├── .gitignore
├── README.md
├── asserts
│   ├── demo.png
│   └── retinanet.png
├── config.json
├── dataGen
│   ├── __init__.py
│   ├── compute_overlap.cpython-36m-x86_64-linux-gnu.so
│   ├── compute_overlap.pyx
│   ├── data_loader.py
│   ├── setup.py
│   ├── targetBuild.py
│   └── utils.py
├── dataset
├── demo.ipynb
├── fold_data
│   ├── trnbox_1.pkl
│   ├── trnbox_2.pkl
│   ├── trnbox_3.pkl
│   ├── trnbox_4.pkl
│   ├── trnbox_5.pkl
│   ├── trnbox_6.pkl
│   ├── trnfps_1.pkl
│   ├── trnfps_2.pkl
│   ├── trnfps_3.pkl
│   ├── trnfps_4.pkl
│   ├── trnfps_5.pkl
│   ├── trnfps_6.pkl
│   ├── trnlabel_1.pkl
│   ├── trnlabel_2.pkl
│   ├── trnlabel_3.pkl
│   ├── trnlabel_4.pkl
│   ├── trnlabel_5.pkl
│   ├── trnlabel_6.pkl
│   ├── valbox_1.pkl
│   ├── valbox_2.pkl
│   ├── valbox_3.pkl
│   ├── valbox_4.pkl
│   ├── valbox_5.pkl
│   ├── valbox_6.pkl
│   ├── valfps_1.pkl
│   ├── valfps_2.pkl
│   ├── valfps_3.pkl
│   ├── valfps_4.pkl
│   ├── valfps_5.pkl
│   ├── valfps_6.pkl
│   ├── vallabel_1.pkl
│   ├── vallabel_2.pkl
│   ├── vallabel_3.pkl
│   ├── vallabel_4.pkl
│   ├── vallabel_5.pkl
│   └── vallabel_6.pkl
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── fpn.py
│   ├── losses.py
│   └── misc.py
├── prepare_data.ipynb
├── pretrainedmodels
│   ├── .gitignore
│   └── download_here.txt
├── requirements.txt
├── setup.sh
├── tests
│   ├── __init__.py
│   ├── test_boxinv.py
│   ├── test_dataloader.py
│   ├── test_filter.py
│   ├── test_fpn.py
│   ├── test_losses.py
│   └── test_retinanet.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .ipynb_checkpoints
3 | checkpoints
4 | .vscode
5 | dataGen/__pycache__
6 | models/__pycache__
7 | tests/__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## An Implementation of RetinaNet in PyTorch
2 | 
3 |
4 | #### optional backbones:
5 | - [se_resnext50_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth)
6 | - [se_resnext101_32x4d](http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth)
7 |
8 | #### usage:
9 | - `setup.sh`: compiles the Cython extension
10 | - `config.json`: configuration file
11 | - `train.py`: main training script
12 | - `dataGen/data_loader.py`: subclass **torch.utils.data.Dataset** or modify **class RsnaDataset** for your own dataset
13 | - `prepare_data.ipynb`: data processing notebook for the **RSNA Pneumonia Detection Challenge**
14 | - `demo.ipynb`: sample code showing how to predict with the **model.predict** method
15 |
16 | ---
17 | #### Application
18 | Train this model on the dataset of the [RSNA Pneumonia Detection Challenge](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge)
19 | 
20 |
21 | ### credits:
22 | 1. [Cadene pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch)
23 | 2. [fizyr/keras-retinanet](https://github.com/fizyr/keras-retinanet)
24 | 3. [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf)
25 | 4. [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf)
--------------------------------------------------------------------------------
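The README above notes that `dataGen/data_loader.py` expects you to subclass `torch.utils.data.Dataset` (or adapt `RsnaDataset`) for other datasets. As a minimal sketch of that interface (not part of the repository; the class and argument names below are hypothetical), each item should yield an image tensor together with the per-anchor targets built by `dataGen.targetBuild`:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

from dataGen.targetBuild import anchors_for_shape, anchor_targets_bbox


class MyDetectionDataset(Dataset):
    """Hypothetical dataset returning (image, labels, regression) like RsnaDataset."""

    def __init__(self, images, bboxes, catids, image_shape=(512, 512)):
        # images: list of HxWx3 float arrays; bboxes/catids: per-image annotations
        self.images, self.bboxes, self.catids = images, bboxes, catids
        self.anchors = anchors_for_shape(image_shape)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # stack boxes and class ids into an (M, 5) annotation array: x1, y1, x2, y2, label
        boxes = np.asarray(self.bboxes[idx], dtype=np.float32).reshape(-1, 4)
        catids = np.asarray(self.catids[idx], dtype=np.float32).reshape(-1, 1)
        annotations = np.concatenate([boxes, catids], axis=1)
        labels, regression = anchor_targets_bbox(self.anchors, annotations)
        img = torch.from_numpy(self.images[idx]).permute(2, 0, 1).float()
        return img, torch.tensor(labels), torch.tensor(regression)
```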
/asserts/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/asserts/demo.png
--------------------------------------------------------------------------------
/asserts/retinanet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/asserts/retinanet.png
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_imagenet": true,
3 | "backbone": "se_resnext50_32x4d",
4 | "image_shape": [512, 512],
5 | "pyramid_levels_default": [3, 4, 5, 6, 7],
6 | "anchor_sizes_default": [32, 64, 128, 256, 512],
7 | "anchor_strides_default": [8, 16, 32, 64, 128],
8 | "anchor_ratios_default": [0.5, 1, 2],
9 | "anchor_scales_default": [1.0, 1.2599, 1.5874],
10 |
11 | "num_classes": 1,
12 | "mean_bbox_transform": [0.0, 0.0, 0.0, 0.0],
13 | "std_bbox_transform": [0.2, 0.2, 0.2, 0.2],
14 | "dicom_train": "dataset/stage_2_train_images",
15 | "RandomRotate": true,
16 | "RandomHorizontalFlip": true,
17 | "batch_size": 8,
18 |
19 | "loss_ratio_FL2L1": 1.0,
20 | "focal_alpha": 0.75,
21 | "l1_sigma": 3,
22 |
23 | "use_cuda": true,
24 | "num_workers": 4,
25 | "rsna_mean": 0.49,
26 | "rsna_std": 0.23
27 | }
--------------------------------------------------------------------------------
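For orientation, the anchor-related defaults above fully determine how many boxes the network predicts: 3 ratios times 3 scales give 9 anchors per feature-map location, and pyramid levels 3-7 on a 512x512 input give 5456 locations. A small sketch (assuming this `config.json` sits in the working directory):

```python
import json

with open('config.json') as f:
    cfg = json.load(f)

per_location = len(cfg['anchor_ratios_default']) * len(cfg['anchor_scales_default'])  # 3 * 3 = 9

h, w = cfg['image_shape']
total = 0
for level in cfg['pyramid_levels_default']:
    # same rounding as dataGen.targetBuild.guess_shapes
    fh = (h + 2 ** level - 1) // 2 ** level
    fw = (w + 2 ** level - 1) // 2 ** level
    total += fh * fw * per_location

print(per_location, total)  # 9 anchors per location, 49104 anchors in total
```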
/dataGen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/dataGen/__init__.py
--------------------------------------------------------------------------------
/dataGen/compute_overlap.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/dataGen/compute_overlap.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/dataGen/compute_overlap.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Sergey Karayev
6 | # --------------------------------------------------------
7 |
8 | cimport cython
9 | import numpy as np
10 | cimport numpy as np
11 |
12 |
13 | def compute_overlap(
14 | np.ndarray[double, ndim=2] boxes,
15 | np.ndarray[double, ndim=2] query_boxes
16 | ):
17 | """
18 | Args
19 |         boxes: (N, 4) ndarray of float
20 |         query_boxes: (K, 4) ndarray of float
21 |
22 | Returns
23 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes
24 | """
25 | cdef unsigned int N = boxes.shape[0]
26 | cdef unsigned int K = query_boxes.shape[0]
27 | cdef np.ndarray[double, ndim=2] overlaps = np.zeros((N, K), dtype=np.float64)
28 | cdef double iw, ih, box_area
29 | cdef double ua
30 | cdef unsigned int k, n
31 | for k in range(K):
32 | box_area = (
33 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) *
34 | (query_boxes[k, 3] - query_boxes[k, 1] + 1)
35 | )
36 | for n in range(N):
37 | iw = (
38 | min(boxes[n, 2], query_boxes[k, 2]) -
39 | max(boxes[n, 0], query_boxes[k, 0]) + 1
40 | )
41 | if iw > 0:
42 | ih = (
43 | min(boxes[n, 3], query_boxes[k, 3]) -
44 | max(boxes[n, 1], query_boxes[k, 1]) + 1
45 | )
46 | if ih > 0:
47 | ua = np.float64(
48 | (boxes[n, 2] - boxes[n, 0] + 1) *
49 | (boxes[n, 3] - boxes[n, 1] + 1) +
50 | box_area - iw * ih
51 | )
52 | overlaps[n, k] = iw * ih / ua
53 | return overlaps
54 |
--------------------------------------------------------------------------------
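As a cross-check for the compiled extension, here is a pure-NumPy sketch of the same pairwise IoU (a reference added here, not part of the repository), using the same +1 pixel convention as the Cython code above:

```python
import numpy as np


def compute_overlap_np(boxes, query_boxes):
    """boxes: (N, 4), query_boxes: (K, 4) -> IoU matrix of shape (N, K)."""
    boxes = boxes[:, None, :]        # (N, 1, 4)
    query = query_boxes[None, :, :]  # (1, K, 4)

    # intersection width/height, clipped at zero when the boxes do not overlap
    iw = np.minimum(boxes[..., 2], query[..., 2]) - np.maximum(boxes[..., 0], query[..., 0]) + 1
    ih = np.minimum(boxes[..., 3], query[..., 3]) - np.maximum(boxes[..., 1], query[..., 1]) + 1
    iw, ih = np.clip(iw, 0, None), np.clip(ih, 0, None)
    inter = iw * ih

    area_b = (boxes[..., 2] - boxes[..., 0] + 1) * (boxes[..., 3] - boxes[..., 1] + 1)
    area_q = (query[..., 2] - query[..., 0] + 1) * (query[..., 3] - query[..., 1] + 1)
    union = area_b + area_q - inter
    return inter / union
```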
/dataGen/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import pydicom
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision import transforms
9 | from albumentations import HorizontalFlip, Rotate, Resize, Compose
10 | from models.backbone import pretrained_settings
11 | from .targetBuild import anchor_targets_bbox, anchors_for_shape
12 |
13 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
14 | with open(config_path, 'r') as f:
15 | config = json.load(f)
16 |
17 |
18 | def load_dicom(img_id):
19 | image_path = os.path.join(os.path.dirname(__file__), '..', config['dicom_train'], img_id+'.dcm')
20 | ds = pydicom.read_file(image_path)
21 | image = ds.pixel_array
22 | # If grayscale. Convert to RGB for consistency.
23 | if len(image.shape) != 3 or image.shape[2] != 3:
24 | image = np.stack((image,) * 3, -1)
25 | return image
26 |
27 |
28 | def get_aug(aug, min_area=0., min_visibility=0.):
29 | return Compose(aug, bbox_params={'format': 'pascal_voc', 'min_area': min_area, 'min_visibility': min_visibility, 'label_fields': ['category_id']})
30 |
31 |
32 | class RsnaDataset(Dataset):
33 | """
34 | A standard PyTorch definition of Dataset which defines the functions __len__ and __getitem__.
35 | """
36 |
37 | def __init__(self, filenames, gt_bboxes, gt_catids, aug=None):
38 |         """
39 |         Store the image ids and their ground-truth annotations. Specifies transforms to apply on images.
40 | 
41 |         Args:
42 |             filenames: (list) image ids; gt_bboxes / gt_catids: (dict) image id -> boxes / category ids
43 |             aug: (list or None) albumentations transforms to apply to the image and boxes
44 |         """
45 | self.filenames = filenames
46 | self.gt_bboxes = gt_bboxes
47 | self.gt_catids = gt_catids
48 | self.aug = get_aug(aug) if aug is not None else get_aug([])
49 | model_name = config['backbone']
50 | # img_mean = pretrained_settings[model_name]['imagenet']['mean']
51 | # img_std = pretrained_settings[model_name]['imagenet']['std']
52 | self.img_tfs = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
53 | mean=[config['rsna_mean']]*3, std=[config['rsna_std']]*3)])
54 | self.anchors = anchors_for_shape(config['image_shape'])
55 |
56 | def __len__(self):
57 | # return size of dataset
58 | return len(self.filenames)
59 |
60 | def __getitem__(self, idx):
61 |         """
62 |         Fetch the image at index idx, apply transforms and build per-anchor targets.
63 | 
64 |         Args:
65 |             idx: (int) index in [0, 1, ..., size_of_dataset-1]
66 | 
67 |         Returns:
68 |             img: (Tensor) transformed image of shape (3, H, W)
69 |             labels, regression: (Tensor) per-anchor classification and box regression targets
70 |         """
71 | fps = self.filenames[idx]
72 | image = load_dicom(fps)
73 | bboxes = self.gt_bboxes[fps]
74 | category_id = self.gt_catids[fps]
75 | augmented = self.aug(image=image, bboxes=bboxes, category_id=category_id)
76 | if not augmented['bboxes']:
77 | gt_annos = np.zeros((0, 5), dtype=np.float32)
78 | else:
79 | gt_annos = np.empty((len(augmented['bboxes']), 5), dtype=np.float32)
80 | gt_annos[:, :4] = augmented['bboxes']
81 | gt_annos[:, 4] = augmented['category_id']
82 | img = self.img_tfs(augmented['image'])
83 | labels, regression = anchor_targets_bbox(self.anchors, gt_annos)
84 | labels, regression = torch.tensor(labels), torch.tensor(regression)
85 | return img, labels, regression
86 |
87 | trn_aug = []
88 | if config['image_shape']:
89 | trn_aug.append(Resize(*config['image_shape'], p=1.0))
90 | val_aug = [Resize(*config['image_shape'], p=1.0)]
91 | else:
92 | val_aug = []
93 |
94 | if config['RandomRotate']:
95 | trn_aug.append(Rotate(limit=15, p=0.5))
96 |
97 | if config['RandomHorizontalFlip']:
98 | trn_aug.append(HorizontalFlip(p=0.5))
99 |
100 |
101 | def fetch_trn_loader(kfold, trnfps=None, bboxdict=None, labeldict=None, aug=None):
102 | if trnfps is None:
103 | trnfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnfps_{}.pkl'.format(kfold))
104 | with open(trnfps_path, 'rb') as f:
105 | trnfps = pickle.load(f)
106 |
107 | if bboxdict is None:
108 | bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnbox_{}.pkl'.format(kfold))
109 | with open(bboxdict_path, 'rb') as f:
110 | bboxdict = pickle.load(f)
111 |
112 | if labeldict is None:
113 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'trnlabel_{}.pkl'.format(kfold))
114 | with open(labeldict_path, 'rb') as f:
115 | labeldict = pickle.load(f)
116 |
117 | if aug is None:
118 | aug = trn_aug
119 |
120 | dataset = RsnaDataset(trnfps, bboxdict, labeldict, aug)
121 | return DataLoader(dataset,
122 | batch_size=config['batch_size'],
123 | shuffle=True,
124 | num_workers=config['num_workers'],
125 | pin_memory=config['use_cuda'])
126 |
127 |
128 | def fetch_val_loader(kfold, valfps=None, bboxdict=None, labeldict=None, aug=None):
129 | if valfps is None:
130 | valfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valfps_{}.pkl'.format(kfold))
131 | with open(valfps_path, 'rb') as f:
132 | valfps = pickle.load(f)
133 |
134 | if bboxdict is None:
135 | bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valbox_{}.pkl'.format(kfold))
136 | with open(bboxdict_path, 'rb') as f:
137 | bboxdict = pickle.load(f)
138 |
139 | if labeldict is None:
140 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'vallabel_{}.pkl'.format(kfold))
141 | with open(labeldict_path, 'rb') as f:
142 | labeldict = pickle.load(f)
143 |
144 | if aug is None:
145 | aug = val_aug
146 |
147 | dataset = RsnaDataset(valfps, bboxdict, labeldict, aug)
148 | return DataLoader(dataset,
149 | batch_size=config['batch_size'],
150 | shuffle=False,
151 | num_workers=config['num_workers'],
152 | pin_memory=config['use_cuda'])
153 |
154 |
--------------------------------------------------------------------------------
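Typical usage of the loaders above might look like the following sketch (it assumes the `fold_data/*.pkl` splits are present and the RSNA DICOMs are available under `config['dicom_train']`; shapes follow the 512x512 default config):

```python
from dataGen.data_loader import fetch_trn_loader, fetch_val_loader

trn_loader = fetch_trn_loader(kfold=1)
val_loader = fetch_val_loader(kfold=1)

# pull one batch and inspect the target shapes
images, labels, regression = next(iter(trn_loader))
print(images.shape)      # (batch_size, 3, 512, 512)
print(labels.shape)      # (batch_size, 49104, num_classes + 1)
print(regression.shape)  # (batch_size, 49104, 5)
```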
/dataGen/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | from setuptools.extension import Extension
3 | import numpy as np
4 | from Cython.Build import cythonize
5 |
6 | extensions = [
7 | Extension(
8 | 'dataGen.compute_overlap',
9 | ['dataGen/compute_overlap.pyx'],
10 | include_dirs=[np.get_include()]
11 | ),
12 | ]
13 |
14 | setuptools.setup(
15 | name='dataGen',
16 | packages=setuptools.find_packages(),
17 | # same with `ext_modules=extensions`,
18 | ext_modules=cythonize(extensions),
19 | setup_requires=["cython>=0.28"]
20 | )
21 |
--------------------------------------------------------------------------------
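After building the extension (via the repository's `setup.sh`, or presumably something along the lines of `python dataGen/setup.py build_ext --inplace` from the repo root), a quick sanity check could compare it with the simple `iou()` helper in `dataGen/utils.py` (a sketch, not from the repository):

```python
import numpy as np

from dataGen.compute_overlap import compute_overlap
from dataGen.utils import iou

a = np.array([[0., 0., 10., 10.]])
b = np.array([[5., 5., 15., 15.]])

print(compute_overlap(a, b)[0, 0])  # ~0.175 with the +1 pixel convention
print(iou(a[0], b[0]))              # ~0.143 without it; close but not identical
```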
/dataGen/targetBuild.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from .compute_overlap import compute_overlap
5 |
6 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
7 | with open(config_path, 'r') as f:
8 | config = json.load(f)
9 |
10 |
11 | class AnchorParameters:
12 |     """ The parameters that define how anchors are generated.
13 |
14 | Args
15 | sizes : List of sizes to use. Each size corresponds to one feature level.
16 |         strides : List of strides to use. Each stride corresponds to one feature level.
17 | ratios : List of ratios to use per location in a feature map.
18 | scales : List of scales to use per location in a feature map.
19 | """
20 |
21 | def __init__(self, sizes, strides, ratios, scales):
22 | self.sizes = sizes
23 | self.strides = strides
24 | self.ratios = ratios
25 | self.scales = scales
26 |
27 | def num_anchors(self):
28 | return len(self.ratios) * len(self.scales)
29 |
30 |
31 | """
32 | The default anchor parameters.
33 | """
34 | AnchorParameters.default = AnchorParameters(
35 | sizes=config['anchor_sizes_default'],
36 | strides=config['anchor_strides_default'],
37 | ratios=np.array(config['anchor_ratios_default'], np.float32),
38 | scales=np.array(config['anchor_scales_default'], np.float32),
39 | )
40 |
41 |
42 | def generate_anchors(base_size=16, ratios=None, scales=None):
43 | """
44 | Generate anchor (reference) windows by enumerating aspect ratios X
45 | scales w.r.t. a reference window.
46 | """
47 |
48 | if ratios is None:
49 | ratios = AnchorParameters.default.ratios
50 |
51 | if scales is None:
52 | scales = AnchorParameters.default.scales
53 |
54 | num_anchors = len(ratios) * len(scales)
55 |
56 | # initialize output anchors
57 | anchors = np.zeros((num_anchors, 4), dtype=np.float32)
58 |
59 | # scale base_size
60 | anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T
61 |
62 | # compute areas of anchors
63 | areas = anchors[:, 2] * anchors[:, 3]
64 |
65 | # correct for ratios
66 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales)))
67 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales))
68 |
69 | # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
70 | anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
71 | anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T
72 |
73 | return anchors
74 |
75 |
76 | def guess_shapes(image_shape, pyramid_levels=None):
77 | """Guess shapes based on pyramid levels.
78 |
79 | Args
80 | image_shape: The shape of the image.
81 | pyramid_levels: A list of what pyramid levels are used.
82 |
83 | Returns
84 | A list of image shapes at each pyramid level.
85 | """
86 | if pyramid_levels is None:
87 | pyramid_levels = config['pyramid_levels_default']
88 | image_shape = np.array(image_shape)
89 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
90 | return image_shapes
91 |
92 |
93 | def shift(shape, stride, anchors):
94 | """ Produce shifted anchors based on shape of the map and stride size.
95 |
96 | Args
97 | shape : Shape to shift the anchors over.
98 | stride : Stride to shift the anchors with over the shape.
99 | anchors: The anchors to apply at each location.
100 | """
101 |
102 | # create a grid starting from half stride from the top left corner
103 | shift_x = (np.arange(0, shape[1]) + 0.5) * stride
104 | shift_y = (np.arange(0, shape[0]) + 0.5) * stride
105 |
106 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
107 |
108 | shifts = np.vstack((
109 | shift_x.ravel(), shift_y.ravel(),
110 | shift_x.ravel(), shift_y.ravel()
111 | )).transpose()
112 |
113 | # add A anchors (1, A, 4) to
114 | # cell K shifts (K, 1, 4) to get
115 | # shift anchors (K, A, 4)
116 | # reshape to (K*A, 4) shifted anchors
117 | A = anchors.shape[0]
118 | K = shifts.shape[0]
119 | all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
120 | all_anchors = all_anchors.reshape((K * A, 4))
121 |
122 | return all_anchors
123 |
124 |
125 | def anchors_for_shape(
126 | image_shape,
127 | pyramid_levels=None,
128 | anchor_params=None,
129 | ):
130 |     """ Generate anchors for a given shape.
131 |
132 | Args
133 | image_shape: The shape of the image.
134 | pyramid_levels: List of ints representing which pyramids to use (defaults to [3, 4, 5, 6, 7]).
135 | anchor_params: Struct containing anchor parameters. If None, default values are used.
136 |
137 | Returns
138 | np.array of shape (N, 4) containing the (x1, y1, x2, y2) coordinates for the anchors.
139 | """
140 |
141 | if pyramid_levels is None:
142 | pyramid_levels = config['pyramid_levels_default']
143 |
144 | if anchor_params is None:
145 | anchor_params = AnchorParameters.default
146 |
147 | image_shapes = guess_shapes(image_shape, pyramid_levels)
148 |
149 | # compute anchors over all pyramid levels
150 | all_anchors = np.zeros((0, 4))
151 | for idx, p in enumerate(pyramid_levels):
152 | anchors = generate_anchors(
153 | base_size=anchor_params.sizes[idx],
154 | ratios=anchor_params.ratios,
155 | scales=anchor_params.scales
156 | )
157 | shifted_anchors = shift(image_shapes[idx], anchor_params.strides[idx], anchors)
158 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
159 |
160 | return all_anchors
161 |
162 |
163 | def compute_gt_annotations(
164 | anchors,
165 | annotations,
166 | negative_overlap=0.4,
167 | positive_overlap=0.5
168 | ):
169 | """ Obtain indices of gt annotations with the greatest overlap.
170 |
171 | Args
172 |         anchors: np.array of anchors of shape (N, 4) for (x1, y1, x2, y2).
173 | annotations: np.array of shape (N, 5) for (x1, y1, x2, y2, label).
174 | negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative).
175 |         positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive).
176 |
177 | Returns
178 | positive_indices: indices of positive anchors
179 | ignore_indices: indices of ignored anchors
180 |         argmax_overlaps_inds: for each anchor, the index of the ground-truth annotation with the highest overlap
181 | """
182 |
183 | overlaps = compute_overlap(anchors.astype(np.float64), annotations.astype(np.float64))
184 | argmax_overlaps_inds = np.argmax(overlaps, axis=1)
185 | max_overlaps = overlaps[np.arange(overlaps.shape[0]), argmax_overlaps_inds]
186 |
187 | # assign "dont care" labels
188 | positive_indices = max_overlaps >= positive_overlap
189 | ignore_indices = (max_overlaps > negative_overlap) & ~positive_indices
190 |
191 | return positive_indices, ignore_indices, argmax_overlaps_inds
192 |
193 |
194 | def bbox_transform(anchors, gt_boxes, mean=None, std=None):
195 | """Compute bounding-box regression targets for an image."""
196 |
197 | if mean is None:
198 | mean = np.array(config['mean_bbox_transform'], dtype=np.float32)
199 | if std is None:
200 | std = np.array(config['std_bbox_transform'], dtype=np.float32)
201 |
202 | if isinstance(mean, (list, tuple)):
203 | mean = np.array(mean)
204 | elif not isinstance(mean, np.ndarray):
205 | raise ValueError('Expected mean to be a np.ndarray, list or tuple. Received: {}'.format(type(mean)))
206 |
207 | if isinstance(std, (list, tuple)):
208 | std = np.array(std)
209 | elif not isinstance(std, np.ndarray):
210 | raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std)))
211 |
212 | anchor_widths = anchors[:, 2] - anchors[:, 0]
213 | anchor_heights = anchors[:, 3] - anchors[:, 1]
214 |
215 | targets_dx1 = (gt_boxes[:, 0] - anchors[:, 0]) / anchor_widths
216 | targets_dy1 = (gt_boxes[:, 1] - anchors[:, 1]) / anchor_heights
217 | targets_dx2 = (gt_boxes[:, 2] - anchors[:, 2]) / anchor_widths
218 | targets_dy2 = (gt_boxes[:, 3] - anchors[:, 3]) / anchor_heights
219 |
220 | targets = np.stack((targets_dx1, targets_dy1, targets_dx2, targets_dy2))
221 | targets = targets.T
222 |
223 | targets = (targets - mean) / std
224 |
225 | return targets
226 |
227 |
228 | def anchor_targets_bbox(
229 | anchors,
230 | annotations,
231 | num_classes=None,
232 | negative_overlap=0.4,
233 | positive_overlap=0.5
234 | ):
235 | """ Generate anchor targets for bbox detection.
236 |
237 | Args
238 |         anchors: np.array of anchors of shape (N, 4) for (x1, y1, x2, y2).
239 | annotations: annotations (np.array of shape (N, 5) for (x1, y1, x2, y2, label)).
240 | num_classes: Number of classes to predict.
241 | negative_overlap: IoU overlap for negative anchors (all anchors with overlap < negative_overlap are negative).
242 |         positive_overlap: IoU overlap for positive anchors (all anchors with overlap > positive_overlap are positive).
243 |
244 | Returns
245 |         labels: np.array of shape (N, num_classes + 1) containing class labels and anchor states,
246 |             where N is the number of anchors for an image and
247 |             the last column defines the anchor state (-1 for ignore, 0 for bg, 1 for fg).
248 |         regression: np.array of shape (N, 4 + 1) containing bounding-box regression targets and anchor states,
249 |             where N is the number of anchors for an image, the first 4 columns define regression targets for (x1, y1, x2, y2)
250 |             and the last column defines the anchor state (-1 for ignore, 0 for bg, 1 for fg).
251 | """
252 | if num_classes is None:
253 | num_classes = config['num_classes']
254 | regression = np.zeros((anchors.shape[0], 4 + 1), dtype=np.float32)
255 | labels = np.zeros((anchors.shape[0], num_classes + 1), dtype=np.float32)
256 |
257 | # compute labels and regression targets
258 | if annotations.shape[0]:
259 | # obtain indices of gt annotations with the greatest overlap
260 | positive_indices, ignore_indices, argmax_overlaps_inds = compute_gt_annotations(anchors, annotations,
261 | negative_overlap,
262 | positive_overlap)
263 | labels[ignore_indices, -1] = -1
264 | labels[positive_indices, -1] = 1
265 |
266 | regression[ignore_indices, -1] = -1
267 | regression[positive_indices, -1] = 1
268 |
269 | # compute box regression targets
270 | annotations = annotations[argmax_overlaps_inds]
271 |
272 | # compute target class labels
273 | labels[positive_indices, annotations[positive_indices, 4].astype(int)] = 1
274 |
275 | regression[:, :-1] = bbox_transform(anchors, annotations)
276 |
277 | return labels, regression
278 |
--------------------------------------------------------------------------------
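A short worked example of the two entry points above (a sketch; it assumes the `config.json` defaults shown earlier): for a 512x512 image, levels P3-P7 give 64^2 + 32^2 + 16^2 + 8^2 + 4^2 = 5456 locations, hence 5456 * 9 = 49104 anchors in total.

```python
import numpy as np

from dataGen.targetBuild import anchors_for_shape, anchor_targets_bbox

anchors = anchors_for_shape((512, 512))
print(anchors.shape)                    # (49104, 4)

# one ground-truth box of class 0 -> per-anchor classification/regression targets
annotations = np.array([[100., 100., 200., 220., 0.]], dtype=np.float32)
labels, regression = anchor_targets_bbox(anchors, annotations)
print(labels.shape, regression.shape)   # (49104, 2) (49104, 5)
print(int((labels[:, -1] == 1).sum()))  # number of positive anchors (IoU >= 0.5)
```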
/dataGen/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def iou(box1, box2):
4 | """
5 | From Yicheng Chen's "Mean Average Precision Metric"
6 | https://www.kaggle.com/chenyc15/mean-average-precision-metric
7 |
8 | helper function to calculate IoU
9 | """
10 | x11, y11, x12, y12 = box1
11 | x21, y21, x22, y22 = box2
12 | w1, h1 = x12-x11, y12-y11
13 | w2, h2 = x22-x21, y22-y21
14 |
15 | area1, area2 = w1 * h1, w2 * h2
16 | xi1, yi1, xi2, yi2 = max([x11, x21]), max(
17 | [y11, y21]), min([x12, x22]), min([y12, y22])
18 |
19 | if xi2 <= xi1 or yi2 <= yi1:
20 | return 0
21 | else:
22 | intersect = (xi2-xi1) * (yi2-yi1)
23 | union = area1 + area2 - intersect
24 | return intersect / union
25 |
26 |
27 | def nms(boxes, scores, overlapThresh):
28 | """
29 | adapted from non-maximum suppression by Adrian Rosebrock
30 | https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
31 | """
32 |
33 | # if there are no boxes, return an empty list
34 | if len(boxes) == 0:
35 | return np.array([]).reshape(0, 4), np.array([])
36 | if boxes.dtype.kind == "i":
37 | boxes = boxes.astype("float")
38 |
39 | pick = []
40 | x1 = boxes[:, 0]
41 | y1 = boxes[:, 1]
42 | x2 = boxes[:, 2]
43 | y2 = boxes[:, 3]
44 |
45 | # compute the area of the bounding boxes
46 | area = (x2 - x1 + 1) * (y2 - y1 + 1)
47 |
48 | # sort the bounding boxes by scores in ascending order
49 | idxs = np.argsort(scores)
50 |
51 | # keep looping while indexes still remain in the indexes list
52 | while len(idxs) > 0:
53 | # grab the last index in the indexes list and add the
54 | # index value to the list of picked indexes
55 | last = len(idxs) - 1
56 | i = idxs[last]
57 | pick.append(i)
58 |
59 | # find the largest (x, y) coordinates for the start of
60 | # the bounding box and the smallest (x, y) coordinates
61 | # for the end of the bounding box
62 | xx1 = np.maximum(x1[i], x1[idxs[:last]])
63 | yy1 = np.maximum(y1[i], y1[idxs[:last]])
64 | xx2 = np.minimum(x2[i], x2[idxs[:last]])
65 | yy2 = np.minimum(y2[i], y2[idxs[:last]])
66 |
67 | # compute the width and height of the bounding box
68 | w = np.maximum(0, xx2 - xx1 + 1)
69 | h = np.maximum(0, yy2 - yy1 + 1)
70 |
71 | # compute the ratio of overlap
72 | overlap = (w * h) / area[idxs[:last]]
73 |
74 |         # delete the picked index and all indexes whose overlap exceeds the threshold
75 | idxs = np.delete(idxs, np.concatenate(([last],
76 | np.where(overlap > overlapThresh)[0])))
77 |
78 | # return only the bounding boxes that were picked using the
79 | # integer data type
80 | return boxes[pick], scores[pick]
81 |
--------------------------------------------------------------------------------
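Toy usage of `nms()` (a sketch): two heavily overlapping boxes plus one far away; with `overlapThresh=0.5` the lower-scoring duplicate is suppressed.

```python
import numpy as np

from dataGen.utils import nms

boxes = np.array([[0, 0, 10, 10],
                  [1, 1, 11, 11],
                  [50, 50, 60, 60]], dtype=np.float64)
scores = np.array([0.9, 0.8, 0.7])

kept_boxes, kept_scores = nms(boxes, scores, overlapThresh=0.5)
print(kept_boxes)   # boxes 0 and 2; box 1 overlaps box 0 too much
print(kept_scores)  # [0.9 0.7]
```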
/dataset:
--------------------------------------------------------------------------------
1 | /home/raytroop/sand/rsna-pneumonia-detection-challenge_stg1
--------------------------------------------------------------------------------
/fold_data/trnbox_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_1.pkl
--------------------------------------------------------------------------------
/fold_data/trnbox_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_2.pkl
--------------------------------------------------------------------------------
/fold_data/trnbox_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_3.pkl
--------------------------------------------------------------------------------
/fold_data/trnbox_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_4.pkl
--------------------------------------------------------------------------------
/fold_data/trnbox_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_5.pkl
--------------------------------------------------------------------------------
/fold_data/trnbox_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnbox_6.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_1.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_2.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_3.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_4.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_5.pkl
--------------------------------------------------------------------------------
/fold_data/trnfps_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnfps_6.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_1.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_2.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_3.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_4.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_5.pkl
--------------------------------------------------------------------------------
/fold_data/trnlabel_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/trnlabel_6.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_1.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_2.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_3.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_4.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_5.pkl
--------------------------------------------------------------------------------
/fold_data/valbox_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valbox_6.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_1.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_2.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_3.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_4.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_5.pkl
--------------------------------------------------------------------------------
/fold_data/valfps_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/valfps_6.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_1.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_2.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_3.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_4.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_5.pkl
--------------------------------------------------------------------------------
/fold_data/vallabel_6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/fold_data/vallabel_6.pkl
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/models/__init__.py
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | """
2 | ResNet code gently borrowed from
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | """
5 | import os
6 | from collections import OrderedDict
7 | import math
8 |
9 | import torch.nn as nn
10 | import torch
11 |
12 |
13 | pretrained_settings = {
14 | 'se_resnext50_32x4d': {
15 | 'imagenet': {
16 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
17 | 'input_space': 'RGB',
18 | 'input_size': [3, 224, 224],
19 | 'input_range': [0, 1],
20 | 'mean': [0.485, 0.456, 0.406],
21 | 'std': [0.229, 0.224, 0.225],
22 | 'num_classes': 1000
23 | }
24 | },
25 | 'se_resnext101_32x4d': {
26 | 'imagenet': {
27 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
28 | 'input_space': 'RGB',
29 | 'input_size': [3, 224, 224],
30 | 'input_range': [0, 1],
31 | 'mean': [0.485, 0.456, 0.406],
32 | 'std': [0.229, 0.224, 0.225],
33 | 'num_classes': 1000
34 | }
35 | },
36 | }
37 |
38 |
39 | class SEModule(nn.Module):
40 |
41 | def __init__(self, channels, reduction):
42 | super(SEModule, self).__init__()
43 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
44 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
45 | padding=0)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
48 | padding=0)
49 | self.sigmoid = nn.Sigmoid()
50 |
51 | def forward(self, x):
52 | module_input = x
53 | x = self.avg_pool(x)
54 | x = self.fc1(x)
55 | x = self.relu(x)
56 | x = self.fc2(x)
57 | x = self.sigmoid(x)
58 | return module_input * x
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | """
63 | Base class for bottlenecks that implements `forward()` method.
64 | """
65 | def forward(self, x):
66 | residual = x
67 |
68 | out = self.conv1(x)
69 | out = self.bn1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv3(out)
77 | out = self.bn3(out)
78 |
79 | if self.downsample is not None:
80 | residual = self.downsample(x)
81 |
82 | out = self.se_module(out) + residual
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class SEResNeXtBottleneck(Bottleneck):
89 | """
90 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
91 | """
92 | expansion = 4
93 |
94 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
95 | downsample=None, base_width=4):
96 | super(SEResNeXtBottleneck, self).__init__()
97 | width = math.floor(planes * (base_width / 64)) * groups
98 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
99 | stride=1)
100 | self.bn1 = nn.BatchNorm2d(width)
101 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
102 | padding=1, groups=groups, bias=False)
103 | self.bn2 = nn.BatchNorm2d(width)
104 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
105 | self.bn3 = nn.BatchNorm2d(planes * 4)
106 | self.relu = nn.ReLU(inplace=True)
107 | self.se_module = SEModule(planes * 4, reduction=reduction)
108 | self.downsample = downsample
109 | self.stride = stride
110 |
111 |
112 | class BackBone(nn.Module):
113 |
114 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
115 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
116 | downsample_padding=1):
117 | """
118 | Parameters
119 | ----------
120 | block (nn.Module): Bottleneck class.
121 | - For SE-ResNeXt models: SEResNeXtBottleneck
122 | layers (list of ints): Number of residual blocks for 4 layers of the
123 | network (layer1...layer4).
124 | groups (int): Number of groups for the 3x3 convolution in each
125 | bottleneck block.
126 | - For BackBone154: 64
127 | - For SE-ResNet models: 1
128 | - For SE-ResNeXt models: 32
129 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
130 | - For all models: 16
131 | dropout_p (float or None): Drop probability for the Dropout layer.
132 | If `None` the Dropout layer is not used.
133 | - For BackBone154: 0.2
134 | - For SE-ResNet models: None
135 | - For SE-ResNeXt models: None
136 | inplanes (int): Number of input channels for layer1.
137 | - For BackBone154: 128
138 | - For SE-ResNet models: 64
139 | - For SE-ResNeXt models: 64
140 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
141 | a single 7x7 convolution in layer0.
142 | - For BackBone154: True
143 | - For SE-ResNet models: False
144 | - For SE-ResNeXt models: False
145 | downsample_kernel_size (int): Kernel size for downsampling convolutions
146 | in layer2, layer3 and layer4.
147 | - For BackBone154: 3
148 | - For SE-ResNet models: 1
149 | - For SE-ResNeXt models: 1
150 | downsample_padding (int): Padding for downsampling convolutions in
151 | layer2, layer3 and layer4.
152 | - For BackBone154: 1
153 | - For SE-ResNet models: 0
154 | - For SE-ResNeXt models: 0
155 | num_classes (int): Number of outputs in `last_linear` layer.
156 | - For all models: 1000
157 | """
158 | super(BackBone, self).__init__()
159 | self.inplanes = inplanes
160 | if input_3x3:
161 | layer0_modules = [
162 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
163 | bias=False)),
164 | ('bn1', nn.BatchNorm2d(64)),
165 | ('relu1', nn.ReLU(inplace=True)),
166 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
167 | bias=False)),
168 | ('bn2', nn.BatchNorm2d(64)),
169 | ('relu2', nn.ReLU(inplace=True)),
170 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
171 | bias=False)),
172 | ('bn3', nn.BatchNorm2d(inplanes)),
173 | ('relu3', nn.ReLU(inplace=True)),
174 | ]
175 | else:
176 | layer0_modules = [
177 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
178 | padding=3, bias=False)),
179 | ('bn1', nn.BatchNorm2d(inplanes)),
180 | ('relu1', nn.ReLU(inplace=True)),
181 | ]
182 | # To preserve compatibility with Caffe weights `ceil_mode=True`
183 | # is used instead of `padding=1`.
184 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
185 | ceil_mode=True)))
186 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
187 | self.layer1 = self._make_layer(
188 | block,
189 | planes=64,
190 | blocks=layers[0],
191 | groups=groups,
192 | reduction=reduction,
193 | downsample_kernel_size=1,
194 | downsample_padding=0
195 | )
196 | self.layer2 = self._make_layer(
197 | block,
198 | planes=128,
199 | blocks=layers[1],
200 | stride=2,
201 | groups=groups,
202 | reduction=reduction,
203 | downsample_kernel_size=downsample_kernel_size,
204 | downsample_padding=downsample_padding
205 | )
206 | self.layer3 = self._make_layer(
207 | block,
208 | planes=256,
209 | blocks=layers[2],
210 | stride=2,
211 | groups=groups,
212 | reduction=reduction,
213 | downsample_kernel_size=downsample_kernel_size,
214 | downsample_padding=downsample_padding
215 | )
216 | self.layer4 = self._make_layer(
217 | block,
218 | planes=512,
219 | blocks=layers[3],
220 | stride=2,
221 | groups=groups,
222 | reduction=reduction,
223 | downsample_kernel_size=downsample_kernel_size,
224 | downsample_padding=downsample_padding
225 | )
226 |
227 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
228 | downsample_kernel_size=1, downsample_padding=0):
229 | downsample = None
230 | if stride != 1 or self.inplanes != planes * block.expansion:
231 | downsample = nn.Sequential(
232 | nn.Conv2d(self.inplanes, planes * block.expansion,
233 | kernel_size=downsample_kernel_size, stride=stride,
234 | padding=downsample_padding, bias=False),
235 | nn.BatchNorm2d(planes * block.expansion),
236 | )
237 |
238 | layers = []
239 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
240 | downsample))
241 | self.inplanes = planes * block.expansion
242 | for i in range(1, blocks):
243 | layers.append(block(self.inplanes, planes, groups, reduction))
244 |
245 | return nn.Sequential(*layers)
246 |
247 | def forward(self, x):
248 | x = self.layer0(x)
249 | C2 = self.layer1(x)
250 | C3 = self.layer2(C2)
251 | C4 = self.layer3(C3)
252 | C5 = self.layer4(C4)
253 | return [C3, C4, C5]
254 |
255 |
256 | def initialize_pretrained_model(model, settings):
257 | weights_path = os.path.join(os.path.join(os.path.dirname(__file__), '..',
258 | 'pretrainedmodels', os.path.basename(settings['url'])))
259 | state_dict = torch.load(weights_path)
260 | # filter out last layer
261 | for layer in ['last_linear.weight', 'last_linear.bias']:
262 | del state_dict[layer]
263 | state_dict_cur = model.state_dict()
264 |     state_dict_cur.update(state_dict)
265 |
266 | model.load_state_dict(state_dict_cur)
267 | model.input_space = settings['input_space']
268 | model.input_size = settings['input_size']
269 | model.input_range = settings['input_range']
270 | model.mean = settings['mean']
271 | model.std = settings['std']
272 |
273 |
274 | def se_resnext50_32x4d(pretrained_imagenet):
275 | model = BackBone(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
276 | dropout_p=None, inplanes=64, input_3x3=False,
277 | downsample_kernel_size=1, downsample_padding=0,
278 | )
279 | if pretrained_imagenet:
280 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet']
281 | initialize_pretrained_model(model, settings)
282 | return model
283 |
284 |
285 | def se_resnext101_32x4d(pretrained_imagenet):
286 | model = BackBone(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
287 | dropout_p=None, inplanes=64, input_3x3=False,
288 | downsample_kernel_size=1, downsample_padding=0,
289 | )
290 | if pretrained_imagenet:
291 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet']
292 | initialize_pretrained_model(model, settings)
293 | return model
294 |
--------------------------------------------------------------------------------
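A quick shape check for the backbone (a sketch; `pretrained_imagenet=False` avoids needing the downloaded weights). For a 512x512 input it returns C3/C4/C5 at strides 8/16/32 with 512/1024/2048 channels:

```python
import torch

from models.backbone import se_resnext50_32x4d

net = se_resnext50_32x4d(pretrained_imagenet=False).eval()
with torch.no_grad():
    C3, C4, C5 = net(torch.randn(1, 3, 512, 512))

print(C3.shape, C4.shape, C5.shape)
# torch.Size([1, 512, 64, 64]) torch.Size([1, 1024, 32, 32]) torch.Size([1, 2048, 16, 16])
```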
/models/fpn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from . import backbone
8 | from .misc import filter_detections, bbox_transform_inv, clip_boxes, build_anchors
9 | from dataGen.targetBuild import anchors_for_shape
10 |
11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
12 | with open(config_path, 'r') as f:
13 | config = json.load(f)
14 |
15 |
16 | class top_down(nn.Module):
17 | """ Creates the FPN layers on top of the backbone features.
18 |
19 | Args
20 | C3 : Feature stage C3 from the backbone.
21 | C4 : Feature stage C4 from the backbone.
22 | C5 : Feature stage C5 from the backbone.
23 | feature_size : The feature size to use for the resulting feature levels.
24 |
25 | Returns
26 | A list of feature levels [P3, P4, P5, P6, P7].
27 | """
28 |
29 | def __init__(self, feature_size=256):
30 | super(top_down, self).__init__()
31 | self.C5_reduced = nn.Conv2d(2048, feature_size, kernel_size=1, stride=1)
32 | self.P5_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
33 | self.C4_reduced = nn.Conv2d(1024, feature_size, kernel_size=1, stride=1)
34 | self.P4_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
35 | self.C3_reduced = nn.Conv2d(512, feature_size, kernel_size=1, stride=1)
36 | self.P3_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
37 | self.P6_conv = nn.Conv2d(2048, feature_size, kernel_size=3, stride=2, padding=1)
38 | self.P7_conv = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)
39 |
40 | def forward(self, x):
41 | C3, C4, C5 = x
42 |
43 | # upsample C5 to get P5 from the FPN paper
44 | P5 = self.C5_reduced(C5)
45 | P5_upsampled = F.interpolate(P5, scale_factor=2, mode='nearest')
46 | P5 = self.P5_conv(P5)
47 |
48 | # add P5 elementwise to C4
49 | P4 = self.C4_reduced(C4)
50 | P4 = P5_upsampled + P4
51 | P4_upsampled = F.interpolate(P4, scale_factor=2, mode='nearest')
52 | P4 = self.P4_conv(P4)
53 |
54 | # add P4 elementwise to C3
55 | P3 = self.C3_reduced(C3)
56 | P3 = P4_upsampled + P3
57 | P3 = self.P3_conv(P3)
58 |
59 | # "P6 is obtained via a 3x3 stride-2 conv on C5"
60 | P6 = self.P6_conv(C5)
61 |
62 | # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
63 | P7 = F.relu(P6)
64 | P7 = self.P7_conv(P7)
65 |
66 | return [P3, P4, P5, P6, P7]
67 |
68 |
69 | class classification_subnet(nn.Module):
70 |     """ Creates the default classification submodel."""
71 | options = {
72 | 'kernel_size': 3,
73 | 'stride': 1,
74 | 'padding': 1,
75 | }
76 |
77 | def __init__(self, num_classes=1, num_anchors=9, pyramid_feature_size=256, prior_probability=0.01):
78 | """
79 | Args
80 | num_classes : Number of classes to predict a score for at each feature level.
81 | num_anchors : Number of anchors to predict classification scores for at each feature level.
82 | pyramid_feature_size : The number of filters to expect from the feature pyramid levels.
83 | prior_probability : Prior probability for training stability in early training
84 | """
85 | super().__init__()
86 | convs = []
87 | for i in range(4):
88 | conv = nn.Conv2d(pyramid_feature_size, pyramid_feature_size, **classification_subnet.options)
89 | nn.init.normal_(conv.weight, mean=0.0, std=0.01)
90 | nn.init.zeros_(conv.bias)
91 | convs.append(conv)
92 | convs.append(nn.ReLU())
93 | self.feats = nn.Sequential(*convs)
94 | self.num_classes = num_classes
95 | head = nn.Conv2d(pyramid_feature_size, out_channels=num_classes * num_anchors, **classification_subnet.options)
96 | nn.init.normal_(head.weight, mean=0.0, std=0.01)
97 | nn.init.constant_(head.bias, val=-math.log((1 - prior_probability) / prior_probability))
98 | self.head = head
99 |
100 | def forward(self, x):
101 | outputs = self.feats(x)
102 | outputs = self.head(outputs)
103 |
104 | # reshape output and apply sigmoid
105 | outputs = outputs.permute(0, 2, 3, 1).contiguous()
106 | outputs = outputs.view(outputs.shape[0], -1, self.num_classes)
107 | outputs = torch.sigmoid(outputs)
108 | return outputs
109 |
110 |
111 | class regression_subnet(nn.Module):
112 | """ Creates the default regression submodel."""
113 | options = {
114 | 'kernel_size': 3,
115 | 'stride': 1,
116 | 'padding': 1,
117 | }
118 |
119 | def __init__(self, num_values=4, num_anchors=9, pyramid_feature_size=256):
120 | """
121 | Args
122 | num_values : Number of values to regress.
123 | num_anchors : Number of anchors to regress for each feature level.
124 | pyramid_feature_size : The number of filters to expect from the feature pyramid levels.
125 | """
126 | super().__init__()
127 | self.num_values = num_values
128 | convs = []
129 | for i in range(4):
130 | conv = nn.Conv2d(pyramid_feature_size, pyramid_feature_size, **regression_subnet.options)
131 | nn.init.normal_(conv.weight, mean=0.0, std=0.01)
132 | nn.init.zeros_(conv.bias)
133 | convs.append(conv)
134 | convs.append(nn.ReLU())
135 | self.feats = nn.Sequential(*convs)
136 |
137 | head = nn.Conv2d(pyramid_feature_size, out_channels=num_anchors * num_values, **regression_subnet.options)
138 | nn.init.normal_(head.weight, mean=0.0, std=0.01)
139 | nn.init.zeros_(head.bias)
140 | self.head = head
141 |
142 | def forward(self, x):
143 | outputs = self.feats(x)
144 | outputs = self.head(outputs)
145 |
146 | # reshape
147 | outputs = outputs.permute(0, 2, 3, 1).contiguous()
148 | outputs = outputs.view(outputs.shape[0], -1, self.num_values)
149 | return outputs
150 |
151 |
152 | class retinanet(nn.Module):
153 | """ Construct a RetinaNet model on top of a backbone, without bbox prediction transform"""
154 |
155 | def __init__(self, backbone_name=None, num_classes=None, num_anchors=None, pretrained_imagenet=None):
156 | """
157 | Args
158 | backbone_name : backbone name, `se_resnext50_32x4d` or `se_resnext101_32x4d`
159 | num_classes : Number of classes to classify.
160 | num_anchors : Number of base anchors.
161 | """
162 | super().__init__()
163 |
164 | if backbone_name is None:
165 | backbone_name = config['backbone']
166 | assert backbone_name in ['se_resnext50_32x4d', 'se_resnext101_32x4d'], \
167 | "`se_resnext50_32x4d` or `se_resnext101_32x4d`"
168 | bottom_up = getattr(backbone, backbone_name)
169 |
170 | if pretrained_imagenet is None:
171 | pretrained_imagenet=config['pretrained_imagenet']
172 | self.bottom_up = bottom_up(pretrained_imagenet)
173 |
174 | self.top_down = top_down(feature_size=256)
175 |
176 | if num_classes is None:
177 | num_classes = config['num_classes']
178 | if num_anchors is None:
179 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default'])
180 | self.classification_subnet = classification_subnet(num_classes, num_anchors, 256, 0.01)
181 | self.regression_subnet = regression_subnet(4, num_anchors, 256)
182 |
183 | def forward(self, images):
184 | """
185 | Args:
186 |             images: Tensor of shape (B, 3, H, W), where B is the batch size and H, W are the image height and width
187 | """
188 | C3, C4, C5 = self.bottom_up(images)
189 | P3, P4, P5, P6, P7 = self.top_down((C3, C4, C5))
190 | classification_output = []
191 | regression_output = []
192 | for P in [P3, P4, P5, P6, P7]:
193 | classification_output.append(self.classification_subnet(P))
194 | regression_output.append(self.regression_subnet(P))
195 |
196 | classification = torch.cat(classification_output, dim=1)
197 | regression = torch.cat(regression_output, dim=1)
198 |
199 | return classification, regression, [P3, P4, P5, P6, P7]
200 |
201 | def predict(self, images):
202 | """
203 | Args:
204 |             images: Tensor of shape (B, 3, H, W), where B is the batch size and H, W are the image height and width
205 |
206 | Returns:
207 | list of [bboxes, labels, scores] per image
208 | """
209 |         classification, regression, [P3, P4, P5, P6, P7] = self(images)
210 |         anchors = build_anchors(features=[P3, P4, P5, P6, P7])
211 |         bboxes = bbox_transform_inv(anchors, regression)
212 |         del anchors
213 |         bboxes = clip_boxes(images, bboxes)
214 |
215 |         return filter_detections(bboxes, classification)
216 |
217 |     def train_extractor(self, active=True):
218 |         """Freeze (active=False) or unfreeze (active=True) the backbone feature extractor."""
219 |         if active:
220 |             for p in self.bottom_up.parameters():
221 |                 p.requires_grad = True
222 |             self.bottom_up.train()
223 |         else:
224 |             for p in self.bottom_up.parameters():
225 |                 p.requires_grad = False
226 |             self.bottom_up.eval()
227 |
--------------------------------------------------------------------------------
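
For reference, a minimal usage sketch of the `retinanet` module above. This is only a sketch: it assumes `config.json` supplies the defaults read in `__init__` (backbone, number of classes, anchor ratios/scales), that the pretrained backbone weights referenced by the config are available (or `pretrained_imagenet` is disabled), that `config['use_cuda']` matches the device the inputs live on, and it uses the 224x224 input size from the tests.

import torch
from models.fpn import retinanet

# Build the detector from the config.json defaults (backbone, num_classes, anchors).
model = retinanet()
model.eval()

# Dummy batch of two 224x224 RGB images.
images = torch.randn(2, 3, 224, 224)

# forward() returns raw per-anchor outputs plus the pyramid feature maps.
classification, regression, pyramid = model(images)
# classification: (2, num_anchors_total, num_classes); regression: (2, num_anchors_total, 4)

# predict() decodes the regression against the anchors, clips the boxes and runs NMS,
# giving one dict per image with 'bboxes', 'scores' and 'category_id'.
with torch.no_grad():
    detections = model.predict(images)
print(detections[0]['bboxes'].shape, detections[0]['scores'].shape)
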
/models/losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from torch.nn import functional as F
5 |
6 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
7 | with open(config_path, 'r') as f:
8 | config = json.load(f)
9 | # torch.device object used throughout this script
10 | device = torch.device("cuda" if config['use_cuda'] else "cpu")
11 |
12 | def focal_loss(alpha=0.25, gamma=2.0):
13 | """ Create a functor for computing the focal loss.
14 |
15 | Args
16 | alpha: Scale the focal weight with alpha.
17 | gamma: Take the power of the focal weight with gamma.
18 |
19 | Returns
20 | A functor that computes the focal loss using the alpha and gamma.
21 | """
22 |
23 | def _focal(y_true, y_pred):
24 | """ Compute the focal loss given the target tensor and the predicted tensor.
25 |
26 | As defined in https://arxiv.org/abs/1708.02002
27 |
28 | Args
29 | y_true: Tensor of target data from the generator with shape (B, N, num_classes+1).
30 | y_pred: Tensor of predicted data from the network with shape (B, N, num_classes).
31 |
32 | Returns
33 | The focal loss of y_pred w.r.t. y_true.
34 | """
35 | labels = y_true[:, :, :-1]
36 | anchor_state = y_true[:, :, -1] # -1 for ignore, 0 for background, 1 for object
37 | classification = y_pred
38 |
39 | # filter out "ignore" anchors
40 | indices = anchor_state != -1
41 | labels = labels[indices]
42 | classification = classification[indices]
43 |
44 | # compute the focal loss
45 | alpha_factor = torch.where(labels == 1, torch.tensor(alpha, dtype=torch.float32, device=device),
46 | torch.tensor(1 - alpha, dtype=torch.float32, device=device))
47 | focal_weight = torch.where(labels == 1, 1 - classification, classification)
48 | focal_weight = alpha_factor * focal_weight ** gamma
49 |
50 | cls_loss = focal_weight * F.binary_cross_entropy(classification, labels, reduction='none')
51 | # compute the normalizer: the number of positive anchors
52 | normalizer = torch.sum(anchor_state == 1)
53 | normalizer = normalizer.type(torch.float32)
54 | normalizer = torch.max(normalizer, torch.tensor(1.0, device=device))
55 |
56 | return torch.sum(cls_loss) / normalizer
57 |
58 | return _focal
59 |
60 |
61 | def smooth_l1_loss(sigma=3.0):
62 | """ Create a smooth L1 loss functor.
63 |
64 | Args
65 | sigma: This argument defines the point where the loss changes from L2 to L1.
66 |
67 | Returns
68 | A functor for computing the smooth L1 loss given target data and predicted data.
69 | """
70 | sigma_squared = sigma ** 2
71 |
72 | def _smooth_l1(y_true, y_pred):
73 | """ Compute the smooth L1 loss of y_pred w.r.t. y_true.
74 |
75 | Args
76 | y_true: Tensor from the generator of shape (B, N, 5). The last value for each box is the state of the anchor (ignore, negative, positive).
77 | y_pred: Tensor from the network of shape (B, N, 4).
78 |
79 | Returns
80 | The smooth L1 loss of y_pred w.r.t. y_true.
81 | """
82 | # separate target and state
83 | regression = y_pred
84 | regression_target = y_true[:, :, :4]
85 | anchor_state = y_true[:, :, 4]
86 |
87 | # filter out "ignore" anchors and "bg"
88 | indices = anchor_state == 1
89 | regression = regression[indices]
90 | regression_target = regression_target[indices]
91 |
92 | # compute smooth L1 loss
93 | # f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma
94 | # |x| - 0.5 / sigma / sigma otherwise
95 | regression_diff = regression - regression_target
96 | regression_diff = torch.abs(regression_diff)
97 | regression_loss = torch.where(
98 | torch.lt(regression_diff, 1.0 / sigma_squared),
99 | 0.5 * sigma_squared * torch.pow(regression_diff, 2),
100 | regression_diff - 0.5 / sigma_squared
101 | )
102 |
103 | # compute the normalizer: the number of positive anchors
104 | normalizer = torch.max(torch.tensor(1, device=device), torch.sum(indices))
105 | normalizer = normalizer.type(torch.float32)
106 | return torch.sum(regression_loss) / normalizer
107 |
108 | return _smooth_l1
109 |
--------------------------------------------------------------------------------
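
A minimal sketch of how the two loss functors above are typically driven, assuming the target layout produced by the data generator: (B, N, num_classes+1) for classification and (B, N, 5) for regression, with the trailing value holding the anchor state (-1 ignore, 0 background, 1 object), and assuming `config['use_cuda']` matches where the tensors live.

import torch
from models.losses import focal_loss, smooth_l1_loss

focal = focal_loss(alpha=0.25, gamma=2.0)
smooth_l1 = smooth_l1_loss(sigma=3.0)

B, N, num_classes = 2, 12, 1
# Targets: the last channel is the anchor state (-1 ignore, 0 background, 1 object).
labels_batch = torch.rand(B, N, num_classes + 1)
labels_batch[..., -1] = torch.randint(-1, 2, (B, N)).float()
regression_batch = torch.rand(B, N, 4 + 1)
regression_batch[..., -1] = labels_batch[..., -1]

# Predictions: classification scores in (0, 1) and raw box deltas.
classification_pred = torch.rand(B, N, num_classes)
regression_pred = torch.randn(B, N, 4)

loss_cls = focal(labels_batch, classification_pred)
loss_reg = smooth_l1(regression_batch, regression_pred)
print(loss_cls.item(), loss_reg.item())
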
/models/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import torch
5 | from dataGen.targetBuild import generate_anchors, AnchorParameters
6 |
7 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
8 | with open(config_path, 'r') as f:
9 | config = json.load(f)
10 | # torch.device object used throughout this script
11 | device = torch.device("cuda" if config['use_cuda'] else "cpu")
12 |
13 |
14 | def shift(shape, stride, anchors):
15 | """ Produce shifted anchors based on shape of the map and stride size.
16 |
17 | Args
18 | shape : Shape to shift the anchors over (H, W).
19 | stride : Stride to shift the anchors with over the shape.
20 |         anchors: The anchors to apply at each location, np.ndarray.
21 | """
22 | H, W = shape
23 | shift_x = (torch.arange(W, dtype=torch.float32, device=device) + torch.tensor(0.5, dtype=torch.float32, device=device)) * stride
24 | shift_y = (torch.arange(H, dtype=torch.float32, device=device) + torch.tensor(0.5, dtype=torch.float32, device=device)) * stride
25 | shift_y, shift_x = torch.meshgrid([shift_y, shift_x])
26 | shift_x = shift_x.contiguous().view(-1)
27 | shift_y = shift_y.contiguous().view(-1)
28 | shifts = torch.stack([shift_x, shift_y, shift_x, shift_y])
29 | shifts = shifts.t()
30 | anchors = torch.tensor(anchors, dtype=torch.float32, device=device)
31 |
32 | shifted_anchors = torch.unsqueeze(anchors, 0) + torch.unsqueeze(shifts, 1)
33 | shifted_anchors = shifted_anchors.view(-1, 4)
34 |
35 | return shifted_anchors
36 |
37 |
38 | def build_anchors(features, anchor_params=None):
39 | # if no anchor parameters are passed, use default values
40 | if anchor_params is None:
41 | anchor_params = AnchorParameters.default
42 | ratios = anchor_params.ratios
43 | scales = anchor_params.scales
44 | sizes = anchor_params.sizes
45 | strides = anchor_params.strides
46 | anchors_bag = []
47 | for feature, size, stride in zip(features, sizes, strides):
48 | shape = feature.shape[-2:]
49 | anchors = generate_anchors(size, ratios, scales)
50 | anchors_shift = shift(shape, stride, anchors)
51 | anchors_bag.append(anchors_shift)
52 |
53 | return torch.cat(anchors_bag, dim=0)
54 |
55 |
56 | def bbox_transform_inv(anchors, regression, mean=None, std=None):
57 | """ Applies deltas (usually regression results) to boxes (usually anchors).
58 |
59 |     Before applying the deltas to the boxes, the normalization that was previously applied (in the generator) has to be undone.
60 |     The mean and std are the same values used in the generator; the deltas are denormalized here and then applied to the boxes.
61 |
62 | Args
63 | anchors : Tensor of shape (N, 4), N the number of boxes and 4 values for (x1, y1, x2, y2).
64 | regression: Tensor of (B, N, 4), where B is the batch size, N the number of boxes.
65 | These deltas (d_x1, d_y1, d_x2, d_y2) are a factor of the width/height.
66 | mean : The mean value used when computing deltas (defaults to [0, 0, 0, 0]).
67 | std : The standard deviation used when computing deltas (defaults to [0.2, 0.2, 0.2, 0.2]).
68 |
69 | Returns
70 | A Tensor of the same shape as boxes, but with deltas applied to each box.
71 | The mean and std are used during training to normalize the regression values (networks love normalization).
72 | """
73 |
74 | if mean is None:
75 | mean = config['mean_bbox_transform']
76 | if std is None:
77 | std = config['std_bbox_transform']
78 |
79 | anchors = torch.unsqueeze(anchors, dim=0) # (1, N, 4)
80 | width = anchors[:, :, 2] - anchors[:, :, 0]
81 | height = anchors[:, :, 3] - anchors[:, :, 1]
82 |
83 | x1 = anchors[:, :, 0] + (regression[:, :, 0] * std[0] + mean[0]) * width
84 | y1 = anchors[:, :, 1] + (regression[:, :, 1] * std[1] + mean[1]) * height
85 | x2 = anchors[:, :, 2] + (regression[:, :, 2] * std[2] + mean[2]) * width
86 | y2 = anchors[:, :, 3] + (regression[:, :, 3] * std[3] + mean[3]) * height
87 |
88 | pred_boxes = torch.stack([x1, y1, x2, y2], dim=2)
89 |
90 | return pred_boxes
91 |
92 |
93 | def clip_boxes(images, boxes):
94 | shape = images.shape
95 | height = shape[-2]
96 | width = shape[-1]
97 |
98 | x1 = torch.clamp(boxes[:, :, 0], 0.0, width)
99 | y1 = torch.clamp(boxes[:, :, 1], 0.0, height)
100 | x2 = torch.clamp(boxes[:, :, 2], 0.0, width)
101 | y2 = torch.clamp(boxes[:, :, 3], 0.0, height)
102 | boxes_x1y1x2y2 = torch.stack([x1, y1, x2, y2], dim=2)
103 | boxes_x1y1x2y2 = boxes_x1y1x2y2.type(torch.int64)
104 | return boxes_x1y1x2y2
105 |
106 |
107 | def bbox_iou(box1, box2):
108 | """
109 | Returns the IoU of two bounding boxes
110 | """
111 | # Get the coordinates of bounding boxes
112 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
113 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
114 |
115 |     # get the coordinates of the intersection rectangle
116 | inter_rect_x1 = torch.max(b1_x1, b2_x1)
117 | inter_rect_y1 = torch.max(b1_y1, b2_y1)
118 | inter_rect_x2 = torch.min(b1_x2, b2_x2)
119 | inter_rect_y2 = torch.min(b1_y2, b2_y2)
120 | # Intersection area
121 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp(
122 | inter_rect_y2 - inter_rect_y1 + 1, min=0
123 | )
124 | # Union Area
125 | b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
126 | b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
127 |
128 | iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
129 |
130 | return iou
131 |
132 |
133 | def non_max_suppression(boxes, scores, max_output_size=300, iou_threshold=0.5):
134 | """
135 |     Performs Non-Maximum Suppression on `boxes`, processed in decreasing order of `scores`.
136 |     Boxes whose IoU with an already-kept box is >= `iou_threshold` are discarded, and at most
137 |     `max_output_size` boxes are kept.
138 |     Returns the indices of the kept boxes (into the original `boxes`/`scores` ordering).
139 | """
140 |
141 | # Sort the detections by maximum objectness confidence
142 | _, conf_sort_index = torch.sort(scores, descending=True)
143 | boxes = boxes[conf_sort_index]
144 | # Perform non-maximum suppression
145 | max_indexes = []
146 | count = 0
147 | while boxes.shape[0] > 0:
148 | # Get detection with highest confidence and save as max detection
149 | max_detections = boxes[0].unsqueeze(0) # expand 1 dim
150 | max_indexes.append(conf_sort_index[0])
151 | # Stop if we're at the last detection
152 | if boxes.shape[0] == 1:
153 | break
154 | # Get the IOUs for all boxes with lower confidence
155 | ious = bbox_iou(max_detections, boxes[1:])
156 | # Remove detections with IoU >= NMS threshold
157 | boxes = boxes[1:][ious < iou_threshold]
158 | conf_sort_index = conf_sort_index[1:][ious < iou_threshold]
159 | # break when get enough bboxes
160 | count += 1
161 | if count >= max_output_size:
162 | break
163 |
164 | # max_detections = torch.cat(max_detections).data
165 | max_indexes = torch.stack(max_indexes).data
166 | return max_indexes
167 |
168 |
169 | def filter_detections(
170 | boxes,
171 | classification,
172 | class_specific_filter=True,
173 | nms=True,
174 | score_threshold=0.01,
175 | max_detections=300,
176 | nms_threshold=0.5
177 | ):
178 | """ Filter detections using the boxes and classification values.
179 |
180 | Args
181 | boxes : Tensor of shape (B, num_boxes, 4) containing the boxes in (x1, y1, x2, y2) format.
182 | classification : Tensor of shape (B, num_boxes, num_classes) containing the classification scores.
183 | class_specific_filter : Whether to perform filtering per class, or take the best scoring class and filter those.
184 | nms : Flag to enable/disable non maximum suppression.
185 | score_threshold : Threshold used to prefilter the boxes with.
186 | max_detections : Maximum number of detections to keep.
187 | nms_threshold : Threshold for the IoU value to determine when a box should be suppressed.
188 |
189 | Returns
190 |         A list with one dict per image, containing:
191 |             'bboxes'     : array of shape (num_detections, 4) with the (x1, y1, x2, y2) of the kept boxes.
192 |             'scores'     : array of shape (num_detections,) with the score of the predicted class.
193 |             'category_id': array of shape (num_detections,) with the predicted label.
194 |         At most max_detections boxes are kept per image; if no box passes the score threshold
195 |         for an image, the arrays in its dict are empty.
196 |     """
197 |
198 | def _filter_detections(boxes, scores, labels):
199 | # threshold based on score
200 | indices = torch.gt(scores, score_threshold).nonzero()
201 | if indices.shape[0] == 0:
202 | return torch.tensor([], dtype=torch.int64, device=device)
203 | indices = indices[:, 0]
204 |
205 | if nms:
206 | filtered_boxes = torch.index_select(boxes, 0, indices)
207 | filtered_scores = torch.index_select(scores, 0, indices)
208 |
209 | # perform NMS
210 | nms_indices = non_max_suppression(filtered_boxes, filtered_scores, max_output_size=max_detections,
211 | iou_threshold=nms_threshold)
212 |
213 | # filter indices based on NMS
214 | indices = torch.index_select(indices, 0, nms_indices)
215 |
216 | # add indices to list of all indices
217 | labels = torch.index_select(labels, 0, indices)
218 | indices = torch.stack([indices, labels], dim=1)
219 |
220 | return indices
221 |
222 | results = []
223 | for box_cur, classification_cur in zip(boxes, classification):
224 | if class_specific_filter:
225 | all_indices = []
226 | # perform per class filtering
227 | for c in range(int(classification_cur.shape[1])):
228 | scores = classification_cur[:, c]
229 | labels = torch.full_like(scores, c, dtype=torch.int64)
230 | all_indices.append(_filter_detections(box_cur, scores, labels))
231 |
232 | # concatenate indices to single tensor
233 | indices = torch.cat(all_indices, dim=0)
234 | else:
235 | scores, labels = torch.max(classification_cur, dim=1)
236 | indices = _filter_detections(box_cur, scores, labels)
237 |
238 | if indices.shape[0] == 0:
239 | results.append({'bboxes':np.zeros((0, 4)), 'scores': np.full((0, ), -1, dtype=np.float32),'category_id': np.full((0, ), -1, dtype=np.int64)})
240 | continue
241 | # select top k
242 | scores = classification_cur[indices[:, 0], indices[:, 1]]
243 | labels = indices[:, 1]
244 | indices = indices[:, 0]
245 |
246 | scores, top_indices = torch.topk(scores, k=min(max_detections, scores.shape[0]))
247 | # filter input using the final set of indices
248 | indices = indices[top_indices]
249 | box_cur = box_cur[indices]
250 | labels = labels[top_indices]
251 | results.append({'bboxes':box_cur.cpu().detach().numpy(),'scores': scores.cpu().detach().numpy(), 'category_id': labels.cpu().detach().numpy()})
252 |
253 | return results
254 |
--------------------------------------------------------------------------------
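
To make the box decoding above concrete, here is a small worked sketch of `bbox_transform_inv` and `clip_boxes`, passing the normalization values from the docstring (mean [0, 0, 0, 0], std [0.2, 0.2, 0.2, 0.2]) explicitly so it does not depend on the values in `config.json`: each delta moves the corresponding corner by `delta * std * width_or_height` pixels.

import torch
from models.misc import bbox_transform_inv, clip_boxes

# One 100x100 anchor and a single set of regression deltas (B=1, N=1, 4).
anchors = torch.tensor([[50.0, 50.0, 150.0, 150.0]])
deltas = torch.tensor([[[0.5, 0.0, -0.5, 0.0]]])

# x1 moves by 0.5 * 0.2 * 100 = +10 px, x2 by -0.5 * 0.2 * 100 = -10 px; y1/y2 stay put.
boxes = bbox_transform_inv(anchors, deltas, mean=[0, 0, 0, 0], std=[0.2, 0.2, 0.2, 0.2])
print(boxes)  # tensor([[[ 60.,  50., 140., 150.]]])

# clip_boxes clamps to the image extent (here 128x128) and casts to int64.
images = torch.zeros(1, 3, 128, 128)
print(clip_boxes(images, boxes))  # tensor([[[ 60,  50, 128, 128]]])
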
/pretrainedmodels/.gitignore:
--------------------------------------------------------------------------------
1 | se_resnext50_32x4d-a260b3a4.pth
2 | se_resnext101_32x4d-3b2fe3d8.pth
--------------------------------------------------------------------------------
/pretrainedmodels/download_here.txt:
--------------------------------------------------------------------------------
1 | https://github.com/Cadene/pretrained-models.pytorch
2 |
3 | pretrained_settings = {
4 | 'se_resnext50_32x4d': {
5 | 'imagenet': {
6 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
7 | 'input_space': 'RGB',
8 | 'input_size': [3, 224, 224],
9 | 'input_range': [0, 1],
10 | 'mean': [0.485, 0.456, 0.406],
11 | 'std': [0.229, 0.224, 0.225],
12 | 'num_classes': 1000
13 | }
14 | },
15 | 'se_resnext101_32x4d': {
16 | 'imagenet': {
17 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
18 | 'input_space': 'RGB',
19 | 'input_size': [3, 224, 224],
20 | 'input_range': [0, 1],
21 | 'mean': [0.485, 0.456, 0.406],
22 | 'std': [0.229, 0.224, 0.225],
23 | 'num_classes': 1000
24 | }
25 | },
26 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision==0.2.1
2 | setuptools==39.1.0
3 | torch==0.4.1
4 | Cython==0.29
5 | pydicom==1.2.0
6 | opencv_python==3.4.3.18
7 | matplotlib==2.2.3
8 | albumentations==0.1.7
9 | numpy==1.15.2
10 | tqdm==4.26.0
11 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python dataGen/setup.py build_ext --inplace
4 | rm -rf build dataGen/compute_overlap.c .eggs
5 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raytroop/RetinaNet_SE-ResNeXt/cab6103629e0fbe212397b6c202d4ca968c6cf52/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_boxinv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import pydicom
5 | import torch
6 | import numpy as np
7 | from albumentations import Resize, Compose
8 | from dataGen.data_loader import RsnaDataset, fetch_val_loader
9 | from models.misc import build_anchors, bbox_transform_inv
10 |
11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
12 | with open(config_path, 'r') as f:
13 | config = json.load(f)
14 |
15 | def test_boxtfsinv():
16 | def load_dicom(img_id):
17 | image_path = os.path.join(os.path.dirname(__file__), '..', 'dataset/stage_2_train_images', img_id+'.dcm')
18 | ds = pydicom.read_file(image_path)
19 | image = ds.pixel_array
20 | # If grayscale. Convert to RGB for consistency.
21 | if len(image.shape) != 3 or image.shape[2] != 3:
22 | image = np.stack((image,) * 3, -1)
23 | return image
24 |
25 | kfold = 1
26 | valfps_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valfps_{}.pkl'.format(kfold))
27 | with open(valfps_path, 'rb') as f:
28 | valfps = pickle.load(f)
29 |
30 | bboxdict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'valbox_{}.pkl'.format(kfold))
31 | with open(bboxdict_path, 'rb') as f:
32 | bboxdict = pickle.load(f)
33 |
34 | labeldict_path = os.path.join(os.path.dirname(__file__), '..', 'fold_data', 'vallabel_{}.pkl'.format(kfold))
35 | with open(labeldict_path, 'rb') as f:
36 | labeldict = pickle.load(f)
37 | sample = None
38 | for i, nm in enumerate(valfps):
39 | if len(labeldict[nm]) > 1:
40 | sample = nm
41 | idx = i
42 | break
43 | # 00436515-870c-4b36-a041-de91049b9ab4
44 | img = load_dicom(sample)
45 | # [[264, 152, 476, 530], [562, 152, 817, 604]]
46 | bboxes = bboxdict[sample]
47 | # [0, 0]
48 | labels = labeldict[sample]
49 | assert img.shape == (1024, 1024, 3)
50 | assert len(bboxes) > 0
51 | assert len(bboxes[0]) == 4
52 | assert len(labels) == len(bboxes)
53 |
54 | val_aug = [Resize(*config['image_shape'], p=1.0)]
55 | dt = RsnaDataset(valfps, bboxdict, labeldict, aug=val_aug)
56 | sample = dt[idx]
57 | assert len(sample) == 3
58 | assert sample[0].shape == (3, 224, 224)
59 | # when `config['image_shape']` == (224, 224)
60 | length = (28*28+14*14+7*7+4*4+2*2)*9
61 | assert sample[1].shape == (length, 2)
62 | # assert sample[1][:, 0].sum().item() == 2
63 | pos_label = (sample[1][:, 1] == 1).sum().item()
64 | ignore_label = (sample[1][:, 1] == -1).sum().item()
65 | neg_label = (sample[1][:, 1] == 0).sum().item()
66 | assert length == pos_label + ignore_label + neg_label
67 |
68 | pos_reg = (sample[2][..., -1] == 1).sum().item()
69 | ignore_reg = (sample[2][..., -1] == -1).sum().item()
70 | neg_reg = (sample[2][..., -1] == 0).sum().item()
71 | assert length == pos_reg + ignore_reg + neg_reg
72 |
73 | assert pos_label == pos_reg == 64
74 | assert ignore_label == ignore_reg == 109
75 | assert neg_label == neg_reg == 9268
76 |
77 | regression = sample[2]
78 | assert regression.shape == (length, 5)
79 | features = []
80 | for sz in [28, 14, 7, 4, 2]:
81 | features.append(np.empty(shape=(config['batch_size'], 256, sz, sz)))
82 | anchors = build_anchors(features)
83 | assert anchors.shape == (length, 4)
84 | regression = regression[None, :, :4]
85 | assert regression.shape == (1, length, 4)
86 | bboxes_pred = bbox_transform_inv(anchors, regression)[0]
87 |
88 | # resize bbox
89 | bboxes = np.array(bboxes, dtype=np.float32) * 224 / 1024
90 | bboxes_pred = bboxes_pred[sample[1][:, -1]==1].numpy()
91 | assert bboxes_pred.shape == (pos_label, 4)
92 | assert bboxes.shape == (2, 4)
93 | assert (bboxes_pred[:, 0] == bboxes[0, 0]).sum() > 0
94 | assert (bboxes_pred[:, 1] == bboxes[0, 1]).sum() > 0
95 | assert (bboxes_pred[:, 2] == bboxes[0, 2]).sum() > 0
96 | assert (bboxes_pred[:, 3] == bboxes[0, 3]).sum() > 0
97 |
98 | assert (bboxes_pred[:, 0] == bboxes[1, 0]).sum() > 0
99 | assert (bboxes_pred[:, 1] == bboxes[1, 1]).sum() > 0
100 | assert (bboxes_pred[:, 2] == bboxes[1, 2]).sum() > 0
101 | assert (bboxes_pred[:, 3] == bboxes[1, 3]).sum() > 0
102 |
--------------------------------------------------------------------------------
/tests/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | import torch
5 | import numpy as np
6 | from dataGen.data_loader import load_dicom, fetch_trn_loader, fetch_val_loader
7 |
8 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
9 | with open(config_path, 'r') as f:
10 | config = json.load(f)
11 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default'])
12 | length = (28*28+14*14+7*7+4*4+2*2)*num_anchors
13 |
14 | def test_loaddicm():
15 | ids = os.listdir(os.path.join(os.path.dirname(__file__), '..', config['dicom_train']))
16 | assert len(ids) > 0
17 | imgid = random.sample(ids, 1)[0]
18 | img = load_dicom(os.path.splitext(imgid)[0])
19 | assert img.shape == (1024, 1024, 3)
20 |
21 |
22 | def test_dataloader():
23 | for i in range(1, 6):
24 | dl = iter(fetch_trn_loader(i))
25 | img_batch, labels_batch, regression_batch = next(dl)
26 | assert img_batch.shape == (config['batch_size'], 3, *config['image_shape'])
27 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1)
28 | assert regression_batch.shape == (config['batch_size'], length, 4+1)
29 |
30 | for i in range(1, 6):
31 | dl = iter(fetch_val_loader(i))
32 | img_batch, labels_batch, regression_batch = next(dl)
33 | assert img_batch.shape == (config['batch_size'], 3, *config['image_shape'])
34 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1)
35 | assert regression_batch.shape == (config['batch_size'], length, 4+1)
36 |
--------------------------------------------------------------------------------
/tests/test_filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import json
4 | import torch
5 | import numpy as np
6 | from models.misc import filter_detections
7 |
8 |
9 | def test_simple():
10 |     # single-image test of filter_detections
11 |
12 | # create simple input
13 | boxes = np.array([[
14 | [0, 0, 10, 10],
15 | [0, 0, 10, 10], # this will be suppressed
16 | ]], dtype=np.float)
17 | boxes = torch.tensor(boxes)
18 |
19 | classification = np.array([[
20 | [0, 0.9], # this will be suppressed
21 | [0, 1],
22 | ]], dtype=np.float)
23 | classification = torch.tensor(classification)
24 |
25 | # compute output
26 | results = filter_detections(boxes, classification)
27 |     actual_boxes = results[0]['bboxes']
28 |     actual_scores = results[0]['scores']
29 |     actual_labels = results[0]['category_id']
30 |
31 | # define expected output
32 | expected_boxes = np.array([[0, 0, 10, 10]], dtype=np.float)
33 |
34 | expected_scores = np.array([1], dtype=np.float)
35 |
36 | expected_labels = np.array([1], dtype=np.float)
37 |
38 | # assert actual and expected are equal
39 | np.testing.assert_array_equal(actual_boxes, expected_boxes)
40 | np.testing.assert_array_equal(actual_scores, expected_scores)
41 | np.testing.assert_array_equal(actual_labels, expected_labels)
42 |
43 |
44 | def test_mini_batch():
45 |     # mini-batch test of filter_detections
46 |
47 | # create input with batch_size=2
48 | boxes = np.array([
49 | [
50 | [0, 0, 10, 10], # this will be suppressed
51 | [0, 0, 10, 10],
52 | ],
53 | [
54 | [100, 100, 150, 150],
55 | [100, 100, 150, 150], # this will be suppressed
56 | ],
57 | ], dtype=np.float)
58 | boxes = torch.tensor(boxes)
59 |
60 | classification = np.array([
61 | [
62 | [0, 0.9], # this will be suppressed
63 | [0, 1],
64 | ],
65 | [
66 | [1, 0],
67 | [0.9, 0], # this will be suppressed
68 | ],
69 | ], dtype=np.float)
70 | classification = torch.tensor(classification)
71 |
72 | # compute output
73 | results = filter_detections(boxes, classification)
74 |
75 |
76 | # define expected output
77 | expected_boxes0 = np.array([[0, 0, 10, 10]], dtype=np.float)
78 | expected_boxes1 = np.array([[100, 100, 150, 150]], dtype=np.float)
79 |
80 | expected_scores0 = np.array([1], dtype=np.float)
81 | expected_scores1 = np.array([1], dtype=np.float)
82 |
83 | expected_labels0 = np.array([1], dtype=np.float)
84 | expected_labels1 = np.array([0], dtype=np.float)
85 |
86 | # assert actual and expected are equal
87 |     np.testing.assert_array_equal(results[0]['bboxes'], expected_boxes0)
88 |     np.testing.assert_array_equal(results[0]['scores'], expected_scores0)
89 |     np.testing.assert_array_equal(results[0]['category_id'], expected_labels0)
90 |
91 |     # assert actual and expected are equal
92 |     np.testing.assert_array_equal(results[1]['bboxes'], expected_boxes1)
93 |     np.testing.assert_array_equal(results[1]['scores'], expected_scores1)
94 |     np.testing.assert_array_equal(results[1]['category_id'], expected_labels1)
95 |
--------------------------------------------------------------------------------
/tests/test_fpn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from models.fpn import top_down, classification_subnet, regression_subnet
6 |
7 | def test_top_down():
8 | model = top_down()
9 | C3 = torch.randn(4, 512, 28, 28)
10 | C4 = torch.randn(4, 1024, 14, 14)
11 | C5 = torch.randn(4, 2048, 7, 7)
12 |
13 | P3, P4, P5, P6, P7 = model((C3, C4, C5))
14 | assert P3.shape == (4, 256, 28, 28)
15 | assert P4.shape == (4, 256, 14, 14)
16 | assert P5.shape == (4, 256, 7, 7)
17 | assert P6.shape == (4, 256, 4, 4)
18 | assert P7.shape == (4, 256, 2, 2)
19 |
20 | # ----------------------------------
21 | C3 = torch.randn(4, 512, 64, 64)
22 | C4 = torch.randn(4, 1024, 32, 32)
23 | C5 = torch.randn(4, 2048, 16, 16)
24 |
25 | P3, P4, P5, P6, P7 = model((C3, C4, C5))
26 | assert P3.shape == (4, 256, 64, 64)
27 | assert P4.shape == (4, 256, 32, 32)
28 | assert P5.shape == (4, 256, 16, 16)
29 | assert P6.shape == (4, 256, 8, 8)
30 | assert P7.shape == (4, 256, 4, 4)
31 |
32 |
33 | def test_classification_subnet():
34 | model = classification_subnet()
35 | P3 = torch.randn(4, 256, 28, 28)
36 | feat3 = model(P3)
37 | P4 = torch.randn(4, 256, 14, 14)
38 | feat4 = model(P4)
39 | P5 = torch.randn(4, 256, 7, 7)
40 | feat5 = model(P5)
41 | P6 = torch.randn(4, 256, 4, 4)
42 | feat6 = model(P6)
43 | P7 = torch.randn(4, 256, 2, 2)
44 | feat7 = model(P7)
45 | assert feat3.shape == (4, 9*28*28, 1)
46 | assert feat4.shape == (4, 9*14*14, 1)
47 | assert feat5.shape == (4, 9*7*7, 1)
48 | assert feat6.shape == (4, 9*4*4, 1)
49 | assert feat7.shape == (4, 9*2*2, 1)
50 |
51 | assert len(list(model.children())) == 2
52 |
53 | assert isinstance(model.head, nn.Conv2d)
54 | assert model.head.weight.shape == (9, 256, 3, 3)
55 | np.testing.assert_almost_equal(model.head.weight.mean().item(), 0, decimal=2)
56 | np.testing.assert_almost_equal(model.head.weight.std().item(), 0.01, decimal=2)
57 | prior_probability=0.01
58 | np.testing.assert_almost_equal(model.head.bias.data.numpy(), -math.log((1 - prior_probability) / prior_probability))
59 |
60 |
61 | def test_regression_subnet():
62 | model = regression_subnet()
63 | P3 = torch.randn(4, 256, 28, 28)
64 | feat3 = model(P3)
65 | P4 = torch.randn(4, 256, 14, 14)
66 | feat4 = model(P4)
67 | P5 = torch.randn(4, 256, 7, 7)
68 | feat5 = model(P5)
69 | P6 = torch.randn(4, 256, 4, 4)
70 | feat6 = model(P6)
71 | P7 = torch.randn(4, 256, 2, 2)
72 | feat7 = model(P7)
73 | assert feat3.shape == (4, 9*28*28, 4)
74 | assert feat4.shape == (4, 9*14*14, 4)
75 | assert feat5.shape == (4, 9*7*7, 4)
76 | assert feat6.shape == (4, 9*4*4, 4)
77 | assert feat7.shape == (4, 9*2*2, 4)
78 |
--------------------------------------------------------------------------------
/tests/test_losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from models.losses import focal_loss, smooth_l1_loss
4 |
5 | def test_focal():
6 | focal = focal_loss()
7 | y_pred = torch.rand(8, 3, 1)
8 | y_true = torch.rand(8, 3, 2)
9 | y_true[..., -1] = torch.tensor([1, -1, 0])
10 | assert (y_true[..., -1] == torch.tensor([1, -1, 0], dtype=torch.float32)).all()
11 | assert focal(y_true, y_pred)
12 |
13 |
14 | def test_smooth_l1():
15 | smooth_l1 = smooth_l1_loss()
16 | y_pred = torch.rand(8, 3, 4)
17 | y_true = torch.rand(8, 3, 5)
18 | y_true[..., -1] = torch.tensor([1, -1, 0])
19 | assert (y_true[..., -1] == torch.tensor([1, -1, 0], dtype=torch.float32)).all()
20 | assert smooth_l1(y_true, y_pred)
21 |
--------------------------------------------------------------------------------
/tests/test_retinanet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from models.fpn import retinanet
7 | from dataGen.data_loader import load_dicom, fetch_trn_loader, fetch_val_loader
8 | from models.losses import focal_loss, smooth_l1_loss
9 |
10 |
11 | config_path = os.path.join(os.path.dirname(__file__), '..', 'config.json')
12 | with open(config_path, 'r') as f:
13 | config = json.load(f)
14 | num_anchors = len(config['anchor_ratios_default']) * len(config['anchor_scales_default'])
15 | length = (28*28+14*14+7*7+4*4+2*2)*num_anchors
16 |
17 | def test_retinanet():
18 | model = retinanet()
19 | images = torch.randn(8, 3, 224, 224)
20 |     classification, regression, _ = model(images)
21 | assert classification.shape == (8, length, 1)
22 | assert regression.shape == (8, length, 4)
23 |
24 | results = model.predict(images)
25 | assert len(results) == 8
26 |     assert len(results[0]) == 3
27 |     for res in results:
28 |         assert res['bboxes'].shape[0] == res['scores'].shape[0] == res['category_id'].shape[0]
29 |         assert res['bboxes'].shape[1] == 4
30 |         assert len(res['scores'].shape) == 1
31 |         assert len(res['category_id'].shape) == 1
32 |         assert res['bboxes'].dtype == np.int64
33 |         assert res['scores'].dtype == np.float32
34 |         assert res['category_id'].dtype == np.int64
35 |
36 |
37 | def test_merged():
38 | model = retinanet()
39 | model = model.cuda()
40 | for k in range(1, 6):
41 | dl = iter(fetch_trn_loader(k))
42 | for i in range(100):
43 | img_batch, labels_batch, regression_batch = next(dl)
44 | img_batch = img_batch.cuda()
45 | labels_batch = labels_batch.cuda()
46 | regression_batch = regression_batch.cuda()
47 |
48 |             classification, regression, _ = model(img_batch)
49 | assert classification.shape == (config['batch_size'], length, config['num_classes'])
50 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1)
51 | assert regression.shape == (config['batch_size'], length, 4)
52 | assert regression_batch.shape == (config['batch_size'], length, 4+1)
53 |
54 | focal = focal_loss()
55 | smooth_l1 = smooth_l1_loss()
56 | assert focal(labels_batch, classification).shape == torch.Size([])
57 | assert smooth_l1(regression_batch, regression).shape == torch.Size([])
58 |
59 | results = model.predict(img_batch)
60 |
61 | for k in range(1, 6):
62 | dl = iter(fetch_val_loader(k))
63 | for i in range(100):
64 | img_batch, labels_batch, regression_batch = next(dl)
65 | img_batch = img_batch.cuda()
66 | labels_batch = labels_batch.cuda()
67 | regression_batch = regression_batch.cuda()
68 |
69 |             classification, regression, _ = model(img_batch)
70 | assert classification.shape == (config['batch_size'], length, config['num_classes'])
71 | assert labels_batch.shape == (config['batch_size'], length, config['num_classes']+1)
72 | assert regression.shape == (config['batch_size'], length, 4)
73 | assert regression_batch.shape == (config['batch_size'], length, 4+1)
74 |
75 | focal = focal_loss()
76 | smooth_l1 = smooth_l1_loss()
77 | assert focal(labels_batch, classification).shape == torch.Size([])
78 | assert smooth_l1(regression_batch, regression).shape == torch.Size([])
79 |
80 | results = model.predict(img_batch)
81 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Train the model"""
2 | import sys
3 | import json
4 | import argparse
5 | import logging
6 | import os
7 | import numpy as np
8 | import torch
9 | from torch import nn, optim
10 | from torch.nn.utils import clip_grad_norm_
11 | from tqdm import tqdm
12 | from models.losses import focal_loss, smooth_l1_loss
13 | from models.fpn import retinanet
14 | from dataGen.data_loader import fetch_trn_loader, fetch_val_loader
15 | import utils
16 |
17 | with open('config.json', 'r') as f:
18 | config = json.load(f)
19 | # torch.device object used throughout this script
20 | device = torch.device("cuda" if config['use_cuda'] else "cpu")
21 |
22 | def parse_args(args=None):
23 | parser = argparse.ArgumentParser()
24 |     parser.add_argument('--fold', required=True, choices=[1, 2, 3, 4, 5], type=int, help="which data fold to train on")
25 | parser.add_argument('--learning_rate', default=1e-5, type=float, help="learning rate of optimizer")
26 | parser.add_argument('--num_epochs', default=30, type=int, help="total epochs to train")
27 |     parser.add_argument('--frozen_epochs', default=20, type=int, help="number of initial epochs during which the backbone parameters stay frozen")
28 |     parser.add_argument('--save_dir', default='checkpoints', type=str, help="directory to save checkpoints and logs")
29 |     parser.add_argument('--checkpoint2load', default=None, type=str, help="path of a checkpoint to load")  # e.g. 'epoch10.pth.tar' or 'best.pth.tar'
30 |     parser.add_argument('--optim_restore', default=True, type=lambda s: str(s).lower() in ('true', '1', 'yes'), help="whether to restore the optimizer state from the checkpoint")
31 | return parser.parse_args(args)
32 |
33 |
34 | def train(model, optimizer, loss_fn, dataloader, params, epoch):
35 | """Train the model on `num_steps` batches
36 |
37 | Args:
38 | model: (torch.nn.Module) the neural network
39 | optimizer: (torch.optim) optimizer for parameters of model
40 |         loss_fn: (dict) loss functors with keys 'focal' and 'smooth_l1'
41 |         dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
42 |         params: (argparse.Namespace) command-line hyperparameters
43 |         epoch: (int) index of the current epoch, used to decide whether the backbone
44 |             stays frozen (see --frozen_epochs)
45 | """
46 | loss_TOTAL = utils.RunningAverage()
47 | loss_FL = utils.RunningAverage()
48 | loss_L1 = utils.RunningAverage()
49 |
50 | model.train()
51 | if epoch < params.frozen_epochs:
52 | model.train_extractor(False)
53 | else:
54 | model.train_extractor(True)
55 |
56 | with tqdm(total=len(dataloader)) as t:
57 | for i, (img_batch, labels_batch, regression_batch) in enumerate(dataloader):
58 | img_batch, labels_batch, regression_batch = img_batch.to(device), labels_batch.to(device), regression_batch.to(device)
59 | classification_pred, regression_pred, _ = model(img_batch)
60 | loss_cls = loss_fn['focal'](labels_batch, classification_pred) * config['loss_ratio_FL2L1']
61 | loss_reg = loss_fn['smooth_l1'](regression_batch, regression_pred)
62 | loss_all = loss_cls + loss_reg
63 |
64 | loss_cls_detach = loss_cls.detach().item()
65 | loss_reg_detach = loss_reg.detach().item()
66 | loss_all_detach = loss_all.detach().item()
67 | # clear previous gradients, compute gradients of all variables wrt loss
68 | optimizer.zero_grad()
69 | loss_all.backward()
70 | # The norm is computed over all gradients together
71 | clip_grad_norm_(model.parameters(), 0.5)
72 | # performs updates using calculated gradients
73 | optimizer.step()
74 |
75 | # update the average loss
76 | loss_TOTAL.update(loss_all_detach)
77 | loss_FL.update(loss_cls_detach)
78 | loss_L1.update(loss_reg_detach)
79 |
80 | del img_batch, labels_batch, regression_batch
81 |
82 | t.set_postfix(total_loss='{:05.3f}'.format(loss_all_detach), FL_loss='{:05.3f}'.format(
83 | loss_cls_detach), L1_loss='{:05.3f}'.format(loss_reg_detach))
84 | t.update()
85 | logging.info("total_loss:{:05.3f} FL_loss:{:05.3f} L1_loss:{:05.3f}".format(loss_TOTAL(), loss_FL(), loss_L1()))
86 | del loss_TOTAL, loss_FL, loss_L1
87 |
88 |
89 | def evaluate(model, loss_fn, val_dataloader, params, epoch):
90 | # set model to evaluation mode
91 | model.eval()
92 |
93 | loss_TOTAL = utils.RunningAverage()
94 | loss_FL = utils.RunningAverage()
95 | loss_L1 = utils.RunningAverage()
96 |
97 | with torch.no_grad():
98 | for i, (img_batch, labels_batch, regression_batch) in enumerate(val_dataloader):
99 | img_batch, labels_batch, regression_batch = img_batch.to(device), labels_batch.to(device), regression_batch.to(device)
100 | classification_pred, regression_pred, _ = model(img_batch)
101 | loss_cls = loss_fn['focal'](labels_batch, classification_pred) * config['loss_ratio_FL2L1']
102 | loss_reg = loss_fn['smooth_l1'](regression_batch, regression_pred)
103 | loss_all = loss_cls + loss_reg
104 |
105 | loss_cls_detach = loss_cls.detach().item()
106 | loss_reg_detach = loss_reg.detach().item()
107 | loss_all_detach = loss_all.detach().item()
108 |
109 | # update the average loss
110 | loss_TOTAL.update(loss_all_detach)
111 | loss_FL.update(loss_cls_detach)
112 | loss_L1.update(loss_reg_detach)
113 |
114 | del img_batch, labels_batch, regression_batch
115 |
116 | logging.info("total_loss:{:05.3f} FL_loss:{:05.3f} L1_loss:{:05.3f}".format(loss_TOTAL(), loss_FL(), loss_L1()))
117 | res = loss_TOTAL()
118 | del loss_TOTAL, loss_FL, loss_L1
119 | return res
120 |
121 | def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, params,
122 | scheduler=None):
123 | """Train the model and evaluate every epoch.
124 |
125 | Args:
126 | model: (torch.nn.Module) the neural network
127 | train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
128 | val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
129 | optimizer: (torch.optim) optimizer for parameters of model
130 |         loss_fn: (dict) loss functors with keys 'focal' and 'smooth_l1'
131 |         params: (argparse.Namespace) command-line hyperparameters
132 |         scheduler: optional learning-rate scheduler, stepped once per epoch
133 | """
134 | init_epoch = 0
135 | best_val_loss = float('inf')
136 |
137 | # reload weights from restore_file if specified
138 | if params.checkpoint2load is not None:
139 | checkpoint = utils.load_checkpoint(params.checkpoint2load, model, optimizer if params.optim_restore else None)
140 | if 'epoch' in checkpoint:
141 | init_epoch = checkpoint['epoch']
142 | if 'best_val_loss' in checkpoint:
143 | best_val_loss = checkpoint['best_val_loss']
144 |
145 | for epoch in range(init_epoch, params.num_epochs):
146 | # Run one epoch
147 | logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
148 | if scheduler is not None:
149 | scheduler.step()
150 | # compute number of batches in one epoch (one full pass over the training set)
151 | train(model, optimizer, loss_fn, train_dataloader, params, epoch)
152 |
153 | logging.info("validating ... ")
154 | # Evaluate for one epoch on validation set
155 | val_loss = evaluate(model, loss_fn, val_dataloader, params, epoch)
156 |
157 | is_best = val_loss <= best_val_loss
158 | if is_best:
159 | best_val_loss = val_loss
160 | # Save weights
161 | utils.save_checkpoint({'epoch': epoch + 1,
162 | 'state_dict': model.state_dict(),
163 | 'optim_dict': optimizer.state_dict(),
164 | 'best_val_loss': best_val_loss},
165 | is_best=is_best,
166 | checkpoint=params.save_dir)
167 |
168 |
169 | if __name__ == '__main__':
170 | # Load the parameters from json file
171 | args = parse_args()
172 |
173 | # Set the random seed for reproducible experiments
174 | torch.manual_seed(42)
175 | if config['use_cuda']:
176 | torch.cuda.manual_seed(42)
177 |
178 | args.save_dir = os.path.join(args.save_dir, config['backbone'], f'fold_{args.fold}')
179 | # Set the logger
180 | if not os.path.isdir(args.save_dir):
181 | os.makedirs(args.save_dir)
182 |
183 | # save config in file
184 | with open(os.path.join(args.save_dir, 'config.json'), 'w') as f:
185 | config.update(vars(args))
186 | json.dump(config, f, indent=4)
187 |
188 | utils.set_logger(os.path.join(args.save_dir, 'train.log'))
189 | logging.info(' '.join(sys.argv[:]))
190 | logging.info(args.save_dir)
191 |
192 | # Create the input data pipeline
193 | logging.info("Loading the datasets...")
194 | # fetch dataloaders
195 | train_dl = fetch_trn_loader(args.fold)
196 | val_dl = fetch_val_loader(args.fold)
197 |
198 | # Define the model and optimizer
199 | Net = retinanet(config['backbone'])
200 | model = Net.to(device)
201 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
202 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
203 | # fetch loss function and metrics
204 | loss_fn = {'focal': focal_loss(alpha=config['focal_alpha']), 'smooth_l1': smooth_l1_loss(sigma=config['l1_sigma'])}
205 |
206 | # Train the model
207 | logging.info("Starting training for {} epoch(s)".format(args.num_epochs))
208 | train_and_evaluate(model, train_dl, val_dl, optimizer, loss_fn, args, scheduler=None)
209 | logging.info('Done')
210 |
--------------------------------------------------------------------------------
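
As a usage note, training runs one fold at a time with the flags defined in `parse_args`, e.g. `python train.py --fold 1 --learning_rate 1e-5 --num_epochs 30 --frozen_epochs 20`; checkpoints and `train.log` land under `checkpoints/<backbone>/fold_<k>/`, and a run can be resumed by passing `--checkpoint2load` the path of a previously saved `.pth.tar` file.
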
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import shutil
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import cv2
9 |
10 | class RunningAverage():
11 | """A simple class that maintains the running average of a quantity
12 |
13 | Example:
14 | ```
15 | loss_avg = RunningAverage()
16 | loss_avg.update(2)
17 | loss_avg.update(4)
18 |         loss_avg()  # returns 3.0
19 | ```
20 | """
21 | def __init__(self):
22 | self.steps = 0
23 | self.total = 0
24 |
25 | def update(self, val):
26 | self.total += val
27 | self.steps += 1
28 |
29 | def __call__(self):
30 | return self.total/float(self.steps)
31 |
32 |
33 | def set_logger(log_path):
34 | """Set the logger to log info in terminal and file `log_path`.
35 |
36 | In general, it is useful to have a logger so that every output to the terminal is saved
37 | in a permanent file. Here we save it to `model_dir/train.log`.
38 |
39 | Example:
40 | ```
41 | logging.info("Starting training...")
42 | ```
43 |
44 | Args:
45 | log_path: (string) where to log
46 | """
47 | logger = logging.getLogger()
48 | logger.setLevel(logging.INFO)
49 |
50 | if not logger.handlers:
51 | # Logging to a file
52 | file_handler = logging.FileHandler(log_path)
53 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
54 | logger.addHandler(file_handler)
55 |
56 | # Logging to console
57 | stream_handler = logging.StreamHandler()
58 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
59 | logger.addHandler(stream_handler)
60 |
61 |
62 | def save_dict_to_json(d, json_path):
63 | """Saves dict of floats in json file
64 |
65 | Args:
66 | d: (dict) of float-castable values (np.float, int, float, etc.)
67 | json_path: (string) path to json file
68 | """
69 | with open(json_path, 'w') as f:
70 | json.dump(d, f, indent=4)
71 |
72 |
73 | def save_checkpoint(state, is_best, checkpoint):
74 | """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
75 | checkpoint + 'best.pth.tar'
76 |
77 | Args:
78 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
79 | is_best: (bool) True if it is the best model seen till now
80 | checkpoint: (string) folder where parameters are to be saved
81 | """
82 | filepath = os.path.join(checkpoint, f"epoch{state['epoch']}.pth.tar")
83 | if not os.path.exists(checkpoint):
84 | print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
85 | os.mkdir(checkpoint)
86 | else:
87 | print("Checkpoint Directory exists! ")
88 | torch.save(state, filepath)
89 | if is_best:
90 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
91 |
92 |
93 | def load_checkpoint(checkpoint, model, optimizer=None):
94 | """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of
95 | optimizer assuming it is present in checkpoint.
96 |
97 | Args:
98 | checkpoint: (string) filename which needs to be loaded
99 | model: (torch.nn.Module) model for which the parameters are loaded
100 | optimizer: (torch.optim) optional: resume optimizer from checkpoint
101 | """
102 | if not os.path.exists(checkpoint):
103 | raise("File doesn't exist {}".format(checkpoint))
104 | checkpoint = torch.load(checkpoint)
105 | model.load_state_dict(checkpoint['state_dict'])
106 |
107 | if optimizer:
108 | optimizer.load_state_dict(checkpoint['optim_dict'])
109 |
110 | return checkpoint
111 |
112 |
113 | # https://github.com/albu/albumentations/blob/master/notebooks/example_bboxes.ipynb
114 | # Functions to visualize bounding boxes and class labels on an image.
115 | # Based on https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/vis.py
116 | BOX_COLOR = {0:(255, 0, 0), 1:(0, 255, 0)}
117 | TEXT_COLOR = (255, 255, 255)
118 |
119 | # Available formats are: coco, pascal_voc.
120 | # The coco format of a bounding box looks like [x_min, y_min, width, height], e.g. [97, 12, 150, 200].
121 | # The pascal_voc format of a bounding box looks like [x_min, y_min, x_max, y_max], e.g. [97, 12, 247, 212].
122 |
123 | def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2, pascal=True):
124 | if pascal:
125 |         x_min, y_min, x_max, y_max = [int(v) for v in bbox]
126 | else:
127 | x_min, y_min, w, h = bbox
128 | x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)
129 |
130 | boxcolor = BOX_COLOR[class_id]
131 | cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=boxcolor, thickness=thickness)
132 | class_name = class_idx_to_name[class_id]
133 | ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
134 | cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), boxcolor, -1)
135 | cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)), cv2.FONT_HERSHEY_SIMPLEX, 0.35,TEXT_COLOR, lineType=cv2.LINE_AA)
136 | return img
137 |
138 |
139 | def visualize(annotations, category_id_to_name):
140 | img = annotations['image'].copy()
141 | for idx, bbox in enumerate(annotations['bboxes']):
142 | img = visualize_bbox(img, bbox, annotations['category_id'][idx], category_id_to_name)
143 | plt.figure(figsize=(12, 12))
144 | plt.imshow(img)
145 |
--------------------------------------------------------------------------------
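
Finally, a minimal inference sketch tying `load_checkpoint` and `visualize` to the detector's output. The checkpoint path and the class-name mapping are illustrative placeholders, the random tensor stands in for a preprocessed image batch, and the same config assumptions as above apply.

import numpy as np
import torch
from models.fpn import retinanet
from utils import load_checkpoint, visualize

model = retinanet()
# Hypothetical path; point this at whatever train.py saved for your fold.
load_checkpoint('checkpoints/se_resnext50_32x4d/fold_1/best.pth.tar', model)
model.eval()

images = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed DICOM batch
with torch.no_grad():
    det = model.predict(images)[0]  # dict with 'bboxes', 'scores', 'category_id'

# visualize() expects an annotations dict holding the image plus pascal_voc boxes and labels,
# which matches the keys returned by predict().
annotations = {
    'image': np.zeros((224, 224, 3), dtype=np.uint8),  # replace with the actual image
    'bboxes': det['bboxes'],
    'category_id': det['category_id'],
}
visualize(annotations, category_id_to_name={0: 'opacity'})  # illustrative class name
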