├── .gitignore ├── README.md ├── analysis ├── plots.py └── scrape_logfiles.py ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── config.json ├── constants.py ├── data_loader ├── data_loaders.py └── datasets.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── model ├── cooc_layers.py ├── loss.py ├── metric.py └── model.py ├── parse_config.py ├── test.py ├── test_image.py ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── joint_transforms.py ├── model_utils.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore data directory 2 | data/ 3 | **/*.csv 4 | 5 | 6 | # Byte-compiled / optimized / DLL files 7 | **/__pycache__/ 8 | **/*.py[cod] 9 | **/*$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # input data, saved log, checkpoints 109 | data/ 110 | input/ 111 | saved/ 112 | datasets/ 113 | 114 | # editor, os cache directory 115 | .vscode/ 116 | .idea/ 117 | __MACOSX/ 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # small-object-detection 2 | 3 | CS231n project, Spring 2019 4 | 5 | ## Overview 6 | Experiments with different models for object detection on the Pascal VOC 2007 dataset. 7 | 8 | See model/ directory for models: SSD300 and SSCoD. See for Faster R-CNN+GAN 9 | 10 | The implementation of the novel spatial co-occurrence layer is in model/cooc_layers.py. This is an extension of the convolutional co-occurrence layer from Shih et al. 2017 (CVPR) for generating local spatial cross-correlation signals. 11 | 12 | 13 | ### Requirements 14 | PyTorch >=1.0 with torchvision, image processing libraries (PIL, cv2). Python 3.7 Anaconda distribution. 15 | 16 | ## Instructions 17 | Set hyperparameters in config.json. 
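For reference, an abridged excerpt of the config.json included in this repo (the full file is shown further down):
```
{
    "name": "VOC_SSD",
    "arch": {"type": "SSCoD", "args": {"n_classes": 21}},
    "data_loader": {"type": "VOCDataLoader",
                    "args": {"data_dir": "./data/", "image_size": 300, "batch_size": 8}},
    "optimizer": {"type": "SGD", "args": {"lr": 1e-3, "momentum": 0.9, "weight_decay": 5e-4}},
    "loss": "MultiBoxLoss",
    "metrics": ["meanAP"],
    "trainer": {"epochs": 200, "save_dir": "saved/", "monitor": "min val_loss"}
}
```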
18 |
19 | Run
20 | ```
21 | python train.py -c config.json
22 | ```
23 |
24 | Model checkpoints are saved automatically. Resume training with
25 | ```
26 | python train.py -r saved/models/path-to-checkpoint.pth
27 | ```
28 |
29 | After training, a selected model can be used for testing with, e.g.,
30 | ```
31 | python test.py -r saved/models/path-to-checkpoint.pth
32 | ```
33 |
34 | The remaining scripts in the root directory are self-explanatory, e.g. producing images with predicted bounding boxes.
35 |
36 |
37 | ## Acknowledgements
38 | The following repos were essential to our work:
39 |
40 |
41 |
42 | The basic project backbone (loggers, model saving, base classes) was adapted from the first repo. I wrote the data loading and preprocessing and all the model components, and rewrote the trainer module. Some helper functions were closely adapted from the tutorial in the second link, which also guided me in writing the models and loss modules. The torchvision transforms are set up to take only images, not the corresponding bounding boxes, so these were subclassed and updated to transform images and boxes jointly. The SSCoD model was written from scratch with help from useful PyTorch forum posts.
43 |
44 |
45 |
--------------------------------------------------------------------------------
/analysis/plots.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 | df = pd.read_csv('../log.csv')
6 |
7 | x = df.epoch
8 | train = df.train
9 | val = df.val
10 |
11 | fig, ax = plt.subplots(ncols=3)
12 |
13 | ax[0].plot(x, train, label='train')
14 | ax[0].plot(x, val, label='val')
15 | ax[0].set_xlabel('Epoch')
16 | ax[0].set_ylabel('Loss')
17 | ax[0].set_title('SSD300')
18 | ax[0].legend()
19 | plt.show()
20 |
--------------------------------------------------------------------------------
/analysis/scrape_logfiles.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 |
4 | LOG = '../saved/log/VOC_SSD/0601_052820/info.log'
5 |
6 | def load_log(logfile):
7 |     epochs = []
8 |     train_loss = []
9 |     val_loss = []
10 |
11 |     with open(logfile, 'r') as f:  # read the logfile argument, not the module-level LOG
12 |         for l in f:
13 |             if ' epoch ' in l:
14 |                 tmp = l.split(':')[-1]
15 |                 tmp = "".join(tmp.split())
16 |                 epochs.append(int(tmp))
17 |             elif ' loss ' in l:
18 |                 tmp = l.split(':')[-1]
19 |                 tmp = "".join(tmp.split())
20 |                 train_loss.append(float(tmp))
21 |             elif ' val_loss ' in l:
22 |                 tmp = l.split(':')[-1]
23 |                 tmp = "".join(tmp.split())
24 |                 val_loss.append(float(tmp))
25 |
26 |     cols = {'epoch': epochs, 'train': train_loss, 'val': val_loss}
27 |     loss = pd.DataFrame(cols)
28 |
29 |     return loss
30 |
31 |
32 | if __name__ == '__main__':
33 |
34 |     loss = load_log(LOG)
35 |     print(loss.head())
36 |     loss.to_csv('./log.csv', index=False)
37 |
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataloader import default_collate
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 |
6 |
7 | class BaseDataLoader(DataLoader):
8 |
""" 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *input): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import WriterTensorboardX 5 | 6 | from constants import DEVICE 7 | 8 | 9 | class BaseTrainer: 10 | """ 11 | Base class for all trainers 12 | """ 13 | def __init__(self, model, loss, metrics, optimizer, config): 14 | self.config = config 15 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 16 | 17 | if torch.cuda.is_available(): 18 | print('CUDA device available.') 19 | print('Using device: {0}'.format(DEVICE)) 20 | 21 | self.model = model.to(DEVICE) 22 | self.loss = loss 23 | self.metrics = metrics 24 | self.optimizer = optimizer 25 | 26 | cfg_trainer = config['trainer'] 27 | self.epochs = cfg_trainer['epochs'] 28 | self.save_period = cfg_trainer['save_period'] 29 | self.monitor = cfg_trainer.get('monitor', 
'off') 30 | 31 | # configuration to monitor model performance and save best 32 | if self.monitor == 'off': 33 | self.mnt_mode = 'off' 34 | self.mnt_best = 0 35 | else: 36 | self.mnt_mode, self.mnt_metric = self.monitor.split() 37 | assert self.mnt_mode in ['min', 'max'] 38 | 39 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 40 | self.early_stop = cfg_trainer.get('early_stop', inf) 41 | 42 | self.start_epoch = 1 43 | 44 | self.checkpoint_dir = config.save_dir 45 | # setup visualization writer instance 46 | self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX']) 47 | 48 | if config.resume is not None: 49 | self._resume_checkpoint(config.resume) 50 | 51 | @abstractmethod 52 | def _train_epoch(self, epoch): 53 | """ 54 | Training logic for an epoch 55 | 56 | :param epoch: Current epoch number 57 | """ 58 | raise NotImplementedError 59 | 60 | def train(self): 61 | """ 62 | Full training logic 63 | """ 64 | for epoch in range(self.start_epoch, self.epochs + 1): 65 | result = self._train_epoch(epoch) 66 | 67 | # save logged informations into log dict 68 | log = {'epoch': epoch} 69 | for key, value in result.items(): 70 | if key == 'metrics': 71 | log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 72 | elif key == 'val_metrics': 73 | log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 74 | else: 75 | log[key] = value 76 | 77 | # print logged informations to the screen 78 | for key, value in log.items(): 79 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 80 | 81 | # evaluate model performance according to configured metric, save best checkpoint as model_best 82 | best = False 83 | if self.mnt_mode != 'off': 84 | not_improved_count = 0 85 | 86 | try: 87 | # check whether model performance improved or not, according to specified metric(mnt_metric) 88 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 89 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 90 | except KeyError: 91 | self.logger.warning("Warning: Metric '{}' is not found. " 92 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 93 | self.mnt_mode = 'off' 94 | improved = False 95 | not_improved_count = 0 96 | 97 | if improved: 98 | self.mnt_best = log[self.mnt_metric] 99 | not_improved_count = 0 100 | best = True 101 | else: 102 | not_improved_count += 1 103 | 104 | if not_improved_count > self.early_stop: 105 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" 106 | "Training stops.".format(self.early_stop)) 107 | break 108 | 109 | if epoch % self.save_period == 0: 110 | self._save_checkpoint(epoch, save_best=best) 111 | 112 | def _save_checkpoint(self, epoch, save_best=False): 113 | """ 114 | Saving checkpoints 115 | 116 | :param epoch: current epoch number 117 | :param log: logging information of the epoch 118 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 119 | """ 120 | arch = type(self.model).__name__ 121 | state = { 122 | 'arch': arch, 123 | 'epoch': epoch, 124 | 'state_dict': self.model.state_dict(), 125 | 'optimizer': self.optimizer.state_dict(), 126 | 'monitor_best': self.mnt_best, 127 | 'config': self.config 128 | } 129 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 130 | torch.save(state, filename) 131 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 132 | if save_best: 133 | best_path = str(self.checkpoint_dir / 'model_best.pth') 134 | torch.save(state, best_path) 135 | self.logger.info("Saving current best: model_best.pth ...") 136 | 137 | def _resume_checkpoint(self, resume_path): 138 | """ 139 | Resume from saved checkpoints 140 | 141 | :param resume_path: Checkpoint path to be resumed 142 | """ 143 | resume_path = str(resume_path) 144 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 145 | checkpoint = torch.load(resume_path) 146 | self.start_epoch = checkpoint['epoch'] + 1 147 | self.mnt_best = checkpoint['monitor_best'] 148 | 149 | # load architecture params from checkpoint. 150 | if checkpoint['config']['arch'] != self.config['arch']: 151 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 152 | "checkpoint. This may yield an exception while state_dict is being loaded.") 153 | self.model.load_state_dict(checkpoint['state_dict']) 154 | 155 | # load optimizer state from checkpoint only when optimizer type is not changed. 156 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 157 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 158 | "Optimizer parameters not being resumed.") 159 | else: 160 | self.optimizer.load_state_dict(checkpoint['optimizer']) 161 | 162 | self.logger.info("Checkpoint loaded. 
Resume training from epoch {}".format(self.start_epoch)) 163 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VOC_SSD", 3 | 4 | "arch": { 5 | "type": "SSCoD", 6 | "args": {"n_classes": 21} 7 | }, 8 | "data_loader": { 9 | "type": "VOCDataLoader", 10 | "args":{ 11 | "data_dir": "./data/", 12 | "image_size": 300, 13 | "batch_size": 8, 14 | "shuffle": true, 15 | "validation_split": 0.0, 16 | "num_workers": 1, 17 | "augment": true 18 | } 19 | }, 20 | "optimizer": { 21 | "type": "SGD", 22 | "args":{ 23 | "lr": 1e-3, 24 | "momentum": 0.9, 25 | "weight_decay": 5e-4 26 | } 27 | }, 28 | "loss": "MultiBoxLoss", 29 | "metrics": ["meanAP"], 30 | "lr_scheduler": { 31 | "type": "StepLR", 32 | "args": { 33 | "step_size": 50, 34 | "gamma": 0.1 35 | } 36 | }, 37 | "trainer": { 38 | "epochs": 200, 39 | 40 | "save_dir": "saved/", 41 | "save_period": 1, 42 | "verbosity": 2, 43 | 44 | "monitor": "min val_loss", 45 | "early_stop": 10, 46 | 47 | "tensorboardX": true 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 4 | 5 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 6 | IMAGENET_STD = [0.229, 0.224, 0.225] 7 | 8 | # Pascal VOC constants 9 | VOC_CLASS_NAMES = ( 10 | '__background__', 11 | 'aeroplane', 12 | 'bicycle', 13 | 'bird', 14 | 'boat', 15 | 'bottle', 16 | 'bus', 17 | 'car', 18 | 'cat', 19 | 'chair', 20 | 'cow', 21 | 'diningtable', 22 | 'dog', 23 | 'horse', 24 | 'motorbike', 25 | 'person', 26 | 'pottedplant', 27 | 'sheep', 28 | 'sofa', 29 | 'train', 30 | 'tvmonitor', 31 | ) 32 | VOC_ENCODING = {cl: id for id, cl in enumerate(VOC_CLASS_NAMES)} 33 | VOC_DECODING = {id: cl for id, cl in enumerate(VOC_CLASS_NAMES)} 34 | VOC_NUM_CLASSES = 21 35 | 36 | VOC_TRAIN_PARAMS = { 37 | "year": "2007", 38 | "image_set": "train" 39 | } 40 | 41 | VOC_VALID_PARAMS = { 42 | "year": "2007", 43 | "image_set": "val" 44 | } 45 | 46 | VOC_TEST_PARAMS = { 47 | "year": "2007", 48 | "image_set": "test" 49 | } 50 | 51 | # Next dataset constants 52 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from constants import * 5 | from base import BaseDataLoader 6 | from utils import joint_transforms as t 7 | from data_loader.datasets import ModVOCDetection 8 | 9 | 10 | class JointTransformer(object): 11 | """ Custom transformer to jointly operate on an image and its 12 | corresponding bounding boxes and labels. 13 | 14 | The important steps are: 15 | Convert bounding box pixels to percent coordinates, 16 | Resize to specific image size (i.e. 
300x300), 17 | Mean subtract, 18 | Convert image to tensor 19 | """ 20 | def __init__(self, image_size, mode, augment): 21 | self.image_size = image_size 22 | self.mode = mode 23 | self.augment = augment 24 | 25 | def __call__(self, image, boxes, labels): 26 | if self.augment and self.mode == 'train': 27 | transformer = t.Compose([ 28 | t.ConvertFromPIL(), 29 | t.PhotometricDistort(), 30 | t.Expand(IMAGENET_MEAN), 31 | t.RandomSampleCrop(), 32 | t.RandomMirror(), 33 | t.ToPercentCoords(), 34 | t.Resize(self.image_size), 35 | t.Normalize(IMAGENET_MEAN, IMAGENET_STD), 36 | t.ToTensor() 37 | ]) 38 | else: 39 | transformer = t.Compose([ 40 | t.ConvertFromPIL(), 41 | t.ToPercentCoords(), 42 | t.Resize(self.image_size), 43 | t.Normalize(IMAGENET_MEAN, IMAGENET_STD), 44 | t.ToTensor() 45 | ]) 46 | return transformer(image, boxes, labels) 47 | 48 | 49 | def collate_fn(batch): 50 | """ Collate objects together in a batch. 51 | 52 | Used by the dataloader for generating batches of data. This is a 53 | simple modification of the default collate_fn. 54 | 55 | Inputs: 56 | batch: an iterable of N sets from __getitem__() 57 | 58 | Return: 59 | a tensor of images, list of varying-size tensors of bounding boxes, 60 | and list of vary-size tensors of encoded labels. 61 | """ 62 | images = [] 63 | boxes_list = [] 64 | labels_list = [] 65 | difficulties = [] 66 | 67 | for item in batch: 68 | images.append(item[0]) 69 | boxes_list.append(item[1]) 70 | labels_list.append(item[2]) 71 | difficulties.append(item[3]) 72 | 73 | images = torch.stack(images, dim=0) 74 | 75 | return images, boxes_list, labels_list, difficulties 76 | 77 | 78 | class VOCDataLoader(BaseDataLoader): 79 | """ 80 | Load Pascal VOC using BaseDataLoader 81 | """ 82 | def __init__(self, data_dir, image_size, batch_size, 83 | shuffle=True, validation_split=0.0, collate_fn=collate_fn, 84 | num_workers=1, augment=False, mode='train'): 85 | 86 | assert mode in ('train', 'valid', 'test') 87 | if mode == 'train': 88 | voc_params = VOC_TRAIN_PARAMS 89 | elif mode == 'valid': 90 | voc_params = VOC_VALID_PARAMS 91 | elif mode == 'test': 92 | voc_params = VOC_TEST_PARAMS 93 | 94 | self.data_dir = data_dir 95 | self.dataset = ModVOCDetection( 96 | self.data_dir, 97 | year=voc_params['year'], 98 | image_set=voc_params['image_set'], 99 | download=False, 100 | joint_transform=JointTransformer(image_size, 101 | mode, 102 | augment) 103 | ) 104 | 105 | super(VOCDataLoader, self).__init__(self.dataset, batch_size, 106 | shuffle, validation_split, 107 | num_workers, collate_fn) 108 | -------------------------------------------------------------------------------- /data_loader/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets 3 | import numpy as np 4 | from PIL import Image 5 | import xml.etree.ElementTree as ET 6 | 7 | from constants import VOC_ENCODING 8 | 9 | 10 | class ModVOCDetection(datasets.VOCDetection): 11 | """ Inherits PyTorch implementation of VOCDetection Dataset with 12 | necessary modifications to support joint transformations. 13 | """ 14 | def __init__(self, root, year='2007', image_set='train', 15 | download=False, joint_transform=None): 16 | 17 | self.joint_transform = joint_transform 18 | super(ModVOCDetection, self).__init__(root, 19 | year=year, 20 | image_set=image_set, 21 | download=download) 22 | 23 | def __getitem__(self, index): 24 | """ Return a tuple consisting of the image, its bounding boxes, 25 | and the labels of the bounding boxes. 
26 | 27 | Args: 28 | index (int): Index 29 | Returns: 30 | tuple: (image, boxes, labels) 31 | """ 32 | img = Image.open(self.images[index]).convert('RGB') 33 | target = self.parse_voc_xml( 34 | ET.parse(self.annotations[index]).getroot()) 35 | 36 | boxes, labels, difficulties = parse_annotation_dict(target) 37 | 38 | if self.joint_transform is not None: 39 | img, boxes, labels = self.joint_transform(img, boxes, labels) 40 | 41 | return img, boxes, labels, torch.LongTensor(difficulties) 42 | 43 | 44 | def parse_annotation_dict(annot): 45 | """ Parse the annotation dictionary for a single image/label set. 46 | 47 | Annotations are stored in a nested dictionary. We require three 48 | elements: 49 | 1. labels of any bounding boxes 50 | 2. corner coordinates of any bounding boxes 51 | 3. "difficulty" of the detected object, used later for metrics 52 | """ 53 | 54 | objects = annot['annotation']['object'] 55 | 56 | labels = [] 57 | boxes = [] 58 | difficulties = [] 59 | if isinstance(objects, list): 60 | for o in objects: 61 | labels.append(VOC_ENCODING[o['name']]) 62 | bbox = o['bndbox'] 63 | boxes.append([int(bbox['xmin']) - 1, int(bbox['ymin']) - 1, 64 | int(bbox['xmax']) - 1, int(bbox['ymax']) - 1]) 65 | difficulties.append(o['difficult']) 66 | 67 | elif isinstance(objects, dict): 68 | labels.append(VOC_ENCODING[objects['name']]) 69 | bbox = objects['bndbox'] 70 | boxes.append([int(bbox['xmin']) - 1, int(bbox['ymin']) - 1, 71 | int(bbox['xmax']) - 1, int(bbox['ymax']) - 1]) 72 | difficulties.append(objects['difficult']) 73 | 74 | boxes = np.array(boxes, dtype=np.float32) 75 | labels = np.array(labels) 76 | difficulties = np.array(difficulties).astype(int) 77 | 78 | return boxes, labels, difficulties 79 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": 
"logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from utils import Timer 3 | 4 | 5 | class WriterTensorboardX(): 6 | def __init__(self, log_dir, logger, enable): 7 | self.writer = None 8 | if enable: 9 | log_dir = str(log_dir) 10 | try: 11 | self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_dir) 12 | except ImportError: 13 | message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \ 14 | "this machine. Please install the package by 'pip install tensorboardx' command or turn " \ 15 | "off the option in the 'config.json' file." 16 | logger.warning(message) 17 | self.step = 0 18 | self.mode = '' 19 | 20 | self.tb_writer_ftns = [ 21 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 22 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 23 | ] 24 | self.tag_mode_exceptions = ['add_histogram', 'add_embedding'] 25 | self.timer = Timer() 26 | 27 | def set_step(self, step, mode='train'): 28 | self.mode = mode 29 | self.step = step 30 | if step == 0: 31 | self.timer.reset() 32 | else: 33 | duration = self.timer.check() 34 | self.add_scalar('steps_per_sec', 1 / duration) 35 | 36 | def __getattr__(self, name): 37 | """ 38 | If visualization is configured to use: 39 | return add_data() methods of tensorboard with additional information (step, tag) added. 40 | Otherwise: 41 | return a blank function handle that does nothing 42 | """ 43 | if name in self.tb_writer_ftns: 44 | add_data = getattr(self.writer, name, None) 45 | 46 | def wrapper(tag, data, *args, **kwargs): 47 | if add_data is not None: 48 | # add mode(train/valid) tag 49 | if name not in self.tag_mode_exceptions: 50 | tag = '{}/{}'.format(tag, self.mode) 51 | add_data(tag, data, self.step, *args, **kwargs) 52 | return wrapper 53 | else: 54 | # default action for returning methods defined in this class, set_step() for instance. 55 | try: 56 | attr = object.__getattr__(name) 57 | except AttributeError: 58 | raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name)) 59 | return attr 60 | -------------------------------------------------------------------------------- /model/cooc_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numbers 6 | 7 | from utils import * 8 | from constants import * 9 | 10 | 11 | 12 | class CoocLayer(nn.Module): 13 | """ Co-occurrence layer as proposed in Shih et al. (CVPR 2017) 14 | """ 15 | def __init__(self, in_channels, out_channels=32): 16 | super().__init__() 17 | 18 | # dimensionality reduction 19 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 20 | 21 | # gaussian filtering for noise removal 22 | self.gaussian = GaussianSmoothing(channels=out_channels, kernel_size=5, sigma=1) 23 | 24 | def forward(self, x): 25 | """ input is the image activations following a convolutional layer. 
26 |         dimensions: N x C x H x W
27 |
28 |         The co-occurrence layer computes a vector of length C ** 2
29 |         """
30 |
31 |         x = F.relu(self.conv1(x))
32 |
33 |         x = F.pad(x, (2, 2, 2, 2), mode='reflect')
34 |         x = self.gaussian(x)
35 |
36 |         N, C, H, W = x.size()
37 |
38 |         # list of length H*W of (N, C, H, W) tensors containing each offset
39 |         x_offsets = [self.roll(self.roll(x, i, 2), j, 3) for i in range(H) for j in range(W)]
40 |         x_offsets = torch.cat(x_offsets, 1).to(DEVICE)  # (N, C*H*W, H, W)
41 |         x_offsets = x_offsets.view(N, C * H * W, H * W).permute(0, 2, 1)  # (N, H*W, C*H*W)
42 |
43 |         x_base = x.view(N, C, H * W)  # (N, C, H*W)
44 |         corrs = torch.bmm(x_base, x_offsets)  # (N, C, C*H*W)
45 |         corrs = corrs.view(N, C * C, H * W).permute(0, 2, 1)  # (N, H*W, C*C)
46 |         c_ij, best_offset = torch.max(corrs, 1)  # (N, C*C)
47 |
48 |         return c_ij
49 |
50 |     @staticmethod
51 |     def roll(tensor, shift, axis):
52 |         """ https://discuss.pytorch.org/t/implementation-of-function-like-numpy-roll/964/6 """
53 |         if shift == 0:
54 |             return tensor
55 |
56 |         if axis < 0:
57 |             axis += tensor.dim()
58 |
59 |         dim_size = tensor.size(axis)
60 |         after_start = dim_size - shift
61 |         if shift < 0:
62 |             after_start = -shift
63 |             shift = dim_size - abs(shift)
64 |
65 |         before = tensor.narrow(axis, 0, dim_size - shift)
66 |         after = tensor.narrow(axis, after_start, shift)
67 |         return torch.cat([after, before], axis)
68 |
69 |
70 | class SpatialCoocLayer(CoocLayer):
71 |     """ Novel adaptation of co-occurrence layer for spatially localized
72 |     co-occurrences. Instead of computing global co-occurrences over the
73 |     entire activations, this returns a spatial activation map of local
74 |     co-occurrences using pooling operations.
75 |     """
76 |     def __init__(self, in_channels, out_channels=32, local_kernel=5):
77 |         super().__init__(in_channels, out_channels)
78 |
79 |         same_pad = (local_kernel - 1) // 2
80 |         self.avgpool = nn.AvgPool2d(kernel_size=local_kernel, stride=1,
81 |                                     padding=same_pad)
82 |
83 |     def forward(self, x):
84 |
85 |         x = F.relu(self.conv1(x))
86 |
87 |         x = F.pad(x, (2, 2, 2, 2), mode='reflect')
88 |         x = self.gaussian(x)
89 |
90 |         N, C, H, W = x.size()
91 |
92 |         # list of length H*W of (N, C, H, W) tensors containing each offset
93 |         x_offsets = [self.roll(self.roll(x, i, 2), j, 3) for i in range(H) for j in range(W)]
94 |         x_offsets = torch.cat(x_offsets, 1).to(DEVICE)  # (N, C*H*W, H, W)
95 |
96 |         all_channels = []
97 |         for c in range(C):
98 |             x_channel = x[:, c, :, :].unsqueeze(1)  # (N, 1, H, W)
99 |             channel_pairs = torch.mul(x_channel, x_offsets)  # (N, C*H*W, H, W)
100 |             channel_pairs = channel_pairs.view(N * C, H * W, H, W)
101 |
102 |             local_corrs = self.avgpool(channel_pairs)  # (N*C, H*W, H, W)
103 |
104 |             max_corrs_over_offsets, _ = torch.max(local_corrs, dim=1)  # (N*C, H, W)
105 |             all_channels.append(max_corrs_over_offsets.view(N, C, H, W))
106 |
107 |         c_ij = torch.cat(all_channels, dim=1)  # (N, C*C, H, W)
108 |         return c_ij
109 |
110 |
111 | class GaussianSmoothing(nn.Module):
112 |     """
113 |     (From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10)
114 |     Apply gaussian smoothing on a
115 |     1d, 2d or 3d tensor. Filtering is performed separately for each channel
116 |     in the input using a depthwise convolution.
117 |     Arguments:
118 |         channels (int, sequence): Number of channels of the input tensors. Output will
119 |             have this number of channels as well.
120 |         kernel_size (int, sequence): Size of the gaussian kernel.
121 | sigma (float, sequence): Standard deviation of the gaussian kernel. 122 | dim (int, optional): The number of dimensions of the data. 123 | Default value is 2 (spatial). 124 | """ 125 | def __init__(self, channels, kernel_size, sigma, dim=2): 126 | super(GaussianSmoothing, self).__init__() 127 | if isinstance(kernel_size, numbers.Number): 128 | kernel_size = [kernel_size] * dim 129 | if isinstance(sigma, numbers.Number): 130 | sigma = [sigma] * dim 131 | 132 | # The gaussian kernel is the product of the 133 | # gaussian function of each dimension. 134 | kernel = 1 135 | meshgrids = torch.meshgrid( 136 | [ 137 | torch.arange(size, dtype=torch.float32) 138 | for size in kernel_size 139 | ] 140 | ) 141 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 142 | mean = (size - 1) / 2 143 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 144 | torch.exp(-((mgrid - mean) / std) ** 2 / 2) 145 | 146 | # Make sure sum of values in gaussian kernel equals 1. 147 | kernel = kernel / torch.sum(kernel) 148 | 149 | # Reshape to depthwise convolutional weight 150 | kernel = kernel.view(1, 1, *kernel.size()) 151 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 152 | 153 | self.register_buffer('weight', kernel) 154 | self.groups = channels 155 | 156 | if dim == 1: 157 | self.conv = F.conv1d 158 | elif dim == 2: 159 | self.conv = F.conv2d 160 | elif dim == 3: 161 | self.conv = F.conv3d 162 | else: 163 | raise RuntimeError( 164 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 165 | ) 166 | 167 | def forward(self, input): 168 | """ 169 | Apply gaussian filter to input. 170 | Arguments: 171 | input (torch.Tensor): Input to apply gaussian filter on. 172 | Returns: 173 | filtered (torch.Tensor): Filtered output. 174 | """ 175 | return self.conv(input, weight=self.weight, groups=self.groups) 176 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torchvision.models import vgg16_bn 5 | from base import BaseModel 6 | from math import sqrt 7 | from itertools import product as product 8 | import torchvision 9 | 10 | from utils import cxcy_to_gcxgcy, cxcy_to_xy, xy_to_cxcy, gcxgcy_to_cxcy 11 | from utils import get_default_boxes, find_jaccard_overlap 12 | from constants import DEVICE 13 | 14 | 15 | class SSDLoss(nn.Module): 16 | 17 | def __init__(self, threshold, neg_pos_ratio, alpha, device): 18 | super().__init__() 19 | self.default_cxcy = get_default_boxes().to(device) 20 | self.default_xy = cxcy_to_xy(self.default_cxcy) 21 | 22 | self.threshold = threshold 23 | self.hard_neg_scale = neg_pos_ratio 24 | self.alpha = alpha 25 | self.device = device 26 | 27 | self.smooth_l1 = nn.SmoothL1Loss(reduction='none') 28 | self.cross_entropy = nn.CrossEntropyLoss(reduction='none') 29 | 30 | 31 | def forward(self, output_boxes, output_scores, true_boxes, true_labels): 32 | 33 | batch_size = output_boxes.size(0) 34 | n_classes = output_boxes.size(2) 35 | n_priors = self.default_cxcy.size(0) 36 | 37 | gt_locs = torch.Tensor(batch_size, n_priors, 4).to(self.device) 38 | gt_class = torch.LongTensor(batch_size, n_priors).to(self.device) 39 | 40 | for im in range(batch_size): 41 | n_objects = true_boxes[im].size(0) 42 | 43 | # compute IoU for each ground truth box with default boxes 44 | # (n_objects, 8732) 45 | overlaps = find_jaccard_overlap(true_boxes[im], 
                                            self.default_xy)
46 |
47 |             # find highest-overlap object for each default, and then highest-
48 |             # overlap default for each object
49 |             overlap_per_default, object_per_default = overlaps.max(dim=0)
50 |             overlap_per_object, default_per_object = overlaps.max(dim=1)
51 |
52 |             # assign object to default box with highest overlap
53 |             object_per_default[default_per_object] = torch.LongTensor(range(n_objects)).to(self.device)
54 |
55 |             # give these default boxes an overlap of 1 (ensure positive)
56 |             overlap_per_default[default_per_object] = 1.
57 |
58 |             # assign labels to the default boxes according to the best overlap
59 |             default_labels = true_labels[im][object_per_default]
60 |             default_labels[overlap_per_default < self.threshold] = 0
61 |
62 |             gt_class[im] = default_labels
63 |             gt_locs[im] = cxcy_to_gcxgcy(xy_to_cxcy(true_boxes[im][object_per_default]), self.default_cxcy)
64 |
65 |         positive_defaults = (gt_class > 0)
66 |
67 |         # localization loss, averaged over positive default boxes (regression targets are the encoded gt_locs)
68 |         L_loc = self.smooth_l1(output_boxes[positive_defaults], gt_locs[positive_defaults]).mean()
69 |
70 |         # confidence loss
71 |         n_positives = positive_defaults.sum(dim=1)  # (N)
72 |         n_hard_negatives = self.hard_neg_scale * n_positives
73 |
74 |         conf_all = output_scores.view(-1, n_classes)
75 |         L_conf_all = self.cross_entropy(conf_all, gt_class.view(-1))
76 |         L_conf_all = L_conf_all.view(batch_size, n_priors)  # (N, 8732)
77 |
78 |         # We already know which priors are positive
79 |         L_conf_pos = L_conf_all[positive_defaults]  # (sum(n_positives))
80 |
81 |         # Next, find which priors are hard-negative
82 |         # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
83 |         L_conf_neg = L_conf_all.clone()  # (N, 8732)
84 |         L_conf_neg[positive_defaults] = 0.  # (N, 8732), positive priors are ignored (never in top n_hard_negatives)
85 |         L_conf_neg, _ = L_conf_neg.sort(dim=1, descending=True)  # (N, 8732), sorted by decreasing hardness
86 |         hardness_ranks = torch.LongTensor(range(n_priors)).unsqueeze(0).expand_as(L_conf_neg).to(self.device)  # (N, 8732)
87 |         hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1)
88 |         L_conf_hard_neg = L_conf_neg[hard_negatives]
89 |
90 |         # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
91 |         L_conf = (L_conf_hard_neg.sum() + L_conf_pos.sum()) / n_positives.sum().float()  # (), scalar
92 |
93 |         loss = L_conf + self.alpha * L_loc
94 |         return loss
95 |
96 |
97 | class MultiBoxLoss(nn.Module):
98 |     """
99 |     https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/model.py
100 |
101 |     The MultiBox loss, a loss function for object detection.
102 |     This is a combination of:
103 |     (1) a localization loss for the predicted locations of the boxes, and
104 |     (2) a confidence loss for the predicted class scores.
105 |     """
106 |
107 |     def __init__(self, threshold=0.5, neg_pos_ratio=3, alpha=1., device=DEVICE):
108 |         super(MultiBoxLoss, self).__init__()
109 |         self.priors_cxcy = get_default_boxes().to(device)
110 |         self.priors_xy = cxcy_to_xy(self.priors_cxcy)
111 |         self.threshold = threshold
112 |         self.neg_pos_ratio = neg_pos_ratio
113 |         self.alpha = alpha
114 |         self.device = device
115 |
116 |         self.smooth_l1 = nn.L1Loss()
117 |         self.cross_entropy = nn.CrossEntropyLoss(reduction='none')
118 |
119 |     def forward(self, predicted_locs, predicted_scores, boxes, labels):
120 |         """
121 |         Forward propagation.
122 | :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4) 123 | :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes) 124 | :param boxes: true object bounding boxes in boundary coordinates, a list of N tensors 125 | :param labels: true object labels, a list of N tensors 126 | :return: multibox loss, a scalar 127 | """ 128 | batch_size = predicted_locs.size(0) 129 | n_priors = self.priors_cxcy.size(0) 130 | n_classes = predicted_scores.size(2) 131 | 132 | assert n_priors == predicted_locs.size(1) == predicted_scores.size(1) 133 | 134 | true_locs = torch.zeros((batch_size, n_priors, 4), dtype=torch.float).to(self.device) # (N, 8732, 4) 135 | true_classes = torch.zeros((batch_size, n_priors), dtype=torch.long).to(self.device) # (N, 8732) 136 | 137 | # For each image 138 | for i in range(batch_size): 139 | n_objects = boxes[i].size(0) 140 | 141 | overlap = find_jaccard_overlap(boxes[i], 142 | self.priors_xy) # (n_objects, 8732) 143 | 144 | # For each prior, find the object that has the maximum overlap 145 | overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0) # (8732) 146 | 147 | # We don't want a situation where an object is not represented in our positive (non-background) priors - 148 | # 1. An object might not be the best object for all priors, and is therefore not in object_for_each_prior. 149 | # 2. All priors with the object may be assigned as background based on the threshold (0.5). 150 | 151 | # To remedy this - 152 | # First, find the prior that has the maximum overlap for each object. 153 | _, prior_for_each_object = overlap.max(dim=1) # (N_o) 154 | 155 | # Then, assign each object to the corresponding maximum-overlap-prior. (This fixes 1.) 156 | object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(n_objects)).to(self.device) 157 | 158 | # To ensure these priors qualify, artificially give them an overlap of greater than 0.5. (This fixes 2.) 159 | overlap_for_each_prior[prior_for_each_object] = 1. 
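            # Illustrative toy example (added for clarity; not in the original):
            # with n_objects = 2 and 5 priors, suppose
            #     overlap = [[0.1, 0.6, 0.2, 0.0, 0.3],
            #                [0.4, 0.1, 0.1, 0.2, 0.1]]
            # then object_for_each_prior = [1, 0, 0, 1, 0] and
            # prior_for_each_object = [1, 0].  The two assignments above force
            # prior 1 -> object 0 and prior 0 -> object 1 with overlap 1., so
            # every ground-truth object keeps at least one positive prior,
            # while priors 2-4 fall below the 0.5 threshold and are labelled
            # background below.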
160 | 161 | # Labels for each prior 162 | label_for_each_prior = labels[i][object_for_each_prior] # (8732) 163 | # Set priors whose overlaps with objects are less than the threshold to be background (no object) 164 | label_for_each_prior[overlap_for_each_prior < self.threshold] = 0 # (8732) 165 | 166 | # Store 167 | true_classes[i] = label_for_each_prior 168 | 169 | # Encode center-size object coordinates into the form we regressed predicted boxes to 170 | true_locs[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy) # (8732, 4) 171 | 172 | # Identify priors that are positive (object/non-background) 173 | positive_priors = true_classes != 0 # (N, 8732) 174 | 175 | # LOCALIZATION LOSS 176 | 177 | # Localization loss is computed only over positive (non-background) priors 178 | loc_loss = self.smooth_l1(predicted_locs[positive_priors], true_locs[positive_priors]) # (), scalar 179 | 180 | # Note: indexing with a torch.uint8 (byte) tensor flattens the tensor when indexing is across multiple dimensions (N & 8732) 181 | # So, if predicted_locs has the shape (N, 8732, 4), predicted_locs[positive_priors] will have (total positives, 4) 182 | 183 | # CONFIDENCE LOSS 184 | 185 | # Confidence loss is computed over positive priors and the most difficult (hardest) negative priors in each image 186 | # That is, FOR EACH IMAGE, 187 | # we will take the hardest (neg_pos_ratio * n_positives) negative priors, i.e where there is maximum loss 188 | # This is called Hard Negative Mining - it concentrates on hardest negatives in each image, and also minimizes pos/neg imbalance 189 | 190 | # Number of positive and hard-negative priors per image 191 | n_positives = positive_priors.sum(dim=1) # (N) 192 | n_hard_negatives = self.neg_pos_ratio * n_positives # (N) 193 | 194 | # First, find the loss for all priors 195 | conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes), true_classes.view(-1)) # (N * 8732) 196 | conf_loss_all = conf_loss_all.view(batch_size, n_priors) # (N, 8732) 197 | 198 | # We already know which priors are positive 199 | conf_loss_pos = conf_loss_all[positive_priors] # (sum(n_positives)) 200 | 201 | # Next, find which priors are hard-negative 202 | # To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives 203 | conf_loss_neg = conf_loss_all.clone() # (N, 8732) 204 | conf_loss_neg[positive_priors] = 0. 
# (N, 8732), positive priors are ignored (never in top n_hard_negatives) 205 | conf_loss_neg, _ = conf_loss_neg.sort(dim=1, descending=True) # (N, 8732), sorted by decreasing hardness 206 | hardness_ranks = torch.LongTensor(range(n_priors)).unsqueeze(0).expand_as(conf_loss_neg).to(self.device) # (N, 8732) 207 | hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1) # (N, 8732) 208 | conf_loss_hard_neg = conf_loss_neg[hard_negatives] # (sum(n_hard_negatives)) 209 | 210 | # As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors 211 | conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives.sum().float() # (), scalar 212 | 213 | # TOTAL LOSS 214 | 215 | return conf_loss + self.alpha * loc_loss 216 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from constants import VOC_ENCODING, VOC_ENCODING, VOC_DECODING 3 | from utils import find_jaccard_overlap 4 | 5 | from constants import DEVICE 6 | 7 | 8 | def meanAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties): 9 | """ 10 | Calculate the Mean Average Precision (mAP) of detected objects. 11 | See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation 12 | :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes 13 | :param det_labels: list of tensors, one tensor for each image containing detected objects' labels 14 | :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores 15 | :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes 16 | :param true_labels: list of tensors, one tensor for each image containing actual objects' labels 17 | :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1) 18 | :return: list of average precisions for all classes, mean average precision (mAP) 19 | """ 20 | assert len(det_boxes) == len(det_labels) == len(det_scores) == len(true_boxes) == len( 21 | true_labels) # these are all lists of tensors of the same length, i.e. number of images 22 | n_classes = len(VOC_ENCODING) 23 | 24 | # Store all (true) objects in a single continuous tensor while keeping track of the image it is from 25 | true_images = list() 26 | for i in range(len(true_labels)): 27 | true_images.extend([i] * true_labels[i].size(0)) 28 | true_images = torch.LongTensor(true_images).to(DEVICE) # (n_objects), n_objects is the total no. 
of objects across all images 29 | true_boxes = torch.cat(true_boxes, dim=0) # (n_objects, 4) 30 | true_labels = torch.cat(true_labels, dim=0) # (n_objects) 31 | true_difficulties = torch.cat(true_difficulties, dim=0) # (n_objects) 32 | 33 | assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0) 34 | 35 | # Store all detections in a single continuous tensor while keeping track of the image it is from 36 | det_images = list() 37 | for i in range(len(det_labels)): 38 | det_images.extend([i] * det_labels[i].size(0)) 39 | det_images = torch.LongTensor(det_images).to(DEVICE) # (n_detections) 40 | det_boxes = torch.cat(det_boxes, dim=0) # (n_detections, 4) 41 | det_labels = torch.cat(det_labels, dim=0) # (n_detections) 42 | det_scores = torch.cat(det_scores, dim=0) # (n_detections) 43 | 44 | assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0) 45 | 46 | # Calculate APs for each class (except background) 47 | average_precisions = torch.zeros((n_classes - 1), dtype=torch.float) # (n_classes - 1) 48 | for c in range(1, n_classes): 49 | # Extract only objects with this class 50 | true_class_images = true_images[true_labels == c] # (n_class_objects) 51 | true_class_boxes = true_boxes[true_labels == c] # (n_class_objects, 4) 52 | true_class_difficulties = true_difficulties[true_labels == c] # (n_class_objects) 53 | n_easy_class_objects = (1 - true_class_difficulties).sum().item() # ignore difficult objects 54 | 55 | # Keep track of which true objects with this class have already been 'detected' 56 | # So far, none 57 | true_class_boxes_detected = torch.zeros((true_class_difficulties.size(0)), dtype=torch.uint8).to(DEVICE) # (n_class_objects) 58 | 59 | # Extract only detections with this class 60 | det_class_images = det_images[det_labels == c] # (n_class_detections) 61 | det_class_boxes = det_boxes[det_labels == c] # (n_class_detections, 4) 62 | det_class_scores = det_scores[det_labels == c] # (n_class_detections) 63 | n_class_detections = det_class_boxes.size(0) 64 | if n_class_detections == 0: 65 | continue 66 | 67 | # Sort detections in decreasing order of confidence/scores 68 | det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0, descending=True) # (n_class_detections) 69 | det_class_images = det_class_images[sort_ind] # (n_class_detections) 70 | det_class_boxes = det_class_boxes[sort_ind] # (n_class_detections, 4) 71 | 72 | # In the order of decreasing scores, check if true or false positive 73 | true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(DEVICE) # (n_class_detections) 74 | false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(DEVICE) # (n_class_detections) 75 | for d in range(n_class_detections): 76 | this_detection_box = det_class_boxes[d].unsqueeze(0) # (1, 4) 77 | this_image = det_class_images[d] # (), scalar 78 | 79 | # Find objects in the same image with this class, their difficulties, and whether they have been detected before 80 | object_boxes = true_class_boxes[true_class_images == this_image] # (n_class_objects_in_img) 81 | object_difficulties = true_class_difficulties[true_class_images == this_image] # (n_class_objects_in_img) 82 | # If no such object in this image, then the detection is a false positive 83 | if object_boxes.size(0) == 0: 84 | false_positives[d] = 1 85 | continue 86 | 87 | # Find maximum overlap of this detection with objects in this image of this class 88 | overlaps = find_jaccard_overlap(this_detection_box, object_boxes) # (1, n_class_objects_in_img) 89 
| max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0) # (), () - scalars 90 | 91 | # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties' 92 | # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index... 93 | original_ind = torch.LongTensor(range(true_class_boxes.size(0)))[true_class_images == this_image][ind] 94 | # We need 'original_ind' to update 'true_class_boxes_detected' 95 | 96 | # If the maximum overlap is greater than the threshold of 0.5, it's a match 97 | if max_overlap.item() > 0.5: 98 | # If the object it matched with is 'difficult', ignore it 99 | if object_difficulties[ind] == 0: 100 | # If this object has already not been detected, it's a true positive 101 | if true_class_boxes_detected[original_ind] == 0: 102 | true_positives[d] = 1 103 | true_class_boxes_detected[original_ind] = 1 # this object has now been detected/accounted for 104 | # Otherwise, it's a false positive (since this object is already accounted for) 105 | else: 106 | false_positives[d] = 1 107 | # Otherwise, the detection occurs in a different location than the actual object, and is a false positive 108 | else: 109 | false_positives[d] = 1 110 | 111 | # Compute cumulative precision and recall at each detection in the order of decreasing scores 112 | cumul_true_positives = torch.cumsum(true_positives, dim=0) # (n_class_detections) 113 | cumul_false_positives = torch.cumsum(false_positives, dim=0) # (n_class_detections) 114 | cumul_precision = cumul_true_positives / ( 115 | cumul_true_positives + cumul_false_positives + 1e-10) # (n_class_detections) 116 | cumul_recall = cumul_true_positives / n_easy_class_objects # (n_class_detections) 117 | 118 | # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't' 119 | recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist() # (11) 120 | precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(DEVICE) # (11) 121 | for i, t in enumerate(recall_thresholds): 122 | recalls_above_t = cumul_recall >= t 123 | if recalls_above_t.any(): 124 | precisions[i] = cumul_precision[recalls_above_t].max() 125 | else: 126 | precisions[i] = 0. 
127 | average_precisions[c - 1] = precisions.mean() # c is in [1, n_classes - 1] 128 | 129 | # Calculate Mean Average Precision (mAP) 130 | mean_average_precision = average_precisions.mean().item() 131 | 132 | # Keep class-wise average precisions in a dictionary 133 | average_precisions = {VOC_DECODING[c + 1]: v for c, v in enumerate(average_precisions.tolist())} 134 | 135 | return average_precisions, mean_average_precision 136 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import vgg16 5 | from math import sqrt 6 | 7 | from utils import * 8 | from constants import * 9 | from .cooc_layers import SpatialCoocLayer 10 | 11 | 12 | class VGG16(nn.Module): 13 | """ Modified VGG16 base for generating features from image 14 | """ 15 | def __init__(self): 16 | super(VGG16, self).__init__() 17 | 18 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 19 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 20 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 21 | 22 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 23 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 24 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 25 | 26 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 27 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 28 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 29 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 30 | 31 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 32 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 33 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 34 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | 36 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 37 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 38 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 39 | self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 40 | 41 | # Replace the VGG16 FC layers with additional conv2d (see Fig. 
2) 42 | self.conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 43 | self.conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 44 | 45 | # Load pretrained layers 46 | self.load_pretrained_layers() 47 | 48 | def forward(self, x): 49 | 50 | out = F.relu(self.conv1_1(x)) # (N, 64, 300, 300) 51 | out = F.relu(self.conv1_2(out)) # (N, 64, 300, 300) 52 | out = self.pool1(out) # (N, 64, 150, 150) 53 | 54 | out = F.relu(self.conv2_1(out)) # (N, 128, 150, 150) 55 | out = F.relu(self.conv2_2(out)) # (N, 128, 150, 150) 56 | out = self.pool2(out) # (N, 128, 75, 75) 57 | 58 | out = F.relu(self.conv3_1(out)) # (N, 256, 75, 75) 59 | out = F.relu(self.conv3_2(out)) # (N, 256, 75, 75) 60 | out = F.relu(self.conv3_3(out)) # (N, 256, 75, 75) 61 | out = self.pool3(out) # (N, 256, 38, 38) (note ceil_mode=True) 62 | 63 | out = F.relu(self.conv4_1(out)) # (N, 512, 38, 38) 64 | out = F.relu(self.conv4_2(out)) # (N, 512, 38, 38) 65 | out = F.relu(self.conv4_3(out)) # (N, 512, 38, 38) 66 | conv4_out = out # (N, 512, 38, 38) 67 | out = self.pool4(out) # (N, 512, 19, 19) 68 | 69 | out = F.relu(self.conv5_1(out)) # (N, 512, 19, 19) 70 | out = F.relu(self.conv5_2(out)) # (N, 512, 19, 19) 71 | out = F.relu(self.conv5_3(out)) # (N, 512, 19, 19) 72 | out = self.pool5(out) # (N, 512, 19, 19) 73 | 74 | out = F.relu(self.conv6(out)) # (N, 1024, 19, 19) 75 | 76 | conv7_out = F.relu(self.conv7(out)) # (N, 1024, 19, 19) 77 | 78 | return conv4_out, conv7_out 79 | 80 | def load_pretrained_layers(self): 81 | """ 82 | (This function as defined in https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection)) 83 | 84 | As in the paper, we use a VGG-16 pretrained on the ImageNet task as the base network. 85 | There's one available in PyTorch, see https://pytorch.org/docs/stable/torchvision/models.html#torchvision.models.vgg16 86 | We copy these parameters into our network. It's straightforward for conv1 to conv5. 87 | However, the original VGG-16 does not contain the conv6 and con7 layers. 88 | Therefore, we convert fc6 and fc7 into convolutional layers, and subsample by decimation. See 'decimate' in utils.py. 89 | """ 90 | # Current state of base 91 | state_dict = self.state_dict() 92 | param_names = list(state_dict.keys()) 93 | 94 | # Pretrained VGG base 95 | pretrained_state_dict = vgg16(pretrained=True).state_dict() 96 | pretrained_param_names = list(pretrained_state_dict.keys()) 97 | 98 | # Transfer conv. 
parameters from pretrained model to current model 99 | for i, param in enumerate(param_names[:-4]): # excluding conv6 and conv7 parameters 100 | state_dict[param] = pretrained_state_dict[pretrained_param_names[i]] 101 | 102 | # Convert fc6, fc7 to convolutional layers, and subsample (by decimation) to sizes of conv6 and conv7 103 | # fc6 104 | conv_fc6_weight = pretrained_state_dict['classifier.0.weight'].view(4096, 512, 7, 7) # (4096, 512, 7, 7) 105 | conv_fc6_bias = pretrained_state_dict['classifier.0.bias'] # (4096) 106 | state_dict['conv6.weight'] = decimate(conv_fc6_weight, m=[4, None, 3, 3]) # (1024, 512, 3, 3) 107 | state_dict['conv6.bias'] = decimate(conv_fc6_bias, m=[4]) # (1024) 108 | # fc7 109 | conv_fc7_weight = pretrained_state_dict['classifier.3.weight'].view(4096, 4096, 1, 1) # (4096, 4096, 1, 1) 110 | conv_fc7_bias = pretrained_state_dict['classifier.3.bias'] # (4096) 111 | state_dict['conv7.weight'] = decimate(conv_fc7_weight, m=[4, 4, None, None]) # (1024, 1024, 1, 1) 112 | state_dict['conv7.bias'] = decimate(conv_fc7_bias, m=[4]) # (1024) 113 | 114 | self.load_state_dict(state_dict) 115 | 116 | print("\nLoaded base model.\n") 117 | 118 | 119 | class ExtraLayers(nn.Module): 120 | """ Additional convolutions after VGG16 for feature scaling. """ 121 | def __init__(self): 122 | super(ExtraLayers, self).__init__() 123 | 124 | self.conv8_1 = nn.Conv2d(1024, 256, kernel_size=1, padding=0) 125 | self.conv8_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 126 | 127 | self.conv9_1 = nn.Conv2d(512, 128, kernel_size=1, padding=0) 128 | self.conv9_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 129 | 130 | self.conv10_1 = nn.Conv2d(256, 128, kernel_size=1, padding=0) 131 | self.conv10_2 = nn.Conv2d(128, 256, kernel_size=3, padding=0) 132 | 133 | self.conv11_1 = nn.Conv2d(256, 128, kernel_size=1, padding=0) 134 | self.conv11_2 = nn.Conv2d(128, 256, kernel_size=3, padding=0) 135 | 136 | for layer in self.children(): 137 | if isinstance(layer, nn.Conv2d): 138 | nn.init.xavier_uniform_(layer.weight) 139 | nn.init.constant_(layer.bias, 0.) 
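# Spatial sizes produced by these extra convolutions (cf. the shape comments in forward below):
# 19x19 -> 10x10 -> 5x5 -> 3x3 -> 1x1, via the stride-2 convs (conv8_2, conv9_2) and the unpadded 3x3 convs (conv10_2, conv11_2).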
140 | 141 | def forward(self, conv7_out): 142 | 143 | out = F.relu(self.conv8_1(conv7_out)) # (N, 256, 19, 19) 144 | out = F.relu(self.conv8_2(out)) # (N, 512, 10, 10) 145 | conv8_out = out # (N, 512, 10, 10) 146 | 147 | out = F.relu(self.conv9_1(out)) # (N, 128, 10, 10) 148 | out = F.relu(self.conv9_2(out)) # (N, 256, 5, 5) 149 | conv9_out = out # (N, 256, 5, 5) 150 | 151 | out = F.relu(self.conv10_1(out)) # (N, 128, 5, 5) 152 | out = F.relu(self.conv10_2(out)) # (N, 256, 3, 3) 153 | conv10_out = out # (N, 256, 3, 3) 154 | 155 | out = F.relu(self.conv11_1(out)) # (N, 128, 3, 3) 156 | conv11_out = F.relu(self.conv11_2(out)) # (N, 256, 1, 1) 157 | 158 | return conv8_out, conv9_out, conv10_out, conv11_out 159 | 160 | 161 | class Classifiers(nn.Module): 162 | 163 | def __init__(self, n_classes, n_boxes): 164 | super(Classifiers, self).__init__() 165 | 166 | self.n_classes = n_classes 167 | self.n_boxes = n_boxes 168 | assert len(self.n_boxes) == 6 169 | 170 | self.box4 = nn.Conv2d(512, n_boxes[0] * 4, kernel_size=3, padding=1) 171 | self.box7 = nn.Conv2d(1024, n_boxes[1] * 4, kernel_size=3, padding=1) 172 | self.box8 = nn.Conv2d(512, n_boxes[2] * 4, kernel_size=3, padding=1) 173 | self.box9 = nn.Conv2d(256, n_boxes[3] * 4, kernel_size=3, padding=1) 174 | self.box10 = nn.Conv2d(256, n_boxes[4] * 4, kernel_size=3, padding=1) 175 | self.box11 = nn.Conv2d(256, n_boxes[5] * 4, kernel_size=3, padding=1) 176 | 177 | self.class4 = nn.Conv2d(512, n_boxes[0] * n_classes, kernel_size=3, padding=1) 178 | self.class7 = nn.Conv2d(1024, n_boxes[1] * n_classes, kernel_size=3, padding=1) 179 | self.class8 = nn.Conv2d(512, n_boxes[2] * n_classes, kernel_size=3, padding=1) 180 | self.class9 = nn.Conv2d(256, n_boxes[3] * n_classes, kernel_size=3, padding=1) 181 | self.class10 = nn.Conv2d(256, n_boxes[4] * n_classes, kernel_size=3, padding=1) 182 | self.class11 = nn.Conv2d(256, n_boxes[5] * n_classes, kernel_size=3, padding=1) 183 | 184 | for layer in self.children(): 185 | if isinstance(layer, nn.Conv2d): 186 | nn.init.xavier_uniform_(layer.weight) 187 | nn.init.constant_(layer.bias, 0.) 
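# Where the 8732 predictions per image quoted below come from, given the default
# n_boxes = (4, 6, 6, 6, 4, 4) and feature map sizes (38, 19, 10, 5, 3, 1):
# 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*4 + 1*1*4 = 5776 + 2166 + 600 + 150 + 36 + 4 = 8732.
# Each default box contributes 4 box offsets and n_classes scores.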
188 | 189 | def forward(self, inputs): 190 | conv4_out, conv7_out, conv8_out, conv9_out, conv10_out, conv11_out = inputs 191 | 192 | N = conv4_out.size(0) 193 | 194 | # swap dimensions around and reassign in natural layout (contiguous) 195 | box4 = self.box4(conv4_out).permute(0, 2, 3, 1).contiguous() 196 | box7 = self.box7(conv7_out).permute(0, 2, 3, 1).contiguous() 197 | box8 = self.box8(conv8_out).permute(0, 2, 3, 1).contiguous() 198 | box9 = self.box9(conv9_out).permute(0, 2, 3, 1).contiguous() 199 | box10 = self.box10(conv10_out).permute(0, 2, 3, 1).contiguous() 200 | box11 = self.box11(conv11_out).permute(0, 2, 3, 1).contiguous() 201 | 202 | class4 = self.class4(conv4_out).permute(0, 2, 3, 1).contiguous() 203 | class7 = self.class7(conv7_out).permute(0, 2, 3, 1).contiguous() 204 | class8 = self.class8(conv8_out).permute(0, 2, 3, 1).contiguous() 205 | class9 = self.class9(conv9_out).permute(0, 2, 3, 1).contiguous() 206 | class10 = self.class10(conv10_out).permute(0, 2, 3, 1).contiguous() 207 | class11 = self.class11(conv11_out).permute(0, 2, 3, 1).contiguous() 208 | 209 | # reshape to match expected bounding box and class score shapes 210 | box4 = box4.view(N, -1, 4) 211 | box7 = box7.view(N, -1, 4) 212 | box8 = box8.view(N, -1, 4) 213 | box9 = box9.view(N, -1, 4) 214 | box10 = box10.view(N, -1, 4) 215 | box11 = box11.view(N, -1, 4) 216 | 217 | class4 = class4.view(N, -1, self.n_classes) 218 | class7 = class7.view(N, -1, self.n_classes) 219 | class8 = class8.view(N, -1, self.n_classes) 220 | class9 = class9.view(N, -1, self.n_classes) 221 | class10 = class10.view(N, -1, self.n_classes) 222 | class11 = class11.view(N, -1, self.n_classes) 223 | 224 | boxes = torch.cat([box4, box7, box8, box9, box10, box11], dim=1) 225 | classes = torch.cat([class4, class7, class8, class9, class10, class11], 226 | dim=1) 227 | 228 | return boxes, classes 229 | 230 | 231 | class SSD300(nn.Module): 232 | """ The full SSD300 network. 233 | 234 | The full process involves: 235 | 1) run VGG16 base on the image and extract layer 4 & 7 features. 236 | 2) run ExtraLayers to systematically downscale image while pulling 237 | features from each scaling. 238 | 3) run Classifiers to compute boxes and classes for each 239 | feature set. 240 | """ 241 | 242 | def __init__(self, n_classes=VOC_NUM_CLASSES, n_boxes=(4, 6, 6, 6, 4, 4)): 243 | super(SSD300, self).__init__() 244 | 245 | self.n_classes = n_classes 246 | self.n_boxes = n_boxes 247 | 248 | self.base = VGG16() 249 | self.extra = ExtraLayers() 250 | self.classifiers = Classifiers(n_classes, n_boxes) 251 | 252 | # L2 norm scaler for conv4_out. Updated thru backprop 253 | self.rescale_factors = nn.Parameter(torch.FloatTensor(1, 512, 1, 1)) # there are 512 channels in conv4_3_feats 254 | nn.init.constant_(self.rescale_factors, 20) 255 | 256 | # default boxes 257 | self.priors = get_default_boxes() 258 | 259 | def forward(self, image): 260 | """ Forward propagation. 261 | 262 | Input: images forming tensor of dimensions (N, 3, 300, 300) 263 | 264 | Returns: 8732 locations and class scores for each image. 
265 | """ 266 | # Run VGG16 267 | conv4_out, conv7_out = self.base(image) # (N, 512, 38, 38), (N, 1024, 19, 19) 268 | conv4_out = self.L2Norm(conv4_out) 269 | 270 | # Run ExtraLayers 271 | conv8_out, conv9_out, conv10_out, conv11_out = self.extra(conv7_out) # (N, 512, 10, 10), (N, 256, 5, 5), (N, 256, 3, 3), (N, 256, 1, 1) 272 | 273 | # setup prediction inputs 274 | features = (conv4_out, conv7_out, conv8_out, conv9_out, 275 | conv10_out, conv11_out) 276 | 277 | # Run Classifiers 278 | output_boxes, output_scores = self.classifiers(features) 279 | 280 | return output_boxes, output_scores 281 | 282 | def L2Norm(self, out, eps=1e-10): 283 | """ Rescale the outputs of conv4. The rescaling factor is a parameter 284 | that gets updated through backprop. 285 | """ 286 | norm = out.pow(2).sum(dim=1, keepdim=True).sqrt() + eps 287 | out = torch.div(out, norm) 288 | return out * self.rescale_factors 289 | 290 | def detect_objects(self, predicted_locs, predicted_scores, min_score, max_overlap, top_k): 291 | """ 292 | (This function as defined in https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection)) 293 | 294 | Decipher the 8732 locations and class scores (output of ths SSD300) to detect objects. 295 | For each class, perform Non-Maximum Suppression (NMS) on boxes that are above a minimum threshold. 296 | :param predicted_locs: predicted locations/boxes w.r.t the 8732 prior boxes, a tensor of dimensions (N, 8732, 4) 297 | :param predicted_scores: class scores for each of the encoded locations/boxes, a tensor of dimensions (N, 8732, n_classes) 298 | :param min_score: minimum threshold for a box to be considered a match for a certain class 299 | :param max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via NMS 300 | :param top_k: if there are a lot of resulting detection across all classes, keep only the top 'k' 301 | :return: detections (boxes, labels, and scores), lists of length batch_size 302 | """ 303 | batch_size = predicted_locs.size(0) 304 | n_priors = self.priors.size(0) 305 | predicted_scores = F.softmax(predicted_scores, dim=2) # (N, 8732, n_classes) 306 | 307 | # Lists to store final predicted boxes, labels, and scores for all images 308 | all_images_boxes = list() 309 | all_images_labels = list() 310 | all_images_scores = list() 311 | 312 | assert n_priors == predicted_locs.size(1) == predicted_scores.size(1) 313 | 314 | for i in range(batch_size): 315 | # Decode object coordinates from the form we regressed predicted boxes to 316 | decoded_locs = cxcy_to_xy( 317 | gcxgcy_to_cxcy(predicted_locs[i], self.priors)) # (8732, 4), these are fractional pt. 
coordinates 318 | 319 | # Lists to store boxes and scores for this image 320 | image_boxes = list() 321 | image_labels = list() 322 | image_scores = list() 323 | 324 | max_scores, best_label = predicted_scores[i].max(dim=1) # (8732) 325 | 326 | # Check for each class 327 | for c in range(1, self.n_classes): 328 | # Keep only predicted boxes and scores where scores for this class are above the minimum score 329 | class_scores = predicted_scores[i][:, c] # (8732) 330 | score_above_min_score = class_scores > min_score # torch.uint8 (byte) tensor, for indexing 331 | n_above_min_score = score_above_min_score.sum().item() 332 | if n_above_min_score == 0: 333 | continue 334 | class_scores = class_scores[score_above_min_score] # (n_qualified), n_min_score <= 8732 335 | class_decoded_locs = decoded_locs[score_above_min_score] # (n_qualified, 4) 336 | 337 | # Sort predicted boxes and scores by scores 338 | class_scores, sort_ind = class_scores.sort(dim=0, descending=True) # (n_qualified), (n_min_score) 339 | class_decoded_locs = class_decoded_locs[sort_ind] # (n_min_score, 4) 340 | 341 | # Find the overlap between predicted boxes 342 | overlap = find_jaccard_overlap(class_decoded_locs, class_decoded_locs) # (n_qualified, n_min_score) 343 | 344 | # Non-Maximum Suppression (NMS) 345 | 346 | # A torch.uint8 (byte) tensor to keep track of which predicted boxes to suppress 347 | # 1 implies suppress, 0 implies don't suppress 348 | suppress = torch.zeros((n_above_min_score), dtype=torch.uint8).to(DEVICE) # (n_qualified) 349 | 350 | # Consider each box in order of decreasing scores 351 | for box in range(class_decoded_locs.size(0)): 352 | # If this box is already marked for suppression 353 | if suppress[box] == 1: 354 | continue 355 | 356 | # Suppress boxes whose overlaps (with this box) are greater than maximum overlap 357 | # Find such boxes and update suppress indices 358 | suppress = torch.max(suppress, overlap[box] > max_overlap) 359 | # The max operation retains previously suppressed boxes, like an 'OR' operation 360 | 361 | # Don't suppress this box, even though it has an overlap of 1 with itself 362 | suppress[box] = 0 363 | 364 | # Store only unsuppressed boxes for this class 365 | image_boxes.append(class_decoded_locs[1 - suppress]) 366 | image_labels.append(torch.LongTensor((1 - suppress).sum().item() * [c]).to(DEVICE)) 367 | image_scores.append(class_scores[1 - suppress]) 368 | 369 | # If no object in any class is found, store a placeholder for 'background' 370 | if len(image_boxes) == 0: 371 | image_boxes.append(torch.FloatTensor([[0., 0., 1., 1.]]).to(DEVICE)) 372 | image_labels.append(torch.LongTensor([0]).to(DEVICE)) 373 | image_scores.append(torch.FloatTensor([0.]).to(DEVICE)) 374 | 375 | # Concatenate into single tensors 376 | image_boxes = torch.cat(image_boxes, dim=0) # (n_objects, 4) 377 | image_labels = torch.cat(image_labels, dim=0) # (n_objects) 378 | image_scores = torch.cat(image_scores, dim=0) # (n_objects) 379 | n_objects = image_scores.size(0) 380 | 381 | # Keep only the top k objects 382 | if n_objects > top_k: 383 | image_scores, sort_ind = image_scores.sort(dim=0, descending=True) 384 | image_scores = image_scores[:top_k] # (top_k) 385 | image_boxes = image_boxes[sort_ind][:top_k] # (top_k, 4) 386 | image_labels = image_labels[sort_ind][:top_k] # (top_k) 387 | 388 | # Append to lists that store predicted boxes and scores for all images 389 | all_images_boxes.append(image_boxes) 390 | all_images_labels.append(image_labels) 391 | all_images_scores.append(image_scores) 392 | 
393 | return all_images_boxes, all_images_labels, all_images_scores # lists of length batch_size 394 | 395 | 396 | class CoClassifiers(Classifiers): 397 | """ Modifed classifier convolutions for SSCoD due to extra features """ 398 | def __init__(self, n_classes, n_boxes): 399 | super(Classifiers, self).__init__() 400 | 401 | self.n_classes = n_classes 402 | self.n_boxes = n_boxes 403 | assert len(self.n_boxes) == 6 404 | 405 | self.box4 = nn.Conv2d(512+64, n_boxes[0] * 4, kernel_size=3, padding=1) 406 | self.box7 = nn.Conv2d(1024+64, n_boxes[1] * 4, kernel_size=3, padding=1) 407 | self.box8 = nn.Conv2d(512, n_boxes[2] * 4, kernel_size=3, padding=1) 408 | self.box9 = nn.Conv2d(256, n_boxes[3] * 4, kernel_size=3, padding=1) 409 | self.box10 = nn.Conv2d(256, n_boxes[4] * 4, kernel_size=3, padding=1) 410 | self.box11 = nn.Conv2d(256, n_boxes[5] * 4, kernel_size=3, padding=1) 411 | 412 | self.class4 = nn.Conv2d(512+64, n_boxes[0] * n_classes, kernel_size=3, padding=1) 413 | self.class7 = nn.Conv2d(1024+64, n_boxes[1] * n_classes, kernel_size=3, padding=1) 414 | self.class8 = nn.Conv2d(512, n_boxes[2] * n_classes, kernel_size=3, padding=1) 415 | self.class9 = nn.Conv2d(256, n_boxes[3] * n_classes, kernel_size=3, padding=1) 416 | self.class10 = nn.Conv2d(256, n_boxes[4] * n_classes, kernel_size=3, padding=1) 417 | self.class11 = nn.Conv2d(256, n_boxes[5] * n_classes, kernel_size=3, padding=1) 418 | 419 | for layer in self.children(): 420 | if isinstance(layer, nn.Conv2d): 421 | nn.init.xavier_uniform_(layer.weight) 422 | nn.init.constant_(layer.bias, 0.) 423 | 424 | 425 | class SSCoD(SSD300): 426 | def __init__(self, n_classes=VOC_NUM_CLASSES, n_boxes=(4, 6, 6, 6, 4, 4)): 427 | super(SSCoD, self).__init__(n_classes, n_boxes) 428 | 429 | self.classifiers = CoClassifiers(n_classes, n_boxes) 430 | 431 | self.spatialcooc4 = SpatialCoocLayer(in_channels=512, out_channels=8, local_kernel=5) 432 | self.spatialcooc7 = SpatialCoocLayer(in_channels=1024, out_channels=8, local_kernel=5) 433 | self.spatialcooc8 = SpatialCoocLayer(in_channels=512, out_channels=8, local_kernel=5) 434 | 435 | def forward(self, image): 436 | """ Forward propagation. 437 | 438 | Input: images forming tensor of dimensions (N, 3, 300, 300) 439 | 440 | Returns: 8732 locations and class scores for each image. 
441 | """ 442 | # Run VGG16 443 | conv4_out, conv7_out = self.base(image) # (N, 512, 38, 38), (N, 1024, 19, 19) 444 | conv4_out = self.L2Norm(conv4_out) 445 | 446 | # Run ExtraLayers 447 | conv8_out, conv9_out, conv10_out, conv11_out = self.extra(conv7_out) # (N, 512, 10, 10), (N, 256, 5, 5), (N, 256, 3, 3), (N, 256, 1, 1) 448 | 449 | # run spatial co-occurrence layers and stack with activations 450 | spatial_corr4 = self.spatialcooc4(conv4_out) 451 | spatial_corr7 = self.spatialcooc7(conv7_out) 452 | spatial_corr8 = self.spatialcooc8(conv8_out) 453 | conv4_out = torch.cat([conv4_out, spatial_corr4], dim=1) 454 | conv7_out = torch.cat([conv7_out, spatial_corr7], dim=1) 455 | conv8_out = torch.cat([conv8_out, spatial_corr8], dim=1) 456 | 457 | # setup prediction inputs 458 | features = (conv4_out, conv7_out, conv8_out, conv9_out, 459 | conv10_out, conv11_out) 460 | 461 | # Run Classifiers 462 | output_boxes, output_scores = self.classifiers(features) 463 | 464 | return output_boxes, output_scores 465 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, args, options='', timestamp=True): 13 | # parse default and custom cli options 14 | for opt in options: 15 | args.add_argument(*opt.flags, default=None, type=opt.type) 16 | args = args.parse_args() 17 | 18 | if args.device: 19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 20 | if args.resume: 21 | self.resume = Path(args.resume) 22 | self.cfg_fname = self.resume.parent / 'config.json' 23 | else: 24 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 25 | assert args.config is not None, msg_no_cfg 26 | self.resume = None 27 | self.cfg_fname = Path(args.config) 28 | 29 | # load config file and apply custom cli options 30 | config = read_json(self.cfg_fname) 31 | self.__config = _update_config(config, options, args) 32 | 33 | # set save_dir where trained model and log will be saved. 34 | save_dir = Path(self.config['trainer']['save_dir']) 35 | timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else '' 36 | 37 | exper_name = self.config['name'] 38 | self.__save_dir = save_dir / 'models' / exper_name / timestamp 39 | self.__log_dir = save_dir / 'log' / exper_name / timestamp 40 | 41 | self.save_dir.mkdir(parents=True, exist_ok=True) 42 | self.log_dir.mkdir(parents=True, exist_ok=True) 43 | 44 | # save updated config file to the checkpoint dir 45 | write_json(self.config, self.save_dir / 'config.json') 46 | 47 | # configure logging module 48 | setup_logging(self.log_dir) 49 | self.log_levels = { 50 | 0: logging.WARNING, 51 | 1: logging.INFO, 52 | 2: logging.DEBUG 53 | } 54 | 55 | def initialize(self, name, module, *args, **kwargs): 56 | """ 57 | finds a function handle with the name given as 'type' in config, and returns the 58 | instance initialized with corresponding keyword args given as 'args'. 
59 | """
60 | module_cfg = self[name]
61 | return getattr(module, module_cfg['type'])(*args, **module_cfg['args'], **kwargs)
62 |
63 | def __getitem__(self, name):
64 | return self.config[name]
65 |
66 | def get_logger(self, name, verbosity=2):
67 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
68 | assert verbosity in self.log_levels, msg_verbosity
69 | logger = logging.getLogger(name)
70 | logger.setLevel(self.log_levels[verbosity])
71 | return logger
72 |
73 | # setting read-only attributes
74 | @property
75 | def config(self):
76 | return self.__config
77 |
78 | @property
79 | def save_dir(self):
80 | return self.__save_dir
81 |
82 | @property
83 | def log_dir(self):
84 | return self.__log_dir
85 |
86 | # helper functions used to update config dict with custom cli options
87 | def _update_config(config, options, args):
88 | for opt in options:
89 | value = getattr(args, _get_opt_name(opt.flags))
90 | if value is not None:
91 | _set_by_path(config, opt.target, value)
92 | return config
93 |
94 | def _get_opt_name(flags):
95 | for flg in flags:
96 | if flg.startswith('--'):
97 | return flg.replace('--', '')
98 | return flags[0].replace('--', '')
99 |
100 | def _set_by_path(tree, keys, value):
101 | """Set a value in a nested object in tree by sequence of keys."""
102 | _get_by_path(tree, keys[:-1])[keys[-1]] = value
103 |
104 | def _get_by_path(tree, keys):
105 | """Access a nested object in tree by sequence of keys."""
106 | return reduce(getitem, keys, tree)
107 | -------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from tqdm import tqdm
4 | import data_loader.data_loaders as module_data
5 | import model.loss as module_loss
6 | import model.metric as module_metric
7 | import model.model as module_arch
8 | from parse_config import ConfigParser
9 |
10 |
11 | def main(config, resume):
12 | logger = config.get_logger('test')
13 |
14 | data_loader = config.initialize('data_loader', module_data, mode='test')
15 |
16 | # build model architecture
17 | model = config.initialize('arch', module_arch)
18 | logger.info(model)
19 |
20 | # get function handles of loss and metrics
21 | loss_fn = getattr(module_loss, config['loss'])
22 | metric_fns = [getattr(module_metric, met) for met in config['metrics']]
23 |
24 | logger.info('Loading checkpoint: {} ...'.format(resume))
25 | checkpoint = torch.load(resume)
26 | state_dict = checkpoint['state_dict']
27 |
28 | model.load_state_dict(state_dict)
29 |
30 | # prepare model for testing
31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32 | model = model.to(device)
33 | model.eval()
34 |
35 | total_loss = 0.0
36 | total_metrics = torch.zeros(len(metric_fns)) # note: never updated below, so the logged metrics stay at 0
37 |
38 | with torch.no_grad():
39 | for batch_idx, (data, boxes, labels, _) in enumerate(data_loader):
40 |
41 | data = data.to(device)
42 | boxes = [b.to(device) for b in boxes]
43 | labels = [l.to(device) for l in labels]
44 |
45 | output_boxes, output_scores = model(data)
46 |
47 | mbloss = loss_fn(threshold=0.5, neg_pos_ratio=3,
48 | alpha=1., device=device)
49 | loss = mbloss(output_boxes, output_scores, boxes, labels)
50 |
51 | # computing loss, metrics on test set
52 | batch_size = data.shape[0]
53 | total_loss += loss.item() * batch_size
54 |
55 | n_samples = len(data_loader.sampler)
56 | log = {'loss': total_loss / n_samples}
57 |
log.update({ 58 | met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) 59 | }) 60 | logger.info(log) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='Evaluator') 65 | 66 | parser.add_argument('-r', '--resume', default=None, type=str, 67 | help='path to latest checkpoint (default: None)') 68 | parser.add_argument('-d', '--device', default=None, type=str, 69 | help='indices of GPUs to enable (default: all)') 70 | 71 | args = parser.parse_args() 72 | config = ConfigParser(args) 73 | main(config, args.resume) 74 | -------------------------------------------------------------------------------- /test_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from torchvision import transforms 3 | from PIL import Image, ImageDraw, ImageFont 4 | from constants import * 5 | from utils import * 6 | from model.model import * 7 | 8 | 9 | def detect(original_image, model, min_score, max_overlap, top_k, suppress=None): 10 | """ Visualize model results on an original VOC image. 11 | 12 | Inputs: 13 | original_image: image, a PIL Image 14 | min_score: minimum threshold for a detected box to be considered a match for a certain class 15 | max_overlap: maximum overlap two boxes can have so that the one with the lower score is not suppressed via Non-Maximum Suppression (NMS) 16 | top_k: if there are a lot of resulting detection across all classes, keep only the top 'k' 17 | suppress: classes that you know for sure cannot be in the image or you do not want in the image, a list 18 | """ 19 | # Transforms 20 | resize = transforms.Resize((300, 300)) 21 | to_tensor = transforms.ToTensor() 22 | normalize = transforms.Normalize(mean=IMAGENET_MEAN, 23 | std=IMAGENET_STD) 24 | image = normalize(to_tensor(resize(original_image))) 25 | 26 | # Move to default device 27 | image = image.to(DEVICE) 28 | 29 | # Forward prop. 30 | predicted_locs, predicted_scores = model(image.unsqueeze(0)) 31 | 32 | # Detect objects in SSD output 33 | det_boxes, det_labels, det_scores = model.detect_objects(predicted_locs, predicted_scores, min_score=min_score, 34 | max_overlap=max_overlap, top_k=top_k) 35 | 36 | # Move detections to the CPU 37 | det_boxes = det_boxes[0].to('cpu') 38 | 39 | # Transform to original image dimensions 40 | original_dims = torch.FloatTensor([original_image.width, original_image.height, original_image.width, original_image.height]).unsqueeze(0) 41 | det_boxes = det_boxes * original_dims 42 | 43 | # Decode class integer labels 44 | det_labels = [VOC_DECODING[l] for l in det_labels[0].to('cpu').tolist()] 45 | 46 | # If no objects found, the detected labels will be set to ['0.'], i.e. 
['background'] in SSD300.detect_objects() in model.py
47 | if det_labels == ['__background__']:
48 | # Just return original image
49 | return original_image
50 |
51 | distinct_colors = ['#e6194b', '#3cb44b', '#ffe119', '#0082c8', '#f58231', '#911eb4', '#46f0f0', '#f032e6',
52 | '#d2f53c', '#fabebe', '#008080', '#000080', '#aa6e28', '#fffac8', '#800000', '#aaffc3', '#808000',
53 | '#ffd8b1', '#e6beff', '#808080', '#FFFFFF']
54 | label_color_map = {k: distinct_colors[i] for i, k in enumerate(VOC_ENCODING.keys())}
55 |
56 | # Annotate
57 | annotated_image = original_image
58 | draw = ImageDraw.Draw(annotated_image)
59 | font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 15)
60 |
61 | # Suppress specific classes, if needed
62 | for i in range(det_boxes.size(0)):
63 | if suppress is not None:
64 | if det_labels[i] in suppress:
65 | continue
66 |
67 | # Boxes
68 | box_location = det_boxes[i].tolist()
69 | draw.rectangle(xy=box_location, outline=label_color_map[det_labels[i]])
70 | draw.rectangle(xy=[l + 1. for l in box_location], outline=label_color_map[
71 | det_labels[i]])
72 |
73 | # Text
74 | text_size = font.getsize(det_labels[i].upper())
75 | text_location = [box_location[0] + 2., box_location[1] - text_size[1]]
76 | textbox_location = [box_location[0], box_location[1] - text_size[1], box_location[0] + text_size[0] + 4.,
77 | box_location[1]]
78 | draw.rectangle(xy=textbox_location, fill=label_color_map[det_labels[i]])
79 | draw.text(xy=text_location, text=det_labels[i].upper(), fill='white',
80 | font=font)
81 | del draw
82 |
83 | return annotated_image
84 |
85 |
86 | def main(args):
87 |
88 | # fall back to a sample image if no path is given on the command line
89 | image = args.image or './data/VOCdevkit/VOC2007/JPEGImages/000012.jpg'
90 | raw_image = Image.open(image, mode='r')
91 | raw_image = raw_image.convert('RGB')
92 |
93 | # Load model checkpoint, falling back to a sample checkpoint if none is given
94 | checkpoint = args.model or './saved/models/VOC_SSD/checkpoint-epoch98.pth'
95 | checkpoint = torch.load(checkpoint)
96 | state_dict = checkpoint['state_dict']
97 |
98 | mod = SSD300(n_classes=21)
99 | mod.load_state_dict(state_dict)
100 | mod = mod.to(DEVICE)
101 | mod.eval()
102 |
103 | detect(raw_image, mod, min_score=0.2, max_overlap=0.5, top_k=200).show()
104 |
105 |
106 |
107 | if __name__ == '__main__':
108 | parser = argparse.ArgumentParser(description='Draw bbox')
109 | parser.add_argument('-m', '--model', default=None, type=str,
110 | help='path to checkpoint to use')
111 | parser.add_argument('-i', '--image', default=None, type=str,
112 | help='path to image to detect')
113 | args = parser.parse_args()
114 | main(args)
115 | -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import torch
4 | import data_loader.data_loaders as module_data
5 | import model.loss as module_loss
6 | import model.metric as module_metric
7 | import model.model as module_arch
8 | from parse_config import ConfigParser
9 | from trainer import Trainer
10 |
11 |
12 | def main(config):
13 | logger = config.get_logger('train')
14 |
15 | # setup data_loader instances
16 | data_loader = config.initialize('data_loader', module_data, mode='train')
17 | valid_data_loader = config.initialize('data_loader', module_data, mode='valid')
18 |
19 | # build model architecture, then print to console
20 | model = config.initialize('arch', module_arch)
21 | logger.info(model)
22 |
23 | # get function handles of loss and metrics
24
| loss = getattr(module_loss, config['loss']) 25 | metrics = [getattr(module_metric, met) for met in config['metrics']] 26 | 27 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 28 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 29 | optimizer = config.initialize('optimizer', torch.optim, trainable_params) 30 | 31 | lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer) 32 | 33 | trainer = Trainer(model, loss, metrics, optimizer, 34 | config=config, 35 | data_loader=data_loader, 36 | valid_data_loader=valid_data_loader, 37 | lr_scheduler=lr_scheduler) 38 | 39 | trainer.train() 40 | 41 | 42 | if __name__ == '__main__': 43 | args = argparse.ArgumentParser(description='Trainer') 44 | args.add_argument('-c', '--config', default=None, type=str, 45 | help='config file path (default: None)') 46 | args.add_argument('-r', '--resume', default=None, type=str, 47 | help='path to latest checkpoint (default: None)') 48 | args.add_argument('-d', '--device', default=None, type=str, 49 | help='indices of GPUs to enable (default: all)') 50 | 51 | # custom cli options to modify configuration from default values given in json file. 52 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 53 | options = [ 54 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')), 55 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')) 56 | ] 57 | config = ConfigParser(args, options) 58 | main(config) 59 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from model.metric import meanAP 6 | 7 | from constants import DEVICE 8 | 9 | class Trainer(BaseTrainer): 10 | """ 11 | Trainer class 12 | 13 | Note: 14 | Inherited from BaseTrainer. 15 | """ 16 | def __init__(self, model, loss, metrics, optimizer, config, 17 | data_loader, valid_data_loader=None, lr_scheduler=None): 18 | super(Trainer, self).__init__(model, loss, metrics, optimizer, config) 19 | self.config = config 20 | self.data_loader = data_loader 21 | self.valid_data_loader = valid_data_loader 22 | self.do_validation = self.valid_data_loader is not None 23 | self.lr_scheduler = lr_scheduler 24 | self.log_step = int(np.sqrt(data_loader.batch_size)) 25 | self.multiboxloss = loss(threshold=0.5, neg_pos_ratio=3, 26 | alpha=1., device=DEVICE) 27 | 28 | def _train_epoch(self, epoch): 29 | """ 30 | Training logic for an epoch 31 | 32 | :param epoch: Current training epoch. 33 | :return: A log that contains all information you want to save. 34 | 35 | Note: 36 | If you have additional information to record, for example: 37 | > additional_log = {"x": x, "y": y} 38 | merge it with log before return. i.e. 39 | > log = {**log, **additional_log} 40 | > return log 41 | 42 | The metrics in log must have the key 'metrics'. 
43 | """ 44 | self.model.train() 45 | 46 | total_loss = 0 47 | for batch_idx, (data, boxes, labels, _) in enumerate(self.data_loader): 48 | 49 | data = data.to(DEVICE) 50 | boxes = [b.to(DEVICE) for b in boxes] 51 | labels = [l.to(DEVICE) for l in labels] 52 | 53 | self.optimizer.zero_grad() 54 | output_boxes, output_scores = self.model(data) 55 | 56 | loss = self.multiboxloss(output_boxes, output_scores, boxes, labels) 57 | loss.backward() 58 | 59 | self.optimizer.step() 60 | 61 | self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx) 62 | self.writer.add_scalar('loss', loss.item()) 63 | total_loss += loss.item() 64 | 65 | if batch_idx % self.log_step == 0: 66 | self.logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( 67 | epoch, 68 | batch_idx * self.data_loader.batch_size, 69 | self.data_loader.n_samples, 70 | 100.0 * batch_idx / len(self.data_loader), 71 | loss.item())) 72 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 73 | 74 | log = { 75 | 'loss': total_loss / len(self.data_loader), 76 | } 77 | 78 | if self.do_validation: 79 | val_log = self._valid_epoch(epoch) 80 | # metric_log = self._valid_metric(epoch) 81 | log = {**log, **val_log} 82 | # log = {**log, **val_log, **metric_log} 83 | 84 | if self.lr_scheduler is not None: 85 | self.lr_scheduler.step() 86 | 87 | return log 88 | 89 | def _valid_epoch(self, epoch): 90 | """ 91 | Validate after training an epoch 92 | 93 | :return: A log that contains information about validation 94 | 95 | """ 96 | self.model.eval() 97 | total_val_loss = 0 98 | with torch.no_grad(): 99 | for batch_idx, (data, boxes, labels, _) in enumerate(self.valid_data_loader): 100 | 101 | data = data.to(DEVICE) 102 | boxes = [b.to(DEVICE) for b in boxes] 103 | labels = [l.to(DEVICE) for l in labels] 104 | 105 | output_boxes, output_scores = self.model(data) 106 | 107 | loss = self.multiboxloss(output_boxes, output_scores, boxes, labels) 108 | 109 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 110 | self.writer.add_scalar('loss', loss.item()) 111 | total_val_loss += loss.item() 112 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 113 | 114 | if batch_idx % (10 * self.log_step) == 0: 115 | self.logger.debug('Val Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( 116 | epoch, 117 | batch_idx * self.data_loader.batch_size, 118 | self.data_loader.n_samples, 119 | 100.0 * batch_idx / len(self.data_loader), 120 | loss.item())) 121 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 122 | 123 | 124 | # add histogram of model parameters to the tensorboard 125 | for name, p in self.model.named_parameters(): 126 | self.writer.add_histogram(name, p, bins='auto') 127 | 128 | return { 129 | 'val_loss': total_val_loss / len(self.valid_data_loader), 130 | } 131 | 132 | def _valid_metric(self, epoch): 133 | """ Compute mAP metric over validation set after certain number of 134 | epochs. Because the metric is computed over the whole dataset, 135 | I separated it from the validation loss method to lower 136 | training time. 
137 | 138 | Return: A log with information about metrics 139 | """ 140 | print('Computing metrics:') 141 | 142 | self.model.eval() 143 | 144 | # must compute mAP over entire dataset 145 | all_boxes = list() 146 | all_labels = list() 147 | all_scores = list() 148 | all_true_boxes = list() 149 | all_true_labels = list() 150 | all_difficulties = list() 151 | 152 | with torch.no_grad(): 153 | for batch_idx, (data, boxes, labels, difficulties) in enumerate(self.valid_data_loader): 154 | 155 | 156 | data = data.to(DEVICE) 157 | boxes = [b.to(DEVICE) for b in boxes] 158 | labels = [l.to(DEVICE) for l in labels] 159 | difficulties = [d.to(DEVICE) for d in difficulties] 160 | 161 | output_boxes, output_scores = self.model(data) 162 | 163 | batch_boxes, batch_labels, batch_scores = self.model.detect_objects(output_boxes, output_scores, 164 | min_score=0.01, max_overlap=0.45, 165 | top_k=200) 166 | 167 | all_boxes.extend(batch_boxes) 168 | all_labels.extend(batch_labels) 169 | all_scores.extend(batch_scores) 170 | all_true_boxes.extend(boxes) 171 | all_true_labels.extend(labels) 172 | all_difficulties.extend(difficulties) 173 | 174 | if batch_idx % (10 * self.log_step) == 0: 175 | self.logger.debug('Val Epoch: {} [{}/{} ({:.0f}%)] Append'.format( 176 | epoch, 177 | batch_idx * self.data_loader.batch_size, 178 | self.data_loader.n_samples, 179 | 100.0 * batch_idx / len(self.data_loader))) 180 | 181 | # Calculate mAP 182 | class_APs, mAP = meanAP(all_boxes, all_labels, all_scores, all_true_boxes, all_true_labels, all_difficulties) 183 | 184 | return { 185 | 'val_mAP': mAP, 186 | 'val_class_AP': class_APs 187 | } 188 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .joint_transforms import * 3 | from .model_utils import * 4 | -------------------------------------------------------------------------------- /utils/joint_transforms.py: -------------------------------------------------------------------------------- 1 | """ Modified from https://github.com/amdegroot/ssd.pytorch for implementing 2 | joint transforms (i.e. co-operations on both images and corresponding 3 | bounding boxes and labels). Modifications are made to suit the format 4 | of our input data. 5 | """ 6 | 7 | import torch 8 | from torchvision import transforms 9 | import cv2 10 | import numpy as np 11 | import types 12 | from numpy import random 13 | 14 | 15 | def intersect(box_a, box_b): 16 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 17 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 18 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 19 | return inter[:, 0] * inter[:, 1] 20 | 21 | 22 | def jaccard_numpy(box_a, box_b): 23 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 24 | is simply the intersection over union of two boxes. 25 | E.g.: 26 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 27 | Args: 28 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 29 | box_b: Single bounding box, Shape: [4] 30 | Return: 31 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 32 | """ 33 | inter = intersect(box_a, box_b) 34 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 35 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 36 | area_b = ((box_b[2]-box_b[0]) * 37 | (box_b[3]-box_b[1])) # [A,B] 38 | 39 | union = area_a + area_b - inter 40 | return inter / union # [A,B] 41 | 42 | 43 | class Compose(object): 44 | """ Composes several augmentations together. 
45 | 46 | Args: 47 | transforms (List[Transform]): list of transforms to compose. 48 | 49 | Example: 50 | >>> augmentations.Compose([ 51 | >>> transforms.CenterCrop(10), 52 | >>> transforms.ToTensor(), 53 | >>> ]) 54 | """ 55 | 56 | def __init__(self, transforms): 57 | self.transforms = transforms 58 | 59 | def __call__(self, img, boxes=None, labels=None): 60 | for t in self.transforms: 61 | img, boxes, labels = t(img, boxes, labels) 62 | return img, boxes, labels 63 | 64 | 65 | class Lambda(object): 66 | """Applies a lambda as a transform.""" 67 | def __init__(self, lambd): 68 | assert isinstance(lambd, types.LambdaType) 69 | self.lambd = lambd 70 | 71 | def __call__(self, img, boxes=None, labels=None): 72 | return self.lambd(img, boxes, labels) 73 | 74 | 75 | class ConvertFromPIL(object): 76 | """ PIL image to numpy array """ 77 | def __call__(self, image, boxes=None, labels=None): 78 | return np.array(image).astype(np.float32), boxes, labels 79 | 80 | 81 | class SubtractMeans(object): 82 | def __init__(self, mean): 83 | self.mean = np.array(mean, dtype=np.float32) 84 | 85 | def __call__(self, image, boxes=None, labels=None): 86 | image = image.astype(np.float32) 87 | image -= self.mean 88 | return image, boxes, labels 89 | 90 | 91 | class Normalize(object): 92 | def __init__(self, mean, std): 93 | self.mean = np.array(mean, dtype=np.float32) 94 | self.std = np.array(std, dtype=np.float32) 95 | 96 | def __call__(self, image, boxes=None, labels=None): 97 | image = image.astype(np.float32) 98 | image /= 256. 99 | image -= self.mean 100 | image /= self.std 101 | return image, boxes, labels 102 | 103 | 104 | class ToAbsoluteCoords(object): 105 | """ Change bbox coordinates from percent to pixels """ 106 | def __call__(self, image, boxes=None, labels=None): 107 | height, width, channels = image.shape 108 | boxes[:, 0] *= width 109 | boxes[:, 2] *= width 110 | boxes[:, 1] *= height 111 | boxes[:, 3] *= height 112 | 113 | return image, boxes, labels 114 | 115 | 116 | class ToPercentCoords(object): 117 | """ Change bbox coordinates from pixels to percents """ 118 | def __call__(self, image, boxes=None, labels=None): 119 | height, width, channels = image.shape 120 | 121 | boxes[:, 0] /= width 122 | boxes[:, 2] /= width 123 | boxes[:, 1] /= height 124 | boxes[:, 3] /= height 125 | 126 | return image, boxes, labels 127 | 128 | 129 | class Resize(object): 130 | def __init__(self, size=300): 131 | self.size = size 132 | 133 | def __call__(self, image, boxes=None, labels=None): 134 | image = cv2.resize(image, (self.size, 135 | self.size)) 136 | return image, boxes, labels 137 | 138 | 139 | class RandomSaturation(object): 140 | """ Randomly apply saturation to image in numpy array form """ 141 | def __init__(self, lower=0.5, upper=1.5): 142 | self.lower = lower 143 | self.upper = upper 144 | assert self.upper >= self.lower, "contrast upper must be >= lower." 145 | assert self.lower >= 0, "contrast lower must be non-negative." 
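# Note: expects an HSV image (see PhotometricDistort below); channel 1 is the saturation channel that gets scaled.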
146 | 147 | def __call__(self, image, boxes=None, labels=None): 148 | if random.randint(2): 149 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 150 | 151 | return image, boxes, labels 152 | 153 | 154 | class RandomHue(object): 155 | """ Randomly apply hue change to image in numpy array form """ 156 | def __init__(self, delta=18.0): 157 | assert delta >= 0.0 and delta <= 360.0 158 | self.delta = delta 159 | 160 | def __call__(self, image, boxes=None, labels=None): 161 | if random.randint(2): 162 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 163 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 164 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 165 | return image, boxes, labels 166 | 167 | 168 | class RandomLightingNoise(object): 169 | def __init__(self): 170 | self.perms = ((0, 1, 2), (0, 2, 1), 171 | (1, 0, 2), (1, 2, 0), 172 | (2, 0, 1), (2, 1, 0)) 173 | 174 | def __call__(self, image, boxes=None, labels=None): 175 | if random.randint(2): 176 | swap = self.perms[random.randint(len(self.perms))] 177 | shuffle = SwapChannels(swap) # shuffle channels 178 | image = shuffle(image) 179 | return image, boxes, labels 180 | 181 | 182 | class ConvertColor(object): 183 | def __init__(self, current='RGB', transform='HSV'): 184 | self.transform = transform 185 | self.current = current 186 | 187 | def __call__(self, image, boxes=None, labels=None): 188 | if self.current == 'RGB' and self.transform == 'HSV': 189 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 190 | elif self.current == 'HSV' and self.transform == 'RGB': 191 | image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB) 192 | else: 193 | raise NotImplementedError 194 | return image, boxes, labels 195 | 196 | 197 | class RandomContrast(object): 198 | def __init__(self, lower=0.5, upper=1.5): 199 | self.lower = lower 200 | self.upper = upper 201 | assert self.upper >= self.lower, "contrast upper must be >= lower." 202 | assert self.lower >= 0, "contrast lower must be non-negative." 203 | 204 | # expects float image 205 | def __call__(self, image, boxes=None, labels=None): 206 | if random.randint(2): 207 | alpha = random.uniform(self.lower, self.upper) 208 | image *= alpha 209 | return image, boxes, labels 210 | 211 | 212 | class RandomBrightness(object): 213 | def __init__(self, delta=32): 214 | assert delta >= 0.0 215 | assert delta <= 255.0 216 | self.delta = delta 217 | 218 | def __call__(self, image, boxes=None, labels=None): 219 | if random.randint(2): 220 | delta = random.uniform(-self.delta, self.delta) 221 | image += delta 222 | return image, boxes, labels 223 | 224 | 225 | class ToCV2Image(object): 226 | def __call__(self, tensor, boxes=None, labels=None): 227 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 228 | 229 | 230 | class ToTensor(object): 231 | """ Convert numpy arrays to tensor. 
Permute to put channels in front """ 232 | def __call__(self, image, boxes=None, labels=None): 233 | image = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1) 234 | boxes = torch.FloatTensor(boxes) 235 | labels = torch.LongTensor(labels) 236 | return image, boxes, labels 237 | 238 | 239 | class RandomSampleCrop(object): 240 | """Crop 241 | Arguments: 242 | img (Image): the image being input during training 243 | boxes (Tensor): the original bounding boxes in pt form 244 | labels (Tensor): the class labels for each bbox 245 | mode (float tuple): the min and max jaccard overlaps 246 | Return: 247 | (img, boxes, classes) 248 | img (Image): the cropped image 249 | boxes (Tensor): the adjusted bounding boxes in pt form 250 | labels (Tensor): the class labels for each bbox 251 | """ 252 | def __init__(self): 253 | self.sample_options = ( 254 | # using entire original input image 255 | None, 256 | # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 257 | (0.1, None), 258 | (0.3, None), 259 | (0.7, None), 260 | (0.9, None), 261 | # randomly sample a patch 262 | (None, None), 263 | ) 264 | 265 | def __call__(self, image, boxes=None, labels=None): 266 | height, width, _ = image.shape 267 | while True: 268 | # randomly choose a mode 269 | mode = random.choice(self.sample_options) 270 | if mode is None: 271 | return image, boxes, labels 272 | 273 | min_iou, max_iou = mode 274 | if min_iou is None: 275 | min_iou = float('-inf') 276 | if max_iou is None: 277 | max_iou = float('inf') 278 | 279 | # max trails (50) 280 | for _ in range(50): 281 | current_image = image 282 | 283 | w = random.uniform(0.3 * width, width) 284 | h = random.uniform(0.3 * height, height) 285 | 286 | # aspect ratio constraint b/t .5 & 2 287 | if h / w < 0.5 or h / w > 2: 288 | continue 289 | 290 | left = random.uniform(width - w) 291 | top = random.uniform(height - h) 292 | 293 | # convert to integer rect x1,y1,x2,y2 294 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 295 | 296 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 297 | overlap = jaccard_numpy(boxes, rect) 298 | 299 | # is min and max overlap constraint satisfied? if not try again 300 | if overlap.min() < min_iou and max_iou < overlap.max(): 301 | continue 302 | 303 | # cut the crop from the image 304 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], 305 | :] 306 | 307 | # keep overlap with gt box IF center in sampled patch 308 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 309 | 310 | # mask in all gt boxes that above and to the left of centers 311 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 312 | 313 | # mask in all gt boxes that under and to the right of centers 314 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 315 | 316 | # mask in that both m1 and m2 are true 317 | mask = m1 * m2 318 | 319 | # have any valid boxes? 
try again if not 320 | if not mask.any(): 321 | continue 322 | 323 | # take only matching gt boxes 324 | current_boxes = boxes[mask, :].copy() 325 | 326 | # take only matching gt labels 327 | current_labels = labels[mask] 328 | 329 | # should we use the box left and top corner or the crop's 330 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], 331 | rect[:2]) 332 | # adjust to crop (by substracting crop's left,top) 333 | current_boxes[:, :2] -= rect[:2] 334 | 335 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], 336 | rect[2:]) 337 | # adjust to crop (by substracting crop's left,top) 338 | current_boxes[:, 2:] -= rect[:2] 339 | 340 | return current_image, current_boxes, current_labels 341 | 342 | 343 | class Expand(object): 344 | """ Randomly expand image size and place original image within. 345 | Note: bounding boxes must be in absolute coordinates 346 | """ 347 | def __init__(self, mean): 348 | self.mean = mean 349 | 350 | def __call__(self, image, boxes, labels): 351 | if random.randint(2): 352 | return image, boxes, labels 353 | 354 | height, width, depth = image.shape 355 | ratio = random.uniform(1, 4) 356 | left = random.uniform(0, width*ratio - width) 357 | top = random.uniform(0, height*ratio - height) 358 | 359 | expand_image = np.zeros( 360 | (int(height*ratio), int(width*ratio), depth), 361 | dtype=image.dtype) 362 | expand_image[:, :, :] = self.mean 363 | expand_image[int(top):int(top + height), 364 | int(left):int(left + width)] = image 365 | image = expand_image 366 | 367 | boxes = boxes.copy() 368 | boxes[:, :2] += (int(left), int(top)) 369 | boxes[:, 2:] += (int(left), int(top)) 370 | 371 | return image, boxes, labels 372 | 373 | 374 | class RandomMirror(object): 375 | def __call__(self, image, boxes, classes): 376 | _, width, _ = image.shape 377 | if random.randint(2): 378 | image = image[:, ::-1] 379 | boxes = boxes.copy() 380 | boxes[:, 0::2] = width - boxes[:, 2::-2] 381 | return image, boxes, classes 382 | 383 | 384 | class SwapChannels(object): 385 | """Transforms a tensorized image by swapping the channels in the order 386 | specified in the swap tuple. 
387 | Args: 388 | swaps (int triple): final order of channels 389 | eg: (2, 1, 0) 390 | """ 391 | 392 | def __init__(self, swaps): 393 | self.swaps = swaps 394 | 395 | def __call__(self, image): 396 | """ 397 | Args: 398 | image (Tensor): image tensor to be transformed 399 | Return: 400 | a tensor with channels swapped according to swap 401 | """ 402 | image = image[:, :, self.swaps] 403 | return image 404 | 405 | 406 | class PhotometricDistort(object): 407 | def __init__(self): 408 | self.pd = [ 409 | RandomContrast(), 410 | ConvertColor(transform='HSV'), 411 | RandomSaturation(), 412 | RandomHue(), 413 | ConvertColor(current='HSV', transform='RGB'), 414 | RandomContrast() 415 | ] 416 | self.rand_brightness = RandomBrightness() 417 | self.rand_light_noise = RandomLightingNoise() 418 | 419 | def __call__(self, image, boxes, labels): 420 | im = image.copy() 421 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 422 | if random.randint(2): 423 | distort = Compose(self.pd[:-1]) 424 | else: 425 | distort = Compose(self.pd[1:]) 426 | im, boxes, labels = distort(im, boxes, labels) 427 | return self.rand_light_noise(im, boxes, labels) 428 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | """ helper utilities adapted from 2 | https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection 3 | """ 4 | 5 | import torch 6 | from math import sqrt 7 | 8 | from constants import DEVICE 9 | 10 | 11 | def decimate(tensor, m): 12 | """ 13 | Decimate a tensor by a factor 'm', i.e. downsample by keeping every 'm'th value. 14 | This is used when we convert FC layers to equivalent Convolutional layers, BUT of a smaller size. 15 | :param tensor: tensor to be decimated 16 | :param m: list of decimation factors for each dimension of the tensor; None if not to be decimated along a dimension 17 | :return: decimated tensor 18 | """ 19 | assert tensor.dim() == len(m) 20 | for d in range(tensor.dim()): 21 | if m[d] is not None: 22 | tensor = tensor.index_select( 23 | dim=d, 24 | index=torch.arange(start=0, end=tensor.size(d), step=m[d]).long()) 25 | 26 | return tensor 27 | 28 | 29 | def get_default_boxes(): 30 | """ Generate the 8732 default boxes following description in paper. 31 | These boxes are in centered format (Cx,Cy,W,H). 
32 |
33 | Returns: tensor of dimensions (8732, 4)
34 | """
35 | layers = ('conv4', 'conv7', 'conv8', 'conv9', 'conv10', 'conv11')
36 | conv_dims = (38, 19, 10, 5, 3, 1)
37 | s_min, s_max = 0.2, 0.9
38 | m = len(layers)
39 |
40 | assert m == len(conv_dims)
41 |
42 | # modified scaling to match paper (different from github implementations)
43 | box_scales = [s_min + (s_max - s_min) / (m - 1) * k for k in range(m)]
44 |
45 | aspect_ratios = ([1., 2., 0.5],
46 | [1., 2., 3., 0.5, .333],
47 | [1., 2., 3., 0.5, .333],
48 | [1., 2., 3., 0.5, .333],
49 | [1., 2., 0.5],
50 | [1., 2., 0.5])
51 |
52 | default_boxes = []
53 | for idx, n_px in enumerate(conv_dims):
54 | for i in range(n_px):
55 | for j in range(n_px):
56 | cx = (j + 0.5) / n_px
57 | cy = (i + 0.5) / n_px
58 |
59 | for ratio in aspect_ratios[idx]:
60 | default_boxes.append([
61 | cx,
62 | cy,
63 | box_scales[idx] * sqrt(ratio),
64 | box_scales[idx] / sqrt(ratio)])
65 |
66 | if ratio == 1.:
67 | # for ratio 1, add an extra square box whose scale is the geometric mean of this and the next layer's scale
68 | try:
69 | square_scale = sqrt(box_scales[idx] * box_scales[idx + 1])
70 | # edge case of last feature map
71 | except IndexError:
72 | square_scale = 1.
73 | default_boxes.append([cx, cy, square_scale, square_scale])
74 |
75 | default_boxes = torch.FloatTensor(default_boxes).to(DEVICE)
76 | default_boxes.clamp_(0, 1)
77 |
78 | return default_boxes
79 |
80 |
81 | def xy_to_cxcy(xy):
82 | """
83 | Convert bounding boxes from boundary coordinates (x_min, y_min, x_max, y_max) to center-size coordinates (c_x, c_y, w, h).
84 | :param xy: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4)
85 | :return: bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4)
86 | """
87 | return torch.cat([(xy[:, 2:] + xy[:, :2]) / 2, # c_x, c_y
88 | xy[:, 2:] - xy[:, :2]], 1) # w, h
89 |
90 |
91 | def cxcy_to_xy(cxcy):
92 | """
93 | Convert bounding boxes from center-size coordinates (c_x, c_y, w, h) to boundary coordinates (x_min, y_min, x_max, y_max).
94 | :param cxcy: bounding boxes in center-size coordinates, a tensor of size (n_boxes, 4)
95 | :return: bounding boxes in boundary coordinates, a tensor of size (n_boxes, 4)
96 | """
97 | return torch.cat([cxcy[:, :2] - (cxcy[:, 2:] / 2), # x_min, y_min
98 | cxcy[:, :2] + (cxcy[:, 2:] / 2)], 1) # x_max, y_max
99 |
100 |
101 | def cxcy_to_gcxgcy(cxcy, priors_cxcy):
102 | """
103 | Encode bounding boxes (that are in center-size form) w.r.t. the corresponding prior boxes (that are in center-size form).
104 | For the center coordinates, find the offset with respect to the prior box, and scale by the size of the prior box.
105 | For the size coordinates, scale by the size of the prior box, and convert to the log-space.
106 | In the model, we are predicting bounding box coordinates in this encoded form.
107 | :param cxcy: bounding boxes in center-size coordinates, a tensor of size (n_priors, 4) 108 | :param priors_cxcy: prior boxes with respect to which the encoding must be performed, a tensor of size (n_priors, 4) 109 | :return: encoded bounding boxes, a tensor of size (n_priors, 4) 110 | """ 111 | 112 | # The 10 and 5 below are referred to as 'variances' in the original Caffe repo, completely empirical 113 | # They are for some sort of numerical conditioning, for 'scaling the localization gradient' 114 | # See https://github.com/weiliu89/caffe/issues/155 115 | return torch.cat([(cxcy[:, :2] - priors_cxcy[:, :2]) / (priors_cxcy[:, 2:] / 10), # g_c_x, g_c_y 116 | torch.log(cxcy[:, 2:] / priors_cxcy[:, 2:]) * 5], 1) # g_w, g_h 117 | 118 | 119 | def gcxgcy_to_cxcy(gcxgcy, priors_cxcy): 120 | """ 121 | Decode bounding box coordinates predicted by the model, since they are encoded in the form mentioned above. 122 | They are decoded into center-size coordinates. 123 | This is the inverse of the function above. 124 | :param gcxgcy: encoded bounding boxes, i.e. output of the model, a tensor of size (n_priors, 4) 125 | :param priors_cxcy: prior boxes with respect to which the encoding is defined, a tensor of size (n_priors, 4) 126 | :return: decoded bounding boxes in center-size form, a tensor of size (n_priors, 4) 127 | """ 128 | 129 | return torch.cat([gcxgcy[:, :2] * priors_cxcy[:, 2:] / 10 + priors_cxcy[:, :2], # c_x, c_y 130 | torch.exp(gcxgcy[:, 2:] / 5) * priors_cxcy[:, 2:]], 1) # w, h 131 | 132 | 133 | def find_intersection(set_1, set_2): 134 | """ 135 | Find the intersection of every box combination between two sets of boxes that are in boundary coordinates. 136 | :param set_1: set 1, a tensor of dimensions (n1, 4) 137 | :param set_2: set 2, a tensor of dimensions (n2, 4) 138 | :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) 139 | """ 140 | 141 | # PyTorch auto-broadcasts singleton dimensions 142 | lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0)) # (n1, n2, 2) 143 | upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0)) # (n1, n2, 2) 144 | intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0) # (n1, n2, 2) 145 | return intersection_dims[:, :, 0] * intersection_dims[:, :, 1] # (n1, n2) 146 | 147 | 148 | def find_jaccard_overlap(set_1, set_2): 149 | """ 150 | Find the Jaccard Overlap (IoU) of every box combination between two sets of boxes that are in boundary coordinates. 
151 | :param set_1: set 1, a tensor of dimensions (n1, 4) 152 | :param set_2: set 2, a tensor of dimensions (n2, 4) 153 | :return: Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2) 154 | """ 155 | 156 | # Find intersections 157 | intersection = find_intersection(set_1, set_2) # (n1, n2) 158 | 159 | # Find areas of each box in both sets 160 | areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1]) # (n1) 161 | areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1]) # (n2) 162 | 163 | # Find the union 164 | # PyTorch auto-broadcasts singleton dimensions 165 | union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection # (n1, n2) 166 | 167 | return intersection / union # (n1, n2) 168 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from datetime import datetime 4 | from collections import OrderedDict 5 | 6 | 7 | def ensure_dir(dirname): 8 | dirname = Path(dirname) 9 | if not dirname.is_dir(): 10 | dirname.mkdir(parents=True, exist_ok=False) 11 | 12 | def read_json(fname): 13 | with fname.open('rt') as handle: 14 | return json.load(handle, object_hook=OrderedDict) 15 | 16 | def write_json(content, fname): 17 | with fname.open('wt') as handle: 18 | json.dump(content, handle, indent=4, sort_keys=False) 19 | 20 | class Timer: 21 | def __init__(self): 22 | self.cache = datetime.now() 23 | 24 | def check(self): 25 | now = datetime.now() 26 | duration = now - self.cache 27 | self.cache = now 28 | return duration.total_seconds() 29 | 30 | def reset(self): 31 | self.cache = datetime.now() 32 | --------------------------------------------------------------------------------
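A minimal sketch (not part of the repository) illustrating the center-size box encoding implemented by the helpers in utils/model_utils.py; it assumes the snippet is run from the repository root so that utils and constants are importable:

import torch
from utils.model_utils import xy_to_cxcy, cxcy_to_xy, cxcy_to_gcxgcy, gcxgcy_to_cxcy

# one ground-truth box and one prior box, both as fractional (x_min, y_min, x_max, y_max)
box_xy = torch.tensor([[0.10, 0.20, 0.50, 0.60]])
prior_cxcy = xy_to_cxcy(torch.tensor([[0.05, 0.15, 0.55, 0.65]]))

# encode the box w.r.t. the prior (the regression-target form), then decode it back
encoded = cxcy_to_gcxgcy(xy_to_cxcy(box_xy), prior_cxcy)   # (g_c_x, g_c_y, g_w, g_h)
decoded = cxcy_to_xy(gcxgcy_to_cxcy(encoded, prior_cxcy))  # boundary coordinates again

assert torch.allclose(decoded, box_xy, atol=1e-6)  # encode/decode is an exact round trip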