├── .gitignore ├── README.md ├── asserts ├── demo.png └── retinanet.png ├── config.json ├── dataGen ├── __init__.py ├── compute_overlap.cpython-36m-x86_64-linux-gnu.so ├── compute_overlap.pyx ├── data_loader.py ├── setup.py ├── targetBuild.py └── utils.py ├── dataset ├── demo.ipynb ├── fold_data ├── trnbox_1.pkl ├── trnbox_2.pkl ├── trnbox_3.pkl ├── trnbox_4.pkl ├── trnbox_5.pkl ├── trnbox_6.pkl ├── trnfps_1.pkl ├── trnfps_2.pkl ├── trnfps_3.pkl ├── trnfps_4.pkl ├── trnfps_5.pkl ├── trnfps_6.pkl ├── trnlabel_1.pkl ├── trnlabel_2.pkl ├── trnlabel_3.pkl ├── trnlabel_4.pkl ├── trnlabel_5.pkl ├── trnlabel_6.pkl ├── valbox_1.pkl ├── valbox_2.pkl ├── valbox_3.pkl ├── valbox_4.pkl ├── valbox_5.pkl ├── valbox_6.pkl ├── valfps_1.pkl ├── valfps_2.pkl ├── valfps_3.pkl ├── valfps_4.pkl ├── valfps_5.pkl ├── valfps_6.pkl ├── vallabel_1.pkl ├── vallabel_2.pkl ├── vallabel_3.pkl ├── vallabel_4.pkl ├── vallabel_5.pkl └── vallabel_6.pkl ├── models ├── __init__.py ├── backbone.py ├── fpn.py ├── losses.py └── misc.py ├── prepare_data.ipynb ├── pretrainedmodels ├── .gitignore └── download_here.txt ├── requirements.txt ├── setup.sh ├── tests ├── __init__.py ├── test_boxinv.py ├── test_dataloader.py ├── test_filter.py ├── test_fpn.py ├── test_losses.py └── test_retinanet.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints 3 | checkpoints 4 | .vscode 5 | dataGen/__pycache__ 6 | models/__pycache__ 7 | tests/__pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## An Implentation of RetinaNet in Pytorch 2 | ![retinanet](asserts/retinanet.png) 3 | 4 | #### optional backbone: 5 | - [se_resnext50_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth) 6 | - [se_resnext101_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth) 7 | 8 | #### usage: 9 | `setup.sh`: compile cython code
10 | `config.json`: configuration file for the backbone, anchor and training settings
11 | `train.py`: main training script
12 | `dataGen/data_loader.py`: Subclass **torch.utils.data.Dataset** or modify **class RsnaDataset** for your own dataset.
13 | `prepare_data.ipynb`: data processing notebook for the **RSNA Pneumonia Detection Challenge**
14 | `demo.ipynb`: sample code showing how to predict with **model.predict** method 15 | 16 | --- 17 | #### Application 18 | train this model with dataset of [RSNA Pneumonia Detection Challenge](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge) 19 | ![demo image](asserts/demo.png) 20 | 21 | ### credits: 22 | 1. [Cadene pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch) 23 | 2. [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet) 24 | 3. [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf) 25 | 4. [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf) -------------------------------------------------------------------------------- /asserts/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/asserts/demo.png -------------------------------------------------------------------------------- /asserts/retinanet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/asserts/retinanet.png -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_imagenet": true, 3 | "backbone": "se_resnext50_32x4d", 4 | "image_shape": [512, 512], 5 | "pyramid_levels_default": [3, 4, 5, 6, 7], 6 | "anchor_sizes_default": [32, 64, 128, 256, 512], 7 | "anchor_strides_default": [8, 16, 32, 64, 128], 8 | "anchor_ratios_default": [0.5, 1, 2], 9 | "anchor_scales_default": [1.0, 1.2599, 1.5874], 10 | 11 | "num_classes": 1, 12 | "mean_bbox_transform": [0.0, 0.0, 0.0, 0.0], 13 | "std_bbox_transform": [0.2, 0.2, 0.2, 0.2], 14 | "dicom_train": "dataset/stage_2_train_images", 15 | "RandomRotate": true, 16 | "RandomHorizontalFlip": true, 17 | "batch_size": 8, 18 | 19 | "loss_ratio_FL2L1": 1.0, 20 | "focal_alpha": 0.75, 21 | "l1_sigma": 3, 22 | 23 | "use_cuda": true, 24 | "num_workers": 4, 25 | "rsna_mean": 0.49, 26 | "rsna_std": 0.23 27 | } -------------------------------------------------------------------------------- /dataGen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/dataGen/__init__.py -------------------------------------------------------------------------------- /dataGen/compute_overlap.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/dataGen/compute_overlap.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /dataGen/compute_overlap.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | cimport cython 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | 13 | def compute_overlap( 14 | np.ndarray[double, ndim=2] boxes, 15 | 
np.ndarray[double, ndim=2] query_boxes 16 | ): 17 | """ 18 | Args 19 | a: (N, 4) ndarray of float 20 | b: (K, 4) ndarray of float 21 | 22 | Returns 23 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 24 | """ 25 | cdef unsigned int N = boxes.shape[0] 26 | cdef unsigned int K = query_boxes.shape[0] 27 | cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64) 28 | cdef double iw, ih, box_area 29 | cdef double ua 30 | cdef unsigned int k, n 31 | for k in range(K): 32 | box_area = ( 33 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 34 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 35 | ) 36 | for n in range(N): 37 | iw = ( 38 | min(boxes[n, 2], query_boxes[k, 2]) - 39 | max(boxes[n, 0], query_boxes[k, 0]) + 1 40 | ) 41 | if iw > 0: 42 | ih = ( 43 | min(boxes[n, 3], query_boxes[k, 3]) - 44 | max(boxes[n, 1], query_boxes[k, 1]) + 1 45 | ) 46 | if ih > 0: 47 | ua = np.float64( 48 | (boxes[n, 2] - boxes[n, 0] + 1) * 49 | (boxes[n, 3] - boxes[n, 1] + 1) + 50 | box_area - iw * ih 51 | ) 52 | overlaps[n, k] = iw * ih / ua 53 | return overlaps 54 | -------------------------------------------------------------------------------- /dataGen/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import pydicom 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms 9 | from albumentations import HorizontalFlip, Rotate, Resize, Compose 10 | from models.backbone import pretrained_settings 11 | from .targetBuild import anchor_targets_bbox, anchors_for_shape 12 | 13 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 14 | with open(config_path, 'r') as f: 15 | config = json.load(f) 16 | 17 | 18 | def load_dicom(img_id): 19 | image_path = os.path.join(os.path.dirname(__file__), '..', config['dicom_train'], img_id+'.dcm') 20 | ds = pydicom.read_file(image_path) 21 | image = ds.pixel_array 22 | # If grayscale. Convert to RGB for consistency. 23 | if len(image.shape) != 3 or image.shape[2] != 3: 24 | image = np.stack((image,) * 3, -1) 25 | return image 26 | 27 | 28 | def get_aug(aug, min_area=0., min_visibility=0.): 29 | return Compose(aug, bbox_params={'format': 'pascal_voc', 'min_area': min_area, 'min_visibility': min_visibility, 'label_fields': ['category_id']}) 30 | 31 | 32 | class RsnaDataset(Dataset): 33 | """ 34 | A standard PyTorch definition of Dataset which defines the functions __len__ and __getitem__. 35 | """ 36 | 37 | def __init__(self, filenames, gt_bboxes, gt_catids, aug=None): 38 | """ 39 | Store the filenames of the jpgs to use. Specifies transforms to apply on images. 
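        Images are read from the RSNA DICOM files via `load_dicom` and normalised with the `rsna_mean`/`rsna_std` values from config.json; bounding boxes are handled in pascal_voc (x1, y1, x2, y2) format.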
40 | 41 | Args: 42 | data_dir: (string) directory containing the dataset 43 | transform: (torchvision.transforms) transformation to apply on image 44 | """ 45 | self.filenames = filenames 46 | self.gt_bboxes = gt_bboxes 47 | self.gt_catids = gt_catids 48 | self.aug = get_aug(aug) if aug is not None else get_aug([]) 49 | model_name = config['backbone'] 50 | # img_mean = pretrained_settings[model_name]['imagenet']['mean'] 51 | # img_std = pretrained_settings[model_name]['imagenet']['std'] 52 | self.img_tfs = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 53 | mean=[config['rsna_mean']]*3, std=[config['rsna_std']]*3)]) 54 | self.anchors = anchors_for_shape(config['image_shape']) 55 | 56 | def __len__(self): 57 | # return size of dataset 58 | return len(self.filenames) 59 | 60 | def __getitem__(self, idx): 61 | """ 62 | Fetch index idx image and labels from dataset. Perform transforms on image. 63 | 64 | Args: 65 | idx: (int) index in [0, 1, ..., size_of_dataset-1] 66 | 67 | Returns: 68 | image: (Tensor) transformed image 69 | label: (int) corresponding label of image 70 | """ 71 | fps = self.filenames[idx] 72 | image = load_dicom(fps) 73 | bboxes = self.gt_bboxes[fps] 74 | category_id = self.gt_catids[fps] 75 | augmented = self.aug(image=image, bboxes=bboxes, category_id=category_id) 76 | if not augmented['bboxes']: 77 | gt_annos = np.zeros((0, 5), dtype=np.float32) 78 | else: 79 | gt_annos = np.empty((len(augmented['bboxes']), 5), dtype=np.float32) 80 | gt_annos[:, :4] = augmented['bboxes'] 81 | gt_annos[:, 4] = augmented['category_id'] 82 | img = self.img_tfs(augmented['image']) 83 | labels, regression = anchor_targets_bbox(self.anchors, gt_annos) 84 | labels, regression = torch.tensor(labels), torch.tensor(regression) 85 | return img, labels, regression 86 | 87 | trn_aug = [] 88 | if config['image_shape']: 89 | trn_aug.append(Resize(*config['image_shape'], p=1.0)) 90 | val_aug = [Resize(*config['image_shape'], p=1.0)] 91 | else: 92 | val_aug = [] 93 | 94 | if config['RandomRotate']: 95 | trn_aug.append(Rotate(limit=15, p=0.5)) 96 | 97 | if config['RandomHorizontalFlip']: 98 | trn_aug.append(HorizontalFlip(p=0.5)) 99 | 100 | 101 | def fetch_trn_loader(kfold, trnfps=None, bboxdict=None, labeldict=None, aug=None): 102 | if trnfps is None: 103 | trnfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnfps_{}.pkl'.format(kfold)) 104 | with open(trnfps_path, 'rb') as f: 105 | trnfps = pickle.load(f) 106 | 107 | if bboxdict is None: 108 | bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnbox_{}.pkl'.format(kfold)) 109 | with open(bboxdict_path, 'rb') as f: 110 | bboxdict = pickle.load(f) 111 | 112 | if labeldict is None: 113 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnlabel_{}.pkl'.format(kfold)) 114 | with open(labeldict_path, 'rb') as f: 115 | labeldict = pickle.load(f) 116 | 117 | if aug is None: 118 | aug = trn_aug 119 | 120 | dataset = RsnaDataset(trnfps, bboxdict, labeldict, aug) 121 | return DataLoader(dataset, 122 | batch_size=config['batch_size'], 123 | shuffle=True, 124 | num_workers=config['num_workers'], 125 | pin_memory=config['use_cuda']) 126 | 127 | 128 | def fetch_val_loader(kfold, valfps=None, bboxdict=None, labeldict=None, aug=None): 129 | if valfps is None: 130 | valfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valfps_{}.pkl'.format(kfold)) 131 | with open(valfps_path, 'rb') as f: 132 | valfps = pickle.load(f) 133 | 134 | if bboxdict is None: 135 | 
bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valbox_{}.pkl'.format(kfold)) 136 | with open(bboxdict_path, 'rb') as f: 137 | bboxdict = pickle.load(f) 138 | 139 | if labeldict is None: 140 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'vallabel_{}.pkl'.format(kfold)) 141 | with open(labeldict_path, 'rb') as f: 142 | labeldict = pickle.load(f) 143 | 144 | if aug is None: 145 | aug = val_aug 146 | 147 | dataset = RsnaDataset(valfps, bboxdict, labeldict, aug) 148 | return DataLoader(dataset, 149 | batch_size=config['batch_size'], 150 | shuffle=False, 151 | num_workers=config['num_workers'], 152 | pin_memory=config['use_cuda']) 153 | 154 | -------------------------------------------------------------------------------- /dataGen/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools.extension import Extension 3 | import numpy as np 4 | from Cython.Build import cythonize 5 | 6 | extensions = [ 7 | Extension( 8 | 'dataGen.compute_overlap', 9 | ['dataGen/compute_overlap.pyx'], 10 | include_dirs=[np.get_include()] 11 | ), 12 | ] 13 | 14 | setuptools.setup( 15 | name='dataGen', 16 | packages=setuptools.find_packages(), 17 | # same with `ext_modules=extensions`, 18 | ext_modules=cythonize(extensions), 19 | setup_requires=["cython>=0.28"] 20 | ) 21 | -------------------------------------------------------------------------------- /dataGen/targetBuild.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from .compute_overlap import compute_overlap 5 | 6 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 7 | with open(config_path, 'r') as f: 8 | config = json.load(f) 9 | 10 | 11 | class AnchorParameters: 12 | """ The parameteres that define how anchors are generated. 13 | 14 | Args 15 | sizes : List of sizes to use. Each size corresponds to one feature level. 16 | strides : List of strides to use. Each stride correspond to one feature level. 17 | ratios : List of ratios to use per location in a feature map. 18 | scales : List of scales to use per location in a feature map. 19 | """ 20 | 21 | def __init__(self, sizes, strides, ratios, scales): 22 | self.sizes = sizes 23 | self.strides = strides 24 | self.ratios = ratios 25 | self.scales = scales 26 | 27 | def num_anchors(self): 28 | return len(self.ratios) * len(self.scales) 29 | 30 | 31 | """ 32 | The default anchor parameters. 33 | """ 34 | AnchorParameters.default = AnchorParameters( 35 | sizes=config['anchor_sizes_default'], 36 | strides=config['anchor_strides_default'], 37 | ratios=np.array(config['anchor_ratios_default'], np.float32), 38 | scales=np.array(config['anchor_scales_default'], np.float32), 39 | ) 40 | 41 | 42 | def generate_anchors(base_size=16, ratios=None, scales=None): 43 | """ 44 | Generate anchor (reference) windows by enumerating aspect ratios X 45 | scales w.r.t. a reference window. 
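    With the default AnchorParameters (3 ratios x 3 scales) this returns a (9, 4) array of (x1, y1, x2, y2) boxes centred on the origin, with sizes proportional to base_size.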
46 | """ 47 | 48 | if ratios is None: 49 | ratios = AnchorParameters.default.ratios 50 | 51 | if scales is None: 52 | scales = AnchorParameters.default.scales 53 | 54 | num_anchors = len(ratios) * len(scales) 55 | 56 | # initialize output anchors 57 | anchors = np.zeros((num_anchors, 4), dtype=np.float32) 58 | 59 | # scale base_size 60 | anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T 61 | 62 | # compute areas of anchors 63 | areas = anchors[:, 2] * anchors[:, 3] 64 | 65 | # correct for ratios 66 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales))) 67 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales)) 68 | 69 | # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2) 70 | anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T 71 | anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T 72 | 73 | return anchors 74 | 75 | 76 | def guess_shapes(image_shape, pyramid_levels=None): 77 | """Guess shapes based on pyramid levels. 78 | 79 | Args 80 | image_shape: The shape of the image. 81 | pyramid_levels: A list of what pyramid levels are used. 82 | 83 | Returns 84 | A list of image shapes at each pyramid level. 85 | """ 86 | if pyramid_levels is None: 87 | pyramid_levels = config['pyramid_levels_default'] 88 | image_shape = np.array(image_shape) 89 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels] 90 | return image_shapes 91 | 92 | 93 | def shift(shape, stride, anchors): 94 | """ Produce shifted anchors based on shape of the map and stride size. 95 | 96 | Args 97 | shape : Shape to shift the anchors over. 98 | stride : Stride to shift the anchors with over the shape. 99 | anchors: The anchors to apply at each location. 100 | """ 101 | 102 | # create a grid starting from half stride from the top left corner 103 | shift_x = (np.arange(0, shape[1]) + 0.5) * stride 104 | shift_y = (np.arange(0, shape[0]) + 0.5) * stride 105 | 106 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 107 | 108 | shifts = np.vstack(( 109 | shift_x.ravel(), shift_y.ravel(), 110 | shift_x.ravel(), shift_y.ravel() 111 | )).transpose() 112 | 113 | # add A anchors (1, A, 4) to 114 | # cell K shifts (K, 1, 4) to get 115 | # shift anchors (K, A, 4) 116 | # reshape to (K*A, 4) shifted anchors 117 | A = anchors.shape[0] 118 | K = shifts.shape[0] 119 | all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))) 120 | all_anchors = all_anchors.reshape((K * A, 4)) 121 | 122 | return all_anchors 123 | 124 | 125 | def anchors_for_shape( 126 | image_shape, 127 | pyramid_levels=None, 128 | anchor_params=None, 129 | ): 130 | """ Generators anchors for a given shape. 131 | 132 | Args 133 | image_shape: The shape of the image. 134 | pyramid_levels: List of ints representing which pyramids to use (defaults to [3, 4, 5, 6, 7]). 135 | anchor_params: Struct containing anchor parameters. If None, default values are used. 136 | 137 | Returns 138 | np.array of shape (N, 4) containing the (x1, y1, x2, y2) coordinates for the anchors. 
139 | """ 140 | 141 | if pyramid_levels is None: 142 | pyramid_levels = config['pyramid_levels_default'] 143 | 144 | if anchor_params is None: 145 | anchor_params = AnchorParameters.default 146 | 147 | image_shapes = guess_shapes(image_shape, pyramid_levels) 148 | 149 | # compute anchors over all pyramid levels 150 | all_anchors = np.zeros((0, 4)) 151 | for idx, p in enumerate(pyramid_levels): 152 | anchors = generate_anchors( 153 | base_size=anchor_params.sizes[idx], 154 | ratios=anchor_params.ratios, 155 | scales=anchor_params.scales 156 | ) 157 | shifted_anchors = shift(image_shapes[idx], anchor_params.strides[idx], anchors) 158 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0) 159 | 160 | return all_anchors 161 | 162 | 163 | def compute_gt_annotations( 164 | anchors, 165 | annotations, 166 | negative_overlap=0.4, 167 | positive_overlap=0.5 168 | ): 169 | """ Obtain indices of gt annotations with the greatest overlap. 170 | 171 | Args 172 | anchors: np.array of annotations of shape (N, 4) for (x1, y1, x2, y2). 173 | annotations: np.array of shape (N, 5) for (x1, y1, x2, y2, label). 174 | negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative). 175 | positive_overlap: IoU overlap or positive anchors (all anchors with overlap > positive_overlap are positive). 176 | 177 | Returns 178 | positive_indices: indices of positive anchors 179 | ignore_indices: indices of ignored anchors 180 | argmax_overlaps_inds: ordered overlaps indices 181 | """ 182 | 183 | overlaps = compute_overlap(anchors.astype(np.float64), annotations.astype(np.float64)) 184 | argmax_overlaps_inds = np.argmax(overlaps, axis=1) 185 | max_overlaps = overlaps[np.arange(overlaps.shape[0]), argmax_overlaps_inds] 186 | 187 | # assign "dont care" labels 188 | positive_indices = max_overlaps >= positive_overlap 189 | ignore_indices = (max_overlaps > negative_overlap) & ~positive_indices 190 | 191 | return positive_indices, ignore_indices, argmax_overlaps_inds 192 | 193 | 194 | def bbox_transform(anchors, gt_boxes, mean=None, std=None): 195 | """Compute bounding-box regression targets for an image.""" 196 | 197 | if mean is None: 198 | mean = np.array(config['mean_bbox_transform'], dtype=np.float32) 199 | if std is None: 200 | std = np.array(config['std_bbox_transform'], dtype=np.float32) 201 | 202 | if isinstance(mean, (list, tuple)): 203 | mean = np.array(mean) 204 | elif not isinstance(mean, np.ndarray): 205 | raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean))) 206 | 207 | if isinstance(std, (list, tuple)): 208 | std = np.array(std) 209 | elif not isinstance(std, np.ndarray): 210 | raise ValueError('Expected std to be a np.ndarray, list or tuple. 
Received: {}'.format(type(std))) 211 | 212 | anchor_widths = anchors[:, 2] - anchors[:, 0] 213 | anchor_heights = anchors[:, 3] - anchors[:, 1] 214 | 215 | targets_dx1 = (gt_boxes[:, 0] - anchors[:, 0]) / anchor_widths 216 | targets_dy1 = (gt_boxes[:, 1] - anchors[:, 1]) / anchor_heights 217 | targets_dx2 = (gt_boxes[:, 2] - anchors[:, 2]) / anchor_widths 218 | targets_dy2 = (gt_boxes[:, 3] - anchors[:, 3]) / anchor_heights 219 | 220 | targets = np.stack((targets_dx1, targets_dy1, targets_dx2, targets_dy2)) 221 | targets = targets.T 222 | 223 | targets = (targets - mean) / std 224 | 225 | return targets 226 | 227 | 228 | def anchor_targets_bbox( 229 | anchors, 230 | annotations, 231 | num_classes=None, 232 | negative_overlap=0.4, 233 | positive_overlap=0.5 234 | ): 235 | """ Generate anchor targets for bbox detection. 236 | 237 | Args 238 | anchors: np.array of annotations of shape (N, 4) for (x1, y1, x2, y2). 239 | annotations: annotations (np.array of shape (N, 5) for (x1, y1, x2, y2, label)). 240 | num_classes: Number of classes to predict. 241 | negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative). 242 | positive_overlap: IoU overlap or positive anchors (all anchors with overlap > positive_overlap are positive). 243 | 244 | Returns 245 | labels: that contains labels & anchor states (np.array of shape (N, num_classes + 1), 246 | where N is the number of anchors for an image and 247 | the last column defines the anchor state (-1 for ignore, 0 for bg, 1 for fg). 248 | regression: that contains bounding-box regression targets for an image & anchor states (np.array of shape (N, 4 + 1), 249 | where N is the number of anchors for an image, the first 4 columns define regression targets for (x1, y1, x2, y2) 250 | and the last column defines anchor states (-1 for ignore, 0 for bg, 1 for fg). 
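    Example (a sketch, assuming the default single-class config):
        >>> anchors = anchors_for_shape([512, 512])                 # (49104, 4)
        >>> annos = np.array([[100., 120., 240., 260., 0.]])        # one gt box of class 0
        >>> labels, regression = anchor_targets_bbox(anchors, annos)
        >>> labels.shape, regression.shape
        ((49104, 2), (49104, 5))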
251 | """ 252 | if num_classes is None: 253 | num_classes = config['num_classes'] 254 | regression = np.zeros((anchors.shape[0], 4 + 1), dtype=np.float32) 255 | labels = np.zeros((anchors.shape[0], num_classes + 1), dtype=np.float32) 256 | 257 | # compute labels and regression targets 258 | if annotations.shape[0]: 259 | # obtain indices of gt annotations with the greatest overlap 260 | positive_indices, ignore_indices, argmax_overlaps_inds = compute_gt_annotations(anchors, annotations, 261 | negative_overlap, 262 | positive_overlap) 263 | labels[ignore_indices, -1] = -1 264 | labels[positive_indices, -1] = 1 265 | 266 | regression[ignore_indices, -1] = -1 267 | regression[positive_indices, -1] = 1 268 | 269 | # compute box regression targets 270 | annotations = annotations[argmax_overlaps_inds] 271 | 272 | # compute target class labels 273 | labels[positive_indices, annotations[positive_indices, 4].astype(int)] = 1 274 | 275 | regression[:, :-1] = bbox_transform(anchors, annotations) 276 | 277 | return labels, regression 278 | -------------------------------------------------------------------------------- /dataGen/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def iou(box1, box2): 4 | """ 5 | From Yicheng Chen's "Mean Average Precision Metric" 6 | https://www.kaggle.com/chenyc15/mean-average-precision-metric 7 | 8 | helper function to calculate IoU 9 | """ 10 | x11, y11, x12, y12 = box1 11 | x21, y21, x22, y22 = box2 12 | w1, h1 = x12-x11, y12-y11 13 | w2, h2 = x22-x21, y22-y21 14 | 15 | area1, area2 = w1 * h1, w2 * h2 16 | xi1, yi1, xi2, yi2 = max([x11, x21]), max( 17 | [y11, y21]), min([x12, x22]), min([y12, y22]) 18 | 19 | if xi2 <= xi1 or yi2 <= yi1: 20 | return 0 21 | else: 22 | intersect = (xi2-xi1) * (yi2-yi1) 23 | union = area1 + area2 - intersect 24 | return intersect / union 25 | 26 | 27 | def nms(boxes, scores, overlapThresh): 28 | """ 29 | adapted from non-maximum suppression by Adrian Rosebrock 30 | https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ 31 | """ 32 | 33 | # if there are no boxes, return an empty list 34 | if len(boxes) == 0: 35 | return np.array([]).reshape(0, 4), np.array([]) 36 | if boxes.dtype.kind == "i": 37 | boxes = boxes.astype("float") 38 | 39 | pick = [] 40 | x1 = boxes[:, 0] 41 | y1 = boxes[:, 1] 42 | x2 = boxes[:, 2] 43 | y2 = boxes[:, 3] 44 | 45 | # compute the area of the bounding boxes 46 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 47 | 48 | # sort the bounding boxes by scores in ascending order 49 | idxs = np.argsort(scores) 50 | 51 | # keep looping while indexes still remain in the indexes list 52 | while len(idxs) > 0: 53 | # grab the last index in the indexes list and add the 54 | # index value to the list of picked indexes 55 | last = len(idxs) - 1 56 | i = idxs[last] 57 | pick.append(i) 58 | 59 | # find the largest (x, y) coordinates for the start of 60 | # the bounding box and the smallest (x, y) coordinates 61 | # for the end of the bounding box 62 | xx1 = np.maximum(x1[i], x1[idxs[:last]]) 63 | yy1 = np.maximum(y1[i], y1[idxs[:last]]) 64 | xx2 = np.minimum(x2[i], x2[idxs[:last]]) 65 | yy2 = np.minimum(y2[i], y2[idxs[:last]]) 66 | 67 | # compute the width and height of the bounding box 68 | w = np.maximum(0, xx2 - xx1 + 1) 69 | h = np.maximum(0, yy2 - yy1 + 1) 70 | 71 | # compute the ratio of overlap 72 | overlap = (w * h) / area[idxs[:last]] 73 | 74 | # delete all indexes from the index list that have 75 | idxs = np.delete(idxs, 
np.concatenate(([last], 76 | np.where(overlap > overlapThresh)[0]))) 77 | 78 | # return only the bounding boxes that were picked using the 79 | # integer data type 80 | return boxes[pick], scores[pick] 81 | -------------------------------------------------------------------------------- /dataset: -------------------------------------------------------------------------------- 1 | /home/raytroop/sand/rsna-pneumonia-detection-challenge_stg1 -------------------------------------------------------------------------------- /fold_data/trnbox_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_1.pkl -------------------------------------------------------------------------------- /fold_data/trnbox_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_2.pkl -------------------------------------------------------------------------------- /fold_data/trnbox_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_3.pkl -------------------------------------------------------------------------------- /fold_data/trnbox_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_4.pkl -------------------------------------------------------------------------------- /fold_data/trnbox_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_5.pkl -------------------------------------------------------------------------------- /fold_data/trnbox_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_6.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_1.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_2.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_3.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_4.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_4.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_5.pkl -------------------------------------------------------------------------------- /fold_data/trnfps_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_6.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_1.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_2.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_3.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_4.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_5.pkl -------------------------------------------------------------------------------- /fold_data/trnlabel_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_6.pkl -------------------------------------------------------------------------------- /fold_data/valbox_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_1.pkl -------------------------------------------------------------------------------- /fold_data/valbox_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_2.pkl -------------------------------------------------------------------------------- /fold_data/valbox_3.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_3.pkl -------------------------------------------------------------------------------- /fold_data/valbox_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_4.pkl -------------------------------------------------------------------------------- /fold_data/valbox_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_5.pkl -------------------------------------------------------------------------------- /fold_data/valbox_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_6.pkl -------------------------------------------------------------------------------- /fold_data/valfps_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_1.pkl -------------------------------------------------------------------------------- /fold_data/valfps_2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_2.pkl -------------------------------------------------------------------------------- /fold_data/valfps_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_3.pkl -------------------------------------------------------------------------------- /fold_data/valfps_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_4.pkl -------------------------------------------------------------------------------- /fold_data/valfps_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_5.pkl -------------------------------------------------------------------------------- /fold_data/valfps_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_6.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_1.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_2.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_2.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_3.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_4.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_5.pkl -------------------------------------------------------------------------------- /fold_data/vallabel_6.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_6.pkl -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/models/__init__.py -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet code gently borrowed from 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | """ 5 | import os 6 | from collections import OrderedDict 7 | import math 8 | 9 | import torch.nn as nn 10 | import torch 11 | 12 | 13 | pretrained_settings = { 14 | 'se_resnext50_32x4d': { 15 | 'imagenet': { 16 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 17 | 'input_space': 'RGB', 18 | 'input_size': [3, 224, 224], 19 | 'input_range': [0, 1], 20 | 'mean': [0.485, 0.456, 0.406], 21 | 'std': [0.229, 0.224, 0.225], 22 | 'num_classes': 1000 23 | } 24 | }, 25 | 'se_resnext101_32x4d': { 26 | 'imagenet': { 27 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 28 | 'input_space': 'RGB', 29 | 'input_size': [3, 224, 224], 30 | 'input_range': [0, 1], 31 | 'mean': [0.485, 0.456, 0.406], 32 | 'std': [0.229, 0.224, 0.225], 33 | 'num_classes': 1000 34 | } 35 | }, 36 | } 37 | 38 | 39 | class SEModule(nn.Module): 40 | 41 | def __init__(self, channels, reduction): 42 | super(SEModule, self).__init__() 43 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 44 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 45 | padding=0) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 48 | padding=0) 49 | self.sigmoid = nn.Sigmoid() 50 | 51 | def forward(self, x): 52 | module_input = x 53 | x = self.avg_pool(x) 54 | x = self.fc1(x) 55 | x = self.relu(x) 56 | x = self.fc2(x) 57 | x = self.sigmoid(x) 58 | return module_input * x 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | """ 63 
| Base class for bottlenecks that implements `forward()` method. 64 | """ 65 | def forward(self, x): 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out = self.se_module(out) + residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class SEResNeXtBottleneck(Bottleneck): 89 | """ 90 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 91 | """ 92 | expansion = 4 93 | 94 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 95 | downsample=None, base_width=4): 96 | super(SEResNeXtBottleneck, self).__init__() 97 | width = math.floor(planes * (base_width / 64)) * groups 98 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 99 | stride=1) 100 | self.bn1 = nn.BatchNorm2d(width) 101 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 102 | padding=1, groups=groups, bias=False) 103 | self.bn2 = nn.BatchNorm2d(width) 104 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 105 | self.bn3 = nn.BatchNorm2d(planes * 4) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.se_module = SEModule(planes * 4, reduction=reduction) 108 | self.downsample = downsample 109 | self.stride = stride 110 | 111 | 112 | class BackBone(nn.Module): 113 | 114 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 115 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 116 | downsample_padding=1): 117 | """ 118 | Parameters 119 | ---------- 120 | block (nn.Module): Bottleneck class. 121 | - For SE-ResNeXt models: SEResNeXtBottleneck 122 | layers (list of ints): Number of residual blocks for 4 layers of the 123 | network (layer1...layer4). 124 | groups (int): Number of groups for the 3x3 convolution in each 125 | bottleneck block. 126 | - For BackBone154: 64 127 | - For SE-ResNet models: 1 128 | - For SE-ResNeXt models: 32 129 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 130 | - For all models: 16 131 | dropout_p (float or None): Drop probability for the Dropout layer. 132 | If `None` the Dropout layer is not used. 133 | - For BackBone154: 0.2 134 | - For SE-ResNet models: None 135 | - For SE-ResNeXt models: None 136 | inplanes (int): Number of input channels for layer1. 137 | - For BackBone154: 128 138 | - For SE-ResNet models: 64 139 | - For SE-ResNeXt models: 64 140 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 141 | a single 7x7 convolution in layer0. 142 | - For BackBone154: True 143 | - For SE-ResNet models: False 144 | - For SE-ResNeXt models: False 145 | downsample_kernel_size (int): Kernel size for downsampling convolutions 146 | in layer2, layer3 and layer4. 147 | - For BackBone154: 3 148 | - For SE-ResNet models: 1 149 | - For SE-ResNeXt models: 1 150 | downsample_padding (int): Padding for downsampling convolutions in 151 | layer2, layer3 and layer4. 152 | - For BackBone154: 1 153 | - For SE-ResNet models: 0 154 | - For SE-ResNeXt models: 0 155 | num_classes (int): Number of outputs in `last_linear` layer. 
156 | - For all models: 1000 157 | """ 158 | super(BackBone, self).__init__() 159 | self.inplanes = inplanes 160 | if input_3x3: 161 | layer0_modules = [ 162 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 163 | bias=False)), 164 | ('bn1', nn.BatchNorm2d(64)), 165 | ('relu1', nn.ReLU(inplace=True)), 166 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 167 | bias=False)), 168 | ('bn2', nn.BatchNorm2d(64)), 169 | ('relu2', nn.ReLU(inplace=True)), 170 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 171 | bias=False)), 172 | ('bn3', nn.BatchNorm2d(inplanes)), 173 | ('relu3', nn.ReLU(inplace=True)), 174 | ] 175 | else: 176 | layer0_modules = [ 177 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 178 | padding=3, bias=False)), 179 | ('bn1', nn.BatchNorm2d(inplanes)), 180 | ('relu1', nn.ReLU(inplace=True)), 181 | ] 182 | # To preserve compatibility with Caffe weights `ceil_mode=True` 183 | # is used instead of `padding=1`. 184 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 185 | ceil_mode=True))) 186 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 187 | self.layer1 = self._make_layer( 188 | block, 189 | planes=64, 190 | blocks=layers[0], 191 | groups=groups, 192 | reduction=reduction, 193 | downsample_kernel_size=1, 194 | downsample_padding=0 195 | ) 196 | self.layer2 = self._make_layer( 197 | block, 198 | planes=128, 199 | blocks=layers[1], 200 | stride=2, 201 | groups=groups, 202 | reduction=reduction, 203 | downsample_kernel_size=downsample_kernel_size, 204 | downsample_padding=downsample_padding 205 | ) 206 | self.layer3 = self._make_layer( 207 | block, 208 | planes=256, 209 | blocks=layers[2], 210 | stride=2, 211 | groups=groups, 212 | reduction=reduction, 213 | downsample_kernel_size=downsample_kernel_size, 214 | downsample_padding=downsample_padding 215 | ) 216 | self.layer4 = self._make_layer( 217 | block, 218 | planes=512, 219 | blocks=layers[3], 220 | stride=2, 221 | groups=groups, 222 | reduction=reduction, 223 | downsample_kernel_size=downsample_kernel_size, 224 | downsample_padding=downsample_padding 225 | ) 226 | 227 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 228 | downsample_kernel_size=1, downsample_padding=0): 229 | downsample = None 230 | if stride != 1 or self.inplanes != planes * block.expansion: 231 | downsample = nn.Sequential( 232 | nn.Conv2d(self.inplanes, planes * block.expansion, 233 | kernel_size=downsample_kernel_size, stride=stride, 234 | padding=downsample_padding, bias=False), 235 | nn.BatchNorm2d(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 240 | downsample)) 241 | self.inplanes = planes * block.expansion 242 | for i in range(1, blocks): 243 | layers.append(block(self.inplanes, planes, groups, reduction)) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def forward(self, x): 248 | x = self.layer0(x) 249 | C2 = self.layer1(x) 250 | C3 = self.layer2(C2) 251 | C4 = self.layer3(C3) 252 | C5 = self.layer4(C4) 253 | return [C3, C4, C5] 254 | 255 | 256 | def initialize_pretrained_model(model, settings): 257 | weights_path = os.path.join(os.path.join(os.path.dirname(__file__), '..', 258 | 'pretrainedmodels', os.path.basename(settings['url']))) 259 | state_dict = torch.load(weights_path) 260 | # filter out last layer 261 | for layer in ['last_linear.weight', 'last_linear.bias']: 262 | del state_dict[layer] 263 | state_dict_cur = model.state_dict() 264 | state_dict_cur.update(state_dict_cur) 
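    # Merge the pretrained weights (classifier entries removed above) into the model's current
    # state dict before loading it; the update() call above only copies the dict onto itself,
    # so this line is what actually brings the ImageNet weights in.
    state_dict_cur.update(state_dict)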
265 | 266 | model.load_state_dict(state_dict_cur) 267 | model.input_space = settings['input_space'] 268 | model.input_size = settings['input_size'] 269 | model.input_range = settings['input_range'] 270 | model.mean = settings['mean'] 271 | model.std = settings['std'] 272 | 273 | 274 | def se_resnext50_32x4d(pretrained_imagenet): 275 | model = BackBone(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 276 | dropout_p=None, inplanes=64, input_3x3=False, 277 | downsample_kernel_size=1, downsample_padding=0, 278 | ) 279 | if pretrained_imagenet: 280 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet'] 281 | initialize_pretrained_model(model, settings) 282 | return model 283 | 284 | 285 | def se_resnext101_32x4d(pretrained_imagenet): 286 | model = BackBone(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 287 | dropout_p=None, inplanes=64, input_3x3=False, 288 | downsample_kernel_size=1, downsample_padding=0, 289 | ) 290 | if pretrained_imagenet: 291 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet'] 292 | initialize_pretrained_model(model, settings) 293 | return model 294 | -------------------------------------------------------------------------------- /models/fpn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from . import backbone 8 | from .misc import filter_detections, bbox_transform_inv, clip_boxes, build_anchors 9 | from dataGen.targetBuild import anchors_for_shape 10 | 11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 12 | with open(config_path, 'r') as f: 13 | config = json.load(f) 14 | 15 | 16 | class top_down(nn.Module): 17 | """ Creates the FPN layers on top of the backbone features. 18 | 19 | Args 20 | C3 : Feature stage C3 from the backbone. 21 | C4 : Feature stage C4 from the backbone. 22 | C5 : Feature stage C5 from the backbone. 23 | feature_size : The feature size to use for the resulting feature levels. 24 | 25 | Returns 26 | A list of feature levels [P3, P4, P5, P6, P7]. 
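    For the default 512x512 input the returned maps have spatial sizes 64, 32, 16, 8 and 4, i.e. strides of 8, 16, 32, 64 and 128 with respect to the input, matching the default anchor strides in config.json.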
27 | """ 28 | 29 | def __init__(self, feature_size=256): 30 | super(top_down, self).__init__() 31 | self.C5_reduced = nn.Conv2d(2048, feature_size, kernel_size=1, stride=1) 32 | self.P5_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 33 | self.C4_reduced = nn.Conv2d(1024, feature_size, kernel_size=1, stride=1) 34 | self.P4_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 35 | self.C3_reduced = nn.Conv2d(512, feature_size, kernel_size=1, stride=1) 36 | self.P3_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 37 | self.P6_conv = nn.Conv2d(2048, feature_size, kernel_size=3, stride=2, padding=1) 38 | self.P7_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1) 39 | 40 | def forward(self, x): 41 | C3, C4, C5 = x 42 | 43 | # upsample C5 to get P5 from the FPN paper 44 | P5 = self.C5_reduced(C5) 45 | P5_upsampled = F.interpolate(P5, scale_factor=2, mode='nearest') 46 | P5 = self.P5_conv(P5) 47 | 48 | # add P5 elementwise to C4 49 | P4 = self.C4_reduced(C4) 50 | P4 = P5_upsampled + P4 51 | P4_upsampled = F.interpolate(P4, scale_factor=2, mode='nearest') 52 | P4 = self.P4_conv(P4) 53 | 54 | # add P4 elementwise to C3 55 | P3 = self.C3_reduced(C3) 56 | P3 = P4_upsampled + P3 57 | P3 = self.P3_conv(P3) 58 | 59 | # "P6 is obtained via a 3x3 stride-2 conv on C5" 60 | P6 = self.P6_conv(C5) 61 | 62 | # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6" 63 | P7 = F.relu(P6) 64 | P7 = self.P7_conv(P7) 65 | 66 | return [P3, P4, P5, P6, P7] 67 | 68 | 69 | class classification_subnet(nn.Module): 70 | """ the default classification submodel.""" 71 | options = { 72 | 'kernel_size': 3, 73 | 'stride': 1, 74 | 'padding': 1, 75 | } 76 | 77 | def __init__(self, num_classes=1, num_anchors=9, pyramid_feature_size=256, prior_probability=0.01): 78 | """ 79 | Args 80 | num_classes : Number of classes to predict a score for at each feature level. 81 | num_anchors : Number of anchors to predict classification scores for at each feature level. 82 | pyramid_feature_size : The number of filters to expect from the feature pyramid levels. 
83 | prior_probability : Prior probability for training stability in early training 84 | """ 85 | super().__init__() 86 | convs = [] 87 | for i in range(4): 88 | conv = nn.Conv2d(pyramid_feature_size, pyramid_feature_size, **classification_subnet.options) 89 | nn.init.normal_(conv.weight, mean=0.0, std=0.01) 90 | nn.init.zeros_(conv.bias) 91 | convs.append(conv) 92 | convs.append(nn.ReLU()) 93 | self.feats = nn.Sequential(*convs) 94 | self.num_classes = num_classes 95 | head = nn.Conv2d(pyramid_feature_size, out_channels=num_classes * num_anchors, **classification_subnet.options) 96 | nn.init.normal_(head.weight, mean=0.0, std=0.01) 97 | nn.init.constant_(head.bias, val=-math.log((1 - prior_probability) / prior_probability)) 98 | self.head = head 99 | 100 | def forward(self, x): 101 | outputs = self.feats(x) 102 | outputs = self.head(outputs) 103 | 104 | # reshape output and apply sigmoid 105 | outputs = outputs.permute(0, 2, 3, 1).contiguous() 106 | outputs = outputs.view(outputs.shape[0], -1, self.num_classes) 107 | outputs = torch.sigmoid(outputs) 108 | return outputs 109 | 110 | 111 | class regression_subnet(nn.Module): 112 | """ Creates the default regression submodel.""" 113 | options = { 114 | 'kernel_size': 3, 115 | 'stride': 1, 116 | 'padding': 1, 117 | } 118 | 119 | def __init__(self, num_values=4, num_anchors=9, pyramid_feature_size=256): 120 | """ 121 | Args 122 | num_values : Number of values to regress. 123 | num_anchors : Number of anchors to regress for each feature level. 124 | pyramid_feature_size : The number of filters to expect from the feature pyramid levels. 125 | """ 126 | super().__init__() 127 | self.num_values = num_values 128 | convs = [] 129 | for i in range(4): 130 | conv = nn.Conv2d(pyramid_feature_size, pyramid_feature_size, **regression_subnet.options) 131 | nn.init.normal_(conv.weight, mean=0.0, std=0.01) 132 | nn.init.zeros_(conv.bias) 133 | convs.append(conv) 134 | convs.append(nn.ReLU()) 135 | self.feats = nn.Sequential(*convs) 136 | 137 | head = nn.Conv2d(pyramid_feature_size, out_channels=num_anchors * num_values, **regression_subnet.options) 138 | nn.init.normal_(head.weight, mean=0.0, std=0.01) 139 | nn.init.zeros_(head.bias) 140 | self.head = head 141 | 142 | def forward(self, x): 143 | outputs = self.feats(x) 144 | outputs = self.head(outputs) 145 | 146 | # reshape 147 | outputs = outputs.permute(0, 2, 3, 1).contiguous() 148 | outputs = outputs.view(outputs.shape[0], -1, self.num_values) 149 | return outputs 150 | 151 | 152 | class retinanet(nn.Module): 153 | """ Construct a RetinaNet model on top of a backbone, without bbox prediction transform""" 154 | 155 | def __init__(self, backbone_name=None, num_classes=None, num_anchors=None, pretrained_imagenet=None): 156 | """ 157 | Args 158 | backbone_name : backbone name, `se_resnext50_32x4d` or `se_resnext101_32x4d` 159 | num_classes : Number of classes to classify. 160 | num_anchors : Number of base anchors. 
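            pretrained_imagenet : If True, load ImageNet weights for the backbone. Any argument left as None falls back to the corresponding value in config.json.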
161 | """ 162 | super().__init__() 163 | 164 | if backbone_name is None: 165 | backbone_name = config['backbone'] 166 | assert backbone_name in ['se_resnext50_32x4d', 'se_resnext101_32x4d'], \ 167 | "`se_resnext50_32x4d` or `se_resnext101_32x4d`" 168 | bottom_up = getattr(backbone, backbone_name) 169 | 170 | if pretrained_imagenet is None: 171 | pretrained_imagenet=config['pretrained_imagenet'] 172 | self.bottom_up = bottom_up(pretrained_imagenet) 173 | 174 | self.top_down = top_down(feature_size=256) 175 | 176 | if num_classes is None: 177 | num_classes = config['num_classes'] 178 | if num_anchors is None: 179 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default']) 180 | self.classification_subnet = classification_subnet(num_classes, num_anchors, 256, 0.01) 181 | self.regression_subnet = regression_subnet(4, num_anchors, 256) 182 | 183 | def forward(self, images): 184 | """ 185 | Args: 186 | images: Tensor of (B, 3, H, W), where B is the batch size; H, w is image height, width 187 | """ 188 | C3, C4, C5 = self.bottom_up(images) 189 | P3, P4, P5, P6, P7 = self.top_down((C3, C4, C5)) 190 | classification_output = [] 191 | regression_output = [] 192 | for P in [P3, P4, P5, P6, P7]: 193 | classification_output.append(self.classification_subnet(P)) 194 | regression_output.append(self.regression_subnet(P)) 195 | 196 | classification = torch.cat(classification_output, dim=1) 197 | regression = torch.cat(regression_output, dim=1) 198 | 199 | return classification, regression, [P3, P4, P5, P6, P7] 200 | 201 | def predict(self, images): 202 | """ 203 | Args: 204 | images: Tensor of (B, 3, H, W), where B is the batch size; H, W is image height, width 205 | 206 | Returns: 207 | list of [bboxes, labels, scores] per image 208 | """ 209 | # C3, C4, C5 = self.bottom_up(images) 210 | # P3, P4, P5, P6, P7 = self.top_down((C3, C4, C5)) 211 | # classification_output = [] 212 | # regression_output = [] 213 | # for P in [P3, P4, P5, P6, P7]: 214 | # classification_output.append(self.classification_subnet(P)) 215 | # regression_output.append(self.regression_subnet(P)) 216 | 217 | # classification = torch.cat(classification_output, dim=1) 218 | # regression = torch.cat(regression_output, dim=1) 219 | classification, regression, [P3, P4, P5, P6, P7] = self.__call__(images) 220 | print(classification.max().item()) 221 | anchors = build_anchors(features=[P3, P4, P5, P6, P7]) 222 | bboxes = bbox_transform_inv(anchors, regression) 223 | del anchors 224 | bboxes = clip_boxes(images, bboxes) 225 | 226 | return filter_detections(bboxes, classification) 227 | 228 | def train_extractor(self, active=True): 229 | if active: 230 | for p in self.bottom_up.parameters(): 231 | p.requires_grad = True 232 | self.bottom_up.train() 233 | else: 234 | for p in self.bottom_up.parameters(): 235 | p.requires_grad = False 236 | self.bottom_up.eval() 237 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 7 | with open(config_path, 'r') as f: 8 | config = json.load(f) 9 | # torch.device object used throughout this script 10 | device = torch.device("cuda" if config['use_cuda'] else "cpu") 11 | 12 | def focal_loss(alpha=0.25, gamma=2.0): 13 | """ Create a functor for computing the focal loss. 
14 | 15 | Args 16 | alpha: Scale the focal weight with alpha. 17 | gamma: Take the power of the focal weight with gamma. 18 | 19 | Returns 20 | A functor that computes the focal loss using the alpha and gamma. 21 | """ 22 | 23 | def _focal(y_true, y_pred): 24 | """ Compute the focal loss given the target tensor and the predicted tensor. 25 | 26 | As defined in https://arxiv.org/abs/1708.02002 27 | 28 | Args 29 | y_true: Tensor of target data from the generator with shape (B, N, num_classes+1). 30 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes). 31 | 32 | Returns 33 | The focal loss of y_pred w.r.t. y_true. 34 | """ 35 | labels = y_true[:, :, :-1] 36 | anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object 37 | classification = y_pred 38 | 39 | # filter out "ignore" anchors 40 | indices = anchor_state != -1 41 | labels = labels[indices] 42 | classification = classification[indices] 43 | 44 | # compute the focal loss 45 | alpha_factor = torch.where(labels == 1, torch.tensor(alpha, dtype=torch.float32, device=device), 46 | torch.tensor(1 - alpha, dtype=torch.float32, device=device)) 47 | focal_weight = torch.where(labels == 1, 1 - classification, classification) 48 | focal_weight = alpha_factor * focal_weight ** gamma 49 | 50 | cls_loss = focal_weight * F.binary_cross_entropy(classification, labels, reduction='none') 51 | # compute the normalizer: the number of positive anchors 52 | normalizer = torch.sum(anchor_state == 1) 53 | normalizer = normalizer.type(torch.float32) 54 | normalizer = torch.max(normalizer, torch.tensor(1.0, device=device)) 55 | 56 | return torch.sum(cls_loss) / normalizer 57 | 58 | return _focal 59 | 60 | 61 | def smooth_l1_loss(sigma=3.0): 62 | """ Create a smooth L1 loss functor. 63 | 64 | Args 65 | sigma: This argument defines the point where the loss changes from L2 to L1. 66 | 67 | Returns 68 | A functor for computing the smooth L1 loss given target data and predicted data. 69 | """ 70 | sigma_squared = sigma ** 2 71 | 72 | def _smooth_l1(y_true, y_pred): 73 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true. 74 | 75 | Args 76 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive). 77 | y_pred: Tensor from the network of shape (B, N, 4). 78 | 79 | Returns 80 | The smooth L1 loss of y_pred w.r.t. y_true. 
81 | """ 82 | # separate target and state 83 | regression = y_pred 84 | regression_target = y_true[:, :, :4] 85 | anchor_state = y_true[:, :, 4] 86 | 87 | # filter out "ignore" anchors and "bg" 88 | indices = anchor_state == 1 89 | regression = regression[indices] 90 | regression_target = regression_target[indices] 91 | 92 | # compute smooth L1 loss 93 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma 94 | # |x| - 0.5 / sigma / sigma otherwise 95 | regression_diff = regression - regression_target 96 | regression_diff = torch.abs(regression_diff) 97 | regression_loss = torch.where( 98 | torch.lt(regression_diff, 1.0 / sigma_squared), 99 | 0.5 * sigma_squared * torch.pow(regression_diff, 2), 100 | regression_diff - 0.5 / sigma_squared 101 | ) 102 | 103 | # compute the normalizer: the number of positive anchors 104 | normalizer = torch.max(torch.tensor(1, device=device), torch.sum(indices)) 105 | normalizer = normalizer.type(torch.float32) 106 | return torch.sum(regression_loss) / normalizer 107 | 108 | return _smooth_l1 109 | -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch 5 | from dataGen.targetBuild import generate_anchors, AnchorParameters 6 | 7 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 8 | with open(config_path, 'r') as f: 9 | config = json.load(f) 10 | # torch.device object used throughout this script 11 | device = torch.device("cuda" if config['use_cuda'] else "cpu") 12 | 13 | 14 | def shift(shape, stride, anchors): 15 | """ Produce shifted anchors based on shape of the map and stride size. 16 | 17 | Args 18 | shape : Shape to shift the anchors over (H, W). 19 | stride : Stride to shift the anchors with over the shape. 20 | anchors: The anchors to apply at each location, np.ndarry. 21 | """ 22 | H, W = shape 23 | shift_x = (torch.arange(W, dtype=torch.float32, device=device) + torch.tensor(0.5, dtype=torch.float32, device=device)) * stride 24 | shift_y = (torch.arange(H, dtype=torch.float32, device=device) + torch.tensor(0.5, dtype=torch.float32, device=device)) * stride 25 | shift_y, shift_x = torch.meshgrid([shift_y, shift_x]) 26 | shift_x = shift_x.contiguous().view(-1) 27 | shift_y = shift_y.contiguous().view(-1) 28 | shifts = torch.stack([shift_x, shift_y, shift_x, shift_y]) 29 | shifts = shifts.t() 30 | anchors = torch.tensor(anchors, dtype=torch.float32, device=device) 31 | 32 | shifted_anchors = torch.unsqueeze(anchors, 0) + torch.unsqueeze(shifts, 1) 33 | shifted_anchors = shifted_anchors.view(-1, 4) 34 | 35 | return shifted_anchors 36 | 37 | 38 | def build_anchors(features, anchor_params=None): 39 | # if no anchor parameters are passed, use default values 40 | if anchor_params is None: 41 | anchor_params = AnchorParameters.default 42 | ratios = anchor_params.ratios 43 | scales = anchor_params.scales 44 | sizes = anchor_params.sizes 45 | strides = anchor_params.strides 46 | anchors_bag = [] 47 | for feature, size, stride in zip(features, sizes, strides): 48 | shape = feature.shape[-2:] 49 | anchors = generate_anchors(size, ratios, scales) 50 | anchors_shift = shift(shape, stride, anchors) 51 | anchors_bag.append(anchors_shift) 52 | 53 | return torch.cat(anchors_bag, dim=0) 54 | 55 | 56 | def bbox_transform_inv(anchors, regression, mean=None, std=None): 57 | """ Applies deltas (usually regression results) to boxes (usually anchors). 
58 | 59 | Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be removed. 60 | The mean and std are the mean and std as applied in the generator. They are unnormalized in this function and then applied to the boxes. 61 | 62 | Args 63 | anchors : Tensor of shape (N, 4), N the number of boxes and 4 values for (x1, y1, x2, y2). 64 | regression: Tensor of (B, N, 4), where B is the batch size, N the number of boxes. 65 | These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height. 66 | mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]). 67 | std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]). 68 | 69 | Returns 70 | A Tensor of the same shape as boxes, but with deltas applied to each box. 71 | The mean and std are used during training to normalize the regression values (networks love normalization). 72 | """ 73 | 74 | if mean is None: 75 | mean = config['mean_bbox_transform'] 76 | if std is None: 77 | std = config['std_bbox_transform'] 78 | 79 | anchors = torch.unsqueeze(anchors, dim=0) # (1, N, 4) 80 | width = anchors[:, :, 2] - anchors[:, :, 0] 81 | height = anchors[:, :, 3] - anchors[:, :, 1] 82 | 83 | x1 = anchors[:, :, 0] + (regression[:, :, 0] * std[0] + mean[0]) * width 84 | y1 = anchors[:, :, 1] + (regression[:, :, 1] * std[1] + mean[1]) * height 85 | x2 = anchors[:, :, 2] + (regression[:, :, 2] * std[2] + mean[2]) * width 86 | y2 = anchors[:, :, 3] + (regression[:, :, 3] * std[3] + mean[3]) * height 87 | 88 | pred_boxes = torch.stack([x1, y1, x2, y2], dim=2) 89 | 90 | return pred_boxes 91 | 92 | 93 | def clip_boxes(images, boxes): 94 | shape = images.shape 95 | height = shape[-2] 96 | width = shape[-1] 97 | 98 | x1 = torch.clamp(boxes[:, :, 0], 0.0, width) 99 | y1 = torch.clamp(boxes[:, :, 1], 0.0, height) 100 | x2 = torch.clamp(boxes[:, :, 2], 0.0, width) 101 | y2 = torch.clamp(boxes[:, :, 3], 0.0, height) 102 | boxes_x1y1x2y2 = torch.stack([x1, y1, x2, y2], dim=2) 103 | boxes_x1y1x2y2 = boxes_x1y1x2y2.type(torch.int64) 104 | return boxes_x1y1x2y2 105 | 106 | 107 | def bbox_iou(box1, box2): 108 | """ 109 | Returns the IoU of two bounding boxes 110 | """ 111 | # Get the coordinates of bounding boxes 112 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 113 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 114 | 115 | # get the corrdinates of the intersection rectangle 116 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 117 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 118 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 119 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 120 | # Intersection area 121 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp( 122 | inter_rect_y2 - inter_rect_y1 + 1, min=0 123 | ) 124 | # Union Area 125 | b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) 126 | b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) 127 | 128 | iou = inter_area / (b1_area + b2_area - inter_area + 1e-16) 129 | 130 | return iou 131 | 132 | 133 | def non_max_suppression(boxes, scores, max_output_size=300, iou_threshold=0.5): 134 | """ 135 | Removes detections with lower object confidence score than 'conf_thres' and performs 136 | Non-Maximum Suppression to further filter detections. 
137 | Returns detections with shape: 138 | (x1, y1, x2, y2, object_conf, class_score, class_pred) 139 | """ 140 | 141 | # Sort the detections by maximum objectness confidence 142 | _, conf_sort_index = torch.sort(scores, descending=True) 143 | boxes = boxes[conf_sort_index] 144 | # Perform non-maximum suppression 145 | max_indexes = [] 146 | count = 0 147 | while boxes.shape[0] > 0: 148 | # Get detection with highest confidence and save as max detection 149 | max_detections = boxes[0].unsqueeze(0) # expand 1 dim 150 | max_indexes.append(conf_sort_index[0]) 151 | # Stop if we're at the last detection 152 | if boxes.shape[0] == 1: 153 | break 154 | # Get the IOUs for all boxes with lower confidence 155 | ious = bbox_iou(max_detections, boxes[1:]) 156 | # Remove detections with IoU >= NMS threshold 157 | boxes = boxes[1:][ious < iou_threshold] 158 | conf_sort_index = conf_sort_index[1:][ious < iou_threshold] 159 | # break when get enough bboxes 160 | count += 1 161 | if count >= max_output_size: 162 | break 163 | 164 | # max_detections = torch.cat(max_detections).data 165 | max_indexes = torch.stack(max_indexes).data 166 | return max_indexes 167 | 168 | 169 | def filter_detections( 170 | boxes, 171 | classification, 172 | class_specific_filter=True, 173 | nms=True, 174 | score_threshold=0.01, 175 | max_detections=300, 176 | nms_threshold=0.5 177 | ): 178 | """ Filter detections using the boxes and classification values. 179 | 180 | Args 181 | boxes : Tensor of shape (B, num_boxes, 4) containing the boxes in (x1, y1, x2, y2) format. 182 | classification : Tensor of shape (B, num_boxes, num_classes) containing the classification scores. 183 | class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those. 184 | nms : Flag to enable/disable non maximum suppression. 185 | score_threshold : Threshold used to prefilter the boxes with. 186 | max_detections : Maximum number of detections to keep. 187 | nms_threshold : Threshold for the IoU value to determine when a box should be suppressed. 188 | 189 | Returns 190 | A list of [boxes, scores, labels, other[0], other[1], ...]. 191 | boxes is shaped (max_detections, 4) and contains the (x1, y1, x2, y2) of the non-suppressed boxes. 192 | scores is shaped (max_detections,) and contains the scores of the predicted class. 193 | labels is shaped (max_detections,) and contains the predicted label. 194 | other[i] is shaped (max_detections, ...) and contains the filtered other[i] data. 195 | In case there are less than max_detections detections, the tensors are padded with -1's. 
196 | """ 197 | 198 | def _filter_detections(boxes, scores, labels): 199 | # threshold based on score 200 | indices = torch.gt(scores, score_threshold).nonzero() 201 | if indices.shape[0] == 0: 202 | return torch.tensor([], dtype=torch.int64, device=device) 203 | indices = indices[:, 0] 204 | 205 | if nms: 206 | filtered_boxes = torch.index_select(boxes, 0, indices) 207 | filtered_scores = torch.index_select(scores, 0, indices) 208 | 209 | # perform NMS 210 | nms_indices = non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections, 211 | iou_threshold=nms_threshold) 212 | 213 | # filter indices based on NMS 214 | indices = torch.index_select(indices, 0, nms_indices) 215 | 216 | # add indices to list of all indices 217 | labels = torch.index_select(labels, 0, indices) 218 | indices = torch.stack([indices, labels], dim=1) 219 | 220 | return indices 221 | 222 | results = [] 223 | for box_cur, classification_cur in zip(boxes, classification): 224 | if class_specific_filter: 225 | all_indices = [] 226 | # perform per class filtering 227 | for c in range(int(classification_cur.shape[1])): 228 | scores = classification_cur[:, c] 229 | labels = torch.full_like(scores, c, dtype=torch.int64) 230 | all_indices.append(_filter_detections(box_cur, scores, labels)) 231 | 232 | # concatenate indices to single tensor 233 | indices = torch.cat(all_indices, dim=0) 234 | else: 235 | scores, labels = torch.max(classification_cur, dim=1) 236 | indices = _filter_detections(box_cur, scores, labels) 237 | 238 | if indices.shape[0] == 0: 239 | results.append({'bboxes':np.zeros((0, 4)), 'scores': np.full((0, ), -1, dtype=np.float32),'category_id': np.full((0, ), -1, dtype=np.int64)}) 240 | continue 241 | # select top k 242 | scores = classification_cur[indices[:, 0], indices[:, 1]] 243 | labels = indices[:, 1] 244 | indices = indices[:, 0] 245 | 246 | scores, top_indices = torch.topk(scores, k=min(max_detections, scores.shape[0])) 247 | # filter input using the final set of indices 248 | indices = indices[top_indices] 249 | box_cur = box_cur[indices] 250 | labels = labels[top_indices] 251 | results.append({'bboxes':box_cur.cpu().detach().numpy(),'scores': scores.cpu().detach().numpy(), 'category_id': labels.cpu().detach().numpy()}) 252 | 253 | return results 254 | -------------------------------------------------------------------------------- /pretrainedmodels/.gitignore: -------------------------------------------------------------------------------- 1 | se_resnext50_32x4d-a260b3a4.pth 2 | se_resnext101_32x4d-3b2fe3d8.pth -------------------------------------------------------------------------------- /pretrainedmodels/download_here.txt: -------------------------------------------------------------------------------- 1 | https://github.com/Cadene/pretrained-models.pytorch 2 | 3 | pretrained_settings = { 4 | 'se_resnext50_32x4d': { 5 | 'imagenet': { 6 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 7 | 'input_space': 'RGB', 8 | 'input_size': [3, 224, 224], 9 | 'input_range': [0, 1], 10 | 'mean': [0.485, 0.456, 0.406], 11 | 'std': [0.229, 0.224, 0.225], 12 | 'num_classes': 1000 13 | } 14 | }, 15 | 'se_resnext101_32x4d': { 16 | 'imagenet': { 17 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 18 | 'input_space': 'RGB', 19 | 'input_size': [3, 224, 224], 20 | 'input_range': [0, 1], 21 | 'mean': [0.485, 0.456, 0.406], 22 | 'std': [0.229, 0.224, 0.225], 23 | 'num_classes': 1000 24 | } 25 | }, 26 | } 
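A minimal sketch (not part of the repository) of how the two backbone weight files listed above could be fetched into `pretrainedmodels/` with the Python standard library, assuming the URLs are still reachable and that `models/backbone.py` looks for the weights in this folder:

import os
from urllib.request import urlretrieve

# URLs and target filenames taken from the settings above and pretrainedmodels/.gitignore
weights = {
    'se_resnext50_32x4d-a260b3a4.pth': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
    'se_resnext101_32x4d-3b2fe3d8.pth': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
}

os.makedirs('pretrainedmodels', exist_ok=True)
for fname, url in weights.items():
    dst = os.path.join('pretrainedmodels', fname)
    if not os.path.exists(dst):
        urlretrieve(url, dst)  # simple blocking download, no progress bar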
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.2.1 2 | setuptools==39.1.0 3 | torch==0.4.1 4 | Cython==0.29 5 | pydicom==1.2.0 6 | opencv_python==3.4.3.18 7 | matplotlib==2.2.3 8 | albumentations==0.1.7 9 | numpy==1.15.2 10 | tqdm==4.26.0 11 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python dataGen/setup.py build_ext --inplace 4 | rm -rf build dataGen/compute_overlap.c .eggs 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_boxinv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import pydicom 5 | import torch 6 | import numpy as np 7 | from albumentations import Resize, Compose 8 | from dataGen.data_loader import RsnaDataset, fetch_val_loader 9 | from models.misc import build_anchors, bbox_transform_inv 10 | 11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 12 | with open(config_path, 'r') as f: 13 | config = json.load(f) 14 | 15 | def test_boxtfsinv(): 16 | def load_dicom(img_id): 17 | image_path = os.path.join(os.path.dirname(__file__), '..', 'dataset/stage_2_train_images', img_id+'.dcm') 18 | ds = pydicom.read_file(image_path) 19 | image = ds.pixel_array 20 | # If grayscale. Convert to RGB for consistency. 
21 | if len(image.shape) != 3 or image.shape[2] != 3: 22 | image = np.stack((image,) * 3, -1) 23 | return image 24 | 25 | kfold = 1 26 | valfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valfps_{}.pkl'.format(kfold)) 27 | with open(valfps_path, 'rb') as f: 28 | valfps = pickle.load(f) 29 | 30 | bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valbox_{}.pkl'.format(kfold)) 31 | with open(bboxdict_path, 'rb') as f: 32 | bboxdict = pickle.load(f) 33 | 34 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'vallabel_{}.pkl'.format(kfold)) 35 | with open(labeldict_path, 'rb') as f: 36 | labeldict = pickle.load(f) 37 | sample = None 38 | for i, nm in enumerate(valfps): 39 | if len(labeldict[nm]) > 1: 40 | sample = nm 41 | idx = i 42 | break 43 | # 00436515-870c-4b36-a041-de91049b9ab4 44 | img = load_dicom(sample) 45 | # [[264, 152, 476, 530], [562, 152, 817, 604]] 46 | bboxes = bboxdict[sample] 47 | # [0, 0] 48 | labels = labeldict[sample] 49 | assert img.shape == (1024, 1024, 3) 50 | assert len(bboxes) > 0 51 | assert len(bboxes[0]) == 4 52 | assert len(labels) == len(bboxes) 53 | 54 | val_aug = [Resize(*config['image_shape'], p=1.0)] 55 | dt = RsnaDataset(valfps, bboxdict, labeldict, aug=val_aug) 56 | sample = dt[idx] 57 | assert len(sample) == 3 58 | assert sample[0].shape == (3, 224, 224) 59 | # when `config['image_shape']` == (224, 224) 60 | length = (28*28+14*14+7*7+4*4+2*2)*9 61 | assert sample[1].shape == (length, 2) 62 | # assert sample[1][:, 0].sum().item() == 2 63 | pos_label = (sample[1][:, 1] == 1).sum().item() 64 | ignore_label = (sample[1][:, 1] == -1).sum().item() 65 | neg_label = (sample[1][:, 1] == 0).sum().item() 66 | assert length == pos_label + ignore_label + neg_label 67 | 68 | pos_reg = (sample[2][..., -1] == 1).sum().item() 69 | ignore_reg = (sample[2][..., -1] == -1).sum().item() 70 | neg_reg = (sample[2][..., -1] == 0).sum().item() 71 | assert length == pos_reg + ignore_reg + neg_reg 72 | 73 | assert pos_label == pos_reg == 64 74 | assert ignore_label == ignore_reg == 109 75 | assert neg_label == neg_reg == 9268 76 | 77 | regression = sample[2] 78 | assert regression.shape == (length, 5) 79 | features = [] 80 | for sz in [28, 14, 7, 4, 2]: 81 | features.append(np.empty(shape=(config['batch_size'], 256, sz, sz))) 82 | anchors = build_anchors(features) 83 | assert anchors.shape == (length, 4) 84 | regression = regression[None, :, :4] 85 | assert regression.shape == (1, length, 4) 86 | bboxes_pred = bbox_transform_inv(anchors, regression)[0] 87 | 88 | # resize bbox 89 | bboxes = np.array(bboxes, dtype=np.float32) * 224 / 1024 90 | bboxes_pred = bboxes_pred[sample[1][:, -1]==1].numpy() 91 | assert bboxes_pred.shape == (pos_label, 4) 92 | assert bboxes.shape == (2, 4) 93 | assert (bboxes_pred[:, 0] == bboxes[0, 0]).sum() > 0 94 | assert (bboxes_pred[:, 1] == bboxes[0, 1]).sum() > 0 95 | assert (bboxes_pred[:, 2] == bboxes[0, 2]).sum() > 0 96 | assert (bboxes_pred[:, 3] == bboxes[0, 3]).sum() > 0 97 | 98 | assert (bboxes_pred[:, 0] == bboxes[1, 0]).sum() > 0 99 | assert (bboxes_pred[:, 1] == bboxes[1, 1]).sum() > 0 100 | assert (bboxes_pred[:, 2] == bboxes[1, 2]).sum() > 0 101 | assert (bboxes_pred[:, 3] == bboxes[1, 3]).sum() > 0 102 | -------------------------------------------------------------------------------- /tests/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import torch 5 | import 
numpy as np 6 | from dataGen.data_loader import load_dicom, fetch_trn_loader, fetch_val_loader 7 | 8 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 9 | with open(config_path, 'r') as f: 10 | config = json.load(f) 11 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default']) 12 | length = (28*28+14*14+7*7+4*4+2*2)*num_anchors 13 | 14 | def test_loaddicm(): 15 | ids = os.listdir(os.path.join(os.path.dirname(__file__), '..', config['dicom_train'])) 16 | assert len(ids) > 0 17 | imgid = random.sample(ids, 1)[0] 18 | img = load_dicom(os.path.splitext(imgid)[0]) 19 | assert img.shape == (1024, 1024, 3) 20 | 21 | 22 | def test_dataloader(): 23 | for i in range(1, 6): 24 | dl = iter(fetch_trn_loader(i)) 25 | img_batch, labels_batch, regression_batch = next(dl) 26 | assert img_batch.shape == (config['batch_size'], 3, *config['image_shape']) 27 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1) 28 | assert regression_batch.shape == (config['batch_size'], length, 4+1) 29 | 30 | for i in range(1, 6): 31 | dl = iter(fetch_val_loader(i)) 32 | img_batch, labels_batch, regression_batch = next(dl) 33 | assert img_batch.shape == (config['batch_size'], 3, *config['image_shape']) 34 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1) 35 | assert regression_batch.shape == (config['batch_size'], length, 4+1) 36 | -------------------------------------------------------------------------------- /tests/test_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import json 4 | import torch 5 | import numpy as np 6 | from models.misc import filter_detections 7 | 8 | 9 | def test_simple(): 10 | # create simple FilterDetections layer 11 | 12 | # create simple input 13 | boxes = np.array([[ 14 | [0, 0, 10, 10], 15 | [0, 0, 10, 10], # this will be suppressed 16 | ]], dtype=np.float) 17 | boxes = torch.tensor(boxes) 18 | 19 | classification = np.array([[ 20 | [0, 0.9], # this will be suppressed 21 | [0, 1], 22 | ]], dtype=np.float) 23 | classification = torch.tensor(classification) 24 | 25 | # compute output 26 | results = filter_detections(boxes, classification) 27 | actual_boxes = results[0][0] 28 | actual_scores = results[0][1] 29 | actual_labels = results[0][2] 30 | 31 | # define expected output 32 | expected_boxes = np.array([[0, 0, 10, 10]], dtype=np.float) 33 | 34 | expected_scores = np.array([1], dtype=np.float) 35 | 36 | expected_labels = np.array([1], dtype=np.float) 37 | 38 | # assert actual and expected are equal 39 | np.testing.assert_array_equal(actual_boxes, expected_boxes) 40 | np.testing.assert_array_equal(actual_scores, expected_scores) 41 | np.testing.assert_array_equal(actual_labels, expected_labels) 42 | 43 | 44 | def test_mini_batch(): 45 | # create simple FilterDetections layer 46 | 47 | # create input with batch_size=2 48 | boxes = np.array([ 49 | [ 50 | [0, 0, 10, 10], # this will be suppressed 51 | [0, 0, 10, 10], 52 | ], 53 | [ 54 | [100, 100, 150, 150], 55 | [100, 100, 150, 150], # this will be suppressed 56 | ], 57 | ], dtype=np.float) 58 | boxes = torch.tensor(boxes) 59 | 60 | classification = np.array([ 61 | [ 62 | [0, 0.9], # this will be suppressed 63 | [0, 1], 64 | ], 65 | [ 66 | [1, 0], 67 | [0.9, 0], # this will be suppressed 68 | ], 69 | ], dtype=np.float) 70 | classification = torch.tensor(classification) 71 | 72 | # compute output 73 | results = filter_detections(boxes, 
classification) 74 | 75 | 76 | # define expected output 77 | expected_boxes0 = np.array([[0, 0, 10, 10]], dtype=np.float) 78 | expected_boxes1 = np.array([[100, 100, 150, 150]], dtype=np.float) 79 | 80 | expected_scores0 = np.array([1], dtype=np.float) 81 | expected_scores1 = np.array([1], dtype=np.float) 82 | 83 | expected_labels0 = np.array([1], dtype=np.float) 84 | expected_labels1 = np.array([0], dtype=np.float) 85 | 86 | # assert actual and expected are equal 87 | np.testing.assert_array_equal(results[0][0], expected_boxes0) 88 | np.testing.assert_array_equal(results[0][1], expected_scores0) 89 | np.testing.assert_array_equal(results[0][2], expected_labels0) 90 | 91 | # assert actual and expected are equal 92 | np.testing.assert_array_equal(results[1][0], expected_boxes1) 93 | np.testing.assert_array_equal(results[1][1], expected_scores1) 94 | np.testing.assert_array_equal(results[1][2], expected_labels1) 95 | -------------------------------------------------------------------------------- /tests/test_fpn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from models.fpn import top_down, classification_subnet, regression_subnet 6 | 7 | def test_top_down(): 8 | model = top_down() 9 | C3 = torch.randn(4, 512, 28, 28) 10 | C4 = torch.randn(4, 1024, 14, 14) 11 | C5 = torch.randn(4, 2048, 7, 7) 12 | 13 | P3, P4, P5, P6, P7 = model((C3, C4, C5)) 14 | assert P3.shape == (4, 256, 28, 28) 15 | assert P4.shape == (4, 256, 14, 14) 16 | assert P5.shape == (4, 256, 7, 7) 17 | assert P6.shape == (4, 256, 4, 4) 18 | assert P7.shape == (4, 256, 2, 2) 19 | 20 | # ---------------------------------- 21 | C3 = torch.randn(4, 512, 64, 64) 22 | C4 = torch.randn(4, 1024, 32, 32) 23 | C5 = torch.randn(4, 2048, 16, 16) 24 | 25 | P3, P4, P5, P6, P7 = model((C3, C4, C5)) 26 | assert P3.shape == (4, 256, 64, 64) 27 | assert P4.shape == (4, 256, 32, 32) 28 | assert P5.shape == (4, 256, 16, 16) 29 | assert P6.shape == (4, 256, 8, 8) 30 | assert P7.shape == (4, 256, 4, 4) 31 | 32 | 33 | def test_classification_subnet(): 34 | model = classification_subnet() 35 | P3 = torch.randn(4, 256, 28, 28) 36 | feat3 = model(P3) 37 | P4 = torch.randn(4, 256, 14, 14) 38 | feat4 = model(P4) 39 | P5 = torch.randn(4, 256, 7, 7) 40 | feat5 = model(P5) 41 | P6 = torch.randn(4, 256, 4, 4) 42 | feat6 = model(P6) 43 | P7 = torch.randn(4, 256, 2, 2) 44 | feat7 = model(P7) 45 | assert feat3.shape == (4, 9*28*28, 1) 46 | assert feat4.shape == (4, 9*14*14, 1) 47 | assert feat5.shape == (4, 9*7*7, 1) 48 | assert feat6.shape == (4, 9*4*4, 1) 49 | assert feat7.shape == (4, 9*2*2, 1) 50 | 51 | assert len(list(model.children())) == 2 52 | 53 | assert isinstance(model.head, nn.Conv2d) 54 | assert model.head.weight.shape == (9, 256, 3, 3) 55 | np.testing.assert_almost_equal(model.head.weight.mean().item(), 0, decimal=2) 56 | np.testing.assert_almost_equal(model.head.weight.std().item(), 0.01, decimal=2) 57 | prior_probability=0.01 58 | np.testing.assert_almost_equal(model.head.bias.data.numpy(), -math.log((1 - prior_probability) / prior_probability)) 59 | 60 | 61 | def test_regression_subnet(): 62 | model = regression_subnet() 63 | P3 = torch.randn(4, 256, 28, 28) 64 | feat3 = model(P3) 65 | P4 = torch.randn(4, 256, 14, 14) 66 | feat4 = model(P4) 67 | P5 = torch.randn(4, 256, 7, 7) 68 | feat5 = model(P5) 69 | P6 = torch.randn(4, 256, 4, 4) 70 | feat6 = model(P6) 71 | P7 = torch.randn(4, 256, 2, 2) 72 | feat7 = model(P7) 73 | 
assert feat3.shape == (4, 9*28*28, 4) 74 | assert feat4.shape == (4, 9*14*14, 4) 75 | assert feat5.shape == (4, 9*7*7, 4) 76 | assert feat6.shape == (4, 9*4*4, 4) 77 | assert feat7.shape == (4, 9*2*2, 4) 78 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from models.losses import focal_loss, smooth_l1_loss 4 | 5 | def test_focal(): 6 | focal = focal_loss() 7 | y_pred = torch.rand(8, 3, 1) 8 | y_true = torch.rand(8, 3, 2) 9 | y_true[..., -1] = torch.tensor([1, -1, 0]) 10 | assert (y_true[..., -1] == torch.tensor([1, -1, 0], dtype=torch.float32)).all() 11 | assert focal(y_true, y_pred) 12 | 13 | 14 | def test_smooth_l1(): 15 | smooth_l1 = smooth_l1_loss() 16 | y_pred = torch.rand(8, 3, 4) 17 | y_true = torch.rand(8, 3, 5) 18 | y_true[..., -1] = torch.tensor([1, -1, 0]) 19 | assert (y_true[..., -1] == torch.tensor([1, -1, 0], dtype=torch.float32)).all() 20 | assert smooth_l1(y_true, y_pred) 21 | -------------------------------------------------------------------------------- /tests/test_retinanet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from models.fpn import retinanet 7 | from dataGen.data_loader import load_dicom, fetch_trn_loader, fetch_val_loader 8 | from models.losses import focal_loss, smooth_l1_loss 9 | 10 | 11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json') 12 | with open(config_path, 'r') as f: 13 | config = json.load(f) 14 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default']) 15 | length = (28*28+14*14+7*7+4*4+2*2)*num_anchors 16 | 17 | def test_retinanet(): 18 | model = retinanet() 19 | images = torch.randn(8, 3, 224, 224) 20 | classification, regression = model(images) 21 | assert classification.shape == (8, length, 1) 22 | assert regression.shape == (8, length, 4) 23 | 24 | results = model.predict(images) 25 | assert len(results) == 8 26 | assert len(results[0]) == 3 27 | for res in results: 28 | assert res[0].shape[0] == res[1].shape[0] == res[2].shape[0] 29 | assert res[0].shape[1] == 4 30 | assert len(res[1].shape) == 1 31 | assert len(res[2].shape) == 1 32 | assert res[0].dtype == np.float32 33 | assert res[1].dtype == np.float32 34 | assert res[2].dtype == np.int64 35 | 36 | 37 | def test_merged(): 38 | model = retinanet() 39 | model = model.cuda() 40 | for k in range(1, 6): 41 | dl = iter(fetch_trn_loader(k)) 42 | for i in range(100): 43 | img_batch, labels_batch, regression_batch = next(dl) 44 | img_batch = img_batch.cuda() 45 | labels_batch = labels_batch.cuda() 46 | regression_batch = regression_batch.cuda() 47 | 48 | classification, regression = model(img_batch) 49 | assert classification.shape == (config['batch_size'], length, config['num_classes']) 50 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1) 51 | assert regression.shape == (config['batch_size'], length, 4) 52 | assert regression_batch.shape == (config['batch_size'], length, 4+1) 53 | 54 | focal = focal_loss() 55 | smooth_l1 = smooth_l1_loss() 56 | assert focal(labels_batch, classification).shape == torch.Size([]) 57 | assert smooth_l1(regression_batch, regression).shape == torch.Size([]) 58 | 59 | results = model.predict(img_batch) 60 | 61 | for k in range(1, 6): 62 | dl = 
iter(fetch_val_loader(k))
63 |         for i in range(100):
64 |             img_batch, labels_batch, regression_batch = next(dl)
65 |             img_batch = img_batch.cuda()
66 |             labels_batch = labels_batch.cuda()
67 |             regression_batch = regression_batch.cuda()
68 |
69 |             classification, regression = model(img_batch)
70 |             assert classification.shape == (config['batch_size'], length, config['num_classes'])
71 |             assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1)
72 |             assert regression.shape == (config['batch_size'], length, 4)
73 |             assert regression_batch.shape == (config['batch_size'], length, 4+1)
74 |
75 |             focal = focal_loss()
76 |             smooth_l1 = smooth_l1_loss()
77 |             assert focal(labels_batch, classification).shape == torch.Size([])
78 |             assert smooth_l1(regression_batch, regression).shape == torch.Size([])
79 |
80 |             results = model.predict(img_batch)
81 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Train the model"""
2 | import sys
3 | import json
4 | import argparse
5 | import logging
6 | import os
7 | import numpy as np
8 | import torch
9 | from torch import nn, optim
10 | from torch.nn.utils import clip_grad_norm_
11 | from tqdm import tqdm
12 | from models.losses import focal_loss, smooth_l1_loss
13 | from models.fpn import retinanet
14 | from dataGen.data_loader import fetch_trn_loader, fetch_val_loader
15 | import utils
16 |
17 | with open('config.json', 'r') as f:
18 |     config = json.load(f)
19 | # torch.device object used throughout this script
20 | device = torch.device("cuda" if config['use_cuda'] else "cpu")
21 |
22 | def parse_args(args=None):
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('--fold', required=True, choices=[1, 2, 3, 4, 5], type=int, help="index of the cross-validation fold to train on")
25 |     parser.add_argument('--learning_rate', default=1e-5, type=float, help="learning rate of optimizer")
26 |     parser.add_argument('--num_epochs', default=30, type=int, help="total epochs to train")
27 |     parser.add_argument('--frozen_epochs', default=20, type=int, help="number of initial epochs during which the backbone parameters are frozen")
28 |     parser.add_argument('--save_dir', default='checkpoints', type=str, help="directory where checkpoints and training logs are saved")
29 |     parser.add_argument('--checkpoint2load', default=None, type=str, help="checkpoint to load") # 'best' or 'train'
30 |     parser.add_argument('--optim_restore', default=True, type=bool, help="whether to restore optimizer parameter")
31 |     return parser.parse_args(args)
32 |
33 |
34 | def train(model, optimizer, loss_fn, dataloader, params, epoch):
35 |     """Train the model for one epoch
36 |
37 |     Args:
38 |         model: (torch.nn.Module) the neural network
39 |         optimizer: (torch.optim) optimizer for parameters of model
40 |         loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
41 |         dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
42 |         params: (argparse.Namespace) command-line arguments returned by `parse_args`
43 |         epoch: (int) index of the current epoch, used to decide whether the backbone stays frozen
44 |
45 |     """
46 |     loss_TOTAL = utils.RunningAverage()
47 |     loss_FL = utils.RunningAverage()
48 |     loss_L1 = utils.RunningAverage()
49 |
50 |     model.train()
51 |     if epoch < params.frozen_epochs:
52 |         model.train_extractor(False)
53 |     else:
54 |         model.train_extractor(True)
55 |
56 |     with 
tqdm(total=len(dataloader)) as t: 57 | for i, (img_batch, labels_batch, regression_batch) in enumerate(dataloader): 58 | img_batch, labels_batch, regression_batch = img_batch.to(device), labels_batch.to(device), regression_batch.to(device) 59 | classification_pred, regression_pred, _ = model(img_batch) 60 | loss_cls = loss_fn['focal'](labels_batch, classification_pred) * config['loss_ratio_FL2L1'] 61 | loss_reg = loss_fn['smooth_l1'](regression_batch, regression_pred) 62 | loss_all = loss_cls + loss_reg 63 | 64 | loss_cls_detach = loss_cls.detach().item() 65 | loss_reg_detach = loss_reg.detach().item() 66 | loss_all_detach = loss_all.detach().item() 67 | # clear previous gradients, compute gradients of all variables wrt loss 68 | optimizer.zero_grad() 69 | loss_all.backward() 70 | # The norm is computed over all gradients together 71 | clip_grad_norm_(model.parameters(), 0.5) 72 | # performs updates using calculated gradients 73 | optimizer.step() 74 | 75 | # update the average loss 76 | loss_TOTAL.update(loss_all_detach) 77 | loss_FL.update(loss_cls_detach) 78 | loss_L1.update(loss_reg_detach) 79 | 80 | del img_batch, labels_batch, regression_batch 81 | 82 | t.set_postfix(total_loss='{:05.3f}'.format(loss_all_detach), FL_loss='{:05.3f}'.format( 83 | loss_cls_detach), L1_loss='{:05.3f}'.format(loss_reg_detach)) 84 | t.update() 85 | logging.info("total_loss:{:05.3f} FL_loss:{:05.3f} L1_loss:{:05.3f}".format(loss_TOTAL(), loss_FL(), loss_L1())) 86 | del loss_TOTAL, loss_FL, loss_L1 87 | 88 | 89 | def evaluate(model, loss_fn, val_dataloader, params, epoch): 90 | # set model to evaluation mode 91 | model.eval() 92 | 93 | loss_TOTAL = utils.RunningAverage() 94 | loss_FL = utils.RunningAverage() 95 | loss_L1 = utils.RunningAverage() 96 | 97 | with torch.no_grad(): 98 | for i, (img_batch, labels_batch, regression_batch) in enumerate(val_dataloader): 99 | img_batch, labels_batch, regression_batch = img_batch.to(device), labels_batch.to(device), regression_batch.to(device) 100 | classification_pred, regression_pred, _ = model(img_batch) 101 | loss_cls = loss_fn['focal'](labels_batch, classification_pred) * config['loss_ratio_FL2L1'] 102 | loss_reg = loss_fn['smooth_l1'](regression_batch, regression_pred) 103 | loss_all = loss_cls + loss_reg 104 | 105 | loss_cls_detach = loss_cls.detach().item() 106 | loss_reg_detach = loss_reg.detach().item() 107 | loss_all_detach = loss_all.detach().item() 108 | 109 | # update the average loss 110 | loss_TOTAL.update(loss_all_detach) 111 | loss_FL.update(loss_cls_detach) 112 | loss_L1.update(loss_reg_detach) 113 | 114 | del img_batch, labels_batch, regression_batch 115 | 116 | logging.info("total_loss:{:05.3f} FL_loss:{:05.3f} L1_loss:{:05.3f}".format(loss_TOTAL(), loss_FL(), loss_L1())) 117 | res = loss_TOTAL() 118 | del loss_TOTAL, loss_FL, loss_L1 119 | return res 120 | 121 | def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, params, 122 | scheduler=None): 123 | """Train the model and evaluate every epoch. 
124 | 125 | Args: 126 | model: (torch.nn.Module) the neural network 127 | train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data 128 | val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data 129 | optimizer: (torch.optim) optimizer for parameters of model 130 | loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch 131 | metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch 132 | params: (Params) hyperparameters 133 | """ 134 | init_epoch = 0 135 | best_val_loss = float('inf') 136 | 137 | # reload weights from restore_file if specified 138 | if params.checkpoint2load is not None: 139 | checkpoint = utils.load_checkpoint(params.checkpoint2load, model, optimizer if params.optim_restore else None) 140 | if 'epoch' in checkpoint: 141 | init_epoch = checkpoint['epoch'] 142 | if 'best_val_loss' in checkpoint: 143 | best_val_loss = checkpoint['best_val_loss'] 144 | 145 | for epoch in range(init_epoch, params.num_epochs): 146 | # Run one epoch 147 | logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs)) 148 | if scheduler is not None: 149 | scheduler.step() 150 | # compute number of batches in one epoch (one full pass over the training set) 151 | train(model, optimizer, loss_fn, train_dataloader, params, epoch) 152 | 153 | logging.info("validating ... ") 154 | # Evaluate for one epoch on validation set 155 | val_loss = evaluate(model, loss_fn, val_dataloader, params, epoch) 156 | 157 | is_best = val_loss <= best_val_loss 158 | if is_best: 159 | best_val_loss = val_loss 160 | # Save weights 161 | utils.save_checkpoint({'epoch': epoch + 1, 162 | 'state_dict': model.state_dict(), 163 | 'optim_dict': optimizer.state_dict(), 164 | 'best_val_loss': best_val_loss}, 165 | is_best=is_best, 166 | checkpoint=params.save_dir) 167 | 168 | 169 | if __name__ == '__main__': 170 | # Load the parameters from json file 171 | args = parse_args() 172 | 173 | # Set the random seed for reproducible experiments 174 | torch.manual_seed(42) 175 | if config['use_cuda']: 176 | torch.cuda.manual_seed(42) 177 | 178 | args.save_dir = os.path.join(args.save_dir, config['backbone'], f'fold_{args.fold}') 179 | # Set the logger 180 | if not os.path.isdir(args.save_dir): 181 | os.makedirs(args.save_dir) 182 | 183 | # save config in file 184 | with open(os.path.join(args.save_dir, 'config.json'), 'w') as f: 185 | config.update(vars(args)) 186 | json.dump(config, f, indent=4) 187 | 188 | utils.set_logger(os.path.join(args.save_dir, 'train.log')) 189 | logging.info(' '.join(sys.argv[:])) 190 | logging.info(args.save_dir) 191 | 192 | # Create the input data pipeline 193 | logging.info("Loading the datasets...") 194 | # fetch dataloaders 195 | train_dl = fetch_trn_loader(args.fold) 196 | val_dl = fetch_val_loader(args.fold) 197 | 198 | # Define the model and optimizer 199 | Net = retinanet(config['backbone']) 200 | model = Net.to(device) 201 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 202 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 203 | # fetch loss function and metrics 204 | loss_fn = {'focal': focal_loss(alpha=config['focal_alpha']), 'smooth_l1': smooth_l1_loss(sigma=config['l1_sigma'])} 205 | 206 | # Train the model 207 | logging.info("Starting training for {} epoch(s)".format(args.num_epochs)) 208 | train_and_evaluate(model, train_dl, val_dl, optimizer, loss_fn, args, scheduler=None) 209 | 
logging.info('Done') 210 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import cv2 9 | 10 | class RunningAverage(): 11 | """A simple class that maintains the running average of a quantity 12 | 13 | Example: 14 | ``` 15 | loss_avg = RunningAverage() 16 | loss_avg.update(2) 17 | loss_avg.update(4) 18 | loss_avg() = 3 19 | ``` 20 | """ 21 | def __init__(self): 22 | self.steps = 0 23 | self.total = 0 24 | 25 | def update(self, val): 26 | self.total += val 27 | self.steps += 1 28 | 29 | def __call__(self): 30 | return self.total/float(self.steps) 31 | 32 | 33 | def set_logger(log_path): 34 | """Set the logger to log info in terminal and file `log_path`. 35 | 36 | In general, it is useful to have a logger so that every output to the terminal is saved 37 | in a permanent file. Here we save it to `model_dir/train.log`. 38 | 39 | Example: 40 | ``` 41 | logging.info("Starting training...") 42 | ``` 43 | 44 | Args: 45 | log_path: (string) where to log 46 | """ 47 | logger = logging.getLogger() 48 | logger.setLevel(logging.INFO) 49 | 50 | if not logger.handlers: 51 | # Logging to a file 52 | file_handler = logging.FileHandler(log_path) 53 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 54 | logger.addHandler(file_handler) 55 | 56 | # Logging to console 57 | stream_handler = logging.StreamHandler() 58 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 59 | logger.addHandler(stream_handler) 60 | 61 | 62 | def save_dict_to_json(d, json_path): 63 | """Saves dict of floats in json file 64 | 65 | Args: 66 | d: (dict) of float-castable values (np.float, int, float, etc.) 67 | json_path: (string) path to json file 68 | """ 69 | with open(json_path, 'w') as f: 70 | json.dump(d, f, indent=4) 71 | 72 | 73 | def save_checkpoint(state, is_best, checkpoint): 74 | """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves 75 | checkpoint + 'best.pth.tar' 76 | 77 | Args: 78 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict 79 | is_best: (bool) True if it is the best model seen till now 80 | checkpoint: (string) folder where parameters are to be saved 81 | """ 82 | filepath = os.path.join(checkpoint, f"epoch{state['epoch']}.pth.tar") 83 | if not os.path.exists(checkpoint): 84 | print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint)) 85 | os.mkdir(checkpoint) 86 | else: 87 | print("Checkpoint Directory exists! ") 88 | torch.save(state, filepath) 89 | if is_best: 90 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar')) 91 | 92 | 93 | def load_checkpoint(checkpoint, model, optimizer=None): 94 | """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of 95 | optimizer assuming it is present in checkpoint. 
96 |
97 |     Args:
98 |         checkpoint: (string) filename which needs to be loaded
99 |         model: (torch.nn.Module) model for which the parameters are loaded
100 |         optimizer: (torch.optim) optional: resume optimizer from checkpoint
101 |     """
102 |     if not os.path.exists(checkpoint):
103 |         raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
104 |     checkpoint = torch.load(checkpoint)
105 |     model.load_state_dict(checkpoint['state_dict'])
106 |
107 |     if optimizer:
108 |         optimizer.load_state_dict(checkpoint['optim_dict'])
109 |
110 |     return checkpoint
111 |
112 |
113 | # https://github.com/albu/albumentations/blob/master/notebooks/example_bboxes.ipynb
114 | # Functions to visualize bounding boxes and class labels on an image.
115 | # Based on https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/vis.py
116 | BOX_COLOR = {0: (255, 0, 0), 1: (0, 255, 0)}
117 | TEXT_COLOR = (255, 255, 255)
118 |
119 | # Available formats are: coco, pascal_voc.
120 | # The coco format of a bounding box looks like [x_min, y_min, width, height], e.g. [97, 12, 150, 200].
121 | # The pascal_voc format of a bounding box looks like [x_min, y_min, x_max, y_max], e.g. [97, 12, 247, 212].
122 |
123 | def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2, pascal=True):
124 |     if pascal:
125 |         x_min, y_min, x_max, y_max = bbox
126 |     else:
127 |         x_min, y_min, w, h = bbox
128 |         x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)
129 |
130 |     boxcolor = BOX_COLOR[class_id]
131 |     cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=boxcolor, thickness=thickness)
132 |     class_name = class_idx_to_name[class_id]
133 |     ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
134 |     cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), boxcolor, -1)
135 |     cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)), cv2.FONT_HERSHEY_SIMPLEX, 0.35, TEXT_COLOR, lineType=cv2.LINE_AA)
136 |     return img
137 |
138 |
139 | def visualize(annotations, category_id_to_name):
140 |     img = annotations['image'].copy()
141 |     for idx, bbox in enumerate(annotations['bboxes']):
142 |         img = visualize_bbox(img, bbox, annotations['category_id'][idx], category_id_to_name)
143 |     plt.figure(figsize=(12, 12))
144 |     plt.imshow(img)
145 |
--------------------------------------------------------------------------------
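A closing usage sketch (not part of the repository) showing how `utils.visualize` consumes the same 'bboxes'/'category_id' keys that `model.predict` returns; the dummy image and the class name 'opacity' are assumptions for illustration only:

import numpy as np
from utils import visualize

# dummy data purely for illustration; real boxes would come from model.predict()
image = np.zeros((224, 224, 3), dtype=np.uint8)
annotations = {
    'image': image,
    'bboxes': [[30, 40, 120, 160]],   # pascal_voc style: x_min, y_min, x_max, y_max
    'category_id': [0],
}
category_id_to_name = {0: 'opacity'}  # assumed name for the single foreground class
visualize(annotations, category_id_to_name)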