├── LICENSE
├── Makefile
├── README.md
├── configs
│   └── demo_config.yaml
├── docker
│   └── Dockerfile
├── media
│   └── figs
│       ├── panoptic-teaser.gif
│       ├── test.png
│       └── tri-logo.png
├── realtime_panoptic
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── defaults.py
│   ├── data
│   │   ├── __init__.py
│   │   └── panoptic_transform.py
│   ├── layers
│   │   ├── __init__.py
│   │   └── scale.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones.py
│   │   ├── panoptic_from_dense_box.py
│   │   └── rt_pano_net.py
│   └── utils
│       ├── __init__.py
│       ├── bounding_box.py
│       ├── boxlist_ops.py
│       └── visualization.py
└── scripts
    └── demo.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Toyota Research Institute - Machine Learning
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for launching common tasks
2 |
3 | PYTHON ?= python
4 | DOCKER_OPTS ?= \
5 | -v /dev/shm:/dev/shm \
6 | -v /root/.ssh:/root/.ssh \
7 | -v /var/run/docker.sock:/var/run/docker.sock \
8 | --network=host \
9 | --privileged
10 | PACKAGE_NAME ?= panoptic
11 | WORKSPACE ?= /workspace/$(PACKAGE_NAME)
12 | DOCKER_IMAGE_NAME ?= $(PACKAGE_NAME)
13 | DOCKER_IMAGE ?= $(DOCKER_IMAGE_NAME):latest
14 |
15 | all: clean test
16 |
17 | clean:
18 | find . -name "*.pyc" | xargs rm -f && \
19 | find . -name "__pycache__" | xargs rm -rf
20 |
21 | clean-logs:
22 | find . -name "tensorboardx" | xargs rm -rf && \
23 | find . -name "wandb" | xargs rm -rf
24 |
25 | test:
26 | PYTHONPATH=${PWD}/tests:${PYTHONPATH} python -m unittest discover -s tests
27 |
28 | docker-build:
29 | docker build \
30 | -f docker/Dockerfile \
31 | -t ${DOCKER_IMAGE} .
32 |
33 | docker-run-test-sample:
34 | nvidia-docker run --name panoptic --rm \
35 | -e DISPLAY=${DISPLAY} \
36 | -v ${PWD}:${WORKSPACE} \
37 | -v /tmp/.X11-unix:/tmp/.X11-unix \
38 | -v ~/.torch:/root/.torch \
39 | -p 8888:8888 \
40 | -p 6006:6006 \
41 | -p 5000:5000 \
42 | -it \
43 | -v ${PWD}:${WORKSPACE} \
44 | ${DOCKER_OPTS} \
45 | ${DOCKER_IMAGE} bash -c \
46 | "wget -P /workspace/panoptic/ -c https://tri-ml-public.s3.amazonaws.com/github/realtime_panoptic/models/cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth && \
47 | python scripts/demo.py \
48 | --config-file configs/demo_config.yaml \
49 | --input media/figs/test.png \
50 | --pretrained-weight cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth"
51 |
52 | docker-start:
53 | nvidia-docker run --name panoptic --rm \
54 | -e DISPLAY=${DISPLAY} \
55 | -v ${PWD}:${WORKSPACE} \
56 | -v /tmp/.X11-unix:/tmp/.X11-unix \
57 | -v /data:/data \
58 | -v ~/.torch:/root/.torch \
59 | -p 8888:8888 \
60 | -p 6006:6006 \
61 | -p 5000:5000 \
62 | -d \
63 | -it \
64 | ${DOCKER_OPTS} \
65 | ${DOCKER_IMAGE} && \
66 | nvidia-docker exec -it panoptic bash
67 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Real-Time Panoptic Segmentation from Dense Detections
2 |
3 | Official [PyTorch](https://pytorch.org/) implementation of the CVPR 2020 Oral **Real-Time Panoptic Segmentation from Dense Detections** by the ML Team at [Toyota Research Institute (TRI)](https://www.tri.global/), cf. [References](#references) below.
4 |
5 | ![Real-Time Panoptic Segmentation teaser](media/figs/panoptic-teaser.gif)
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Install
14 | ```
15 | git clone https://github.com/TRI-ML/realtime_panoptic.git
16 | cd realtime_panoptic
17 | make docker-build
18 | ```
19 |
20 | To verify your installation, you can also run a simple test that conducts inference on one test image using our Cityscapes pretrained model:
21 | ```
22 | make docker-run-test-sample
23 | ```
24 |
25 | Now you can start a docker container in interactive mode:
26 | ```
27 | make docker-start
28 | ```
29 | ## Demo
30 | We provide demo code to run inference with a Cityscapes pretrained model.
31 | ```
32 | python scripts/demo.py --config-file <config_file> --input <input_image> \
33 | --pretrained-weight <pretrained_weights_file>
34 | ```
35 | Simple user example using our pretrained model provided in the Models section:
36 | ```
37 | python scripts/demo.py --config-file ./configs/demo_config.yaml --input media/figs/test.png --pretrained-weight cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth
38 | ```
39 |
40 | ## Models
41 |
42 |
43 | ### Cityscapes
44 | | Model | PQ | PQ_th | PQ_st |
45 | | :--- | :---: | :---: | :---: |
46 | | [ResNet-50](https://tri-ml-public.s3.amazonaws.com/github/realtime_panoptic/models/cvpr_realtime_pano_cityscapes_standalone_no_prefix.pth) | 58.8 | 52.1| 63.7 |
47 |
48 | ## License
49 |
50 | The source code is released under the [MIT license](LICENSE).
51 |
52 | ## References
53 |
54 | #### Real-Time Panoptic Segmentation from Dense Detections (CVPR 2020 oral)
55 | *Rui Hou\*, Jie Li\*, Arjun Bhargava, Allan Raventos, Vitor Guizilini, Chao Fang, Jerome Lynch, Adrien Gaidon*, [**[paper]**](https://arxiv.org/abs/1912.01202), [**[oral presentation]**](https://www.youtube.com/watch?v=xrxaRU2g2vo), [**[teaser]**](https://www.youtube.com/watch?v=_N4kGJEg-rM)
56 | ```
57 | @InProceedings{real-time-panoptic,
58 | author = {Hou, Rui and Li, Jie and Bhargava, Arjun and Raventos, Allan and Guizilini, Vitor and Fang, Chao and Lynch, Jerome and Gaidon, Adrien},
59 | title = {Real-Time Panoptic Segmentation From Dense Detections},
60 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
61 | month = {June},
62 | year = {2020}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
/configs/demo_config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 | model:
3 | name: 'Cityscape_realtime_panoptic'
4 | backbone: 'R-50-FPN-RETINANET'
5 | panoptic:
6 | num_classes: 19
7 | num_thing_classes: 8
8 | fpn_post_nms_top_n: 100
9 | instance_id_range: (11, 18)
10 | pre_nms_thresh: 0.06
11 |
12 |
--------------------------------------------------------------------------------
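
The values above override the defaults defined in `realtime_panoptic/config/defaults.py`. A minimal sketch (not part of the repo) of loading this YAML through the yacs `cfg` node exported by `realtime_panoptic.config`:

```python
# Minimal sketch: merge the demo YAML into the yacs defaults.
# Assumes the repo root is on PYTHONPATH (as in the provided Docker image).
from realtime_panoptic.config import cfg

cfg.merge_from_file("configs/demo_config.yaml")  # demo values override defaults.py
cfg.freeze()

print(cfg.model.backbone)                        # 'R-50-FPN-RETINANET'
print(cfg.model.panoptic.instance_id_range)      # (11, 18) -> Cityscapes 'thing' ids
```
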
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # ----------------------------------------------- Base Docker Image -----------------------------------------------
2 | FROM nvidia/cuda:10.0-devel-ubuntu18.04
3 |
4 | ENV PYTORCH_VERSION=1.1.0
5 | ENV TORCHVISION_VERSION=0.3.0
6 | ENV CUDNN_VERSION=7.6.0.64-1+cuda10.0
7 | ENV NCCL_VERSION=2.4.7-1+cuda10.0
8 |
9 | # Workaround for deadlock issue. To be removed with next major Pytorch release
10 | ENV NCCL_LL_THRESHOLD=0
11 |
12 | # Python 2.7 or 3.6 is supported by Ubuntu Bionic out of the box
13 | ARG python=3.6
14 | ENV PYTHON_VERSION=${python}
15 | ENV DEBIAN_FRONTEND=noninteractive
16 |
17 | # Set default shell to /bin/bash
18 | SHELL ["/bin/bash", "-cu"]
19 |
20 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
21 | build-essential \
22 | cmake \
23 | g++-4.8 \
24 | git \
25 | curl \
26 | docker.io \
27 | vim \
28 | wget \
29 | ca-certificates \
30 | libcudnn7=${CUDNN_VERSION} \
31 | libnccl2=${NCCL_VERSION} \
32 | libnccl-dev=${NCCL_VERSION} \
33 | libjpeg-dev \
34 | libpng-dev \
35 | python${PYTHON_VERSION} \
36 | python${PYTHON_VERSION}-dev \
37 | python3-tk \
38 | librdmacm1 \
39 | libibverbs1 \
40 | ibverbs-providers \
41 | libgtk2.0-dev \
42 | unzip \
43 | bzip2 \
44 | htop
45 |
46 |
47 | # Install Python and pip
48 | RUN if [[ "${PYTHON_VERSION}" == "3.6" ]]; then \
49 | apt-get install -y python${PYTHON_VERSION}-distutils; \
50 | fi
51 |
52 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python
53 |
54 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
55 | python get-pip.py && \
56 | rm get-pip.py
57 |
58 | # Install PyTorch
59 | RUN pip install future typing numpy awscli
60 | RUN pip install https://download.pytorch.org/whl/cu100/torch-${PYTORCH_VERSION}-cp36-cp36m-linux_x86_64.whl
61 | RUN pip install https://download.pytorch.org/whl/cu100/torchvision-${TORCHVISION_VERSION}-cp36-cp36m-linux_x86_64.whl
62 |
63 |
64 | # Configure environment variables - default working directory is "/workspace"
65 | WORKDIR /workspace
66 | ENV PYTHONPATH="/workspace"
67 |
68 | # Install dependencies
69 | RUN pip install ninja yacs cython matplotlib opencv-python tqdm onnx onnxruntime coloredlogs scipy pycuda
70 | RUN pip uninstall -y pillow
71 | RUN pip install pillow-simd==6.2.2.post1 pycocotools
72 |
73 | # Install apex
74 | WORKDIR /workspace
75 | RUN git clone https://github.com/NVIDIA/apex.git
76 | WORKDIR /workspace/apex
77 | RUN git checkout 82dac9c9419035110d1ccc49b2608681337903ed
78 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
79 |
80 | # Copy repo
81 | ENV PYTHONPATH="/workspace/panoptic:$PYTHONPATH"
82 | COPY . /workspace/panoptic
83 | WORKDIR /workspace/panoptic
84 |
--------------------------------------------------------------------------------
/media/figs/panoptic-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/panoptic-teaser.gif
--------------------------------------------------------------------------------
/media/figs/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/test.png
--------------------------------------------------------------------------------
/media/figs/tri-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/media/figs/tri-logo.png
--------------------------------------------------------------------------------
/realtime_panoptic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
--------------------------------------------------------------------------------
/realtime_panoptic/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import cfg
--------------------------------------------------------------------------------
/realtime_panoptic/config/defaults.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
3 | import os
4 |
5 | from yacs.config import CfgNode as CN
6 | cfg = CN()
7 | ########################################################################################################################
8 | ### MODEL
9 | ########################################################################################################################
10 | cfg.model = CN()
11 | cfg.model.name = '' # Training model
12 | cfg.model.backbone = '' # Backbone
13 | cfg.model.checkpoint_path = '' # Checkpoint path for model saving
14 |
15 | cfg.model.panoptic = CN()
16 | cfg.model.panoptic.num_classes = 19 # number of total classes
17 | cfg.model.panoptic.num_thing_classes = 8 # number of thing classes
18 | cfg.model.panoptic.pre_nms_thresh = 0.05 # objectness threshold before NMS
19 | cfg.model.panoptic.pre_nms_top_n = 1000 # max num of accepted bboxes before NMS
20 | cfg.model.panoptic.nms_thresh = 0.6 # NMS threshold
21 | cfg.model.panoptic.fpn_post_nms_top_n = 100 # Top detection post NMS
22 | # for cityscapes, it is (11,18), for COCO it is (0,79), for Vistas it is (0,36)
23 | cfg.model.panoptic.instance_id_range = (11, 18)
24 |
25 |
26 | ########################################################################################################################
27 | ### INPUT
28 | ########################################################################################################################
29 | cfg.input = CN()
30 | cfg.input.pixel_mean = [102.9801, 115.9465, 122.7717]
31 | cfg.input.pixel_std = [1., 1., 1.]
32 | # Convert image to BGR format, in range 0-255
33 | cfg.input.to_bgr255 = True
34 |
--------------------------------------------------------------------------------
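
For quick experiments, the same `cfg` node can also be overridden programmatically; a small sketch using standard yacs calls (not something the repo itself does):

```python
# Sketch of programmatic overrides with yacs; key paths follow defaults.py above.
from realtime_panoptic.config import cfg

cfg.defrost()
cfg.merge_from_list([
    "model.panoptic.pre_nms_thresh", 0.06,      # same value as configs/demo_config.yaml
    "model.panoptic.fpn_post_nms_top_n", 100,
])
cfg.freeze()
print(cfg.model.panoptic.pre_nms_thresh)        # 0.06
```
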
/realtime_panoptic/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/realtime_panoptic/97ee03657c82b19a628f52764703cb1c693d08a0/realtime_panoptic/data/__init__.py
--------------------------------------------------------------------------------
/realtime_panoptic/data/panoptic_transform.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 | import random
3 |
4 | import torchvision.transforms as torchvision_transforms
5 | from PIL import Image
6 | from torchvision.transforms import functional as F
7 |
8 |
9 | class Compose:
10 |     """Compose a set of data transform operations.
11 |
12 | Parameters:
13 | -----------
14 | transforms: list
15 | A list of transform operations.
16 |
17 | Returns:
18 | -------
19 | data: Dict
20 | The final output data after the given set of transforms.
21 | """
22 | def __init__(self, transforms):
23 | self.transforms = transforms
24 |
25 | def __call__(self, data):
26 | for t in self.transforms:
27 | data = t(data)
28 | return data
29 |
30 | def __repr__(self):
31 | format_string = self.__class__.__name__ + "("
32 | for t in self.transforms:
33 | format_string += "\n"
34 | format_string += " {0}".format(t)
35 | format_string += "\n)"
36 | return format_string
37 |
38 |
39 | class Resize:
40 | """Resize operation on panoptic data.
41 |     The input data will be resized by randomly choosing a minimum side length from
42 |     a given set, with the maximum side capped by a given length.
43 |
44 | Parameters:
45 | ----------
46 | min_size: list or tuple
47 |         A list of sizes to choose from for the image's minimum side.
48 |
49 | max_size: int
50 | Maximum side length of the processed image
51 |
52 | """
53 | def __init__(self, min_size, max_size, is_train=True):
54 | if not isinstance(min_size, (list, tuple)):
55 | min_size = (min_size, )
56 | self.min_size = min_size
57 | self.max_size = max_size
58 | self.is_train = is_train
59 |
60 | # modified from torchvision to add support for max size
61 | # NOTE: this method will always try to make the smaller size match ``self.min_size``
62 | # so in the case of Vistas, this will mean evaluating at a significantly lower resolution
63 | # than the original image for some images
64 | def get_size(self, image_size):
65 | w, h = image_size
66 | size = random.choice(self.min_size)
67 | max_size = self.max_size
68 | if max_size is not None:
69 | min_original_size = float(min((w, h)))
70 | max_original_size = float(max((w, h)))
71 | if max_original_size / min_original_size * size > max_size:
72 | size = int(
73 | round(max_size * min_original_size / max_original_size))
74 |
75 | if (w <= h and w == size) or (h <= w and h == size):
76 | return (h, w)
77 |
78 | if w < h:
79 | ow = size
80 | oh = int(size * h / w)
81 | else:
82 | oh = size
83 | ow = int(size * w / h)
84 |
85 | return (oh, ow)
86 |
87 | def __call__(self, data):
88 | size = self.get_size(data["image"].size)
89 | data["image"] = F.resize(data["image"], size)
90 | if self.is_train:
91 | if "segmentation_target" in data:
92 | data["segmentation_target"] = F.resize(
93 | data["segmentation_target"],
94 | size,
95 | interpolation=Image.NEAREST)
96 | if "detection_target" in data:
97 | data["detection_target"] = data["detection_target"].resize(
98 | data["image"].size)
99 | return data
100 |
101 |
102 | class RandomHorizontalFlip:
103 |     """Randomly flip the input data with a given probability.
104 |
105 | Parameters:
106 | ----------
107 | prob: float
108 |         Probability in [0, 1] with which to flip the data.
109 | """
110 | def __init__(self, prob=0.5):
111 | self.prob = prob
112 |
113 | def __call__(self, data):
114 | if random.random() < self.prob:
115 | data["image"] = F.hflip(data["image"])
116 | if "detection_target" in data:
117 | data["detection_target"] = data["detection_target"].transpose(
118 | 0)
119 | if "segmentation_target" in data:
120 | data["segmentation_target"] = F.hflip(
121 | data["segmentation_target"])
122 | return data
123 |
124 |
125 | class ToTensor:
126 | """Convert the input data to Tensor.
127 | """
128 | def __call__(self, data):
129 | data["image"] = F.to_tensor(data["image"])
130 | if "segmentation_target" in data:
131 | data["segmentation_target"] = F.to_tensor(
132 | data["segmentation_target"])
133 | return data
134 |
135 |
136 | class ColorJitter:
137 | """Apply color jittering to input image.
138 | """
139 | def __call__(self, data):
140 | data["image"] = torchvision_transforms.ColorJitter().__call__(
141 | data["image"])
142 | return data
143 |
144 |
145 | class Normalize:
146 |     """Normalize the input image, with an option to convert RGB to BGR.
147 |
148 | Parameters:
149 | ----------
150 | mean: list
151 | Mean value for the 3 image channels.
152 |
153 | std: list
154 | Standard deviation for the 3 image channels.
155 |
156 | to_bgr255: bool
157 |         If True, the input image is assumed to be RGB in [0, 1] scale;
158 |         it will be converted to BGR in [0, 255] scale.
159 | """
160 | def __init__(self, mean, std, to_bgr255=True):
161 | self.mean = mean
162 | self.std = std
163 | self.to_bgr255 = to_bgr255
164 |
165 | def __call__(self, data):
166 | if self.to_bgr255:
167 | data["image"] = data["image"][[2, 1, 0]] * 255
168 | if "segmentation_target" in data:
169 | data["segmentation_target"] = (
170 | data["segmentation_target"] * 255).long()
171 | data["image"] = F.normalize(
172 | data["image"], mean=self.mean, std=self.std)
173 | return data
174 |
175 |
176 | class RandomCrop:
177 |     """Randomly crop the input panoptic data.
178 |
179 | Parameters:
180 | ----------
181 | crop_size: tuple
182 | Desired crop size of the data.
183 | """
184 | def __init__(self, crop_size):
185 | # A couple of safety checks
186 | assert isinstance(crop_size, tuple)
187 | self.crop_size = crop_size
188 | self.crop = torchvision_transforms.RandomCrop(crop_size)
189 |
190 | def __call__(self, data):
191 | if len(self.crop_size) <= 1:
192 | return data
193 |
194 | # If image size is smaller than crop size,
195 | # resize both image and target to at least crop size.
196 | if self.crop_size[0] > data["image"].size[0] or self.crop_size[
197 | 1] > data["image"].size[1]:
198 | print("Image will be resized before cropping. {},{}".format(
199 | self.crop_size, data["image"].size))
200 | resize_func = Resize(
201 | max(self.crop_size),
202 | round(
203 | max(data["image"].size) * max(self.crop_size) / min(
204 | data["image"].size)))
205 | data = resize_func(data)
206 |
207 | if "detection_target" not in data:
208 |
209 | image_width, image_height = data["image"].size
210 | crop_width, crop_height = self.crop_size
211 | assert image_width >= crop_width and image_height >= crop_height
212 |
213 | left = 0
214 | if image_width > crop_width:
215 | left = random.randint(0, image_width - crop_width)
216 | top = 0
217 | if image_height > crop_height:
218 | top = random.randint(0, image_height - crop_height)
219 |
220 | data["image"] = data["image"].crop((left, top, left + crop_width,
221 | top + crop_height))
222 | if "segmentation_target" in data:
223 | data["segmentation_target"] = data["segmentation_target"].crop(
224 | (left, top, left + crop_width, top + crop_height))
225 |
226 | else:
227 | # We always crop an area that contains at least one instance.
228 | # TODO: We are making an assumption here that data are filtered to only include
229 | # non-empty training samples. So the while loop will not be a dead lock. Need to
230 | # improve the efficiency of this part.
231 | while True:
232 | # continuously try till there's instance inside it.
233 | w, h = data["image"].size
234 | if w <= self.crop_size[0] or h <= self.crop_size[1]:
235 | break
236 | # y, x, h, w
237 | top, left, crop_height, crop_width = self.crop.get_params(
238 | data["image"], (self.crop_size[1], self.crop_size[0]))
239 |
240 | image_cropped = F.crop(data["image"], top, left, crop_height,
241 | crop_width)
242 |
243 | detection_target_cropped = data[
244 | "detection_target"].augmentation_crop(
245 | top, left, crop_height, crop_width)
246 |
247 | if detection_target_cropped is not None:
248 | data["image"] = image_cropped
249 | data["detection_target"] = detection_target_cropped
250 |
251 | # Once ``detection_target`` gets properly cropped, then crop
252 | # ``segmentation_target`` using the same parameters
253 | if "segmentation_target" in data:
254 | data["segmentation_target"] = F.crop(
255 | data["segmentation_target"], top, left,
256 | crop_height, crop_width)
257 | break
258 | return data
--------------------------------------------------------------------------------
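
These transforms operate on a dict with an "image" entry (plus optional targets). Below is a minimal inference-time pipeline sketch using the normalization constants from `realtime_panoptic/config/defaults.py`; the resize values are illustrative assumptions, not taken from the repo.

```python
# Sketch of an inference-time transform pipeline; resize sizes are assumptions.
from PIL import Image

from realtime_panoptic.config import cfg
from realtime_panoptic.data.panoptic_transform import Compose, Normalize, Resize, ToTensor

transform = Compose([
    Resize(800, 1600, is_train=False),   # assumed min/max sides, not mandated by the repo
    ToTensor(),
    Normalize(cfg.input.pixel_mean, cfg.input.pixel_std, to_bgr255=cfg.input.to_bgr255),
])

data = transform({"image": Image.open("media/figs/test.png").convert("RGB")})
print(data["image"].shape)               # torch.Size([3, H, W]), BGR, mean-subtracted
```
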
/realtime_panoptic/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
--------------------------------------------------------------------------------
/realtime_panoptic/layers/scale.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
3 | # Adapted from FCOS
4 | # https://github.com/tianzhi0549/FCOS/blob/master/fcos_core/layers/scale.py
5 |
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class Scale(nn.Module):
11 |     def __init__(self, init_value=1.0):
12 |         """Scale layer with trainable scale factor.
13 | 
14 |         Parameters
15 |         ----------
16 |         init_value: float
17 |             Initial value of the scale factor.
18 |         """
19 |         super(Scale, self).__init__()
20 |         self.scale = nn.Parameter(torch.FloatTensor([init_value]))
21 |
22 | def forward(self, input):
23 | return input * self.scale
24 |
--------------------------------------------------------------------------------
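
A tiny usage sketch: `Scale` is just a single learnable multiplier; `PanopticHead` (in `models/rt_pano_net.py`) keeps one per FPN level to scale its bounding-box regression output.

```python
# Scale multiplies its input by one trainable parameter.
import torch

from realtime_panoptic.layers.scale import Scale

layer = Scale(init_value=2.0)
print(layer(torch.ones(3)))          # tensor([2., 2., 2.], grad_fn=<MulBackward0>)
print(list(layer.parameters()))      # a single trainable parameter of shape (1,)
```
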
/realtime_panoptic/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
--------------------------------------------------------------------------------
/realtime_panoptic/models/backbones.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 | from torch import nn
3 | from torchvision.models import resnet
4 | from torchvision.models._utils import IntermediateLayerGetter
5 | from torchvision.ops import misc as misc_nn_ops
6 | from torchvision.ops.feature_pyramid_network import (FeaturePyramidNetwork,
7 | LastLevelP6P7)
8 |
9 | # This class is adapted from "BackboneWithFPN" in torchvision.models.detection.backbone_utils
10 |
11 | class ResNetWithModifiedFPN(nn.Module):
12 |     """Adds an FPN with extra P6 and P7 levels on top of a ResNet model, with more options.
13 | 
14 |     We adapt this class from torchvision.models.detection.backbone_utils.
15 |     Modifications have been added to enable a RetinaNet-style FPN with P6 and P7 as extra blocks.
16 |
17 | Parameters
18 | ----------
19 | backbone_name: string
20 |         ResNet architecture supported by torchvision. Possible values are 'resnet18', 'resnet34', 'resnet50',
21 | 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
22 |
23 | norm_layer: torchvision.ops
24 | It is recommended to use the default value. For details visit:
25 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
26 |
27 | pretrained: bool
28 | If True, returns a model with backbone pre-trained on Imagenet. Default: False
29 |
30 | trainable_layers: int
31 | Number of trainable (not frozen) resnet layers starting from final block.
32 | Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
33 |
34 | out_channels: int
35 | number of channels in the FPN.
36 | """
37 |
38 | def __init__(self, backbone_name, pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3, out_channels = 256):
39 | super().__init__()
40 | # Get ResNet
41 | backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
42 |         # select layers that won't be frozen
43 | assert 0 <= trainable_layers <= 5
44 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
45 | # freeze layers only if pretrained backbone is used
46 | for name, parameter in backbone.named_parameters():
47 | if all([not name.startswith(layer) for layer in layers_to_train]):
48 | parameter.requires_grad_(False)
49 |
50 | return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
51 |
52 | in_channels_stage2 = backbone.inplanes // 8
53 | self.in_channels_list = [
54 | 0,
55 | in_channels_stage2 * 2,
56 | in_channels_stage2 * 4,
57 | in_channels_stage2 * 8,
58 | ]
59 |
60 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
61 | self.fpn = FeaturePyramidNetwork(
62 | in_channels_list=self.in_channels_list[1:], # nonzero only
63 | out_channels=out_channels,
64 | extra_blocks=LastLevelP6P7(out_channels, out_channels),
65 | )
66 | self.out_channels = out_channels
67 |
68 | def forward(self, x):
69 | x = self.body(x)
70 | keys = list(x.keys())
71 | for idx, key in enumerate(keys):
72 | if self.in_channels_list[idx] == 0:
73 | del x[key]
74 | x = self.fpn(x)
75 | return x
76 |
--------------------------------------------------------------------------------
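
A small sketch (input size chosen arbitrarily) of the pyramid this backbone produces: three ResNet stages fed through the FPN plus the extra P6/P7 levels, i.e. five maps at strides 8, 16, 32, 64 and 128, each with `out_channels` channels.

```python
# Inspect the feature pyramid returned by ResNetWithModifiedFPN.
import torch

from realtime_panoptic.models.backbones import ResNetWithModifiedFPN

backbone = ResNetWithModifiedFPN('resnet50', pretrained=False)
backbone.eval()
with torch.no_grad():
    features = backbone(torch.randn(1, 3, 512, 1024))   # arbitrary test resolution

for name, feature_map in features.items():
    print(name, tuple(feature_map.shape))
# Expect 5 levels with 256 channels each, spatial sizes from 64x128 down to 4x8.
```
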
/realtime_panoptic/models/panoptic_from_dense_box.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 | import torch
3 | import torch.nn.functional as F
4 | from realtime_panoptic.utils.bounding_box import BoxList
5 | from realtime_panoptic.utils.boxlist_ops import (boxlist_nms, cat_boxlist, remove_small_boxes)
6 |
7 | class PanopticFromDenseBox:
8 | """Performs post-processing on the outputs of the RTPanonet.
9 |
10 | Parameters
11 | ----------
12 | pre_nms_thresh: float
13 | Acceptance class probability threshold for bounding box candidates before NMS.
14 |
15 | pre_nms_top_n: int
16 | Maximum number of accepted bounding box candidates before NMS.
17 |
18 | nms_thresh: float
19 | NMS threshold.
20 |
21 | fpn_post_nms_top_n: int
22 |         Maximum number of detected objects per image.
23 |
24 | min_size: int
25 | Minimum dimension of accepted detection.
26 |
27 | num_classes: int
28 | Number of total semantic classes (stuff and things).
29 |
30 | mask_thresh: float
31 |         Bounding box IoU threshold to determine 'similar bounding boxes' in mask reconstruction.
32 |
33 | instance_id_range: list of int
34 |         [min_id, max_id] defines the range of ids in 1:num_classes that correspond to thing classes.
35 |
36 | is_training: bool
37 |         Whether the module is being used during training.
38 | """
39 |
40 | def __init__(
41 | self,
42 | pre_nms_thresh,
43 | pre_nms_top_n,
44 | nms_thresh,
45 | fpn_post_nms_top_n,
46 | min_size,
47 | num_classes,
48 | mask_thresh,
49 | instance_id_range,
50 | is_training
51 | ):
52 | super(PanopticFromDenseBox, self).__init__()
53 | # assign parameters
54 | self.pre_nms_thresh = pre_nms_thresh
55 | self.pre_nms_top_n = pre_nms_top_n
56 | self.nms_thresh = nms_thresh
57 | self.fpn_post_nms_top_n = fpn_post_nms_top_n
58 | self.min_size = min_size
59 | self.num_classes = num_classes
60 | self.mask_thresh = mask_thresh
61 | self.instance_id_range = instance_id_range
62 | self.is_training = is_training
63 |
64 | def process(
65 | self, locations, box_cls, box_regression, centerness, levelness_logits, semantic_logits, image_sizes
66 | ):
67 | """ Reconstruct panoptic segmentation result from raw predictions.
68 |
69 |         This function conducts post-processing of the panoptic head's raw predictions, including bounding box
70 |         predictions, semantic segmentation and levelness, to reconstruct instance segmentation results.
71 |
72 | Parameters
73 | ----------
74 | locations: list of torch.Tensor
75 |             Corresponding pixel locations of each FPN prediction.
76 |
77 | box_cls: list of torch.Tensor
78 |             Predicted bounding box classes from each FPN layer.
79 |
80 | box_regression: list of torch.Tensor
81 |             Predicted bounding box offsets from each FPN layer.
82 |
83 | centerness: list of torch.Tensor
84 |             Predicted object centerness from each FPN layer.
85 |
86 | levelness_logits:
87 | Global prediction of best source FPN layer for each pixel location.
88 |
89 | semantic_logits:
90 | Global prediction of semantic segmentation.
91 |
92 | image_sizes: list of [int,int]
93 | Image sizes.
94 |
95 |         Returns
96 |         -------
97 |         boxlists: list of BoxList
98 |             Reconstructed instances with masks.
99 | """
100 | num_locs_per_level = [len(loc_per_level) for loc_per_level in locations]
101 |
102 | sampled_boxes = []
103 | for i, (l, o, b, c) in enumerate(zip(locations[:-1], box_cls, box_regression, centerness)):
104 | if self.is_training:
105 | layer_boxes = self.forward_for_single_feature_map(l, o, b, c, image_sizes)
106 | for layer_box in layer_boxes:
107 | pred_indices = layer_box.get_field("indices")
108 | pred_indices = pred_indices + sum(num_locs_per_level[:i])
109 | layer_box.add_field("indices", pred_indices)
110 | sampled_boxes.append(layer_boxes)
111 | else:
112 | sampled_boxes.append(self.forward_for_single_feature_map(l, o, b, c, image_sizes))
113 |
114 | # sampled_boxes are a list of bbox_list per level
115 | # the following converts it to per image
116 | boxlists = list(zip(*sampled_boxes))
117 | # per image, concat bbox_list of different levels into one bbox_list
118 | # boxlists is a list of bboxlists of N images
119 | try:
120 | boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
121 | boxlists = self.select_over_all_levels(boxlists)
122 | except Exception as e:
123 | print(e)
124 | for boxlist in boxlists:
125 | for box in boxlist:
126 | print(box, "box shape", box.bbox.shape)
127 |
128 | # Generate bounding box feature map at size of [H/4, W/4] with bounding box prediction as features.
129 | levelness_locations = locations[-1]
130 | _, c_semantic, _, _ = semantic_logits.shape
131 | N, _, h_map, w_map = levelness_logits.shape
132 | bounding_box_feature_map = self.generate_box_feature_map(levelness_locations, box_regression, levelness_logits)
133 |
134 | # process semantic raw prediction
135 | semantic_logits = F.interpolate(semantic_logits, size=(h_map, w_map), mode='bilinear')
136 | semantic_logits = semantic_logits.view(N, c_semantic, h_map, w_map).permute(0, 2, 3, 1)
137 | semantic_logits = semantic_logits.reshape(N, -1, c_semantic)
138 |
139 | # insert semantic prob into mask
140 | semantic_probability = F.softmax(semantic_logits, dim=2)
141 | semantic_probability = semantic_probability[:, :, self.instance_id_range[0]:]
142 | boxlists = self.mask_reconstruction(
143 | boxlists=boxlists,
144 | box_feature_map=bounding_box_feature_map,
145 | semantic_prob=semantic_probability,
146 | box_feature_map_location=levelness_locations,
147 | h_map=h_map,
148 | w_map=w_map
149 | )
150 | # resize instance masks to original image size
151 | if not self.is_training:
152 | for boxlist in boxlists:
153 | masks = boxlist.get_field("mask")
154 | # NOTE: BoxList size is the image size without padding. MASK here is a mask with padding.
155 | # Mask need to be interpolated into padded image size and then crop to unpadded size.
156 | w, h = boxlist.size
157 | if len(masks.shape) == 3 and masks.shape[0] != 0:
158 | masks = F.interpolate(masks.unsqueeze(0), size=(h_map * 4, w_map * 4), mode='bilinear').squeeze()
159 | else:
160 | # handle 0 shape dummy mask.
161 | masks = masks.view([-1, h_map * 4, w_map * 4])
162 | masks = masks >= self.mask_thresh
163 | if len(masks.shape) < 3:
164 | masks = masks.unsqueeze(0)
165 | masks = masks[:, 0:h, 0:w].contiguous()
166 | boxlist.add_field("mask", masks)
167 | return boxlists
168 |
169 | def forward_for_single_feature_map(self, locations, box_cls, box_regression, centerness, image_sizes):
170 | """Recover dense bounding box detection results from raw predictions for each FPN layer.
171 |
172 | Parameters
173 | ----------
174 | locations: torch.Tensor
175 | Corresponding pixel location of FPN feature map with size of (N, H * W, 2).
176 |
177 | box_cls: torch.Tensor
178 | Predicted bounding box class probability with size of (N, C, H, W).
179 |
180 | box_regression: torch.Tensor
181 | Predicted bounding box offset centered at corresponding pixel with size of (N, 4, H, W).
182 |
183 | centerness: torch.Tensor
184 | Predicted centerness of corresponding pixel with size of (N, 1, H, W).
185 |
186 |             Note: N is the batch size (number of images).
187 |
188 | Returns
189 | -------
190 | results: List of BoxList
191 | A list of dense bounding boxes from each FPN layer.
192 | """
193 |
194 | N, C, H, W = box_cls.shape
195 | # M = H x W is the total number of proposal for this single feature map
196 |
197 | # put in the same format as locations
198 | # from (N, C, H, W) to (N, H, W, C)
199 | box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
200 | # from (N, H, W, C) to (N, M, C)
201 |         # map class logits to probabilities in (0, 1) via sigmoid
202 | box_cls = box_cls.reshape(N, -1, C).sigmoid()
203 | # from (N, 4, H, W) to (N, H, W, 4) to (N, M, 4)
204 | box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
205 | box_regression = box_regression.reshape(N, -1, 4)
206 |         # from (N, 1, H, W) to (N, H, W, 1) to (N, M)
207 |         # map centerness logits to probabilities in (0, 1) via sigmoid
208 | centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
209 | centerness = centerness.reshape(N, -1).sigmoid()
210 |
211 | # before NMS, per level filter out low cls prob with threshold 0.05
212 | # after this candidate_inds of size (N, M, C) with values corresponding to
213 | # low prob predictions become 0, otherwise 1
214 | candidate_inds = box_cls > self.pre_nms_thresh
215 |
216 | # pre_nms_top_n of size (N, M * C) => (N, 1)
217 | # N -> batch index, 1 -> total number of bbox predictions per image
218 | pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
219 | # total number of proposal before NMS
220 | # if have more than self.pre_nms_top_n (1000) clamp to 1000
221 | pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
222 |
223 | # multiply the classification scores with centerness scores
224 | # (N, M, C) * (N, M, 1)
225 | box_cls = box_cls * centerness[:, :, None]
226 |
227 | results = []
228 | for i in range(N):
229 |             # filter out low score candidates
230 | per_box_cls = box_cls[i] # (M, C)
231 | per_candidate_inds = candidate_inds[i] # (M, C)
232 | # per_box_cls of size P, P < M * C
233 | per_box_cls = per_box_cls[per_candidate_inds]
234 |
235 |             # indices of seed bounding boxes
236 | # 0-dim corresponding to M, location
237 | # 1-dim corresponding to C, class
238 | per_candidate_nonzeros = per_candidate_inds.nonzero()
239 | # Each of the following is of size P < M * C
240 | per_box_loc = per_candidate_nonzeros[:, 0]
241 | per_class = per_candidate_nonzeros[:, 1] + 1
242 |
243 | # per_box_regression of size (M, 4)
244 | per_box_regression = box_regression[i]
245 | # (M, 4) => (P, 4)
246 | # in P, there might be identical bbox prediction in M
247 | per_box_regression = per_box_regression[per_box_loc]
248 | # (M, 2) => (P, 2)
249 | # in P, there might be identical locations in M
250 | per_locations = locations[per_box_loc]
251 |
252 |
253 |             # upper bound on the number of predictions for this image
254 | per_pre_nms_top_n = pre_nms_top_n[i]
255 |
256 |             # if the number of valid predictions exceeds the upper bound
257 | # only select topK
258 | if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
259 | per_box_cls, top_k_indices = \
260 | per_box_cls.topk(per_pre_nms_top_n, sorted=False)
261 | per_class = per_class[top_k_indices]
262 | per_box_regression = per_box_regression[top_k_indices]
263 | per_locations = per_locations[top_k_indices]
264 | if self.is_training:
265 | per_box_loc = per_box_loc[top_k_indices]
266 |
267 | detections = torch.stack([
268 | per_locations[:, 0] - per_box_regression[:, 0],
269 | per_locations[:, 1] - per_box_regression[:, 1],
270 | per_locations[:, 0] + per_box_regression[:, 2],
271 | per_locations[:, 1] + per_box_regression[:, 3],
272 | ],
273 | dim=1)
274 |
275 | h, w = image_sizes[i]
276 |
277 | boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
278 | boxlist.add_field("labels", per_class)
279 | boxlist.add_field("scores", per_box_cls)
280 | if self.is_training:
281 | boxlist.add_field("indices", per_box_loc)
282 |
283 | boxlist = boxlist.clip_to_image(remove_empty=False)
284 | boxlist = remove_small_boxes(boxlist, self.min_size)
285 | results.append(boxlist)
286 | return results
287 |
288 | def generate_box_feature_map(self, location, box_regression, levelness_logits):
289 | """Generate bounding box feature aggregating dense bounding box predictions.
290 |
291 | Parameters
292 | ----------
293 | location: torch.Tensor
294 | Pixel location of levelness.
295 |
296 | box_regression: list of torch.Tensor
297 | Bounding box offsets from each FPN.
298 |
299 |         levelness_logits: torch.Tensor
300 | Global prediction of best source FPN layer for each pixel location.
301 | Predict at the resolution of (H/4, W/4).
302 |
303 | Returns
304 | -------
305 | bounding_box_feature_map: torch.Tensor
306 | Aggregated bounding box feature map.
307 | """
308 | upscaled_box_reg = []
309 | N, _, h_map, w_map = levelness_logits.shape
310 | downsampled_shape = torch.Size((h_map, w_map))
311 | for box_reg in box_regression:
312 | upscaled_box_reg.append(F.interpolate(box_reg, size=downsampled_shape, mode='bilinear').unsqueeze(1))
313 |
314 | # N_level, 4, h_map, w_map
315 | upscaled_box_reg = torch.cat(upscaled_box_reg, 1)
316 |
317 | max_v, level = torch.max(levelness_logits[:, 1:, :, :], dim=1)
318 |
319 | box_feature_map = torch.gather(
320 | upscaled_box_reg, dim=1, index=level.unsqueeze(1).expand([N, 4, h_map, w_map]).unsqueeze(1)
321 | )
322 |
323 | box_feature_map = box_feature_map.view(N, 4, h_map, w_map).permute(0, 2, 3, 1)
324 | box_feature_map = box_feature_map.reshape(N, -1, 4)
325 | # generate all valid bboxes from feature map
326 | # shape (N, M, 4)
327 | levelness_locations_repeat = location.repeat(N, 1, 1)
328 | bounding_box_feature_map = torch.stack([
329 | levelness_locations_repeat[:, :, 0] - box_feature_map[:, :, 0],
330 | levelness_locations_repeat[:, :, 1] - box_feature_map[:, :, 1],
331 | levelness_locations_repeat[:, :, 0] + box_feature_map[:, :, 2],
332 | levelness_locations_repeat[:, :, 1] + box_feature_map[:, :, 3],
333 | ], dim=2)
334 | return bounding_box_feature_map
335 |
336 | def mask_reconstruction(self, boxlists, box_feature_map, semantic_prob, box_feature_map_location, h_map, w_map):
337 | """Reconstruct instance mask from dense bounding box and semantic smoothing.
338 |
339 | Parameters
340 | ----------
341 | boxlists: List of Boxlist
342 | Object detection result after NMS.
343 |
344 | box_feature_map: torch.Tensor
345 | Aggregated bounding box feature map.
346 |
347 | semantic_prob: torch.Tensor
348 |             Predicted semantic probability.
349 |
350 | box_feature_map_location: torch.Tensor
351 | Corresponding pixel location of bounding box feature map.
352 |
353 | h_map: int
354 | Height of bounding box feature map.
355 |
356 | w_map: int
357 | Width of bounding box feature map.
358 | """
359 | for i, (boxlist, per_image_bounding_box_feature_map, per_image_semantic_prob,
360 | box_feature_map_loc) in enumerate(zip(boxlists, box_feature_map, semantic_prob, box_feature_map_location)):
361 |
362 | # decode mask from bbox embedding
363 | if len(boxlist) > 0:
364 | # query_boxes is of shape (P, 4)
365 | # dense_detections is of shape (P', 4)
366 | # P' is larger than P
367 | query_boxes = boxlist.bbox
368 | propose_cls = boxlist.get_field("labels")
369 |                 # (P, 4) -> (P, 4, 1) -> (P, 4, P') -> (P, P', 4)
370 | propose_bbx = query_boxes.unsqueeze(2).repeat(1, 1,
371 | per_image_bounding_box_feature_map.shape[0]).permute(0, 2, 1)
372 | # (P',4) -> (4, P') -> (1, 4, P') -> (P, 4, P') -> (P, P', 4)
373 | voting_bbx = per_image_bounding_box_feature_map.permute(1, 0).unsqueeze(0).repeat(query_boxes.shape[0], 1,
374 | 1).permute(0, 2, 1)
375 | # implementation based on IOU for bbox_correlation_map
376 | # 0, 1, 2, 3 => left, top, right, bottom
377 | proposal_area = (propose_bbx[:, :, 2] - propose_bbx[:, :, 0]) * \
378 | (propose_bbx[:, :, 3] - propose_bbx[:, :, 1])
379 | voting_area = (voting_bbx[:, :, 2] - voting_bbx[:, :, 0]) * \
380 | (voting_bbx[:, :, 3] - voting_bbx[:, :, 1])
381 | w_intersect = torch.min(voting_bbx[:, :, 2], propose_bbx[:, :, 2]) - \
382 | torch.max(voting_bbx[:, :, 0], propose_bbx[:, :, 0])
383 | h_intersect = torch.min(voting_bbx[:, :, 3], propose_bbx[:, :, 3]) - \
384 | torch.max(voting_bbx[:, :, 1], propose_bbx[:, :, 1])
385 | w_intersect = w_intersect.clamp(min=0.0)
386 | h_intersect = h_intersect.clamp(min=0.0)
387 | w_general = torch.max(voting_bbx[:, :, 2], propose_bbx[:, :, 2]) - \
388 | torch.min(voting_bbx[:, :, 0], propose_bbx[:, :, 0])
389 | h_general = torch.max(voting_bbx[:, :, 3], propose_bbx[:, :, 3]) - \
390 | torch.min(voting_bbx[:, :, 1], propose_bbx[:, :, 1])
391 | # calculate IOU
392 | area_intersect = w_intersect * h_intersect
393 | area_union = proposal_area + voting_area - area_intersect
394 | torch.cuda.synchronize()
395 |
396 | area_general = w_general * h_general + 1e-7
397 | bbox_correlation_map = (area_intersect + 1.0) / (area_union + 1.0) - \
398 | (area_general - area_union) / area_general
399 |
400 | per_image_cls_prob = per_image_semantic_prob[:, propose_cls - 1].permute(1, 0)
401 | # bbox_correlation_map is of size (P or per_pre_nms_top_n, P')
402 | bbox_correlation_map = bbox_correlation_map * per_image_cls_prob
403 | # query_boxes.shape[0] is the number of filtered boxes
404 | masks = bbox_correlation_map.view(query_boxes.shape[0], h_map, w_map)
405 | if len(masks.shape) < 3:
406 | masks = masks.unsqueeze(0)
407 | boxlist.add_field("mask", masks)
408 | else:
409 | dummy_masks = torch.zeros(len(boxlist), h_map,
410 | w_map).float().to(boxlist.bbox.device).to(boxlist.bbox.dtype)
411 | boxlist.add_field("mask", dummy_masks)
412 | return boxlists
413 |
414 | def select_over_all_levels(self, boxlists):
415 | """NMS of bounding box candidates.
416 |
417 | Parameters
418 | ----------
419 | boxlists: list of Boxlist
420 | Pre-NMS bounding boxes.
421 |
422 | Returns
423 | -------
424 | results: list of Boxlist
425 | Final detection result.
426 | """
427 | num_images = len(boxlists)
428 | results = []
429 | for i in range(num_images):
430 | boxlist = boxlists[i]
431 | scores = boxlist.get_field("scores")
432 | labels = boxlist.get_field("labels")
433 | if self.is_training:
434 | indices = boxlist.get_field("indices")
435 | boxes = boxlist.bbox
436 |
437 | result = []
438 | w, h = boxlist.size
439 | # skip the background
440 | if boxes.shape[0] < 1:
441 | results.append(boxlist)
442 | continue
443 | for j in range(1, self.num_classes):
444 | inds = (labels == j).nonzero().view(-1)
445 | if len(inds) > 0:
446 | scores_j = scores[inds]
447 | boxes_j = boxes[inds, :].view(-1, 4)
448 |
449 | boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
450 | boxlist_for_class.add_field("scores", scores_j)
451 |
452 | if self.is_training:
453 | indices_j = indices[inds]
454 | boxlist_for_class.add_field("indices", indices_j)
455 |
456 | boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms_thresh, score_field="scores")
457 | num_labels = len(boxlist_for_class)
458 | boxlist_for_class.add_field(
459 | "labels", torch.full((num_labels, ), j, dtype=torch.int64, device=scores.device)
460 | )
461 | result.append(boxlist_for_class)
462 | result = cat_boxlist(result)
463 |
464 | # global NMS
465 | result = boxlist_nms(result, 0.97, score_field="scores")
466 |
467 | number_of_detections = len(result)
468 |
469 | # Limit to max_per_image detections **over all classes**
470 | if number_of_detections > self.fpn_post_nms_top_n > 0:
471 | cls_scores = result.get_field("scores")
472 | image_thresh, _ = torch.kthvalue(cls_scores.cpu(), number_of_detections - self.fpn_post_nms_top_n + 1)
473 | keep = cls_scores >= image_thresh.item()
474 | keep = torch.nonzero(keep).squeeze(1)
475 | result = result[keep]
476 | results.append(result)
477 | return results
478 |
--------------------------------------------------------------------------------
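
To make the `bbox_correlation_map` in `mask_reconstruction` concrete, here is a toy sketch (with made-up boxes) of the GIoU-style affinity it computes between one query detection and every dense per-pixel box vote; in the real module this map is then weighted by the semantic class probability and reshaped into an instance mask.

```python
# Toy example of the box-affinity score used in mask_reconstruction (made-up numbers).
import torch

query = torch.tensor([[10., 10., 50., 50.]])            # (P=1, 4) detection after NMS, xyxy
dense = torch.tensor([[12., 12., 48., 48.],             # pixel vote that agrees with the query
                      [60., 60., 90., 90.]])            # pixel vote from an unrelated region

q = query.unsqueeze(1).expand(-1, dense.shape[0], -1)   # (P, P', 4)
d = dense.unsqueeze(0).expand(query.shape[0], -1, -1)   # (P, P', 4)

w_inter = (torch.min(q[..., 2], d[..., 2]) - torch.max(q[..., 0], d[..., 0])).clamp(min=0.0)
h_inter = (torch.min(q[..., 3], d[..., 3]) - torch.max(q[..., 1], d[..., 1])).clamp(min=0.0)
area_q = (q[..., 2] - q[..., 0]) * (q[..., 3] - q[..., 1])
area_d = (d[..., 2] - d[..., 0]) * (d[..., 3] - d[..., 1])
intersect = w_inter * h_inter
union = area_q + area_d - intersect
enclose = (torch.max(q[..., 2], d[..., 2]) - torch.min(q[..., 0], d[..., 0])) * \
          (torch.max(q[..., 3], d[..., 3]) - torch.min(q[..., 1], d[..., 1])) + 1e-7

affinity = (intersect + 1.0) / (union + 1.0) - (enclose - union) / enclose
print(affinity)   # high (~0.81) for the agreeing vote, negative for the unrelated one
```
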
/realtime_panoptic/models/rt_pano_net.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
3 | # We adapted some FCOS related functions from official repository.
4 | # https://github.com/tianzhi0549/FCOS
5 |
6 | import math
7 | from collections import OrderedDict
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | import apex
12 | from realtime_panoptic.layers.scale import Scale
13 | from realtime_panoptic.utils.bounding_box import BoxList
14 | from realtime_panoptic.models.backbones import ResNetWithModifiedFPN
15 | from realtime_panoptic.models.panoptic_from_dense_box import PanopticFromDenseBox
16 | class RTPanoNet(torch.nn.Module):
17 | """Real-Time Panoptic Network
18 | This module takes the input from a FPN backbone and conducts feature extraction
19 | through a panoptic head, which can be then fed into post processing for final panoptic
20 | results including semantic segmentation and instance segmentation.
21 | NOTE: Currently only the inference functionality is supported.
22 |
23 | Parameters
24 | ----------
25 | backbone: str
26 | backbone type.
27 |
28 | num_classes: int
29 | Number of total classes, including 'things' and 'stuff'.
30 |
31 | things_num_classes: int
32 | number of thing classes
33 |
34 | pre_nms_thresh: float
35 | Acceptance class probability threshold for bounding box candidates before NMS.
36 |
37 | pre_nms_top_n: int
38 | Maximum number of accepted bounding box candidates before NMS.
39 |
40 | nms_thresh: float
41 | NMS threshold.
42 |
43 | fpn_post_nms_top_n: int
44 |         Maximum number of detected objects per image.
45 |
46 | instance_id_range: list of int
47 |         [min_id, max_id] defines the range of ids in 1:num_classes that correspond to thing classes.
48 | """
49 |
50 | def __init__(
51 | self,
52 | backbone,
53 | num_classes,
54 | things_num_classes,
55 | pre_nms_thresh,
56 | pre_nms_top_n,
57 | nms_thresh,
58 | fpn_post_nms_top_n,
59 | instance_id_range
60 | ):
61 | super(RTPanoNet, self).__init__()
62 | # TODO: adapt more backbone.
63 | if backbone == 'R-50-FPN-RETINANET':
64 | self.backbone = ResNetWithModifiedFPN('resnet50')
65 | backbone_out_channels = 256
66 | fpn_strides = [8, 16, 32, 64, 128]
67 | num_fpn_levels = 5
68 | else:
69 | raise NotImplementedError("Backbone type: {} is not supported yet.".format(backbone))
70 | # Global panoptic head that extracts features from each FPN output feature map
71 | self.panoptic_head = PanopticHead(
72 | num_classes,
73 | things_num_classes,
74 | num_fpn_levels,
75 | fpn_strides,
76 | backbone_out_channels
77 | )
78 |
79 | # Parameters
80 | self.fpn_strides = fpn_strides
81 |
82 | # Use dense bounding boxes to reconstruct panoptic segmentation results.
83 | self.panoptic_from_dense_bounding_box = PanopticFromDenseBox(
84 | pre_nms_thresh=pre_nms_thresh,
85 | pre_nms_top_n=pre_nms_top_n,
86 | nms_thresh=nms_thresh,
87 | fpn_post_nms_top_n=fpn_post_nms_top_n,
88 | min_size=0,
89 | num_classes=num_classes,
90 | mask_thresh=0.4,
91 | instance_id_range=instance_id_range,
92 | is_training=False)
93 |
94 | def forward(self, images, detection_targets=None, segmentation_targets=None):
95 | """ Forward function.
96 |
97 | Parameters
98 | ----------
99 | images: torchvision.models.detection.ImageList
100 | Images for which we want to compute the predictions
101 |
102 | detection_targets: list of BoxList
103 | Ground-truth boxes present in the image
104 |
105 | segmentation_targets: List of torch.Tensor
106 | semantic segmentation target for each image in the batch.
107 |
108 | Returns
109 | -------
110 | panoptic_result: Dict
111 | 'instance_segmentation_result': list of BoxList
112 | The predicted boxes (including instance masks), one BoxList per image.
113 | 'semantic_segmentation_result': torch.Tensor
114 | semantic logits interpolated to input data size.
115 | NOTE: this might not be the original input image size due to paddings.
116 | losses: dict of torch.ScalarTensor
117 | the losses for the model during training. During testing, it is an empty dict.
118 | """
119 | features = self.backbone(torch.stack(images.tensors))
120 |
121 | locations = self.compute_locations(list(features.values()))
122 |
123 | semantic_logits, box_cls, box_regression, centerness, levelness_logits = self.panoptic_head(list(features.values()))
124 |
125 | # Get full size semantic logits.
126 | downsampled_level = images.tensors[0].shape[-1] // semantic_logits.shape[-1]
127 | interpolated_semantic_logits_padded = F.interpolate(semantic_logits, scale_factor=downsampled_level, mode='bilinear')
128 | interpolated_semantic_logits = interpolated_semantic_logits_padded[:,:,:images.tensors[0].shape[-2], :images.tensors[0].shape[-1]]
129 | # Calculate levelness locations.
130 | h, w = levelness_logits.size()[-2:]
131 | levelness_location = self.compute_locations_per_level(h, w, self.fpn_strides[0] // 2, levelness_logits.device)
132 | locations.append(levelness_location)
133 |
134 | # Reconstruct mask from dense bounding box and semantic predictions
135 | panoptic_result = OrderedDict()
136 | boxes = self.panoptic_from_dense_bounding_box.process(
137 | locations, box_cls, box_regression, centerness, levelness_logits, semantic_logits, images.image_sizes
138 | )
139 | panoptic_result["instance_segmentation_result"] = boxes
140 | panoptic_result["semantic_segmentation_result"] = interpolated_semantic_logits
141 | return panoptic_result, {}
142 |
143 | def compute_locations(self, features):
144 | """Compute corresponding pixel location for feature maps.
145 |
146 | Parameters
147 | ----------
148 | features: list of torch.Tensor
149 | List of feature maps.
150 |
151 | Returns
152 | -------
153 | locations: list of torch.Tensor
154 | List of pixel location corresponding to the list of features.
155 | """
156 | locations = []
157 | for level, feature in enumerate(features):
158 | h, w = feature.size()[-2:]
159 | locations_per_level = self.compute_locations_per_level(h, w, self.fpn_strides[level], feature.device)
160 | locations.append(locations_per_level)
161 | return locations
162 |
163 | def compute_locations_per_level(self, h, w, stride, device):
164 | """Compute corresponding pixel location for a feature map in pyramid space with certain stride.
165 |
166 | Parameters
167 | ----------
168 | h: int
169 | height of current feature map.
170 |
171 | w: int
172 | width of current feature map.
173 |
174 | stride: int
175 | stride level of current feature map with respect to original input.
176 |
177 | device: torch.device
178 | device to create return tensor.
179 |
180 | Returns
181 | -------
182 | locations: torch.Tensor
183 | pixel location map.
184 | """
185 | shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
186 | shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
187 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
188 | shift_x = shift_x.reshape(-1)
189 | shift_y = shift_y.reshape(-1)
190 | locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
191 | return locations
192 |
193 |
194 | class PanopticHead(torch.nn.Module):
195 | """Network module of Panoptic Head extracting features from FPN feature maps.
196 |
197 | Parameters
198 | ----------
199 | num_classes: int
200 | Number of total classes, including 'things' and 'stuff'.
201 |
202 | things_num_classes: int
203 | number of thing classes.
204 |
205 | num_fpn_levels: int
206 | Number of FPN levels.
207 |
208 | fpn_strides: list
209 | FPN strides at each FPN scale.
210 |
211 | in_channels: int
212 | Number of channels of the input features (output of FPN)
213 |
214 | norm_reg_targets: bool
215 | If true, train on normalized target.
216 |
217 | centerness_on_reg: bool
218 | If true, regress centerness on box tower of FCOS.
219 |
220 | fcos_num_convs: int
221 | number of convolution modules used in FCOS towers.
222 |
223 | fcos_norm: str
224 | Normalization layer type used in FCOS modules.
225 |
226 | prior_prob: float
227 | Initial probability for focal loss. See `https://arxiv.org/pdf/1708.02002.pdf` for more details.
228 | """
229 | def __init__(
230 | self,
231 | num_classes,
232 | things_num_classes,
233 | num_fpn_levels,
234 | fpn_strides,
235 | in_channels,
236 | norm_reg_targets=False,
237 | centerness_on_reg=True,
238 | fcos_num_convs=4,
239 | fcos_norm='GN',
240 | prior_prob=0.01,
241 | ):
242 | super(PanopticHead, self).__init__()
243 | self.fpn_strides = fpn_strides
244 | self.norm_reg_targets = norm_reg_targets
245 | self.centerness_on_reg = centerness_on_reg
246 |
247 | cls_tower = []
248 | bbox_tower = []
249 |
250 | mid_channels = in_channels // 2
251 |
252 | for i in range(fcos_num_convs):
253 | # Class branch
254 | if i == 0:
255 | cls_tower.append(nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1))
256 | else:
257 | cls_tower.append(nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=1, padding=1))
258 | if fcos_norm == "GN":
259 | cls_tower.append(nn.GroupNorm(mid_channels // 8, mid_channels))
260 | elif fcos_norm == "BN":
261 | cls_tower.append(nn.BatchNorm2d(mid_channels))
262 | elif fcos_norm == "SBN":
263 | cls_tower.append(apex.parallel.SyncBatchNorm(mid_channels))
264 | cls_tower.append(nn.ReLU())
265 |
266 | # Box regression branch
267 | bbox_tower.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
268 |
269 | if fcos_norm == "GN":
270 | bbox_tower.append(nn.GroupNorm(in_channels // 8, in_channels))
271 | elif fcos_norm == "BN":
272 | bbox_tower.append(nn.BatchNorm2d(in_channels))
273 | elif fcos_norm == "SBN":
274 | bbox_tower.append(apex.parallel.SyncBatchNorm(in_channels))
275 | bbox_tower.append(nn.ReLU())
276 |
277 | self.add_module('cls_tower', nn.Sequential(*cls_tower))
278 | self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
279 |
280 | self.cls_logits = nn.Conv2d(mid_channels * 5, num_classes, kernel_size=3, stride=1, padding=1)
281 | self.box_cls_logits = nn.Conv2d(mid_channels, things_num_classes, kernel_size=3, stride=1, padding=1)
282 | self.bbox_pred = nn.Conv2d(in_channels, 4, kernel_size=3, stride=1, padding=1)
283 | self.centerness = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1)
284 | self.levelness = nn.Conv2d(in_channels * 5, num_fpn_levels + 1, kernel_size=3, stride=1, padding=1)
285 |
286 | # initialization
287 | to_initialize = [
288 | self.bbox_tower, self.cls_logits, self.cls_tower, self.bbox_pred, self.centerness, self.levelness,
289 | self.box_cls_logits
290 | ]
291 |
292 | for modules in to_initialize:
293 | for l in modules.modules():
294 | if isinstance(l, nn.Conv2d):
295 | torch.nn.init.normal_(l.weight, std=0.01)
296 | torch.nn.init.constant_(l.bias, 0)
297 |
298 | # initialize the bias for focal loss
299 | prior_prob = prior_prob
300 | bias_value = -math.log((1 - prior_prob) / prior_prob)
301 | torch.nn.init.constant_(self.cls_logits.bias, bias_value)
302 |
303 | self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
304 |
305 |     def forward(self, x):
306 |         box_cls = []
307 |         logits = []
308 |         bbox_reg = []
309 |         centerness = []
310 |         levelness = []
311 | 
312 |         downsampled_shape = x[0].shape[2:]
313 |         box_feature_map_downsampled_shape = torch.Size((downsampled_shape[0] * 2, downsampled_shape[1] * 2))
314 | 
315 |         for l, feature in enumerate(x):
316 |             # bbox
317 |             box_tower = self.bbox_tower(feature)
318 | 
319 |             # class
320 |             cls_tower = self.cls_tower(feature)
321 |             box_cls.append(self.box_cls_logits(cls_tower))
322 |             logits.append(F.interpolate(cls_tower, size=downsampled_shape, mode='bilinear'))
323 | 
324 |             # centerness
325 |             if self.centerness_on_reg:
326 |                 centerness.append(self.centerness(box_tower))
327 |             else:
328 |                 centerness.append(self.centerness(cls_tower))
329 | 
330 |             # bbox regression
331 |             bbox_pred = self.scales[l](self.bbox_pred(box_tower))
332 |             if self.norm_reg_targets:
333 |                 bbox_pred = F.relu(bbox_pred)
334 |                 if self.training:
335 |                     bbox_reg.append(bbox_pred)
336 |                 else:
337 |                     bbox_reg.append(bbox_pred * self.fpn_strides[l])
338 |             else:
339 |                 bbox_reg.append(torch.exp(bbox_pred.clamp(max=math.log(10000))))
340 | 
341 |             # levelness prediction
342 |             levelness.append(F.interpolate(box_tower, size=box_feature_map_downsampled_shape, mode='bilinear'))
343 | 
344 |         # predict levelness
345 |         levelness = torch.cat(levelness, dim=1)
346 |         # levelness = torch.stack(levelness, dim=0).sum(dim=0)
347 |         levelness_logits = self.levelness(levelness)
348 | 
349 |         # level attention for semantic segmentation
350 |         logits = torch.cat(logits, 1)
351 |         semantic_logits = self.cls_logits(logits)
352 | 
353 |         return semantic_logits, box_cls, bbox_reg, centerness, levelness_logits
354 |
355 |
356 |
--------------------------------------------------------------------------------
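Note: as a quick sanity check of PanopticHead's input/output contract, the sketch below is illustrative only (it is not part of the repository, and the class count, channel count, FPN strides and feature resolutions are assumed values). It builds the head defined above and runs it on five dummy FPN feature maps, assuming the repository's dependencies are installed:

# Hypothetical smoke test for PanopticHead; all numeric values are assumptions.
import torch
from realtime_panoptic.models.rt_pano_net import PanopticHead

head = PanopticHead(
    num_classes=19,            # e.g. number of Cityscapes semantic classes (assumed)
    things_num_classes=8,      # e.g. number of Cityscapes instance classes (assumed)
    num_fpn_levels=5,
    fpn_strides=[8, 16, 32, 64, 128],
    in_channels=256,
)
# One dummy feature map per FPN level, halving the resolution at each level.
features = [torch.randn(1, 256, 64 // (2 ** i), 128 // (2 ** i)) for i in range(5)]
semantic_logits, box_cls, bbox_reg, centerness, levelness_logits = head(features)
print(semantic_logits.shape)   # (1, num_classes, H0, W0) at the finest FPN resolution
print(levelness_logits.shape)  # (1, num_fpn_levels + 1, 2 * H0, 2 * W0)

box_cls, bbox_reg and centerness come back as per-level lists, while the semantic and levelness logits are single dense tensors at (and at twice) the resolution of the finest FPN level.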
/realtime_panoptic/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
--------------------------------------------------------------------------------
/realtime_panoptic/utils/bounding_box.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Copyright 2020 Toyota Research Institute. All rights reserved.
3 |
4 | # Adapted from maskrcnn-benchmark
5 | # https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/structures/bounding_box.py
6 |
7 | import math
8 |
9 | import torch
10 |
11 | # transpose
12 | FLIP_LEFT_RIGHT = 0
13 | FLIP_TOP_BOTTOM = 1
14 | ROTATE_90 = 2
15 |
16 |
17 | class BoxList:
18 | """This class represents a set of bounding boxes.
19 |     The bounding boxes are represented as an Nx4 Tensor.
20 | In order to uniquely determine the bounding boxes with respect
21 | to an image, we also store the corresponding image dimensions.
22 | They can contain extra information that is specific to each bounding box, such as
23 | labels.
24 | """
25 |
26 | def __init__(self, bbox, image_size, mode="xyxy"):
27 | """Initial function.
28 |
29 | Parameters
30 | ----------
31 | bbox: tensor
32 | Nx4 tensor following bounding box parameterization defined by "mode".
33 |
34 | image_size: list
35 | [W,H] Image size.
36 |
37 |         mode: str
38 |             Bounding box parameterization. 'xyxy' or 'xywh'.
39 | """
40 | device = bbox.device if isinstance(
41 | bbox, torch.Tensor) else torch.device("cpu")
42 | bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
43 | if bbox.ndimension() != 2:
44 | raise ValueError("bbox should have 2 dimensions, got {}".format(
45 | bbox.ndimension()), bbox)
46 | if bbox.size(-1) != 4:
47 | raise ValueError("last dimension of bbox should have a "
48 | "size of 4, got {}".format(bbox.size(-1)))
49 | if mode not in ("xyxy", "xywh"):
50 | raise ValueError("mode should be 'xyxy' or 'xywh'")
51 |
52 | self.bbox = bbox
53 | self.size = image_size # (image_width, image_height)
54 | self.mode = mode
55 | self.extra_fields = {}
56 |
57 | def add_field(self, field, field_data):
58 | """Add a field to boxlist.
59 | """
60 | self.extra_fields[field] = field_data
61 |
62 | def get_field(self, field):
63 | """Get a field from boxlist.
64 | """
65 | return self.extra_fields[field]
66 |
67 | def has_field(self, field):
68 |         """Check if a certain field exists in the boxlist.
69 | """
70 | return field in self.extra_fields
71 |
72 | def fields(self):
73 | """Get all available field names.
74 | """
75 | return list(self.extra_fields.keys())
76 |
77 | def _copy_extra_fields(self, bbox):
78 | """Copy extra fields from given boxlist to current boxlist.
79 | """
80 | for k, v in bbox.extra_fields.items():
81 | self.extra_fields[k] = v
82 |
83 | def convert(self, mode):
84 | """Convert bounding box parameterization mode.
85 | """
86 | if mode not in ("xyxy", "xywh"):
87 | raise ValueError("mode should be 'xyxy' or 'xywh'")
88 | if mode == self.mode:
89 | return self
90 | # we only have two modes, so don't need to check
91 | # self.mode
92 | xmin, ymin, xmax, ymax = self._split_into_xyxy()
93 | if mode == "xyxy":
94 | bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
95 | bbox = BoxList(bbox, self.size, mode=mode)
96 | else:
97 | TO_REMOVE = 1
98 | bbox = torch.cat(
99 | (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE),
100 | dim=-1)
101 | bbox = BoxList(bbox, self.size, mode=mode)
102 | bbox._copy_extra_fields(self)
103 | return bbox
104 |
105 | def _split_into_xyxy(self):
106 |         """Split the bounding boxes into xmin, ymin, xmax, ymax coordinate tensors.
107 | """
108 | if self.mode == "xyxy":
109 | xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
110 | return xmin, ymin, xmax, ymax
111 | elif self.mode == "xywh":
112 | TO_REMOVE = 1
113 | xmin, ymin, w, h = self.bbox.split(1, dim=-1)
114 | return (
115 | xmin,
116 | ymin,
117 | xmin + (w - TO_REMOVE).clamp(min=0),
118 | ymin + (h - TO_REMOVE).clamp(min=0),
119 | )
120 | else:
121 | raise RuntimeError("Should not be here")
122 |
123 | def resize(self, size, *args, **kwargs):
124 | """Returns a resized copy of this bounding box.
125 |
126 | Parameters
127 | ----------
128 | size: list or tuple
129 | The requested image size in pixels, as a 2-tuple:
130 | (width, height).
131 | """
132 |
133 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
134 | if ratios[0] == ratios[1]:
135 | ratio = ratios[0]
136 | scaled_box = self.bbox * ratio
137 | bbox = BoxList(scaled_box, size, mode=self.mode)
138 | # bbox._copy_extra_fields(self)
139 | for k, v in self.extra_fields.items():
140 | if not isinstance(v, torch.Tensor):
141 | v = v.resize(size, *args, **kwargs)
142 | bbox.add_field(k, v)
143 | return bbox
144 |
145 | ratio_width, ratio_height = ratios
146 | xmin, ymin, xmax, ymax = self._split_into_xyxy()
147 | scaled_xmin = xmin * ratio_width
148 | scaled_xmax = xmax * ratio_width
149 | scaled_ymin = ymin * ratio_height
150 | scaled_ymax = ymax * ratio_height
151 | scaled_box = torch.cat(
152 | (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1)
153 | bbox = BoxList(scaled_box, size, mode="xyxy")
154 | # bbox._copy_extra_fields(self)
155 | for k, v in self.extra_fields.items():
156 | if not isinstance(v, torch.Tensor):
157 | v = v.resize(size, *args, **kwargs)
158 | bbox.add_field(k, v)
159 |
160 | return bbox.convert(self.mode)
161 |
162 | def transpose(self, method):
163 | """Transpose bounding box (flip or rotate in 90 degree steps)
164 |
165 | Parameters
166 | ----------
167 |         method: int
168 |             One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
169 |             :py:attr:`PIL.Image.FLIP_TOP_BOTTOM` or :py:attr:`PIL.Image.ROTATE_90`
170 |             (this module's FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM and ROTATE_90 constants
171 |             use the same values). Other PIL transpose methods are not supported.
172 | """
173 | if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM, ROTATE_90):
174 | raise NotImplementedError(
175 | "Only FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM and ROTATE_90 implemented"
176 | )
177 |
178 | image_width, image_height = self.size
179 | xmin, ymin, xmax, ymax = self._split_into_xyxy()
180 | if method == FLIP_LEFT_RIGHT:
181 | TO_REMOVE = 1
182 | transposed_xmin = image_width - xmax - TO_REMOVE
183 | transposed_xmax = image_width - xmin - TO_REMOVE
184 | transposed_ymin = ymin
185 | transposed_ymax = ymax
186 | elif method == FLIP_TOP_BOTTOM:
187 | transposed_xmin = xmin
188 | transposed_xmax = xmax
189 | transposed_ymin = image_height - ymax
190 | transposed_ymax = image_height - ymin
191 | elif method == ROTATE_90:
192 | transposed_xmin = ymin * image_width / image_height
193 | transposed_xmax = ymax * image_width / image_height
194 | transposed_ymin = (image_width - xmax) * image_height / image_width
195 | transposed_ymax = (image_width - xmin) * image_height / image_width
196 |
197 | transposed_boxes = torch.cat((transposed_xmin, transposed_ymin,
198 | transposed_xmax, transposed_ymax),
199 | dim=-1)
200 | bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
201 | # bbox._copy_extra_fields(self)
202 | for k, v in self.extra_fields.items():
203 | if not isinstance(v, torch.Tensor):
204 | v = v.transpose(method)
205 | bbox.add_field(k, v)
206 | return bbox.convert(self.mode)
207 |
208 | def translate(self, x_offset, y_offset):
209 | """Translate bounding box.
210 |
211 | Parameters
212 | ----------
213 |         x_offset: float
214 | x offset
215 | y_offset: float
216 | y offset
217 | """
218 | xmin, ymin, xmax, ymax = self._split_into_xyxy()
219 |
220 | translated_xmin = xmin + x_offset
221 | translated_xmax = xmax + x_offset
222 | translated_ymin = ymin + y_offset
223 | translated_ymax = ymax + y_offset
224 |
225 | translated_boxes = torch.cat((translated_xmin, translated_ymin,
226 | translated_xmax, translated_ymax),
227 | dim=-1)
228 | bbox = BoxList(translated_boxes, self.size, mode="xyxy")
229 | for k, v in self.extra_fields.items():
230 | if not isinstance(v, torch.Tensor):
231 | v = v.translate(x_offset, y_offset)
232 | bbox.add_field(k, v)
233 | return bbox.convert(self.mode)
234 |
235 | def crop(self, box):
236 | """Crop a rectangular region from this bounding box.
237 |
238 | Parameters
239 | ----------
240 | box: tuple
241 | The box is a 4-tuple defining the left, upper, right, and lower pixel
242 | coordinate.
243 | """
244 | xmin, ymin, xmax, ymax = self._split_into_xyxy()
245 | w, h = box[2] - box[0], box[3] - box[1]
246 | cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
247 | cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
248 | cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
249 | cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)
250 |
251 | # TODO should I filter empty boxes here?
252 | if False:
253 | is_empty = (cropped_xmin == cropped_xmax) | (
254 | cropped_ymin == cropped_ymax)
255 |
256 | cropped_box = torch.cat(
257 | (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1)
258 | bbox = BoxList(cropped_box, (w, h), mode="xyxy")
259 | # bbox._copy_extra_fields(self)
260 | for k, v in self.extra_fields.items():
261 | if not isinstance(v, torch.Tensor):
262 | v = v.crop(box)
263 | bbox.add_field(k, v)
264 | return bbox.convert(self.mode)
265 |
266 | def augmentation_crop(self, top, left, crop_height, crop_width):
267 | """Random cropping of the bounding box (bbox).
268 |         This function is intended for cropping label boxes at training time.
269 |
270 | Parameters:
271 | -----------
272 | top: int
273 | Top pixel position of crop area
274 |
275 | left: int
276 | left pixel position of crop area
277 |
278 | crop_height: int
279 | Height of crop area
280 |
281 | crop_width: int
282 | Width of crop area
283 |
284 | Returns:
285 | --------
286 | bbox_cropped: BoxList
287 | A BoxList object with instances after cropping. If no valid instance is left after
288 | cropping, return None.
289 |
290 |
291 | """
292 | # SegmentationMasks object
293 | masks = self.extra_fields["masks"]
294 |
295 | # Conduct mask level cropping and return only the valid ones left.
296 | masks_cropped, keep_ids = masks.augmentation_crop(
297 | top, left, crop_height, crop_width)
298 |
299 | # the return cropped mask should be in "poly" mode
300 | if not keep_ids:
301 | return None
302 | assert masks_cropped.mode == "poly"
303 | bbox_cropped = []
304 | labels = self.extra_fields["labels"]
305 | labels_cropped = [labels[idx] for idx in keep_ids]
306 | labels_cropped = torch.as_tensor(labels_cropped, dtype=torch.long)
307 |
308 | crop_box_xyxy = [
309 | float(left),
310 | float(top),
311 | float(left + crop_width),
312 | float(top + crop_height)
313 | ]
314 | # Crop bounding box.
315 | # Note: this function will not change "masks"
316 | self.extra_fields.pop("masks", None)
317 | new_bbox = self.crop(crop_box_xyxy).convert("xyxy")
318 |
319 | # Further clip the boxes according to the clipped masks.
320 | for mask_id, box_id in enumerate(keep_ids):
321 | x1, y1, x2, y2 = new_bbox.bbox[box_id].numpy()
322 |
323 | # only resize the box on the edge:
324 | if x1 > 0 and y1 > 0 and x2 < crop_width - 1 and y2 < crop_height - 1:
325 | bbox_cropped.append([x1, y1, x2, y2])
326 | else:
327 | # get PolygonInstance for current instance
328 | current_polygon_instance = masks_cropped.instances.polygons[
329 | mask_id]
330 | x_ids = []
331 | y_ids = []
332 | for poly in current_polygon_instance.polygons:
333 | p = poly.clone()
334 | x_ids.extend(p[0::2])
335 | y_ids.extend(p[1::2])
336 | bbox_cropped.append(
337 | [min(x_ids),
338 | min(y_ids),
339 | max(x_ids),
340 | max(y_ids)])
341 | bbox_cropped = BoxList(
342 | bbox_cropped, (crop_width, crop_height), mode="xyxy")
343 | bbox_cropped = bbox_cropped.convert(self.mode)
344 | bbox_cropped.add_field("masks", masks_cropped)
345 | bbox_cropped.add_field("labels", labels_cropped)
346 | return bbox_cropped
347 |
348 | def to(self, device):
349 | """Move object to torch device.
350 | """
351 | bbox = BoxList(self.bbox.to(device), self.size, self.mode)
352 | for k, v in self.extra_fields.items():
353 | if hasattr(v, "to"):
354 | v = v.to(device)
355 | bbox.add_field(k, v)
356 | return bbox
357 |
358 | def __getitem__(self, item):
359 |         """Get a sub-list of the BoxList as a new BoxList.
360 |         """
361 |         item_bbox = self.bbox[item]
362 |         if len(item_bbox.shape) < 2:
363 |             item_bbox = item_bbox.unsqueeze(0)
364 | bbox = BoxList(item_bbox, self.size, self.mode)
365 | for k, v in self.extra_fields.items():
366 | bbox.add_field(k, v[item])
367 | return bbox
368 |
369 | def __len__(self):
370 | return self.bbox.shape[0]
371 |
372 | def clip_to_image(self, remove_empty=True):
373 | """Clip bounding box coordinates according to the image range.
374 | """
375 | TO_REMOVE = 1
376 | self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE)
377 | self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE)
378 | self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE)
379 | self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE)
380 | if remove_empty:
381 | box = self.bbox
382 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
383 | return self[keep]
384 | return self
385 |
386 | def area(self, idx=None):
387 | """Get bounding box area.
388 | """
389 | box = self.bbox if idx is None else self.bbox[idx].unsqueeze(0)
390 | if self.mode == "xyxy":
391 | TO_REMOVE = 1
392 | area = (box[:, 2] - box[:, 0] + TO_REMOVE) * \
393 | (box[:, 3] - box[:, 1] + TO_REMOVE)
394 | elif self.mode == "xywh":
395 | area = box[:, 2] * box[:, 3]
396 | else:
397 | raise RuntimeError("Should not be here")
398 |
399 | return area
400 |
401 | def copy_with_fields(self, fields, skip_missing=False):
402 |         """Provide a copy of this BoxList containing only the requested fields.
403 | """
404 | bbox = BoxList(self.bbox, self.size, self.mode)
405 | if not isinstance(fields, (list, tuple)):
406 | fields = [fields]
407 | for field in fields:
408 | if self.has_field(field):
409 | bbox.add_field(field, self.get_field(field))
410 | elif not skip_missing:
411 | raise KeyError("Field '{}' not found in {}".format(
412 | field, self))
413 | return bbox
414 |
415 | def __repr__(self):
416 | s = self.__class__.__name__ + "("
417 | s += "num_boxes={}, ".format(len(self))
418 | s += "image_width={}, ".format(self.size[0])
419 | s += "image_height={}, ".format(self.size[1])
420 | s += "mode={})".format(self.mode)
421 | return s
422 |
423 |
424 | if __name__ == "__main__":
425 | bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
426 | s_bbox = bbox.resize((5, 5))
427 | print(s_bbox)
428 | print(s_bbox.bbox)
429 |
430 | t_bbox = bbox.transpose(0)
431 | print(t_bbox)
432 | print(t_bbox.bbox)
433 |
--------------------------------------------------------------------------------
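Note: complementing the small __main__ example above, the following hypothetical snippet (not part of the repository; the box coordinates, labels and scores are made up) shows how extra fields travel with the boxes through conversion, indexing and clipping:

# Hypothetical usage sketch for BoxList with extra fields.
import torch
from realtime_panoptic.utils.bounding_box import BoxList

boxes = torch.tensor([[10., 10., 50., 80.], [30., 5., 120., 60.]])
boxlist = BoxList(boxes, image_size=(100, 90), mode="xyxy")
boxlist.add_field("labels", torch.tensor([1, 3]))
boxlist.add_field("scores", torch.tensor([0.9, 0.4]))

# Convert the parameterization; extra fields are carried over.
print(boxlist.convert("xywh").bbox)

# Indexing with a boolean mask keeps the matching field entries.
keep = boxlist.get_field("scores") > 0.5
kept = boxlist[keep]
print(len(kept), kept.get_field("labels"))

# Clip to the image extent and drop degenerate boxes.
clipped = boxlist.clip_to_image(remove_empty=True)
print(clipped.area())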
/realtime_panoptic/utils/boxlist_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Copyright 2020 Toyota Research Institute. All rights reserved.
3 |
4 | # Adapted from maskrcnn-benchmark
5 | # https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/structures/boxlist_ops.py
6 | import torch
7 | from torchvision.ops.boxes import nms as _box_nms
8 |
9 | from .bounding_box import BoxList
10 |
11 |
12 | def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
13 | """Performs non-maximum suppression on a boxlist.
14 | The ranking scores are specified in a boxlist field via score_field.
15 |
16 | Parameters
17 | ----------
18 | boxlist : BoxList
19 | Original boxlist
20 |
21 | nms_thresh : float
22 | NMS threshold
23 |
24 | max_proposals : int
25 | If > 0, then only the top max_proposals are kept after non-maximum suppression
26 |
27 | score_field : str
28 | Boxlist field to use during NMS score ranking. Field value needs to be numeric.
29 | """
30 | if nms_thresh <= 0:
31 | return boxlist
32 | mode = boxlist.mode
33 | boxlist = boxlist.convert("xyxy")
34 | boxes = boxlist.bbox
35 | score = boxlist.get_field(score_field)
36 | keep = _box_nms(boxes, score, nms_thresh)
37 | if max_proposals > 0:
38 | keep = keep[:max_proposals]
39 | boxlist = boxlist[keep]
40 | return boxlist.convert(mode)
41 |
42 |
43 | def remove_small_boxes(boxlist, min_size):
44 | """Only keep boxes with both sides >= min_size
45 |
46 | Parameters
47 | ----------
48 | boxlist : Boxlist
49 | Original boxlist
50 |
51 | min_size : int
52 |         Minimum edge dimension of the boxes to be kept.
53 | """
54 | # TODO maybe add an API for querying the ws / hs
55 | xywh_boxes = boxlist.convert("xywh").bbox
56 | _, _, ws, hs = xywh_boxes.unbind(dim=1)
57 | keep = ((ws >= min_size) & (hs >= min_size)).nonzero().squeeze(1)
58 | return boxlist[keep]
59 |
60 |
61 | # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
62 | # with slight modifications
63 | def boxlist_iou(boxlist1, boxlist2, optimize_memory=False):
64 |     """Compute the intersection over union of two sets of boxes.
65 | The box order must be (xmin, ymin, xmax, ymax).
66 |
67 | Parameters
68 | ----------
69 |     boxlist1: BoxList
70 |         Bounding boxes, sized [N,4].
71 |     boxlist2: BoxList
72 |         Bounding boxes, sized [M,4].
73 |     optimize_memory: bool
74 |         If True, build the IoU matrix one row of boxlist1 at a time to reduce peak memory usage.
75 | Returns
76 | -------
77 | iou : tensor
78 | IoU of input boxes in matrix form, sized [N,M].
79 |
80 | Reference:
81 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
82 | """
83 | if boxlist1.size != boxlist2.size:
84 | raise RuntimeError("boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
85 |
86 | N = len(boxlist1)
87 | M = len(boxlist2)
88 |
89 | area2 = boxlist2.area()
90 |
91 | if not optimize_memory:
92 |
93 | # If not optimizing memory, then following original ``maskrcnn-benchmark`` implementation
94 |
95 | area1 = boxlist1.area()
96 |
97 | box1, box2 = boxlist1.bbox, boxlist2.bbox
98 |
99 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # shape: (N, M, 2)
100 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # shape: (N, M, 2)
101 |
102 | TO_REMOVE = 1
103 |
104 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (N, M, 2)
105 | inter = wh[:, :, 0] * wh[:, :, 1] # shape: (N, M)
106 |
107 | iou = inter / (area1[:, None] + area2 - inter)
108 |
109 | else:
110 |
111 | # If optimizing memory, construct IoU matrix one box1 entry at a time
112 | # (in current usage this means one GT at a time)
113 |
114 | # Entry i of ious will hold the IoU between the ith box in boxlist1 and all boxes
115 | # in boxlist2
116 | ious = []
117 |
118 | box2 = boxlist2.bbox
119 |
120 | for i in range(N):
121 | area1 = boxlist1.area(i)
122 |
123 | box1 = boxlist1.bbox[i].unsqueeze(0)
124 |
125 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # shape: (1, M, 2)
126 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # shape: (1, M, 2)
127 |
128 | TO_REMOVE = 1
129 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (1, M, 2)
130 |
131 | inter = wh[:, :, 0] * wh[:, :, 1] # shape: (1, M)
132 |
133 | iou = inter / (area1 + area2 - inter)
134 |
135 | ious.append(iou)
136 |
137 | iou = torch.cat(ious) # shape: (N, M)
138 |
139 | return iou
140 |
141 |
142 | def cat_boxlist(bboxes):
143 |     """Concatenate a list of BoxList objects (sharing the same image size)
144 |     into a single BoxList.
145 |
146 | Parameters
147 | ----------
148 | bboxes : list[BoxList]
149 | """
150 | assert isinstance(bboxes, (list, tuple))
151 | assert all(isinstance(bbox, BoxList) for bbox in bboxes)
152 |
153 | size = bboxes[0].size
154 | assert all(bbox.size == size for bbox in bboxes)
155 |
156 | mode = bboxes[0].mode
157 | assert all(bbox.mode == mode for bbox in bboxes)
158 |
159 | fields = set(bboxes[0].fields())
160 | assert all(set(bbox.fields()) == fields for bbox in bboxes)
161 |
162 | cat_boxes = BoxList(torch.cat([bbox.bbox for bbox in bboxes], dim=0), size, mode)
163 |
164 | for field in fields:
165 | data = torch.cat([bbox.get_field(field) for bbox in bboxes], dim=0)
166 | cat_boxes.add_field(field, data)
167 |
168 | return cat_boxes
169 |
170 |
171 | def pair_boxlist_iou(boxlist1, boxlist2):
172 |     """Compute the intersection over union between corresponding pairs of boxes.
173 |     The box order must be (xmin, ymin, xmax, ymax).
174 |
175 | Parameters
176 | ----------
177 |     boxlist1 : BoxList
178 |         Bounding boxes, sized [N,4].
179 |     boxlist2 : BoxList
180 |         Bounding boxes, sized [N,4].
181 |
182 | Returns
183 | -------
184 |     iou : tensor
185 |         Tensor of IoU values between the corresponding pairs of boxes, sized [N].
186 |
187 | Reference:
188 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
189 | """
190 | if boxlist1.size != boxlist2.size:
191 | raise RuntimeError("boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
192 |
193 | assert len(boxlist1) == len(boxlist2), "Two boxlists should have same length"
194 | N = len(boxlist1)
195 |
196 | area2 = boxlist2.area()
197 | area1 = boxlist1.area()
198 |
199 | box1, box2 = boxlist1.bbox, boxlist2.bbox
200 | lt = torch.max(box1[:, :2], box2[:, :2]) # shape: (N, 2)
201 | rb = torch.min(box1[:, 2:], box2[:, 2:]) # shape: (N, 2)
202 | TO_REMOVE = 1
203 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # shape: (N, 2)
204 |     inter = wh[:, 0] * wh[:, 1] # shape: (N,)
205 | iou = inter / (area1 + area2 - inter)
206 | return iou
207 |
--------------------------------------------------------------------------------
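Note: a hypothetical usage sketch for these ops (not part of the repository; the image size, box coordinates and scores are made up), combining BoxList construction, IoU computation, concatenation and NMS:

# Hypothetical sketch combining the boxlist ops above.
import torch
from realtime_panoptic.utils.bounding_box import BoxList
from realtime_panoptic.utils.boxlist_ops import boxlist_nms, boxlist_iou, cat_boxlist

size = (100, 100)
a = BoxList(torch.tensor([[10., 10., 60., 60.], [12., 12., 62., 62.]]), size)
a.add_field("scores", torch.tensor([0.9, 0.8]))
b = BoxList(torch.tensor([[50., 50., 90., 90.]]), size)
b.add_field("scores", torch.tensor([0.7]))

# IoU matrix between the two sets, shape (2, 1).
print(boxlist_iou(a, b))

# Concatenate (same image size, mode and fields required), then suppress near-duplicates.
merged = cat_boxlist([a, b])
kept = boxlist_nms(merged, nms_thresh=0.5, score_field="scores")
print(len(kept))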
/realtime_panoptic/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | DETECTRON_PALETTE = np.array(
8 | [
9 | 0.000, 0.447, 0.741,
10 | 0.850, 0.325, 0.098,
11 | 0.929, 0.694, 0.125,
12 | 0.494, 0.184, 0.556,
13 | 0.466, 0.674, 0.188,
14 | 0.301, 0.745, 0.933,
15 | 0.635, 0.078, 0.184,
16 | 0.300, 0.300, 0.300,
17 | 0.600, 0.600, 0.600,
18 | 1.000, 0.000, 0.000,
19 | 1.000, 0.500, 0.000,
20 | 0.749, 0.749, 0.000,
21 | 0.000, 1.000, 0.000,
22 | 0.000, 0.000, 1.000,
23 | 0.667, 0.000, 1.000,
24 | 0.333, 0.333, 0.000,
25 | 0.333, 0.667, 0.000,
26 | 0.333, 1.000, 0.000,
27 | 0.667, 0.333, 0.000,
28 | 0.667, 0.667, 0.000,
29 | 0.667, 1.000, 0.000,
30 | 1.000, 0.333, 0.000,
31 | 1.000, 0.667, 0.000,
32 | 1.000, 1.000, 0.000,
33 | 0.000, 0.333, 0.500,
34 | 0.000, 0.667, 0.500,
35 | 0.000, 1.000, 0.500,
36 | 0.333, 0.000, 0.500,
37 | 0.333, 0.333, 0.500,
38 | 0.333, 0.667, 0.500,
39 | 0.333, 1.000, 0.500,
40 | 0.667, 0.000, 0.500,
41 | 0.667, 0.333, 0.500,
42 | 0.667, 0.667, 0.500,
43 | 0.667, 1.000, 0.500,
44 | 1.000, 0.000, 0.500,
45 | 1.000, 0.333, 0.500,
46 | 1.000, 0.667, 0.500,
47 | 1.000, 1.000, 0.500,
48 | 0.000, 0.333, 1.000,
49 | 0.000, 0.667, 1.000,
50 | 0.000, 1.000, 1.000,
51 | 0.333, 0.000, 1.000,
52 | 0.333, 0.333, 1.000,
53 | 0.333, 0.667, 1.000,
54 | 0.333, 1.000, 1.000,
55 | 0.667, 0.000, 1.000,
56 | 0.667, 0.333, 1.000,
57 | 0.667, 0.667, 1.000,
58 | 0.667, 1.000, 1.000,
59 | 1.000, 0.000, 1.000,
60 | 1.000, 0.333, 1.000,
61 | 1.000, 0.667, 1.000,
62 | 0.167, 0.000, 0.000,
63 | 0.333, 0.000, 0.000,
64 | 0.500, 0.000, 0.000,
65 | 0.667, 0.000, 0.000,
66 | 0.833, 0.000, 0.000,
67 | 1.000, 0.000, 0.000,
68 | 0.000, 0.167, 0.000,
69 | 0.000, 0.333, 0.000,
70 | 0.000, 0.500, 0.000,
71 | 0.000, 0.667, 0.000,
72 | 0.000, 0.833, 0.000,
73 | 0.000, 1.000, 0.000,
74 | 0.000, 0.000, 0.167,
75 | 0.000, 0.000, 0.333,
76 | 0.000, 0.000, 0.500,
77 | 0.000, 0.000, 0.667,
78 | 0.000, 0.000, 0.833,
79 | 0.000, 0.000, 1.000,
80 | 0.000, 0.000, 0.000,
81 | 0.143, 0.143, 0.143,
82 | 0.286, 0.286, 0.286,
83 | 0.429, 0.429, 0.429,
84 | 0.571, 0.571, 0.571,
85 | 0.714, 0.714, 0.714,
86 | 0.857, 0.857, 0.857,
87 | 1.000, 1.000, 1.000
88 | ]
89 | ).astype(np.float32).reshape(-1, 3) * 255
90 |
91 |
92 | def visualize_segmentation_image(predictions, original_image, colormap, fade_weight=0.5):
93 | """Log a single segmentation result for visualization using a colormap.
94 |
95 | Overlays predicted classes on top of raw RGB image if given.
96 |
97 | Parameters:
98 | -----------
99 |     predictions: torch.cuda.LongTensor
100 |         Per-pixel predicted class IDs for a single input image.
101 |         Shape: (H, W)
102 | 
103 |     original_image: np.array
104 |         HxWx3 original image, or None.
105 | 
106 |     colormap: np.array
107 |         (N+1)x3 colormap array, where N+1 equals the number of classes.
108 | 
109 |     fade_weight: float, default: 0.5
110 |         Visualization is fade_weight * original_image + (1 - fade_weight) * predictions
111 | 
112 |     Returns:
113 |     --------
114 |     visualized_image: np.array
115 |         Semantic segmentation visualization color coded by classes.
116 |         The visualization is overlaid on the RGB image if given.
117 | """
118 |
119 | # ``original_image`` has shape (H, W,3)
120 | if not isinstance(original_image, np.ndarray):
121 | original_image = np.array(original_image)
122 | original_image_height, original_image_width,_ = original_image.shape
123 |
124 | # Grab colormap from dataset for the given number of segmentation classes
125 | # (uses black for the IGNORE class)
126 | # ``colormap`` has shape (num_classes + 1, 3)
127 |
128 | # Color per-pixel predictions using the generated color map
129 | # ``colored_predictions_numpy`` has shape (H, W, 3)
130 | predictions_numpy = predictions.cpu().numpy().astype('uint8')
131 | colored_predictions_numpy = colormap[predictions_numpy.flatten()]
132 | colored_predictions_numpy = colored_predictions_numpy.reshape(original_image_height, original_image_width, 3)
133 |
134 | # Overlay images and predictions
135 | overlaid_predictions = original_image * fade_weight + colored_predictions_numpy * (1 - fade_weight)
136 |
137 | visualized_image = overlaid_predictions.astype('uint8')
138 | return visualized_image
139 |
140 | def random_color(base, max_dist=30):
141 | """Generate random color close to a given base color.
142 |
143 | Parameters:
144 | -----------
145 | base: array_like
146 | Base color for random color generation
147 |
148 | max_dist: int
149 |         Maximum distance from the generated color to the base color along each RGB axis.
150 |
151 | Returns:
152 | --------
153 | random_color: tuple
154 | 3 channel random color around the given base color.
155 | """
156 | base = np.array(base)
157 | new_color = base + np.random.randint(low=-max_dist, high=max_dist + 1, size=3)
158 | return tuple(np.maximum(0, np.minimum(255, new_color)))
159 |
160 | def draw_mask(im, mask, alpha=0.5, color=None):
161 | """Overlay a mask on top of the image.
162 |
163 | Parameters:
164 | -----------
165 | im: array_like
166 | A 3-channel uint8 image
167 |
168 | mask: array_like
169 | A binary 1-channel image of the same size
170 |
171 |     alpha: float
172 |         Blending weight of the mask color in the overlay.
173 | 
174 |     color: array_like, optional
175 |         Color used for the mask. If None, a random color from the Detectron palette is chosen.
176 |
177 | Returns:
178 | --------
179 | im: np.array
180 | Image overlaid by masks.
181 |
182 | color: list
183 | Color used for masks.
184 | """
185 | if color is None:
186 | color = DETECTRON_PALETTE[np.random.choice(len(DETECTRON_PALETTE))][::-1]
187 | color = np.asarray(color, dtype=np.int64)
188 | im = np.where(np.repeat((mask > 0)[:, :, None], 3, axis=2), im * (1 - alpha) + color * alpha, im)
189 | im = im.astype('uint8')
190 | return im, color.tolist()
191 |
192 | def visualize_detection_image(predictions, original_image, label_id_to_names, fade_weight=0.8):
193 | """Log a single detection result for visualization.
194 |
195 | Overlays predicted classes on top of raw RGB image.
196 |
197 | Parameters:
198 | -----------
199 |     predictions: BoxList
200 |         Predicted instances for a single input image, carrying "labels", "scores"
201 |         and optionally "mask" fields.
202 |
203 | original_image: np.array
204 | HxWx3 original image. or None
205 |
206 | label_id_to_names: list
207 | list of class names for instance labels
208 |
209 | fade_weight: float, default: 0.8
210 | Visualization is fade_weight * original_image + (1 - fade_weight) * predictions
211 |
212 | Returns:
213 | -------
214 | visualized_image: np.array
215 | Visualized image with detection results.
216 | """
217 |
218 |     # Convert the input image to a numpy array if needed.
219 |     # ``original_image`` has shape (H, W, 3)
220 |     # Boxes, masks and class labels are drawn onto a copy of it below.
221 | if not isinstance(original_image, np.ndarray):
222 | original_image = np.array(original_image)
223 | original_image_height, original_image_width,_ = original_image.shape
224 |
225 | # overlay_boxes
226 | visualized_image = copy.copy(np.array(original_image))
227 |
228 | labels = predictions.get_field("labels").to("cpu")
229 | boxes = predictions.bbox
230 |
231 | dtype = labels.dtype
232 | palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]).to(dtype)
233 | colors = labels[:, None] * palette
234 | colors = (colors % 255).numpy().astype("uint8")
235 | masks = None
236 | if predictions.has_field("mask"):
237 | masks = predictions.get_field("mask")
238 | else:
239 | masks = [None] * len(boxes)
240 | # overlay_class_names_and_score
241 | if predictions.has_field("scores"):
242 | scores = predictions.get_field("scores").tolist()
243 | else:
244 | scores = [1.0] * len(boxes)
245 | # predicted label starts from 1 as 0 is reserved for background.
246 | label_names = [label_id_to_names[i-1] for i in labels.tolist()]
247 |
248 | text_template = "{}: {:.2f}"
249 |
250 | for box, color, score, mask, label in zip(boxes, colors, scores, masks, label_names):
251 | if score < 0.5:
252 | continue
253 | box = box.to(torch.int64)
254 | color = random_color(color)
255 | color = tuple(map(int, color))
256 |
257 | if mask is not None:
258 | thresh = (mask > 0.5).cpu().numpy().astype('uint8')
259 | visualized_image, color = draw_mask(visualized_image, thresh)
260 |
261 | x, y = box[:2]
262 | s = text_template.format(label, score)
263 |         cv2.putText(visualized_image, s, (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1)
264 |
265 | top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
266 | visualized_image = cv2.rectangle(visualized_image, tuple(top_left), tuple(bottom_right), tuple(color), 1)
267 | return visualized_image
--------------------------------------------------------------------------------
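Note: a minimal, hypothetical sketch of visualize_segmentation_image on dummy data (not part of the repository; the image size, class count and random colormap are assumptions, and in the real pipeline the prediction comes from the model's semantic logits):

# Hypothetical sketch: color a dummy per-pixel prediction and overlay it on a blank image.
import numpy as np
import torch
from PIL import Image
from realtime_panoptic.utils.visualization import visualize_segmentation_image

height, width, num_classes = 120, 160, 19
original = np.zeros((height, width, 3), dtype=np.uint8)
# Fake per-pixel class IDs; in practice this is the argmax over the semantic logits.
prediction = torch.randint(0, num_classes, (height, width))
# Any (num_classes + 1) x 3 colormap works; random colors are used here for brevity.
colormap = np.random.randint(0, 255, size=(num_classes + 1, 3))
overlay = visualize_segmentation_image(prediction, original, colormap, fade_weight=0.5)
Image.fromarray(overlay).save("dummy_semantic_overlay.png")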
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Toyota Research Institute. All rights reserved.
2 |
3 | # This script provides a demo of inference with a model trained on the Cityscapes dataset.
4 | import warnings
5 | import argparse
6 | import torch
7 | import numpy as np
8 | from PIL import Image
9 | from torchvision.models.detection.image_list import ImageList
10 |
11 | from realtime_panoptic.models.rt_pano_net import RTPanoNet
12 | from realtime_panoptic.config import cfg
13 | import realtime_panoptic.data.panoptic_transform as P
14 | from realtime_panoptic.utils.visualization import visualize_segmentation_image,visualize_detection_image
15 |
16 | cityscapes_colormap = np.array([
17 | [128, 64, 128],
18 | [244, 35, 232],
19 | [ 70, 70, 70],
20 | [102, 102, 156],
21 | [190, 153, 153],
22 | [153, 153, 153],
23 | [250 ,170, 30],
24 | [220, 220, 0],
25 | [107, 142, 35],
26 | [152, 251, 152],
27 | [ 70, 130, 180],
28 | [220, 20, 60],
29 | [255, 0, 0],
30 | [ 0, 0, 142],
31 | [ 0, 0, 70],
32 | [ 0, 60, 100],
33 | [ 0, 80, 100],
34 | [ 0, 0, 230],
35 | [119, 11, 32],
36 | [ 0, 0, 0]])
37 |
38 | cityscapes_instance_label_name = ['person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']
39 | warnings.filterwarnings("ignore", category=UserWarning)
40 |
41 | def demo():
42 | # Parse the input arguments.
43 | parser = argparse.ArgumentParser(description="Simple demo for real-time-panoptic model")
44 | parser.add_argument("--config-file", metavar="FILE", help="path to config", required=True)
45 | parser.add_argument("--pretrained-weight", metavar="FILE", help="path to pretrained_weight", required=True)
46 | parser.add_argument("--input", metavar="FILE", help="path to jpg/png file", required=True)
47 | parser.add_argument("--device", help="inference device", default='cuda')
48 | args = parser.parse_args()
49 |
50 | # General config object from given config files.
51 | cfg.merge_from_file(args.config_file)
52 |
53 | # Initialize model.
54 | model = RTPanoNet(
55 | backbone=cfg.model.backbone,
56 | num_classes=cfg.model.panoptic.num_classes,
57 | things_num_classes=cfg.model.panoptic.num_thing_classes,
58 | pre_nms_thresh=cfg.model.panoptic.pre_nms_thresh,
59 | pre_nms_top_n=cfg.model.panoptic.pre_nms_top_n,
60 | nms_thresh=cfg.model.panoptic.nms_thresh,
61 | fpn_post_nms_top_n=cfg.model.panoptic.fpn_post_nms_top_n,
62 | instance_id_range=cfg.model.panoptic.instance_id_range)
63 | device = args.device
64 | model.to(device)
65 |     model.load_state_dict(torch.load(args.pretrained_weight, map_location=device))
66 | 
67 |     # Print out model architecture for sanity checking.
68 | print(model)
69 |
70 | # Prepare for model inference.
71 | model.eval()
72 | input_image = Image.open(args.input)
73 | data = {'image': input_image}
74 | # data pre-processing
75 | normalize_transform = P.Normalize(mean=cfg.input.pixel_mean, std=cfg.input.pixel_std, to_bgr255=cfg.input.to_bgr255)
76 | transform = P.Compose([
77 | P.ToTensor(),
78 | normalize_transform,
79 | ])
80 | data = transform(data)
81 | print("Done with data preparation and model configuration.")
82 | with torch.no_grad():
83 | input_image_list = ImageList([data['image'].to(device)], image_sizes=[input_image.size[::-1]])
84 | panoptic_result, _ = model.forward(input_image_list)
85 | print("Done with model inference.")
86 | print("Process and visualizing the outputs...")
87 | instance_detection = [o.to('cpu') for o in panoptic_result["instance_segmentation_result"]]
88 |         semseg_logits = [o.to('cpu') for o in panoptic_result["semantic_segmentation_result"]]
89 |         semseg_prob = [torch.argmax(semantic_logit, dim=0) for semantic_logit in semseg_logits]
90 |
91 | seg_vis = visualize_segmentation_image(semseg_prob[0], input_image, cityscapes_colormap)
92 | Image.fromarray(seg_vis.astype('uint8')).save('semantic_segmentation_result.jpg')
93 | print("Saved semantic segmentation visualization in semantic_segmentation_result.jpg")
94 | det_vis = visualize_detection_image(instance_detection[0], input_image, cityscapes_instance_label_name)
95 | Image.fromarray(det_vis.astype('uint8')).save('instance_segmentation_result.jpg')
96 | print("Saved instance segmentation visualization in instance_segmentation_result.jpg")
97 | print("Demo finished.")
98 |
99 | if __name__ == "__main__":
100 | demo()
101 |
--------------------------------------------------------------------------------
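Note: a hypothetical extension of the demo above for batched inference over a folder of images. It is not part of the repository; it assumes ``model``, ``transform`` and ``device`` have already been set up exactly as in scripts/demo.py, and the folder path is a placeholder:

# Hypothetical batched-inference sketch reusing the demo's preprocessing.
import glob
import torch
from PIL import Image
from torchvision.models.detection.image_list import ImageList

paths = sorted(glob.glob("my_images/*.png"))  # placeholder input folder
images = [Image.open(p).convert("RGB") for p in paths]
tensors = [transform({'image': im})['image'].to(device) for im in images]
with torch.no_grad():
    batch = ImageList(tensors, image_sizes=[im.size[::-1] for im in images])
    panoptic_result, _ = model.forward(batch)
# panoptic_result["semantic_segmentation_result"] and
# panoptic_result["instance_segmentation_result"] then hold one entry per input image,
# which can be post-processed the same way as in the single-image demo.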