├── LICENSE ├── Makefile ├── README.md ├── configs └── demo_config.yaml ├── docker └── Dockerfile ├── media └── figs │ ├── panoptic-teaser.gif │ ├── test.png │ └── tri-logo.png ├── realtime_panoptic ├── __init__.py ├── config │ ├── __init__.py │ └── defaults.py ├── data │ ├── __init__.py │ └── panoptic_transform.py ├── layers │ ├── __init__.py │ └── scale.py ├── models │ ├── __init__.py │ ├── backbones.py │ ├── panoptic_from_dense_box.py │ └── rt_pano_net.py └── utils │ ├── __init__.py │ ├── bounding_box.py │ ├── boxlist_ops.py │ └── visualization.py └── scripts └── demo.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Toyota Research Institute - Machine Learning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for launching common tasks 2 | 3 | PYTHON ?= python 4 | DOCKER_OPTS ?= \ 5 | -v /dev/shm:/dev/shm \ 6 | -v /root/.ssh:/root/.ssh \ 7 | -v /var/run/docker.sock:/var/run/docker.sock \ 8 | --network=host \ 9 | --privileged 10 | PACKAGE_NAME ?= panoptic 11 | WORKSPACE ?= /workspace/$(PACKAGE_NAME) 12 | DOCKER_IMAGE_NAME ?= $(PACKAGE_NAME) 13 | DOCKER_IMAGE ?= $(DOCKER_IMAGE_NAME):latest 14 | 15 | all: clean test 16 | 17 | clean: 18 | find . -name "*.pyc" | xargs rm -f && \ 19 | find . -name "__pycache__" | xargs rm -rf 20 | 21 | clean-logs: 22 | find . -name "tensorboardx" | xargs rm -rf && \ 23 | find . -name "wandb" | xargs rm -rf 24 | 25 | test: 26 | PYTHONPATH=${PWD}/tests:${PYTHONPATH} python -m unittest discover -s tests 27 | 28 | docker-build: 29 | docker build \ 30 | -f docker/Dockerfile \ 31 | -t ${DOCKER_IMAGE} . 
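# Note: the ?= assignments above are only defaults, so they can be overridden per
# invocation, e.g. (hypothetical tag) `make docker-build DOCKER_IMAGE=panoptic:dev`.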
32 | 33 | docker-run-test-sample: 34 | nvidia-docker run --name panoptic --rm \ 35 | -e DISPLAY=${DISPLAY} \ 36 | -v ${PWD}:${WORKSPACE} \ 37 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 38 | -v ~/.torch:/root/.torch \ 39 | -p 8888:8888 \ 40 | -p 6006:6006 \ 41 | -p 5000:5000 \ 42 | -it \ 43 | -v ${PWD}:${WORKSPACE} \ 44 | ${DOCKER_OPTS} \ 45 | ${DOCKER_IMAGE} bash -c \ 46 | "wget -P /workspace/panoptic/ -c https://tri-ml-public.s3.amazonaws.com/github/realtime_panoptic/models/cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth && \ 47 | python scripts/demo.py \ 48 | --config-file configs/demo_config.yaml \ 49 | --input media/figs/test.png \ 50 | --pretrained-weight cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth" 51 | 52 | docker-start: 53 | nvidia-docker run --name panoptic --rm \ 54 | -e DISPLAY=${DISPLAY} \ 55 | -v ${PWD}:${WORKSPACE} \ 56 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 57 | -v /data:/data \ 58 | -v ~/.torch:/root/.torch \ 59 | -p 8888:8888 \ 60 | -p 6006:6006 \ 61 | -p 5000:5000 \ 62 | -d \ 63 | -it \ 64 | ${DOCKER_OPTS} \ 65 | ${DOCKER_IMAGE} && \ 66 | nvidia-docker exec -it panoptic bash 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Real-Time Panoptic Segmentation from Dense Detections 2 | 3 | Official [PyTorch](https://pytorch.org/) implementation of the CVPR 2020 Oral **Real-Time Panoptic Segmentation from Dense Detections** by the ML Team at [Toyota Research Institute (TRI)](https://www.tri.global/), cf. [References](#references) below. 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | ## Install 14 | ``` 15 | git clone https://github.com/TRI-ML/realtime_panoptic.git 16 | cd realtime_panoptic 17 | make docker-build 18 | ``` 19 | 20 | To verify your installation, you can also run a quick test that performs inference on a single test image using our Cityscapes pretrained model: 21 | ``` 22 | make docker-run-test-sample 23 | ``` 24 | 25 | Now you can start a docker container in interactive mode: 26 | ``` 27 | make docker-start 28 | ``` 29 | ## Demo 30 | We provide demo code to run inference with our Cityscapes pretrained model. 31 | ``` 32 | python scripts/demo.py --config-file <config-file> --input <input-image> \ 33 | --pretrained-weight <pretrained-weights> 34 | ``` 35 | Simple example using the pretrained model provided in the [Models](#models) section: 36 | ``` 37 | python scripts/demo.py --config-file ./configs/demo_config.yaml --input media/figs/test.png --pretrained-weight cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth 38 | ``` 39 | 40 | ## Models 41 | 42 | 43 | ### Cityscapes 44 | | Model | PQ | PQ_th | PQ_st | 45 | | :--- | :---: | :---: | :---: | 46 | | [ResNet-50](https://tri-ml-public.s3.amazonaws.com/github/realtime_panoptic/models/cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth) | 58.8 | 52.1| 63.7 | 47 | 48 | ## License 49 | 50 | The source code is released under the [MIT license](LICENSE.md).
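## Python usage (sketch)

If you prefer to drive the network from Python rather than through `scripts/demo.py`, the sketch below shows how the packaged config and model classes fit together. It is an untested sketch: it assumes the shipped `configs/demo_config.yaml` and that the Cityscapes checkpoint above is a plain PyTorch `state_dict` (as its `standalone_no_prefix` name suggests).

```python
import torch

from realtime_panoptic.config import cfg
from realtime_panoptic.models.rt_pano_net import RTPanoNet

# Merge the demo settings on top of the defaults in realtime_panoptic/config/defaults.py.
cfg.merge_from_file("configs/demo_config.yaml")

model = RTPanoNet(
    backbone=cfg.model.backbone,
    num_classes=cfg.model.panoptic.num_classes,
    things_num_classes=cfg.model.panoptic.num_thing_classes,
    pre_nms_thresh=cfg.model.panoptic.pre_nms_thresh,
    pre_nms_top_n=cfg.model.panoptic.pre_nms_top_n,
    nms_thresh=cfg.model.panoptic.nms_thresh,
    fpn_post_nms_top_n=cfg.model.panoptic.fpn_post_nms_top_n,
    instance_id_range=cfg.model.panoptic.instance_id_range)

# Assumes the downloaded checkpoint is a plain state_dict.
model.load_state_dict(torch.load("cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth", map_location="cpu"))
model.eval()

# RTPanoNet.forward expects a torchvision ImageList whose .tensors is a list of
# normalized CHW image tensors (see cfg.input for the BGR / 0-255 normalization used
# here) and whose .image_sizes holds the (height, width) of each unpadded image.
```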
51 | 52 | ## References 53 | 54 | #### Real-Time Panoptic Segmentation from Dense Detections (CVPR 2020 oral) 55 | *Rui Hou\*, Jie Li\*, Arjun Bhargava, Allan Raventos, Vitor Guizilini, Chao Fang, Jerome Lynch, Adrien Gaidon*, [**[paper]**](https://arxiv.org/abs/1912.01202), [**[oral presentation]**](https://www.youtube.com/watch?v=xrxaRU2g2vo), [**[teaser]**](https://www.youtube.com/watch?v=_N4kGJEg-rM) 56 | ``` 57 | @InProceedings{real-time-panoptic, 58 | author = {Hou, Rui and Li, Jie and Bhargava, Arjun and Raventos, Allan and Guizilini, Vitor and Fang, Chao and Lynch, Jerome and Gaidon, Adrien}, 59 | title = {Real-Time Panoptic Segmentation From Dense Detections}, 60 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 61 | month = {June}, 62 | year = {2020} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /configs/demo_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | model: 3 | name: 'Cityscape_realtime_panoptic' 4 | backbone: 'R-50-FPN-RETINANET' 5 | panoptic: 6 | num_classes: 19 7 | num_thing_classes: 8 8 | fpn_post_nms_top_n: 100 9 | instance_id_range: (11, 18) 10 | pre_nms_thresh: 0.06 11 | 12 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------- Base Docker Image ----------------------------------------------- 2 | FROM nvidia/cuda:10.0-devel-ubuntu18.04 3 | 4 | ENV PYTORCH_VERSION=1.1.0 5 | ENV TORCHVISION_VERSION=0.3.0 6 | ENV CUDNN_VERSION=7.6.0.64-1+cuda10.0 7 | ENV NCCL_VERSION=2.4.7-1+cuda10.0 8 | 9 | # Workaround for deadlock issue. 
To be removed with next major Pytorch release 10 | ENV NCCL_LL_THRESHOLD=0 11 | 12 | # Python 2.7 or 3.6 is supported by Ubuntu Bionic out of the box 13 | ARG python=3.6 14 | ENV PYTHON_VERSION=${python} 15 | ENV DEBIAN_FRONTEND=noninteractive 16 | 17 | # Set default shell to /bin/bash 18 | SHELL ["/bin/bash", "-cu"] 19 | 20 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \ 21 | build-essential \ 22 | cmake \ 23 | g++-4.8 \ 24 | git \ 25 | curl \ 26 | docker.io \ 27 | vim \ 28 | wget \ 29 | ca-certificates \ 30 | libcudnn7=${CUDNN_VERSION} \ 31 | libnccl2=${NCCL_VERSION} \ 32 | libnccl-dev=${NCCL_VERSION} \ 33 | libjpeg-dev \ 34 | libpng-dev \ 35 | python${PYTHON_VERSION} \ 36 | python${PYTHON_VERSION}-dev \ 37 | python3-tk \ 38 | librdmacm1 \ 39 | libibverbs1 \ 40 | ibverbs-providers \ 41 | libgtk2.0-dev \ 42 | unzip \ 43 | bzip2 \ 44 | htop 45 | 46 | 47 | # Instal Python and pip 48 | RUN if [[ "${PYTHON_VERSION}" == "3.6" ]]; then \ 49 | apt-get install -y python${PYTHON_VERSION}-distutils; \ 50 | fi 51 | 52 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python 53 | 54 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ 55 | python get-pip.py && \ 56 | rm get-pip.py 57 | 58 | # Install PyTorch 59 | RUN pip install future typing numpy awscli 60 | RUN pip install https://download.pytorch.org/whl/cu100/torch-${PYTORCH_VERSION}-cp36-cp36m-linux_x86_64.whl 61 | RUN pip install https://download.pytorch.org/whl/cu100/torchvision-${TORCHVISION_VERSION}-cp36-cp36m-linux_x86_64.whl 62 | 63 | 64 | # Configure environment variables - default working directory is "/workspace" 65 | WORKDIR /workspace 66 | ENV PYTHONPATH="/workspace" 67 | 68 | # Install dependencies 69 | RUN pip install ninja yacs cython matplotlib opencv-python tqdm onnx onnxruntime coloredlogs scipy pycuda 70 | RUN pip uninstall -y pillow 71 | RUN pip install pillow-simd==6.2.2.post1 pycocotools 72 | 73 | # Install apex 74 | WORKDIR /workspace 75 | RUN git clone https://github.com/NVIDIA/apex.git 76 | WORKDIR /workspace/apex 77 | RUN git checkout 82dac9c9419035110d1ccc49b2608681337903ed 78 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 79 | 80 | # Copy repo 81 | ENV PYTHONPATH="/workspace/panoptic:$PYTHONPATH" 82 | COPY . 
/workspace/panoptic 83 | WORKDIR /workspace/panoptic 84 | -------------------------------------------------------------------------------- /media/figs/panoptic-teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/panoptic-teaser.gif -------------------------------------------------------------------------------- /media/figs/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/test.png -------------------------------------------------------------------------------- /media/figs/tri-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/tri-logo.png -------------------------------------------------------------------------------- /realtime_panoptic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | -------------------------------------------------------------------------------- /realtime_panoptic/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import cfg -------------------------------------------------------------------------------- /realtime_panoptic/config/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | 3 | import os 4 | 5 | from yacs.config import CfgNode as CN 6 | cfg = CN() 7 | ######################################################################################################################## 8 | ### MODEL 9 | ######################################################################################################################## 10 | cfg.model = CN() 11 | cfg.model.name = '' # Training model 12 | cfg.model.backbone = '' # Backbone 13 | cfg.model.checkpoint_path = '' # Checkpoint path for model saving 14 | 15 | cfg.model.panoptic = CN() 16 | cfg.model.panoptic.num_classes = 19 # number of total classes 17 | cfg.model.panoptic.num_thing_classes = 8 # number of thing classes 18 | cfg.model.panoptic.pre_nms_thresh = 0.05 # objectness threshold before NMS 19 | cfg.model.panoptic.pre_nms_top_n = 1000 # max num of accepted bboxes before NMS 20 | cfg.model.panoptic.nms_thresh = 0.6 # NMS threshold 21 | cfg.model.panoptic.fpn_post_nms_top_n = 100 # Top detection post NMS 22 | # for cityscapes, it is (11,18), for COCO it is (0,79), for Vistas it is (0,36) 23 | cfg.model.panoptic.instance_id_range = (11, 18) 24 | 25 | 26 | ######################################################################################################################## 27 | ### INPUT 28 | ######################################################################################################################## 29 | cfg.input = CN() 30 | cfg.input.pixel_mean = [102.9801, 115.9465, 122.7717] 31 | cfg.input.pixel_std = [1., 1., 1.] 
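# Usage sketch (standard yacs pattern, assuming configs/demo_config.yaml): downstream code
# merges a YAML file on top of these defaults, e.g.
#   from realtime_panoptic.config import cfg
#   cfg.merge_from_file("configs/demo_config.yaml")
# Keys absent from the YAML keep the default values defined in this module.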
32 | # Convert image to BGR format, in range 0-255 33 | cfg.input.to_bgr255 = True 34 | -------------------------------------------------------------------------------- /realtime_panoptic/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/realtime_panoptic/data/__init__.py -------------------------------------------------------------------------------- /realtime_panoptic/data/panoptic_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | import random 3 | 4 | import torchvision.transforms as torchvision_transforms 5 | from PIL import Image 6 | from torchvision.transforms import functional as F 7 | 8 | 9 | class Compose: 10 | """Compose a set of data transform operations. 11 | 12 | Parameters: 13 | ----------- 14 | transforms: list 15 | A list of transform operations. 16 | 17 | Returns: 18 | ------- 19 | data: Dict 20 | The final output data after the given set of transforms. 21 | """ 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __call__(self, data): 26 | for t in self.transforms: 27 | data = t(data) 28 | return data 29 | 30 | def __repr__(self): 31 | format_string = self.__class__.__name__ + "(" 32 | for t in self.transforms: 33 | format_string += "\n" 34 | format_string += " {0}".format(t) 35 | format_string += "\n)" 36 | return format_string 37 | 38 | 39 | class Resize: 40 | """Resize operation on panoptic data. 41 | The input data will be resized by randomly choosing a minimum side length from 42 | a given set, with the maximum side capped by a given length. 43 | 44 | Parameters: 45 | ---------- 46 | min_size: list or tuple 47 | A list of candidate sizes for the image's minimum side.
48 | 49 | max_size: int 50 | Maximum side length of the processed image 51 | 52 | """ 53 | def __init__(self, min_size, max_size, is_train=True): 54 | if not isinstance(min_size, (list, tuple)): 55 | min_size = (min_size, ) 56 | self.min_size = min_size 57 | self.max_size = max_size 58 | self.is_train = is_train 59 | 60 | # modified from torchvision to add support for max size 61 | # NOTE: this method will always try to make the smaller size match ``self.min_size`` 62 | # so in the case of Vistas, this will mean evaluating at a significantly lower resolution 63 | # than the original image for some images 64 | def get_size(self, image_size): 65 | w, h = image_size 66 | size = random.choice(self.min_size) 67 | max_size = self.max_size 68 | if max_size is not None: 69 | min_original_size = float(min((w, h))) 70 | max_original_size = float(max((w, h))) 71 | if max_original_size / min_original_size * size > max_size: 72 | size = int( 73 | round(max_size * min_original_size / max_original_size)) 74 | 75 | if (w <= h and w == size) or (h <= w and h == size): 76 | return (h, w) 77 | 78 | if w < h: 79 | ow = size 80 | oh = int(size * h / w) 81 | else: 82 | oh = size 83 | ow = int(size * w / h) 84 | 85 | return (oh, ow) 86 | 87 | def __call__(self, data): 88 | size = self.get_size(data["image"].size) 89 | data["image"] = F.resize(data["image"], size) 90 | if self.is_train: 91 | if "segmentation_target" in data: 92 | data["segmentation_target"] = F.resize( 93 | data["segmentation_target"], 94 | size, 95 | interpolation=Image.NEAREST) 96 | if "detection_target" in data: 97 | data["detection_target"] = data["detection_target"].resize( 98 | data["image"].size) 99 | return data 100 | 101 | 102 | class RandomHorizontalFlip: 103 | """Randomly Flip the input data with given probability. 104 | 105 | Parameters: 106 | ---------- 107 | prob: float 108 | A probability to flip the data in [0,1]. 109 | """ 110 | def __init__(self, prob=0.5): 111 | self.prob = prob 112 | 113 | def __call__(self, data): 114 | if random.random() < self.prob: 115 | data["image"] = F.hflip(data["image"]) 116 | if "detection_target" in data: 117 | data["detection_target"] = data["detection_target"].transpose( 118 | 0) 119 | if "segmentation_target" in data: 120 | data["segmentation_target"] = F.hflip( 121 | data["segmentation_target"]) 122 | return data 123 | 124 | 125 | class ToTensor: 126 | """Convert the input data to Tensor. 127 | """ 128 | def __call__(self, data): 129 | data["image"] = F.to_tensor(data["image"]) 130 | if "segmentation_target" in data: 131 | data["segmentation_target"] = F.to_tensor( 132 | data["segmentation_target"]) 133 | return data 134 | 135 | 136 | class ColorJitter: 137 | """Apply color jittering to input image. 138 | """ 139 | def __call__(self, data): 140 | data["image"] = torchvision_transforms.ColorJitter().__call__( 141 | data["image"]) 142 | return data 143 | 144 | 145 | class Normalize: 146 | """Normalize the input image with options of RGB/BGR converting. 147 | 148 | Parameters: 149 | ---------- 150 | mean: list 151 | Mean value for the 3 image channels. 152 | 153 | std: list 154 | Standard deviation for the 3 image channels. 155 | 156 | to_bgr255: bool 157 | If true, the default image come in with rgb channel and [0,1] scale. 158 | it will be converted into bgr with [0,255] scale. 
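Note (see ``__call__`` below): when ``to_bgr255`` is enabled, any ``segmentation_target`` (a [0, 1] float tensor after ``ToTensor``) is also rescaled by 255 and cast to long label ids.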
159 | """ 160 | def __init__(self, mean, std, to_bgr255=True): 161 | self.mean = mean 162 | self.std = std 163 | self.to_bgr255 = to_bgr255 164 | 165 | def __call__(self, data): 166 | if self.to_bgr255: 167 | data["image"] = data["image"][[2, 1, 0]] * 255 168 | if "segmentation_target" in data: 169 | data["segmentation_target"] = ( 170 | data["segmentation_target"] * 255).long() 171 | data["image"] = F.normalize( 172 | data["image"], mean=self.mean, std=self.std) 173 | return data 174 | 175 | 176 | class RandomCrop: 177 | """Randomly Crop in input panoptic data. 178 | 179 | Parameters: 180 | ---------- 181 | crop_size: tuple 182 | Desired crop size of the data. 183 | """ 184 | def __init__(self, crop_size): 185 | # A couple of safety checks 186 | assert isinstance(crop_size, tuple) 187 | self.crop_size = crop_size 188 | self.crop = torchvision_transforms.RandomCrop(crop_size) 189 | 190 | def __call__(self, data): 191 | if len(self.crop_size) <= 1: 192 | return data 193 | 194 | # If image size is smaller than crop size, 195 | # resize both image and target to at least crop size. 196 | if self.crop_size[0] > data["image"].size[0] or self.crop_size[ 197 | 1] > data["image"].size[1]: 198 | print("Image will be resized before cropping. {},{}".format( 199 | self.crop_size, data["image"].size)) 200 | resize_func = Resize( 201 | max(self.crop_size), 202 | round( 203 | max(data["image"].size) * max(self.crop_size) / min( 204 | data["image"].size))) 205 | data = resize_func(data) 206 | 207 | if "detection_target" not in data: 208 | 209 | image_width, image_height = data["image"].size 210 | crop_width, crop_height = self.crop_size 211 | assert image_width >= crop_width and image_height >= crop_height 212 | 213 | left = 0 214 | if image_width > crop_width: 215 | left = random.randint(0, image_width - crop_width) 216 | top = 0 217 | if image_height > crop_height: 218 | top = random.randint(0, image_height - crop_height) 219 | 220 | data["image"] = data["image"].crop((left, top, left + crop_width, 221 | top + crop_height)) 222 | if "segmentation_target" in data: 223 | data["segmentation_target"] = data["segmentation_target"].crop( 224 | (left, top, left + crop_width, top + crop_height)) 225 | 226 | else: 227 | # We always crop an area that contains at least one instance. 228 | # TODO: We are making an assumption here that data are filtered to only include 229 | # non-empty training samples. So the while loop will not be a dead lock. Need to 230 | # improve the efficiency of this part. 231 | while True: 232 | # continuously try till there's instance inside it. 
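# Each iteration samples a fresh crop window via RandomCrop.get_params and keeps it
# only if `augmentation_crop` returns a non-None detection target, i.e. at least one
# instance survives the crop. If the image is not strictly larger than the crop size,
# the loop exits immediately and the (already resized) data is left uncropped.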
233 | w, h = data["image"].size 234 | if w <= self.crop_size[0] or h <= self.crop_size[1]: 235 | break 236 | # y, x, h, w 237 | top, left, crop_height, crop_width = self.crop.get_params( 238 | data["image"], (self.crop_size[1], self.crop_size[0])) 239 | 240 | image_cropped = F.crop(data["image"], top, left, crop_height, 241 | crop_width) 242 | 243 | detection_target_cropped = data[ 244 | "detection_target"].augmentation_crop( 245 | top, left, crop_height, crop_width) 246 | 247 | if detection_target_cropped is not None: 248 | data["image"] = image_cropped 249 | data["detection_target"] = detection_target_cropped 250 | 251 | # Once ``detection_target`` gets properly cropped, then crop 252 | # ``segmentation_target`` using the same parameters 253 | if "segmentation_target" in data: 254 | data["segmentation_target"] = F.crop( 255 | data["segmentation_target"], top, left, 256 | crop_height, crop_width) 257 | break 258 | return data -------------------------------------------------------------------------------- /realtime_panoptic/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | -------------------------------------------------------------------------------- /realtime_panoptic/layers/scale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | 3 | # Adapted from FCOS 4 | # https://github.com/tianzhi0549/FCOS/blob/master/fcos_core/layers/scale.py 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class Scale(nn.Module): 11 | def __init__(self, init_value=1.0): 12 | super(Scale, self).__init__() 13 | """Scale layer with trainable scale factor. 14 | 15 | Parameters 16 | ---------- 17 | init_value: float 18 | Initial value of the scale factor. 19 | """ 20 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 21 | 22 | def forward(self, input): 23 | return input * self.scale 24 | -------------------------------------------------------------------------------- /realtime_panoptic/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | -------------------------------------------------------------------------------- /realtime_panoptic/models/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | from torch import nn 3 | from torchvision.models import resnet 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from torchvision.ops import misc as misc_nn_ops 6 | from torchvision.ops.feature_pyramid_network import (FeaturePyramidNetwork, 7 | LastLevelP6P7) 8 | 9 | # This class is adapted from "BackboneWithFPN" in torchvision.models.detection.backbone_utils 10 | 11 | class ResNetWithModifiedFPN(nn.Module): 12 | """Adds a p67-FPN on top of a ResNet model with more options. 13 | 14 | We adopt this function from torchvision.models.detection.backbone_utils. 15 | Modification has been added to enable RetinaNet style FPN with P6 P7 as extra blocks. 16 | 17 | Parameters 18 | ---------- 19 | backbone_name: string 20 | Resnet architecture supported by torchvision. 
Possible values are 'ResNet', 'resnet18', 'resnet34', 'resnet50', 21 | 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' 22 | 23 | norm_layer: torchvision.ops 24 | It is recommended to use the default value. For details visit: 25 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) 26 | 27 | pretrained: bool 28 | If True, returns a model with backbone pre-trained on Imagenet. Default: False 29 | 30 | trainable_layers: int 31 | Number of trainable (not frozen) resnet layers starting from final block. 32 | Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 33 | 34 | out_channels: int 35 | number of channels in the FPN. 36 | """ 37 | 38 | def __init__(self, backbone_name, pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3, out_channels = 256): 39 | super().__init__() 40 | # Get ResNet 41 | backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) 42 | # select layers that wont be frozen 43 | assert 0 <= trainable_layers <= 5 44 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] 45 | # freeze layers only if pretrained backbone is used 46 | for name, parameter in backbone.named_parameters(): 47 | if all([not name.startswith(layer) for layer in layers_to_train]): 48 | parameter.requires_grad_(False) 49 | 50 | return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'} 51 | 52 | in_channels_stage2 = backbone.inplanes // 8 53 | self.in_channels_list = [ 54 | 0, 55 | in_channels_stage2 * 2, 56 | in_channels_stage2 * 4, 57 | in_channels_stage2 * 8, 58 | ] 59 | 60 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 61 | self.fpn = FeaturePyramidNetwork( 62 | in_channels_list=self.in_channels_list[1:], # nonzero only 63 | out_channels=out_channels, 64 | extra_blocks=LastLevelP6P7(out_channels, out_channels), 65 | ) 66 | self.out_channels = out_channels 67 | 68 | def forward(self, x): 69 | x = self.body(x) 70 | keys = list(x.keys()) 71 | for idx, key in enumerate(keys): 72 | if self.in_channels_list[idx] == 0: 73 | del x[key] 74 | x = self.fpn(x) 75 | return x 76 | -------------------------------------------------------------------------------- /realtime_panoptic/models/panoptic_from_dense_box.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | from realtime_panoptic.utils.bounding_box import BoxList 5 | from realtime_panoptic.utils.boxlist_ops import (boxlist_nms, cat_boxlist, remove_small_boxes) 6 | 7 | class PanopticFromDenseBox: 8 | """Performs post-processing on the outputs of the RTPanonet. 9 | 10 | Parameters 11 | ---------- 12 | pre_nms_thresh: float 13 | Acceptance class probability threshold for bounding box candidates before NMS. 14 | 15 | pre_nms_top_n: int 16 | Maximum number of accepted bounding box candidates before NMS. 17 | 18 | nms_thresh: float 19 | NMS threshold. 20 | 21 | fpn_post_nms_top_n: int 22 | Maximum number of detected object per image. 23 | 24 | min_size: int 25 | Minimum dimension of accepted detection. 26 | 27 | num_classes: int 28 | Number of total semantic classes (stuff and things). 29 | 30 | mask_thresh: float 31 | Bounding box IoU threshold to determined 'similar bounding box' in mask reconstruction. 
32 | 33 | instance_id_range: list of int 34 | [min_id, max_id] defines the range of ids in 1:num_classes that correspond to thing classes. 35 | 36 | is_training: bool 37 | Whether the current process is part of the training process. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | pre_nms_thresh, 43 | pre_nms_top_n, 44 | nms_thresh, 45 | fpn_post_nms_top_n, 46 | min_size, 47 | num_classes, 48 | mask_thresh, 49 | instance_id_range, 50 | is_training 51 | ): 52 | super(PanopticFromDenseBox, self).__init__() 53 | # assign parameters 54 | self.pre_nms_thresh = pre_nms_thresh 55 | self.pre_nms_top_n = pre_nms_top_n 56 | self.nms_thresh = nms_thresh 57 | self.fpn_post_nms_top_n = fpn_post_nms_top_n 58 | self.min_size = min_size 59 | self.num_classes = num_classes 60 | self.mask_thresh = mask_thresh 61 | self.instance_id_range = instance_id_range 62 | self.is_training = is_training 63 | 64 | def process( 65 | self, locations, box_cls, box_regression, centerness, levelness_logits, semantic_logits, image_sizes 66 | ): 67 | """ Reconstruct panoptic segmentation result from raw predictions. 68 | 69 | This function conducts post-processing of the panoptic head's raw predictions, including bounding box 70 | prediction, semantic segmentation and levelness, to reconstruct instance segmentation results. 71 | 72 | Parameters 73 | ---------- 74 | locations: list of torch.Tensor 75 | Corresponding pixel locations of each FPN prediction. 76 | 77 | box_cls: list of torch.Tensor 78 | Predicted bounding box classes from each FPN layer. 79 | 80 | box_regression: list of torch.Tensor 81 | Predicted bounding box offsets from each FPN layer. 82 | 83 | centerness: list of torch.Tensor 84 | Predicted object centerness from each FPN layer. 85 | 86 | levelness_logits: torch.Tensor 87 | Global prediction of best source FPN layer for each pixel location. 88 | 89 | semantic_logits: torch.Tensor 90 | Global prediction of semantic segmentation. 91 | 92 | image_sizes: list of [int,int] 93 | Image sizes. 94 | 95 | Returns: 96 | -------- 97 | boxlists: list of BoxList 98 | Reconstructed instances with masks. 99 | """ 100 | num_locs_per_level = [len(loc_per_level) for loc_per_level in locations] 101 | 102 | sampled_boxes = [] 103 | for i, (l, o, b, c) in enumerate(zip(locations[:-1], box_cls, box_regression, centerness)): 104 | if self.is_training: 105 | layer_boxes = self.forward_for_single_feature_map(l, o, b, c, image_sizes) 106 | for layer_box in layer_boxes: 107 | pred_indices = layer_box.get_field("indices") 108 | pred_indices = pred_indices + sum(num_locs_per_level[:i]) 109 | layer_box.add_field("indices", pred_indices) 110 | sampled_boxes.append(layer_boxes) 111 | else: 112 | sampled_boxes.append(self.forward_for_single_feature_map(l, o, b, c, image_sizes)) 113 | 114 | # sampled_boxes is a list of bbox_list per level 115 | # the following converts it to per image 116 | boxlists = list(zip(*sampled_boxes)) 117 | # per image, concat bbox_list of different levels into one bbox_list 118 | # boxlists is a list of bboxlists of N images 119 | try: 120 | boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] 121 | boxlists = self.select_over_all_levels(boxlists) 122 | except Exception as e: 123 | print(e) 124 | for boxlist in boxlists: 125 | for box in boxlist: 126 | print(box, "box shape", box.bbox.shape) 127 | 128 | # Generate bounding box feature map at size of [H/4, W/4] with bounding box predictions as features.
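# Levelness predicts, for every location on this 1/4-resolution grid, which FPN level's
# box regression to trust there; `generate_box_feature_map` upsamples each level's
# regression to that grid, gathers the per-pixel winning level, and converts the
# (l, t, r, b) offsets into dense xyxy boxes anchored at the levelness locations.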
129 | levelness_locations = locations[-1] 130 | _, c_semantic, _, _ = semantic_logits.shape 131 | N, _, h_map, w_map = levelness_logits.shape 132 | bounding_box_feature_map = self.generate_box_feature_map(levelness_locations, box_regression, levelness_logits) 133 | 134 | # process semantic raw prediction 135 | semantic_logits = F.interpolate(semantic_logits, size=(h_map, w_map), mode='bilinear') 136 | semantic_logits = semantic_logits.view(N, c_semantic, h_map, w_map).permute(0, 2, 3, 1) 137 | semantic_logits = semantic_logits.reshape(N, -1, c_semantic) 138 | 139 | # insert semantic prob into mask 140 | semantic_probability = F.softmax(semantic_logits, dim=2) 141 | semantic_probability = semantic_probability[:, :, self.instance_id_range[0]:] 142 | boxlists = self.mask_reconstruction( 143 | boxlists=boxlists, 144 | box_feature_map=bounding_box_feature_map, 145 | semantic_prob=semantic_probability, 146 | box_feature_map_location=levelness_locations, 147 | h_map=h_map, 148 | w_map=w_map 149 | ) 150 | # resize instance masks to original image size 151 | if not self.is_training: 152 | for boxlist in boxlists: 153 | masks = boxlist.get_field("mask") 154 | # NOTE: BoxList size is the image size without padding. MASK here is a mask with padding. 155 | # Mask need to be interpolated into padded image size and then crop to unpadded size. 156 | w, h = boxlist.size 157 | if len(masks.shape) == 3 and masks.shape[0] != 0: 158 | masks = F.interpolate(masks.unsqueeze(0), size=(h_map * 4, w_map * 4), mode='bilinear').squeeze() 159 | else: 160 | # handle 0 shape dummy mask. 161 | masks = masks.view([-1, h_map * 4, w_map * 4]) 162 | masks = masks >= self.mask_thresh 163 | if len(masks.shape) < 3: 164 | masks = masks.unsqueeze(0) 165 | masks = masks[:, 0:h, 0:w].contiguous() 166 | boxlist.add_field("mask", masks) 167 | return boxlists 168 | 169 | def forward_for_single_feature_map(self, locations, box_cls, box_regression, centerness, image_sizes): 170 | """Recover dense bounding box detection results from raw predictions for each FPN layer. 171 | 172 | Parameters 173 | ---------- 174 | locations: torch.Tensor 175 | Corresponding pixel location of FPN feature map with size of (N, H * W, 2). 176 | 177 | box_cls: torch.Tensor 178 | Predicted bounding box class probability with size of (N, C, H, W). 179 | 180 | box_regression: torch.Tensor 181 | Predicted bounding box offset centered at corresponding pixel with size of (N, 4, H, W). 182 | 183 | centerness: torch.Tensor 184 | Predicted centerness of corresponding pixel with size of (N, 1, H, W). 185 | 186 | Note: N is the number of FPN level. 187 | 188 | Returns 189 | ------- 190 | results: List of BoxList 191 | A list of dense bounding boxes from each FPN layer. 
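Note: each location (x, y) with predicted offsets (l, t, r, b) decodes to the box (x - l, y - t, x + r, y + b) in input-image coordinates, as implemented in the ``detections`` stack below.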
192 | """ 193 | 194 | N, C, H, W = box_cls.shape 195 | # M = H x W is the total number of proposal for this single feature map 196 | 197 | # put in the same format as locations 198 | # from (N, C, H, W) to (N, H, W, C) 199 | box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1) 200 | # from (N, H, W, C) to (N, M, C) 201 | # map class prob to (-1, +1) 202 | box_cls = box_cls.reshape(N, -1, C).sigmoid() 203 | # from (N, 4, H, W) to (N, H, W, 4) to (N, M, 4) 204 | box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1) 205 | box_regression = box_regression.reshape(N, -1, 4) 206 | # from (N, 4, H, W) to (N, H, W, 1) to (N, M) 207 | # map centerness prob to (-1, +1) 208 | centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1) 209 | centerness = centerness.reshape(N, -1).sigmoid() 210 | 211 | # before NMS, per level filter out low cls prob with threshold 0.05 212 | # after this candidate_inds of size (N, M, C) with values corresponding to 213 | # low prob predictions become 0, otherwise 1 214 | candidate_inds = box_cls > self.pre_nms_thresh 215 | 216 | # pre_nms_top_n of size (N, M * C) => (N, 1) 217 | # N -> batch index, 1 -> total number of bbox predictions per image 218 | pre_nms_top_n = candidate_inds.view(N, -1).sum(1) 219 | # total number of proposal before NMS 220 | # if have more than self.pre_nms_top_n (1000) clamp to 1000 221 | pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) 222 | 223 | # multiply the classification scores with centerness scores 224 | # (N, M, C) * (N, M, 1) 225 | box_cls = box_cls * centerness[:, :, None] 226 | 227 | results = [] 228 | for i in range(N): 229 | # filer out low score candidates 230 | per_box_cls = box_cls[i] # (M, C) 231 | per_candidate_inds = candidate_inds[i] # (M, C) 232 | # per_box_cls of size P, P < M * C 233 | per_box_cls = per_box_cls[per_candidate_inds] 234 | 235 | # indices of seeds bounding boxes 236 | # 0-dim corresponding to M, location 237 | # 1-dim corresponding to C, class 238 | per_candidate_nonzeros = per_candidate_inds.nonzero() 239 | # Each of the following is of size P < M * C 240 | per_box_loc = per_candidate_nonzeros[:, 0] 241 | per_class = per_candidate_nonzeros[:, 1] + 1 242 | 243 | # per_box_regression of size (M, 4) 244 | per_box_regression = box_regression[i] 245 | # (M, 4) => (P, 4) 246 | # in P, there might be identical bbox prediction in M 247 | per_box_regression = per_box_regression[per_box_loc] 248 | # (M, 2) => (P, 2) 249 | # in P, there might be identical locations in M 250 | per_locations = locations[per_box_loc] 251 | 252 | 253 | # upperbound of the number of predictions for this image 254 | per_pre_nms_top_n = pre_nms_top_n[i] 255 | 256 | # if valid predictions is more than the upperbound 257 | # only select topK 258 | if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): 259 | per_box_cls, top_k_indices = \ 260 | per_box_cls.topk(per_pre_nms_top_n, sorted=False) 261 | per_class = per_class[top_k_indices] 262 | per_box_regression = per_box_regression[top_k_indices] 263 | per_locations = per_locations[top_k_indices] 264 | if self.is_training: 265 | per_box_loc = per_box_loc[top_k_indices] 266 | 267 | detections = torch.stack([ 268 | per_locations[:, 0] - per_box_regression[:, 0], 269 | per_locations[:, 1] - per_box_regression[:, 1], 270 | per_locations[:, 0] + per_box_regression[:, 2], 271 | per_locations[:, 1] + per_box_regression[:, 3], 272 | ], 273 | dim=1) 274 | 275 | h, w = image_sizes[i] 276 | 277 | boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy") 278 | 
boxlist.add_field("labels", per_class) 279 | boxlist.add_field("scores", per_box_cls) 280 | if self.is_training: 281 | boxlist.add_field("indices", per_box_loc) 282 | 283 | boxlist = boxlist.clip_to_image(remove_empty=False) 284 | boxlist = remove_small_boxes(boxlist, self.min_size) 285 | results.append(boxlist) 286 | return results 287 | 288 | def generate_box_feature_map(self, location, box_regression, levelness_logits): 289 | """Generate bounding box feature aggregating dense bounding box predictions. 290 | 291 | Parameters 292 | ---------- 293 | location: torch.Tensor 294 | Pixel location of levelness. 295 | 296 | box_regression: list of torch.Tensor 297 | Bounding box offsets from each FPN. 298 | 299 | levelness_logits: torch.Tenor 300 | Global prediction of best source FPN layer for each pixel location. 301 | Predict at the resolution of (H/4, W/4). 302 | 303 | Returns 304 | ------- 305 | bounding_box_feature_map: torch.Tensor 306 | Aggregated bounding box feature map. 307 | """ 308 | upscaled_box_reg = [] 309 | N, _, h_map, w_map = levelness_logits.shape 310 | downsampled_shape = torch.Size((h_map, w_map)) 311 | for box_reg in box_regression: 312 | upscaled_box_reg.append(F.interpolate(box_reg, size=downsampled_shape, mode='bilinear').unsqueeze(1)) 313 | 314 | # N_level, 4, h_map, w_map 315 | upscaled_box_reg = torch.cat(upscaled_box_reg, 1) 316 | 317 | max_v, level = torch.max(levelness_logits[:, 1:, :, :], dim=1) 318 | 319 | box_feature_map = torch.gather( 320 | upscaled_box_reg, dim=1, index=level.unsqueeze(1).expand([N, 4, h_map, w_map]).unsqueeze(1) 321 | ) 322 | 323 | box_feature_map = box_feature_map.view(N, 4, h_map, w_map).permute(0, 2, 3, 1) 324 | box_feature_map = box_feature_map.reshape(N, -1, 4) 325 | # generate all valid bboxes from feature map 326 | # shape (N, M, 4) 327 | levelness_locations_repeat = location.repeat(N, 1, 1) 328 | bounding_box_feature_map = torch.stack([ 329 | levelness_locations_repeat[:, :, 0] - box_feature_map[:, :, 0], 330 | levelness_locations_repeat[:, :, 1] - box_feature_map[:, :, 1], 331 | levelness_locations_repeat[:, :, 0] + box_feature_map[:, :, 2], 332 | levelness_locations_repeat[:, :, 1] + box_feature_map[:, :, 3], 333 | ], dim=2) 334 | return bounding_box_feature_map 335 | 336 | def mask_reconstruction(self, boxlists, box_feature_map, semantic_prob, box_feature_map_location, h_map, w_map): 337 | """Reconstruct instance mask from dense bounding box and semantic smoothing. 338 | 339 | Parameters 340 | ---------- 341 | boxlists: List of Boxlist 342 | Object detection result after NMS. 343 | 344 | box_feature_map: torch.Tensor 345 | Aggregated bounding box feature map. 346 | 347 | semantic_prob: torch.Tensor 348 | Prediction semantic probability. 349 | 350 | box_feature_map_location: torch.Tensor 351 | Corresponding pixel location of bounding box feature map. 352 | 353 | h_map: int 354 | Height of bounding box feature map. 355 | 356 | w_map: int 357 | Width of bounding box feature map. 
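Returns
-------
boxlists: list of BoxList
    Input boxlists with an added ``mask`` field containing the reconstructed (soft) instance masks at (h_map, w_map) resolution; they are thresholded against ``mask_thresh`` later in ``process``.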
358 | """ 359 | for i, (boxlist, per_image_bounding_box_feature_map, per_image_semantic_prob, 360 | box_feature_map_loc) in enumerate(zip(boxlists, box_feature_map, semantic_prob, box_feature_map_location)): 361 | 362 | # decode mask from bbox embedding 363 | if len(boxlist) > 0: 364 | # query_boxes is of shape (P, 4) 365 | # dense_detections is of shape (P', 4) 366 | # P' is larger than P 367 | query_boxes = boxlist.bbox 368 | propose_cls = boxlist.get_field("labels") 369 | # (P, 4) -> (P, 4, 1) -> (P, 4, P) -> (P, P', 4) 370 | propose_bbx = query_boxes.unsqueeze(2).repeat(1, 1, 371 | per_image_bounding_box_feature_map.shape[0]).permute(0, 2, 1) 372 | # (P',4) -> (4, P') -> (1, 4, P') -> (P, 4, P') -> (P, P', 4) 373 | voting_bbx = per_image_bounding_box_feature_map.permute(1, 0).unsqueeze(0).repeat(query_boxes.shape[0], 1, 374 | 1).permute(0, 2, 1) 375 | # implementation based on IOU for bbox_correlation_map 376 | # 0, 1, 2, 3 => left, top, right, bottom 377 | proposal_area = (propose_bbx[:, :, 2] - propose_bbx[:, :, 0]) * \ 378 | (propose_bbx[:, :, 3] - propose_bbx[:, :, 1]) 379 | voting_area = (voting_bbx[:, :, 2] - voting_bbx[:, :, 0]) * \ 380 | (voting_bbx[:, :, 3] - voting_bbx[:, :, 1]) 381 | w_intersect = torch.min(voting_bbx[:, :, 2], propose_bbx[:, :, 2]) - \ 382 | torch.max(voting_bbx[:, :, 0], propose_bbx[:, :, 0]) 383 | h_intersect = torch.min(voting_bbx[:, :, 3], propose_bbx[:, :, 3]) - \ 384 | torch.max(voting_bbx[:, :, 1], propose_bbx[:, :, 1]) 385 | w_intersect = w_intersect.clamp(min=0.0) 386 | h_intersect = h_intersect.clamp(min=0.0) 387 | w_general = torch.max(voting_bbx[:, :, 2], propose_bbx[:, :, 2]) - \ 388 | torch.min(voting_bbx[:, :, 0], propose_bbx[:, :, 0]) 389 | h_general = torch.max(voting_bbx[:, :, 3], propose_bbx[:, :, 3]) - \ 390 | torch.min(voting_bbx[:, :, 1], propose_bbx[:, :, 1]) 391 | # calculate IOU 392 | area_intersect = w_intersect * h_intersect 393 | area_union = proposal_area + voting_area - area_intersect 394 | torch.cuda.synchronize() 395 | 396 | area_general = w_general * h_general + 1e-7 397 | bbox_correlation_map = (area_intersect + 1.0) / (area_union + 1.0) - \ 398 | (area_general - area_union) / area_general 399 | 400 | per_image_cls_prob = per_image_semantic_prob[:, propose_cls - 1].permute(1, 0) 401 | # bbox_correlation_map is of size (P or per_pre_nms_top_n, P') 402 | bbox_correlation_map = bbox_correlation_map * per_image_cls_prob 403 | # query_boxes.shape[0] is the number of filtered boxes 404 | masks = bbox_correlation_map.view(query_boxes.shape[0], h_map, w_map) 405 | if len(masks.shape) < 3: 406 | masks = masks.unsqueeze(0) 407 | boxlist.add_field("mask", masks) 408 | else: 409 | dummy_masks = torch.zeros(len(boxlist), h_map, 410 | w_map).float().to(boxlist.bbox.device).to(boxlist.bbox.dtype) 411 | boxlist.add_field("mask", dummy_masks) 412 | return boxlists 413 | 414 | def select_over_all_levels(self, boxlists): 415 | """NMS of bounding box candidates. 416 | 417 | Parameters 418 | ---------- 419 | boxlists: list of Boxlist 420 | Pre-NMS bounding boxes. 421 | 422 | Returns 423 | ------- 424 | results: list of Boxlist 425 | Final detection result. 
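Note: NMS runs per class with ``self.nms_thresh``, followed by a class-agnostic NMS at IoU 0.97, and the surviving detections are finally capped at ``self.fpn_post_nms_top_n`` per image by score.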
426 | """ 427 | num_images = len(boxlists) 428 | results = [] 429 | for i in range(num_images): 430 | boxlist = boxlists[i] 431 | scores = boxlist.get_field("scores") 432 | labels = boxlist.get_field("labels") 433 | if self.is_training: 434 | indices = boxlist.get_field("indices") 435 | boxes = boxlist.bbox 436 | 437 | result = [] 438 | w, h = boxlist.size 439 | # skip the background 440 | if boxes.shape[0] < 1: 441 | results.append(boxlist) 442 | continue 443 | for j in range(1, self.num_classes): 444 | inds = (labels == j).nonzero().view(-1) 445 | if len(inds) > 0: 446 | scores_j = scores[inds] 447 | boxes_j = boxes[inds, :].view(-1, 4) 448 | 449 | boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") 450 | boxlist_for_class.add_field("scores", scores_j) 451 | 452 | if self.is_training: 453 | indices_j = indices[inds] 454 | boxlist_for_class.add_field("indices", indices_j) 455 | 456 | boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms_thresh, score_field="scores") 457 | num_labels = len(boxlist_for_class) 458 | boxlist_for_class.add_field( 459 | "labels", torch.full((num_labels, ), j, dtype=torch.int64, device=scores.device) 460 | ) 461 | result.append(boxlist_for_class) 462 | result = cat_boxlist(result) 463 | 464 | # global NMS 465 | result = boxlist_nms(result, 0.97, score_field="scores") 466 | 467 | number_of_detections = len(result) 468 | 469 | # Limit to max_per_image detections **over all classes** 470 | if number_of_detections > self.fpn_post_nms_top_n > 0: 471 | cls_scores = result.get_field("scores") 472 | image_thresh, _ = torch.kthvalue(cls_scores.cpu(), number_of_detections - self.fpn_post_nms_top_n + 1) 473 | keep = cls_scores >= image_thresh.item() 474 | keep = torch.nonzero(keep).squeeze(1) 475 | result = result[keep] 476 | results.append(result) 477 | return results 478 | -------------------------------------------------------------------------------- /realtime_panoptic/models/rt_pano_net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | 3 | # We adapted some FCOS related functions from official repository. 4 | # https://github.com/tianzhi0549/FCOS 5 | 6 | import math 7 | from collections import OrderedDict 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import apex 12 | from realtime_panoptic.layers.scale import Scale 13 | from realtime_panoptic.utils.bounding_box import BoxList 14 | from realtime_panoptic.models.backbones import ResNetWithModifiedFPN 15 | from realtime_panoptic.models.panoptic_from_dense_box import PanopticFromDenseBox 16 | class RTPanoNet(torch.nn.Module): 17 | """Real-Time Panoptic Network 18 | This module takes the input from a FPN backbone and conducts feature extraction 19 | through a panoptic head, which can be then fed into post processing for final panoptic 20 | results including semantic segmentation and instance segmentation. 21 | NOTE: Currently only the inference functionality is supported. 22 | 23 | Parameters 24 | ---------- 25 | backbone: str 26 | backbone type. 27 | 28 | num_classes: int 29 | Number of total classes, including 'things' and 'stuff'. 30 | 31 | things_num_classes: int 32 | number of thing classes 33 | 34 | pre_nms_thresh: float 35 | Acceptance class probability threshold for bounding box candidates before NMS. 36 | 37 | pre_nms_top_n: int 38 | Maximum number of accepted bounding box candidates before NMS. 39 | 40 | nms_thresh: float 41 | NMS threshold. 
42 | 43 | fpn_post_nms_top_n: int 44 | Maximum number of detected object per image. 45 | 46 | instance_id_range: list of int 47 | [min_id, max_id] defines the range of id in 1:num_classes that corresponding to thing classes. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | backbone, 53 | num_classes, 54 | things_num_classes, 55 | pre_nms_thresh, 56 | pre_nms_top_n, 57 | nms_thresh, 58 | fpn_post_nms_top_n, 59 | instance_id_range 60 | ): 61 | super(RTPanoNet, self).__init__() 62 | # TODO: adapt more backbone. 63 | if backbone == 'R-50-FPN-RETINANET': 64 | self.backbone = ResNetWithModifiedFPN('resnet50') 65 | backbone_out_channels = 256 66 | fpn_strides = [8, 16, 32, 64, 128] 67 | num_fpn_levels = 5 68 | else: 69 | raise NotImplementedError("Backbone type: {} is not supported yet.".format(backbone)) 70 | # Global panoptic head that extracts features from each FPN output feature map 71 | self.panoptic_head = PanopticHead( 72 | num_classes, 73 | things_num_classes, 74 | num_fpn_levels, 75 | fpn_strides, 76 | backbone_out_channels 77 | ) 78 | 79 | # Parameters 80 | self.fpn_strides = fpn_strides 81 | 82 | # Use dense bounding boxes to reconstruct panoptic segmentation results. 83 | self.panoptic_from_dense_bounding_box = PanopticFromDenseBox( 84 | pre_nms_thresh=pre_nms_thresh, 85 | pre_nms_top_n=pre_nms_top_n, 86 | nms_thresh=nms_thresh, 87 | fpn_post_nms_top_n=fpn_post_nms_top_n, 88 | min_size=0, 89 | num_classes=num_classes, 90 | mask_thresh=0.4, 91 | instance_id_range=instance_id_range, 92 | is_training=False) 93 | 94 | def forward(self, images, detection_targets=None, segmentation_targets=None): 95 | """ Forward function. 96 | 97 | Parameters 98 | ---------- 99 | images: torchvision.models.detection.ImageList 100 | Images for which we want to compute the predictions 101 | 102 | detection_targets: list of BoxList 103 | Ground-truth boxes present in the image 104 | 105 | segmentation_targets: List of torch.Tensor 106 | semantic segmentation target for each image in the batch. 107 | 108 | Returns 109 | ------- 110 | panoptic_result: Dict 111 | 'instance_segmentation_result': list of BoxList 112 | The predicted boxes (including instance masks), one BoxList per image. 113 | 'semantic_segmentation_result': torch.Tensor 114 | semantic logits interpolated to input data size. 115 | NOTE: this might not be the original input image size due to paddings. 116 | losses: dict of torch.ScalarTensor 117 | the losses for the model during training. During testing, it is an empty dict. 118 | """ 119 | features = self.backbone(torch.stack(images.tensors)) 120 | 121 | locations = self.compute_locations(list(features.values())) 122 | 123 | semantic_logits, box_cls, box_regression, centerness, levelness_logits = self.panoptic_head(list(features.values())) 124 | 125 | # Get full size semantic logits. 126 | downsampled_level = images.tensors[0].shape[-1] // semantic_logits.shape[-1] 127 | interpolated_semantic_logits_padded = F.interpolate(semantic_logits, scale_factor=downsampled_level, mode='bilinear') 128 | interpolated_semantic_logits = interpolated_semantic_logits_padded[:,:,:images.tensors[0].shape[-2], :images.tensors[0].shape[-1]] 129 | # Calculate levelness locations. 
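# Levelness (and the dense box feature map derived from it) is predicted at stride
# fpn_strides[0] // 2 = 4, i.e. a quarter of the input resolution, so its pixel
# locations are computed with that stride and appended as the last entry of `locations`.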
130 | h, w = levelness_logits.size()[-2:] 131 | levelness_location = self.compute_locations_per_level(h, w, self.fpn_strides[0] // 2, levelness_logits.device) 132 | locations.append(levelness_location) 133 | 134 | # Reconstruct mask from dense bounding box and semantic predictions 135 | panoptic_result = OrderedDict() 136 | boxes = self.panoptic_from_dense_bounding_box.process( 137 | locations, box_cls, box_regression, centerness, levelness_logits, semantic_logits, images.image_sizes 138 | ) 139 | panoptic_result["instance_segmentation_result"] = boxes 140 | panoptic_result["semantic_segmentation_result"] = interpolated_semantic_logits 141 | return panoptic_result, {} 142 | 143 | def compute_locations(self, features): 144 | """Compute corresponding pixel location for feature maps. 145 | 146 | Parameters 147 | ---------- 148 | features: list of torch.Tensor 149 | List of feature maps. 150 | 151 | Returns 152 | ------- 153 | locations: list of torch.Tensor 154 | List of pixel location corresponding to the list of features. 155 | """ 156 | locations = [] 157 | for level, feature in enumerate(features): 158 | h, w = feature.size()[-2:] 159 | locations_per_level = self.compute_locations_per_level(h, w, self.fpn_strides[level], feature.device) 160 | locations.append(locations_per_level) 161 | return locations 162 | 163 | def compute_locations_per_level(self, h, w, stride, device): 164 | """Compute corresponding pixel location for a feature map in pyramid space with certain stride. 165 | 166 | Parameters 167 | ---------- 168 | h: int 169 | height of current feature map. 170 | 171 | w: int 172 | width of current feature map. 173 | 174 | stride: int 175 | stride level of current feature map with respect to original input. 176 | 177 | device: torch.device 178 | device to create return tensor. 179 | 180 | Returns 181 | ------- 182 | locations: torch.Tensor 183 | pixel location map. 184 | """ 185 | shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device) 186 | shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device) 187 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 188 | shift_x = shift_x.reshape(-1) 189 | shift_y = shift_y.reshape(-1) 190 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 191 | return locations 192 | 193 | 194 | class PanopticHead(torch.nn.Module): 195 | """Network module of Panoptic Head extracting features from FPN feature maps. 196 | 197 | Parameters 198 | ---------- 199 | num_classes: int 200 | Number of total classes, including 'things' and 'stuff'. 201 | 202 | things_num_classes: int 203 | number of thing classes. 204 | 205 | num_fpn_levels: int 206 | Number of FPN levels. 207 | 208 | fpn_strides: list 209 | FPN strides at each FPN scale. 210 | 211 | in_channels: int 212 | Number of channels of the input features (output of FPN) 213 | 214 | norm_reg_targets: bool 215 | If true, train on normalized target. 216 | 217 | centerness_on_reg: bool 218 | If true, regress centerness on box tower of FCOS. 219 | 220 | fcos_num_convs: int 221 | number of convolution modules used in FCOS towers. 222 | 223 | fcos_norm: str 224 | Normalization layer type used in FCOS modules. 225 | 226 | prior_prob: float 227 | Initial probability for focal loss. See `https://arxiv.org/pdf/1708.02002.pdf` for more details. 
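The forward pass returns a tuple of (semantic_logits, box_cls, box_regression, centerness, levelness_logits); see ``forward`` below.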
228 | """ 229 | def __init__( 230 | self, 231 | num_classes, 232 | things_num_classes, 233 | num_fpn_levels, 234 | fpn_strides, 235 | in_channels, 236 | norm_reg_targets=False, 237 | centerness_on_reg=True, 238 | fcos_num_convs=4, 239 | fcos_norm='GN', 240 | prior_prob=0.01, 241 | ): 242 | super(PanopticHead, self).__init__() 243 | self.fpn_strides = fpn_strides 244 | self.norm_reg_targets = norm_reg_targets 245 | self.centerness_on_reg = centerness_on_reg 246 | 247 | cls_tower = [] 248 | bbox_tower = [] 249 | 250 | mid_channels = in_channels // 2 251 | 252 | for i in range(fcos_num_convs): 253 | # Class branch 254 | if i == 0: 255 | cls_tower.append(nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)) 256 | else: 257 | cls_tower.append(nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1, padding=1)) 258 | if fcos_norm == "GN": 259 | cls_tower.append(nn.GroupNorm(mid_channels // 8, mid_channels)) 260 | elif fcos_norm == "BN": 261 | cls_tower.append(nn.BatchNorm2d(mid_channels)) 262 | elif fcos_norm == "SBN": 263 | cls_tower.append(apex.parallel.SyncBatchNorm(mid_channels)) 264 | cls_tower.append(nn.ReLU()) 265 | 266 | # Box regression branch 267 | bbox_tower.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) 268 | 269 | if fcos_norm == "GN": 270 | bbox_tower.append(nn.GroupNorm(in_channels // 8, in_channels)) 271 | elif fcos_norm == "BN": 272 | bbox_tower.append(nn.BatchNorm2d(in_channels)) 273 | elif fcos_norm == "SBN": 274 | bbox_tower.append(apex.parallel.SyncBatchNorm(in_channels)) 275 | bbox_tower.append(nn.ReLU()) 276 | 277 | self.add_module('cls_tower', nn.Sequential(*cls_tower)) 278 | self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) 279 | 280 | self.cls_logits = nn.Conv2d(mid_channels * 5, num_classes, kernel_size=3, stride=1, padding=1) 281 | self.box_cls_logits = nn.Conv2d(mid_channels, things_num_classes, kernel_size=3, stride=1, padding=1) 282 | self.bbox_pred = nn.Conv2d(in_channels, 4, kernel_size=3, stride=1, padding=1) 283 | self.centerness = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1) 284 | self.levelness = nn.Conv2d(in_channels * 5, num_fpn_levels + 1, kernel_size=3, stride=1, padding=1) 285 | 286 | # initialization 287 | to_initialize = [ 288 | self.bbox_tower, self.cls_logits, self.cls_tower, self.bbox_pred, self.centerness, self.levelness, 289 | self.box_cls_logits 290 | ] 291 | 292 | for modules in to_initialize: 293 | for l in modules.modules(): 294 | if isinstance(l, nn.Conv2d): 295 | torch.nn.init.normal_(l.weight, std=0.01) 296 | torch.nn.init.constant_(l.bias, 0) 297 | 298 | # initialize the bias for focal loss 299 | prior_prob = prior_prob 300 | bias_value = -math.log((1 - prior_prob) / prior_prob) 301 | torch.nn.init.constant_(self.cls_logits.bias, bias_value) 302 | 303 | self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) 304 | 305 | def forward(self, x): 306 | box_cls = [] 307 | logits = [] 308 | bbox_reg = [] 309 | centerness = [] 310 | levelness = [] 311 | 312 | downsampled_shape = x[0].shape[2:] 313 | box_feature_map_downsampled_shape = torch.Size((downsampled_shape[0] * 2, downsampled_shape[1] * 2)) 314 | 315 | for l, feature in enumerate(x): 316 | # bbox 317 | box_tower = self.bbox_tower(feature) 318 | 319 | # class 320 | cls_tower = self.cls_tower(feature) 321 | box_cls.append(self.box_cls_logits(cls_tower)) 322 | logits.append(F.interpolate(cls_tower, size=downsampled_shape, mode='bilinear')) 323 | 324 | # centerness 325 | if 
self.centerness_on_reg: 326 | centerness.append(self.centerness(box_tower)) 327 | else: 328 | centerness.append(self.centerness(cls_tower)) 329 | 330 | # bbox regression 331 | bbox_pred = self.scales[l](self.bbox_pred(box_tower)) 332 | if self.norm_reg_targets: 333 | bbox_pred = F.relu(bbox_pred) 334 | if self.training: 335 | bbox_reg.append(bbox_pred) 336 | else: 337 | bbox_reg.append(bbox_pred * self.fpn_strides[l]) 338 | else: 339 | bbox_reg.append(torch.exp(bbox_pred.clamp(max=math.log(10000)))) 340 | 341 | # levelness prediction 342 | levelness.append(F.interpolate(box_tower, size=box_feature_map_downsampled_shape, mode='bilinear')) 343 | 344 | # predict levelness 345 | levelness = torch.cat(levelness, dim=1) 346 | # levelness = torch.stack(levelness, dim=0).sum(dim=0) 347 | levelness_logits = self.levelness(levelness) 348 | 349 | # level attention for semantic segmentation 350 | logits = torch.cat(logits, 1) 351 | semantic_logits = self.cls_logits(logits) 352 | 353 | return semantic_logits, box_cls, bbox_reg, centerness, levelness_logits 354 | 355 | 356 | -------------------------------------------------------------------------------- /realtime_panoptic/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | -------------------------------------------------------------------------------- /realtime_panoptic/utils/bounding_box.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Copyright 2020 Toyota Research Institute. All rights reserved. 3 | 4 | # Adapted from maskrcnn-benchmark 5 | # https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/structures/bounding_box.py 6 | 7 | import math 8 | 9 | import torch 10 | 11 | # transpose 12 | FLIP_LEFT_RIGHT = 0 13 | FLIP_TOP_BOTTOM = 1 14 | ROTATE_90 = 2 15 | 16 | 17 | class BoxList: 18 | """This class represents a set of bounding boxes. 19 | The bounding boxes are represented as a Nx4 Tensor. 20 | In order to uniquely determine the bounding boxes with respect 21 | to an image, we also store the corresponding image dimensions. 22 | They can contain extra information that is specific to each bounding box, such as 23 | labels. 24 | """ 25 | 26 | def __init__(self, bbox, image_size, mode="xyxy"): 27 | """Initial function. 28 | 29 | Parameters 30 | ---------- 31 | bbox: tensor 32 | Nx4 tensor following bounding box parameterization defined by "mode". 33 | 34 | image_size: list 35 | [W,H] Image size. 36 | 37 | mode: str 38 | Bounding box parameterization. 'xyxy' or 'xyhw'. 39 | """ 40 | device = bbox.device if isinstance( 41 | bbox, torch.Tensor) else torch.device("cpu") 42 | bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device) 43 | if bbox.ndimension() != 2: 44 | raise ValueError("bbox should have 2 dimensions, got {}".format( 45 | bbox.ndimension()), bbox) 46 | if bbox.size(-1) != 4: 47 | raise ValueError("last dimension of bbox should have a " 48 | "size of 4, got {}".format(bbox.size(-1))) 49 | if mode not in ("xyxy", "xywh"): 50 | raise ValueError("mode should be 'xyxy' or 'xywh'") 51 | 52 | self.bbox = bbox 53 | self.size = image_size # (image_width, image_height) 54 | self.mode = mode 55 | self.extra_fields = {} 56 | 57 | def add_field(self, field, field_data): 58 | """Add a field to boxlist. 
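As a quick sanity check on the head defined above, the sketch below runs `PanopticHead` on a dummy five-level FPN pyramid. The import path, channel count (256) and strides are assumptions meant to mirror a typical configuration for this model, not values read from `configs/demo_config.yaml`:

```
# Shape check for PanopticHead on random FPN features (illustrative settings).
import torch
from realtime_panoptic.models.rt_pano_net import PanopticHead  # assumed import path

head = PanopticHead(
    num_classes=19,            # Cityscapes 'stuff' + 'things'
    things_num_classes=8,
    num_fpn_levels=5,
    fpn_strides=[8, 16, 32, 64, 128],
    in_channels=256,
)
# Dummy FPN pyramid for a 512x1024 input.
feats = [torch.randn(1, 256, 512 // s, 1024 // s) for s in [8, 16, 32, 64, 128]]
semantic_logits, box_cls, bbox_reg, centerness, levelness_logits = head(feats)
print(semantic_logits.shape)   # torch.Size([1, 19, 64, 128])  -> stride-8 resolution
print(levelness_logits.shape)  # torch.Size([1, 6, 128, 256])  -> stride-4 resolution
```

The semantic logits come out at the stride-8 resolution of the first FPN level, while the levelness map is predicted at twice that resolution, matching the `fpn_strides[0] // 2` location grid computed earlier.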
59 | """ 60 | self.extra_fields[field] = field_data 61 | 62 | def get_field(self, field): 63 | """Get a field from boxlist. 64 | """ 65 | return self.extra_fields[field] 66 | 67 | def has_field(self, field): 68 | """Check if certain field exist in boxlist 69 | """ 70 | return field in self.extra_fields 71 | 72 | def fields(self): 73 | """Get all available field names. 74 | """ 75 | return list(self.extra_fields.keys()) 76 | 77 | def _copy_extra_fields(self, bbox): 78 | """Copy extra fields from given boxlist to current boxlist. 79 | """ 80 | for k, v in bbox.extra_fields.items(): 81 | self.extra_fields[k] = v 82 | 83 | def convert(self, mode): 84 | """Convert bounding box parameterization mode. 85 | """ 86 | if mode not in ("xyxy", "xywh"): 87 | raise ValueError("mode should be 'xyxy' or 'xywh'") 88 | if mode == self.mode: 89 | return self 90 | # we only have two modes, so don't need to check 91 | # self.mode 92 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 93 | if mode == "xyxy": 94 | bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1) 95 | bbox = BoxList(bbox, self.size, mode=mode) 96 | else: 97 | TO_REMOVE = 1 98 | bbox = torch.cat( 99 | (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), 100 | dim=-1) 101 | bbox = BoxList(bbox, self.size, mode=mode) 102 | bbox._copy_extra_fields(self) 103 | return bbox 104 | 105 | def _split_into_xyxy(self): 106 | """split lists of bounding box corners. 107 | """ 108 | if self.mode == "xyxy": 109 | xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1) 110 | return xmin, ymin, xmax, ymax 111 | elif self.mode == "xywh": 112 | TO_REMOVE = 1 113 | xmin, ymin, w, h = self.bbox.split(1, dim=-1) 114 | return ( 115 | xmin, 116 | ymin, 117 | xmin + (w - TO_REMOVE).clamp(min=0), 118 | ymin + (h - TO_REMOVE).clamp(min=0), 119 | ) 120 | else: 121 | raise RuntimeError("Should not be here") 122 | 123 | def resize(self, size, *args, **kwargs): 124 | """Returns a resized copy of this bounding box. 125 | 126 | Parameters 127 | ---------- 128 | size: list or tuple 129 | The requested image size in pixels, as a 2-tuple: 130 | (width, height). 
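A short usage sketch for the `BoxList` container defined above, showing field handling and `convert`; the numbers are illustrative:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList

boxes = BoxList(torch.tensor([[10., 20., 50., 80.]]), (640, 480), mode="xyxy")
boxes.add_field("labels", torch.tensor([3]))

xywh = boxes.convert("xywh")
print(xywh.bbox)                 # tensor([[10., 20., 41., 61.]])  width/height use the "+1" pixel convention
print(xywh.get_field("labels"))  # tensor([3]) -- extra fields are carried over by convert()
```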
131 | """ 132 | 133 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) 134 | if ratios[0] == ratios[1]: 135 | ratio = ratios[0] 136 | scaled_box = self.bbox * ratio 137 | bbox = BoxList(scaled_box, size, mode=self.mode) 138 | # bbox._copy_extra_fields(self) 139 | for k, v in self.extra_fields.items(): 140 | if not isinstance(v, torch.Tensor): 141 | v = v.resize(size, *args, **kwargs) 142 | bbox.add_field(k, v) 143 | return bbox 144 | 145 | ratio_width, ratio_height = ratios 146 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 147 | scaled_xmin = xmin * ratio_width 148 | scaled_xmax = xmax * ratio_width 149 | scaled_ymin = ymin * ratio_height 150 | scaled_ymax = ymax * ratio_height 151 | scaled_box = torch.cat( 152 | (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1) 153 | bbox = BoxList(scaled_box, size, mode="xyxy") 154 | # bbox._copy_extra_fields(self) 155 | for k, v in self.extra_fields.items(): 156 | if not isinstance(v, torch.Tensor): 157 | v = v.resize(size, *args, **kwargs) 158 | bbox.add_field(k, v) 159 | 160 | return bbox.convert(self.mode) 161 | 162 | def transpose(self, method): 163 | """Transpose bounding box (flip or rotate in 90 degree steps) 164 | 165 | Parameters 166 | ---------- 167 | method: str 168 | One of:py:attr:`PIL.Image.FLIP_LEFT_RIGHT`, 169 | :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`,:py:attr:`PIL.Image.ROTATE_90`, 170 | :py:attr:`PIL.Image.ROTATE_180`,:py:attr:`PIL.Image.ROTATE_270`, 171 | :py:attr:`PIL.Image.TRANSPOSE` or:py:attr:`PIL.Image.TRANSVERSE`. 172 | """ 173 | if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM, ROTATE_90): 174 | raise NotImplementedError( 175 | "Only FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM and ROTATE_90 implemented" 176 | ) 177 | 178 | image_width, image_height = self.size 179 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 180 | if method == FLIP_LEFT_RIGHT: 181 | TO_REMOVE = 1 182 | transposed_xmin = image_width - xmax - TO_REMOVE 183 | transposed_xmax = image_width - xmin - TO_REMOVE 184 | transposed_ymin = ymin 185 | transposed_ymax = ymax 186 | elif method == FLIP_TOP_BOTTOM: 187 | transposed_xmin = xmin 188 | transposed_xmax = xmax 189 | transposed_ymin = image_height - ymax 190 | transposed_ymax = image_height - ymin 191 | elif method == ROTATE_90: 192 | transposed_xmin = ymin * image_width / image_height 193 | transposed_xmax = ymax * image_width / image_height 194 | transposed_ymin = (image_width - xmax) * image_height / image_width 195 | transposed_ymax = (image_width - xmin) * image_height / image_width 196 | 197 | transposed_boxes = torch.cat((transposed_xmin, transposed_ymin, 198 | transposed_xmax, transposed_ymax), 199 | dim=-1) 200 | bbox = BoxList(transposed_boxes, self.size, mode="xyxy") 201 | # bbox._copy_extra_fields(self) 202 | for k, v in self.extra_fields.items(): 203 | if not isinstance(v, torch.Tensor): 204 | v = v.transpose(method) 205 | bbox.add_field(k, v) 206 | return bbox.convert(self.mode) 207 | 208 | def translate(self, x_offset, y_offset): 209 | """Translate bounding box. 
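`resize` takes a fast path when both axes share one scale factor and otherwise rescales x and y independently; a small hand-checked sketch with illustrative values:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList

boxes = BoxList(torch.tensor([[10., 20., 50., 80.]]), (640, 480), mode="xyxy")
print(boxes.resize((320, 240)).bbox)   # tensor([[ 5., 10., 25., 40.]])   single shared ratio
print(boxes.resize((1280, 480)).bbox)  # tensor([[ 20.,  20., 100.,  80.]]) per-axis ratios
```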
210 | 211 | Parameters 212 | ---------- 213 | x_offseflt: float 214 | x offset 215 | y_offset: float 216 | y offset 217 | """ 218 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 219 | 220 | translated_xmin = xmin + x_offset 221 | translated_xmax = xmax + x_offset 222 | translated_ymin = ymin + y_offset 223 | translated_ymax = ymax + y_offset 224 | 225 | translated_boxes = torch.cat((translated_xmin, translated_ymin, 226 | translated_xmax, translated_ymax), 227 | dim=-1) 228 | bbox = BoxList(translated_boxes, self.size, mode="xyxy") 229 | for k, v in self.extra_fields.items(): 230 | if not isinstance(v, torch.Tensor): 231 | v = v.translate(x_offset, y_offset) 232 | bbox.add_field(k, v) 233 | return bbox.convert(self.mode) 234 | 235 | def crop(self, box): 236 | """Crop a rectangular region from this bounding box. 237 | 238 | Parameters 239 | ---------- 240 | box: tuple 241 | The box is a 4-tuple defining the left, upper, right, and lower pixel 242 | coordinate. 243 | """ 244 | xmin, ymin, xmax, ymax = self._split_into_xyxy() 245 | w, h = box[2] - box[0], box[3] - box[1] 246 | cropped_xmin = (xmin - box[0]).clamp(min=0, max=w) 247 | cropped_ymin = (ymin - box[1]).clamp(min=0, max=h) 248 | cropped_xmax = (xmax - box[0]).clamp(min=0, max=w) 249 | cropped_ymax = (ymax - box[1]).clamp(min=0, max=h) 250 | 251 | # TODO should I filter empty boxes here? 252 | if False: 253 | is_empty = (cropped_xmin == cropped_xmax) | ( 254 | cropped_ymin == cropped_ymax) 255 | 256 | cropped_box = torch.cat( 257 | (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1) 258 | bbox = BoxList(cropped_box, (w, h), mode="xyxy") 259 | # bbox._copy_extra_fields(self) 260 | for k, v in self.extra_fields.items(): 261 | if not isinstance(v, torch.Tensor): 262 | v = v.crop(box) 263 | bbox.add_field(k, v) 264 | return bbox.convert(self.mode) 265 | 266 | def augmentation_crop(self, top, left, crop_height, crop_width): 267 | """Random cropping of the bounding box (bbox). 268 | This function is created for label box to be crop at training time. 269 | 270 | Parameters: 271 | ----------- 272 | top: int 273 | Top pixel position of crop area 274 | 275 | left: int 276 | left pixel position of crop area 277 | 278 | crop_height: int 279 | Height of crop area 280 | 281 | crop_width: int 282 | Width of crop area 283 | 284 | Returns: 285 | -------- 286 | bbox_cropped: BoxList 287 | A BoxList object with instances after cropping. If no valid instance is left after 288 | cropping, return None. 289 | 290 | 291 | """ 292 | # SegmentationMasks object 293 | masks = self.extra_fields["masks"] 294 | 295 | # Conduct mask level cropping and return only the valid ones left. 296 | masks_cropped, keep_ids = masks.augmentation_crop( 297 | top, left, crop_height, crop_width) 298 | 299 | # the return cropped mask should be in "poly" mode 300 | if not keep_ids: 301 | return None 302 | assert masks_cropped.mode == "poly" 303 | bbox_cropped = [] 304 | labels = self.extra_fields["labels"] 305 | labels_cropped = [labels[idx] for idx in keep_ids] 306 | labels_cropped = torch.as_tensor(labels_cropped, dtype=torch.long) 307 | 308 | crop_box_xyxy = [ 309 | float(left), 310 | float(top), 311 | float(left + crop_width), 312 | float(top + crop_height) 313 | ] 314 | # Crop bounding box. 315 | # Note: this function will not change "masks" 316 | self.extra_fields.pop("masks", None) 317 | new_bbox = self.crop(crop_box_xyxy).convert("xyxy") 318 | 319 | # Further clip the boxes according to the clipped masks. 
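`crop` above shifts boxes into the crop window's coordinate frame and clamps them to its extent; a hand-checked sketch with illustrative values:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList

boxes = BoxList(torch.tensor([[10., 20., 50., 80.]]), (640, 480), mode="xyxy")
cropped = boxes.crop((30, 30, 130, 130))   # (left, upper, right, lower)
print(cropped.bbox)   # tensor([[ 0.,  0., 20., 50.]])  coordinates relative to the crop
print(cropped.size)   # (100, 100)                      new image size is the crop size
```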
320 | for mask_id, box_id in enumerate(keep_ids): 321 | x1, y1, x2, y2 = new_bbox.bbox[box_id].numpy() 322 | 323 | # only resize the box on the edge: 324 | if x1 > 0 and y1 > 0 and x2 < crop_width - 1 and y2 < crop_height - 1: 325 | bbox_cropped.append([x1, y1, x2, y2]) 326 | else: 327 | # get PolygonInstance for current instance 328 | current_polygon_instance = masks_cropped.instances.polygons[ 329 | mask_id] 330 | x_ids = [] 331 | y_ids = [] 332 | for poly in current_polygon_instance.polygons: 333 | p = poly.clone() 334 | x_ids.extend(p[0::2]) 335 | y_ids.extend(p[1::2]) 336 | bbox_cropped.append( 337 | [min(x_ids), 338 | min(y_ids), 339 | max(x_ids), 340 | max(y_ids)]) 341 | bbox_cropped = BoxList( 342 | bbox_cropped, (crop_width, crop_height), mode="xyxy") 343 | bbox_cropped = bbox_cropped.convert(self.mode) 344 | bbox_cropped.add_field("masks", masks_cropped) 345 | bbox_cropped.add_field("labels", labels_cropped) 346 | return bbox_cropped 347 | 348 | def to(self, device): 349 | """Move object to torch device. 350 | """ 351 | bbox = BoxList(self.bbox.to(device), self.size, self.mode) 352 | for k, v in self.extra_fields.items(): 353 | if hasattr(v, "to"): 354 | v = v.to(device) 355 | bbox.add_field(k, v) 356 | return bbox 357 | 358 | def __getitem__(self, item): 359 | """Get a sub-list of Boxlist as a new Boxlist 360 | """ 361 | item_bbox = self.bbox[item] 362 | if len(item_bbox.shape) < 2: 363 | item_bbox.unsqueeze(0) 364 | bbox = BoxList(item_bbox, self.size, self.mode) 365 | for k, v in self.extra_fields.items(): 366 | bbox.add_field(k, v[item]) 367 | return bbox 368 | 369 | def __len__(self): 370 | return self.bbox.shape[0] 371 | 372 | def clip_to_image(self, remove_empty=True): 373 | """Clip bounding box coordinates according to the image range. 374 | """ 375 | TO_REMOVE = 1 376 | self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE) 377 | self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE) 378 | self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE) 379 | self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE) 380 | if remove_empty: 381 | box = self.bbox 382 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) 383 | return self[keep] 384 | return self 385 | 386 | def area(self, idx=None): 387 | """Get bounding box area. 388 | """ 389 | box = self.bbox if idx is None else self.bbox[idx].unsqueeze(0) 390 | if self.mode == "xyxy": 391 | TO_REMOVE = 1 392 | area = (box[:, 2] - box[:, 0] + TO_REMOVE) * \ 393 | (box[:, 3] - box[:, 1] + TO_REMOVE) 394 | elif self.mode == "xywh": 395 | area = box[:, 2] * box[:, 3] 396 | else: 397 | raise RuntimeError("Should not be here") 398 | 399 | return area 400 | 401 | def copy_with_fields(self, fields, skip_missing=False): 402 | """Provide deep copy of Boxlist with requested fields. 
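`clip_to_image` clamps coordinates to the image bounds and, with `remove_empty=True`, drops degenerate boxes while keeping the extra fields aligned through `__getitem__`; a small sketch with illustrative values:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList

boxes = BoxList(torch.tensor([[-5., 10., 700., 100.],
                              [200., 50., 200., 50.]]), (640, 480), mode="xyxy")
boxes.add_field("scores", torch.tensor([0.9, 0.2]))

kept = boxes.clip_to_image(remove_empty=True)
print(len(kept))                 # 1 -- the zero-area box is removed
print(kept.bbox)                 # tensor([[  0.,  10., 639., 100.]])
print(kept.get_field("scores"))  # tensor([0.9000]) -- fields stay in sync with the kept boxes
```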
403 | """ 404 | bbox = BoxList(self.bbox, self.size, self.mode) 405 | if not isinstance(fields, (list, tuple)): 406 | fields = [fields] 407 | for field in fields: 408 | if self.has_field(field): 409 | bbox.add_field(field, self.get_field(field)) 410 | elif not skip_missing: 411 | raise KeyError("Field '{}' not found in {}".format( 412 | field, self)) 413 | return bbox 414 | 415 | def __repr__(self): 416 | s = self.__class__.__name__ + "(" 417 | s += "num_boxes={}, ".format(len(self)) 418 | s += "image_width={}, ".format(self.size[0]) 419 | s += "image_height={}, ".format(self.size[1]) 420 | s += "mode={})".format(self.mode) 421 | return s 422 | 423 | 424 | if __name__ == "__main__": 425 | bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10)) 426 | s_bbox = bbox.resize((5, 5)) 427 | print(s_bbox) 428 | print(s_bbox.bbox) 429 | 430 | t_bbox = bbox.transpose(0) 431 | print(t_bbox) 432 | print(t_bbox.bbox) 433 | -------------------------------------------------------------------------------- /realtime_panoptic/utils/boxlist_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Copyright 2020 Toyota Research Institute. All rights reserved. 3 | 4 | # Adapted from maskrcnn-benchmark 5 | # https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/structures/boxlist_ops.py 6 | import torch 7 | from torchvision.ops.boxes import nms as _box_nms 8 | 9 | from .bounding_box import BoxList 10 | 11 | 12 | def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"): 13 | """Performs non-maximum suppression on a boxlist. 14 | The ranking scores are specified in a boxlist field via score_field. 15 | 16 | Parameters 17 | ---------- 18 | boxlist : BoxList 19 | Original boxlist 20 | 21 | nms_thresh : float 22 | NMS threshold 23 | 24 | max_proposals : int 25 | If > 0, then only the top max_proposals are kept after non-maximum suppression 26 | 27 | score_field : str 28 | Boxlist field to use during NMS score ranking. Field value needs to be numeric. 29 | """ 30 | if nms_thresh <= 0: 31 | return boxlist 32 | mode = boxlist.mode 33 | boxlist = boxlist.convert("xyxy") 34 | boxes = boxlist.bbox 35 | score = boxlist.get_field(score_field) 36 | keep = _box_nms(boxes, score, nms_thresh) 37 | if max_proposals > 0: 38 | keep = keep[:max_proposals] 39 | boxlist = boxlist[keep] 40 | return boxlist.convert(mode) 41 | 42 | 43 | def remove_small_boxes(boxlist, min_size): 44 | """Only keep boxes with both sides >= min_size 45 | 46 | Parameters 47 | ---------- 48 | boxlist : Boxlist 49 | Original boxlist 50 | 51 | min_size : int 52 | Max edge dimension of boxes to be kept. 53 | """ 54 | # TODO maybe add an API for querying the ws / hs 55 | xywh_boxes = boxlist.convert("xywh").bbox 56 | _, _, ws, hs = xywh_boxes.unbind(dim=1) 57 | keep = ((ws >= min_size) & (hs >= min_size)).nonzero().squeeze(1) 58 | return boxlist[keep] 59 | 60 | 61 | # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py 62 | # with slight modifications 63 | def boxlist_iou(boxlist1, boxlist2, optimize_memory=False): 64 | """Compute the intersection over union of two set of boxes. 65 | The box order must be (xmin, ymin, xmax, ymax). 66 | 67 | Parameters 68 | ---------- 69 | box1: BoxList 70 | Bounding boxes, sized [N,4]. 71 | box2: BoxList 72 | Bounding boxes, sized [M,4]. 
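`boxlist_nms` above ranks boxes by the `scores` field and suppresses overlaps via torchvision's NMS; a hand-checked sketch with illustrative values:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList
from realtime_panoptic.utils.boxlist_ops import boxlist_nms

boxes = BoxList(torch.tensor([[0., 0., 100., 100.],
                              [5., 5., 105., 105.],
                              [200., 200., 260., 260.]]), (640, 480), mode="xyxy")
boxes.add_field("scores", torch.tensor([0.9, 0.8, 0.7]))

kept = boxlist_nms(boxes, nms_thresh=0.5)
print(kept.bbox)  # boxes 1 and 3 survive; box 2 (IoU ~0.82 with box 1) is suppressed
```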
73 | 74 | 75 | Returns 76 | ------- 77 | iou : tensor 78 | IoU of input boxes in matrix form, sized [N,M]. 79 | 80 | Reference: 81 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py 82 | """ 83 | if boxlist1.size != boxlist2.size: 84 | raise RuntimeError("boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2)) 85 | 86 | N = len(boxlist1) 87 | M = len(boxlist2) 88 | 89 | area2 = boxlist2.area() 90 | 91 | if not optimize_memory: 92 | 93 | # If not optimizing memory, then following original ``maskrcnn-benchmark`` implementation 94 | 95 | area1 = boxlist1.area() 96 | 97 | box1, box2 = boxlist1.bbox, boxlist2.bbox 98 | 99 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # shape: (N, M, 2) 100 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # shape: (N, M, 2) 101 | 102 | TO_REMOVE = 1 103 | 104 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (N, M, 2) 105 | inter = wh[:, :, 0] * wh[:, :, 1] # shape: (N, M) 106 | 107 | iou = inter / (area1[:, None] + area2 - inter) 108 | 109 | else: 110 | 111 | # If optimizing memory, construct IoU matrix one box1 entry at a time 112 | # (in current usage this means one GT at a time) 113 | 114 | # Entry i of ious will hold the IoU between the ith box in boxlist1 and all boxes 115 | # in boxlist2 116 | ious = [] 117 | 118 | box2 = boxlist2.bbox 119 | 120 | for i in range(N): 121 | area1 = boxlist1.area(i) 122 | 123 | box1 = boxlist1.bbox[i].unsqueeze(0) 124 | 125 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # shape: (1, M, 2) 126 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # shape: (1, M, 2) 127 | 128 | TO_REMOVE = 1 129 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (1, M, 2) 130 | 131 | inter = wh[:, :, 0] * wh[:, :, 1] # shape: (1, M) 132 | 133 | iou = inter / (area1 + area2 - inter) 134 | 135 | ious.append(iou) 136 | 137 | iou = torch.cat(ious) # shape: (N, M) 138 | 139 | return iou 140 | 141 | 142 | def cat_boxlist(bboxes): 143 | """Concatenates a list of BoxList into a single BoxList 144 | image sizes needs to be same in this operation. 145 | 146 | Parameters 147 | ---------- 148 | bboxes : list[BoxList] 149 | """ 150 | assert isinstance(bboxes, (list, tuple)) 151 | assert all(isinstance(bbox, BoxList) for bbox in bboxes) 152 | 153 | size = bboxes[0].size 154 | assert all(bbox.size == size for bbox in bboxes) 155 | 156 | mode = bboxes[0].mode 157 | assert all(bbox.mode == mode for bbox in bboxes) 158 | 159 | fields = set(bboxes[0].fields()) 160 | assert all(set(bbox.fields()) == fields for bbox in bboxes) 161 | 162 | cat_boxes = BoxList(torch.cat([bbox.bbox for bbox in bboxes], dim=0), size, mode) 163 | 164 | for field in fields: 165 | data = torch.cat([bbox.get_field(field) for bbox in bboxes], dim=0) 166 | cat_boxes.add_field(field, data) 167 | 168 | return cat_boxes 169 | 170 | 171 | def pair_boxlist_iou(boxlist1, boxlist2): 172 | """Compute the intersection over union of two pairs of boxes. 173 | The box order must be (xmin, ymin, xmax, ymax). 174 | 175 | Parameters 176 | ---------- 177 | box1 : BoxList 178 | Bounding boxes, sized [N,4]. 179 | box2 : BoxList 180 | Bounding boxes, sized [N,4]. 181 | 182 | Returns 183 | ------- 184 | iou : tensor, 185 | Tensor of iou between the input pair of boxes. sized [N]. 
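`boxlist_iou` keeps the maskrcnn-benchmark "+1" pixel convention, so a box spanning pixels 0..9 has area 100; a hand-checked sketch with illustrative values:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList
from realtime_panoptic.utils.boxlist_ops import boxlist_iou

a = BoxList(torch.tensor([[0., 0., 9., 9.]]), (640, 480), mode="xyxy")
b = BoxList(torch.tensor([[5., 5., 14., 14.]]), (640, 480), mode="xyxy")
print(boxlist_iou(a, b))   # tensor([[0.1429]])  = 25 / (100 + 100 - 25)
```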
186 | 187 | Reference: 188 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py 189 | """ 190 | if boxlist1.size != boxlist2.size: 191 | raise RuntimeError("boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2)) 192 | 193 | assert len(boxlist1) == len(boxlist2), "Two boxlists should have same length" 194 | N = len(boxlist1) 195 | 196 | area2 = boxlist2.area() 197 | area1 = boxlist1.area() 198 | 199 | box1, box2 = boxlist1.bbox, boxlist2.bbox 200 | lt = torch.max(box1[:, :2], box2[:, :2]) # shape: (N, 2) 201 | rb = torch.min(box1[:, 2:], box2[:, 2:]) # shape: (N, 2) 202 | TO_REMOVE = 1 203 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (N, 2) 204 | inter = wh[:, 0] * wh[:, 1] # shape: (N, 1) 205 | iou = inter / (area1 + area2 - inter) 206 | return iou 207 | -------------------------------------------------------------------------------- /realtime_panoptic/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | DETECTRON_PALETTE = np.array( 8 | [ 9 | 0.000, 0.447, 0.741, 10 | 0.850, 0.325, 0.098, 11 | 0.929, 0.694, 0.125, 12 | 0.494, 0.184, 0.556, 13 | 0.466, 0.674, 0.188, 14 | 0.301, 0.745, 0.933, 15 | 0.635, 0.078, 0.184, 16 | 0.300, 0.300, 0.300, 17 | 0.600, 0.600, 0.600, 18 | 1.000, 0.000, 0.000, 19 | 1.000, 0.500, 0.000, 20 | 0.749, 0.749, 0.000, 21 | 0.000, 1.000, 0.000, 22 | 0.000, 0.000, 1.000, 23 | 0.667, 0.000, 1.000, 24 | 0.333, 0.333, 0.000, 25 | 0.333, 0.667, 0.000, 26 | 0.333, 1.000, 0.000, 27 | 0.667, 0.333, 0.000, 28 | 0.667, 0.667, 0.000, 29 | 0.667, 1.000, 0.000, 30 | 1.000, 0.333, 0.000, 31 | 1.000, 0.667, 0.000, 32 | 1.000, 1.000, 0.000, 33 | 0.000, 0.333, 0.500, 34 | 0.000, 0.667, 0.500, 35 | 0.000, 1.000, 0.500, 36 | 0.333, 0.000, 0.500, 37 | 0.333, 0.333, 0.500, 38 | 0.333, 0.667, 0.500, 39 | 0.333, 1.000, 0.500, 40 | 0.667, 0.000, 0.500, 41 | 0.667, 0.333, 0.500, 42 | 0.667, 0.667, 0.500, 43 | 0.667, 1.000, 0.500, 44 | 1.000, 0.000, 0.500, 45 | 1.000, 0.333, 0.500, 46 | 1.000, 0.667, 0.500, 47 | 1.000, 1.000, 0.500, 48 | 0.000, 0.333, 1.000, 49 | 0.000, 0.667, 1.000, 50 | 0.000, 1.000, 1.000, 51 | 0.333, 0.000, 1.000, 52 | 0.333, 0.333, 1.000, 53 | 0.333, 0.667, 1.000, 54 | 0.333, 1.000, 1.000, 55 | 0.667, 0.000, 1.000, 56 | 0.667, 0.333, 1.000, 57 | 0.667, 0.667, 1.000, 58 | 0.667, 1.000, 1.000, 59 | 1.000, 0.000, 1.000, 60 | 1.000, 0.333, 1.000, 61 | 1.000, 0.667, 1.000, 62 | 0.167, 0.000, 0.000, 63 | 0.333, 0.000, 0.000, 64 | 0.500, 0.000, 0.000, 65 | 0.667, 0.000, 0.000, 66 | 0.833, 0.000, 0.000, 67 | 1.000, 0.000, 0.000, 68 | 0.000, 0.167, 0.000, 69 | 0.000, 0.333, 0.000, 70 | 0.000, 0.500, 0.000, 71 | 0.000, 0.667, 0.000, 72 | 0.000, 0.833, 0.000, 73 | 0.000, 1.000, 0.000, 74 | 0.000, 0.000, 0.167, 75 | 0.000, 0.000, 0.333, 76 | 0.000, 0.000, 0.500, 77 | 0.000, 0.000, 0.667, 78 | 0.000, 0.000, 0.833, 79 | 0.000, 0.000, 1.000, 80 | 0.000, 0.000, 0.000, 81 | 0.143, 0.143, 0.143, 82 | 0.286, 0.286, 0.286, 83 | 0.429, 0.429, 0.429, 84 | 0.571, 0.571, 0.571, 85 | 0.714, 0.714, 0.714, 86 | 0.857, 0.857, 0.857, 87 | 1.000, 1.000, 1.000 88 | ] 89 | ).astype(np.float32).reshape(-1, 3) * 255 90 | 91 | 92 | def visualize_segmentation_image(predictions, original_image, colormap, fade_weight=0.5): 93 | """Log a single segmentation result for visualization using a colormap. 94 | 95 | Overlays predicted classes on top of raw RGB image if given. 
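`pair_boxlist_iou` matches boxes element-wise (row i against row i) instead of building the full N×M matrix; a short sketch with illustrative values and the same "+1" convention as above:

```
import torch
from realtime_panoptic.utils.bounding_box import BoxList
from realtime_panoptic.utils.boxlist_ops import pair_boxlist_iou

pred = BoxList(torch.tensor([[0., 0., 9., 9.], [20., 20., 29., 29.]]), (640, 480))
gt   = BoxList(torch.tensor([[0., 0., 9., 9.], [25., 25., 34., 34.]]), (640, 480))
print(pair_boxlist_iou(pred, gt))   # tensor([1.0000, 0.1429])
```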
96 | 97 | Parameters: 98 | ----------- 99 | predictions: torch.cuda.LongTensor 100 | Per-pixel predicted class ID's for a single input image 101 | Shape: (H, W) 102 | 103 | original_image: np.array 104 | HxWx3 original image. or None 105 | 106 | colormap: np.array 107 | (N+1)x3 array colormap,where N+1 equals to the number of classes. 108 | 109 | fade_weight: float, default: 0.8 110 | Visualization is fade_weight * original_image + (1 - fade_weight) * predictions 111 | 112 | Returns: 113 | -------- 114 | visualized_image: np.array 115 | Semantic semantic visualization color coded by classes. 116 | The visualization will be overlaid on a the RGB image if given. 117 | """ 118 | 119 | # ``original_image`` has shape (H, W,3) 120 | if not isinstance(original_image, np.ndarray): 121 | original_image = np.array(original_image) 122 | original_image_height, original_image_width,_ = original_image.shape 123 | 124 | # Grab colormap from dataset for the given number of segmentation classes 125 | # (uses black for the IGNORE class) 126 | # ``colormap`` has shape (num_classes + 1, 3) 127 | 128 | # Color per-pixel predictions using the generated color map 129 | # ``colored_predictions_numpy`` has shape (H, W, 3) 130 | predictions_numpy = predictions.cpu().numpy().astype('uint8') 131 | colored_predictions_numpy = colormap[predictions_numpy.flatten()] 132 | colored_predictions_numpy = colored_predictions_numpy.reshape(original_image_height, original_image_width, 3) 133 | 134 | # Overlay images and predictions 135 | overlaid_predictions = original_image * fade_weight + colored_predictions_numpy * (1 - fade_weight) 136 | 137 | visualized_image = overlaid_predictions.astype('uint8') 138 | return visualized_image 139 | 140 | def random_color(base, max_dist=30): 141 | """Generate random color close to a given base color. 142 | 143 | Parameters: 144 | ----------- 145 | base: array_like 146 | Base color for random color generation 147 | 148 | max_dist: int 149 | Max distance from generated color to base color on all RGB axis. 150 | 151 | Returns: 152 | -------- 153 | random_color: tuple 154 | 3 channel random color around the given base color. 155 | """ 156 | base = np.array(base) 157 | new_color = base + np.random.randint(low=-max_dist, high=max_dist + 1, size=3) 158 | return tuple(np.maximum(0, np.minimum(255, new_color))) 159 | 160 | def draw_mask(im, mask, alpha=0.5, color=None): 161 | """Overlay a mask on top of the image. 162 | 163 | Parameters: 164 | ----------- 165 | im: array_like 166 | A 3-channel uint8 image 167 | 168 | mask: array_like 169 | A binary 1-channel image of the same size 170 | 171 | color: bool 172 | If None, will choose automatically 173 | 174 | alpha: float 175 | mask intensity 176 | 177 | Returns: 178 | -------- 179 | im: np.array 180 | Image overlaid by masks. 181 | 182 | color: list 183 | Color used for masks. 184 | """ 185 | if color is None: 186 | color = DETECTRON_PALETTE[np.random.choice(len(DETECTRON_PALETTE))][::-1] 187 | color = np.asarray(color, dtype=np.int64) 188 | im = np.where(np.repeat((mask > 0)[:, :, None], 3, axis=2), im * (1 - alpha) + color * alpha, im) 189 | im = im.astype('uint8') 190 | return im, color.tolist() 191 | 192 | def visualize_detection_image(predictions, original_image, label_id_to_names, fade_weight=0.8): 193 | """Log a single detection result for visualization. 194 | 195 | Overlays predicted classes on top of raw RGB image. 
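A quick smoke test for `visualize_segmentation_image` using a dummy two-class prediction; the sizes and colors are illustrative, and a real colormap needs one RGB row per class id:

```
import numpy as np
import torch
from realtime_panoptic.utils.visualization import visualize_segmentation_image

h, w = 120, 160
image = np.zeros((h, w, 3), dtype=np.uint8)
pred = torch.zeros((h, w), dtype=torch.long)
pred[:, w // 2:] = 1                                   # right half is class 1
colormap = np.array([[128, 64, 128], [220, 20, 60]])   # one RGB row per class id
overlay = visualize_segmentation_image(pred, image, colormap, fade_weight=0.5)
print(overlay.shape, overlay.dtype)                    # (120, 160, 3) uint8
```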
196 | 197 | Parameters: 198 | ----------- 199 | predictions: torch.cuda.LongTensor 200 | Per-pixel predicted class ID's for a single input image 201 | Shape: (H, W) 202 | 203 | original_image: np.array 204 | HxWx3 original image. or None 205 | 206 | label_id_to_names: list 207 | list of class names for instance labels 208 | 209 | fade_weight: float, default: 0.8 210 | Visualization is fade_weight * original_image + (1 - fade_weight) * predictions 211 | 212 | Returns: 213 | ------- 214 | visualized_image: np.array 215 | Visualized image with detection results. 216 | """ 217 | 218 | # Load raw image using provided dataset and index 219 | # ``images_numpy`` has shape (H, W, 3) 220 | # ``images_numpy`` has shape (H, W,3) 221 | if not isinstance(original_image, np.ndarray): 222 | original_image = np.array(original_image) 223 | original_image_height, original_image_width,_ = original_image.shape 224 | 225 | # overlay_boxes 226 | visualized_image = copy.copy(np.array(original_image)) 227 | 228 | labels = predictions.get_field("labels").to("cpu") 229 | boxes = predictions.bbox 230 | 231 | dtype = labels.dtype 232 | palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]).to(dtype) 233 | colors = labels[:, None] * palette 234 | colors = (colors % 255).numpy().astype("uint8") 235 | masks = None 236 | if predictions.has_field("mask"): 237 | masks = predictions.get_field("mask") 238 | else: 239 | masks = [None] * len(boxes) 240 | # overlay_class_names_and_score 241 | if predictions.has_field("scores"): 242 | scores = predictions.get_field("scores").tolist() 243 | else: 244 | scores = [1.0] * len(boxes) 245 | # predicted label starts from 1 as 0 is reserved for background. 246 | label_names = [label_id_to_names[i-1] for i in labels.tolist()] 247 | 248 | text_template = "{}: {:.2f}" 249 | 250 | for box, color, score, mask, label in zip(boxes, colors, scores, masks, label_names): 251 | if score < 0.5: 252 | continue 253 | box = box.to(torch.int64) 254 | color = random_color(color) 255 | color = tuple(map(int, color)) 256 | 257 | if mask is not None: 258 | thresh = (mask > 0.5).cpu().numpy().astype('uint8') 259 | visualized_image, color = draw_mask(visualized_image, thresh) 260 | 261 | x, y = box[:2] 262 | s = text_template.format(label, score) 263 | cv2.putText(visualized_image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1) 264 | 265 | top_left, bottom_right = box[:2].tolist(), box[2:].tolist() 266 | visualized_image = cv2.rectangle(visualized_image, tuple(top_left), tuple(bottom_right), tuple(color), 1) 267 | return visualized_image -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Toyota Research Institute. All rights reserved. 2 | 3 | # This script provides a demo inference a model trained on Cityscapes dataset. 
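`visualize_detection_image` above expects a `BoxList` carrying 1-indexed `labels` and, optionally, `scores` and `mask` fields; a minimal smoke test with illustrative values:

```
import numpy as np
import torch
from realtime_panoptic.utils.bounding_box import BoxList
from realtime_panoptic.utils.visualization import visualize_detection_image

image = np.zeros((120, 160, 3), dtype=np.uint8)
det = BoxList(torch.tensor([[10., 10., 60., 90.]]), (160, 120), mode="xyxy")
det.add_field("labels", torch.tensor([3]))     # 3 -> 'car' in the Cityscapes thing-class list
det.add_field("scores", torch.tensor([0.95]))
vis = visualize_detection_image(det, image, ['person', 'rider', 'car', 'truck',
                                             'bus', 'train', 'motorcycle', 'bicycle'])
print(vis.shape)   # (120, 160, 3) -- image with box, label and score drawn on top
```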
4 | import warnings 5 | import argparse 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | from torchvision.models.detection.image_list import ImageList 10 | 11 | from realtime_panoptic.models.rt_pano_net import RTPanoNet 12 | from realtime_panoptic.config import cfg 13 | import realtime_panoptic.data.panoptic_transform as P 14 | from realtime_panoptic.utils.visualization import visualize_segmentation_image,visualize_detection_image 15 | 16 | cityscapes_colormap = np.array([ 17 | [128, 64, 128], 18 | [244, 35, 232], 19 | [ 70, 70, 70], 20 | [102, 102, 156], 21 | [190, 153, 153], 22 | [153, 153, 153], 23 | [250 ,170, 30], 24 | [220, 220, 0], 25 | [107, 142, 35], 26 | [152, 251, 152], 27 | [ 70, 130, 180], 28 | [220, 20, 60], 29 | [255, 0, 0], 30 | [ 0, 0, 142], 31 | [ 0, 0, 70], 32 | [ 0, 60, 100], 33 | [ 0, 80, 100], 34 | [ 0, 0, 230], 35 | [119, 11, 32], 36 | [ 0, 0, 0]]) 37 | 38 | cityscapes_instance_label_name = ['person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] 39 | warnings.filterwarnings("ignore", category=UserWarning) 40 | 41 | def demo(): 42 | # Parse the input arguments. 43 | parser = argparse.ArgumentParser(description="Simple demo for real-time-panoptic model") 44 | parser.add_argument("--config-file", metavar="FILE", help="path to config", required=True) 45 | parser.add_argument("--pretrained-weight", metavar="FILE", help="path to pretrained_weight", required=True) 46 | parser.add_argument("--input", metavar="FILE", help="path to jpg/png file", required=True) 47 | parser.add_argument("--device", help="inference device", default='cuda') 48 | args = parser.parse_args() 49 | 50 | # General config object from given config files. 51 | cfg.merge_from_file(args.config_file) 52 | 53 | # Initialize model. 54 | model = RTPanoNet( 55 | backbone=cfg.model.backbone, 56 | num_classes=cfg.model.panoptic.num_classes, 57 | things_num_classes=cfg.model.panoptic.num_thing_classes, 58 | pre_nms_thresh=cfg.model.panoptic.pre_nms_thresh, 59 | pre_nms_top_n=cfg.model.panoptic.pre_nms_top_n, 60 | nms_thresh=cfg.model.panoptic.nms_thresh, 61 | fpn_post_nms_top_n=cfg.model.panoptic.fpn_post_nms_top_n, 62 | instance_id_range=cfg.model.panoptic.instance_id_range) 63 | device = args.device 64 | model.to(device) 65 | model.load_state_dict(torch.load(args.pretrained_weight)) 66 | 67 | # Print out mode architecture for sanity checking. 68 | print(model) 69 | 70 | # Prepare for model inference. 
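One practical note on the weight loading above, written as a hedged sketch rather than a change to the script: if the checkpoint was serialized from a GPU run, loading it on a CPU-only machine generally needs an explicit `map_location`:

```
# Hypothetical CPU-only variant of the load call above (not part of the original script):
state_dict = torch.load(args.pretrained_weight, map_location="cpu")
model.load_state_dict(state_dict)
# Pairing this with `--device cpu` keeps the rest of the demo consistent.
```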
71 | model.eval() 72 | input_image = Image.open(args.input) 73 | data = {'image': input_image} 74 | # data pre-processing 75 | normalize_transform = P.Normalize(mean=cfg.input.pixel_mean, std=cfg.input.pixel_std, to_bgr255=cfg.input.to_bgr255) 76 | transform = P.Compose([ 77 | P.ToTensor(), 78 | normalize_transform, 79 | ]) 80 | data = transform(data) 81 | print("Done with data preparation and model configuration.") 82 | with torch.no_grad(): 83 | input_image_list = ImageList([data['image'].to(device)], image_sizes=[input_image.size[::-1]]) 84 | panoptic_result, _ = model.forward(input_image_list) 85 | print("Done with model inference.") 86 | print("Process and visualizing the outputs...") 87 | instance_detection = [o.to('cpu') for o in panoptic_result["instance_segmentation_result"]] 88 | semseg_logics = [o.to('cpu') for o in panoptic_result["semantic_segmentation_result"]] 89 | semseg_prob = [torch.argmax(semantic_logit , dim=0) for semantic_logit in semseg_logics] 90 | 91 | seg_vis = visualize_segmentation_image(semseg_prob[0], input_image, cityscapes_colormap) 92 | Image.fromarray(seg_vis.astype('uint8')).save('semantic_segmentation_result.jpg') 93 | print("Saved semantic segmentation visualization in semantic_segmentation_result.jpg") 94 | det_vis = visualize_detection_image(instance_detection[0], input_image, cityscapes_instance_label_name) 95 | Image.fromarray(det_vis.astype('uint8')).save('instance_segmentation_result.jpg') 96 | print("Saved instance segmentation visualization in instance_segmentation_result.jpg") 97 | print("Demo finished.") 98 | 99 | if __name__ == "__main__": 100 | demo() 101 | --------------------------------------------------------------------------------
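Finally, a minimal batch-inference sketch assembled from the same pieces as `scripts/demo.py`. It assumes `model`, `device`, `transform` and `cityscapes_colormap` were already built exactly as in `demo()` above; the glob pattern and output naming are illustrative:

```
import glob

import torch
from PIL import Image
from torchvision.models.detection.image_list import ImageList

from realtime_panoptic.utils.visualization import visualize_segmentation_image

# Assumes `model`, `device`, `transform` and `cityscapes_colormap` are set up as in demo().
for path in sorted(glob.glob("media/figs/*.png")):
    image = Image.open(path).convert("RGB")
    data = transform({'image': image})
    with torch.no_grad():
        image_list = ImageList([data['image'].to(device)], image_sizes=[image.size[::-1]])
        panoptic_result, _ = model.forward(image_list)
    # Per-pixel class ids from the semantic head, then the same colormap overlay as the demo.
    semantic_logits = panoptic_result["semantic_segmentation_result"][0].to('cpu')
    semseg = torch.argmax(semantic_logits, dim=0)
    seg_vis = visualize_segmentation_image(semseg, image, cityscapes_colormap)
    Image.fromarray(seg_vis.astype('uint8')).save(path.replace('.png', '_semseg.jpg'))
```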