├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── cityscapes.py
│   ├── kitti.py
│   └── transforms
│       ├── __init__.py
│       └── transforms.py
├── loss
│   ├── __init__.py
│   ├── boxloss.py
│   ├── focalloss.py
│   └── multitaskloss.py
├── metrics
│   ├── iou.py
│   └── mean_ap.py
├── models
│   ├── __init__.py
│   ├── box2pix.py
│   └── multibox.py
├── prediction
│   ├── __init__.py
│   └── predictor.py
├── requirements.txt
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── box_coder.py
    ├── box_utils.py
    ├── helper.py
    └── nms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Sphinx documentation
53 | docs/_build/
54 |
55 | # Jupyter Notebook
56 | .ipynb_checkpoints
57 |
58 | # pyenv
59 | .python-version
60 |
61 | # mkdocs documentation
62 | /site
63 |
64 |
65 | # Datasets folder
66 | data/
67 |
68 | # PyCharm
69 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Michael Kösel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [WIP] pytorch-box2pix 
2 |
3 | Unofficial PyTorch implementation of [Box2Pix: Single-Shot Instance Segmentation by Assigning Pixels to Object Boxes](https://lmb.informatik.uni-freiburg.de/Publications/2018/UB18) (Uhrig et al., 2018).
4 |
5 | ## TODO:
6 |
7 | This is needed to get the project in a state where it can be trained:
8 |
9 | - [ ] mAP metric
10 |
11 | Instance segmentation can be added later, as it is just a post-processing step.
12 |
13 | ## Requirements
14 |
15 | - Install PyTorch ([pytorch.org](http://pytorch.org))
16 | - `pip install -r requirements.txt` (Currently requires torchvision master)
17 | - Download the Cityscapes dataset
18 |
19 | ## Usage
20 |
21 | Train model:
22 |
23 | ```bash
24 | python train.py --dataset-dir 'data/cityscapes'
25 | ```
26 |
--------------------------------------------------------------------------------
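
Besides `train.py`, the repository also ships a small `test.py` entry point that runs the `Predictor` on a video file. Based on its argument parser (and assuming a trained model, since checkpoint loading is not wired up yet), the invocation would be:

```bash
python test.py path/to/video.mp4 --show-boxes
```
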
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cityscapes import *
2 | from .kitti import *
3 |
--------------------------------------------------------------------------------
/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision.datasets as datasets
4 | from PIL import Image
5 |
6 | from utils.box_utils import get_bounding_box
7 |
8 |
9 | class CityscapesDataset(datasets.Cityscapes):
10 |
11 | def __init__(self, root, split='train', joint_transform=None, img_transform=None):
12 | super(CityscapesDataset, self).__init__(root, split, target_type=['instance', 'polygon'])
13 |
14 | self.joint_transform = joint_transform
15 | self.img_transform = img_transform
16 |
17 | def __getitem__(self, index):
18 | image, target = super(CityscapesDataset, self).__getitem__(index)
19 | instance, json = target
20 |
21 | instance = self._convert_id_to_train_id(instance)
22 | boxes, labels = self._create_boxes(json)
23 |
24 | if self.joint_transform:
25 | image, instance, boxes, labels = self.joint_transform(image, instance, boxes, labels)
26 |
27 | if self.img_transform:
28 | image = self.img_transform(image)
29 |
30 | return image, instance, boxes, labels
31 |
32 | def _convert_id_to_train_id(self, instance):
33 | instance = np.array(instance)
34 | instance_copy = instance.copy()
35 |
36 | for cls in self.classes:
37 | instance_copy[instance == cls.id] = cls.train_id
38 | instance = Image.fromarray(instance_copy.astype(np.uint8))
39 |
40 | return instance
41 |
42 | def _create_boxes(self, json):
43 | boxes = []
44 | labels = []
45 | objects = json['objects']
46 | for obj in objects:
47 | polygons = obj['polygon']
48 | cls = self.get_class_from_name(obj['label'])
49 | if cls and cls.has_instances:
50 | boxes.append(get_bounding_box(polygons))
51 | labels.append(cls.id)
52 |
53 | return torch.as_tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)
54 |
55 | @staticmethod
56 | def get_class_from_name(name):
57 | for cls in CityscapesDataset.classes:
58 | if cls.name == name:
59 | return cls
60 | return None
61 |
62 | @staticmethod
63 | def get_class_from_id(id):
64 | for cls in CityscapesDataset.classes:
65 | if cls.id == id:
66 | return cls
67 | return None
68 |
69 | @staticmethod
70 | def get_instance_classes():
71 | return [cls for cls in CityscapesDataset.classes if cls.has_instances]
72 |
73 | @staticmethod
74 | def num_instance_classes():
75 | return len(CityscapesDataset.get_instance_classes())
76 |
77 | @staticmethod
78 | def get_colormap():
79 | cmap = torch.zeros([256, 3], dtype=torch.uint8)
80 |
81 | for cls in CityscapesDataset.classes:
82 | if cls.has_instances:
83 |                 cmap[cls.train_id, :] = torch.tensor(cls.color, dtype=torch.uint8)
84 |
85 | return cmap
86 |
--------------------------------------------------------------------------------
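
A minimal sketch of how this dataset class is meant to be consumed, mirroring the transform pipeline in `train.py` (it assumes the Cityscapes archives are unpacked under `data/cityscapes`; the printed shapes are illustrative):

```python
from torchvision.transforms import Normalize

from datasets.cityscapes import CityscapesDataset
from datasets.transforms.transforms import Compose, Resize, ToTensor

joint_transform = Compose([Resize((512, 1024)), ToTensor()])
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

dataset = CityscapesDataset('data/cityscapes', split='train',
                            joint_transform=joint_transform, img_transform=normalize)

image, instance, boxes, labels = dataset[0]
print(image.shape, instance.shape)  # e.g. torch.Size([3, 512, 1024]) torch.Size([512, 1024])
print(boxes.shape, labels.shape)    # one entry per annotated object, so this varies per image

# Because boxes and labels differ in length from image to image, batching with a plain
# DataLoader needs either batch_size=1 or a custom collate_fn.
```
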
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import torch.utils.data as data
5 | import torchvision
6 | from PIL import Image
7 | from matplotlib.patches import Rectangle
8 | 
9 | from datasets.transforms.transforms import Compose, RandomHorizontalFlip, Resize, ToTensor
19 |
20 |
21 | class KITTI(data.Dataset):
22 |     """`KITTI `_ semantic/instance segmentation dataset.
23 | 
24 |     Args:
25 |         root (string): Root directory of the dataset where the ``training`` and
26 |             ``testing`` folders (each containing ``image_2``) are located.
27 |         train (bool, optional): If True, use the ``training`` split (images and targets),
28 |             otherwise the ``testing`` split (images only).
29 |         target_type (string, optional): Type of target to use, ``instance``, ``semantic``
30 |             or ``semantic_rgb``.
31 |         joint_transform (callable, optional): A function/transform that takes in the image,
32 |             target and boxes and returns transformed versions of all of them.
33 |         img_transform (callable, optional): A function/transform that takes in the PIL image
34 |             and returns a transformed version. E.g, ``transforms.Normalize``
35 |     """
36 |
37 | def __init__(self, root, train=True, target_type='instance', joint_transform=None, img_transform=None):
38 | self.root = os.path.expanduser(root)
39 | split = 'training' if train else 'testing'
40 | self.images_dir = os.path.join(self.root, split, 'image_2')
41 | self.targets_dir = os.path.join(self.root, split, target_type)
42 | self.joint_transform = joint_transform
43 | self.img_transform = img_transform
44 | self.target_type = target_type
45 | self.train = train
46 | self.images = []
47 | self.targets = []
48 |
49 | if target_type not in ['instance', 'semantic', 'semantic_rgb']:
50 | raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic"'
51 | ' or "semantic_rgb"')
52 |
53 | if not os.path.isdir(self.images_dir) and not os.path.isdir(self.targets_dir):
54 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
55 | ' specified "split" and "mode" are inside the "root" directory')
56 |
57 | for file_name in os.listdir(self.images_dir):
58 | self.images.append(os.path.join(self.images_dir, file_name))
59 | if train:
60 | self.targets.append(os.path.join(self.targets_dir, file_name))
61 |
62 | def __getitem__(self, index):
63 | if self.train:
64 | image = Image.open(self.images[index]).convert('RGB')
65 | target = Image.open(self.targets[index])
66 |
67 | boxes = [(325, 170, 475, 240), (555, 165, 695, 220), (720, 155, 900, 240)]
68 | confs = torch.tensor([1, 1, 1])
69 |
70 | if self.joint_transform:
71 | image, target, boxes = self.joint_transform(image, target, boxes)
72 |
73 | if self.img_transform:
74 | image = self.img_transform(image)
75 |
76 | return image, target, boxes, confs
77 | else:
78 | image = Image.open(self.images[index]).convert('RGB')
79 |
80 | if self.img_transform:
81 | image = self.img_transform(image)
82 |
83 | return image
84 |
85 | def __len__(self):
86 | return len(self.images)
87 |
88 |
89 | if __name__ == '__main__':
90 | import matplotlib.pyplot as plt
91 |
92 | joint_transforms = Compose([
93 | RandomHorizontalFlip(),
94 | #ToTensor()
95 | ])
96 |
97 | dataset = KITTI('../data/kitti', train=True, joint_transform=joint_transforms)
98 | img, inst, bboxes, confs = dataset[10]
99 |
100 | #print('Box size: ', bboxes.size())
101 | #print('Instance size: ', inst.size())
102 | #img = torchvision.transforms.functional.to_pil_image(img)
103 | #plt.imshow(img)
104 |
105 | #inst = torchvision.transforms.functional.to_pil_image(inst)
106 | plt.imshow(inst)
107 | ax = plt.gca()
108 |
109 | for i, box in enumerate(bboxes):
110 | xmin, ymin, xmax, ymax = box
111 | width = xmax - xmin
112 | height = ymax - ymin
113 |
114 | rect = Rectangle((xmin, ymin), width, height, linewidth=1, edgecolor='r', facecolor='none')
115 | ax.add_patch(rect)
116 |
117 | plt.show()
118 |
--------------------------------------------------------------------------------
/datasets/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .transforms import *
2 |
--------------------------------------------------------------------------------
/datasets/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms.functional as F
6 | from PIL import Image
7 | from skimage.filters import gaussian
8 |
9 |
10 | class Compose(object):
11 | def __init__(self, transforms):
12 | self.transforms = transforms
13 |
14 | def __call__(self, img, inst, boxes, labels):
15 | for t in self.transforms:
16 | img, inst, boxes, labels = t(img, inst, boxes, labels)
17 | return img, inst, boxes, labels
18 |
19 |
20 | class ToTensor(object):
21 | def __call__(self, img, inst, boxes, labels):
22 | img = F.to_tensor(img)
23 |         inst = torch.as_tensor(np.array(inst), dtype=torch.long)  # to_tensor would rescale the label ids to [0, 1]
24 |
25 | return img, inst, boxes, labels
26 |
27 |
28 | class Resize(object):
29 | def __init__(self, new_size, old_size=(1024, 2048)):
30 | self.old_size = old_size
31 | self.new_size = new_size
32 |
33 | self.xscale = self.new_size[1] / self.old_size[1]
34 | self.yscale = self.new_size[0] / self.old_size[0]
35 |
36 | def __call__(self, img, inst, boxes, labels):
37 | img = F.resize(img, self.new_size, interpolation=Image.BILINEAR)
38 | inst = F.resize(inst, self.new_size, interpolation=Image.NEAREST)
39 | boxes = self._resize_boxes(boxes)
40 |
41 | return img, inst, boxes, labels
42 |
43 | def _resize_boxes(self, boxes):
44 | boxes = boxes.clone()
45 | boxes[:, 0] *= self.xscale
46 | boxes[:, 1] *= self.yscale
47 | boxes[:, 2] *= self.xscale
48 | boxes[:, 3] *= self.yscale
49 |
50 | return boxes
51 |
52 |
53 | class RandomHorizontalFlip(object):
54 | def __init__(self, p=0.5):
55 | self.p = p
56 |
57 | def __call__(self, img, inst, boxes, labels):
58 | if random.random() < self.p:
59 | img = F.hflip(img)
60 | inst = F.hflip(inst)
61 | boxes = self._hflip_boxes(img.size[0], boxes)
62 |
63 | return img, inst, boxes, labels
64 |
65 | def _hflip_boxes(self, width, boxes):
66 | boxes = boxes.clone()
67 | box_width = boxes[:, 2] - boxes[:, 0]
68 | boxes[:, 2] = width - boxes[:, 0]
69 | boxes[:, 0] = boxes[:, 2] - box_width
70 |
71 | return boxes
72 |
73 |
74 | class RandomGaussianBlur(object):
75 |     def __init__(self, sigma=(0.15, 1.15)):
76 |         self.sigma = sigma
77 | 
78 |     def __call__(self, img, inst, boxes, labels):
79 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
80 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
81 | blurred_img *= 255
82 | img = Image.fromarray(blurred_img.astype(np.uint8))
83 |
84 | return img, inst, boxes, labels
85 |
86 |
87 | class RandomScale(object):
88 | def __init__(self, scale=1.0):
89 | self.scale = scale
90 |
91 | def __call__(self, img, inst, boxes, labels):
92 | scale = random.uniform(1.0, self.scale)
93 |
94 | img = F.affine(img, 0, (0, 0), scale, 0)
95 | inst = F.affine(inst, 0, (0, 0), scale, 0)
96 |
97 | return img, inst, boxes, labels
98 |
99 |
100 | class ColorJitter(object):
101 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
102 | from torchvision import transforms
103 | self.transform = transforms.ColorJitter(brightness, contrast, saturation, hue)
104 |
105 | def __call__(self, img, inst, boxes, labels):
106 | img = self.transform(img)
107 |
108 | return img, inst, boxes, labels
109 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .boxloss import *
2 | from .focalloss import *
3 | from .multitaskloss import *
4 |
--------------------------------------------------------------------------------
/loss/boxloss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from loss.focalloss import FocalLoss
4 |
5 |
6 | class BoxLoss(nn.Module):
7 |
8 |     def __init__(self, num_classes=11, gamma=2, reduction='sum'):
9 | super(BoxLoss, self).__init__()
10 | self.num_classes = num_classes
11 | self.reduction = reduction
12 | self.focal_loss = FocalLoss(gamma, reduction=reduction)
13 | self.l2_loss = nn.MSELoss(reduction=reduction)
14 |
15 | def forward(self, loc_pred, loc_target, conf_pred, labels_target):
16 | # find only non-background predictions
17 | positives = labels_target > 0
18 | predicted_loc = loc_pred[positives, :].reshape(-1, 4)
19 | groundtruth_loc = loc_target[positives, :].reshape(-1, 4)
20 |
21 | predicted_conf = conf_pred[positives, :].reshape(-1, self.num_classes)
22 |         groundtruth_label = labels_target[positives]
23 | 
24 |         loc_loss = self.l2_loss(predicted_loc, groundtruth_loc)
25 |         conf_loss = self.focal_loss(predicted_conf, groundtruth_label)
26 | 
27 |         num_positives = max(positives.sum().item(), 1)
28 |
29 | return loc_loss / num_positives, conf_loss / num_positives
30 |
--------------------------------------------------------------------------------
/loss/focalloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FocalLoss(nn.Module):
7 |     r"""Focal Loss for Dense Object Detection
8 |
9 |
10 | .. math::
11 | \text{loss}(p_{t}) = -(1-p_{t})^ \gamma \cdot \log(p_{t})
12 |
13 | Args:
14 |         gamma (int, optional): Gamma smoothly adjusts the rate at which easy examples
15 |             are down-weighted. If gamma equals 0, the loss is equivalent to cross entropy. Default: 1
16 | reduction (string, optional): Specifies the reduction to apply to the output:
17 | 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
18 | 'mean': the sum of the output will be divided by the number of
19 | elements in the output, 'sum': the output will be summed. Default: 'mean'
20 | """
21 |
22 | def __init__(self, gamma=1, reduction='mean'):
23 | super(FocalLoss, self).__init__()
24 |
25 | self.gamma = gamma
26 | self.reduction = reduction
27 |
28 | def forward(self, input, target):
29 | log_pt = -F.cross_entropy(input, target, reduction='none')
30 | pt = torch.exp(log_pt)
31 |
32 | loss = -torch.pow(1 - pt, self.gamma) * log_pt
33 |
34 | if self.reduction == 'mean':
35 | return loss.mean()
36 | elif self.reduction == 'sum':
37 | return loss.sum()
38 | else:
39 | return loss
40 |
--------------------------------------------------------------------------------
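
A quick numerical sanity check (not part of the repo) of the docstring's claim that gamma = 0 reduces the focal loss to plain cross entropy, and that gamma > 0 only down-weights:

```python
import torch
import torch.nn.functional as F

from loss.focalloss import FocalLoss

torch.manual_seed(0)
logits = torch.randn(8, 11)           # 8 predictions, 11 classes
targets = torch.randint(0, 11, (8,))

fl0 = FocalLoss(gamma=0, reduction='mean')
print(torch.allclose(fl0(logits, targets), F.cross_entropy(logits, targets)))  # True

# With gamma > 0 every per-example term is scaled by (1 - p_t)^gamma <= 1,
# so the focal loss can never exceed the cross entropy loss.
print(FocalLoss(gamma=2)(logits, targets) <= fl0(logits, targets))  # tensor(True)
```
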
/loss/multitaskloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MultiTaskLoss(nn.Module):
6 | def __init__(self):
7 | super(MultiTaskLoss, self).__init__()
8 |
9 | self.uncert_semantics = nn.Parameter(torch.zeros(1, requires_grad=True))
10 | self.uncert_offsets = nn.Parameter(torch.zeros(1, requires_grad=True))
11 | self.uncert_ssdbox = nn.Parameter(torch.zeros(1, requires_grad=True))
12 | self.uncert_ssdclass = nn.Parameter(torch.zeros(1, requires_grad=True))
13 |
14 | def forward(self, semantics_loss, offsets_loss, box_loss, conf_loss):
15 | loss1 = 0.5 * torch.exp(-self.uncert_semantics) * semantics_loss + self.uncert_semantics
16 | loss2 = torch.exp(-self.uncert_offsets) * offsets_loss + self.uncert_offsets
17 | loss3 = torch.exp(-self.uncert_ssdbox) * box_loss + self.uncert_ssdbox
18 | loss4 = 0.5 * torch.exp(-self.uncert_ssdclass) * conf_loss + self.uncert_ssdclass
19 |
20 | loss = loss1 + loss2 + loss3 + loss4
21 |
22 | return loss
23 |
--------------------------------------------------------------------------------
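
The four `uncert_*` parameters act as learned per-task weights (in the spirit of learning task uncertainties instead of hand-tuning loss weights). A minimal usage sketch with dummy scalar losses; in `train.py` the same criterion receives the real head losses and its parameters are added to the optimizer:

```python
import torch

from loss.multitaskloss import MultiTaskLoss

criterion = MultiTaskLoss()

# dummy per-task losses standing in for the semantics, offsets and SSD box/class heads
total = criterion(torch.tensor(1.3), torch.tensor(0.7), torch.tensor(0.4), torch.tensor(0.9))
print(total)  # a 1-element tensor that depends on the learned uncertainty terms

# the weights are nn.Parameters and must be optimized jointly with the model
print([name for name, _ in criterion.named_parameters()])
# ['uncert_semantics', 'uncert_offsets', 'uncert_ssdbox', 'uncert_ssdclass']
```
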
/metrics/iou.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ignite.metrics import Metric
3 |
4 |
5 | class IntersectionOverUnion(Metric):
6 | """Computes the intersection over union (IoU) per class.
7 |
8 | based on: https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
9 |
10 | - `update` must receive output of the form `(y_pred, y)`.
11 | """
12 |
13 | def __init__(self, num_classes=10, ignore_index=255, output_transform=lambda x: x):
14 | self.num_classes = num_classes
15 | self.ignore_index = ignore_index
16 | self.confusion_matrix = np.zeros((num_classes, num_classes))
17 |
18 | super(IntersectionOverUnion, self).__init__(output_transform=output_transform)
19 |
20 | def _fast_hist(self, label_true, label_pred):
21 | # mask = (label_true >= 0) & (label_true < self.num_classes)
22 | mask = label_true != self.ignore_index
23 |         hist = np.bincount(self.num_classes * label_true[mask].astype(np.int64) + label_pred[mask],
24 | minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes)
25 | return hist
26 |
27 | def reset(self):
28 | self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))
29 |
30 | def update(self, output):
31 | y_pred, y = output
32 |
33 | for label_true, label_pred in zip(y.numpy(), y_pred.numpy()):
34 | self.confusion_matrix += self._fast_hist(label_true.flatten(), label_pred.flatten())
35 |
36 | def compute(self):
37 | hist = self.confusion_matrix
38 | with np.errstate(divide='ignore', invalid='ignore'):
39 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
40 |
41 | return np.nanmean(iu)
42 |
--------------------------------------------------------------------------------
/metrics/mean_ap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ignite.metrics import Metric
3 |
4 |
5 | class MeanAveragePrecision(Metric):
6 |
7 | def __init__(self, num_classes=20, output_transform=lambda x: x):
8 | super(MeanAveragePrecision, self).__init__(output_transform=output_transform)
9 |
10 | self.num_classes = num_classes
11 |
12 | def reset(self):
13 | self._true_boxes = torch.tensor([], dtype=torch.long)
14 | self._true_labels = torch.tensor([], dtype=torch.long)
15 |
16 | self._det_boxes = torch.tensor([], dtype=torch.float32)
17 | self._det_labels = torch.tensor([], dtype=torch.float32)
18 | self._det_scores = torch.tensor([], dtype=torch.float32)
19 |
20 | def update(self, output):
21 | boxes_preds, labels_preds, scores_preds, boxes, labels = output
22 |
23 | self._true_boxes = torch.cat([self._true_boxes, boxes], dim=0)
24 | self._true_labels = torch.cat([self._true_labels, labels], dim=0)
25 |
26 | self._det_boxes = torch.cat([self._det_boxes, boxes_preds], dim=0)
27 | self._det_labels = torch.cat([self._det_labels, labels_preds], dim=0)
28 | self._det_scores = torch.cat([self._det_scores, scores_preds], dim=0)
29 |
30 | def compute(self):
31 | for c in range(1, self.num_classes):
32 | pass
33 |
--------------------------------------------------------------------------------
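
`compute()` is still a stub, matching the open "mAP metric" item in the README. The sketch below shows the per-class average precision calculation it would need (greedy IoU-0.5 matching, all-point interpolation); the helper names and the single-image simplification are assumptions, not the repo's final design:

```python
import torch


def box_iou(a, b):
    """Pairwise IoU between two sets of boxes in (xmin, ymin, xmax, ymax) form."""
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = torch.max(a[:, None, :2], b[None, :, :2])
    rb = torch.min(a[:, None, 2:], b[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def average_precision(det_boxes, det_scores, gt_boxes, iou_thresh=0.5):
    """AP for a single class, with all detections/ground truths taken from one image."""
    if det_boxes.numel() == 0:
        return 0.0

    order = det_scores.argsort(descending=True)
    det_boxes = det_boxes[order]

    matched = torch.zeros(gt_boxes.size(0), dtype=torch.bool)
    tp = torch.zeros(det_boxes.size(0))
    fp = torch.zeros(det_boxes.size(0))
    ious = box_iou(det_boxes, gt_boxes) if gt_boxes.numel() > 0 else None

    for d in range(det_boxes.size(0)):
        if ious is None:
            fp[d] = 1
            continue
        iou, g = ious[d].max(0)
        if iou >= iou_thresh and not matched[g]:
            tp[d], matched[g] = 1, True
        else:
            fp[d] = 1

    tp_cum, fp_cum = tp.cumsum(0), fp.cumsum(0)
    recall = tp_cum / max(gt_boxes.size(0), 1)
    precision = tp_cum / (tp_cum + fp_cum)

    # integrate the precision envelope over recall
    ap, prev_r = 0.0, 0.0
    for d in range(det_boxes.size(0)):
        ap += (recall[d].item() - prev_r) * precision[d:].max().item()
        prev_r = recall[d].item()
    return ap
```

`compute()` would run this per class (with image ids threaded through the accumulated tensors so detections only match ground truths from the same image) and average the per-class results into the final mAP.
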
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .box2pix import *
2 | from .multibox import *
3 |
--------------------------------------------------------------------------------
/models/box2pix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils import model_zoo
4 | from torchvision import models
5 | from torchvision.models.googlenet import BasicConv2d, Inception
6 |
7 | from models.multibox import MultiBox
8 | from utils.helper import get_upsampling_weight
9 |
10 |
11 | def box2pix(num_classes=11, pretrained=False, **kwargs):
12 | if pretrained:
13 | if 'transform_input' not in kwargs:
14 | kwargs['transform_input'] = True
15 | model = Box2Pix(num_classes, **kwargs)
16 | model.load_state_dict(model_zoo.load_url(''))
17 | return model
18 |
19 | return Box2Pix(num_classes, **kwargs)
20 |
21 |
22 | class Box2Pix(nn.Module):
23 | """
24 | Implementation of Box2Pix: Single-Shot Instance Segmentation by Assigning Pixels to Object Boxes
25 |
26 | """
27 |
28 | def __init__(self, num_classes=11, transform_input=False):
29 | super(Box2Pix, self).__init__()
30 | self.transform_input = transform_input
31 |
32 | self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
33 | self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
34 | self.conv2 = BasicConv2d(64, 64, kernel_size=1)
35 | self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
36 |
37 | self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
38 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
39 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
40 |
41 | self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
42 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
43 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
44 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
45 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
46 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
47 |
48 | self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
49 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
50 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
51 |
52 | self.maxpool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53 | self.inception6a = Inception2(1024, 256, 160, 320, 32, 128, 128)
54 | self.inception6b = Inception2(832, 384, 192, 384, 48, 128, 128)
55 |
56 | self.maxpool6 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57 | self.inception7a = Inception2(1024, 256, 160, 320, 32, 128, 128)
58 | self.inception7b = Inception2(832, 384, 192, 384, 48, 128, 128)
59 |
60 | self.sem_score3b = nn.Conv2d(480, num_classes, kernel_size=1)
61 | self.sem_score4e = nn.Conv2d(832, num_classes, kernel_size=1)
62 | self.sem_score5b = nn.Conv2d(1024, num_classes, kernel_size=1)
63 | self.sem_score6b = nn.Conv2d(1024, num_classes, kernel_size=1)
64 | self.sem_upscore = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
65 | self.sem_upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
66 | self.sem_upscore4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
67 | self.sem_upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, bias=False)
68 |
69 | self.offs_score3b = nn.Conv2d(480, 2, kernel_size=1)
70 | self.offs_score4e = nn.Conv2d(832, 2, kernel_size=1)
71 | self.offs_score5b = nn.Conv2d(1024, 2, kernel_size=1)
72 | self.offs_score6b = nn.Conv2d(1024, 2, kernel_size=1)
73 | self.offs_upscore = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False)
74 | self.offs_upscore2 = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False)
75 | self.offs_upscore4 = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, bias=False)
76 | self.offs_upscore8 = nn.ConvTranspose2d(2, 2, kernel_size=16, stride=8, bias=False)
77 |
78 | self.multibox = MultiBox(num_classes)
79 | self._initialize_weights(num_classes)
80 |
81 | def _initialize_weights(self, num_classes):
82 | for m in self.modules():
83 | if isinstance(m, nn.Conv2d):
84 | if m.kernel_size[0] == 1 and m.out_channels in [num_classes, 2]:
85 | nn.init.constant_(m.weight, 0)
86 | nn.init.constant_(m.bias, 0)
87 | else:
88 | nn.init.xavier_uniform_(m.weight)
89 | elif isinstance(m, nn.ConvTranspose2d):
90 | upsampling_weight = get_upsampling_weight(m.out_channels, m.kernel_size[0])
91 | with torch.no_grad():
92 | m.weight.copy_(upsampling_weight)
93 | elif isinstance(m, nn.BatchNorm2d):
94 | nn.init.constant_(m.weight, 1)
95 | nn.init.constant_(m.bias, 0)
96 |
97 | def init_from_googlenet(self):
98 | googlenet = models.googlenet(pretrained=True)
99 | self.load_state_dict(googlenet.state_dict(), strict=False)
100 | self.transform_input = True
101 |
102 | def _transform_input(self, x):
103 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
104 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
105 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
106 |
107 | return torch.cat([x_ch0, x_ch1, x_ch2], 1)
108 |
109 | def forward(self, x):
110 | feature_maps = []
111 | size = x.size()
112 |
113 | if self.transform_input:
114 | x = self._transform_input(x)
115 | x = self.conv1(x)
116 | x = self.maxpool1(x)
117 | x = self.conv2(x)
118 | x = self.conv3(x)
119 | x = self.maxpool2(x)
120 | x = self.inception3a(x)
121 | inception3b = self.inception3b(x)
122 | x = self.maxpool3(inception3b)
123 | x = self.inception4a(x)
124 | x = self.inception4b(x)
125 | x = self.inception4c(x)
126 | x = self.inception4d(x)
127 | inception4e = self.inception4e(x)
128 | feature_maps.append(inception4e)
129 |
130 | x = self.maxpool4(inception4e)
131 | x = self.inception5a(x)
132 | inception5b = self.inception5b(x)
133 | feature_maps.append(inception5b)
134 |
135 | x = self.maxpool5(inception5b)
136 | x = self.inception6a(x)
137 | inception6b = self.inception6b(x)
138 | feature_maps.append(inception6b)
139 |
140 | x = self.maxpool6(inception6b)
141 | x = self.inception7a(x)
142 | inception7b = self.inception7b(x)
143 | feature_maps.append(inception7b)
144 |
145 | loc_preds, conf_preds = self.multibox(feature_maps)
146 |
147 | sem_score6b = self.sem_score6b(inception6b)
148 | sem_score5b = self.sem_score5b(inception5b)
149 | semantics = self.sem_upscore(sem_score6b)
150 | semantics = semantics[:, :, 1:1 + sem_score5b.size()[2], 1:1 + sem_score5b.size()[3]]
151 | semantics += sem_score5b
152 | sem_score4e = self.sem_score4e(inception4e)
153 | semantics = self.sem_upscore2(semantics)
154 | semantics = semantics[:, :, 1:1 + sem_score4e.size()[2], 1:1 + sem_score4e.size()[3]]
155 | semantics += sem_score4e
156 | sem_score3b = self.sem_score3b(inception3b)
157 | semantics = self.sem_upscore4(semantics)
158 | semantics = semantics[:, :, 1:1 + sem_score3b.size()[2], 1:1 + sem_score3b.size()[3]]
159 | semantics += sem_score3b
160 | semantics = self.sem_upscore8(semantics)
161 | semantics = semantics[:, :, 4:4 + size[2], 4:4 + size[3]].contiguous()
162 |
163 | offs_score6b = self.offs_score6b(inception6b)
164 | offs_score5b = self.offs_score5b(inception5b)
165 | offsets = self.offs_upscore(offs_score6b)
166 | offsets = offsets[:, :, 1:1 + offs_score5b.size()[2], 1:1 + offs_score5b.size()[3]]
167 | offsets += offs_score5b
168 | offs_score4e = self.offs_score4e(inception4e)
169 | offsets = self.offs_upscore2(offsets)
170 | offsets = offsets[:, :, 1:1 + offs_score4e.size()[2], 1:1 + offs_score4e.size()[3]]
171 | offsets += offs_score4e
172 | offs_score3b = self.offs_score3b(inception3b)
173 | offsets = self.offs_upscore4(offsets)
174 | offsets = offsets[:, :, 1:1 + offs_score3b.size()[2], 1:1 + offs_score3b.size()[3]]
175 | offsets += offs_score3b
176 | offsets = self.offs_upscore8(offsets)
177 | offsets = offsets[:, :, 4:4 + size[2], 4:4 + size[3]].contiguous()
178 |
179 | return loc_preds, conf_preds, semantics, offsets
180 |
181 |
182 | class Inception2(nn.Module):
183 |
184 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
185 | super(Inception2, self).__init__()
186 |
187 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
188 | self.branch2 = nn.Sequential(
189 | BasicConv2d(in_channels, ch3x3red, kernel_size=1),
190 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
191 | )
192 | self.branch3 = nn.Sequential(
193 | BasicConv2d(in_channels, ch5x5red, kernel_size=1),
194 | BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
195 | )
196 | self.branch4 = nn.Sequential(
197 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
198 | BasicConv2d(in_channels, pool_proj, kernel_size=1)
199 | )
200 |
201 | def forward(self, x):
202 | branch1 = self.branch1(x)
203 | branch2 = self.branch2(x)
204 | branch3 = self.branch3(x)
205 | branch4 = self.branch4(x)
206 |
207 | outputs = [branch1, branch2, branch3, branch4]
208 | return torch.cat(outputs, 1)
209 |
210 |
211 | if __name__ == '__main__':
212 |     num_classes, height, width = 20, 1024, 2048
213 |
214 | model = Box2Pix(num_classes) # .to('cuda')
215 | model.init_from_googlenet()
216 | inp = torch.randn(1, 3, height, width) # .to('cuda')
217 |
218 | loc, conf, sem, offs = model(inp)
219 |
220 | assert loc.size(2) == 4
221 | assert conf.size(2) == num_classes
222 | assert sem.size() == torch.Size([1, num_classes, height, width])
223 | assert offs.size() == torch.Size([1, 2, height, width])
224 |
225 | print('Pass size check.')
226 |
--------------------------------------------------------------------------------
/models/multibox.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MultiBox(nn.Module):
6 |
7 | def __init__(self, num_classes):
8 | super(MultiBox, self).__init__()
9 |
10 | self.num_classes = num_classes
11 | self.loc_layers = nn.ModuleList()
12 | self.conf_layers = nn.ModuleList()
13 |
14 | num_defaults = [16, 16, 20, 21]
15 | in_channels = [832, 1024, 1024, 1024]
16 |
17 | for i in range(len(in_channels)):
18 | self.loc_layers.append(nn.Conv2d(in_channels[i], num_defaults[i] * 4, kernel_size=1))
19 | self.conf_layers.append(nn.Conv2d(in_channels[i], num_defaults[i] * num_classes, kernel_size=1))
20 |
21 | def forward(self, input):
22 | loc_preds = []
23 | conf_preds = []
24 |
25 | for i, layer in enumerate(input):
26 | loc = self.loc_layers[i](layer)
27 | # (N x C x H x W) -> (N x H x W x C)
28 | loc = loc.permute(0, 2, 3, 1).contiguous()
29 | loc = loc.view(loc.size(0), -1, 4)
30 | loc_preds.append(loc)
31 |
32 | conf = self.conf_layers[i](layer)
33 | # (N x C x H x W) -> (N x H x W x C)
34 | conf = conf.permute(0, 2, 3, 1).contiguous()
35 | conf = conf.view(conf.size(0), -1, self.num_classes)
36 | conf_preds.append(conf)
37 |
38 | loc_preds = torch.cat(loc_preds, 1)
39 | conf_preds = torch.cat(conf_preds, 1)
40 |
41 | return loc_preds, conf_preds
42 |
--------------------------------------------------------------------------------
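
The hard-coded `num_defaults = [16, 16, 20, 21]` are how many of the 21 prior shapes get assigned to each of the four source layers (inception4e, 5b, 6b, 7b). They can be reproduced from the prior list in `utils/box_utils.py` and the per-layer receptive field sizes used in `utils/box_coder.py`, with the same assignment rule used there:

```python
from utils.box_utils import priors_new

# receptive field sizes of inception4e, inception5b, inception6b, inception7b
receptive_sizes = [427, 715, 1291, 2443]

num_defaults = []
for size in receptive_sizes:
    # a prior shape is handled by a layer whose receptive field is more than
    # twice as large as at least one of the prior's sides
    num_defaults.append(sum(1 for h, w in priors_new if size > 2 * h or size > 2 * w))

print(num_defaults)  # [16, 16, 20, 21] -- the values used by MultiBox above
```
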
/prediction/__init__.py:
--------------------------------------------------------------------------------
1 | from .predictor import *
2 |
--------------------------------------------------------------------------------
/prediction/predictor.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | import torch
4 | import torchvision.transforms as transforms
5 |
6 | from models.box2pix import Box2Pix
7 |
8 |
9 | class Predictor(object):
10 |
11 | def __init__(self, show_segmentation=True, show_labels=False, show_boxes=False):
12 | self.show_segmentation = show_segmentation
13 | self.show_labels = show_labels
14 | self.show_boxes = show_boxes
15 |
16 | self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
17 | self.net = Box2Pix().to(self.device)
18 | self.net.eval()
19 |
20 | self.transform = self.get_transform()
21 |
22 | def get_transform(self):
23 | transform = transforms.Compose([
24 | transforms.ToPILImage(),
25 | transforms.Resize((512, 1024)),
26 | transforms.ToTensor(),
27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28 | ])
29 |
30 | return transform
31 |
32 | def run(self, image):
33 | result = image.copy()
34 |
35 | result = self.transform(result)
36 | result = result.unsqueeze(0)
37 | result = result.to(self.device)
38 |
39 | with torch.no_grad():
40 | loc_preds, conf_preds, semantics_pred, offsets_pred = self.net(result)
41 |
42 | if self.show_segmentation:
43 | result = self.add_segmentation_overlay(result, None)
44 |
45 | if self.show_boxes:
46 | result = self.add_boxes_overlay(result, None)
47 |
48 | if self.show_labels:
49 | result = self.add_overlay_classes(result, None)
50 |
51 | return result
52 |
53 | def add_boxes_overlay(self, image, predictions):
54 |
55 | for box in predictions:
56 | top_left, bottom_right = tuple(box[:2].tolist()), tuple(box[2:].tolist())
57 | image = cv2.rectangle(image, top_left, bottom_right, 0, 1)
58 |
59 | return image
60 |
61 | def add_segmentation_overlay(self, image, predictions):
62 | return image
63 |
64 | def add_overlay_classes(self, image, predictions):
65 | scores = [0]
66 | labels = [0]
67 | boxes = predictions
68 |
69 | for box, score, label in zip(boxes, scores, labels):
70 | x, y = box[:2]
71 | cv2.putText(image, 'car', (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
72 |
73 | return image
74 |
75 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | pytorch-ignite
4 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import cv2
4 |
5 | from prediction.predictor import Predictor
6 |
7 |
8 | def run(file_path, show_boxes):
9 | cap = cv2.VideoCapture(file_path)
10 |     detector = Predictor(show_boxes=show_boxes)
11 |
12 | while cap.isOpened():
13 | ret, frame = cap.read()
14 | res = detector.run(frame)
15 |
16 | cv2.imshow('frame', res)
17 |
18 | if cv2.waitKey(25) & 0xFF == ord('q'):
19 | break
20 |
21 | cap.release()
22 | cv2.destroyAllWindows()
23 |
24 |
25 | if __name__ == '__main__':
26 | parser = ArgumentParser('Box2Pix with PyTorch')
27 | parser.add_argument('input', help='input video file')
28 | parser.add_argument('--show-boxes', action='store_true',
29 | help='whether or not to also display boxes in the result')
30 |
31 | args = parser.parse_args()
32 |
33 | if args.input is not None:
34 | run(args.input, args.show_boxes)
35 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from argparse import ArgumentParser
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from ignite.contrib.handlers import ProgressBar
11 | from ignite.contrib.handlers.tensorboard_logger import *
12 | from ignite.engine import Events, Engine
13 | from ignite.handlers import ModelCheckpoint, Timer
14 | from ignite.metrics import RunningAverage, Loss, ConfusionMatrix, mIoU
15 | from ignite.utils import convert_tensor
16 | from torch.utils.data import DataLoader
17 | from torchvision.transforms import Normalize
18 |
19 | import models
20 | from datasets.cityscapes import CityscapesDataset
21 | from datasets.transforms import transforms
22 | from datasets.transforms.transforms import ToTensor
23 | from loss.boxloss import BoxLoss
24 | from loss.multitaskloss import MultiTaskLoss
25 | from metrics.mean_ap import MeanAveragePrecision
26 | from utils import helper
27 | from utils.box_coder import BoxCoder
28 |
29 |
30 | def get_data_loaders(data_dir, batch_size, num_workers):
31 | # new_size = (512, 1024) # (1024, 2048)
32 |
33 | joint_transform = transforms.Compose([
34 | # transforms.Resize(new_size),
35 | transforms.RandomHorizontalFlip(),
36 | transforms.ColorJitter(0.2, 0.2, 0.2),
37 |         transforms.RandomGaussianBlur(sigma=(0, 1.2)),
38 | transforms.ToTensor()
39 | ])
40 |
41 | normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
42 |
43 | train_loader = DataLoader(CityscapesDataset(root=data_dir, split='train', joint_transform=joint_transform,
44 | img_transform=normalize),
45 | batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
46 |
47 | val_loader = DataLoader(CityscapesDataset(root=data_dir, split='val', joint_transform=ToTensor(),
48 | img_transform=normalize),
49 | batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
50 |
51 | return train_loader, val_loader
52 |
53 |
54 | def run(args):
55 |     train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size, args.num_workers)
56 |
57 | if args.seed is not None:
58 | torch.manual_seed(args.seed)
59 |
60 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
61 |
62 | num_classes = CityscapesDataset.num_instance_classes() + 1
63 | model = models.box2pix(num_classes)
64 | model.init_from_googlenet()
65 |
66 | if torch.cuda.device_count() > 1:
67 | print("Using %d GPU(s)" % torch.cuda.device_count())
68 | model = nn.DataParallel(model)
69 |
70 | model = model.to(device)
71 |
72 | semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
73 | offsets_criterion = nn.MSELoss()
74 | box_criterion = BoxLoss(num_classes, gamma=2)
75 | multitask_criterion = MultiTaskLoss().to(device)
76 |
77 | box_coder = BoxCoder()
78 | optimizer = optim.Adam([{'params': model.parameters()},
79 | {'params': multitask_criterion.parameters()}], lr=args.lr)
80 |
81 | if args.resume:
82 | if os.path.isfile(args.resume):
83 | print("Loading checkpoint '{}'".format(args.resume))
84 | checkpoint = torch.load(args.resume)
85 | args.start_epoch = checkpoint['epoch']
86 | model.load_state_dict(checkpoint['model'])
87 | optimizer.load_state_dict(checkpoint['optimizer'])
88 | multitask_criterion.load_state_dict(checkpoint['multitask'])
89 | print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch']))
90 | else:
91 | print("No checkpoint found at '{}'".format(args.resume))
92 |
93 | def _prepare_batch(batch, non_blocking=True):
94 | x, instance, boxes, labels = batch
95 |
96 | return (convert_tensor(x, device=device, non_blocking=non_blocking),
97 | convert_tensor(instance, device=device, non_blocking=non_blocking),
98 | convert_tensor(boxes, device=device, non_blocking=non_blocking),
99 | convert_tensor(labels, device=device, non_blocking=non_blocking))
100 |
101 | def _update(engine, batch):
102 | model.train()
103 | optimizer.zero_grad()
104 | x, instance, boxes, labels = _prepare_batch(batch)
105 | boxes, labels = box_coder.encode(boxes, labels)
106 |
107 | loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)
108 |
109 | semantics_loss = semantics_criterion(semantics_pred, instance)
110 | offsets_loss = offsets_criterion(offsets_pred, instance)
111 | box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)
112 |
113 | loss = multitask_criterion(semantics_loss, offsets_loss, box_loss, conf_loss)
114 |
115 | loss.backward()
116 | optimizer.step()
117 |
118 | return {
119 | 'loss': loss.item(),
120 | 'loss_semantics': semantics_loss.item(),
121 | 'loss_offsets': offsets_loss.item(),
122 | 'loss_ssdbox': box_loss.item(),
123 | 'loss_ssdclass': conf_loss.item()
124 | }
125 |
126 | trainer = Engine(_update)
127 |
128 | checkpoint_handler = ModelCheckpoint(args.output_dir, 'checkpoint', save_interval=1, n_saved=10,
129 | require_empty=False, create_dir=True, save_as_state_dict=False)
130 | timer = Timer(average=True)
131 |
132 | # attach running average metrics
133 | train_metrics = ['loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox', 'loss_ssdclass']
134 | for m in train_metrics:
135 | transform = partial(lambda x, metric: x[metric], metric=m)
136 | RunningAverage(output_transform=transform).attach(trainer, m)
137 |
138 | # attach progress bar
139 | pbar = ProgressBar(persist=True)
140 | pbar.attach(trainer, metric_names=train_metrics)
141 |
142 |     def _save_checkpoint(engine):  # build the dict lazily so the current state is saved
143 |         checkpoint = {'model': model.state_dict(), 'epoch': engine.state.epoch,
144 |                       'optimizer': optimizer.state_dict(), 'multitask': multitask_criterion.state_dict()}
145 |         checkpoint_handler(engine, {'checkpoint': checkpoint})
146 |     trainer.add_event_handler(Events.EPOCH_COMPLETED, _save_checkpoint)
147 |
148 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
149 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
150 |
151 | def _inference(engine, batch):
152 | model.eval()
153 | with torch.no_grad():
154 | x, instance, boxes, labels = _prepare_batch(batch)
155 | loc_preds, conf_preds, semantics, offsets_pred = model(x)
156 |             boxes_preds, labels_preds, scores_preds = box_coder.decode(loc_preds, F.softmax(conf_preds, dim=2),
157 | score_thresh=0.01)
158 |
159 | semantics_loss = semantics_criterion(semantics, instance)
160 | offsets_loss = offsets_criterion(offsets_pred, instance)
161 | box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)
162 |
163 | semantics_pred = semantics.argmax(dim=1)
164 | instances = helper.assign_pix2box(semantics_pred, offsets_pred, boxes_preds, labels_preds)
165 |
166 | return {
167 | 'loss': (semantics_loss, offsets_loss, {'box_loss': box_loss, 'conf_loss': conf_loss}),
168 | 'objects': (boxes_preds, labels_preds, scores_preds, boxes, labels),
169 | 'semantics': semantics_pred,
170 | 'instances': instances
171 | }
172 |
173 | train_evaluator = Engine(_inference)
174 | cm = ConfusionMatrix(num_classes=num_classes, output_transform=lambda x: x['semantics'])
175 | mIoU(cm, ignore_index=0).attach(train_evaluator, 'mIoU')
176 | Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
177 | MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(train_evaluator, 'mAP')
178 |
179 | evaluator = Engine(_inference)
180 | cm2 = ConfusionMatrix(num_classes=num_classes, output_transform=lambda x: x['semantics'])
181 |     mIoU(cm2, ignore_index=0).attach(evaluator, 'mIoU')
182 | Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
183 | MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(evaluator, 'mAP')
184 |
185 | tb_logger = TensorboardLogger(args.log_dir)
186 | tb_logger.attach(trainer,
187 | log_handler=OutputHandler(tag='training', output_transform=lambda loss: {
188 | 'loss': loss['loss'],
189 | 'loss_semantics': loss['loss_semantics'],
190 | 'loss_offsets': loss['loss_offsets'],
191 | 'loss_ssdbox': loss['loss_ssdbox'],
192 | 'loss_ssdclass': loss['loss_ssdclass']
193 |
194 | }),
195 | event_name=Events.ITERATION_COMPLETED)
196 |
197 | tb_logger.attach(train_evaluator,
198 | log_handler=OutputHandler(tag='training_eval',
199 | metric_names=['loss', 'mAP', 'mIoU'],
200 | output_transform=lambda loss: {
201 | 'loss': loss['loss'],
202 | 'objects': loss['objects'],
203 | 'semantics': loss['semantics']
204 | },
205 | another_engine=trainer),
206 | event_name=Events.EPOCH_COMPLETED)
207 |
208 | tb_logger.attach(evaluator,
209 | log_handler=OutputHandler(tag='validation_eval',
210 | metric_names=['loss', 'mAP', 'mIoU'],
211 | output_transform=lambda loss: {
212 | 'loss': loss['loss'],
213 | 'objects': loss['objects'],
214 | 'semantics': loss['semantics']
215 | },
216 | another_engine=trainer),
217 | event_name=Events.EPOCH_COMPLETED)
218 |
219 | @trainer.on(Events.STARTED)
220 | def initialize(engine):
221 | if args.resume:
222 | engine.state.epoch = args.start_epoch
223 |
224 | @trainer.on(Events.EPOCH_COMPLETED)
225 | def print_times(engine):
226 | pbar.log_message("Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(engine.state.epoch,
227 | engine.state.max_epochs, timer.value()))
228 | timer.reset()
229 |
230 | @trainer.on(Events.EPOCH_COMPLETED)
231 | def log_training_results(engine):
232 | train_evaluator.run(train_loader)
233 | metrics = train_evaluator.state.metrics
234 | loss = metrics['loss']
235 | mean_ap = metrics['mAP']
236 | iou = metrics['mIoU']
237 |
238 |         pbar.log_message('Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
239 |                          .format(engine.state.epoch, engine.state.max_epochs, loss, mean_ap, iou * 100.0))
240 |
241 | @trainer.on(Events.EPOCH_COMPLETED)
242 | def log_validation_results(engine):
243 | evaluator.run(val_loader)
244 | metrics = evaluator.state.metrics
245 | loss = metrics['loss']
246 | mean_ap = metrics['mAP']
247 | iou = metrics['mIoU']
248 |
249 |         pbar.log_message('Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
250 |                          .format(engine.state.epoch, engine.state.max_epochs, loss, mean_ap, iou * 100.0))
251 |
252 | @trainer.on(Events.EXCEPTION_RAISED)
253 | def handle_exception(engine, e):
254 | if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
255 | engine.terminate()
256 | warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
257 |
258 | checkpoint_handler(engine, {'model_exception': model})
259 | else:
260 | raise e
261 |
262 | @trainer.on(Events.COMPLETED)
263 | def save_final_model(engine):
264 | checkpoint_handler(engine, {'final': model})
265 |
266 | trainer.run(train_loader, max_epochs=args.epochs)
267 | tb_logger.close()
268 |
269 |
270 | if __name__ == '__main__':
271 | parser = ArgumentParser('Box2Pix with PyTorch')
272 | parser.add_argument('--batch_size', type=int, default=8,
273 | help='input batch size for training')
274 | parser.add_argument('--num-workers', type=int, default=8,
275 | help='number of workers')
276 | parser.add_argument('--epochs', type=int, default=200,
277 | help='number of epochs to train')
278 | parser.add_argument('--lr', type=float, default=1e-4,
279 | help='learning rate')
280 | parser.add_argument('--seed', type=int, help='manual seed')
281 | parser.add_argument('--output-dir', default='./checkpoints',
282 | help='directory to save model checkpoints')
283 | parser.add_argument('--resume', type=str, metavar='PATH',
284 | help='path to latest checkpoint (default: none)')
285 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
286 | help='manual epoch number (useful on restarts)')
287 | parser.add_argument('--log-interval', type=int, default=10,
288 | help='how many batches to wait before logging training status')
289 | parser.add_argument("--log-dir", type=str, default="tensorboard_logs",
290 | help="log directory for Tensorboard log output")
291 | parser.add_argument("--dataset-dir", type=str, default="data/cityscapes",
292 | help="location of the dataset")
293 |
294 | run(parser.parse_args())
295 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .box_coder import *
2 | from .box_utils import *
3 | from .helper import *
4 |
--------------------------------------------------------------------------------
/utils/box_coder.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 |
5 | from utils import box_utils
6 | from utils.nms import non_maximum_suppression
7 | 
8 | FeatureMapDef = namedtuple('FeatureMapDef', ['width', 'height', 'receptive_size'])
9 |
10 | class BoxCoder(object):
11 | """
12 | References:
13 | https://github.com/ECer23/ssd.pytorch
14 | https://github.com/kuangliu/torchcv
15 | """
16 |
17 | def __init__(self, img_width=2048, img_height=1024):
18 | self.variances = (0.1, 0.2)
19 |
20 | priors = [
21 | # height, width
22 | (4, 52), (24, 24), (54, 8), (80, 22), (52, 52),
23 | (20, 78), (156, 50), (78, 78), (48, 144), (412, 76),
24 | (104, 150), (74, 404), (644, 166), (358, 448), (70, 686), (68, 948),
25 | (772, 526), (476, 820), (150, 1122), (890, 880), (516, 1130)
26 | ]
27 |
28 | feature_maps = [
29 | FeatureMapDef(128, 64, 427),
30 | FeatureMapDef(64, 32, 715),
31 | FeatureMapDef(32, 16, 1291),
32 | FeatureMapDef(16, 8, 2443)
33 | ]
34 |
35 | boxes = []
36 | for fm in feature_maps:
37 | step_w = fm.width / img_width
38 | step_h = fm.height / img_height
39 | for x in range(fm.width):
40 | for y in range(fm.height):
41 | for p_h, p_w in priors:
42 | cx = (x + 0.5) * step_w
43 | cy = (y + 0.5) * step_h
44 | h = p_h / img_height
45 | w = p_w / img_width
46 |
47 | if fm.receptive_size > (p_h * 2) or fm.receptive_size > (p_w * 2):
48 | boxes.append((cx, cy, h, w))
49 |
50 | self.priors = torch.as_tensor(boxes, dtype=torch.float32).clamp_(0.0, 1.0)
51 |
52 | def encode(self, boxes, labels, change_threshold=0.7):
53 | """Encode target bounding boxes and class labels.
54 | SSD coding rules:
55 | tx = (x - anchor_x) / (variance[0] * anchor_w)
56 | ty = (y - anchor_y) / (variance[0] * anchor_h)
57 | tw = log(w / anchor_w) / variance[1]
58 | th = log(h / anchor_h) / variance[1]
59 |
60 | Args:
61 | boxes: (tensor) bounding boxes of (xmin, ymin, xmax, ymax), sized [#obj, 4].
62 | labels: (tensor) object class labels, sized [#obj,].
63 | change_threshold: (float) the change metric threshold
64 | """
65 |
66 | priors = self.priors
67 | priors = box_utils.center_to_corner_form(priors)
68 |
69 | change = box_utils.d_change(boxes, priors)
70 |
71 | change, max_idx = change.max(0)
72 | max_idx.squeeze_(0)
73 | change.squeeze_(0)
74 |
75 | boxes = boxes[max_idx]
76 | boxes = box_utils.corner_to_center_form(boxes)
77 | priors = box_utils.corner_to_center_form(priors)
78 |
79 | loc_xy = (boxes[:, :2] - priors[:, :2]) / priors[:, 2:] / self.variances[0]
80 | loc_wh = torch.log(boxes[:, 2:] / priors[:, 2:]) / self.variances[1]
81 | loc = torch.cat([loc_xy, loc_wh], 1)
82 |
83 | conf = labels[max_idx] + 1 # background class = 0
84 | conf[change < change_threshold] = 0 # background
85 |
86 | return loc, conf
87 |
88 | def decode(self, loc_preds, conf_preds, score_thresh=0.6, nms_thresh=0.5):
89 | """Decode predicted loc/cls back to real box locations and class labels.
90 | Args:
91 | loc_preds: (tensor) predicted loc, sized [8732,4].
92 | conf_preds: (tensor) predicted conf, sized [8732,21].
93 | score_thresh: (float) threshold for object confidence score.
94 | nms_thresh: (float) threshold for box nms.
95 |
96 | Returns:
97 | boxes: (tensor) bbox locations, sized [#obj,4].
98 | labels: (tensor) class labels, sized [#obj,].
99 | """
100 | cxcy = loc_preds[:, :2] * self.variances[0] * self.priors[:, 2:] + self.priors[:, :2]
101 | wh = torch.exp(loc_preds[:, 2:] * self.variances[1]) * self.priors[:, 2:]
102 | box_preds = torch.cat([cxcy - wh / 2, cxcy + wh / 2], 1)
103 |
104 | boxes = []
105 | labels = []
106 | scores = []
107 | num_classes = conf_preds.size(1)
108 | for i in range(num_classes - 1):
109 | score = conf_preds[:, i + 1] # class i corresponds to (i + 1) column
110 | mask = score > score_thresh
111 | if not mask.any():
112 | continue
113 |             box = box_preds[mask]
114 |             score = score[mask]
115 | 
116 |             keep = non_maximum_suppression(box, score, nms_thresh)
117 |             boxes.append(box[keep])
118 |             labels.append(torch.full((keep.numel(),), i, dtype=torch.int64))
119 |             scores.append(score[keep])
120 |
121 | boxes = torch.cat(boxes, 0)
122 | labels = torch.cat(labels, 0)
123 | scores = torch.cat(scores, 0)
124 |
125 | return boxes, labels, scores
126 |
127 |
128 | if __name__ == '__main__':
129 | coder = BoxCoder()
130 | print(coder.priors[4:])
131 | print(coder.priors.size())
132 |
--------------------------------------------------------------------------------
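
A standalone check of the coding rules quoted in the `encode()` docstring, written directly against plain tensors in (cx, cy, w, h) form so it does not depend on the prior generation or the matching step; encoding a box against an anchor and decoding it with the inverse transform from `decode()` should give the box back:

```python
import torch

variances = (0.1, 0.2)

# one anchor/prior and one ground-truth box in center form (cx, cy, w, h), normalized coordinates
prior = torch.tensor([[0.50, 0.50, 0.20, 0.10]])
box = torch.tensor([[0.52, 0.48, 0.25, 0.12]])

# encode: tx = (x - anchor_x) / (variance[0] * anchor_w), tw = log(w / anchor_w) / variance[1]
loc_xy = (box[:, :2] - prior[:, :2]) / prior[:, 2:] / variances[0]
loc_wh = torch.log(box[:, 2:] / prior[:, 2:]) / variances[1]
loc = torch.cat([loc_xy, loc_wh], 1)

# decode: the inverse transform used in BoxCoder.decode
cxcy = loc[:, :2] * variances[0] * prior[:, 2:] + prior[:, :2]
wh = torch.exp(loc[:, 2:] * variances[1]) * prior[:, 2:]

print(torch.allclose(torch.cat([cxcy, wh], 1), box))  # True
```
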
/utils/box_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | priors = [
4 | # height, width
5 | (4, 52),
6 | (25, 25),
7 | (54, 8),
8 | (80, 22),
9 | (52, 52),
10 | (20, 78),
11 | (156, 50),
12 | (78, 78),
13 | (48, 144),
14 | (412, 76),
15 | (104, 150),
16 | (74, 404),
17 | (645, 166),
18 | (358, 448),
19 | (70, 686),
20 | (68, 948),
21 | (772, 526),
22 | (476, 820),
23 | (150, 1122),
24 | (890, 880),
25 | (518, 1130)
26 | ]
27 |
28 | priors_new = [
29 | # height, width
30 | (4, 52),
31 | (24, 24),
32 | (54, 8),
33 | (80, 22),
34 | (52, 52),
35 | (20, 78),
36 | (156, 50),
37 | (78, 78),
38 | (48, 144),
39 | (412, 76),
40 | (104, 150),
41 | (74, 404),
42 | (644, 166),
43 | (358, 448),
44 | (70, 686),
45 | (68, 948),
46 | (772, 526),
47 | (476, 820),
48 | (150, 1122),
49 | (890, 880),
50 | (516, 1130)
51 | ]
52 |
53 |
54 | def get_bounding_box(polygon):
55 | fpoint = polygon[0]
56 | xmin, ymin, xmax, ymax = fpoint[0], fpoint[1], fpoint[0], fpoint[1]
57 | for point in polygon:
58 | x, y = point[0], point[1]
59 | xmin = min(xmin, x)
60 | ymin = min(ymin, y)
61 | xmax = max(xmax, x)
62 | ymax = max(ymax, y)
63 |
64 | return xmin, ymin, xmax, ymax
65 |
66 |
67 | def d_change(ground_truth, prior):
68 |     """Compute the change-based matching metric between ground truth boxes and prior boxes.
69 | 
70 |     Args:
71 |         ground_truth (tensor): Ground truth bounding boxes in corner form, Shape: [num_objects, 4]
72 |         prior (tensor): Prior boxes in corner form, Shape: [num_priors, 4]
73 |     """
74 |     # broadcast both sets to a [num_objects, num_priors] matrix of corner displacements
75 |     gt, pr = ground_truth.unsqueeze(1), prior.unsqueeze(0)
76 |     xtl = torch.abs(pr[..., 0] - gt[..., 0])
77 |     ytl = torch.abs(pr[..., 1] - gt[..., 1])
78 |     xbr = torch.abs(pr[..., 2] - gt[..., 2])
79 |     ybr = torch.abs(pr[..., 3] - gt[..., 3])
80 |     wgt = gt[..., 2] - gt[..., 0]
81 |     hgt = gt[..., 3] - gt[..., 1]
82 | 
83 |     return torch.sqrt((torch.pow(ytl, 2) / hgt) + (torch.pow(xtl, 2) / wgt)
84 |                       + (torch.pow(ybr, 2) / hgt) + (torch.pow(xbr, 2) / wgt))
85 |
86 |
87 | def corner_to_center_form(boxes):
88 | """Convert bounding boxes from (xmin, ymin, xmax, ymax) to (cx, cy, width, height)
89 |
90 | Args:
91 | boxes (tensor): Boxes, Shape: [num_priors, 4]
92 | """
93 |
94 | return torch.cat([(boxes[:, 2:] + boxes[:, :2]) / 2,
95 | boxes[:, 2:] - boxes[:, :2]], 1)
96 |
97 |
98 | def center_to_corner_form(boxes):
99 | """Convert bounding boxes from (cx, cy, width, height) to (xmin, ymin, xmax, ymax)
100 |
101 | Args:
102 | boxes (tensor): Boxes, Shape: [num_priors, 4]
103 | """
104 |
105 |     return torch.cat([boxes[:, :2] - (boxes[:, 2:] / 2),
106 |                       boxes[:, :2] + (boxes[:, 2:] / 2)], 1)
107 |
108 |
109 | def nms(boxes, scores, thresh):
110 | x1 = boxes[:, 0]
111 | y1 = boxes[:, 1]
112 | x2 = boxes[:, 2]
113 | y2 = boxes[:, 3]
114 |
115 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
116 |     _, indices = scores.sort(0, descending=True)
117 |
118 | keep = []
119 |     while indices.numel() > 0:
120 | i = indices[0]
121 | keep.append(i)
122 |
123 | xx1 = torch.max(x1[i], x1[indices[1:]])
124 | yy1 = torch.max(y1[i], y1[indices[1:]])
125 | xx2 = torch.min(x2[i], x2[indices[1:]])
126 | yy2 = torch.min(y2[i], y2[indices[1:]])
127 |
128 | w = torch.clamp(xx2 - xx1, min=0.0)
129 | h = torch.clamp(yy2 - yy1, min=0.0)
130 | inter = w * h
131 | ovr = inter / (areas[i] + areas[indices[1:]] - inter)
132 |
133 | inds = torch.nonzero(ovr <= thresh).squeeze()
134 | indices = indices[inds + 1]
135 |
136 | return keep
137 |
138 |
139 | if __name__ == '__main__':
140 | """
141 | layers = [
142 | # inception4e
143 | {'size': 427, 'boxes': []},
144 | # inception5b
145 | {'size': 715, 'boxes': []},
146 | # inception6b
147 | {'size': 1291, 'boxes': []},
148 | # inception7b
149 | {'size': 2443, 'boxes': []}
150 | ]
151 |
152 | # calculate the number of associated prior boxes for each layer
153 | for prior in priors_new:
154 | height, width = prior
155 |
156 | for layer in layers:
157 | if layer['size'] > (height * 2) or layer['size'] > (width * 2):
158 | layer['boxes'].append(prior)
159 |
160 | for layer in layers:
161 | print('Number of priors: ', len(layer['boxes']))
162 | """
163 |
--------------------------------------------------------------------------------
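
A small consistency check for the two conversion helpers above (corner form is (xmin, ymin, xmax, ymax), center form is (cx, cy, w, h)); converting back and forth should be the identity:

```python
import torch

from utils.box_utils import center_to_corner_form, corner_to_center_form

corners = torch.tensor([[100., 200., 300., 260.],
                        [  0.,   0.,  50.,  80.]])

centers = corner_to_center_form(corners)
print(centers)  # tensor([[200., 230., 200., 60.], [25., 40., 50., 80.]])

print(torch.allclose(center_to_corner_form(centers), corners))  # True
```
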
/utils/helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def get_upsampling_weight(channels, kernel_size):
6 | """Make a 2D bilinear kernel suitable for upsampling
7 |
8 | Based on: https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
9 | """
10 | factor = (kernel_size + 1) // 2
11 | if kernel_size % 2 == 1:
12 | center = factor - 1
13 | else:
14 | center = factor - 0.5
15 | og = np.ogrid[:kernel_size, :kernel_size]
16 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
17 | filt = torch.from_numpy(filt)
18 | weight = torch.zeros([channels, channels, kernel_size, kernel_size], dtype=torch.float64)
19 | weight[range(channels), range(channels), :, :] = filt
20 |
21 | return weight
22 |
23 |
24 | def assign_pix2box(semantics, offsets, boxes, labels):
25 | return semantics
26 |
--------------------------------------------------------------------------------
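
`get_upsampling_weight` is what `Box2Pix._initialize_weights` copies into its transposed convolutions, so they start out as fixed bilinear upsamplers. A minimal sketch of that initialization in isolation (the channel count is chosen arbitrarily):

```python
import torch
import torch.nn as nn

from utils.helper import get_upsampling_weight

channels = 11  # e.g. the semantic head's num_classes
upsample = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, bias=False)

with torch.no_grad():
    upsample.weight.copy_(get_upsampling_weight(channels, kernel_size=4))

x = torch.randn(1, channels, 16, 32)
print(upsample(x).shape)  # torch.Size([1, 11, 34, 66]) -- roughly 2x; Box2Pix crops afterwards
```
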
/utils/nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def non_maximum_suppression(boxes, scores, threshold=0.5):
5 |
6 | if boxes.numel() == 0:
7 | return torch.LongTensor()
8 |
9 | xmin = boxes[:, 0]
10 | ymin = boxes[:, 1]
11 | xmax = boxes[:, 2]
12 | ymax = boxes[:, 3]
13 |
14 | areas = (xmax - xmin) * (ymax - ymin)
15 |
16 | _, indices = scores.sort(0, descending=True)
17 | keep = []
18 |     while indices.numel() > 0:
19 |         i = indices[0]
20 |         keep.append(i.item())
21 | 
22 |         if indices.numel() == 1:
23 |             break
24 | 
25 |         # intersection of the selected box with all remaining (lower scoring) boxes
26 |         xx1 = torch.max(xmin[i], xmin[indices[1:]])
27 |         yy1 = torch.max(ymin[i], ymin[indices[1:]])
28 |         xx2 = torch.min(xmax[i], xmax[indices[1:]])
29 |         yy2 = torch.min(ymax[i], ymax[indices[1:]])
30 | 
31 |         inter = torch.clamp(xx2 - xx1, min=0) * torch.clamp(yy2 - yy1, min=0)
32 |         iou = inter / (areas[i] + areas[indices[1:]] - inter)
33 | 
34 |         # keep only boxes that do not overlap the selected box by more than the threshold
35 |         indices = indices[1:][iou <= threshold]
36 | 
37 |     return torch.LongTensor(keep)
38 | 
--------------------------------------------------------------------------------
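
A quick usage sketch for the routine above: of two heavily overlapping boxes only the higher-scoring one should survive, while a non-overlapping box is always kept:

```python
import torch

from utils.nms import non_maximum_suppression

boxes = torch.tensor([[100., 100., 200., 200.],   # highest score in its cluster -> kept
                      [105., 105., 205., 205.],   # IoU with the first box is ~0.82 -> suppressed
                      [300., 300., 380., 380.]])  # no overlap -> kept
scores = torch.tensor([0.9, 0.8, 0.7])

keep = non_maximum_suppression(boxes, scores, threshold=0.5)
print(keep)  # tensor([0, 2])
```
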