├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── cityscapes.py ├── kitti.py └── transforms │ ├── __init__.py │ └── transforms.py ├── loss ├── __init__.py ├── boxloss.py ├── focalloss.py └── multitaskloss.py ├── metrics ├── iou.py └── mean_ap.py ├── models ├── __init__.py ├── box2pix.py └── multibox.py ├── prediction ├── __init__.py └── predictor.py ├── requirements.txt ├── test.py ├── train.py └── utils ├── __init__.py ├── box_coder.py ├── box_utils.py ├── helper.py └── nms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Sphinx documentation 53 | docs/_build/ 54 | 55 | # Jupyter Notebook 56 | .ipynb_checkpoints 57 | 58 | # pyenv 59 | .python-version 60 | 61 | # mkdocs documentation 62 | /site 63 | 64 | 65 | # Datasets folder 66 | data/ 67 | 68 | # PyCharm 69 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael Kösel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [WIP] pytorch-box2pix ![alt text](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat) 2 | 3 | Inofficial PyTorch implementation of [Box2Pix: Single-Shot Instance Segmentation by Assigning Pixels to Object Boxes](https://lmb.informatik.uni-freiburg.de/Publications/2018/UB18) (Uhrig et al., 2018). 4 | 5 | ## TODO: 6 | 7 | This is needed to get the project in a state where it can be trained: 8 | 9 | - [ ] mAP metric 10 | 11 | Instance segmentation can be added later as it's just a post processing step. 12 | 13 | ## Requirements 14 | 15 | - Install PyTorch ([pytorch.org](http://pytorch.org)) 16 | - `pip install -r requirements.txt` (Currently requires torchvision master) 17 | - Download the Cityscapes dataset 18 | 19 | ## Usage 20 | 21 | Train model: 22 | 23 | ```bash 24 | python train.py --dataset-dir 'data/cityscapes' 25 | ``` 26 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cityscapes import * 2 | from .kitti import * 3 | -------------------------------------------------------------------------------- /datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.datasets as datasets 4 | from PIL import Image 5 | 6 | from utils.box_utils import get_bounding_box 7 | 8 | 9 | class CityscapesDataset(datasets.Cityscapes): 10 | 11 | def __init__(self, root, split='train', joint_transform=None, img_transform=None): 12 | super(CityscapesDataset, self).__init__(root, split, target_type=['instance', 'polygon']) 13 | 14 | self.joint_transform = joint_transform 15 | self.img_transform = img_transform 16 | 17 | def __getitem__(self, index): 18 | image, target = super(CityscapesDataset, self).__getitem__(index) 19 | instance, json = target 20 | 21 | instance = self._convert_id_to_train_id(instance) 22 | boxes, labels = self._create_boxes(json) 23 | 24 | if self.joint_transform: 25 | image, instance, boxes, labels = self.joint_transform(image, instance, boxes, labels) 26 | 27 | if self.img_transform: 28 | image = self.img_transform(image) 29 | 30 | return image, instance, boxes, labels 31 | 32 | def _convert_id_to_train_id(self, instance): 33 | instance = np.array(instance) 34 | instance_copy = instance.copy() 35 | 36 | for cls in self.classes: 37 | instance_copy[instance == cls.id] = cls.train_id 38 | instance = Image.fromarray(instance_copy.astype(np.uint8)) 39 | 40 | return instance 41 | 42 | def _create_boxes(self, json): 43 | boxes = [] 44 | labels = [] 45 | objects = json['objects'] 46 | for obj in objects: 47 | polygons = obj['polygon'] 48 | cls = self.get_class_from_name(obj['label']) 49 | if cls and cls.has_instances: 50 | boxes.append(get_bounding_box(polygons)) 51 | labels.append(cls.id) 52 | 53 | return torch.as_tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64) 54 | 55 | @staticmethod 56 | def get_class_from_name(name): 57 | for cls in CityscapesDataset.classes: 58 | if cls.name == name: 59 | return cls 60 | return None 61 | 62 | @staticmethod 63 | def get_class_from_id(id): 64 | for cls in CityscapesDataset.classes: 65 | if cls.id == id: 66 | return cls 67 | return None 68 | 69 | @staticmethod 70 | def 
get_instance_classes(): 71 | return [cls for cls in CityscapesDataset.classes if cls.has_instances] 72 | 73 | @staticmethod 74 | def num_instance_classes(): 75 | return len(CityscapesDataset.get_instance_classes()) 76 | 77 | @staticmethod 78 | def get_colormap(): 79 | cmap = torch.zeros([256, 3], dtype=torch.uint8) 80 | 81 | for cls in CityscapesDataset.classes: 82 | if cls.has_instances: 83 | cmap[cls.trainId, :] = torch.tensor(cls.color, dtype=torch.uint8) 84 | 85 | return cmap 86 | -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.utils.data as data 5 | import torchvision 6 | from PIL import Image 7 | import os 8 | import os.path 9 | import errno 10 | 11 | import json 12 | import os 13 | 14 | import torch.utils.data as data 15 | from PIL import Image 16 | from matplotlib.patches import Rectangle 17 | 18 | from datasets.transforms.transforms import Compose, RandomHorizontalFlip, Resize, ToTensor 19 | 20 | 21 | class KITTI(data.Dataset): 22 | """`KITTI `_ Dataset. 23 | 24 | Args: 25 | root (string): Root directory of dataset where directory ``leftImg8bit`` 26 | and ``gtFine`` or ``gtCoarse`` are located. 27 | train (bool, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine" 28 | otherwise ``train``, ``train_extra`` or ``val`` 29 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` 30 | or ``color``. 31 | transform (callable, optional): A function/transform that takes in a PIL image 32 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 33 | target_transform (callable, optional): A function/transform that takes in the 34 | target and transforms it. 35 | """ 36 | 37 | def __init__(self, root, train=True, target_type='instance', joint_transform=None, img_transform=None): 38 | self.root = os.path.expanduser(root) 39 | split = 'training' if train else 'testing' 40 | self.images_dir = os.path.join(self.root, split, 'image_2') 41 | self.targets_dir = os.path.join(self.root, split, target_type) 42 | self.joint_transform = joint_transform 43 | self.img_transform = img_transform 44 | self.target_type = target_type 45 | self.train = train 46 | self.images = [] 47 | self.targets = [] 48 | 49 | if target_type not in ['instance', 'semantic', 'semantic_rgb']: 50 | raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic"' 51 | ' or "semantic_rgb"') 52 | 53 | if not os.path.isdir(self.images_dir) and not os.path.isdir(self.targets_dir): 54 | raise RuntimeError('Dataset not found or incomplete. 
Please make sure all required folders for the' 55 | ' specified "split" and "mode" are inside the "root" directory') 56 | 57 | for file_name in os.listdir(self.images_dir): 58 | self.images.append(os.path.join(self.images_dir, file_name)) 59 | if train: 60 | self.targets.append(os.path.join(self.targets_dir, file_name)) 61 | 62 | def __getitem__(self, index): 63 | if self.train: 64 | image = Image.open(self.images[index]).convert('RGB') 65 | target = Image.open(self.targets[index]) 66 | 67 | boxes = [(325, 170, 475, 240), (555, 165, 695, 220), (720, 155, 900, 240)] 68 | confs = torch.tensor([1, 1, 1]) 69 | 70 | if self.joint_transform: 71 | image, target, boxes = self.joint_transform(image, target, boxes) 72 | 73 | if self.img_transform: 74 | image = self.img_transform(image) 75 | 76 | return image, target, boxes, confs 77 | else: 78 | image = Image.open(self.images[index]).convert('RGB') 79 | 80 | if self.img_transform: 81 | image = self.img_transform(image) 82 | 83 | return image 84 | 85 | def __len__(self): 86 | return len(self.images) 87 | 88 | 89 | if __name__ == '__main__': 90 | import matplotlib.pyplot as plt 91 | 92 | joint_transforms = Compose([ 93 | RandomHorizontalFlip(), 94 | #ToTensor() 95 | ]) 96 | 97 | dataset = KITTI('../data/kitti', train=True, joint_transform=joint_transforms) 98 | img, inst, bboxes, confs = dataset[10] 99 | 100 | #print('Box size: ', bboxes.size()) 101 | #print('Instance size: ', inst.size()) 102 | #img = torchvision.transforms.functional.to_pil_image(img) 103 | #plt.imshow(img) 104 | 105 | #inst = torchvision.transforms.functional.to_pil_image(inst) 106 | plt.imshow(inst) 107 | ax = plt.gca() 108 | 109 | for i, box in enumerate(bboxes): 110 | xmin, ymin, xmax, ymax = box 111 | width = xmax - xmin 112 | height = ymax - ymin 113 | 114 | rect = Rectangle((xmin, ymin), width, height, linewidth=1, edgecolor='r', facecolor='none') 115 | ax.add_patch(rect) 116 | 117 | plt.show() 118 | -------------------------------------------------------------------------------- /datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /datasets/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision.transforms.functional as F 6 | from PIL import Image 7 | from skimage.filters import gaussian 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, img, inst, boxes, labels): 15 | for t in self.transforms: 16 | img, inst, boxes, labels = t(img, inst, boxes, labels) 17 | return img, inst, boxes, labels 18 | 19 | 20 | class ToTensor(object): 21 | def __call__(self, img, inst, boxes, labels): 22 | img = F.to_tensor(img) 23 | inst = F.to_tensor(inst).long() 24 | 25 | return img, inst, boxes, labels 26 | 27 | 28 | class Resize(object): 29 | def __init__(self, new_size, old_size=(1024, 2048)): 30 | self.old_size = old_size 31 | self.new_size = new_size 32 | 33 | self.xscale = self.new_size[1] / self.old_size[1] 34 | self.yscale = self.new_size[0] / self.old_size[0] 35 | 36 | def __call__(self, img, inst, boxes, labels): 37 | img = F.resize(img, self.new_size, interpolation=Image.BILINEAR) 38 | inst = F.resize(inst, self.new_size, interpolation=Image.NEAREST) 39 | boxes = self._resize_boxes(boxes) 40 | 
41 | return img, inst, boxes, labels 42 | 43 | def _resize_boxes(self, boxes): 44 | boxes = boxes.clone() 45 | boxes[:, 0] *= self.xscale 46 | boxes[:, 1] *= self.yscale 47 | boxes[:, 2] *= self.xscale 48 | boxes[:, 3] *= self.yscale 49 | 50 | return boxes 51 | 52 | 53 | class RandomHorizontalFlip(object): 54 | def __init__(self, p=0.5): 55 | self.p = p 56 | 57 | def __call__(self, img, inst, boxes, labels): 58 | if random.random() < self.p: 59 | img = F.hflip(img) 60 | inst = F.hflip(inst) 61 | boxes = self._hflip_boxes(img.size[0], boxes) 62 | 63 | return img, inst, boxes, labels 64 | 65 | def _hflip_boxes(self, width, boxes): 66 | boxes = boxes.clone() 67 | box_width = boxes[:, 2] - boxes[:, 0] 68 | boxes[:, 2] = width - boxes[:, 0] 69 | boxes[:, 0] = boxes[:, 2] - box_width 70 | 71 | return boxes 72 | 73 | 74 | class RandomGaussionBlur(object): 75 | def __init__(self, sigma=(0.15, 1.15)): 76 | self.sigma = sigma 77 | 78 | def __call__(self, img, inst, boxes, labels): 79 | sigma = self.sigma[0] + random.random() * self.sigma[1] 80 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 81 | blurred_img *= 255 82 | img = Image.fromarray(blurred_img.astype(np.uint8)) 83 | 84 | return img, inst, boxes, labels 85 | 86 | 87 | class RandomScale(object): 88 | def __init__(self, scale=1.0): 89 | self.scale = scale 90 | 91 | def __call__(self, img, inst, boxes, labels): 92 | scale = random.uniform(1.0, self.scale) 93 | 94 | img = F.affine(img, 0, (0, 0), scale, 0) 95 | inst = F.affine(inst, 0, (0, 0), scale, 0) 96 | 97 | return img, inst, boxes, labels 98 | 99 | 100 | class ColorJitter(object): 101 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 102 | from torchvision import transforms 103 | self.transform = transforms.ColorJitter(brightness, contrast, saturation, hue) 104 | 105 | def __call__(self, img, inst, boxes, labels): 106 | img = self.transform(img) 107 | 108 | return img, inst, boxes, labels 109 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .boxloss import * 2 | from .focalloss import * 3 | from .multitaskloss import * 4 | -------------------------------------------------------------------------------- /loss/boxloss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from loss.focalloss import FocalLoss 4 | 5 | 6 | class BoxLoss(nn.Module): 7 | 8 | def __init__(self, num_classes=11, gamma=2, reduction='none'): 9 | super(BoxLoss, self).__init__() 10 | self.num_classes = num_classes 11 | self.reduction = reduction 12 | self.focal_loss = FocalLoss(gamma, reduction=reduction) 13 | self.l2_loss = nn.MSELoss(reduction=reduction) 14 | 15 | def forward(self, loc_pred, loc_target, conf_pred, labels_target): 16 | # find only non-background predictions 17 | positives = labels_target > 0 18 | predicted_loc = loc_pred[positives, :].reshape(-1, 4) 19 | groundtruth_loc = loc_target[positives, :].reshape(-1, 4) 20 | 21 | predicted_conf = conf_pred[positives, :].reshape(-1, self.num_classes) 22 | groundtruth_label = labels_target[positives, :] # .reshape(-1, self.num_classes) 23 | 24 | loc_loss = self.l2_loss(predicted_loc, groundtruth_loc) 25 | conf_loss = self.focal_loss(predicted_conf, groundtruth_label) 26 | 27 | num_positives = loc_target.size(0) 28 | 29 | return loc_loss / num_positives, conf_loss / num_positives 30 | 
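A self-contained sketch of the same SSD-style objective that `BoxLoss` above implements, written as a plain function; it can be handy for checking tensor shapes in isolation. The function name, the assumed shapes (`[N, A, 4]` locations, `[N, A, C]` class scores, `[N, A]` labels with 0 = background) and the normalisation by the number of positive anchors are illustrative assumptions, not the repository's API.

```python
import torch
import torch.nn.functional as F


def ssd_box_loss(loc_pred, loc_target, conf_pred, labels_target, gamma=2.0):
    # positive anchors are those matched to a ground-truth object (label > 0)
    positives = labels_target > 0
    num_pos = positives.sum().clamp(min=1).float()  # avoid division by zero

    # localization: L2 regression on positive anchors only
    loc_loss = F.mse_loss(loc_pred[positives], loc_target[positives], reduction='sum')

    # classification: focal-weighted cross entropy over every anchor
    ce = F.cross_entropy(conf_pred.view(-1, conf_pred.size(-1)),
                         labels_target.view(-1), reduction='none')
    pt = torch.exp(-ce)
    conf_loss = ((1.0 - pt) ** gamma * ce).sum()

    return loc_loss / num_pos, conf_loss / num_pos
```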
-------------------------------------------------------------------------------- /loss/focalloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FocalLoss(nn.Module): 7 | """Focal Loss for Dense Object Detection 8 | 9 | 10 | .. math:: 11 | \text{loss}(p_{t}) = -(1-p_{t})^ \gamma \cdot \log(p_{t}) 12 | 13 | Args: 14 | gamma (int, optional): Gamma smoothly adjusts the rate at which easy examples 15 | are down weighted. If gamma is equals 0 it's the same as cross entropy loss. Default: 1 16 | reduction (string, optional): Specifies the reduction to apply to the output: 17 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 18 | 'mean': the sum of the output will be divided by the number of 19 | elements in the output, 'sum': the output will be summed. Default: 'mean' 20 | """ 21 | 22 | def __init__(self, gamma=1, reduction='mean'): 23 | super(FocalLoss, self).__init__() 24 | 25 | self.gamma = gamma 26 | self.reduction = reduction 27 | 28 | def forward(self, input, target): 29 | log_pt = -F.cross_entropy(input, target, reduction='none') 30 | pt = torch.exp(log_pt) 31 | 32 | loss = -torch.pow(1 - pt, self.gamma) * log_pt 33 | 34 | if self.reduction == 'mean': 35 | return loss.mean() 36 | elif self.reduction == 'sum': 37 | return loss.sum() 38 | else: 39 | return loss 40 | -------------------------------------------------------------------------------- /loss/multitaskloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultiTaskLoss(nn.Module): 6 | def __init__(self): 7 | super(MultiTaskLoss, self).__init__() 8 | 9 | self.uncert_semantics = nn.Parameter(torch.zeros(1, requires_grad=True)) 10 | self.uncert_offsets = nn.Parameter(torch.zeros(1, requires_grad=True)) 11 | self.uncert_ssdbox = nn.Parameter(torch.zeros(1, requires_grad=True)) 12 | self.uncert_ssdclass = nn.Parameter(torch.zeros(1, requires_grad=True)) 13 | 14 | def forward(self, semantics_loss, offsets_loss, box_loss, conf_loss): 15 | loss1 = 0.5 * torch.exp(-self.uncert_semantics) * semantics_loss + self.uncert_semantics 16 | loss2 = torch.exp(-self.uncert_offsets) * offsets_loss + self.uncert_offsets 17 | loss3 = torch.exp(-self.uncert_ssdbox) * box_loss + self.uncert_ssdbox 18 | loss4 = 0.5 * torch.exp(-self.uncert_ssdclass) * conf_loss + self.uncert_ssdclass 19 | 20 | loss = loss1 + loss2 + loss3 + loss4 21 | 22 | return loss 23 | -------------------------------------------------------------------------------- /metrics/iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ignite.metrics import Metric 3 | 4 | 5 | class IntersectionOverUnion(Metric): 6 | """Computes the intersection over union (IoU) per class. 7 | 8 | based on: https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 9 | 10 | - `update` must receive output of the form `(y_pred, y)`. 
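    - `y_pred` and `y` are expected to be integer class-id maps of the same shape (e.g. the argmax over the class dimension), since their values directly index the confusion matrix.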
11 | """ 12 | 13 | def __init__(self, num_classes=10, ignore_index=255, output_transform=lambda x: x): 14 | self.num_classes = num_classes 15 | self.ignore_index = ignore_index 16 | self.confusion_matrix = np.zeros((num_classes, num_classes)) 17 | 18 | super(IntersectionOverUnion, self).__init__(output_transform=output_transform) 19 | 20 | def _fast_hist(self, label_true, label_pred): 21 | # mask = (label_true >= 0) & (label_true < self.num_classes) 22 | mask = label_true != self.ignore_index 23 | hist = np.bincount(self.num_classes * label_true[mask].astype(np.int) + label_pred[mask], 24 | minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) 25 | return hist 26 | 27 | def reset(self): 28 | self.confusion_matrix = np.zeros((self.num_classes, self.num_classes)) 29 | 30 | def update(self, output): 31 | y_pred, y = output 32 | 33 | for label_true, label_pred in zip(y.numpy(), y_pred.numpy()): 34 | self.confusion_matrix += self._fast_hist(label_true.flatten(), label_pred.flatten()) 35 | 36 | def compute(self): 37 | hist = self.confusion_matrix 38 | with np.errstate(divide='ignore', invalid='ignore'): 39 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 40 | 41 | return np.nanmean(iu) 42 | -------------------------------------------------------------------------------- /metrics/mean_ap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ignite.metrics import Metric 3 | 4 | 5 | class MeanAveragePrecision(Metric): 6 | 7 | def __init__(self, num_classes=20, output_transform=lambda x: x): 8 | super(MeanAveragePrecision, self).__init__(output_transform=output_transform) 9 | 10 | self.num_classes = num_classes 11 | 12 | def reset(self): 13 | self._true_boxes = torch.tensor([], dtype=torch.long) 14 | self._true_labels = torch.tensor([], dtype=torch.long) 15 | 16 | self._det_boxes = torch.tensor([], dtype=torch.float32) 17 | self._det_labels = torch.tensor([], dtype=torch.float32) 18 | self._det_scores = torch.tensor([], dtype=torch.float32) 19 | 20 | def update(self, output): 21 | boxes_preds, labels_preds, scores_preds, boxes, labels = output 22 | 23 | self._true_boxes = torch.cat([self._true_boxes, boxes], dim=0) 24 | self._true_labels = torch.cat([self._true_labels, labels], dim=0) 25 | 26 | self._det_boxes = torch.cat([self._det_boxes, boxes_preds], dim=0) 27 | self._det_labels = torch.cat([self._det_labels, labels_preds], dim=0) 28 | self._det_scores = torch.cat([self._det_scores, scores_preds], dim=0) 29 | 30 | def compute(self): 31 | for c in range(1, self.num_classes): 32 | pass 33 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .box2pix import * 2 | from .multibox import * 3 | -------------------------------------------------------------------------------- /models/box2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils import model_zoo 4 | from torchvision import models 5 | from torchvision.models.googlenet import BasicConv2d, Inception 6 | 7 | from models.multibox import MultiBox 8 | from utils.helper import get_upsampling_weight 9 | 10 | 11 | def box2pix(num_classes=11, pretrained=False, **kwargs): 12 | if pretrained: 13 | if 'transform_input' not in kwargs: 14 | kwargs['transform_input'] = True 15 | model = Box2Pix(num_classes, **kwargs) 
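        # NOTE: the model_zoo URL on the next line is an empty placeholder, so pretrained=True will fail until real weights are hosted.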
16 | model.load_state_dict(model_zoo.load_url('')) 17 | return model 18 | 19 | return Box2Pix(num_classes, **kwargs) 20 | 21 | 22 | class Box2Pix(nn.Module): 23 | """ 24 | Implementation of Box2Pix: Single-Shot Instance Segmentation by Assigning Pixels to Object Boxes 25 | 26 | """ 27 | 28 | def __init__(self, num_classes=11, transform_input=False): 29 | super(Box2Pix, self).__init__() 30 | self.transform_input = transform_input 31 | 32 | self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) 33 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 34 | self.conv2 = BasicConv2d(64, 64, kernel_size=1) 35 | self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 36 | 37 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 38 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 39 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 40 | 41 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 42 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 43 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 44 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 45 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 46 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 47 | 48 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 49 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 50 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 51 | 52 | self.maxpool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 53 | self.inception6a = Inception2(1024, 256, 160, 320, 32, 128, 128) 54 | self.inception6b = Inception2(832, 384, 192, 384, 48, 128, 128) 55 | 56 | self.maxpool6 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 57 | self.inception7a = Inception2(1024, 256, 160, 320, 32, 128, 128) 58 | self.inception7b = Inception2(832, 384, 192, 384, 48, 128, 128) 59 | 60 | self.sem_score3b = nn.Conv2d(480, num_classes, kernel_size=1) 61 | self.sem_score4e = nn.Conv2d(832, num_classes, kernel_size=1) 62 | self.sem_score5b = nn.Conv2d(1024, num_classes, kernel_size=1) 63 | self.sem_score6b = nn.Conv2d(1024, num_classes, kernel_size=1) 64 | self.sem_upscore = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 65 | self.sem_upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 66 | self.sem_upscore4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False) 67 | self.sem_upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, bias=False) 68 | 69 | self.offs_score3b = nn.Conv2d(480, 2, kernel_size=1) 70 | self.offs_score4e = nn.Conv2d(832, 2, kernel_size=1) 71 | self.offs_score5b = nn.Conv2d(1024, 2, kernel_size=1) 72 | self.offs_score6b = nn.Conv2d(1024, 2, kernel_size=1) 73 | self.offs_upscore = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False) 74 | self.offs_upscore2 = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False) 75 | self.offs_upscore4 = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False) 76 | self.offs_upscore8 = nn.ConvTranspose2d(2, 2, kernel_size=16, stride=8, bias=False) 77 | 78 | self.multibox = MultiBox(num_classes) 79 | self._initialize_weights(num_classes) 80 | 81 | def _initialize_weights(self, num_classes): 82 | for m in self.modules(): 83 | if isinstance(m, nn.Conv2d): 84 | if m.kernel_size[0] == 1 and m.out_channels in [num_classes, 2]: 85 | nn.init.constant_(m.weight, 0) 86 | nn.init.constant_(m.bias, 0) 87 | 
else: 88 | nn.init.xavier_uniform_(m.weight) 89 | elif isinstance(m, nn.ConvTranspose2d): 90 | upsampling_weight = get_upsampling_weight(m.out_channels, m.kernel_size[0]) 91 | with torch.no_grad(): 92 | m.weight.copy_(upsampling_weight) 93 | elif isinstance(m, nn.BatchNorm2d): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def init_from_googlenet(self): 98 | googlenet = models.googlenet(pretrained=True) 99 | self.load_state_dict(googlenet.state_dict(), strict=False) 100 | self.transform_input = True 101 | 102 | def _transform_input(self, x): 103 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 104 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 105 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 106 | 107 | return torch.cat([x_ch0, x_ch1, x_ch2], 1) 108 | 109 | def forward(self, x): 110 | feature_maps = [] 111 | size = x.size() 112 | 113 | if self.transform_input: 114 | x = self._transform_input(x) 115 | x = self.conv1(x) 116 | x = self.maxpool1(x) 117 | x = self.conv2(x) 118 | x = self.conv3(x) 119 | x = self.maxpool2(x) 120 | x = self.inception3a(x) 121 | inception3b = self.inception3b(x) 122 | x = self.maxpool3(inception3b) 123 | x = self.inception4a(x) 124 | x = self.inception4b(x) 125 | x = self.inception4c(x) 126 | x = self.inception4d(x) 127 | inception4e = self.inception4e(x) 128 | feature_maps.append(inception4e) 129 | 130 | x = self.maxpool4(inception4e) 131 | x = self.inception5a(x) 132 | inception5b = self.inception5b(x) 133 | feature_maps.append(inception5b) 134 | 135 | x = self.maxpool5(inception5b) 136 | x = self.inception6a(x) 137 | inception6b = self.inception6b(x) 138 | feature_maps.append(inception6b) 139 | 140 | x = self.maxpool6(inception6b) 141 | x = self.inception7a(x) 142 | inception7b = self.inception7b(x) 143 | feature_maps.append(inception7b) 144 | 145 | loc_preds, conf_preds = self.multibox(feature_maps) 146 | 147 | sem_score6b = self.sem_score6b(inception6b) 148 | sem_score5b = self.sem_score5b(inception5b) 149 | semantics = self.sem_upscore(sem_score6b) 150 | semantics = semantics[:, :, 1:1 + sem_score5b.size()[2], 1:1 + sem_score5b.size()[3]] 151 | semantics += sem_score5b 152 | sem_score4e = self.sem_score4e(inception4e) 153 | semantics = self.sem_upscore2(semantics) 154 | semantics = semantics[:, :, 1:1 + sem_score4e.size()[2], 1:1 + sem_score4e.size()[3]] 155 | semantics += sem_score4e 156 | sem_score3b = self.sem_score3b(inception3b) 157 | semantics = self.sem_upscore4(semantics) 158 | semantics = semantics[:, :, 1:1 + sem_score3b.size()[2], 1:1 + sem_score3b.size()[3]] 159 | semantics += sem_score3b 160 | semantics = self.sem_upscore8(semantics) 161 | semantics = semantics[:, :, 4:4 + size[2], 4:4 + size[3]].contiguous() 162 | 163 | offs_score6b = self.offs_score6b(inception6b) 164 | offs_score5b = self.offs_score5b(inception5b) 165 | offsets = self.offs_upscore(offs_score6b) 166 | offsets = offsets[:, :, 1:1 + offs_score5b.size()[2], 1:1 + offs_score5b.size()[3]] 167 | offsets += offs_score5b 168 | offs_score4e = self.offs_score4e(inception4e) 169 | offsets = self.offs_upscore2(offsets) 170 | offsets = offsets[:, :, 1:1 + offs_score4e.size()[2], 1:1 + offs_score4e.size()[3]] 171 | offsets += offs_score4e 172 | offs_score3b = self.offs_score3b(inception3b) 173 | offsets = self.offs_upscore4(offsets) 174 | offsets = offsets[:, :, 1:1 + offs_score3b.size()[2], 1:1 + offs_score3b.size()[3]] 175 | offsets += offs_score3b 176 | offsets = 
self.offs_upscore8(offsets) 177 | offsets = offsets[:, :, 4:4 + size[2], 4:4 + size[3]].contiguous() 178 | 179 | return loc_preds, conf_preds, semantics, offsets 180 | 181 | 182 | class Inception2(nn.Module): 183 | 184 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 185 | super(Inception2, self).__init__() 186 | 187 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 188 | self.branch2 = nn.Sequential( 189 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 190 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) 191 | ) 192 | self.branch3 = nn.Sequential( 193 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 194 | BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) 195 | ) 196 | self.branch4 = nn.Sequential( 197 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 198 | BasicConv2d(in_channels, pool_proj, kernel_size=1) 199 | ) 200 | 201 | def forward(self, x): 202 | branch1 = self.branch1(x) 203 | branch2 = self.branch2(x) 204 | branch3 = self.branch3(x) 205 | branch4 = self.branch4(x) 206 | 207 | outputs = [branch1, branch2, branch3, branch4] 208 | return torch.cat(outputs, 1) 209 | 210 | 211 | if __name__ == '__main__': 212 | num_classes, width, height = 20, 1024, 2048 213 | 214 | model = Box2Pix(num_classes) # .to('cuda') 215 | model.init_from_googlenet() 216 | inp = torch.randn(1, 3, height, width) # .to('cuda') 217 | 218 | loc, conf, sem, offs = model(inp) 219 | 220 | assert loc.size(2) == 4 221 | assert conf.size(2) == num_classes 222 | assert sem.size() == torch.Size([1, num_classes, height, width]) 223 | assert offs.size() == torch.Size([1, 2, height, width]) 224 | 225 | print('Pass size check.') 226 | -------------------------------------------------------------------------------- /models/multibox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultiBox(nn.Module): 6 | 7 | def __init__(self, num_classes): 8 | super(MultiBox, self).__init__() 9 | 10 | self.num_classes = num_classes 11 | self.loc_layers = nn.ModuleList() 12 | self.conf_layers = nn.ModuleList() 13 | 14 | num_defaults = [16, 16, 20, 21] 15 | in_channels = [832, 1024, 1024, 1024] 16 | 17 | for i in range(len(in_channels)): 18 | self.loc_layers.append(nn.Conv2d(in_channels[i], num_defaults[i] * 4, kernel_size=1)) 19 | self.conf_layers.append(nn.Conv2d(in_channels[i], num_defaults[i] * num_classes, kernel_size=1)) 20 | 21 | def forward(self, input): 22 | loc_preds = [] 23 | conf_preds = [] 24 | 25 | for i, layer in enumerate(input): 26 | loc = self.loc_layers[i](layer) 27 | # (N x C x H x W) -> (N x H x W x C) 28 | loc = loc.permute(0, 2, 3, 1).contiguous() 29 | loc = loc.view(loc.size(0), -1, 4) 30 | loc_preds.append(loc) 31 | 32 | conf = self.conf_layers[i](layer) 33 | # (N x C x H x W) -> (N x H x W x C) 34 | conf = conf.permute(0, 2, 3, 1).contiguous() 35 | conf = conf.view(conf.size(0), -1, self.num_classes) 36 | conf_preds.append(conf) 37 | 38 | loc_preds = torch.cat(loc_preds, 1) 39 | conf_preds = torch.cat(conf_preds, 1) 40 | 41 | return loc_preds, conf_preds 42 | -------------------------------------------------------------------------------- /prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from .predictor import * 2 | -------------------------------------------------------------------------------- /prediction/predictor.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | from models.box2pix import Box2Pix 7 | 8 | 9 | class Predictor(object): 10 | 11 | def __init__(self, show_segmentation=True, show_labels=False, show_boxes=False): 12 | self.show_segmentation = show_segmentation 13 | self.show_labels = show_labels 14 | self.show_boxes = show_boxes 15 | 16 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 | self.net = Box2Pix().to(self.device) 18 | self.net.eval() 19 | 20 | self.transform = self.get_transform() 21 | 22 | def get_transform(self): 23 | transform = transforms.Compose([ 24 | transforms.ToPILImage(), 25 | transforms.Resize((512, 1024)), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | return transform 31 | 32 | def run(self, image): 33 | result = image.copy() 34 | 35 | result = self.transform(result) 36 | result = result.unsqueeze(0) 37 | result = result.to(self.device) 38 | 39 | with torch.no_grad(): 40 | loc_preds, conf_preds, semantics_pred, offsets_pred = self.net(result) 41 | 42 | if self.show_segmentation: 43 | result = self.add_segmentation_overlay(result, None) 44 | 45 | if self.show_boxes: 46 | result = self.add_boxes_overlay(result, None) 47 | 48 | if self.show_labels: 49 | result = self.add_overlay_classes(result, None) 50 | 51 | return result 52 | 53 | def add_boxes_overlay(self, image, predictions): 54 | 55 | for box in predictions: 56 | top_left, bottom_right = tuple(box[:2].tolist()), tuple(box[2:].tolist()) 57 | image = cv2.rectangle(image, top_left, bottom_right, 0, 1) 58 | 59 | return image 60 | 61 | def add_segmentation_overlay(self, image, predictions): 62 | return image 63 | 64 | def add_overlay_classes(self, image, predictions): 65 | scores = [0] 66 | labels = [0] 67 | boxes = predictions 68 | 69 | for box, score, label in zip(boxes, scores, labels): 70 | x, y = box[:2] 71 | cv2.putText(image, 'car', (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) 72 | 73 | return image 74 | 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pytorch-ignite 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import cv2 4 | 5 | from prediction.predictor import Predictor 6 | 7 | 8 | def run(file_path, show_boxes): 9 | cap = cv2.VideoCapture(file_path) 10 | detector = Predictor(show_boxes) 11 | 12 | while cap.isOpened(): 13 | ret, frame = cap.read() 14 | res = detector.run(frame) 15 | 16 | cv2.imshow('frame', res) 17 | 18 | if cv2.waitKey(25) & 0xFF == ord('q'): 19 | break 20 | 21 | cap.release() 22 | cv2.destroyAllWindows() 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = ArgumentParser('Box2Pix with PyTorch') 27 | parser.add_argument('input', help='input video file') 28 | parser.add_argument('--show-boxes', action='store_true', 29 | help='whether or not to also display boxes in the result') 30 | 31 | args = parser.parse_args() 32 | 33 | if args.input is not None: 34 | run(args.input, args.show_boxes) 35 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from argparse import ArgumentParser 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from ignite.contrib.handlers import ProgressBar 11 | from ignite.contrib.handlers.tensorboard_logger import * 12 | from ignite.engine import Events, Engine 13 | from ignite.handlers import ModelCheckpoint, Timer 14 | from ignite.metrics import RunningAverage, Loss, ConfusionMatrix, mIoU 15 | from ignite.utils import convert_tensor 16 | from torch.utils.data import DataLoader 17 | from torchvision.transforms import Normalize 18 | 19 | import models 20 | from datasets.cityscapes import CityscapesDataset 21 | from datasets.transforms import transforms 22 | from datasets.transforms.transforms import ToTensor 23 | from loss.boxloss import BoxLoss 24 | from loss.multitaskloss import MultiTaskLoss 25 | from metrics.mean_ap import MeanAveragePrecision 26 | from utils import helper 27 | from utils.box_coder import BoxCoder 28 | 29 | 30 | def get_data_loaders(data_dir, batch_size, num_workers): 31 | # new_size = (512, 1024) # (1024, 2048) 32 | 33 | joint_transform = transforms.Compose([ 34 | # transforms.Resize(new_size), 35 | transforms.RandomHorizontalFlip(), 36 | transforms.ColorJitter(0.2, 0.2, 0.2), 37 | transforms.RandomGaussionBlur(sigma=(0, 1.2)), 38 | transforms.ToTensor() 39 | ]) 40 | 41 | normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 42 | 43 | train_loader = DataLoader(CityscapesDataset(root=data_dir, split='train', joint_transform=joint_transform, 44 | img_transform=normalize), 45 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 46 | 47 | val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val', joint_transform=ToTensor(), 48 | img_transform=normalize), 49 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 50 | 51 | return train_loader, val_loader 52 | 53 | 54 | def run(args): 55 | train_loader, val_loader = get_data_loaders(args.dir, args.batch_size, args.num_workers) 56 | 57 | if args.seed is not None: 58 | torch.manual_seed(args.seed) 59 | 60 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 61 | 62 | num_classes = CityscapesDataset.num_instance_classes() + 1 63 | model = models.box2pix(num_classes) 64 | model.init_from_googlenet() 65 | 66 | if torch.cuda.device_count() > 1: 67 | print("Using %d GPU(s)" % torch.cuda.device_count()) 68 | model = nn.DataParallel(model) 69 | 70 | model = model.to(device) 71 | 72 | semantics_criterion = nn.CrossEntropyLoss(ignore_index=255) 73 | offsets_criterion = nn.MSELoss() 74 | box_criterion = BoxLoss(num_classes, gamma=2) 75 | multitask_criterion = MultiTaskLoss().to(device) 76 | 77 | box_coder = BoxCoder() 78 | optimizer = optim.Adam([{'params': model.parameters()}, 79 | {'params': multitask_criterion.parameters()}], lr=args.lr) 80 | 81 | if args.resume: 82 | if os.path.isfile(args.resume): 83 | print("Loading checkpoint '{}'".format(args.resume)) 84 | checkpoint = torch.load(args.resume) 85 | args.start_epoch = checkpoint['epoch'] 86 | model.load_state_dict(checkpoint['model']) 87 | optimizer.load_state_dict(checkpoint['optimizer']) 88 | multitask_criterion.load_state_dict(checkpoint['multitask']) 89 | print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch'])) 90 | else: 91 | print("No 
checkpoint found at '{}'".format(args.resume)) 92 | 93 | def _prepare_batch(batch, non_blocking=True): 94 | x, instance, boxes, labels = batch 95 | 96 | return (convert_tensor(x, device=device, non_blocking=non_blocking), 97 | convert_tensor(instance, device=device, non_blocking=non_blocking), 98 | convert_tensor(boxes, device=device, non_blocking=non_blocking), 99 | convert_tensor(labels, device=device, non_blocking=non_blocking)) 100 | 101 | def _update(engine, batch): 102 | model.train() 103 | optimizer.zero_grad() 104 | x, instance, boxes, labels = _prepare_batch(batch) 105 | boxes, labels = box_coder.encode(boxes, labels) 106 | 107 | loc_preds, conf_preds, semantics_pred, offsets_pred = model(x) 108 | 109 | semantics_loss = semantics_criterion(semantics_pred, instance) 110 | offsets_loss = offsets_criterion(offsets_pred, instance) 111 | box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels) 112 | 113 | loss = multitask_criterion(semantics_loss, offsets_loss, box_loss, conf_loss) 114 | 115 | loss.backward() 116 | optimizer.step() 117 | 118 | return { 119 | 'loss': loss.item(), 120 | 'loss_semantics': semantics_loss.item(), 121 | 'loss_offsets': offsets_loss.item(), 122 | 'loss_ssdbox': box_loss.item(), 123 | 'loss_ssdclass': conf_loss.item() 124 | } 125 | 126 | trainer = Engine(_update) 127 | 128 | checkpoint_handler = ModelCheckpoint(args.output_dir, 'checkpoint', save_interval=1, n_saved=10, 129 | require_empty=False, create_dir=True, save_as_state_dict=False) 130 | timer = Timer(average=True) 131 | 132 | # attach running average metrics 133 | train_metrics = ['loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox', 'loss_ssdclass'] 134 | for m in train_metrics: 135 | transform = partial(lambda x, metric: x[metric], metric=m) 136 | RunningAverage(output_transform=transform).attach(trainer, m) 137 | 138 | # attach progress bar 139 | pbar = ProgressBar(persist=True) 140 | pbar.attach(trainer, metric_names=train_metrics) 141 | 142 | checkpoint = {'model': model.state_dict(), 'epoch': trainer.state.epoch, 'optimizer': optimizer.state_dict(), 143 | 'multitask': multitask_criterion.state_dict()} 144 | trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 145 | 'checkpoint': checkpoint 146 | }) 147 | 148 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 149 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 150 | 151 | def _inference(engine, batch): 152 | model.eval() 153 | with torch.no_grad(): 154 | x, instance, boxes, labels = _prepare_batch(batch) 155 | loc_preds, conf_preds, semantics, offsets_pred = model(x) 156 | boxes_preds, labels_preds, scores_preds = box_coder.decode(loc_preds, F.softmax(conf_preds, dim=1), 157 | score_thresh=0.01) 158 | 159 | semantics_loss = semantics_criterion(semantics, instance) 160 | offsets_loss = offsets_criterion(offsets_pred, instance) 161 | box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels) 162 | 163 | semantics_pred = semantics.argmax(dim=1) 164 | instances = helper.assign_pix2box(semantics_pred, offsets_pred, boxes_preds, labels_preds) 165 | 166 | return { 167 | 'loss': (semantics_loss, offsets_loss, {'box_loss': box_loss, 'conf_loss': conf_loss}), 168 | 'objects': (boxes_preds, labels_preds, scores_preds, boxes, labels), 169 | 'semantics': semantics_pred, 170 | 'instances': instances 171 | } 172 | 173 | train_evaluator = Engine(_inference) 174 | cm = ConfusionMatrix(num_classes=num_classes, 
output_transform=lambda x: x['semantics']) 175 | mIoU(cm, ignore_index=0).attach(train_evaluator, 'mIoU') 176 | Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss') 177 | MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(train_evaluator, 'mAP') 178 | 179 | evaluator = Engine(_inference) 180 | cm2 = ConfusionMatrix(num_classes=num_classes, output_transform=lambda x: x['semantics']) 181 | mIoU(cm2, ignore_index=0).attach(train_evaluator, 'mIoU') 182 | Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(evaluator, 'loss') 183 | MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(evaluator, 'mAP') 184 | 185 | tb_logger = TensorboardLogger(args.log_dir) 186 | tb_logger.attach(trainer, 187 | log_handler=OutputHandler(tag='training', output_transform=lambda loss: { 188 | 'loss': loss['loss'], 189 | 'loss_semantics': loss['loss_semantics'], 190 | 'loss_offsets': loss['loss_offsets'], 191 | 'loss_ssdbox': loss['loss_ssdbox'], 192 | 'loss_ssdclass': loss['loss_ssdclass'] 193 | 194 | }), 195 | event_name=Events.ITERATION_COMPLETED) 196 | 197 | tb_logger.attach(train_evaluator, 198 | log_handler=OutputHandler(tag='training_eval', 199 | metric_names=['loss', 'mAP', 'mIoU'], 200 | output_transform=lambda loss: { 201 | 'loss': loss['loss'], 202 | 'objects': loss['objects'], 203 | 'semantics': loss['semantics'] 204 | }, 205 | another_engine=trainer), 206 | event_name=Events.EPOCH_COMPLETED) 207 | 208 | tb_logger.attach(evaluator, 209 | log_handler=OutputHandler(tag='validation_eval', 210 | metric_names=['loss', 'mAP', 'mIoU'], 211 | output_transform=lambda loss: { 212 | 'loss': loss['loss'], 213 | 'objects': loss['objects'], 214 | 'semantics': loss['semantics'] 215 | }, 216 | another_engine=trainer), 217 | event_name=Events.EPOCH_COMPLETED) 218 | 219 | @trainer.on(Events.STARTED) 220 | def initialize(engine): 221 | if args.resume: 222 | engine.state.epoch = args.start_epoch 223 | 224 | @trainer.on(Events.EPOCH_COMPLETED) 225 | def print_times(engine): 226 | pbar.log_message("Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(engine.state.epoch, 227 | engine.state.max_epochs, timer.value())) 228 | timer.reset() 229 | 230 | @trainer.on(Events.EPOCH_COMPLETED) 231 | def log_training_results(engine): 232 | train_evaluator.run(train_loader) 233 | metrics = train_evaluator.state.metrics 234 | loss = metrics['loss'] 235 | mean_ap = metrics['mAP'] 236 | iou = metrics['mIoU'] 237 | 238 | pbar.log_message('Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}' 239 | .format(loss, evaluator.state.epochs, evaluator.state.max_epochs, mean_ap, iou * 100.0)) 240 | 241 | @trainer.on(Events.EPOCH_COMPLETED) 242 | def log_validation_results(engine): 243 | evaluator.run(val_loader) 244 | metrics = evaluator.state.metrics 245 | loss = metrics['loss'] 246 | mean_ap = metrics['mAP'] 247 | iou = metrics['mIoU'] 248 | 249 | pbar.log_message('Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}' 250 | .format(loss, evaluator.state.epochs, evaluator.state.max_epochs, mean_ap, iou * 100.0)) 251 | 252 | @trainer.on(Events.EXCEPTION_RAISED) 253 | def handle_exception(engine, e): 254 | if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): 255 | engine.terminate() 256 | warnings.warn("KeyboardInterrupt caught. 
Exiting gracefully.") 257 | 258 | checkpoint_handler(engine, {'model_exception': model}) 259 | else: 260 | raise e 261 | 262 | @trainer.on(Events.COMPLETED) 263 | def save_final_model(engine): 264 | checkpoint_handler(engine, {'final': model}) 265 | 266 | trainer.run(train_loader, max_epochs=args.epochs) 267 | tb_logger.close() 268 | 269 | 270 | if __name__ == '__main__': 271 | parser = ArgumentParser('Box2Pix with PyTorch') 272 | parser.add_argument('--batch_size', type=int, default=8, 273 | help='input batch size for training') 274 | parser.add_argument('--num-workers', type=int, default=8, 275 | help='number of workers') 276 | parser.add_argument('--epochs', type=int, default=200, 277 | help='number of epochs to train') 278 | parser.add_argument('--lr', type=float, default=1e-4, 279 | help='learning rate') 280 | parser.add_argument('--seed', type=int, help='manual seed') 281 | parser.add_argument('--output-dir', default='./checkpoints', 282 | help='directory to save model checkpoints') 283 | parser.add_argument('--resume', type=str, metavar='PATH', 284 | help='path to latest checkpoint (default: none)') 285 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 286 | help='manual epoch number (useful on restarts)') 287 | parser.add_argument('--log-interval', type=int, default=10, 288 | help='how many batches to wait before logging training status') 289 | parser.add_argument("--log-dir", type=str, default="tensorboard_logs", 290 | help="log directory for Tensorboard log output") 291 | parser.add_argument("--dataset-dir", type=str, default="data/cityscapes", 292 | help="location of the dataset") 293 | 294 | run(parser.parse_args()) 295 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .box_coder import * 2 | from .box_utils import * 3 | from .helper import * 4 | -------------------------------------------------------------------------------- /utils/box_coder.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | 5 | from utils import box_utils 6 | 7 | FeatureMapDef = namedtuple('FeatureMapDef', ['width', 'height', 'receptive_size']) 8 | 9 | 10 | class BoxCoder(object): 11 | """ 12 | References: 13 | https://github.com/ECer23/ssd.pytorch 14 | https://github.com/kuangliu/torchcv 15 | """ 16 | 17 | def __init__(self, img_width=2048, img_height=1024): 18 | self.variances = (0.1, 0.2) 19 | 20 | priors = [ 21 | # height, width 22 | (4, 52), (24, 24), (54, 8), (80, 22), (52, 52), 23 | (20, 78), (156, 50), (78, 78), (48, 144), (412, 76), 24 | (104, 150), (74, 404), (644, 166), (358, 448), (70, 686), (68, 948), 25 | (772, 526), (476, 820), (150, 1122), (890, 880), (516, 1130) 26 | ] 27 | 28 | feature_maps = [ 29 | FeatureMapDef(128, 64, 427), 30 | FeatureMapDef(64, 32, 715), 31 | FeatureMapDef(32, 16, 1291), 32 | FeatureMapDef(16, 8, 2443) 33 | ] 34 | 35 | boxes = [] 36 | for fm in feature_maps: 37 | step_w = fm.width / img_width 38 | step_h = fm.height / img_height 39 | for x in range(fm.width): 40 | for y in range(fm.height): 41 | for p_h, p_w in priors: 42 | cx = (x + 0.5) * step_w 43 | cy = (y + 0.5) * step_h 44 | h = p_h / img_height 45 | w = p_w / img_width 46 | 47 | if fm.receptive_size > (p_h * 2) or fm.receptive_size > (p_w * 2): 48 | boxes.append((cx, cy, h, w)) 49 | 50 | self.priors = torch.as_tensor(boxes, 
dtype=torch.float32).clamp_(0.0, 1.0) 51 | 52 | def encode(self, boxes, labels, change_threshold=0.7): 53 | """Encode target bounding boxes and class labels. 54 | SSD coding rules: 55 | tx = (x - anchor_x) / (variance[0] * anchor_w) 56 | ty = (y - anchor_y) / (variance[0] * anchor_h) 57 | tw = log(w / anchor_w) / variance[1] 58 | th = log(h / anchor_h) / variance[1] 59 | 60 | Args: 61 | boxes: (tensor) bounding boxes of (xmin, ymin, xmax, ymax), sized [#obj, 4]. 62 | labels: (tensor) object class labels, sized [#obj,]. 63 | change_threshold: (float) the change metric threshold 64 | """ 65 | 66 | priors = self.priors 67 | priors = box_utils.center_to_corner_form(priors) 68 | 69 | change = box_utils.d_change(boxes, priors) 70 | 71 | change, max_idx = change.max(0) 72 | max_idx.squeeze_(0) 73 | change.squeeze_(0) 74 | 75 | boxes = boxes[max_idx] 76 | boxes = box_utils.corner_to_center_form(boxes) 77 | priors = box_utils.corner_to_center_form(priors) 78 | 79 | loc_xy = (boxes[:, :2] - priors[:, :2]) / priors[:, 2:] / self.variances[0] 80 | loc_wh = torch.log(boxes[:, 2:] / priors[:, 2:]) / self.variances[1] 81 | loc = torch.cat([loc_xy, loc_wh], 1) 82 | 83 | conf = labels[max_idx] + 1 # background class = 0 84 | conf[change < change_threshold] = 0 # background 85 | 86 | return loc, conf 87 | 88 | def decode(self, loc_preds, conf_preds, score_thresh=0.6, nms_thresh=0.5): 89 | """Decode predicted loc/cls back to real box locations and class labels. 90 | Args: 91 | loc_preds: (tensor) predicted loc, sized [8732,4]. 92 | conf_preds: (tensor) predicted conf, sized [8732,21]. 93 | score_thresh: (float) threshold for object confidence score. 94 | nms_thresh: (float) threshold for box nms. 95 | 96 | Returns: 97 | boxes: (tensor) bbox locations, sized [#obj,4]. 98 | labels: (tensor) class labels, sized [#obj,]. 
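            scores: (tensor) confidence score of each kept box, sized [#obj,].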
99 | """ 100 | cxcy = loc_preds[:, :2] * self.variances[0] * self.priors[:, 2:] + self.priors[:, :2] 101 | wh = torch.exp(loc_preds[:, 2:] * self.variances[1]) * self.priors[:, 2:] 102 | box_preds = torch.cat([cxcy - wh / 2, cxcy + wh / 2], 1) 103 | 104 | boxes = [] 105 | labels = [] 106 | scores = [] 107 | num_classes = conf_preds.size(1) 108 | for i in range(num_classes - 1): 109 | score = conf_preds[:, i + 1] # class i corresponds to (i + 1) column 110 | mask = score > score_thresh 111 | if not mask.any(): 112 | continue 113 | box = box_preds[mask.nonzero().squeeze()] 114 | score = score[mask] 115 | 116 | keep = box # torchvision.layers.nms(box, score, nms_thresh) 117 | boxes.append(box[keep]) 118 | labels.append(torch.full(box[keep].size()[0], i, dtype=torch.int64)) 119 | scores.append(score[keep]) 120 | 121 | boxes = torch.cat(boxes, 0) 122 | labels = torch.cat(labels, 0) 123 | scores = torch.cat(scores, 0) 124 | 125 | return boxes, labels, scores 126 | 127 | 128 | if __name__ == '__main__': 129 | coder = BoxCoder() 130 | print(coder.priors[4:]) 131 | print(coder.priors.size()) 132 | -------------------------------------------------------------------------------- /utils/box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | priors = [ 4 | # height, width 5 | (4, 52), 6 | (25, 25), 7 | (54, 8), 8 | (80, 22), 9 | (52, 52), 10 | (20, 78), 11 | (156, 50), 12 | (78, 78), 13 | (48, 144), 14 | (412, 76), 15 | (104, 150), 16 | (74, 404), 17 | (645, 166), 18 | (358, 448), 19 | (70, 686), 20 | (68, 948), 21 | (772, 526), 22 | (476, 820), 23 | (150, 1122), 24 | (890, 880), 25 | (518, 1130) 26 | ] 27 | 28 | priors_new = [ 29 | # height, width 30 | (4, 52), 31 | (24, 24), 32 | (54, 8), 33 | (80, 22), 34 | (52, 52), 35 | (20, 78), 36 | (156, 50), 37 | (78, 78), 38 | (48, 144), 39 | (412, 76), 40 | (104, 150), 41 | (74, 404), 42 | (644, 166), 43 | (358, 448), 44 | (70, 686), 45 | (68, 948), 46 | (772, 526), 47 | (476, 820), 48 | (150, 1122), 49 | (890, 880), 50 | (516, 1130) 51 | ] 52 | 53 | 54 | def get_bounding_box(polygon): 55 | fpoint = polygon[0] 56 | xmin, ymin, xmax, ymax = fpoint[0], fpoint[1], fpoint[0], fpoint[1] 57 | for point in polygon: 58 | x, y = point[0], point[1] 59 | xmin = min(xmin, x) 60 | ymin = min(ymin, y) 61 | xmax = max(xmax, x) 62 | ymax = max(ymax, y) 63 | 64 | return xmin, ymin, xmax, ymax 65 | 66 | 67 | def d_change(prior, ground_truth): 68 | """Compute a change based metric of two sets of boxes. 
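    As implemented below, the value is sqrt(dx_tl**2 / w_gt + dy_tl**2 / h_gt + dx_br**2 / w_gt + dy_br**2 / h_gt), i.e. the squared corner displacements normalized by the ground-truth box width and height.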
69 | 70 | Args: 71 | prior (tensor): Prior boxes, Shape: [num_priors, 4] 72 | ground_truth (tensor): Ground truth bounding boxes, Shape: [num_objects, 4] 73 | """ 74 | 75 | xtl = torch.abs(prior[:, 0] - ground_truth[:, 0]) 76 | ytl = torch.abs(prior[:, 1] - ground_truth[:, 1]) 77 | xbr = torch.abs(prior[:, 2] - ground_truth[:, 2]) 78 | ybr = torch.abs(prior[:, 3] - ground_truth[:, 3]) 79 | 80 | wgt = ground_truth[:, 2] - ground_truth[:, 0] 81 | hgt = ground_truth[:, 3] - ground_truth[:, 1] 82 | 83 | return torch.sqrt((torch.pow(ytl, 2) / hgt) + (torch.pow(xtl, 2) / wgt) 84 | + (torch.pow(ybr, 2) / hgt) + (torch.pow(xbr, 2) / wgt)) 85 | 86 | 87 | def corner_to_center_form(boxes): 88 | """Convert bounding boxes from (xmin, ymin, xmax, ymax) to (cx, cy, width, height) 89 | 90 | Args: 91 | boxes (tensor): Boxes, Shape: [num_priors, 4] 92 | """ 93 | 94 | return torch.cat([(boxes[:, 2:] + boxes[:, :2]) / 2, 95 | boxes[:, 2:] - boxes[:, :2]], 1) 96 | 97 | 98 | def center_to_corner_form(boxes): 99 | """Convert bounding boxes from (cx, cy, width, height) to (xmin, ymin, xmax, ymax) 100 | 101 | Args: 102 | boxes (tensor): Boxes, Shape: [num_priors, 4] 103 | """ 104 | 105 | return torch.cat([boxes[:, 2:] - (boxes[:, :2] / 2), 106 | boxes[:, 2:] + (boxes[:, :2] / 2)], 1) 107 | 108 | 109 | def nms(boxes, scores, thresh): 110 | x1 = boxes[:, 0] 111 | y1 = boxes[:, 1] 112 | x2 = boxes[:, 2] 113 | y2 = boxes[:, 3] 114 | 115 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 116 | _, indices = scores.scores.sort(0, descending=True) 117 | 118 | keep = [] 119 | while indices.size > 0: 120 | i = indices[0] 121 | keep.append(i) 122 | 123 | xx1 = torch.max(x1[i], x1[indices[1:]]) 124 | yy1 = torch.max(y1[i], y1[indices[1:]]) 125 | xx2 = torch.min(x2[i], x2[indices[1:]]) 126 | yy2 = torch.min(y2[i], y2[indices[1:]]) 127 | 128 | w = torch.clamp(xx2 - xx1, min=0.0) 129 | h = torch.clamp(yy2 - yy1, min=0.0) 130 | inter = w * h 131 | ovr = inter / (areas[i] + areas[indices[1:]] - inter) 132 | 133 | inds = torch.nonzero(ovr <= thresh).squeeze() 134 | indices = indices[inds + 1] 135 | 136 | return keep 137 | 138 | 139 | if __name__ == '__main__': 140 | """ 141 | layers = [ 142 | # inception4e 143 | {'size': 427, 'boxes': []}, 144 | # inception5b 145 | {'size': 715, 'boxes': []}, 146 | # inception6b 147 | {'size': 1291, 'boxes': []}, 148 | # inception7b 149 | {'size': 2443, 'boxes': []} 150 | ] 151 | 152 | # calculate the number of associated prior boxes for each layer 153 | for prior in priors_new: 154 | height, width = prior 155 | 156 | for layer in layers: 157 | if layer['size'] > (height * 2) or layer['size'] > (width * 2): 158 | layer['boxes'].append(prior) 159 | 160 | for layer in layers: 161 | print('Number of priors: ', len(layer['boxes'])) 162 | """ 163 | -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_upsampling_weight(channels, kernel_size): 6 | """Make a 2D bilinear kernel suitable for upsampling 7 | 8 | Based on: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py 9 | """ 10 | factor = (kernel_size + 1) // 2 11 | if kernel_size % 2 == 1: 12 | center = factor - 1 13 | else: 14 | center = factor - 0.5 15 | og = np.ogrid[:kernel_size, :kernel_size] 16 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 17 | filt = torch.from_numpy(filt) 18 | weight = torch.zeros([channels, 
channels, kernel_size, kernel_size], dtype=torch.float64)
19 |     weight[range(channels), range(channels), :, :] = filt
20 | 
21 |     return weight
22 | 
23 | 
24 | def assign_pix2box(semantics, offsets, boxes, labels):
25 |     return semantics
26 | 
--------------------------------------------------------------------------------
/utils/nms.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def non_maximum_suppression(boxes, scores, threshold=0.5):
 5 |     """Greedy NMS on (xmin, ymin, xmax, ymax) boxes; returns indices of kept boxes."""
 6 |     if boxes.numel() == 0:
 7 |         return torch.LongTensor()
 8 | 
 9 |     xmin = boxes[:, 0]
10 |     ymin = boxes[:, 1]
11 |     xmax = boxes[:, 2]
12 |     ymax = boxes[:, 3]
13 | 
14 |     areas = (xmax - xmin) * (ymax - ymin)
15 | 
16 |     _, indices = scores.sort(0, descending=True)
17 |     keep = []
18 |     while indices.numel() > 0:
19 |         i = indices[0]
20 |         keep.append(i.item())
21 | 
22 |         if indices.numel() == 1:
23 |             break
24 | 
25 |         # intersection of the highest-scoring box with the remaining candidates
26 |         xx1 = torch.clamp(xmin[indices[1:]], min=xmin[i].item())
27 |         yy1 = torch.clamp(ymin[indices[1:]], min=ymin[i].item())
28 |         xx2 = torch.clamp(xmax[indices[1:]], max=xmax[i].item())
29 |         yy2 = torch.clamp(ymax[indices[1:]], max=ymax[i].item())
30 | 
31 |         w = torch.clamp(xx2 - xx1, min=0.0)
32 |         h = torch.clamp(yy2 - yy1, min=0.0)
33 |         inter = w * h
34 |         iou = inter / (areas[i] + areas[indices[1:]] - inter)
35 | 
36 |         # keep only candidates whose overlap with the chosen box is below the threshold
37 |         indices = indices[1:][iou <= threshold]
38 | 
39 |     return torch.LongTensor(keep)
40 | 
--------------------------------------------------------------------------------
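A quick sanity check for the completed `non_maximum_suppression` above; the boxes and scores are made-up values for illustration only:

```python
import torch

from utils.nms import non_maximum_suppression

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],     # highest score -> kept
                      [1.0, 1.0, 10.0, 10.0],     # IoU of about 0.81 with the first box -> suppressed
                      [20.0, 20.0, 30.0, 30.0]])  # disjoint -> kept
scores = torch.tensor([0.9, 0.8, 0.7])

keep = non_maximum_suppression(boxes, scores, threshold=0.5)
print(keep)  # tensor([0, 2])
```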