├── .idea
│   ├── Repulsion_Loss.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── anchors.py
├── csv_eval.py
├── dataloader.py
├── img
│   ├── 1.png
│   └── 2.png
├── lib
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── build.sh
│   └── nms
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   └── pth_nms.cpython-36.pyc
│       ├── _ext
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-36.pyc
│       │   └── nms
│       │       ├── __init__.py
│       │       ├── __pycache__
│       │       │   └── __init__.cpython-36.pyc
│       │       └── _nms.so
│       ├── build.py
│       ├── pth_nms.py
│       └── src
│           ├── cuda
│           │   ├── nms_kernel.cu
│           │   ├── nms_kernel.cu.o
│           │   └── nms_kernel.h
│           ├── nms.c
│           ├── nms.h
│           ├── nms_cuda.c
│           └── nms_cuda.h
├── losses.py
├── model.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Repulsion_Loss
2 |
3 | PyTorch implementation of Repulsion Loss as described in [Repulsion Loss: Detecting Pedestrians in a Crowd](https://arxiv.org/abs/1711.07752). In addition to the usual attraction (regression) term, the loss adds two repulsion terms, RepGT and RepBox, which push each predicted box away from other ground-truth targets and from other predicted boxes. The baseline detector is RetinaNet, following this [repo](https://github.com/yhenon/pytorch-retinanet).
4 |
5 | ## Requirements
6 |
7 | - Python 3
8 | - PyTorch 0.4
9 | - torchvision
10 | - tensorboardX
11 |
12 | ## Installation
13 |
14 | Install packages.
15 |
16 | ```
17 | sudo apt-get install tk-dev python-tk
18 | pip install cffi
19 | pip install cython
20 | pip install pandas
21 | pip install tensorboardX
22 | ```
23 |
24 | Build NMS.
25 |
26 | ```
27 | cd Repulsion_Loss/lib
28 | sh build.sh
29 | ```
30 |
31 | Create folders.
32 |
33 | ```
34 | cd Repulsion_Loss
35 | mkdir ckpt mAP_txt summary weight
36 | ```
37 |
38 | ## Datasets
39 | This repo is built for human detection. Popular human (pedestrian) detection datasets such as [Citypersons](https://arxiv.org/pdf/1702.05693.pdf) and [Crowdhuman](https://arxiv.org/pdf/1805.00123.pdf) annotate both human bounding boxes and ignore regions. You should write the annotations into CSV or TXT files.
40 |
41 | ### Annotations format
42 | Three examples are as follows:
43 |
44 | ```
45 | $image_path/img_1.jpg x1 y1 x2 y2 person
46 | $image_path/img_1.jpg x1 y1 x2 y2 ignore
47 | $image_path/img_2.jpg . . . . .
48 | ```
49 |
50 | Images with more than one bounding box use one row per box. The labels we typically use are 'person' and 'ignore'. When an image contains no bounding box at all, set every field except the image path to '.', as in the third row above.
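
For instance, an image containing two people and one ignore region, followed by an image with no boxes, might be annotated like this (the coordinates are made up for illustration):

```
$image_path/img_1.jpg 120 80 260 480 person
$image_path/img_1.jpg 300 90 430 470 person
$image_path/img_1.jpg 500 60 640 400 ignore
$image_path/img_2.jpg . . . . .
```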
51 |
52 | ### Label encoding file
53 | A TXT file (classes.txt) is needed to map labels to IDs. Each line contains one label name and its ID, separated by a space. The 'ignore' label is handled specially by the dataloader and does not need an entry. One example is as follows:
54 |
55 | ```
56 | person 0
57 | ```
58 |
59 | ## Pretrained Model
60 |
61 | We use resnet18, 34, 50, 101 or 152 as the backbone. Download the corresponding weights and put them in the `weight` folder created above (see the download example after the list):
62 |
63 | - resnet18: [https://download.pytorch.org/models/resnet18-5c106cde.pth](https://download.pytorch.org/models/resnet18-5c106cde.pth)
64 | - resnet34: [https://download.pytorch.org/models/resnet34-333f7ec4.pth](https://download.pytorch.org/models/resnet34-333f7ec4.pth)
65 | - resnet50: [https://download.pytorch.org/models/resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)
66 | - resnet101: [https://download.pytorch.org/models/resnet101-5d3b4d8f.pth](https://download.pytorch.org/models/resnet101-5d3b4d8f.pth)
67 | - resnet152: [https://download.pytorch.org/models/resnet152-b121ed2d.pth](https://download.pytorch.org/models/resnet152-b121ed2d.pth)
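
For example, the ResNet-50 weights can be fetched into the `weight` folder with a command like the following (assuming `wget` is available):

```
wget -P weight https://download.pytorch.org/models/resnet50-19c8e357.pth
```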
68 |
69 | ## Training
70 |
71 | ```
72 | python train.py --csv_train <$path/train.txt> --csv_val <$path/val.txt> --csv_classes <$path/classes.txt> --depth <50> --pretrained resnet50-19c8e357.pth --model_name
73 | ```
74 |
75 | ## Visualization Result
76 | Baseline
77 |
78 | ![baseline](img/1.png)
79 |
80 | With repulsion loss
81 |
82 | ![repulsion loss](img/2.png)
83 |
84 | ## Reference
85 |
86 | - [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)
87 | - [Repulsion Loss: Detecting Pedestrians in a Crowd](https://arxiv.org/abs/1711.07752)
--------------------------------------------------------------------------------
/anchors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Anchors(nn.Module):
7 | def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None):
8 | super(Anchors, self).__init__()
9 |
10 | if pyramid_levels is None:
11 | self.pyramid_levels = [3, 4, 5, 6, 7]
12 | if strides is None:
13 | self.strides = [2 ** x for x in self.pyramid_levels]
14 | if sizes is None:
15 | self.sizes = [2 ** (x + 2) for x in self.pyramid_levels]
16 | if ratios is None:
17 | #self.ratios = np.array([1., 1.5, 2., 2.5, 3.])
18 | self.ratios = np.array([0.5, 1., 2.])
19 | if scales is None:
20 | self.scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
21 |
22 | def forward(self, image):
23 |
24 | image_shape = image.shape[2:]
25 | image_shape = np.array(image_shape)
26 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
27 |
28 | # compute anchors over all pyramid levels
29 | all_anchors = np.zeros((0, 4)).astype(np.float32)
30 |
31 | for idx, p in enumerate(self.pyramid_levels):
32 | anchors = generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales)
33 | shifted_anchors = shift(image_shapes[idx], self.strides[idx], anchors)
34 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
35 |
36 | all_anchors = np.expand_dims(all_anchors, axis=0)
37 |
38 | return torch.from_numpy(all_anchors.astype(np.float32)).cuda()
39 |
40 | def generate_anchors(base_size=16, ratios=None, scales=None):
41 | """
42 | Generate anchor (reference) windows by enumerating aspect ratios X
43 | scales w.r.t. a reference window.
44 | """
45 |
46 | if ratios is None:
47 | ratios = np.array([0.5, 1, 2])
48 |
49 | if scales is None:
50 | scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
51 |
52 | num_anchors = len(ratios) * len(scales)
53 |
54 | # initialize output anchors
55 | anchors = np.zeros((num_anchors, 4))
56 |
57 | # scale base_size
58 | anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T
59 |
60 | # compute areas of anchors
61 | areas = anchors[:, 2] * anchors[:, 3]
62 |
63 | # correct for ratios
64 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales)))
65 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales))
66 |
67 | # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
68 | anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
69 | anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T
70 |
71 | return anchors
72 |
73 | def compute_shape(image_shape, pyramid_levels):
74 | """Compute shapes based on pyramid levels.
75 |
76 | :param image_shape:
77 | :param pyramid_levels:
78 | :return:
79 | """
80 | image_shape = np.array(image_shape[:2])
81 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
82 | return image_shapes
83 |
84 |
85 | def anchors_for_shape(
86 | image_shape,
87 | pyramid_levels=None,
88 | ratios=None,
89 | scales=None,
90 | strides=None,
91 | sizes=None,
92 | shapes_callback=None,
93 | ):
94 |
95 | image_shapes = compute_shape(image_shape, pyramid_levels)
96 |
97 | # compute anchors over all pyramid levels
98 | all_anchors = np.zeros((0, 4))
99 | for idx, p in enumerate(pyramid_levels):
100 | anchors = generate_anchors(base_size=sizes[idx], ratios=ratios, scales=scales)
101 | shifted_anchors = shift(image_shapes[idx], strides[idx], anchors)
102 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
103 |
104 | return all_anchors
105 |
106 |
107 | def shift(shape, stride, anchors):
108 | shift_x = (np.arange(0, shape[1]) + 0.5) * stride
109 | shift_y = (np.arange(0, shape[0]) + 0.5) * stride
110 |
111 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
112 |
113 | shifts = np.vstack((
114 | shift_x.ravel(), shift_y.ravel(),
115 | shift_x.ravel(), shift_y.ravel()
116 | )).transpose()
117 |
118 | # add A anchors (1, A, 4) to
119 | # cell K shifts (K, 1, 4) to get
120 | # shift anchors (K, A, 4)
121 | # reshape to (K*A, 4) shifted anchors
122 | A = anchors.shape[0]
123 | K = shifts.shape[0]
124 | all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
125 | all_anchors = all_anchors.reshape((K * A, 4))
126 |
127 | return all_anchors
128 |
129 |
--------------------------------------------------------------------------------
/csv_eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | import json
5 | import os
6 |
7 | import torch
8 |
9 |
10 |
11 | def compute_overlap(a, b):
12 | """
13 | Parameters
14 | ----------
15 | a: (N, 4) ndarray of float
16 | b: (K, 4) ndarray of float
17 | Returns
18 | -------
19 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes
20 | """
21 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
22 |
23 | iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])
24 | ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])
25 |
26 | iw = np.maximum(iw, 0)
27 | ih = np.maximum(ih, 0)
28 |
29 | ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih
30 |
31 | ua = np.maximum(ua, np.finfo(float).eps)
32 |
33 | intersection = iw * ih
34 |
35 | return intersection / ua
36 |
37 |
38 | def _compute_ap(recall, precision):
39 | """ Compute the average precision, given the recall and precision curves.
40 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
41 | # Arguments
42 | recall: The recall curve (list).
43 | precision: The precision curve (list).
44 | # Returns
45 | The average precision as computed in py-faster-rcnn.
46 | """
47 | # correct AP calculation
48 | # first append sentinel values at the end
49 | mrec = np.concatenate(([0.], recall, [1.]))
50 | mpre = np.concatenate(([0.], precision, [0.]))
51 |
52 | # compute the precision envelope
53 | for i in range(mpre.size - 1, 0, -1):
54 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
55 |
56 | # to calculate area under PR curve, look for points
57 | # where X axis (recall) changes value
58 | i = np.where(mrec[1:] != mrec[:-1])[0]
59 |
60 | # and sum (\Delta recall) * prec
61 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
62 | return ap
63 |
64 |
65 | def _get_detections(dataset, retinanet, score_threshold=0.05, max_detections=100, save_path=None):
66 | """ Get the detections from the retinanet using the generator.
67 | The result is a list of lists such that the size is:
68 | all_detections[num_images][num_classes] = detections[num_detections, 4 + num_classes]
69 | # Arguments
70 | dataset : The generator used to run images through the retinanet.
71 | retinanet : The retinanet to run on the images.
72 | score_threshold : The score confidence threshold to use.
73 | max_detections : The maximum number of detections to use per image.
74 | save_path : The path to save the images with visualized detections to.
75 | # Returns
76 | A list of lists containing the detections for each image in the generator.
77 | """
78 | all_detections = [[None for i in range(dataset.num_classes())] for j in range(len(dataset))]
79 |
80 | retinanet.eval()
81 |
82 | with torch.no_grad():
83 |
84 | for index in range(len(dataset)):
85 | data = dataset[index]
86 | scale = data['scale']
87 |
88 | # run network
89 | scores, labels, boxes = retinanet(data['img'].permute(2, 0, 1).cuda().float().unsqueeze(dim=0))
90 | if isinstance(scores, torch.Tensor):
91 | scores = scores.cpu().numpy()
92 | labels = labels.cpu().numpy()
93 | boxes = boxes.cpu().numpy()
94 |
95 | # correct boxes for image scale
96 | boxes /= scale
97 |
98 | # select indices which have a score above the threshold
99 | indices = np.where(scores > score_threshold)[0]
100 |
101 | # select those scores
102 | scores = scores[indices]
103 |
104 | # find the order with which to sort the scores
105 | scores_sort = np.argsort(-scores)[:max_detections]
106 |
107 | # select detections
108 | image_boxes = boxes[indices[scores_sort], :]
109 | image_scores = scores[scores_sort]
110 | image_labels = labels[indices[scores_sort]]
111 | image_detections = np.concatenate([image_boxes, np.expand_dims(image_scores, axis=1), np.expand_dims(image_labels, axis=1)], axis=1)
112 |
113 | # copy detections to all_detections
114 | for label in range(dataset.num_classes()):
115 | all_detections[index][label] = image_detections[image_detections[:, -1] == label, :-1]
116 | else:
117 | # copy detections to all_detections
118 | for label in range(dataset.num_classes()):
119 | all_detections[index][label] = np.zeros((0, 5))
120 |
121 | print('{}/{}'.format(index + 1, len(dataset)), end='\r')
122 |
123 | return all_detections
124 |
125 |
126 | def _get_annotations(generator):
127 | """ Get the ground truth annotations from the generator.
128 | The result is a list of lists such that the size is:
129 | all_detections[num_images][num_classes] = annotations[num_detections, 5]
130 | # Arguments
131 | generator : The generator used to retrieve ground truth annotations.
132 | # Returns
133 | A list of lists containing the annotations for each image in the generator.
134 | """
135 | all_annotations = [[None for i in range(generator.num_classes())] for j in range(len(generator))]
136 |
137 | for i in range(len(generator)):
138 | # load the annotations
139 | annotations, _ = generator.load_annotations(i)
140 |
141 | # copy detections to all_annotations
142 | for label in range(generator.num_classes()):
143 | all_annotations[i][label] = annotations[annotations[:, 4] == label, :4].copy()
144 |
145 | print('{}/{}'.format(i + 1, len(generator)), end='\r')
146 |
147 | return all_annotations
148 |
149 |
150 | def evaluate(
151 | generator,
152 | retinanet,
153 | iou_threshold=0.5,
154 | score_threshold=0.05,
155 | max_detections=100,
156 | save_path=None
157 | ):
158 | """ Evaluate a given dataset using a given retinanet.
159 | # Arguments
160 | generator : The generator that represents the dataset to evaluate.
161 | retinanet : The retinanet to evaluate.
162 | iou_threshold : The threshold used to consider when a detection is positive or negative.
163 | score_threshold : The score confidence threshold to use for detections.
164 | max_detections : The maximum number of detections to use per image.
165 | save_path : The path to save images with visualized detections to.
166 | # Returns
167 | A dict mapping class names to mAP scores.
168 | """
169 |
170 |
171 |
172 | # gather all detections and annotations
173 |
174 | all_detections = _get_detections(generator, retinanet, score_threshold=score_threshold, max_detections=max_detections, save_path=save_path)
175 | all_annotations = _get_annotations(generator)
176 |
177 | average_precisions = {}
178 |
179 | for label in range(generator.num_classes()):
180 | false_positives = np.zeros((0,))
181 | true_positives = np.zeros((0,))
182 | scores = np.zeros((0,))
183 | num_annotations = 0.0
184 |
185 | for i in range(len(generator)):
186 | detections = all_detections[i][label]
187 | annotations = all_annotations[i][label]
188 | num_annotations += annotations.shape[0]
189 | detected_annotations = []
190 |
191 | for d in detections:
192 | scores = np.append(scores, d[4])
193 |
194 | if annotations.shape[0] == 0:
195 | false_positives = np.append(false_positives, 1)
196 | true_positives = np.append(true_positives, 0)
197 | continue
198 |
199 | overlaps = compute_overlap(np.expand_dims(d, axis=0), annotations)
200 | assigned_annotation = np.argmax(overlaps, axis=1)
201 | max_overlap = overlaps[0, assigned_annotation]
202 |
203 | if max_overlap >= iou_threshold and assigned_annotation not in detected_annotations:
204 | false_positives = np.append(false_positives, 0)
205 | true_positives = np.append(true_positives, 1)
206 | detected_annotations.append(assigned_annotation)
207 | else:
208 | false_positives = np.append(false_positives, 1)
209 | true_positives = np.append(true_positives, 0)
210 |
211 | # no annotations -> AP for this class is 0 (is this correct?)
212 | if num_annotations == 0:
213 | average_precisions[label] = 0, 0
214 | continue
215 |
216 | # sort by score
217 | indices = np.argsort(-scores)
218 | false_positives = false_positives[indices]
219 | true_positives = true_positives[indices]
220 |
221 | # compute false positives and true positives
222 | false_positives = np.cumsum(false_positives)
223 | true_positives = np.cumsum(true_positives)
224 |
225 | # compute recall and precision
226 | recall = true_positives / num_annotations
227 | precision = true_positives / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
228 |
229 | # compute average precision
230 | average_precision = _compute_ap(recall, precision)
231 | average_precisions[label] = average_precision, num_annotations
232 |
233 | print('\nmAP:')
234 | for label in range(generator.num_classes()):
235 | label_name = generator.label_to_name(label)
236 | print('{}: {}'.format(label_name, average_precisions[label][0]))
237 |
238 | return average_precisions
239 |
240 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import sys
3 | import os
4 | import torch
5 | import pandas as pd
6 | import numpy as np
7 | import random
8 | import csv
9 | import time
10 | import cv2
11 |
12 | from torch.utils.data import Dataset, DataLoader
13 | from torchvision import transforms, utils
14 | from torch.utils.data.sampler import Sampler
15 |
16 |
17 | import skimage.io
18 | import skimage.transform
19 | import skimage.color
20 | import skimage
21 |
22 | from PIL import Image, ImageEnhance, ImageFilter
23 |
24 |
25 | class CSVDataset(Dataset):
26 | """CSV dataset."""
27 |
28 | def __init__(self, train_file, class_list, transform=None):
29 | """
30 | Args:
31 | train_file (string): CSV file with training annotations
32 | class_list (string): CSV/TXT file mapping class names to IDs
33 | transform (callable, optional): transform applied to each sample
34 | """
35 | self.train_file = train_file
36 | self.class_list = class_list
37 | self.transform = transform
38 |
39 | # parse the provided class file
40 | try:
41 | with self._open_for_csv(self.class_list) as file:
42 | self.classes = self.load_classes(csv.reader(file, delimiter=' '))
43 | except ValueError as e:
44 | raise ValueError('invalid CSV class file: {}: {}'.format(self.class_list, e))
45 |
46 | self.labels = {}
47 | for key, value in self.classes.items():
48 | self.labels[value] = key
49 |
50 | # csv with img_path, x1, y1, x2, y2, class_name
51 | try:
52 | with self._open_for_csv(self.train_file) as file:
53 | self.image_data = self._read_annotations(csv.reader(file, delimiter=' '), self.classes)
54 | except ValueError as e:
55 | raise ValueError('invalid CSV annotations file: {}: {}'.format(self.train_file, e))
56 | self.image_names = list(self.image_data.keys())
57 |
58 | #import pdb
59 | #pdb.set_trace()
60 |
61 | def _parse(self, value, function, fmt):
62 | """
63 | Parse a string into a value, and format a nice ValueError if it fails.
64 | Returns `function(value)`.
65 | Any `ValueError` raised is caught and a new `ValueError` is raised
66 | with message `fmt.format(e)`, where `e` is the original `ValueError`.
67 | """
68 | try:
69 | return function(value)
70 | except ValueError as e:
71 | raise ValueError(fmt.format(e))
72 |
73 | def _open_for_csv(self, path):
74 | """
75 | Open a file with flags suitable for csv.reader.
76 | For python2 this means opening with mode 'rb';
77 | for python3 this means mode 'r' with "universal newlines".
78 | """
79 | if sys.version_info[0] < 3:
80 | return open(path, 'rb')
81 | else:
82 | return open(path, 'r', newline='')
83 |
84 |
85 | def load_classes(self, csv_reader):
86 | result = {}
87 |
88 | for line, row in enumerate(csv_reader):
89 | line += 1
90 |
91 | try:
92 | class_name, class_id = row
93 | except ValueError:
94 | raise ValueError('line {}: format should be \'class_name class_id\''.format(line))
95 | class_id = self._parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))
96 |
97 | if class_name in result:
98 | raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
99 | result[class_name] = class_id
100 | return result
101 |
102 |
103 | def __len__(self):
104 | return len(self.image_names)
105 |
106 | def __getitem__(self, idx):
107 |
108 | img = self.load_image(idx)
109 | annot, igno = self.load_annotations(idx)
110 | sample = {'img': img, 'annot': annot, 'ignore': igno, 'scale': 1}
111 | if self.transform:
112 | sample = self.transform(sample)
113 |
114 | return sample
115 |
116 | def load_image(self, image_index):
117 |
118 | #img = skimage.io.imread(self.image_names[image_index])
119 | img = cv2.imread(self.image_names[image_index])
120 | b,g,r = cv2.split(img)
121 | img = cv2.merge([r,g,b])
122 |
123 |
124 | if len(img.shape) == 2:
125 | img = skimage.color.gray2rgb(img)
126 |
127 | return img.astype(np.float32)/255.0
128 |
129 | def load_annotations(self, image_index):
130 | # get ground truth annotations
131 | annotation_list = self.image_data[self.image_names[image_index]]
132 | annotations = np.zeros((0, 5))
133 | ignores = np.zeros((0, 5))
134 |
135 | # some images appear to miss annotations (like image with id 257034)
136 | if len(annotation_list) == 0:
137 | return annotations, ignores
138 |
139 | # parse annotations
140 | for idx, a in enumerate(annotation_list):
141 | # some annotations have basically no width / height, skip them
142 | x1 = a['x1']
143 | x2 = a['x2']
144 | y1 = a['y1']
145 | y2 = a['y2']
146 |
147 | if (x2-x1) < 1 or (y2-y1) < 1:
148 | continue
149 |
150 | annotation = np.zeros((1, 5))
151 |
152 | annotation[0, 0] = x1
153 | annotation[0, 1] = y1
154 | annotation[0, 2] = x2
155 | annotation[0, 3] = y2
156 | if a['class'] == 'ignore':
157 | annotation[0, 4] = -2
158 | ignores = np.append(ignores, annotation, axis=0)
159 | else:
160 | annotation[0, 4] = self.name_to_label(a['class'])
161 | annotations = np.append(annotations, annotation, axis=0)
162 |
163 | return annotations, ignores
164 |
165 | def _read_annotations(self, csv_reader, classes):
166 | result = {}
167 | for line, row in enumerate(csv_reader):
168 | line += 1
169 |
170 | try:
171 | img_file, x1, y1, x2, y2, class_name = row[:6]
172 | except ValueError:
173 | raise ValueError('line {}: format should be \'img_file x1 y1 x2 y2 class_name\' or \'img_file . . . . .\''.format(line))
174 |
175 | if img_file not in result:
176 | result[img_file] = []
177 |
178 | # If a row contains only an image path, it's an image without annotations.
179 | if (x1, y1, x2, y2, class_name) == ('.', '.', '.', '.', '.'):
180 | continue
181 |
182 | x1 = self._parse(x1, int, 'line {}: malformed x1: {{}}'.format(line))
183 | y1 = self._parse(y1, int, 'line {}: malformed y1: {{}}'.format(line))
184 | x2 = self._parse(x2, int, 'line {}: malformed x2: {{}}'.format(line))
185 | y2 = self._parse(y2, int, 'line {}: malformed y2: {{}}'.format(line))
186 |
187 | if class_name !='ignore':
188 | # Check that the bounding box is valid.
189 | if x2 <= x1:
190 | raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
191 | if y2 <= y1:
192 | raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))
193 |
194 | # check if the current class name is correctly present
195 | if class_name not in classes:
196 | raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))
197 |
198 | result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': class_name})
199 | return result
200 |
201 | def name_to_label(self, name):
202 | return self.classes[name]
203 |
204 | def label_to_name(self, label):
205 | return self.labels[label]
206 |
207 | def num_classes(self):
208 | return max(self.classes.values()) + 1
209 |
210 | def image_aspect_ratio(self, image_index):
211 | image = Image.open(self.image_names[image_index])
212 | return float(image.width) / float(image.height)
213 |
214 |
215 | def collater(data):
216 |
217 | imgs = [s['img'] for s in data]
218 | annots = [s['annot'] for s in data]
219 | igno = [s['ignore'] for s in data]
220 | scales = [s['scale'] for s in data]
221 |
222 | widths = [int(s.shape[0]) for s in imgs]
223 | heights = [int(s.shape[1]) for s in imgs]
224 | batch_size = len(imgs)
225 |
226 | max_width = np.array(widths).max()
227 | max_height = np.array(heights).max()
228 |
229 | padded_imgs = torch.zeros(batch_size, max_width, max_height, 3)
230 |
231 | for i in range(batch_size):
232 | img = imgs[i]
233 | padded_imgs[i, :int(img.shape[0]), :int(img.shape[1]), :] = img
234 |
235 | max_num_annots = max(annot.shape[0] for annot in annots)
236 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
237 | #print(annot_padded.shape)
238 | if max_num_annots > 0:
239 | for idx, annot in enumerate(annots):
240 | #print(annot.shape)
241 | if annot.shape[0] > 0:
242 | annot_padded[idx, :annot.shape[0], :] = annot
243 | else:
244 | annot_padded = torch.ones((len(annots), 1, 5)) * -1
245 |
246 | max_num_igno = max(ig.shape[0] for ig in igno)
247 | ig_padded = torch.ones((len(igno), max_num_igno, 5)) * -1
248 | if max_num_igno > 0:
249 | for idx, ig in enumerate(igno):
250 | if ig.shape[0] > 0:
251 | ig_padded[idx, :ig.shape[0], :] = ig
252 | else:
253 | ig_padded = torch.ones((len(igno), 1, 5)) * -1
254 |
255 | padded_imgs = padded_imgs.permute(0, 3, 1, 2)
256 |
257 | return {'img': padded_imgs, 'annot': annot_padded, 'ignore': ig_padded, 'scale': scales}
258 |
259 | class Resizer(object):
260 | """Convert ndarrays in sample to Tensors."""
261 |
262 | def __call__(self, sample, min_side=800, max_side=1400):
263 |
264 | image, annots, igno, scale = sample['img'], sample['annot'], sample['ignore'], sample['scale']
265 |
266 | rows, cols, cns = image.shape
267 |
268 | #scale = min_side / rows
269 |
270 |
271 | smallest_side = min(rows, cols)
272 |
273 | # rescale the image so the smallest side is min_side
274 | scale = min_side / smallest_side
275 |
276 | # check if the largest side is now greater than max_side, which can happen
277 | # when images have a large aspect ratio
278 | largest_side = max(rows, cols)
279 |
280 | if largest_side * scale > max_side:
281 | scale = max_side / largest_side
282 |
283 |
284 | # resize the image with the computed scale
285 |
286 | image = cv2.resize(image, (int(round((cols*scale))), int(round((rows*scale)))))
287 | #image = skimage.transform.resize(image, (int(round(rows*scale)), int(round((cols*scale)))))
288 |
289 | rows, cols, cns = image.shape
290 |
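# Zero-pad the bottom and right so the padded height and width are multiples of 32
# (presumably so the stride-32 backbone feature maps divide evenly).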
291 | pad_w = 32 - rows%32
292 | pad_h = 32 - cols%32
293 |
294 | new_image = np.zeros((rows + pad_w, cols + pad_h, cns)).astype(np.float32)
295 | new_image[:rows, :cols, :] = image.astype(np.float32)
296 |
297 | annots[:, :4] *= scale
298 | igno[:, :4] *= scale
299 |
300 | return {'img': new_image, 'annot': annots, 'ignore': igno, 'scale': scale}
301 |
302 |
303 | class Augmenter(object):
304 | """Convert ndarrays in sample to Tensors."""
305 |
306 | def __call__(self, sample, flip_x=0.5):
307 |
308 | if np.random.rand() < flip_x:
309 | image, annots, igno, scales = sample['img'], sample['annot'], sample['ignore'], sample['scale']
310 | image = image[:, ::-1, :]
311 |
312 | rows, cols, channels = image.shape
313 |
314 | x1 = annots[:, 0].copy()
315 | x2 = annots[:, 2].copy()
316 |
317 | x_tmp = x1.copy()
318 |
319 | annots[:, 0] = cols - x2
320 | annots[:, 2] = cols - x_tmp
321 |
322 | y1 = igno[:, 0].copy()
323 | y2 = igno[:, 2].copy()
324 | y_tmp = y1.copy()
325 | igno[:, 0] = cols - y2
326 | igno[:, 2] = cols - y_tmp
327 |
328 | sample = {'img': image, 'annot': annots, 'ignore': igno, 'scale': scales}
329 |
330 | return sample
331 |
332 |
333 | class Random_crop(object):
334 |
335 | def __call__(self, sample):
336 |
337 | image, annots, igno, scales = sample['img'], sample['annot'], sample['ignore'], sample['scale']
338 |
339 | if not annots.shape[0]:
340 | return {'img': image, 'annot': annots, 'ignore': igno, 'scale': scales}
341 | if random.choice([0, 1]):
342 | return {'img': image, 'annot': annots, 'ignore': igno, 'scale': scales}
343 | else:
344 | rows, cols, cns = image.shape
345 | flag = 0
346 | while True:
347 | flag += 1
348 | if flag > 10:
349 | return {'img': image, 'annot': annots, 'ignore': igno, 'scale': scales}
350 |
351 | crop_ratio = random.uniform(0.5, 1)
352 | rows_zero = int(rows * random.uniform(0, 1 - crop_ratio))
353 | cols_zero = int(cols * random.uniform(0, 1 - crop_ratio))
354 | crop_rows = int(rows * crop_ratio)
355 | crop_cols = int(cols * crop_ratio)
356 | '''
357 | new_image = image[rows_zero:rows_zero+crop_rows, cols_zero:cols_zero+crop_cols, :]
358 | new_image = cv2.resize(new_image, (cols, rows))
359 | #new_image = skimage.transform.resize(new_image, (rows, cols))
360 |
361 | new_annots = np.zeros((0, 5))
362 | for i in range(annots.shape[0]):
363 | x1 = max(annots[i, 0] - cols_zero, 0)
364 | y1 = max(annots[i, 1] - rows_zero, 0)
365 | x2 = min(annots[i, 2] - cols_zero, crop_cols)
366 | y2 = min(annots[i, 3] - rows_zero, crop_rows)
367 | label = annots[i, 4]
368 | if x1 + 10 < x2 and y1 + 10 < y2:
369 | x1 /= crop_ratio
370 | y1 /= crop_ratio
371 | x2 /= crop_ratio
372 | y2 /= crop_ratio
373 | new_annots = np.append(new_annots, np.array([[x1, y1, x2, y2, label]]), axis=0)
374 |
375 | if not new_annots.shape[0]:
376 | continue
377 |
378 | new_igno = np.zeros((0, 5))
379 | for i in range(igno.shape[0]):
380 | x1 = max(igno[i, 0] - cols_zero, 0)
381 | y1 = max(igno[i, 1] - rows_zero, 0)
382 | x2 = min(igno[i, 2] - cols_zero, crop_cols)
383 | y2 = min(igno[i, 3] - rows_zero, crop_rows)
384 | label = igno[i, 4]
385 | if x1 + 10 < x2 and y1 + 5 < y2:
386 | x1 /= crop_ratio
387 | y1 /= crop_ratio
388 | x2 /= crop_ratio
389 | y2 /= crop_ratio
390 | new_igno = np.append(new_igno, np.array([[x1, y1, x2, y2, label]]), axis=0)
391 | '''
392 | new_image = np.zeros((rows , cols , cns))
393 | new_image[rows_zero:rows_zero+crop_rows, cols_zero:cols_zero+crop_cols, :] = image[rows_zero:rows_zero+crop_rows, cols_zero:cols_zero+crop_cols, :]
394 |
395 | new_annots = np.zeros((0, 5))
396 | for i in range(annots.shape[0]):
397 | x1 = max(cols_zero, annots[i, 0])
398 | y1 = max(rows_zero, annots[i, 1])
399 | x2 = min(cols_zero+crop_cols, annots[i, 2])
400 | y2 = min(rows_zero+crop_rows, annots[i, 3])
401 | label = annots[i, 4]
402 | if x1+10 < x2 and y1+10 < y2:
403 | new_annots = np.append(new_annots, np.array([[x1,y1,x2,y2,label]]), axis=0)
404 |
405 | if not new_annots.shape[0]:
406 | continue
407 |
408 | new_igno = np.zeros((0, 5))
409 | for i in range(igno.shape[0]):
410 | x1 = max(cols_zero, igno[i, 0])
411 | y1 = max(rows_zero, igno[i, 1])
412 | x2 = min(cols_zero + crop_cols, igno[i, 2])
413 | y2 = min(rows_zero + crop_rows, igno[i, 3])
414 | label = igno[i, 4]
415 | if x1+10 < x2 and y1+5 < y2:
416 | new_igno = np.append(new_igno, np.array([[x1, y1, x2, y2,label]]), axis=0)
417 |
418 | return {'img': new_image, 'annot': new_annots, 'ignore': new_igno, 'scale': scales}
419 |
420 |
421 | class Color(object):
422 |
423 | def __call__(self, sample):
424 | image, annots, igno, scales = sample['img'], sample['annot'], sample['ignore'], sample['scale']
425 | image = Image.fromarray(image)
426 |
427 | ratio = [0.5, 0.8, 1.2, 1.5]
428 |
429 | if random.choice([0, 1]):
430 | enh_bri = ImageEnhance.Brightness(image)
431 | brightness = random.choice(ratio)
432 | image = enh_bri.enhance(brightness)
433 | if random.choice([0, 1]):
434 | enh_col = ImageEnhance.Color(image)
435 | color = random.choice(ratio)
436 | image = enh_col.enhance(color)
437 | if random.choice([0, 1]):
438 | enh_con = ImageEnhance.Contrast(image)
439 | contrast = random.choice(ratio)
440 | image = enh_con.enhance(contrast)
441 | if random.choice([0, 1]):
442 | enh_sha = ImageEnhance.Sharpness(image)
443 | sharpness = random.choice(ratio)
444 | image = enh_sha.enhance(sharpness)
445 | if random.choice([0, 1]):
446 | image = image.filter(ImageFilter.BLUR)
447 |
448 | image = np.asarray(image)
449 | return {'img': image, 'annot': annots, 'ignore': igno, 'scale': scales}
450 |
451 |
452 | class Normalizer(object):
453 |
454 | def __init__(self):
455 | self.mean = np.array([[[0.485, 0.456, 0.406]]])
456 | self.std = np.array([[[0.229, 0.224, 0.225]]])
457 |
458 | def __call__(self, sample):
459 |
460 | image, annots, igno, scales = sample['img'], sample['annot'], sample['ignore'], sample['scale']
461 |
462 | image = (image.astype(np.float32)-self.mean)/self.std
463 |
464 | sample = {'img': torch.from_numpy(image), 'annot': torch.from_numpy(annots), 'ignore':torch.from_numpy(igno), 'scale': scales}
465 | return sample
466 |
467 | class UnNormalizer(object):
468 | def __init__(self, mean=None, std=None):
469 | if mean == None:
470 | self.mean = [0.485, 0.456, 0.406]
471 | else:
472 | self.mean = mean
473 | if std == None:
474 | self.std = [0.229, 0.224, 0.225]
475 | else:
476 | self.std = std
477 |
478 | def __call__(self, tensor):
479 | """
480 | Args:
481 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
482 | Returns:
483 | Tensor: Normalized image.
484 | """
485 | for t, m, s in zip(tensor, self.mean, self.std):
486 | t.mul_(s).add_(m)
487 | return tensor
488 |
489 |
490 | class AspectRatioBasedSampler(Sampler):
491 |
492 | def __init__(self, data_source, batch_size, drop_last):
493 | self.data_source = data_source
494 | self.batch_size = batch_size
495 | self.drop_last = drop_last
496 | self.groups = self.group_images()
497 |
498 | def __iter__(self):
499 | random.shuffle(self.groups)
500 | for group in self.groups:
501 | yield group
502 |
503 | def __len__(self):
504 | if self.drop_last:
505 | return len(self.data_source) // self.batch_size
506 | else:
507 | return (len(self.data_source) + self.batch_size - 1) // self.batch_size
508 |
509 | def group_images(self):
510 | # determine the order of the images
511 | order = list(range(len(self.data_source)))
512 | order.sort(key=lambda x: self.data_source.image_aspect_ratio(x))
513 |
514 | # divide into groups, one group = one batch
515 | return [[order[x % len(order)] for x in range(i, i + self.batch_size)] for i in range(0, len(order), self.batch_size)]
516 |
517 |
518 |
--------------------------------------------------------------------------------
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/img/1.png
--------------------------------------------------------------------------------
/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/img/2.png
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/__init__.py
--------------------------------------------------------------------------------
/lib/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/build.sh:
--------------------------------------------------------------------------------
1 | CUDA_ARCH="-gencode arch=compute_30,code=sm_30 \
2 | -gencode arch=compute_35,code=sm_35 \
3 | -gencode arch=compute_50,code=sm_50 \
4 | -gencode arch=compute_52,code=sm_52 \
5 | -gencode arch=compute_60,code=sm_60 \
6 | -gencode arch=compute_61,code=sm_61"
7 |
8 |
9 | # Build NMS
10 | cd nms/src/cuda
11 | echo "Compiling nms kernels by nvcc..."
12 | /usr/local/cuda/bin/nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC $CUDA_ARCH
13 | cd ../../
14 | python build.py
15 | cd ../
16 |
--------------------------------------------------------------------------------
/lib/nms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/__init__.py
--------------------------------------------------------------------------------
/lib/nms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/nms/__pycache__/pth_nms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/__pycache__/pth_nms.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/nms/_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/_ext/__init__.py
--------------------------------------------------------------------------------
/lib/nms/_ext/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/_ext/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/nms/_ext/nms/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._nms import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/lib/nms/_ext/nms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/_ext/nms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/lib/nms/_ext/nms/_nms.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/_ext/nms/_nms.so
--------------------------------------------------------------------------------
/lib/nms/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 |
5 |
6 | sources = ['src/nms.c']
7 | headers = ['src/nms.h']
8 | defines = []
9 | with_cuda = False
10 |
11 | if torch.cuda.is_available():
12 | print('Including CUDA code.')
13 | sources += ['src/nms_cuda.c']
14 | headers += ['src/nms_cuda.h']
15 | defines += [('WITH_CUDA', None)]
16 | with_cuda = True
17 |
18 | this_file = os.path.dirname(os.path.realpath(__file__))
19 | print(this_file)
20 | extra_objects = ['src/cuda/nms_kernel.cu.o']
21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
22 |
23 | ffi = create_extension(
24 | '_ext.nms',
25 | headers=headers,
26 | sources=sources,
27 | define_macros=defines,
28 | relative_to=__file__,
29 | with_cuda=with_cuda,
30 | extra_objects=extra_objects,
31 | extra_compile_args=['-std=c99']
32 | )
33 |
34 | if __name__ == '__main__':
35 | ffi.build()
36 |
--------------------------------------------------------------------------------
/lib/nms/pth_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ._ext import nms
3 | import numpy as np
4 |
5 | def pth_nms(dets, thresh):
6 | """
7 | dets has to be an (N, 5) tensor of [x1, y1, x2, y2, score]; thresh is the IoU threshold
8 | """
9 | if not dets.is_cuda:
10 | x1 = dets[:, 0]
11 | y1 = dets[:, 1]
12 | x2 = dets[:, 2]
13 | y2 = dets[:, 3]
14 | scores = dets[:, 4]
15 |
16 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
17 | order = scores.sort(0, descending=True)[1]
18 | # order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long()
19 |
20 | keep = torch.LongTensor(dets.size(0))
21 | num_out = torch.LongTensor(1)
22 | nms.cpu_nms(keep, num_out, dets, order, areas, thresh)
23 |
24 | return keep[:num_out[0]]
25 | else:
26 | x1 = dets[:, 0]
27 | y1 = dets[:, 1]
28 | x2 = dets[:, 2]
29 | y2 = dets[:, 3]
30 | scores = dets[:, 4]
31 |
32 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
33 | order = scores.sort(0, descending=True)[1]
34 | # order = torch.from_numpy(np.ascontiguousarray(scores.cpu().numpy().argsort()[::-1])).long().cuda()
35 |
36 | dets = dets[order].contiguous()
37 |
38 | keep = torch.LongTensor(dets.size(0))
39 | num_out = torch.LongTensor(1)
40 | # keep = torch.cuda.LongTensor(dets.size(0))
41 | # num_out = torch.cuda.LongTensor(1)
42 | nms.gpu_nms(keep, num_out, dets, thresh)
43 |
44 | return order[keep[:num_out[0]].cuda()].contiguous()
45 | # return order[keep[:num_out[0]]].contiguous()
46 |
47 |
--------------------------------------------------------------------------------
/lib/nms/src/cuda/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | // ------------------------------------------------------------------
2 | // Faster R-CNN
3 | // Copyright (c) 2015 Microsoft
4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5 | // Written by Shaoqing Ren
6 | // ------------------------------------------------------------------
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 |
11 | #include
12 | #include
13 | #include
14 | #include "nms_kernel.h"
15 |
16 | __device__ inline float devIoU(float const * const a, float const * const b) {
17 | float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
18 | float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
19 | float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f);
20 | float interS = width * height;
21 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
22 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
23 | return interS / (Sa + Sb - interS);
24 | }
25 |
26 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
27 | const float *dev_boxes, unsigned long long *dev_mask) {
28 | const int row_start = blockIdx.y;
29 | const int col_start = blockIdx.x;
30 |
31 | // if (row_start > col_start) return;
32 |
33 | const int row_size =
34 | fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
35 | const int col_size =
36 | fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
37 |
38 | __shared__ float block_boxes[threadsPerBlock * 5];
39 | if (threadIdx.x < col_size) {
40 | block_boxes[threadIdx.x * 5 + 0] =
41 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
42 | block_boxes[threadIdx.x * 5 + 1] =
43 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
44 | block_boxes[threadIdx.x * 5 + 2] =
45 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
46 | block_boxes[threadIdx.x * 5 + 3] =
47 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
48 | block_boxes[threadIdx.x * 5 + 4] =
49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
50 | }
51 | __syncthreads();
52 |
53 | if (threadIdx.x < row_size) {
54 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
55 | const float *cur_box = dev_boxes + cur_box_idx * 5;
56 | int i = 0;
57 | unsigned long long t = 0;
58 | int start = 0;
59 | if (row_start == col_start) {
60 | start = threadIdx.x + 1;
61 | }
62 | for (i = start; i < col_size; i++) {
63 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
64 | t |= 1ULL << i;
65 | }
66 | }
67 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
68 | dev_mask[cur_box_idx * col_blocks + col_start] = t;
69 | }
70 | }
71 |
72 |
73 | void _nms(int boxes_num, float * boxes_dev,
74 | unsigned long long * mask_dev, float nms_overlap_thresh) {
75 |
76 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
77 | DIVUP(boxes_num, threadsPerBlock));
78 | dim3 threads(threadsPerBlock);
79 | nms_kernel<<<blocks, threads>>>(boxes_num,
80 | nms_overlap_thresh,
81 | boxes_dev,
82 | mask_dev);
83 | }
84 |
85 | #ifdef __cplusplus
86 | }
87 | #endif
88 |
--------------------------------------------------------------------------------
/lib/nms/src/cuda/nms_kernel.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainofmine/Repulsion_Loss/eb55686b175f2e69c2647472c325c172b2df6fcd/lib/nms/src/cuda/nms_kernel.cu.o
--------------------------------------------------------------------------------
/lib/nms/src/cuda/nms_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _NMS_KERNEL
2 | #define _NMS_KERNEL
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
9 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
10 |
11 | void _nms(int boxes_num, float * boxes_dev,
12 | unsigned long long * mask_dev, float nms_overlap_thresh);
13 |
14 | #ifdef __cplusplus
15 | }
16 | #endif
17 |
18 | #endif
19 |
20 |
--------------------------------------------------------------------------------
/lib/nms/src/nms.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <math.h>
3 |
4 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) {
5 | // boxes has to be sorted
6 | THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous");
7 | THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous");
8 | THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous");
9 | THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous");
10 | // Number of ROIs
11 | long boxes_num = THFloatTensor_size(boxes, 0);
12 | long boxes_dim = THFloatTensor_size(boxes, 1);
13 |
14 | long * keep_out_flat = THLongTensor_data(keep_out);
15 | float * boxes_flat = THFloatTensor_data(boxes);
16 | long * order_flat = THLongTensor_data(order);
17 | float * areas_flat = THFloatTensor_data(areas);
18 |
19 | THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num);
20 | THByteTensor_fill(suppressed, 0);
21 | unsigned char * suppressed_flat = THByteTensor_data(suppressed);
22 |
23 | // nominal indices
24 | int i, j;
25 | // sorted indices
26 | int _i, _j;
27 | // temp variables for box i's (the box currently under consideration)
28 | float ix1, iy1, ix2, iy2, iarea;
29 | // variables for computing overlap with box j (lower scoring box)
30 | float xx1, yy1, xx2, yy2;
31 | float w, h;
32 | float inter, ovr;
33 |
34 | long num_to_keep = 0;
35 | for (_i=0; _i < boxes_num; ++_i) {
36 | i = order_flat[_i];
37 | if (suppressed_flat[i] == 1) {
38 | continue;
39 | }
40 | keep_out_flat[num_to_keep++] = i;
41 | ix1 = boxes_flat[i * boxes_dim];
42 | iy1 = boxes_flat[i * boxes_dim + 1];
43 | ix2 = boxes_flat[i * boxes_dim + 2];
44 | iy2 = boxes_flat[i * boxes_dim + 3];
45 | iarea = areas_flat[i];
46 | for (_j = _i + 1; _j < boxes_num; ++_j) {
47 | j = order_flat[_j];
48 | if (suppressed_flat[j] == 1) {
49 | continue;
50 | }
51 | xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]);
52 | yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]);
53 | xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]);
54 | yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]);
55 | w = fmaxf(0.0, xx2 - xx1 + 1);
56 | h = fmaxf(0.0, yy2 - yy1 + 1);
57 | inter = w * h;
58 | ovr = inter / (iarea + areas_flat[j] - inter);
59 | if (ovr >= nms_overlap_thresh) {
60 | suppressed_flat[j] = 1;
61 | }
62 | }
63 | }
64 |
65 | long *num_out_flat = THLongTensor_data(num_out);
66 | *num_out_flat = num_to_keep;
67 | THByteTensor_free(suppressed);
68 | return 1;
69 | }
--------------------------------------------------------------------------------
/lib/nms/src/nms.h:
--------------------------------------------------------------------------------
1 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh);
--------------------------------------------------------------------------------
/lib/nms/src/nms_cuda.c:
--------------------------------------------------------------------------------
1 | // ------------------------------------------------------------------
2 | // Faster R-CNN
3 | // Copyright (c) 2015 Microsoft
4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5 | // Written by Shaoqing Ren
6 | // ------------------------------------------------------------------
7 | #include <THC/THC.h>
8 | #include <TH/TH.h>
9 | #include <math.h>
10 | #include <stdio.h>
11 |
12 | #include "cuda/nms_kernel.h"
13 |
14 |
15 | extern THCState *state;
16 |
17 | int gpu_nms(THLongTensor * keep, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh) {
18 | // boxes has to be sorted
19 | THArgCheck(THLongTensor_isContiguous(keep), 0, "boxes must be contiguous");
20 | THArgCheck(THCudaTensor_isContiguous(state, boxes), 2, "boxes must be contiguous");
21 | // Number of ROIs
22 | int boxes_num = THCudaTensor_size(state, boxes, 0);
23 | int boxes_dim = THCudaTensor_size(state, boxes, 1);
24 |
25 | float* boxes_flat = THCudaTensor_data(state, boxes);
26 |
27 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
28 | THCudaLongTensor * mask = THCudaLongTensor_newWithSize2d(state, boxes_num, col_blocks);
29 | unsigned long long* mask_flat = THCudaLongTensor_data(state, mask);
30 |
31 | _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh);
32 |
33 | THLongTensor * mask_cpu = THLongTensor_newWithSize2d(boxes_num, col_blocks);
34 | THLongTensor_copyCuda(state, mask_cpu, mask);
35 | THCudaLongTensor_free(state, mask);
36 |
37 | unsigned long long * mask_cpu_flat = THLongTensor_data(mask_cpu);
38 |
39 | THLongTensor * remv_cpu = THLongTensor_newWithSize1d(col_blocks);
40 | unsigned long long* remv_cpu_flat = THLongTensor_data(remv_cpu);
41 | THLongTensor_fill(remv_cpu, 0);
42 |
43 | long * keep_flat = THLongTensor_data(keep);
44 | long num_to_keep = 0;
45 |
46 | int i, j;
47 | for (i = 0; i < boxes_num; i++) {
48 | int nblock = i / threadsPerBlock;
49 | int inblock = i % threadsPerBlock;
50 |
51 | if (!(remv_cpu_flat[nblock] & (1ULL << inblock))) {
52 | keep_flat[num_to_keep++] = i;
53 | unsigned long long *p = &mask_cpu_flat[0] + i * col_blocks;
54 | for (j = nblock; j < col_blocks; j++) {
55 | remv_cpu_flat[j] |= p[j];
56 | }
57 | }
58 | }
59 |
60 | long * num_out_flat = THLongTensor_data(num_out);
61 | * num_out_flat = num_to_keep;
62 |
63 | THLongTensor_free(mask_cpu);
64 | THLongTensor_free(remv_cpu);
65 |
66 | return 1;
67 | }
68 |
--------------------------------------------------------------------------------
/lib/nms/src/nms_cuda.h:
--------------------------------------------------------------------------------
1 | int gpu_nms(THLongTensor * keep_out, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh);
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import random
5 |
6 |
7 | def calc_iou(a, b):
8 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
9 |
10 | iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
11 | ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])
12 |
13 | iw = torch.clamp(iw, min=0)
14 | ih = torch.clamp(ih, min=0)
15 |
16 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
17 |
18 | ua = torch.clamp(ua, min=1e-8)
19 |
20 | intersection = iw * ih
21 |
22 | IoU = intersection / ua
23 |
24 | return IoU
25 |
26 |
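# IoG (Intersection over Ground-truth) as used in the Repulsion Loss paper:
# the intersection of box_a and box_b divided by the area of box_a alone, rather than by the union.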
27 | def IoG(box_a, box_b):
28 |
29 | inter_xmin = torch.max(box_a[:, 0], box_b[:, 0])
30 | inter_ymin = torch.max(box_a[:, 1], box_b[:, 1])
31 | inter_xmax = torch.min(box_a[:, 2], box_b[:, 2])
32 | inter_ymax = torch.min(box_a[:, 3], box_b[:, 3])
33 | Iw = torch.clamp(inter_xmax - inter_xmin, min=0)
34 | Ih = torch.clamp(inter_ymax - inter_ymin, min=0)
35 | I = Iw * Ih
36 | G = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
37 | return I / G
38 |
39 |
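# Smoothed ln penalty from the Repulsion Loss paper:
#   smooth_ln(x) = -ln(1 - x)                                      if x <= smooth
#                = (x - smooth) / (1 - smooth) - ln(1 - smooth)    otherwise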
40 | def smooth_ln(x, smooth):
41 | return torch.where(
42 | torch.le(x, smooth),
43 | -torch.log(1 - x),
44 | ((x - smooth) / (1 - smooth)) - np.log(1 - smooth)
45 | )
46 |
47 |
48 | class FocalLoss(nn.Module):
49 | # def __init__(self):
50 |
51 | def forward(self, classifications, regressions, anchors, annotations, ignores):
52 | alpha = 0.25
53 | gamma = 2.0
54 | batch_size = classifications.shape[0]
55 | classification_losses = []
56 | regression_losses = []
57 | RepGT_losses = []
58 | RepBox_losses = []
59 |
60 | anchor = anchors[0, :, :]
61 |
62 | anchor_widths = anchor[:, 2] - anchor[:, 0]
63 | anchor_heights = anchor[:, 3] - anchor[:, 1]
64 | anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
65 | anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights
66 |
67 | for j in range(batch_size):
68 |
69 | classification = classifications[j, :, :]
70 | regression = regressions[j, :, :]
71 |
72 | bbox_annotation = annotations[j, :, :]
73 | bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
74 |
75 | if bbox_annotation.shape[0] == 0:
76 | regression_losses.append(torch.tensor(0).float().cuda())
77 | classification_losses.append(torch.tensor(0).float().cuda())
78 | RepGT_losses.append(torch.tensor(0).float().cuda())
79 |
80 | continue
81 |
82 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
83 |
84 | ignore = ignores[j, :, :]
85 | ignore = ignore[ignore[:, 4] != -1]
86 | if ignore.shape[0] > 0:
87 | iou_igno = calc_iou(anchor, ignore)
88 | iou_igno_max, iou_igno_argmax = torch.max(iou_igno, dim=1)
89 | index_igno = torch.lt(iou_igno_max, 0.5)
90 | anchor_keep = anchor[index_igno, :]
91 | classification = classification[index_igno, :]
92 | regression = regression[index_igno, :]
93 | anchor_widths_keep = anchor_widths[index_igno]
94 | anchor_heights_keep = anchor_heights[index_igno]
95 | anchor_ctr_x_keep = anchor_ctr_x[index_igno]
96 | anchor_ctr_y_keep = anchor_ctr_y[index_igno]
97 | else:
98 | anchor_keep = anchor
99 | anchor_widths_keep = anchor_widths
100 | anchor_heights_keep = anchor_heights
101 | anchor_ctr_x_keep = anchor_ctr_x
102 | anchor_ctr_y_keep = anchor_ctr_y
103 |
104 | IoU = calc_iou(anchor_keep, bbox_annotation[:, :4]) # num_anchors x num_annotations
105 |
106 | IoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1
107 |
108 | # compute the loss for classification
109 | targets = torch.ones(classification.shape) * -1
110 | targets = targets.cuda()
111 |
112 | targets[torch.lt(IoU_max, 0.4), :] = 0
113 |
114 | positive_indices = torch.ge(IoU_max, 0.5)
115 |
116 | num_positive_anchors = positive_indices.sum()
117 |
118 | assigned_annotations = bbox_annotation[IoU_argmax, :]  # which gt each anchor matches best
119 |
120 | targets[positive_indices, :] = 0
121 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
122 |
123 | alpha_factor = torch.ones(targets.shape).cuda() * alpha
124 |
125 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
126 | focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
127 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
128 |
129 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
130 |
131 | # cls_loss = focal_weight * torch.pow(bce, gamma)
132 | cls_loss = focal_weight * bce
133 |
134 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())
135 |
136 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))
137 |
138 | # compute the loss for regression
139 |
140 | if positive_indices.sum() > 0:
141 | assigned_annotations = assigned_annotations[positive_indices, :] #num_pos_anchors * 5
142 |
143 | anchor_widths_pi = anchor_widths_keep[positive_indices]
144 | anchor_heights_pi = anchor_heights_keep[positive_indices]
145 | anchor_ctr_x_pi = anchor_ctr_x_keep[positive_indices]
146 | anchor_ctr_y_pi = anchor_ctr_y_keep[positive_indices]
147 |
148 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
149 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
150 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
151 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
152 |
153 | # clamp gt widths and heights to a minimum of 1
154 | gt_widths = torch.clamp(gt_widths, min=1)
155 | gt_heights = torch.clamp(gt_heights, min=1)
156 |
157 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
158 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
159 | targets_dw = torch.log(gt_widths / anchor_widths_pi)
160 | targets_dh = torch.log(gt_heights / anchor_heights_pi)
161 |
162 | targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
163 | targets = targets.t()
164 |
165 | targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
166 |
167 | negative_indices = 1 - positive_indices
168 |
169 | regression_diff = torch.abs(targets - regression[positive_indices, :])
170 |
171 | regression_loss = torch.where(
172 | torch.le(regression_diff, 1.0 / 9.0),
173 | 0.5 * 9.0 * torch.pow(regression_diff, 2),
174 | regression_diff - 0.5 / 9.0
175 | )
176 | regression_loss = regression_loss.mean()
177 | regression_losses.append(regression_loss)
178 |
179 |
180 | # decode the regression outputs of the positive anchors into predicted boxes
181 | if bbox_annotation.shape[0] == 1:
182 | RepGT_losses.append(torch.tensor(0).float().cuda())
183 | RepBox_losses.append(torch.tensor(0).float().cuda())  # keep RepBox_losses non-empty when an image has a single gt box
184 | else:
185 | regression_pos = regression[positive_indices, :]
186 | regression_pos_dx = regression_pos[:, 0]
187 | regression_pos_dy = regression_pos[:, 1]
188 | regression_pos_dw = regression_pos[:, 2]
189 | regression_pos_dh = regression_pos[:, 3]
190 | predict_w = torch.exp(regression_pos_dw) * anchor_widths_pi
191 | predict_h = torch.exp(regression_pos_dh) * anchor_heights_pi
192 | predict_x = regression_pos_dx * anchor_widths_pi + anchor_ctr_x_pi
193 | predict_y = regression_pos_dy * anchor_heights_pi + anchor_ctr_y_pi
194 | predict_xmin = predict_x - 0.5 * predict_w
195 | predict_ymin = predict_y - 0.5 * predict_h
196 | predict_xmax = predict_x + 0.5 * predict_w
197 | predict_ymax = predict_y + 0.5 * predict_h
198 | predict_boxes = torch.stack((predict_xmin, predict_ymin, predict_xmax, predict_ymax)).t()
199 |
200 | # add RepGT_losses
201 | IoU_pos = IoU[positive_indices, :]
202 | IoU_max_keep, IoU_argmax_keep = torch.max(IoU_pos, dim=1, keepdim=True) # num_anchors x 1
203 | for idx in range(IoU_argmax_keep.shape[0]):
204 | IoU_pos[idx, IoU_argmax_keep[idx]] = -1
205 | IoU_sec, IoU_argsec = torch.max(IoU_pos, dim=1)
206 |
207 | assigned_annotations_sec = bbox_annotation[IoU_argsec, :]  # the gt with the second-highest IoU for each positive anchor, num_pos_anchors x 5
208 |
209 | IoG_to_minimize = IoG(assigned_annotations_sec, predict_boxes)
210 | RepGT_loss = smooth_ln(IoG_to_minimize, 0.5)
211 | RepGT_loss = RepGT_loss.mean()
212 | RepGT_losses.append(RepGT_loss)
213 |
214 | # add RepBox losses
215 | IoU_argmax_pos = IoU_argmax[positive_indices].float()
216 | IoU_argmax_pos = IoU_argmax_pos.unsqueeze(0).t()
217 | predict_boxes = torch.cat([predict_boxes, IoU_argmax_pos], dim=1)
218 | predict_boxes_np = predict_boxes.detach().cpu().numpy()
219 | num_gt = bbox_annotation.shape[0]
220 | predict_boxes_sampled = []
221 | for id in range(num_gt):
222 | index = np.where(predict_boxes_np[:, 4]==id)[0]
223 | if index.shape[0]:
224 | idx = random.choice(range(index.shape[0]))
225 | predict_boxes_sampled.append(predict_boxes[index[idx], :4])
226 | predict_boxes_sampled = torch.stack(predict_boxes_sampled)
227 | iou_repbox = calc_iou(predict_boxes_sampled, predict_boxes_sampled)
228 | mask = torch.lt(iou_repbox, 1.).float()
229 | iou_repbox = iou_repbox * mask
230 | RepBox_loss = smooth_ln(iou_repbox, 0.5)
231 | RepBox_loss = RepBox_loss.sum() / torch.clamp(torch.sum(torch.gt(iou_repbox, 0)).float(), min=1.0)
232 | RepBox_losses.append(RepBox_loss)
233 |
234 | else:
235 | regression_losses.append(torch.tensor(0).float().cuda())
236 | RepGT_losses.append(torch.tensor(0).float().cuda())
237 | RepBox_losses.append(torch.tensor(0).float().cuda())
238 |
239 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
240 | torch.stack(regression_losses).mean(dim=0, keepdim=True), \
241 | torch.stack(RepGT_losses).mean(dim=0, keepdim=True), \
242 | torch.stack(RepBox_losses).mean(dim=0, keepdim=True)
243 |
244 |
245 |
--------------------------------------------------------------------------------
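Both repulsion terms above feed an overlap measure through `smooth_ln`: RepGT penalizes the IoG between a predicted box and the ground-truth box with the second-highest IoU, while RepBox penalizes the IoU between predicted boxes assigned to different ground truths. A minimal sketch of the two terms in isolation, assuming `losses.py` is importable from the repository root; the box coordinates are toy values for illustration:

```
import torch
from losses import IoG, calc_iou, smooth_ln

# Decoded predictions for two positive anchors, [x1, y1, x2, y2].
predict_boxes = torch.tensor([[10., 10.,  60.,  90.],
                              [40., 20., 110., 100.]])
# For RepGT: the gt box with the second-highest IoU for each prediction.
repulsion_gt  = torch.tensor([[20., 15.,  70., 100.],
                              [90., 25., 140., 105.]])

# RepGT: push predictions away from non-target ground truths.
rep_gt = smooth_ln(IoG(repulsion_gt, predict_boxes), 0.5).mean()

# RepBox: push predictions assigned to different ground truths apart.
iou = calc_iou(predict_boxes, predict_boxes)
iou = iou * torch.lt(iou, 1.).float()  # zero the diagonal (self-overlap)
rep_box = smooth_ln(iou, 0.5).sum() / torch.clamp(torch.gt(iou, 0).sum().float(), min=1.0)

print(float(rep_gt), float(rep_box))
```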
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | import time
5 | import os
6 | import numpy as np
7 | import cv2
8 | import matplotlib.pyplot as plt
9 | import torch.utils.model_zoo as model_zoo
10 | from torch.nn import init
11 | from utils import BasicBlock, Bottleneck, BBoxTransform, ClipBoxes
12 | from anchors import Anchors
13 | import losses
14 | from lib.nms.pth_nms import pth_nms
15 | from dataloader import UnNormalizer
16 | unnormalize = UnNormalizer()
17 |
18 |
19 | def nms(dets, thresh):
20 | "Dispatch to either CPU or GPU NMS implementations.\
21 | Accept dets as tensor"""
22 | return pth_nms(dets, thresh)
23 |
24 |
25 |
26 | class PyramidFeatures(nn.Module):
27 | def __init__(self, C3_size, C4_size, C5_size, feature_size=256):
28 | super(PyramidFeatures, self).__init__()
29 |
30 | # upsample C5 to get P5 from the FPN paper
31 | self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
32 | self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
33 | self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
34 |
35 | # add P5 elementwise to C4
36 | self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
37 | self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
38 | self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
39 |
40 | # add P4 elementwise to C3
41 | self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
42 | self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
43 |
44 | # "P6 is obtained via a 3x3 stride-2 conv on C5"
45 | self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)
46 |
47 | # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
48 | self.P7_1 = nn.ReLU()
49 | self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)
50 |
51 | def forward(self, inputs):
52 | C3, C4, C5 = inputs
53 |
54 | P5_x = self.P5_1(C5)
55 | P5_upsampled_x = self.P5_upsampled(P5_x)
56 | P5_x = self.P5_2(P5_x)
57 |
58 | P4_x = self.P4_1(C4)
59 | P4_x = P5_upsampled_x + P4_x
60 | P4_upsampled_x = self.P4_upsampled(P4_x)
61 | P4_x = self.P4_2(P4_x)
62 |
63 | P3_x = self.P3_1(C3)
64 | P3_x = P3_x + P4_upsampled_x
65 | P3_x = self.P3_2(P3_x)
66 |
67 | P6_x = self.P6(C5)
68 |
69 | P7_x = self.P7_1(P6_x)
70 | P7_x = self.P7_2(P7_x)
71 |
72 | return [P3_x, P4_x, P5_x, P6_x, P7_x]
73 |
74 |
75 | class RegressionModel(nn.Module):
76 | def __init__(self, num_features_in, num_anchors=9, feature_size=256):
77 | super(RegressionModel, self).__init__()
78 |
79 | self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
80 | self.act1 = nn.ReLU()
81 |
82 | self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
83 | self.act2 = nn.ReLU()
84 |
85 | self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
86 | self.act3 = nn.ReLU()
87 |
88 | self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
89 | self.act4 = nn.ReLU()
90 |
91 | self.output = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, padding=1)
92 |
93 | def forward(self, x):
94 | out = self.conv1(x)
95 | out = self.act1(out)
96 |
97 | out = self.conv2(out)
98 | out = self.act2(out)
99 |
100 | out = self.conv3(out)
101 | out = self.act3(out)
102 |
103 | out = self.conv4(out)
104 | out = self.act4(out)
105 |
106 | out = self.output(out)
107 |
108 | # out is B x C x W x H, with C = 4*num_anchors
109 | out = out.permute(0, 2, 3, 1)
110 |
111 | return out.contiguous().view(out.shape[0], -1, 4)
112 |
113 |
114 | class ClassificationModel(nn.Module):
115 | def __init__(self, num_features_in, num_anchors=9, num_classes=80, prior=0.01, feature_size=256):
116 | super(ClassificationModel, self).__init__()
117 |
118 | self.num_classes = num_classes
119 | self.num_anchors = num_anchors
120 |
121 | self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
122 | self.act1 = nn.ReLU()
123 |
124 | self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
125 | self.act2 = nn.ReLU()
126 |
127 | self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
128 | self.act3 = nn.ReLU()
129 |
130 | self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
131 | self.act4 = nn.ReLU()
132 |
133 | self.output = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, padding=1)
134 | self.output_act = nn.Sigmoid()
135 |
136 | def forward(self, x):
137 | out = self.conv1(x)
138 | out = self.act1(out)
139 |
140 | out = self.conv2(out)
141 | out = self.act2(out)
142 |
143 | out = self.conv3(out)
144 | out = self.act3(out)
145 |
146 | out = self.conv4(out)
147 | out = self.act4(out)
148 |
149 | out = self.output(out)
150 | out = self.output_act(out)
151 |
152 | # out is B x C x W x H, with C = n_classes * n_anchors
153 | out1 = out.permute(0, 2, 3, 1)
154 |
155 | batch_size, width, height, channels = out1.shape
156 |
157 | out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes)
158 |
159 | return out2.contiguous().view(x.shape[0], -1, self.num_classes)
160 |
161 |
162 | class ResNet(nn.Module):
163 |
164 | def __init__(self, num_classes, block, layers):
165 | self.inplanes = 64
166 | super(ResNet, self).__init__()
167 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
168 | self.bn1 = nn.BatchNorm2d(64)
169 | self.relu = nn.ReLU(inplace=True)
170 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
171 | self.layer1 = self._make_layer(block, 64, layers[0])
172 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
173 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
175 |
176 | if block == BasicBlock:
177 | fpn_sizes = [self.layer2[layers[1] - 1].conv2.out_channels, self.layer3[layers[2] - 1].conv2.out_channels,
178 | self.layer4[layers[3] - 1].conv2.out_channels]
179 | elif block == Bottleneck:
180 | fpn_sizes = [self.layer2[layers[1] - 1].conv3.out_channels, self.layer3[layers[2] - 1].conv3.out_channels,
181 | self.layer4[layers[3] - 1].conv3.out_channels]
182 |
183 | self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2])
184 |
185 | self.regressionModel = RegressionModel(256)
186 | self.classificationModel = ClassificationModel(256, num_classes=num_classes)
187 |
188 | self.anchors = Anchors()
189 |
190 | self.regressBoxes = BBoxTransform()
191 |
192 | self.clipBoxes = ClipBoxes()
193 |
194 | self.focalLoss = losses.FocalLoss()
195 |
196 | for m in self.modules():
197 | if isinstance(m, nn.Conv2d):
198 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
199 | m.weight.data.normal_(0, math.sqrt(2. / n))
200 | # init.xavier_normal(m.weight)
201 | elif isinstance(m, nn.BatchNorm2d):
202 | m.weight.data.fill_(1)
203 | m.bias.data.zero_()
204 |
205 | prior = 0.01
206 |
207 | self.classificationModel.output.weight.data.fill_(0)
208 | self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))
209 |
210 | self.regressionModel.output.weight.data.fill_(0)
211 | self.regressionModel.output.bias.data.fill_(0)
212 |
213 | self.freeze_bn()
214 |
215 | def _make_layer(self, block, planes, blocks, stride=1):
216 | downsample = None
217 | if stride != 1 or self.inplanes != planes * block.expansion:
218 | downsample = nn.Sequential(
219 | nn.Conv2d(self.inplanes, planes * block.expansion,
220 | kernel_size=1, stride=stride, bias=False),
221 | nn.BatchNorm2d(planes * block.expansion),
222 | )
223 |
224 | layers = []
225 | layers.append(block(self.inplanes, planes, stride, downsample))
226 | self.inplanes = planes * block.expansion
227 | for i in range(1, blocks):
228 | layers.append(block(self.inplanes, planes))
229 |
230 | return nn.Sequential(*layers)
231 |
232 | def freeze_bn(self):
233 | '''Freeze BatchNorm layers.'''
234 | for layer in self.modules():
235 | if isinstance(layer, nn.BatchNorm2d):
236 | layer.eval()
237 |
238 | def forward(self, inputs):
239 |
240 | if self.training:
241 | img_batch, annotations, ignores = inputs
242 | else:
243 | img_batch = inputs
244 |
245 | x = self.conv1(img_batch)
246 | x = self.bn1(x)
247 | x = self.relu(x)
248 | x = self.maxpool(x)
249 |
250 | x1 = self.layer1(x)
251 | x2 = self.layer2(x1)
252 | x3 = self.layer3(x2)
253 | x4 = self.layer4(x3)
254 |
255 | features = self.fpn([x2, x3, x4])
256 |
257 | regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)
258 |
259 | classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)
260 |
261 | anchors = self.anchors(img_batch)
262 |
263 | if self.training:
264 | return self.focalLoss(classification, regression, anchors, annotations, ignores)
265 | else:
266 |
267 | transformed_anchors = self.regressBoxes(anchors, regression)
268 | transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)
269 |
270 | scores = torch.max(classification, dim=2, keepdim=True)[0]
271 | scores_over_thresh = (scores > 0.05)[0, :, 0]
272 |
273 | if scores_over_thresh.sum() == 0:
274 | # no boxes to NMS, just return
275 | # return [torch.zeros(0), torch.zeros(0), torch.zeros(0, 4)]
276 | return [None, None, None]
277 |
278 | classification = classification[:, scores_over_thresh, :]
279 | transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
280 | scores = scores[:, scores_over_thresh, :]
281 |
282 | anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
283 | nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
284 | return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
285 |
286 |
287 | def resnet18(num_classes, **kwargs):
288 | """Constructs a ResNet-18 model.
289 | Args:
290 | num_classes (int): number of object classes the detector predicts
291 | """
292 | model = ResNet(num_classes, BasicBlock, [2, 2, 2, 2], **kwargs)
293 | return model
294 |
295 |
296 | def resnet34(num_classes, **kwargs):
297 | """Constructs a ResNet-34 model.
298 | Args:
299 | num_classes (int): number of object classes the detector predicts
300 | """
301 | model = ResNet(num_classes, BasicBlock, [3, 4, 6, 3], **kwargs)
302 | return model
303 |
304 |
305 | def resnet50(num_classes, **kwargs):
306 | """Constructs a ResNet-50 model.
307 | Args:
308 | num_classes (int): number of object classes the detector predicts
309 | """
310 | model = ResNet(num_classes, Bottleneck, [3, 4, 6, 3], **kwargs)
311 | return model
312 |
313 |
314 | def resnet101(num_classes, **kwargs):
315 | """Constructs a ResNet-101 model.
316 | Args:
317 | num_classes (int): number of object classes the detector predicts
318 | """
319 | model = ResNet(num_classes, Bottleneck, [3, 4, 23, 3], **kwargs)
320 | return model
321 |
322 |
323 | def resnet152(num_classes, **kwargs):
324 | """Constructs a ResNet-152 model.
325 | Args:
326 | num_classes (int): number of object classes the detector predicts
327 | """
328 | model = ResNet(num_classes, Bottleneck, [3, 8, 36, 3], **kwargs)
329 | return model
--------------------------------------------------------------------------------
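A minimal inference sketch for the model above, assuming a CUDA device and a checkpoint saved by `train.py` (the path below is a placeholder). In eval mode the forward pass returns `[nms_scores, nms_class, boxes]`, or `[None, None, None]` when no score clears the 0.05 threshold:

```
import torch

# train.py saves the whole module, so torch.load restores the network directly.
retinanet = torch.load('./ckpt/repulsion_r50_49.pt')  # placeholder checkpoint path
retinanet = retinanet.cuda()
retinanet.eval()

with torch.no_grad():
    # The image must be resized and normalized the same way as during training.
    img = torch.randn(1, 3, 608, 1024).cuda()  # dummy input for illustration
    scores, labels, boxes = retinanet(img)

if scores is not None:
    keep = scores > 0.5
    print(boxes[keep], labels[keep])
```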
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import copy
4 | import argparse
5 | import pdb
6 | import collections
7 | import sys
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from torch.optim import lr_scheduler
15 | from torch.autograd import Variable
16 | from torchvision import datasets, models, transforms
17 | import torchvision
18 | from tensorboardX import SummaryWriter
19 |
20 | import model
21 | from anchors import Anchors
22 | import losses
23 | from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer
24 | from torch.utils.data import Dataset, DataLoader
25 |
26 | import csv_eval
27 | import cv2
28 | assert torch.__version__.split('.')[1] == '4'
29 |
30 | print('CUDA available: {}'.format(torch.cuda.is_available()))
31 |
32 | ckpt = False
33 | def main(args=None):
34 |
35 | parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')
36 |
37 | parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
38 | parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
39 | parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')
40 |
41 | parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
42 | parser.add_argument('--epochs', help='Number of epochs', type=int, default=50)
43 |
44 | parser.add_argument('--model_name', help='name of the model to save')
45 | parser.add_argument('--pretrained', help='File name of the pretrained backbone weights, looked up under ./weight/')
46 |
47 | parser = parser.parse_args(args)
48 |
49 | # Create the data loaders
50 | dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes, transform=transforms.Compose([Resizer(), Augmenter(), Normalizer()]))
51 |
52 | if parser.csv_val is None:
53 | dataset_val = None
54 | print('No validation annotations provided.')
55 | else:
56 | dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes, transform=transforms.Compose([Resizer(), Normalizer()]))
57 |
58 | sampler = AspectRatioBasedSampler(dataset_train, batch_size=2, drop_last=False)
59 | dataloader_train = DataLoader(dataset_train, num_workers=16, collate_fn=collater, batch_sampler=sampler)
60 | #dataloader_train = DataLoader(dataset_train, num_workers=16, collate_fn=collater, batch_size=8, shuffle=True)
61 |
62 | if dataset_val is not None:
63 | sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=2, drop_last=False)
64 | dataloader_val = DataLoader(dataset_val, num_workers=16, collate_fn=collater, batch_sampler=sampler_val)
65 | #dataloader_val = DataLoader(dataset_train, num_workers=16, collate_fn=collater, batch_size=8, shuffle=True)
66 |
67 | # Create the model
68 | if parser.depth == 18:
69 | retinanet = model.resnet18(num_classes=dataset_train.num_classes())
70 | elif parser.depth == 34:
71 | retinanet = model.resnet34(num_classes=dataset_train.num_classes())
72 | elif parser.depth == 50:
73 | retinanet = model.resnet50(num_classes=dataset_train.num_classes())
74 | elif parser.depth == 101:
75 | retinanet = model.resnet101(num_classes=dataset_train.num_classes())
76 | elif parser.depth == 152:
77 | retinanet = model.resnet152(num_classes=dataset_train.num_classes())
78 | else:
79 | raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')
80 |
81 | if ckpt:
82 | retinanet = torch.load('')
83 | print('load ckpt')
84 | else:
85 | retinanet_dict = retinanet.state_dict()
86 | pretrained_dict = torch.load('./weight/' + parser.pretrained)
87 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in retinanet_dict}
88 | retinanet_dict.update(pretrained_dict)
89 | retinanet.load_state_dict(retinanet_dict)
90 | print('load pretrained backbone')
91 |
92 | print(retinanet)
93 | retinanet = torch.nn.DataParallel(retinanet, device_ids=[0])
94 | retinanet.cuda()
95 |
96 | retinanet.training = True
97 |
98 | optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)
99 | #optimizer = optim.SGD(retinanet.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
100 |
101 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
102 | #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
103 |
104 | loss_hist = collections.deque(maxlen=500)
105 |
106 | retinanet.train()
107 | retinanet.module.freeze_bn()
108 |
109 | print('Num training images: {}'.format(len(dataset_train)))
110 | f_map = open('./mAP_txt/' + parser.model_name + '.txt', 'a')
111 | writer = SummaryWriter(log_dir='./summary')
112 | iters = 0
113 | for epoch_num in range(0, parser.epochs):
114 |
115 | retinanet.train()
116 | retinanet.module.freeze_bn()
117 |
118 | epoch_loss = []
119 | #scheduler.step()
120 |
121 | for iter_num, data in enumerate(dataloader_train):
122 |
123 | iters += 1
124 |
125 | optimizer.zero_grad()
126 |
127 | classification_loss, regression_loss, repgt_loss, repbox_loss = retinanet([data['img'].cuda().float(), data['annot'], data['ignore']])
128 |
129 | classification_loss = classification_loss.mean()
130 | regression_loss = regression_loss.mean()
131 | repgt_loss = repgt_loss.mean()
132 | repbox_loss = repbox_loss.mean()
133 |
134 | loss = classification_loss + regression_loss + 0.5 * repgt_loss + 0.5 * repbox_loss
135 |
136 | if bool(loss == 0):
137 | continue
138 |
139 | loss.backward()
140 |
141 | torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
142 |
143 | optimizer.step()
144 |
145 | loss_hist.append(float(loss))
146 |
147 | epoch_loss.append(float(loss))
148 |
149 | print('Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | RepGT loss {:1.5f} | RepBox loss {:1.5f} | Running loss: {:1.5f}'.format(epoch_num, iter_num, float(classification_loss), float(regression_loss), float(repgt_loss), float(repbox_loss), np.mean(loss_hist)))
150 |
151 | writer.add_scalar('classification_loss', classification_loss, iters)
152 | writer.add_scalar('regression_loss', regression_loss, iters)
153 | writer.add_scalar('loss', loss, iters)
154 |
155 | del classification_loss
156 | del regression_loss
157 |
158 |
159 | if parser.csv_val is not None:
160 |
161 | print('Evaluating dataset')
162 |
163 | mAP = csv_eval.evaluate(dataset_val, retinanet)
164 | f_map.write('mAP:{}, epoch:{}'.format(mAP[0][0], epoch_num))
165 | f_map.write('\n')
166 |
167 | scheduler.step(np.mean(epoch_loss))
168 |
169 | torch.save(retinanet.module, './ckpt/' + parser.model_name + '_{}.pt'.format(epoch_num))
170 |
171 | retinanet.eval()
172 |
173 | writer.export_scalars_to_json('./summary/' + parser.pretrained + 'all_scalars.json')
174 | f_map.close()
175 | writer.close()
176 |
177 | if __name__ == '__main__':
178 | main()
179 |
--------------------------------------------------------------------------------
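`train.py` reads the pretrained backbone from `./weight/`, writes checkpoints to `./ckpt/`, mAP logs to `./mAP_txt/`, and TensorBoard summaries to `./summary/`, so those directories need to exist. A minimal sketch of an invocation through `main()`, assuming PyTorch 0.4 and CSV annotation files in the format of the base RetinaNet repo; every file name below is a placeholder:

```
from train import main

# Equivalent to calling train.py from the command line with the same flags.
main(['--csv_train', 'train_annots.csv',   # placeholder training annotations
      '--csv_classes', 'class_list.csv',   # placeholder class list
      '--csv_val', 'val_annots.csv',       # optional validation annotations
      '--depth', '50',
      '--epochs', '50',
      '--model_name', 'repulsion_r50',     # placeholder run name
      '--pretrained', 'resnet50.pth'])     # backbone weights under ./weight/
```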
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=1, bias=False)
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, inplanes, planes, stride=1, downsample=None):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = conv3x3(inplanes, planes, stride)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.conv2 = conv3x3(planes, planes)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.downsample = downsample
22 | self.stride = stride
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 |
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 |
37 | out += residual
38 | out = self.relu(out)
39 |
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None):
47 | super(Bottleneck, self).__init__()
48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
51 | padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(planes * 4)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv3(out)
71 | out = self.bn3(out)
72 |
73 | if self.downsample is not None:
74 | residual = self.downsample(x)
75 |
76 | out += residual
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class SELayer(nn.Module):
83 | def __init__(self, channel, reduction=16):
84 | super(SELayer, self).__init__()
85 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
86 | self.fc = nn.Sequential(
87 | nn.Linear(channel, channel // reduction),
88 | nn.ReLU(inplace=True),
89 | nn.Linear(channel // reduction, channel),
90 | nn.Sigmoid()
91 | )
92 |
93 | def forward(self, x):
94 | b, c, _, _ = x.size()
95 | y = self.avg_pool(x).view(b, c)
96 | y = self.fc(y).view(b, c, 1, 1)
97 | return x * y
98 |
99 | class BottleneckSE(nn.Module):
100 | expansion = 4
101 |
102 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
103 | super(BottleneckSE, self).__init__()
104 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
105 | self.bn1 = nn.BatchNorm2d(planes)
106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
107 | padding=1, bias=False)
108 | self.bn2 = nn.BatchNorm2d(planes)
109 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
110 | self.bn3 = nn.BatchNorm2d(planes * 4)
111 | self.relu = nn.ReLU(inplace=True)
112 | self.se = SELayer(planes * 4, reduction)
113 | self.downsample = downsample
114 | self.stride = stride
115 |
116 | def forward(self, x):
117 | residual = x
118 |
119 | out = self.conv1(x)
120 | out = self.bn1(out)
121 | out = self.relu(out)
122 |
123 | out = self.conv2(out)
124 | out = self.bn2(out)
125 | out = self.relu(out)
126 |
127 | out = self.conv3(out)
128 | out = self.bn3(out)
129 | out = self.se(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 |
134 | out += residual
135 | out = self.relu(out)
136 |
137 | return out
138 |
139 |
140 | class CBAM_Module(nn.Module):
141 |
142 | def __init__(self, channels, reduction):
143 | super(CBAM_Module, self).__init__()
144 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
145 | self.max_pool = nn.AdaptiveMaxPool2d(1)
146 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
147 | padding=0)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
150 | padding=0)
151 | self.sigmoid_channel = nn.Sigmoid()
152 | self.conv_after_concat = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
153 | self.sigmoid_spatial = nn.Sigmoid()
154 |
155 | def forward(self, x):
156 | module_input = x
157 | avg = self.avg_pool(x)
158 | mx = self.max_pool(x)
159 | avg = self.fc1(avg)
160 | mx = self.fc1(mx)
161 | avg = self.relu(avg)
162 | mx = self.relu(mx)
163 | avg = self.fc2(avg)
164 | mx = self.fc2(mx)
165 | x = avg + mx
166 | x = self.sigmoid_channel(x)
167 | x = module_input * x
168 | module_input = x
169 | avg = torch.mean(x, 1, True)
170 | mx, _ = torch.max(x, 1, True)
171 | x = torch.cat((avg, mx), 1)
172 | x = self.conv_after_concat(x)
173 | x = self.sigmoid_spatial(x)
174 | x = module_input * x
175 | return x
176 |
177 | class BottleneckCBAM(nn.Module):
178 | expansion = 4
179 |
180 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
181 | super(BottleneckCBAM, self).__init__()
182 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
183 | self.bn1 = nn.BatchNorm2d(planes)
184 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
185 | padding=1, bias=False)
186 | self.bn2 = nn.BatchNorm2d(planes)
187 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
188 | self.bn3 = nn.BatchNorm2d(planes * 4)
189 | self.relu = nn.ReLU(inplace=True)
190 | self.se = CBAM_Module(planes * 4, reduction)
191 | self.downsample = downsample
192 | self.stride = stride
193 |
194 | def forward(self, x):
195 | residual = x
196 |
197 | out = self.conv1(x)
198 | out = self.bn1(out)
199 | out = self.relu(out)
200 |
201 | out = self.conv2(out)
202 | out = self.bn2(out)
203 | out = self.relu(out)
204 |
205 | out = self.conv3(out)
206 | out = self.bn3(out)
207 | out = self.se(out)
208 |
209 | if self.downsample is not None:
210 | residual = self.downsample(x)
211 |
212 | out += residual
213 | out = self.relu(out)
214 |
215 | return out
216 |
217 |
218 | class BBoxTransform(nn.Module):
219 |
220 | def __init__(self, mean=None, std=None):
221 | super(BBoxTransform, self).__init__()
222 | if mean is None:
223 | self.mean = torch.from_numpy(np.array([0, 0, 0, 0]).astype(np.float32)).cuda()
224 | else:
225 | self.mean = mean
226 | if std is None:
227 | self.std = torch.from_numpy(np.array([0.1, 0.1, 0.2, 0.2]).astype(np.float32)).cuda()
228 | else:
229 | self.std = std
230 |
231 | def forward(self, boxes, deltas):
232 |
233 | widths = boxes[:, :, 2] - boxes[:, :, 0]
234 | heights = boxes[:, :, 3] - boxes[:, :, 1]
235 | ctr_x = boxes[:, :, 0] + 0.5 * widths
236 | ctr_y = boxes[:, :, 1] + 0.5 * heights
237 |
238 | dx = deltas[:, :, 0] * self.std[0] + self.mean[0]
239 | dy = deltas[:, :, 1] * self.std[1] + self.mean[1]
240 | dw = deltas[:, :, 2] * self.std[2] + self.mean[2]
241 | dh = deltas[:, :, 3] * self.std[3] + self.mean[3]
242 |
243 | pred_ctr_x = ctr_x + dx * widths
244 | pred_ctr_y = ctr_y + dy * heights
245 | pred_w = torch.exp(dw) * widths
246 | pred_h = torch.exp(dh) * heights
247 |
248 | pred_boxes_x1 = pred_ctr_x - 0.5 * pred_w
249 | pred_boxes_y1 = pred_ctr_y - 0.5 * pred_h
250 | pred_boxes_x2 = pred_ctr_x + 0.5 * pred_w
251 | pred_boxes_y2 = pred_ctr_y + 0.5 * pred_h
252 |
253 | pred_boxes = torch.stack([pred_boxes_x1, pred_boxes_y1, pred_boxes_x2, pred_boxes_y2], dim=2)
254 |
255 | return pred_boxes
256 |
257 |
258 | class ClipBoxes(nn.Module):
259 |
260 | def __init__(self, width=None, height=None):
261 | super(ClipBoxes, self).__init__()
262 |
263 | def forward(self, boxes, img):
264 |
265 | batch_size, num_channels, height, width = img.shape
266 |
267 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
268 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
269 |
270 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width)
271 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height)
272 |
273 | return boxes
274 |
--------------------------------------------------------------------------------
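A minimal sketch of how `BBoxTransform` and `ClipBoxes` are meant to be chained, mirroring their use in `model.py`. A CUDA device is assumed because `BBoxTransform` puts its default mean and std on the GPU; the anchors and deltas are toy values for illustration:

```
import torch
from utils import BBoxTransform, ClipBoxes

regress_boxes = BBoxTransform()
clip_boxes = ClipBoxes()

# One image, two anchors in [x1, y1, x2, y2] form, and per-anchor regression
# deltas (dx, dy, dw, dh) as produced by the regression head.
anchors = torch.tensor([[[ 10.,  10.,  50.,  50.],
                         [100., 100., 180., 220.]]]).cuda()
deltas  = torch.tensor([[[ 0.5, -0.2, 0.1, 0.0],
                         [-1.0,  0.0, 0.0, 0.3]]]).cuda()
img = torch.zeros(1, 3, 200, 200).cuda()  # dummy image, only its shape is used

boxes = regress_boxes(anchors, deltas)  # decode deltas into absolute boxes
boxes = clip_boxes(boxes, img)          # clamp boxes to the image bounds
print(boxes)
```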