├── .gitignore
├── README.md
├── config
│   └── voc.yaml
├── dataset
│   ├── __init__.py
│   └── voc.py
├── model
│   ├── __init__.py
│   └── ssd.py
├── requirements.txt
└── tools
    ├── __init__.py
    ├── infer.py
    └── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Ignore all image files
*.jpg
*.png
*.jpeg

# Ignore pycharm and system files
.DS_Store
*.idea
__pycache__
*.zip

# Ignore dataset files
*.csv
*.json

# Ignore checkpoints
*.pth

# Ignore pickle files
*.pkl

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
SSD Implementation in PyTorch
========

This repository implements SSD, with training, inference and mAP evaluation in PyTorch.
Most of the code is taken from the PyTorch SSD implementation; all I have done is remove abstractions and comment the code.

The repo provides code to train on the VOC dataset. Specifically, I trained on the trainval images of the VOC 2007 dataset and tested on the VOC 2007 test set.

## SSD Explanation and Implementation Video

SSD Explanation and Implementation

## Result by training SSD on VOC 2007 dataset
One should be able to get **71-72% mAP** by training on VOC 2007 trainval images (**68% reported in paper**).

Adding 2012 trainval, we should be able to get **>77% mAP**.
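
The mAP figures are the mean, over the 20 VOC classes, of each class's average precision (AP). The evaluation code lives in `tools/infer.py`, which is not reproduced here, so the following is only a minimal sketch of VOC-style AP assuming all-point interpolation (area under the precision-recall curve after making precision non-increasing). The helpers `voc_ap` and `ap_from_detections` are hypothetical, and the sketch assumes detections were already matched to ground truth (e.g. at IoU 0.5) with `difficult` objects excluded.

```python
import numpy as np


def voc_ap(recall, precision):
    """All-point interpolated AP (hypothetical helper, not taken from this repo).

    recall, precision: 1-D sequences for one class, computed over detections
    sorted by descending confidence.
    """
    # Add sentinels so the curve starts at recall 0 and ends at recall 1
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))

    # Make precision monotonically non-increasing (the precision "envelope")
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])

    # Sum rectangle areas at every point where recall changes
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))


def ap_from_detections(tp_flags, num_gt):
    """AP for one class from per-detection true/false positive flags.

    tp_flags: 1 for a true positive, 0 for a false positive, ordered by
    descending score (assumes IoU matching was already done).
    num_gt: number of (non-difficult) ground-truth boxes of this class.
    """
    tp_flags = np.asarray(tp_flags, dtype=np.float64)
    tp = np.cumsum(tp_flags)
    fp = np.cumsum(1.0 - tp_flags)
    recall = tp / max(num_gt, 1)
    precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    return voc_ap(recall, precision)


# Example: 4 ground-truth boxes, detections ranked TP, TP, FP, TP
print(ap_from_detections([1, 1, 0, 1], num_gt=4))  # 0.6875
```

Averaging the per-class values produced this way gives the mAP. The older 11-point VOC metric samples precision at fixed recall levels instead, so numbers from the two variants are not directly comparable.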

Here's an evaluation result that I got after training for 100 epochs.
```
Class Wise Average Precisions
AP for class aeroplane = 0.7552
AP for class bicycle = 0.8384
AP for class bird = 0.7025
AP for class boat = 0.6543
AP for class bottle = 0.3411
AP for class bus = 0.8355
AP for class car = 0.8611
AP for class cat = 0.8682
AP for class chair = 0.4798
AP for class cow = 0.7453
AP for class diningtable = 0.7092
AP for class dog = 0.8582
AP for class horse = 0.8506
AP for class motorbike = 0.8259
AP for class person = 0.7721
AP for class pottedplant = 0.3939
AP for class sheep = 0.7300
AP for class sofa = 0.7626
AP for class train = 0.8615
AP for class tvmonitor = 0.7260
Mean Average Precision : 0.7286
```


## Data preparation
For setting up the VOC 2007 dataset:
* Create a `data` directory inside `SSD-PyTorch`.
* Download the VOC 2007 train/val data from http://host.robots.ox.ac.uk/pascal/VOC/voc2007 and copy the `VOC2007` directory into the `data` directory.
* Download the VOC 2007 test data from http://host.robots.ox.ac.uk/pascal/VOC/voc2007, copy its `VOC2007` directory into `data`, and rename it to `VOC2007-test`.
* If you want to use the 2012 trainval images as well, download the VOC 2012 train/val data from http://host.robots.ox.ac.uk/pascal/VOC/voc2012 and copy the `VOC2012` directory into `data`.
* Place all the directories inside the repo's `data` folder according to the structure below. `VOC2007-test` also needs its `ImageSets` folder, because `dataset/voc.py` reads the list of test images from `ImageSets/Main/test.txt`.
```
SSD-PyTorch
    -> data
        -> VOC2007
            -> JPEGImages
            -> Annotations
            -> ImageSets
        -> VOC2007-test
            -> JPEGImages
            -> Annotations
            -> ImageSets
        -> VOC2012 (if needed)
            -> JPEGImages
            -> Annotations
            -> ImageSets
    -> tools
        -> train.py
        -> infer.py
    -> config
        -> voc.yaml
    -> model
        -> ssd.py
    -> dataset
        -> voc.py
```
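
A misplaced folder usually only shows up as a crash once the dataset is built. A quick pre-flight check is to confirm, for each image set, that every id listed in `ImageSets/Main/<split>.txt` has a matching XML annotation and JPEG image, since those are the files `load_images_and_anns` in `dataset/voc.py` opens. A minimal sketch; `missing_voc_files` is a hypothetical helper, not part of the repo, and the folder names assume the layout above.

```python
import os


def missing_voc_files(root, split_file):
    """Hypothetical helper (not part of this repo): list the files that
    dataset/voc.py would try to read under `root` but cannot find.

    root: an image-set folder such as 'data/VOC2007' or 'data/VOC2007-test'
    split_file: 'trainval' for training sets, 'test' for the test set
    """
    split_path = os.path.join(root, 'ImageSets', 'Main', split_file + '.txt')
    if not os.path.isfile(split_path):
        return [split_path]

    missing = []
    with open(split_path) as f:
        for line in f:
            name = line.strip()
            if not name:
                continue  # skip blank lines
            for path in (os.path.join(root, 'Annotations', name + '.xml'),
                         os.path.join(root, 'JPEGImages', name + '.jpg')):
                if not os.path.isfile(path):
                    missing.append(path)
    return missing


if __name__ == '__main__':
    # Folder names follow the layout above; adjust if yours differ
    for im_set, split_file in [('data/VOC2007', 'trainval'),
                               ('data/VOC2007-test', 'test')]:
        missing = missing_voc_files(im_set, split_file)
        status = 'OK' if not missing else '{} missing, e.g. {}'.format(
            len(missing), missing[0])
        print(im_set, status)
```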

## For training on your own dataset

* Update the paths in `train_im_sets` and `test_im_sets` in the config.
  * To train on VOC 2007+2012 trainval, set `train_im_sets` to `['data/VOC2007', 'data/VOC2012']`.
* Modify the dataset file `dataset/voc.py`, specifically the `load_images_and_anns` method, to load your images and annotations.
* Update the class list of your dataset in the dataset file, and set `num_classes` in the config to the number of classes plus one for the background class.
* The dataset class should return the following:
```
im_tensor(C x H x W),
target{
    'bboxes': Number of GTs x 4 (x1y1x2y2 format, normalized to 0-1)
    'labels': Number of GTs,
    'difficult': Number of GTs,
}
file_path
```


## For modifications
* If your GPU does not support a large batch size, you can use a smaller `batch_size` such as 2 and set `acc_steps` in the config to 4 (to mimic training with a batch size of 8).
* To use a different backbone, you would have to change the following:
  * Change the backbone, the extra conv layers and the creation of feature maps in the initialization of the SSD model.
  * Ensure `out_channels` is set to the channel counts of all feature maps used for prediction [here](https://github.com/explainingai-code/SSD-PyTorch/blob/main/model/ssd.py#L316).
  * In the forward method, call the backbone and the extra conv layers, and ensure `outputs` is set to the list of feature maps [here](https://github.com/explainingai-code/SSD-PyTorch/blob/main/model/ssd.py#L472).

## Quickstart
* Create a new conda environment with Python 3.10, then run the commands below:
  * ```git clone https://github.com/explainingai-code/SSD-PyTorch.git```
  * ```cd SSD-PyTorch```
  * ```pip install -r requirements.txt```
* For training and inference, use the commands below, passing the desired configuration file as the config argument if you want to experiment with it.
  * ```python -m tools.train``` for training SSD on the VOC dataset
  * ```python -m tools.infer --evaluate False --infer_samples True``` for generating inference predictions
  * ```python -m tools.infer --evaluate True --infer_samples False``` for evaluating on the test dataset

## Configuration
* ```config/voc.yaml``` - Allows you to experiment with different components of SSD on the VOC dataset.


## Output
Outputs will be saved according to the configuration in the yaml files.

For every run, a folder named after the `task_name` key in the config will be created.

During training of SSD, the following output will be saved:
* Latest model checkpoint in the ```task_name``` directory

During inference, the following output will be saved:
* Sample prediction outputs for images in ```task_name/samples```

## Citations
```
@article{DBLP:journals/corr/LiuAESR15,
  author    = {Wei Liu and
               Dragomir Anguelov and
               Dumitru Erhan and
               Christian Szegedy and
               Scott E. Reed and
               Cheng{-}Yang Fu and
               Alexander C. Berg},
  title     = {{SSD:} Single Shot MultiBox Detector},
  journal   = {CoRR},
  volume    = {abs/1512.02325},
  year      = {2015},
  url       = {http://arxiv.org/abs/1512.02325},
  eprinttype = {arXiv},
  eprint    = {1512.02325},
  timestamp = {Wed, 12 Feb 2020 08:32:49 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/LiuAESR15.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

--------------------------------------------------------------------------------
/config/voc.yaml:
--------------------------------------------------------------------------------
dataset_params:
  train_im_sets: ['data/VOC2007']
  test_im_sets: ['data/VOC2007-test']
  num_classes : 21
  im_size : 300

model_params:
  im_channels : 3
  aspect_ratios : [
    [ 1., 2., 0.5 ],
    [ 1., 2., 3., 0.5, .333 ],
    [ 1., 2., 3., 0.5, .333 ],
    [ 1., 2., 3., 0.5, .333 ],
    [ 1., 2., 0.5 ],
    [ 1., 2., 0.5 ]
    ]
  scales : [0.1, 0.2, 0.375, 0.55, 0.725, 0.9]
  iou_threshold : 0.5
  low_score_threshold : 0.01
  neg_pos_ratio : 3
  pre_nms_topK : 400
  detections_per_img : 200
  nms_threshold : 0.45

train_params:
  task_name: 'voc'
  seed: 1111
  acc_steps: 1
  num_epochs: 100
  batch_size: 8
  lr_steps: [ 40, 50, 60, 70, 80, 90 ]
  lr: 0.001
  log_steps : 100
  infer_conf_threshold : 0.5
  ckpt_name: 'ssd_voc2007.pth'

--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/dataset/__init__.py

--------------------------------------------------------------------------------
/dataset/voc.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision.transforms.v2
from torch.utils.data.dataset import Dataset
import xml.etree.ElementTree as ET
from torchvision import tv_tensors
from torchvision.io import read_image


def load_images_and_anns(im_sets, label2idx, ann_fname, split):
    r"""
    Method to get the xml files and, for each file,
    get all the objects and their ground truth detection
    information for the dataset
    :param im_sets: Sets of images to consider
    :param label2idx: Class Name to index mapping for dataset
    :param ann_fname: name (without .txt) of the file listing image names {trainval/test}
    :param split: train/test
    :return: List of dicts, one per image, with its path, size and detections
    """
    im_infos = []
    ims = []

    for im_set in im_sets:
        im_names = []
        # Fetch all image names in txt file for this imageset
        with open(os.path.join(
                im_set, 'ImageSets', 'Main', '{}.txt'.format(ann_fname))) as f:
            for line in f:
                im_names.append(line.strip())

        # Set annotation and image path
        ann_dir = os.path.join(im_set, 'Annotations')
        im_dir = os.path.join(im_set, 'JPEGImages')
        for im_name in im_names:
            ann_file = os.path.join(ann_dir, '{}.xml'.format(im_name))
            im_info = {}
            ann_info = ET.parse(ann_file)
            root = ann_info.getroot()
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            im_info['img_id'] = os.path.basename(ann_file).split('.xml')[0]
            im_info['filename'] = os.path.join(
                im_dir, '{}.jpg'.format(im_info['img_id'])
            )
            im_info['width'] = width
            im_info['height'] = height
            detections = []

            for obj in ann_info.findall('object'):
                det = {}
                label = label2idx[obj.find('name').text]
                difficult = int(obj.find('difficult').text)
                bbox_info = obj.find('bndbox')
                # VOC pixel coordinates are 1-based, so shift them to 0-based
                bbox = [
                    int(bbox_info.find('xmin').text) - 1,
                    int(bbox_info.find('ymin').text) - 1,
                    int(bbox_info.find('xmax').text) - 1,
                    int(bbox_info.find('ymax').text) - 1
                ]
                det['label'] = label
                det['bbox'] = bbox
                det['difficult'] = difficult
                # At test time, the evaluation takes care of ignoring difficult objects
                detections.append(det)

            im_info['detections'] = detections
            im_infos.append(im_info)
    print('Total {} images found'.format(len(im_infos)))
    return im_infos


class VOCDataset(Dataset):
    def __init__(self, split, im_sets, im_size=300):
        self.split = split

        # Imagesets for this dataset instance (VOC2007/VOC2007+VOC2012/VOC2007-test)
        self.im_sets = im_sets
        self.fname = 'trainval' if self.split == 'train' else 'test'
        self.im_size = im_size
        # Approximate ImageNet mean in the 0-255 range, used as fill for RandomZoomOut
        self.im_mean = [123.0, 117.0, 104.0]
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

        # Train and test transformations
        self.transforms = {
            'train': torchvision.transforms.v2.Compose([
                torchvision.transforms.v2.RandomPhotometricDistort(),
                torchvision.transforms.v2.RandomZoomOut(fill=self.im_mean),
                torchvision.transforms.v2.RandomIoUCrop(),
                torchvision.transforms.v2.RandomHorizontalFlip(p=0.5),
                torchvision.transforms.v2.Resize(size=(self.im_size, self.im_size)),
                torchvision.transforms.v2.SanitizeBoundingBoxes(
                    labels_getter=lambda transform_input:
                    (transform_input[1]["labels"], transform_input[1]["difficult"])),
                torchvision.transforms.v2.ToPureTensor(),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.Normalize(mean=self.imagenet_mean,
                                                    std=self.imagenet_std)

            ]),
            'test': torchvision.transforms.v2.Compose([
                torchvision.transforms.v2.Resize(size=(self.im_size, self.im_size)),
                torchvision.transforms.v2.ToPureTensor(),
                torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
                torchvision.transforms.v2.Normalize(mean=self.imagenet_mean,
                                                    std=self.imagenet_std)
            ]),
        }

        classes = [
            'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
            'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
            'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor'
        ]
        classes = sorted(classes)
        # We need to add the background class as well, with index 0
        classes = ['background'] + classes

        self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
        self.idx2label = {idx: classes[idx] for idx in range(len(classes))}
        print(self.idx2label)
        self.images_info = load_images_and_anns(self.im_sets,
                                                self.label2idx,
                                                self.fname,
                                                self.split)

    def __len__(self):
        return len(self.images_info)

    def __getitem__(self, index):
        im_info = self.images_info[index]
        im = read_image(im_info['filename'])

        # Get annotations for this image
        targets = {}
        targets['bboxes'] = tv_tensors.BoundingBoxes(
            [detection['bbox'] for detection in im_info['detections']],
            format='XYXY', canvas_size=im.shape[-2:])
        targets['labels'] = torch.as_tensor(
            [detection['label'] for detection in im_info['detections']])
        targets['difficult'] = torch.as_tensor(
            [detection['difficult'] for detection in im_info['detections']])

        # Transform the image and targets
        transformed_info = self.transforms[self.split](im, targets)
        im_tensor, targets = transformed_info

        # Normalize box coordinates to 0-1 using the transformed image size
        h, w = im_tensor.shape[-2:]
        wh_tensor = torch.as_tensor([[w, h, w, h]]).expand_as(targets['bboxes'])
        targets['bboxes'] = targets['bboxes'] / wh_tensor
        return im_tensor, targets, im_info['filename']

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/model/__init__.py

--------------------------------------------------------------------------------
/model/ssd.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import math
import torchvision


def get_iou(boxes1, boxes2):
    r"""
    IOU between two sets of boxes
    :param boxes1: (Tensor of shape N x 4)
    :param boxes2: (Tensor of shape M x 4)
    :return: IOU matrix of shape N x M
    """

    # Area of boxes (x2-x1)*(y2-y1)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # (N,)
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # (M,)

    # Get top left x1,y1 coordinate
    x_left = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # (N, M)
    y_top = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # (N, M)

    # Get bottom right x2,y2 coordinate
    x_right = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # (N, M)
    y_bottom = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # (N, M)

    intersection_area = ((x_right - x_left).clamp(min=0) *
                         (y_bottom - y_top).clamp(min=0))  # (N, M)
    union = area1[:, None] + area2 - intersection_area  # (N, M)
    iou = intersection_area / union  # (N, M)
    return iou


def boxes_to_transformation_targets(ground_truth_boxes,
                                    default_boxes,
                                    weights=(10., 10., 5., 5.)):
    r"""
    Method to compute targets for each default_boxes.
    Assumes boxes are in x1y1x2y2 format.
    We first convert boxes to cx,cy,w,h format and then
    compute targets based on the following formulation
        target_dx = wx * (gt_cx - default_boxes_cx) / default_boxes_w
        target_dy = wy * (gt_cy - default_boxes_cy) / default_boxes_h
        target_dw = ww * log(gt_w / default_boxes_w)
        target_dh = wh * log(gt_h / default_boxes_h)
    :param ground_truth_boxes: (Tensor of shape N x 4)
    :param default_boxes: (Tensor of shape N x 4)
    :param weights: Tuple[float] -> (wx, wy, ww, wh)
    :return: regression_targets: (Tensor of shape N x 4)
    """
    # Get center_x,center_y,w,h from x1,y1,x2,y2 for default_boxes
    widths = default_boxes[:, 2] - default_boxes[:, 0]
    heights = default_boxes[:, 3] - default_boxes[:, 1]
    center_x = default_boxes[:, 0] + 0.5 * widths
    center_y = default_boxes[:, 1] + 0.5 * heights

    # Get center_x,center_y,w,h from x1,y1,x2,y2 for gt boxes
    gt_widths = ground_truth_boxes[:, 2] - ground_truth_boxes[:, 0]
    gt_heights = ground_truth_boxes[:, 3] - ground_truth_boxes[:, 1]
    gt_center_x = ground_truth_boxes[:, 0] + 0.5 * gt_widths
    gt_center_y = ground_truth_boxes[:, 1] + 0.5 * gt_heights

    # Use formulation to compute all targets
    targets_dx = weights[0] * (gt_center_x - center_x) / widths
    targets_dy = weights[1] * (gt_center_y - center_y) / heights
    targets_dw = weights[2] * torch.log(gt_widths / widths)
    targets_dh = weights[3] * torch.log(gt_heights / heights)
    regression_targets = torch.stack((targets_dx,
                                      targets_dy,
                                      targets_dw,
                                      targets_dh), dim=1)
    return regression_targets


def apply_regression_pred_to_default_boxes(box_transform_pred,
                                           default_boxes,
                                           weights=(10., 10., 5., 5.)):
    r"""
    Method to transform default_boxes based on transformation parameter
    prediction.
    Assumes boxes are in x1y1x2y2 format
    :param box_transform_pred: (Tensor of shape N x 4)
    :param default_boxes: (Tensor of shape N x 4)
    :param weights: Tuple[float] -> (wx, wy, ww, wh)
    :return: pred_boxes: (Tensor of shape N x 4)
    """

    # Get cx, cy, w, h from x1,y1,x2,y2
    w = default_boxes[:, 2] - default_boxes[:, 0]
    h = default_boxes[:, 3] - default_boxes[:, 1]
    center_x = default_boxes[:, 0] + 0.5 * w
    center_y = default_boxes[:, 1] + 0.5 * h

    dx = box_transform_pred[..., 0] / weights[0]
    dy = box_transform_pred[..., 1] / weights[1]
    dw = box_transform_pred[..., 2] / weights[2]
    dh = box_transform_pred[..., 3] / weights[3]
    # dh -> (num_default_boxes)

    pred_center_x = dx * w + center_x
    pred_center_y = dy * h + center_y
    pred_w = torch.exp(dw) * w
    pred_h = torch.exp(dh) * h
    # pred_center_x -> (num_default_boxes)

    pred_box_x1 = pred_center_x - 0.5 * pred_w
    pred_box_y1 = pred_center_y - 0.5 * pred_h
    pred_box_x2 = pred_center_x + 0.5 * pred_w
    pred_box_y2 = pred_center_y + 0.5 * pred_h

    pred_boxes = torch.stack((
        pred_box_x1,
        pred_box_y1,
        pred_box_x2,
        pred_box_y2),
        dim=-1)
    return pred_boxes


def generate_default_boxes(feat, aspect_ratios, scales):
    r"""
    Method to generate default_boxes for all feature maps of the image
    :param feat: List[(Tensor of shape B x C x Feat_H x Feat_W)]
    :param aspect_ratios: List[List[float]] aspect ratios for each feature map
    :param scales: List[float] scales for each feature map
    :return: default_boxes : List[(Tensor of shape N x 4)] default_boxes over all
            feature maps aggregated for each batch image
    """

    # List to store default boxes for all feature maps
    default_boxes = []
    for k in range(len(feat)):
        # We first add the extra box with aspect ratio 1 and
        # scale sqrt(scales[k] * scales[k+1])
        s_prime_k = math.sqrt(scales[k] * scales[k + 1])
        wh_pairs = [[s_prime_k, s_prime_k]]

        # Adding all possible w,h pairs according to
        # aspect ratio of the feature map k
        for ar in aspect_ratios[k]:
            sq_ar = math.sqrt(ar)
            w = scales[k] * sq_ar
            h = scales[k] / sq_ar

            wh_pairs.extend([[w, h]])

        feat_h, feat_w = feat[k].shape[-2:]

        # These shifts will be the centre of each of the default boxes
        shifts_x = ((torch.arange(0, feat_w) + 0.5) / feat_w).to(torch.float32)
        shifts_y = ((torch.arange(0, feat_h) + 0.5) / feat_h).to(torch.float32)
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
        shift_x = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)

        # Duplicate these shifts once for every
        # default box (w,h pair) per position
        shifts = torch.stack((shift_x, shift_y) * len(wh_pairs), dim=-1).reshape(-1, 2)
        # shifts for first feature map will be (5776 x 2)

        wh_pairs = torch.as_tensor(wh_pairs)

        # Repeat the wh pairs for all positions in feature map
        wh_pairs = wh_pairs.repeat((feat_h * feat_w), 1)
        # wh_pairs for first feature map will be (5776 x 2)

        # Concat the shifts(cx cy) and wh values for all positions
        default_box = torch.cat((shifts, wh_pairs), dim=1)
        # default box for feat_1 -> (5776, 4)
        # default box for feat_2 -> (2166, 4)
        # default box for feat_3 -> (600, 4)
        # default box for feat_4 -> (150, 4)
        # default box for feat_5 -> (36, 4)
        # default box for feat_6 -> (4, 4)

        default_boxes.append(default_box)
    default_boxes = torch.cat(default_boxes, dim=0)
    # default_boxes -> (8732, 4)

    # We now duplicate these default boxes
    # for all images in the batch
    # and also convert cx,cy,w,h format of
    # default boxes to x1,y1,x2,y2
    dboxes = []
    for _ in range(feat[0].size(0)):
        dboxes_in_image = default_boxes
        # x1 = cx - 0.5 * width
        # y1 = cy - 0.5 * height
        # x2 = cx + 0.5 * width
        # y2 = cy + 0.5 * height
        dboxes_in_image = torch.cat(
            [
                (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]),
                (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]),
            ],
            -1,
        )
        dboxes.append(dboxes_in_image.to(feat[0].device))
    return dboxes


class SSD(nn.Module):
    r"""
    Main Class for SSD. Does the following steps
    to generate detections/losses.
    During initialization
        1. Load VGG Imagenet pretrained model
        2. Extract Backbone from VGG and add extra conv layers
        3. Add class prediction and bbox transformation prediction layers
        4. Initialize all conv2d layers

    During Forward Pass
        1. Get conv4_3 output
        2. Normalize and scale conv4_3 output (feat_output_1)
        3. Pass the unscaled conv4_3 to conv5_3 layers and conv layers
           replacing fc6 and fc7 of vgg (feat_output_2)
        4. Pass the conv_fc7 output to extra conv layers (feat_output_3-6)
        5. Get the classification and regression predictions for all 6 feature maps
        6. Generate default_boxes for all these feature maps (8732 x 4)
        7a. If in training mode, assign targets for these default_boxes and
            compute localization and classification losses
        7b. If in inference mode, do all pre-NMS filtering, NMS
            and then post-NMS filtering, and return the detected boxes,
            their labels and their scores
    """
    def __init__(self, config, num_classes=21):
        super().__init__()
        self.aspect_ratios = config['aspect_ratios']

        # Copy the list so appending the final scale does not mutate the config
        self.scales = list(config['scales'])
        self.scales.append(1.0)

        self.num_classes = num_classes
        self.iou_threshold = config['iou_threshold']
        self.low_score_threshold = config['low_score_threshold']
        self.neg_pos_ratio = config['neg_pos_ratio']
        self.pre_nms_topK = config['pre_nms_topK']
        self.nms_threshold = config['nms_threshold']
        self.detections_per_img = config['detections_per_img']

        # Load imagenet pretrained vgg network
        backbone = torchvision.models.vgg16(
            weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
        )

        # Get all max pool indexes to determine different stages
        max_pool_pos = [idx for idx, layer in enumerate(list(backbone.features))
                        if isinstance(layer, nn.MaxPool2d)]
        max_pool_stage_3_pos = max_pool_pos[-3]  # for vgg16 this would be 16
        max_pool_stage_4_pos = max_pool_pos[-2]  # for vgg16 this would be 23

        # Use ceil mode, otherwise the vgg conv4_3 output will be 37x37 instead of 38x38
        backbone.features[max_pool_stage_3_pos].ceil_mode = True
        self.features = nn.Sequential(*backbone.features[:max_pool_stage_4_pos])
        self.scale_weight = nn.Parameter(torch.ones(512) * 20)

        ###################################
        # Conv5_3 + Conv for fc6 and fc7  #
        ###################################
        # Conv modules replacing fc6 and fc7
        # Ideally we would copy the weights
        # but here we are just adding new layers
        # and not copying fc6 and fc7 weights by
        # subsampling
        fcs = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3,
                      padding=6, dilation=6),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.conv5_3_fc = nn.Sequential(
            *backbone.features[max_pool_stage_4_pos:-1],
            fcs,
        )

        ##########################
        # Additional Conv Layers #
        ##########################
        # Modules to take from 19x19 to 10x10
        self.conv8_2 = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, padding=1,
                      stride=2),
            nn.ReLU(inplace=True)
        )

        # Modules to take from 10x10 to 5x5
        self.conv9_2 = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1,
                      stride=2),
            nn.ReLU(inplace=True)
        )

        # Modules to take from 5x5 to 3x3
        self.conv10_2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3),
            nn.ReLU(inplace=True)
        )

        # Modules to take from 3x3 to 1x1
        self.conv11_2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3),
            nn.ReLU(inplace=True)
        )

        # Must match conv4_3, fcs, conv8_2, conv9_2, conv10_2, conv11_2
        out_channels = [512, 1024, 512, 256, 256, 256]

        #####################
        # Prediction Layers #
        #####################
        self.cls_heads = nn.ModuleList()
        for channels, aspect_ratio in zip(out_channels, self.aspect_ratios):
            # Extra 1 is added for the box with scale sqrt(s_k * s_(k+1))
            self.cls_heads.append(nn.Conv2d(channels,
                                            self.num_classes * (len(aspect_ratio) + 1),
                                            kernel_size=3,
                                            padding=1))

        self.bbox_reg_heads = nn.ModuleList()
        for channels, aspect_ratio in zip(out_channels, self.aspect_ratios):
            # Extra 1 is added for the box with scale sqrt(s_k * s_(k+1))
            self.bbox_reg_heads.append(nn.Conv2d(channels, 4 * (len(aspect_ratio) + 1),
                                                 kernel_size=3,
                                                 padding=1))

        #############################
        # Conv Layer Initialization #
        #############################
        for layer in fcs.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0.0)

        for conv_module in [self.conv8_2, self.conv9_2, self.conv10_2, self.conv11_2]:
            for layer in conv_module.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        torch.nn.init.constant_(layer.bias, 0.0)

        for module in self.cls_heads:
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0.0)
        for module in self.bbox_reg_heads:
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0.0)

    def compute_loss(
            self,
            targets,
            cls_logits,
            bbox_regression,
            default_boxes,
            matched_idxs,
    ):
        # Counting all the foreground default_boxes for computing N in loss equation
        num_foreground = 0
        # BBox losses for all batch images (for foreground default_boxes)
        bbox_loss = []
        # Classification targets for all batch images (for ALL default_boxes)
        cls_targets = []
        for (
                targets_per_image,
                bbox_regression_per_image,
                cls_logits_per_image,
                default_boxes_per_image,
                matched_idxs_per_image,
        ) in zip(targets, bbox_regression, cls_logits, default_boxes, matched_idxs):
            # Foreground default_boxes -> matched_idx >= 0
            # Background default_boxes -> matched_idx = -1
            fg_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            foreground_matched_idxs_per_image = matched_idxs_per_image[
                fg_idxs_per_image
            ]
            num_foreground += foreground_matched_idxs_per_image.numel()

            # Get foreground default_boxes and their transformation predictions
            matched_gt_boxes_per_image = targets_per_image["boxes"][
                foreground_matched_idxs_per_image
            ]
            bbox_regression_per_image = bbox_regression_per_image[fg_idxs_per_image, :]
            default_boxes_per_image = default_boxes_per_image[fg_idxs_per_image, :]
            target_regression = boxes_to_transformation_targets(
                matched_gt_boxes_per_image,
                default_boxes_per_image)

            bbox_loss.append(
                torch.nn.functional.smooth_l1_loss(bbox_regression_per_image,
                                                   target_regression,
                                                   reduction='sum')
            )

            # Get classification target for ALL default_boxes
            # For all default_boxes set it as 0 first
            # Then set foreground default_boxes target as label
            # of assigned gt box
            gt_classes_target = torch.zeros(
                (cls_logits_per_image.size(0),),
                dtype=targets_per_image["labels"].dtype,
                device=targets_per_image["labels"].device,
            )
            gt_classes_target[fg_idxs_per_image] = targets_per_image["labels"][
                foreground_matched_idxs_per_image
            ]
            cls_targets.append(gt_classes_target)

        # Aggregated bbox loss and classification targets
        # for all batch images
        bbox_loss = torch.stack(bbox_loss)
        cls_targets = torch.stack(cls_targets)  # (B, 8732)

        # Calculate classification loss for ALL default_boxes
        num_classes = cls_logits.size(-1)
        cls_loss = torch.nn.functional.cross_entropy(cls_logits.view(-1, num_classes),
                                                     cls_targets.view(-1),
                                                     reduction="none").view(
            cls_targets.size()
        )

        # Hard Negative Mining
        foreground_idxs = cls_targets > 0
        # We will sample a total of neg_pos_ratio (3) x (number of fg default_boxes)
        # background default_boxes
        num_negative = self.neg_pos_ratio * foreground_idxs.sum(1, keepdim=True)

        # As of now cls_loss is for ALL default_boxes
        negative_loss = cls_loss.clone()
        # We want to ensure that after sorting based on loss value,
        # foreground default_boxes are never picked when choosing topK
        # highest loss indexes
        negative_loss[foreground_idxs] = -float("inf")
        values, idx = negative_loss.sort(1, descending=True)
        # Fetch those indexes which are in the topK (K=num_negative) losses
        background_idxs = idx.sort(1)[1] < num_negative
        N = max(1, num_foreground)
        return {
            "bbox_regression": bbox_loss.sum() / N,
            "classification": (cls_loss[foreground_idxs].sum() +
                               cls_loss[background_idxs].sum()) / N,
        }

    def forward(self, x, targets=None):
        # Call everything till conv4_3 layers first
        conv_4_3_out = self.features(x)

        # Scale conv4_3 output using learnt norm scale
        conv_4_3_out_scaled = (self.scale_weight.view(1, -1, 1, 1) *
                               torch.nn.functional.normalize(conv_4_3_out))

        # Call conv5_3 with the non-scaled conv4_3 output and also
        # call the additional conv layers
        conv_5_3_fc_out = self.conv5_3_fc(conv_4_3_out)
        conv8_2_out = self.conv8_2(conv_5_3_fc_out)
        conv9_2_out = self.conv9_2(conv8_2_out)
        conv10_2_out = self.conv10_2(conv9_2_out)
        conv11_2_out = self.conv11_2(conv10_2_out)

        # Feature maps for predictions
        outputs = [
            conv_4_3_out_scaled,  # 38 x 38
            conv_5_3_fc_out,  # 19 x 19
            conv8_2_out,  # 10 x 10
            conv9_2_out,  # 5 x 5
            conv10_2_out,  # 3 x 3
            conv11_2_out,  # 1 x 1
        ]

        # Classification and bbox regression for all feature maps
        cls_logits = []
        bbox_reg_deltas = []
        for i, features in enumerate(outputs):
            cls_feat_i = self.cls_heads[i](features)
            bbox_reg_feat_i = self.bbox_reg_heads[i](features)

            # Cls output from (B, A * num_classes, H, W) to (B, HWA, num_classes).
            N, _, H, W = cls_feat_i.shape
            cls_feat_i = cls_feat_i.view(N, -1, self.num_classes, H, W)
            # (B, A, num_classes, H, W)
            cls_feat_i = cls_feat_i.permute(0, 3, 4, 1, 2)  # (B, H, W, A, num_classes)
            cls_feat_i = cls_feat_i.reshape(N, -1, self.num_classes)
            # (B, HWA, num_classes)
            cls_logits.append(cls_feat_i)

            # Permute bbox reg output from (B, A * 4, H, W) to (B, HWA, 4).
            N, _, H, W = bbox_reg_feat_i.shape
            bbox_reg_feat_i = bbox_reg_feat_i.view(N, -1, 4, H, W)  # (B, A, 4, H, W)
            bbox_reg_feat_i = bbox_reg_feat_i.permute(0, 3, 4, 1, 2)  # (B, H, W, A, 4)
            bbox_reg_feat_i = bbox_reg_feat_i.reshape(N, -1, 4)  # Size=(B, HWA, 4)
            bbox_reg_deltas.append(bbox_reg_feat_i)

        # Concat cls logits and bbox regression predictions for all feature maps
        cls_logits = torch.cat(cls_logits, dim=1)  # (B, 8732, num_classes)
        bbox_reg_deltas = torch.cat(bbox_reg_deltas, dim=1)  # (B, 8732, 4)

        # Generate default_boxes for all feature maps
        default_boxes = generate_default_boxes(outputs, self.aspect_ratios, self.scales)
        # default_boxes -> List[Tensor of shape 8732 x 4]
        # len(default_boxes) = Batch size

        losses = {}
        detections = []
        if self.training:
            # List holding, for each image, the index of the gt box assigned
            # to every default box, or -1 if unassigned (background)
            matched_idxs = []
            for default_boxes_per_image, targets_per_image in zip(default_boxes,
                                                                  targets):
                if targets_per_image["boxes"].numel() == 0:
                    matched_idxs.append(
                        torch.full(
                            (default_boxes_per_image.size(0),), -1,
                            dtype=torch.int64,
                            device=default_boxes_per_image.device
                        )
                    )
                    continue
                iou_matrix = get_iou(targets_per_image["boxes"],
                                     default_boxes_per_image)
                # For each default box find best ground truth box
                matched_vals, matches = iou_matrix.max(dim=0)
                # matches -> [8732]

                # Set the match index to -1 for all default_boxes whose
                # maximum iou with any gt box is below iou_threshold.
                # This allows selecting foreground boxes as match index >= 0
                below_low_threshold = matched_vals < self.iou_threshold
                matches[below_low_threshold] = -1

                # We want to also assign the best default box for every gt
                # as foreground
                # So first find the best default box for every gt
                _, highest_quality_pred_foreach_gt = iou_matrix.max(dim=1)
                # Update the best matching gt index for these best default_boxes
                # as 0, 1, 2, ...., len(gt)-1
                matches[highest_quality_pred_foreach_gt] = torch.arange(
                    highest_quality_pred_foreach_gt.size(0), dtype=torch.int64,
                    device=highest_quality_pred_foreach_gt.device
                )
                matched_idxs.append(matches)
            losses = self.compute_loss(targets, cls_logits, bbox_reg_deltas,
                                       default_boxes, matched_idxs)
        else:
            # For test time we do the following:
            # 1. Convert default_boxes to boxes using predicted bbox regression deltas
            # 2. Low score filtering
            # 3. Pre-NMS TopK filtering
            # 4. NMS
            # 5. Post NMS TopK Filtering
            cls_scores = torch.nn.functional.softmax(cls_logits, dim=-1)
            num_classes = cls_scores.size(-1)

            for bbox_deltas_i, cls_scores_i, default_boxes_i in zip(bbox_reg_deltas,
                                                                    cls_scores,
                                                                    default_boxes):
                boxes = apply_regression_pred_to_default_boxes(bbox_deltas_i,
                                                               default_boxes_i)
                # Ensure all values are between 0-1
                boxes.clamp_(min=0., max=1.)

                pred_boxes = []
                pred_scores = []
                pred_labels = []
                # Class wise filtering
                for label in range(1, num_classes):
                    score = cls_scores_i[:, label]

                    # Remove low scoring boxes of this class
                    keep_idxs = score > self.low_score_threshold
                    score = score[keep_idxs]
                    box = boxes[keep_idxs]

                    # keep only topk scoring predictions of this class
                    score, top_k_idxs = score.topk(min(self.pre_nms_topK, len(score)))
                    box = box[top_k_idxs]

                    pred_boxes.append(box)
                    pred_scores.append(score)
                    pred_labels.append(torch.full_like(score, fill_value=label,
                                                       dtype=torch.int64,
                                                       device=cls_scores.device))

                pred_boxes = torch.cat(pred_boxes, dim=0)
                pred_scores = torch.cat(pred_scores, dim=0)
                pred_labels = torch.cat(pred_labels, dim=0)

                # Class wise NMS
                keep_mask = torch.zeros_like(pred_scores, dtype=torch.bool)
                for class_id in torch.unique(pred_labels):
                    curr_indices = torch.where(pred_labels == class_id)[0]
                    curr_keep_idxs = torch.ops.torchvision.nms(pred_boxes[curr_indices],
                                                               pred_scores[curr_indices],
                                                               self.nms_threshold)
                    keep_mask[curr_indices[curr_keep_idxs]] = True
                keep_indices = torch.where(keep_mask)[0]
                post_nms_keep_indices = keep_indices[pred_scores[keep_indices].sort(
                    descending=True)[1]]
                keep = post_nms_keep_indices[:self.detections_per_img]
                pred_boxes, pred_scores, pred_labels = (pred_boxes[keep],
                                                        pred_scores[keep],
                                                        pred_labels[keep])

                detections.append(
                    {
                        "boxes": pred_boxes,
                        "scores": pred_scores,
                        "labels": pred_labels,
                    }
                )
        return losses, detections
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
einops==0.8.0
numpy==2.0.1
opencv_python==4.10.0.84
Pillow==10.4.0
PyYAML==6.0.1
torch==2.3.1
torchvision==0.18.1
tqdm==4.66.4
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explainingai-code/SSD-PyTorch/41b309063138a9d32a0031cfda513f197631d50a/tools/__init__.py
--------------------------------------------------------------------------------
/tools/infer.py:
--------------------------------------------------------------------------------
import torch
import argparse
import os
import yaml
import random
from tqdm import tqdm
from model.ssd import SSD
import numpy as np
import cv2
from dataset.voc import VOCDataset
from torch.utils.data.dataloader import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using mps')


def get_iou(det, gt):
    det_x1, det_y1, det_x2, det_y2 = det
    gt_x1, gt_y1, gt_x2, gt_y2 = gt

    x_left = max(det_x1, gt_x1)
    y_top = max(det_y1, gt_y1)
    x_right = min(det_x2, gt_x2)
    y_bottom = min(det_y2, gt_y2)

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    area_intersection = (x_right - x_left) * (y_bottom - y_top)
    det_area = (det_x2 - det_x1) * (det_y2 - det_y1)
    gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
    area_union = float(det_area + gt_area - area_intersection + 1E-6)
    iou = area_intersection / area_union
    return iou


def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method='area', difficult=None):
    # det_boxes = [
    #   {
    #       'person' : [[x1, y1, x2, y2, score], ...],
    #       'car' : [[x1, y1, x2, y2, score], ...]
    #   }
    #   {det_boxes_img_2},
    #   ...
    #   {det_boxes_img_N},
    # ]
    #
    # gt_boxes = [
    #   {
    #       'person' : [[x1, y1, x2, y2], ...],
    #       'car' : [[x1, y1, x2, y2], ...]
    #   },
    #   {gt_boxes_img_2},
    #   ...
    #   {gt_boxes_img_N},
    # ]

    gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
    gt_labels = sorted(gt_labels)

    all_aps = {}
    # average precisions for ALL classes
    aps = []
    for idx, label in enumerate(gt_labels):
        # Get detection predictions of this class
        cls_dets = [
            [im_idx, im_dets_label] for im_idx, im_dets in enumerate(det_boxes)
            if label in im_dets for im_dets_label in im_dets[label]
        ]

        # cls_dets = [
        #   (0, [x1_0, y1_0, x2_0, y2_0, score_0]),
        #   ...
        #   (0, [x1_M, y1_M, x2_M, y2_M, score_M]),
        #   (1, [x1_0, y1_0, x2_0, y2_0, score_0]),
        #   ...
        #   (1, [x1_N, y1_N, x2_N, y2_N, score_N]),
        #   ...
        # ]

        # Sort them by confidence score
        cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

        # For tracking which gt boxes of this class have already been matched
        gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
        # Number of gt boxes for this class for recall calculation
        num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])
        num_difficults = sum([sum(difficults_label[label]) for difficults_label in difficult])

        tp = [0] * len(cls_dets)
        fp = [0] * len(cls_dets)

        # For each prediction
        for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
            # Get gt boxes for this image and this label
            im_gts = gt_boxes[im_idx][label]
            im_gt_difficults = difficult[im_idx][label]

            max_iou_found = -1
            max_iou_gt_idx = -1

            # Get best matching gt box
            for gt_box_idx, gt_box in enumerate(im_gts):
                gt_box_iou = get_iou(det_pred[:-1], gt_box)
                if gt_box_iou > max_iou_found:
                    max_iou_found = gt_box_iou
                    max_iou_gt_idx = gt_box_idx
            # TP only if iou >= threshold and this gt has not yet been matched.
            # Detections matched to a "difficult" gt box are ignored
            # (counted as neither TP nor FP)
            if max_iou_found >= iou_threshold:
                if not im_gt_difficults[max_iou_gt_idx]:
                    if not gt_matched[im_idx][max_iou_gt_idx]:
                        # If tp then we set this gt box as matched
                        gt_matched[im_idx][max_iou_gt_idx] = True
                        tp[det_idx] = 1
                    else:
                        fp[det_idx] = 1
            else:
                fp[det_idx] = 1

        # Cumulative tp and fp
        tp = np.cumsum(tp)
        fp = np.cumsum(fp)

        eps = np.finfo(np.float32).eps
        # Difficult gt boxes are excluded from the recall denominator
        recalls = tp / np.maximum(num_gts - num_difficults, eps)
        precisions = tp / np.maximum((tp + fp), eps)

        if method == 'area':
            recalls = np.concatenate(([0.0], recalls, [1.0]))
            precisions = np.concatenate(([0.0], precisions, [0.0]))

            # Replace the precision at each recall r with the maximum
            # precision at any recall >= r.
            # This computes the precision envelope
            for i in range(precisions.size - 1, 0, -1):
                precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])
            # For computing area, get points where recall changes value
            i = np.where(recalls[1:] != recalls[:-1])[0]
            # Add the rectangular areas to get ap
            ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
        elif method == 'interp':
            ap = 0.0
            for interp_pt in np.arange(0, 1 + 1E-3, 0.1):
                # Get precision values for recall values >= interp_pt
                prec_interp_pt = precisions[recalls >= interp_pt]

                # Get max of those precision values
                prec_interp_pt = prec_interp_pt.max() if prec_interp_pt.size > 0.0 else 0.0
                ap += prec_interp_pt
            ap = ap / 11.0
        else:
            raise ValueError('Method can only be area or interp')
        if num_gts > 0:
            aps.append(ap)
            all_aps[label] = ap
        else:
            all_aps[label] = np.nan
    # compute mAP at provided iou threshold
    mean_ap = sum(aps) / len(aps)
    return mean_ap, all_aps


def load_model_and_dataset(args):
    # Read the config file #
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    ########################

    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']

    voc = VOCDataset('test',
                     im_sets=dataset_config['test_im_sets'])
    test_dataset = DataLoader(voc, batch_size=1, shuffle=False)

    model = SSD(config=model_config,
                num_classes=dataset_config['num_classes'])
    model.to(device=torch.device(device))
    model.eval()

    assert os.path.exists(os.path.join(train_config['task_name'],
                                       train_config['ckpt_name'])), \
        "No checkpoint exists at {}".format(os.path.join(train_config['task_name'],
                                                         train_config['ckpt_name']))
    model.load_state_dict(torch.load(os.path.join(train_config['task_name'],
                                                  train_config['ckpt_name']),
                                     map_location=device))
    return model, voc, test_dataset, config


def infer(args):
    if not os.path.exists('samples'):
        os.mkdir('samples')

    model, voc, test_dataset, config = load_model_and_dataset(args)
    conf_threshold = config['train_params']['infer_conf_threshold']
    model.low_score_threshold = conf_threshold

    num_samples = 5
    for i in tqdm(range(num_samples)):
        # random.randint includes both endpoints, so the last valid index is len(voc) - 1
        dataset_idx = random.randint(0, len(voc) - 1)
        im_tensor, target, fname = voc[dataset_idx]
        _, ssd_detections = model(im_tensor.unsqueeze(0).to(device), [target])

        gt_im = cv2.imread(fname)
        h, w = gt_im.shape[:2]
        gt_im_copy = gt_im.copy()
        # Saving images with ground truth boxes
        for idx, box in enumerate(target['bboxes']):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(w*x1), int(h*y1), int(w*x2), int(h*y2)
            cv2.rectangle(gt_im, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
            cv2.rectangle(gt_im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
            text = voc.idx2label[target['labels'][idx].detach().cpu().item()]
            text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
            text_w, text_h = text_size
            cv2.rectangle(gt_im_copy, (x1, y1), (x1 + 10 + text_w, y1 + 10 + text_h), [255, 255, 255], -1)
            cv2.putText(gt_im, text=voc.idx2label[target['labels'][idx].detach().cpu().item()],
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
            cv2.putText(gt_im_copy, text=text,
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
        cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
        cv2.imwrite('samples/output_ssd_gt_{}.png'.format(i), gt_im)

        # Getting predictions from trained model
        boxes = ssd_detections[0]['boxes']
        labels = ssd_detections[0]['labels']
        scores = ssd_detections[0]['scores']
        im = cv2.imread(fname)
        im_copy = im.copy()

        # Saving images with predicted boxes
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(w * x1), int(h * y1), int(w * x2), int(h * y2)
            cv2.rectangle(im, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
            cv2.rectangle(im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
            text = '{} : {:.2f}'.format(voc.idx2label[labels[idx].detach().cpu().item()],
                                        scores[idx].detach().cpu().item())
            text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
            text_w, text_h = text_size
            cv2.rectangle(im_copy, (x1, y1), (x1 + 10 + text_w, y1 + 10 + text_h), [255, 255, 255], -1)
            cv2.putText(im, text=text,
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
            cv2.putText(im_copy, text=text,
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
        cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
        cv2.imwrite('samples/output_ssd_{}.jpg'.format(i), im)

    print('Done Detecting...')


def evaluate_map(args):
    model, voc, test_dataset, config = load_model_and_dataset(args)

    gts = []
    preds = []
    difficults = []
    for im_tensor, target, fname in tqdm(test_dataset):
        im_tensor = im_tensor.float().to(device)
        target_bboxes = target['bboxes'].float()[0].to(device)
        target_labels = target['labels'].long()[0].to(device)
        difficult = target['difficult'].long()[0].to(device)
        _, ssd_detections = model(im_tensor)

        boxes = ssd_detections[0]['boxes']
        labels = ssd_detections[0]['labels']
        scores = ssd_detections[0]['scores']

        pred_boxes = {}
        gt_boxes = {}
        difficult_boxes = {}

        for label_name in voc.label2idx:
            pred_boxes[label_name] = []
            gt_boxes[label_name] = []
            difficult_boxes[label_name] = []

        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = labels[idx].detach().cpu().item()
            score = scores[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            pred_boxes[label_name].append([x1, y1, x2, y2, score])
        for idx, box in enumerate(target_bboxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = target_labels[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            gt_boxes[label_name].append([x1, y1, x2, y2])
            difficult_boxes[label_name].append(difficult[idx].detach().cpu().item())

        gts.append(gt_boxes)
        preds.append(pred_boxes)
        difficults.append(difficult_boxes)
    mean_ap, all_aps = compute_map(preds, gts, method='area', difficult=difficults)
    print('Class Wise Average Precisions')
    for idx in range(len(voc.idx2label)):
        print('AP for class {} = {:.4f}'.format(voc.idx2label[idx],
                                                all_aps[voc.idx2label[idx]]))
    print('Mean Average Precision : {:.4f}'.format(mean_ap))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ssd inference')
    parser.add_argument('--config', dest='config_path',
                        default='config/voc.yaml', type=str)
    # type=bool would turn any non-empty string (even 'False') into True,
    # so parse the flag values explicitly
    parser.add_argument('--evaluate', dest='evaluate',
                        default=False, type=lambda s: s.lower() in ('true', '1', 'yes'))
    parser.add_argument('--infer_samples', dest='infer_samples',
                        default=True, type=lambda s: s.lower() in ('true', '1', 'yes'))
    args = parser.parse_args()

    with torch.no_grad():
        if args.infer_samples:
            infer(args)
        else:
            print('Not Inferring for samples as `infer_samples` argument is False')

        if args.evaluate:
            evaluate_map(args)
        else:
            print('Not Evaluating as `evaluate` argument is False')
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
import torch
import argparse
import os
import numpy as np
import yaml
import random
from tqdm import tqdm
from model.ssd import SSD
import torchvision
from dataset.voc import VOCDataset
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('Using mps')


def collate_function(data):
    return tuple(zip(*data))


def train(args):
    # Read the config file #
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    #########################

    dataset_config = config['dataset_params']
    train_config = config['train_params']

    seed = train_config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # device is a torch.device, so compare its type rather than the object itself
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)

    voc = VOCDataset('train',
                     im_sets=dataset_config['train_im_sets'],
                     im_size=dataset_config['im_size'])
    train_dataset = DataLoader(voc,
                               batch_size=train_config['batch_size'],
                               shuffle=True,
                               collate_fn=collate_function)

    # Instantiate model and load checkpoint if present
    model = SSD(config=config['model_params'],
                num_classes=dataset_config['num_classes'])
    model.to(device)
    model.train()
    if os.path.exists(os.path.join(train_config['task_name'],
                                   train_config['ckpt_name'])):
        print('Loading checkpoint as one exists')
        model.load_state_dict(torch.load(
            os.path.join(train_config['task_name'],
                         train_config['ckpt_name']),
            map_location=device))

    if not os.path.exists(train_config['task_name']):
        os.mkdir(train_config['task_name'])

    optimizer = torch.optim.SGD(lr=train_config['lr'],
                                params=model.parameters(),
                                weight_decay=5E-4, momentum=0.9)
    lr_scheduler = MultiStepLR(optimizer, milestones=train_config['lr_steps'], gamma=0.5)
    acc_steps = train_config['acc_steps']
    num_epochs = train_config['num_epochs']
    steps = 0
    for i in range(num_epochs):
        ssd_classification_losses = []
        ssd_localization_losses = []
        for idx, (ims, targets, _) in enumerate(tqdm(train_dataset)):
            for target in targets:
                target['boxes'] = target['bboxes'].float().to(device)
                del target['bboxes']
                target['labels'] = target['labels'].long().to(device)
            images = torch.stack([im.float().to(device) for im in ims], dim=0)
            batch_losses, _ = model(images, targets)
            # Sum out of place: `loss += ...` would modify the classification
            # loss tensor itself and corrupt the logged classification loss
            loss = batch_losses['classification'] + batch_losses['bbox_regression']

            ssd_classification_losses.append(batch_losses['classification'].item())
            ssd_localization_losses.append(batch_losses['bbox_regression'].item())
            loss = loss / acc_steps
            loss.backward()

            if (idx + 1) % acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            if steps % train_config['log_steps'] == 0:
                loss_output = ''
                loss_output += 'SSD Classification Loss : {:.4f}'.format(np.mean(ssd_classification_losses))
                loss_output += ' | SSD Localization Loss : {:.4f}'.format(np.mean(ssd_localization_losses))
                print(loss_output)
            if torch.isnan(loss):
                print('Loss is becoming nan. Exiting')
                exit(0)
            steps += 1
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        print('Finished epoch {}'.format(i+1))
        loss_output = ''
        loss_output += 'SSD Classification Loss : {:.4f}'.format(np.mean(ssd_classification_losses))
        loss_output += ' | SSD Localization Loss : {:.4f}'.format(np.mean(ssd_localization_losses))
        print(loss_output)
        torch.save(model.state_dict(), os.path.join(train_config['task_name'],
                                                    train_config['ckpt_name']))
    print('Done Training...')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ssd training')
    parser.add_argument('--config', dest='config_path',
                        default='config/voc.yaml', type=str)
    args = parser.parse_args()
    train(args)
--------------------------------------------------------------------------------
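The `(B, 8732, num_classes)` shape comments in `SSD.forward` come from two steps: each head's output is flattened from `(B, A * C, H, W)` to `(B, H*W*A, C)`, and the results are concatenated across the six feature maps. The plain-Python sketch below reproduces that count. It assumes the SSD300 per-location box counts `(4, 6, 6, 6, 4, 4)` from the SSD paper. In the repo these counts come from `aspect_ratios` in `config/voc.yaml`, and `count_default_boxes` is a hypothetical helper.

```python
# Assumed SSD300 layout: (feature map size, default boxes per location).
# The per-location counts are an assumption; the repo derives them from
# aspect_ratios in config/voc.yaml.
FEATURE_MAPS = [(38, 4), (19, 6), (10, 6), (5, 6), (3, 4), (1, 4)]


def count_default_boxes(feature_maps):
    """Hypothetical helper: predictions per map after flattening to (B, H*W*A, C)."""
    per_map = [size * size * boxes for size, boxes in feature_maps]
    return per_map, sum(per_map)


per_map, total = count_default_boxes(FEATURE_MAPS)
print(per_map, total)  # [5776, 2166, 600, 150, 36, 4] 8732
```

Most of the 8732 predictions come from the 38 x 38 conv4_3 map. This is one reason the conv4_3 features are L2-normalized and rescaled before they reach the heads.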
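The hard negative mining at the end of `compute_loss` keeps, per image, only the `neg_pos_ratio * num_foreground` background default boxes with the highest classification loss. It finds them with a double sort: `idx.sort(1)[1]` turns the descending order into each box's rank. Below is a framework-free sketch of the same selection for a single image. It uses plain lists instead of tensors, and the helper names `hard_negative_mask` and `ssd_classification_loss` are hypothetical.

```python
def hard_negative_mask(cls_loss, labels, neg_pos_ratio=3):
    """Boolean mask of the hardest background boxes for one image.

    cls_loss: per-default-box classification loss (floats)
    labels:   per-default-box target class, 0 = background
    """
    num_pos = sum(1 for label in labels if label > 0)
    num_neg = neg_pos_ratio * num_pos
    # Foreground boxes get -inf so they sort last (like negative_loss in compute_loss)
    neg_loss = [float("-inf") if label > 0 else loss
                for loss, label in zip(cls_loss, labels)]
    # order[k] = index of the box with the k-th highest loss
    order = sorted(range(len(neg_loss)), key=lambda i: neg_loss[i], reverse=True)
    # rank[i] = position of box i in that order ("argsort of argsort")
    rank = [0] * len(order)
    for position, box_idx in enumerate(order):
        rank[box_idx] = position
    # Extra guard (not in compute_loss): never select a foreground box,
    # even if num_neg exceeds the number of background boxes
    return [r < num_neg and label == 0 for r, label in zip(rank, labels)]


def ssd_classification_loss(cls_loss, labels, neg_pos_ratio=3):
    """Foreground loss plus hard-negative loss, divided by max(1, #foreground)."""
    mask = hard_negative_mask(cls_loss, labels, neg_pos_ratio)
    fg = sum(loss for loss, label in zip(cls_loss, labels) if label > 0)
    bg = sum(loss for loss, keep in zip(cls_loss, mask) if keep)
    num_pos = sum(1 for label in labels if label > 0)
    return (fg + bg) / max(1, num_pos)


losses = [0.1, 2.0, 0.5, 3.0, 0.2, 0.9]
targets = [0, 1, 0, 0, 0, 0]
print(hard_negative_mask(losses, targets, neg_pos_ratio=2))
# [False, False, False, True, False, True]
print(ssd_classification_loss(losses, targets, neg_pos_ratio=2))
```

One difference from the sketch: `compute_loss` divides by the number of foreground boxes across the whole batch, not per image. With a single image the two are the same.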
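At test time, `SSD.forward` runs NMS separately for each class and then keeps the `detections_per_img` highest-scoring survivors. It does this with `torch.ops.torchvision.nms` inside a loop over `torch.unique(pred_labels)`. Below is a minimal pure-Python sketch of that greedy, class-wise procedure. `box_iou` and `class_wise_nms` are hypothetical names, and the thresholds are passed in as arguments rather than read from the config.

```python
def box_iou(a, b):
    """IoU of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def class_wise_nms(boxes, scores, labels, iou_threshold, max_dets):
    """Greedy NMS run per class; returns kept indices by descending score."""
    keep = []
    for cls in sorted(set(labels)):
        # Indices of this class, highest score first
        order = sorted((i for i, l in enumerate(labels) if l == cls),
                       key=lambda i: scores[i], reverse=True)
        while order:
            best = order.pop(0)
            keep.append(best)
            # Drop same-class boxes overlapping the kept box by more than the threshold
            order = [i for i in order if box_iou(boxes[best], boxes[i]) <= iou_threshold]
    # Post-NMS top-K across all classes (detections_per_img in forward)
    keep.sort(key=lambda i: scores[i], reverse=True)
    return keep[:max_dets]


boxes = [[0, 0, 1, 1], [0, 0, 1, 0.9], [0, 0, 1, 1], [2, 2, 3, 3]]
scores = [0.9, 0.8, 0.7, 0.6]
labels = [1, 1, 2, 1]
print(class_wise_nms(boxes, scores, labels, iou_threshold=0.5, max_dets=10))  # [0, 2, 3]
```

Box 1 is suppressed by box 0 of the same class. Box 2 survives despite a full overlap with box 0 because it belongs to a different class. torchvision also provides `torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)`, which performs this per-class suppression in one call. It could likely replace the Python loop in `forward`, but that is a suggestion, not what the repo does.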
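With `method='area'`, `compute_map` in `tools/infer.py` computes VOC-style all-point AP in three steps. It builds cumulative precision and recall over the score-sorted detections, replaces each precision with the maximum precision at any higher recall (the envelope), and sums the rectangle areas wherever recall changes. The sketch below mirrors those steps for one class in plain Python. `voc_area_ap` is a hypothetical name. The sketch assumes detections matched to difficult gt boxes have already been dropped and that `num_gts` excludes difficult boxes, matching the `num_gts - num_difficults` denominator in `compute_map`.

```python
def voc_area_ap(tp_flags, num_gts):
    """All-point interpolated AP for one class.

    tp_flags: 1 (true positive) or 0 (false positive) per detection,
              already sorted by descending score
    num_gts:  number of non-difficult ground-truth boxes of that class
    """
    recalls, precisions = [], []
    tp = fp = 0
    for flag in tp_flags:
        tp += flag
        fp += 1 - flag
        recalls.append(tp / max(num_gts, 1e-12))
        precisions.append(tp / (tp + fp))
    # Sentinel points at recall 0 and 1
    mrec = [0.0] + recalls + [1.0]
    mpre = [0.0] + precisions + [0.0]
    # Precision envelope: precision at recall r = max precision at any recall >= r
    for i in range(len(mpre) - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    # Sum rectangle areas wherever recall changes
    return sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]
               for i in range(len(mrec) - 1) if mrec[i + 1] != mrec[i])


print(voc_area_ap([1, 0, 1], num_gts=2))  # approximately 0.8333 (= 0.5 + 0.5 * 2/3)
```

The `method='interp'` branch of `compute_map` works differently. It averages the maximum precision at the 11 recall points 0, 0.1, ..., 1.0, which is the older VOC2007 protocol.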