├── src
│   ├── __init__.py
│   ├── datasets
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   ├── iteration_based_batch_sampler.py
│   │   │   └── grouped_batch_sampler.py
│   │   ├── __init__.py
│   │   ├── coco_pan.py
│   │   ├── mapillary_pan.py
│   │   ├── cityscapes_pan.py
│   │   ├── cityscapes_sem.py
│   │   ├── augmentations
│   │   │   ├── __init__.py
│   │   │   └── augmentations.py
│   │   ├── base.py
│   │   └── gt_producer.py
│   ├── models
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── FPN.py
│   │   ├── base_model.py
│   │   ├── PFPN_pan.py
│   │   ├── PFPN_d2.py
│   │   ├── __init__.py
│   │   └── panoptic_base.py
│   ├── pcv
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── grid_specs.py
│   │   │   ├── ballot.py
│   │   │   └── snake.py
│   │   ├── inference
│   │   │   └── __init__.py
│   │   ├── gaussian_smooth
│   │   │   ├── __init__.py
│   │   │   ├── vis.py
│   │   │   └── prob_tsr.py
│   │   ├── __init__.py
│   │   ├── pcv_basic.py
│   │   ├── pcv.py
│   │   ├── pcv_boundless.py
│   │   ├── pcv_smooth.py
│   │   ├── pcv_igc.py
│   │   └── pcv_igc_boundless.py
│   ├── optimizer.py
│   ├── scheduler.py
│   ├── utils.py
│   ├── config.py
│   ├── reporters.py
│   ├── metric.py
│   ├── pan_eval.py
│   ├── loss.py
│   ├── pan_vis.py
│   └── pan_analyzer.py
├── .gitignore
├── LICENSE
├── run.py
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/pcv/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/pcv/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/pcv/gaussian_smooth/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import (
2 | SGD, Adam
3 | )
4 |
--------------------------------------------------------------------------------
/src/pcv/__init__.py:
--------------------------------------------------------------------------------
1 | from ..utils import dynamic_load_py_object
2 |
3 |
4 | def get_pcv(pcv_cfg):
5 | module = dynamic_load_py_object(__name__, pcv_cfg.name)
6 | return module()
7 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from panoptic.utils import dynamic_load_py_object
2 |
3 |
4 | def get_dataset_module(dset_name):
5 | dataset = dynamic_load_py_object(
6 | package_name=__name__, module_name=dset_name
7 | )
8 | return dataset
9 |
--------------------------------------------------------------------------------
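A usage sketch of the loader (hedged: it assumes `get_dataset_module` is importable, and reads only class-level metadata shown elsewhere in this dump):

```python
# Hypothetical lookup: the module name 'coco_pan' resolves to the COCO_Pan
# class inside coco_pan.py; the match is case-insensitive, courtesy of
# dynamic_load_py_object (see src/utils.py).
dataset_cls = get_dataset_module('coco_pan')
print(dataset_cls.aspect_grouping)  # True
```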
/src/datasets/coco_pan.py:
--------------------------------------------------------------------------------
1 | from .base import PanopticSeg
2 |
3 |
4 | class COCO_Pan(PanopticSeg):
5 | aspect_grouping = True
6 |
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(name='coco', *args, **kwargs)
9 | self.meta['stuff_pred_thresh'] = 4096
10 |
--------------------------------------------------------------------------------
/src/datasets/mapillary_pan.py:
--------------------------------------------------------------------------------
1 | from .base import PanopticSeg
2 |
3 |
4 | class Mapillary_Pan(PanopticSeg):
5 | aspect_grouping = False
6 |
7 | def __init__(self, *args, **kwargs):
8 | super().__init__(name='mapillary', *args, **kwargs)
9 | self.meta['stuff_pred_thresh'] = 2048
10 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | dist/
3 | *.egg-info/
4 | *.egg
5 | *.py[cod]
6 | __pycache__/
7 | *.so
8 | *~
9 |
10 | # due to using tox and pytest
11 | .tox
12 | .cache
13 |
14 | # notebook checkpoints
15 | .ipynb_checkpoints
16 | .vscode/
17 |
18 | # exp folder
19 | exp/
20 | new_world/
21 | ablations/
22 | notebooks/
23 | test/
24 |
--------------------------------------------------------------------------------
/src/datasets/cityscapes_pan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .base import PanopticSeg
3 |
4 |
5 | class Cityscapes_Pan(PanopticSeg):
6 | aspect_grouping = False
7 |
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(name='cityscapes', *args, **kwargs)
10 | self.meta['stuff_pred_thresh'] = 2048
11 |
--------------------------------------------------------------------------------
/src/datasets/cityscapes_sem.py:
--------------------------------------------------------------------------------
1 | from panoptic.datasets.base import SemanticSeg
2 | import os.path as osp
3 |
4 |
5 | class Cityscapes_Sem(SemanticSeg):
6 | def __init__(self, split, transforms=None):
7 | super().__init__(name='cityscapes', split=split, transforms=transforms)
8 | self.vanilla_lbl_root = osp.join(
9 | '/share/data/vision-greg/cityscapes/gtFine', split
10 | )
11 |
--------------------------------------------------------------------------------
/src/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import (
2 | StepLR, MultiStepLR, CosineAnnealingLR, LambdaLR
3 | )
4 |
5 |
6 | def ups_scheduler(optimizer, milestones, gamma, warmup=0):
7 | assert len(milestones) > 0
8 | assert warmup < milestones[0]
9 | from bisect import bisect_right # do not pollute scheduler.py namespace
10 |
11 | def lr_policy(step):
12 | init_multiplier = 1.0
13 | if step < warmup:
14 | return (0.1 + 0.9 * step / warmup) * init_multiplier
15 | exponent = bisect_right(milestones, step)
16 | return init_multiplier * (gamma ** exponent)
17 |
18 | return LambdaLR(optimizer, lr_lambda=lr_policy)
19 |
--------------------------------------------------------------------------------
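A minimal usage sketch, assuming `ups_scheduler` is imported from this module; the milestone, gamma, and warmup values below are made up for illustration. The multiplier ramps linearly from 0.1 to 1.0 over the warmup steps, then decays by `gamma` at each milestone.

```python
import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)
scheduler = ups_scheduler(
    optimizer, milestones=[60000, 80000], gamma=0.1, warmup=1000
)
for step in range(3):
    optimizer.step()
    scheduler.step()  # at step 0 the lr is 0.1 * (0.1 + 0.9 * 0 / 1000) = 0.01
```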
/src/datasets/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from panoptic.utils import dynamic_load_py_object
2 |
3 |
4 | class Compose(object):
5 | def __init__(self, augmentations):
6 | self.augmentations = augmentations
7 |
8 | def __call__(self, img, mask):
9 |
10 | # assert img.size == mask.size
11 | for a in self.augmentations:
12 | img, mask = a(img, mask)
13 |
14 | return img, mask
15 |
16 |
17 | def get_composed_augmentations(aug_list):
18 | if aug_list is None:
19 | print("Using No Augmentations")
20 | return None
21 |
22 | augmentations = []
23 | for aug_meta in aug_list:
24 | name = aug_meta.name
25 | params = aug_meta.params
26 | aug = dynamic_load_py_object(
27 | package_name=__name__,
28 | module_name='augmentations', obj_name=name
29 | )
30 | instance = aug(**params)
31 | augmentations.append(instance)
32 |
33 | return Compose(augmentations)
34 |
--------------------------------------------------------------------------------
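For reference, a toy augmentation obeying the `(img, mask) -> (img, mask)` contract that `Compose` expects; the `HFlip` class below is hypothetical, not part of the codebase.

```python
import numpy as np

class HFlip:
    """Flip image and mask together so they stay spatially aligned."""
    def __call__(self, img, mask):
        return img[:, ::-1], mask[:, ::-1]

img, mask = np.zeros((4, 6, 3)), np.zeros((4, 6))
img, mask = Compose([HFlip()])(img, mask)  # Compose as defined above
print(img.shape, mask.shape)  # (4, 6, 3) (4, 6)
```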
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Haochen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/pcv/components/grid_specs.py:
--------------------------------------------------------------------------------
1 | specs = [
2 | # (size, num_rounds)
3 |     [ # size 243, 233 bins
4 | (1, 1), (3, 4), (9, 3), (27, 3)
5 | ],
6 |     [ # size 243, 225 bins
7 | (3, 4), (9, 3), (27, 3)
8 | ],
9 | [ # size 231, 233 bins
10 | (3, 3), (7, 3), (21, 4),
11 | ],
12 | [ # size 243, 41 bins
13 | (1, 1), (3, 1), (9, 1), (27, 1), (81, 1)
14 | ],
15 | [ # size 225, 121 bins
16 | (3, 2), (15, 2), (25, 3)
17 | ],
18 | # [ # size 231, 233 bins
19 | # (3, 2), (5, 2), (7, 2), (21, 4)
20 | # ],
21 |
22 |     [ # size 225, 15 * 15 = 225 bins
23 | (15, 7),
24 | ],
25 | ]
26 |
27 |
28 | igc_specs = [
29 | {
30 | 'base': [(1, 1), (3, 4), (9, 3), (27, 3)],
31 | 'pyramid': { # radius: grid
32 | (9 - 1) // 2: [(1, 1), (3, 4), (9, 3), (27, 3)],
33 | (27 - 1) // 2: [(3, 4), (9, 3), (27, 3)],
34 | (81 - 1) // 2: [(9, 4), (27, 3)],
35 | (243 - 1) // 2: [(27, 4) ]
36 | }
37 | # 'pyramid': { # radius: grid
38 | # (9 - 1) // 2: [(1, 1), (3, 4), (9, 3), (27, 3)],
39 | # (27 - 1) // 2: [(3, 4), (9, 3), (27, 3)],
40 | # (81 - 1) // 2: [(9, 4), (27, 3)],
41 | # (243 - 1) // 2: [(9, 4), (27, 3)]
42 | # }
43 | }
44 | ]
45 |
--------------------------------------------------------------------------------
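The `size, bins` annotations above can be recomputed from a spec with the same block arithmetic used in `Ballot.__init__` (src/pcv/components/ballot.py); a standalone sketch:

```python
def spec_stats(spec):
    """Field diameter and bin count implied by a (size, num_rounds) spec."""
    acc_size = spec[0][0]  # innermost cell size
    bins = 0
    for i, (size, num_rounds) in enumerate(spec):
        inner_blocks = acc_size // size
        total_blocks = inner_blocks + 2 * num_rounds
        # newly added ring of blocks, plus the bull's eye cell on stage 0
        bins += total_blocks ** 2 - inner_blocks ** 2 + int(i == 0)
        acc_size += 2 * num_rounds * size
    return acc_size, bins

print(spec_stats([(1, 1), (3, 4), (9, 3), (27, 3)]))  # (243, 233)
```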
/src/datasets/samplers/iteration_based_batch_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | '''
3 | HC: this is not used. It creates more problems than it purports to solve.
4 | '''
5 | from torch.utils.data.sampler import BatchSampler
6 |
7 |
8 | class IterationBasedBatchSampler(BatchSampler):
9 | """
10 | Wraps a BatchSampler, resampling from it until
11 | a specified number of iterations have been sampled
12 | """
13 |
14 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
15 | self.batch_sampler = batch_sampler
16 | self.num_iterations = num_iterations
17 | self.start_iter = start_iter
18 |
19 | def __iter__(self):
20 | iteration = self.start_iter
21 | while iteration <= self.num_iterations:
22 | # if the underlying sampler has a set_epoch method, like
23 | # DistributedSampler, used for making each process see
24 | # a different split of the dataset, then set it
25 | if hasattr(self.batch_sampler.sampler, "set_epoch"):
26 | self.batch_sampler.sampler.set_epoch(iteration)
27 | for batch in self.batch_sampler:
28 | iteration += 1
29 | if iteration > self.num_iterations:
30 | break
31 | yield batch
32 |
33 | def __len__(self):
34 | return self.num_iterations
35 |
--------------------------------------------------------------------------------
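A quick sketch of the wrap-around behavior, assuming `IterationBasedBatchSampler` is importable from this module (toy sizes for illustration):

```python
from torch.utils.data.sampler import BatchSampler, SequentialSampler

base = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
wrapped = IterationBasedBatchSampler(base, num_iterations=7)
print(len(list(wrapped)))  # 7: the 3-batch epoch is replayed until 7 iterations
```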
/src/models/base_model.py:
--------------------------------------------------------------------------------
1 | class BaseModel:
2 | """
3 | A model should encapsulate everything related to its training and inference.
4 | Only the most generic APIs should be exposed for maximal flexibility
5 | """
6 | def __init__(self):
7 | self.cfg = None
8 | self.net = None
9 | self.criteria = None
10 | self.optimizer = None
11 | self.scheduler = None
12 | self.log_writer = None
13 | self.total_train_epoch = -1
14 | self.curr_epoch = 0
15 |
16 | self.is_train = None
17 |
18 | def set_train_eval_state(self, to_train):
19 | assert to_train in (True, False)
20 | if self.is_train is None or self.is_train != to_train:
21 | if to_train:
22 | self.net.train()
23 | else:
24 | self.net.eval()
25 | self.is_train = to_train
26 |
27 | def ingest_train_input(self, input):
28 | raise NotImplementedError()
29 |
30 | def infer(self, input):
31 | raise NotImplementedError()
32 |
33 | def optimize_params(self):
34 | raise NotImplementedError()
35 |
36 | def advance_to_next_epoch(self):
37 | raise NotImplementedError()
38 |
39 | def load_latest_checkpoint_if_available(self, manager):
40 | raise NotImplementedError()
41 |
42 | def write_checkpoint(self, manager):
43 | raise NotImplementedError()
44 |
45 | def add_log_writer(self, log_writer):
46 | self.log_writer = log_writer
47 |
48 | def log_statistics(self, step, level=0):
49 | pass
50 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import datetime
3 | from fabric.utils.timer import Timer
4 |
5 |
6 | def dynamic_load_py_object(
7 | package_name, module_name, obj_name=None
8 | ):
9 | '''Dynamically import an object.
10 | Assumes that the object lives at module_name.py:obj_name
11 | If obj_name is not given, assume it shares the same name as the module.
12 | obj_name is case insensitive e.g. kitti.py/KiTTI is valid
13 | '''
14 | if obj_name is None:
15 | obj_name = module_name
16 |     # use relative import syntax .target_name
17 | target_module = importlib.import_module(
18 | '.{}'.format(module_name), package=package_name
19 | )
20 | target_obj = None
21 | for name, cls in target_module.__dict__.items():
22 | if name.lower() == obj_name.lower():
23 | target_obj = cls
24 |
25 | if target_obj is None:
26 | raise ValueError(
27 | "No object in {}.{}.py whose lower-case name matches {}".format(
28 | package_name, module_name, obj_name)
29 | )
30 |
31 | return target_obj
32 |
33 |
34 | class CompoundTimer():
35 | def __init__(self):
36 | self.data, self.compute = Timer(), Timer()
37 |
38 | def eta(self, curr_step, tot_step):
39 | remaining_steps = tot_step - curr_step
40 | avg = self.data.avg + self.compute.avg
41 | eta_seconds = avg * remaining_steps
42 | return str(datetime.timedelta(seconds=int(eta_seconds)))
43 |
44 | def __repr__(self):
45 | return "data avg {:.2f}, compute avg {:.2f}".format(
46 | self.data.avg, self.compute.avg)
47 |
--------------------------------------------------------------------------------
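A standalone sketch of the case-insensitive lookup that `dynamic_load_py_object` performs, demonstrated on a stdlib module; `load_obj` is a hypothetical simplification without the relative-import handling.

```python
import importlib

def load_obj(module_path, obj_name):
    mod = importlib.import_module(module_path)
    for name, obj in vars(mod).items():
        if name.lower() == obj_name.lower():
            return obj
    raise ValueError('no object in {} matching {}'.format(module_path, obj_name))

print(load_obj('collections', 'ordereddict'))  # <class 'collections.OrderedDict'>
```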
/src/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 | from pprint import pformat
3 | from fabric.cluster.configurator import Configurator
4 | from fabric.deploy.sow import NodeTracer
5 |
6 |
7 | _ERR_MSG_UNINITED_ = 'global config state not inited'
8 |
9 |
10 | class _CFG():
11 | def __init__(self):
12 | self.manager = None
13 | self.settings = {}
14 | self.inited = False
15 |
16 | def init_state(self, cfg_yaml_path, override_opts):
17 | self.manager = Configurator(cfg_yaml_path)
18 | self.settings = edict(self.manager.config)
19 | if override_opts is not None:
20 | self.merge_from_list(override_opts)
21 | self.inited = True
22 |
23 | def __getattr__(self, name):
24 | assert self.inited, _ERR_MSG_UNINITED_
25 | return getattr(self.settings, name)
26 |
27 | def __str__(self):
28 | assert self.inited, _ERR_MSG_UNINITED_
29 | return pformat(self.settings)
30 |
31 | def merge_from_list(self, cfg_list):
32 | """Merge config (keys, values) in a list (e.g., from command line) into
33 | this cfg. For example, `cfg_list = ['FOO.BAR', 0.5]`.
34 | """
35 | assert len(cfg_list) % 2 == 0, \
36 | "Override list has odd length: {}; it must be a list of pairs".format(
37 | cfg_list
38 | )
39 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
40 | self.trace_and_replace(full_key, v)
41 |
42 | def trace_and_replace(self, key, val):
43 | tracer = NodeTracer(self.settings)
44 | tracer.advance_pointer(key)
45 | tracer.replace([], val)
46 | self.settings = tracer.state
47 |
48 |
49 | _C = _CFG()
50 | cfg = _C
51 |
--------------------------------------------------------------------------------
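`NodeTracer` lives in the external `fabric` package, but the override convention boils down to dotted-path assignment into the nested settings; a toy stand-in (hypothetical helper, not the real implementation):

```python
def set_by_path(d, dotted_key, val):
    """Assign into a nested dict via a dotted key path, e.g. 'pcv.grid_inx'."""
    *parents, leaf = dotted_key.split('.')
    for k in parents:
        d = d[k]
    d[leaf] = val

settings = {'pcv': {'grid_inx': 0, 'num_groups': 1}}
set_by_path(settings, 'pcv.grid_inx', 3)  # mirrors the override 'pcv.grid_inx 3'
print(settings)  # {'pcv': {'grid_inx': 3, 'num_groups': 1}}
```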
/src/pcv/pcv_basic.py:
--------------------------------------------------------------------------------
1 | from .pcv import PCV_base
2 | from .components.snake import Snake
3 | from .components.ballot import Ballot
4 | from .components.grid_specs import specs
5 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
6 |
7 | from .. import cfg
8 |
9 |
10 | class PCV_Basic(PCV_base):
11 | def __init__(self):
12 | spec = specs[cfg.pcv.grid_inx]
13 | self.num_groups = cfg.pcv.num_groups
14 | self.centroid_mode = cfg.pcv.centroid
15 | self.raw_spec = spec
16 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
17 | self.grid_spec = grid_spec
18 | self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
19 | self._ballot_module = None
20 |
21 | @property
22 | def ballot_module(self):
23 |         # instantiate on demand to prevent train-time data loading workers
24 |         # from holding on to GPU memory
25 | if self._ballot_module is None:
26 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda()
27 | return self._ballot_module
28 |
29 | # 1 for bull's eye center, 1 for abstain vote
30 | @property
31 | def num_bins(self):
32 | return len(self.grid_spec)
33 |
34 | @property
35 | def num_votes(self):
36 | return 1 + self.num_bins
37 |
38 | @property
39 | def vote_mask(self):
40 | return self._vote_mask
41 |
42 | @property
43 | def query_mask(self):
44 | return super().query_mask
45 |
46 | def centroid_from_ins_mask(self, ins_mask):
47 | return super().centroid_from_ins_mask(ins_mask)
48 |
49 | def discrete_vote_inx_from_offset(self, offset):
50 | return self._discretize_offset(self.vote_mask, offset)
51 |
52 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
53 | # make the meta data actually required explicit!!
54 | if self.num_groups == 1:
55 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
56 | else:
57 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
58 | pan_mask, meta = mfv.infer_panoptic_mask()
59 | return pan_mask, meta
60 |
--------------------------------------------------------------------------------
/src/models/components/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision.models import resnet50, resnet101, resnext101_32x8d
4 |
5 | def set_bn_eval(m):
6 | """freeze batch norms, per RT's method"""
7 | classname = m.__class__.__name__
8 | if classname.find('BatchNorm') != -1:
9 | m.eval()
10 | for _p in m.parameters():
11 | _p.requires_grad = False
12 |
13 |
14 | def set_conv_stride_to_1(m):
15 | """
16 | A ResNet Bottleneck has 2 types of modules potentially with stride 2
17 | conv2: 3x3 conv used to reduce the spatial dimension of features
18 | downsample[0]: 1x1 conv used to change shape of original in case the learnt
19 | residual has incompatible channel or spatial dim
20 | It does not make sense to give this 1x1 downsample conv dilation 2.
21 | Hence the if condition testing for conv kernel size
22 | """
23 | classname = m.__class__.__name__
24 | if classname.find('Conv') != -1 and m.stride == (2, 2):
25 | m.stride = (1, 1)
26 | if m.kernel_size == (3, 3):
27 | m.dilation = (2, 2)
28 | m.padding = (2, 2)
29 |
30 |
31 | class ResNetFeatureExtractor(nn.Module):
32 | def __init__(self, resnet):
33 | super().__init__()
34 | layer0 = nn.Sequential(
35 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool
36 | )
37 | self.layers = nn.ModuleList([
38 | layer0,
39 | resnet.layer1, resnet.layer2,
40 | resnet.layer3, resnet.layer4
41 | ])
42 | self.extractor = nn.Sequential(*self.layers)
43 |
44 | def forward(self, input):
45 | return self.extractor(input)
46 |
47 | def train(self, mode=True):
48 | if mode:
49 | super().train(True)
50 | # self.apply(set_bn_eval)
51 | else:
52 | super().train(False)
53 |
54 |
55 | def resnet50_gn(pretrained=True):
56 | model = resnet50(norm_layer=lambda x: nn.GroupNorm(32, x))
57 | model.load_state_dict(torch.load('/share/data/vision-greg/rluo/model/pytorch-resnet/resnet_gn50-pth.pth'))
58 | return model
59 |
60 |
61 | def resnet101_gn(pretrained=True):
62 | model = resnet101(norm_layer=lambda x: nn.GroupNorm(32, x))
63 | model.load_state_dict(torch.load('/share/data/vision-greg/rluo/model/pytorch-resnet/resnet_gn101-pth.pth'))
64 | return model
65 |
66 |
67 | if __name__ == '__main__':
68 | from torchvision.models import resnet18
69 | model = ResNetFeatureExtractor(resnet18(pretrained=False))
70 | input = torch.rand(size=(1, 3, 64, 64))
71 | with torch.no_grad():
72 | out = model(input)
73 | print(out.shape)
74 |
--------------------------------------------------------------------------------
/src/pcv/pcv.py:
--------------------------------------------------------------------------------
1 | """Module for dilation based pixel consensus votin
2 | For now hardcode 3x3 voting kernel and see
3 | """
4 | from abc import ABCMeta, abstractmethod, abstractproperty
5 | import numpy as np
6 | from scipy.ndimage.measurements import center_of_mass
7 | from ..box_and_mask import get_xywh_bbox_from_binary_mask
8 | from .. import cfg
9 |
10 |
11 | class PCV_base(metaclass=ABCMeta):
12 | def __init__(self):
13 | # store the necessary modules
14 | pass
15 |
16 | @abstractproperty
17 | def num_bins(self):
18 | pass
19 |
20 | @abstractproperty
21 | def num_votes(self):
22 | pass
23 |
24 | @abstractproperty
25 | def vote_mask(self):
26 | pass
27 |
28 | @abstractproperty
29 | def query_mask(self):
30 | """
31 | Flipped from inside out
32 | """
33 | diam = len(self.vote_mask)
34 | radius = (diam - 1) // 2
35 | center = (radius, radius)
36 | mask_shape = self.vote_mask.shape
37 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1]
38 | offsets = center - offset_grid
39 | allegiance = self.discrete_vote_inx_from_offset(
40 | offsets.reshape(-1, 2)
41 | ).reshape(mask_shape)
42 | return allegiance
43 |
44 | @abstractmethod
45 | def centroid_from_ins_mask(self, ins_mask):
46 | mode = self.centroid_mode
47 | assert mode in ('bbox', 'cmass')
48 | if mode == 'bbox':
49 | bbox = get_xywh_bbox_from_binary_mask(ins_mask)
50 | x, y, w, h = bbox
51 | return [x + w // 2, y + h // 2]
52 | else:
53 | y, x = center_of_mass(ins_mask)
54 | x, y = int(x), int(y)
55 | return [x, y]
56 |
57 | @abstractmethod
58 | def discrete_vote_inx_from_offset(self, offset):
59 | pass
60 |
61 | @staticmethod
62 | def _discretize_offset(vote_mask, offset):
63 | """
64 | Args:
65 | offset: [N, 2] array of offset towards each pixel's own center,
66 | Each row is filled with (x, y) pair, not (y, x)!
67 | """
68 | shape = offset.shape
69 | assert len(shape) == 2 and shape[1] == 2
70 | offset = offset[:, ::-1] # swap to (y, x) for indexing
71 |
72 | diam = len(vote_mask)
73 | radius = (diam - 1) // 2
74 | center = (radius, radius)
75 | coord = offset + center
76 | del offset
77 |
78 | ret = -1 * np.ones(len(coord), dtype=np.int32)
79 | valid_inds = np.where(
80 | (coord[:, 0] >= 0) & (coord[:, 0] < diam)
81 | & (coord[:, 1] >= 0) & (coord[:, 1] < diam)
82 | )[0]
83 | _y_inds, _x_inds = coord[valid_inds].T
84 | vals = vote_mask[_y_inds, _x_inds]
85 | ret[valid_inds] = vals
86 | return ret
87 |
88 | @abstractmethod
89 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
90 | pass
91 |
--------------------------------------------------------------------------------
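To make the (x, y) convention concrete, a toy run of `_discretize_offset` on a 3x3 field; the mask values here are arbitrary stand-ins for snake-grid bin indices, and `PCV_base` is assumed imported from this module.

```python
import numpy as np

vote_mask = np.arange(9).reshape(3, 3)  # toy field: one bin index per cell
offsets = np.array([[0, 0], [1, 0], [0, -1], [5, 5]])  # (x, y) pairs
# the zero offset lands on the center bin; out-of-field offsets map to -1
print(PCV_base._discretize_offset(vote_mask, offsets))  # [ 4  5  1 -1]
```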
/src/pcv/pcv_boundless.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .pcv import PCV_base
3 | from .components.snake import Snake
4 | from .components.ballot import Ballot
5 | from .components.grid_specs import specs
6 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
7 |
8 | from .. import cfg
9 |
10 |
11 | class PCV_Boundless(PCV_base):
12 | def __init__(self):
13 | spec = specs[cfg.pcv.grid_inx]
14 | self.num_groups = cfg.pcv.num_groups
15 | self.centroid_mode = cfg.pcv.centroid
16 | self.raw_spec = spec
17 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
18 | self.grid_spec = grid_spec
19 | self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
20 | self.boundless_vote_mask = \
21 | Snake.paint_bound_ignore_trail_mask(field_diam, grid_spec)
22 | self._ballot_module = None
23 |
24 | @property
25 | def ballot_module(self):
26 |         # instantiate on demand to prevent train-time data loading workers
27 |         # from holding on to GPU memory
28 | if self._ballot_module is None:
29 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda()
30 | return self._ballot_module
31 |
32 | # 1 for bull's eye center, 1 for abstain vote
33 | @property
34 | def num_bins(self):
35 | return len(self.grid_spec)
36 |
37 | @property
38 | def num_votes(self):
39 | return 1 + self.num_bins
40 |
41 | @property
42 | def vote_mask(self):
43 | return self._vote_mask
44 |
45 | @property
46 | def query_mask(self):
47 | """
48 | Flipped from inside out
49 | """
50 | diam = len(self.vote_mask)
51 | radius = (diam - 1) // 2
52 | center = (radius, radius)
53 | mask_shape = self.vote_mask.shape
54 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1]
55 | offsets = center - offset_grid
56 | # allegiance = self.discrete_vote_inx_from_offset(
57 | # offsets.reshape(-1, 2)
58 | # ).reshape(mask_shape)
59 | allegiance = self._discretize_offset(
60 | self.vote_mask, offsets.reshape(-1, 2)
61 | ).reshape(mask_shape)
62 | return allegiance
63 |
64 | def centroid_from_ins_mask(self, ins_mask):
65 | return super().centroid_from_ins_mask(ins_mask)
66 |
67 | def discrete_vote_inx_from_offset(self, offset):
68 | # note that when assigning votes, use the boundary ignore vote mask
69 | return self._discretize_offset(self.boundless_vote_mask, offset)
70 |
71 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
72 | # make the meta data actually required explicit!!
73 | if self.num_groups == 1:
74 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
75 | else:
76 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
77 | pan_mask, meta = mfv.infer_panoptic_mask()
78 | return pan_mask, meta
79 |
--------------------------------------------------------------------------------
/src/models/PFPN_pan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from panoptic.models.panoptic_base import PanopticBase
5 | from panoptic.models.components.resnet import (
6 | ResNetFeatureExtractor, resnet50, resnet101, resnet50_gn, resnet101_gn
7 | )
8 | from panoptic.models.components.FPN import FPN
9 |
10 |
11 | primary_backbones = {
12 | 'resnet50': resnet50,
13 | 'resnet101': resnet101,
14 | 'resnet50_gn': resnet50_gn,
15 | 'resnet101_gn': resnet101_gn
16 | }
17 |
18 |
19 | class PFPN_pan(PanopticBase):
20 | def instantiate_network(self, cfg):
21 | self.net = Net(
22 | num_classes=self.dset_meta['num_classes'], num_votes=self.pcv.num_votes,
23 | **cfg.model.params
24 | )
25 |
26 |
27 | class Net(nn.Module):
28 | def __init__(self, num_classes, num_votes, FPN_C, backbone='resnet50', deep_classifier=False):
29 | super().__init__()
30 | backbone_f = primary_backbones[backbone]
31 | fea_extractor = ResNetFeatureExtractor(backbone_f(pretrained=True))
32 | self.distilled_fpn = FPN(fea_extractor, FPN_C)
33 | if deep_classifier:
34 | self.sem_classifier = nn.Sequential(
35 | nn.Conv2d(FPN_C, 256, kernel_size=3, padding=1),
36 | nn.ReLU(True),
37 | nn.Dropout(0.1),
38 | nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
39 | )
40 | self.vote_classifier = nn.Sequential(
41 | nn.Conv2d(FPN_C, 256, kernel_size=3, padding=1),
42 | nn.ReLU(True),
43 | nn.Dropout(0.1),
44 | nn.Conv2d(256, num_votes, kernel_size=1, bias=True)
45 | )
46 | else:
47 | self.sem_classifier = nn.Conv2d(
48 | FPN_C, num_classes, kernel_size=1, bias=True
49 | )
50 | self.vote_classifier = nn.Conv2d(
51 | FPN_C, num_votes, kernel_size=1, bias=True
52 | )
53 | # self.sem_classifier = nn.Conv2d(
54 | # FPN_C, num_classes, kernel_size=1, bias=True
55 | # )
56 | # self.vote_classifier = nn.Conv2d(
57 | # FPN_C, num_votes, kernel_size=1, bias=True
58 | # )
59 | self.loss_module = None
60 |
61 | def forward(self, *inputs):
62 | """
63 | If gt is supplied, then compute loss
64 | """
65 | x = inputs[0]
66 | x = self.distilled_fpn(x)
67 | sem_pred = self.sem_classifier(x)
68 | vote_pred = self.vote_classifier(x)
69 |
70 | if len(inputs) > 1:
71 | assert self.loss_module is not None
72 | loss = self.loss_module(sem_pred, vote_pred, *inputs[1:])
73 | return loss
74 | else:
75 | return sem_pred, vote_pred
76 |
77 |
78 | if __name__ == '__main__':
79 |     model = Net(num_classes=10, num_votes=10, FPN_C=128)
80 | input = torch.rand(size=(1, 3, 64, 64))
81 | with torch.no_grad():
82 | out = model(input)
83 |     print(out[0].shape, out[1].shape)
84 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | import random
5 |
6 | import torch
7 | import torch.multiprocessing as mp
8 | import torch.distributed as dist
9 |
10 | import panoptic
11 | from panoptic.entry import Entry
12 | from fabric.utils.mailer import ExceptionEmail
13 | from fabric.utils.logging import setup_logging
14 | from fabric.utils.git import git_version
15 |
16 | logger = setup_logging(__file__)
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser(description="run script")
21 | parser.add_argument('--command', '-c', type=str, default='train')
22 | parser.add_argument('--debug', action='store_true')
23 | parser.add_argument('--dist', action='store_true')
24 | parser.add_argument(
25 | "opts",
26 | help="Modify config options using the command-line",
27 | default=None,
28 | nargs=argparse.REMAINDER,
29 | )
30 | args = parser.parse_args()
31 |
32 | git_hash = git_version(osp.dirname(panoptic.__file__))
33 | logger.info('git repository hash: {}'.format(git_hash))
34 |
35 | global fly
36 |     # only decorate when not in interactive mode; bugs are expected there.
37 | if 'INTERACTIVE' not in os.environ:
38 | recipient = 'whc@ttic.edu'
39 | logger.info("decorating with warning email to {}".format(recipient))
40 | email_subject_headline = "{} tripped".format(
41 | # take the immediate dirname as email label
42 | osp.dirname(osp.abspath(__file__)).split('/')[-1]
43 | )
44 | fly = ExceptionEmail(
45 | subject=email_subject_headline, address=recipient
46 | )(fly)
47 |
48 | ngpus = torch.cuda.device_count()
49 | port = random.randint(10000, 20000)
50 | argv = (ngpus, args.command, args.debug, args.opts, port)
51 | if args.dist:
52 | mp.spawn(fly, nprocs=ngpus, args=argv)
53 | else:
54 | fly(None, *argv)
55 |
56 |
57 | def fly(rank, ngpus, command, debug, opts, port):
58 | distributed = rank is not None # and not debug
59 | if distributed: # multiprocess distributed training
60 | dist.init_process_group(
61 | world_size=ngpus, rank=rank,
62 | backend='nccl', init_method=f'tcp://127.0.0.1:{port}',
63 | )
64 | assert command == 'train' # for now only train uses mp distributed
65 | torch.cuda.set_device(rank)
66 |
67 | entry = Entry(
68 | __file__, override_opts=opts, debug=debug,
69 | mp_distributed=distributed, rank=rank, world_size=ngpus
70 | )
71 | if command == 'train':
72 | entry.train()
73 | elif command == 'validate': # for evaluate semantic segmentation mean iou
74 | entry.validate(False)
75 | elif command == 'evaluate':
76 | entry.evaluate()
77 | elif command == 'report':
78 | entry.report()
79 | elif command == 'test':
80 | entry.PQ_test(save_output=True)
81 | elif command == 'make_figures':
82 | entry.make_figures()
83 | else:
84 | raise ValueError("unrecognized command")
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
89 |
--------------------------------------------------------------------------------
/src/models/PFPN_d2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from panoptic.models.panoptic_base import PanopticBase, AbstractNet
5 | from panoptic.models.components.resnet import (
6 | ResNetFeatureExtractor, resnet50, resnet101, resnet50_gn, resnet101_gn
7 | )
8 | from panoptic.models.components.FPN import FPN
9 | import detectron2.config
10 | import detectron2.modeling
11 | from detectron2.modeling.backbone import build_backbone
12 | from detectron2.modeling.meta_arch.semantic_seg import build_sem_seg_head
13 | from detectron2.checkpoint import DetectionCheckpointer
14 |
15 | from panoptic import cfg
16 |
17 | primary_backbones = {
18 | 'resnet50': resnet50,
19 | 'resnet101': resnet101,
20 | 'resnet50_gn': resnet50_gn,
21 | 'resnet101_gn': resnet101_gn
22 | }
23 |
24 |
25 | class PFPN_d2(PanopticBase):
26 | def instantiate_network(self, cfg):
27 | self.net = Net(
28 | num_classes=self.dset_meta['num_classes'], num_votes=self.pcv.num_votes,
29 | **cfg.model.params
30 | )
31 |
32 |
33 | class Net(AbstractNet):
34 | def __init__(self, num_classes, num_votes,
35 | fix_bn=True,
36 | freeze_at=2,
37 | norm='GN',
38 | fpn_norm='',
39 | conv_dims=128,
40 | **kwargs): # FPN_C, backbone='resnet50'):
41 | super().__init__()
42 | d2_cfg = detectron2.config.get_cfg()
43 | d2_cfg.merge_from_file('/home-nfs/whc/glab/detectron2/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml')
44 | if fix_bn:
45 | d2_cfg.MODEL.RESNETS.NORM = "FrozenBN"
46 | else:
47 | d2_cfg.MODEL.RESNETS.NORM = "BN"
48 | d2_cfg.MODEL.BACKBONE.FREEZE_AT = freeze_at
49 |
50 | d2_cfg.MODEL.FPN.NORM = 'BN'
51 | self.backbone = build_backbone(d2_cfg)
52 |
53 | d2_cfg.MODEL.SEM_SEG_HEAD.NORM = norm
54 | d2_cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = conv_dims
55 |
56 | d2_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = num_classes
57 | self.sem_seg_head = build_sem_seg_head(d2_cfg, self.backbone.output_shape())
58 | self.sem_classifier = self.sem_seg_head.predictor
59 | del self.sem_seg_head.predictor
60 |
61 | d2_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = num_votes
62 | tmp = build_sem_seg_head(d2_cfg, self.backbone.output_shape())
63 | self.vote_classifier = tmp.predictor
64 | assert cfg.data.dataset.params.caffe_mode == True
65 |
66 | checkpointer = DetectionCheckpointer(self)
67 | checkpointer.load(d2_cfg.MODEL.WEIGHTS)
68 |
69 | def stage1(self, x):
70 | return self.backbone(x)
71 |
72 | def stage2(self, features):
73 | # copy from detectron2/modeling/meta_arch/semantic_seg.py
74 | # why? because segFPNHead doesn't accept training==True and targets==None
75 | for i, f in enumerate(self.sem_seg_head.in_features):
76 | if i == 0:
77 | x = self.sem_seg_head.scale_heads[i](features[f])
78 | else:
79 | x = x + self.sem_seg_head.scale_heads[i](features[f])
80 | # x = self.sem_seg_head.predictor(x)
81 | return x, x
82 | # x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False)
83 |
84 |
85 | if __name__ == '__main__':
86 | model = Net(10,10)
87 | model.eval()
88 | input = torch.rand(size=(1, 3, 224, 224))
89 | with torch.no_grad():
90 | out = model(input)
91 | print(out[0].shape, out[1].shape)
92 |
--------------------------------------------------------------------------------
/src/pcv/pcv_smooth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .pcv import PCV_base
4 | from .components.snake import Snake
5 | from .components.ballot import Ballot
6 | from .components.grid_specs import specs
7 | from .gaussian_smooth.prob_tsr import MakeProbTsr
8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
9 |
10 | from .. import cfg
11 |
12 |
13 | def flesh_out_grid(spec):
14 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
15 | vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
16 | return vote_mask
17 |
18 |
19 | def flesh_out_spec(spec_group):
20 | ret = {'base': None, 'pyramid': dict()}
21 | ret['base'] = flesh_out_grid(spec_group['base'])
22 | for k, v in spec_group['pyramid'].items():
23 | ret['pyramid'][k] = flesh_out_grid(v)
24 | return ret
25 |
26 |
27 | class PCV_Smooth(PCV_base):
28 |     def __init__(self):
29 | self.num_groups = cfg.pcv.num_groups
30 | self.spec = specs[cfg.pcv.grid_inx]
31 | self.centroid_mode = cfg.pcv.centroid
32 | field_diam, grid_spec = Snake.flesh_out_grid_spec(self.spec)
33 | self.grid_spec = grid_spec
34 | self._vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
35 | self._ballot_module = None
36 |
37 | maker = MakeProbTsr(
38 | self.spec, field_diam, grid_spec, self.vote_mask, var=0.05
39 | )
40 | self.spatial_prob_tsr = maker.compute_voting_prob_tsr(normalize=True)
41 |
42 | @property
43 | def ballot_module(self):
44 |         # instantiate on demand to prevent train-time data loading workers
45 |         # from holding on to GPU memory
46 | if self._ballot_module is None:
47 | self._ballot_module = Ballot(self.spec, self.num_groups).cuda()
48 | return self._ballot_module
49 |
50 | # 1 for bull's eye center, 1 for abstain vote
51 | @property
52 | def num_bins(self):
53 | return len(self.grid_spec)
54 |
55 | @property
56 | def num_votes(self):
57 | return 1 + self.num_bins
58 |
59 | @property
60 | def vote_mask(self):
61 | return self._vote_mask
62 |
63 | @property
64 | def query_mask(self):
65 | return super().query_mask
66 |
67 | def centroid_from_ins_mask(self, ins_mask):
68 | return super().centroid_from_ins_mask(ins_mask)
69 |
70 | def discrete_vote_inx_from_offset(self, offset):
71 | return self._discretize_offset(self.vote_mask, offset)
72 |
73 | def smooth_prob_tsr_from_offset(self, offset):
74 | """
75 | Args:
76 | offset: [N, 2] array of offset towards each pixel's own center,
77 | Each row is filled with (x, y) pair, not (y, x)!
78 | Returns:
79 | vote_tsr: [N, num_votes] of float tsr with each entry being a prob
80 | """
81 | shape = offset.shape
82 | assert len(shape) == 2 and shape[1] == 2
83 | offset = offset[:, ::-1] # swap to (y, x) for indexing
84 |
85 | diam = len(self.spatial_prob_tsr)
86 | radius = (diam - 1) // 2
87 | center = (radius, radius)
88 | coord = offset + center
89 |
90 | # [N, num_votes]
91 | tsr = np.zeros((len(offset), self.num_votes), dtype=np.float32)
92 | valid_inds = np.where(
93 | (coord[:, 0] >= 0) & (coord[:, 0] < diam)
94 | & (coord[:, 1] >= 0) & (coord[:, 1] < diam)
95 | )[0]
96 | _y_inds, _x_inds = coord[valid_inds].T
97 | tsr[valid_inds] = self.spatial_prob_tsr[_y_inds, _x_inds]
98 | return tsr
99 |
100 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
101 | # make the meta data actually required explicit!!
102 | if self.num_groups == 1:
103 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
104 | else:
105 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
106 | pan_mask, meta = mfv.infer_panoptic_mask()
107 | return pan_mask, meta
108 |
--------------------------------------------------------------------------------
/src/pcv/pcv_igc.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_left
2 | import numpy as np
3 |
4 | from .pcv import PCV_base
5 | from .components.snake import Snake
6 | from .components.ballot import Ballot
7 | from .components.grid_specs import igc_specs
8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
9 |
10 | from .. import cfg
11 |
12 |
13 | def flesh_out_grid(spec):
14 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
15 | vote_mask = Snake.paint_trail_mask(field_diam, grid_spec)
16 | return vote_mask
17 |
18 |
19 | def flesh_out_spec(spec_group):
20 | ret = {'base': None, 'pyramid': dict()}
21 | ret['base'] = flesh_out_grid(spec_group['base'])
22 | for k, v in spec_group['pyramid'].items():
23 | ret['pyramid'][k] = flesh_out_grid(v)
24 | return ret
25 |
26 |
27 | class PCV_IGC(PCV_base):
28 | def __init__(self):
29 | # grid inx for now is a dummy
30 | spec_group = igc_specs[cfg.pcv.grid_inx]
31 | self.num_groups = cfg.pcv.num_groups
32 | self.centroid_mode = cfg.pcv.centroid
33 | self.raw_spec = spec_group['base']
34 | _, self.grid_spec = Snake.flesh_out_grid_spec(self.raw_spec)
35 | self.grid_group = flesh_out_spec(spec_group)
36 | self._vote_mask = self.grid_group['base']
37 | self._ballot_module = None
38 | self.coalesce_thresh = list(self.grid_group['pyramid'].keys())
39 |
40 | @property
41 | def ballot_module(self):
42 |         # instantiate on demand to prevent train-time data loading workers
43 |         # from holding on to GPU memory
44 | if self._ballot_module is None:
45 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda()
46 | return self._ballot_module
47 |
48 | # 1 for bull's eye center, 1 for abstain vote
49 | @property
50 | def num_bins(self):
51 | return len(self.grid_spec)
52 |
53 | @property
54 | def num_votes(self):
55 | return 1 + self.num_bins
56 |
57 | @property
58 | def vote_mask(self):
59 | return self._vote_mask
60 |
61 | @property
62 | def query_mask(self):
63 | return super().query_mask
64 |
65 | def centroid_from_ins_mask(self, ins_mask):
66 | return super().centroid_from_ins_mask(ins_mask)
67 |
68 | def discrete_vote_inx_from_offset(self, offset):
69 | return self._discretize_offset(self.vote_mask, offset)
70 |
71 | def tensorized_vote_from_offset(self, offset):
72 | """
73 | Args:
74 | offset: [N, 2] array of offset towards each pixel's own center,
75 | Each row is filled with (x, y) pair, not (y, x)!
76 | Returns:
77 | vote_tsr: [N, num_votes] of bool tsr where 0/1 denotes gt entries.
78 | """
79 | # dispatch to the proper grid
80 | base_mask = self.grid_group['base']
81 | radius = (len(base_mask) - 1) // 2
82 | max_offset = min(radius, np.abs(offset).max())
83 | inx = bisect_left(self.coalesce_thresh, max_offset)
84 | key = self.coalesce_thresh[inx]
85 | vote_mask = self.grid_group['pyramid'][key]
86 |
87 |         tsr = np.zeros((len(offset), self.num_votes), dtype=bool)  # [N, num_votes]
88 | gt_indices = self._discretize_offset(vote_mask, offset) # [N, ]
89 | for inx in np.unique(gt_indices):
90 | if inx == -1:
91 | continue
92 | entries = np.unique(base_mask[vote_mask == inx])
93 | inds = np.where(gt_indices == inx)[0]
94 | tsr[inds.reshape(-1, 1), entries.reshape(1, -1)] = True
95 | return tsr
96 |
97 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
98 | # make the meta data actually required explicit!!
99 | if self.num_groups == 1:
100 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
101 | else:
102 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
103 | pan_mask, meta = mfv.infer_panoptic_mask()
104 | return pan_mask, meta
105 |
--------------------------------------------------------------------------------
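The pyramid dispatch in `tensorized_vote_from_offset` routes larger offsets to coarser grids. With `igc_specs[0]` the radii keys are 4, 13, 40, 121; a standalone sketch of the `bisect_left` lookup:

```python
from bisect import bisect_left
import numpy as np

coalesce_thresh = [4, 13, 40, 121]  # (diam - 1) // 2 for diam 9, 27, 81, 243
offset = np.array([[30, -2]])       # (x, y) offset pairs
max_offset = min(121, np.abs(offset).max())
key = coalesce_thresh[bisect_left(coalesce_thresh, max_offset)]
print(key)  # 40 -> this offset is voted on the [(9, 4), (27, 3)] grid
```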
/README.md:
--------------------------------------------------------------------------------
1 | # Pixel Consensus Voting (CVPR 2020)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | The core of our approach, Pixel Consensus Voting, is a framework for instance segmentation based on the Generalized Hough transform. Pixels cast discretized, probabilistic votes for the likely regions that contain instance centroids. At the detected peaks that emerge in the voting heatmap, backprojection is applied to collect pixels and produce instance masks. Unlike a sliding window detector that densely enumerates object proposals, our method detects instances as a result of the consensus among pixel-wise votes. We implement vote aggregation and backprojection using native operators of a convolutional neural network. The discretization of centroid voting reduces the training of instance segmentation to pixel labeling, analogous and complementary to FCN-style semantic segmentation, leading to an efficient and unified architecture that jointly models things and stuff. We demonstrate the effectiveness of our pipeline on COCO and Cityscapes Panoptic Segmentation and obtain competitive results.
10 |
11 | ## Quick Intro
12 | - The codebase contains the essential ingredients of PCV, including various spatial discretization schemes and convolutional backprojection inference. The network backbone is a simple FPN on ResNet.
13 | - Visualizer 1 ([vis.py](src/vis.py)): loads a single image into a dynamic, interactive interface that lets users click on pixels to inspect model predictions. It is built on the matplotlib interactive API and jupyter widgets (which are React under the hood).
14 | - Visualizer 2 ([pan_vis.py](src/pan_vis.py)): a global inspector that takes panoptic segmentation predictions and displays predicted segments against ground truth. Useful for tracking down which images cause the most serious errors, and how.
15 |
16 | - The core of PCV is contained in [src/pcv](src/pcv). The results reported in the paper use [src/pcv/pcv_basic](src/pcv/pcv_basic.py). There are also a few modification ideas that didn't work out, e.g. "inner grid collapse" ([src/pcv/pcv_igc](src/pcv/pcv_igc.py)), erasing boundary loss ([src/pcv/pcv_boundless](src/pcv/pcv_boundless.py)), and smoothened gt assignment ([src/pcv/pcv_smooth](src/pcv/pcv_smooth.py)).
17 | - The deconv voting filter weight initialization is in [src/pcv/components/ballot.py](src/pcv/components/ballot.py). Different deconv discretization schemes can be found in [src/pcv/components/grid_specs.py](src/pcv/components/grid_specs.py). [src/pcv/components/snake.py](src/pcv/components/snake.py) manages the generation of the snake grid on which PCV operates.
18 |
19 | - The backprojection code is in [src/pcv/inference/mask_from_vote.py](src/pcv/inference/mask_from_vote.py). Since this is a non-standard procedure that convolves a filter to do equality comparison, I implemented a simple conv using advanced indexing; see the function [src/pcv/inference/mask_from_vote.py:unroll_img_inds](src/pcv/inference/mask_from_vote.py#L110-L119). As a fun side project, I am thinking about rewriting the backprojection in Julia and generating GPU code (PTX) directly through LLVM, so that we don't have to deal with CUDA kernels that are hard to maintain.
20 | - The main entry points are [run.py](run.py) and [src/entry.py](src/entry.py).
21 |
22 |
23 | ## Getting Started
24 |
25 | ### Dependencies
26 | - pytorch==1.4.0
27 | - fabric (personal toolkit that needs to be re-factored in)
28 | - pycocotools
29 |
30 | ~~~bash
31 | python run.py -c train
32 | python run.py -c evaluate
33 | ~~~
34 | These commands run the default PCV configuration reported in Table 3 of the paper.
35 |
36 |
37 | ## Bibtex
38 |
39 | ```bibtex
40 | @inproceedings{pcv2020,
41 | title={Pixel consensus voting for panoptic segmentation},
42 | author={Wang, Haochen and Luo, Ruotian and Maire, Michael and Shakhnarovich, Greg},
43 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
44 | pages={9464--9473},
45 | year={2020}
46 | }
47 | ```
48 |
--------------------------------------------------------------------------------
/src/pcv/components/ballot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from .snake import Snake
6 |
7 |
8 | class Ballot(nn.Module):
9 | def __init__(self, spec, num_groups):
10 | super().__init__()
11 | self.num_groups = num_groups
12 | votes = []
13 | smears = []
14 |
15 | splits = []
16 | acc_size = spec[0][0] # inner most size
17 | for i, (size, num_rounds) in enumerate(spec):
18 | inner_blocks = acc_size // size
19 | total_blocks = inner_blocks + 2 * num_rounds
20 | acc_size += 2 * num_rounds * size
21 |
22 | prepend = (i == 0)
23 | num_in_chnls = (total_blocks ** 2 - inner_blocks ** 2) + int(prepend)
24 | num_in_chnls *= num_groups
25 | num_out_chnls = 1 * num_groups
26 | splits.append(num_in_chnls)
27 | deconv_vote = nn.ConvTranspose2d(
28 | in_channels=num_in_chnls, out_channels=num_out_chnls,
29 | kernel_size=total_blocks, stride=1,
30 | padding=(total_blocks - 1) // 2 * size, dilation=size,
31 | groups=num_groups, bias=False
32 | )
33 | # each group [ num_chnl, 1, 3, 3 ] -> [num_groups * num_chnl, 1, 3, 3]
34 | throw_away = 0 if prepend else inner_blocks
35 | weight = torch.cat([
36 |             get_voting_deconv_kernel_weight(
37 | side=total_blocks, throw_away=throw_away
38 | )
39 | for _ in range(num_groups)
40 | ], dim=0)
41 | deconv_vote.weight.data.copy_(weight)
42 |
43 | votes.append(deconv_vote)
44 | smear_ker = nn.AvgPool2d(
45 | # in_channels=num_out_chnls, out_channels=num_out_chnls,
46 | kernel_size=size, stride=1, padding=int( (size - 1) / 2 )
47 | )
48 | # smear_ker = nn.ConvTranspose2d(
49 | # in_channels=num_out_chnls, out_channels=num_out_chnls,
50 | # kernel_size=size, stride=1, padding=int( (size - 1) / 2 ),
51 | # groups=num_groups, bias=False
52 | # )
53 | # smear_ker.weight.data.fill_(1 / (size ** 2 ))
54 | smears.append(smear_ker)
55 |
56 | self.splits = splits
57 | self.votes = nn.ModuleList(votes)
58 | self.smears = nn.ModuleList(smears)
59 |
60 | @torch.no_grad()
61 | def forward(self, x):
62 | num_groups = self.num_groups
63 | splitted = x.split(self.splits, dim=1)
64 | assert len(splitted) == len(self.votes)
65 | output = []
66 | for i in range(len(splitted)):
67 | # if i == (len(splitted) - 1):
68 | # continue
69 | x = splitted[i]
70 | if num_groups > 1: # painful lesson
71 | _, C, H, W = x.shape
72 | x = x.reshape(-1, C // num_groups, num_groups, H, W)
73 | x = x.transpose(1, 2)
74 | x = x.reshape(-1, C, H, W)
75 | x = self.votes[i](x)
76 | x = self.smears[i](x)
77 | output.append(x)
78 | return sum(output)
79 |
80 |
81 | def get_voting_deconv_kernel_weight(side, throw_away, return_tsr=True):
82 | """
83 | The logic is neat; it makes use of the fact that negatives will be indexed
84 | from the last channel, and one simply needs to throw those away
85 | """
86 | assert throw_away <= side
87 | throw_away = throw_away ** 2
88 | rounds = (side - 1) // 2
89 | spatial_inds = Snake.paint_trail_mask(
90 | *Snake.flesh_out_grid_spec([[1, rounds]])
91 | ) - throw_away
92 | weight = np.zeros(shape=(side, side, side ** 2))
93 | dim_0_inds, dim_1_inds = np.ix_(range(side), range(side))
94 | weight[dim_0_inds, dim_1_inds, spatial_inds] = 1
95 | if throw_away > 0:
96 | weight = weight[:, :, :-throw_away]
97 | if not return_tsr:
98 | return weight
99 | else:
100 | weight = np.expand_dims(weight.transpose(2, 0, 1), axis=1)
101 | kernel = torch.as_tensor(weight).float()
102 | return kernel
103 |
--------------------------------------------------------------------------------
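A small sanity check of the one-hot deconv weights, assuming the function above is importable: with `throw_away=0`, each spatial position of the kernel activates exactly one input channel.

```python
import numpy as np

w = get_voting_deconv_kernel_weight(side=3, throw_away=0, return_tsr=False)
print(w.shape)                      # (3, 3, 9)
assert (w.sum(axis=-1) == 1).all()  # one-hot across channels per position
```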
/src/pcv/gaussian_smooth/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from ipywidgets import Output
4 | from matplotlib.patches import Circle
5 |
6 | from panoptic.pcv.gaussian_smooth.prob_tsr import MakeProbTsr
7 | from panoptic.pcv.components.snake import Snake
8 |
9 | from panoptic.vis import Visualizer as BaseVisualizer
10 |
11 |
12 | class Plot():
13 | def __init__(self, ax, bin_center_yx, vote_mask, spatial_prob):
14 | self.ax = ax
15 | self.bin_center_yx = bin_center_yx
16 | self.vote_mask = vote_mask
17 | self.spatial_prob = spatial_prob
18 |
19 | # self.pressed_xy = None
20 | self.dot = None
21 | self.texts = None
22 | self.init_artists()
23 | self.render_visual()
24 |
25 | def init_artists(self):
26 | if self.dot is not None:
27 | self.dot.remove()
28 |
29 | if self.texts is not None:
30 | assert isinstance(self.texts, (tuple, list))
31 | for elem in self.texts:
32 | elem.remove()
33 |
34 | self.dot = None
35 | self.texts = None
36 |
37 | def render_visual(self):
38 | self.ax.imshow(self.vote_mask)
39 |
40 | def press_coord(self, x, y, button):
41 | del button # ignoring button for now
42 | # self.pressed_xy = x, y
43 | self.init_artists()
44 | self.render_single_dot(x, y)
45 | self.render_prob_dist(x, y)
46 |
47 | def render_prob_dist(self, x, y):
48 | thresh = 0
49 | dist = self.spatial_prob[y, x]
50 | inds = np.where(dist > thresh)[0]
51 | probs = dist[inds] * 100
52 | # print(probs)
53 | bin_centers = self.bin_center_yx[inds]
54 |
55 | acc = []
56 | for cen, p in zip(bin_centers, probs):
57 | y, x = cen
58 | _a = self.ax.text(
59 | x, y, s='{:.2f}'.format(p), fontsize='small', color='r'
60 | )
61 | acc.append(_a)
62 | self.texts = acc
63 |
64 | def query_coord(self, x, y, button):
65 | pass
66 |
67 | def motion_coord(self, x, y):
68 | self.press_coord(x, y, None)
69 |
70 | def render_single_dot(self, x, y):
71 | cir = Circle((x, y), radius=0.5, color='white')
72 | self.ax.add_patch(cir)
73 | self.dot = cir
74 |
75 |
76 | class Visualizer(BaseVisualizer):
77 | def __init__(self):
78 |         spec = [ # size 243, 225 bins
79 | (3, 4), (9, 3), (27, 3)
80 | ]
81 | # spec = [
82 | # (3, 3), (7, 3), (21, 4)
83 | # ]
84 | diam, grid_spec = Snake.flesh_out_grid_spec(spec)
85 | vote_mask = Snake.paint_trail_mask(diam, grid_spec)
86 | maker = MakeProbTsr(spec, diam, grid_spec, vote_mask)
87 | spatial_prob = maker.compute_voting_prob_tsr()
88 |
89 | self.vote_mask = vote_mask
90 | self.spatial_prob = spatial_prob
91 |
92 | radius = (diam - 1) // 2
93 | center = np.array((radius, radius))
94 | self.bin_center_yx = grid_spec[:, :2] + center
95 |
96 | self.output_widget = Output()
97 | self.init_state()
98 | self.pressed = False
99 | np.set_printoptions(
100 | formatter={'float': lambda x: "{:.3f}".format(x)}
101 | )
102 |
103 | def vis(self):
104 | fig = plt.figure(figsize=(10, 10), constrained_layout=True)
105 | self.fig = fig
106 | self.canvas = fig.canvas
107 | self.plots = dict()
108 |
109 | key = 'spatial prob dist'
110 | ax = fig.add_subplot(111)
111 | ax.set_title(key)
112 | self.plots[key] = Plot(
113 | ax, self.bin_center_yx, self.vote_mask, self.spatial_prob
114 | )
115 | self.connect()
116 |
117 |
118 | def test():
119 | # spec = [ # 243, 233 bins
120 | # (1, 1), (3, 4), (9, 3), (27, 3)
121 | # ]
122 |     spec = [ # size 27, 17 bins
123 | (3, 1), (9, 1)
124 | ]
125 | diam, grid_spec = Snake.flesh_out_grid_spec(spec)
126 | vote_mask = Snake.paint_trail_mask(diam, grid_spec)
127 | maker = MakeProbTsr(spec, diam, grid_spec, vote_mask)
128 |
129 |
130 | if __name__ == "__main__":
131 | test()
132 |
--------------------------------------------------------------------------------
/src/models/components/FPN.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def conv_gn_relu(in_C, out_C, kernel_size, use_relu=True):
8 | assert kernel_size in (3, 1)
9 | pad = (kernel_size - 1) // 2
10 | num_groups = out_C // 16 # note this is hardcoded
11 | module = [
12 | nn.Conv2d(in_C, out_C, kernel_size, padding=pad, bias=False),
13 | # nn.GroupNorm(num_groups, num_channels=out_C),
14 | nn.BatchNorm2d(out_C)
15 | ]
16 | if use_relu:
17 | module.append(nn.ReLU(inplace=True))
18 | return nn.Sequential(*module)
19 |
20 |
21 | class FPN(nn.Module):
22 | def __init__(self, resnet_extractor, FPN_distill_C):
23 | super().__init__()
24 | self.strides = [32, 16, 8, 4]
25 | self.dims = [2048, 1024, 512, 256] # assume resnet 50 and above
26 | self.FPN_feature_C = 256
27 | self.FPN_distill_C = FPN_distill_C
28 |
29 | self.backbone = resnet_extractor
30 | self.FPN_create_modules = nn.ModuleList()
31 | self.FPN_distill_modules = nn.ModuleList()
32 |
33 | for in_dim, stride in zip(self.dims, self.strides):
34 | self.FPN_create_modules.append(
35 | nn.ModuleDict({
36 | 'lateral': conv_gn_relu(
37 | in_dim, self.FPN_feature_C,
38 | kernel_size=1, use_relu=False
39 | ),
40 | 'refine': conv_gn_relu(
41 | self.FPN_feature_C, self.FPN_feature_C,
42 | kernel_size=3, use_relu=False
43 | )
44 | })
45 | )
46 |
47 | for stride in self.strides:
48 | self.FPN_distill_modules.append(
49 | self.get_distill_module(
50 | stride / 4, self.FPN_feature_C, self.FPN_distill_C
51 | )
52 | )
53 |
54 | @staticmethod
55 | def get_distill_module(upsample_ratio, in_C, out_C):
56 | levels = math.log(upsample_ratio, 2)
57 | assert levels.is_integer()
58 | levels = int(levels)
59 | if levels == 0:
60 | return conv_gn_relu(in_C, out_C, kernel_size=3)
61 |
62 | module = []
63 | for _ in range(levels):
64 | module.append(
65 | # maybe use ordered dict?
66 | nn.Sequential(
67 | conv_gn_relu(in_C, out_C, kernel_size=3),
68 | nn.Upsample(
69 | scale_factor=2, mode='bilinear', align_corners=False)
70 | )
71 | )
72 | in_C = out_C
73 | return nn.Sequential(*module)
74 |
75 | def collect_resnet_features(self, x):
76 | acc = []
77 | for layer in self.backbone.layers:
78 | x = layer(x)
79 | acc.append(x)
80 | acc = acc[1:] # throw away layer0 output
81 | return acc[::-1] # high level features first
82 |
83 | def create_FPN(self, res_features):
84 | FPN_features = []
85 | prev = None
86 | for tsr, curr_module in zip(res_features, self.FPN_create_modules):
87 | tsr = curr_module['lateral'](tsr)
88 | if prev is not None:
89 | prev = F.interpolate(prev, scale_factor=2, mode='nearest')
90 | tsr = tsr + prev
91 | prev = tsr
92 | refined = curr_module['refine'](tsr)
93 | FPN_features.append(refined)
94 | return FPN_features
95 |
96 | def distill_FPN(self, FPN_features):
97 | acc = []
98 | for tsr, curr_module in zip(FPN_features, self.FPN_distill_modules):
99 | tsr = curr_module(tsr)
100 | acc.append(tsr)
101 | return sum(acc)
102 |
103 | def forward(self, x):
104 | res_features = self.collect_resnet_features(x)
105 | FPN_features = self.create_FPN(res_features)
106 | final = self.distill_FPN(FPN_features)
107 | return final
108 |
109 |
110 | if __name__ == '__main__':
111 | from panoptic.models.components.resnet import ResNetFeatureExtractor
112 | from torchvision.models import resnet50
113 | extractor = ResNetFeatureExtractor(resnet50(pretrained=False))
114 |     fpn = FPN(extractor, FPN_distill_C=128)  # distill width is arbitrary for this smoke test
115 |     x = torch.rand(size=(1, 3, 64, 64))
116 |     with torch.no_grad():
117 |         output = fpn(x)
118 |     print(output.shape)
119 |
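A quick shape check of the distill arithmetic above (a sketch; the 256/128
channel widths are arbitrary and the `panoptic.*` import path mirrors the
__main__ block):

    import torch
    from panoptic.models.components.FPN import FPN

    # every stride-s feature map is distilled to stride 4, i.e. upsampled by s / 4
    for stride, hw in [(32, 8), (16, 16), (8, 32), (4, 64)]:
        m = FPN.get_distill_module(stride / 4, in_C=256, out_C=128)
        y = m(torch.rand(1, 256, hw, hw))
        assert y.shape == (1, 128, 64, 64)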
--------------------------------------------------------------------------------
/src/pcv/pcv_igc_boundless.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_left
2 | import numpy as np
3 |
4 | from .pcv import PCV_base
5 | from .components.snake import Snake
6 | from .components.ballot import Ballot
7 | from .components.grid_specs import igc_specs
8 | from .inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
9 |
10 | from .. import cfg
11 |
12 |
13 | def _flesh_out_grid(spec):
14 | '''used only by flesh_out_spec'''
15 | field_diam, grid_spec = Snake.flesh_out_grid_spec(spec)
16 | full_mask = Snake.paint_trail_mask(field_diam, grid_spec)
17 | boundless_mask = Snake.paint_bound_ignore_trail_mask(field_diam, grid_spec)
18 | return {
19 | 'full': full_mask,
20 | 'boundless': boundless_mask
21 | }
22 |
23 |
24 | def flesh_out_spec(spec_group):
25 | ret = {'base': None, 'pyramid': dict()}
26 | ret['base'] = _flesh_out_grid(spec_group['base'])
27 | for k, v in spec_group['pyramid'].items():
28 | ret['pyramid'][k] = _flesh_out_grid(v)
29 | return ret
30 |
31 |
32 | class PCV_IGC_Boundless(PCV_base):
33 | def __init__(self):
34 |         # cfg.pcv.grid_inx is a dummy selector for now
35 | spec_group = igc_specs[cfg.pcv.grid_inx]
36 | self.num_groups = cfg.pcv.num_groups
37 | self.centroid_mode = cfg.pcv.centroid
38 | self.raw_spec = spec_group['base']
39 | _, self.grid_spec = Snake.flesh_out_grid_spec(self.raw_spec)
40 | self.mask_group = flesh_out_spec(spec_group)
41 | self._vote_mask = self.mask_group['base']['full']
42 | self._ballot_module = None
43 | self.coalesce_thresh = list(self.mask_group['pyramid'].keys())
44 |
45 | @property
46 | def ballot_module(self):
47 |         # instantiate on demand to keep train-time data loading workers
48 |         # from holding on to GPU memory
49 | if self._ballot_module is None:
50 | self._ballot_module = Ballot(self.raw_spec, self.num_groups).cuda()
51 | return self._ballot_module
52 |
53 |     # num_bins counts the spatial bins; num_votes adds 1 for the abstain vote
54 | @property
55 | def num_bins(self):
56 | return len(self.grid_spec)
57 |
58 | @property
59 | def num_votes(self):
60 | return 1 + self.num_bins
61 |
62 | @property
63 | def vote_mask(self):
64 | return self._vote_mask
65 |
66 | @property
67 | def query_mask(self):
68 | """
69 |         The vote mask flipped inside out: entry (y, x) is the bin a pixel there uses to vote for the center
70 | """
71 | diam = len(self.vote_mask)
72 | radius = (diam - 1) // 2
73 | center = (radius, radius)
74 | mask_shape = self.vote_mask.shape
75 | offset_grid = np.indices(mask_shape).transpose(1, 2, 0)[..., ::-1]
76 | offsets = center - offset_grid
77 | # allegiance = self.discrete_vote_inx_from_offset(
78 | # offsets.reshape(-1, 2)
79 | # ).reshape(mask_shape)
80 | allegiance = self._discretize_offset(
81 | self.vote_mask, offsets.reshape(-1, 2)
82 | ).reshape(mask_shape)
83 | return allegiance
84 |
85 | def centroid_from_ins_mask(self, ins_mask):
86 | return super().centroid_from_ins_mask(ins_mask)
87 |
88 | def discrete_vote_inx_from_offset(self, offset):
89 | base_boundless_mask = self.mask_group['base']['boundless']
90 | return self._discretize_offset(base_boundless_mask, offset)
91 |
92 | def tensorized_vote_from_offset(self, offset):
93 | """
94 | Args:
95 | offset: [N, 2] array of offset towards each pixel's own center,
96 | Each row is filled with (x, y) pair, not (y, x)!
97 | Returns:
98 | vote_tsr: [N, num_votes] of bool tsr where 0/1 denotes gt entries.
99 | """
100 | # dispatch to the proper grid
101 | base_mask = self.mask_group['base']['boundless']
102 | radius = (len(base_mask) - 1) // 2
103 | max_offset = min(radius, np.abs(offset).max())
104 | inx = bisect_left(self.coalesce_thresh, max_offset)
105 | key = self.coalesce_thresh[inx]
106 | vote_mask = self.mask_group['pyramid'][key]['boundless']
107 |
108 |         tsr = np.zeros((len(offset), self.num_votes), dtype=bool)  # [N, num_votes]
109 | gt_indices = self._discretize_offset(vote_mask, offset) # [N, ]
110 | for inx in np.unique(gt_indices):
111 | if inx == -1:
112 | continue
113 | entries = np.unique(base_mask[vote_mask == inx])
114 | inds = np.where(gt_indices == inx)[0]
115 | tsr[inds.reshape(-1, 1), entries.reshape(1, -1)] = True
116 | return tsr
117 |
118 | def mask_from_sem_vote_tsr(self, dset_meta, sem_pred, vote_pred):
119 | # make the meta data actually required explicit!!
120 | if self.num_groups == 1:
121 | mfv = MaskFromVote(dset_meta, self, sem_pred, vote_pred)
122 | else:
123 | mfv = MFV_CatSeparate(dset_meta, self, sem_pred, vote_pred)
124 | pan_mask, meta = mfv.infer_panoptic_mask()
125 | return pan_mask, meta
126 |
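The pyramid dispatch in tensorized_vote_from_offset can be exercised in
isolation (a sketch; the threshold keys here are made up):

    from bisect import bisect_left
    import numpy as np

    coalesce_thresh = [13, 40, 121]        # hypothetical pyramid keys
    offset = np.array([[3, -5], [60, 2]])
    max_offset = np.abs(offset).max()      # 60
    inx = bisect_left(coalesce_thresh, max_offset)
    print(coalesce_thresh[inx])            # 121: first level whose threshold covers 60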
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The assembled models live in the sibling modules (e.g. PFPN_pan.py)
3 | The rest of the modules contain helpers and modular components
4 | """
5 | import torch
6 | import torch.nn
7 | from panoptic.utils import dynamic_load_py_object
8 |
9 |
10 | def get_model_module(model_name):
11 | return dynamic_load_py_object(__name__, model_name)
12 |
13 |
14 | def convert_inplace_sync_batchnorm(module, process_group=None):
15 |     r"""Helper function to convert `torch.nn.BatchNormND` layers in the model
16 |     to `inplace_abn.InPlaceABNSync` layers.
17 |     Args:
18 |         module (nn.Module): containing module
19 |         process_group (optional): process group to scope synchronization,
20 |             default is the whole world
21 |     Returns:
22 |         The original module with the converted `InPlaceABNSync` layers
23 |     Example::
24 |         >>> # Network with nn.BatchNorm layer
25 |         >>> module = torch.nn.Sequential(
26 |         >>>     torch.nn.Linear(20, 100),
27 |         >>>     torch.nn.BatchNorm1d(100)
28 |         >>> ).cuda()
29 |         >>> # creating process group (optional)
30 |         >>> # process_ids is a list of int identifying rank ids.
31 |         >>> process_group = torch.distributed.new_group(process_ids)
32 |         >>> sync_bn_module = convert_inplace_sync_batchnorm(module, process_group)
33 |     """
34 | from inplace_abn import InPlaceABNSync
35 | module_output = module
36 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
37 | module_output = InPlaceABNSync(module.num_features,
38 | module.eps, module.momentum,
39 | module.affine,
40 | activation='identity',
41 | group=process_group)
42 | if module.affine:
43 | module_output.weight.data = module.weight.data.clone().detach()
44 | module_output.bias.data = module.bias.data.clone().detach()
45 |             # keep requires_grad unchanged
46 | module_output.weight.requires_grad = module.weight.requires_grad
47 | module_output.bias.requires_grad = module.bias.requires_grad
48 | module_output.running_mean = module.running_mean
49 | module_output.running_var = module.running_var
50 |     if isinstance(module, torch.nn.ReLU):
51 |         module_output = torch.nn.ReLU()  # swap in an out-of-place ReLU; in-place ops clash with InPlaceABN
52 | for name, child in module.named_children():
53 | module_output.add_module(name, convert_inplace_sync_batchnorm(child, process_group))
54 | del module
55 | return module_output
56 |
57 | def convert_naive_sync_batchnorm(module, process_group=None):
58 | from detectron2.layers import NaiveSyncBatchNorm
59 |     return convert_xbatchnorm(module, NaiveSyncBatchNorm, process_group=process_group)
60 |
61 | def convert_apex_sync_batchnorm(module, process_group=None):
62 | from apex.parallel import SyncBatchNorm
63 |     return convert_xbatchnorm(module, SyncBatchNorm, process_group=process_group)
64 |
65 |
66 | def convert_xbatchnorm(module, bn_module, process_group=None):
67 |     r"""Helper function to convert `torch.nn.BatchNormND` layers in the model
68 |     to the given sync batchnorm class `bn_module`.
69 |     Args:
70 |         module (nn.Module): containing module
71 |         bn_module: sync batchnorm class to instantiate in place of BatchNorm
72 |         process_group (optional): process group to scope synchronization,
73 |             default is the whole world
74 |     Returns:
75 |         The original module with the converted sync batchnorm layers
76 |     Example::
77 |         >>> module = torch.nn.Sequential(
78 |         >>>     torch.nn.Linear(20, 100),
79 |         >>>     torch.nn.BatchNorm1d(100)
80 |         >>> ).cuda()
81 |         >>> # creating process group (optional)
82 |         >>> # process_ids is a list of int identifying rank ids.
83 |         >>> process_group = torch.distributed.new_group(process_ids)
84 |         >>> sync_bn_module = convert_xbatchnorm(module, NaiveSyncBatchNorm, process_group)
85 |     """
86 | module_output = module
87 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
88 | module_output = bn_module(module.num_features,
89 | module.eps, module.momentum,
90 | module.affine,
91 | module.track_running_stats,
92 | process_group)
93 | if module.affine:
94 | module_output.weight.data = module.weight.data.clone().detach()
95 | module_output.bias.data = module.bias.data.clone().detach()
96 | # keep reuqires_grad unchanged
97 | module_output.weight.requires_grad = module.weight.requires_grad
98 | module_output.bias.requires_grad = module.bias.requires_grad
99 | module_output.running_mean = module.running_mean
100 | module_output.running_var = module.running_var
101 | module_output.num_batches_tracked = module.num_batches_tracked
102 | for name, child in module.named_children():
103 | module_output.add_module(name, convert_xbatchnorm(child, bn_module, process_group))
104 | del module
105 | return module_output
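A minimal usage sketch for these converters (assumes apex is installed, a CUDA
device is available, and torch.distributed has been initialized; the
`panoptic.models` path follows the imports used elsewhere in this repo):

    import torch
    from panoptic.models import convert_apex_sync_batchnorm

    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.BatchNorm2d(16),
        torch.nn.ReLU(),
    ).cuda()
    net = convert_apex_sync_batchnorm(net)  # BatchNorm2d -> apex SyncBatchNorm
    # then wrap net in DistributedDataParallel as usual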
--------------------------------------------------------------------------------
/src/pcv/components/snake.py:
--------------------------------------------------------------------------------
1 | """
2 | This module manages the generation of snake grid on which pcv operates.
3 |
4 | It is intended to be maximally flexible by taking in simple parameters to
5 | achieve various grid configurations for downstream experimentation.
6 | """
7 | import numpy as np
8 |
9 |
10 | class Snake():
11 | def __init__(self):
12 | pass
13 |
14 | @staticmethod
15 | def flesh_out_grid_spec(raw_spec):
16 | """
17 | Args:
18 | raw_spec: [N, 2] each row (size, num_rounds)
19 | """
20 | if not isinstance(raw_spec, np.ndarray):
21 | raw_spec = np.array(raw_spec)
22 | shape = raw_spec.shape
23 | assert len(shape) == 2 and shape[0] > 0 and shape[1] == 2
24 | size = raw_spec[0][0]
25 | trail = [ np.array([0, 0, (size - 1) // 2]).reshape(1, -1), ]
26 | field_diam = size
27 | for size, num_rounds in raw_spec:
28 | for _ in range(num_rounds):
29 | field_diam, _round_trail = Snake.ring_walk(field_diam, size)
30 | trail.append(_round_trail)
31 | trail = np.concatenate(trail, axis=0)
32 | return field_diam, trail
33 |
34 | @staticmethod
35 | def ring_walk(field_diam, body_diam):
36 | assert body_diam > 0 and body_diam % 2 == 1
37 | body_radius = (body_diam - 1) // 2
38 |
39 | assert field_diam > 0 and field_diam % 2 == 1
40 | field_radius = (field_diam - 1) // 2
41 |
42 | assert field_diam % body_diam == 0
43 |
44 | ext_diam = field_diam + 2 * body_diam
45 | ext_radius = field_radius + body_diam
46 | assert ext_diam == ext_radius * 2 + 1
47 |
48 | j = 1 + field_radius + body_radius
49 | # each of the corner coord is the offset from field center
50 | # anticlockwise SE -> NE -> NW -> SW ->
51 | corner_centers = np.array([(+j, +j), (-j, +j), (-j, -j), (+j, -j)])
52 | directs = np.array([(-1, 0), (0, -1), (+1, 0), (0, +1)])
53 | trail = []
54 | num_tiles = ext_diam // body_diam
55 | for corn, dirc in zip(corner_centers, directs):
56 | segment = [corn + i * dirc * body_diam for i in range(num_tiles - 1)]
57 | trail.extend(segment)
58 | trail = np.array(trail) # [N, 2] each row is a center-based offset
59 |         radii = np.array([body_radius] * len(trail)).reshape(-1, 1)
60 |         trail = np.concatenate([trail, radii], axis=1)
61 | return ext_diam, trail # [N, 3]: (offset_y, offset_x, radius)
62 |
63 | @staticmethod
64 | def paint_trail_mask(field_diam, trail, tiling=True):
65 | """
66 | Args:
67 | trail: [N, 3] each row (offset_y, offset_x, radius)
68 | """
69 | assert field_diam > 0 and field_diam % 2 == 1
70 | field_radius = (field_diam - 1) // 2
71 | CEN = np.array((field_radius, field_radius))
72 | trail = trail.copy()
73 | trail[:, :2] += CEN
74 | canvas = -1 * np.ones((field_diam, field_diam), dtype=int)
75 | for i, walk in enumerate(trail):
76 | y, x, r = walk
77 | if tiling:
78 | y, x = y - r, x - r
79 | d = 2 * r + 1
80 | canvas[y: y + d, x: x + d] = i
81 | else:
82 | canvas[y, x] = i
83 | return canvas
84 |
85 | @staticmethod
86 | def paint_bound_ignore_trail_mask(field_diam, trail):
87 | """
88 | Args:
89 | trail: [N, 3] each row (offset_y, offset_x, radius)
90 | """
91 | assert field_diam > 0 and field_diam % 2 == 1
92 | field_radius = (field_diam - 1) // 2
93 | CEN = np.array((field_radius, field_radius))
94 | trail = trail.copy()
95 | trail[:, :2] += CEN
96 | canvas = -1 * np.ones((field_diam, field_diam), dtype=int)
97 | for i, walk in enumerate(trail):
98 | y, x, r = walk
99 | d = 2 * r + 1
100 | boundary_ignore = int(d * 0.12)
101 | # if d > 1: # at least cut 1 unless the grid is of size 1
102 | # boundary_ignore = max(boundary_ignore, 1)
103 | y, x = y - r + boundary_ignore, x - r + boundary_ignore
104 | d = d - 2 * boundary_ignore
105 | canvas[y: y + d, x: x + d] = i
106 | return canvas
107 |
108 | @staticmethod
109 | def vote_channel_splits(raw_spec):
110 | splits = []
111 |         acc_size = raw_spec[0][0]  # innermost size
112 | for i, (size, num_rounds) in enumerate(raw_spec):
113 | inner_blocks = acc_size // size
114 | total_blocks = inner_blocks + 2 * num_rounds
115 | acc_size += 2 * num_rounds * size
116 | prepend = (i == 0)
117 | num_chnls = (total_blocks ** 2 - inner_blocks ** 2) + int(prepend)
118 | splits.append(num_chnls)
119 | return splits
120 |
121 |
122 | if __name__ == '__main__':
123 | sample_spec = np.array([
124 | # (size, num_rounds)
125 | (1, 1),
126 | (3, 1),
127 | (9, 1)
128 | ])
129 | s = Snake()
130 | diam, trail = s.flesh_out_grid_spec(sample_spec)
131 | mask = s.paint_trail_mask(diam, trail, tiling=True)
132 | print(mask.shape)
133 |
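Working the arithmetic for the sample spec above (a sketch; the
`panoptic.pcv.components.snake` import path mirrors the repo's other imports):

    import numpy as np
    from panoptic.pcv.components.snake import Snake

    spec = np.array([(1, 1), (3, 1), (9, 1)])   # (tile_size, num_rounds) per level
    diam, trail = Snake.flesh_out_grid_spec(spec)
    print(diam)                                 # 27: the field grows 1 -> 3 -> 9 -> 27
    print(len(trail))                           # 25 bins: 1 center + 8 + 8 + 8 ring tiles
    print(Snake.vote_channel_splits(spec))      # [9, 8, 8]: the center is folded into level 0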
--------------------------------------------------------------------------------
/src/reporters.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 | import numpy as np
5 | from PIL import Image
6 | import torch.nn.functional as F
7 | from easydict import EasyDict as edict
8 | from .metric import PanMetric as Metric, PQMetric
9 | from panopticapi.utils import rgb2id, id2rgb
10 |
11 | from fabric.io import save_object
12 |
13 | _VALID_ORACLE_MODES = ('full', 'sem', 'vote')
14 |
15 |
16 | class BaseReporter():
17 | def __init__(self, infer_cfg, model, dset, output_root):
18 | self.output_root = output_root
19 |
20 |     def process(self, inx, model_outputs, inputs, dset):
21 | pass
22 |
23 | def generate_report(self):
24 | # return some report here
25 | pass
26 |
27 |
28 | class mIoU(BaseReporter):
29 | def __init__(self, infer_cfg, model, dset, output_root):
30 | del infer_cfg # not needed
31 | self.metric = Metric(
32 | model.dset_meta['num_classes'],
33 | model.pcv.num_votes, model.dset_meta['trainId_2_catName']
34 | )
35 | self.output_root = output_root
36 |
37 | def process(self, inx, model_outputs, inputs, dset):
38 | sem_pred, vote_pred = model_outputs
39 | sem_pred = F.interpolate(
40 | sem_pred, scale_factor=4, mode='nearest'
41 | )
42 | sem_pred, vote_pred = sem_pred.argmax(1), vote_pred.argmax(1)
43 | self.metric.update(sem_pred, vote_pred, *inputs[1:3])
44 |
45 | def generate_report(self):
46 | # return some report here
47 | return str(self.metric)
48 |
49 |
50 | class PQ_report(BaseReporter):
51 | def __init__(self, infer_cfg, model, dset, output_root):
52 | self.infer_cfg = edict(infer_cfg.copy())
53 | self.output_root = output_root
54 | self.metric = PQMetric(dset.meta)
55 | self.overall_pred_meta = {
56 | 'images': list(dset.imgs.values()),
57 | 'categories': list(dset.meta['cats'].values()),
58 | 'annotations': []
59 | }
60 | self.model = model
61 | self.oracle_mode = self.infer_cfg['oracle_mode']
62 | if self.oracle_mode is not None:
63 | assert self.oracle_mode in _VALID_ORACLE_MODES
64 |
65 |         os.makedirs(output_root, exist_ok=True)
66 | PRED_OUT_NAME = 'pred'
67 | self.pan_json_fname = osp.join(output_root, '{}.json'.format(PRED_OUT_NAME))
68 | self.pan_mask_dir = osp.join(output_root, PRED_OUT_NAME)
69 |
70 | def process(self, inx, model_outputs, inputs, dset):
71 |         from panoptic.entry import gt_tsr_res_reduction  # FIXME: local import; this helper should live somewhere shared
72 | sem_pd, vote_pd = model_outputs
73 | sem_pd, vote_pd = F.softmax(sem_pd, dim=1), F.softmax(vote_pd, dim=1)
74 |
75 | model = self.model
76 | oracle_mode = self.oracle_mode
77 | oracle_res = 4
78 |         imgMeta, segments_info, img, pan_gt_mask = dset.pan_getitem(
79 |             inx, apply_trans=False
80 |         )
81 |         if dset.transforms is not None:
82 |             _, trans_pan_gt = dset.transforms(img, pan_gt_mask)
83 | else:
84 | trans_pan_gt = pan_gt_mask.copy()
85 |
86 | pan_gt_ann = {
87 | 'image_id': imgMeta['id'],
88 |             # hack: only the image file_name is accessible here, so derive the mask name from it
89 | 'file_name': imgMeta['file_name'].split('.')[0] + '.png',
90 | 'segments_info': list(segments_info.values())
91 | }
92 |
93 | if oracle_mode is not None:
94 | sem_ora, vote_ora = gt_tsr_res_reduction(
95 | oracle_res, dset.gt_prod_handle,
96 | dset.meta, model.pcv, trans_pan_gt, segments_info
97 | )
98 | if oracle_mode == 'vote':
99 | pass # if using model sem pd, maintain stuff pred thresh
100 | else: # sem or full oracle, using gt sem pd, do not filter
101 | self.infer_cfg['stuff_pred_thresh'] = -1
102 |
103 | if oracle_mode == 'sem':
104 | sem_pd = sem_ora
105 | elif oracle_mode == 'vote':
106 | vote_pd = vote_ora
107 | else:
108 | sem_pd, vote_pd = sem_ora, vote_ora
109 |
110 | pan_pd_mask, pan_pd_ann = model.stitch_pan_mask(
111 | self.infer_cfg, sem_pd, vote_pd, pan_gt_mask.size
112 | )
113 | pan_pd_ann['image_id'] = pan_gt_ann['image_id']
114 | pan_pd_ann['file_name'] = pan_gt_ann['file_name']
115 | self.overall_pred_meta['annotations'].append(pan_pd_ann)
116 |
117 | self.metric.update(
118 | pan_gt_ann, rgb2id(np.array(pan_gt_mask)), pan_pd_ann, pan_pd_mask
119 | )
120 |
121 | pan_pd_mask = Image.fromarray(id2rgb(pan_pd_mask))
122 | fname = osp.join(self.pan_mask_dir, pan_pd_ann['file_name'])
123 | os.makedirs(osp.dirname(fname), exist_ok=True) # make region subdir
124 | pan_pd_mask.save(fname)
125 |
126 | def generate_report(self):
127 | with open(self.pan_json_fname, 'w') as f:
128 | json.dump(self.overall_pred_meta, f)
129 | save_object(
130 | self.metric.state_dict(), osp.join(self.output_root, 'score.pkl')
131 | )
132 | return self.metric
133 |
134 |
135 | reporter_modules = {
136 | 'mIoU': mIoU,
137 | 'pq': PQ_report
138 | }
139 |
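How a reporter is meant to be driven (a sketch; infer_cfg, model, dset and the
eval_results loop are stand-ins for the actual runner, which is not shown here):

    reporter_cls = reporter_modules['pq']
    reporter = reporter_cls(infer_cfg, model, dset, output_root='./out')
    for inx, (inputs, outputs) in enumerate(eval_results):
        reporter.process(inx, outputs, inputs, dset)
    print(reporter.generate_report())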
--------------------------------------------------------------------------------
/src/datasets/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | # import torchvision.transforms as tv_trans
5 | from torchvision.transforms import ColorJitter as tv_color_jitter
6 | import torchvision.transforms.functional as tv_f
7 | from PIL import Image
8 |
9 |
10 | def _pad_to_sizes(mask, tw, th):
11 | w, h = mask.size
12 | right = max(tw - w, 0)
13 | bottom = max(th - h, 0)
14 | mask = tv_f.pad(mask, fill=0, padding=(0, 0, right, bottom))
15 | return mask
16 |
17 |
18 | class RandomCrop():
19 | def __init__(self, size):
20 | if isinstance(size, (tuple, list)):
21 | assert len(size) == 2
22 | self.size = size
23 | else:
24 | self.size = [size, size]
25 |
26 | def __call__(self, img, mask):
27 | assert img.size == mask.size
28 | tw, th = self.size
29 |
30 | img = _pad_to_sizes(img, tw, th)
31 | mask = _pad_to_sizes(mask, tw, th)
32 | w, h = img.size
33 |
34 | x1 = random.randint(0, w - tw)
35 | y1 = random.randint(0, h - th)
36 | square = (x1, y1, x1 + tw, y1 + th)
37 | return img.crop(square), mask.crop(square)
38 |
39 |
40 | class RandomHorizontalFlip():
41 | def __init__(self, p=0.5):
42 | self.p = p
43 |
44 | def __call__(self, img, mask):
45 | if random.random() < self.p:
46 | return (
47 | img.transpose(Image.FLIP_LEFT_RIGHT),
48 | mask.transpose(Image.FLIP_LEFT_RIGHT),
49 | )
50 | return img, mask
51 |
52 |
53 | class Scale():
54 | def __init__(self, ratio=0.5):
55 | self.ratio = ratio
56 |
57 | def __call__(self, img, mask):
58 | w, h = img.size
59 | ratio = self.ratio
60 | target_size = (int(ratio * w), int(ratio * h))
61 | img = img.resize(target_size, Image.BILINEAR)
62 | mask = mask.resize(target_size, Image.NEAREST)
63 | return (img, mask)
64 |
65 |
66 | class Resize():
67 | def __init__(self, size):
68 | self.size = size
69 |
70 | def __call__(self, img, mask):
71 | assert img.size == mask.size
72 | target_size = self.size
73 | img = img.resize(target_size, Image.BILINEAR)
74 | mask = mask.resize(target_size, Image.NEAREST)
75 | return (img, mask)
76 |
77 |
78 | class RandomSized():
79 | def __init__(self, jitter=(0.5, 2)):
80 | self.jitter = jitter
81 |
82 | def __call__(self, img, mask):
83 | assert img.size == mask.size
84 |
85 | scale = random.uniform(self.jitter[0], self.jitter[1])
86 |
87 | w = int(scale * img.size[0])
88 | h = int(scale * img.size[1])
89 |
90 | img, mask = (
91 | img.resize((w, h), Image.BILINEAR),
92 | mask.resize((w, h), Image.NEAREST),
93 | )
94 |
95 | return img, mask
96 |
97 |
98 | class RoundToMultiple():
99 | def __init__(self, stride, method='pad'):
100 | assert method in ('pad', 'resize')
101 | self.stride = stride
102 | self.method = method
103 |
104 | def __call__(self, img, mask):
105 | assert img.size == mask.size
106 | stride = self.stride
107 | if stride > 0:
108 | w, h = img.size
109 | w = int(math.ceil(w / stride) * stride)
110 | h = int(math.ceil(h / stride) * stride)
111 | if self.method == 'pad':
112 | img = _pad_to_sizes(img, w, h)
113 | mask = _pad_to_sizes(mask, w, h)
114 | else:
115 | img = img.resize((w, h), Image.BILINEAR)
116 | mask = mask.resize((w, h), Image.NEAREST)
117 | return img, mask
118 |
119 |
120 | class COCOResize():
121 | '''
122 | adapted from maskrcnn_benchmark/data/transforms/transforms.py
123 | '''
124 | def __init__(self, min_size, max_size, round_to_divisble):
125 | if not isinstance(min_size, (list, tuple)):
126 | min_size = (min_size,)
127 | assert round_to_divisble in ('pad', 'resize')
128 | self.min_size = min_size
129 | self.max_size = max_size
130 | self.round_to_divisble = round_to_divisble
131 |
132 | # modified from torchvision to add support for max size
133 | def get_size(self, image_size):
134 | w, h = image_size
135 | size = random.choice(self.min_size)
136 | max_size = self.max_size
137 | if max_size is not None:
138 | min_original_size = float(min((w, h)))
139 | max_original_size = float(max((w, h)))
140 | if max_original_size / min_original_size * size > max_size:
141 | size = int(round(max_size * min_original_size / max_original_size))
142 |
143 | if (w <= h and w == size) or (h <= w and h == size):
144 |             return (w, h)  # PIL resize takes (width, height)
145 |
146 | if w < h:
147 | ow = size
148 | oh = int(size * h / w)
149 | else:
150 | oh = size
151 | ow = int(size * w / h)
152 |
153 | return (ow, oh)
154 |
155 | def __call__(self, img, mask):
156 | assert img.size == mask.size
157 | size = self.get_size(img.size)
158 | img = img.resize(size)
159 | mask = mask.resize(size)
160 | padder = RoundToMultiple(32, self.round_to_divisble)
161 | img, mask = padder(img, mask)
162 | return img, mask
163 |
164 |
165 | class ColorJitter():
166 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
167 | self.tv_jitter = tv_color_jitter(
168 | brightness, contrast, saturation, hue
169 | )
170 |
171 | def __call__(self, img, mask):
172 | assert img.size == mask.size
173 | img = self.tv_jitter(img)
174 | return img, mask
175 |
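All of these transforms share the paired (img, mask) calling convention, so a
pipeline is plain sequential application (a minimal sketch using the classes above):

    from PIL import Image

    img = Image.new('RGB', (200, 150))
    mask = Image.new('L', (200, 150))
    for t in [RandomSized((0.5, 2)), RandomHorizontalFlip(0.5), RandomCrop(128)]:
        img, mask = t(img, mask)
    print(img.size, mask.size)  # both (128, 128) after the final crop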
--------------------------------------------------------------------------------
/src/datasets/samplers/grouped_batch_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import itertools
3 | import math
4 | import torch
5 | from torch.utils.data.sampler import BatchSampler
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | class GroupedBatchSampler(BatchSampler):
10 | """
11 | Wraps another sampler to yield a mini-batch of indices.
12 | It enforces that elements from the same group should appear in groups of batch_size.
13 |     It also tries to provide mini-batches that follow an ordering as close
14 |     as possible to the ordering from the original sampler.
15 |
16 | Arguments:
17 | sampler (Sampler): Base sampler.
18 | batch_size (int): Size of mini-batch.
19 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose
20 | size is less than ``batch_size``
21 |
22 | """
23 |
24 | def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
25 | if not isinstance(sampler, Sampler):
26 | raise ValueError(
27 | "sampler should be an instance of "
28 | "torch.utils.data.Sampler, but got sampler={}".format(sampler)
29 | )
30 | assert drop_uneven is False, 'for now do not allow dropping last batch'
31 | self.sampler = sampler
32 | self.group_ids = torch.as_tensor(group_ids)
33 | assert self.group_ids.dim() == 1
34 | self.batch_size = batch_size
35 | self.drop_uneven = drop_uneven
36 |
37 | self.groups = torch.unique(self.group_ids).sort(0)[0]
38 |
39 | self._can_reuse_batches = False
40 | self._intended_num_batches = int(math.ceil(len(sampler) / batch_size))
41 |
42 | def _prepare_batches(self):
43 | dataset_size = len(self.group_ids)
44 | # get the sampled indices from the sampler
45 | sampled_ids = torch.as_tensor(list(self.sampler))
46 | # potentially not all elements of the dataset were sampled
47 | # by the sampler (e.g., DistributedSampler).
48 | # construct a tensor which contains -1 if the element was
49 | # not sampled, and a non-negative number indicating the
50 | # order where the element was sampled.
51 | # for example. if sampled_ids = [3, 1] and dataset_size = 5,
52 | # the order is [-1, 1, -1, 0, -1]
53 | order = torch.full((dataset_size,), -1, dtype=torch.int64)
54 | order[sampled_ids] = torch.arange(len(sampled_ids))
55 |
56 | # get a mask with the elements that were sampled
57 | mask = order >= 0
58 |
59 | # find the elements that belong to each individual cluster
60 | clusters = [(self.group_ids == i) & mask for i in self.groups]
61 | # get relative order of the elements inside each cluster
62 | # that follows the order from the sampler
63 | relative_order = [order[cluster] for cluster in clusters]
64 | # with the relative order, find the absolute order in the
65 | # sampled space
66 | permutation_ids = [s[s.sort()[1]] for s in relative_order]
67 | # permute each cluster so that they follow the order from
68 | # the sampler
69 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]
70 |
71 | # splits each cluster in batch_size, and merge as a list of tensors
72 | splits = [c.split(self.batch_size) for c in permuted_clusters]
73 | merged = tuple(itertools.chain.from_iterable(splits))
74 |         merged = merged[:self._intended_num_batches]  # cap at ceil(len(sampler) / batch_size) batches
75 |
76 | # now each batch internally has the right order, but
77 | # they are grouped by clusters. Find the permutation between
78 | # different batches that brings them as close as possible to
79 | # the order that we have in the sampler. For that, we will consider the
80 | # ordering as coming from the first element of each batch, and sort
81 | # correspondingly
82 | first_element_of_batch = [t[0].item() for t in merged]
83 | # get and inverse mapping from sampled indices and the position where
84 | # they occur (as returned by the sampler)
85 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
86 | # from the first element in each batch, get a relative ordering
87 | first_index_of_batch = torch.as_tensor(
88 | [inv_sampled_ids_map[s] for s in first_element_of_batch]
89 | )
90 |
91 | # permute the batches so that they approximately follow the order
92 | # from the sampler
93 | permutation_order = first_index_of_batch.sort(0)[1].tolist()
94 | # finally, permute the batches
95 | batches = [merged[i].tolist() for i in permutation_order]
96 |
97 | if self.drop_uneven:
98 | kept = []
99 | for batch in batches:
100 | if len(batch) == self.batch_size:
101 | kept.append(batch)
102 | batches = kept
103 | # print(
104 | # '{} sampled points; size of cluster {}; size of splits {}; '
105 | # 'size of merged: {}, size of batches {}'.format(
106 | # len(sampled_ids),
107 | # [inds.shape[0] for inds in permuted_clusters],
108 | # [len(sp) for sp in splits], len(merged), len(batches)
109 | # )
110 | # )
111 | return batches
112 |
113 | def __iter__(self):
114 | if self._can_reuse_batches:
115 | batches = self._batches
116 | self._can_reuse_batches = False
117 | else:
118 | batches = self._prepare_batches()
119 | self._batches = batches
120 | return iter(batches)
121 |
122 | def __len__(self):
123 | if not hasattr(self, "_batches"):
124 | self._batches = self._prepare_batches()
125 | self._can_reuse_batches = True
126 | return len(self._batches)
127 |
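A toy run showing the grouping behavior (a sketch; in practice the group ids
come from aspect-ratio bucketing):

    from torch.utils.data.sampler import SequentialSampler

    group_ids = [0, 1, 0, 1, 0, 1]
    batch_sampler = GroupedBatchSampler(
        SequentialSampler(range(6)), group_ids, batch_size=2
    )
    print(list(batch_sampler))  # [[0, 2], [1, 3], [4]]: no batch mixes groups,
                                # and the batch-count cap drops the leftover [5]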
--------------------------------------------------------------------------------
/src/pcv/gaussian_smooth/prob_tsr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy.stats import multivariate_normal
5 |
6 | from ..components.snake import Snake
7 |
8 | _nine_offsets = [
9 | ( 0, 0),
10 | ( 1, 1),
11 | ( 0, 1),
12 | (-1, 1),
13 | (-1, 0),
14 | (-1, -1),
15 | ( 0, -1),
16 | ( 1, -1),
17 | ( 1, 0),
18 | ]
19 |
20 |
21 | class GaussianField():
22 | def __init__(self, diam, cov=0.05):
23 |         assert (diam % 2 == 1), 'diam must be odd'
24 | self.diam = diam
25 | self.cov = cov # .05 leaves about 95% prob mass within central block
26 | # only consider the 3x3 region
27 | self.increment = 1 / diam
28 | # compute 3 units
29 | self.l, self.r = -1.5, 1.5
30 | self.field_shape = (3 * diam, 3 * diam)
31 | self.unit_area = self.increment ** 2
32 | self.prob_field = self.compute_prob_field()
33 |
34 | def compute_prob_field(self):
35 | cov = self.cov
36 | increment = self.increment
37 | l, r = self.l, self.r
38 | cov_mat = np.array([
39 | [cov, 0],
40 | [0, cov]
41 | ])
42 | rv = multivariate_normal([0, 0], cov_mat)
43 | half_increment = increment / 2
44 | xs, ys = np.mgrid[
45 | l + half_increment: r: increment,
46 | l + half_increment: r: increment
47 | ] # use half increment to make things properly centered
48 | pos = np.dstack((xs, ys))
49 | prob_field = rv.pdf(pos).astype(np.float32)
50 | assert prob_field.shape == self.field_shape
51 | return prob_field
52 |
53 | @torch.no_grad()
54 | def compute_local_mass(self):
55 | kernel_size = self.diam
56 | pad = (kernel_size - 1) // 2
57 | prob_field = self.prob_field
58 |
59 | conv = nn.Conv2d(
60 | in_channels=1, out_channels=1, kernel_size=kernel_size,
61 | padding=pad, bias=False
62 | ) # do not use cuda for now; no point
63 | conv.weight.data.copy_(torch.tensor(1.0))
64 | prob_field = torch.as_tensor(
65 | prob_field, device=conv.weight.device
66 | )[(None,) * 2] # [1, 1, h, w]
67 | local_sum = conv(prob_field).squeeze().cpu().numpy()
68 | local_sum = local_sum * self.unit_area
69 | return local_sum
70 |
71 |
72 | class MakeProbTsr():
73 | '''
74 | make a prob tsr of shape [h, w, num_votes] filled with the corresponding
75 | spatial voting prob
76 | '''
77 | def __init__(self, spec, diam, grid_spec, vote_mask, var=0.05):
78 |         # the [H, W, (y, x)] index grid is built later in compute_voting_prob_tsr;
79 |         # __init__ only normalizes the grid spec to 0-based centers and diameters
80 | self.spec = spec
81 | self.diam = diam
82 | self.vote_mask = vote_mask
83 | self.var = var
84 |
85 | # process grid spec to 0 based indexing and change radius to diam
86 | radius = (diam - 1) // 2
87 | center = np.array((radius, radius))
88 | grid_spec = grid_spec.copy()
89 | grid_spec[:, :2] += center
90 | grid_spec[:, -1] = 1 + 2 * grid_spec[:, -1] # change from r to diam
91 | self.grid_spec = grid_spec
92 |
93 | def compute_voting_prob_tsr(self, normalize=True):
94 | spec = self.spec
95 | diam = self.diam
96 | grid_spec = self.grid_spec
97 | vote_mask = self.vote_mask
98 |
99 | spatial_shape = (diam, diam)
100 | spatial_yx = np.indices(spatial_shape).transpose(1, 2, 0).astype(int)
101 | # [H, W, 2] where each arr[y, x] is the containing grid's center
102 | spatial_cen_yx = np.empty_like(spatial_yx)
103 | # [H, W, 1] where each arr[y, x] is the containing grid's diam
104 | spatial_diam = np.empty(spatial_shape, dtype=int)[..., None]
105 | for i, (y, x, d) in enumerate(grid_spec):
106 | _m = vote_mask == i
107 | spatial_cen_yx[_m] = (y, x)
108 | spatial_diam[_m] = d
109 |
110 | max_vote_bin_diam = spec[-1][0]
111 | spatial_9_inds = self.nine_neighbor_inds(
112 | spatial_diam, spatial_yx, vote_mask,
113 | vote_mask_padding=max_vote_bin_diam
114 | )
115 |
116 | # spatial_9_probs = np.ones_like(spatial_9_inds).astype(float)
117 | spatial_9_probs = self.nine_neighbor_probs(
118 | spatial_diam, spatial_yx, spatial_cen_yx, self.var
119 | )
120 |
121 | # [H, W, num_votes + 1] 1 extra to trash the -1s
122 | spatial_prob = np.zeros((diam, diam, len(grid_spec) + 1))
123 | inds0, inds1, _ = np.ix_(range(diam), range(diam), range(1))
124 | np.add.at(spatial_prob, (inds0, inds1, spatial_9_inds), spatial_9_probs)
125 |         spatial_prob[..., -1] = 0  # erase but keep the trash bin -> abstain bin
126 | spatial_prob = self.erase_inward_prob_dist(spec, vote_mask, spatial_prob)
127 | if normalize:
128 | spatial_prob = spatial_prob / spatial_prob.sum(-1, keepdims=True)
129 | return spatial_prob
130 |
131 | @staticmethod
132 | def erase_inward_prob_dist(spec, vote_mask, spatial_prob):
133 |         '''Expedient workaround: for pixels in an outer pyramid belt, zero out
134 |         any probability mass that fell on bins of the inner levels, so smoothed
135 |         votes never leak inward across belt boundaries.
136 |         '''
137 | splits = Snake.vote_channel_splits(spec)
138 | # ret = np.zeros_like(spatial_prob)
139 | layer_inds = np.cumsum(splits)
140 | for i in range(1, len(layer_inds)):
141 | curr = layer_inds[i]
142 | prev = layer_inds[i-1]
143 | belt_mask = (vote_mask < curr) & (vote_mask >= prev)
144 | spatial_prob[belt_mask, :prev] = 0
145 | return spatial_prob
146 |
147 | @staticmethod
148 | def nine_neighbor_inds(
149 | spatial_diam, spatial_yx, vote_mask, vote_mask_padding,
150 | ):
151 | # [H, W, 1, 1] * [9, 2] -> [H, W, 9, 2]
152 | spatial_9_offsets = spatial_diam[..., None] * np.array(_nine_offsets)
153 | # [H, W, 2] reshapes [H, W, 1, 2] + [H, W, 9, 2] -> [H, W, 9, 2]
154 | spatial_9_loc_yx = np.expand_dims(spatial_yx, 2) + spatial_9_offsets
155 |
156 | padded_vote_mask = np.pad(
157 | vote_mask, vote_mask_padding, mode='constant', constant_values=-1
158 | )
159 | # shift the inds
160 | spatial_9_loc_yx += (vote_mask_padding, vote_mask_padding)
161 | # [H, W, 9] where arr[y, x] contains the 9 inds centered on y, x
162 | spatial_9_inds = padded_vote_mask[
163 | tuple(np.split(spatial_9_loc_yx, 2, axis=-1))
164 | ].squeeze(-1)
165 | return spatial_9_inds
166 |
167 | @staticmethod
168 | def nine_neighbor_probs(spatial_diam, spatial_yx, spatial_cen_yx, var):
169 | spatial_cen_yx_offset = spatial_cen_yx - spatial_yx
170 | del spatial_cen_yx, spatial_yx
171 | single_cell_diam = 81
172 | field_diam = single_cell_diam * 3
173 | gauss = GaussianField(diam=single_cell_diam, cov=var)
174 | prob_local_mass = gauss.compute_local_mass()
175 |
176 | # now read off prob from every pix's 9 neighboring locations
177 | '''
178 | single_cell_diam: scalar; 1/3 of the field size for prob field
179 | prob_local_mass: [3 * single_cell_diam, 3 * single_cell_diam]
180 | spatial_diam: [H, W, 1]; arr[y, x] gives its grid diam
181 | spatial_cen_yx_offset: [H, W, 2] arr[y, x] gives dy, dx to its grid center
182 | '''
183 | assert field_diam == prob_local_mass.shape[0]
184 | assert prob_local_mass.shape[0] == prob_local_mass.shape[1]
185 | norm_spatial_cen_yx_offset = (
186 | spatial_cen_yx_offset * single_cell_diam / spatial_diam
187 |     ).astype(int)  # [H, W, 2]
188 | del spatial_cen_yx_offset, spatial_diam
189 |
190 | spatial_9_offsets = (
191 | single_cell_diam * np.array(_nine_offsets)
192 | ).reshape(1, 1, 9, 2)
193 | field_radius = (field_diam - 1) // 2
194 | center = (field_radius, field_radius)
195 | spatial_yx_loc = center + norm_spatial_cen_yx_offset
196 | # [H, W, 2] reshapes [H, W, 1, 2] + [1, 1, 9, 2] -> [H, W, 9, 2]
197 | spatial_9_loc_yx = np.expand_dims(spatial_yx_loc, axis=2) + spatial_9_offsets
198 |
199 | spatial_9_probs = prob_local_mass[
200 | tuple(np.split(spatial_9_loc_yx, 2, axis=-1))
201 | ].squeeze(-1)
202 | return spatial_9_probs
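The cov=0.05 default can be sanity-checked directly: with that covariance about
95% of the probability mass lands inside the central cell (a sketch; the import
path is assumed from the repo layout):

    from panoptic.pcv.gaussian_smooth.prob_tsr import GaussianField

    gauss = GaussianField(diam=81, cov=0.05)
    mass = gauss.compute_local_mass()  # [243, 243]: mass inside each diam-sized window
    c = mass.shape[0] // 2
    print(mass[c, c])                  # ~0.95 at the field center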
203 |
--------------------------------------------------------------------------------
/src/metric.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import torch
4 | from tabulate import tabulate
5 | from panoptic.pan_eval import pq_compute_single_img, PQStat
6 |
7 |
8 | class Metric(object):
9 | def __init__(self, num_classes, trainId_2_catName=None):
10 | """
11 |         Args:
12 |             num_classes: number of semantic classes
13 |             trainId_2_catName: optional mapping from train id to category name
14 | All metrics are computed from a confusion matrix that aggregates pixel
15 | count across all images. This evaluation scheme only has a global
16 | view of all the pixels, and disregards image as a unit.
17 | "mean" always refers to averaging across pixel semantic classes.
18 | """
19 | self.num_classes = num_classes
20 | self.trainId_2_catName = trainId_2_catName
21 | self.init_state()
22 |
23 | def init_state(self):
24 | """caller may use it to reset the metric"""
25 | self.scores = dict()
26 | self.confusion_matrix = np.zeros(
27 | shape=(self.num_classes, self.num_classes), dtype=np.int64
28 | )
29 |
30 | def update(self, pred, gt):
31 | """
32 | Args:
33 |             pred: [N, H, W] labels or [N, C, H, W] logits (argmaxed internally)
34 |             gt: [N, H, W] torch tsr or ndarray
35 | """
36 | if len(pred.shape) == 4:
37 | pred = pred.argmax(dim=1)
38 | assert pred.shape == gt.shape
39 | if isinstance(pred, torch.Tensor):
40 | pred, gt = pred.cpu().numpy(), gt.cpu().numpy()
41 | hist = self.fast_hist(pred, gt, self.num_classes)
42 | self.confusion_matrix += hist
43 | self.scores = self.compute_scores(
44 | self.confusion_matrix, self.trainId_2_catName
45 | )
46 |
47 | @staticmethod
48 | def fast_hist(pred, gt, num_classes):
49 | assert pred.shape == gt.shape
50 | valid_mask = (gt >= 0) & (gt < num_classes)
51 | hist = np.bincount(
52 | num_classes * gt[valid_mask] + pred[valid_mask],
53 | minlength=num_classes ** 2
54 | ).reshape(num_classes, num_classes)
55 | return hist
56 |
57 | @staticmethod
58 | def compute_scores(hist, trainId_2_catName):
59 | res = dict()
60 | num_classes = hist.shape[0]
61 |
62 | # per class statistics
63 | gt_freq = hist.sum(axis=1)
64 | pd_freq = hist.sum(axis=0) # pred frequency
65 | intersect = np.diag(hist)
66 | union = gt_freq + pd_freq - intersect
67 | iou = intersect / union
68 |
69 | cls_names = [
70 | trainId_2_catName[inx] for inx in range(num_classes)
71 | ] if trainId_2_catName is not None else range(num_classes)
72 |
73 | details = dict()
74 | for inx, name in enumerate(cls_names):
75 | details[name] = {
76 | 'gt_freq': gt_freq[inx],
77 | 'pd_freq': pd_freq[inx],
78 | 'intersect': intersect[inx],
79 | 'union': union[inx],
80 | 'iou': iou[inx]
81 | }
82 | res['details'] = details
83 |
84 | # aggregate statistics
85 | pix_acc = intersect.sum() / hist.sum()
86 | m_iou = np.nanmean(iou)
87 | freq = gt_freq / gt_freq.sum()
88 | # masking to avoid potential nan in per cls iou
89 | fwm_iou = (freq[freq > 0] * iou[freq > 0]).sum()
90 | del freq
91 | res['pix_acc'] = pix_acc
92 | res['m_iou'] = m_iou
93 | res['fwm_iou'] = fwm_iou
94 |
95 | return res
96 |
97 | def __repr__(self):
98 | return repr(self.scores)
99 |
100 | def __str__(self):
101 | return self.display(self.scores)
102 |
103 | @staticmethod
104 | def display(src_dict):
105 | """Only print out scalar metric like mIoU in a nice tabular form
106 | Detailed per cls info, etc, are withheld for clear presentation
107 | """
108 | to_display = dict()
109 | for k, v in src_dict.items():
110 | if isinstance(v, dict):
111 | continue # ignore those which cannot be tabulated
112 | to_display[k] = [v]
113 | table = tabulate(
114 | to_display,
115 | headers='keys', tablefmt='fancy_grid',
116 | floatfmt=".3f", numalign='decimal'
117 | )
118 | return str(table)
119 |
120 |     # these save and load functions are ugly because they are tied to the
121 |     # infrastructure. Re-write them later.
122 | def save(self, epoch_or_fname, manager):
123 | assert manager is not None
124 | state = self.scores
125 | # now save the acc with manager
126 | if isinstance(epoch_or_fname, int):
127 | epoch = epoch_or_fname
128 | manager.save(epoch, state)
129 | else:
130 | fname = epoch_or_fname
131 | save_path = osp.join(manager.root, fname)
132 | manager.save_f(state, save_path)
133 |
134 | def load(self, state):
135 | """Assume that the state is already read by the caller
136 | Args:
137 | state: dict with fields 'scores' and 'visuals'
138 | """
139 | self.scores = state
140 |
141 |
142 | class PanMetric(Metric):
143 | """
144 | This is a metric that evaluates predictions from both heads as pixel-wise
145 | classification
146 | """
147 | def __init__(self, num_classes, num_votes, trainId_2_catName):
148 | self.num_votes = num_votes
149 | super().__init__(num_classes, trainId_2_catName)
150 |
151 | def init_state(self):
152 | self.scores = dict()
153 | self.sem_confusion = np.zeros(
154 | shape=(self.num_classes, self.num_classes), dtype=np.int64
155 | )
156 | self.vote_confusion = np.zeros(
157 | shape=(self.num_votes, self.num_votes), dtype=np.int64
158 | )
159 |
160 | def update(self, sem_pred, vote_pred, sem_gt, vote_gt):
161 | """
162 | Args:
163 | sem_pred: [N, H, W] torch tsr or ndarray
164 | vote_pred: [N, H, W] torch tsr or ndarray
165 | sem_gt: [N, H, W] torch tsr or ndarray
166 | vote_gt: [N, H, W] torch tsr or ndarray
167 | """
168 | self._update_pair(
169 | sem_pred, sem_gt, 'sem', self.sem_confusion, self.num_classes)
170 | self._update_pair(
171 | vote_pred, vote_gt, 'vote', self.vote_confusion, self.num_votes)
172 |
173 | def _update_pair(self, pred, gt, key, confusion_matrix, num_cats):
174 | if len(pred.shape) == 4:
175 | pred = pred.argmax(dim=1)
176 | if isinstance(pred, torch.Tensor):
177 | pred, gt = pred.cpu().numpy(), gt.cpu().numpy()
178 | hist = self.fast_hist(pred, gt, num_cats)
179 | confusion_matrix += hist
180 | trainId_2_catName = self.trainId_2_catName if key == 'sem' else None
181 | self.scores[key] = self.compute_scores(
182 | confusion_matrix, trainId_2_catName
183 | )
184 |
185 | def __str__(self):
186 | sem_str = self.display(self.scores['sem'])
187 | vote_str = self.display(self.scores['vote'])
188 | combined = "sem: \n{} \nvote: \n{}".format(sem_str, vote_str)
189 | return combined
190 |
191 |
192 | class PQMetric():
193 | def __init__(self, dset_meta):
194 | self.cats = dset_meta['cats']
195 | self.score = PQStat()
196 | self.metrics = [("All", None), ("Things", True), ("Stuff", False)]
197 | self.results = {}
198 |
199 | def update(self, gt_ann, gt, pred_ann, pred):
200 | assert gt.shape == pred.shape
201 | stat = pq_compute_single_img(self.cats, gt_ann, gt, pred_ann, pred)
202 | self.score += stat
203 |
204 | def state_dict(self):
205 | self.aggregate_results()
206 | return self.results
207 |
208 | def aggregate_results(self):
209 | for name, isthing in self.metrics:
210 | self.results[name], per_class_results = self.score.pq_average(
211 | self.cats, isthing=isthing
212 | )
213 | if name == 'All':
214 | self.results['per_class'] = per_class_results
215 |
216 | def __str__(self):
217 | self.aggregate_results()
218 | headers = [ m[0] for m in self.metrics ]
219 | keys = ['pq', 'sq', 'rq']
220 | data = []
221 | for tranche in headers:
222 | row = [100 * self.results[tranche][k] for k in keys]
223 | row = [tranche] + row + [self.results[tranche]['n']]
224 | data.append(row)
225 | table = tabulate(
226 | tabular_data=data,
227 | headers=([''] + keys + ['n_cats']),
228 | floatfmt=".2f", tablefmt='fancy_grid',
229 | )
230 | return table
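Driving the plain semantic metric end to end (a minimal sketch with random
labels, just to show the expected shapes and output):

    import numpy as np

    m = Metric(num_classes=3)
    pred = np.random.randint(0, 3, size=(2, 8, 8))
    gt = np.random.randint(0, 3, size=(2, 8, 8))
    m.update(pred, gt)
    print(m)  # tabulated pix_acc / m_iou / fwm_iou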
231 |
--------------------------------------------------------------------------------
/src/models/panoptic_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import apex
7 | from panoptic.models.base_model import BaseModel
8 | from panoptic import get_loss, get_optimizer, get_scheduler
9 | from ..pcv.inference.mask_from_vote import MaskFromVote, MFV_CatSeparate
10 |
11 | from fabric.utils.logging import setup_logging
12 | logger = setup_logging(__file__)
13 |
14 |
15 | class PanopticBase(BaseModel):
16 | def __init__(self, cfg, pcv, dset_meta):
17 | super().__init__()
18 | self.cfg = cfg
19 | self.pcv = pcv
20 | self.dset_meta = dset_meta
21 |
22 | self.instantiate_network(cfg)
23 | assert self.net is not None, "instantiate network properly!"
24 | self.criteria = get_loss(cfg.loss)
25 | self.net.loss_module = self.criteria # fold crit into net
26 |
27 | # torch.cuda.synchronize()
28 | self.add_optimizer(cfg.optimizer)
29 | assert self.optimizer is not None # confirm the child has done the job
30 | # torch.cuda.synchronize()
31 |
32 | self.curr_epoch = 0
33 | self.total_train_epoch = cfg.scheduler.total_epochs
34 | self.scheduler = get_scheduler(cfg.scheduler)(
35 | optimizer=self.optimizer
36 | )
37 |
38 | # used as temporary place holder for model training inputs
39 | self._latest_loss = None
40 | self.img, self.sem_mask, self.vote_mask, self.weight_mask = \
41 | None, None, None, None
42 |
43 | def instantiate_network(self, cfg):
44 | # child class must overwrite the method
45 | raise NotImplementedError()
46 |
47 | def get_params_lr(self, initial_lr):
48 | return self.net.parameters()
49 |
50 | def add_optimizer(self, optim_cfg):
51 |         """New method for this class; only retrieves the optimizer handle.
52 |         Child classes should implement it such that self.optimizer is
53 |         filled; they could e.g. assign different learning rates to
54 |         different param groups.
55 | """
56 | optim_handle = get_optimizer(optim_cfg)
57 | # self.optimizer = optim_handle(params=self.net.parameters())
58 | params_lr = self.get_params_lr(optim_cfg.params.lr)
59 | self.optimizer = optim_handle(params=params_lr)
60 |
61 | def ingest_train_input(self, *inputs):
62 | self.inputs = inputs
63 |
64 | def optimize_params(self):
65 | self.set_train_eval_state(True)
66 | self.optimizer.zero_grad()
67 |
68 | loss, sem_loss, vote_loss = self.net(*self.inputs)
69 |         loss = loss.mean()  # reduce per-device losses (e.g. under DataParallel)
70 | if getattr(self.cfg, 'apex', False):
71 | with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
72 | scaled_loss.backward()
73 | else:
74 | loss.backward()
75 | self.optimizer.step()
76 | self._latest_loss = {
77 | 'loss': loss.item(),
78 | 'sem_loss': sem_loss.mean().item(),
79 | 'vote_loss': vote_loss.mean().item()
80 | }
81 |
82 | def latest_loss(self, get_numeric=False):
83 | if get_numeric:
84 | return self._latest_loss
85 | else:
86 | return {k: '{:.4f}'.format(v) for k, v in self._latest_loss.items()}
87 |
88 | def curr_lr(self):
89 | lr = [ group['lr'] for group in self.optimizer.param_groups ][0]
90 | return lr
91 |
92 | def advance_to_next_epoch(self):
93 | self.curr_epoch += 1
94 |
95 | def log_statistics(self, step, level=0):
96 | if self.log_writer is not None:
97 | for k, v in self.latest_loss(get_numeric=True).items():
98 | self.log_writer.add_scalar(k, v, step)
99 | if level >= 1:
100 | curr_lr = self.curr_lr()
101 | self.log_writer.add_scalar('lr', curr_lr, step)
102 |
103 | def infer(self, x, *args, **kwargs):
104 | return self._infer(x, *args, **kwargs)
105 |
106 | def _infer(
107 | self, x, upsize_sem=False, take_argmax=False, softmax_normalize=False
108 | ):
109 | self.set_train_eval_state(False)
110 | with torch.no_grad():
111 | sem_pred, vote_pred = self.net(x)
112 |
113 | if upsize_sem:
114 | sem_pred = F.interpolate(
115 | sem_pred, scale_factor=4, mode='nearest'
116 | )
117 |
118 | if take_argmax:
119 |             sem_pred, vote_pred = sem_pred.argmax(1), vote_pred.argmax(1)
120 | return sem_pred, vote_pred
121 |
122 | if softmax_normalize:
123 | sem_pred = F.softmax(sem_pred, dim=1)
124 | vote_pred = F.softmax(vote_pred, dim=1)
125 |
126 | return sem_pred, vote_pred
127 |
128 | def _flip_infer(self, x, *args, **kwargs):
129 | # x: 1x3xhxw
130 | new_x = torch.cat([x, x.flip(-1)], 0)
131 | _sem_pred, _vote_pred = self._infer(new_x, *args, **kwargs)
132 |
133 | if not hasattr(self, 'flip_infer_mapping'):
134 | original_mask = torch.from_numpy(self.pcv.vote_mask)
135 | new_mask = original_mask.flip(-1)
136 | mapping = {}
137 | for ii, jj in zip(new_mask.view(-1).tolist(), original_mask.view(-1).tolist()):
138 | if ii in mapping:
139 | assert mapping[ii] == jj
140 | else:
141 | mapping[ii] = jj
142 | mapping[len(mapping)] = len(mapping) #abstain
143 | self.flip_infer_mapping = torch.LongTensor([mapping[_] for _ in range(len(mapping))])
144 |
145 | # sem_pred = (_sem_pred[:1] + _sem_pred[1:].flip(-1)) / 2
146 | sem_pred = F.softmax((_sem_pred[:1].log() + _sem_pred[1:].flip(-1).log()) / 2, dim=1)
147 | # vote_pred = (_vote_pred[:1] + _vote_pred[1:, self.flip_infer_mapping].flip(-1)) / 2
148 | vote_pred = F.softmax((_vote_pred[:1].log() + _vote_pred[1:, self.flip_infer_mapping].flip(-1).log()) / 2, dim=1)
149 | return sem_pred, vote_pred
150 |
151 | def stitch_pan_mask(self, infer_cfg, sem_pred, vote_pred, target_size=None, return_hmap=False):
152 | """
153 | Args:
154 | sem_pred: [1, num_class, H, W] torch gpu tsr
155 | vote_pred: [1, num_bins, H, W] torch gpu tsr
156 | target_size: optionally postprocess the output, resize, filter
157 | stuff predictions on threshold, etc
158 | """
159 | assert self.dset_meta is not None and self.pcv is not None
160 | # make the meta data actually required explicit!!
161 | infer_m = MaskFromVote if infer_cfg.num_groups == 1 else MFV_CatSeparate
162 | mfv = infer_m(infer_cfg, self.dset_meta, self.pcv, sem_pred, vote_pred)
163 | pan_mask, meta = mfv.infer_panoptic_mask()
164 | vote_hmap = mfv.vote_hmap
165 | # pan_mask, meta = self.pcv.mask_from_sem_vote_tsr(
166 | # self.dset_meta, sem_pred, vote_pred
167 | # ) don't go the extra route. Be straightforward
168 |
169 | if target_size is not None:
170 | stuff_filt_thresh = infer_cfg.get(
171 | 'stuff_pred_thresh', self.dset_meta['stuff_pred_thresh']
172 | )
173 | pan_mask, meta = self._post_process_pan_pred(
174 | pan_mask, meta, target_size, stuff_filt_thresh
175 | )
176 | if not return_hmap:
177 | return pan_mask, meta
178 | else:
179 | return pan_mask, meta, vote_hmap
180 |
181 | @staticmethod
182 | def _post_process_pan_pred(pan_mask, pan_meta, target_size, stuff_thresh=-1):
183 |         """Assumes pan_mask is an np array and target_size a PIL-style (w, h) tuple;
184 |         adjusts annotations when downsizing erases instances
185 |         """
186 | pan_mask = Image.fromarray(pan_mask)
187 | pan_mask_size = pan_mask.size
188 | pan_mask = np.array(pan_mask.resize(target_size, resample=Image.NEAREST))
189 |
190 | # account for lost segments due to downsizing
191 | if pan_mask_size[0] > target_size[0]:
192 | # downsizing is the only case where instances could be erased
193 | remains = np.unique(pan_mask)
194 | segs = pan_meta['segments_info']
195 | acc = []
196 | for seg in segs:
197 | if seg['id'] not in remains:
198 |                 logger.warning('segment erased due to downsizing')
199 | else:
200 | acc.append(seg)
201 | pan_meta['segments_info'] = acc
202 |
203 | # filter out stuff segments based on area threshold
204 | segs = pan_meta['segments_info']
205 | acc = []
206 | for seg in segs:
207 | if seg['isthing'] == 0:
208 | _mask = (pan_mask == seg['id'])
209 | area = _mask.sum()
210 | if area > stuff_thresh:
211 | acc.append(seg)
212 | else:
213 | pan_mask[_mask] = 0
214 | else:
215 | acc.append(seg)
216 | pan_meta['segments_info'] = acc
217 |
218 | return pan_mask, pan_meta
219 |
220 | def load_latest_checkpoint_if_available(self, manager, direct_ckpt=None):
221 | ckpt = manager.load_latest() if direct_ckpt is None else direct_ckpt
222 | if ckpt:
223 | self.curr_epoch = ckpt['curr_epoch']
224 | self.net.module.load_state_dict(ckpt['model_state'])
225 | self.optimizer.load_state_dict(ckpt['optim_state'])
226 | # scheduler is incremented by global_step in the training loop
227 | self.scheduler = get_scheduler(self.cfg.scheduler)(
228 | optimizer=self.optimizer,
229 | )
230 | logger.info("loaded checkpoint that completes epoch {}".format(
231 | ckpt['curr_epoch']))
232 | self.curr_epoch += 1
233 | else:
234 | logger.info("No checkpoint found")
235 |
236 | def write_checkpoint(self, manager):
237 | ckpt_obj = dict(
238 | curr_epoch=self.curr_epoch,
239 | model_state=self.net.module.state_dict(),
240 | optim_state=self.optimizer.state_dict(),
241 | )
242 | manager.save(self.curr_epoch, ckpt_obj)
243 |
244 |
245 | class AbstractNet(nn.Module):
246 | def forward(self, *inputs):
247 | """
248 | If gt is supplied, then compute loss
249 | """
250 | x = inputs[0]
251 | x = self.stage1(x) # usually are resnet, hrnet, mobilenet etc.
252 | x = self.stage2(x) # like fpn, aspp, ups, etc.
253 | sem_pred = self.sem_classifier(x[0])
254 | vote_pred = self.vote_classifier(x[1])
255 |
256 | if len(inputs) > 1:
257 | assert self.loss_module is not None
258 | loss = self.loss_module(sem_pred, vote_pred, *inputs[1:])
259 | return loss
260 | else:
261 | return sem_pred, vote_pred
262 |
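The fusion in _flip_infer averages log-probabilities, i.e. it takes a
renormalized geometric mean of the two softmax outputs; a tiny sketch makes
that concrete:

    import torch
    import torch.nn.functional as F

    p1 = torch.tensor([[0.7, 0.2, 0.1]])
    p2 = torch.tensor([[0.5, 0.3, 0.2]])
    geo = F.softmax((p1.log() + p2.log()) / 2, dim=1)
    print(geo)  # proportional to sqrt(p1 * p2), renormalized to sum to 1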
--------------------------------------------------------------------------------
/src/pan_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | from __future__ import unicode_literals
6 | import os
7 | import numpy as np
8 | import json
9 | import time
10 | from collections import defaultdict
11 | import argparse
12 | import multiprocessing
13 |
14 | import PIL.Image as Image
15 |
16 | from panopticapi.utils import get_traceback, rgb2id, id2rgb
17 |
18 | OFFSET = 256 * 256 * 256
19 | VOID = 0
20 |
21 |
22 | class PQStatCat():
23 | def __init__(self):
24 | self.iou = 0.0
25 | self.tp = 0
26 | self.fp = 0
27 | self.fn = 0
28 |
29 | def __iadd__(self, pq_stat_cat):
30 | self.iou += pq_stat_cat.iou
31 | self.tp += pq_stat_cat.tp
32 | self.fp += pq_stat_cat.fp
33 | self.fn += pq_stat_cat.fn
34 | return self
35 |
36 |
37 | class PQStat():
38 | def __init__(self):
39 | self.pq_per_cat = defaultdict(PQStatCat)
40 |
41 | def __getitem__(self, i):
42 | return self.pq_per_cat[i]
43 |
44 | def __iadd__(self, pq_stat):
45 | for label, pq_stat_cat in pq_stat.pq_per_cat.items():
46 | self.pq_per_cat[label] += pq_stat_cat
47 | return self
48 |
49 | def pq_average(self, categories, isthing):
50 | pq, sq, rq, n = 0, 0, 0, 0
51 | per_class_results = {}
52 | for label, label_info in categories.items():
53 | name = label_info['name']
54 | if isthing is not None:
55 | cat_isthing = label_info['isthing'] == 1
56 | if isthing != cat_isthing:
57 | continue
58 | iou = self.pq_per_cat[label].iou
59 | tp = self.pq_per_cat[label].tp
60 | fp = self.pq_per_cat[label].fp
61 | fn = self.pq_per_cat[label].fn
62 | if tp + fp + fn == 0:
63 |             per_class_results[name] = {'pq': 0.0, 'sq': 0.0, 'rq': 0.0}
64 | continue
65 | n += 1
66 | pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
67 | sq_class = iou / tp if tp != 0 else 0
68 | rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)
69 | per_class_results[name] = {
70 | 'pq': pq_class, 'sq': sq_class, 'rq': rq_class
71 | }
72 | pq += pq_class
73 | sq += sq_class
74 | rq += rq_class
75 |
76 | return {'pq': pq / n, 'sq': sq / n, 'rq': rq / n, 'n': n}, per_class_results
77 |
78 |
79 | def pq_compute_single_img(categories, gt_ann, pan_gt, pred_ann, pan_pred):
80 | pq_stat = PQStat()
81 | gt_segms = {el['id']: el for el in gt_ann['segments_info']}
82 | pred_segms = {el['id']: el for el in pred_ann['segments_info']}
83 |
84 | # predicted segments area calculation + prediction sanity checks
85 | pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
86 | labels, labels_cnt = np.unique(pan_pred, return_counts=True)
87 | for label, label_cnt in zip(labels, labels_cnt):
88 | if label not in pred_segms:
89 | if label == VOID:
90 | continue
91 | raise KeyError('In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.'.format(gt_ann['image_id'], label))
92 | pred_segms[label]['area'] = int(label_cnt)
93 | pred_labels_set.remove(label)
94 | if pred_segms[label]['category_id'] not in categories:
95 | raise KeyError('In the image with ID {} segment with ID {} has unknown category_id {}.'.format(gt_ann['image_id'], label, pred_segms[label]['category_id']))
96 | if len(pred_labels_set) != 0:
97 | raise KeyError('In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.'.format(gt_ann['image_id'], list(pred_labels_set)))
98 |
99 | # confusion matrix calculation
100 | pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64)
101 | gt_pred_map = {}
102 | labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
103 | for label, intersection in zip(labels, labels_cnt):
104 | gt_id = label // OFFSET
105 | pred_id = label % OFFSET
106 | gt_pred_map[(gt_id, pred_id)] = intersection
107 |
108 | # count all matched pairs
109 | gt_matched = set()
110 | pred_matched = set()
111 | for label_tuple, intersection in gt_pred_map.items():
112 | gt_label, pred_label = label_tuple
113 | if gt_label not in gt_segms:
114 | continue
115 | if pred_label not in pred_segms:
116 | continue
117 | if gt_segms[gt_label]['iscrowd'] == 1:
118 | continue
119 | if gt_segms[gt_label]['category_id'] != pred_segms[pred_label]['category_id']:
120 | continue
121 |
122 | union = pred_segms[pred_label]['area'] + gt_segms[gt_label]['area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
123 | iou = intersection / union
124 | if iou > 0.5:
125 | pq_stat[gt_segms[gt_label]['category_id']].tp += 1
126 | pq_stat[gt_segms[gt_label]['category_id']].iou += iou
127 | gt_matched.add(gt_label)
128 | pred_matched.add(pred_label)
129 |
130 |     # count false negatives
131 | crowd_labels_dict = {}
132 | for gt_label, gt_info in gt_segms.items():
133 | if gt_label in gt_matched:
134 | continue
135 | # crowd segments are ignored
136 | if gt_info['iscrowd'] == 1:
137 | crowd_labels_dict[gt_info['category_id']] = gt_label
138 | continue
139 | pq_stat[gt_info['category_id']].fn += 1
140 |
141 | # count false positives
142 | for pred_label, pred_info in pred_segms.items():
143 | if pred_label in pred_matched:
144 | continue
145 | # intersection of the segment with VOID
146 | intersection = gt_pred_map.get((VOID, pred_label), 0)
147 | # plus intersection with corresponding CROWD region if it exists
148 | if pred_info['category_id'] in crowd_labels_dict:
149 | intersection += gt_pred_map.get((crowd_labels_dict[pred_info['category_id']], pred_label), 0)
150 | # predicted segment is ignored if more than half of the segment correspond to VOID and CROWD regions
151 | if intersection / pred_info['area'] > 0.5:
152 | continue
153 | pq_stat[pred_info['category_id']].fp += 1
154 | return pq_stat
155 |
156 |
157 | @get_traceback
158 | def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories):
159 | pq_stat = PQStat()
160 | idx = 0
161 | for gt_ann, pred_ann in annotation_set:
162 | if idx % 100 == 0:
163 | print('Core: {}, {} from {} images processed'.format(proc_id, idx, len(annotation_set)))
164 | idx += 1
165 |
166 | pan_gt = np.array(Image.open(os.path.join(gt_folder, gt_ann['file_name'])), dtype=np.uint32)
167 | pan_gt = rgb2id(pan_gt)
168 | pan_pred = np.array(Image.open(os.path.join(pred_folder, pred_ann['file_name'])), dtype=np.uint32)
169 | pan_pred = rgb2id(pan_pred)
170 |
171 | # downsize pred and upsize back; resolve the lost segments
172 | ratio = 8
173 | if ratio > 1:
174 | pan_pred = id2rgb(pan_pred)
175 | pan_pred = Image.fromarray(pan_pred)
176 |             w, h = pan_pred.size  # PIL .size is (width, height)
177 |
178 |             pan_pred = pan_pred\
179 |                 .resize((w // ratio, h // ratio), Image.NEAREST)\
180 |                 .resize((w, h), Image.NEAREST)
181 | pan_pred = np.array(pan_pred, dtype=np.uint32)
182 | pan_pred = rgb2id(pan_pred)
183 | acc = []
184 | _p_segs = pred_ann['segments_info']
185 | for el in _p_segs:
186 | iid = el['id']
187 | if iid not in pan_pred:
188 | continue
189 | acc.append(el)
190 | pred_ann['segments_info'] = acc
191 |
192 | _single_img_stat = pq_compute_single_img(
193 | categories, gt_ann, pan_gt, pred_ann, pan_pred)
194 | pq_stat += _single_img_stat
195 |
196 | print('Core: {}, all {} images processed'.format(proc_id, len(annotation_set)))
197 | return pq_stat
198 |
199 |
200 | def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories):
201 | cpu_num = multiprocessing.cpu_count()
202 | annotations_split = np.array_split(matched_annotations_list, cpu_num)
203 | print("Number of cores: {}, images per core: {}".format(cpu_num, len(annotations_split[0])))
204 | workers = multiprocessing.Pool(processes=cpu_num)
205 | processes = []
206 | for proc_id, annotation_set in enumerate(annotations_split):
207 | p = workers.apply_async(pq_compute_single_core,
208 | (proc_id, annotation_set, gt_folder, pred_folder, categories))
209 | processes.append(p)
210 | pq_stat = PQStat()
211 | for p in processes:
212 | pq_stat += p.get()
213 | return pq_stat
214 |
215 |
216 | def pq_compute(gt_json_file, pred_json_file, gt_folder=None, pred_folder=None):
217 |
218 | start_time = time.time()
219 | with open(gt_json_file, 'r') as f:
220 | gt_json = json.load(f)
221 | with open(pred_json_file, 'r') as f:
222 | pred_json = json.load(f)
223 |
224 | if gt_folder is None:
225 | gt_folder = gt_json_file.replace('.json', '')
226 | if pred_folder is None:
227 | pred_folder = pred_json_file.replace('.json', '')
228 | categories = {el['id']: el for el in gt_json['categories']}
229 |
230 | print("Evaluation panoptic segmentation metrics:")
231 | print("Ground truth:")
232 | print("\tSegmentation folder: {}".format(gt_folder))
233 | print("\tJSON file: {}".format(gt_json_file))
234 | print("Prediction:")
235 | print("\tSegmentation folder: {}".format(pred_folder))
236 | print("\tJSON file: {}".format(pred_json_file))
237 |
238 | if not os.path.isdir(gt_folder):
239 | raise Exception("Folder {} with ground truth segmentations doesn't exist".format(gt_folder))
240 | if not os.path.isdir(pred_folder):
241 | raise Exception("Folder {} with predicted segmentations doesn't exist".format(pred_folder))
242 |
243 | pred_annotations = {el['image_id']: el for el in pred_json['annotations']}
244 | matched_annotations_list = []
245 | for gt_ann in gt_json['annotations']:
246 | image_id = gt_ann['image_id']
247 | if image_id not in pred_annotations:
248 | raise Exception('no prediction for the image with id: {}'.format(image_id))
249 | matched_annotations_list.append((gt_ann, pred_annotations[image_id]))
250 |
251 | pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories)
252 |
253 | metrics = [("All", None), ("Things", True), ("Stuff", False)]
254 | results = {}
255 | for name, isthing in metrics:
256 | results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing)
257 | if name == 'All':
258 | results['per_class'] = per_class_results
259 | print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N"))
260 | print("-" * (10 + 7 * 4))
261 |
262 | for name, _isthing in metrics:
263 | print("{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format(
264 | name,
265 | 100 * results[name]['pq'],
266 | 100 * results[name]['sq'],
267 | 100 * results[name]['rq'],
268 | results[name]['n'])
269 | )
270 |
271 | t_delta = time.time() - start_time
272 | print("Time elapsed: {:0.2f} seconds".format(t_delta))
273 |
274 | return results
275 |
276 |
277 | if __name__ == "__main__":
278 | parser = argparse.ArgumentParser()
279 | parser.add_argument('--gt_json_file', type=str,
280 | help="JSON file with ground truth data")
281 | parser.add_argument('--pred_json_file', type=str,
282 | help="JSON file with predictions data")
283 | parser.add_argument('--gt_folder', type=str, default=None,
284 |                         help="Folder with ground truth COCO format segmentations. \
285 | Default: X if the corresponding json file is X.json")
286 | parser.add_argument('--pred_folder', type=str, default=None,
287 | help="Folder with prediction COCO format segmentations. \
288 | Default: X if the corresponding json file is X.json")
289 | args = parser.parse_args()
290 | pq_compute(args.gt_json_file, args.pred_json_file, args.gt_folder, args.pred_folder)
291 |
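
# A tiny worked example of the PQ bookkeeping above (illustrative, not part of
# the evaluator). One category with 2 TPs (IoUs 0.8 and 0.6), 1 FP, 1 FN gives
# RQ = TP / (TP + 0.5*FP + 0.5*FN) = 2/3, SQ = mean matched IoU = 0.7, and
# PQ = SQ * RQ = iou_sum / (TP + 0.5*FP + 0.5*FN) = 1.4/3 ~ 0.467.
def _pq_toy_example():
    stat = PQStat()
    stat[1].tp, stat[1].iou, stat[1].fp, stat[1].fn = 2, 0.8 + 0.6, 1, 1
    cats = {1: {'name': 'person', 'isthing': 1}}
    overall, per_class = stat.pq_average(cats, isthing=None)
    print(overall)    # {'pq': 0.4666..., 'sq': 0.7, 'rq': 0.6666..., 'n': 1}
    print(per_class)  # {'person': {'pq': ..., 'sq': ..., 'rq': ...}}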
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from . import cfg
6 | from .datasets.base import ignore_index
7 | from detectron2.layers.batch_norm import AllReduce
8 | import torch.distributed as dist
9 |
10 |
11 | class Loss(nn.Module):
12 | def __init__(self, w_vote, w_sem):
13 | super().__init__()
14 | self.w_vote = w_vote
15 | self.w_sem = w_sem
16 |
17 | def forward(self, *args, **kwargs):
18 |         '''this 3-step interface is enforced for decomposable loss visualization
19 | see /vis.py
20 | '''
21 | sem_loss_mask, vote_loss_mask = self.per_pix_loss(*args, **kwargs)
22 |         # risky: when no weight mask is needed, the caller must still pass a
23 |         # placeholder that normalize() simply ignores
24 |         if len(args) == 5:
25 |             sem_weight_mask = vote_weight_mask = args[-1]
26 |         else:
27 |             sem_weight_mask, vote_weight_mask = args[-2:]
27 | sem_loss_mask, vote_loss_mask = self.normalize(
28 | sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask
29 | )
30 | loss, s_loss, v_loss = self.aggregate(sem_loss_mask, vote_loss_mask)
31 | return loss, s_loss, v_loss
32 |
33 | def per_pix_loss(self, *args, **kwargs):
34 | '''compute raw per pixel loss that reflects how good a pixel decision is
35 | do not apply any segment based normalization yet
36 | '''
37 | raise NotImplementedError()
38 |
39 |     def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
40 | raise NotImplementedError()
41 |
42 | def aggregate(self, sem_loss_mask, vote_loss_mask):
43 | s_loss, v_loss = sem_loss_mask.sum(), vote_loss_mask.sum()
44 | loss = self.w_sem * s_loss + self.w_vote * v_loss
45 | return loss, s_loss, v_loss
46 |
47 |
48 | def loss_normalize(loss, weight_mask, local_batch_size):
49 | '''An abomination of a method. Have to live with it for now but
50 | I will delete it soon enough. This 'total' option makes no sense except to
51 | serve the expedience of making the loss small
52 | '''
53 | norm = cfg.loss.normalize
54 | if norm == 'batch':
55 | loss = (loss * weight_mask) / local_batch_size
56 | elif norm == 'total':
57 | loss = (loss * weight_mask) / (weight_mask.sum() + 1e-10)
58 | elif norm == 'segs_across_batch':
59 | global_seg_cnt = weight_mask.sum()
60 | if torch.distributed.is_initialized():
61 | global_seg_cnt = AllReduce.apply(global_seg_cnt) / dist.get_world_size()
62 | # print('local sum {:.2f} vs global sum {:.2f}'.format(weight_mask.sum(), weight_mask_sum))
63 | loss = (loss * weight_mask) / (global_seg_cnt + 1e-10)
64 | else:
65 | raise ValueError('invalid value for normalization: {}'.format(norm))
66 | return loss
67 |
68 |
69 | class PanLoss(Loss):
70 | def per_pix_loss(self, sem_pred, vote_pred, sem_mask, vote_mask, *args, **kwargs):
71 | sem_loss_mask = F.cross_entropy(
72 | sem_pred, sem_mask, ignore_index=ignore_index, reduction='none'
73 | )
74 | vote_loss_mask = F.cross_entropy(
75 | vote_pred, vote_mask, ignore_index=ignore_index, reduction='none'
76 | )
77 | vote_loss_mask = vote_loss_mask # * 0.1
78 | return sem_loss_mask, vote_loss_mask
79 |
80 | def normalize(self, sem_loss_mask, vote_loss_mask, *args, **kwargs):
81 | sem_loss_mask = sem_loss_mask / sem_loss_mask.numel()
82 | vote_loss_mask = vote_loss_mask / vote_loss_mask.numel()
83 | return sem_loss_mask, vote_loss_mask
84 |
85 |
86 | class MaskedPanLoss(PanLoss):
87 | def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
88 | assert len(sem_loss_mask) == len(vote_loss_mask)
89 | assert sem_weight_mask is not None and vote_weight_mask is not None
90 | batch_size = len(sem_loss_mask)
91 | sem_loss_mask = loss_normalize(sem_loss_mask, sem_weight_mask, batch_size)
92 | vote_loss_mask = loss_normalize(vote_loss_mask, vote_weight_mask, batch_size)
93 | return sem_loss_mask, vote_loss_mask
94 |
95 |
96 | class DeeperlabPanLoss(PanLoss):
97 | def normalize(self, sem_loss_mask, vote_loss_mask, sem_weight_mask, vote_weight_mask):
98 | assert len(sem_loss_mask) == len(vote_loss_mask)
99 | assert sem_weight_mask is not None and vote_weight_mask is not None
100 | batch_size = len(sem_loss_mask)
101 | sem_loss_mask = deeperlab_loss_normalize(sem_loss_mask, sem_weight_mask, batch_size)
102 | vote_loss_mask = deeperlab_loss_normalize(vote_loss_mask, vote_weight_mask, batch_size)
103 | return sem_loss_mask, vote_loss_mask
104 |
105 |
106 | def deeperlab_loss_normalize(loss, weight_mask, local_batch_size):
107 |     new_weight = (weight_mask > 1/16/16).float()*2+1  # weight > 1/(16*16), i.e. area < 16x16; small segments get 3x weight
108 | flat_loss = loss.reshape(loss.shape[0], -1) # b x h x w -> b x hw
109 | new_weight = new_weight.reshape(new_weight.shape[0], -1)
110 |
111 | topk_loss, topk_inx = flat_loss.topk(int(flat_loss.shape[-1] * 0.15), sorted=False, dim=-1)
112 | topk_weight = new_weight.gather(1, topk_inx)
113 |
114 | loss = (topk_loss * topk_weight).mean()
115 | return loss
116 |
117 | # class TsrCoalesceLoss(Loss):
118 | # def per_pix_loss(
119 | # self, sem_pred, vote_pred, sem_mask, vote_mask, vote_bool_tsr, weight_mask
120 | # ):
121 | # del vote_mask # vote_mask is accepted merely for parameter compatibility
122 | # sem_loss_mask = self.regular_ce_loss(sem_pred, sem_mask, weight_mask)
123 | # vote_loss_mask = self.unsophisticated_loss(vote_pred, vote_bool_tsr, weight_mask)
124 | # return sem_loss_mask, vote_loss_mask
125 |
126 | # @staticmethod
127 | # def regular_ce_loss(pred, lbls, weight_mask):
128 | # return MaskedPanLoss._single_loss(pred, lbls, weight_mask)
129 |
130 | # @staticmethod
131 | # def booltsr_loss(pred, bool_tsr, weight_mask):
132 | # raise ValueError('cannot be back-propagated')
133 | # is_valid = bool_tsr.any(dim=1) # [N, H, W]
134 | # weight_mask = weight_mask[is_valid] # [num_valid, ]
135 | # # pred = pred.permute(0, 2, 3, 1)
136 | # # bool_tsr = bool_tsr.permute(0, 2, 3, 1)
137 | # # pred, bool_tsr = pred[is_valid], bool_tsr[is_valid] # [num_valid, C]
138 |
139 | # bottom = torch.logsumexp(pred, dim=1)
140 | # pred = torch.where(bool_tsr, pred, torch.tensor(float('-inf')).cuda())
141 | # pred = torch.logsumexp(pred, dim=1)
142 | # loss = (bottom - pred)[is_valid] # -1 is implicit here by reversing order
143 | # loss = (loss * weight_mask).sum() / weight_mask.sum()
144 | # return loss
145 |
146 | # @staticmethod
147 | # def unsophisticated_loss(pred, bool_tsr, weight_mask):
148 | # raise ValueError('validity mask changes spatial shape; embarassing')
149 | # is_valid = bool_tsr.any(dim=1) # [N, H, W]
150 | # weight_mask = weight_mask[is_valid]
151 |
152 | # pred = F.softmax(pred, dim=1)
153 | # pred = torch.where(bool_tsr, pred, torch.tensor(0.).cuda())
154 | # loss = torch.log(pred.sum(dim=1)[is_valid])
155 | # loss = loss_normalize(loss, weight_mask, len(pred))
156 | # loss = -1 * loss
157 | # return loss
158 |
159 |
160 | # class MaskedKLPanLoss(PanLoss):
161 | # '''
162 | # cross entropy loss for semantic segmentation and KL-divergence loss for
163 | # voting
164 | # '''
165 | # def per_pix_loss(
166 | # self, sem_pred, vote_pred,
167 | # sem_mask, vote_mask, vote_gt_prob, weight_mask
168 | # ):
169 | # sem_loss_mask = self.sem_loss(sem_pred, sem_mask, weight_mask)
170 | # vote_loss_mask = self.vote_loss(vote_pred, vote_gt_prob, weight_mask)
171 | # return sem_loss_mask, vote_loss_mask
172 |
173 | # @staticmethod
174 | # def sem_loss(sem_pred, sem_mask, weight_mask):
175 | # loss = F.cross_entropy(
176 | # sem_pred, sem_mask, ignore_index=ignore_index, reduction='none'
177 | # )
178 | # loss = loss_normalize(loss, weight_mask, len(sem_pred))
179 | # return loss
180 |
181 | # @staticmethod
182 | # def vote_loss(vote_pred, vote_gt_prob, weight_mask):
183 | # loss = F.kl_div(
184 | # F.log_softmax(vote_pred, dim=1), vote_gt_prob, reduction='none'
185 | # )
186 | # loss = loss.sum(dim=1)
187 | # loss = loss_normalize(loss, weight_mask, len(vote_pred))
188 | # return loss
189 |
190 |
191 | # class _CELoss(nn.Module):
192 | # def __init__(self):
193 | # super().__init__()
194 |
195 | # def forward(self, pred, mask):
196 | # """Assume that ignore index is set at 0
197 | # pred: [N, C, H, W]
198 | # mask: [N, H, W]
199 | # """
200 | # loss = F.cross_entropy(
201 | # pred, mask, ignore_index=ignore_index, reduction='mean'
202 | # )
203 | # return loss
204 |
205 |
206 | class NormalizedFocalPanLoss(nn.Module):
207 |     # Adapted from AdaptIS
208 | def __init__(self, w_vote=0.5, w_sem=0.5, gamma=0, alpha=1):
209 | super().__init__()
210 | # assert w_vote + w_sem == 1
211 | self.w_vote = w_vote
212 | self.w_sem = w_sem
213 | self.gamma = gamma
214 |
215 | def focal_loss(self, input, target):
216 | logpt = - F.cross_entropy(input, target, ignore_index=ignore_index, reduction='none')
217 | pt = torch.exp(logpt)
218 |
219 | beta = (1 - pt) ** self.gamma
220 |
221 | t = target != ignore_index
222 |         t_sum = t.float().sum(dim=(-2, -1), keepdim=True)
223 |         beta_sum = beta.sum(dim=(-2, -1), keepdim=True)
224 |
225 | eps = 1e-10
226 |
227 | mult = t_sum / (beta_sum + eps)
228 |         # detach: the normalizer should not receive gradient
229 |         mult = mult.detach()
230 | beta = beta * mult
231 |
232 | loss = - beta * logpt.clamp(min=-20) # B x H x W
233 |
234 | # size average
235 |         loss = loss.sum(dim=(-2, -1)) / t_sum.squeeze()
236 | loss = loss.sum()
237 |
238 | return loss
239 |
240 | def forward(self, sem_pred, vote_pred, sem_mask, vote_mask, *args, **kwargs):
241 | sem_loss = self.focal_loss(sem_pred, sem_mask)
242 | vote_loss = self.focal_loss(vote_pred, vote_mask)
243 | loss = self.w_vote * vote_loss + self.w_sem * sem_loss
244 | return loss, sem_loss, vote_loss
245 |
246 |
247 | # class MaskedPanLoss(Loss):
248 | # def forward(self, sem_pred, vote_pred, sem_mask, vote_mask, weight_mask, vote_weight_mask=None):
249 | # sem_loss = self._single_loss(sem_pred, sem_mask, weight_mask)
250 | # if vote_weight_mask is None:
251 | # vote_weight_mask = weight_mask
252 | # vote_loss = self._single_loss(vote_pred, vote_mask, vote_weight_mask)
253 | # loss = self.w_vote * vote_loss + self.w_sem * sem_loss
254 | # return loss, sem_loss, vote_loss
255 |
256 | # @staticmethod
257 | # def _single_loss(pred, lbls, weight_mask):
258 | # loss = F.cross_entropy(
259 | # pred, lbls, ignore_index=ignore_index, reduction='none'
260 | # )
261 | # loss = loss_normalize(loss, weight_mask, len(pred))
262 | # return loss
263 |
264 |
265 | # class TsrCoalesceLoss(nn.Module):
266 | # def __init__(self, w_vote=0.5, w_sem=0.5):
267 | # super().__init__()
268 | # self.w_vote = w_vote
269 | # self.w_sem = w_sem
270 |
271 | # def forward(
272 | # self, sem_pred, vote_pred, sem_mask, vote_mask, vote_bool_tsr, weight_mask
273 | # ):
274 | # del vote_mask
275 | # sem_loss = self.regular_ce_loss(sem_pred, sem_mask, weight_mask)
276 | # vote_loss = self.unsophisticated_loss(vote_pred, vote_bool_tsr, weight_mask)
277 | # loss = self.w_vote * vote_loss + self.w_sem * sem_loss
278 | # return loss, sem_loss, vote_loss
279 |
280 | # @staticmethod
281 | # def regular_ce_loss(pred, lbls, weight_mask):
282 | # return MaskedPanLoss._single_loss(pred, lbls, weight_mask)
283 |
284 | # @staticmethod
285 | # def booltsr_loss(pred, bool_tsr, weight_mask):
286 | # raise ValueError('cannot be back-propagated')
287 | # is_valid = bool_tsr.any(dim=1) # [N, H, W]
288 | # weight_mask = weight_mask[is_valid] # [num_valid, ]
289 | # # pred = pred.permute(0, 2, 3, 1)
290 | # # bool_tsr = bool_tsr.permute(0, 2, 3, 1)
291 | # # pred, bool_tsr = pred[is_valid], bool_tsr[is_valid] # [num_valid, C]
292 |
293 | # bottom = torch.logsumexp(pred, dim=1)
294 | # pred = torch.where(bool_tsr, pred, torch.tensor(float('-inf')).cuda())
295 | # pred = torch.logsumexp(pred, dim=1)
296 | # loss = (bottom - pred)[is_valid] # -1 is implicit here by reversing order
297 | # loss = (loss * weight_mask).sum() / weight_mask.sum()
298 | # return loss
299 |
300 | # @staticmethod
301 | # def unsophisticated_loss(pred, bool_tsr, weight_mask):
302 | # is_valid = bool_tsr.any(dim=1) # [N, H, W]
303 | # weight_mask = weight_mask[is_valid]
304 |
305 | # pred = F.softmax(pred, dim=1)
306 | # pred = torch.where(bool_tsr, pred, torch.tensor(0.).cuda())
307 | # loss = torch.log(pred.sum(dim=1)[is_valid])
308 | # loss = loss_normalize(loss, weight_mask, len(pred))
309 | # loss = -1 * loss
310 | # return loss
311 |
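
# A standalone sketch of the 3-step per_pix_loss -> normalize -> aggregate
# interface above, with dummy tensors (not wired to cfg; ignore_index assumed
# to be -1 as in datasets/gt_producer.py):
def _three_step_demo():
    B, C, H, W = 2, 4, 8, 8
    sem_pred, vote_pred = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
    sem_mask = torch.randint(0, C, (B, H, W))
    vote_mask = torch.randint(0, C, (B, H, W))
    # step 1: raw per-pixel losses, no segment-based normalization yet
    s = F.cross_entropy(sem_pred, sem_mask, ignore_index=-1, reduction='none')
    v = F.cross_entropy(vote_pred, vote_mask, ignore_index=-1, reduction='none')
    # step 2: normalize (PanLoss style: divide by element count)
    s, v = s / s.numel(), v / v.numel()
    # step 3: aggregate the two branches with w_sem = w_vote = 0.5
    loss = 0.5 * s.sum() + 0.5 * v.sum()
    print(loss)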
--------------------------------------------------------------------------------
/src/pan_vis.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | from PIL import Image
4 | import matplotlib.pyplot as plt
5 |
6 | from IPython.display import display as ipy_display
7 | from ipywidgets import interactive
8 | import ipywidgets as widgets
9 |
10 | from panopticapi.utils import rgb2id
11 | from panoptic.pan_analyzer import (
12 | PanopticEvalAnalyzer,
13 | _SEGMENT_MATCHED, _SEGMENT_UNMATCHED, _SEGMENT_FORGIVEN
14 | )
15 |
16 |
17 | class PanVis():
18 | def __init__(self, img_root, gt_json_meta_fname, pd_json_meta_fname):
19 |         """Expect the pan mask dir to be right beside the meta json file, e.g.
20 | val.json
21 | val/
22 | Args:
23 | img_root: root dir where images are stored
24 | gt_json_meta_fname: abs fname to gt json
25 | pd_json_meta_fname: ...
26 | """
27 | self.img_root = img_root
28 |
29 | analyzer = PanopticEvalAnalyzer(gt_json_meta_fname, pd_json_meta_fname)
30 | self.gt, self.pd = analyzer.gt, analyzer.pd
31 | self.imgIds = analyzer.imgIds
32 | self.res_dframe, self.overall_table, self.cat_table = analyzer.summarize()
33 | # cached widgets
34 | self.global_walk = None
35 | self.__init_global_state__()
36 |
37 | def evaluate(self):
38 | # this is now a dud for backwards compatibility
39 | pass
40 |
41 | def summarize(self):
42 | print(self.cat_table)
43 | print(self.overall_table)
44 |
45 | def __init_global_state__(self):
46 | self.global_state = {
47 | 'imgId': None, 'segId': None,
48 | 'catId': None, 'tranche': None
49 | }
50 |
51 | # the following are the modules for widgets, from root to leaf
52 | def root_wdgt(self):
53 | """
54 | root widget delegates to either global or image
55 | """
56 | self.summarize()
57 | modes = ['Global', 'Single-Image']
58 |
59 | def logic(mode):
60 | # cache the widget later
61 | if mode == modes[0]:
62 | if self.global_walk is None:
63 | self.global_walk = self.global_walk_specifier()
64 | ipy_display(self.global_walk)
65 | elif mode == modes[1]:
66 | self.image_view = self.single_image_selector()
67 | # if self.image_view is None:
68 | # self.image_view = self.single_image_selector()
69 | # ipy_display(self.image_view)
70 |
71 | UI = interactive(
72 | logic, mode=widgets.ToggleButtons(options=modes, value=modes[0])
73 | )
74 | UI.children[-1].layout.height = '1000px'
75 | ipy_display(UI)
76 |
77 | def global_walk_specifier(self):
78 | tranche_map = self._tranche_filter(self.gt.segs, self.pd.segs)
79 |
80 | def logic(catId, tranche):
81 | if self.global_state['catId'] != catId \
82 | or self.global_state['tranche'] != tranche:
83 | self.__init_global_state__()
84 | self.global_state['catId'] = catId
85 | self.global_state['tranche'] = tranche
86 | seg_list = self._cat_filter_and_merge_tranche_map(
87 | tranche_map, [catId], [tranche]
88 | )
89 | # areas = [ seg['area'] for seg in seg_list ]
90 | # plt.hist(areas)
91 | self.walk_primary(seg_list, is_global=True)
92 | UI = interactive(
93 | logic,
94 | catId=self._category_roulette(self.gt.cats.keys(), multi_select=False),
95 | tranche=widgets.Select(
96 | options=tranche_map.keys(), value=list(tranche_map.keys())[0]
97 | )
98 | )
99 | return UI
100 |
101 | def single_image_selector(self):
102 | imgIds = self.imgIds
103 | inx, txt = self._inx_txt_scroller_pair(
104 | imgIds, default_txt=self.global_state['imgId'])
105 |
106 | def logic(inx):
107 | print("curr image {}/{}".format(inx, len(imgIds)))
108 | imgId = imgIds[inx]
109 | self.single_image_view_specifier(imgId)
110 | UI = interactive(logic, inx=inx)
111 | ipy_display(txt)
112 | ipy_display(UI)
113 |
114 | def single_image_view_specifier(self, imgId):
115 | gt_segs, pd_segs = self.gt.img2seg[imgId], self.pd.img2seg[imgId]
116 | _gt_cats = {seg['category_id'] for seg in gt_segs.values()}
117 | _pd_cats = {seg['category_id'] for seg in pd_segs.values()}
118 | relevant_catIds = _gt_cats | _pd_cats
119 | tranche_map = self._tranche_filter(gt_segs, pd_segs)
120 | modes = ['bulk', 'walk']
121 |
122 | def logic(catIds, tranches, mode):
123 | # only for walk, not for bulk display
124 | seg_list = self._cat_filter_and_merge_tranche_map(
125 | tranche_map, catIds, tranches)
126 | if mode == modes[0]:
127 | self.single_image_bulk_display(seg_list)
128 | elif mode == modes[1]:
129 | self.walk_primary(seg_list)
130 | UI = interactive(
131 | logic,
132 | mode=widgets.ToggleButtons(options=modes, value=modes[0]),
133 | catIds=self._category_roulette(
134 | relevant_catIds, multi_select=True,
135 | default_cid=[self.global_state['catId']]
136 | ),
137 | tranches=widgets.SelectMultiple(
138 | options=tranche_map.keys(),
139 | value=[self.global_state['tranche']]
140 | )
141 | )
142 | ipy_display(UI)
143 |
144 | def single_image_bulk_display(self, segs):
145 | if len(segs) == 0:
146 |             return print('no segments in this tranche')
147 | imgId = segs[0]['image_id']
148 | for seg in segs:
149 | assert seg['image_id'] == imgId
150 | segIds = list(map(lambda x: x['sid'], segs))
151 | gt_seg_ids = list(filter(lambda x: x.startswith('gt/'), segIds))
152 | pd_seg_ids = list(filter(lambda x: x.startswith('pd/'), segIds))
153 | self.single_image_plot(imgId, gt_seg_ids, pd_seg_ids)
154 |
155 | def walk_primary(self, segs, is_global=False):
156 | """segs: a list of seg"""
157 | # the watching logic here is quite messy
158 | sids = [seg['sid'] for seg in segs]
159 | if len(sids) == 0:
160 |             return print('no available segs')
161 | inx, txt = self._inx_txt_scroller_pair(
162 | sids, default_txt=self.global_state['segId'] if is_global else None
163 | )
164 |
165 | def logic(inx):
166 | seg = segs[inx]
167 | if is_global:
168 | self.global_state['segId'] = seg['sid']
169 | self.global_state['imgId'] = seg['image_id']
170 | print("Primary seg {}/{} matches with {} segments".format(
171 | inx, len(segs), len(seg['matchings'])))
172 | self.walk_matched(seg)
173 | UI = interactive(logic, inx=inx)
174 | print("Primary Segment:")
175 | ipy_display(txt)
176 | ipy_display(UI)
177 |
178 | def walk_matched(self, ref_seg):
179 | """child of walk_primary"""
180 | ref_sid = ref_seg['sid']
181 | # note that matchings is {sid: IoU}
182 | matched_sids = list(ref_seg['matchings'].keys())
183 | matched_ious = list(ref_seg['matchings'].values())
184 | if len(matched_sids) == 0:
185 | matched_sids = (None, )
186 | matched_ious = (0, )
187 |
188 | def segid_to_catname(partition, sid):
189 | if sid is None:
190 | return 'NA'
191 | return self.gt.cats[partition.segs[sid]['category_id']]['name']
192 |
193 | def logic(inx):
194 | match_sid = matched_sids[inx]
195 | if ref_sid.startswith('gt/'):
196 | gt_sid, pd_sid, highlight = ref_sid, match_sid, 1
197 | imgId = self.gt.segs[ref_sid]['image_id']
198 | else:
199 | gt_sid, pd_sid, highlight = match_sid, ref_sid, 2
200 | imgId = self.pd.segs[ref_sid]['image_id']
201 | print('IoU: {:.3f}'.format(matched_ious[inx]))
202 | print('gt: {} vs pd: {}'.format(
203 | segid_to_catname(self.gt, gt_sid),
204 | segid_to_catname(self.pd, pd_sid)
205 | ))
206 | self.single_image_plot(imgId, gt_sid, pd_sid, highlight)
207 |
208 | inx, txt = self._inx_txt_scroller_pair(matched_sids)
209 | UI = interactive(logic, inx=inx)
210 | print("Matched Segment:")
211 | ipy_display(txt)
212 | ipy_display(UI)
213 |
214 | @staticmethod
215 | def _tranche_filter(gt_segs, pd_segs):
216 | """
217 | Args:
218 | gt_segs: {segId: seg}
219 | pd_segs: {segId: seg}
220 | """
221 | def _filter(state, seg_map):
222 | seg_list = [
223 | seg for seg in seg_map.values() if seg['match_state'] == state
224 | ]
225 | seg_list = sorted(seg_list, key=lambda x: x['area'], reverse=True)
226 | return seg_list
227 |
228 | tranche_map = {
229 | 'TP': _filter(_SEGMENT_MATCHED, pd_segs),
230 | 'FN': _filter(_SEGMENT_UNMATCHED, gt_segs),
231 | 'FP': _filter(_SEGMENT_UNMATCHED, pd_segs),
232 | 'GT_FORGIVEN': _filter(_SEGMENT_FORGIVEN, gt_segs),
233 | 'PD_FORGIVEN': _filter(_SEGMENT_FORGIVEN, pd_segs)
234 | }
235 | assert len(gt_segs) == sum(
236 | map(lambda x: len(tranche_map[x]), ['TP', 'FN', 'GT_FORGIVEN'])
237 | )
238 | assert len(pd_segs) == sum(
239 | map(lambda x: len(tranche_map[x]), ['TP', 'FP', 'PD_FORGIVEN'])
240 | )
241 | return tranche_map
242 |
243 | @staticmethod
244 | def _cat_filter_and_merge_tranche_map(tranche_map, catIds, chosen_tranches):
245 | local_tranche_map = {
246 | k: list(filter(lambda seg: seg['category_id'] in catIds, seg_list))
247 | for k, seg_list in tranche_map.items()
248 | }
249 | for k, v in local_tranche_map.items():
250 | print("{}: {}".format(k, len(v)), end='; ')
251 | print('')
252 | seg_list = sum([local_tranche_map[_tr] for _tr in chosen_tranches], [])
253 | return seg_list
254 |
255 | @staticmethod
256 | def _inx_txt_scroller_pair(sids, default_txt=None):
257 | """
258 | Args:
259 | sids: [str, ] segment ids
260 | Note that since a handler is only called if 'value' changes, this mutual
261 | watching would not lead to infinite back-and-forth bouncing.
262 | In addition, bouncing-back is prevented by internal_change flag.
263 | """
264 | assert len(sids) > 0
265 | if default_txt is not None:
266 | default_inx, default_txt = sids.index(default_txt), default_txt
267 | else:
268 | default_inx, default_txt = 0, sids[0]
269 | inx = widgets.BoundedIntText(value=default_inx, min=0, max=len(sids) - 1)
270 | txt = widgets.Text(value=str(default_txt), description='ID')
271 | internal_change = False
272 |
273 | def inx_update_reaction(*args):
274 | nonlocal internal_change
275 | if internal_change:
276 | internal_change = False
277 | return
278 | curr_inx = inx.value
279 | curr_sid = sids[curr_inx]
280 | internal_change = True
281 | txt.value = curr_sid
282 | inx.observe(inx_update_reaction, 'value')
283 |
284 | def txt_update_reaction(*args):
285 | nonlocal internal_change
286 | if internal_change:
287 | internal_change = False
288 | return
289 | curr_sid = txt.value
290 | if curr_sid in sids:
291 | curr_inx = sids.index(curr_sid)
292 | internal_change = True
293 | inx.value = curr_inx
294 | txt.observe(txt_update_reaction, 'value')
295 |
296 | return inx, txt
297 |
298 | def _category_roulette(
299 | self, selected_catIds, multi_select=False, default_cid=None,
300 | ):
301 | """
302 | Things first, Stuff next, each sorted from high to low by PQ
303 | Note that this roulette is multi-selective, and return a tuple of catIds
304 | e.g.
305 | T, 16.60, Person
306 | T, 15.12, Bicycle
307 | S, 32.10, Road
308 | """
309 | catIds = np.array(sorted(self.gt.cats.keys()))
310 | PQ = self.res_dframe.values[:, 0] # (num_cats, )
311 |
312 | # first filter by selection, then sort by PQ from high to low
313 | chosen_mask = np.array(
314 |             [ catId in selected_catIds for catId in catIds ], dtype=bool)
315 | catIds, PQ = catIds[chosen_mask], PQ[chosen_mask]
316 | order = np.argsort(-PQ) # high to low
317 | catIds, PQ = catIds[order], PQ[order]
318 |
319 | # now do things first followed by stuff
320 | acc = []
321 | isthing = np.array(
322 | [self.gt.cats[id]['isthing'] for id in catIds], dtype=bool)
323 | acc += [
324 | ('T, {:>4.2f}, {}'.format(pq, self.gt.cats[cid]['name']), cid)
325 | for pq, cid in zip(PQ[isthing], catIds[isthing])
326 | ]
327 | acc += [
328 | ('S, {:>4.2f}, {}'.format(pq, self.gt.cats[cid]['name']), cid)
329 | for pq, cid in zip(PQ[~isthing], catIds[~isthing])
330 | ]
331 |
332 | if default_cid is None:
333 | default_cid = acc[0][1]
334 | if multi_select and not isinstance(default_cid, (tuple, list)):
335 | default_cid = [default_cid]
336 | _module = widgets.SelectMultiple if multi_select else widgets.Select
337 | roulette = _module(options=acc, rows=15, value=default_cid)
338 | return roulette
339 |
340 | def single_image_plot(
341 | self, imgId, gt_seg_sid_list, pd_seg_sid_list,
342 | highlight=None, seg_alpha=0.7, seg_cmap='Blues'
343 | ):
344 | # first load image and annotations masks
345 | img = np.array(Image.open(
346 | osp.join(self.img_root, self.gt.imgs[imgId]['file_name'])
347 | ))
348 | gt_rgb = np.array(
349 | Image.open(osp.join(
350 | self.gt.mask_root, self.gt.imgs[imgId]['ann_fname']
351 | )),
352 | dtype=np.uint32
353 | )
354 | gt_mask = rgb2id(gt_rgb)
355 | pd_rgb = np.array(
356 | Image.open(osp.join(
357 | self.pd.mask_root, self.pd.imgs[imgId]['ann_fname']
358 | )),
359 | dtype=np.uint32
360 | )
361 | pd_mask = rgb2id(pd_rgb)
362 |
363 | # now aggregate the segment masks for both pd and gt
364 | def aggregate_seg_mask(sid_list, ref_mask, segs_map):
365 | if sid_list is None:
366 | sid_list = []
367 | if not isinstance(sid_list, (list, tuple)):
368 | sid_list = (sid_list, )
369 |             seg_mask = np.zeros(ref_mask.shape, dtype=bool)
370 | for sid in sid_list:
371 | seg = segs_map[sid]
372 | assert seg['image_id'] == imgId
373 | seg_mask |= (ref_mask == seg['id'])
374 | return seg_mask
375 |
376 | gt_seg_mask = aggregate_seg_mask(gt_seg_sid_list, gt_mask, self.gt.segs)
377 | pd_seg_mask = aggregate_seg_mask(pd_seg_sid_list, pd_mask, self.pd.segs)
378 |
379 | # plot them together
380 | WHITE = [255, 255, 255]
381 | fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(30, 12))
382 | axes[0].imshow(img)
383 | gt_rgb[gt_seg_mask] = WHITE
384 | axes[1].imshow(gt_rgb)
385 | # axes[1].imshow(gt_seg_mask, alpha=seg_alpha, cmap=seg_cmap)
386 | pd_rgb[pd_seg_mask] = WHITE
387 | axes[2].imshow(pd_rgb)
388 | # axes[2].imshow(pd_seg_mask, alpha=seg_alpha, cmap=seg_cmap)
389 |
390 | if highlight is not None:
391 | axes[highlight].set_title(
392 | 'frame in focus', bbox=dict(facecolor='orange')
393 | )
394 | plt.show()
395 |
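
# A miniature of the loop-guard pattern used by _inx_txt_scroller_pair above
# (illustrative; the widget values are made up). Two widgets watch each other,
# and the internal flag swallows the echo when one handler updates the other.
def _two_way_binding_demo():
    sids = ['gt/1', 'gt/2', 'pd/7']
    inx = widgets.BoundedIntText(value=0, min=0, max=len(sids) - 1)
    txt = widgets.Text(value=sids[0])
    internal = {'flag': False}  # mutable cell shared by both handlers

    def on_inx(change):
        if internal['flag']:
            internal['flag'] = False
            return
        internal['flag'] = True
        txt.value = sids[change['new']]  # triggers on_txt once, then stops

    def on_txt(change):
        if internal['flag']:
            internal['flag'] = False
            return
        if change['new'] in sids:
            internal['flag'] = True
            inx.value = sids.index(change['new'])

    inx.observe(on_inx, 'value')
    txt.observe(on_txt, 'value')
    inx.value = 2
    print(txt.value)  # 'pd/7' -- no infinite back-and-forth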
--------------------------------------------------------------------------------
/src/datasets/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | import bisect
3 | import copy
4 | import os.path as osp
5 | import json
6 | from functools import partial
7 |
8 | import numpy as np
9 | from PIL import Image
10 | import cv2
11 |
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision.transforms.functional import normalize as tf_norm, to_tensor
15 |
16 | from .. import cfg as global_cfg
17 | from .augmentations import get_composed_augmentations
18 | from .samplers.grouped_batch_sampler import GroupedBatchSampler
19 | from .gt_producer import (
20 | ignore_index, convert_pan_m_to_sem_m, GTGenerator
21 | )
22 | from panopticapi.utils import rgb2id
23 |
24 | from fabric.utils.logging import setup_logging
25 | logger = setup_logging(__file__)
26 |
27 | '''
28 | Expect a data root directory with the following layout (only coco sub-tree is
29 | expanded for simplicity)
30 | .
31 | ├── coco
32 | │ ├── annotations
33 | │ │ ├── train/
34 | │ │ ├── train.json
35 | │ │ ├── val/
36 | │ │ └── val.json
37 | │ └── images
38 | │ ├── train/
39 | │ └── val/
40 | ├── mapillary/
41 | └── cityscapes/
42 | '''
43 |
44 |
45 | '''
46 | A reminder of Panoptic meta json structure: (only the relevant fields are listed)
47 |
48 | info/
49 | licenses/
50 | categories:
51 | - id: 1
52 | name: person
53 | supercategory: person
54 | isthing: 1
55 | - id: 191
56 | name: pavement-merged
57 | supercategory: ground
58 | isthing: 0
59 | images:
60 | - id: 397133
61 | file_name: 000000397133.jpg
62 | height: 427
63 | width: 640
64 | - id: 397134
65 | file_name: 000000397134.jpg
66 | height: 422
67 | width: 650
68 | annotations:
69 | - image_id: 139
70 | file_name: 000000000139.png
71 | segments_info:
72 | - id: 3226956
73 | category_id: 1
74 | iscrowd: 0,
75 | bbox: [413, 158, 53, 138] in xywh form,
76 | area: 2840
77 | - repeat omitted
78 | '''
79 |
80 | imagenet_mean = [0.485, 0.456, 0.406]
81 | imagenet_std = [0.229, 0.224, 0.225]
82 |
83 |
84 | def imagenet_normalize(im):
85 |     """This only operates on a single image"""
86 | if im.shape[0] == 1:
87 | return im # deal with these gray channel images in coco later. Hell
88 | return tf_norm(im, imagenet_mean, imagenet_std)
89 |
90 |
91 | def caffe_imagenet_normalize(im):
92 | im = im * 255.0
93 | im = tf_norm(im, (102.9801, 115.9465, 122.7717), (1.0, 1.0, 1.0))
94 | return im
95 |
96 |
97 | def check_and_tuplize_splits(splits):
98 | if not isinstance(splits, (tuple, list)):
99 | splits = (splits, )
100 | for split in splits:
101 | assert split in ('train', 'val', 'test')
102 | return splits
103 |
104 |
105 | def mapify_iterable(iter_of_dict, field_name):
106 | """Convert an iterable of dicts into a big dict indexed by chosen field
107 | I can't think of a better name. 'Tis catchy.
108 | """
109 | acc = dict()
110 | for item in iter_of_dict:
111 | acc[item[field_name]] = item
112 | return acc
113 |
114 |
115 | def ttic_find_data_root(dset_name):
116 | '''Find the fastest data root on TTIC slurm cluster'''
117 | default = osp.join('/share/data/vision-greg/panoptic', dset_name)
118 |     return default  # NOTE: the fast-scratch lookup below is intentionally disabled
119 | fast = osp.join('/vscratch/vision/panoptic', dset_name)
120 | return fast if osp.isdir(fast) else default
121 |
122 |
123 | def test_meta_conform(rmeta):
124 | '''the metadata for test set does not conform to panoptic format;
125 |     Test ann only has 'images' and 'categories'; let's fill in 'annotations'
126 | '''
127 | images = rmeta['images']
128 | anns = []
129 | for img in images:
130 | _curr_ann = {
131 | 'image_id': img['id'],
132 | 'segments_info': []
133 | # do not fill in file_name
134 | }
135 | anns.append(_curr_ann)
136 | rmeta['annotations'] = anns
137 | return rmeta
138 |
139 |
140 | class PanopticBase(Dataset):
141 | def __init__(self, name, split):
142 | available_dsets = ('coco', 'cityscapes', 'mapillary')
143 | assert name in available_dsets, '{} dset is not available'.format(name)
144 | root = ttic_find_data_root(name)
145 | logger.info('using data root {}'.format(root))
146 | self.root = root
147 | self.name = name
148 | self.split = split
149 | self.img_root = osp.join(root, 'images', split)
150 | self.ann_root = osp.join(root, 'annotations', split)
151 | meta_fname = osp.join(root, 'annotations', '{}.json'.format(split))
152 | with open(meta_fname) as f:
153 | rmeta = json.load(f) # rmeta stands for raw metadata
154 | if self.split.startswith('test'):
155 | rmeta = test_meta_conform(rmeta)
156 |
157 | # store category metadata
158 | self.meta = dict()
159 | self.meta['cats'] = mapify_iterable(rmeta['categories'], 'id')
160 | self.meta['cat_IdToName'] = dict()
161 | self.meta['cat_NameToId'] = dict()
162 | for cat in rmeta['categories']:
163 | id, name = cat['id'], cat['name']
164 | self.meta['cat_IdToName'][id] = name
165 | self.meta['cat_NameToId'][name] = id
166 |
167 | # store image and annotations metadata
168 | self.imgs = mapify_iterable(rmeta['images'], 'id')
169 | self.imgToAnns = mapify_iterable(rmeta['annotations'], 'image_id')
170 | self.imgIds = list(sorted(self.imgs.keys()))
171 |
172 | def confine_to_subset(self, imgIds):
173 | '''confine data loading to a subset of images
174 | This is used for figure making.
175 | '''
176 | # confirm that the supplied imgIds are all valid
177 | for supplied_id in imgIds:
178 | assert supplied_id in self.imgIds
179 | self.imgIds = imgIds
180 |
181 | def test_seek_imgs(self, i, total_splits):
182 | assert isinstance(total_splits, int)
183 | length = len(self.imgIds)
184 | portion_size = int(math.ceil(length * 1.0 / total_splits))
185 | start = i * portion_size
186 | end = min(length, (i + 1) * portion_size)
187 | self.imgIds = self.imgIds[start:end]
188 | acc = { k: self.imgs[k] for k in self.imgIds }
189 | self.imgs = acc
190 |
191 | def __len__(self):
192 | return len(self.imgIds)
193 |
194 | def read_img(self, img_fname):
195 | # there are some gray scale images in coco; convert to RGB
196 | return Image.open(img_fname).convert('RGB')
197 |
198 | def get_meta(self, index):
199 | imgId = self.imgIds[index]
200 | imgMeta = self.imgs[imgId]
201 | anns = self.imgToAnns[imgId]
202 | return imgMeta, anns
203 |
204 | def __getitem__(self, index):
205 | imgMeta, anns = self.get_meta(index)
206 | img_fname = osp.join(self.img_root, imgMeta['file_name'])
207 | img = self.read_img(img_fname)
208 | if self.split.startswith('test'):
209 | mask = Image.fromarray(
210 | np.zeros(np.array(img).shape, dtype=np.uint8)
211 | )
212 | else:
213 | mask = Image.open(osp.join(self.ann_root, anns['file_name']))
214 | segments_info = mapify_iterable(anns['segments_info'], 'id')
215 | return imgMeta, segments_info, img, mask
216 |
217 |
218 | class SemanticSeg(PanopticBase):
219 | def __init__(self, name, split, transforms):
220 | super().__init__(name=name, split=split)
221 | self.transforms = transforms
222 | # produce train cat index id starting at 0
223 | self.meta['catId_2_trainId'] = dict()
224 | self.meta['trainId_2_catId'] = dict()
225 | self.meta['trainId_2_catName'] = dict() # all things grouped into "things"
226 | self.meta['trainId_2_catName'][ignore_index] = 'ignore'
227 | self.prep_trainId()
228 | self.meta['num_classes'] = len(self.meta['trainId_2_catName']) - 1
229 |
230 | def prep_trainId(self):
231 | curr_inx = 0
232 | for catId, cat in self.meta['cats'].items():
233 | self.meta['catId_2_trainId'][catId] = curr_inx
234 | self.meta['trainId_2_catId'][curr_inx] = catId
235 | self.meta['trainId_2_catName'][curr_inx] = cat['name']
236 | curr_inx += 1
237 |
238 | def __getitem__(self, index):
239 |         raise ValueError('disabling data loading through this class for now')
240 | _, segments_info, im, mask = super().__getitem__(index)
241 | mask = np.array(mask)
242 | mask = rgb2id(np.array(mask))
243 | mask = convert_pan_m_to_sem_m(
244 | mask, segments_info, self.meta['catId_2_trainId'])
245 | mask = Image.fromarray(mask, mode='I')
246 | if self.transforms is not None:
247 | im, mask = self.transforms(im, mask)
248 | im, mask = to_tensor(im), to_tensor(mask).squeeze(dim=0).long()
249 | im = imagenet_normalize(im)
250 | return im, mask
251 |
252 |
253 | class PanopticSeg(SemanticSeg):
254 | def __init__(
255 | self, name, split, transforms, pcv, gt_producers,
256 | caffe_mode=False, tensorize=True
257 | ):
258 | super().__init__(name=name, split=split, transforms=transforms)
259 | self.pcv = pcv
260 | self.meta['stuff_pred_thresh'] = -1
261 | self.gt_producer_cfgs = gt_producers
262 | self.tensorize = tensorize
263 | self.caffe_mode = caffe_mode
264 | self.gt_prod_handle = partial(GTGenerator, producer_cfgs=gt_producers)
265 |
266 | def read_img(self, img_fname):
267 | if self.caffe_mode:
268 | # cv2 reads imgs in BGR, which is what caffe trained models expect
269 | img = cv2.imread(img_fname) # cv2 auto converts gray to BGR
270 | img = Image.fromarray(img)
271 | else:
272 | img = Image.open(img_fname).convert('RGB')
273 | return img
274 |
275 | def pan_getitem(self, index, apply_trans=True):
276 | # this is now exposed as public API
277 | imgMeta, segments_info, img, mask = PanopticBase.__getitem__(self, index)
278 | if apply_trans and self.transforms is not None:
279 | img, mask = self.transforms(img, mask)
280 | return imgMeta, segments_info, img, mask
281 |
282 | def __getitem__(self, index):
283 | _, segments_info, im, pan_mask = self.pan_getitem(index)
284 | gts = []
285 | if not self.split.startswith('test'):
286 | lo_pan_mask = pan_mask.resize(
287 |                 np.array(im.size, dtype=int) // 4, resample=Image.NEAREST
288 | )
289 | gts = self.gt_prod_handle(
290 | self.meta, self.pcv, lo_pan_mask, segments_info
291 | ).generate_gt()
292 |
293 | if self.split == 'train':
294 | sem_gt = gts[0]
295 | else:
296 | hi_sem_gt = self.gt_prod_handle(
297 | self.meta, self.pcv, pan_mask, segments_info,
298 | ).sem_gt
299 | sem_gt = hi_sem_gt
300 | gts[0] = sem_gt
301 | # else for test/test-dev, do not produce ground truth at all
302 |
303 | if self.tensorize:
304 | im = to_tensor(im)
305 | gts = [ torch.as_tensor(elem) for elem in gts ]
306 | if self.caffe_mode:
307 | im = caffe_imagenet_normalize(im)
308 | else:
309 | im = imagenet_normalize(im)
310 | else:
311 | im = np.array(im)
312 |
313 | gts.insert(0, im)
314 | return tuple(gts)
315 |
316 | @classmethod
317 | def make_loader(
318 | cls, data_cfg, pcv_module, is_train, mp_distributed, world_size,
319 | val_split='val'
320 | ):
321 | if is_train:
322 | split = 'train'
323 | batch_size = data_cfg.train_batch_size
324 | transforms_cfg = data_cfg.train_transforms
325 | else:
326 | split = val_split
327 | batch_size = data_cfg.test_batch_size
328 | transforms_cfg = data_cfg.test_transforms
329 |
330 | num_workers = data_cfg.num_loading_threads
331 | if mp_distributed:
332 | num_workers = int((num_workers + world_size - 1) / world_size)
333 | if is_train:
334 | batch_size = int(batch_size / world_size)
335 | # at test time a model does not need to reduce its batch size
336 |
337 | # 1. dataset
338 | trans = get_composed_augmentations(transforms_cfg)
339 | instance = cls(
340 | split=split, transforms=trans, pcv=pcv_module,
341 | gt_producers=data_cfg.dataset.gt_producers,
342 | **data_cfg.dataset.params,
343 | )
344 | # if split.startswith('test'):
345 | # inx = global_cfg.testing.inx
346 | # total_splits = global_cfg.testing.portions
347 | # instance.test_seek_imgs(inx, total_splits)
348 |
349 | # 2. sampler
350 | sampler = cls.make_sampler(instance, is_train, mp_distributed)
351 |
352 | # 3. batch sampler
353 | batch_sampler = cls.make_batch_sampler(
354 | instance, sampler, batch_size=batch_size,
355 | aspect_grouping=cls.aspect_grouping,
356 | )
357 | del sampler
358 |
359 | # 4. collator
360 | if split.startswith('test'):
361 | collator = None
362 | else:
363 | collator = BatchCollator(data_cfg.dataset.gt_producers)
364 |
365 | # 5. loader
366 | loader = DataLoader(
367 | instance,
368 | num_workers=num_workers,
369 | batch_sampler=batch_sampler,
370 | collate_fn=collator,
371 | # pin_memory=True maskrcnn-benchmark does not pin memory
372 | )
373 | return loader
374 |
375 | @staticmethod
376 | def make_sampler(dataset, is_train, distributed):
377 | if is_train:
378 | if distributed:
379 | # as of pytorch 1.1.0 the distributed sampler always shuffles
380 | sampler = torch.utils.data.distributed.DistributedSampler(dataset)
381 | else:
382 | sampler = torch.utils.data.sampler.RandomSampler(dataset)
383 | else:
384 | sampler = torch.utils.data.sampler.SequentialSampler(dataset)
385 | return sampler
386 |
387 | @staticmethod
388 | def make_batch_sampler(
389 | dataset, sampler, batch_size, aspect_grouping,
390 | ):
391 | if aspect_grouping:
392 | aspect_ratios = _compute_aspect_ratios(dataset)
393 | group_ids = _quantize(aspect_ratios, bins=[1, ])
394 | batch_sampler = GroupedBatchSampler(
395 | sampler, group_ids, batch_size, drop_uneven=False
396 | )
397 | else:
398 | batch_sampler = torch.utils.data.sampler.BatchSampler(
399 | sampler, batch_size, drop_last=False
400 | )
401 | '''
402 | I've decided after much deliberation not to use iteration based training.
403 | Our cluster has a 4-hour time limit before job interrupt.
404 | Under this constraint, I have to resume from where the sampling stopped
405 | at the exact epoch the interrupt occurs and this requires checkpointing
406 | the sampler state.
407 | However, under distributed settings, each process has to checkpoint a
408 | different state since each sees only a portion of the data by virtue of the
409 | distributed sampler. This is bug-prone and brittle.
410 | '''
411 | # if num_iters is not None:
412 | # batch_sampler = samplers.IterationBasedBatchSampler(
413 | # batch_sampler, num_iters, start_iter
414 | # )
415 | return batch_sampler
416 |
417 |
418 | class BatchCollator(object):
419 | def __init__(self, producer_cfg):
420 | fills = [0, ignore_index, ] # always 0 for img, ignore for sem_gt
421 | fills.extend(
422 | [cfg['params']['fill'] for cfg in producer_cfg]
423 | )
424 | self.fills = fills
425 |
426 | def __call__(self, batch):
427 | transposed_batch = list(zip(*batch))
428 | assert len(self.fills) == len(transposed_batch), 'must match in length'
429 | acc = []
430 | for tsr_list, fill in zip(transposed_batch, self.fills):
431 | tsr = self.collate_tensor_list(tsr_list, fill=fill)
432 | acc.append(tsr)
433 | return tuple(acc)
434 |
435 | @staticmethod
436 | def collate_tensor_list(tensors, fill):
437 | """
438 | Pad the Tensors with the specified constant
439 | so that they have the same shape
440 | """
441 | assert isinstance(tensors, (tuple, list))
442 | # largest size along each dimension
443 | max_size = tuple(max(s) for s in zip(*[tsr.shape for tsr in tensors]))
444 | assert len(max_size) == 2 or len(max_size) == 3
445 | batch_shape = (len(tensors),) + max_size
446 | batched_tsr = tensors[0].new(*batch_shape).fill_(fill)
447 | for tsr, pad_tsr in zip(tensors, batched_tsr):
448 | # WARNING only pad the last 2, that is the spatial dimensions
449 | pad_tsr[..., :tsr.shape[-2], :tsr.shape[-1]].copy_(tsr)
450 | return batched_tsr
451 |
452 |
453 | def _quantize(x, bins):
454 | bins = copy.copy(bins)
455 | bins = sorted(bins)
456 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
457 | return quantized
458 |
459 |
460 | def _compute_aspect_ratios(dataset):
461 | aspect_ratios = []
462 | for i in range(len(dataset)):
463 | imgMeta, _ = dataset.get_meta(i)
464 | aspect_ratio = float(imgMeta["height"]) / float(imgMeta["width"])
465 | aspect_ratios.append(aspect_ratio)
466 | return aspect_ratios
467 |
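
# A small demonstration of what BatchCollator.collate_tensor_list does to a
# batch of unevenly sized masks (illustrative shapes): each tensor is padded on
# the bottom/right with the fill value up to the per-dimension max shape.
def _collate_demo():
    a = torch.ones(2, 3, dtype=torch.long)
    b = torch.ones(3, 2, dtype=torch.long)
    batched = BatchCollator.collate_tensor_list([a, b], fill=ignore_index)
    print(batched.shape)  # torch.Size([2, 3, 3])
    print(batched[0])     # bottom row is ignore_index (-1) padding
    print(batched[1])     # right column is ignore_index (-1) padding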
--------------------------------------------------------------------------------
/src/datasets/gt_producer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from panopticapi.utils import rgb2id
4 |
5 | from fabric.io import save_object
6 |
7 | ignore_index = -1
8 |
9 |
10 | def convert_pan_m_to_sem_m(pan_mask, segments_info, catId_2_trainId):
11 | """Convert a panoptic mask to semantic segmentation mask
12 | pan_mask: [H, W, 3] panoptic mask
13 | segments_info: dict
14 | catId_2_trainId: dict
15 | """
16 | sem = np.zeros_like(pan_mask, np.int64) # torch requires long tensor
17 | iid_to_catid = {
18 | el['id']: el['category_id'] for el in segments_info.values()
19 | }
20 | for iid in np.unique(pan_mask):
21 | if iid not in iid_to_catid:
22 | assert iid == 0
23 | sem[pan_mask == 0] = ignore_index
24 | continue
25 | cat_id = iid_to_catid[iid]
26 | train_id = catId_2_trainId[cat_id]
27 | sem[pan_mask == iid] = train_id
28 | return sem
29 |
30 |
31 | def tensorize_2d_spatial_assignement(spatial, num_channels):
32 | """
33 | Args:
34 | spatial: [H, W] where -1 encodes invalid region
35 |         Here we produce an extra last channel that index -1 lands on,
36 |         only to be thrown out later. Neat
37 | Ret:
38 | np array of shape [1, C, H, W]
39 | """
40 | H, W = spatial.shape
41 | num_channels += 1 # add 1 extra channel to be turned on by inx -1
42 | tsr = np.zeros(shape=(num_channels, H, W))
43 | dim_0_inds, dim_1_inds = np.ix_(range(H), range(W))
44 | tsr[spatial, dim_0_inds, dim_1_inds] = 1
45 | tsr = tsr[:-1, :, :] # throw out the extra dimension for -1s
46 | tsr = tsr[np.newaxis, ...]
47 | return tsr
48 |
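# The extra-channel trick above in action (a quick illustrative check): index
# -1 lands on the appended channel, which is then sliced away, so invalid
# pixels end up all-zero instead of corrupting a real channel.
def _tensorize_demo():
    spatial = np.array([[0, 1],
                        [-1, 2]])
    tsr = tensorize_2d_spatial_assignement(spatial, num_channels=3)
    print(tsr.shape)        # (1, 3, 2, 2)
    print(tsr[0, :, 1, 0])  # [0. 0. 0.] -- the -1 pixel activates nothing
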
49 |
50 | class GtProducer():
51 | def __init__(self, pcv, mask_shape, params):
52 | self.pcv = pcv
53 | self.mask_shape = mask_shape
54 | self.params = params
55 | self.interpretable_as_prob_tsr = False
56 |
57 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
58 | raise NotImplementedError()
59 |
60 |     def produce(self):
61 | raise NotImplementedError()
62 |
63 |     def convert_to_prob_tsr(self):
64 | raise NotImplementedError()
65 |
66 |
67 | class WeightMask(GtProducer):
68 | def __init__(self, *args, **kwargs):
69 | super().__init__(*args, **kwargs)
70 | self.weight = np.zeros(self.mask_shape, dtype=np.float32)
71 | self.power = self.params.get('power', 0.5)
72 |
73 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
74 | w = 1 / (segment_area ** self.power)
75 | if isthing and not iscrowd:
76 | self.weight[segment_mask] = w
77 | elif not isthing:
78 | self.weight[segment_mask] = w
79 | elif iscrowd:
80 | pass
81 | else:
82 | raise ValueError('unreachable')
83 |
84 | def produce(self):
85 | return self.weight
86 |
87 |
88 | class _WeightMask(GtProducer):
89 | def __init__(self, *args, **kwargs):
90 | super().__init__(*args, **kwargs)
91 | self.weight = np.zeros(self.mask_shape, dtype=np.float32)
92 | self.power = self.params.get('power', 0.5)
93 |
94 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
95 | w = 1 / (segment_area ** self.power)
96 | if isthing and not iscrowd:
97 | lbls = self.pcv.discrete_vote_inx_from_offset(ins_offset)
98 |             # pixels within an instance may be ignored if they sit on grid boundaries
99 | is_valid = lbls > ignore_index
100 | valid_area = is_valid.sum()
101 | # assert valid_area > 0, 'segment area {} vs valid area {}, lbls are {}'.format(
102 | # segment_area, valid_area, lbls
103 | # )
104 | if valid_area == 0:
105 | w = 0
106 | else:
107 | w = 1 / (valid_area ** self.power)
108 | self.weight[segment_mask] = w * is_valid
109 | elif not isthing:
110 | self.weight[segment_mask] = w
111 | elif iscrowd:
112 | pass
113 | else:
114 | raise ValueError('unreachable')
115 |
116 | def produce(self):
117 | return self.weight
118 |
119 |
120 | class ThingWeightMask(GtProducer):
121 | def __init__(self, *args, **kwargs):
122 | super().__init__(*args, **kwargs)
123 | self.weight = np.zeros(self.mask_shape, dtype=np.float32)
124 | self.power = self.params.get('power', 0.5)
125 |
126 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
127 | w = 1 / (segment_area ** self.power)
128 | if isthing and not iscrowd:
129 | lbls = self.pcv.discrete_vote_inx_from_offset(ins_offset)
130 |             # pixels within an instance may be ignored if they sit on grid boundaries
131 | is_valid = lbls > ignore_index
132 | valid_area = is_valid.sum()
133 | # assert valid_area > 0, 'segment area {} vs valid area {}, lbls are {}'.format(
134 | # segment_area, valid_area, lbls
135 | # )
136 | if valid_area == 0:
137 | w = 0
138 | else:
139 | w = 1 / (valid_area ** self.power)
140 | self.weight[segment_mask] = w * is_valid
141 | elif not isthing:
142 | self.weight[segment_mask] = 0
143 | elif iscrowd:
144 | pass
145 | else:
146 | raise ValueError('unreachable')
147 |
148 | def produce(self):
149 | return self.weight
150 |
151 |
152 | class NoAbstainWeightMask(GtProducer):
153 | def __init__(self, *args, **kwargs):
154 | super().__init__(*args, **kwargs)
155 | self.weight = np.zeros(self.mask_shape, dtype=np.float32)
156 | self.power = self.params.get('power', 0.5)
157 |
158 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
159 | w = 1 / (segment_area ** self.power)
160 | if isthing and not iscrowd:
161 | self.weight[segment_mask] = w
162 | elif not isthing: # only difference compared to WeightMask
163 | pass
164 | elif iscrowd:
165 | pass
166 | else:
167 | raise ValueError('unreachable')
168 |
169 | def produce(self):
170 | return self.weight
171 |
172 |
173 | class InsWeightMask(GtProducer):
174 | def __init__(self, *args, **kwargs):
175 | super().__init__(*args, **kwargs)
176 | self.weight = np.zeros(self.mask_shape, dtype=np.float32)
177 | self.stuff_mask = np.zeros(self.mask_shape, dtype=bool)
178 |
179 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
180 | if isthing and not iscrowd:
181 | self.weight[segment_mask] = 1 / segment_area
182 | elif not isthing:
183 | self.stuff_mask[segment_mask] = True
184 | elif iscrowd:
185 | pass
186 | else:
187 | raise ValueError('unreachable')
188 |
189 | def produce(self):
190 | stuff_area = self.stuff_mask.sum()
191 | stuff_w = 20 / stuff_area
192 | self.weight[self.stuff_mask] = stuff_w
193 | self.weight = self.weight / 40
194 | return self.weight
195 |
196 |
197 | class PCV_vote_no_abstain_mask(GtProducer):
198 | def __init__(self, *args, **kwargs):
199 | super().__init__(*args, **kwargs)
200 | self.interpretable_as_prob_tsr = True
201 | self.gt_mask = ignore_index * np.ones(self.mask_shape, np.int32)
202 | self.votable_mask = np.zeros(self.mask_shape, dtype=bool)
203 | self.offset_mask = np.empty(
204 | list(self.mask_shape) + [2], dtype=np.int32
205 | )
206 | pcv = self.pcv
207 | assert (pcv.num_votes - pcv.num_bins) == 1
208 | self.ABSTAIN_INX = ignore_index
209 |
210 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
211 | if isthing and not iscrowd:
212 | self.votable_mask[segment_mask] = True
213 | self.offset_mask[segment_mask] = ins_offset
214 | elif not isthing:
215 | self.gt_mask[segment_mask] = self.ABSTAIN_INX
216 | elif iscrowd:
217 | pass
218 | else:
219 | raise ValueError('unreachable')
220 |
221 | def produce(self):
222 | votable_mask = self.votable_mask
223 | self.gt_mask[votable_mask] = self.pcv.discrete_vote_inx_from_offset(
224 | self.offset_mask[votable_mask]
225 | )
226 | self.gt_mask = self.gt_mask.astype(np.int64)
227 | return self.gt_mask
228 |
229 | def convert_to_prob_tsr(self):
230 | return tensorize_2d_spatial_assignement(self.gt_mask, self.pcv.num_votes)
231 |
232 |
233 | class PCV_vote_mask(GtProducer):
234 | def __init__(self, *args, **kwargs):
235 | super().__init__(*args, **kwargs)
236 | self.interpretable_as_prob_tsr = True
237 | self.gt_mask = ignore_index * np.ones(self.mask_shape, np.int32)
238 | self.votable_mask = np.zeros(self.mask_shape, dtype=bool)
239 | self.offset_mask = np.empty(
240 | list(self.mask_shape) + [2], dtype=np.int32
241 | )
242 | pcv = self.pcv
243 | assert (pcv.num_votes - pcv.num_bins) == 1
244 | self.ABSTAIN_INX = pcv.num_bins
245 |
246 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
247 | if isthing and not iscrowd:
248 | self.votable_mask[segment_mask] = True
249 | self.offset_mask[segment_mask] = ins_offset
250 | elif not isthing:
251 | self.gt_mask[segment_mask] = self.ABSTAIN_INX
252 | elif iscrowd:
253 | pass
254 | else:
255 | raise ValueError('unreachable')
256 |
257 | def produce(self):
258 | votable_mask = self.votable_mask
259 | self.gt_mask[votable_mask] = self.pcv.discrete_vote_inx_from_offset(
260 | self.offset_mask[votable_mask]
261 | )
262 | self.gt_mask = self.gt_mask.astype(np.int64)
263 | return self.gt_mask
264 |
265 | def convert_to_prob_tsr(self):
266 | return tensorize_2d_spatial_assignement(self.gt_mask, self.pcv.num_votes)
267 |
268 |
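
Both vote-mask producers hand the finished [H, W] integer mask to tensorize_2d_spatial_assignement to obtain a per-pixel probability tensor. That helper lives elsewhere in the package; the sketch below is a hypothetical stand-in (the [C, H, W] output layout and the all-zero treatment of ignored pixels are assumptions, not confirmed by this file):

import numpy as np

def one_hot_spatial(labels, num_classes):
    # labels: [H, W] int array; out-of-range entries (e.g. ignore_index)
    # are left as all-zero columns
    h, w = labels.shape
    out = np.zeros((num_classes, h, w), dtype=np.float32)
    valid = (labels >= 0) & (labels < num_classes)
    ys, xs = np.nonzero(valid)
    out[labels[valid], ys, xs] = 1.0
    return out
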
269 | class PCV_igc_tsr(GtProducer):
270 |     '''this code is problematic, but at least the issue is now well exposed'''
271 | def __init__(self, *args, **kwargs):
272 | super().__init__(*args, **kwargs)
273 | self.interpretable_as_prob_tsr = True
274 | self.tsr = np.zeros(
275 |             list(self.mask_shape) + [self.pcv.num_votes], dtype=bool
276 | )
277 |
278 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
279 | if isthing and not iscrowd:
280 | self.tsr[segment_mask] = self.pcv.tensorized_vote_from_offset(
281 | ins_offset
282 | )
283 | elif not isthing:
284 | self.tsr[segment_mask, -1] = True # abstaining
285 | elif iscrowd:
286 | pass
287 | else:
288 | raise ValueError('unreachable')
289 |
290 | def produce(self):
291 | tsr = torch.as_tensor(
292 | self.tsr.transpose(2, 0, 1) # [C, H, W]
293 | ).contiguous()
294 | self.tsr = tsr
295 | return tsr
296 |
297 | def convert_to_prob_tsr(self):
298 | tsr = self.tsr
299 | vote_tsr = tsr.clone().float()
300 | base = vote_tsr.sum(dim=0, keepdim=True)
301 |         base[base == 0] = 1  # avoid division by zero; those columns stay all-zero
302 | vote_tsr = (vote_tsr / base).unsqueeze(dim=0).numpy()
303 | return vote_tsr
304 |
305 |
306 | class PCV_smooth_tsr(GtProducer):
307 | def __init__(self, *args, **kwargs):
308 | super().__init__(*args, **kwargs)
309 | self.interpretable_as_prob_tsr = True
310 | self.votable_mask = np.zeros(self.mask_shape, dtype=bool)
311 | self.abstain_mask = np.zeros(self.mask_shape, dtype=bool)
312 | self.offset_mask = np.empty(
313 | list(self.mask_shape) + [2], dtype=np.int32
314 | )
315 | self.tsr = np.zeros(
316 | list(self.mask_shape) + [self.pcv.num_votes], dtype=np.float32
317 | )
318 | pcv = self.pcv
319 | assert (pcv.num_votes - pcv.num_bins) == 1
320 |
321 | def process(self, isthing, iscrowd, segment_mask, segment_area, ins_offset):
322 | if isthing and not iscrowd:
323 | self.votable_mask[segment_mask] = True
324 | self.offset_mask[segment_mask] = ins_offset
325 | elif not isthing:
326 | self.abstain_mask[segment_mask] = True
327 | elif iscrowd:
328 | pass
329 | else:
330 | raise ValueError('unreachable')
331 |
332 | def produce(self):
333 | self.tsr[self.abstain_mask, -1] = 1.0
334 | self.tsr[self.votable_mask] = self.pcv.smooth_prob_tsr_from_offset(
335 | self.offset_mask[self.votable_mask]
336 | )
337 | self.tsr = torch.as_tensor(
338 | self.tsr.transpose(2, 0, 1) # [C, H, W]
339 | ).contiguous()
340 | return self.tsr
341 |
342 | def convert_to_prob_tsr(self):
343 | tsr = self.tsr
344 | vote_tsr = tsr.clone().unsqueeze(dim=0).numpy()
345 | return vote_tsr
346 |
347 |
348 | _REGISTRY = {
349 | 'weight_mask': WeightMask,
350 | 'thing_weight_mask': ThingWeightMask,
351 | 'vote_mask': PCV_vote_mask,
352 | 'vote_no_abstain_mask': PCV_vote_no_abstain_mask,
353 | 'igc_tsr': PCV_igc_tsr,
354 | 'smth_tsr': PCV_smooth_tsr,
355 | 'ins_weight_mask': InsWeightMask,
356 | 'no_abstain_weight_mask': NoAbstainWeightMask
357 | }
358 |
359 |
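
GTGenerator below instantiates producers from this registry using a list of config dicts. Based only on the keys the class actually reads (name, params, and the is_vote_gt flag that must mark exactly one producer), a plausible producer_cfgs could look like:

producer_cfgs = (
    {'name': 'vote_mask', 'params': {'is_vote_gt': True}},
    {'name': 'weight_mask', 'params': {'power': 0.5}},
)
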
360 | class GTGenerator():
361 | def __init__(self, meta, pcv, pan_mask, segments_info, producer_cfgs=()):
362 | assert len(meta['catId_2_trainId']) == len(meta['cats'])
363 | self.pcv = pcv
364 | self.meta = meta
365 | self.pan_mask = rgb2id(np.array(pan_mask))
366 | self.segments_info = segments_info
367 |
368 | am_I_gt_producer = [
369 | int(cfg['params'].get('is_vote_gt', False)) for cfg in producer_cfgs
370 | ]
371 | if sum(am_I_gt_producer) != 1:
372 | raise ValueError('exactly 1 producer should be in charge of vote gt')
373 | self.gt_producer_inx = np.argmax(am_I_gt_producer)
374 | self.producer_cfgs = producer_cfgs
375 | self.producers = None
376 |
377 | self.ins_centroids = None
378 | self._sem_gt = None
379 |
380 | @property
381 | def sem_gt(self):
382 | if self._sem_gt is None:
383 | self._sem_gt = convert_pan_m_to_sem_m(
384 | self.pan_mask, self.segments_info, self.meta['catId_2_trainId']
385 | )
386 | return self._sem_gt
387 |
388 | def generate_gt(self):
389 | """Return a new mask with each pixel containing a discrete numeric label
390 | indicating which spatial bin the pixel should vote for
391 | Args:
392 | mask: [H, W] array filled with segment id
393 | segments_info:
394 | """
395 | pcv = self.pcv
396 | mask = self.pan_mask
397 | segments_info = self.segments_info
398 | category_meta = self.meta['cats']
399 |
400 | self.producers = [
401 | _REGISTRY[cfg['name']](pcv, mask.shape, cfg['params'])
402 | for cfg in self.producer_cfgs
403 |         ]  # initialize the producers with their params
404 |
405 | # indices grid of shape [2, H, W], where first dim is y, x; swap them
406 | # [H, W, 2] where last channel is x, y
407 | spatial_inds = np.indices(
408 | mask.shape).transpose(1, 2, 0)[..., ::-1].astype(np.int32)
409 | ins_centroids = []
410 |
411 | for segment_id, info in segments_info.items():
412 | cat, iscrowd = info['category_id'], info['iscrowd']
413 | isthing = category_meta[cat]['isthing']
414 | segment_mask = (mask == segment_id)
415 | area = segment_mask.sum()
416 | if area == 0:
417 | # cropping or extreme resizing might cause segments to disappear
418 | continue
419 |
420 | ins_offset = None
421 | if isthing and not iscrowd:
422 | ins_center = pcv.centroid_from_ins_mask(segment_mask)
423 |                 ins_offset = (ins_center - spatial_inds[segment_mask]).astype(np.int32)
424 |                 # NOTE: commented-out alternatives rounded ins_center first
425 |                 # (per-coordinate ceil, or an int32 cast) before subtracting;
426 |                 # whether the centroid should be rounded before taking the
427 |                 # offsets is an open question that still needs investigation
428 | ins_centroids.append(ins_center)
429 |
430 | for actor in self.producers:
431 | actor.process(isthing, iscrowd, segment_mask, area, ins_offset)
432 |
433 |             # (a commented-out debugging block that wrapped the producer loop
434 |             # in try/except and dumped the offending segment to
435 |             # ./troubleshoot.pkl used to live here)
449 |
450 | self.ins_centroids = np.array(ins_centroids).reshape(-1, 2)
451 |
452 | training_gts = [self.sem_gt, ]
453 | training_gts.extend(
454 | [ actor.produce() for actor in self.producers ]
455 | )
456 | return training_gts
457 |
458 | def collect_prob_tsr(self):
459 | assert self.producers is not None, 'must create gt first'
460 | gt_producer = self.producers[self.gt_producer_inx]
461 | assert gt_producer.interpretable_as_prob_tsr, \
462 | '{} should be interpretable as prob tsr'.format(gt_producer.__class__)
463 |
464 | sem_prob_tsr = tensorize_2d_spatial_assignement(
465 | self.sem_gt, len(self.meta['catId_2_trainId'])
466 | )
467 | vote_prob_tsr = gt_producer.convert_to_prob_tsr()
468 | return sem_prob_tsr, vote_prob_tsr
469 |
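
The coordinate handling in generate_gt (np.indices gives a [2, H, W] grid ordered y, x; the transpose and [..., ::-1] turn it into an [H, W, 2] grid ordered x, y) is easy to get backwards. A small sanity check, using a plain mean as the centroid where the real code delegates to pcv.centroid_from_ins_mask:

import numpy as np

mask = np.zeros((3, 4), dtype=bool)
mask[1:, 1:] = True  # a 2x3 instance

spatial_inds = np.indices(mask.shape).transpose(1, 2, 0)[..., ::-1]

center = spatial_inds[mask].mean(axis=0)            # (x, y) centroid
offsets = (center - spatial_inds[mask]).astype(np.int32)
print(center)       # [2.  1.5]
print(offsets[0])   # offset of pixel (x=1, y=1) toward the centroid
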
--------------------------------------------------------------------------------
/src/pan_analyzer.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import os.path as osp
3 | from collections import defaultdict
4 | import json
5 |
6 | import numpy as np
7 | import scipy.linalg as LA
8 | from scipy.ndimage import binary_dilation, generate_binary_structure
9 | import pandas as pd
10 | from PIL import Image
11 |
12 | from tabulate import tabulate
13 |
14 | from panopticapi.utils import rgb2id
15 | from panoptic.pan_eval import PQStat, OFFSET, VOID
16 | from panoptic.datasets.base import mapify_iterable
17 |
18 | from fabric.io import load_object, save_object
19 |
20 | _SEGMENT_UNMATCHED = 0
21 | _SEGMENT_MATCHED = 1
22 | _SEGMENT_FORGIVEN = 2
23 |
24 |
25 | def generalized_aspect_ratio(binary_mask):
26 | xs, ys = np.where(binary_mask)
27 | coords = np.array([xs, ys]).T
28 | # mean center the coords
29 | coords = coords - coords.mean(axis=0)
30 |     # scatter matrix (covariance up to a constant factor)
31 | cov = coords.T @ coords
32 | first, second = LA.eigvalsh(cov)[::-1]
33 | ratio = (first ** 0.5) / (second + 1e-8) ** 0.5
34 | return ratio
35 |
36 |
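
generalized_aspect_ratio measures the elongation of a segment from the eigenvalues of its coordinate scatter matrix: scipy's eigvalsh returns eigenvalues in ascending order, the [::-1] flip makes first the larger one, and the ratio of square roots behaves like width/height for an axis-aligned box. A quick check using the function above (the result is approximate since pixel coordinates are discrete):

import numpy as np

rect = np.zeros((10, 40), dtype=bool)
rect[:, :] = True  # a filled 10 x 40 rectangle

print(generalized_aspect_ratio(rect))  # ~4.0, matching the 40:10 shape
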
37 | class Annotation():
38 | '''
39 | Overall Schema for Each Side (Either Gt or Pred)
40 | e.g. pred:
41 | cats: {id: cat}
42 | imgs: {id: image}
43 |         segs: {sid: seg}
44 | img2seg: {image_id: {sid: seg}}
45 | cat2seg: {cat_id: {sid: seg}}
46 |
47 | cat:
48 | id: 7,
49 | name: road,
50 | supercategory: 'flat',
51 | color: [128, 64, 128],
52 | isthing: 0
53 |
54 | image:
55 | id: 'frankfurt_000000_005898',
56 | file_name: frankfurt/frankfurt_000000_005898_leftImg8bit.png
57 | ann_fname: abcde
58 | width: 2048,
59 | height: 1024,
60 | ---
61 | mask: a cached mask that is loaded
62 |
63 | seg:
64 | sid (seg id): gt/frankfurt_000000_000294/8405120
65 | image_id: frankfurt_000000_000294
66 | id: 8405120
67 | category_id: 7
68 | area: 624611
69 | bbox: [6, 432, 1909, 547]
70 | iscrowd: 0
71 |         match_state: one of (UNMATCHED, MATCHED, FORGIVEN)
72 | matchings: {sid: iou} sorted from high to low
73 |         (breakdown_flag): this can be optionally introduced for breakdown analysis
74 | '''
75 | def __init__(self, json_meta_fname, is_gt, state_dict=None):
76 | '''
77 | if state_dict is provided, then just load it and avoid the computation
78 | '''
79 | dirname, fname = osp.split(json_meta_fname)
80 | self.root = dirname
81 | self.mask_root = osp.join(dirname, fname.split('.')[0])
82 |
83 | if state_dict is None:
84 | with open(json_meta_fname) as f:
85 | state_dict = self.process_raw_meta(json.load(f), is_gt)
86 | self.state_dict = state_dict
87 | self.register_state(state_dict)
88 |
89 | def register_state(self, state):
90 | for k, v in state.items():
91 | setattr(self, k, v)
92 |
93 | @staticmethod
94 | def process_raw_meta(raw_meta, is_gt):
95 | state = dict()
96 | state['cats'] = mapify_iterable(raw_meta['categories'], 'id')
97 | state['imgs'] = mapify_iterable(raw_meta['images'], 'id')
98 | state['segs'] = dict()
99 | state['img2seg'] = defaultdict(dict)
100 | state['cat2seg'] = defaultdict(dict)
101 |
102 | sid_prefix = 'gt' if is_gt else 'pd'
103 |
104 | for ann in raw_meta['annotations']:
105 | image_id = ann['image_id']
106 | segments = ann['segments_info']
107 | state['imgs'][image_id]['ann_fname'] = ann['file_name']
108 | for seg in segments:
109 | cat_id = seg['category_id']
110 |
111 | unique_id = '{}/{}/{}'.format(sid_prefix, image_id, seg['id'])
112 | seg['sid'] = unique_id
113 | seg['image_id'] = image_id
114 | seg['match_state'] = _SEGMENT_FORGIVEN
115 | seg['matchings'] = dict()
116 |
117 | state['segs'][unique_id] = seg
118 | state['img2seg'][image_id][unique_id] = seg
119 | state['cat2seg'][cat_id][unique_id] = seg
120 | return state
121 |
122 | def seg_sort_matchings(self):
123 | """sort matchings from high to low IoU"""
124 | for _, seg in self.segs.items():
125 | matchings = seg['matchings']
126 | seg['matchings'] = dict(
127 | sorted(matchings.items(), key=lambda x: x[1], reverse=True)
128 | )
129 |
130 | def match_summarize(self, breakdown_flag=None):
131 | '''
132 | ret: [num_cats, 4] where each row contains
133 | (iou_sum, num_matched, num_unmatched, total_inst)
134 | '''
135 | ret = []
136 | for cat in sorted(self.cats.keys()):
137 | segs = self.cat2seg[cat].values()
138 | iou_sum, num_matched, num_unmatched, total_inst = 0.0, 0.0, 0.0, 0.0
139 | for seg in segs:
140 | if breakdown_flag is not None and seg['breakdown_flag'] != breakdown_flag:
141 | continue # if breakdown is activated, only summarize those required
142 | total_inst += 1
143 | if seg['match_state'] == _SEGMENT_MATCHED:
144 | iou_sum += list(seg['matchings'].values())[0]
145 | num_matched += 1
146 | elif seg['match_state'] == _SEGMENT_UNMATCHED:
147 | num_unmatched += 1
148 | ret.append([iou_sum, num_matched, num_unmatched, total_inst])
149 | ret = np.array(ret)
150 | return ret
151 |
152 | def catId_given_catName(self, catName):
153 | for catId, cat in self.cats.items():
154 | if cat['name'] == catName:
155 | return catId
156 | raise ValueError('what kind of category is this? {}'.format(catName))
157 |
158 | def get_mask_given_seg(self, seg):
159 |         return self.get_mask_given_imgid(seg['image_id'])
160 |
161 | def get_img_given_imgid(self, image_id, img_root):
162 | img = self.imgs[image_id]
163 | img_fname = img['file_name']
164 | img_fname = osp.join(img_root, img_fname)
165 | img = Image.open(img_fname)
166 | return img
167 |
168 | def get_mask_given_imgid(self, image_id, store_in_cache=True):
169 | img = self.imgs[image_id]
170 | _MASK_KEYNAME = 'mask'
171 | cache_entry = img.get(_MASK_KEYNAME, None)
172 | if cache_entry is not None:
173 | assert isinstance(cache_entry, np.ndarray)
174 | return cache_entry
175 | else:
176 | mask_fname = img['ann_fname']
177 | mask = np.array(
178 | Image.open(osp.join(self.mask_root, mask_fname)),
179 | dtype=np.uint32
180 | )
181 | mask = rgb2id(mask)
182 | if store_in_cache:
183 | img[_MASK_KEYNAME] = mask
184 | return mask
185 |
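
rgb2id above comes from panopticapi and decodes the COCO panoptic PNG color encoding, where each segment id is packed into the three color channels as id = R + 256 * G + 256**2 * B. An equivalent reference implementation:

import numpy as np

def rgb2id_ref(color):
    # color: [..., 3] uint array -> [...] array of segment ids
    color = color.astype(np.uint32)
    return color[..., 0] + 256 * color[..., 1] + 256 * 256 * color[..., 2]
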
186 | def compute_seg_shape_oddity(self):
187 | print('start computing shape oddity')
188 | i = 0
189 | for imgId, segs in self.img2seg.items():
190 | i += 1
191 | if (i % 50) == 0:
192 | print(i)
193 | mask = self.get_mask_given_imgid(imgId, store_in_cache=False)
194 | for _, s in segs.items():
195 | seg_id = s['id']
196 | binary_mask = (mask == seg_id)
197 | s['gen_aspect_ratio'] = generalized_aspect_ratio(binary_mask)
198 |
199 | def compute_seg_boundary_stats(self):
200 | print('start computing boundary stats')
201 | i = 0
202 | for imgId, segs in self.img2seg.items():
203 | i += 1
204 | if (i % 50) == 0:
205 | print(i)
206 | mask = self.get_mask_given_imgid(imgId, store_in_cache=False)
207 | for _, s in segs.items():
208 | seg_id = s['id']
209 | binary_mask = (mask == seg_id)
210 | self._per_seg_neighbors_stats(s, binary_mask, mask)
211 |
212 | def _per_seg_neighbors_stats(self, seg_dict, binary_mask, mask):
213 | area = binary_mask.sum()
214 | # struct = generate_binary_structure(2, 2)
215 | dilated = binary_dilation(binary_mask, structure=None, iterations=1)
216 | boundary = dilated ^ binary_mask
217 |
218 | # stats
219 | length = boundary.sum()
220 | ratio = length ** 2 / area
221 | seg_dict['la_ratio'] = ratio
222 |
223 | # get the neighbors
224 | ids, cnts = np.unique(mask[boundary], return_counts=True)
225 |
226 | sid_prefix = '/'.join(
227 | seg_dict['sid'].split('/')[:2] # throw away the last
228 | )
229 | sids = [ '{}/{}'.format(sid_prefix, id) for id in ids ]
230 |
231 | thing_neighbors = {
232 | sid: cnt for sid, id, cnt in zip(sids, ids, cnts)
233 | if id > 0 and self.cats[self.segs[sid]['category_id']]['isthing']
234 | }
235 | seg_dict['thing_neighbors'] = thing_neighbors
236 |
237 |
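
_per_seg_neighbors_stats extracts a one-pixel outer boundary by XOR-ing a binary dilation against the original mask, then counts which segment ids appear on that ring to find a segment's thing neighbors. The same trick in isolation:

import numpy as np
from scipy.ndimage import binary_dilation

mask = np.zeros((5, 5), dtype=bool)
mask[1:4, 1:4] = True                    # a 3x3 segment

boundary = binary_dilation(mask) ^ mask  # the ring just outside the mask
print(boundary.sum())                    # 12 with the default 4-connected structure

It also records la_ratio = perimeter**2 / area, a scale-invariant compactness measure: larger values mean wigglier, less blob-like shapes.
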
238 | class PanopticEvalAnalyzer():
239 | def __init__(self, gt_json_meta_fname, pd_json_meta_fname, load_state=True):
240 | # use the pd folder as root directory since a single gt ann can correspond
241 | # to many pd anns.
242 | root = osp.split(pd_json_meta_fname)[0]
243 | self.state_dump_fname = osp.join(root, 'analyze_dump.pkl')
244 |
245 | is_evaluated = False
246 | if osp.isfile(self.state_dump_fname) and load_state:
247 | state = load_object(self.state_dump_fname)
248 | gt_state, pd_state = state['gt'], state['pd']
249 | is_evaluated = True
250 | else:
251 | gt_state, pd_state = None, None
252 |
253 | self.gt = Annotation(gt_json_meta_fname, is_gt=True, state_dict=gt_state)
254 | self.pd = Annotation(pd_json_meta_fname, is_gt=False, state_dict=pd_state)
255 |
256 | # validate that gt and pd json completely match
257 | assert self.gt.imgs.keys() == self.pd.imgs.keys()
258 | assert self.gt.cats == self.pd.cats
259 | self.imgIds = list(sorted(self.gt.imgs.keys()))
260 |
261 | if not is_evaluated:
262 | # evaluate and then save the state
263 | self._evaluate()
264 | self.gt.compute_seg_shape_oddity()
265 | self.pd.compute_seg_shape_oddity()
266 | self.dump_state()
267 |
268 | def _gt_boundary_stats(self):
269 | self.gt.compute_seg_boundary_stats()
270 |
271 | def dump_state(self):
272 | state = {
273 | 'gt': self.gt.state_dict,
274 | 'pd': self.pd.state_dict
275 | }
276 | save_object(state, self.state_dump_fname)
277 |
278 | def _evaluate(self):
279 | stats = PQStat()
280 | cats = self.gt.cats
281 | for i, imgId in enumerate(self.imgIds):
282 | if (i % 50) == 0:
283 | print("progress {} / {}".format(i, len(self.imgIds)))
284 | # if (i > 100):
285 | # break
286 |
287 | gt_ann = {
288 | 'image_id': imgId, 'segments_info': self.gt.img2seg[imgId].values()
289 | }
290 | gt_mask = np.array(
291 | Image.open(osp.join(
292 | self.gt.mask_root, self.gt.imgs[imgId]['ann_fname']
293 | )),
294 | dtype=np.uint32
295 | )
296 | gt_mask = rgb2id(gt_mask)
297 |
298 | pd_ann = {
299 | 'image_id': imgId, 'segments_info': self.pd.img2seg[imgId].values()
300 | }
301 | pd_mask = np.array(
302 | Image.open(osp.join(
303 | self.pd.mask_root, self.pd.imgs[imgId]['ann_fname']
304 | )),
305 | dtype=np.uint32
306 | )
307 | pd_mask = rgb2id(pd_mask)
308 |
309 | _single_stat = self.pq_compute_single_img(
310 | cats, gt_ann, gt_mask, pd_ann, pd_mask
311 | )
312 | stats += _single_stat
313 |
314 | self.gt.seg_sort_matchings()
315 | self.pd.seg_sort_matchings()
316 | return stats
317 |
318 | def summarize(self, flag=None):
319 | per_cat_res, overall_table, cat_table = self._aggregate(
320 | gt_stats=self.gt.match_summarize(flag),
321 | pd_stats=self.pd.match_summarize(flag),
322 | cats=self.gt.cats
323 | )
324 | return per_cat_res, overall_table, cat_table
325 |
326 | @staticmethod
327 | def _aggregate(gt_stats, pd_stats, cats):
328 | '''
329 | Args:
330 | pd/gt_stats: [num_cats, 4] with each row contains
331 | (iou_sum, num_matched, num_unmatched, total_inst)
332 | cats: a dict of {catId: catMetaData}
333 | Returns:
334 | 1. per cat pandas dataframe; easy to programmatically manipulate
335 | 2. str formatted overall result table
336 | 3. str formatted per category result table
337 | '''
338 | # each is of shape [num_cats]
339 | gt_iou, gt_matched, gt_unmatched, gt_tot_inst = gt_stats.T
340 | pd_iou, pd_matched, pd_unmatched, pd_tot_inst = pd_stats.T
341 | assert np.allclose(gt_iou, pd_iou) and (gt_matched == pd_matched).all()
342 |
343 | catIds = list(sorted(cats.keys()))
344 | catNames = [cats[id]['name'] for id in catIds]
345 | isthing = np.array([cats[id]['isthing'] for id in catIds], dtype=bool)
346 |
347 | RQ = gt_matched / (gt_matched + 0.5 * gt_unmatched + 0.5 * pd_unmatched)
348 | SQ = gt_iou / gt_matched
349 | RQ, SQ = np.nan_to_num(RQ), np.nan_to_num(SQ)
350 | PQ = RQ * SQ
351 | results = np.array([PQ, SQ, RQ]) * 100 # [3, num_cats]
352 |
353 | overall_table = tabulate(
354 | headers=['', 'PQ', 'SQ', 'RQ', 'num_cats'],
355 | floatfmt=".2f", tablefmt='fancy_grid',
356 | tabular_data=[
357 | ['all'] + list(map(lambda x: x.mean(), results)) + [len(catIds)],
358 | ['things'] + list(map(lambda x: x[isthing].mean(), results)) + [sum(isthing)],
359 | ['stuff'] + list(map(lambda x: x[~isthing].mean(), results)) + [sum(1 - isthing)],
360 | ]
361 | )
362 |
363 | headers = (
364 | 'PQ', 'SQ', 'RQ',
365 | 'num_matched', 'gt_unmatched', 'pd_unmatched', 'tot_gt_inst',
366 | 'isthing'
367 | )
368 | results = np.array(
369 | list(results) + [gt_matched, gt_unmatched, pd_unmatched, gt_tot_inst, isthing]
370 | )
371 | results = results.T
372 | data_frame = pd.DataFrame(results, columns=headers, index=catNames)
373 | cat_table = tabulate(
374 | data_frame, headers='keys', floatfmt=".2f", tablefmt='fancy_grid'
375 | )
376 | return data_frame, overall_table, cat_table
377 |
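
_aggregate reproduces the standard panoptic quality decomposition: RQ = TP / (TP + 0.5 * FN + 0.5 * FP), SQ = sum(IoU) / TP, and PQ = RQ * SQ. A worked example for one category:

tp, fn, fp, iou_sum = 3, 1, 2, 2.4
rq = tp / (tp + 0.5 * fn + 0.5 * fp)  # 3 / 4.5 ~ 0.667
sq = iou_sum / tp                     # 0.8
print(rq * sq)                        # PQ ~ 0.533
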
378 | @staticmethod
379 | def pq_compute_single_img(cats, gt_ann, gt_mask, pd_ann, pd_mask):
380 | """
381 | This is the original eval function refactored for readability
382 | """
383 | pq_stat = PQStat()
384 | gt_segms = {el['id']: el for el in gt_ann['segments_info']}
385 | pd_segms = {el['id']: el for el in pd_ann['segments_info']}
386 |
387 | # predicted segments area calculation + prediction sanity checks
388 | pd_labels_set = set(el['id'] for el in pd_ann['segments_info'])
389 | labels, labels_cnt = np.unique(pd_mask, return_counts=True)
390 | for label, label_cnt in zip(labels, labels_cnt):
391 | if label not in pd_segms:
392 | if label == VOID:
393 | continue
394 | raise KeyError(
395 | ('In the image with ID {} '
396 | 'segment with ID {} is presented in PNG '
397 | 'and not presented in JSON.').format(gt_ann['image_id'], label)
398 | )
399 | pd_segms[label]['area'] = int(label_cnt)
400 | pd_labels_set.remove(label)
401 | if pd_segms[label]['category_id'] not in cats:
402 | raise KeyError(
403 | ('In the image with ID {} '
404 | 'segment with ID {} has unknown '
405 | 'category_id {}.').format(
406 | gt_ann['image_id'], label, pd_segms[label]['category_id'])
407 | )
408 | if len(pd_labels_set) != 0:
409 | raise KeyError(
410 | ('In the image with ID {} '
411 | 'the following segment IDs {} are presented '
412 | 'in JSON and not presented in PNG.').format(
413 | gt_ann['image_id'], list(pd_labels_set))
414 | )
415 |
416 | # confusion matrix calculation
417 | gt_vs_pd = gt_mask.astype(np.uint64) * OFFSET + pd_mask.astype(np.uint64)
418 | gt_pd_itrsct = {}
419 | labels, labels_cnt = np.unique(gt_vs_pd, return_counts=True)
420 | for label, intersection in zip(labels, labels_cnt):
421 | gt_id, pd_id = label // OFFSET, label % OFFSET
422 | gt_pd_itrsct[(gt_id, pd_id)] = intersection
423 |
424 | # count all matched pairs
425 | gt_matched, pd_matched = set(), set()
426 | for label_tuple, intersection in gt_pd_itrsct.items():
427 | gt_label, pd_label = label_tuple
428 | if gt_label not in gt_segms:
429 | continue
430 | if pd_label not in pd_segms:
431 | continue
432 |
433 | gt_seg, pd_seg = gt_segms[gt_label], pd_segms[pd_label]
434 | union = pd_seg['area'] + gt_seg['area'] \
435 | - intersection - gt_pd_itrsct.get((VOID, pd_label), 0)
436 | iou = intersection / union
437 | if iou > 0.1:
438 | gt_seg['matchings'][pd_seg['sid']] = iou
439 | pd_seg['matchings'][gt_seg['sid']] = iou
440 |
441 | if gt_seg['iscrowd'] == 1:
442 | continue
443 | if gt_seg['category_id'] != pd_seg['category_id']:
444 | continue
445 |
446 | if iou > 0.5:
447 | gt_cat_id = gt_seg['category_id']
448 | pq_stat[gt_cat_id].tp += 1
449 | pq_stat[gt_cat_id].iou += iou
450 | gt_matched.add(gt_label)
451 | pd_matched.add(pd_label)
452 | gt_seg['match_state'] = _SEGMENT_MATCHED
453 | pd_seg['match_state'] = _SEGMENT_MATCHED
454 |
455 | # count false negatives
456 |         # HC: assumes each category in an image has at most one crowd segment!
457 |         # (per image and per category, all crowd segments are merged into one)
458 | crowd_cat_segid = {}
459 | for gt_label, gt_info in gt_segms.items():
460 | if gt_label in gt_matched:
461 | continue
462 | # crowd segments are ignored;
463 | if gt_info['iscrowd'] == 1:
464 | crowd_cat_segid[gt_info['category_id']] = gt_label
465 | continue
466 | pq_stat[gt_info['category_id']].fn += 1
467 | gt_info['match_state'] = _SEGMENT_UNMATCHED
468 |
469 | # count false positives
470 | for pd_label, pd_info in pd_segms.items():
471 | if pd_label in pd_matched:
472 | continue
473 | # intersection of the segment with VOID
474 | intersection = gt_pd_itrsct.get((VOID, pd_label), 0)
475 | # plus intersection with corresponding CROWD region if it exists
476 | if pd_info['category_id'] in crowd_cat_segid:
477 | intersection += gt_pd_itrsct.get(
478 | (crowd_cat_segid[pd_info['category_id']], pd_label), 0
479 | )
480 |             # a predicted segment is ignored if more than half of it
481 |             # corresponds to VOID and CROWD regions
482 | if intersection / pd_info['area'] > 0.5:
483 | continue
484 | pq_stat[pd_info['category_id']].fp += 1
485 | pd_info['match_state'] = _SEGMENT_UNMATCHED
486 | return pq_stat
487 |
488 |
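
pq_compute_single_img counts every gt/pred intersection in a single np.unique pass by packing each (gt_id, pd_id) pair into one integer: gt * OFFSET + pd is collision-free as long as OFFSET exceeds every possible prediction id. The same trick in miniature (the OFFSET value here is illustrative, not the constant from pan_eval):

import numpy as np

OFFSET = 1000  # must exceed any prediction id below
gt = np.array([[1, 1], [2, 2]], dtype=np.uint64)
pd = np.array([[7, 8], [7, 7]], dtype=np.uint64)

pairs, counts = np.unique(gt * OFFSET + pd, return_counts=True)
for pair, cnt in zip(pairs, counts):
    print(divmod(int(pair), OFFSET), cnt)  # (1, 7) 1, (1, 8) 1, (2, 7) 2
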
489 | class BreakdownPolicy():
490 | def __init__(self):
491 | self.flags = []
492 |
493 | def breakdown(self, gt_segs, pd_segs):
494 | pass
495 |
496 |
497 | class DummyBreakdown(BreakdownPolicy):
498 | def __init__(self):
499 | self.flags = ['sector1', 'sector2']
500 |
501 | def breakdown(self, gt_segs, pd_segs):
502 | import numpy.random as npr
503 | gt_flags = [ npr.choice(self.flags) for _ in range(len(gt_segs)) ]
504 | for i, seg in enumerate(gt_segs.values()):
505 | seg['breakdown_flag'] = gt_flags[i]
506 |
507 | pd_flags = [ npr.choice(self.flags) for _ in range(len(pd_segs)) ]
508 | for i, seg in enumerate(pd_segs.values()):
509 | seg['breakdown_flag'] = pd_flags[i]
510 |
511 |
512 | class BoxScaleBreakdown(BreakdownPolicy):
513 | def __init__(self):
514 | # self.flags = ['tiny', 'small', 'medium', 'large', 'huge']
515 | # self.scale_thresholds = [ 16 ** 2, 32 ** 2, 64 ** 2, 128 ** 2]
516 |
517 | self.flags = ['small', 'medium', 'large']
518 | self.scale_thresholds = [32 ** 2, 128 ** 2]
519 |
520 | def breakdown(self, gt_segs, pd_segs):
521 | thresh, flags = self.scale_thresholds, self.flags
522 | gt_areas = [
523 | seg['bbox'][-1] * seg['bbox'][-2] for seg in gt_segs.values()
524 | ]
525 | gt_flags = [
526 | flags[bisect.bisect_right(thresh, s_area)] for s_area in gt_areas
527 | ]
528 | # give each gt the flag
529 | for i, g_seg in enumerate(gt_segs.values()):
530 | g_seg['breakdown_flag'] = gt_flags[i]
531 |
532 | for p_seg in pd_segs.values():
533 | matchings = p_seg['matchings']
534 | flag = None
535 | if len(matchings) == 0:
536 | area = p_seg['bbox'][-1] * p_seg['bbox'][-2]
537 | flag = flags[bisect.bisect_right(thresh, area)]
538 | else:
539 | gt_s_sid = list(matchings.keys())[0]
540 | flag = gt_segs[gt_s_sid]['breakdown_flag']
541 | p_seg['breakdown_flag'] = flag
542 |
543 |
544 | class MaskScaleBreakdown(BreakdownPolicy):
545 | def __init__(self):
546 | # self.flags = ['tiny', 'small', 'medium', 'large', 'huge']
547 | # self.scale_thresholds = [ 16 ** 2, 32 ** 2, 64 ** 2, 128 ** 2]
548 |
549 | self.flags = ['small', 'medium', 'large']
550 | self.scale_thresholds = [32 ** 2, 128 ** 2]
551 |
552 | def breakdown(self, gt_segs, pd_segs):
553 | thresh, flags = self.scale_thresholds, self.flags
554 | gt_areas = [ seg['area'] for seg in gt_segs.values() ]
555 | gt_flags = [
556 | flags[bisect.bisect_right(thresh, s_area)] for s_area in gt_areas
557 | ]
558 | # give each gt the flag
559 | for i, g_seg in enumerate(gt_segs.values()):
560 | g_seg['breakdown_flag'] = gt_flags[i]
561 |
562 | for p_seg in pd_segs.values():
563 | matchings = p_seg['matchings']
564 | flag = None
565 | if len(matchings) == 0:
566 | area = p_seg['area']
567 | flag = flags[bisect.bisect_right(thresh, area)]
568 | else:
569 | gt_s_sid = list(matchings.keys())[0]
570 | flag = gt_segs[gt_s_sid]['breakdown_flag']
571 | p_seg['breakdown_flag'] = flag
572 |
573 |
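
Both scale-breakdown policies map an area to a flag via bisect_right over the sorted thresholds: the return value is the number of thresholds the area exceeds, which indexes straight into flags. With the thresholds above ([32**2, 128**2] = [1024, 16384]):

import bisect

flags = ['small', 'medium', 'large']
thresholds = [32 ** 2, 128 ** 2]

for area in (900, 5000, 20000):
    print(area, flags[bisect.bisect_right(thresholds, area)])
# 900 small, 5000 medium, 20000 large
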
574 | class BoxAspectRatioBreakdown(BreakdownPolicy):
575 | def __init__(self):
576 | self.flags = []
577 |
578 | def breakdown(self, gt_segs, pd_segs):
579 | pass
580 |
581 |
582 | policy_register = {
583 | 'dummy': DummyBreakdown,
584 | 'bbox_scale': BoxScaleBreakdown,
585 | 'mask_scale': MaskScaleBreakdown
586 | }
587 |
588 |
589 | class StatsBreakdown():
590 | def __init__(self, gt_json_meta_fname, pd_json_meta_fname, breakdown_policy):
591 | analyzer = PanopticEvalAnalyzer(gt_json_meta_fname, pd_json_meta_fname)
592 | self.gt_segs = analyzer.gt.segs
593 | self.pd_segs = analyzer.pd.segs
594 | self.analyzer = analyzer
595 | self.policy = policy_register[breakdown_policy]()
596 | self.policy.breakdown(self.gt_segs, self.pd_segs)
597 | self.verify_policy_execution(self.policy, self.gt_segs, self.pd_segs)
598 |
599 | @staticmethod
600 | def verify_policy_execution(policy, gt_segs, pd_segs):
601 | '''make sure that each seg has been given a flag'''
602 | flags = policy.flags
603 | for seg in gt_segs.values():
604 | assert seg['breakdown_flag'] in flags
605 | for seg in pd_segs.values():
606 | assert seg['breakdown_flag'] in flags
607 |
608 | def aggregate(self):
609 | res = []
610 | for flag in self.policy.flags:
611 | dataframe, overall_table, cat_table = self.analyzer.summarize(flag)
612 | res.append(dataframe)
613 | res = pd.concat(res, axis=1)
614 |
615 |         # print results in semicolon-separated format for easy transfer to Google Docs
616 | print(self.policy.flags)
617 | # upper left corner of the table is 'name'
618 | cols = ';'.join(['name'] + list(res.columns))
619 | print(cols)
620 | for catName, row in res.iterrows():
621 | row = [ '{:.2f}'.format(elem) for elem in row.values ]
622 | score_str = ';'.join([catName] + row)
623 | print(score_str)
624 |
625 |
626 | def test():
627 | model = 'pcv'
628 | dset = 'cityscapes'
629 | split = 'val'
630 | gt_json_fname = '/share/data/vision-greg/panoptic/{}/annotations/{}.json'.format(dset, split)
631 | pd_json_fname = '/home-nfs/whc/panout/{}/{}/{}/pred.json'.format(model, dset, split)
632 |
633 | # analyzer = PanopticEvalAnalyzer(gt_json_fname, pd_json_fname)
634 |     # _, overall_table, cat_table = analyzer._aggregate(
635 | # gt_stats=analyzer.gt.match_summarize(),
636 | # pd_stats=analyzer.pd.match_summarize(),
637 | # cats=analyzer.gt.cats
638 | # )
639 | # print(cat_table)
640 | # print(overall_table)
641 |
642 | breakdown_stats = StatsBreakdown(gt_json_fname, pd_json_fname, 'mask_scale')
643 | breakdown_stats.aggregate()
644 |
645 |
646 | def draw_failure_cases_spatially():
647 | model = 'pcv'
648 | dset = 'cityscapes'
649 | split = 'val'
650 | gt_json_fname = '/share/data/vision-greg/panoptic/{}/annotations/{}.json'.format(dset, split)
651 | pd_json_fname = '/home-nfs/whc/panout/{}/{}/{}/pred.json'.format(model, dset, split)
652 |
653 | analyzer = PanopticEvalAnalyzer(gt_json_fname, pd_json_fname)
654 | # plot unmatched gt
655 |     gt_accumulator = np.zeros((1024, 2048))  # cityscapes H x W
656 | # plt.imshow(gt_accumulator)
657 | gt = analyzer.gt
658 | for k, seg in gt.segs.items():
659 | if seg['match_state'] == _SEGMENT_UNMATCHED:
660 | x, y, w, h = seg['bbox']
661 | c = [y + h // 2, x + w // 2]
662 | y, x = c
663 | gt_accumulator[y, x] = gt_accumulator[y, x] + 1
664 | return gt_accumulator
665 |
666 |
667 | if __name__ == "__main__":
668 | draw_failure_cases_spatially()
669 |
--------------------------------------------------------------------------------