├── .gitignore
├── README.md
├── config
│   └── voc.yaml
├── dataset
│   ├── __init__.py
│   └── voc.py
├── model
│   ├── __init__.py
│   └── ssd.py
├── requirements.txt
└── tools
    ├── __init__.py
    ├── infer.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all image files
2 | *.jpg
3 | *.png
4 | *.jpeg
5 |
6 | # Ignore pycharm and system files
7 | .DS_Store
8 | *.idea
9 | __pycache__
10 | *.zip
11 |
12 | # Ignore dataset files
13 | *.csv
14 | *.json
15 |
16 | # Ignore checkpoints
17 | *.pth
18 |
19 | # Ignore pickle files
20 | *.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | SSD Implementation in PyTorch
2 | ========
3 |
4 | This repository implements SSD, with training, inference and mAP evaluation, in PyTorch.
5 | Most of the code is adapted from the PyTorch SSD implementation; the main changes are removing abstractions and commenting the code.
6 | 
7 | The repo provides code to train on the VOC dataset. Specifically, training uses the trainval images of VOC 2007 and testing uses the VOC 2007 test set.
8 |
9 | ## SSD Explanation and Implementation Video
10 |
11 |
12 |
13 |
14 |
15 | ## Results from training SSD on the VOC 2007 dataset
16 | One should be able to get **71-72% mAP** by training on VOC 2007 trainval images (**68% reported in the paper**).
17 | 
18 | Adding the 2012 trainval images should give **>77% mAP**.
19 |
20 |
21 |
22 |
23 |
24 |
25 | Here's an evaluation result obtained after training for 100 epochs.
26 | ```
27 | Class Wise Average Precisions
28 | AP for class aeroplane = 0.7552
29 | AP for class bicycle = 0.8384
30 | AP for class bird = 0.7025
31 | AP for class boat = 0.6543
32 | AP for class bottle = 0.3411
33 | AP for class bus = 0.8355
34 | AP for class car = 0.8611
35 | AP for class cat = 0.8682
36 | AP for class chair = 0.4798
37 | AP for class cow = 0.7453
38 | AP for class diningtable = 0.7092
39 | AP for class dog = 0.8582
40 | AP for class horse = 0.8506
41 | AP for class motorbike = 0.8259
42 | AP for class person = 0.7721
43 | AP for class pottedplant = 0.3939
44 | AP for class sheep = 0.7300
45 | AP for class sofa = 0.7626
46 | AP for class train = 0.8615
47 | AP for class tvmonitor = 0.7260
48 | Mean Average Precision : 0.7286
49 | ```
50 |
51 |
52 | ## Data preparation
53 | For setting up the VOC 2007 dataset:
54 | * Create a `data` directory inside `SSD-Pytorch`
55 | * Download the VOC 2007 train/val data from http://host.robots.ox.ac.uk/pascal/VOC/voc2007 and copy the `VOC2007` directory into the `data` directory
56 | * Download the VOC 2007 test data from http://host.robots.ox.ac.uk/pascal/VOC/voc2007, copy its `VOC2007` directory into `data` and rename it to `VOC2007-test`
57 | * If you want to use the 2012 trainval images as well, download the VOC 2012 train/val data from http://host.robots.ox.ac.uk/pascal/VOC/voc2012 and copy the `VOC2012` directory into `data`
58 | * Place all the directories inside the `data` folder of the repo according to the structure below (a small sanity-check sketch follows it)
59 | ```
60 | SSD-Pytorch
61 | -> data
62 | -> VOC2007
63 | -> JPEGImages
64 | -> Annotations
65 | -> ImageSets
66 | -> VOC2007-test
67 | -> JPEGImages
68 | -> Annotations
69 | -> VOC2012 (if needed)
70 | -> JPEGImages
71 | -> Annotations
72 | -> ImageSets
73 | -> tools
74 | -> train.py
75 | -> infer.py
76 | -> config
77 | -> voc.yaml
78 | -> model
79 | -> ssd.py
80 | -> dataset
81 | -> voc.py
82 | ```
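
A small optional sketch for sanity-checking that the directories are in place. The paths follow the structure above; note that `dataset/voc.py` also reads `ImageSets/Main/test.txt` from the test directory, so `VOC2007-test` needs an `ImageSets` folder as well.

```python
import os

# Optional sanity check that the expected VOC directories are in place
expected = [
    'data/VOC2007/JPEGImages',
    'data/VOC2007/Annotations',
    'data/VOC2007/ImageSets',
    'data/VOC2007-test/JPEGImages',
    'data/VOC2007-test/Annotations',
    'data/VOC2007-test/ImageSets',  # needed for ImageSets/Main/test.txt
]
for path in expected:
    print(path, 'found' if os.path.isdir(path) else 'MISSING')
```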
83 |
84 | ## For training on your own dataset
85 |
86 | * Update the paths for `train_im_sets` and `test_im_sets` in the config
87 | * If you want to train on 2007+2012 trainval then set `train_im_sets` to `['data/VOC2007', 'data/VOC2012']`
88 | * Modify the dataset file `dataset/voc.py` to load your images and annotations, specifically the `load_images_and_anns` method
89 | * Update the class list of your dataset in the dataset file.
90 | * The dataset class should return the following (a minimal sketch is shown after this block):
91 | ```
92 | im_tensor(C x H x W) ,
93 | target{
94 | 'bboxes': Number of Gts x 4 (this is in x1y1x2y2 format normalized from 0-1)
95 | 'labels': Number of Gts,
96 | 'difficult': Number of Gts,
97 | }
98 | file_path
99 | ```
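
As a reference, here is a minimal sketch of a custom dataset returning data in this format. The `image_paths`/`annotations` inputs and their layout are assumptions for illustration, not part of this repo.

```python
import torch
from torch.utils.data.dataset import Dataset
from torchvision import tv_tensors
from torchvision.io import read_image


class MyDataset(Dataset):
    def __init__(self, image_paths, annotations):
        # annotations[i] is assumed to be a dict with 'bboxes'
        # (list of [x1, y1, x2, y2] in pixels), 'labels' and
        # 'difficult' lists for image_paths[i]
        self.image_paths = image_paths
        self.annotations = annotations

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        im = read_image(self.image_paths[index])  # C x H x W uint8
        ann = self.annotations[index]
        h, w = im.shape[-2:]
        targets = {
            'bboxes': tv_tensors.BoundingBoxes(ann['bboxes'], format='XYXY',
                                               canvas_size=(h, w)),
            'labels': torch.as_tensor(ann['labels'], dtype=torch.int64),
            'difficult': torch.as_tensor(ann['difficult'], dtype=torch.int64),
        }
        # The repo's VOC dataset additionally applies augmentations and
        # ImageNet normalization here; this sketch only scales to 0-1
        im_tensor = im.float() / 255.
        # Normalize boxes to 0-1 in x1y1x2y2 format as the training code expects
        wh = torch.as_tensor([[w, h, w, h]]).expand_as(targets['bboxes'])
        targets['bboxes'] = targets['bboxes'] / wh
        return im_tensor, targets, self.image_paths[index]
```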
100 |
101 |
102 | ## For modifications
103 | * In case your GPU cannot handle a large batch size, you can use a smaller batch size like 2 and set `acc_steps` in the config to 4 (to mimic training with batch size 8).
104 | * For using a different backbone you would have to change the following (a sketch of where these changes go is shown below):
105 |   * Change the backbone, extra conv layers and creation of feature maps in the initialization of the SSD model
106 |   * Ensure `out_channels` is correctly set to the channel counts of all feature maps used for prediction [here](https://github.com/explainingai-code/SSD-PyTorch/blob/main/model/ssd.py#L316)
107 |   * In the forward method, call the backbone and extra conv layers and ensure `outputs` is correctly set to the list of feature maps [here](https://github.com/explainingai-code/SSD-PyTorch/blob/main/model/ssd.py#L472)
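
Below is a rough, runnable sketch of the two places that change when swapping the backbone. The backbone stages and channel counts here are made up for illustration, not the repo's actual layers; the `aspect_ratios` and `scales` entries in the config would also need one entry per feature map.

```python
import torch
import torch.nn as nn


class TinyBackboneSketch(nn.Module):
    # Illustrative only: shows where `out_channels` and `outputs`
    # would change when using a different backbone.
    def __init__(self):
        super().__init__()
        # Hypothetical backbone stages producing the prediction feature maps
        self.stage1 = nn.Sequential(nn.Conv2d(3, 256, 3, stride=8, padding=1),
                                    nn.ReLU())
        self.stage2 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=2, padding=1),
                                    nn.ReLU())
        self.stage3 = nn.Sequential(nn.Conv2d(512, 1024, 3, stride=2, padding=1),
                                    nn.ReLU())
        # Analogous to `out_channels` in SSD.__init__:
        # channels of every feature map used for prediction
        self.out_channels = [256, 512, 1024]

    def forward(self, x):
        feat1 = self.stage1(x)      # e.g. 38 x 38 for a 300 x 300 input
        feat2 = self.stage2(feat1)  # e.g. 19 x 19
        feat3 = self.stage3(feat2)  # e.g. 10 x 10
        # Analogous to `outputs` in SSD.forward: list of feature maps
        return [feat1, feat2, feat3]


if __name__ == '__main__':
    feats = TinyBackboneSketch()(torch.randn(1, 3, 300, 300))
    print([tuple(f.shape) for f in feats])
```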
108 |
109 | # Quickstart
110 | * Create a new conda environment with Python 3.10, then run the commands below
111 | * ```git clone https://github.com/explainingai-code/SSD-PyTorch.git```
112 | * ```cd SSD-PyTorch```
113 | * ```pip install -r requirements.txt```
114 | * For training/inference, use the commands below, passing the desired configuration file as the `--config` argument in case you want to play with it.
115 | * ```python -m tools.train``` for training SSD on VOC dataset
116 | * ```python -m tools.infer --evaluate False --infer_samples True``` for generating inference predictions
117 | * ```python -m tools.infer --evaluate True --infer_samples False``` for evaluating on test dataset
118 |
119 | ## Configuration
120 | * ```config/voc.yaml``` - Allows you to play with different components of SSD on the VOC dataset
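
The tools scripts read this file with `yaml.safe_load`; a minimal sketch for loading it in your own code (mirrors `tools/train.py`):

```python
import yaml

# Read the config file the same way tools/train.py does
with open('config/voc.yaml', 'r') as f:
    config = yaml.safe_load(f)

dataset_config = config['dataset_params']
model_config = config['model_params']
train_config = config['train_params']
print(train_config['task_name'], train_config['batch_size'])
```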
121 |
122 |
123 | ## Output
124 | Outputs will be saved according to the configuration present in the yaml file.
125 | 
126 | For every run, a folder named after the `task_name` key in the config will be created.
127 |
128 | During training of SSD the following output will be saved
129 | * Latest Model checkpoint in ```task_name``` directory
130 |
131 | During inference the following output will be saved
132 | * Sample prediction outputs for images in ```task_name/samples```
133 |
134 | ## Citations
135 | ```
136 | @article{DBLP:journals/corr/LiuAESR15,
137 | author = {Wei Liu and
138 | Dragomir Anguelov and
139 | Dumitru Erhan and
140 | Christian Szegedy and
141 | Scott E. Reed and
142 | Cheng{-}Yang Fu and
143 | Alexander C. Berg},
144 | title = {{SSD:} Single Shot MultiBox Detector},
145 | journal = {CoRR},
146 | volume = {abs/1512.02325},
147 | year = {2015},
148 | url = {http://arxiv.org/abs/1512.02325},
149 | eprinttype = {arXiv},
150 | eprint = {1512.02325},
151 | timestamp = {Wed, 12 Feb 2020 08:32:49 +0100},
152 | biburl = {https://dblp.org/rec/journals/corr/LiuAESR15.bib},
153 | bibsource = {dblp computer science bibliography, https://dblp.org}
154 | }
155 | ```
156 |
--------------------------------------------------------------------------------
/config/voc.yaml:
--------------------------------------------------------------------------------
1 | dataset_params:
2 | train_im_sets: ['data/VOC2007']
3 | test_im_sets: ['data/VOC2007-test']
4 | num_classes : 21
5 | im_size : 300
6 |
7 | model_params:
8 | im_channels : 3
9 | aspect_ratios : [
10 | [ 1., 2., 0.5 ],
11 | [ 1., 2., 3., 0.5, .333 ],
12 | [ 1., 2., 3., 0.5, .333 ],
13 | [ 1., 2., 3., 0.5, .333 ],
14 | [ 1., 2., 0.5 ],
15 | [ 1., 2., 0.5 ]
16 | ]
17 | scales : [0.1, 0.2, 0.375, 0.55, 0.725, 0.9]
18 | iou_threshold : 0.5
19 | low_score_threshold : 0.01
20 | neg_pos_ratio : 3
21 | pre_nms_topK : 400
22 | detections_per_img : 200
23 | nms_threshold : 0.45
24 |
25 | train_params:
26 | task_name: 'voc'
27 | seed: 1111
28 | acc_steps: 1
29 | num_epochs: 100
30 | batch_size: 8
31 | lr_steps: [ 40, 50, 60, 70, 80, 90 ]
32 | lr: 0.001
33 | log_steps : 100
34 | infer_conf_threshold : 0.5
35 | ckpt_name: 'ssd_voc2007.pth'
36 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms.v2
4 | from torch.utils.data.dataset import Dataset
5 | import xml.etree.ElementTree as ET
6 | from torchvision import tv_tensors
7 | from torchvision.io import read_image
8 |
9 |
10 | def load_images_and_anns(im_sets, label2idx, ann_fname, split):
11 | r"""
12 | Method to get the xml files and for each file
13 | get all the objects and their ground truth detection
14 | information for the dataset
15 | :param im_sets: Sets of images to consider
16 | :param label2idx: Class Name to index mapping for dataset
17 |     :param ann_fname: txt file containing image names {trainval.txt/test.txt}
18 | :param split: train/test
19 | :return:
20 | """
21 | im_infos = []
22 | ims = []
23 |
24 | for im_set in im_sets:
25 | im_names = []
26 | # Fetch all image names in txt file for this imageset
27 | for line in open(os.path.join(
28 | im_set, 'ImageSets', 'Main', '{}.txt'.format(ann_fname))):
29 | im_names.append(line.strip())
30 |
31 | # Set annotation and image path
32 | ann_dir = os.path.join(im_set, 'Annotations')
33 | im_dir = os.path.join(im_set, 'JPEGImages')
34 | for im_name in im_names:
35 | ann_file = os.path.join(ann_dir, '{}.xml'.format(im_name))
36 | im_info = {}
37 | ann_info = ET.parse(ann_file)
38 | root = ann_info.getroot()
39 | size = root.find('size')
40 | width = int(size.find('width').text)
41 | height = int(size.find('height').text)
42 | im_info['img_id'] = os.path.basename(ann_file).split('.xml')[0]
43 | im_info['filename'] = os.path.join(
44 | im_dir, '{}.jpg'.format(im_info['img_id'])
45 | )
46 | im_info['width'] = width
47 | im_info['height'] = height
48 | detections = []
49 |
50 | for obj in ann_info.findall('object'):
51 | det = {}
52 | label = label2idx[obj.find('name').text]
53 | difficult = int(obj.find('difficult').text)
54 | bbox_info = obj.find('bndbox')
55 | bbox = [
56 | int(bbox_info.find('xmin').text) - 1,
57 | int(bbox_info.find('ymin').text) - 1,
58 | int(bbox_info.find('xmax').text) - 1,
59 | int(bbox_info.find('ymax').text) - 1
60 | ]
61 | det['label'] = label
62 | det['bbox'] = bbox
63 | det['difficult'] = difficult
64 | # At test time eval does the job of ignoring difficult
65 | detections.append(det)
66 |
67 | im_info['detections'] = detections
68 | im_infos.append(im_info)
69 | print('Total {} images found'.format(len(im_infos)))
70 | return im_infos
71 |
72 |
73 | class VOCDataset(Dataset):
74 | def __init__(self, split, im_sets, im_size=300):
75 | self.split = split
76 |
77 | # Imagesets for this dataset instance (VOC2007/VOC2007+VOC2012/VOC2007-test)
78 | self.im_sets = im_sets
79 | self.fname = 'trainval' if self.split == 'train' else 'test'
80 | self.im_size = im_size
81 | self.im_mean = [123.0, 117.0, 104.0]
82 | self.imagenet_mean = [0.485, 0.456, 0.406]
83 | self.imagenet_std = [0.229, 0.224, 0.225]
84 |
85 | # Train and test transformations
86 | self.transforms = {
87 | 'train': torchvision.transforms.v2.Compose([
88 | torchvision.transforms.v2.RandomPhotometricDistort(),
89 | torchvision.transforms.v2.RandomZoomOut(fill=self.im_mean),
90 | torchvision.transforms.v2.RandomIoUCrop(),
91 | torchvision.transforms.v2.RandomHorizontalFlip(p=0.5),
92 | torchvision.transforms.v2.Resize(size=(self.im_size, self.im_size)),
93 | torchvision.transforms.v2.SanitizeBoundingBoxes(
94 | labels_getter=lambda transform_input:
95 | (transform_input[1]["labels"], transform_input[1]["difficult"])),
96 | torchvision.transforms.v2.ToPureTensor(),
97 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
98 | torchvision.transforms.v2.Normalize(mean=self.imagenet_mean,
99 | std=self.imagenet_std)
100 |
101 | ]),
102 | 'test': torchvision.transforms.v2.Compose([
103 | torchvision.transforms.v2.Resize(size=(self.im_size, self.im_size)),
104 | torchvision.transforms.v2.ToPureTensor(),
105 | torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
106 | torchvision.transforms.v2.Normalize(mean=self.imagenet_mean,
107 | std=self.imagenet_std)
108 | ]),
109 | }
110 |
111 | classes = [
112 | 'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
113 | 'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
114 | 'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor'
115 | ]
116 | classes = sorted(classes)
117 | # We need to add background class as well with 0 index
118 | classes = ['background'] + classes
119 |
120 | self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
121 | self.idx2label = {idx: classes[idx] for idx in range(len(classes))}
122 | print(self.idx2label)
123 | self.images_info = load_images_and_anns(self.im_sets,
124 | self.label2idx,
125 | self.fname,
126 | self.split)
127 |
128 | def __len__(self):
129 | return len(self.images_info)
130 |
131 | def __getitem__(self, index):
132 | im_info = self.images_info[index]
133 | im = read_image(im_info['filename'])
134 |
135 | # Get annotations for this image
136 | targets = {}
137 | targets['bboxes'] = tv_tensors.BoundingBoxes(
138 | [detection['bbox'] for detection in im_info['detections']],
139 | format='XYXY', canvas_size=im.shape[-2:])
140 | targets['labels'] = torch.as_tensor(
141 | [detection['label'] for detection in im_info['detections']])
142 | targets['difficult'] = torch.as_tensor(
143 |             [detection['difficult'] for detection in im_info['detections']])
144 |
145 | # Transform the image and targets
146 | transformed_info = self.transforms[self.split](im, targets)
147 | im_tensor, targets = transformed_info
148 |
149 | h, w = im_tensor.shape[-2:]
150 | wh_tensor = torch.as_tensor([[w, h, w, h]]).expand_as(targets['bboxes'])
151 | targets['bboxes'] = targets['bboxes'] / wh_tensor
152 | return im_tensor, targets, im_info['filename']
153 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/model/__init__.py
--------------------------------------------------------------------------------
/model/ssd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | import torchvision
5 |
6 |
7 | def get_iou(boxes1, boxes2):
8 | r"""
9 | IOU between two sets of boxes
10 | :param boxes1: (Tensor of shape N x 4)
11 | :param boxes2: (Tensor of shape M x 4)
12 | :return: IOU matrix of shape N x M
13 | """
14 |
15 | # Area of boxes (x2-x1)*(y2-y1)
16 | area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) # (N,)
17 | area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) # (M,)
18 |
19 | # Get top left x1,y1 coordinate
20 | x_left = torch.max(boxes1[:, None, 0], boxes2[:, 0]) # (N, M)
21 | y_top = torch.max(boxes1[:, None, 1], boxes2[:, 1]) # (N, M)
22 |
23 | # Get bottom right x2,y2 coordinate
24 | x_right = torch.min(boxes1[:, None, 2], boxes2[:, 2]) # (N, M)
25 | y_bottom = torch.min(boxes1[:, None, 3], boxes2[:, 3]) # (N, M)
26 |
27 | intersection_area = ((x_right - x_left).clamp(min=0) *
28 | (y_bottom - y_top).clamp(min=0)) # (N, M)
29 | union = area1[:, None] + area2 - intersection_area # (N, M)
30 | iou = intersection_area / union # (N, M)
31 | return iou
32 |
33 |
34 | def boxes_to_transformation_targets(ground_truth_boxes,
35 | default_boxes,
36 | weights=(10., 10., 5., 5.)):
37 | r"""
38 | Method to compute targets for each default_boxes.
39 | Assumes boxes are in x1y1x2y2 format.
40 | We first convert boxes to cx,cy,w,h format and then
41 | compute targets based on following formulation
42 | target_dx = (gt_cx - default_boxes_cx) / default_boxes_w
43 | target_dy = (gt_cy - default_boxes_cy) / default_boxes_h
44 | target_dw = log(gt_w / default_boxes_w)
45 | target_dh = log(gt_h / default_boxes_h)
46 | :param ground_truth_boxes: (Tensor of shape N x 4)
47 | :param default_boxes: (Tensor of shape N x 4)
48 | :param weights: Tuple[float] -> (wx, wy, ww, wh)
49 | :return: regression_targets: (Tensor of shape N x 4)
50 | """
51 | # # Get center_x,center_y,w,h from x1,y1,x2,y2 for default_boxes
52 | widths = default_boxes[:, 2] - default_boxes[:, 0]
53 | heights = default_boxes[:, 3] - default_boxes[:, 1]
54 | center_x = default_boxes[:, 0] + 0.5 * widths
55 | center_y = default_boxes[:, 1] + 0.5 * heights
56 |
57 | # # Get center_x,center_y,w,h from x1,y1,x2,y2 for gt boxes
58 | gt_widths = (ground_truth_boxes[:, 2] - ground_truth_boxes[:, 0])
59 | gt_heights = ground_truth_boxes[:, 3] - ground_truth_boxes[:, 1]
60 | gt_center_x = ground_truth_boxes[:, 0] + 0.5 * gt_widths
61 | gt_center_y = ground_truth_boxes[:, 1] + 0.5 * gt_heights
62 |
63 | # Use formulation to compute all targets
64 | targets_dx = weights[0] * (gt_center_x - center_x) / widths
65 | targets_dy = weights[1] * (gt_center_y - center_y) / heights
66 | targets_dw = weights[2] * torch.log(gt_widths / widths)
67 | targets_dh = weights[3] * torch.log(gt_heights / heights)
68 | regression_targets = torch.stack((targets_dx,
69 | targets_dy,
70 | targets_dw,
71 | targets_dh), dim=1)
72 | return regression_targets
73 |
74 |
75 | def apply_regression_pred_to_default_boxes(box_transform_pred,
76 | default_boxes,
77 | weights=(10., 10., 5., 5.)):
78 | r"""
79 | Method to transform default_boxes based on transformation parameter
80 | prediction.
81 | Assumes boxes are in x1y1x2y2 format
82 | :param box_transform_pred: (Tensor of shape N x 4)
83 | :param default_boxes: (Tensor of shape N x 4)
84 | :param weights: Tuple[float] -> (wx, wy, ww, wh)
85 | :return: pred_boxes: (Tensor of shape N x 4)
86 | """
87 |
88 | # Get cx, cy, w, h from x1,y1,x2,y2
89 | w = default_boxes[:, 2] - default_boxes[:, 0]
90 | h = default_boxes[:, 3] - default_boxes[:, 1]
91 | center_x = default_boxes[:, 0] + 0.5 * w
92 | center_y = default_boxes[:, 1] + 0.5 * h
93 |
94 | dx = box_transform_pred[..., 0] / weights[0]
95 | dy = box_transform_pred[..., 1] / weights[1]
96 | dw = box_transform_pred[..., 2] / weights[2]
97 | dh = box_transform_pred[..., 3] / weights[3]
98 | # dh -> (num_default_boxes)
99 |
100 | pred_center_x = dx * w + center_x
101 | pred_center_y = dy * h + center_y
102 | pred_w = torch.exp(dw) * w
103 | pred_h = torch.exp(dh) * h
104 |     # pred_center_x -> (num_default_boxes,)
105 |
106 | pred_box_x1 = pred_center_x - 0.5 * pred_w
107 | pred_box_y1 = pred_center_y - 0.5 * pred_h
108 | pred_box_x2 = pred_center_x + 0.5 * pred_w
109 | pred_box_y2 = pred_center_y + 0.5 * pred_h
110 |
111 | pred_boxes = torch.stack((
112 | pred_box_x1,
113 | pred_box_y1,
114 | pred_box_x2,
115 | pred_box_y2),
116 | dim=-1)
117 | return pred_boxes
118 |
119 |
120 | def generate_default_boxes(feat, aspect_ratios, scales):
121 | r"""
122 |     Method to generate default_boxes for all feature maps of the image
123 |     :param feat: List[(Tensor of shape B x C x Feat_H x Feat_W)]
124 | :param aspect_ratios: List[List[float]] aspect ratios for each feature map
125 | :param scales: List[float] scales for each feature map
126 | :return: default_boxes : List[(Tensor of shape N x 4)] default_boxes over all
127 | feature maps aggregated for each batch image
128 | """
129 |
130 | # List to store default boxes for all feature maps
131 | default_boxes = []
132 | for k in range(len(feat)):
133 |         # We first add the aspect ratio 1 box with scale sqrt(scales[k]*scales[k+1])
134 | s_prime_k = math.sqrt(scales[k] * scales[k + 1])
135 | wh_pairs = [[s_prime_k, s_prime_k]]
136 |
137 | # Adding all possible w,h pairs according to
138 | # aspect ratio of the feature map k
139 | for ar in aspect_ratios[k]:
140 | sq_ar = math.sqrt(ar)
141 | w = scales[k] * sq_ar
142 | h = scales[k] / sq_ar
143 |
144 | wh_pairs.extend([[w, h]])
145 |
146 | feat_h, feat_w = feat[k].shape[-2:]
147 |
148 | # These shifts will be the centre of each of the default boxes
149 | shifts_x = ((torch.arange(0, feat_w) + 0.5) / feat_w).to(torch.float32)
150 | shifts_y = ((torch.arange(0, feat_h) + 0.5) / feat_h).to(torch.float32)
151 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
152 | shift_x = shift_x.reshape(-1)
153 | shift_y = shift_y.reshape(-1)
154 |
155 |         # Duplicate these shifts for as
156 |         # many boxes (aspect ratios)
157 |         # per position as we have
158 | shifts = torch.stack((shift_x, shift_y) * len(wh_pairs), dim=-1).reshape(-1, 2)
159 | # shifts for first feature map will be (5776 x 2)
160 |
161 | wh_pairs = torch.as_tensor(wh_pairs)
162 |
163 | # Repeat the wh pairs for all positions in feature map
164 | wh_pairs = wh_pairs.repeat((feat_h * feat_w), 1)
165 | # wh_pairs for first feature map will be (5776 x 2)
166 |
167 | # Concat the shifts(cx cy) and wh values for all positions
168 | default_box = torch.cat((shifts, wh_pairs), dim=1)
169 | # default box for feat_1 -> (5776, 4)
170 | # default box for feat_2 -> (2166, 4)
171 | # default box for feat_3 -> (600, 4)
172 | # default box for feat_4 -> (150, 4)
173 | # default box for feat_5 -> (36, 4)
174 | # default box for feat_6 -> (4, 4)
175 |
176 | default_boxes.append(default_box)
177 | default_boxes = torch.cat(default_boxes, dim=0)
178 | # default_boxes -> (8732, 4)
179 |
180 | # We now duplicate these default boxes
181 | # for all images in the batch
182 | # and also convert cx,cy,w,h format of
183 | # default boxes to x1,y1,x2,y2
184 | dboxes = []
185 | for _ in range(feat[0].size(0)):
186 | dboxes_in_image = default_boxes
187 | # x1 = cx - 0.5 * width
188 | # y1 = cy - 0.5 * height
189 | # x2 = cx + 0.5 * width
190 | # y2 = cy + 0.5 * height
191 | dboxes_in_image = torch.cat(
192 | [
193 | (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]),
194 | (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]),
195 | ],
196 | -1,
197 | )
198 | dboxes.append(dboxes_in_image.to(feat[0].device))
199 | return dboxes
200 |
201 |
202 | class SSD(nn.Module):
203 | r"""
204 | Main Class for SSD. Does the following steps
205 | to generate detections/losses.
206 | During initialization
207 | 1. Load VGG Imagenet pretrained model
208 | 2. Extract Backbone from VGG and add extra conv layers
209 | 3. Add class prediction and bbox transformation prediction layers
210 | 4. Initialize all conv2d layers
211 |
212 | During Forward Pass
213 | 1. Get conv4_3 output
214 | 2. Normalize and scale conv4_3 output (feat_output_1)
215 | 3. Pass the unscaled conv4_3 to conv5_3 layers and conv layers
216 | replacing fc6 and fc7 of vgg (feat_output_2)
217 | 4. Pass the conv_fc7 output to extra conv layers (feat_output_3-6)
218 | 5. Get the classification and regression predictions for all 6 feature maps
219 | 6. Generate default_boxes for all these feature maps(8732 x 4)
220 | 7a. If in training assign targets for these default_boxes and
221 | compute localization and classification losses
222 | 7b. If in inference mode, then do all pre-nms filtering, nms
223 | and then post nms filtering and return the detected boxes,
224 | their labels and their scores
225 | """
226 | def __init__(self, config, num_classes=21):
227 | super().__init__()
228 | self.aspect_ratios = config['aspect_ratios']
229 |
230 | self.scales = config['scales']
231 | self.scales.append(1.0)
232 |
233 | self.num_classes = num_classes
234 | self.iou_threshold = config['iou_threshold']
235 | self.low_score_threshold = config['low_score_threshold']
236 | self.neg_pos_ratio = config['neg_pos_ratio']
237 | self.pre_nms_topK = config['pre_nms_topK']
238 | self.nms_threshold = config['nms_threshold']
239 | self.detections_per_img = config['detections_per_img']
240 |
241 | # Load imagenet pretrained vgg network
242 | backbone = torchvision.models.vgg16(
243 | weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
244 | )
245 |
246 | # Get all max pool indexes to determine different stages
247 | max_pool_pos = [idx for idx, layer in enumerate(list(backbone.features))
248 | if isinstance(layer, nn.MaxPool2d)]
249 | max_pool_stage_3_pos = max_pool_pos[-3] # for vgg16 this would be 16
250 | max_pool_stage_4_pos = max_pool_pos[-2] # for vgg16 this would be 23
251 |
252 | backbone.features[max_pool_stage_3_pos].ceil_mode = True
253 | # otherwise vgg conv4_3 output will be 37x37
254 | self.features = nn.Sequential(*backbone.features[:max_pool_stage_4_pos])
255 | self.scale_weight = nn.Parameter(torch.ones(512) * 20)
256 |
257 | ###################################
258 | # Conv5_3 + Conv for fc6 and fc 7 #
259 | ###################################
260 | # Conv modules replacing fc6 and fc7
261 | # Ideally we would copy the weights
262 | # but here we are just adding new layers
263 | # and not copying fc6 and fc7 weights by
264 | # subsampling
265 | fcs = nn.Sequential(
266 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
267 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3,
268 | padding=6, dilation=6),
269 | nn.ReLU(inplace=True),
270 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),
271 | nn.ReLU(inplace=True),
272 | )
273 | self.conv5_3_fc = nn.Sequential(
274 | *backbone.features[max_pool_stage_4_pos:-1],
275 | fcs,
276 | )
277 |
278 | ##########################
279 | # Additional Conv Layers #
280 | ##########################
281 | # Modules to take from 19x19 to 10x10
282 | self.conv8_2 = nn.Sequential(
283 | nn.Conv2d(1024, 256, kernel_size=1),
284 | nn.ReLU(inplace=True),
285 | nn.Conv2d(256, 512, kernel_size=3, padding=1,
286 | stride=2),
287 | nn.ReLU(inplace=True)
288 | )
289 |
290 | # Modules to take from 10x10 to 5x5
291 | self.conv9_2 = nn.Sequential(
292 | nn.Conv2d(512, 128, kernel_size=1),
293 | nn.ReLU(inplace=True),
294 | nn.Conv2d(128, 256, kernel_size=3, padding=1,
295 | stride=2),
296 | nn.ReLU(inplace=True)
297 | )
298 |
299 | # Modules to take from 5x5 to 3x3
300 | self.conv10_2 = nn.Sequential(
301 | nn.Conv2d(256, 128, kernel_size=1),
302 | nn.ReLU(inplace=True),
303 | nn.Conv2d(128, 256, kernel_size=3),
304 | nn.ReLU(inplace=True)
305 | )
306 |
307 | # Modules to take from 3x3 to 1x1
308 | self.conv11_2 = nn.Sequential(
309 | nn.Conv2d(256, 128, kernel_size=1),
310 | nn.ReLU(inplace=True),
311 | nn.Conv2d(128, 256, kernel_size=3),
312 | nn.ReLU(inplace=True)
313 | )
314 |
315 | # Must match conv4_3, fcs, conv8_2, conv9_2, conv10_2, conv11_2
316 | out_channels = [512, 1024, 512, 256, 256, 256]
317 |
318 | #####################
319 | # Prediction Layers #
320 | #####################
321 | self.cls_heads = nn.ModuleList()
322 | for channels, aspect_ratio in zip(out_channels, self.aspect_ratios):
323 | # extra 1 is added for scale of sqrt(sk*sk+1)
324 | self.cls_heads.append(nn.Conv2d(channels,
325 | self.num_classes * (len(aspect_ratio)+1),
326 | kernel_size=3,
327 | padding=1))
328 |
329 | self.bbox_reg_heads = nn.ModuleList()
330 | for channels, aspect_ratio in zip(out_channels, self.aspect_ratios):
331 | # extra 1 is added for scale of sqrt(sk*sk+1)
332 | self.bbox_reg_heads.append(nn.Conv2d(channels, 4 * (len(aspect_ratio)+1),
333 | kernel_size=3,
334 | padding=1))
335 |
336 | #############################
337 | # Conv Layer Initialization #
338 | #############################
339 | for layer in fcs.modules():
340 | if isinstance(layer, nn.Conv2d):
341 | torch.nn.init.xavier_uniform_(layer.weight)
342 | if layer.bias is not None:
343 | torch.nn.init.constant_(layer.bias, 0.0)
344 |
345 | for conv_module in [self.conv8_2, self.conv9_2, self.conv10_2, self.conv11_2]:
346 | for layer in conv_module.modules():
347 | if isinstance(layer, nn.Conv2d):
348 | torch.nn.init.xavier_uniform_(layer.weight)
349 | if layer.bias is not None:
350 | torch.nn.init.constant_(layer.bias, 0.0)
351 |
352 | for module in self.cls_heads:
353 | torch.nn.init.xavier_uniform_(module.weight)
354 | if module.bias is not None:
355 | torch.nn.init.constant_(module.bias, 0.0)
356 | for module in self.bbox_reg_heads:
357 | torch.nn.init.xavier_uniform_(module.weight)
358 | if module.bias is not None:
359 | torch.nn.init.constant_(module.bias, 0.0)
360 |
361 | def compute_loss(
362 | self,
363 | targets,
364 | cls_logits,
365 | bbox_regression,
366 | default_boxes,
367 | matched_idxs,
368 | ):
369 | # Counting all the foreground default_boxes for computing N in loss equation
370 | num_foreground = 0
371 | # BBox losses for all batch images(for foreground default_boxes)
372 | bbox_loss = []
373 | # classification targets for all batch images(for ALL default_boxes)
374 | cls_targets = []
375 | for (
376 | targets_per_image,
377 | bbox_regression_per_image,
378 | cls_logits_per_image,
379 | default_boxes_per_image,
380 | matched_idxs_per_image,
381 | ) in zip(targets, bbox_regression, cls_logits, default_boxes, matched_idxs):
382 | # Foreground default_boxes -> matched_idx >=0
383 | # Background default_boxes -> matched_idx = -1
384 | fg_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
385 | foreground_matched_idxs_per_image = matched_idxs_per_image[
386 | fg_idxs_per_image
387 | ]
388 | num_foreground += foreground_matched_idxs_per_image.numel()
389 |
390 | # Get foreground default_boxes and their transformation predictions
391 | matched_gt_boxes_per_image = targets_per_image["boxes"][
392 | foreground_matched_idxs_per_image
393 | ]
394 | bbox_regression_per_image = bbox_regression_per_image[fg_idxs_per_image, :]
395 | default_boxes_per_image = default_boxes_per_image[fg_idxs_per_image, :]
396 | target_regression = boxes_to_transformation_targets(
397 | matched_gt_boxes_per_image,
398 | default_boxes_per_image)
399 |
400 | bbox_loss.append(
401 | torch.nn.functional.smooth_l1_loss(bbox_regression_per_image,
402 | target_regression,
403 | reduction='sum')
404 | )
405 |
406 | # Get classification target for ALL default_boxes
407 | # For all default_boxes set it as 0 first
408 | # Then set foreground default_boxes target as label
409 | # of assigned gt box
410 | gt_classes_target = torch.zeros(
411 | (cls_logits_per_image.size(0),),
412 | dtype=targets_per_image["labels"].dtype,
413 | device=targets_per_image["labels"].device,
414 | )
415 | gt_classes_target[fg_idxs_per_image] = targets_per_image["labels"][
416 | foreground_matched_idxs_per_image
417 | ]
418 | cls_targets.append(gt_classes_target)
419 |
420 | # Aggregated bbox loss and classification targets
421 | # for all batch images
422 | bbox_loss = torch.stack(bbox_loss)
423 | cls_targets = torch.stack(cls_targets) # (B, 8732)
424 |
425 | # Calculate classification loss for ALL default_boxes
426 | num_classes = cls_logits.size(-1)
427 | cls_loss = torch.nn.functional.cross_entropy(cls_logits.view(-1, num_classes),
428 | cls_targets.view(-1),
429 | reduction="none").view(
430 | cls_targets.size()
431 | )
432 |
433 | # Hard Negative Mining
434 | foreground_idxs = cls_targets > 0
435 | # We will sample total of 3 x (number of fg default_boxes)
436 | # background default_boxes
437 | num_negative = self.neg_pos_ratio * foreground_idxs.sum(1, keepdim=True)
438 |
439 | # As of now cls_loss is for ALL default_boxes
440 | negative_loss = cls_loss.clone()
441 | # We want to ensure that after sorting based on loss value,
442 | # foreground default_boxes are never picked when choosing topK
443 | # highest loss indexes
444 | negative_loss[foreground_idxs] = -float("inf")
445 | values, idx = negative_loss.sort(1, descending=True)
446 | # Fetch those indexes which have in topK(K=num_negative) losses
447 | background_idxs = idx.sort(1)[1] < num_negative
448 | N = max(1, num_foreground)
449 | return {
450 | "bbox_regression": bbox_loss.sum() / N,
451 | "classification": (cls_loss[foreground_idxs].sum() +
452 | cls_loss[background_idxs].sum()) / N,
453 | }
454 |
455 | def forward(self, x, targets=None):
456 | # Call everything till conv4_3 layers first
457 | conv_4_3_out = self.features(x)
458 |
459 | # Scale conv4_3 output using learnt norm scale
460 | conv_4_3_out_scaled = (self.scale_weight.view(1, -1, 1, 1) *
461 | torch.nn.functional.normalize(conv_4_3_out))
462 |
463 | # Call conv5_3 with non_scaled conv_3 and also
464 | # Call additional conv layers
465 | conv_5_3_fc_out = self.conv5_3_fc(conv_4_3_out)
466 | conv8_2_out = self.conv8_2(conv_5_3_fc_out)
467 | conv9_2_out = self.conv9_2(conv8_2_out)
468 | conv10_2_out = self.conv10_2(conv9_2_out)
469 | conv11_2_out = self.conv11_2(conv10_2_out)
470 |
471 | # Feature maps for predictions
472 | outputs = [
473 | conv_4_3_out_scaled, # 38 x 38
474 | conv_5_3_fc_out, # 19 x 19
475 | conv8_2_out, # 10 x 10
476 | conv9_2_out, # 5 x 5
477 | conv10_2_out, # 3 x 3
478 | conv11_2_out, # 1 x 1
479 | ]
480 |
481 | # Classification and bbox regression for all feature maps
482 | cls_logits = []
483 | bbox_reg_deltas = []
484 | for i, features in enumerate(outputs):
485 | cls_feat_i = self.cls_heads[i](features)
486 | bbox_reg_feat_i = self.bbox_reg_heads[i](features)
487 |
488 | # Cls output from (B, A * num_classes, H, W) to (B, HWA, num_classes).
489 | N, _, H, W = cls_feat_i.shape
490 | cls_feat_i = cls_feat_i.view(N, -1, self.num_classes, H, W)
491 | # (B, A, num_classes, H, W)
492 | cls_feat_i = cls_feat_i.permute(0, 3, 4, 1, 2) # (B, H, W, A, num_classes)
493 | cls_feat_i = cls_feat_i.reshape(N, -1, self.num_classes)
494 | # (B, HWA, num_classes)
495 | cls_logits.append(cls_feat_i)
496 |
497 | # Permute bbox reg output from (B, A * 4, H, W) to (B, HWA, 4).
498 | N, _, H, W = bbox_reg_feat_i.shape
499 | bbox_reg_feat_i = bbox_reg_feat_i.view(N, -1, 4, H, W) # (B, A, 4, H, W)
500 | bbox_reg_feat_i = bbox_reg_feat_i.permute(0, 3, 4, 1, 2) # (B, H, W, A, 4)
501 | bbox_reg_feat_i = bbox_reg_feat_i.reshape(N, -1, 4) # Size=(B, HWA, 4)
502 | bbox_reg_deltas.append(bbox_reg_feat_i)
503 |
504 | # Concat cls logits and bbox regression predictions for all feature maps
505 | cls_logits = torch.cat(cls_logits, dim=1) # (B, 8732, num_classes)
506 | bbox_reg_deltas = torch.cat(bbox_reg_deltas, dim=1) # (B, 8732, 4)
507 |
508 | # Generate default_boxes for all feature maps
509 | default_boxes = generate_default_boxes(outputs, self.aspect_ratios, self.scales)
510 | # default_boxes -> List[Tensor of shape 8732 x 4]
511 | # len(default_boxes) = Batch size
512 |
513 | losses = {}
514 | detections = []
515 | if self.training:
516 | # List to hold for each image, which default box
517 | # is assigned to with gt box if any
518 | # or unassigned(background)
519 | matched_idxs = []
520 | for default_boxes_per_image, targets_per_image in zip(default_boxes,
521 | targets):
522 | if targets_per_image["boxes"].numel() == 0:
523 | matched_idxs.append(
524 | torch.full(
525 | (default_boxes_per_image.size(0),), -1,
526 | dtype=torch.int64,
527 | device=default_boxes_per_image.device
528 | )
529 | )
530 | continue
531 | iou_matrix = get_iou(targets_per_image["boxes"],
532 | default_boxes_per_image)
533 | # For each default box find best ground truth box
534 | matched_vals, matches = iou_matrix.max(dim=0)
535 | # matches -> [8732]
536 |
537 | # Update index of match for all default_boxes which
538 | # have maximum iou with a gt box < low threshold
539 | # as -1
540 | # This allows selecting foreground boxes as match index >= 0
541 | below_low_threshold = matched_vals < self.iou_threshold
542 | matches[below_low_threshold] = -1
543 |
544 | # We want to also assign the best default box for every gt
545 | # as foreground
546 | # So first find the best default box for every gt
547 | _, highest_quality_pred_foreach_gt = iou_matrix.max(dim=1)
548 | # Update the best matching gt index for these best default_boxes
549 | # as 0, 1, 2, ...., len(gt)-1
550 | matches[highest_quality_pred_foreach_gt] = torch.arange(
551 | highest_quality_pred_foreach_gt.size(0), dtype=torch.int64,
552 | device=highest_quality_pred_foreach_gt.device
553 | )
554 | matched_idxs.append(matches)
555 | losses = self.compute_loss(targets, cls_logits, bbox_reg_deltas,
556 | default_boxes, matched_idxs)
557 | else:
558 | # For test time we do the following:
559 | # 1. Convert default_boxes to boxes using predicted bbox regression deltas
560 | # 2. Low score filtering
561 | # 3. Pre-NMS TopK filtering
562 | # 4. NMS
563 | # 5. Post NMS TopK Filtering
564 | cls_scores = torch.nn.functional.softmax(cls_logits, dim=-1)
565 | num_classes = cls_scores.size(-1)
566 |
567 | for bbox_deltas_i, cls_scores_i, default_boxes_i in zip(bbox_reg_deltas,
568 | cls_scores,
569 | default_boxes):
570 | boxes = apply_regression_pred_to_default_boxes(bbox_deltas_i,
571 | default_boxes_i)
572 | # Ensure all values are between 0-1
573 | boxes.clamp_(min=0., max=1.)
574 |
575 | pred_boxes = []
576 | pred_scores = []
577 | pred_labels = []
578 | # Class wise filtering
579 | for label in range(1, num_classes):
580 | score = cls_scores_i[:, label]
581 |
582 | # Remove low scoring boxes of this class
583 | keep_idxs = score > self.low_score_threshold
584 | score = score[keep_idxs]
585 | box = boxes[keep_idxs]
586 |
587 | # keep only topk scoring predictions of this class
588 | score, top_k_idxs = score.topk(min(self.pre_nms_topK, len(score)))
589 | box = box[top_k_idxs]
590 |
591 | pred_boxes.append(box)
592 | pred_scores.append(score)
593 | pred_labels.append(torch.full_like(score, fill_value=label,
594 | dtype=torch.int64,
595 | device=cls_scores.device))
596 |
597 | pred_boxes = torch.cat(pred_boxes, dim=0)
598 | pred_scores = torch.cat(pred_scores, dim=0)
599 | pred_labels = torch.cat(pred_labels, dim=0)
600 |
601 | # Class wise NMS
602 | keep_mask = torch.zeros_like(pred_scores, dtype=torch.bool)
603 | for class_id in torch.unique(pred_labels):
604 | curr_indices = torch.where(pred_labels == class_id)[0]
605 | curr_keep_idxs = torch.ops.torchvision.nms(pred_boxes[curr_indices],
606 | pred_scores[curr_indices],
607 | self.nms_threshold)
608 | keep_mask[curr_indices[curr_keep_idxs]] = True
609 | keep_indices = torch.where(keep_mask)[0]
610 | post_nms_keep_indices = keep_indices[pred_scores[keep_indices].sort(
611 | descending=True)[1]]
612 | keep = post_nms_keep_indices[:self.detections_per_img]
613 | pred_boxes, pred_scores, pred_labels = (pred_boxes[keep],
614 | pred_scores[keep],
615 | pred_labels[keep])
616 |
617 | detections.append(
618 | {
619 | "boxes": pred_boxes,
620 | "scores": pred_scores,
621 | "labels": pred_labels,
622 | }
623 | )
624 | return losses, detections
625 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.8.0
2 | numpy==2.0.1
3 | opencv_python==4.10.0.84
4 | Pillow==10.4.0
5 | PyYAML==6.0.1
6 | torch==2.3.1
7 | torchvision==0.18.1
8 | tqdm==4.66.4
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/tools/__init__.py
--------------------------------------------------------------------------------
/tools/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import yaml
5 | import random
6 | from tqdm import tqdm
7 | from model.ssd import SSD
8 | import numpy as np
9 | import cv2
10 | from dataset.voc import VOCDataset
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | if torch.backends.mps.is_available():
15 | device = torch.device('mps')
16 | print('Using mps')
17 |
18 |
19 | def get_iou(det, gt):
20 | det_x1, det_y1, det_x2, det_y2 = det
21 | gt_x1, gt_y1, gt_x2, gt_y2 = gt
22 |
23 | x_left = max(det_x1, gt_x1)
24 | y_top = max(det_y1, gt_y1)
25 | x_right = min(det_x2, gt_x2)
26 | y_bottom = min(det_y2, gt_y2)
27 |
28 | if x_right < x_left or y_bottom < y_top:
29 | return 0.0
30 |
31 | area_intersection = (x_right - x_left) * (y_bottom - y_top)
32 | det_area = (det_x2 - det_x1) * (det_y2 - det_y1)
33 | gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
34 | area_union = float(det_area + gt_area - area_intersection + 1E-6)
35 | iou = area_intersection / area_union
36 | return iou
37 |
38 |
39 | def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method='area', difficult=None):
40 | # det_boxes = [
41 | # {
42 | # 'person' : [[x1, y1, x2, y2, score], ...],
43 | # 'car' : [[x1, y1, x2, y2, score], ...]
44 | # }
45 | # {det_boxes_img_2},
46 | # ...
47 | # {det_boxes_img_N},
48 | # ]
49 | #
50 | # gt_boxes = [
51 | # {
52 | # 'person' : [[x1, y1, x2, y2], ...],
53 | # 'car' : [[x1, y1, x2, y2], ...]
54 | # },
55 | # {gt_boxes_img_2},
56 | # ...
57 | # {gt_boxes_img_N},
58 | # ]
59 |
60 | gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
61 | gt_labels = sorted(gt_labels)
62 |
63 | all_aps = {}
64 | # average precisions for ALL classes
65 | aps = []
66 | for idx, label in enumerate(gt_labels):
67 | # Get detection predictions of this class
68 | cls_dets = [
69 | [im_idx, im_dets_label] for im_idx, im_dets in enumerate(det_boxes)
70 | if label in im_dets for im_dets_label in im_dets[label]
71 | ]
72 |
73 | # cls_dets = [
74 | # (0, [x1_0, y1_0, x2_0, y2_0, score_0]),
75 | # ...
76 | # (0, [x1_M, y1_M, x2_M, y2_M, score_M]),
77 | # (1, [x1_0, y1_0, x2_0, y2_0, score_0]),
78 | # ...
79 | # (1, [x1_N, y1_N, x2_N, y2_N, score_N]),
80 | # ...
81 | # ]
82 |
83 | # Sort them by confidence score
84 | cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])
85 |
86 | # For tracking which gt boxes of this class have already been matched
87 | gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
88 | # Number of gt boxes for this class for recall calculation
89 | num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])
90 | num_difficults = sum([sum(difficults_label[label]) for difficults_label in difficult])
91 |
92 | tp = [0] * len(cls_dets)
93 | fp = [0] * len(cls_dets)
94 |
95 | # For each prediction
96 | for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
97 | # Get gt boxes for this image and this label
98 | im_gts = gt_boxes[im_idx][label]
99 | im_gt_difficults = difficult[im_idx][label]
100 |
101 | max_iou_found = -1
102 | max_iou_gt_idx = -1
103 |
104 | # Get best matching gt box
105 | for gt_box_idx, gt_box in enumerate(im_gts):
106 | gt_box_iou = get_iou(det_pred[:-1], gt_box)
107 | if gt_box_iou > max_iou_found:
108 | max_iou_found = gt_box_iou
109 | max_iou_gt_idx = gt_box_idx
110 | # TP only if iou >= threshold and this gt has not yet been matched
111 | if max_iou_found >= iou_threshold:
112 | if not im_gt_difficults[max_iou_gt_idx]:
113 | if not gt_matched[im_idx][max_iou_gt_idx]:
114 | # If tp then we set this gt box as matched
115 | gt_matched[im_idx][max_iou_gt_idx] = True
116 | tp[det_idx] = 1
117 | else:
118 | fp[det_idx] = 1
119 | else:
120 | fp[det_idx] = 1
121 |
122 | # Cumulative tp and fp
123 | tp = np.cumsum(tp)
124 | fp = np.cumsum(fp)
125 |
126 | eps = np.finfo(np.float32).eps
127 | # recalls = tp / np.maximum(num_gts, eps)
128 | recalls = tp / np.maximum(num_gts - num_difficults, eps)
129 | precisions = tp / np.maximum((tp + fp), eps)
130 |
131 | if method == 'area':
132 | recalls = np.concatenate(([0.0], recalls, [1.0]))
133 | precisions = np.concatenate(([0.0], precisions, [0.0]))
134 |
135 | # Replace precision values with recall r with maximum precision value
136 | # of any recall value >= r
137 | # This computes the precision envelope
138 | for i in range(precisions.size - 1, 0, -1):
139 | precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])
140 | # For computing area, get points where recall changes value
141 | i = np.where(recalls[1:] != recalls[:-1])[0]
142 | # Add the rectangular areas to get ap
143 | ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
144 | elif method == 'interp':
145 | ap = 0.0
146 | for interp_pt in np.arange(0, 1 + 1E-3, 0.1):
147 | # Get precision values for recall values >= interp_pt
148 | prec_interp_pt = precisions[recalls >= interp_pt]
149 |
150 | # Get max of those precision values
151 |                 prec_interp_pt = prec_interp_pt.max() if prec_interp_pt.size > 0 else 0.0
152 | ap += prec_interp_pt
153 | ap = ap / 11.0
154 | else:
155 | raise ValueError('Method can only be area or interp')
156 | if num_gts > 0:
157 | aps.append(ap)
158 | all_aps[label] = ap
159 | else:
160 | all_aps[label] = np.nan
161 | # compute mAP at provided iou threshold
162 | mean_ap = sum(aps) / len(aps)
163 | return mean_ap, all_aps
164 |
165 |
166 | def load_model_and_dataset(args):
167 | # Read the config file #
168 | with open(args.config_path, 'r') as file:
169 | try:
170 | config = yaml.safe_load(file)
171 | except yaml.YAMLError as exc:
172 | print(exc)
173 | print(config)
174 | ########################
175 |
176 | dataset_config = config['dataset_params']
177 | model_config = config['model_params']
178 | train_config = config['train_params']
179 |
180 | voc = VOCDataset('test',
181 | im_sets=dataset_config['test_im_sets'])
182 | test_dataset = DataLoader(voc, batch_size=1, shuffle=False)
183 |
184 | model = SSD(config=model_config,
185 | num_classes=dataset_config['num_classes'])
186 | model.to(device=torch.device(device))
187 | model.eval()
188 |
189 | assert os.path.exists(os.path.join(train_config['task_name'],
190 | train_config['ckpt_name'])), \
191 | "No checkpoint exists at {}".format(os.path.join(train_config['task_name'],
192 | train_config['ckpt_name']))
193 | model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
194 | train_config['ckpt_name']),
195 | map_location=device))
196 | return model, voc, test_dataset, config
197 |
198 |
199 | def infer(args):
200 | if not os.path.exists('samples'):
201 | os.mkdir('samples')
202 |
203 | model, voc, test_dataset, config = load_model_and_dataset(args)
204 | conf_threshold = config['train_params']['infer_conf_threshold']
205 | model.low_score_threshold = conf_threshold
206 |
207 | num_samples = 5
208 | for i in tqdm(range(num_samples)):
209 |         dataset_idx = random.randint(0, len(voc) - 1)
210 | im_tensor, target, fname = voc[dataset_idx]
211 | _, ssd_detections = model(im_tensor.unsqueeze(0).to(device), [target])
212 |
213 | gt_im = cv2.imread(fname)
214 | h, w = gt_im.shape[:2]
215 | gt_im_copy = gt_im.copy()
216 | # Saving images with ground truth boxes
217 | for idx, box in enumerate(target['bboxes']):
218 | x1, y1, x2, y2 = box.detach().cpu().numpy()
219 | x1, y1, x2, y2 = int(w*x1), int(h*y1), int(w*x2), int(h*y2)
220 | cv2.rectangle(gt_im, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
221 | cv2.rectangle(gt_im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
222 | text = voc.idx2label[target['labels'][idx].detach().cpu().item()]
223 | text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
224 | text_w, text_h = text_size
225 | cv2.rectangle(gt_im_copy, (x1, y1), (x1 + 10 + text_w, y1 + 10 + text_h), [255, 255, 255], -1)
226 | cv2.putText(gt_im, text=voc.idx2label[target['labels'][idx].detach().cpu().item()],
227 | org=(x1 + 5, y1 + 15),
228 | thickness=1,
229 | fontScale=1,
230 | color=[0, 0, 0],
231 | fontFace=cv2.FONT_HERSHEY_PLAIN)
232 | cv2.putText(gt_im_copy, text=text,
233 | org=(x1 + 5, y1 + 15),
234 | thickness=1,
235 | fontScale=1,
236 | color=[0, 0, 0],
237 | fontFace=cv2.FONT_HERSHEY_PLAIN)
238 | cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
239 | cv2.imwrite('samples/output_ssd_gt_{}.png'.format(i), gt_im)
240 |
241 | # Getting predictions from trained model
242 | boxes = ssd_detections[0]['boxes']
243 | labels = ssd_detections[0]['labels']
244 | scores = ssd_detections[0]['scores']
245 | im = cv2.imread(fname)
246 | im_copy = im.copy()
247 |
248 | # Saving images with predicted boxes
249 | for idx, box in enumerate(boxes):
250 | x1, y1, x2, y2 = box.detach().cpu().numpy()
251 | x1, y1, x2, y2 = int(w * x1), int(h * y1), int(w * x2), int(h * y2)
252 | cv2.rectangle(im, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
253 | cv2.rectangle(im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
254 | text = '{} : {:.2f}'.format(voc.idx2label[labels[idx].detach().cpu().item()],
255 | scores[idx].detach().cpu().item())
256 | text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
257 | text_w, text_h = text_size
258 | cv2.rectangle(im_copy, (x1, y1), (x1 + 10 + text_w, y1 + 10 + text_h), [255, 255, 255], -1)
259 | cv2.putText(im, text=text,
260 | org=(x1 + 5, y1 + 15),
261 | thickness=1,
262 | fontScale=1,
263 | color=[0, 0, 0],
264 | fontFace=cv2.FONT_HERSHEY_PLAIN)
265 | cv2.putText(im_copy, text=text,
266 | org=(x1 + 5, y1 + 15),
267 | thickness=1,
268 | fontScale=1,
269 | color=[0, 0, 0],
270 | fontFace=cv2.FONT_HERSHEY_PLAIN)
271 | cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
272 | cv2.imwrite('samples/output_ssd_{}.jpg'.format(i), im)
273 |
274 | print('Done Detecting...')
275 |
276 |
277 | def evaluate_map(args):
278 | model, voc, test_dataset, config = load_model_and_dataset(args)
279 |
280 | gts = []
281 | preds = []
282 | difficults = []
283 | for im_tensor, target, fname in tqdm(test_dataset):
284 | im_tensor = im_tensor.float().to(device)
285 | target_bboxes = target['bboxes'].float()[0].to(device)
286 | target_labels = target['labels'].long()[0].to(device)
287 | difficult = target['difficult'].long()[0].to(device)
288 | _, ssd_detections = model(im_tensor)
289 |
290 | boxes = ssd_detections[0]['boxes']
291 | labels = ssd_detections[0]['labels']
292 | scores = ssd_detections[0]['scores']
293 |
294 | pred_boxes = {}
295 | gt_boxes = {}
296 | difficult_boxes = {}
297 |
298 | for label_name in voc.label2idx:
299 | pred_boxes[label_name] = []
300 | gt_boxes[label_name] = []
301 | difficult_boxes[label_name] = []
302 |
303 | for idx, box in enumerate(boxes):
304 | x1, y1, x2, y2 = box.detach().cpu().numpy()
305 | label = labels[idx].detach().cpu().item()
306 | score = scores[idx].detach().cpu().item()
307 | label_name = voc.idx2label[label]
308 | pred_boxes[label_name].append([x1, y1, x2, y2, score])
309 | for idx, box in enumerate(target_bboxes):
310 | x1, y1, x2, y2 = box.detach().cpu().numpy()
311 | label = target_labels[idx].detach().cpu().item()
312 | label_name = voc.idx2label[label]
313 | gt_boxes[label_name].append([x1, y1, x2, y2])
314 | difficult_boxes[label_name].append(difficult[idx].detach().cpu().item())
315 |
316 | gts.append(gt_boxes)
317 | preds.append(pred_boxes)
318 | difficults.append(difficult_boxes)
319 | mean_ap, all_aps = compute_map(preds, gts, method='area', difficult=difficults)
320 | print('Class Wise Average Precisions')
321 | for idx in range(len(voc.idx2label)):
322 | print('AP for class {} = {:.4f}'.format(voc.idx2label[idx],
323 | all_aps[voc.idx2label[idx]]))
324 | print('Mean Average Precision : {:.4f}'.format(mean_ap))
325 |
326 |
327 | if __name__ == '__main__':
328 | parser = argparse.ArgumentParser(description='Arguments for ssd inference')
329 | parser.add_argument('--config', dest='config_path',
330 | default='config/voc.yaml', type=str)
331 |     parser.add_argument('--evaluate', dest='evaluate',
332 |                         default=False, type=lambda x: x.lower() == 'true')  # type=bool would treat 'False' as True
333 |     parser.add_argument('--infer_samples', dest='infer_samples',
334 |                         default=True, type=lambda x: x.lower() == 'true')
335 | args = parser.parse_args()
336 |
337 | with torch.no_grad():
338 | if args.infer_samples:
339 | infer(args)
340 | else:
341 | print('Not Inferring for samples as `infer_samples` argument is False')
342 |
343 | if args.evaluate:
344 | evaluate_map(args)
345 | else:
346 | print('Not Evaluating as `evaluate` argument is False')
347 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import numpy as np
5 | import yaml
6 | import random
7 | from tqdm import tqdm
8 | from model.ssd import SSD
9 | import torchvision
10 | from dataset.voc import VOCDataset
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.optim.lr_scheduler import MultiStepLR
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 |
16 | if torch.backends.mps.is_available():
17 | device = torch.device('mps')
18 | print('Using mps')
19 |
20 |
21 | def collate_function(data):
22 | return tuple(zip(*data))
23 |
24 |
25 | def train(args):
26 | # Read the config file #
27 | with open(args.config_path, 'r') as file:
28 | try:
29 | config = yaml.safe_load(file)
30 | except yaml.YAMLError as exc:
31 | print(exc)
32 | print(config)
33 | #########################
34 |
35 | dataset_config = config['dataset_params']
36 | train_config = config['train_params']
37 |
38 | seed = train_config['seed']
39 | torch.manual_seed(seed)
40 | np.random.seed(seed)
41 | random.seed(seed)
42 |     if device.type == 'cuda':
43 | torch.cuda.manual_seed_all(seed)
44 |
45 | voc = VOCDataset('train',
46 | im_sets=dataset_config['train_im_sets'],
47 | im_size=dataset_config['im_size'])
48 | train_dataset = DataLoader(voc,
49 | batch_size=train_config['batch_size'],
50 | shuffle=True,
51 | collate_fn=collate_function)
52 |
53 | # Instantiate model and load checkpoint if present
54 | model = SSD(config=config['model_params'],
55 | num_classes=dataset_config['num_classes'])
56 | model.to(device)
57 | model.train()
58 | if os.path.exists(os.path.join(train_config['task_name'],
59 | train_config['ckpt_name'])):
60 | print('Loading checkpoint as one exists')
61 | model.load_state_dict(torch.load(
62 | os.path.join(train_config['task_name'],
63 | train_config['ckpt_name']),
64 | map_location=device))
65 |
66 | if not os.path.exists(train_config['task_name']):
67 | os.mkdir(train_config['task_name'])
68 |
69 | optimizer = torch.optim.SGD(lr=train_config['lr'],
70 | params=model.parameters(),
71 | weight_decay=5E-4, momentum=0.9)
72 | lr_scheduler = MultiStepLR(optimizer, milestones=train_config['lr_steps'], gamma=0.5)
73 | acc_steps = train_config['acc_steps']
74 | num_epochs = train_config['num_epochs']
75 | steps = 0
76 | for i in range(num_epochs):
77 | ssd_classification_losses = []
78 | ssd_localization_losses = []
79 | for idx, (ims, targets, _) in enumerate(tqdm(train_dataset)):
80 | for target in targets:
81 | target['boxes'] = target['bboxes'].float().to(device)
82 | del target['bboxes']
83 | target['labels'] = target['labels'].long().to(device)
84 | images = torch.stack([im.float().to(device) for im in ims], dim=0)
85 | batch_losses, _ = model(images, targets)
86 | loss = batch_losses['classification']
87 | loss += batch_losses['bbox_regression']
88 |
89 | ssd_classification_losses.append(batch_losses['classification'].item())
90 | ssd_localization_losses.append(batch_losses['bbox_regression'].item())
91 | loss = loss / acc_steps
92 | loss.backward()
93 |
94 | if (idx + 1) % acc_steps == 0:
95 | optimizer.step()
96 | optimizer.zero_grad()
97 | if steps % train_config['log_steps'] == 0:
98 | loss_output = ''
99 | loss_output += 'SSD Classification Loss : {:.4f}'.format(np.mean(ssd_classification_losses))
100 | loss_output += ' | SSD Localization Loss : {:.4f}'.format(np.mean(ssd_localization_losses))
101 | print(loss_output)
102 | if torch.isnan(loss):
103 | print('Loss is becoming nan. Exiting')
104 | exit(0)
105 | steps += 1
106 | optimizer.step()
107 | optimizer.zero_grad()
108 | lr_scheduler.step()
109 | print('Finished epoch {}'.format(i+1))
110 | loss_output = ''
111 | loss_output += 'SSD Classification Loss : {:.4f}'.format(np.mean(ssd_classification_losses))
112 | loss_output += ' | SSD Localization Loss : {:.4f}'.format(np.mean(ssd_localization_losses))
113 | print(loss_output)
114 | torch.save(model.state_dict(), os.path.join(train_config['task_name'],
115 | train_config['ckpt_name']))
116 | print('Done Training...')
117 |
118 |
119 | if __name__ == '__main__':
120 | parser = argparse.ArgumentParser(description='Arguments for ssd training')
121 | parser.add_argument('--config', dest='config_path',
122 | default='config/voc.yaml', type=str)
123 | args = parser.parse_args()
124 | train(args)
125 |
--------------------------------------------------------------------------------