├── src
│   ├── __init__.py
│   ├── datasets
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── iteration_based_batch_sampler.py
│   │   │   └── grouped_batch_sampler.py
│   │   ├── __init__.py
│   │   ├── coco_pan.py
│   │   ├── mapillary_pan.py
│   │   ├── cityscapes_pan.py
│   │   ├── cityscapes_sem.py
│   │   ├── augmentations
│   │   │   ├── __init__.py
│   │   │   └── augmentations.py
│   │   ├── base.py
│   │   └── gt_producer.py
│   ├── models
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── FPN.py
│   │   ├── base_model.py
│   │   ├── PFPN_pan.py
│   │   ├── PFPN_d2.py
│   │   ├── __init__.py
│   │   └── panoptic_base.py
│   ├── pcv
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── grid_specs.py
│   │   │   ├── ballot.py
│   │   │   └── snake.py
│   │   ├── inference
│   │   │   └── __init__.py
│   │   ├── gaussian_smooth
│   │   │   ├── __init__.py
│   │   │   ├── vis.py
│   │   │   └── prob_tsr.py
│   │   ├── __init__.py
│   │   ├── pcv_basic.py
│   │   ├── pcv.py
│   │   ├── pcv_boundless.py
│   │   ├── pcv_smooth.py
│   │   ├── pcv_igc.py
│   │   └── pcv_igc_boundless.py
│   ├── optimizer.py
│   ├── scheduler.py
│   ├── utils.py
│   ├── config.py
│   ├── reporters.py
│   ├── metric.py
│   ├── pan_eval.py
│   ├── loss.py
│   ├── pan_vis.py
│   └── pan_analyzer.py
├── .gitignore
├── LICENSE
├── run.py
└── README.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/pcv/components/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/pcv/inference/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/pcv/gaussian_smooth/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import (
2 |     SGD, Adam
3 | )
4 | 
--------------------------------------------------------------------------------
/src/pcv/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import dynamic_load_py_object
2 | 
3 | 
4 | def get_pcv(pcv_cfg):
5 |     module = dynamic_load_py_object(__name__, pcv_cfg.name)
6 |     return module()
7 | 
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from panoptic.utils import dynamic_load_py_object
2 | 
3 | 
4 | def get_dataset_module(dset_name):
5 |     dataset = dynamic_load_py_object(
6 |         package_name=__name__, module_name=dset_name
7 |     )
8 |     return dataset
9 | 
--------------------------------------------------------------------------------
/src/datasets/coco_pan.py:
--------------------------------------------------------------------------------
1 | from .base import PanopticSeg
2 | 
3 | 
4 | class COCO_Pan(PanopticSeg):
5 |     aspect_grouping = True
6 | 
7 |     def __init__(self, *args, **kwargs):
8 |
super().__init__(name='coco', *args, **kwargs) 9 | self.meta['stuff_pred_thresh'] = 4096 10 | -------------------------------------------------------------------------------- /src/datasets/mapillary_pan.py: -------------------------------------------------------------------------------- 1 | from .base import PanopticSeg 2 | 3 | 4 | class Mapillary_Pan(PanopticSeg): 5 | aspect_grouping = False 6 | 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(name='mapillary', *args, **kwargs) 9 | self.meta['stuff_pred_thresh'] = 2048 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | *.egg-info/ 4 | *.egg 5 | *.py[cod] 6 | __pycache__/ 7 | *.so 8 | *~ 9 | 10 | # due to using tox and pytest 11 | .tox 12 | .cache 13 | 14 | # notebook checkpoints 15 | .ipynb_checkpoints 16 | .vscode/ 17 | 18 | # exp folder 19 | exp/ 20 | new_world/ 21 | ablations/ 22 | notebooks/ 23 | test/ 24 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_pan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .base import PanopticSeg 3 | 4 | 5 | class Cityscapes_Pan(PanopticSeg): 6 | aspect_grouping = False 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(name='cityscapes', *args, **kwargs) 10 | self.meta['stuff_pred_thresh'] = 2048 11 | -------------------------------------------------------------------------------- /src/datasets/cityscapes_sem.py: -------------------------------------------------------------------------------- 1 | from panoptic.datasets.base import SemanticSeg 2 | import os.path as osp 3 | 4 | 5 | class Cityscapes_Sem(SemanticSeg): 6 | def __init__(self, split, transforms=None): 7 | super().__init__(name='cityscapes', split=split, transforms=transforms) 8 | self.vanilla_lbl_root = osp.join( 9 | '/share/data/vision-greg/cityscapes/gtFine', split 10 | ) 11 | -------------------------------------------------------------------------------- /src/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import ( 2 | StepLR, MultiStepLR, CosineAnnealingLR, LambdaLR 3 | ) 4 | 5 | 6 | def ups_scheduler(optimizer, milestones, gamma, warmup=0): 7 | assert len(milestones) > 0 8 | assert warmup < milestones[0] 9 | from bisect import bisect_right # do not pollute scheduler.py namespace 10 | 11 | def lr_policy(step): 12 | init_multiplier = 1.0 13 | if step < warmup: 14 | return (0.1 + 0.9 * step / warmup) * init_multiplier 15 | exponent = bisect_right(milestones, step) 16 | return init_multiplier * (gamma ** exponent) 17 | 18 | return LambdaLR(optimizer, lr_lambda=lr_policy) 19 | -------------------------------------------------------------------------------- /src/datasets/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from panoptic.utils import dynamic_load_py_object 2 | 3 | 4 | class Compose(object): 5 | def __init__(self, augmentations): 6 | self.augmentations = augmentations 7 | 8 | def __call__(self, img, mask): 9 | 10 | # assert img.size == mask.size 11 | for a in self.augmentations: 12 | img, mask = a(img, mask) 13 | 14 | return img, mask 15 | 16 | 17 | def get_composed_augmentations(aug_list): 18 | if aug_list is None: 19 | print("Using No Augmentations") 20 | return None 21 | 22 | 
augmentations = [] 23 | for aug_meta in aug_list: 24 | name = aug_meta.name 25 | params = aug_meta.params 26 | aug = dynamic_load_py_object( 27 | package_name=__name__, 28 | module_name='augmentations', obj_name=name 29 | ) 30 | instance = aug(**params) 31 | augmentations.append(instance) 32 | 33 | return Compose(augmentations) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Haochen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/pcv/components/grid_specs.py: -------------------------------------------------------------------------------- 1 | specs = [ 2 | # (size, num_rounds) 3 | [ # 243, 233 bins 4 | (1, 1), (3, 4), (9, 3), (27, 3) 5 | ], 6 | [ # 243 7 | (3, 4), (9, 3), (27, 3) 8 | ], 9 | [ # size 231, 233 bins 10 | (3, 3), (7, 3), (21, 4), 11 | ], 12 | [ # size 243, 41 bins 13 | (1, 1), (3, 1), (9, 1), (27, 1), (81, 1) 14 | ], 15 | [ # size 225, 121 bins 16 | (3, 2), (15, 2), (25, 3) 17 | ], 18 | # [ # size 231, 233 bins 19 | # (3, 2), (5, 2), (7, 2), (21, 4) 20 | # ], 21 | 22 | [ # size 250, 25 * 25 = 625 bins 23 | (15, 7), 24 | ], 25 | ] 26 | 27 | 28 | igc_specs = [ 29 | { 30 | 'base': [(1, 1), (3, 4), (9, 3), (27, 3)], 31 | 'pyramid': { # radius: grid 32 | (9 - 1) // 2: [(1, 1), (3, 4), (9, 3), (27, 3)], 33 | (27 - 1) // 2: [(3, 4), (9, 3), (27, 3)], 34 | (81 - 1) // 2: [(9, 4), (27, 3)], 35 | (243 - 1) // 2: [(27, 4) ] 36 | } 37 | # 'pyramid': { # radius: grid 38 | # (9 - 1) // 2: [(1, 1), (3, 4), (9, 3), (27, 3)], 39 | # (27 - 1) // 2: [(3, 4), (9, 3), (27, 3)], 40 | # (81 - 1) // 2: [(9, 4), (27, 3)], 41 | # (243 - 1) // 2: [(9, 4), (27, 3)] 42 | # } 43 | } 44 | ] 45 | -------------------------------------------------------------------------------- /src/datasets/samplers/iteration_based_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | ''' 3 | HC: this is not used. It creates more problems than it purports to solve. 
4 | ''' 5 | from torch.utils.data.sampler import BatchSampler 6 | 7 | 8 | class IterationBasedBatchSampler(BatchSampler): 9 | """ 10 | Wraps a BatchSampler, resampling from it until 11 | a specified number of iterations have been sampled 12 | """ 13 | 14 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 15 | self.batch_sampler = batch_sampler 16 | self.num_iterations = num_iterations 17 | self.start_iter = start_iter 18 | 19 | def __iter__(self): 20 | iteration = self.start_iter 21 | while iteration <= self.num_iterations: 22 | # if the underlying sampler has a set_epoch method, like 23 | # DistributedSampler, used for making each process see 24 | # a different split of the dataset, then set it 25 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 26 | self.batch_sampler.sampler.set_epoch(iteration) 27 | for batch in self.batch_sampler: 28 | iteration += 1 29 | if iteration > self.num_iterations: 30 | break 31 | yield batch 32 | 33 | def __len__(self): 34 | return self.num_iterations 35 | -------------------------------------------------------------------------------- /src/models/base_model.py: -------------------------------------------------------------------------------- 1 | class BaseModel: 2 | """ 3 | A model should encapsulate everything related to its training and inference. 4 | Only the most generic APIs should be exposed for maximal flexibility 5 | """ 6 | def __init__(self): 7 | self.cfg = None 8 | self.net = None 9 | self.criteria = None 10 | self.optimizer = None 11 | self.scheduler = None 12 | self.log_writer = None 13 | self.total_train_epoch = -1 14 | self.curr_epoch = 0 15 | 16 | self.is_train = None 17 | 18 | def set_train_eval_state(self, to_train): 19 | assert to_train in (True, False) 20 | if self.is_train is None or self.is_train != to_train: 21 | if to_train: 22 | self.net.train() 23 | else: 24 | self.net.eval() 25 | self.is_train = to_train 26 | 27 | def ingest_train_input(self, input): 28 | raise NotImplementedError() 29 | 30 | def infer(self, input): 31 | raise NotImplementedError() 32 | 33 | def optimize_params(self): 34 | raise NotImplementedError() 35 | 36 | def advance_to_next_epoch(self): 37 | raise NotImplementedError() 38 | 39 | def load_latest_checkpoint_if_available(self, manager): 40 | raise NotImplementedError() 41 | 42 | def write_checkpoint(self, manager): 43 | raise NotImplementedError() 44 | 45 | def add_log_writer(self, log_writer): 46 | self.log_writer = log_writer 47 | 48 | def log_statistics(self, step, level=0): 49 | pass 50 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import datetime 3 | from fabric.utils.timer import Timer 4 | 5 | 6 | def dynamic_load_py_object( 7 | package_name, module_name, obj_name=None 8 | ): 9 | '''Dynamically import an object. 10 | Assumes that the object lives at module_name.py:obj_name 11 | If obj_name is not given, assume it shares the same name as the module. 12 | obj_name is case insensitive e.g. 
kitti.py/KiTTI is valid 13 | ''' 14 | if obj_name is None: 15 | obj_name = module_name 16 | # use relative import syntax .targt_name 17 | target_module = importlib.import_module( 18 | '.{}'.format(module_name), package=package_name 19 | ) 20 | target_obj = None 21 | for name, cls in target_module.__dict__.items(): 22 | if name.lower() == obj_name.lower(): 23 | target_obj = cls 24 | 25 | if target_obj is None: 26 | raise ValueError( 27 | "No object in {}.{}.py whose lower-case name matches {}".format( 28 | package_name, module_name, obj_name) 29 | ) 30 | 31 | return target_obj 32 | 33 | 34 | class CompoundTimer(): 35 | def __init__(self): 36 | self.data, self.compute = Timer(), Timer() 37 | 38 | def eta(self, curr_step, tot_step): 39 | remaining_steps = tot_step - curr_step 40 | avg = self.data.avg + self.compute.avg 41 | eta_seconds = avg * remaining_steps 42 | return str(datetime.timedelta(seconds=int(eta_seconds))) 43 | 44 | def __repr__(self): 45 | return "data avg {:.2f}, compute avg {:.2f}".format( 46 | self.data.avg, self.compute.avg) 47 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | from pprint import pformat 3 | from fabric.cluster.configurator import Configurator 4 | from fabric.deploy.sow import NodeTracer 5 | 6 | 7 | _ERR_MSG_UNINITED_ = 'global config state not inited' 8 | 9 | 10 | class _CFG(): 11 | def __init__(self): 12 | self.manager = None 13 | self.settings = {} 14 | self.inited = False 15 | 16 | def init_state(self, cfg_yaml_path, override_opts): 17 | self.manager = Configurator(cfg_yaml_path) 18 | self.settings = edict(self.manager.config) 19 | if override_opts is not None: 20 | self.merge_from_list(override_opts) 21 | self.inited = True 22 | 23 | def __getattr__(self, name): 24 | assert self.inited, _ERR_MSG_UNINITED_ 25 | return getattr(self.settings, name) 26 | 27 | def __str__(self): 28 | assert self.inited, _ERR_MSG_UNINITED_ 29 | return pformat(self.settings) 30 | 31 | def merge_from_list(self, cfg_list): 32 | """Merge config (keys, values) in a list (e.g., from command line) into 33 | this cfg. For example, `cfg_list = ['FOO.BAR', 0.5]`. 34 | """ 35 | assert len(cfg_list) % 2 == 0, \ 36 | "Override list has odd length: {}; it must be a list of pairs".format( 37 | cfg_list 38 | ) 39 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 40 | self.trace_and_replace(full_key, v) 41 | 42 | def trace_and_replace(self, key, val): 43 | tracer = NodeTracer(self.settings) 44 | tracer.advance_pointer(key) 45 | tracer.replace([], val) 46 | self.settings = tracer.state 47 | 48 | 49 | _C = _CFG() 50 | cfg = _C 51 | -------------------------------------------------------------------------------- /src/pcv/pcv_basic.py: -------------------------------------------------------------------------------- 1 | from .pcv import PCV_base 2 | from .components.snake import Snake 3 | from .components.ballot import Ballot 4 | from .components.grid_specs import specs 5 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate 6 | 7 | from .. 
import cfg 8 | 9 | 10 | class PCV_Basic(PCV_base): 11 | def __init__(self): 12 | spec = specs[cfg.pcv.grid_inx] 13 | self.num_groups = cfg.pcv.num_groups 14 | self.centroid_mode = cfg.pcv.centroid 15 | self.raw_spec = spec 16 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec) 17 | self.grid_spec = grid_spec 18 | self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec) 19 | self._ballot_module = None 20 | 21 | @property 22 | def ballot_module(self): 23 | # instantiate on demand to prevent train time data loading workers to 24 | # hold on to GPU memory 25 | if self._ballot_module is None: 26 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda() 27 | return self._ballot_module 28 | 29 | # 1 for bull's eye center, 1 for abstain vote 30 | @property 31 | def num_bins(self): 32 | return len(self.grid_spec) 33 | 34 | @property 35 | def num_votes(self): 36 | return 1 + self.num_bins 37 | 38 | @property 39 | def vote_mask(self): 40 | return self._vote_mask 41 | 42 | @property 43 | def query_mask(self): 44 | return super().query_mask 45 | 46 | def centroid_from_ins_mask(self, ins_mask): 47 | return super().centroid_from_ins_mask(ins_mask) 48 | 49 | def discrete_vote_inx_from_offset(self, offset): 50 | return self._discretize_offset(self.vote_mask, offset) 51 | 52 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred): 53 | # make the meta data actually required explicit!! 54 | if self.num_groups == 1: 55 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred) 56 | else: 57 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred) 58 | pan_mask, meta = mfv.infer_panoptic_mask() 59 | return pan_mask, meta 60 | -------------------------------------------------------------------------------- /src/models/components/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models import resnet50, resnet101, resnext101_32x8d 4 | 5 | def set_bn_eval(m): 6 | """freeze batch norms, per RT's method""" 7 | classname = m.__class__.__name__ 8 | if classname.find('BatchNorm') != -1: 9 | m.eval() 10 | for _p in m.parameters(): 11 | _p.requires_grad = False 12 | 13 | 14 | def set_conv_stride_to_1(m): 15 | """ 16 | A ResNet Bottleneck has 2 types of modules potentially with stride 2 17 | conv2: 3x3 conv used to reduce the spatial dimension of features 18 | downsample[0]: 1x1 conv used to change shape of original in case the learnt 19 | residual has incompatible channel or spatial dim 20 | It does not make sense to give this 1x1 downsample conv dilation 2. 
21 |     Hence the if condition testing for conv kernel size
22 |     """
23 |     classname = m.__class__.__name__
24 |     if classname.find('Conv') != -1 and m.stride == (2, 2):
25 |         m.stride = (1, 1)
26 |         if m.kernel_size == (3, 3):
27 |             m.dilation = (2, 2)
28 |             m.padding = (2, 2)
29 | 
30 | 
31 | class ResNetFeatureExtractor(nn.Module):
32 |     def __init__(self, resnet):
33 |         super().__init__()
34 |         layer0 = nn.Sequential(
35 |             resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool
36 |         )
37 |         self.layers = nn.ModuleList([
38 |             layer0,
39 |             resnet.layer1, resnet.layer2,
40 |             resnet.layer3, resnet.layer4
41 |         ])
42 |         self.extractor = nn.Sequential(*self.layers)
43 | 
44 |     def forward(self, input):
45 |         return self.extractor(input)
46 | 
47 |     def train(self, mode=True):
48 |         if mode:
49 |             super().train(True)
50 |             # self.apply(set_bn_eval)
51 |         else:
52 |             super().train(False)
53 | 
54 | 
55 | def resnet50_gn(pretrained=True):
56 |     model = resnet50(norm_layer=lambda x: nn.GroupNorm(32, x))
57 |     model.load_state_dict(torch.load('/share/data/vision-greg/rluo/model/pytorch-resnet/resnet_gn50-pth.pth'))
58 |     return model
59 | 
60 | 
61 | def resnet101_gn(pretrained=True):
62 |     model = resnet101(norm_layer=lambda x: nn.GroupNorm(32, x))
63 |     model.load_state_dict(torch.load('/share/data/vision-greg/rluo/model/pytorch-resnet/resnet_gn101-pth.pth'))
64 |     return model
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     from torchvision.models import resnet18
69 |     model = ResNetFeatureExtractor(resnet18(pretrained=False))
70 |     input = torch.rand(size=(1, 3, 64, 64))
71 |     with torch.no_grad():
72 |         out = model(input)
73 |     print(out.shape)
74 | 
--------------------------------------------------------------------------------
/src/pcv/pcv.py:
--------------------------------------------------------------------------------
1 | """Module for dilation based pixel consensus voting.
2 | For now hardcode the 3x3 voting kernel and see how it fares.
3 | """
4 | from abc import ABCMeta, abstractmethod, abstractproperty
5 | import numpy as np
6 | from scipy.ndimage.measurements import center_of_mass
7 | from ..box_and_mask import get_xywh_bbox_from_binary_mask
8 | from ..
import cfg 9 | 10 | 11 | class PCV_base(metaclass=ABCMeta): 12 | def __init__(self): 13 | # store the necessary modules 14 | pass 15 | 16 | @abstractproperty 17 | def num_bins(self): 18 | pass 19 | 20 | @abstractproperty 21 | def num_votes(self): 22 | pass 23 | 24 | @abstractproperty 25 | def vote_mask(self): 26 | pass 27 | 28 | @abstractproperty 29 | def query_mask(self): 30 | """ 31 | Flipped from inside out 32 | """ 33 | diam = len(self.vote_mask) 34 | radius = (diam - 1) // 2 35 | center = (radius, radius) 36 | mask_shape = self.vote_mask.shape 37 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1] 38 | offsets = center - offset_grid 39 | allegiance = self.discrete_vote_inx_from_offset( 40 | offsets.reshape(-1, 2) 41 | ).reshape(mask_shape) 42 | return allegiance 43 | 44 | @abstractmethod 45 | def centroid_from_ins_mask(self, ins_mask): 46 | mode = self.centroid_mode 47 | assert mode in ('bbox', 'cmass') 48 | if mode == 'bbox': 49 | bbox = get_xywh_bbox_from_binary_mask(ins_mask) 50 | x, y, w, h = bbox 51 | return [x + w // 2, y + h // 2] 52 | else: 53 | y, x = center_of_mass(ins_mask) 54 | x, y = int(x), int(y) 55 | return [x, y] 56 | 57 | @abstractmethod 58 | def discrete_vote_inx_from_offset(self, offset): 59 | pass 60 | 61 | @staticmethod 62 | def _discretize_offset(vote_mask, offset): 63 | """ 64 | Args: 65 | offset: [N, 2] array of offset towards each pixel's own center, 66 | Each row is filled with (x, y) pair, not (y, x)! 67 | """ 68 | shape = offset.shape 69 | assert len(shape) == 2 and shape[1] == 2 70 | offset = offset[:, ::-1] # swap to (y, x) for indexing 71 | 72 | diam = len(vote_mask) 73 | radius = (diam - 1) // 2 74 | center = (radius, radius) 75 | coord = offset + center 76 | del offset 77 | 78 | ret = -1 * np.ones(len(coord), dtype=np.int32) 79 | valid_inds = np.where( 80 | (coord[:, 0] >= 0) & (coord[:, 0] < diam) 81 | & (coord[:, 1] >= 0) & (coord[:, 1] < diam) 82 | )[0] 83 | _y_inds, _x_inds = coord[valid_inds].T 84 | vals = vote_mask[_y_inds, _x_inds] 85 | ret[valid_inds] = vals 86 | return ret 87 | 88 | @abstractmethod 89 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred): 90 | pass 91 | -------------------------------------------------------------------------------- /src/pcv/pcv_boundless.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .pcv import PCV_base 3 | from .components.snake import Snake 4 | from .components.ballot import Ballot 5 | from .components.grid_specs import specs 6 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate 7 | 8 | from .. 
import cfg 9 | 10 | 11 | class PCV_Boundless(PCV_base): 12 | def __init__(self): 13 | spec = specs[cfg.pcv.grid_inx] 14 | self.num_groups = cfg.pcv.num_groups 15 | self.centroid_mode = cfg.pcv.centroid 16 | self.raw_spec = spec 17 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec) 18 | self.grid_spec = grid_spec 19 | self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec) 20 | self.boundless_vote_mask = \ 21 | Snake.paint_bound_ignore_trail_mask(field_diam, grid_spec) 22 | self._ballot_module = None 23 | 24 | @property 25 | def ballot_module(self): 26 | # instantiate on demand to prevent train time data loading workers to 27 | # hold on to GPU memory 28 | if self._ballot_module is None: 29 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda() 30 | return self._ballot_module 31 | 32 | # 1 for bull's eye center, 1 for abstain vote 33 | @property 34 | def num_bins(self): 35 | return len(self.grid_spec) 36 | 37 | @property 38 | def num_votes(self): 39 | return 1 + self.num_bins 40 | 41 | @property 42 | def vote_mask(self): 43 | return self._vote_mask 44 | 45 | @property 46 | def query_mask(self): 47 | """ 48 | Flipped from inside out 49 | """ 50 | diam = len(self.vote_mask) 51 | radius = (diam - 1) // 2 52 | center = (radius, radius) 53 | mask_shape = self.vote_mask.shape 54 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1] 55 | offsets = center - offset_grid 56 | # allegiance = self.discrete_vote_inx_from_offset( 57 | # offsets.reshape(-1, 2) 58 | # ).reshape(mask_shape) 59 | allegiance = self._discretize_offset( 60 | self.vote_mask, offsets.reshape(-1, 2) 61 | ).reshape(mask_shape) 62 | return allegiance 63 | 64 | def centroid_from_ins_mask(self, ins_mask): 65 | return super().centroid_from_ins_mask(ins_mask) 66 | 67 | def discrete_vote_inx_from_offset(self, offset): 68 | # note that when assigning votes, use the boundary ignore vote mask 69 | return self._discretize_offset(self.boundless_vote_mask, offset) 70 | 71 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred): 72 | # make the meta data actually required explicit!! 
73 | if self.num_groups == 1: 74 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred) 75 | else: 76 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred) 77 | pan_mask, meta = mfv.infer_panoptic_mask() 78 | return pan_mask, meta 79 | -------------------------------------------------------------------------------- /src/models/PFPN_pan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from panoptic.models.panoptic_base import PanopticBase 5 | from panoptic.models.components.resnet import ( 6 | ResNetFeatureExtractor, resnet50, resnet101, resnet50_gn, resnet101_gn 7 | ) 8 | from panoptic.models.components.FPN import FPN 9 | 10 | 11 | primary_backbones = { 12 | 'resnet50': resnet50, 13 | 'resnet101': resnet101, 14 | 'resnet50_gn': resnet50_gn, 15 | 'resnet101_gn': resnet101_gn 16 | } 17 | 18 | 19 | class PFPN_pan(PanopticBase): 20 | def instantiate_network(self, cfg): 21 | self.net = Net( 22 | num_classes=self.dset_meta['num_classes'], num_votes=self.pcv.num_votes, 23 | **cfg.model.params 24 | ) 25 | 26 | 27 | class Net(nn.Module): 28 | def __init__(self, num_classes, num_votes, FPN_C, backbone='resnet50', deep_classifier=False): 29 | super().__init__() 30 | backbone_f = primary_backbones[backbone] 31 | fea_extractor = ResNetFeatureExtractor(backbone_f(pretrained=True)) 32 | self.distilled_fpn = FPN(fea_extractor, FPN_C) 33 | if deep_classifier: 34 | self.sem_classifier = nn.Sequential( 35 | nn.Conv2d(FPN_C, 256, kernel_size=3, padding=1), 36 | nn.ReLU(True), 37 | nn.Dropout(0.1), 38 | nn.Conv2d(256, num_classes, kernel_size=1, bias=True) 39 | ) 40 | self.vote_classifier = nn.Sequential( 41 | nn.Conv2d(FPN_C, 256, kernel_size=3, padding=1), 42 | nn.ReLU(True), 43 | nn.Dropout(0.1), 44 | nn.Conv2d(256, num_votes, kernel_size=1, bias=True) 45 | ) 46 | else: 47 | self.sem_classifier = nn.Conv2d( 48 | FPN_C, num_classes, kernel_size=1, bias=True 49 | ) 50 | self.vote_classifier = nn.Conv2d( 51 | FPN_C, num_votes, kernel_size=1, bias=True 52 | ) 53 | # self.sem_classifier = nn.Conv2d( 54 | # FPN_C, num_classes, kernel_size=1, bias=True 55 | # ) 56 | # self.vote_classifier = nn.Conv2d( 57 | # FPN_C, num_votes, kernel_size=1, bias=True 58 | # ) 59 | self.loss_module = None 60 | 61 | def forward(self, *inputs): 62 | """ 63 | If gt is supplied, then compute loss 64 | """ 65 | x = inputs[0] 66 | x = self.distilled_fpn(x) 67 | sem_pred = self.sem_classifier(x) 68 | vote_pred = self.vote_classifier(x) 69 | 70 | if len(inputs) > 1: 71 | assert self.loss_module is not None 72 | loss = self.loss_module(sem_pred, vote_pred, *inputs[1:]) 73 | return loss 74 | else: 75 | return sem_pred, vote_pred 76 | 77 | 78 | if __name__ == '__main__': 79 | model = Net() 80 | input = torch.rand(size=(1, 3, 64, 64)) 81 | with torch.no_grad(): 82 | out = model(input) 83 | print(out.shape) 84 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import random 5 | 6 | import torch 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | 10 | import panoptic 11 | from panoptic.entry import Entry 12 | from fabric.utils.mailer import ExceptionEmail 13 | from fabric.utils.logging import setup_logging 14 | from fabric.utils.git import git_version 15 | 16 | logger = setup_logging(__file__) 17 | 18 | 19 
| def main(): 20 | parser = argparse.ArgumentParser(description="run script") 21 | parser.add_argument('--command', '-c', type=str, default='train') 22 | parser.add_argument('--debug', action='store_true') 23 | parser.add_argument('--dist', action='store_true') 24 | parser.add_argument( 25 | "opts", 26 | help="Modify config options using the command-line", 27 | default=None, 28 | nargs=argparse.REMAINDER, 29 | ) 30 | args = parser.parse_args() 31 | 32 | git_hash = git_version(osp.dirname(panoptic.__file__)) 33 | logger.info('git repository hash: {}'.format(git_hash)) 34 | 35 | global fly 36 | # only decorate when not in interative mode; Bugs are expected there. 37 | if 'INTERACTIVE' not in os.environ: 38 | recipient = 'whc@ttic.edu' 39 | logger.info("decorating with warning email to {}".format(recipient)) 40 | email_subject_headline = "{} tripped".format( 41 | # take the immediate dirname as email label 42 | osp.dirname(osp.abspath(__file__)).split('/')[-1] 43 | ) 44 | fly = ExceptionEmail( 45 | subject=email_subject_headline, address=recipient 46 | )(fly) 47 | 48 | ngpus = torch.cuda.device_count() 49 | port = random.randint(10000, 20000) 50 | argv = (ngpus, args.command, args.debug, args.opts, port) 51 | if args.dist: 52 | mp.spawn(fly, nprocs=ngpus, args=argv) 53 | else: 54 | fly(None, *argv) 55 | 56 | 57 | def fly(rank, ngpus, command, debug, opts, port): 58 | distributed = rank is not None # and not debug 59 | if distributed: # multiprocess distributed training 60 | dist.init_process_group( 61 | world_size=ngpus, rank=rank, 62 | backend='nccl', init_method=f'tcp://127.0.0.1:{port}', 63 | ) 64 | assert command == 'train' # for now only train uses mp distributed 65 | torch.cuda.set_device(rank) 66 | 67 | entry = Entry( 68 | __file__, override_opts=opts, debug=debug, 69 | mp_distributed=distributed, rank=rank, world_size=ngpus 70 | ) 71 | if command == 'train': 72 | entry.train() 73 | elif command == 'validate': # for evaluate semantic segmentation mean iou 74 | entry.validate(False) 75 | elif command == 'evaluate': 76 | entry.evaluate() 77 | elif command == 'report': 78 | entry.report() 79 | elif command == 'test': 80 | entry.PQ_test(save_output=True) 81 | elif command == 'make_figures': 82 | entry.make_figures() 83 | else: 84 | raise ValueError("unrecognized command") 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /src/models/PFPN_d2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from panoptic.models.panoptic_base import PanopticBase, AbstractNet 5 | from panoptic.models.components.resnet import ( 6 | ResNetFeatureExtractor, resnet50, resnet101, resnet50_gn, resnet101_gn 7 | ) 8 | from panoptic.models.components.FPN import FPN 9 | import detectron2.config 10 | import detectron2.modeling 11 | from detectron2.modeling.backbone import build_backbone 12 | from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head 13 | from detectron2.checkpoint import DetectionCheckpointer 14 | 15 | from panoptic import cfg 16 | 17 | primary_backbones = { 18 | 'resnet50': resnet50, 19 | 'resnet101': resnet101, 20 | 'resnet50_gn': resnet50_gn, 21 | 'resnet101_gn': resnet101_gn 22 | } 23 | 24 | 25 | class PFPN_d2(PanopticBase): 26 | def instantiate_network(self, cfg): 27 | self.net = Net( 28 | num_classes=self.dset_meta['num_classes'], num_votes=self.pcv.num_votes, 29 | 
**cfg.model.params 30 | ) 31 | 32 | 33 | class Net(AbstractNet): 34 | def __init__(self, num_classes, num_votes, 35 | fix_bn=True, 36 | freeze_at=2, 37 | norm='GN', 38 | fpn_norm='', 39 | conv_dims=128, 40 | **kwargs): # FPN_C, backbone='resnet50'): 41 | super().__init__() 42 | d2_cfg = detectron2.config.get_cfg() 43 | d2_cfg.merge_from_file('/home-nfs/whc/glab/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml') 44 | if fix_bn: 45 | d2_cfg.MODEL.RESNETS.NORM = "FrozenBN" 46 | else: 47 | d2_cfg.MODEL.RESNETS.NORM = "BN" 48 | d2_cfg.MODEL.BACKBONE.FREEZE_AT = freeze_at 49 | 50 | d2_cfg.MODEL.FPN.NORM = 'BN' 51 | self.backbone = build_backbone(d2_cfg) 52 | 53 | d2_cfg.MODEL.SEM_SEG_HEAD.NORM = norm 54 | d2_cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = conv_dims 55 | 56 | d2_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = num_classes 57 | self.sem_seg_head = build_sem_seg_head(d2_cfg, self.backbone.output_shape()) 58 | self.sem_classifier = self.sem_seg_head.predictor 59 | del self.sem_seg_head.predictor 60 | 61 | d2_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = num_votes 62 | tmp = build_sem_seg_head(d2_cfg, self.backbone.output_shape()) 63 | self.vote_classifier = tmp.predictor 64 | assert cfg.data.dataset.params.caffe_mode == True 65 | 66 | checkpointer = DetectionCheckpointer(self) 67 | checkpointer.load(d2_cfg.MODEL.WEIGHTS) 68 | 69 | def stage1(self, x): 70 | return self.backbone(x) 71 | 72 | def stage2(self, features): 73 | # copy from detectron2/modeling/meta_arch/semantic_seg.py 74 | # why? because segFPNHead doesn't accept training==True and targets==None 75 | for i, f in enumerate(self.sem_seg_head.in_features): 76 | if i == 0: 77 | x = self.sem_seg_head.scale_heads[i](features[f]) 78 | else: 79 | x = x + self.sem_seg_head.scale_heads[i](features[f]) 80 | # x = self.sem_seg_head.predictor(x) 81 | return x, x 82 | # x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False) 83 | 84 | 85 | if __name__ == '__main__': 86 | model = Net(10,10) 87 | model.eval() 88 | input = torch.rand(size=(1, 3, 224, 224)) 89 | with torch.no_grad(): 90 | out = model(input) 91 | print(out[0].shape, out[1].shape) 92 | -------------------------------------------------------------------------------- /src/pcv/pcv_smooth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .pcv import PCV_base 4 | from .components.snake import Snake 5 | from .components.ballot import Ballot 6 | from .components.grid_specs import specs 7 | from .gaussian_smooth.prob_tsr import MakeProbTsr 8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate 9 | 10 | from .. 
import cfg
11 | 
12 | 
13 | def flesh_out_grid(spec):
14 |     field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
15 |     vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
16 |     return vote_mask
17 | 
18 | 
19 | def flesh_out_spec(spec_group):
20 |     ret = {'base': None, 'pyramid': dict()}
21 |     ret['base'] = flesh_out_grid(spec_group['base'])
22 |     for k, v in spec_group['pyramid'].items():
23 |         ret['pyramid'][k] = flesh_out_grid(v)
24 |     return ret
25 | 
26 | 
27 | class PCV_Smooth(PCV_base):
28 |     def __init__(self):  # no args, so get_pcv can instantiate it like its siblings
29 |         self.num_groups = cfg.pcv.num_groups
30 |         self.spec = specs[cfg.pcv.grid_inx]
31 |         self.centroid_mode = cfg.pcv.centroid
32 |         field_diam, grid_spec = Snake.flesh_out_grid_spec(self.spec)
33 |         self.grid_spec = grid_spec
34 |         self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
35 |         self._ballot_module = None
36 | 
37 |         maker = MakeProbTsr(
38 |             self.spec, field_diam, grid_spec, self.vote_mask, var=0.05
39 |         )
40 |         self.spatial_prob_tsr = maker.compute_voting_prob_tsr(normalize=True)
41 | 
42 |     @property
43 |     def ballot_module(self):
44 |         # instantiate on demand to keep train-time data loading workers
45 |         # from holding on to GPU memory
46 |         if self._ballot_module is None:
47 |             self._ballot_module = Ballot(self.spec, self.num_groups).cuda()
48 |         return self._ballot_module
49 | 
50 |     # 1 for bull's eye center, 1 for abstain vote
51 |     @property
52 |     def num_bins(self):
53 |         return len(self.grid_spec)
54 | 
55 |     @property
56 |     def num_votes(self):
57 |         return 1 + self.num_bins
58 | 
59 |     @property
60 |     def vote_mask(self):
61 |         return self._vote_mask
62 | 
63 |     @property
64 |     def query_mask(self):
65 |         return super().query_mask
66 | 
67 |     def centroid_from_ins_mask(self, ins_mask):
68 |         return super().centroid_from_ins_mask(ins_mask)
69 | 
70 |     def discrete_vote_inx_from_offset(self, offset):
71 |         return self._discretize_offset(self.vote_mask, offset)
72 | 
73 |     def smooth_prob_tsr_from_offset(self, offset):
74 |         """
75 |         Args:
76 |             offset: [N, 2] array of offset towards each pixel's own center,
77 |                 Each row is filled with (x, y) pair, not (y, x)!
78 |         Returns:
79 |             vote_tsr: [N, num_votes] of float tsr with each entry being a prob
80 |         """
81 |         shape = offset.shape
82 |         assert len(shape) == 2 and shape[1] == 2
83 |         offset = offset[:, ::-1]  # swap to (y, x) for indexing
84 | 
85 |         diam = len(self.spatial_prob_tsr)
86 |         radius = (diam - 1) // 2
87 |         center = (radius, radius)
88 |         coord = offset + center
89 | 
90 |         # [N, num_votes]
91 |         tsr = np.zeros((len(offset), self.num_votes), dtype=np.float32)
92 |         valid_inds = np.where(
93 |             (coord[:, 0] >= 0) & (coord[:, 0] < diam)
94 |             & (coord[:, 1] >= 0) & (coord[:, 1] < diam)
95 |         )[0]
96 |         _y_inds, _x_inds = coord[valid_inds].T
97 |         tsr[valid_inds] = self.spatial_prob_tsr[_y_inds, _x_inds]
98 |         return tsr
99 | 
100 |     def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
101 |         # make the meta data actually required explicit!!
102 | if self.num_groups == 1: 103 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred) 104 | else: 105 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred) 106 | pan_mask, meta = mfv.infer_panoptic_mask() 107 | return pan_mask, meta 108 | -------------------------------------------------------------------------------- /src/pcv/pcv_igc.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_left 2 | import numpy as np 3 | 4 | from .pcv import PCV_base 5 | from .components.snake import Snake 6 | from .components.ballot import Ballot 7 | from .components.grid_specs import igc_specs 8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate 9 | 10 | from .. import cfg 11 | 12 | 13 | def flesh_out_grid(spec): 14 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec) 15 | vote_mask = Snake.paint_trail_mask(field_diam, grid_spec) 16 | return vote_mask 17 | 18 | 19 | def flesh_out_spec(spec_group): 20 | ret = {'base': None, 'pyramid': dict()} 21 | ret['base'] = flesh_out_grid(spec_group['base']) 22 | for k, v in spec_group['pyramid'].items(): 23 | ret['pyramid'][k] = flesh_out_grid(v) 24 | return ret 25 | 26 | 27 | class PCV_IGC(PCV_base): 28 | def __init__(self): 29 | # grid inx for now is a dummy 30 | spec_group = igc_specs[cfg.pcv.grid_inx] 31 | self.num_groups = cfg.pcv.num_groups 32 | self.centroid_mode = cfg.pcv.centroid 33 | self.raw_spec = spec_group['base'] 34 | _, self.grid_spec = Snake.flesh_out_grid_spec(self.raw_spec) 35 | self.grid_group = flesh_out_spec(spec_group) 36 | self._vote_mask = self.grid_group['base'] 37 | self._ballot_module = None 38 | self.coalesce_thresh = list(self.grid_group['pyramid'].keys()) 39 | 40 | @property 41 | def ballot_module(self): 42 | # instantiate on demand to prevent train time data loading workers to 43 | # hold on to GPU memory 44 | if self._ballot_module is None: 45 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda() 46 | return self._ballot_module 47 | 48 | # 1 for bull's eye center, 1 for abstain vote 49 | @property 50 | def num_bins(self): 51 | return len(self.grid_spec) 52 | 53 | @property 54 | def num_votes(self): 55 | return 1 + self.num_bins 56 | 57 | @property 58 | def vote_mask(self): 59 | return self._vote_mask 60 | 61 | @property 62 | def query_mask(self): 63 | return super().query_mask 64 | 65 | def centroid_from_ins_mask(self, ins_mask): 66 | return super().centroid_from_ins_mask(ins_mask) 67 | 68 | def discrete_vote_inx_from_offset(self, offset): 69 | return self._discretize_offset(self.vote_mask, offset) 70 | 71 | def tensorized_vote_from_offset(self, offset): 72 | """ 73 | Args: 74 | offset: [N, 2] array of offset towards each pixel's own center, 75 | Each row is filled with (x, y) pair, not (y, x)! 76 | Returns: 77 | vote_tsr: [N, num_votes] of bool tsr where 0/1 denotes gt entries. 
78 | """ 79 | # dispatch to the proper grid 80 | base_mask = self.grid_group['base'] 81 | radius = (len(base_mask) - 1) // 2 82 | max_offset = min(radius, np.abs(offset).max()) 83 | inx = bisect_left(self.coalesce_thresh, max_offset) 84 | key = self.coalesce_thresh[inx] 85 | vote_mask = self.grid_group['pyramid'][key] 86 | 87 | tsr = np.zeros((len(offset), self.num_votes), dtype=np.bool) # [N, num_votes] 88 | gt_indices = self._discretize_offset(vote_mask, offset) # [N, ] 89 | for inx in np.unique(gt_indices): 90 | if inx == -1: 91 | continue 92 | entries = np.unique(base_mask[vote_mask == inx]) 93 | inds = np.where(gt_indices == inx)[0] 94 | tsr[inds.reshape(-1, 1), entries.reshape(1, -1)] = True 95 | return tsr 96 | 97 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred): 98 | # make the meta data actually required explicit!! 99 | if self.num_groups == 1: 100 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred) 101 | else: 102 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred) 103 | pan_mask, meta = mfv.infer_panoptic_mask() 104 | return pan_mask, meta 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pixel Consensus Voting (CVPR 2020) 2 | 3 |
4 | *(figure omitted: backprojection visualization)*
5 | *(figure omitted: results panel)*
6 | 
7 | 
8 | 
9 | The core of our approach, Pixel Consensus Voting, is a framework for instance segmentation based on the Generalized Hough transform. Pixels cast discretized, probabilistic votes for the likely regions that contain instance centroids. At the detected peaks that emerge in the voting heatmap, backprojection is applied to collect pixels and produce instance masks. Unlike a sliding window detector that densely enumerates object proposals, our method detects instances as a result of the consensus among pixel-wise votes. We implement vote aggregation and backprojection using native operators of a convolutional neural network. The discretization of centroid voting reduces the training of instance segmentation to pixel labeling, analogous and complementary to FCN-style semantic segmentation, leading to an efficient and unified architecture that jointly models things and stuff. We demonstrate the effectiveness of our pipeline on COCO and Cityscapes Panoptic Segmentation and obtain competitive results.
10 | 
11 | ## Quick Intro
12 | - The codebase contains the essential ingredients of PCV, including various spatial discretization schemes and convolutional backprojection inference. The network backbone is a simple FPN on ResNet.
13 | - Visualizer 1 ([vis.py](src/vis.py)): loads a single image into a dynamic, interactive interface that allows users to click on pixels to inspect model predictions. It is built on the matplotlib interactive API and Jupyter widgets. Under the hood it's React.
14 | - Visualizer 2 ([pan_vis.py](src/pan_vis.py)): a global inspector that takes panoptic segmentation predictions and displays predicted segments against the ground truth. Useful for tracking down which images make the most serious errors, and how.
15 | 
16 | - The core of PCV is contained in [src/pcv](src/pcv). The results reported in the paper use [src/pcv/pcv_basic](src/pcv/pcv_basic.py). There are also a few modification ideas that didn't work out, e.g. "inner grid collapse" ([src/pcv/pcv_igc](src/pcv/pcv_igc.py)), erased boundary loss ([src/pcv/pcv_boundless](src/pcv/pcv_boundless.py)), and smoothed gt assignment ([src/pcv/pcv_smooth](src/pcv/pcv_smooth.py)).
17 | - The deconv voting filter weight initialization is in [src/pcv/components/ballot.py](src/pcv/components/ballot.py). Different deconv discretization schemes can be found in [src/pcv/components/grid_specs.py](src/pcv/components/grid_specs.py). [src/pcv/components/snake.py](src/pcv/components/snake.py) manages the generation of the snake grid on which PCV operates.
18 | 
19 | - The backprojection code is in [src/pcv/inference/mask_from_vote.py](src/pcv/inference/mask_from_vote.py). Since this is a non-standard procedure (convolving a filter to perform equality comparison), I implemented a simple conv using advanced indexing. See the function [src/pcv/inference/mask_from_vote.py:unroll_img_inds](src/pcv/inference/mask_from_vote.py#L110-L119). For a fun side project, I am thinking about rewriting the backprojection in Julia and generating GPU code (PTX) directly through LLVM. That way we don't have to deal with CUDA kernels that are hard to maintain.
20 | - The main entry points are [run.py](run.py) and [src/entry.py](src/entry.py).
21 | 
22 | 
23 | ## Getting Started
24 | 
25 | ### Dependencies
26 | - pytorch==1.4.0
27 | - fabric (personal toolkit that needs to be re-factored in)
28 | - pycocotools
29 | 
30 | ~~~python
31 | python run.py -c train
32 | python run.py -c evaluate
33 | ~~~
34 | runs the default PCV configuration reported in Table 3 of the paper.
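
Config options can also be overridden from the command line: `run.py` forwards any trailing arguments to `merge_from_list` in [src/config.py](src/config.py), which consumes them as dotted-key/value pairs. A minimal sketch (the values here are hypothetical; the keys must exist in the experiment's YAML config, and `pcv.grid_inx` / `pcv.num_groups` are examples read by [src/pcv/pcv_basic.py](src/pcv/pcv_basic.py)):

~~~python
# train while overriding two config entries; each override is a
# dotted key followed by its new value, consumed in pairs
python run.py -c train pcv.grid_inx 0 pcv.num_groups 1

# multi-GPU training: run.py spawns one process per visible GPU and
# initializes a NCCL process group on a random local port
python run.py --dist -c train
~~~

Besides `train` and `evaluate`, `run.py` also accepts `-c validate` (semantic segmentation mean IoU), `-c report`, `-c test`, and `-c make_figures`.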
35 | 36 | 37 | ## Bibtex 38 | 39 | ```bibtex 40 | @inproceedings{pcv2020, 41 | title={Pixel consensus voting for panoptic segmentation}, 42 | author={Wang, Haochen and Luo, Ruotian and Maire, Michael and Shakhnarovich, Greg}, 43 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 44 | pages={9464--9473}, 45 | year={2020} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /src/pcv/components/ballot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .snake import Snake 6 | 7 | 8 | class Ballot(nn.Module): 9 | def __init__(self, spec, num_groups): 10 | super().__init__() 11 | self.num_groups = num_groups 12 | votes = [] 13 | smears = [] 14 | 15 | splits = [] 16 | acc_size = spec[0][0] # inner most size 17 | for i, (size, num_rounds) in enumerate(spec): 18 | inner_blocks = acc_size // size 19 | total_blocks = inner_blocks + 2 * num_rounds 20 | acc_size += 2 * num_rounds * size 21 | 22 | prepend = (i == 0) 23 | num_in_chnls = (total_blocks ** 2 - inner_blocks ** 2) + int(prepend) 24 | num_in_chnls *= num_groups 25 | num_out_chnls = 1 * num_groups 26 | splits.append(num_in_chnls) 27 | deconv_vote = nn.ConvTranspose2d( 28 | in_channels=num_in_chnls, out_channels=num_out_chnls, 29 | kernel_size=total_blocks, stride=1, 30 | padding=(total_blocks - 1) // 2 * size, dilation=size, 31 | groups=num_groups, bias=False 32 | ) 33 | # each group [ num_chnl, 1, 3, 3 ] -> [num_groups * num_chnl, 1, 3, 3] 34 | throw_away = 0 if prepend else inner_blocks 35 | weight = torch.cat([ 36 | get_voting_deonv_kernel_weight( 37 | side=total_blocks, throw_away=throw_away 38 | ) 39 | for _ in range(num_groups) 40 | ], dim=0) 41 | deconv_vote.weight.data.copy_(weight) 42 | 43 | votes.append(deconv_vote) 44 | smear_ker = nn.AvgPool2d( 45 | # in_channels=num_out_chnls, out_channels=num_out_chnls, 46 | kernel_size=size, stride=1, padding=int( (size - 1) / 2 ) 47 | ) 48 | # smear_ker = nn.ConvTranspose2d( 49 | # in_channels=num_out_chnls, out_channels=num_out_chnls, 50 | # kernel_size=size, stride=1, padding=int( (size - 1) / 2 ), 51 | # groups=num_groups, bias=False 52 | # ) 53 | # smear_ker.weight.data.fill_(1 / (size ** 2 )) 54 | smears.append(smear_ker) 55 | 56 | self.splits = splits 57 | self.votes = nn.ModuleList(votes) 58 | self.smears = nn.ModuleList(smears) 59 | 60 | @torch.no_grad() 61 | def forward(self, x): 62 | num_groups = self.num_groups 63 | splitted = x.split(self.splits, dim=1) 64 | assert len(splitted) == len(self.votes) 65 | output = [] 66 | for i in range(len(splitted)): 67 | # if i == (len(splitted) - 1): 68 | # continue 69 | x = splitted[i] 70 | if num_groups > 1: # painful lesson 71 | _, C, H, W = x.shape 72 | x = x.reshape(-1, C // num_groups, num_groups, H, W) 73 | x = x.transpose(1, 2) 74 | x = x.reshape(-1, C, H, W) 75 | x = self.votes[i](x) 76 | x = self.smears[i](x) 77 | output.append(x) 78 | return sum(output) 79 | 80 | 81 | def get_voting_deonv_kernel_weight(side, throw_away, return_tsr=True): 82 | """ 83 | The logic is neat; it makes use of the fact that negatives will be indexed 84 | from the last channel, and one simply needs to throw those away 85 | """ 86 | assert throw_away <= side 87 | throw_away = throw_away ** 2 88 | rounds = (side - 1) // 2 89 | spatial_inds = Snake.paint_trail_mask( 90 | *Snake.flesh_out_grid_spec([[1, rounds]]) 91 | ) - throw_away 92 | weight = 
np.zeros(shape=(side, side, side ** 2))
93 |     dim_0_inds, dim_1_inds = np.ix_(range(side), range(side))
94 |     weight[dim_0_inds, dim_1_inds, spatial_inds] = 1
95 |     if throw_away > 0:
96 |         weight = weight[:, :, :-throw_away]
97 |     if not return_tsr:
98 |         return weight
99 |     else:
100 |         weight = np.expand_dims(weight.transpose(2, 0, 1), axis=1)
101 |         kernel = torch.as_tensor(weight).float()
102 |         return kernel
103 | 
--------------------------------------------------------------------------------
/src/pcv/gaussian_smooth/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from ipywidgets import Output
4 | from matplotlib.patches import Circle
5 | 
6 | from panoptic.pcv.gaussian_smooth.prob_tsr import MakeProbTsr
7 | from panoptic.pcv.components.snake import Snake
8 | 
9 | from panoptic.vis import Visualizer as BaseVisualizer
10 | 
11 | 
12 | class Plot():
13 |     def __init__(self, ax, bin_center_yx, vote_mask, spatial_prob):
14 |         self.ax = ax
15 |         self.bin_center_yx = bin_center_yx
16 |         self.vote_mask = vote_mask
17 |         self.spatial_prob = spatial_prob
18 | 
19 |         # self.pressed_xy = None
20 |         self.dot = None
21 |         self.texts = None
22 |         self.init_artists()
23 |         self.render_visual()
24 | 
25 |     def init_artists(self):
26 |         if self.dot is not None:
27 |             self.dot.remove()
28 | 
29 |         if self.texts is not None:
30 |             assert isinstance(self.texts, (tuple, list))
31 |             for elem in self.texts:
32 |                 elem.remove()
33 | 
34 |         self.dot = None
35 |         self.texts = None
36 | 
37 |     def render_visual(self):
38 |         self.ax.imshow(self.vote_mask)
39 | 
40 |     def press_coord(self, x, y, button):
41 |         del button  # ignoring button for now
42 |         # self.pressed_xy = x, y
43 |         self.init_artists()
44 |         self.render_single_dot(x, y)
45 |         self.render_prob_dist(x, y)
46 | 
47 |     def render_prob_dist(self, x, y):
48 |         thresh = 0
49 |         dist = self.spatial_prob[y, x]
50 |         inds = np.where(dist > thresh)[0]
51 |         probs = dist[inds] * 100
52 |         # print(probs)
53 |         bin_centers = self.bin_center_yx[inds]
54 | 
55 |         acc = []
56 |         for cen, p in zip(bin_centers, probs):
57 |             y, x = cen
58 |             _a = self.ax.text(
59 |                 x, y, s='{:.2f}'.format(p), fontsize='small', color='r'
60 |             )
61 |             acc.append(_a)
62 |         self.texts = acc
63 | 
64 |     def query_coord(self, x, y, button):
65 |         pass
66 | 
67 |     def motion_coord(self, x, y):
68 |         self.press_coord(x, y, None)
69 | 
70 |     def render_single_dot(self, x, y):
71 |         cir = Circle((x, y), radius=0.5, color='white')
72 |         self.ax.add_patch(cir)
73 |         self.dot = cir
74 | 
75 | 
76 | class Visualizer(BaseVisualizer):
77 |     def __init__(self):
78 |         spec = [  # 243, 233 bins
79 |             (3, 4), (9, 3), (27, 3)
80 |         ]
81 |         # spec = [
82 |         #     (3, 3), (7, 3), (21, 4)
83 |         # ]
84 |         diam, grid_spec = Snake.flesh_out_grid_spec(spec)
85 |         vote_mask = Snake.paint_trail_mask(diam, grid_spec)
86 |         maker = MakeProbTsr(spec, diam, grid_spec, vote_mask)
87 |         spatial_prob = maker.compute_voting_prob_tsr()
88 | 
89 |         self.vote_mask = vote_mask
90 |         self.spatial_prob = spatial_prob
91 | 
92 |         radius = (diam - 1) // 2
93 |         center = np.array((radius, radius))
94 |         self.bin_center_yx = grid_spec[:, :2] + center
95 | 
96 |         self.output_widget = Output()
97 |         self.init_state()
98 |         self.pressed = False
99 |         np.set_printoptions(
100 |             formatter={'float': lambda x: "{:.3f}".format(x)}
101 |         )
102 | 
103 |     def vis(self):
104 |         fig = plt.figure(figsize=(10, 10), constrained_layout=True)
105 |         self.fig = fig
106 |         self.canvas = fig.canvas
107 |         self.plots = dict()
108 | 
109 |         key = 'spatial prob dist'
110 |         ax = fig.add_subplot(111)
111 |         ax.set_title(key)
112 |         self.plots[key] = Plot(
113 |             ax, self.bin_center_yx, self.vote_mask, self.spatial_prob
114 |         )
115 |         self.connect()
116 | 
117 | 
118 | def test():
119 |     # spec = [ # 243, 233 bins
120 |     #     (1, 1), (3, 4), (9, 3), (27, 3)
121 |     # ]
122 |     spec = [  # small toy spec for testing
123 |         (3, 1), (9, 1)
124 |     ]
125 |     diam, grid_spec = Snake.flesh_out_grid_spec(spec)
126 |     vote_mask = Snake.paint_trail_mask(diam, grid_spec)
127 |     maker = MakeProbTsr(spec, diam, grid_spec, vote_mask)
128 | 
129 | 
130 | if __name__ == "__main__":
131 |     test()
132 | 
--------------------------------------------------------------------------------
/src/models/components/FPN.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def conv_gn_relu(in_C, out_C, kernel_size, use_relu=True):
8 |     assert kernel_size in (3, 1)
9 |     pad = (kernel_size - 1) // 2
10 |     num_groups = out_C // 16  # note this is hardcoded
11 |     module = [
12 |         nn.Conv2d(in_C, out_C, kernel_size, padding=pad, bias=False),
13 |         # nn.GroupNorm(num_groups, num_channels=out_C),
14 |         nn.BatchNorm2d(out_C)
15 |     ]
16 |     if use_relu:
17 |         module.append(nn.ReLU(inplace=True))
18 |     return nn.Sequential(*module)
19 | 
20 | 
21 | class FPN(nn.Module):
22 |     def __init__(self, resnet_extractor, FPN_distill_C):
23 |         super().__init__()
24 |         self.strides = [32, 16, 8, 4]
25 |         self.dims = [2048, 1024, 512, 256]  # assume resnet 50 and above
26 |         self.FPN_feature_C = 256
27 |         self.FPN_distill_C = FPN_distill_C
28 | 
29 |         self.backbone = resnet_extractor
30 |         self.FPN_create_modules = nn.ModuleList()
31 |         self.FPN_distill_modules = nn.ModuleList()
32 | 
33 |         for in_dim, stride in zip(self.dims, self.strides):
34 |             self.FPN_create_modules.append(
35 |                 nn.ModuleDict({
36 |                     'lateral': conv_gn_relu(
37 |                         in_dim, self.FPN_feature_C,
38 |                         kernel_size=1, use_relu=False
39 |                     ),
40 |                     'refine': conv_gn_relu(
41 |                         self.FPN_feature_C, self.FPN_feature_C,
42 |                         kernel_size=3, use_relu=False
43 |                     )
44 |                 })
45 |             )
46 | 
47 |         for stride in self.strides:
48 |             self.FPN_distill_modules.append(
49 |                 self.get_distill_module(
50 |                     stride / 4, self.FPN_feature_C, self.FPN_distill_C
51 |                 )
52 |             )
53 | 
54 |     @staticmethod
55 |     def get_distill_module(upsample_ratio, in_C, out_C):
56 |         levels = math.log(upsample_ratio, 2)
57 |         assert levels.is_integer()
58 |         levels = int(levels)
59 |         if levels == 0:
60 |             return conv_gn_relu(in_C, out_C, kernel_size=3)
61 | 
62 |         module = []
63 |         for _ in range(levels):
64 |             module.append(
65 |                 # maybe use ordered dict?
66 |                 nn.Sequential(
67 |                     conv_gn_relu(in_C, out_C, kernel_size=3),
68 |                     nn.Upsample(
69 |                         scale_factor=2, mode='bilinear', align_corners=False)
70 |                 )
71 |             )
72 |             in_C = out_C
73 |         return nn.Sequential(*module)
74 | 
75 |     def collect_resnet_features(self, x):
76 |         acc = []
77 |         for layer in self.backbone.layers:
78 |             x = layer(x)
79 |             acc.append(x)
80 |         acc = acc[1:]  # throw away layer0 output
81 |         return acc[::-1]  # high level features first
82 | 
83 |     def create_FPN(self, res_features):
84 |         FPN_features = []
85 |         prev = None
86 |         for tsr, curr_module in zip(res_features, self.FPN_create_modules):
87 |             tsr = curr_module['lateral'](tsr)
88 |             if prev is not None:
89 |                 prev = F.interpolate(prev, scale_factor=2, mode='nearest')
90 |                 tsr = tsr + prev
91 |             prev = tsr
92 |             refined = curr_module['refine'](tsr)
93 |             FPN_features.append(refined)
94 |         return FPN_features
95 | 
96 |     def distill_FPN(self, FPN_features):
97 |         acc = []
98 |         for tsr, curr_module in zip(FPN_features, self.FPN_distill_modules):
99 |             tsr = curr_module(tsr)
100 |             acc.append(tsr)
101 |         return sum(acc)
102 | 
103 |     def forward(self, x):
104 |         res_features = self.collect_resnet_features(x)
105 |         FPN_features = self.create_FPN(res_features)
106 |         final = self.distill_FPN(FPN_features)
107 |         return final
108 | 
109 | 
110 | if __name__ == '__main__':
111 |     from panoptic.models.components.resnet import ResNetFeatureExtractor
112 |     from torchvision.models import resnet50
113 |     extractor = ResNetFeatureExtractor(resnet50(pretrained=False))
114 |     fpn = FPN(extractor, FPN_distill_C=128)  # FPN_distill_C is required
115 |     input = torch.rand(size=(1, 3, 64, 64))
116 |     with torch.no_grad():
117 |         output = fpn(input)
118 |     print(output.shape)
119 | 
--------------------------------------------------------------------------------
/src/pcv/pcv_igc_boundless.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_left
2 | import numpy as np
3 | 
4 | from .pcv import PCV_base
5 | from .components.snake import Snake
6 | from .components.ballot import Ballot
7 | from .components.grid_specs import igc_specs
8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
9 | 
10 | from ..
import cfg 11 | 12 | 13 | def _flesh_out_grid(spec): 14 | '''used only by flesh_out_spec''' 15 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec) 16 | full_mask = Snake.paint_trail_mask(field_diam, grid_spec) 17 | boundless_mask = Snake.paint_bound_ignore_trail_mask(field_diam, grid_spec) 18 | return { 19 | 'full': full_mask, 20 | 'boundless': boundless_mask 21 | } 22 | 23 | 24 | def flesh_out_spec(spec_group): 25 | ret = {'base': None, 'pyramid': dict()} 26 | ret['base'] = _flesh_out_grid(spec_group['base']) 27 | for k, v in spec_group['pyramid'].items(): 28 | ret['pyramid'][k] = _flesh_out_grid(v) 29 | return ret 30 | 31 | 32 | class PCV_IGC_Boundless(PCV_base): 33 | def __init__(self): 34 | # grid inx for now is a dummy 35 | spec_group = igc_specs[cfg.pcv.grid_inx] 36 | self.num_groups = cfg.pcv.num_groups 37 | self.centroid_mode = cfg.pcv.centroid 38 | self.raw_spec = spec_group['base'] 39 | _, self.grid_spec = Snake.flesh_out_grid_spec(self.raw_spec) 40 | self.mask_group = flesh_out_spec(spec_group) 41 | self._vote_mask = self.mask_group['base']['full'] 42 | self._ballot_module = None 43 | self.coalesce_thresh = list(self.mask_group['pyramid'].keys()) 44 | 45 | @property 46 | def ballot_module(self): 47 | # instantiate on demand to prevent train time data loading workers to 48 | # hold on to GPU memory 49 | if self._ballot_module is None: 50 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda() 51 | return self._ballot_module 52 | 53 | # 1 for bull's eye center, 1 for abstain vote 54 | @property 55 | def num_bins(self): 56 | return len(self.grid_spec) 57 | 58 | @property 59 | def num_votes(self): 60 | return 1 + self.num_bins 61 | 62 | @property 63 | def vote_mask(self): 64 | return self._vote_mask 65 | 66 | @property 67 | def query_mask(self): 68 | """ 69 | Flipped from inside out 70 | """ 71 | diam = len(self.vote_mask) 72 | radius = (diam - 1) // 2 73 | center = (radius, radius) 74 | mask_shape = self.vote_mask.shape 75 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1] 76 | offsets = center - offset_grid 77 | # allegiance = self.discrete_vote_inx_from_offset( 78 | # offsets.reshape(-1, 2) 79 | # ).reshape(mask_shape) 80 | allegiance = self._discretize_offset( 81 | self.vote_mask, offsets.reshape(-1, 2) 82 | ).reshape(mask_shape) 83 | return allegiance 84 | 85 | def centroid_from_ins_mask(self, ins_mask): 86 | return super().centroid_from_ins_mask(ins_mask) 87 | 88 | def discrete_vote_inx_from_offset(self, offset): 89 | base_boundless_mask = self.mask_group['base']['boundless'] 90 | return self._discretize_offset(base_boundless_mask, offset) 91 | 92 | def tensorized_vote_from_offset(self, offset): 93 | """ 94 | Args: 95 | offset: [N, 2] array of offset towards each pixel's own center, 96 | Each row is filled with (x, y) pair, not (y, x)! 97 | Returns: 98 | vote_tsr: [N, num_votes] of bool tsr where 0/1 denotes gt entries. 
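            (editor's note) the pyramid level is chosen by the largest offset
            magnitude, so larger instances vote on coarser grids; every base
            grid entry covered by the chosen coarse bin is then marked True.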
99 |         """
100 |         # dispatch to the proper grid
101 |         base_mask = self.mask_group['base']['boundless']
102 |         radius = (len(base_mask) - 1) // 2
103 |         max_offset = min(radius, np.abs(offset).max())
104 |         inx = bisect_left(self.coalesce_thresh, max_offset)
105 |         key = self.coalesce_thresh[inx]
106 |         vote_mask = self.mask_group['pyramid'][key]['boundless']
107 | 
108 |         tsr = np.zeros((len(offset), self.num_votes), dtype=bool)  # [N, num_votes]; np.bool is a deprecated alias
109 |         gt_indices = self._discretize_offset(vote_mask, offset)  # [N, ]
110 |         for inx in np.unique(gt_indices):
111 |             if inx == -1:
112 |                 continue
113 |             entries = np.unique(base_mask[vote_mask == inx])
114 |             inds = np.where(gt_indices == inx)[0]
115 |             tsr[inds.reshape(-1, 1), entries.reshape(1, -1)] = True
116 |         return tsr
117 | 
118 |     def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
119 |         # TODO: make the metadata that is actually required explicit
120 |         if self.num_groups == 1:
121 |             mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
122 |         else:
123 |             mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
124 |         pan_mask, meta = mfv.infer_panoptic_mask()
125 |         return pan_mask, meta
126 | 
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The final assembled model is in assembled.py
3 | The rest of the modules contain helpers and modular components
4 | """
5 | import torch
6 | import torch.nn
7 | from panoptic.utils import dynamic_load_py_object
8 | 
9 | 
10 | def get_model_module(model_name):
11 |     return dynamic_load_py_object(__name__, model_name)
12 | 
13 | 
14 | def convert_inplace_sync_batchnorm(module, process_group=None):
15 |     r"""Helper function to convert `torch.nn.BatchNormND` layers in the
16 |     model to `InPlaceABNSync` layers (with identity activation).
17 |     Args:
18 |         module (nn.Module): containing module
19 |         process_group (optional): process group to scope synchronization,
20 |             default is the whole world
21 |     Returns:
22 |         The original module with its batch-norm layers converted to `InPlaceABNSync`
23 |     Example::
24 |         >>> # Network with nn.BatchNorm layer
25 |         >>> module = torch.nn.Sequential(
26 |         >>>     torch.nn.Linear(20, 100),
27 |         >>>     torch.nn.BatchNorm1d(100)
28 |         >>> ).cuda()
29 |         >>> # creating process group (optional)
30 |         >>> # process_ids is a list of int identifying rank ids.
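        >>> # e.g. process_ids = list(range(torch.distributed.get_world_size()))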
31 |         >>> process_group = torch.distributed.new_group(process_ids)
32 |         >>> sync_bn_module = convert_inplace_sync_batchnorm(module, process_group)
33 |     """
34 |     from inplace_abn import InPlaceABNSync
35 |     module_output = module
36 |     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
37 |         module_output = InPlaceABNSync(module.num_features,
38 |                                        module.eps, module.momentum,
39 |                                        module.affine,
40 |                                        activation='identity',
41 |                                        group=process_group)
42 |         if module.affine:
43 |             module_output.weight.data = module.weight.data.clone().detach()
44 |             module_output.bias.data = module.bias.data.clone().detach()
45 |             # keep requires_grad unchanged
46 |             module_output.weight.requires_grad = module.weight.requires_grad
47 |             module_output.bias.requires_grad = module.bias.requires_grad
48 |         module_output.running_mean = module.running_mean
49 |         module_output.running_var = module.running_var
50 |     if isinstance(module, torch.nn.ReLU):
51 |         module_output = torch.nn.ReLU()
52 |     for name, child in module.named_children():
53 |         module_output.add_module(name, convert_inplace_sync_batchnorm(child, process_group))
54 |     del module
55 |     return module_output
56 | 
57 | def convert_naive_sync_batchnorm(module, process_group=None):
58 |     from detectron2.layers import NaiveSyncBatchNorm
59 |     return convert_xbatchnorm(module, NaiveSyncBatchNorm, process_group=process_group)
60 | 
61 | def convert_apex_sync_batchnorm(module, process_group=None):
62 |     from apex.parallel import SyncBatchNorm
63 |     return convert_xbatchnorm(module, SyncBatchNorm, process_group=process_group)
64 | 
65 | 
66 | def convert_xbatchnorm(module, bn_module, process_group=None):
67 |     r"""Helper function to convert `torch.nn.BatchNormND` layers in the
68 |     model to the given sync batch-norm layer `bn_module`.
69 |     Args:
70 |         module (nn.Module): containing module
71 |         process_group (optional): process group to scope synchronization,
72 |             default is the whole world
73 |     Returns:
74 |         The original module with the converted `bn_module` layers
75 |     Example::
76 |         >>> # Network with nn.BatchNorm layer
77 |         >>> module = torch.nn.Sequential(
78 |         >>>     torch.nn.Linear(20, 100),
79 |         >>>     torch.nn.BatchNorm1d(100)
80 |         >>> ).cuda()
81 |         >>> # creating process group (optional)
82 |         >>> # process_ids is a list of int identifying rank ids.
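        >>> # e.g. process_ids = list(range(torch.distributed.get_world_size()))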
83 |         >>> process_group = torch.distributed.new_group(process_ids)
84 |         >>> sync_bn_module = convert_xbatchnorm(module, SyncBatchNorm, process_group)
85 |     """
86 |     module_output = module
87 |     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
88 |         module_output = bn_module(module.num_features,
89 |                                   module.eps, module.momentum,
90 |                                   module.affine,
91 |                                   module.track_running_stats,
92 |                                   process_group)
93 |         if module.affine:
94 |             module_output.weight.data = module.weight.data.clone().detach()
95 |             module_output.bias.data = module.bias.data.clone().detach()
96 |             # keep requires_grad unchanged
97 |             module_output.weight.requires_grad = module.weight.requires_grad
98 |             module_output.bias.requires_grad = module.bias.requires_grad
99 |         module_output.running_mean = module.running_mean
100 |         module_output.running_var = module.running_var
101 |         module_output.num_batches_tracked = module.num_batches_tracked
102 |     for name, child in module.named_children():
103 |         module_output.add_module(name, convert_xbatchnorm(child, bn_module, process_group))
104 |     del module
105 |     return module_output
--------------------------------------------------------------------------------
/src/pcv/components/snake.py:
--------------------------------------------------------------------------------
1 | """
2 | This module manages the generation of the snake grid on which pcv operates.
3 | 
4 | It is intended to be maximally flexible by taking in simple parameters to
5 | achieve various grid configurations for downstream experimentation.
6 | """
7 | import numpy as np
8 | 
9 | 
10 | class Snake():
11 |     def __init__(self):
12 |         pass
13 | 
14 |     @staticmethod
15 |     def flesh_out_grid_spec(raw_spec):
16 |         """
17 |         Args:
18 |             raw_spec: [N, 2] each row (size, num_rounds)
19 |         """
20 |         if not isinstance(raw_spec, np.ndarray):
21 |             raw_spec = np.array(raw_spec)
22 |         shape = raw_spec.shape
23 |         assert len(shape) == 2 and shape[0] > 0 and shape[1] == 2
24 |         size = raw_spec[0][0]
25 |         trail = [np.array([0, 0, (size - 1) // 2]).reshape(1, -1)]
26 |         field_diam = size
27 |         for size, num_rounds in raw_spec:
28 |             for _ in range(num_rounds):
29 |                 field_diam, _round_trail = Snake.ring_walk(field_diam, size)
30 |                 trail.append(_round_trail)
31 |         trail = np.concatenate(trail, axis=0)
32 |         return field_diam, trail
33 | 
34 |     @staticmethod
35 |     def ring_walk(field_diam, body_diam):
36 |         assert body_diam > 0 and body_diam % 2 == 1
37 |         body_radius = (body_diam - 1) // 2
38 | 
39 |         assert field_diam > 0 and field_diam % 2 == 1
40 |         field_radius = (field_diam - 1) // 2
41 | 
42 |         assert field_diam % body_diam == 0
43 | 
44 |         ext_diam = field_diam + 2 * body_diam
45 |         ext_radius = field_radius + body_diam
46 |         assert ext_diam == ext_radius * 2 + 1
47 | 
48 |         j = 1 + field_radius + body_radius
49 |         # each of the corner coords is the offset from the field center
50 |         # anticlockwise SE -> NE -> NW -> SW ->
51 |         corner_centers = np.array([(+j, +j), (-j, +j), (-j, -j), (+j, -j)])
52 |         directs = np.array([(-1, 0), (0, -1), (+1, 0), (0, +1)])
53 |         trail = []
54 |         num_tiles = ext_diam // body_diam
55 |         for corn, dirc in zip(corner_centers, directs):
56 |             segment = [corn + i * dirc * body_diam for i in range(num_tiles - 1)]
57 |             trail.extend(segment)
58 |         trail = np.array(trail)  # [N, 2] each row is a center-based offset
59 |         sizes = np.array([body_radius] * len(trail)).reshape(-1, 1)
60 |         trail = np.concatenate([trail, sizes], axis=1)
61 |         return ext_diam, trail  # [N, 3]: (offset_y, offset_x, radius)
62 | 
63 |     @staticmethod
64 |     def paint_trail_mask(field_diam, trail,
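                         # (editor's note) tiling=True paints each trail stop
                         # as a filled block of its own diameter; False marks
                         # only the single center pixel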
tiling=True): 65 | """ 66 | Args: 67 | trail: [N, 3] each row (offset_y, offset_x, radius) 68 | """ 69 | assert field_diam > 0 and field_diam % 2 == 1 70 | field_radius = (field_diam - 1) // 2 71 | CEN = np.array((field_radius, field_radius)) 72 | trail = trail.copy() 73 | trail[:, :2] += CEN 74 | canvas = -1 * np.ones((field_diam, field_diam), dtype=int) 75 | for i, walk in enumerate(trail): 76 | y, x, r = walk 77 | if tiling: 78 | y, x = y - r, x - r 79 | d = 2 * r + 1 80 | canvas[y: y + d, x: x + d] = i 81 | else: 82 | canvas[y, x] = i 83 | return canvas 84 | 85 | @staticmethod 86 | def paint_bound_ignore_trail_mask(field_diam, trail): 87 | """ 88 | Args: 89 | trail: [N, 3] each row (offset_y, offset_x, radius) 90 | """ 91 | assert field_diam > 0 and field_diam % 2 == 1 92 | field_radius = (field_diam - 1) // 2 93 | CEN = np.array((field_radius, field_radius)) 94 | trail = trail.copy() 95 | trail[:, :2] += CEN 96 | canvas = -1 * np.ones((field_diam, field_diam), dtype=int) 97 | for i, walk in enumerate(trail): 98 | y, x, r = walk 99 | d = 2 * r + 1 100 | boundary_ignore = int(d * 0.12) 101 | # if d > 1: # at least cut 1 unless the grid is of size 1 102 | # boundary_ignore = max(boundary_ignore, 1) 103 | y, x = y - r + boundary_ignore, x - r + boundary_ignore 104 | d = d - 2 * boundary_ignore 105 | canvas[y: y + d, x: x + d] = i 106 | return canvas 107 | 108 | @staticmethod 109 | def vote_channel_splits(raw_spec): 110 | splits = [] 111 | acc_size = raw_spec[0][0] # inner most size 112 | for i, (size, num_rounds) in enumerate(raw_spec): 113 | inner_blocks = acc_size // size 114 | total_blocks = inner_blocks + 2 * num_rounds 115 | acc_size += 2 * num_rounds * size 116 | prepend = (i == 0) 117 | num_chnls = (total_blocks ** 2 - inner_blocks ** 2) + int(prepend) 118 | splits.append(num_chnls) 119 | return splits 120 | 121 | 122 | if __name__ == '__main__': 123 | sample_spec = np.array([ 124 | # (size, num_rounds) 125 | (1, 1), 126 | (3, 1), 127 | (9, 1) 128 | ]) 129 | s = Snake() 130 | diam, trail = s.flesh_out_grid_spec(sample_spec) 131 | mask = s.paint_trail_mask(diam, trail, tiling=True) 132 | print(mask.shape) 133 | -------------------------------------------------------------------------------- /src/reporters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import numpy as np 5 | from PIL import Image 6 | import torch.nn.functional as F 7 | from easydict import EasyDict as edict 8 | from .metric import PanMetric as Metric, PQMetric 9 | from panopticapi.utils import rgb2id, id2rgb 10 | 11 | from fabric.io import save_object 12 | 13 | _VALID_ORACLE_MODES = ('full', 'sem', 'vote') 14 | 15 | 16 | class BaseReporter(): 17 | def __init__(self, infer_cfg, model, dset, output_root): 18 | self.output_root = output_root 19 | 20 | def process(self, model_outputs, inputs, dset): 21 | pass 22 | 23 | def generate_report(self): 24 | # return some report here 25 | pass 26 | 27 | 28 | class mIoU(BaseReporter): 29 | def __init__(self, infer_cfg, model, dset, output_root): 30 | del infer_cfg # not needed 31 | self.metric = Metric( 32 | model.dset_meta['num_classes'], 33 | model.pcv.num_votes, model.dset_meta['trainId_2_catName'] 34 | ) 35 | self.output_root = output_root 36 | 37 | def process(self, inx, model_outputs, inputs, dset): 38 | sem_pred, vote_pred = model_outputs 39 | sem_pred = F.interpolate( 40 | sem_pred, scale_factor=4, mode='nearest' 41 | ) 42 | sem_pred, vote_pred = sem_pred.argmax(1), 
vote_pred.argmax(1)
43 |         self.metric.update(sem_pred, vote_pred, *inputs[1:3])
44 | 
45 |     def generate_report(self):
46 |         # return some report here
47 |         return str(self.metric)
48 | 
49 | 
50 | class PQ_report(BaseReporter):
51 |     def __init__(self, infer_cfg, model, dset, output_root):
52 |         self.infer_cfg = edict(infer_cfg.copy())
53 |         self.output_root = output_root
54 |         self.metric = PQMetric(dset.meta)
55 |         self.overall_pred_meta = {
56 |             'images': list(dset.imgs.values()),
57 |             'categories': list(dset.meta['cats'].values()),
58 |             'annotations': []
59 |         }
60 |         self.model = model
61 |         self.oracle_mode = self.infer_cfg['oracle_mode']
62 |         if self.oracle_mode is not None:
63 |             assert self.oracle_mode in _VALID_ORACLE_MODES
64 | 
65 |         os.makedirs(osp.dirname(output_root), exist_ok=True)
66 |         PRED_OUT_NAME = 'pred'
67 |         self.pan_json_fname = osp.join(output_root, '{}.json'.format(PRED_OUT_NAME))
68 |         self.pan_mask_dir = osp.join(output_root, PRED_OUT_NAME)
69 | 
70 |     def process(self, inx, model_outputs, inputs, dset):
71 |         from panoptic.entry import gt_tsr_res_reduction  # deferred import; TODO: move this helper out of entry
72 |         sem_pd, vote_pd = model_outputs
73 |         sem_pd, vote_pd = F.softmax(sem_pd, dim=1), F.softmax(vote_pd, dim=1)
74 | 
75 |         model = self.model
76 |         oracle_mode = self.oracle_mode
77 |         oracle_res = 4
78 |         imgMeta, segments_info, img, pan_gt_mask = dset.pan_getitem(
79 |             inx, apply_trans=False
80 |         )
81 |         if dset.transforms is not None:
82 |             _, trans_pan_gt = dset.transforms(img, pan_gt_mask)
83 |         else:
84 |             trans_pan_gt = pan_gt_mask.copy()
85 | 
86 |         pan_gt_ann = {
87 |             'image_id': imgMeta['id'],
88 |             # only the image file name is accessible here, so derive the mask name from it
89 |             'file_name': imgMeta['file_name'].split('.')[0] + '.png',
90 |             'segments_info': list(segments_info.values())
91 |         }
92 | 
93 |         if oracle_mode is not None:
94 |             sem_ora, vote_ora = gt_tsr_res_reduction(
95 |                 oracle_res, dset.gt_prod_handle,
96 |                 dset.meta, model.pcv, trans_pan_gt, segments_info
97 |             )
98 |             if oracle_mode == 'vote':
99 |                 pass  # if using model sem pd, maintain stuff pred thresh
100 |             else:  # sem or full oracle, using gt sem pd, do not filter
101 |                 self.infer_cfg['stuff_pred_thresh'] = -1
102 | 
103 |             if oracle_mode == 'sem':
104 |                 sem_pd = sem_ora
105 |             elif oracle_mode == 'vote':
106 |                 vote_pd = vote_ora
107 |             else:
108 |                 sem_pd, vote_pd = sem_ora, vote_ora
109 | 
110 |         pan_pd_mask, pan_pd_ann = model.stitch_pan_mask(
111 |             self.infer_cfg, sem_pd, vote_pd, pan_gt_mask.size
112 |         )
113 |         pan_pd_ann['image_id'] = pan_gt_ann['image_id']
114 |         pan_pd_ann['file_name'] = pan_gt_ann['file_name']
115 |         self.overall_pred_meta['annotations'].append(pan_pd_ann)
116 | 
117 |         self.metric.update(
118 |             pan_gt_ann, rgb2id(np.array(pan_gt_mask)), pan_pd_ann, pan_pd_mask
119 |         )
120 | 
121 |         pan_pd_mask = Image.fromarray(id2rgb(pan_pd_mask))
122 |         fname = osp.join(self.pan_mask_dir, pan_pd_ann['file_name'])
123 |         os.makedirs(osp.dirname(fname), exist_ok=True)  # make region subdir
124 |         pan_pd_mask.save(fname)
125 | 
126 |     def generate_report(self):
127 |         with open(self.pan_json_fname, 'w') as f:
128 |             json.dump(self.overall_pred_meta, f)
129 |         save_object(
130 |             self.metric.state_dict(), osp.join(self.output_root, 'score.pkl')
131 |         )
132 |         return self.metric
133 | 
134 | 
135 | reporter_modules = {
136 |     'mIoU': mIoU,
137 |     'pq': PQ_report
138 | }
139 | 
--------------------------------------------------------------------------------
/src/datasets/augmentations/augmentations.py:
-------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | # import torchvision.transforms as tv_trans 5 | from torchvision.transforms import ColorJitter as tv_color_jitter 6 | import torchvision.transforms.functional as tv_f 7 | from PIL import Image 8 | 9 | 10 | def _pad_to_sizes(mask, tw, th): 11 | w, h = mask.size 12 | right = max(tw - w, 0) 13 | bottom = max(th - h, 0) 14 | mask = tv_f.pad(mask, fill=0, padding=(0, 0, right, bottom)) 15 | return mask 16 | 17 | 18 | class RandomCrop(): 19 | def __init__(self, size): 20 | if isinstance(size, (tuple, list)): 21 | assert len(size) == 2 22 | self.size = size 23 | else: 24 | self.size = [size, size] 25 | 26 | def __call__(self, img, mask): 27 | assert img.size == mask.size 28 | tw, th = self.size 29 | 30 | img = _pad_to_sizes(img, tw, th) 31 | mask = _pad_to_sizes(mask, tw, th) 32 | w, h = img.size 33 | 34 | x1 = random.randint(0, w - tw) 35 | y1 = random.randint(0, h - th) 36 | square = (x1, y1, x1 + tw, y1 + th) 37 | return img.crop(square), mask.crop(square) 38 | 39 | 40 | class RandomHorizontalFlip(): 41 | def __init__(self, p=0.5): 42 | self.p = p 43 | 44 | def __call__(self, img, mask): 45 | if random.random() < self.p: 46 | return ( 47 | img.transpose(Image.FLIP_LEFT_RIGHT), 48 | mask.transpose(Image.FLIP_LEFT_RIGHT), 49 | ) 50 | return img, mask 51 | 52 | 53 | class Scale(): 54 | def __init__(self, ratio=0.5): 55 | self.ratio = ratio 56 | 57 | def __call__(self, img, mask): 58 | w, h = img.size 59 | ratio = self.ratio 60 | target_size = (int(ratio * w), int(ratio * h)) 61 | img = img.resize(target_size, Image.BILINEAR) 62 | mask = mask.resize(target_size, Image.NEAREST) 63 | return (img, mask) 64 | 65 | 66 | class Resize(): 67 | def __init__(self, size): 68 | self.size = size 69 | 70 | def __call__(self, img, mask): 71 | assert img.size == mask.size 72 | target_size = self.size 73 | img = img.resize(target_size, Image.BILINEAR) 74 | mask = mask.resize(target_size, Image.NEAREST) 75 | return (img, mask) 76 | 77 | 78 | class RandomSized(): 79 | def __init__(self, jitter=(0.5, 2)): 80 | self.jitter = jitter 81 | 82 | def __call__(self, img, mask): 83 | assert img.size == mask.size 84 | 85 | scale = random.uniform(self.jitter[0], self.jitter[1]) 86 | 87 | w = int(scale * img.size[0]) 88 | h = int(scale * img.size[1]) 89 | 90 | img, mask = ( 91 | img.resize((w, h), Image.BILINEAR), 92 | mask.resize((w, h), Image.NEAREST), 93 | ) 94 | 95 | return img, mask 96 | 97 | 98 | class RoundToMultiple(): 99 | def __init__(self, stride, method='pad'): 100 | assert method in ('pad', 'resize') 101 | self.stride = stride 102 | self.method = method 103 | 104 | def __call__(self, img, mask): 105 | assert img.size == mask.size 106 | stride = self.stride 107 | if stride > 0: 108 | w, h = img.size 109 | w = int(math.ceil(w / stride) * stride) 110 | h = int(math.ceil(h / stride) * stride) 111 | if self.method == 'pad': 112 | img = _pad_to_sizes(img, w, h) 113 | mask = _pad_to_sizes(mask, w, h) 114 | else: 115 | img = img.resize((w, h), Image.BILINEAR) 116 | mask = mask.resize((w, h), Image.NEAREST) 117 | return img, mask 118 | 119 | 120 | class COCOResize(): 121 | ''' 122 | adapted from maskrcnn_benchmark/data/transforms/transforms.py 123 | ''' 124 | def __init__(self, min_size, max_size, round_to_divisble): 125 | if not isinstance(min_size, (list, tuple)): 126 | min_size = (min_size,) 127 | assert round_to_divisble in ('pad', 'resize') 128 | self.min_size = min_size 129 | self.max_size = 
max_size 130 | self.round_to_divisble = round_to_divisble 131 | 132 | # modified from torchvision to add support for max size 133 | def get_size(self, image_size): 134 | w, h = image_size 135 | size = random.choice(self.min_size) 136 | max_size = self.max_size 137 | if max_size is not None: 138 | min_original_size = float(min((w, h))) 139 | max_original_size = float(max((w, h))) 140 | if max_original_size / min_original_size * size > max_size: 141 | size = int(round(max_size * min_original_size / max_original_size)) 142 | 143 | if (w <= h and w == size) or (h <= w and h == size): 144 | return (h, w) 145 | 146 | if w < h: 147 | ow = size 148 | oh = int(size * h / w) 149 | else: 150 | oh = size 151 | ow = int(size * w / h) 152 | 153 | return (ow, oh) 154 | 155 | def __call__(self, img, mask): 156 | assert img.size == mask.size 157 | size = self.get_size(img.size) 158 | img = img.resize(size) 159 | mask = mask.resize(size) 160 | padder = RoundToMultiple(32, self.round_to_divisble) 161 | img, mask = padder(img, mask) 162 | return img, mask 163 | 164 | 165 | class ColorJitter(): 166 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 167 | self.tv_jitter = tv_color_jitter( 168 | brightness, contrast, saturation, hue 169 | ) 170 | 171 | def __call__(self, img, mask): 172 | assert img.size == mask.size 173 | img = self.tv_jitter(img) 174 | return img, mask 175 | -------------------------------------------------------------------------------- /src/datasets/samplers/grouped_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import itertools 3 | import math 4 | import torch 5 | from torch.utils.data.sampler import BatchSampler 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | class GroupedBatchSampler(BatchSampler): 10 | """ 11 | Wraps another sampler to yield a mini-batch of indices. 12 | It enforces that elements from the same group should appear in groups of batch_size. 13 | It also tries to provide mini-batches which follows an ordering which is 14 | as close as possible to the ordering from the original sampler. 15 | 16 | Arguments: 17 | sampler (Sampler): Base sampler. 18 | batch_size (int): Size of mini-batch. 19 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose 20 | size is less than ``batch_size`` 21 | 22 | """ 23 | 24 | def __init__(self, sampler, group_ids, batch_size, drop_uneven=False): 25 | if not isinstance(sampler, Sampler): 26 | raise ValueError( 27 | "sampler should be an instance of " 28 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 29 | ) 30 | assert drop_uneven is False, 'for now do not allow dropping last batch' 31 | self.sampler = sampler 32 | self.group_ids = torch.as_tensor(group_ids) 33 | assert self.group_ids.dim() == 1 34 | self.batch_size = batch_size 35 | self.drop_uneven = drop_uneven 36 | 37 | self.groups = torch.unique(self.group_ids).sort(0)[0] 38 | 39 | self._can_reuse_batches = False 40 | self._intended_num_batches = int(math.ceil(len(sampler) / batch_size)) 41 | 42 | def _prepare_batches(self): 43 | dataset_size = len(self.group_ids) 44 | # get the sampled indices from the sampler 45 | sampled_ids = torch.as_tensor(list(self.sampler)) 46 | # potentially not all elements of the dataset were sampled 47 | # by the sampler (e.g., DistributedSampler). 
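# (editor's note) in the distributed case each rank sees only its shard,
# so batches must be formed from the sampled subset alone; e.g. with
# group_ids=[0, 0, 1, 1], batch_size=2 and a sequential base sampler, the
# yielded batches are [[0, 1], [2, 3]] -- aspect groups are never mixed.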
48 | # construct a tensor which contains -1 if the element was 49 | # not sampled, and a non-negative number indicating the 50 | # order where the element was sampled. 51 | # for example. if sampled_ids = [3, 1] and dataset_size = 5, 52 | # the order is [-1, 1, -1, 0, -1] 53 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 54 | order[sampled_ids] = torch.arange(len(sampled_ids)) 55 | 56 | # get a mask with the elements that were sampled 57 | mask = order >= 0 58 | 59 | # find the elements that belong to each individual cluster 60 | clusters = [(self.group_ids == i) & mask for i in self.groups] 61 | # get relative order of the elements inside each cluster 62 | # that follows the order from the sampler 63 | relative_order = [order[cluster] for cluster in clusters] 64 | # with the relative order, find the absolute order in the 65 | # sampled space 66 | permutation_ids = [s[s.sort()[1]] for s in relative_order] 67 | # permute each cluster so that they follow the order from 68 | # the sampler 69 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 70 | 71 | # splits each cluster in batch_size, and merge as a list of tensors 72 | splits = [c.split(self.batch_size) for c in permuted_clusters] 73 | merged = tuple(itertools.chain.from_iterable(splits)) 74 | merged = merged[:self._intended_num_batches] 75 | 76 | # now each batch internally has the right order, but 77 | # they are grouped by clusters. Find the permutation between 78 | # different batches that brings them as close as possible to 79 | # the order that we have in the sampler. For that, we will consider the 80 | # ordering as coming from the first element of each batch, and sort 81 | # correspondingly 82 | first_element_of_batch = [t[0].item() for t in merged] 83 | # get and inverse mapping from sampled indices and the position where 84 | # they occur (as returned by the sampler) 85 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 86 | # from the first element in each batch, get a relative ordering 87 | first_index_of_batch = torch.as_tensor( 88 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 89 | ) 90 | 91 | # permute the batches so that they approximately follow the order 92 | # from the sampler 93 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 94 | # finally, permute the batches 95 | batches = [merged[i].tolist() for i in permutation_order] 96 | 97 | if self.drop_uneven: 98 | kept = [] 99 | for batch in batches: 100 | if len(batch) == self.batch_size: 101 | kept.append(batch) 102 | batches = kept 103 | # print( 104 | # '{} sampled points; size of cluster {}; size of splits {}; ' 105 | # 'size of merged: {}, size of batches {}'.format( 106 | # len(sampled_ids), 107 | # [inds.shape[0] for inds in permuted_clusters], 108 | # [len(sp) for sp in splits], len(merged), len(batches) 109 | # ) 110 | # ) 111 | return batches 112 | 113 | def __iter__(self): 114 | if self._can_reuse_batches: 115 | batches = self._batches 116 | self._can_reuse_batches = False 117 | else: 118 | batches = self._prepare_batches() 119 | self._batches = batches 120 | return iter(batches) 121 | 122 | def __len__(self): 123 | if not hasattr(self, "_batches"): 124 | self._batches = self._prepare_batches() 125 | self._can_reuse_batches = True 126 | return len(self._batches) 127 | -------------------------------------------------------------------------------- /src/pcv/gaussian_smooth/prob_tsr.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy.stats import multivariate_normal 5 | 6 | from ..components.snake import Snake 7 | 8 | _nine_offsets = [ 9 | ( 0, 0), 10 | ( 1, 1), 11 | ( 0, 1), 12 | (-1, 1), 13 | (-1, 0), 14 | (-1, -1), 15 | ( 0, -1), 16 | ( 1, -1), 17 | ( 1, 0), 18 | ] 19 | 20 | 21 | class GaussianField(): 22 | def __init__(self, diam, cov=0.05): 23 | assert (diam % 2 == 1), 'diam must be an odd' 24 | self.diam = diam 25 | self.cov = cov # .05 leaves about 95% prob mass within central block 26 | # only consider the 3x3 region 27 | self.increment = 1 / diam 28 | # compute 3 units 29 | self.l, self.r = -1.5, 1.5 30 | self.field_shape = (3 * diam, 3 * diam) 31 | self.unit_area = self.increment ** 2 32 | self.prob_field = self.compute_prob_field() 33 | 34 | def compute_prob_field(self): 35 | cov = self.cov 36 | increment = self.increment 37 | l, r = self.l, self.r 38 | cov_mat = np.array([ 39 | [cov, 0], 40 | [0, cov] 41 | ]) 42 | rv = multivariate_normal([0, 0], cov_mat) 43 | half_increment = increment / 2 44 | xs, ys = np.mgrid[ 45 | l + half_increment: r: increment, 46 | l + half_increment: r: increment 47 | ] # use half increment to make things properly centered 48 | pos = np.dstack((xs, ys)) 49 | prob_field = rv.pdf(pos).astype(np.float32) 50 | assert prob_field.shape == self.field_shape 51 | return prob_field 52 | 53 | @torch.no_grad() 54 | def compute_local_mass(self): 55 | kernel_size = self.diam 56 | pad = (kernel_size - 1) // 2 57 | prob_field = self.prob_field 58 | 59 | conv = nn.Conv2d( 60 | in_channels=1, out_channels=1, kernel_size=kernel_size, 61 | padding=pad, bias=False 62 | ) # do not use cuda for now; no point 63 | conv.weight.data.copy_(torch.tensor(1.0)) 64 | prob_field = torch.as_tensor( 65 | prob_field, device=conv.weight.device 66 | )[(None,) * 2] # [1, 1, h, w] 67 | local_sum = conv(prob_field).squeeze().cpu().numpy() 68 | local_sum = local_sum * self.unit_area 69 | return local_sum 70 | 71 | 72 | class MakeProbTsr(): 73 | ''' 74 | make a prob tsr of shape [h, w, num_votes] filled with the corresponding 75 | spatial voting prob 76 | ''' 77 | def __init__(self, spec, diam, grid_spec, vote_mask, var=0.05): 78 | # indices grid of shape [2, H, W], where first dim is y, x; swap them 79 | # obtain [H, W, 2] where last channel is (y, x) 80 | self.spec = spec 81 | self.diam = diam 82 | self.vote_mask = vote_mask 83 | self.var = var 84 | 85 | # process grid spec to 0 based indexing and change radius to diam 86 | radius = (diam - 1) // 2 87 | center = np.array((radius, radius)) 88 | grid_spec = grid_spec.copy() 89 | grid_spec[:, :2] += center 90 | grid_spec[:, -1] = 1 + 2 * grid_spec[:, -1] # change from r to diam 91 | self.grid_spec = grid_spec 92 | 93 | def compute_voting_prob_tsr(self, normalize=True): 94 | spec = self.spec 95 | diam = self.diam 96 | grid_spec = self.grid_spec 97 | vote_mask = self.vote_mask 98 | 99 | spatial_shape = (diam, diam) 100 | spatial_yx = np.indices(spatial_shape).transpose(1, 2, 0).astype(int) 101 | # [H, W, 2] where each arr[y, x] is the containing grid's center 102 | spatial_cen_yx = np.empty_like(spatial_yx) 103 | # [H, W, 1] where each arr[y, x] is the containing grid's diam 104 | spatial_diam = np.empty(spatial_shape, dtype=int)[..., None] 105 | for i, (y, x, d) in enumerate(grid_spec): 106 | _m = vote_mask == i 107 | spatial_cen_yx[_m] = (y, x) 108 | spatial_diam[_m] = d 109 | 110 | max_vote_bin_diam = spec[-1][0] 111 | spatial_9_inds = self.nine_neighbor_inds( 112 | spatial_diam, spatial_yx, vote_mask, 113 | 
vote_mask_padding=max_vote_bin_diam
114 |         )
115 | 
116 |         # spatial_9_probs = np.ones_like(spatial_9_inds).astype(float)
117 |         spatial_9_probs = self.nine_neighbor_probs(
118 |             spatial_diam, spatial_yx, spatial_cen_yx, self.var
119 |         )
120 | 
121 |         # [H, W, num_votes + 1] 1 extra to trash the -1s
122 |         spatial_prob = np.zeros((diam, diam, len(grid_spec) + 1))
123 |         inds0, inds1, _ = np.ix_(range(diam), range(diam), range(1))
124 |         np.add.at(spatial_prob, (inds0, inds1, spatial_9_inds), spatial_9_probs)
125 |         spatial_prob[..., -1] = 0  # erase, but keep the trash bin -> abstain bin
126 |         spatial_prob = self.erase_inward_prob_dist(spec, vote_mask, spatial_prob)
127 |         if normalize:
128 |             spatial_prob = spatial_prob / spatial_prob.sum(-1, keepdims=True)
129 |         return spatial_prob
130 | 
131 |     @staticmethod
132 |     def erase_inward_prob_dist(spec, vote_mask, spatial_prob):
133 |         '''A measure of expedience borne of time constraints: pixels that
134 |         fall in an outer belt of the pyramid have their probability mass on
135 |         all inner-layer bins zeroed out, so each belt only votes outward.
136 |         '''
137 |         splits = Snake.vote_channel_splits(spec)
138 |         # ret = np.zeros_like(spatial_prob)
139 |         layer_inds = np.cumsum(splits)
140 |         for i in range(1, len(layer_inds)):
141 |             curr = layer_inds[i]
142 |             prev = layer_inds[i-1]
143 |             belt_mask = (vote_mask < curr) & (vote_mask >= prev)
144 |             spatial_prob[belt_mask, :prev] = 0
145 |         return spatial_prob
146 | 
147 |     @staticmethod
148 |     def nine_neighbor_inds(
149 |         spatial_diam, spatial_yx, vote_mask, vote_mask_padding,
150 |     ):
151 |         # [H, W, 1, 1] * [9, 2] -> [H, W, 9, 2]
152 |         spatial_9_offsets = spatial_diam[..., None] * np.array(_nine_offsets)
153 |         # [H, W, 2] reshapes [H, W, 1, 2] + [H, W, 9, 2] -> [H, W, 9, 2]
154 |         spatial_9_loc_yx = np.expand_dims(spatial_yx, 2) + spatial_9_offsets
155 | 
156 |         padded_vote_mask = np.pad(
157 |             vote_mask, vote_mask_padding, mode='constant', constant_values=-1
158 |         )
159 |         # shift the inds
160 |         spatial_9_loc_yx += (vote_mask_padding, vote_mask_padding)
161 |         # [H, W, 9] where arr[y, x] contains the 9 inds centered on y, x
162 |         spatial_9_inds = padded_vote_mask[
163 |             tuple(np.split(spatial_9_loc_yx, 2, axis=-1))
164 |         ].squeeze(-1)
165 |         return spatial_9_inds
166 | 
167 |     @staticmethod
168 |     def nine_neighbor_probs(spatial_diam, spatial_yx, spatial_cen_yx, var):
169 |         spatial_cen_yx_offset = spatial_cen_yx - spatial_yx
170 |         del spatial_cen_yx, spatial_yx
171 |         single_cell_diam = 81
172 |         field_diam = single_cell_diam * 3
173 |         gauss = GaussianField(diam=single_cell_diam, cov=var)
174 |         prob_local_mass = gauss.compute_local_mass()
175 | 
176 |         # now read off prob from every pix's 9 neighboring locations
177 |         '''
178 |         single_cell_diam: scalar; 1/3 of the field size for prob field
179 |         prob_local_mass: [3 * single_cell_diam, 3 * single_cell_diam]
180 |         spatial_diam: [H, W, 1]; arr[y, x] gives its grid diam
181 |         spatial_cen_yx_offset: [H, W, 2] arr[y, x] gives dy, dx to its grid center
182 |         '''
183 |         assert field_diam == prob_local_mass.shape[0]
184 |         assert prob_local_mass.shape[0] == prob_local_mass.shape[1]
185 |         norm_spatial_cen_yx_offset = (
186 |             spatial_cen_yx_offset * single_cell_diam / spatial_diam
187 |         ).astype(int)  # [H, W, 2]; np.int is a deprecated alias
188 |         del spatial_cen_yx_offset, spatial_diam
189 | 
190 |         spatial_9_offsets = (
191 |             single_cell_diam * np.array(_nine_offsets)
192 |         ).reshape(1, 1, 9, 2)
193 |         field_radius = (field_diam - 1) // 2
194 |         center = (field_radius, field_radius)
195 |         spatial_yx_loc = center +
norm_spatial_cen_yx_offset 196 | # [H, W, 2] reshapes [H, W, 1, 2] + [1, 1, 9, 2] -> [H, W, 9, 2] 197 | spatial_9_loc_yx = np.expand_dims(spatial_yx_loc, axis=2) + spatial_9_offsets 198 | 199 | spatial_9_probs = prob_local_mass[ 200 | tuple(np.split(spatial_9_loc_yx, 2, axis=-1)) 201 | ].squeeze(-1) 202 | return spatial_9_probs 203 | -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | from tabulate import tabulate 5 | from panoptic.pan_eval import pq_compute_single_img, PQStat 6 | 7 | 8 | class Metric(object): 9 | def __init__(self, num_classes, trainId_2_catName=None): 10 | """ 11 | Args: 12 | window_size: the number of batch of visuals stated here 13 | 14 | All metrics are computed from a confusion matrix that aggregates pixel 15 | count across all images. This evaluation scheme only has a global 16 | view of all the pixels, and disregards image as a unit. 17 | "mean" always refers to averaging across pixel semantic classes. 18 | """ 19 | self.num_classes = num_classes 20 | self.trainId_2_catName = trainId_2_catName 21 | self.init_state() 22 | 23 | def init_state(self): 24 | """caller may use it to reset the metric""" 25 | self.scores = dict() 26 | self.confusion_matrix = np.zeros( 27 | shape=(self.num_classes, self.num_classes), dtype=np.int64 28 | ) 29 | 30 | def update(self, pred, gt): 31 | """ 32 | Args: 33 | pred: [N, H, W] torch tsr or ndarray 34 | gt: [N, H, W] torch tsr or ndarray 35 | """ 36 | if len(pred.shape) == 4: 37 | pred = pred.argmax(dim=1) 38 | assert pred.shape == gt.shape 39 | if isinstance(pred, torch.Tensor): 40 | pred, gt = pred.cpu().numpy(), gt.cpu().numpy() 41 | hist = self.fast_hist(pred, gt, self.num_classes) 42 | self.confusion_matrix += hist 43 | self.scores = self.compute_scores( 44 | self.confusion_matrix, self.trainId_2_catName 45 | ) 46 | 47 | @staticmethod 48 | def fast_hist(pred, gt, num_classes): 49 | assert pred.shape == gt.shape 50 | valid_mask = (gt >= 0) & (gt < num_classes) 51 | hist = np.bincount( 52 | num_classes * gt[valid_mask] + pred[valid_mask], 53 | minlength=num_classes ** 2 54 | ).reshape(num_classes, num_classes) 55 | return hist 56 | 57 | @staticmethod 58 | def compute_scores(hist, trainId_2_catName): 59 | res = dict() 60 | num_classes = hist.shape[0] 61 | 62 | # per class statistics 63 | gt_freq = hist.sum(axis=1) 64 | pd_freq = hist.sum(axis=0) # pred frequency 65 | intersect = np.diag(hist) 66 | union = gt_freq + pd_freq - intersect 67 | iou = intersect / union 68 | 69 | cls_names = [ 70 | trainId_2_catName[inx] for inx in range(num_classes) 71 | ] if trainId_2_catName is not None else range(num_classes) 72 | 73 | details = dict() 74 | for inx, name in enumerate(cls_names): 75 | details[name] = { 76 | 'gt_freq': gt_freq[inx], 77 | 'pd_freq': pd_freq[inx], 78 | 'intersect': intersect[inx], 79 | 'union': union[inx], 80 | 'iou': iou[inx] 81 | } 82 | res['details'] = details 83 | 84 | # aggregate statistics 85 | pix_acc = intersect.sum() / hist.sum() 86 | m_iou = np.nanmean(iou) 87 | freq = gt_freq / gt_freq.sum() 88 | # masking to avoid potential nan in per cls iou 89 | fwm_iou = (freq[freq > 0] * iou[freq > 0]).sum() 90 | del freq 91 | res['pix_acc'] = pix_acc 92 | res['m_iou'] = m_iou 93 | res['fwm_iou'] = fwm_iou 94 | 95 | return res 96 | 97 | def __repr__(self): 98 | return repr(self.scores) 99 | 100 | def __str__(self): 101 | return 
self.display(self.scores) 102 | 103 | @staticmethod 104 | def display(src_dict): 105 | """Only print out scalar metric like mIoU in a nice tabular form 106 | Detailed per cls info, etc, are withheld for clear presentation 107 | """ 108 | to_display = dict() 109 | for k, v in src_dict.items(): 110 | if isinstance(v, dict): 111 | continue # ignore those which cannot be tabulated 112 | to_display[k] = [v] 113 | table = tabulate( 114 | to_display, 115 | headers='keys', tablefmt='fancy_grid', 116 | floatfmt=".3f", numalign='decimal' 117 | ) 118 | return str(table) 119 | 120 | # these save and load functions are ugly cuz they are tied to the 121 | # infrastructure. Re-write them later. 122 | def save(self, epoch_or_fname, manager): 123 | assert manager is not None 124 | state = self.scores 125 | # now save the acc with manager 126 | if isinstance(epoch_or_fname, int): 127 | epoch = epoch_or_fname 128 | manager.save(epoch, state) 129 | else: 130 | fname = epoch_or_fname 131 | save_path = osp.join(manager.root, fname) 132 | manager.save_f(state, save_path) 133 | 134 | def load(self, state): 135 | """Assume that the state is already read by the caller 136 | Args: 137 | state: dict with fields 'scores' and 'visuals' 138 | """ 139 | self.scores = state 140 | 141 | 142 | class PanMetric(Metric): 143 | """ 144 | This is a metric that evaluates predictions from both heads as pixel-wise 145 | classification 146 | """ 147 | def __init__(self, num_classes, num_votes, trainId_2_catName): 148 | self.num_votes = num_votes 149 | super().__init__(num_classes, trainId_2_catName) 150 | 151 | def init_state(self): 152 | self.scores = dict() 153 | self.sem_confusion = np.zeros( 154 | shape=(self.num_classes, self.num_classes), dtype=np.int64 155 | ) 156 | self.vote_confusion = np.zeros( 157 | shape=(self.num_votes, self.num_votes), dtype=np.int64 158 | ) 159 | 160 | def update(self, sem_pred, vote_pred, sem_gt, vote_gt): 161 | """ 162 | Args: 163 | sem_pred: [N, H, W] torch tsr or ndarray 164 | vote_pred: [N, H, W] torch tsr or ndarray 165 | sem_gt: [N, H, W] torch tsr or ndarray 166 | vote_gt: [N, H, W] torch tsr or ndarray 167 | """ 168 | self._update_pair( 169 | sem_pred, sem_gt, 'sem', self.sem_confusion, self.num_classes) 170 | self._update_pair( 171 | vote_pred, vote_gt, 'vote', self.vote_confusion, self.num_votes) 172 | 173 | def _update_pair(self, pred, gt, key, confusion_matrix, num_cats): 174 | if len(pred.shape) == 4: 175 | pred = pred.argmax(dim=1) 176 | if isinstance(pred, torch.Tensor): 177 | pred, gt = pred.cpu().numpy(), gt.cpu().numpy() 178 | hist = self.fast_hist(pred, gt, num_cats) 179 | confusion_matrix += hist 180 | trainId_2_catName = self.trainId_2_catName if key == 'sem' else None 181 | self.scores[key] = self.compute_scores( 182 | confusion_matrix, trainId_2_catName 183 | ) 184 | 185 | def __str__(self): 186 | sem_str = self.display(self.scores['sem']) 187 | vote_str = self.display(self.scores['vote']) 188 | combined = "sem: \n{} \nvote: \n{}".format(sem_str, vote_str) 189 | return combined 190 | 191 | 192 | class PQMetric(): 193 | def __init__(self, dset_meta): 194 | self.cats = dset_meta['cats'] 195 | self.score = PQStat() 196 | self.metrics = [("All", None), ("Things", True), ("Stuff", False)] 197 | self.results = {} 198 | 199 | def update(self, gt_ann, gt, pred_ann, pred): 200 | assert gt.shape == pred.shape 201 | stat = pq_compute_single_img(self.cats, gt_ann, gt, pred_ann, pred) 202 | self.score += stat 203 | 204 | def state_dict(self): 205 | self.aggregate_results() 206 | return 
self.results 207 | 208 | def aggregate_results(self): 209 | for name, isthing in self.metrics: 210 | self.results[name], per_class_results = self.score.pq_average( 211 | self.cats, isthing=isthing 212 | ) 213 | if name == 'All': 214 | self.results['per_class'] = per_class_results 215 | 216 | def __str__(self): 217 | self.aggregate_results() 218 | headers = [ m[0] for m in self.metrics ] 219 | keys = ['pq', 'sq', 'rq'] 220 | data = [] 221 | for tranche in headers: 222 | row = [100 * self.results[tranche][k] for k in keys] 223 | row = [tranche] + row + [self.results[tranche]['n']] 224 | data.append(row) 225 | table = tabulate( 226 | tabular_data=data, 227 | headers=([''] + keys + ['n_cats']), 228 | floatfmt=".2f", tablefmt='fancy_grid', 229 | ) 230 | return table 231 | -------------------------------------------------------------------------------- /src/models/panoptic_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import apex 7 | from panoptic.models.base_model import BaseModel 8 | from panoptic import get_loss, get_optimizer, get_scheduler 9 | from ..pcv.inference.mask_from_vote import MaskFromVote, MFV_CatSeparate 10 | 11 | from fabric.utils.logging import setup_logging 12 | logger = setup_logging(__file__) 13 | 14 | 15 | class PanopticBase(BaseModel): 16 | def __init__(self, cfg, pcv, dset_meta): 17 | super().__init__() 18 | self.cfg = cfg 19 | self.pcv = pcv 20 | self.dset_meta = dset_meta 21 | 22 | self.instantiate_network(cfg) 23 | assert self.net is not None, "instantiate network properly!" 24 | self.criteria = get_loss(cfg.loss) 25 | self.net.loss_module = self.criteria # fold crit into net 26 | 27 | # torch.cuda.synchronize() 28 | self.add_optimizer(cfg.optimizer) 29 | assert self.optimizer is not None # confirm the child has done the job 30 | # torch.cuda.synchronize() 31 | 32 | self.curr_epoch = 0 33 | self.total_train_epoch = cfg.scheduler.total_epochs 34 | self.scheduler = get_scheduler(cfg.scheduler)( 35 | optimizer=self.optimizer 36 | ) 37 | 38 | # used as temporary place holder for model training inputs 39 | self._latest_loss = None 40 | self.img, self.sem_mask, self.vote_mask, self.weight_mask = \ 41 | None, None, None, None 42 | 43 | def instantiate_network(self, cfg): 44 | # child class must overwrite the method 45 | raise NotImplementedError() 46 | 47 | def get_params_lr(self, initial_lr): 48 | return self.net.parameters() 49 | 50 | def add_optimizer(self, optim_cfg): 51 | """New Method for this class. Only retrive the module 52 | Children class should implement the function s.t self.optimizer 53 | is filled. The child class could add different learning rate for 54 | different params, etc 55 | """ 56 | optim_handle = get_optimizer(optim_cfg) 57 | # self.optimizer = optim_handle(params=self.net.parameters()) 58 | params_lr = self.get_params_lr(optim_cfg.params.lr) 59 | self.optimizer = optim_handle(params=params_lr) 60 | 61 | def ingest_train_input(self, *inputs): 62 | self.inputs = inputs 63 | 64 | def optimize_params(self): 65 | self.set_train_eval_state(True) 66 | self.optimizer.zero_grad() 67 | 68 | loss, sem_loss, vote_loss = self.net(*self.inputs) 69 | loss = loss.mean() # warning here. 
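        # (editor's note) under apex amp, the loss is scaled inside
        # scale_loss() before backward() so small fp16 gradients do not
        # underflow; the unscaled loss is what gets logged below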
70 | if getattr(self.cfg, 'apex', False): 71 | with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss: 72 | scaled_loss.backward() 73 | else: 74 | loss.backward() 75 | self.optimizer.step() 76 | self._latest_loss = { 77 | 'loss': loss.item(), 78 | 'sem_loss': sem_loss.mean().item(), 79 | 'vote_loss': vote_loss.mean().item() 80 | } 81 | 82 | def latest_loss(self, get_numeric=False): 83 | if get_numeric: 84 | return self._latest_loss 85 | else: 86 | return {k: '{:.4f}'.format(v) for k, v in self._latest_loss.items()} 87 | 88 | def curr_lr(self): 89 | lr = [ group['lr'] for group in self.optimizer.param_groups ][0] 90 | return lr 91 | 92 | def advance_to_next_epoch(self): 93 | self.curr_epoch += 1 94 | 95 | def log_statistics(self, step, level=0): 96 | if self.log_writer is not None: 97 | for k, v in self.latest_loss(get_numeric=True).items(): 98 | self.log_writer.add_scalar(k, v, step) 99 | if level >= 1: 100 | curr_lr = self.curr_lr() 101 | self.log_writer.add_scalar('lr', curr_lr, step) 102 | 103 | def infer(self, x, *args, **kwargs): 104 | return self._infer(x, *args, **kwargs) 105 | 106 | def _infer( 107 | self, x, upsize_sem=False, take_argmax=False, softmax_normalize=False 108 | ): 109 | self.set_train_eval_state(False) 110 | with torch.no_grad(): 111 | sem_pred, vote_pred = self.net(x) 112 | 113 | if upsize_sem: 114 | sem_pred = F.interpolate( 115 | sem_pred, scale_factor=4, mode='nearest' 116 | ) 117 | 118 | if take_argmax: 119 | sem_pred, vote_pred = sem_pred.max(1)[1], vote_pred.max(1)[1] #sem_pred.argmax(1), vote_pred.argmax(1) 120 | return sem_pred, vote_pred 121 | 122 | if softmax_normalize: 123 | sem_pred = F.softmax(sem_pred, dim=1) 124 | vote_pred = F.softmax(vote_pred, dim=1) 125 | 126 | return sem_pred, vote_pred 127 | 128 | def _flip_infer(self, x, *args, **kwargs): 129 | # x: 1x3xhxw 130 | new_x = torch.cat([x, x.flip(-1)], 0) 131 | _sem_pred, _vote_pred = self._infer(new_x, *args, **kwargs) 132 | 133 | if not hasattr(self, 'flip_infer_mapping'): 134 | original_mask = torch.from_numpy(self.pcv.vote_mask) 135 | new_mask = original_mask.flip(-1) 136 | mapping = {} 137 | for ii, jj in zip(new_mask.view(-1).tolist(), original_mask.view(-1).tolist()): 138 | if ii in mapping: 139 | assert mapping[ii] == jj 140 | else: 141 | mapping[ii] = jj 142 | mapping[len(mapping)] = len(mapping) #abstain 143 | self.flip_infer_mapping = torch.LongTensor([mapping[_] for _ in range(len(mapping))]) 144 | 145 | # sem_pred = (_sem_pred[:1] + _sem_pred[1:].flip(-1)) / 2 146 | sem_pred = F.softmax((_sem_pred[:1].log() + _sem_pred[1:].flip(-1).log()) / 2, dim=1) 147 | # vote_pred = (_vote_pred[:1] + _vote_pred[1:, self.flip_infer_mapping].flip(-1)) / 2 148 | vote_pred = F.softmax((_vote_pred[:1].log() + _vote_pred[1:, self.flip_infer_mapping].flip(-1).log()) / 2, dim=1) 149 | return sem_pred, vote_pred 150 | 151 | def stitch_pan_mask(self, infer_cfg, sem_pred, vote_pred, target_size=None, return_hmap=False): 152 | """ 153 | Args: 154 | sem_pred: [1, num_class, H, W] torch gpu tsr 155 | vote_pred: [1, num_bins, H, W] torch gpu tsr 156 | target_size: optionally postprocess the output, resize, filter 157 | stuff predictions on threshold, etc 158 | """ 159 | assert self.dset_meta is not None and self.pcv is not None 160 | # make the meta data actually required explicit!! 
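        # (editor's note) a single shared voting head (num_groups == 1) is
        # decoded by MaskFromVote; grouped, per-category voting channels go
        # through MFV_CatSeparate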
161 | infer_m = MaskFromVote if infer_cfg.num_groups == 1 else MFV_CatSeparate 162 | mfv = infer_m(infer_cfg, self.dset_meta, self.pcv, sem_pred, vote_pred) 163 | pan_mask, meta = mfv.infer_panoptic_mask() 164 | vote_hmap = mfv.vote_hmap 165 | # pan_mask, meta = self.pcv.mask_from_sem_vote_tsr( 166 | # self.dset_meta, sem_pred, vote_pred 167 | # ) don't go the extra route. Be straightforward 168 | 169 | if target_size is not None: 170 | stuff_filt_thresh = infer_cfg.get( 171 | 'stuff_pred_thresh', self.dset_meta['stuff_pred_thresh'] 172 | ) 173 | pan_mask, meta = self._post_process_pan_pred( 174 | pan_mask, meta, target_size, stuff_filt_thresh 175 | ) 176 | if not return_hmap: 177 | return pan_mask, meta 178 | else: 179 | return pan_mask, meta, vote_hmap 180 | 181 | @staticmethod 182 | def _post_process_pan_pred(pan_mask, pan_meta, target_size, stuff_thresh=-1): 183 | """Assume that pan_mask is np array and target is a PIL Image 184 | adjust annotations as needed when size change erases instances 185 | """ 186 | pan_mask = Image.fromarray(pan_mask) 187 | pan_mask_size = pan_mask.size 188 | pan_mask = np.array(pan_mask.resize(target_size, resample=Image.NEAREST)) 189 | 190 | # account for lost segments due to downsizing 191 | if pan_mask_size[0] > target_size[0]: 192 | # downsizing is the only case where instances could be erased 193 | remains = np.unique(pan_mask) 194 | segs = pan_meta['segments_info'] 195 | acc = [] 196 | for seg in segs: 197 | if seg['id'] not in remains: 198 | logger.warn('segment erased due to downsizing') 199 | else: 200 | acc.append(seg) 201 | pan_meta['segments_info'] = acc 202 | 203 | # filter out stuff segments based on area threshold 204 | segs = pan_meta['segments_info'] 205 | acc = [] 206 | for seg in segs: 207 | if seg['isthing'] == 0: 208 | _mask = (pan_mask == seg['id']) 209 | area = _mask.sum() 210 | if area > stuff_thresh: 211 | acc.append(seg) 212 | else: 213 | pan_mask[_mask] = 0 214 | else: 215 | acc.append(seg) 216 | pan_meta['segments_info'] = acc 217 | 218 | return pan_mask, pan_meta 219 | 220 | def load_latest_checkpoint_if_available(self, manager, direct_ckpt=None): 221 | ckpt = manager.load_latest() if direct_ckpt is None else direct_ckpt 222 | if ckpt: 223 | self.curr_epoch = ckpt['curr_epoch'] 224 | self.net.module.load_state_dict(ckpt['model_state']) 225 | self.optimizer.load_state_dict(ckpt['optim_state']) 226 | # scheduler is incremented by global_step in the training loop 227 | self.scheduler = get_scheduler(self.cfg.scheduler)( 228 | optimizer=self.optimizer, 229 | ) 230 | logger.info("loaded checkpoint that completes epoch {}".format( 231 | ckpt['curr_epoch'])) 232 | self.curr_epoch += 1 233 | else: 234 | logger.info("No checkpoint found") 235 | 236 | def write_checkpoint(self, manager): 237 | ckpt_obj = dict( 238 | curr_epoch=self.curr_epoch, 239 | model_state=self.net.module.state_dict(), 240 | optim_state=self.optimizer.state_dict(), 241 | ) 242 | manager.save(self.curr_epoch, ckpt_obj) 243 | 244 | 245 | class AbstractNet(nn.Module): 246 | def forward(self, *inputs): 247 | """ 248 | If gt is supplied, then compute loss 249 | """ 250 | x = inputs[0] 251 | x = self.stage1(x) # usually are resnet, hrnet, mobilenet etc. 252 | x = self.stage2(x) # like fpn, aspp, ups, etc. 
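        # (editor's note) stage2 is expected to return a pair of feature
        # maps: x[0] feeds the semantic classifier, x[1] the vote classifier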
253 | sem_pred = self.sem_classifier(x[0]) 254 | vote_pred = self.vote_classifier(x[1]) 255 | 256 | if len(inputs) > 1: 257 | assert self.loss_module is not None 258 | loss = self.loss_module(sem_pred, vote_pred, *inputs[1:]) 259 | return loss 260 | else: 261 | return sem_pred, vote_pred 262 | -------------------------------------------------------------------------------- /src/pan_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | import os 7 | import numpy as np 8 | import json 9 | import time 10 | from collections import defaultdict 11 | import argparse 12 | import multiprocessing 13 | 14 | import PIL.Image as Image 15 | 16 | from panopticapi.utils import get_traceback, rgb2id, id2rgb 17 | 18 | OFFSET = 256 * 256 * 256 19 | VOID = 0 20 | 21 | 22 | class PQStatCat(): 23 | def __init__(self): 24 | self.iou = 0.0 25 | self.tp = 0 26 | self.fp = 0 27 | self.fn = 0 28 | 29 | def __iadd__(self, pq_stat_cat): 30 | self.iou += pq_stat_cat.iou 31 | self.tp += pq_stat_cat.tp 32 | self.fp += pq_stat_cat.fp 33 | self.fn += pq_stat_cat.fn 34 | return self 35 | 36 | 37 | class PQStat(): 38 | def __init__(self): 39 | self.pq_per_cat = defaultdict(PQStatCat) 40 | 41 | def __getitem__(self, i): 42 | return self.pq_per_cat[i] 43 | 44 | def __iadd__(self, pq_stat): 45 | for label, pq_stat_cat in pq_stat.pq_per_cat.items(): 46 | self.pq_per_cat[label] += pq_stat_cat 47 | return self 48 | 49 | def pq_average(self, categories, isthing): 50 | pq, sq, rq, n = 0, 0, 0, 0 51 | per_class_results = {} 52 | for label, label_info in categories.items(): 53 | name = label_info['name'] 54 | if isthing is not None: 55 | cat_isthing = label_info['isthing'] == 1 56 | if isthing != cat_isthing: 57 | continue 58 | iou = self.pq_per_cat[label].iou 59 | tp = self.pq_per_cat[label].tp 60 | fp = self.pq_per_cat[label].fp 61 | fn = self.pq_per_cat[label].fn 62 | if tp + fp + fn == 0: 63 | per_class_results[label] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0} 64 | continue 65 | n += 1 66 | pq_class = iou / (tp + 0.5 * fp + 0.5 * fn) 67 | sq_class = iou / tp if tp != 0 else 0 68 | rq_class = tp / (tp + 0.5 * fp + 0.5 * fn) 69 | per_class_results[name] = { 70 | 'pq': pq_class, 'sq': sq_class, 'rq': rq_class 71 | } 72 | pq += pq_class 73 | sq += sq_class 74 | rq += rq_class 75 | 76 | return {'pq': pq / n, 'sq': sq / n, 'rq': rq / n, 'n': n}, per_class_results 77 | 78 | 79 | def pq_compute_single_img(categories, gt_ann, pan_gt, pred_ann, pan_pred): 80 | pq_stat = PQStat() 81 | gt_segms = {el['id']: el for el in gt_ann['segments_info']} 82 | pred_segms = {el['id']: el for el in pred_ann['segments_info']} 83 | 84 | # predicted segments area calculation + prediction sanity checks 85 | pred_labels_set = set(el['id'] for el in pred_ann['segments_info']) 86 | labels, labels_cnt = np.unique(pan_pred, return_counts=True) 87 | for label, label_cnt in zip(labels, labels_cnt): 88 | if label not in pred_segms: 89 | if label == VOID: 90 | continue 91 | raise KeyError('In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.'.format(gt_ann['image_id'], label)) 92 | pred_segms[label]['area'] = int(label_cnt) 93 | pred_labels_set.remove(label) 94 | if pred_segms[label]['category_id'] not in categories: 95 | raise KeyError('In the image with ID {} segment with ID {} has unknown category_id 
{}.'.format(gt_ann['image_id'], label, pred_segms[label]['category_id'])) 96 | if len(pred_labels_set) != 0: 97 | raise KeyError('In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.'.format(gt_ann['image_id'], list(pred_labels_set))) 98 | 99 | # confusion matrix calculation 100 | pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64) 101 | gt_pred_map = {} 102 | labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True) 103 | for label, intersection in zip(labels, labels_cnt): 104 | gt_id = label // OFFSET 105 | pred_id = label % OFFSET 106 | gt_pred_map[(gt_id, pred_id)] = intersection 107 | 108 | # count all matched pairs 109 | gt_matched = set() 110 | pred_matched = set() 111 | for label_tuple, intersection in gt_pred_map.items(): 112 | gt_label, pred_label = label_tuple 113 | if gt_label not in gt_segms: 114 | continue 115 | if pred_label not in pred_segms: 116 | continue 117 | if gt_segms[gt_label]['iscrowd'] == 1: 118 | continue 119 | if gt_segms[gt_label]['category_id'] != pred_segms[pred_label]['category_id']: 120 | continue 121 | 122 | union = pred_segms[pred_label]['area'] + gt_segms[gt_label]['area'] - intersection - gt_pred_map.get((VOID, pred_label), 0) 123 | iou = intersection / union 124 | if iou > 0.5: 125 | pq_stat[gt_segms[gt_label]['category_id']].tp += 1 126 | pq_stat[gt_segms[gt_label]['category_id']].iou += iou 127 | gt_matched.add(gt_label) 128 | pred_matched.add(pred_label) 129 | 130 | # count false positives 131 | crowd_labels_dict = {} 132 | for gt_label, gt_info in gt_segms.items(): 133 | if gt_label in gt_matched: 134 | continue 135 | # crowd segments are ignored 136 | if gt_info['iscrowd'] == 1: 137 | crowd_labels_dict[gt_info['category_id']] = gt_label 138 | continue 139 | pq_stat[gt_info['category_id']].fn += 1 140 | 141 | # count false positives 142 | for pred_label, pred_info in pred_segms.items(): 143 | if pred_label in pred_matched: 144 | continue 145 | # intersection of the segment with VOID 146 | intersection = gt_pred_map.get((VOID, pred_label), 0) 147 | # plus intersection with corresponding CROWD region if it exists 148 | if pred_info['category_id'] in crowd_labels_dict: 149 | intersection += gt_pred_map.get((crowd_labels_dict[pred_info['category_id']], pred_label), 0) 150 | # predicted segment is ignored if more than half of the segment correspond to VOID and CROWD regions 151 | if intersection / pred_info['area'] > 0.5: 152 | continue 153 | pq_stat[pred_info['category_id']].fp += 1 154 | return pq_stat 155 | 156 | 157 | @get_traceback 158 | def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories): 159 | pq_stat = PQStat() 160 | idx = 0 161 | for gt_ann, pred_ann in annotation_set: 162 | if idx % 100 == 0: 163 | print('Core: {}, {} from {} images processed'.format(proc_id, idx, len(annotation_set))) 164 | idx += 1 165 | 166 | pan_gt = np.array(Image.open(os.path.join(gt_folder, gt_ann['file_name'])), dtype=np.uint32) 167 | pan_gt = rgb2id(pan_gt) 168 | pan_pred = np.array(Image.open(os.path.join(pred_folder, pred_ann['file_name'])), dtype=np.uint32) 169 | pan_pred = rgb2id(pan_pred) 170 | 171 | # downsize pred and upsize back; resolve the lost segments 172 | ratio = 8 173 | if ratio > 1: 174 | pan_pred = id2rgb(pan_pred) 175 | pan_pred = Image.fromarray(pan_pred) 176 | h, w = pan_pred.size 177 | 178 | pan_pred = pan_pred\ 179 | .resize((h // ratio, w // ratio), Image.NEAREST)\ 180 | .resize((h, w), Image.NEAREST) 181 | pan_pred = 
np.array(pan_pred, dtype=np.uint32) 182 | pan_pred = rgb2id(pan_pred) 183 | acc = [] 184 | _p_segs = pred_ann['segments_info'] 185 | for el in _p_segs: 186 | iid = el['id'] 187 | if iid not in pan_pred: 188 | continue 189 | acc.append(el) 190 | pred_ann['segments_info'] = acc 191 | 192 | _single_img_stat = pq_compute_single_img( 193 | categories, gt_ann, pan_gt, pred_ann, pan_pred) 194 | pq_stat += _single_img_stat 195 | 196 | print('Core: {}, all {} images processed'.format(proc_id, len(annotation_set))) 197 | return pq_stat 198 | 199 | 200 | def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories): 201 | cpu_num = multiprocessing.cpu_count() 202 | annotations_split = np.array_split(matched_annotations_list, cpu_num) 203 | print("Number of cores: {}, images per core: {}".format(cpu_num, len(annotations_split[0]))) 204 | workers = multiprocessing.Pool(processes=cpu_num) 205 | processes = [] 206 | for proc_id, annotation_set in enumerate(annotations_split): 207 | p = workers.apply_async(pq_compute_single_core, 208 | (proc_id, annotation_set, gt_folder, pred_folder, categories)) 209 | processes.append(p) 210 | pq_stat = PQStat() 211 | for p in processes: 212 | pq_stat += p.get() 213 | return pq_stat 214 | 215 | 216 | def pq_compute(gt_json_file, pred_json_file, gt_folder=None, pred_folder=None): 217 | 218 | start_time = time.time() 219 | with open(gt_json_file, 'r') as f: 220 | gt_json = json.load(f) 221 | with open(pred_json_file, 'r') as f: 222 | pred_json = json.load(f) 223 | 224 | if gt_folder is None: 225 | gt_folder = gt_json_file.replace('.json', '') 226 | if pred_folder is None: 227 | pred_folder = pred_json_file.replace('.json', '') 228 | categories = {el['id']: el for el in gt_json['categories']} 229 | 230 | print("Evaluation panoptic segmentation metrics:") 231 | print("Ground truth:") 232 | print("\tSegmentation folder: {}".format(gt_folder)) 233 | print("\tJSON file: {}".format(gt_json_file)) 234 | print("Prediction:") 235 | print("\tSegmentation folder: {}".format(pred_folder)) 236 | print("\tJSON file: {}".format(pred_json_file)) 237 | 238 | if not os.path.isdir(gt_folder): 239 | raise Exception("Folder {} with ground truth segmentations doesn't exist".format(gt_folder)) 240 | if not os.path.isdir(pred_folder): 241 | raise Exception("Folder {} with predicted segmentations doesn't exist".format(pred_folder)) 242 | 243 | pred_annotations = {el['image_id']: el for el in pred_json['annotations']} 244 | matched_annotations_list = [] 245 | for gt_ann in gt_json['annotations']: 246 | image_id = gt_ann['image_id'] 247 | if image_id not in pred_annotations: 248 | raise Exception('no prediction for the image with id: {}'.format(image_id)) 249 | matched_annotations_list.append((gt_ann, pred_annotations[image_id])) 250 | 251 | pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories) 252 | 253 | metrics = [("All", None), ("Things", True), ("Stuff", False)] 254 | results = {} 255 | for name, isthing in metrics: 256 | results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing) 257 | if name == 'All': 258 | results['per_class'] = per_class_results 259 | print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N")) 260 | print("-" * (10 + 7 * 4)) 261 | 262 | for name, _isthing in metrics: 263 | print("{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format( 264 | name, 265 | 100 * results[name]['pq'], 266 | 100 * results[name]['sq'], 267 | 100 * results[name]['rq'], 268 | results[name]['n']) 269 | 
    )
270 | 
271 |     t_delta = time.time() - start_time
272 |     print("Time elapsed: {:0.2f} seconds".format(t_delta))
273 | 
274 |     return results
275 | 
276 | 
277 | if __name__ == "__main__":
278 |     parser = argparse.ArgumentParser()
279 |     parser.add_argument('--gt_json_file', type=str,
280 |                         help="JSON file with ground truth data")
281 |     parser.add_argument('--pred_json_file', type=str,
282 |                         help="JSON file with predictions data")
283 |     parser.add_argument('--gt_folder', type=str, default=None,
284 |                         help="Folder with ground truth COCO format segmentations. \
285 |                         Default: X if the corresponding json file is X.json")
286 |     parser.add_argument('--pred_folder', type=str, default=None,
287 |                         help="Folder with prediction COCO format segmentations. \
288 |                         Default: X if the corresponding json file is X.json")
289 |     args = parser.parse_args()
290 |     pq_compute(args.gt_json_file, args.pred_json_file, args.gt_folder, args.pred_folder)
291 | 
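# Usage sketch (illustration only; the paths are placeholders, assuming the
# data layout documented in src/datasets/base.py):
#
#     python src/pan_eval.py \
#         --gt_json_file /data/coco/annotations/val.json \
#         --pred_json_file exp/run1/pred.json
#
# When --gt_folder / --pred_folder are omitted, the PNG folders default to the
# JSON paths minus the '.json' suffix, i.e. .../annotations/val and
# exp/run1/pred in this hypothetical invocation.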
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from . import cfg
6 | from .datasets.base import ignore_index
7 | from detectron2.layers.batch_norm import AllReduce
8 | import torch.distributed as dist
9 | 
10 | 
11 | class Loss(nn.Module):
12 |     def __init__(self, w_vote, w_sem):
13 |         super().__init__()
14 |         self.w_vote = w_vote
15 |         self.w_sem = w_sem
16 | 
17 |     def forward(self, *args, **kwargs):
18 |         '''this 3-step interface is enforced for decomposable loss visualization
19 |         see /vis.py
20 |         '''
21 |         sem_loss_mask, vote_loss_mask = self.per_pix_loss(*args, **kwargs)
22 |         if len(args) == 5:
23 |             sem_weight_mask = vote_weight_mask = args[-1]  # risky: assumes the last positional arg is a single shared weight mask
24 |         else:
25 |             sem_weight_mask, vote_weight_mask = args[-2:]
26 |             # (used when separate sem/vote weight masks are supplied by the caller)
27 |         sem_loss_mask, vote_loss_mask = self.normalize(
28 |             sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask
29 |         )
30 |         loss, s_loss, v_loss = self.aggregate(sem_loss_mask, vote_loss_mask)
31 |         return loss, s_loss, v_loss
32 | 
33 |     def per_pix_loss(self, *args, **kwargs):
34 |         '''compute raw per pixel loss that reflects how good a pixel decision is
35 |         do not apply any segment based normalization yet
36 |         '''
37 |         raise NotImplementedError()
38 | 
39 |     def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
40 |         raise NotImplementedError()
41 | 
42 |     def aggregate(self, sem_loss_mask, vote_loss_mask):
43 |         s_loss, v_loss = sem_loss_mask.sum(), vote_loss_mask.sum()
44 |         loss = self.w_sem * s_loss + self.w_vote * v_loss
45 |         return loss, s_loss, v_loss
46 | 
47 | 
48 | def loss_normalize(loss, weight_mask, local_batch_size):
49 |     '''An abomination of a method. Have to live with it for now but
50 |     I will delete it soon enough. This 'total' option makes no sense except to
51 |     serve the expedience of making the loss small
52 |     '''
53 |     norm = cfg.loss.normalize
54 |     if norm == 'batch':
55 |         loss = (loss * weight_mask) / local_batch_size
56 |     elif norm == 'total':
57 |         loss = (loss * weight_mask) / (weight_mask.sum() + 1e-10)
58 |     elif norm == 'segs_across_batch':
59 |         global_seg_cnt = weight_mask.sum()
60 |         if torch.distributed.is_initialized():
61 |             global_seg_cnt = AllReduce.apply(global_seg_cnt) / dist.get_world_size()
62 |         # print('local sum {:.2f} vs global sum {:.2f}'.format(weight_mask.sum(), weight_mask_sum))
63 |         loss = (loss * weight_mask) / (global_seg_cnt + 1e-10)
64 |     else:
65 |         raise ValueError('invalid value for normalization: {}'.format(norm))
66 |     return loss
67 | 
68 | 
69 | class PanLoss(Loss):
70 |     def per_pix_loss(self, sem_pred, vote_pred, sem_mask, vote_mask, *args, **kwargs):
71 |         sem_loss_mask = F.cross_entropy(
72 |             sem_pred, sem_mask, ignore_index=ignore_index, reduction='none'
73 |         )
74 |         vote_loss_mask = F.cross_entropy(
75 |             vote_pred, vote_mask, ignore_index=ignore_index, reduction='none'
76 |         )
77 |         # vote_loss_mask = vote_loss_mask * 0.1  # optional down-weighting of the vote loss, disabled
78 |         return sem_loss_mask, vote_loss_mask
79 | 
80 |     def normalize(self, sem_loss_mask, vote_loss_mask, *args, **kwargs):
81 |         sem_loss_mask = sem_loss_mask / sem_loss_mask.numel()
82 |         vote_loss_mask = vote_loss_mask / vote_loss_mask.numel()
83 |         return sem_loss_mask, vote_loss_mask
84 | 
85 | 
86 | class MaskedPanLoss(PanLoss):
87 |     def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
88 |         assert len(sem_loss_mask) == len(vote_loss_mask)
89 |         assert sem_weight_mask is not None and vote_weight_mask is not None
90 |         batch_size = len(sem_loss_mask)
91 |         sem_loss_mask = loss_normalize(sem_loss_mask, sem_weight_mask, batch_size)
92 |         vote_loss_mask = loss_normalize(vote_loss_mask, vote_weight_mask, batch_size)
93 |         return sem_loss_mask, vote_loss_mask
94 | 
95 | 
96 | class DeeperlabPanLoss(PanLoss):
97 |     def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
98 |         assert len(sem_loss_mask) == len(vote_loss_mask)
99 |         assert sem_weight_mask is not None and vote_weight_mask is not None
100 |         batch_size = len(sem_loss_mask)
101 |         sem_loss_mask = deeperlab_loss_normalize(sem_loss_mask, sem_weight_mask, batch_size)
102 |         vote_loss_mask = deeperlab_loss_normalize(vote_loss_mask, vote_weight_mask, batch_size)
103 |         return sem_loss_mask, vote_loss_mask
104 | 
105 | 
106 | def deeperlab_loss_normalize(loss, weight_mask, local_batch_size):
107 |     new_weight = (weight_mask > 1/16/16).float()*2+1  # == area < 16x16 (assuming weight_mask stores 1 / area)
108 |     flat_loss = loss.reshape(loss.shape[0], -1)  # b x h x w -> b x hw
109 |     new_weight = new_weight.reshape(new_weight.shape[0], -1)
110 | 
111 |     topk_loss, topk_inx = flat_loss.topk(int(flat_loss.shape[-1] * 0.15), sorted=False, dim=-1)
112 |     topk_weight = new_weight.gather(1, topk_inx)
113 | 
114 |     loss = (topk_loss * topk_weight).mean()
115 |     return loss
116 | 
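# A minimal shape sketch of the 3-step Loss interface above (illustration
# only; the tensor sizes are made up -- 19 semantic classes and 26 vote
# channels are hypothetical -- and cfg.loss.normalize must be set, e.g. to
# 'total', for loss_normalize to work):
#
#     criterion = MaskedPanLoss(w_vote=0.5, w_sem=0.5)
#     sem_pred  = torch.randn(2, 19, 64, 128)   # B x C_sem  x H x W logits
#     vote_pred = torch.randn(2, 26, 64, 128)   # B x C_vote x H x W logits
#     sem_mask  = torch.zeros(2, 64, 128, dtype=torch.long)
#     vote_mask = torch.zeros(2, 64, 128, dtype=torch.long)
#     sem_w = vote_w = torch.ones(2, 64, 128)
#     loss, s_loss, v_loss = criterion(
#         sem_pred, vote_pred, sem_mask, vote_mask, sem_w, vote_w)
#
# per_pix_loss yields B x H x W cross-entropy maps, normalize rescales them by
# the weight masks (via loss_normalize above), and aggregate sums and mixes
# them with w_sem / w_vote.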
117 | # class TsrCoalesceLoss(Loss):
118 | #     def per_pix_loss(
119 | #         self, sem_pred, vote_pred, sem_mask, vote_mask, vote_bool_tsr, weight_mask
120 | #     ):
121 | #         del vote_mask  # vote_mask is accepted merely for parameter compatibility
122 | #         sem_loss_mask = self.regular_ce_loss(sem_pred, sem_mask, weight_mask)
123 | #         vote_loss_mask = self.unsophisticated_loss(vote_pred, vote_bool_tsr, weight_mask)
124 | #         return sem_loss_mask, vote_loss_mask
125 | 
126 | #     @staticmethod
127 | #     def regular_ce_loss(pred, lbls, weight_mask):
128 | #         return MaskedPanLoss._single_loss(pred, lbls, weight_mask)
129 | 
130 | #     @staticmethod
131 | #     def booltsr_loss(pred, bool_tsr, weight_mask):
132 | #         raise ValueError('cannot be back-propagated')
133 | #         is_valid = bool_tsr.any(dim=1)  # [N, H, W]
134 | #         weight_mask = weight_mask[is_valid]  # [num_valid, ]
135 | #         # pred = pred.permute(0, 2, 3, 1)
136 | #         # bool_tsr = bool_tsr.permute(0, 2, 3, 1)
137 | #         # pred, bool_tsr = pred[is_valid], bool_tsr[is_valid]  # [num_valid, C]
138 | 
139 | #         bottom = torch.logsumexp(pred, dim=1)
140 | #         pred = torch.where(bool_tsr, pred, torch.tensor(float('-inf')).cuda())
141 | #         pred = torch.logsumexp(pred, dim=1)
142 | #         loss = (bottom - pred)[is_valid]  # -1 is implicit here by reversing order
143 | #         loss = (loss * weight_mask).sum() / weight_mask.sum()
144 | #         return loss
145 | 
146 | #     @staticmethod
147 | #     def unsophisticated_loss(pred, bool_tsr, weight_mask):
148 | #         raise ValueError('validity mask changes spatial shape; embarrassing')
149 | #         is_valid = bool_tsr.any(dim=1)  # [N, H, W]
150 | #         weight_mask = weight_mask[is_valid]
151 | 
152 | #         pred = F.softmax(pred, dim=1)
153 | #         pred = torch.where(bool_tsr, pred, torch.tensor(0.).cuda())
154 | #         loss = torch.log(pred.sum(dim=1)[is_valid])
155 | #         loss = loss_normalize(loss, weight_mask, len(pred))
156 | #         loss = -1 * loss
157 | #         return loss
158 | 
159 | 
160 | # class MaskedKLPanLoss(PanLoss):
161 | #     '''
162 | #     cross entropy loss for semantic segmentation and KL-divergence loss for
163 | #     voting
164 | #     '''
165 | #     def per_pix_loss(
166 | #         self, sem_pred, vote_pred,
167 | #         sem_mask, vote_mask, vote_gt_prob, weight_mask
168 | #     ):
169 | #         sem_loss_mask = self.sem_loss(sem_pred, sem_mask, weight_mask)
170 | #         vote_loss_mask = self.vote_loss(vote_pred, vote_gt_prob, weight_mask)
171 | #         return sem_loss_mask, vote_loss_mask
172 | 
173 | #     @staticmethod
174 | #     def sem_loss(sem_pred, sem_mask, weight_mask):
175 | #         loss = F.cross_entropy(
176 | #             sem_pred, sem_mask, ignore_index=ignore_index, reduction='none'
177 | #         )
178 | #         loss = loss_normalize(loss, weight_mask, len(sem_pred))
179 | #         return loss
180 | 
181 | #     @staticmethod
182 | #     def vote_loss(vote_pred, vote_gt_prob, weight_mask):
183 | #         loss = F.kl_div(
184 | #             F.log_softmax(vote_pred, dim=1), vote_gt_prob, reduction='none'
185 | #         )
186 | #         loss = loss.sum(dim=1)
187 | #         loss = loss_normalize(loss, weight_mask, len(vote_pred))
188 | #         return loss
189 | 
190 | 
191 | # class _CELoss(nn.Module):
192 | #     def __init__(self):
193 | #         super().__init__()
194 | 
195 | #     def forward(self, pred, mask):
196 | #         """Assume that ignore index is set at 0
197 | #         pred: [N, C, H, W]
198 | #         mask: [N, H, W]
199 | #         """
200 | #         loss = F.cross_entropy(
201 | #             pred, mask, ignore_index=ignore_index, reduction='mean'
202 | #         )
203 | #         return loss
204 | 
205 | 
206 | class NormalizedFocalPanLoss(nn.Module):
207 |     # Adapted from AdaptIS
208 |     def __init__(self, w_vote=0.5, w_sem=0.5, gamma=0, alpha=1):  # alpha is accepted but unused
209 |         super().__init__()
210 |         # assert w_vote + w_sem == 1
211 |         self.w_vote = w_vote
212 |         self.w_sem = w_sem
213 |         self.gamma = gamma
214 | 
215 |     def focal_loss(self, input, target):
216 |         logpt = - F.cross_entropy(input, target, ignore_index=ignore_index, reduction='none')
217 |         pt = torch.exp(logpt)
218 | 
219 |         beta = (1 - pt) ** self.gamma
220 | 
221 |         t = target != ignore_index
222 |         t_sum = t.float().sum(axis=[-2, -1], keepdims=True)
223 |         beta_sum = beta.sum(axis=[-2, -1], keepdims=True)
224 | 
225 |         eps = 1e-10
226 | 
227 |         mult = t_sum / (beta_sum + eps)
228 |         if True:  # always on; kept as a toggle
229 | mult = mult.detach() 230 | beta = beta * mult 231 | 232 | loss = - beta * logpt.clamp(min=-20) # B x H x W 233 | 234 | # size average 235 | loss = loss.sum(axis=[-2,-1]) / t_sum.squeeze() 236 | loss = loss.sum() 237 | 238 | return loss 239 | 240 | def forward(self, sem_pred, vote_pred, sem_mask, vote_mask, *args, **kwargs): 241 | sem_loss = self.focal_loss(sem_pred, sem_mask) 242 | vote_loss = self.focal_loss(vote_pred, vote_mask) 243 | loss = self.w_vote * vote_loss + self.w_sem * sem_loss 244 | return loss, sem_loss, vote_loss 245 | 246 | 247 | # class MaskedPanLoss(Loss): 248 | # def forward(self, sem_pred, vote_pred, sem_mask, vote_mask, weight_mask, vote_weight_mask=None): 249 | # sem_loss = self._single_loss(sem_pred, sem_mask, weight_mask) 250 | # if vote_weight_mask is None: 251 | # vote_weight_mask = weight_mask 252 | # vote_loss = self._single_loss(vote_pred, vote_mask, vote_weight_mask) 253 | # loss = self.w_vote * vote_loss + self.w_sem * sem_loss 254 | # return loss, sem_loss, vote_loss 255 | 256 | # @staticmethod 257 | # def _single_loss(pred, lbls, weight_mask): 258 | # loss = F.cross_entropy( 259 | # pred, lbls, ignore_index=ignore_index, reduction='none' 260 | # ) 261 | # loss = loss_normalize(loss, weight_mask, len(pred)) 262 | # return loss 263 | 264 | 265 | # class TsrCoalesceLoss(nn.Module): 266 | # def __init__(self, w_vote=0.5, w_sem=0.5): 267 | # super().__init__() 268 | # self.w_vote = w_vote 269 | # self.w_sem = w_sem 270 | 271 | # def forward( 272 | # self, sem_pred, vote_pred, sem_mask, vote_mask, vote_bool_tsr, weight_mask 273 | # ): 274 | # del vote_mask 275 | # sem_loss = self.regular_ce_loss(sem_pred, sem_mask, weight_mask) 276 | # vote_loss = self.unsophisticated_loss(vote_pred, vote_bool_tsr, weight_mask) 277 | # loss = self.w_vote * vote_loss + self.w_sem * sem_loss 278 | # return loss, sem_loss, vote_loss 279 | 280 | # @staticmethod 281 | # def regular_ce_loss(pred, lbls, weight_mask): 282 | # return MaskedPanLoss._single_loss(pred, lbls, weight_mask) 283 | 284 | # @staticmethod 285 | # def booltsr_loss(pred, bool_tsr, weight_mask): 286 | # raise ValueError('cannot be back-propagated') 287 | # is_valid = bool_tsr.any(dim=1) # [N, H, W] 288 | # weight_mask = weight_mask[is_valid] # [num_valid, ] 289 | # # pred = pred.permute(0, 2, 3, 1) 290 | # # bool_tsr = bool_tsr.permute(0, 2, 3, 1) 291 | # # pred, bool_tsr = pred[is_valid], bool_tsr[is_valid] # [num_valid, C] 292 | 293 | # bottom = torch.logsumexp(pred, dim=1) 294 | # pred = torch.where(bool_tsr, pred, torch.tensor(float('-inf')).cuda()) 295 | # pred = torch.logsumexp(pred, dim=1) 296 | # loss = (bottom - pred)[is_valid] # -1 is implicit here by reversing order 297 | # loss = (loss * weight_mask).sum() / weight_mask.sum() 298 | # return loss 299 | 300 | # @staticmethod 301 | # def unsophisticated_loss(pred, bool_tsr, weight_mask): 302 | # is_valid = bool_tsr.any(dim=1) # [N, H, W] 303 | # weight_mask = weight_mask[is_valid] 304 | 305 | # pred = F.softmax(pred, dim=1) 306 | # pred = torch.where(bool_tsr, pred, torch.tensor(0.).cuda()) 307 | # loss = torch.log(pred.sum(dim=1)[is_valid]) 308 | # loss = loss_normalize(loss, weight_mask, len(pred)) 309 | # loss = -1 * loss 310 | # return loss 311 | -------------------------------------------------------------------------------- /src/pan_vis.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | 6 | from 
IPython.display import display as ipy_display 7 | from ipywidgets import interactive 8 | import ipywidgets as widgets 9 | 10 | from panopticapi.utils import rgb2id 11 | from panoptic.pan_analyzer import ( 12 | PanopticEvalAnalyzer, 13 | _SEGMENT_MATCHED, _SEGMENT_UNMATCHED, _SEGMENT_FORGIVEN 14 | ) 15 | 16 | 17 | class PanVis(): 18 | def __init__(self, img_root, gt_json_meta_fname, pd_json_meta_fname): 19 | """Expect that the pan mask dir to be right beside the meta json file 20 | val.json 21 | val/ 22 | Args: 23 | img_root: root dir where images are stored 24 | gt_json_meta_fname: abs fname to gt json 25 | pd_json_meta_fname: ... 26 | """ 27 | self.img_root = img_root 28 | 29 | analyzer = PanopticEvalAnalyzer(gt_json_meta_fname, pd_json_meta_fname) 30 | self.gt, self.pd = analyzer.gt, analyzer.pd 31 | self.imgIds = analyzer.imgIds 32 | self.res_dframe, self.overall_table, self.cat_table = analyzer.summarize() 33 | # cached widgets 34 | self.global_walk = None 35 | self.__init_global_state__() 36 | 37 | def evaluate(self): 38 | # this is now a dud for backwards compatibility 39 | pass 40 | 41 | def summarize(self): 42 | print(self.cat_table) 43 | print(self.overall_table) 44 | 45 | def __init_global_state__(self): 46 | self.global_state = { 47 | 'imgId': None, 'segId': None, 48 | 'catId': None, 'tranche': None 49 | } 50 | 51 | # the following are the modules for widgets, from root to leaf 52 | def root_wdgt(self): 53 | """ 54 | root widget delegates to either global or image 55 | """ 56 | self.summarize() 57 | modes = ['Global', 'Single-Image'] 58 | 59 | def logic(mode): 60 | # cache the widget later 61 | if mode == modes[0]: 62 | if self.global_walk is None: 63 | self.global_walk = self.global_walk_specifier() 64 | ipy_display(self.global_walk) 65 | elif mode == modes[1]: 66 | self.image_view = self.single_image_selector() 67 | # if self.image_view is None: 68 | # self.image_view = self.single_image_selector() 69 | # ipy_display(self.image_view) 70 | 71 | UI = interactive( 72 | logic, mode=widgets.ToggleButtons(options=modes, value=modes[0]) 73 | ) 74 | UI.children[-1].layout.height = '1000px' 75 | ipy_display(UI) 76 | 77 | def global_walk_specifier(self): 78 | tranche_map = self._tranche_filter(self.gt.segs, self.pd.segs) 79 | 80 | def logic(catId, tranche): 81 | if self.global_state['catId'] != catId \ 82 | or self.global_state['tranche'] != tranche: 83 | self.__init_global_state__() 84 | self.global_state['catId'] = catId 85 | self.global_state['tranche'] = tranche 86 | seg_list = self._cat_filter_and_merge_tranche_map( 87 | tranche_map, [catId], [tranche] 88 | ) 89 | # areas = [ seg['area'] for seg in seg_list ] 90 | # plt.hist(areas) 91 | self.walk_primary(seg_list, is_global=True) 92 | UI = interactive( 93 | logic, 94 | catId=self._category_roulette(self.gt.cats.keys(), multi_select=False), 95 | tranche=widgets.Select( 96 | options=tranche_map.keys(), value=list(tranche_map.keys())[0] 97 | ) 98 | ) 99 | return UI 100 | 101 | def single_image_selector(self): 102 | imgIds = self.imgIds 103 | inx, txt = self._inx_txt_scroller_pair( 104 | imgIds, default_txt=self.global_state['imgId']) 105 | 106 | def logic(inx): 107 | print("curr image {}/{}".format(inx, len(imgIds))) 108 | imgId = imgIds[inx] 109 | self.single_image_view_specifier(imgId) 110 | UI = interactive(logic, inx=inx) 111 | ipy_display(txt) 112 | ipy_display(UI) 113 | 114 | def single_image_view_specifier(self, imgId): 115 | gt_segs, pd_segs = self.gt.img2seg[imgId], self.pd.img2seg[imgId] 116 | _gt_cats = {seg['category_id'] 
for seg in gt_segs.values()} 117 | _pd_cats = {seg['category_id'] for seg in pd_segs.values()} 118 | relevant_catIds = _gt_cats | _pd_cats 119 | tranche_map = self._tranche_filter(gt_segs, pd_segs) 120 | modes = ['bulk', 'walk'] 121 | 122 | def logic(catIds, tranches, mode): 123 | # only for walk, not for bulk display 124 | seg_list = self._cat_filter_and_merge_tranche_map( 125 | tranche_map, catIds, tranches) 126 | if mode == modes[0]: 127 | self.single_image_bulk_display(seg_list) 128 | elif mode == modes[1]: 129 | self.walk_primary(seg_list) 130 | UI = interactive( 131 | logic, 132 | mode=widgets.ToggleButtons(options=modes, value=modes[0]), 133 | catIds=self._category_roulette( 134 | relevant_catIds, multi_select=True, 135 | default_cid=[self.global_state['catId']] 136 | ), 137 | tranches=widgets.SelectMultiple( 138 | options=tranche_map.keys(), 139 | value=[self.global_state['tranche']] 140 | ) 141 | ) 142 | ipy_display(UI) 143 | 144 | def single_image_bulk_display(self, segs): 145 | if len(segs) == 0: 146 | return 'no segments in this tranche' 147 | imgId = segs[0]['image_id'] 148 | for seg in segs: 149 | assert seg['image_id'] == imgId 150 | segIds = list(map(lambda x: x['sid'], segs)) 151 | gt_seg_ids = list(filter(lambda x: x.startswith('gt/'), segIds)) 152 | pd_seg_ids = list(filter(lambda x: x.startswith('pd/'), segIds)) 153 | self.single_image_plot(imgId, gt_seg_ids, pd_seg_ids) 154 | 155 | def walk_primary(self, segs, is_global=False): 156 | """segs: a list of seg""" 157 | # the watching logic here is quite messy 158 | sids = [seg['sid'] for seg in segs] 159 | if len(sids) == 0: 160 | return 'no available segs' 161 | inx, txt = self._inx_txt_scroller_pair( 162 | sids, default_txt=self.global_state['segId'] if is_global else None 163 | ) 164 | 165 | def logic(inx): 166 | seg = segs[inx] 167 | if is_global: 168 | self.global_state['segId'] = seg['sid'] 169 | self.global_state['imgId'] = seg['image_id'] 170 | print("Primary seg {}/{} matches with {} segments".format( 171 | inx, len(segs), len(seg['matchings']))) 172 | self.walk_matched(seg) 173 | UI = interactive(logic, inx=inx) 174 | print("Primary Segment:") 175 | ipy_display(txt) 176 | ipy_display(UI) 177 | 178 | def walk_matched(self, ref_seg): 179 | """child of walk_primary""" 180 | ref_sid = ref_seg['sid'] 181 | # note that matchings is {sid: IoU} 182 | matched_sids = list(ref_seg['matchings'].keys()) 183 | matched_ious = list(ref_seg['matchings'].values()) 184 | if len(matched_sids) == 0: 185 | matched_sids = (None, ) 186 | matched_ious = (0, ) 187 | 188 | def segid_to_catname(partition, sid): 189 | if sid is None: 190 | return 'NA' 191 | return self.gt.cats[partition.segs[sid]['category_id']]['name'] 192 | 193 | def logic(inx): 194 | match_sid = matched_sids[inx] 195 | if ref_sid.startswith('gt/'): 196 | gt_sid, pd_sid, highlight = ref_sid, match_sid, 1 197 | imgId = self.gt.segs[ref_sid]['image_id'] 198 | else: 199 | gt_sid, pd_sid, highlight = match_sid, ref_sid, 2 200 | imgId = self.pd.segs[ref_sid]['image_id'] 201 | print('IoU: {:.3f}'.format(matched_ious[inx])) 202 | print('gt: {} vs pd: {}'.format( 203 | segid_to_catname(self.gt, gt_sid), 204 | segid_to_catname(self.pd, pd_sid) 205 | )) 206 | self.single_image_plot(imgId, gt_sid, pd_sid, highlight) 207 | 208 | inx, txt = self._inx_txt_scroller_pair(matched_sids) 209 | UI = interactive(logic, inx=inx) 210 | print("Matched Segment:") 211 | ipy_display(txt) 212 | ipy_display(UI) 213 | 214 | @staticmethod 215 | def _tranche_filter(gt_segs, pd_segs): 216 | """ 217 | 
Args: 218 | gt_segs: {segId: seg} 219 | pd_segs: {segId: seg} 220 | """ 221 | def _filter(state, seg_map): 222 | seg_list = [ 223 | seg for seg in seg_map.values() if seg['match_state'] == state 224 | ] 225 | seg_list = sorted(seg_list, key=lambda x: x['area'], reverse=True) 226 | return seg_list 227 | 228 | tranche_map = { 229 | 'TP': _filter(_SEGMENT_MATCHED, pd_segs), 230 | 'FN': _filter(_SEGMENT_UNMATCHED, gt_segs), 231 | 'FP': _filter(_SEGMENT_UNMATCHED, pd_segs), 232 | 'GT_FORGIVEN': _filter(_SEGMENT_FORGIVEN, gt_segs), 233 | 'PD_FORGIVEN': _filter(_SEGMENT_FORGIVEN, pd_segs) 234 | } 235 | assert len(gt_segs) == sum( 236 | map(lambda x: len(tranche_map[x]), ['TP', 'FN', 'GT_FORGIVEN']) 237 | ) 238 | assert len(pd_segs) == sum( 239 | map(lambda x: len(tranche_map[x]), ['TP', 'FP', 'PD_FORGIVEN']) 240 | ) 241 | return tranche_map 242 | 243 | @staticmethod 244 | def _cat_filter_and_merge_tranche_map(tranche_map, catIds, chosen_tranches): 245 | local_tranche_map = { 246 | k: list(filter(lambda seg: seg['category_id'] in catIds, seg_list)) 247 | for k, seg_list in tranche_map.items() 248 | } 249 | for k, v in local_tranche_map.items(): 250 | print("{}: {}".format(k, len(v)), end='; ') 251 | print('') 252 | seg_list = sum([local_tranche_map[_tr] for _tr in chosen_tranches], []) 253 | return seg_list 254 | 255 | @staticmethod 256 | def _inx_txt_scroller_pair(sids, default_txt=None): 257 | """ 258 | Args: 259 | sids: [str, ] segment ids 260 | Note that since a handler is only called if 'value' changes, this mutual 261 | watching would not lead to infinite back-and-forth bouncing. 262 | In addition, bouncing-back is prevented by internal_change flag. 263 | """ 264 | assert len(sids) > 0 265 | if default_txt is not None: 266 | default_inx, default_txt = sids.index(default_txt), default_txt 267 | else: 268 | default_inx, default_txt = 0, sids[0] 269 | inx = widgets.BoundedIntText(value=default_inx, min=0, max=len(sids) - 1) 270 | txt = widgets.Text(value=str(default_txt), description='ID') 271 | internal_change = False 272 | 273 | def inx_update_reaction(*args): 274 | nonlocal internal_change 275 | if internal_change: 276 | internal_change = False 277 | return 278 | curr_inx = inx.value 279 | curr_sid = sids[curr_inx] 280 | internal_change = True 281 | txt.value = curr_sid 282 | inx.observe(inx_update_reaction, 'value') 283 | 284 | def txt_update_reaction(*args): 285 | nonlocal internal_change 286 | if internal_change: 287 | internal_change = False 288 | return 289 | curr_sid = txt.value 290 | if curr_sid in sids: 291 | curr_inx = sids.index(curr_sid) 292 | internal_change = True 293 | inx.value = curr_inx 294 | txt.observe(txt_update_reaction, 'value') 295 | 296 | return inx, txt 297 | 298 | def _category_roulette( 299 | self, selected_catIds, multi_select=False, default_cid=None, 300 | ): 301 | """ 302 | Things first, Stuff next, each sorted from high to low by PQ 303 | Note that this roulette is multi-selective, and return a tuple of catIds 304 | e.g. 
305 |         T, 16.60, Person
306 |         T, 15.12, Bicycle
307 |         S, 32.10, Road
308 |         """
309 |         catIds = np.array(sorted(self.gt.cats.keys()))
310 |         PQ = self.res_dframe.values[:, 0]  # (num_cats, )
311 | 
312 |         # first filter by selection, then sort by PQ from high to low
313 |         chosen_mask = np.array(
314 |             [catId in selected_catIds for catId in catIds], dtype=bool)
315 |         catIds, PQ = catIds[chosen_mask], PQ[chosen_mask]
316 |         order = np.argsort(-PQ)  # high to low
317 |         catIds, PQ = catIds[order], PQ[order]
318 | 
319 |         # now do things first followed by stuff
320 |         acc = []
321 |         isthing = np.array(
322 |             [self.gt.cats[id]['isthing'] for id in catIds], dtype=bool)
323 |         acc += [
324 |             ('T, {:>4.2f}, {}'.format(pq, self.gt.cats[cid]['name']), cid)
325 |             for pq, cid in zip(PQ[isthing], catIds[isthing])
326 |         ]
327 |         acc += [
328 |             ('S, {:>4.2f}, {}'.format(pq, self.gt.cats[cid]['name']), cid)
329 |             for pq, cid in zip(PQ[~isthing], catIds[~isthing])
330 |         ]
331 | 
332 |         if default_cid is None:
333 |             default_cid = acc[0][1]
334 |         if multi_select and not isinstance(default_cid, (tuple, list)):
335 |             default_cid = [default_cid]
336 |         _module = widgets.SelectMultiple if multi_select else widgets.Select
337 |         roulette = _module(options=acc, rows=15, value=default_cid)
338 |         return roulette
339 | 
340 |     def single_image_plot(
341 |         self, imgId, gt_seg_sid_list, pd_seg_sid_list,
342 |         highlight=None, seg_alpha=0.7, seg_cmap='Blues'
343 |     ):
344 |         # first load image and annotations masks
345 |         img = np.array(Image.open(
346 |             osp.join(self.img_root, self.gt.imgs[imgId]['file_name'])
347 |         ))
348 |         gt_rgb = np.array(
349 |             Image.open(osp.join(
350 |                 self.gt.mask_root, self.gt.imgs[imgId]['ann_fname']
351 |             )),
352 |             dtype=np.uint32
353 |         )
354 |         gt_mask = rgb2id(gt_rgb)
355 |         pd_rgb = np.array(
356 |             Image.open(osp.join(
357 |                 self.pd.mask_root, self.pd.imgs[imgId]['ann_fname']
358 |             )),
359 |             dtype=np.uint32
360 |         )
361 |         pd_mask = rgb2id(pd_rgb)
362 | 
363 |         # now aggregate the segment masks for both pd and gt
364 |         def aggregate_seg_mask(sid_list, ref_mask, segs_map):
365 |             if sid_list is None:
366 |                 sid_list = []
367 |             if not isinstance(sid_list, (list, tuple)):
368 |                 sid_list = (sid_list, )
369 |             seg_mask = np.zeros(ref_mask.shape, dtype=bool)
370 |             for sid in sid_list:
371 |                 seg = segs_map[sid]
372 |                 assert seg['image_id'] == imgId
373 |                 seg_mask |= (ref_mask == seg['id'])
374 |             return seg_mask
375 | 
376 |         gt_seg_mask = aggregate_seg_mask(gt_seg_sid_list, gt_mask, self.gt.segs)
377 |         pd_seg_mask = aggregate_seg_mask(pd_seg_sid_list, pd_mask, self.pd.segs)
378 | 
379 |         # plot them together
380 |         WHITE = [255, 255, 255]
381 |         fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(30, 12))
382 |         axes[0].imshow(img)
383 |         gt_rgb[gt_seg_mask] = WHITE
384 |         axes[1].imshow(gt_rgb)
385 |         # axes[1].imshow(gt_seg_mask, alpha=seg_alpha, cmap=seg_cmap)
386 |         pd_rgb[pd_seg_mask] = WHITE
387 |         axes[2].imshow(pd_rgb)
388 |         # axes[2].imshow(pd_seg_mask, alpha=seg_alpha, cmap=seg_cmap)
389 | 
390 |         if highlight is not None:
391 |             axes[highlight].set_title(
392 |                 'frame in focus', bbox=dict(facecolor='orange')
393 |             )
394 |         plt.show()
395 | 
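# Typical notebook usage (illustration only; the paths are placeholders). As
# the constructor docstring notes, each meta json is expected to sit right
# beside its mask folder (val.json next to val/):
#
#     vis = PanVis(
#         img_root='/data/coco/images/val',
#         gt_json_meta_fname='/data/coco/annotations/val.json',
#         pd_json_meta_fname='exp/run1/pred.json',
#     )
#     vis.root_wdgt()  # toggles between Global and Single-Image browsing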
--------------------------------------------------------------------------------
/src/datasets/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | import bisect
3 | import copy
4 | import os.path as osp
5 | import json
6 | from functools import partial
7 | 
8 | import numpy as np
9 | from PIL import Image
10 | import cv2
11 | 
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision.transforms.functional import normalize as tf_norm, to_tensor
15 | 
16 | from .. import cfg as global_cfg
17 | from .augmentations import get_composed_augmentations
18 | from .samplers.grouped_batch_sampler import GroupedBatchSampler
19 | from .gt_producer import (
20 |     ignore_index, convert_pan_m_to_sem_m, GTGenerator
21 | )
22 | from panopticapi.utils import rgb2id
23 | 
24 | from fabric.utils.logging import setup_logging
25 | logger = setup_logging(__file__)
26 | 
27 | '''
28 | Expect a data root directory with the following layout (only coco sub-tree is
29 | expanded for simplicity)
30 | .
31 | ├── coco
32 | │   ├── annotations
33 | │   │   ├── train/
34 | │   │   ├── train.json
35 | │   │   ├── val/
36 | │   │   └── val.json
37 | │   └── images
38 | │       ├── train/
39 | │       └── val/
40 | ├── mapillary/
41 | └── cityscapes/
42 | '''
43 | 
44 | 
45 | '''
46 | A reminder of Panoptic meta json structure: (only the relevant fields are listed)
47 | 
48 | info/
49 | licenses/
50 | categories:
51 |   - id: 1
52 |     name: person
53 |     supercategory: person
54 |     isthing: 1
55 |   - id: 191
56 |     name: pavement-merged
57 |     supercategory: ground
58 |     isthing: 0
59 | images:
60 |   - id: 397133
61 |     file_name: 000000397133.jpg
62 |     height: 427
63 |     width: 640
64 |   - id: 397134
65 |     file_name: 000000397134.jpg
66 |     height: 422
67 |     width: 650
68 | annotations:
69 |   - image_id: 139
70 |     file_name: 000000000139.png
71 |     segments_info:
72 |       - id: 3226956
73 |         category_id: 1
74 |         iscrowd: 0,
75 |         bbox: [413, 158, 53, 138] in xywh form,
76 |         area: 2840
77 |       - repeat omitted
78 | '''
79 | 
80 | imagenet_mean = [0.485, 0.456, 0.406]
81 | imagenet_std = [0.229, 0.224, 0.225]
82 | 
83 | 
84 | def imagenet_normalize(im):
85 |     """This only operates on a single image"""
86 |     if im.shape[0] == 1:
87 |         return im  # deal with these gray channel images in coco later. Hell
88 |     return tf_norm(im, imagenet_mean, imagenet_std)
89 | 
90 | 
91 | def caffe_imagenet_normalize(im):
92 |     im = im * 255.0
93 |     im = tf_norm(im, (102.9801, 115.9465, 122.7717), (1.0, 1.0, 1.0))
94 |     return im
95 | 
96 | 
97 | def check_and_tuplize_splits(splits):
98 |     if not isinstance(splits, (tuple, list)):
99 |         splits = (splits, )
100 |     for split in splits:
101 |         assert split in ('train', 'val', 'test')
102 |     return splits
103 | 
104 | 
105 | def mapify_iterable(iter_of_dict, field_name):
106 |     """Convert an iterable of dicts into a big dict indexed by chosen field
107 |     I can't think of a better name. 'Tis catchy.
108 | """ 109 | acc = dict() 110 | for item in iter_of_dict: 111 | acc[item[field_name]] = item 112 | return acc 113 | 114 | 115 | def ttic_find_data_root(dset_name): 116 | '''Find the fastest data root on TTIC slurm cluster''' 117 | default = osp.join('/share/data/vision-greg/panoptic', dset_name) 118 | return default 119 | fast = osp.join('/vscratch/vision/panoptic', dset_name) 120 | return fast if osp.isdir(fast) else default 121 | 122 | 123 | def test_meta_conform(rmeta): 124 | '''the metadata for test set does not conform to panoptic format; 125 | Test ann only has 'images' and 'categories; let's fill in annotations' 126 | ''' 127 | images = rmeta['images'] 128 | anns = [] 129 | for img in images: 130 | _curr_ann = { 131 | 'image_id': img['id'], 132 | 'segments_info': [] 133 | # do not fill in file_name 134 | } 135 | anns.append(_curr_ann) 136 | rmeta['annotations'] = anns 137 | return rmeta 138 | 139 | 140 | class PanopticBase(Dataset): 141 | def __init__(self, name, split): 142 | available_dsets = ('coco', 'cityscapes', 'mapillary') 143 | assert name in available_dsets, '{} dset is not available'.format(name) 144 | root = ttic_find_data_root(name) 145 | logger.info('using data root {}'.format(root)) 146 | self.root = root 147 | self.name = name 148 | self.split = split 149 | self.img_root = osp.join(root, 'images', split) 150 | self.ann_root = osp.join(root, 'annotations', split) 151 | meta_fname = osp.join(root, 'annotations', '{}.json'.format(split)) 152 | with open(meta_fname) as f: 153 | rmeta = json.load(f) # rmeta stands for raw metadata 154 | if self.split.startswith('test'): 155 | rmeta = test_meta_conform(rmeta) 156 | 157 | # store category metadata 158 | self.meta = dict() 159 | self.meta['cats'] = mapify_iterable(rmeta['categories'], 'id') 160 | self.meta['cat_IdToName'] = dict() 161 | self.meta['cat_NameToId'] = dict() 162 | for cat in rmeta['categories']: 163 | id, name = cat['id'], cat['name'] 164 | self.meta['cat_IdToName'][id] = name 165 | self.meta['cat_NameToId'][name] = id 166 | 167 | # store image and annotations metadata 168 | self.imgs = mapify_iterable(rmeta['images'], 'id') 169 | self.imgToAnns = mapify_iterable(rmeta['annotations'], 'image_id') 170 | self.imgIds = list(sorted(self.imgs.keys())) 171 | 172 | def confine_to_subset(self, imgIds): 173 | '''confine data loading to a subset of images 174 | This is used for figure making. 
175 |         '''
176 |         # confirm that the supplied imgIds are all valid
177 |         for supplied_id in imgIds:
178 |             assert supplied_id in self.imgIds
179 |         self.imgIds = imgIds
180 | 
181 |     def test_seek_imgs(self, i, total_splits):
182 |         assert isinstance(total_splits, int)
183 |         length = len(self.imgIds)
184 |         portion_size = int(math.ceil(length * 1.0 / total_splits))
185 |         start = i * portion_size
186 |         end = min(length, (i + 1) * portion_size)
187 |         self.imgIds = self.imgIds[start:end]
188 |         acc = {k: self.imgs[k] for k in self.imgIds}
189 |         self.imgs = acc
190 | 
191 |     def __len__(self):
192 |         return len(self.imgIds)
193 | 
194 |     def read_img(self, img_fname):
195 |         # there are some gray scale images in coco; convert to RGB
196 |         return Image.open(img_fname).convert('RGB')
197 | 
198 |     def get_meta(self, index):
199 |         imgId = self.imgIds[index]
200 |         imgMeta = self.imgs[imgId]
201 |         anns = self.imgToAnns[imgId]
202 |         return imgMeta, anns
203 | 
204 |     def __getitem__(self, index):
205 |         imgMeta, anns = self.get_meta(index)
206 |         img_fname = osp.join(self.img_root, imgMeta['file_name'])
207 |         img = self.read_img(img_fname)
208 |         if self.split.startswith('test'):
209 |             mask = Image.fromarray(
210 |                 np.zeros(np.array(img).shape, dtype=np.uint8)
211 |             )
212 |         else:
213 |             mask = Image.open(osp.join(self.ann_root, anns['file_name']))
214 |         segments_info = mapify_iterable(anns['segments_info'], 'id')
215 |         return imgMeta, segments_info, img, mask
216 | 
217 | 
218 | class SemanticSeg(PanopticBase):
219 |     def __init__(self, name, split, transforms):
220 |         super().__init__(name=name, split=split)
221 |         self.transforms = transforms
222 |         # produce train cat index id starting at 0
223 |         self.meta['catId_2_trainId'] = dict()
224 |         self.meta['trainId_2_catId'] = dict()
225 |         self.meta['trainId_2_catName'] = dict()  # all things grouped into "things"
226 |         self.meta['trainId_2_catName'][ignore_index] = 'ignore'
227 |         self.prep_trainId()
228 |         self.meta['num_classes'] = len(self.meta['trainId_2_catName']) - 1
229 | 
230 |     def prep_trainId(self):
231 |         curr_inx = 0
232 |         for catId, cat in self.meta['cats'].items():
233 |             self.meta['catId_2_trainId'][catId] = curr_inx
234 |             self.meta['trainId_2_catId'][curr_inx] = catId
235 |             self.meta['trainId_2_catName'][curr_inx] = cat['name']
236 |             curr_inx += 1
237 | 
238 |     def __getitem__(self, index):
239 |         raise ValueError('disabling data loading through this class for now')
240 |         _, segments_info, im, mask = super().__getitem__(index)
241 |         mask = np.array(mask)
242 |         mask = rgb2id(np.array(mask))
243 |         mask = convert_pan_m_to_sem_m(
244 |             mask, segments_info, self.meta['catId_2_trainId'])
245 |         mask = Image.fromarray(mask, mode='I')
246 |         if self.transforms is not None:
247 |             im, mask = self.transforms(im, mask)
248 |         im, mask = to_tensor(im), to_tensor(mask).squeeze(dim=0).long()
249 |         im = imagenet_normalize(im)
250 |         return im, mask
251 | 
252 | 
253 | class PanopticSeg(SemanticSeg):
254 |     def __init__(
255 |         self, name, split, transforms, pcv, gt_producers,
256 |         caffe_mode=False, tensorize=True
257 |     ):
258 |         super().__init__(name=name, split=split, transforms=transforms)
259 |         self.pcv = pcv
260 |         self.meta['stuff_pred_thresh'] = -1
261 |         self.gt_producer_cfgs = gt_producers
262 |         self.tensorize = tensorize
263 |         self.caffe_mode = caffe_mode
264 |         self.gt_prod_handle = partial(GTGenerator, producer_cfgs=gt_producers)
265 | 
266 |     def read_img(self, img_fname):
267 |         if self.caffe_mode:
268 |             # cv2 reads imgs in BGR, which is what caffe trained models expect
269 |             img = 
cv2.imread(img_fname)  # cv2 auto converts gray to BGR
270 |             img = Image.fromarray(img)
271 |         else:
272 |             img = Image.open(img_fname).convert('RGB')
273 |         return img
274 | 
275 |     def pan_getitem(self, index, apply_trans=True):
276 |         # this is now exposed as public API
277 |         imgMeta, segments_info, img, mask = PanopticBase.__getitem__(self, index)
278 |         if apply_trans and self.transforms is not None:
279 |             img, mask = self.transforms(img, mask)
280 |         return imgMeta, segments_info, img, mask
281 | 
282 |     def __getitem__(self, index):
283 |         _, segments_info, im, pan_mask = self.pan_getitem(index)
284 |         gts = []
285 |         if not self.split.startswith('test'):
286 |             lo_pan_mask = pan_mask.resize(
287 |                 np.array(im.size, dtype=int) // 4, resample=Image.NEAREST
288 |             )
289 |             gts = self.gt_prod_handle(
290 |                 self.meta, self.pcv, lo_pan_mask, segments_info
291 |             ).generate_gt()
292 | 
293 |             if self.split == 'train':
294 |                 sem_gt = gts[0]
295 |             else:
296 |                 hi_sem_gt = self.gt_prod_handle(
297 |                     self.meta, self.pcv, pan_mask, segments_info,
298 |                 ).sem_gt
299 |                 sem_gt = hi_sem_gt
300 |             gts[0] = sem_gt
301 |         # else for test/test-dev, do not produce ground truth at all
302 | 
303 |         if self.tensorize:
304 |             im = to_tensor(im)
305 |             gts = [torch.as_tensor(elem) for elem in gts]
306 |             if self.caffe_mode:
307 |                 im = caffe_imagenet_normalize(im)
308 |             else:
309 |                 im = imagenet_normalize(im)
310 |         else:
311 |             im = np.array(im)
312 | 
313 |         gts.insert(0, im)
314 |         return tuple(gts)
315 | 
316 |     @classmethod
317 |     def make_loader(
318 |         cls, data_cfg, pcv_module, is_train, mp_distributed, world_size,
319 |         val_split='val'
320 |     ):
321 |         if is_train:
322 |             split = 'train'
323 |             batch_size = data_cfg.train_batch_size
324 |             transforms_cfg = data_cfg.train_transforms
325 |         else:
326 |             split = val_split
327 |             batch_size = data_cfg.test_batch_size
328 |             transforms_cfg = data_cfg.test_transforms
329 | 
330 |         num_workers = data_cfg.num_loading_threads
331 |         if mp_distributed:
332 |             num_workers = int((num_workers + world_size - 1) / world_size)
333 |             if is_train:
334 |                 batch_size = int(batch_size / world_size)
335 |             # at test time a model does not need to reduce its batch size
336 | 
337 |         # 1. dataset
338 |         trans = get_composed_augmentations(transforms_cfg)
339 |         instance = cls(
340 |             split=split, transforms=trans, pcv=pcv_module,
341 |             gt_producers=data_cfg.dataset.gt_producers,
342 |             **data_cfg.dataset.params,
343 |         )
344 |         # if split.startswith('test'):
345 |         #     inx = global_cfg.testing.inx
346 |         #     total_splits = global_cfg.testing.portions
347 |         #     instance.test_seek_imgs(inx, total_splits)
348 | 
349 |         # 2. sampler
350 |         sampler = cls.make_sampler(instance, is_train, mp_distributed)
351 | 
352 |         # 3. batch sampler
353 |         batch_sampler = cls.make_batch_sampler(
354 |             instance, sampler, batch_size=batch_size,
355 |             aspect_grouping=cls.aspect_grouping,
356 |         )
357 |         del sampler
358 | 
359 |         # 4. collator
360 |         if split.startswith('test'):
361 |             collator = None
362 |         else:
363 |             collator = BatchCollator(data_cfg.dataset.gt_producers)
364 | 
365 |         # 5. 
loader 366 | loader = DataLoader( 367 | instance, 368 | num_workers=num_workers, 369 | batch_sampler=batch_sampler, 370 | collate_fn=collator, 371 | # pin_memory=True maskrcnn-benchmark does not pin memory 372 | ) 373 | return loader 374 | 375 | @staticmethod 376 | def make_sampler(dataset, is_train, distributed): 377 | if is_train: 378 | if distributed: 379 | # as of pytorch 1.1.0 the distributed sampler always shuffles 380 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 381 | else: 382 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 383 | else: 384 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 385 | return sampler 386 | 387 | @staticmethod 388 | def make_batch_sampler( 389 | dataset, sampler, batch_size, aspect_grouping, 390 | ): 391 | if aspect_grouping: 392 | aspect_ratios = _compute_aspect_ratios(dataset) 393 | group_ids = _quantize(aspect_ratios, bins=[1, ]) 394 | batch_sampler = GroupedBatchSampler( 395 | sampler, group_ids, batch_size, drop_uneven=False 396 | ) 397 | else: 398 | batch_sampler = torch.utils.data.sampler.BatchSampler( 399 | sampler, batch_size, drop_last=False 400 | ) 401 | ''' 402 | I've decided after much deliberation not to use iteration based training. 403 | Our cluster has a 4-hour time limit before job interrupt. 404 | Under this constraint, I have to resume from where the sampling stopped 405 | at the exact epoch the interrupt occurs and this requires checkpointing 406 | the sampler state. 407 | However, under distributed settings, each process has to checkpoint a 408 | different state since each sees only a portion of the data by virtue of the 409 | distributed sampler. This is bug-prone and brittle. 410 | ''' 411 | # if num_iters is not None: 412 | # batch_sampler = samplers.IterationBasedBatchSampler( 413 | # batch_sampler, num_iters, start_iter 414 | # ) 415 | return batch_sampler 416 | 417 | 418 | class BatchCollator(object): 419 | def __init__(self, producer_cfg): 420 | fills = [0, ignore_index, ] # always 0 for img, ignore for sem_gt 421 | fills.extend( 422 | [cfg['params']['fill'] for cfg in producer_cfg] 423 | ) 424 | self.fills = fills 425 | 426 | def __call__(self, batch): 427 | transposed_batch = list(zip(*batch)) 428 | assert len(self.fills) == len(transposed_batch), 'must match in length' 429 | acc = [] 430 | for tsr_list, fill in zip(transposed_batch, self.fills): 431 | tsr = self.collate_tensor_list(tsr_list, fill=fill) 432 | acc.append(tsr) 433 | return tuple(acc) 434 | 435 | @staticmethod 436 | def collate_tensor_list(tensors, fill): 437 | """ 438 | Pad the Tensors with the specified constant 439 | so that they have the same shape 440 | """ 441 | assert isinstance(tensors, (tuple, list)) 442 | # largest size along each dimension 443 | max_size = tuple(max(s) for s in zip(*[tsr.shape for tsr in tensors])) 444 | assert len(max_size) == 2 or len(max_size) == 3 445 | batch_shape = (len(tensors),) + max_size 446 | batched_tsr = tensors[0].new(*batch_shape).fill_(fill) 447 | for tsr, pad_tsr in zip(tensors, batched_tsr): 448 | # WARNING only pad the last 2, that is the spatial dimensions 449 | pad_tsr[..., :tsr.shape[-2], :tsr.shape[-1]].copy_(tsr) 450 | return batched_tsr 451 | 452 | 453 | def _quantize(x, bins): 454 | bins = copy.copy(bins) 455 | bins = sorted(bins) 456 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 457 | return quantized 458 | 459 | 460 | def _compute_aspect_ratios(dataset): 461 | aspect_ratios = [] 462 | for i in range(len(dataset)): 463 | imgMeta, _ = 
dataset.get_meta(i)
464 |         aspect_ratio = float(imgMeta["height"]) / float(imgMeta["width"])
465 |         aspect_ratios.append(aspect_ratio)
466 |     return aspect_ratios
467 | 
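# How aspect grouping uses the two helpers above (illustration only): with
# bins=[1], _quantize buckets each height/width ratio with bisect_right, so
# landscape images (ratio < 1) land in group 0 and portrait ones in group 1.
# GroupedBatchSampler then batches only like-shaped images together, which
# keeps the padding done by BatchCollator small.
#
#     >>> _quantize([0.75, 1.0, 1.33], bins=[1])
#     [0, 1, 1]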
--------------------------------------------------------------------------------
/src/datasets/gt_producer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from panopticapi.utils import rgb2id
4 | 
5 | from fabric.io import save_object
6 | 
7 | ignore_index = -1
8 | 
9 | 
10 | def convert_pan_m_to_sem_m(pan_mask, segments_info, catId_2_trainId):
11 |     """Convert a panoptic mask to semantic segmentation mask
12 |     pan_mask: [H, W, 3] panoptic mask
13 |     segments_info: dict
14 |     catId_2_trainId: dict
15 |     """
16 |     sem = np.zeros_like(pan_mask, np.int64)  # torch requires long tensor
17 |     iid_to_catid = {
18 |         el['id']: el['category_id'] for el in segments_info.values()
19 |     }
20 |     for iid in np.unique(pan_mask):
21 |         if iid not in iid_to_catid:
22 |             assert iid == 0
23 |             sem[pan_mask == 0] = ignore_index
24 |             continue
25 |         cat_id = iid_to_catid[iid]
26 |         train_id = catId_2_trainId[cat_id]
27 |         sem[pan_mask == iid] = train_id
28 |     return sem
29 | 
30 | 
31 | def tensorize_2d_spatial_assignement(spatial, num_channels):
32 |     """
33 |     Args:
34 |         spatial: [H, W] where -1 encodes invalid region
35 |         Here we produce an extra, last channel to be set by inx -1,
36 |         only to be thrown out later. Neat
37 |     Ret:
38 |         np array of shape [1, C, H, W]
39 |     """
40 |     H, W = spatial.shape
41 |     num_channels += 1  # add 1 extra channel to be turned on by inx -1
42 |     tsr = np.zeros(shape=(num_channels, H, W))
43 |     dim_0_inds, dim_1_inds = np.ix_(range(H), range(W))
44 |     tsr[spatial, dim_0_inds, dim_1_inds] = 1
45 |     tsr = tsr[:-1, :, :]  # throw out the extra dimension for -1s
46 |     tsr = tsr[np.newaxis, ...]
47 |     return tsr
48 | 
49 | 
50 | class GtProducer():
51 |     def __init__(self, pcv, mask_shape, params):
52 |         self.pcv = pcv
53 |         self.mask_shape = mask_shape
54 |         self.params = params
55 |         self.interpretable_as_prob_tsr = False
56 | 
57 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
58 |         raise NotImplementedError()
59 | 
60 |     def produce(self):
61 |         raise NotImplementedError()
62 | 
63 |     def convert_to_prob_tsr(self):
64 |         raise NotImplementedError()
65 | 
66 | 
67 | class WeightMask(GtProducer):
68 |     def __init__(self, *args, **kwargs):
69 |         super().__init__(*args, **kwargs)
70 |         self.weight = np.zeros(self.mask_shape, dtype=np.float32)
71 |         self.power = self.params.get('power', 0.5)
72 | 
73 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
74 |         w = 1 / (segment_area ** self.power)
75 |         if isthing and not iscrowd:
76 |             self.weight[segment_mask] = w
77 |         elif not isthing:
78 |             self.weight[segment_mask] = w
79 |         elif iscrowd:
80 |             pass
81 |         else:
82 |             raise ValueError('unreachable')
83 | 
84 |     def produce(self):
85 |         return self.weight
86 | 
87 | 
88 | class _WeightMask(GtProducer):
89 |     def __init__(self, *args, **kwargs):
90 |         super().__init__(*args, **kwargs)
91 |         self.weight = np.zeros(self.mask_shape, dtype=np.float32)
92 |         self.power = self.params.get('power', 0.5)
93 | 
94 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
95 |         w = 1 / (segment_area ** self.power)
96 |         if isthing and not iscrowd:
97 |             lbls = self.pcv.discrete_vote_inx_from_offset(ins_offset)
98 |             # pixs within an instance may be ignored if they sit on grid boundaries
99 |             is_valid = lbls > ignore_index
100 |             valid_area = is_valid.sum()
101 |             # assert valid_area > 0, 'segment area {} vs valid area {}, lbls are {}'.format(
102 |             #     segment_area, valid_area, lbls
103 |             # )
104 |             if valid_area == 0:
105 |                 w = 0
106 |             else:
107 |                 w = 1 / (valid_area ** self.power)
108 |             self.weight[segment_mask] = w * is_valid
109 |         elif not isthing:
110 |             self.weight[segment_mask] = w
111 |         elif iscrowd:
112 |             pass
113 |         else:
114 |             raise ValueError('unreachable')
115 | 
116 |     def produce(self):
117 |         return self.weight
118 | 
119 | 
120 | class ThingWeightMask(GtProducer):
121 |     def __init__(self, *args, **kwargs):
122 |         super().__init__(*args, **kwargs)
123 |         self.weight = np.zeros(self.mask_shape, dtype=np.float32)
124 |         self.power = self.params.get('power', 0.5)
125 | 
126 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
127 |         w = 1 / (segment_area ** self.power)
128 |         if isthing and not iscrowd:
129 |             lbls = self.pcv.discrete_vote_inx_from_offset(ins_offset)
130 |             # pixs within an instance may be ignored if they sit on grid boundaries
131 |             is_valid = lbls > ignore_index
132 |             valid_area = is_valid.sum()
133 |             # assert valid_area > 0, 'segment area {} vs valid area {}, lbls are {}'.format(
134 |             #     segment_area, valid_area, lbls
135 |             # )
136 |             if valid_area == 0:
137 |                 w = 0
138 |             else:
139 |                 w = 1 / (valid_area ** self.power)
140 |             self.weight[segment_mask] = w * is_valid
141 |         elif not isthing:
142 |             self.weight[segment_mask] = 0
143 |         elif iscrowd:
144 |             pass
145 |         else:
146 |             raise ValueError('unreachable')
147 | 
148 |     def produce(self):
149 |         return self.weight
150 | 
151 | 
152 | class NoAbstainWeightMask(GtProducer):
153 |     def __init__(self, *args, **kwargs):
154 |         super().__init__(*args, **kwargs)
155 |         self.weight = np.zeros(self.mask_shape, dtype=np.float32)
156 |         self.power = self.params.get('power', 0.5)
157 | 
158 |     def 
process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset): 159 | w = 1 / (segment_area ** self.power) 160 | if isthing and not iscrowd: 161 | self.weight[segment_mask] = w 162 | elif not isthing: # only difference compared to WeightMask 163 | pass 164 | elif iscrowd: 165 | pass 166 | else: 167 | raise ValueError('unreachable') 168 | 169 | def produce(self): 170 | return self.weight 171 | 172 | 173 | class InsWeightMask(GtProducer): 174 | def __init__(self, *args, **kwargs): 175 | super().__init__(*args, **kwargs) 176 | self.weight = np.zeros(self.mask_shape, dtype=np.float32) 177 | self.stuff_mask = np.zeros(self.mask_shape, dtype=bool) 178 | 179 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset): 180 | if isthing and not iscrowd: 181 | self.weight[segment_mask] = 1 / segment_area 182 | elif not isthing: 183 | self.stuff_mask[segment_mask] = True 184 | elif iscrowd: 185 | pass 186 | else: 187 | raise ValueError('unreachable') 188 | 189 | def produce(self): 190 | stuff_area = self.stuff_mask.sum() 191 | stuff_w = 20 / stuff_area 192 | self.weight[self.stuff_mask] = stuff_w 193 | self.weight = self.weight / 40 194 | return self.weight 195 | 196 | 197 | class PCV_vote_no_abstain_mask(GtProducer): 198 | def __init__(self, *args, **kwargs): 199 | super().__init__(*args, **kwargs) 200 | self.interpretable_as_prob_tsr = True 201 | self.gt_mask = ignore_index * np.ones(self.mask_shape, np.int32) 202 | self.votable_mask = np.zeros(self.mask_shape, dtype=bool) 203 | self.offset_mask = np.empty( 204 | list(self.mask_shape) + [2], dtype=np.int32 205 | ) 206 | pcv = self.pcv 207 | assert (pcv.num_votes - pcv.num_bins) == 1 208 | self.ABSTAIN_INX = ignore_index 209 | 210 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset): 211 | if isthing and not iscrowd: 212 | self.votable_mask[segment_mask] = True 213 | self.offset_mask[segment_mask] = ins_offset 214 | elif not isthing: 215 | self.gt_mask[segment_mask] = self.ABSTAIN_INX 216 | elif iscrowd: 217 | pass 218 | else: 219 | raise ValueError('unreachable') 220 | 221 | def produce(self): 222 | votable_mask = self.votable_mask 223 | self.gt_mask[votable_mask] = self.pcv.discrete_vote_inx_from_offset( 224 | self.offset_mask[votable_mask] 225 | ) 226 | self.gt_mask = self.gt_mask.astype(np.int64) 227 | return self.gt_mask 228 | 229 | def convert_to_prob_tsr(self): 230 | return tensorize_2d_spatial_assignement(self.gt_mask, self.pcv.num_votes) 231 | 232 | 233 | class PCV_vote_mask(GtProducer): 234 | def __init__(self, *args, **kwargs): 235 | super().__init__(*args, **kwargs) 236 | self.interpretable_as_prob_tsr = True 237 | self.gt_mask = ignore_index * np.ones(self.mask_shape, np.int32) 238 | self.votable_mask = np.zeros(self.mask_shape, dtype=bool) 239 | self.offset_mask = np.empty( 240 | list(self.mask_shape) + [2], dtype=np.int32 241 | ) 242 | pcv = self.pcv 243 | assert (pcv.num_votes - pcv.num_bins) == 1 244 | self.ABSTAIN_INX = pcv.num_bins 245 | 246 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset): 247 | if isthing and not iscrowd: 248 | self.votable_mask[segment_mask] = True 249 | self.offset_mask[segment_mask] = ins_offset 250 | elif not isthing: 251 | self.gt_mask[segment_mask] = self.ABSTAIN_INX 252 | elif iscrowd: 253 | pass 254 | else: 255 | raise ValueError('unreachable') 256 | 257 | def produce(self): 258 | votable_mask = self.votable_mask 259 | self.gt_mask[votable_mask] = self.pcv.discrete_vote_inx_from_offset( 260 | self.offset_mask[votable_mask] 
261 |         )
262 |         self.gt_mask = self.gt_mask.astype(np.int64)
263 |         return self.gt_mask
264 | 
265 |     def convert_to_prob_tsr(self):
266 |         return tensorize_2d_spatial_assignement(self.gt_mask, self.pcv.num_votes)
267 | 
268 | 
269 | class PCV_igc_tsr(GtProducer):
270 |     '''this code is problematic, but at least it is well exposed now'''
271 |     def __init__(self, *args, **kwargs):
272 |         super().__init__(*args, **kwargs)
273 |         self.interpretable_as_prob_tsr = True
274 |         self.tsr = np.zeros(
275 |             list(self.mask_shape) + [self.pcv.num_votes], dtype=bool
276 |         )
277 | 
278 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
279 |         if isthing and not iscrowd:
280 |             self.tsr[segment_mask] = self.pcv.tensorized_vote_from_offset(
281 |                 ins_offset
282 |             )
283 |         elif not isthing:
284 |             self.tsr[segment_mask, -1] = True  # abstaining
285 |         elif iscrowd:
286 |             pass
287 |         else:
288 |             raise ValueError('unreachable')
289 | 
290 |     def produce(self):
291 |         tsr = torch.as_tensor(
292 |             self.tsr.transpose(2, 0, 1)  # [C, H, W]
293 |         ).contiguous()
294 |         self.tsr = tsr
295 |         return tsr
296 | 
297 |     def convert_to_prob_tsr(self):
298 |         tsr = self.tsr
299 |         vote_tsr = tsr.clone().float()
300 |         base = vote_tsr.sum(dim=0, keepdim=True)
301 |         base[base == 0] = 1  # any number
302 |         vote_tsr = (vote_tsr / base).unsqueeze(dim=0).numpy()
303 |         return vote_tsr
304 | 
305 | 
306 | class PCV_smooth_tsr(GtProducer):
307 |     def __init__(self, *args, **kwargs):
308 |         super().__init__(*args, **kwargs)
309 |         self.interpretable_as_prob_tsr = True
310 |         self.votable_mask = np.zeros(self.mask_shape, dtype=bool)
311 |         self.abstain_mask = np.zeros(self.mask_shape, dtype=bool)
312 |         self.offset_mask = np.empty(
313 |             list(self.mask_shape) + [2], dtype=np.int32
314 |         )
315 |         self.tsr = np.zeros(
316 |             list(self.mask_shape) + [self.pcv.num_votes], dtype=np.float32
317 |         )
318 |         pcv = self.pcv
319 |         assert (pcv.num_votes - pcv.num_bins) == 1
320 | 
321 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
322 |         if isthing and not iscrowd:
323 |             self.votable_mask[segment_mask] = True
324 |             self.offset_mask[segment_mask] = ins_offset
325 |         elif not isthing:
326 |             self.abstain_mask[segment_mask] = True
327 |         elif iscrowd:
328 |             pass
329 |         else:
330 |             raise ValueError('unreachable')
331 | 
332 |     def produce(self):
333 |         self.tsr[self.abstain_mask, -1] = 1.0
334 |         self.tsr[self.votable_mask] = self.pcv.smooth_prob_tsr_from_offset(
335 |             self.offset_mask[self.votable_mask]
336 |         )
337 |         self.tsr = torch.as_tensor(
338 |             self.tsr.transpose(2, 0, 1)  # [C, H, W]
339 |         ).contiguous()
340 |         return self.tsr
341 | 
342 |     def convert_to_prob_tsr(self):
343 |         tsr = self.tsr
344 |         vote_tsr = tsr.clone().unsqueeze(dim=0).numpy()
345 |         return vote_tsr
346 | 
347 | 
348 | _REGISTRY = {
349 |     'weight_mask': WeightMask,
350 |     'thing_weight_mask': ThingWeightMask,
351 |     'vote_mask': PCV_vote_mask,
352 |     'vote_no_abstain_mask': PCV_vote_no_abstain_mask,
353 |     'igc_tsr': PCV_igc_tsr,
354 |     'smth_tsr': PCV_smooth_tsr,
355 |     'ins_weight_mask': InsWeightMask,
356 |     'no_abstain_weight_mask': NoAbstainWeightMask
357 | }
358 | 
359 | 
360 | class GTGenerator():
361 |     def __init__(self, meta, pcv, pan_mask, segments_info, producer_cfgs=()):
362 |         assert len(meta['catId_2_trainId']) == len(meta['cats'])
363 |         self.pcv = pcv
364 |         self.meta = meta
365 |         self.pan_mask = rgb2id(np.array(pan_mask))
366 |         self.segments_info = segments_info
367 | 
368 |         am_I_gt_producer = [
369 |             int(cfg['params'].get('is_vote_gt', False)) for cfg in 
269 | class PCV_igc_tsr(GtProducer):
270 |     '''this code is problematic, but at least it is well exposed now'''
271 |     def __init__(self, *args, **kwargs):
272 |         super().__init__(*args, **kwargs)
273 |         self.interpretable_as_prob_tsr = True
274 |         self.tsr = np.zeros(
275 |             list(self.mask_shape) + [self.pcv.num_votes], dtype=bool
276 |         )
277 |
278 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
279 |         if isthing and not iscrowd:
280 |             self.tsr[segment_mask] = self.pcv.tensorized_vote_from_offset(
281 |                 ins_offset
282 |             )
283 |         elif not isthing:
284 |             self.tsr[segment_mask, -1] = True  # abstaining
285 |         elif iscrowd:
286 |             pass
287 |         else:
288 |             raise ValueError('unreachable')
289 |
290 |     def produce(self):
291 |         tsr = torch.as_tensor(
292 |             self.tsr.transpose(2, 0, 1)  # [C, H, W]
293 |         ).contiguous()
294 |         self.tsr = tsr
295 |         return tsr
296 |
297 |     def convert_to_prob_tsr(self):
298 |         tsr = self.tsr
299 |         vote_tsr = tsr.clone().float()
300 |         base = vote_tsr.sum(dim=0, keepdim=True)
301 |         base[base == 0] = 1  # avoid division by zero where no votes are cast
302 |         vote_tsr = (vote_tsr / base).unsqueeze(dim=0).numpy()
303 |         return vote_tsr
304 |
305 |
306 | class PCV_smooth_tsr(GtProducer):
307 |     def __init__(self, *args, **kwargs):
308 |         super().__init__(*args, **kwargs)
309 |         self.interpretable_as_prob_tsr = True
310 |         self.votable_mask = np.zeros(self.mask_shape, dtype=bool)
311 |         self.abstain_mask = np.zeros(self.mask_shape, dtype=bool)
312 |         self.offset_mask = np.empty(
313 |             list(self.mask_shape) + [2], dtype=np.int32
314 |         )
315 |         self.tsr = np.zeros(
316 |             list(self.mask_shape) + [self.pcv.num_votes], dtype=np.float32
317 |         )
318 |         pcv = self.pcv
319 |         assert (pcv.num_votes - pcv.num_bins) == 1
320 |
321 |     def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
322 |         if isthing and not iscrowd:
323 |             self.votable_mask[segment_mask] = True
324 |             self.offset_mask[segment_mask] = ins_offset
325 |         elif not isthing:
326 |             self.abstain_mask[segment_mask] = True
327 |         elif iscrowd:
328 |             pass
329 |         else:
330 |             raise ValueError('unreachable')
331 |
332 |     def produce(self):
333 |         self.tsr[self.abstain_mask, -1] = 1.0
334 |         self.tsr[self.votable_mask] = self.pcv.smooth_prob_tsr_from_offset(
335 |             self.offset_mask[self.votable_mask]
336 |         )
337 |         self.tsr = torch.as_tensor(
338 |             self.tsr.transpose(2, 0, 1)  # [C, H, W]
339 |         ).contiguous()
340 |         return self.tsr
341 |
342 |     def convert_to_prob_tsr(self):
343 |         tsr = self.tsr
344 |         vote_tsr = tsr.clone().unsqueeze(dim=0).numpy()
345 |         return vote_tsr
346 |
347 |
348 | _REGISTRY = {
349 |     'weight_mask': WeightMask,
350 |     'thing_weight_mask': ThingWeightMask,
351 |     'vote_mask': PCV_vote_mask,
352 |     'vote_no_abstain_mask': PCV_vote_no_abstain_mask,
353 |     'igc_tsr': PCV_igc_tsr,
354 |     'smth_tsr': PCV_smooth_tsr,
355 |     'ins_weight_mask': InsWeightMask,
356 |     'no_abstain_weight_mask': NoAbstainWeightMask
357 | }
358 |
359 |
360 | class GTGenerator():
361 |     def __init__(self, meta, pcv, pan_mask, segments_info, producer_cfgs=()):
362 |         assert len(meta['catId_2_trainId']) == len(meta['cats'])
363 |         self.pcv = pcv
364 |         self.meta = meta
365 |         self.pan_mask = rgb2id(np.array(pan_mask))
366 |         self.segments_info = segments_info
367 |
368 |         am_I_gt_producer = [
369 |             int(cfg['params'].get('is_vote_gt', False)) for cfg in producer_cfgs
370 |         ]
371 |         if sum(am_I_gt_producer) != 1:
372 |             raise ValueError('exactly 1 producer should be in charge of vote gt')
373 |         self.gt_producer_inx = np.argmax(am_I_gt_producer)
374 |         self.producer_cfgs = producer_cfgs
375 |         self.producers = None
376 |
377 |         self.ins_centroids = None
378 |         self._sem_gt = None
379 |
380 |     @property
381 |     def sem_gt(self):
382 |         if self._sem_gt is None:
383 |             self._sem_gt = convert_pan_m_to_sem_m(
384 |                 self.pan_mask, self.segments_info, self.meta['catId_2_trainId']
385 |             )
386 |         return self._sem_gt
387 |
388 |     def generate_gt(self):
389 |         """Return a new mask with each pixel containing a discrete numeric label
390 |         indicating which spatial bin the pixel should vote for.
391 |         Args:
392 |             mask: [H, W] array filled with segment ids
393 |             segments_info:
394 |         """
395 |         pcv = self.pcv
396 |         mask = self.pan_mask
397 |         segments_info = self.segments_info
398 |         category_meta = self.meta['cats']
399 |
400 |         self.producers = [
401 |             _REGISTRY[cfg['name']](pcv, mask.shape, cfg['params'])
402 |             for cfg in self.producer_cfgs
403 |         ]  # initialize the producers with params
404 |
405 |         # indices grid of shape [2, H, W], where first dim is y, x; swap them
406 |         # [H, W, 2] where last channel is x, y
407 |         spatial_inds = np.indices(
408 |             mask.shape).transpose(1, 2, 0)[..., ::-1].astype(np.int32)
409 |         ins_centroids = []
410 |
411 |         for segment_id, info in segments_info.items():
412 |             cat, iscrowd = info['category_id'], info['iscrowd']
413 |             isthing = category_meta[cat]['isthing']
414 |             segment_mask = (mask == segment_id)
415 |             area = segment_mask.sum()
416 |             if area == 0:
417 |                 # cropping or extreme resizing might cause segments to disappear
418 |                 continue
419 |
420 |             ins_offset = None
421 |             if isthing and not iscrowd:
422 |                 ins_center = pcv.centroid_from_ins_mask(segment_mask)
423 |                 # ins_center = [math.ceil(_x) for _x in ins_center]
424 |                 ins_offset = (ins_center - spatial_inds[segment_mask]).astype(np.int32)
425 |                 # ins_center = np.array(ins_center, dtype=np.int32)
426 |                 # ins_offset = ins_center - spatial_inds[segment_mask]
427 |                 # TODO: the int32 cast truncates fractional offsets toward zero; investigate
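                # worked illustration of the truncation (editor's addition):
                # with a float centroid at (x=10.5, y=20.25), the pixel at
                # (x=8, y=22) has a raw offset of (2.5, -1.75), which the
                # int32 cast turns into (2, -1) -- truncated toward zero.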
428 |                 ins_centroids.append(ins_center)
429 |
430 |             for actor in self.producers:
431 |                 actor.process(isthing, iscrowd, segment_mask, area, ins_offset)
432 |
433 |         self.ins_centroids = np.array(ins_centroids).reshape(-1, 2)
434 |
435 |         training_gts = [self.sem_gt, ]
436 |         training_gts.extend(
437 |             [ actor.produce() for actor in self.producers ]
438 |         )
439 |         return training_gts
440 |
441 |     def collect_prob_tsr(self):
442 |         assert self.producers is not None, 'must create gt first'
443 |         gt_producer = self.producers[self.gt_producer_inx]
444 |         assert gt_producer.interpretable_as_prob_tsr, \
445 |             '{} should be interpretable as prob tsr'.format(gt_producer.__class__)
446 |
447 |         sem_prob_tsr = tensorize_2d_spatial_assignement(
448 |             self.sem_gt, len(self.meta['catId_2_trainId'])
449 |         )
450 |         vote_prob_tsr = gt_producer.convert_to_prob_tsr()
451 |         return sem_prob_tsr, vote_prob_tsr
452 |
--------------------------------------------------------------------------------
/src/pan_analyzer.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import os.path as osp
3 | from collections import defaultdict
4 | import json
5 |
6 | import numpy as np
7 | import scipy.linalg as LA
8 | from scipy.ndimage import binary_dilation, generate_binary_structure
9 | import pandas as pd
10 | from PIL import Image
11 |
12 | from tabulate import tabulate
13 |
14 | from panopticapi.utils import rgb2id
15 | from panoptic.pan_eval import PQStat, OFFSET, VOID
16 | from panoptic.datasets.base import mapify_iterable
17 |
18 | from fabric.io import load_object, save_object
19 |
20 | _SEGMENT_UNMATCHED = 0
21 | _SEGMENT_MATCHED = 1
22 | _SEGMENT_FORGIVEN = 2
23 |
24 |
25 | def generalized_aspect_ratio(binary_mask):
26 |     xs, ys = np.where(binary_mask)
27 |     coords = np.array([xs, ys]).T
28 |     # mean center the coords
29 |     coords = coords - coords.mean(axis=0)
30 |     # cov matrix
31 |     cov = coords.T @ coords
32 |     first, second = LA.eigvalsh(cov)[::-1]
33 |     ratio = (first ** 0.5) / (second + 1e-8) ** 0.5
34 |     return ratio
35 |
36 |
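# [editor's note, not in the original source] a quick sanity check of
# generalized_aspect_ratio: for an axis-aligned 2 x 8 rectangle of pixels the
# covariance eigenvalues come out around 84 and 4, so the returned value is
# sqrt(84) / sqrt(4) ~= 4.6, i.e. roughly the rectangle's side ratio:
#
#     m = np.zeros((10, 10), dtype=bool)
#     m[4:6, 1:9] = True                # 2 rows x 8 cols
#     generalized_aspect_ratio(m)       # ~= 4.6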
37 | class Annotation():
38 |     '''
39 |     Overall Schema for Each Side (Either Gt or Pred)
40 |     e.g. pred:
41 |         cats: {id: cat}
42 |         imgs: {id: image}
43 |         segs: {sid: seg}
44 |         img2seg: {image_id: {sid: seg}}
45 |         cat2seg: {cat_id: {sid: seg}}
46 |
47 |     cat:
48 |         id: 7,
49 |         name: road,
50 |         supercategory: 'flat',
51 |         color: [128, 64, 128],
52 |         isthing: 0
53 |
54 |     image:
55 |         id: 'frankfurt_000000_005898',
56 |         file_name: frankfurt/frankfurt_000000_005898_leftImg8bit.png
57 |         ann_fname: abcde
58 |         width: 2048,
59 |         height: 1024,
60 |         ---
61 |         mask: a cached mask that is loaded
62 |
63 |     seg:
64 |         sid (seg id): gt/frankfurt_000000_000294/8405120
65 |         image_id: frankfurt_000000_000294
66 |         id: 8405120
67 |         category_id: 7
68 |         area: 624611
69 |         bbox: [6, 432, 1909, 547]
70 |         iscrowd: 0
71 |         match_state: one of (UNMATCHED, MATCHED, FORGIVEN)
72 |         matchings: {sid: iou} sorted from high to low
73 |         (breakdown_flag): this can be optionally introduced for breakdown analysis
74 |     '''
75 |     def __init__(self, json_meta_fname, is_gt, state_dict=None):
76 |         '''
77 |         if state_dict is provided, then just load it and avoid the computation
78 |         '''
79 |         dirname, fname = osp.split(json_meta_fname)
80 |         self.root = dirname
81 |         self.mask_root = osp.join(dirname, fname.split('.')[0])
82 |
83 |         if state_dict is None:
84 |             with open(json_meta_fname) as f:
85 |                 state_dict = self.process_raw_meta(json.load(f), is_gt)
86 |         self.state_dict = state_dict
87 |         self.register_state(state_dict)
88 |
89 |     def register_state(self, state):
90 |         for k, v in state.items():
91 |             setattr(self, k, v)
92 |
93 |     @staticmethod
94 |     def process_raw_meta(raw_meta, is_gt):
95 |         state = dict()
96 |         state['cats'] = mapify_iterable(raw_meta['categories'], 'id')
97 |         state['imgs'] = mapify_iterable(raw_meta['images'], 'id')
98 |         state['segs'] = dict()
99 |         state['img2seg'] = defaultdict(dict)
100 |         state['cat2seg'] = defaultdict(dict)
101 |
102 |         sid_prefix = 'gt' if is_gt else 'pd'
103 |
104 |         for ann in raw_meta['annotations']:
105 |             image_id = ann['image_id']
106 |             segments = ann['segments_info']
107 |             state['imgs'][image_id]['ann_fname'] = ann['file_name']
108 |             for seg in segments:
109 |                 cat_id = seg['category_id']
110 |
111 |                 unique_id = '{}/{}/{}'.format(sid_prefix, image_id, seg['id'])
112 |                 seg['sid'] = unique_id
113 |                 seg['image_id'] = image_id
114 |                 seg['match_state'] = _SEGMENT_FORGIVEN
115 |                 seg['matchings'] = dict()
116 |
117 |                 state['segs'][unique_id] = seg
118 |                 state['img2seg'][image_id][unique_id] = seg
119 |                 state['cat2seg'][cat_id][unique_id] = seg
120 |         return state
121 |
122 |     def seg_sort_matchings(self):
123 |         """sort matchings from high to low IoU"""
124 |         for _, seg in self.segs.items():
125 |             matchings = seg['matchings']
126 |             seg['matchings'] = dict(
127 |                 sorted(matchings.items(), key=lambda x: x[1], reverse=True)
128 |             )
129 |
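    # [editor's note, not in the original source] after sorting, a matched
    # seg's 'matchings' might look like this (illustrative values only):
    #     {'pd/frankfurt_000000_000294/8405121': 0.82,
    #      'pd/frankfurt_000000_000294/8405139': 0.14}
    # so list(matchings.values())[0] in match_summarize reads off the best IoU.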
130 |     def match_summarize(self, breakdown_flag=None):
131 |         '''
132 |         ret: [num_cats, 4] where each row contains
133 |             (iou_sum, num_matched, num_unmatched, total_inst)
134 |         '''
135 |         ret = []
136 |         for cat in sorted(self.cats.keys()):
137 |             segs = self.cat2seg[cat].values()
138 |             iou_sum, num_matched, num_unmatched, total_inst = 0.0, 0.0, 0.0, 0.0
139 |             for seg in segs:
140 |                 if breakdown_flag is not None and seg['breakdown_flag'] != breakdown_flag:
141 |                     continue  # if breakdown is activated, only summarize those required
142 |                 total_inst += 1
143 |                 if seg['match_state'] == _SEGMENT_MATCHED:
144 |                     iou_sum += list(seg['matchings'].values())[0]
145 |                     num_matched += 1
146 |                 elif seg['match_state'] == _SEGMENT_UNMATCHED:
147 |                     num_unmatched += 1
148 |             ret.append([iou_sum, num_matched, num_unmatched, total_inst])
149 |         ret = np.array(ret)
150 |         return ret
151 |
152 |     def catId_given_catName(self, catName):
153 |         for catId, cat in self.cats.items():
154 |             if cat['name'] == catName:
155 |                 return catId
156 |         raise ValueError('unknown category name: {}'.format(catName))
157 |
158 |     def get_mask_given_seg(self, seg):
159 |         return self.get_mask_given_imgid(seg['image_id'])
160 |
161 |     def get_img_given_imgid(self, image_id, img_root):
162 |         img = self.imgs[image_id]
163 |         img_fname = img['file_name']
164 |         img_fname = osp.join(img_root, img_fname)
165 |         img = Image.open(img_fname)
166 |         return img
167 |
168 |     def get_mask_given_imgid(self, image_id, store_in_cache=True):
169 |         img = self.imgs[image_id]
170 |         _MASK_KEYNAME = 'mask'
171 |         cache_entry = img.get(_MASK_KEYNAME, None)
172 |         if cache_entry is not None:
173 |             assert isinstance(cache_entry, np.ndarray)
174 |             return cache_entry
175 |         else:
176 |             mask_fname = img['ann_fname']
177 |             mask = np.array(
178 |                 Image.open(osp.join(self.mask_root, mask_fname)),
179 |                 dtype=np.uint32
180 |             )
181 |             mask = rgb2id(mask)
182 |             if store_in_cache:
183 |                 img[_MASK_KEYNAME] = mask
184 |             return mask
185 |
186 |     def compute_seg_shape_oddity(self):
187 |         print('start computing shape oddity')
188 |         i = 0
189 |         for imgId, segs in self.img2seg.items():
190 |             i += 1
191 |             if (i % 50) == 0:
192 |                 print(i)
193 |             mask = self.get_mask_given_imgid(imgId, store_in_cache=False)
194 |             for _, s in segs.items():
195 |                 seg_id = s['id']
196 |                 binary_mask = (mask == seg_id)
197 |                 s['gen_aspect_ratio'] = generalized_aspect_ratio(binary_mask)
198 |
199 |     def compute_seg_boundary_stats(self):
200 |         print('start computing boundary stats')
201 |         i = 0
202 |         for imgId, segs in self.img2seg.items():
203 |             i += 1
204 |             if (i % 50) == 0:
205 |                 print(i)
206 |             mask = self.get_mask_given_imgid(imgId, store_in_cache=False)
207 |             for _, s in segs.items():
208 |                 seg_id = s['id']
209 |                 binary_mask = (mask == seg_id)
210 |                 self._per_seg_neighbors_stats(s, binary_mask, mask)
211 |
212 |     def _per_seg_neighbors_stats(self, seg_dict, binary_mask, mask):
213 |         area = binary_mask.sum()
214 |         # struct = generate_binary_structure(2, 2)
215 |         dilated = binary_dilation(binary_mask, structure=None, iterations=1)
216 |         boundary = dilated ^ binary_mask
217 |
218 |         # stats
219 |         length = boundary.sum()
220 |         ratio = length ** 2 / area
221 |         seg_dict['la_ratio'] = ratio
222 |
223 |         # get the neighbors
224 |         ids, cnts = np.unique(mask[boundary], return_counts=True)
225 |
226 |         sid_prefix = '/'.join(
227 |             seg_dict['sid'].split('/')[:2]  # throw away the last
228 |         )
229 |         sids = [ '{}/{}'.format(sid_prefix, id) for id in ids ]
230 |
231 |         thing_neighbors = {
232 |             sid: cnt for sid, id, cnt in zip(sids, ids, cnts)
233 |             if id > 0 and self.cats[self.segs[sid]['category_id']]['isthing']
234 |         }
235 |         seg_dict['thing_neighbors'] = thing_neighbors
236 |
237 |
238 | class PanopticEvalAnalyzer():
239 |     def __init__(self, gt_json_meta_fname, pd_json_meta_fname, load_state=True):
240 |         # use the pd folder as root directory since a single gt ann can correspond
241 |         # to many pd anns.
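        # [editor's note, not in the original source] evaluation state is
        # cached to analyze_dump.pkl in that folder; delete the file or pass
        # load_state=False to force a fresh _evaluate() pass.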
242 |         root = osp.split(pd_json_meta_fname)[0]
243 |         self.state_dump_fname = osp.join(root, 'analyze_dump.pkl')
244 |
245 |         is_evaluated = False
246 |         if osp.isfile(self.state_dump_fname) and load_state:
247 |             state = load_object(self.state_dump_fname)
248 |             gt_state, pd_state = state['gt'], state['pd']
249 |             is_evaluated = True
250 |         else:
251 |             gt_state, pd_state = None, None
252 |
253 |         self.gt = Annotation(gt_json_meta_fname, is_gt=True, state_dict=gt_state)
254 |         self.pd = Annotation(pd_json_meta_fname, is_gt=False, state_dict=pd_state)
255 |
256 |         # validate that gt and pd json completely match
257 |         assert self.gt.imgs.keys() == self.pd.imgs.keys()
258 |         assert self.gt.cats == self.pd.cats
259 |         self.imgIds = list(sorted(self.gt.imgs.keys()))
260 |
261 |         if not is_evaluated:
262 |             # evaluate and then save the state
263 |             self._evaluate()
264 |             self.gt.compute_seg_shape_oddity()
265 |             self.pd.compute_seg_shape_oddity()
266 |             self.dump_state()
267 |
268 |     def _gt_boundary_stats(self):
269 |         self.gt.compute_seg_boundary_stats()
270 |
271 |     def dump_state(self):
272 |         state = {
273 |             'gt': self.gt.state_dict,
274 |             'pd': self.pd.state_dict
275 |         }
276 |         save_object(state, self.state_dump_fname)
277 |
278 |     def _evaluate(self):
279 |         stats = PQStat()
280 |         cats = self.gt.cats
281 |         for i, imgId in enumerate(self.imgIds):
282 |             if (i % 50) == 0:
283 |                 print("progress {} / {}".format(i, len(self.imgIds)))
284 |             # if (i > 100):
285 |             #     break
286 |
287 |             gt_ann = {
288 |                 'image_id': imgId, 'segments_info': self.gt.img2seg[imgId].values()
289 |             }
290 |             gt_mask = np.array(
291 |                 Image.open(osp.join(
292 |                     self.gt.mask_root, self.gt.imgs[imgId]['ann_fname']
293 |                 )),
294 |                 dtype=np.uint32
295 |             )
296 |             gt_mask = rgb2id(gt_mask)
297 |
298 |             pd_ann = {
299 |                 'image_id': imgId, 'segments_info': self.pd.img2seg[imgId].values()
300 |             }
301 |             pd_mask = np.array(
302 |                 Image.open(osp.join(
303 |                     self.pd.mask_root, self.pd.imgs[imgId]['ann_fname']
304 |                 )),
305 |                 dtype=np.uint32
306 |             )
307 |             pd_mask = rgb2id(pd_mask)
308 |
309 |             _single_stat = self.pq_compute_single_img(
310 |                 cats, gt_ann, gt_mask, pd_ann, pd_mask
311 |             )
312 |             stats += _single_stat
313 |
314 |         self.gt.seg_sort_matchings()
315 |         self.pd.seg_sort_matchings()
316 |         return stats
317 |
318 |     def summarize(self, flag=None):
319 |         per_cat_res, overall_table, cat_table = self._aggregate(
320 |             gt_stats=self.gt.match_summarize(flag),
321 |             pd_stats=self.pd.match_summarize(flag),
322 |             cats=self.gt.cats
323 |         )
324 |         return per_cat_res, overall_table, cat_table
325 |
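    # [editor's note, not in the original source] a small numeric check of
    # the PQ arithmetic implemented in _aggregate below: with 8 matches
    # (iou_sum = 6.0), 2 unmatched gt and 4 unmatched pd segments for a
    # category,
    #     RQ = 8 / (8 + 0.5 * 2 + 0.5 * 4) = 8 / 11 ~= 0.727
    #     SQ = 6.0 / 8 = 0.75
    #     PQ = RQ * SQ ~= 0.545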
326 |     @staticmethod
327 |     def _aggregate(gt_stats, pd_stats, cats):
328 |         '''
329 |         Args:
330 |             pd/gt_stats: [num_cats, 4] where each row contains
331 |                 (iou_sum, num_matched, num_unmatched, total_inst)
332 |             cats: a dict of {catId: catMetaData}
333 |         Returns:
334 |             1. per cat pandas dataframe; easy to programmatically manipulate
335 |             2. str formatted overall result table
336 |             3. str formatted per category result table
337 |         '''
338 |         # each is of shape [num_cats]
339 |         gt_iou, gt_matched, gt_unmatched, gt_tot_inst = gt_stats.T
340 |         pd_iou, pd_matched, pd_unmatched, pd_tot_inst = pd_stats.T
341 |         assert np.allclose(gt_iou, pd_iou) and (gt_matched == pd_matched).all()
342 |
343 |         catIds = list(sorted(cats.keys()))
344 |         catNames = [cats[id]['name'] for id in catIds]
345 |         isthing = np.array([cats[id]['isthing'] for id in catIds], dtype=bool)
346 |
347 |         RQ = gt_matched / (gt_matched + 0.5 * gt_unmatched + 0.5 * pd_unmatched)
348 |         SQ = gt_iou / gt_matched
349 |         RQ, SQ = np.nan_to_num(RQ), np.nan_to_num(SQ)
350 |         PQ = RQ * SQ
351 |         results = np.array([PQ, SQ, RQ]) * 100  # [3, num_cats]
352 |
353 |         overall_table = tabulate(
354 |             headers=['', 'PQ', 'SQ', 'RQ', 'num_cats'],
355 |             floatfmt=".2f", tablefmt='fancy_grid',
356 |             tabular_data=[
357 |                 ['all'] + list(map(lambda x: x.mean(), results)) + [len(catIds)],
358 |                 ['things'] + list(map(lambda x: x[isthing].mean(), results)) + [sum(isthing)],
359 |                 ['stuff'] + list(map(lambda x: x[~isthing].mean(), results)) + [sum(1 - isthing)],
360 |             ]
361 |         )
362 |
363 |         headers = (
364 |             'PQ', 'SQ', 'RQ',
365 |             'num_matched', 'gt_unmatched', 'pd_unmatched', 'tot_gt_inst',
366 |             'isthing'
367 |         )
368 |         results = np.array(
369 |             list(results) + [gt_matched, gt_unmatched, pd_unmatched, gt_tot_inst, isthing]
370 |         )
371 |         results = results.T
372 |         data_frame = pd.DataFrame(results, columns=headers, index=catNames)
373 |         cat_table = tabulate(
374 |             data_frame, headers='keys', floatfmt=".2f", tablefmt='fancy_grid'
375 |         )
376 |         return data_frame, overall_table, cat_table
377 |
378 |     @staticmethod
379 |     def pq_compute_single_img(cats, gt_ann, gt_mask, pd_ann, pd_mask):
380 |         """
381 |         This is the original eval function refactored for readability
382 |         """
383 |         pq_stat = PQStat()
384 |         gt_segms = {el['id']: el for el in gt_ann['segments_info']}
385 |         pd_segms = {el['id']: el for el in pd_ann['segments_info']}
386 |
387 |         # predicted segments area calculation + prediction sanity checks
388 |         pd_labels_set = set(el['id'] for el in pd_ann['segments_info'])
389 |         labels, labels_cnt = np.unique(pd_mask, return_counts=True)
390 |         for label, label_cnt in zip(labels, labels_cnt):
391 |             if label not in pd_segms:
392 |                 if label == VOID:
393 |                     continue
394 |                 raise KeyError(
395 |                     ('In the image with ID {} '
396 |                      'segment with ID {} is presented in PNG '
397 |                      'and not presented in JSON.').format(gt_ann['image_id'], label)
398 |                 )
399 |             pd_segms[label]['area'] = int(label_cnt)
400 |             pd_labels_set.remove(label)
401 |             if pd_segms[label]['category_id'] not in cats:
402 |                 raise KeyError(
403 |                     ('In the image with ID {} '
404 |                      'segment with ID {} has unknown '
405 |                      'category_id {}.').format(
406 |                         gt_ann['image_id'], label, pd_segms[label]['category_id'])
407 |                 )
408 |         if len(pd_labels_set) != 0:
409 |             raise KeyError(
410 |                 ('In the image with ID {} '
411 |                  'the following segment IDs {} are presented '
412 |                  'in JSON and not presented in PNG.').format(
413 |                     gt_ann['image_id'], list(pd_labels_set))
414 |             )
415 |
416 |         # confusion matrix calculation
417 |         gt_vs_pd = gt_mask.astype(np.uint64) * OFFSET + pd_mask.astype(np.uint64)
418 |         gt_pd_itrsct = {}
419 |         labels, labels_cnt = np.unique(gt_vs_pd, return_counts=True)
420 |         for label, intersection in zip(labels, labels_cnt):
421 |             gt_id, pd_id = label // OFFSET, label % OFFSET
422 |             gt_pd_itrsct[(gt_id, pd_id)] = intersection
423 |
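        # [editor's note, not in the original source] the OFFSET trick above
        # packs a (gt_id, pd_id) pair into a single uint64 as
        # gt_id * OFFSET + pd_id (OFFSET = 256 ** 3 in panopticapi), so one
        # np.unique call counts every gt/pd intersection at once; e.g. gt 5
        # overlapping pd 9 encodes as 5 * OFFSET + 9.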
424 |         # count all matched pairs
425 |         gt_matched, pd_matched = set(), set()
426 |         for label_tuple, intersection in gt_pd_itrsct.items():
427 |             gt_label, pd_label = label_tuple
428 |             if gt_label not in gt_segms:
429 |                 continue
430 |             if pd_label not in pd_segms:
431 |                 continue
432 |
433 |             gt_seg, pd_seg = gt_segms[gt_label], pd_segms[pd_label]
434 |             union = pd_seg['area'] + gt_seg['area'] \
435 |                 - intersection - gt_pd_itrsct.get((VOID, pd_label), 0)
436 |             iou = intersection / union
437 |             if iou > 0.1:
438 |                 gt_seg['matchings'][pd_seg['sid']] = iou
439 |                 pd_seg['matchings'][gt_seg['sid']] = iou
440 |
441 |             if gt_seg['iscrowd'] == 1:
442 |                 continue
443 |             if gt_seg['category_id'] != pd_seg['category_id']:
444 |                 continue
445 |
446 |             if iou > 0.5:
447 |                 gt_cat_id = gt_seg['category_id']
448 |                 pq_stat[gt_cat_id].tp += 1
449 |                 pq_stat[gt_cat_id].iou += iou
450 |                 gt_matched.add(gt_label)
451 |                 pd_matched.add(pd_label)
452 |                 gt_seg['match_state'] = _SEGMENT_MATCHED
453 |                 pd_seg['match_state'] = _SEGMENT_MATCHED
454 |
455 |         # count false negatives
456 |         # HC: assumes each category in an image has at most one crowd segment!
457 |         # (per image and category, all crowd segments are merged into one segment)
458 |         crowd_cat_segid = {}
459 |         for gt_label, gt_info in gt_segms.items():
460 |             if gt_label in gt_matched:
461 |                 continue
462 |             # crowd segments are ignored;
463 |             if gt_info['iscrowd'] == 1:
464 |                 crowd_cat_segid[gt_info['category_id']] = gt_label
465 |                 continue
466 |             pq_stat[gt_info['category_id']].fn += 1
467 |             gt_info['match_state'] = _SEGMENT_UNMATCHED
468 |
469 |         # count false positives
470 |         for pd_label, pd_info in pd_segms.items():
471 |             if pd_label in pd_matched:
472 |                 continue
473 |             # intersection of the segment with VOID
474 |             intersection = gt_pd_itrsct.get((VOID, pd_label), 0)
475 |             # plus intersection with corresponding CROWD region if it exists
476 |             if pd_info['category_id'] in crowd_cat_segid:
477 |                 intersection += gt_pd_itrsct.get(
478 |                     (crowd_cat_segid[pd_info['category_id']], pd_label), 0
479 |                 )
480 |             # predicted segment is ignored if more than half of the segment
481 |             # corresponds to VOID and CROWD regions
482 |             if intersection / pd_info['area'] > 0.5:
483 |                 continue
484 |             pq_stat[pd_info['category_id']].fp += 1
485 |             pd_info['match_state'] = _SEGMENT_UNMATCHED
486 |         return pq_stat
487 |
488 |
489 | class BreakdownPolicy():
490 |     def __init__(self):
491 |         self.flags = []
492 |
493 |     def breakdown(self, gt_segs, pd_segs):
494 |         pass
495 |
496 |
497 | class DummyBreakdown(BreakdownPolicy):
498 |     def __init__(self):
499 |         self.flags = ['sector1', 'sector2']
500 |
501 |     def breakdown(self, gt_segs, pd_segs):
502 |         import numpy.random as npr
503 |         gt_flags = [ npr.choice(self.flags) for _ in range(len(gt_segs)) ]
504 |         for i, seg in enumerate(gt_segs.values()):
505 |             seg['breakdown_flag'] = gt_flags[i]
506 |
507 |         pd_flags = [ npr.choice(self.flags) for _ in range(len(pd_segs)) ]
508 |         for i, seg in enumerate(pd_segs.values()):
509 |             seg['breakdown_flag'] = pd_flags[i]
510 |
511 |
512 | class BoxScaleBreakdown(BreakdownPolicy):
513 |     def __init__(self):
514 |         # self.flags = ['tiny', 'small', 'medium', 'large', 'huge']
515 |         # self.scale_thresholds = [ 16 ** 2, 32 ** 2, 64 ** 2, 128 ** 2]
516 |
517 |         self.flags = ['small', 'medium', 'large']
518 |         self.scale_thresholds = [32 ** 2, 128 ** 2]
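        # [editor's note, not in the original source] bisect_right maps an
        # area to a bucket, e.g. a 40 x 40 box (area 1600) lands in 'medium'
        # since 32 ** 2 = 1024 <= 1600 < 128 ** 2 = 16384.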
519 |
520 |     def breakdown(self, gt_segs, pd_segs):
521 |         thresh, flags = self.scale_thresholds, self.flags
522 |         gt_areas = [
523 |             seg['bbox'][-1] * seg['bbox'][-2] for seg in gt_segs.values()
524 |         ]
525 |         gt_flags = [
526 |             flags[bisect.bisect_right(thresh, s_area)] for s_area in gt_areas
527 |         ]
528 |         # give each gt the flag
529 |         for i, g_seg in enumerate(gt_segs.values()):
530 |             g_seg['breakdown_flag'] = gt_flags[i]
531 |
532 |         for p_seg in pd_segs.values():
533 |             matchings = p_seg['matchings']
534 |             flag = None
535 |             if len(matchings) == 0:
536 |                 area = p_seg['bbox'][-1] * p_seg['bbox'][-2]
537 |                 flag = flags[bisect.bisect_right(thresh, area)]
538 |             else:
539 |                 gt_s_sid = list(matchings.keys())[0]
540 |                 flag = gt_segs[gt_s_sid]['breakdown_flag']
541 |             p_seg['breakdown_flag'] = flag
542 |
543 |
544 | class MaskScaleBreakdown(BreakdownPolicy):
545 |     def __init__(self):
546 |         # self.flags = ['tiny', 'small', 'medium', 'large', 'huge']
547 |         # self.scale_thresholds = [ 16 ** 2, 32 ** 2, 64 ** 2, 128 ** 2]
548 |
549 |         self.flags = ['small', 'medium', 'large']
550 |         self.scale_thresholds = [32 ** 2, 128 ** 2]
551 |
552 |     def breakdown(self, gt_segs, pd_segs):
553 |         thresh, flags = self.scale_thresholds, self.flags
554 |         gt_areas = [ seg['area'] for seg in gt_segs.values() ]
555 |         gt_flags = [
556 |             flags[bisect.bisect_right(thresh, s_area)] for s_area in gt_areas
557 |         ]
558 |         # give each gt the flag
559 |         for i, g_seg in enumerate(gt_segs.values()):
560 |             g_seg['breakdown_flag'] = gt_flags[i]
561 |
562 |         for p_seg in pd_segs.values():
563 |             matchings = p_seg['matchings']
564 |             flag = None
565 |             if len(matchings) == 0:
566 |                 area = p_seg['area']
567 |                 flag = flags[bisect.bisect_right(thresh, area)]
568 |             else:
569 |                 gt_s_sid = list(matchings.keys())[0]
570 |                 flag = gt_segs[gt_s_sid]['breakdown_flag']
571 |             p_seg['breakdown_flag'] = flag
572 |
573 |
574 | class BoxAspectRatioBreakdown(BreakdownPolicy):
575 |     def __init__(self):
576 |         self.flags = []
577 |
578 |     def breakdown(self, gt_segs, pd_segs):
579 |         pass
580 |
581 |
582 | policy_register = {
583 |     'dummy': DummyBreakdown,
584 |     'bbox_scale': BoxScaleBreakdown,
585 |     'mask_scale': MaskScaleBreakdown
586 | }
587 |
588 |
589 | class StatsBreakdown():
590 |     def __init__(self, gt_json_meta_fname, pd_json_meta_fname, breakdown_policy):
591 |         analyzer = PanopticEvalAnalyzer(gt_json_meta_fname, pd_json_meta_fname)
592 |         self.gt_segs = analyzer.gt.segs
593 |         self.pd_segs = analyzer.pd.segs
594 |         self.analyzer = analyzer
595 |         self.policy = policy_register[breakdown_policy]()
596 |         self.policy.breakdown(self.gt_segs, self.pd_segs)
597 |         self.verify_policy_execution(self.policy, self.gt_segs, self.pd_segs)
598 |
599 |     @staticmethod
600 |     def verify_policy_execution(policy, gt_segs, pd_segs):
601 |         '''make sure that each seg has been given a flag'''
602 |         flags = policy.flags
603 |         for seg in gt_segs.values():
604 |             assert seg['breakdown_flag'] in flags
605 |         for seg in pd_segs.values():
606 |             assert seg['breakdown_flag'] in flags
607 |
608 |     def aggregate(self):
609 |         res = []
610 |         for flag in self.policy.flags:
611 |             dataframe, overall_table, cat_table = self.analyzer.summarize(flag)
612 |             res.append(dataframe)
613 |         res = pd.concat(res, axis=1)
614 |
615 |         # print results in semicolon-separated format for easy transfer to Google Docs
616 |         print(self.policy.flags)
617 |         # upper left corner of the table is 'name'
618 |         cols = ';'.join(['name'] + list(res.columns))
619 |         print(cols)
620 |         for catName, row in res.iterrows():
621 |             row = [ '{:.2f}'.format(elem) for elem in row.values ]
622 |             score_str = ';'.join([catName] + row)
623 |             print(score_str)
624 |
625 |
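# [editor's note, not in the original source] StatsBreakdown.aggregate()
# prints one semicolon-separated row per category, e.g. (illustrative values):
#     name;PQ;SQ;RQ;...;PQ;SQ;RQ;...    <- one column block per breakdown flag
#     road;97.43;98.10;99.32;...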
626 | def test():
627 |     model = 'pcv'
628 |     dset = 'cityscapes'
629 |     split = 'val'
630 |     gt_json_fname = '/share/data/vision-greg/panoptic/{}/annotations/{}.json'.format(dset, split)
631 |     pd_json_fname = '/home-nfs/whc/panout/{}/{}/{}/pred.json'.format(model, dset, split)
632 |
633 |     # analyzer = PanopticEvalAnalyzer(gt_json_fname, pd_json_fname)
634 |     # _, overall_table, cat_table = analyzer._aggregate(
635 |     #     gt_stats=analyzer.gt.match_summarize(),
636 |     #     pd_stats=analyzer.pd.match_summarize(),
637 |     #     cats=analyzer.gt.cats
638 |     # )
639 |     # print(cat_table)
640 |     # print(overall_table)
641 |
642 |     breakdown_stats = StatsBreakdown(gt_json_fname, pd_json_fname, 'mask_scale')
643 |     breakdown_stats.aggregate()
644 |
645 |
646 | def draw_failure_cases_spatially():
647 |     model = 'pcv'
648 |     dset = 'cityscapes'
649 |     split = 'val'
650 |     gt_json_fname = '/share/data/vision-greg/panoptic/{}/annotations/{}.json'.format(dset, split)
651 |     pd_json_fname = '/home-nfs/whc/panout/{}/{}/{}/pred.json'.format(model, dset, split)
652 |
653 |     analyzer = PanopticEvalAnalyzer(gt_json_fname, pd_json_fname)
654 |     # plot unmatched gt
655 |     gt_accumulator = np.zeros((1024, 2048))
656 |     # plt.imshow(gt_accumulator)
657 |     gt = analyzer.gt
658 |     for k, seg in gt.segs.items():
659 |         if seg['match_state'] == _SEGMENT_UNMATCHED:
660 |             x, y, w, h = seg['bbox']
661 |             c = [y + h // 2, x + w // 2]
662 |             y, x = c
663 |             gt_accumulator[y, x] += 1
664 |     return gt_accumulator
665 |
666 |
667 | if __name__ == "__main__":
668 |     draw_failure_cases_spatially()
669 |
--------------------------------------------------------------------------------
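[editor's note] A minimal usage sketch for pan_analyzer.py, assuming the
package is importable as `panoptic` (as its own imports suggest); the json
paths below are placeholders, not real files:

    from panoptic.pan_analyzer import PanopticEvalAnalyzer, StatsBreakdown

    # end-to-end PQ/SQ/RQ tables (state cached to analyze_dump.pkl)
    analyzer = PanopticEvalAnalyzer('gt.json', 'pred.json')
    df, overall_table, cat_table = analyzer.summarize()
    print(overall_table)

    # per-scale breakdown, printed as ';'-separated rows per category
    sb = StatsBreakdown('gt.json', 'pred.json', breakdown_policy='mask_scale')
    sb.aggregate()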