├── README.md
├── __pycache__
├── assin_priors.cpython-37.pyc
├── basket_utils.cpython-37.pyc
├── dataset.cpython-37.pyc
├── lossfn.cpython-37.pyc
├── model.cpython-37.pyc
├── nms.cpython-37.pyc
├── post_processer.cpython-37.pyc
├── priors.cpython-37.pyc
├── python_nms.cpython-37.pyc
└── viz.cpython-37.pyc
├── a.jpg
├── augment
├── __init__.py
└── augmentor.py
├── augmentor.py
├── basket_utils.py
├── config.py
├── dataset.py
├── datasets
├── images
│ ├── 009814.jpg
│ ├── 045346.jpg
│ ├── 059925.jpg
│ ├── 102954.jpg
│ ├── 121349.jpg
│ ├── 147664.jpg
│ ├── 167843.jpg
│ ├── 201586.jpg
│ ├── 213702.jpg
│ ├── 216026.jpg
│ ├── 250583.jpg
│ ├── 257559.jpg
│ ├── 260655.jpg
│ ├── 273406.jpg
│ ├── 275153.jpg
│ ├── 286036.jpg
│ ├── 308634.jpg
│ ├── 325869.jpg
│ ├── 346044.jpg
│ ├── 447131.jpg
│ ├── 492718.jpg
│ ├── 542925.jpg
│ ├── 567307.jpg
│ ├── 574125.jpg
│ ├── 605272.jpg
│ ├── 630439.jpg
│ ├── 645593.jpg
│ ├── 721074.jpg
│ ├── 721205.jpg
│ ├── 771259.jpg
│ ├── 785985.jpg
│ ├── 795439.jpg
│ ├── 826272.jpg
│ ├── 845712.jpg
│ ├── 847258.jpg
│ ├── 858721.jpg
│ ├── 875555.jpg
│ ├── 884419.jpg
│ ├── 896922.jpg
│ ├── 898712.jpg
│ ├── 916700.jpg
│ ├── 918904.jpg
│ ├── 945782.jpg
│ ├── 959677.jpg
│ ├── 960859.jpg
│ ├── 968216.jpg
│ ├── 974557.jpg
│ ├── 976494.jpg
│ └── 998724.jpg
└── labels
│ ├── 009814.xml
│ ├── 045346.xml
│ ├── 059925.xml
│ ├── 102954.xml
│ ├── 121349.xml
│ ├── 147664.xml
│ ├── 167843.xml
│ ├── 201586.xml
│ ├── 213702.xml
│ ├── 216026.xml
│ ├── 250583.xml
│ ├── 257559.xml
│ ├── 260655.xml
│ ├── 273406.xml
│ ├── 275153.xml
│ ├── 286036.xml
│ ├── 308634.xml
│ ├── 325869.xml
│ ├── 346044.xml
│ ├── 447131.xml
│ ├── 492718.xml
│ ├── 542925.xml
│ ├── 567307.xml
│ ├── 574125.xml
│ ├── 605272.xml
│ ├── 630439.xml
│ ├── 645593.xml
│ ├── 721074.xml
│ ├── 721205.xml
│ ├── 771259.xml
│ ├── 785985.xml
│ ├── 795439.xml
│ ├── 826272.xml
│ ├── 845712.xml
│ ├── 847258.xml
│ ├── 858721.xml
│ ├── 875555.xml
│ ├── 884419.xml
│ ├── 896922.xml
│ ├── 898712.xml
│ ├── 916700.xml
│ ├── 918904.xml
│ ├── 945782.xml
│ ├── 959677.xml
│ ├── 960859.xml
│ ├── 968216.xml
│ ├── 974557.xml
│ ├── 976494.xml
│ └── 998724.xml
├── demo.png
├── lossfn.py
├── model.py
├── nms.py
├── post_processer.py
├── predict.py
├── priors.py
├── python_nms.py
├── rf_erf_visualize.png
├── split_pic.py
├── train.py
├── visualize_demo.py
└── viz.py
/README.md:
--------------------------------------------------------------------------------
1 | # BasketNet
2 |
3 | This is a demo of an LFFD model for a basketball recognition competition.
4 |
5 | The demo is based on SSD and LFFD.
6 | ## Install
7 | ```[cmd]
8 | git clone https://github.com/aoru45/LFFD-Pytorch.git
9 | ```
10 |
11 | ## Usage
12 | Download or make your own dataset and modify the dataset.py file.
13 | ```[cmd]
14 | python train.py
15 | ```
16 |
17 | ## Receptive field (RF) and effective receptive field (ERF) visualization
18 | 
19 |
20 | The visualization code is available here: https://github.com/aoru45/LFFD-Pytorch/blob/master/visualize_demo.py
21 |
22 | ## Network Structure
23 |
24 | 
25 |
26 |
27 | ## Demo result
28 |
29 | 
30 |
31 | ## Reference
32 |
33 | SSD:https://arxiv.org/abs/1512.02325
34 |
35 | LFFD:https://arxiv.org/pdf/1904.10633.pdf
36 |
37 |
38 |
--------------------------------------------------------------------------------
/__pycache__/assin_priors.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/assin_priors.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/basket_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/basket_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/lossfn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/lossfn.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/nms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/nms.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/post_processer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/post_processer.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/priors.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/priors.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/python_nms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/python_nms.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/viz.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/__pycache__/viz.cpython-37.pyc
--------------------------------------------------------------------------------
/a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/a.jpg
--------------------------------------------------------------------------------
/augment/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/augment/augmentor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | This module provides many types of image augmentation. One can choose appropriate augmentation for
4 | detection, segmentation and classification.
5 | """
6 | import cv2
7 | import numpy
8 | import random
9 |
10 |
class Augmentor(object):
    """
    Collection of image augmentation operations.

    Every operation is a static method taking a numpy image in OpenCV layout
    (H x W for grayscale, H x W x C for colour, BGR channel order where it
    matters). Most operations only accept ``uint8`` images: for any other
    dtype they print a message and return ``None`` (``random_crop`` returns
    ``(False, image)`` instead).
    """

    def __init__(self):
        pass

    @staticmethod
    def histogram_equalisation(image):
        """
        Equalise the histogram of a grayscale image.

        :param image: single-channel uint8 image
        :return: equalised image, or None if the input is not 2-D or not uint8
        """
        if image.ndim != 2:
            print('Input image is not grayscale!')
            return None
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        return cv2.equalizeHist(image)

    @staticmethod
    def grayscale(image):
        """
        Convert a BGR image to grayscale.

        :param image: 3-D uint8 BGR image
        :return: single-channel image; None (without a message) if the input
            is not 3-D, or None with a message if it is not uint8
        """
        if image.ndim != 3:
            return None
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def inversion(image):
        """
        Invert pixel values (``255 - image``).

        :param image: uint8 BGR or grayscale image
        :return: inverted image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        return 255 - image

    @staticmethod
    def binarization(image, block_size=5, C=10):
        """
        Binarise the image with ``cv2.adaptiveThreshold``.

        Uses the mean-of-neighbourhood method with ``THRESH_BINARY`` and a
        max value of 255; see the OpenCV docs for details.

        :param image: BGR (3-D, converted to gray first) or grayscale image
        :param block_size: neighbourhood size passed to adaptiveThreshold
        :param C: constant subtracted from the neighbourhood mean
        :return: binary image
        """
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image

        return cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                     cv2.THRESH_BINARY, block_size, C)

    @staticmethod
    def brightness(image, min_factor=0.5, max_factor=1.5):
        """
        Scale all pixel values by a random factor.

        :param image: uint8 image
        :param min_factor: lower bound of the uniformly drawn factor
        :param max_factor: upper bound of the uniformly drawn factor
        :return: adjusted uint8 image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        factor = numpy.random.uniform(min_factor, max_factor)
        result = image * factor
        if factor > 1:
            # only an amplifying factor can overflow the uint8 range
            result[result > 255] = 255
        return result.astype(numpy.uint8)

    @staticmethod
    def saturation(image, min_factor=0.5, max_factor=1.5):
        """
        Blend each colour channel with the grayscale image by a random factor.

        Factors above 1 increase saturation; factors below 1 move the image
        towards gray.

        :param image: uint8 BGR image
        :param min_factor: lower bound of the uniformly drawn factor
        :param max_factor: upper bound of the uniformly drawn factor
        :return: adjusted uint8 image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        factor = numpy.random.uniform(min_factor, max_factor)

        result = numpy.zeros(image.shape, dtype=numpy.float32)
        for channel in range(3):
            result[:, :, channel] = image[:, :, channel] * factor + image_gray * (1 - factor)
        result[result > 255] = 255
        result[result < 0] = 0
        return result.astype(numpy.uint8)

    @staticmethod
    def contrast(image, min_factor=0.5, max_factor=1.5):
        """
        Blend each channel with the mean gray level by a random factor.

        :param image: uint8 BGR image
        :param min_factor: lower bound of the uniformly drawn factor
        :param max_factor: upper bound of the uniformly drawn factor
        :return: adjusted uint8 image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray_mean = numpy.mean(image_gray)
        temp = numpy.ones((image.shape[0], image.shape[1]), dtype=numpy.float32) * gray_mean
        factor = numpy.random.uniform(min_factor, max_factor)

        result = numpy.zeros(image.shape, dtype=numpy.float32)
        for channel in range(3):
            result[:, :, channel] = image[:, :, channel] * factor + temp * (1 - factor)

        result[result > 255] = 255
        result[result < 0] = 0
        return result.astype(numpy.uint8)

    @staticmethod
    def blur(image, mode='random', kernel_size=3, sigma=1):
        """
        Blur the image.

        :param image: uint8 image
        :param mode: 'normalized' (box filter), 'gaussian', 'median', or
            'random' to pick one of those three uniformly
        :param kernel_size: square kernel side length
        :param sigma: standard deviation, used only for gaussian blur
        :return: blurred image; the input itself for an unsupported mode;
            None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        if mode == 'random':
            mode = random.choice(['normalized', 'gaussian', 'median'])

        if mode == 'normalized':
            result = cv2.blur(image, (kernel_size, kernel_size))
        elif mode == 'gaussian':
            result = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigmaX=sigma, sigmaY=sigma)
        elif mode == 'median':
            result = cv2.medianBlur(image, kernel_size)
        else:
            print('Blur mode is not supported: %s.' % mode)
            result = image
        return result

    @staticmethod
    def rotation(image, degree=10, mode='crop', scale=1):
        """
        Rotate the image about its centre.

        :param image: uint8 image
        :param degree: rotation angle in degrees, as passed to
            cv2.getRotationMatrix2D
        :param mode: 'crop' keeps the original size (corners are cut off);
            any other value enlarges the canvas so the whole rotated image fits
        :param scale: scale factor applied together with the rotation
        :return: rotated image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        h, w = image.shape[:2]
        center_x, center_y = w / 2, h / 2
        M = cv2.getRotationMatrix2D((center_x, center_y), degree, scale)

        if mode == 'crop':
            new_w, new_h = w, h
        else:
            # bounding size of the rotated image, then shift so it is centred
            cos = numpy.abs(M[0, 0])
            sin = numpy.abs(M[0, 1])
            new_w = int(h * sin + w * cos)
            new_h = int(h * cos + w * sin)
            M[0, 2] += (new_w / 2) - center_x
            M[1, 2] += (new_h / 2) - center_y

        return cv2.warpAffine(image, M, (new_w, new_h))

    @staticmethod
    def flip(image, orientation='h'):
        """
        Mirror the image.

        :param image: uint8 image
        :param orientation: 'h' flips left/right, 'v' flips top/bottom;
            anything else returns the input unchanged
        :return: flipped image, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        if orientation == 'h':
            return cv2.flip(image, 1)
        elif orientation == 'v':
            return cv2.flip(image, 0)
        else:
            print('Unsupported orientation: %s.' % orientation)
            return image

    @staticmethod
    def resize(image, size_in_pixel=None, size_in_scale=None):
        """
        Resize the image to an absolute size or by scale factors.

        :param image: uint8 image
        :param size_in_pixel: tuple (width, height); takes precedence
        :param size_in_scale: tuple (width_scale, height_scale)
        :return: resized image; the input itself if both sizes are None;
            None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        if size_in_pixel is not None:
            return cv2.resize(image, size_in_pixel)
        elif size_in_scale is not None:
            return cv2.resize(image, (0, 0), fx=size_in_scale[0], fy=size_in_scale[1])
        else:
            print('size_in_pixel and size_in_scale are both None.')
            return image

    @staticmethod
    def crop(image, x, y, width, height):
        """
        Return a view of a rectangular region of the image.

        :param image: uint8 image
        :param x: crop area top-left x coordinate
        :param y: crop area top-left y coordinate
        :param width: crop area width
        :param height: crop area height
        :return: cropped view, or None if the input is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return None

        # trailing channel axis (if any) is kept implicitly
        return image[y:y + height, x:x + width]

    @staticmethod
    def random_crop(image, width, height):
        """
        Randomly crop (or centre-pad) the image to ``height x width``.

        Along each axis where the image is larger than the target a random
        window is taken; along each axis where it is smaller the image is
        placed centred on a zero canvas.

        :param image: uint8 image
        :param width: crop area width
        :param height: crop area height
        :return: the cropped/padded image, or ``(False, image)`` if the input
            is not uint8
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return False, image

        w_interval = image.shape[1] - width
        h_interval = image.shape[0] - height

        # zero canvas with the same channel layout as the input
        result = numpy.zeros((height, width) + image.shape[2:], dtype=numpy.uint8)

        # Offsets must be integers: true division would produce floats,
        # which numpy rejects as slice bounds.
        if w_interval >= 0 and h_interval >= 0:
            crop_x, crop_y = random.randint(0, w_interval), random.randint(0, h_interval)
            result = image[crop_y:crop_y + height, crop_x:crop_x + width]
        elif w_interval < 0 and h_interval >= 0:
            put_x = (-w_interval) // 2
            crop_y = random.randint(0, h_interval)
            result[:, put_x:put_x + image.shape[1]] = image[crop_y:crop_y + height]
        elif w_interval >= 0 and h_interval < 0:
            crop_x = random.randint(0, w_interval)
            put_y = (-h_interval) // 2
            result[put_y:put_y + image.shape[0]] = image[:, crop_x:crop_x + width]
        else:
            put_x, put_y = (-w_interval) // 2, (-h_interval) // 2
            result[put_y:put_y + image.shape[0], put_x:put_x + image.shape[1]] = image

        return result
326 |
327 |
--------------------------------------------------------------------------------
/augmentor.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-17 00:45:53
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 18:30:47
8 | '''
9 | import cv2 as cv
10 | import numpy as np
11 | from PIL import Image
12 | from albumentations import (
13 |
14 | HorizontalFlip,
15 | Resize,
16 | RandomSizedBBoxSafeCrop,
17 | Compose,
18 | RandomSunFlare,
19 | RandomShadow,
20 | RandomBrightness,
21 | RandomContrast,
22 | RandomCrop,
23 | GaussianBlur,
24 | CenterCrop,
25 | SmallestMaxSize,
26 | PadIfNeeded,
27 | LongestMaxSize
28 | )
29 | from torchvision.transforms import ToTensor
class BasketAug(object):
    """
    Albumentations-based augmentation pipeline for the basketball dataset.

    The image is resized so its longest side is 640 and zero-padded to
    640x640. It then gets random brightness/contrast, a random shadow in the
    lower half, a horizontal flip and Gaussian blur. Bounding boxes are
    transformed together with the image.

    :param to_tensor: if True (default) the augmented image is converted with
        torchvision's ``ToTensor``; if False the numpy array produced by
        albumentations is returned unchanged.
    """

    def __init__(self, to_tensor=True):
        self.to_tensor = to_tensor
        augs = [
            LongestMaxSize(640),
            PadIfNeeded(640, 640, cv.BORDER_CONSTANT, value=0),
            RandomBrightness(p=0.5),
            RandomContrast(p=0.5),
            RandomShadow(p=0.5, num_shadows_lower=1, num_shadows_upper=1,
                         shadow_dimension=5, shadow_roi=(0, 0.5, 1, 1)),
            HorizontalFlip(p=0.5),
            GaussianBlur(p=0.5),
        ]
        # Boxes use albumentations' normalised [x_min, y_min, x_max, y_max]
        # format; boxes less than 20% visible after augmentation are dropped
        # together with their entry in ``category_id``.
        self.transform = Compose(augs,
                                 bbox_params={"format": "albumentations", "min_area": 0,
                                              "min_visibility": 0.2, 'label_fields': ['category_id']}
                                 )

    def __call__(self, cv_img, boxes=None, labels=None):
        """
        Augment an image together with its boxes.

        :param cv_img: image as a numpy array (callers in this repo pass
            H x W x 3 uint8 arrays)
        :param boxes: boxes in albumentations format
        :param labels: one class id per box
        :return: ``(image, boxes, labels)``; ``image`` is a tensor when the
            pipeline was built with ``to_tensor=True``, otherwise a numpy array
        """
        auged = self.transform(image=cv_img, bboxes=boxes, category_id=labels)
        image = auged["image"]
        if self.to_tensor:
            image = ToTensor()(image)
        return image, auged["bboxes"], auged["category_id"]
if __name__ == "__main__":
    # Visual smoke test: augment ./a.jpg with two dummy boxes and draw them.
    basket_aug = BasketAug()
    img = cv.imread("./a.jpg")
    if img is None:
        raise SystemExit("could not read ./a.jpg")
    boxes = [[0.2, 0.2, 0.5, 0.5], [0, 0, 0.1, 0.1]]
    labels = [0, 1]
    img, boxes, labels = basket_aug(img, boxes, labels)

    # With to_tensor enabled the image is a C x H x W float tensor in [0, 1]
    # (torchvision ToTensor); OpenCV drawing needs a contiguous
    # H x W x C uint8 numpy array.
    if not isinstance(img, np.ndarray):
        img = np.ascontiguousarray(img.mul(255).byte().permute(1, 2, 0).numpy())

    side = 640  # output size of the LongestMaxSize/PadIfNeeded pipeline
    for xmin, ymin, xmax, ymax in boxes:
        cv.rectangle(img,
                     (int(xmin * side), int(ymin * side)),
                     (int(xmax * side), int(ymax * side)),
                     (255, 0, 0), 2)
    cv.imshow("img", img)
    cv.waitKey(0)
68 |
--------------------------------------------------------------------------------
/basket_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-02 01:32:25
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 18:37:41
8 | '''
9 | import torch
10 | import math
11 | from config import *
12 |
def convert_locations_to_boxes(locations, priors, variance=2):
    """
    Decode regression offsets back into corner-form boxes.

    Inverse of ``convert_boxes_to_locations``: each decoded coordinate pair
    is ``priors[..., :2] - offset * priors[..., 2:]``.

    :param locations: tensor whose last dimension holds the offsets for the
        first corner (``[..., :2]``) and the second corner (``[..., 2:]``)
    :param priors: prior tensor; if it has one dimension fewer than
        ``locations`` a leading batch dimension is added for broadcasting
    :param variance: unused; kept for API compatibility
    :return: decoded boxes concatenated along the last dimension
    """
    if locations.dim() == priors.dim() + 1:
        priors = priors.unsqueeze(0)
    anchor_xy = priors[..., :2]
    anchor_size = priors[..., 2:]
    first_corner = anchor_xy - locations[..., :2] * anchor_size
    second_corner = anchor_xy - locations[..., 2:] * anchor_size
    return torch.cat([first_corner, second_corner], dim=locations.dim() - 1)
22 |
23 |
def convert_boxes_to_locations(corner_form_boxes, priors, variance=2):
    """
    Encode corner-form boxes as offsets relative to the priors.

    Each corner is encoded as ``(priors[..., :2] - corner) / priors[..., 2:]``.

    :param corner_form_boxes: tensor whose last dimension holds the first
        corner (``[..., :2]``) and the second corner (``[..., 2:]``)
    :param priors: prior tensor; if it has one dimension fewer than the boxes
        a leading batch dimension is added for broadcasting
    :param variance: unused; kept for API compatibility
    :return: encoded offsets concatenated along the last dimension
    """
    if corner_form_boxes.dim() == priors.dim() + 1:
        priors = priors.unsqueeze(0)
    anchor_xy = priors[..., :2]
    anchor_size = priors[..., 2:]
    offsets = [
        (anchor_xy - corner_form_boxes[..., :2]) / anchor_size,
        (anchor_xy - corner_form_boxes[..., 2:]) / anchor_size,
    ]
    return torch.cat(offsets, dim=corner_form_boxes.dim() - 1)
32 |
def area_of(left_top, right_bottom):
    """
    Area of axis-aligned boxes given their two corners.

    Negative extents (right_bottom left of / above left_top) are clamped to
    zero, so non-overlapping intersections get area 0.

    :param left_top: tensor ``[..., 2]`` of (x, y) top-left corners
    :param right_bottom: tensor ``[..., 2]`` of (x, y) bottom-right corners
    :return: tensor ``[...]`` of areas
    """
    extent = (right_bottom - left_top).clamp(min=0.0)
    width, height = extent[..., 0], extent[..., 1]
    return width * height
36 |
37 |
def iou_of(boxes0, boxes1, eps=1e-5):
    """
    Intersection-over-union of corner-form boxes, with broadcasting.

    :param boxes0: tensor ``[..., 4]`` of (x1, y1, x2, y2) boxes
    :param boxes1: tensor ``[..., 4]`` broadcastable against ``boxes0``
    :param eps: added to the union to avoid division by zero
    :return: tensor of IoU values with the broadcast leading shape
    """
    def _area(top_left, bottom_right):
        # negative extents mean "no overlap" and are clamped to zero
        wh = (bottom_right - top_left).clamp(min=0.0)
        return wh[..., 0] * wh[..., 1]

    # The intersection's top-left is the element-wise max of the two
    # top-left corners, its bottom-right the element-wise min.
    inter_top_left = torch.max(boxes0[..., :2], boxes1[..., :2])
    inter_bottom_right = torch.min(boxes0[..., 2:], boxes1[..., 2:])

    intersection = _area(inter_top_left, inter_bottom_right)
    union = _area(boxes0[..., :2], boxes0[..., 2:]) + _area(boxes1[..., :2], boxes1[..., 2:]) - intersection
    return intersection / (union + eps)
46 |
def in_center(priors, gt):
    """
    Test whether each prior's centre lies strictly inside each box.

    Inputs are broadcast against each other; ``assign_priors`` passes
    priors as ``(num_priors, 1, C)`` and boxes as ``(1, num_gt, 4)``, giving
    a ``(num_priors, num_gt)`` result.

    :param priors: tensor whose ``[..., 0]`` / ``[..., 1]`` are centre x / y
    :param gt: corner-form boxes ``[..., 4]`` as (x1, y1, x2, y2)
    :return: boolean tensor, True where the centre is inside the box
    """
    center_x = priors[..., 0]
    center_y = priors[..., 1]
    inside_x = (center_x > gt[..., 0]) & (center_x < gt[..., 2])
    inside_y = (center_y > gt[..., 1]) & (center_y < gt[..., 3])
    return inside_x & inside_y
51 | '''
52 | def assign_priors(gt_boxes, gt_labels, priors,
53 | iou_threshold):
54 |
55 | #print(gt_boxes,corner_form_priors)
56 | # size: num_priors x num_targets
57 | # prior 包含 gt
58 | s1 = in_center(priors.unsqueeze(1),gt_boxes.unsqueeze(0)) # 直接匹配在中间的
59 |
60 | #s1[torch.sum(s1,dim = 1) > 1] = False # 同时有多个匹配的
61 | not_ignored = torch.ones(s1.size(0),dtype = torch.uint8)
62 | not_ignored[torch.sum(s1,dim = 1) > 1] = 0
63 | best_target_per_prior, best_target_per_prior_index = s1.max(1)
64 |
65 |
66 | labels = gt_labels[best_target_per_prior_index] # (num_priors,1)
67 | labels[best_target_per_prior == False] = 0 # the backgournd id,没有匹配的给背景
68 |
69 | boxes = gt_boxes[best_target_per_prior_index] #num_priors,4
70 | t = 0
71 | center_form_gt_boxes = corner_form_to_center_form(boxes)
72 | for f,scale in zip(feature_maps,scales):
73 | d = torch.min(center_form_gt_boxes[t:t+f*f, 2],center_form_gt_boxes[t:t+f*f, 3])
74 | condition = (d < scale[0]/image_size) | (d > scale[1]/image_size)
75 | labels[t:t+f*f][condition] = 0
76 |
77 | left_gray_scale = [0.9 *scale[0], scale[0]]
78 | right_gray_scale = [scale[1], 1.1 * scale[1]]
79 |
80 | not_ignored[t:t+f*f][(d > left_gray_scale[0]) & (d < left_gray_scale[1])] = 0
81 | not_ignored[t:t+f*f][(d > right_gray_scale[0]) & (d < right_gray_scale[1])] = 0
82 |
83 | t += f*f
84 | return boxes, labels, not_ignored
85 | '''
def assign_priors(gt_boxes, gt_labels, priors,
                  iou_threshold):
    """
    Match ground-truth boxes to priors and build per-prior training targets.

    A ground-truth box is a candidate for a prior only if the prior's centre
    lies strictly inside that box; among candidates the box with the highest
    IoU wins. In addition, every ground-truth box is force-matched to the
    prior with which it has the highest IoU (regardless of the centre test).
    Priors whose matched IoU is below ``iou_threshold`` get label 0
    (background).

    :param gt_boxes: (num_gt, 4) corner-form boxes; BasketDataset passes them
        in normalised image coordinates
    :param gt_labels: (num_gt,) class ids
    :param priors: (num_priors, C) centre-form priors; presumably ordered by
        feature map, one ``f * f`` run per entry of ``config.feature_maps``
    :param iou_threshold: minimum IoU for a prior to keep its match
    :return: ``(boxes, labels, not_ignored)``. ``boxes`` is the matched gt box
        per prior, shape (num_priors, 4). ``labels`` is the class id per prior.
        ``not_ignored`` is a uint8 mask that is 0 for priors whose centre
        lies inside more than one gt box, and for priors whose matched box
        size falls in the gray zone just outside their feature map's scale
        range.
    """
    num_priors = priors.size(0)
    if gt_boxes.numel() == 0:
        # No objects to match (e.g. an image without annotations): every
        # prior is background and none is ignored.
        return (priors.new_zeros((num_priors, 4)),
                gt_labels.new_zeros(num_priors),
                torch.ones(num_priors, dtype=torch.uint8))

    # (num_priors, num_gt): prior centre strictly inside the gt box
    s1 = in_center(priors.unsqueeze(1), gt_boxes.unsqueeze(0))
    corner_form_priors = center_form_to_corner_form(priors)
    ious = iou_of(corner_form_priors.unsqueeze(1), gt_boxes.unsqueeze(0))

    # A prior whose centre lies inside several gt boxes is ambiguous: ignore it.
    not_ignored = torch.ones(s1.size(0), dtype=torch.uint8)
    not_ignored[torch.sum(s1, dim=1) > 1] = 0

    # Best gt per prior, considering only boxes containing the prior's centre.
    best_target_per_prior, best_target_per_prior_index = (ious * s1.float()).max(1)
    # Force every gt box onto its highest-IoU prior; 2 exceeds any IoU, so
    # these matches always survive the threshold below.
    _, best_prior_per_target_index = ious.max(0)
    for target_index, prior_index in enumerate(best_prior_per_target_index):
        best_target_per_prior_index[prior_index] = target_index
    best_target_per_prior.index_fill_(0, best_prior_per_target_index, 2)

    labels = gt_labels[best_target_per_prior_index]  # (num_priors,)
    labels[best_target_per_prior < iou_threshold] = 0  # background id

    boxes = gt_boxes[best_target_per_prior_index]  # (num_priors, 4)

    center_form_gt_boxes = corner_form_to_center_form(boxes)
    t = 0
    for f, scale in zip(feature_maps, scales):
        n = f * f
        # shorter side of the matched gt box, in the units of gt_boxes
        d = torch.min(center_form_gt_boxes[t:t + n, 2], center_form_gt_boxes[t:t + n, 3])
        # config.scales is in pixels; convert to the same normalised units
        # as d, otherwise the gray-zone test below can never fire.
        low = scale[0] / image_size
        high = scale[1] / image_size
        in_left_gray = (d > 0.9 * low) & (d < low)
        in_right_gray = (d > high) & (d < 1.1 * high)
        not_ignored[t:t + n][in_left_gray | in_right_gray] = 0
        t += n
    return boxes, labels, not_ignored
122 |
def hard_negative_mining(loss, labels, neg_pos_ratio):
    """
    Select all positives plus the hardest negatives of each sample.

    For every row, all priors with ``labels > 0`` are kept. Among the
    remaining priors, the ``neg_pos_ratio * num_pos`` with the highest
    ``loss`` are kept as well.

    :param loss: per-prior loss used to rank negatives; rows are samples
        (sorting and counting happen over ``dim=1``). Not modified.
    :param labels: class ids with the same shape as ``loss``; 0 is background
    :param neg_pos_ratio: number of negatives to keep per positive
    :return: boolean mask of the selected priors, same shape as ``labels``
    """
    pos_mask = labels > 0
    num_pos = pos_mask.long().sum(dim=1, keepdim=True)
    num_neg = num_pos * neg_pos_ratio

    # Rank on a copy so positives can be pushed to the end of the ordering
    # without overwriting the caller's loss tensor.
    ranking_loss = loss.clone()
    ranking_loss[pos_mask] = -math.inf
    _, indexes = ranking_loss.sort(dim=1, descending=True)
    _, orders = indexes.sort(dim=1)  # rank of every prior in descending order
    neg_mask = orders < num_neg
    return pos_mask | neg_mask
133 |
134 |
def center_form_to_corner_form(locations):
    """
    Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2).

    :param locations: tensor whose last dimension holds centres
        (``[..., :2]``) followed by sizes (``[..., 2:]``)
    :return: tensor of corners concatenated along the last dimension
    """
    centers = locations[..., :2]
    half_sizes = locations[..., 2:] / 2
    top_left = centers - half_sizes
    bottom_right = centers + half_sizes
    return torch.cat([top_left, bottom_right], locations.dim() - 1)
138 |
139 |
def corner_form_to_center_form(boxes):
    """
    Convert (x1, y1, x2, y2) boxes to (cx, cy, w, h).

    :param boxes: tensor whose last dimension holds the top-left corner
        (``[..., :2]``) followed by the bottom-right corner (``[..., 2:]``)
    :return: tensor of centres and sizes concatenated along the last dimension
    """
    top_left = boxes[..., :2]
    bottom_right = boxes[..., 2:]
    centers = (top_left + bottom_right) / 2
    sizes = bottom_right - top_left
    return torch.cat([centers, sizes], boxes.dim() - 1)
145 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Shared training / prior-layout settings, consumed via ``from config import *``.

num_classes = 3  # includes the background class
num_epochs = 500
image_size = 640  # network input side length in pixels

# One entry per detection branch: feature-map stride and feature-map side
# length (assign_priors walks the priors in runs of f * f per branch).
strides = [4, 4, 8, 8, 16, 32, 32, 32]
feature_maps = [159, 159, 79, 79, 39, 19, 19, 19]

# Upper bound (pixels) of each branch's object-size range.
scale_factors = [15, 20, 40, 70, 110, 250, 400, 560]
# (lower, upper) size range per branch: each range starts where the
# previous one ends, and the first one starts at 10 px.
scales = list(zip([10] + scale_factors[:-1], scale_factors))
8 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-10 12:51:20
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 18:38:22
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import Dataset
12 | import glob
13 | from PIL import Image
14 | import xml.etree.ElementTree as ET
15 | import numpy as np
16 | from priors import Priors
17 | from basket_utils import *
np.set_printoptions(threshold=np.inf)


class BasketDataset(Dataset):
    """
    Basketball / volleyball detection dataset.

    Expects images at ``<img_path>/images/*.jpg`` and one XML annotation per
    image at ``<img_path>/labels/<same stem>.xml``. Each item is matched
    against the model priors so it can be fed directly to the loss.

    :param img_path: dataset root directory
    :param transform: optional callable ``transform(image, boxes, labels)``
        returning ``(image, boxes, labels)``
    """

    def __init__(self, img_path, transform=None):
        # sorted so sample indices are stable across platforms/runs
        self.img_paths = sorted(glob.glob(img_path + "/images/*.jpg"))
        self.labels = [self._label_path_for(path) for path in self.img_paths]
        self.class_names = ("__background__", "basketball", "volleyball")
        prior = Priors()
        self.priors = prior()  # center form
        self.imgW, self.imgH = 640, 640
        self.transform = transform

    @staticmethod
    def _label_path_for(image_path):
        """Map ``<root>/images/<stem>.jpg`` to ``<root>/labels/<stem>.xml``."""
        import os
        # Only the final directory and the extension are rewritten, so
        # "images"/".jpg" elsewhere in the path are left untouched.
        images_dir, file_name = os.path.split(image_path)
        stem = os.path.splitext(file_name)[0]
        return os.path.join(os.path.dirname(images_dir), "labels", stem + ".xml")

    def __getitem__(self, idx):
        """
        Load sample ``idx`` and build its prior-level targets.

        :return: ``[img, locations, labels, not_ignored]`` where the last three
            come from ``assign_priors`` / ``convert_boxes_to_locations``
        """
        img = Image.open(self.img_paths[idx]).convert("RGB")
        gt_bboxes, gt_classes = self._get_annotation(idx)

        if self.transform:
            img, gt_bboxes, gt_classes = self.transform(np.array(img), gt_bboxes, gt_classes)
        # reshape keeps a (0, 4) shape when the image has no boxes
        gt_bboxes = torch.tensor(gt_bboxes).reshape(-1, 4)
        gt_classes = torch.LongTensor(gt_classes)

        gt_bboxes, gt_classes, ignored = assign_priors(gt_bboxes, gt_classes, self.priors, 0.5)
        locations = convert_boxes_to_locations(gt_bboxes, self.priors, 2)

        return [img, locations, gt_classes, ignored]

    def _get_annotation(self, idx):
        """
        Parse the XML annotation of sample ``idx``.

        Box coordinates are shifted by -1 (presumably converting 1-based
        VOC-style pixels to 0-based) and divided by the image width/height
        from the file's ``<size>`` element.

        :return: ``(boxes, labels)``: a list of ``[x1, y1, x2, y2]`` normalised
            boxes and a list of indices into ``self.class_names``
        """
        tree = ET.parse(self.labels[idx])
        objects = tree.findall("object")
        boxes = []
        labels = []
        if not objects:
            return boxes, labels

        size = tree.find("size")
        img_w = float(size.find('width').text)
        img_h = float(size.find('height').text)
        for obj in objects:
            class_name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            boxes.append([x1 / img_w, y1 / img_h, x2 / img_w, y2 / img_h])
            labels.append(self.class_names.index(class_name))
        return boxes, labels

    def __len__(self):
        return len(self.img_paths)
if __name__ == '__main__':
    # Visual check: draw the priors that are positive and not ignored for a
    # random training sample.
    import random
    from augmentor import BasketAug
    import cv2 as cv

    transform = BasketAug(to_tensor=False)
    dataset = BasketDataset("../BasketNet_circle/datasets", transform=transform)
    img, gt_loc, gt_labels, ignored = dataset[random.randrange(len(dataset))]

    # The transform may hand back a C x H x W float tensor in [0, 1]
    # (torchvision ToTensor); convert it to an H x W x C uint8 array.
    if torch.is_tensor(img):
        img = img.mul(255).byte().permute(1, 2, 0).numpy()
    cv_img = cv.cvtColor(np.array(img), cv.COLOR_RGB2BGR)

    # `ignored` is really a "not ignored" mask (1 = keep)
    keep = (gt_labels > 0) & ignored.bool()
    loc = convert_locations_to_boxes(gt_loc, dataset.priors, 2)[keep]
    priors = dataset.priors[keep]
    print(loc.size())
    # NOTE(review): assumes three-column priors (x, y, r); confirm against priors.py
    for x, y, r in priors.tolist():
        x, y, r = x * 640, y * 640, r * 640
        print(x, y, r)
        cv.circle(cv_img, (int(x), int(y)), int(r), (255, 0, 0), 2)
    cv.imshow("cv", cv_img)
    cv.waitKey(0)
94 |
--------------------------------------------------------------------------------
/datasets/images/009814.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/009814.jpg
--------------------------------------------------------------------------------
/datasets/images/045346.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/045346.jpg
--------------------------------------------------------------------------------
/datasets/images/059925.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/059925.jpg
--------------------------------------------------------------------------------
/datasets/images/102954.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/102954.jpg
--------------------------------------------------------------------------------
/datasets/images/121349.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/121349.jpg
--------------------------------------------------------------------------------
/datasets/images/147664.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/147664.jpg
--------------------------------------------------------------------------------
/datasets/images/167843.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/167843.jpg
--------------------------------------------------------------------------------
/datasets/images/201586.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/201586.jpg
--------------------------------------------------------------------------------
/datasets/images/213702.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/213702.jpg
--------------------------------------------------------------------------------
/datasets/images/216026.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/216026.jpg
--------------------------------------------------------------------------------
/datasets/images/250583.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/250583.jpg
--------------------------------------------------------------------------------
/datasets/images/257559.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/257559.jpg
--------------------------------------------------------------------------------
/datasets/images/260655.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/260655.jpg
--------------------------------------------------------------------------------
/datasets/images/273406.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/273406.jpg
--------------------------------------------------------------------------------
/datasets/images/275153.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/275153.jpg
--------------------------------------------------------------------------------
/datasets/images/286036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/286036.jpg
--------------------------------------------------------------------------------
/datasets/images/308634.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/308634.jpg
--------------------------------------------------------------------------------
/datasets/images/325869.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/325869.jpg
--------------------------------------------------------------------------------
/datasets/images/346044.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/346044.jpg
--------------------------------------------------------------------------------
/datasets/images/447131.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/447131.jpg
--------------------------------------------------------------------------------
/datasets/images/492718.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/492718.jpg
--------------------------------------------------------------------------------
/datasets/images/542925.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/542925.jpg
--------------------------------------------------------------------------------
/datasets/images/567307.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/567307.jpg
--------------------------------------------------------------------------------
/datasets/images/574125.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/574125.jpg
--------------------------------------------------------------------------------
/datasets/images/605272.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/605272.jpg
--------------------------------------------------------------------------------
/datasets/images/630439.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/630439.jpg
--------------------------------------------------------------------------------
/datasets/images/645593.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/645593.jpg
--------------------------------------------------------------------------------
/datasets/images/721074.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/721074.jpg
--------------------------------------------------------------------------------
/datasets/images/721205.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/721205.jpg
--------------------------------------------------------------------------------
/datasets/images/771259.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/771259.jpg
--------------------------------------------------------------------------------
/datasets/images/785985.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/785985.jpg
--------------------------------------------------------------------------------
/datasets/images/795439.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/795439.jpg
--------------------------------------------------------------------------------
/datasets/images/826272.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/826272.jpg
--------------------------------------------------------------------------------
/datasets/images/845712.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/845712.jpg
--------------------------------------------------------------------------------
/datasets/images/847258.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/847258.jpg
--------------------------------------------------------------------------------
/datasets/images/858721.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/858721.jpg
--------------------------------------------------------------------------------
/datasets/images/875555.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/875555.jpg
--------------------------------------------------------------------------------
/datasets/images/884419.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/884419.jpg
--------------------------------------------------------------------------------
/datasets/images/896922.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/896922.jpg
--------------------------------------------------------------------------------
/datasets/images/898712.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/898712.jpg
--------------------------------------------------------------------------------
/datasets/images/916700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/916700.jpg
--------------------------------------------------------------------------------
/datasets/images/918904.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/918904.jpg
--------------------------------------------------------------------------------
/datasets/images/945782.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/945782.jpg
--------------------------------------------------------------------------------
/datasets/images/959677.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/959677.jpg
--------------------------------------------------------------------------------
/datasets/images/960859.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/960859.jpg
--------------------------------------------------------------------------------
/datasets/images/968216.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/968216.jpg
--------------------------------------------------------------------------------
/datasets/images/974557.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/974557.jpg
--------------------------------------------------------------------------------
/datasets/images/976494.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/976494.jpg
--------------------------------------------------------------------------------
/datasets/images/998724.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/datasets/images/998724.jpg
--------------------------------------------------------------------------------
/datasets/labels/009814.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 009814.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/009814.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/045346.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 045346.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/045346.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/059925.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 059925.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/059925.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/102954.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 102954.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/102954.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/121349.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 121349.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/121349.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/147664.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 147664.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/147664.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/167843.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 167843.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/167843.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/201586.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 201586.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/201586.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/213702.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 213702.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/213702.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/216026.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 216026.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/216026.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/250583.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 250583.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/250583.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/257559.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 257559.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/257559.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/260655.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 260655.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/260655.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/273406.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 273406.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/273406.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/275153.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 275153.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/275153.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/286036.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 286036.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/286036.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/308634.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 308634.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/308634.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/325869.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 325869.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/325869.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/346044.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 346044.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/346044.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/447131.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 447131.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/447131.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/492718.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 492718.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/492718.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/542925.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 542925.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/542925.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/567307.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 567307.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/567307.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/574125.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 574125.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/574125.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/605272.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 605272.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/605272.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/630439.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 630439.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/630439.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/645593.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 645593.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/645593.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/721074.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 721074.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/721074.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/721205.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 721205.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/721205.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/771259.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 771259.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/771259.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/785985.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 785985.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/785985.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/795439.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 795439.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/795439.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/826272.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 826272.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/826272.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
50 |
51 |
--------------------------------------------------------------------------------
/datasets/labels/845712.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 845712.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/845712.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/847258.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 847258.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/847258.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/858721.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 858721.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/858721.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/875555.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 875555.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/875555.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/884419.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 884419.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/884419.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/896922.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 896922.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/896922.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/898712.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 898712.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/898712.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/916700.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 916700.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/916700.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/918904.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 918904.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/918904.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/945782.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 945782.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/945782.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/959677.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 959677.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/959677.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/960859.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 960859.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/960859.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/968216.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 968216.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/968216.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/974557.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 974557.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/974557.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
38 |
39 |
--------------------------------------------------------------------------------
/datasets/labels/976494.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 976494.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/976494.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/datasets/labels/998724.xml:
--------------------------------------------------------------------------------
1 |
2 | images
3 | 998724.jpg
4 | /media/xueaoru/其他/ML/my-SSD/dataset/images/998724.jpg
5 |
6 | Unknown
7 |
8 |
9 | 512
10 | 512
11 | 3
12 |
13 | 0
14 |
26 |
27 |
--------------------------------------------------------------------------------
/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/demo.png
--------------------------------------------------------------------------------
/lossfn.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-10 00:03:42
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-01 23:59:09
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 | import basket_utils
14 |
class BasketLoss(nn.Module):
    """Localization (MSE) + classification (cross-entropy) loss with hard negative mining.

    Args:
        neg_pos_ratio: Ratio forwarded to ``basket_utils.hard_negative_mining``
            (presumably the number of negatives kept per positive -- confirm
            against basket_utils). Defaults to 10, the previously hard-coded
            value.
    """

    def __init__(self, neg_pos_ratio=10):
        super(BasketLoss, self).__init__()
        self.neg_pos_ratio = neg_pos_ratio

    def forward(self, scores, predicted_locations, labels, gt_locations, not_ignored):
        """Compute the regression and classification losses.

        Args:
            scores: Class logits. Indexed as ``scores[:, :, 0]`` and
                ``scores.size(2)``, so rank 3; presumably
                (batch, num_priors, num_classes) -- TODO confirm.
            predicted_locations: Predicted box regressions; last dim is 4.
            labels: Integer class targets with the same leading shape as
                ``scores``. Class 0 is background; positives are ``labels > 0``.
            gt_locations: Regression targets, same shape as
                ``predicted_locations``.
            not_ignored: Boolean mask with the same leading shape as
                ``labels``; ``False`` entries are excluded from the
                classification loss.

        Returns:
            Tuple ``(mse_loss, classification_loss)``. Both are sums normalized
            by the number of positive priors. The normalizer is clamped to 1 so
            a batch with no positives yields finite losses instead of NaN.
        """
        num_classes = scores.size(2)
        with torch.no_grad():
            # Per-prior background loss; only used to rank negatives for
            # mining, so no gradient is needed.
            background_loss = -F.log_softmax(scores, dim=2)[:, :, 0]
            mask = basket_utils.hard_negative_mining(
                background_loss, labels, self.neg_pos_ratio) & not_ignored

        confidence = scores[mask, :]
        classification_loss = F.cross_entropy(
            confidence.view(-1, num_classes), labels[mask], reduction='sum')

        pos_mask = labels > 0
        predicted_locations = predicted_locations[pos_mask, :].view(-1, 4)
        gt_locations = gt_locations[pos_mask, :].view(-1, 4)
        mse_loss = F.mse_loss(predicted_locations, gt_locations, reduction='sum')

        # Clamp to 1: with zero positives the sums are 0 and 0 / 0 would be NaN.
        num_pos = max(gt_locations.size(0), 1)
        return mse_loss / num_pos, classification_loss / num_pos
35 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-09 23:13:46
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 18:39:48
8 | '''
9 |
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from priors import Priors
15 | from basket_utils import *
16 | class SeparableConv2d(nn.Module):
17 | def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1):
18 | super(SeparableConv2d, self).__init__()
19 | self.conv = nn.Sequential(
20 | nn.Conv2d(in_channels, in_channels, kernel_size = kernel_size, stride=stride, groups=in_channels,padding=padding),
21 | nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride=(1, 1))
22 | )
23 | def forward(self,x):
24 | return self.conv(x)
25 | class ResBlock(nn.Module):
26 | def __init__(self,channels):
27 | super(ResBlock, self).__init__()
28 | self.conv2dRelu = nn.Sequential(
29 | SeparableConv2d(channels,channels,kernel_size=3,stride=1,padding=1),
30 | nn.ReLU6(channels),
31 | SeparableConv2d(channels,channels,kernel_size=3,stride=1,padding=1),
32 | nn.ReLU6(channels)
33 | )
34 | self.relu = nn.ReLU6(channels)
35 | def forward(self,x):
36 | return self.relu(x + self.conv2dRelu(x))
37 | class BasketLossBranch(nn.Module):
38 | def __init__(self,in_channels,out_channels=64,num_classes=3):
39 | super(BasketLossBranch, self).__init__()
40 | self.conv1x1relu = nn.Sequential(
41 | nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0),
42 | nn.ReLU6(out_channels)
43 | )
44 | self.score =nn.Sequential(
45 | nn.Conv2d(out_channels,out_channels,kernel_size=1,stride=1,padding=0),
46 | nn.ReLU6(out_channels),
47 | nn.Conv2d(out_channels,num_classes,kernel_size=1,stride=1)
48 | )
49 | self.locations =nn.Sequential(
50 | nn.Conv2d(out_channels,out_channels,kernel_size=1,stride=1,padding=0),
51 | nn.ReLU6(out_channels),
52 | nn.Conv2d(out_channels,4,kernel_size=1,stride=1)
53 | )
54 | def forward(self,x):
55 | score = self.score(self.conv1x1relu(x))
56 | locations = self.locations(self.conv1x1relu(x))
57 | return score,locations
58 | class BasketNet(nn.Module):
59 | def __init__(self,num_classes = 3):
60 | super(BasketNet, self).__init__()
61 | self.num_classes = num_classes
62 | self.priors = None
63 | self.c1 = nn.Sequential(
64 | SeparableConv2d(3,64,kernel_size=3,stride=2,padding=0),
65 | nn.ReLU6(64)
66 | )
67 | self.c2 = nn.Sequential(
68 | SeparableConv2d(64,64,kernel_size=3,stride=2,padding=0),
69 | nn.ReLU6(64)
70 | )
71 | self.tinypart1 = nn.Sequential(
72 | ResBlock(64),
73 | ResBlock(64),
74 | ResBlock(64)
75 | )
76 | self.tinypart2 = ResBlock(64)
77 | self.c11 = nn.Sequential(
78 | nn.Conv2d(64,64,kernel_size=3,stride=2,padding=0),
79 | nn.ReLU6(64)
80 | )
81 | self.smallpart1 = ResBlock(64)
82 | self.smallpart2 = ResBlock(64)
83 | self.c16 = nn.Sequential(
84 | nn.Conv2d(64,128,kernel_size=3,stride=2,padding=0),
85 | nn.ReLU6(128)
86 | )
87 | self.mediumpart = ResBlock(128)
88 | self.c19 = nn.Sequential(
89 | nn.Conv2d(128,128,kernel_size=3,stride=2,padding=0),
90 | nn.ReLU6(128)
91 | )
92 | self.largepart1 = ResBlock(128)
93 | self.largepart2 = ResBlock(128)
94 | self.largepart3 = ResBlock(128)
95 |
96 |
97 | self.lossbranch1 = BasketLossBranch(64,num_classes = self.num_classes)
98 | self.lossbranch2 = BasketLossBranch(64,num_classes = self.num_classes)
99 | self.lossbranch3 = BasketLossBranch(64,num_classes = self.num_classes)
100 | self.lossbranch4 = BasketLossBranch(64,num_classes = self.num_classes)
101 | self.lossbranch5 = BasketLossBranch(128,num_classes = self.num_classes)
102 | self.lossbranch6 = BasketLossBranch(128,num_classes = self.num_classes)
103 | self.lossbranch7 = BasketLossBranch(128,num_classes = self.num_classes)
104 | self.lossbranch8 = BasketLossBranch(128,num_classes = self.num_classes)
105 | def forward(self, x):
106 | c1 = self.c1(x)
107 | c2 = self.c2(c1)
108 |
109 | c8 = self.tinypart1(c2)
110 | c10 = self.tinypart2(c8)
111 |
112 | c11 = self.c11(c10)
113 | c13 = self.smallpart1(c11)
114 | c15 = self.smallpart2(c13)
115 |
116 | c16 = self.c16(c15)
117 | c18 = self.mediumpart(c16)
118 |
119 | c19 = self.c19(c18)
120 | c21 = self.largepart1(c19)
121 | c23 = self.largepart2(c21)
122 | c25 = self.largepart3(c23)
123 |
124 | score1,loc1 = self.lossbranch1(c8)
125 | score2,loc2 = self.lossbranch2(c10)
126 | score3,loc3 = self.lossbranch3(c13)
127 | score4,loc4 = self.lossbranch4(c15)
128 | score5,loc5 = self.lossbranch5(c18)
129 | #print(loc1.size(),loc2.size(),loc3.size(),loc4.size(),loc5.size())
130 | score6,loc6 = self.lossbranch6(c21)
131 | score7,loc7 = self.lossbranch7(c23)
132 | score8,loc8 = self.lossbranch8(c25)
133 |
134 | cls = torch.cat([score1.permute(0, 2, 3, 1).contiguous().view(score1.size(0),-1, self.num_classes),
135 | score2.permute(0, 2, 3, 1).contiguous().view(score2.size(0),-1, self.num_classes),
136 | score3.permute(0, 2, 3, 1).contiguous().view(score3.size(0), -1, self.num_classes),
137 | score4.permute(0, 2, 3, 1).contiguous().view(score4.size(0), -1, self.num_classes),
138 | score5.permute(0, 2, 3, 1).contiguous().view(score5.size(0), -1, self.num_classes),
139 | score6.permute(0, 2, 3, 1).contiguous().view(score6.size(0), -1, self.num_classes),
140 | score7.permute(0, 2, 3, 1).contiguous().view(score7.size(0), -1, self.num_classes),
141 | score8.permute(0, 2, 3, 1).contiguous().view(score8.size(0), -1, self.num_classes)], dim=1)
142 | loc = torch.cat([loc1.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
143 | loc2.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
144 | loc3.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
145 | loc4.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
146 | loc5.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
147 | loc6.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
148 | loc7.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4),
149 | loc8.permute(0, 2, 3, 1).contiguous().view(loc1.size(0), -1,4)], dim=1)
150 |
151 | if not self.training:
152 | if self.priors is None:
153 | self.priors = Priors()() # center form
154 | #self.priors = self.priors.cuda()
155 | boxes = convert_locations_to_boxes(
156 | loc, self.priors, 2
157 | )# corner_form
158 | cls = F.softmax(cls, dim=2)
159 | return cls, boxes
160 | else:
161 | #print(confidences.size(),locations.size())
162 | return (cls,loc) # (2,1111,3) (2,1111,4)
163 |
164 |
165 | from torchsummary import summary
166 | if __name__ == '__main__':
167 | model = BasketNet()
168 | summary(model,(3,640,640),device = "cpu")
169 |
--------------------------------------------------------------------------------
/nms.py:
--------------------------------------------------------------------------------
1 | try:
2 | import torch_extension
3 |
4 | _nms = torch_extension.nms
5 | except ImportError:
6 | from python_nms import python_nms
7 |
8 | _nms = python_nms
9 |
10 |
11 | def boxes_nms(boxes, scores, nms_thresh, max_count=-1):
12 | """ Performs non-maximum suppression, run on GPU or CPU according to
13 | boxes's device.
14 | Args:
15 | boxes(Tensor): `xyxy` mode boxes, use absolute coordinates(not support relative coordinates),
16 | shape is (n, 4)
17 | scores(Tensor): scores, shape is (n, )
18 | nms_thresh(float): thresh
19 | max_count (int): if > 0, then only the top max_proposals are kept after non-maximum suppression
20 | Returns:
21 | indices kept.
22 | """
23 | keep = _nms(boxes, scores, nms_thresh)
24 | if max_count > 0:
25 | keep = keep[:max_count]
26 | return keep
27 |
--------------------------------------------------------------------------------
/post_processer.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-13 21:06:46
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-03 00:23:31
8 | '''
9 | import torch
10 |
11 | from nms import boxes_nms
12 |
13 |
14 | class PostProcessor:
15 | def __init__(self,
16 | iou_threshold = 0.10,
17 | score_threshold = 0.8,
18 | image_size = 640,
19 | max_per_class=200,
20 | max_per_image=-1):
21 | self.confidence_threshold = score_threshold
22 | self.iou_threshold = iou_threshold
23 | self.width = image_size
24 | self.height = image_size
25 | self.max_per_class = max_per_class
26 | self.max_per_image = max_per_image
27 |
28 | def __call__(self, confidences, locations, width=None, height=None, batch_ids=None):
29 | if width is None:
30 | width = self.width
31 | if height is None:
32 | height = self.height
33 |
34 | batch_size = confidences.size(0)
35 | if batch_ids is None:
36 | batch_ids = torch.arange(batch_size, device=confidences.device)
37 | else:
38 | batch_ids = torch.tensor(batch_ids, device=confidences.device)
39 |
40 | locations = locations[batch_ids]
41 | confidences = confidences[batch_ids]
42 |
43 | results = []
44 | for decoded_boxes, scores in zip(locations, confidences):
45 | # per batch
46 | filtered_boxes = []
47 | filtered_labels = []
48 | filtered_probs = []
49 | for class_index in range(1, scores.size(1)):
50 | probs = scores[:, class_index]
51 | mask = probs > self.confidence_threshold
52 | probs = probs[mask]
53 | if probs.size(0) == 0:
54 | continue
55 | boxes = decoded_boxes[mask, :] # x1,y1,x2,y2
56 | #ratio = (width + height) / 2.
57 | boxes[:, 0] *= width
58 | boxes[:, 2] *= width
59 | boxes[:, 1] *= height
60 | boxes[:, 3] *= height
61 |
62 | keep = boxes_nms(boxes, probs, self.iou_threshold, self.max_per_class)
63 |
64 | boxes = boxes[keep, :]
65 | labels = torch.tensor([class_index] * keep.size(0))
66 | probs = probs[keep]
67 |
68 | filtered_boxes.append(boxes)
69 | filtered_labels.append(labels)
70 | filtered_probs.append(probs)
71 |
72 | # no object detected
73 | if len(filtered_boxes) == 0:
74 | filtered_boxes = torch.empty(0, 4)
75 | filtered_labels = torch.empty(0)
76 | filtered_probs = torch.empty(0)
77 | else: # cat all result
78 | filtered_boxes = torch.cat(filtered_boxes, 0)
79 | filtered_labels = torch.cat(filtered_labels, 0)
80 | filtered_probs = torch.cat(filtered_probs, 0)
81 | if 0 < self.max_per_image < filtered_probs.size(0):
82 | keep = torch.argsort(filtered_probs, dim=0, descending=True)[:self.max_per_image]
83 | filtered_boxes = filtered_boxes[keep, :]
84 | filtered_labels = filtered_labels[keep]
85 | filtered_probs = filtered_probs[keep]
86 | results.append((filtered_boxes, filtered_labels, filtered_probs))
87 | return results
88 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-02 21:08:56
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-03 00:22:04
8 | '''
9 | import torch
10 | import torchvision
11 | from model import BasketNet
12 | from torchvision import transforms
13 | #from transforms import *
14 | from PIL import Image
15 | import numpy as np
16 | from viz import draw_bounding_boxes,draw_circles
17 | from post_processer import PostProcessor
18 | from priors import Priors
19 |
20 | prior = Priors()
21 | center_form_priors = prior()
22 |
23 | post_process = PostProcessor()
24 | color_map= {
25 | 1: (0,0,255),
26 | 2: (0,255,255)
27 | }
28 | transform = transforms.Compose([
29 | transforms.Resize((640,640)),
30 | transforms.ToTensor(),
31 | #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32 | ]
33 | )
34 |
35 |
36 | def pic_test():
37 | img = Image.open("./test2.jpg").convert('RGB')
38 | image = np.array(img,dtype = np.float32)
39 | height, width, _ = image.shape
40 | img = transform(img)
41 | img = img.unsqueeze(0)
42 | #img = img.cuda()
43 | net = BasketNet()
44 |
45 | net.load_state_dict(torch.load("./ckpt/1748.pth",map_location="cpu"))
46 | #net.cuda()
47 | print("network load...")
48 | net.eval()
49 | with torch.no_grad():
50 | pred_confidence,boxes = net(img)
51 |
52 |
53 | output = post_process(pred_confidence,boxes, width=width, height=height)[0]
54 | #print(output)
55 | boxes, labels, scores = [o.to("cpu").numpy() for o in output]
56 |
57 |
58 | drawn_image = draw_bounding_boxes(image, boxes, labels, scores, ("__background__","basketball","volleyball")).astype(np.uint8)
59 |
60 | Image.fromarray(drawn_image).save("./a.jpg")
61 |
62 | def cap_test():
63 | import cv2 as cv
64 | cap = cv.VideoCapture("../test.mp4")
65 | net = BasketNet()
66 |
67 | net.load_state_dict(torch.load("./ckpt/1748.pth",map_location="cpu"))
68 | #net.cuda()
69 | net.eval()
70 | while True:
71 | ret,frame = cap.read()
72 | if not ret:
73 | break
74 |
75 | height,width,_ = frame.shape
76 | center_width = width//2
77 | frame = frame[:,center_width-height//2:center_width + height//2]
78 | height,width,_ = frame.shape
79 | cv_img = cv.cvtColor(frame,cv.COLOR_BGR2RGB)
80 | img = Image.fromarray(cv_img)
81 | img = transform(img)
82 | img = img.unsqueeze(0)
83 | #img = img.cuda()
84 |
85 | with torch.no_grad():
86 | pred_confidence,boxes = net(img)
87 |
88 | output = post_process(pred_confidence,boxes, width=width, height=height)[0]
89 | boxes, labels, scores = [o.to("cpu").numpy() for o in output]
90 | drawn_image = draw_bounding_boxes(frame, boxes, labels, scores, ("__background__","basketball","volleyball")).astype(np.uint8)
91 |
92 | cv.imshow("img",drawn_image)
93 | key = cv.waitKey(1)
94 | if key == ord("q"):
95 | break
96 | cv.destroyAllWindows()
97 | cap.release()
98 | if __name__ == "__main__":
99 | pic_test()
100 |
--------------------------------------------------------------------------------
/priors.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-10 01:02:36
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 18:36:33
8 | '''
9 | from itertools import product
10 |
11 | import torch
12 | from math import sqrt
13 | from config import *
14 |
15 | class Priors:
16 | def __init__(self,clip = True):
17 | self.image_size = image_size
18 | self.strides = strides
19 | self.feature_maps = feature_maps
20 | self.clip = clip
21 | def __call__(self):
22 | priors = []
23 | for k, f in enumerate(self.feature_maps):
24 | # 513/4 = 128.25
25 | # 126/128.25 = 0.98245
26 | # 126.5/128.25 = 0.98635
27 | scale = self.image_size / self.strides[k]
28 | for i, j in product(range(f), repeat=2):
29 |
30 | cx = (j + 0.5) / scale
31 | cy = (i + 0.5) / scale
32 | scale_factor = scale_factors[k] * 0.5 / self.image_size
33 | #r = self.sizes[k]
34 | #r = r/self.image_size
35 | #h = w = self.sizes[k] / self.image_size
36 | priors.append([cx, cy, scale_factor])
37 |
38 | priors = torch.tensor(priors) #(num_priors,4)
39 | if self.clip:
40 | priors.clamp_(max=1, min=0)
41 | return priors
42 | def get_priors():
43 | return Priors()
44 | if __name__ == "__main__":
45 | priors = Priors()
46 | print(priors().size())
47 |
--------------------------------------------------------------------------------
/python_nms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def python_nms(boxes, scores, nms_thresh, max_count=-1):
6 | """ Performs non-maximum suppression using numpy
7 | Args:
8 | boxes(Tensor): `xyxy` mode boxes, use absolute coordinates(not support relative coordinates),
9 | shape is (n, 4)
10 | scores(Tensor): scores, shape is (n, )
11 | nms_thresh(float): thresh
12 | max_count (int): if > 0, then only the top max_proposals are kept after non-maximum suppression
13 | Returns:
14 | indices kept.
15 | """
16 | if boxes.numel() == 0:
17 | return torch.empty((0,), dtype=torch.long)
18 | # Use numpy to run nms. Running nms in PyTorch code on CPU is really slow.
19 | origin_device = boxes.device
20 | cpu_device = torch.device('cpu')
21 | boxes = boxes.to(cpu_device).numpy()
22 | scores = scores.to(cpu_device).numpy()
23 |
24 | x1 = boxes[:, 0]
25 | y1 = boxes[:, 1]
26 | x2 = boxes[:, 2]
27 | y2 = boxes[:, 3]
28 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
29 | order = np.argsort(scores)[::-1]
30 | num_detections = boxes.shape[0]
31 | suppressed = np.zeros((num_detections,), dtype=np.bool)
32 | for _i in range(num_detections):
33 | i = order[_i]
34 | if suppressed[i]:
35 | continue
36 | ix1 = x1[i]
37 | iy1 = y1[i]
38 | ix2 = x2[i]
39 | iy2 = y2[i]
40 | iarea = areas[i]
41 |
42 | for _j in range(_i + 1, num_detections):
43 | j = order[_j]
44 | if suppressed[j]:
45 | continue
46 |
47 | xx1 = max(ix1, x1[j])
48 | yy1 = max(iy1, y1[j])
49 | xx2 = min(ix2, x2[j])
50 | yy2 = min(iy2, y2[j])
51 | w = max(0, xx2 - xx1 + 1)
52 | h = max(0, yy2 - yy1 + 1)
53 |
54 | inter = w * h
55 | ovr = inter / (iarea + areas[j] - inter)
56 | if ovr >= nms_thresh:
57 | suppressed[j] = True
58 | keep = np.nonzero(suppressed == 0)[0]
59 | if max_count > 0:
60 | keep = keep[:max_count]
61 | keep = torch.from_numpy(keep).to(origin_device)
62 | return keep
63 |
--------------------------------------------------------------------------------
/rf_erf_visualize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aoru45/LFFD-Pytorch/b76596b7af32e2fddad78275ccd90c09ac8e20bb/rf_erf_visualize.png
--------------------------------------------------------------------------------
/split_pic.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-10 18:22:23
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-09-10 18:29:11
8 | '''
9 | import cv2 as cv
10 | import random
11 |
12 | if __name__ == "__main__":
13 | video_path = "/home/xueaoru/视频/VID_20180815_142126.mp4"
14 | cap = cv.VideoCapture(video_path)
15 | idx = 0
16 | while True:
17 | ret,frame = cap.read()
18 | if not ret:
19 | break
20 | idx +=1
21 | if idx % 30 ==0:
22 | cv.imwrite("./images/{:0>6d}.jpg".format(random.randint(0,999999)),frame)
23 | cv.destroyAllWindows()
24 |
25 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-10 13:56:50
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 20:14:08
8 | '''
9 | import torch
10 | from torch.utils.data import DataLoader
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from model import BasketNet
14 | from torchvision import transforms
15 | from dataset import BasketDataset
16 | from lossfn import BasketLoss
17 | from tqdm import tqdm
18 | from augmentor import BasketAug
19 | from config import *
20 |
21 | if __name__ == '__main__':
22 | transform = BasketAug()
23 | #transform = transforms.Compose([
24 | # transforms.Resize(512),
25 | # transforms.ToTensor(),
26 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
27 | #])
28 | dataset = BasketDataset("../BasketNet_circle/datasets",transform = transform)
29 | dataloader = DataLoader(dataset,batch_size = 16,shuffle = True,num_workers = 6,pin_memory=True)
30 |
31 | net = BasketNet(num_classes = num_classes)
32 | net.cuda()
33 | #net.load_state_dict(torch.load("./ckpt/3090.pth"))
34 | optimizer = optim.Adam(net.parameters(),lr = 1e-3)
35 | #optimizer = optim.SGD(net.parameters(), lr=1e-1,momentum=0.9,weight_decay=0.)
36 | #scheduler = optim.lr_scheduler.StepLR(optimizer,100,0.1)
37 | loss_fn = BasketLoss()
38 | num_batches = len(dataloader)
39 | min_loss = float("inf")
40 | for epoch in range(num_epochs):
41 | epoch_loss_cls = 0.
42 | epoch_loss_reg = 0.
43 |
44 | for img,gt_pos,gt_labels,not_ignored in tqdm(dataloader):
45 | img = img.cuda()
46 | gt_pos =gt_pos.cuda()
47 | not_ignored = not_ignored.cuda()
48 | gt_labels =gt_labels.cuda()
49 | cls,loc = net(img)
50 | reg_loss,cls_loss = loss_fn(cls,loc,gt_labels,gt_pos,not_ignored)
51 | epoch_loss_cls += cls_loss.item()
52 | epoch_loss_reg += reg_loss.item()
53 | loss = (reg_loss + cls_loss)
54 | optimizer.zero_grad()
55 | loss.backward()
56 | optimizer.step()
57 | #scheduler.step()
58 | print("cls_loss:{}---reg_loss:{}".format(epoch_loss_cls/num_batches,epoch_loss_reg/num_batches))
59 | if (epoch_loss_cls/num_batches + epoch_loss_reg/num_batches) < min_loss:
60 | min_loss = epoch_loss_cls/num_batches + epoch_loss_reg/num_batches
61 | torch.save(net.state_dict(),"./ckpt/{}.pth".format(int(epoch_loss_cls/num_batches * 1000 + epoch_loss_reg/num_batches * 1000)))
62 |
--------------------------------------------------------------------------------
/visualize_demo.py:
--------------------------------------------------------------------------------
1 | import cv2 as cv
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | from torchvision.models import ResNet
6 | import numpy as np
7 | class ResBlock(nn.Module):
8 | def __init__(self,channels):
9 | super(ResBlock, self).__init__()
10 | self.conv2dRelu = nn.Sequential(
11 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1),
12 | nn.ReLU(channels),
13 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1),
14 | nn.ReLU(channels)
15 | )
16 | self.relu = nn.ReLU(channels)
17 | def forward(self,x):
18 | return self.relu(x + self.conv2dRelu(x))
19 | class TinyNet(nn.Module):
20 | def __init__(self):
21 | super(TinyNet,self).__init__()
22 | self.c1 = nn.Sequential(
23 | nn.Conv2d(3,64,kernel_size=3,stride=2,padding=0),
24 | nn.ReLU(64)
25 | )
26 | self.c2 = nn.Sequential(
27 | nn.Conv2d(64,64,kernel_size=3,stride=2,padding=0),
28 | nn.ReLU(64)
29 | )
30 | self.tinypart1 = nn.Sequential(
31 | ResBlock(64),
32 | ResBlock(64),
33 | ResBlock(64)
34 | )
35 | def forward(self,x):
36 | c1 = self.c1(x)
37 | c2 = self.c2(c1)
38 | c8 = self.tinypart1(c2)
39 | return c8
40 | if __name__ == "__main__":
41 | model = TinyNet()
42 | for module in model.modules():
43 | try:
44 | nn.init.constant_(module.weight, 0.05)
45 | nn.init.zeros_(module.bias)
46 | nn.init.zeros_(module.running_mean)
47 | nn.init.ones_(module.running_var)
48 | except Exception as e:
49 | pass
50 | if type(module) is nn.BatchNorm2d:
51 | module.eval()
52 | x = torch.ones(1,3,640,640,requires_grad= True)
53 | pred = model(x)
54 | grad = torch.zeros_like(pred, requires_grad= True)
55 | grad[0, 0, 64, 64] = 1
56 | pred.backward(gradient = grad)
57 | grad_input = x.grad[0,0,...].data.numpy()
58 | grad_input = grad_input / np.max(grad_input)
59 | # 有效感受野 0.75 - 0.85
60 | #grad_input = np.where(grad_input>0.85,1,0)
61 | grad_input = np.where(grad_input>0.75,1,0)
62 | # 注释掉即为感受野
63 | grad_input = (grad_input * 255).astype(np.uint8)
64 | kernel = np.ones((5, 5), np.uint8)
65 | grad_input = cv.dilate(grad_input, kernel=kernel)
66 |
67 | contours, _ = cv.findContours(grad_input, mode=cv.RETR_EXTERNAL, method=cv.CHAIN_APPROX_SIMPLE)
68 | rect = cv.boundingRect(contours[0])
69 | print(rect[-2:])
70 | cv.imshow( "a",grad_input)
71 | cv.waitKey(0)
--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Descripttion: This is Aoru Xue's demo,which is only for reference
3 | @version:
4 | @Author: Aoru Xue
5 | @Date: 2019-09-13 21:06:56
6 | @LastEditors: Aoru Xue
7 | @LastEditTime: 2019-10-02 17:00:40
8 | '''
9 | import numpy as np
10 | from six.moves import range
11 | import PIL.Image as Image
12 | import PIL.ImageDraw as ImageDraw
13 | import PIL.ImageFont as ImageFont
14 |
15 | STANDARD_COLORS = [
16 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
17 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
18 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
19 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
20 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
21 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
22 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
23 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
24 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
25 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
26 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
27 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
28 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
29 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
30 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
31 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
32 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
33 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
34 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
35 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
36 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
37 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
38 | 'WhiteSmoke', 'Yellow', 'YellowGreen'
39 | ]
40 |
41 | NUM_COLORS = len(STANDARD_COLORS)
42 |
43 | try:
44 | FONT = ImageFont.truetype('arial.ttf', 24)
45 | except IOError:
46 | FONT = ImageFont.load_default()
47 |
48 |
49 | def _draw_single_box(image, xmin, ymin, xmax, ymax, color='black', display_str=None, font=None, thickness=2):
50 | draw = ImageDraw.Draw(image)
51 | left, right, top, bottom = xmin, xmax, ymin, ymax
52 | draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], width=thickness, fill=color)
53 | if display_str is not None:
54 | text_bottom = bottom
55 | # Reverse list and print from bottom to top.
56 | text_width, text_height = font.getsize(display_str)
57 | margin = np.ceil(0.05 * text_height)
58 | draw.rectangle(
59 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, text_bottom)], fill=color)
60 | draw.text((left + margin, text_bottom - text_height - margin),
61 | display_str,
62 | fill='black',
63 | font=font)
64 |
65 | return image
66 |
67 | def _draw_single_circle(image, xmin, ymin, xmax, ymax, color='black', display_str=None, font=None, thickness=2):
68 | draw = ImageDraw.Draw(image)
69 | left, right, top, bottom = xmin, xmax, ymin, ymax
70 | draw.ellipse([(left, top), (right, bottom)], width=thickness)
71 | if display_str is not None:
72 | text_bottom = bottom
73 | text_width, text_height = font.getsize(display_str)
74 | margin = np.ceil(0.05 * text_height)
75 | draw.text((left + margin, text_bottom - text_height - margin),
76 | display_str,
77 | fill='black',
78 | font=font)
79 | # text_bottom = bottom
80 | # Reverse list and print from bottom to top.
81 | # text_width, text_height = font.getsize(display_str)
82 | # margin = np.ceil(0.05 * text_height)
83 | #draw.rectangle(
84 | # [(left, text_bottom - text_height - 2 * margin), (left + text_width, text_bottom)], fill=color)
85 | #draw.text((left + margin, text_bottom - text_height - margin),
86 | # display_str,
87 | # fill='black',
88 | # font=font)
89 |
90 | return image
91 | def draw_circles(image, circles, labels=None, probs=None, class_name_map=None):
92 | num_circles = circles.shape[0]
93 | gt_boxes_new = circles.copy()
94 | draw_image = Image.fromarray(np.uint8(image))
95 | for i in range(num_circles):
96 | display_str = None
97 | this_class = 0
98 | if labels is not None:
99 | this_class = labels[i]
100 | class_name = class_name_map[this_class] if class_name_map is not None else str(this_class)
101 | class_name = class_name.decode('utf-8') if isinstance(class_name, bytes) else class_name
102 | if probs is not None:
103 | prob = probs[i]
104 | display_str = '{}:{:.2f}'.format(class_name, prob)
105 | else:
106 | display_str = class_name
107 | draw_image = _draw_single_circle(image=draw_image,
108 | xmin=gt_boxes_new[i, 0] - gt_boxes_new[i, 2],
109 | ymin=gt_boxes_new[i, 1] - gt_boxes_new[i, 2],
110 | xmax=gt_boxes_new[i, 0] + gt_boxes_new[i, 2],
111 | ymax=gt_boxes_new[i, 1] + gt_boxes_new[i, 2],
112 | color=STANDARD_COLORS[this_class % NUM_COLORS],
113 | display_str=display_str,
114 | font=FONT)
115 |
116 | image = np.array(draw_image, dtype=np.float32)
117 | return image
118 |
119 |
120 | def draw_bounding_boxes(image, boxes, labels=None, probs=None, class_name_map=None):
121 | """Draw bboxes(labels, probs) on image
122 | Args:
123 | image: numpy array image, shape should be (height, width, channel)
124 | boxes: bboxes, shape should be (N, 4), and each row is (xmin, ymin, xmax, ymax)
125 | labels: labels, shape: (N, )
126 | probs: label scores, shape: (N, ), can be False/True or None
127 | class_name_map: list or dict, map class id to class name for visualization.
128 | can be False/True or None
129 | Returns:
130 | An image with information drawn on it.
131 | """
132 | num_boxes = boxes.shape[0]
133 | gt_boxes_new = boxes.copy()
134 | draw_image = Image.fromarray(np.uint8(image))
135 | for i in range(num_boxes):
136 | display_str = None
137 | this_class = 0
138 | if labels is not None:
139 | this_class = labels[i]
140 | class_name = class_name_map[this_class] if class_name_map is not None else str(this_class)
141 | class_name = class_name.decode('utf-8') if isinstance(class_name, bytes) else class_name
142 | if probs is not None:
143 | prob = probs[i]
144 | display_str = '{}:{:.2f}'.format(class_name, prob)
145 | else:
146 | display_str = class_name
147 | draw_image = _draw_single_box(image=draw_image,
148 | xmin=gt_boxes_new[i, 0],
149 | ymin=gt_boxes_new[i, 1],
150 | xmax=gt_boxes_new[i, 2],
151 | ymax=gt_boxes_new[i, 3],
152 | color=STANDARD_COLORS[this_class % NUM_COLORS],
153 | display_str=display_str,
154 | font=FONT)
155 |
156 | image = np.array(draw_image, dtype=np.float32)
157 | return image
158 |
--------------------------------------------------------------------------------