├── .gitignore ├── README.md ├── Yolov1_demo ├── test1.png ├── test10.png ├── test11.png ├── test12.png ├── test13.png ├── test14.png ├── test15.png ├── test16.png ├── test17.png ├── test18.png ├── test19.png ├── test2.png ├── test20.png ├── test21.png ├── test3.png ├── test4.png ├── test5.png ├── test6.png ├── test7.png ├── test8.png └── test9.png └── src ├── __init__.py ├── augmentations.py ├── config.py ├── dataset.py ├── detect.py ├── lr_scheduler.py ├── model.py ├── predict.py ├── test.py ├── train.py ├── utils.py └── yolov1loss.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | __pycache__/ 4 | *.pkl 5 | test_img/ 6 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # YOLOv1-Pytorch Implementation 2 | 3 | 4 | 5 | ## Details 6 | 7 | You can see more details about YOLO by clicking the link below: 8 | 9 | 10 | 11 | 12 | 13 | ## Pretrained Model 14 | 15 | You can download pretrained ResNet-backbone YOLOv1 weights here: 16 | 17 | https://pan.baidu.com/s/1YnPqOepzAbr9T_z4Ux-Ocg 18 | 19 | 20 | 21 | ## Demo 22 | 23 | ![](./Yolov1_demo/test1.png) 24 | 25 | ![](./Yolov1_demo/test2.png) 26 | 27 | ![](./Yolov1_demo/test3.png) 28 | 29 | ![](./Yolov1_demo/test4.png) 30 | 31 | ![](./Yolov1_demo/test37.png) 32 | 33 | ![](./Yolov1_demo/test6.png) 34 | 35 | ![](./Yolov1_demo/test10.png) 36 | 37 | ![](./Yolov1_demo/test7.png) 38 | 39 | 40 | 41 | ## Features 42 | 43 | * Auto-save and auto-load mechanism; the default directory is './model' 44 | * The backbone can be chosen from the pre-trained ResNet18, ResNet50 and ResNet101 45 | 46 | 47 | 48 | ## Setup 49 | 50 | The dataset directory should be organized like this: 51 | 52 | ```` 53 | base_dir 54 | VOC2007 55 | Annotations 56 | ImageSets 57 | JPEGImages 58 | SegmentationClass 59 | SegmentationObject 60 | VOC2012 61 | Annotations 62 | ImageSets 63 | JPEGImages 64 | SegmentationClass 65 | SegmentationObject 66 | ```` 67 | 68 | You can set the base_dir like this: 69 | 70 | ``` 71 | python train.py --voc_data_set_root /base_dir 72 | ``` 73 | 74 | 75 | 76 | ## Usage 77 | 78 | * You can train the model like this: 79 | 80 | * ``` 81 | python train.py --voc_data_set_root /media/charles/750GB/VOC0712trainval --num_workers 12 --batch_size 16 --backbone resnet50 --save_step 500 82 | ``` 83 | 84 | * Set the image directory in predict.py and run it to see the demo 85 | 86 | * Model weights will be uploaded later 87 | 88 | 89 | 90 | --------------------------------------------------------------------------------
/Yolov1_demo/test1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test1.png --------------------------------------------------------------------------------
/Yolov1_demo/test10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test10.png --------------------------------------------------------------------------------
/Yolov1_demo/test11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test11.png
-------------------------------------------------------------------------------- /Yolov1_demo/test12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test12.png -------------------------------------------------------------------------------- /Yolov1_demo/test13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test13.png -------------------------------------------------------------------------------- /Yolov1_demo/test14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test14.png -------------------------------------------------------------------------------- /Yolov1_demo/test15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test15.png -------------------------------------------------------------------------------- /Yolov1_demo/test16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test16.png -------------------------------------------------------------------------------- /Yolov1_demo/test17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test17.png -------------------------------------------------------------------------------- /Yolov1_demo/test18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test18.png -------------------------------------------------------------------------------- /Yolov1_demo/test19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test19.png -------------------------------------------------------------------------------- /Yolov1_demo/test2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test2.png -------------------------------------------------------------------------------- /Yolov1_demo/test20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test20.png -------------------------------------------------------------------------------- /Yolov1_demo/test21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test21.png -------------------------------------------------------------------------------- /Yolov1_demo/test3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test3.png -------------------------------------------------------------------------------- /Yolov1_demo/test4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test4.png -------------------------------------------------------------------------------- /Yolov1_demo/test5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test5.png -------------------------------------------------------------------------------- /Yolov1_demo/test6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test6.png -------------------------------------------------------------------------------- /Yolov1_demo/test7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test7.png -------------------------------------------------------------------------------- /Yolov1_demo/test8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test8.png -------------------------------------------------------------------------------- /Yolov1_demo/test9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/Yolov1_demo/test9.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/src/__init__.py -------------------------------------------------------------------------------- /src/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import cv2 4 | import numpy as np 5 | import types 6 | from numpy import random 7 | from PIL import Image 8 | 9 | TEST_MODE = False 10 | 11 | 12 | def intersect(box_a, box_b): 13 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 14 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 15 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 16 | return inter[:, 0] * inter[:, 1] 17 | 18 | 19 | def jaccard_numpy(box_a, box_b): 20 | r""" 21 | Compute the jaccard overlap of two sets of boxes. The jaccard overlap 22 | is simply the intersection over union of two boxes. 
23 | E.g.: 24 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 25 | Args: 26 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 27 | box_b: Single bounding box, Shape: [4] 28 | Return: 29 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 30 | """ 31 | inter = intersect(box_a, box_b) 32 | area_a = ((box_a[:, 2] - box_a[:, 0]) * 33 | (box_a[:, 3] - box_a[:, 1])) # [A,B] 34 | area_b = ((box_b[2] - box_b[0]) * 35 | (box_b[3] - box_b[1])) # [A,B] 36 | union = area_a + area_b - inter 37 | return inter / union # [A,B] 38 | 39 | 40 | # 和transform.Compose 一个作用 41 | class Compose(object): 42 | r""" 43 | Composes several augmentations together. 44 | Args: 45 | transforms (List[Transform]): list of transforms to compose. 46 | Example: 47 | >>> augmentations.Compose([ 48 | >>> transforms.CenterCrop(10), 49 | >>> transforms.ToTensor(), 50 | >>> ]) 51 | """ 52 | 53 | def __init__(self, transforms): 54 | self.transforms = transforms 55 | 56 | def __call__(self, img, boxes=None, labels=None): 57 | for t in self.transforms: 58 | img, boxes, labels = t(img, boxes, labels) 59 | return img, boxes, labels 60 | 61 | 62 | class Lambda(object): 63 | """ 64 | Applies a lambda as a transform. 65 | """ 66 | 67 | def __init__(self, lambd): 68 | assert isinstance(lambd, types.LambdaType) 69 | self.lambd = lambd 70 | 71 | def __call__(self, img, boxes=None, labels=None): 72 | return self.lambd(img, boxes, labels) 73 | 74 | 75 | # 将图片转化为np.float32的格式 76 | class ConvertFromInts(object): 77 | def __call__(self, image, boxes=None, labels=None): 78 | return image.astype(np.float32), boxes, labels 79 | 80 | 81 | # 对图片减去其三通道的均值 82 | class SubtractMeans(object): 83 | def __init__(self, mean): 84 | self.mean = np.array(mean, dtype=np.float32) 85 | 86 | def __call__(self, image, boxes=None, labels=None): 87 | image = image.astype(np.float32) 88 | image -= self.mean 89 | return image.astype(np.float32), boxes, labels 90 | 91 | 92 | # 将坐标从相对坐标[0~1]转化为绝对的像素点坐标 93 | class ToAbsoluteCoords(object): 94 | def __call__(self, image, boxes=None, labels=None): 95 | height, width, channels = image.shape 96 | boxes[:, 0] *= width 97 | boxes[:, 2] *= width 98 | boxes[:, 1] *= height 99 | boxes[:, 3] *= height 100 | 101 | return image, boxes, labels 102 | 103 | 104 | # 将坐标绝对的像素点坐标转化为相对坐标[0~1] 105 | class ToPercentCoords(object): 106 | def __call__(self, image, boxes=None, labels=None): 107 | height, width, channels = image.shape 108 | boxes[:, 0] /= width 109 | boxes[:, 2] /= width 110 | boxes[:, 1] /= height 111 | boxes[:, 3] /= height 112 | 113 | return image, boxes, labels 114 | 115 | 116 | # 将图片转化为对应尺寸大小 117 | class AdaptiveResize(object): 118 | def __init__(self, min_size=600, max_size=1000): 119 | self.min_size = min_size 120 | self.max_size = max_size 121 | 122 | def __call__(self, image, boxes=None, labels=None): 123 | # resize according to the rules: 124 | # 1. scale shorter side to IMAGE_MIN_SIDE 125 | # 2. 
after scaling, if longer side > IMAGE_MAX_SIDE, scale longer side to IMAGE_MAX_SIDE 126 | h, w, _ = image.shape 127 | # print('h:{} | w:{}'.format(h, w)) 128 | img_min_size = min(h, w) 129 | img_max_size = max(h, w) 130 | ratio = float(self.min_size) / float(img_min_size) 131 | if img_max_size * ratio > self.max_size: 132 | ratio *= (float(self.max_size) / float(img_max_size * ratio)) 133 | new_h, new_w = h * ratio, w * ratio 134 | # print('new_h:{} | new_w:{} | ratio:{}'.format(new_h, new_w, ratio)) 135 | # resize(src, (w, h)) 136 | image = cv2.resize(image, (int(new_w), int(new_h))) 137 | if TEST_MODE: 138 | Image.fromarray(image.astype(np.uint8)[..., (2, 1, 0)]).show() 139 | print(f'image.shape:{image.shape}') 140 | return image, boxes, labels 141 | 142 | 143 | class Resize(object): 144 | def __init__(self, size): 145 | self.size = size 146 | 147 | def __call__(self, image, boxes=None, labels=None): 148 | image = cv2.resize(image, (int(self.size), int(self.size))) 149 | return image, boxes, labels 150 | 151 | 152 | class RandomSaturation(object): 153 | def __init__(self, lower=0.5, upper=1.5): 154 | self.lower = lower 155 | self.upper = upper 156 | assert self.upper >= self.lower, "contrast upper must be >= lower." 157 | assert self.lower >= 0, "contrast lower must be non-negative." 158 | 159 | def __call__(self, image, boxes=None, labels=None): 160 | # 有1/2的几率对HSV中的S通道进行缩放操作 161 | if random.randint(2): 162 | # 生成均一分布的一个数 \in(self.lower, self.upper) 163 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 164 | 165 | return image, boxes, labels 166 | 167 | 168 | class RandomHue(object): 169 | def __init__(self, delta=18.0): 170 | assert delta >= 0.0 and delta <= 360.0 171 | self.delta = delta 172 | 173 | def __call__(self, image, boxes=None, labels=None): 174 | # 有1/2的几率对HSV中的H通道进行偏移操作 175 | if random.randint(2): 176 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 177 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 178 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 179 | return image, boxes, labels 180 | 181 | 182 | class RandomLightingNoise(object): 183 | def __init__(self): 184 | self.perms = ((0, 1, 2), (0, 2, 1), 185 | (1, 0, 2), (1, 2, 0), 186 | (2, 0, 1), (2, 1, 0)) 187 | 188 | def __call__(self, image, boxes=None, labels=None): 189 | # 有1/2的几率对BGR?通道进行shuffle 190 | if random.randint(2): 191 | swap = self.perms[random.randint(len(self.perms))] 192 | shuffle = SwapChannels(swap) # shuffle channels 193 | image = shuffle(image) 194 | return image, boxes, labels 195 | 196 | 197 | # 将图片在BGR和HSV通道之间转换 198 | class ConvertColor(object): 199 | def __init__(self, current='BGR', transform='HSV'): 200 | self.transform = transform 201 | self.current = current 202 | 203 | def __call__(self, image, boxes=None, labels=None): 204 | if self.current == 'BGR' and self.transform == 'HSV': 205 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 206 | elif self.current == 'HSV' and self.transform == 'BGR': 207 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 208 | else: 209 | raise NotImplementedError 210 | return image, boxes, labels 211 | 212 | 213 | class RandomContrast(object): 214 | def __init__(self, lower=0.5, upper=1.5): 215 | self.lower = lower 216 | self.upper = upper 217 | assert self.upper >= self.lower, "contrast upper must be >= lower." 218 | assert self.lower >= 0, "contrast lower must be non-negative." 
219 | 220 | # expects float image 221 | def __call__(self, image, boxes=None, labels=None): 222 | # 有1/2几率对RGB三个通道的颜色数值进行缩放 223 | if random.randint(2): 224 | alpha = random.uniform(self.lower, self.upper) 225 | image *= alpha 226 | return image, boxes, labels 227 | 228 | 229 | class RandomBrightness(object): 230 | def __init__(self, delta=32): 231 | assert delta >= 0.0 232 | assert delta <= 255.0 233 | self.delta = delta 234 | 235 | def __call__(self, image, boxes=None, labels=None): 236 | # 有1/2几率对RGB三个通道的颜色数值进行增减 237 | if random.randint(2): 238 | delta = random.uniform(-self.delta, self.delta) 239 | image += delta 240 | return image, boxes, labels 241 | 242 | 243 | # 将tensor图片转化为cv的图片 244 | class ToCV2Image(object): 245 | def __call__(self, tensor, boxes=None, labels=None): 246 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 247 | 248 | 249 | # 将cv的图片转化为tensor图片 250 | class ToTensor(object): 251 | def __call__(self, cvimage, boxes=None, labels=None): 252 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels 253 | 254 | 255 | # 随机裁剪 256 | class RandomSampleCrop(object): 257 | """Crop 258 | Arguments: 259 | img (Image): the image being input during training 260 | boxes (Tensor): the original bounding boxes in pt form 261 | labels (Tensor): the class labels for each bbox 262 | mode (float tuple): the min and max jaccard overlaps 263 | Return: 264 | (img, boxes, classes) 265 | img (Image): the cropped image 266 | boxes (Tensor): the adjusted bounding boxes in pt form? 267 | labels (Tensor): the class labels for each bbox 268 | """ 269 | 270 | def __init__(self): 271 | self.sample_options = ( 272 | # using entire original input image 273 | None, 274 | # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 275 | (0.1, None), 276 | (0.3, None), 277 | (0.7, None), 278 | (0.9, None), 279 | # randomly sample a patch 280 | (None, None), 281 | ) 282 | 283 | def __call__(self, image, boxes=None, labels=None): 284 | height, width, _ = image.shape 285 | while True: 286 | # randomly choose a mode 287 | mode = random.choice(self.sample_options) 288 | if mode is None: 289 | return image, boxes, labels 290 | 291 | min_iou, max_iou = mode 292 | if min_iou is None: 293 | min_iou = float('-inf') 294 | if max_iou is None: 295 | max_iou = float('inf') 296 | 297 | # max trails (50) 298 | for _ in range(50): 299 | current_image = image 300 | # 将w, h, 在30%~100%之间缩放 301 | w = random.uniform(0.3 * width, width) 302 | h = random.uniform(0.3 * height, height) 303 | 304 | # aspect ratio constraint b/t .5 & 2 305 | if h / w < 0.5 or h / w > 2: 306 | continue 307 | 308 | # 随机选择一个左上角 309 | left = random.uniform(width - w) 310 | top = random.uniform(height - h) 311 | 312 | # convert to integer rect x1,y1,x2,y2 313 | rect = np.array([int(left), int(top), int(left + w), int(top + h)]) 314 | 315 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 316 | overlap = jaccard_numpy(boxes, rect) 317 | 318 | # is min and max overlap constraint satisfied? 
if not try again 319 | # 检查物体与裁剪区域的overlap情况 320 | if overlap.min() < min_iou and max_iou < overlap.max(): 321 | continue 322 | 323 | # cut the crop from the image 324 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], :] 325 | 326 | # keep overlap with gt box IF center in sampled patch 327 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 328 | 329 | # mask in all gt boxes that above and to the left of centers 330 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 331 | 332 | # mask in all gt boxes that under and to the right of centers 333 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 334 | 335 | # mask in that both m1 and m2 are true 336 | mask = m1 * m2 337 | 338 | # have any valid boxes? try again if not 339 | if not mask.any(): 340 | continue 341 | 342 | # take only matching gt boxes 343 | current_boxes = boxes[mask, :].copy() 344 | 345 | # take only matching gt labels 346 | current_labels = labels[mask] 347 | 348 | # should we use the box left and top corner or the crop's 349 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) 350 | # adjust to crop (by substracting crop's left,top) 351 | current_boxes[:, :2] -= rect[:2] 352 | 353 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) 354 | # adjust to crop (by substracting crop's left,top) 355 | current_boxes[:, 2:] -= rect[:2] 356 | 357 | return current_image, current_boxes, current_labels 358 | 359 | 360 | # 随机扩大,用数据集三通道均值来垫底 361 | class Expand(object): 362 | def __init__(self, mean): 363 | self.mean = mean 364 | 365 | def __call__(self, image, boxes, labels): 366 | if random.randint(2): 367 | return image, boxes, labels 368 | 369 | height, width, depth = image.shape 370 | # ratio取值范围是(1,2,3) 371 | ratio = random.uniform(1, 4) 372 | # 选择左上角坐标,用于放入原始图片 373 | left = random.uniform(0, width * ratio - width) 374 | top = random.uniform(0, height * ratio - height) 375 | 376 | expand_image = np.zeros( 377 | (int(height * ratio), int(width * ratio), depth), 378 | dtype=image.dtype) 379 | # 和均值一致的颜色 380 | expand_image[:, :, :] = self.mean 381 | expand_image[int(top):int(top + height), int(left):int(left + width)] = image 382 | image = expand_image 383 | 384 | # 转换bounding boxes的坐标 385 | boxes = boxes.copy() 386 | boxes[:, :2] += (int(left), int(top)) 387 | boxes[:, 2:] += (int(left), int(top)) 388 | 389 | return image, boxes, labels 390 | 391 | 392 | class RandomMirror(object): 393 | def __call__(self, image, boxes, classes): 394 | _, width, _ = image.shape 395 | if random.randint(2): 396 | image = image[:, ::-1] 397 | boxes = boxes.copy() 398 | boxes[:, 0::2] = width - boxes[:, 2::-2] 399 | return image, boxes, classes 400 | 401 | 402 | class SwapChannels(object): 403 | """ 404 | Transforms a tensorized image by swapping the channels in the order 405 | specified in the swap tuple. 
406 | Args: 407 | swaps (int triple): final order of channels 408 | eg: (2, 1, 0) 409 | """ 410 | 411 | def __init__(self, swaps): 412 | self.swaps = swaps 413 | 414 | def __call__(self, image): 415 | """ 416 | Args: 417 | image (Tensor): image tensor to be transformed 418 | Return: 419 | a tensor with channels swapped according to swap 420 | """ 421 | # if torch.is_tensor(image): 422 | # image = image.data.cpu().numpy() 423 | # else: 424 | # image = np.array(image) 425 | image = image[:, :, self.swaps] 426 | return image 427 | 428 | 429 | class PhotometricDistort(object): 430 | def __init__(self): 431 | self.pd = [ 432 | RandomContrast(), 433 | ConvertColor(transform='HSV'), 434 | RandomSaturation(), 435 | RandomHue(), 436 | ConvertColor(current='HSV', transform='BGR'), 437 | RandomContrast() 438 | ] 439 | self.rand_brightness = RandomBrightness() 440 | self.rand_light_noise = RandomLightingNoise() 441 | 442 | def __call__(self, image, boxes, labels): 443 | im = image.copy() 444 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 445 | if random.randint(2): 446 | distort = Compose(self.pd[:-1]) 447 | else: 448 | distort = Compose(self.pd[1:]) 449 | im, boxes, labels = distort(im, boxes, labels) 450 | return self.rand_light_noise(im, boxes, labels) 451 | 452 | 453 | class SSDAugmentation(object): 454 | def __init__(self, size=300, mean=(104, 117, 123), expand_and_crop=False): 455 | self.mean = mean 456 | self.augment = Compose([ 457 | # 将图片转化为np.float32 458 | ConvertFromInts(), 459 | # 得到框的绝对坐标 460 | ToAbsoluteCoords(), 461 | PhotometricDistort(), 462 | Expand(self.mean), 463 | RandomSampleCrop(), 464 | RandomMirror(), 465 | # 相对坐标 466 | ToPercentCoords(), 467 | Resize(size=size), 468 | SubtractMeans(self.mean) 469 | ]) if expand_and_crop else Compose([ 470 | # 将图片转化为np.float32 471 | ConvertFromInts(), 472 | # 得到框的绝对坐标 473 | ToAbsoluteCoords(), 474 | PhotometricDistort(), 475 | # Expand(self.mean), 476 | # RandomSampleCrop(), 477 | RandomMirror(), 478 | # 相对坐标 479 | ToPercentCoords(), 480 | Resize(size=size), 481 | SubtractMeans(self.mean) 482 | ]) 483 | 484 | # 输入的是 cv_img 485 | def __call__(self, img, boxes, labels): 486 | return self.augment(img, boxes, labels) 487 | 488 | 489 | # 对图片减去其三通道的均值 490 | class Divide255(object): 491 | def __init__(self): 492 | self.max_uint = np.array(255.0, dtype=np.float32) 493 | 494 | def __call__(self, image, boxes=None, labels=None): 495 | image = image.astype(np.float32) 496 | image /= self.max_uint 497 | return image.astype(np.float32), boxes, labels 498 | 499 | 500 | # 随机扩大,用数据集三通道均值来垫底 501 | class ExpandTo1000x1000(object): 502 | def __init__(self, mean): 503 | self.mean = mean 504 | 505 | def __call__(self, image, boxes, labels): 506 | height, width, depth = image.shape 507 | # ratio取值范围是(1,2,3) 508 | # 选择左上角坐标,用于放入原始图片 509 | left = random.uniform(0, 1000 - width) 510 | top = random.uniform(0, 1000 - height) 511 | 512 | expand_image = np.zeros((1000, 1000, depth), dtype=image.dtype) 513 | # 和均值一致的颜色 514 | # print(expand_image[int(top):int(top + height), int(left):int(left + width)].shape) 515 | # print(image.shape) 516 | # print(left, top) 517 | expand_image[:, :, :] = self.mean 518 | expand_image[int(top):int(top + height), int(left):int(left + width)] = image 519 | image = expand_image 520 | 521 | # 转换bounding boxes的坐标 522 | boxes = boxes.copy() 523 | boxes[:, :2] += (int(left), int(top)) 524 | boxes[:, 2:] += (int(left), int(top)) 525 | 526 | return image, boxes, labels 527 | 528 | 529 | class FasterRCNNAugmentation(object): 530 | def 
__init__(self, min_size=600, max_size=1000, mean=(104, 117, 123), percent_coord=False): 531 | self.mean = mean 532 | self.min_size = min_size 533 | self.max_size = max_size 534 | self.augment = Compose([ 535 | # 将图片转化为np.float32 536 | ConvertFromInts(), 537 | AdaptiveResize(min_size=self.min_size, max_size=self.max_size), 538 | # 得到框的绝对坐标 539 | ToAbsoluteCoords(), 540 | PhotometricDistort(), 541 | ExpandTo1000x1000(self.mean), 542 | RandomMirror(), 543 | # 相对坐标 544 | ToPercentCoords(), 545 | # SubtractMeans(self.mean) 546 | Divide255() 547 | ]) if percent_coord else Compose([ 548 | # 将图片转化为np.float32 549 | ConvertFromInts(), 550 | AdaptiveResize(min_size=self.min_size, max_size=self.max_size), 551 | # 得到框的绝对坐标 552 | ToAbsoluteCoords(), 553 | PhotometricDistort(), 554 | # ExpandTo1000x1000(self.mean), 555 | RandomMirror(), 556 | # 相对坐标 557 | # ToPercentCoords(), 558 | # SubtractMeans(self.mean) 559 | Divide255() 560 | ]) 561 | 562 | # 输入的是 cv_img 563 | def __call__(self, img, boxes, labels): 564 | return self.augment(img, boxes, labels) 565 | 566 | 567 | class Yolov1Augmentation(object): 568 | def __init__(self, size=448, mean=(104, 117, 123), percent_coord=False): 569 | self.augment = Compose([ 570 | # 将图片转化为np.float32 571 | ConvertFromInts(), 572 | # 得到框的绝对坐标 573 | ToAbsoluteCoords(), 574 | PhotometricDistort(), 575 | Expand(mean), 576 | RandomMirror(), 577 | # 相对坐标 578 | ToPercentCoords(), 579 | Divide255(), 580 | Resize(size=size), 581 | ]) if percent_coord else Compose([ 582 | # 将图片转化为np.float32 583 | ConvertFromInts(), 584 | # 得到框的绝对坐标 585 | ToAbsoluteCoords(), 586 | PhotometricDistort(), 587 | Expand(mean), 588 | RandomMirror(), 589 | Divide255(), 590 | Resize(size=size), 591 | ]) 592 | 593 | def __call__(self, img, boxes, labels): 594 | return self.augment(img, boxes, labels) 595 | 596 | 597 | class Yolov1TestAugmentation(object): 598 | def __init__(self, size=448, percent_coord=False): 599 | self.augment = Compose([ 600 | # 将图片转化为np.float32 601 | ConvertFromInts(), 602 | # 得到框的绝对坐标 603 | ToAbsoluteCoords(), 604 | # 相对坐标 605 | ToPercentCoords(), 606 | Divide255(), 607 | Resize(size=size), 608 | ]) if percent_coord else Compose([ 609 | # 将图片转化为np.float32 610 | ConvertFromInts(), 611 | # 得到框的绝对坐标 612 | ToAbsoluteCoords(), 613 | Divide255(), 614 | Resize(size=size), 615 | ]) 616 | 617 | def __call__(self, img, boxes, labels): 618 | return self.augment(img, boxes, labels) 619 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Tuple, List 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | CHANNEL_MEANS = (104, 117, 123) 10 | # IMAGE_MIN_SIDE: float = 600.0 11 | # IMAGE_MAX_SIDE: float = 1000.0 12 | 13 | # ANCHOR_RATIOS: List[Tuple[int, int]] = [(1, 2), (1, 1), (2, 1)] 14 | # ANCHOR_SIZES: List[int] = [128, 256, 512] 15 | 16 | # RPN_PRE_NMS_TOP_N: int = 12000 17 | # RPN_POST_NMS_TOP_N: int = 2000 18 | 19 | # ANCHOR_SMOOTH_L1_LOSS_BETA: float = 1.0 20 | # PROPOSAL_SMOOTH_L1_LOSS_BETA: float = 1.0 21 | 22 | LEARNING_RATE: float = 0.001 23 | MOMENTUM: float = 0.9 24 | WEIGHT_DECAY: float = 0.0005 25 | STEP_LR_SIZES: List[int] = [200000, 400000] 26 | STEP_LR_GAMMA: float = 0.1 27 | WARM_UP_FACTOR: float = 0.1 28 | WARM_UP_NUM_ITERS: int = 1000 29 | 30 | NUM_STEPS_TO_SAVE: int = 100 31 | 
NUM_STEPS_TO_SNAPSHOT: int = 10000 32 | NUM_STEPS_TO_FINISH: int = 600000 33 | 34 | 35 | YOLOv1_PIC_SIZE = 448 36 | VOC_DATA_SET_ROOT = '' 37 | MODEL_SAVE_DIR = '../model' 38 | GRID_NUM = 7 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | import os.path as osp 9 | import sys 10 | import torch 11 | import torch.utils.data as data 12 | import cv2 13 | import numpy as np 14 | from config import GRID_NUM 15 | 16 | # sys.version_info(major=3, minor=7, micro=1, releaselevel='final', serial=0) 17 | # 根据Python版本导入模块 18 | if sys.version_info[0] == 2: 19 | import xml.etree.cElementTree as ET 20 | else: 21 | import xml.etree.ElementTree as ET 22 | 23 | VOC_CLASSES = ( # always index 0 24 | 'aeroplane', 'bicycle', 'bird', 'boat', 25 | 'bottle', 'bus', 'car', 'cat', 'chair', 26 | 'cow', 'diningtable', 'dog', 'horse', 27 | 'motorbike', 'person', 'pottedplant', 28 | 'sheep', 'sofa', 'train', 'tvmonitor' 29 | ) 30 | 31 | 32 | class VOCAnnotationTransform(object): 33 | # 将VOC的标注转换为 (x,y,w,h,class), class为上面VOC_CLASSES的序号 34 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 35 | Initilized with a dictionary lookup of classnames to indexes 36 | 37 | Arguments: 38 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 39 | (default: alphabetic indexing of VOC's 20 classes) 40 | keep_difficult (bool, optional): keep difficult instances or not 41 | (default: False) 42 | height (int): height 43 | width (int): width 44 | """ 45 | 46 | def __init__(self, class_to_ind=None, keep_difficult=False): 47 | self.class_to_ind = class_to_ind or dict( 48 | # 将物体名称与0~class数量绑定 49 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 50 | self.keep_difficult = keep_difficult 51 | 52 | # 可调用对象 53 | def __call__(self, target, width, height): 54 | """Arguments: 55 | target (annotation) : the target annotation to be made usable 56 | will be an ET.Element 57 | Returns: 58 | a list containing lists of bounding boxes [bbox coords, class name] 59 | """ 60 | res = [] 61 | # 对ET.Element 里面名字为'object'的对象进行遍历 62 | # 具体用法:https://www.cnblogs.com/ifantastic/archive/2013/04/12/3017110.html 63 | for obj in target.iter('object'): 64 | # difficult VOC文档里的含义,标为 1 表示难以辨认 65 | # ‘difficult’: an object marked as ‘difficult’ indicates that the object is considered 66 | # difficult to recognize, for example an object which is clearly visible but unidentifiable 67 | # without substantial use of context. Objects marked as difficult are currently ignored 68 | # in the evaluation of the challenge. 69 | difficult = int(obj.find('difficult').text) == 1 70 | # 检测目标为难以检测而且self.keep_difficult标记为1才继续进行操作 71 | if not self.keep_difficult and difficult: 72 | continue 73 | 74 | # 用法解释: 75 | # str = "00000003210Runoob01230000000"; 76 | # print str.strip( '0' ); # 去除首尾字符 0 77 | name = obj.find('name').text.lower().strip() 78 | # 数据格式: 79 | # 80 | # 174 81 | # 101 82 | # 349 83 | # 351 84 | # 85 | bbox = obj.find('bndbox') 86 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 87 | bndbox = [] 88 | # 得到一组0~1.0范围的值 89 | for i, pt in enumerate(pts): 90 | # bbox 数值为像素点的位置,从1开始取所以要减去1? 
91 | cur_pt = int(bbox.find(pt).text) - 1 92 | # scale height or width 93 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 94 | bndbox.append(cur_pt) 95 | label_idx = self.class_to_ind[name] 96 | # 查找name类别对应的标号 97 | bndbox.append(label_idx) 98 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 99 | # img_id = target.find('filename').text[:-4] 100 | # res: tensor[ ,5] i.e. [xmin, ymin, xmax, ymax, label_ind], ... ] 101 | return res 102 | 103 | 104 | class VOCDetection(data.Dataset): 105 | """VOC Detection Dataset Object 106 | 107 | input is image, target is annotation 108 | 109 | Arguments: 110 | root (string): filepath to VOCdevkit folder. 111 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 112 | transform (callable, optional): transformation to perform on the 113 | input image 114 | target_transform (callable, optional): transformation to perform on the 115 | target `annotation` 116 | (eg: take in caption string, return tensor of word indices) 117 | dataset_name (string, optional): which dataset to load 118 | (default: 'VOC2007') 119 | """ 120 | 121 | def __init__(self, root, # /VOCdevkit ? 122 | image_sets=(('2007', 'trainval'), ('2012', 'trainval')), 123 | transform=None, 124 | target_transform=VOCAnnotationTransform(), 125 | dataset_name='VOC0712'): 126 | self.root = root 127 | self.image_set = image_sets 128 | self.transform = transform 129 | self.target_transform = target_transform 130 | self.name = dataset_name 131 | # 标记文本的位置 132 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 133 | # 图片的位置 134 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 135 | self.ids = list() 136 | for (year, name) in image_sets: 137 | # ./root/VOC2007 138 | rootpath = osp.join(self.root, 'VOC' + year) 139 | # /root/VOC2007/ImageSets/Main/trainval.txt 140 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 141 | # (./root/VOC2007, Image_ID) 142 | self.ids.append((rootpath, line.strip())) 143 | 144 | def __getitem__(self, index): 145 | # im为图片,gt=get_target 146 | im, gt, h, w = self.pull_item(index) 147 | # i.e. tensor[c,h,w],[[xmin, ymin, xmax, ymax, label_idx], ... ] 148 | return im, gt 149 | 150 | def __len__(self): 151 | return len(self.ids) 152 | 153 | def pull_item(self, index): 154 | # img_id=(./VOCdevkit/VOC2007, Image_ID) 155 | img_id = self.ids[index] 156 | # '%s/Annotations/%s.xml'.format((./VOCdevkit/VOC2007, Img_ID)) 157 | # ===>./root/VOC2007/Annotations/Image_ID.xml' 158 | # target 为解析后的.xml 文件根节点。 159 | target = ET.parse(self._annopath % img_id).getroot() 160 | # ===>./root/VOC2007/Annotations/Image_ID.jpg' 161 | img = cv2.imread(self._imgpath % img_id) 162 | # 得到图片的宽高 163 | height, width, channels = img.shape 164 | 165 | # 对标注格式进行转换,默认为上文的VOCAnnotationTransform() 166 | # 输入一个ET.parse().getroot()的element,得到[[xmin, ymin, xmax, ymax, label_ind], ... ] 167 | if self.target_transform is not None: 168 | target = self.target_transform(target, width, height) 169 | 170 | if self.transform is not None: 171 | # 将list转化为np.ndarray 172 | target = np.array(target) 173 | # img为cv图片 174 | # boxes=[xmin, ymin, xmax, ymax]\in[0,1], 175 | # abels=类名对应的序号,i.e.[idx] 176 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 177 | # to rgb:[h,w,c], 其中c 为 BGR 178 | # i.e. img = img.transpose(2, 0, 1) 179 | img = img[:, :, (2, 1, 0)] 180 | 181 | # hstack,在最低的维度进行连接,这不还原成了上面的target? 182 | # [[xmin, ymin, xmax, ymax, label_idx], ... 
] 183 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 184 | # tensor[c,h,w], np.array[[xmin, ymin, xmax, ymax, label_ind], ... ] 185 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 186 | # return torch.from_numpy(img), target, height, width 187 | 188 | # 返回原始的PIL图片 189 | def pull_image(self, index): 190 | '''Returns the original image object at index in PIL form 191 | 192 | Note: not using self.__getitem__(), as any transformations passed in 193 | could mess up this functionality. 194 | 195 | Argument: 196 | index (int): index of img to show 197 | Return: 198 | PIL img 199 | ''' 200 | img_id = self.ids[index] 201 | # cv.IMREAD_COLOR = 1 : 将图像转为彩色读取 202 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 203 | 204 | def pull_anno(self, index): 205 | '''Returns the original annotation of image at index 206 | 207 | Note: not using self.__getitem__(), as any transformations passed in 208 | could mess up this functionality. 209 | 210 | Argument: 211 | index (int): index of img to get annotation of 212 | Return: 213 | list: [img_id, [(label, bbox coords),...]] 214 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 215 | ''' 216 | img_id = self.ids[index] 217 | anno = ET.parse(self._annopath % img_id).getroot() 218 | gt = self.target_transform(anno, 1, 1) 219 | return img_id[1], gt 220 | 221 | def pull_tensor(self, index): 222 | '''Returns the original image at an index in tensor form 223 | 224 | Note: not using self.__getitem__(), as any transformations passed in 225 | could mess up this functionality. 226 | 227 | Argument: 228 | index (int): index of img to show 229 | Return: 230 | tensorized version of img, squeezed 231 | ''' 232 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 233 | 234 | 235 | # 如何把多个sample打包成batch的函数 236 | def detection_collate(batch): 237 | r"""Custom collate fn for dealing with batches of images that have a different 238 | number of associated object annotations (bounding boxes). 239 | 240 | Arguments: 241 | batch: (tuple) A tuple of tensor images and lists of annotations 242 | 243 | Return: 244 | imgs: tensor [batch_size, 3, 448, 448] 245 | boxes: list of tensor:[, 4] for (x1, y1, x2, y2) 246 | labels: list of LongTensor:[,1] 247 | gt_outs: the ground truth outputs of model 248 | """ 249 | imgs = [] 250 | boxes, labels, gt_outs = [], [], [] 251 | for sample in batch: 252 | # sample[0]:[3,h,w], sample[1]:[, 5] 253 | imgs.append(sample[0]) 254 | # print(sample[1].shape, sample[1]) 255 | box = torch.FloatTensor([i[:4] for i in sample[1]]) 256 | label = torch.LongTensor([i[4] for i in sample[1]]) 257 | boxes.append(box) 258 | labels.append(label) 259 | gt_outs.append(yolov1_data_encoder(box, label)) 260 | # print(f'boxes:{boxes}\nlabels:{labels}') 261 | 262 | return torch.stack(imgs, 0), boxes, labels, torch.stack(gt_outs, 0) 263 | # return imgs, targets 264 | # return torch.stack(imgs, 0), targets 265 | 266 | 267 | def yolov1_data_encoder(boxes, labels, grid_num=GRID_NUM): 268 | """ 269 | boxes (tensor) [[x1,y1,x2,y2],[]] 270 | labels (tensor) [...] 271 | return SxSx30 272 | 30: B1[:4], Obj1[4], B2[5:9], Obj[9], C[9:] 273 | """ 274 | target = torch.zeros((grid_num, grid_num, 30)) 275 | cell_size = 1. 
/ grid_num 276 | # (w,h) 277 | wh = boxes[:, 2:] - boxes[:, :2] 278 | # center(x,y) 279 | cxcy = (boxes[:, 2:] + boxes[:, :2]) / 2 280 | for i in range(cxcy.size()[0]): 281 | cxcy_sample = cxcy[i] 282 | # 计算属于格子的第几行第几列 283 | ij = (cxcy_sample / cell_size).ceil() - 1 284 | # B1、B2、C 标记为1 285 | target[int(ij[1]), int(ij[0]), 4] = 1 286 | target[int(ij[1]), int(ij[0]), 9] = 1 287 | # int(labels[i]) + 10 288 | target[int(ij[1]), int(ij[0]), int(labels[i]) + 10] = 1 289 | # 匹配到的网格的左上角相对坐标 290 | xy = ij * cell_size 291 | # 真框相对于格子坐上角的偏移量 292 | delta_xy = (cxcy_sample - xy) / cell_size 293 | target[int(ij[1]), int(ij[0]), 2:4] = wh[i] 294 | target[int(ij[1]), int(ij[0]), :2] = delta_xy 295 | target[int(ij[1]), int(ij[0]), 7:9] = wh[i] 296 | target[int(ij[1]), int(ij[0]), 5:7] = delta_xy 297 | return target 298 | 299 | 300 | # batch:(imgs:list[tensor img \in(0,1)], targets:list[tensor:[object_num, 5]]) 301 | def get_voc_data_set(args, percent_coord=False, test=False, year=None): 302 | if not test: 303 | image_sets = (('2007', 'trainval'), ('2012', 'trainval')) 304 | else: 305 | if year is None: 306 | image_sets = (('2007test', 'test'), ('2012test', 'test')) 307 | elif year == '2007': 308 | image_sets = (('2007test', 'test'),) 309 | elif year == '2012': 310 | image_sets = (('2012test', 'test'),) 311 | from augmentations import Yolov1Augmentation 312 | dataset = VOCDetection(root=args.voc_data_set_root, 313 | image_sets=image_sets, 314 | # transform=Yolov1Augmentation(size=YOLOv1_PIC_SIZE, percent_coord=percent_coord)) 315 | transform=Yolov1Augmentation(size=448, percent_coord=percent_coord)) 316 | return data.DataLoader(dataset, 317 | args.batch_size, 318 | num_workers=args.num_workers, 319 | shuffle=True, 320 | collate_fn=detection_collate, 321 | pin_memory=False) 322 | 323 | 324 | if __name__ == '__main__': 325 | from predict import draw_box 326 | 327 | # global TEST_MODE 328 | from train import config_parser 329 | 330 | args = config_parser() 331 | data_set = get_voc_data_set(args, percent_coord=True, test=True, year='2012') 332 | for _, (imgs, gt_boxes, gt_labels, gt_outs) in enumerate(data_set): 333 | # print(f'img:{imgs}') 334 | # print(f'targets:{targets}') 335 | print(f'gt_encode:{gt_outs}') 336 | for gt_out in gt_outs: 337 | print(f'gt_out:{gt_out.shape}') 338 | print(gt_out.nonzero().transpose(1, 0)) 339 | gt_out_nonzero_split = torch.split(gt_out.nonzero().transpose(1, 0), dim=0, split_size_or_sections=1) 340 | print(f'gt_out_nonzero_split:{gt_out_nonzero_split}') 341 | print(f'gt_out:{gt_out[gt_out_nonzero_split]}') 342 | for img, gt_box, gt_label in zip(imgs, gt_boxes, gt_labels): 343 | gt_box_np = gt_box.cpu().numpy() 344 | gt_label_np = gt_label.cpu().numpy() 345 | print(f'gt_label_np:{gt_label_np}') 346 | print(f'gt_box_np{gt_box_np.shape},gt_label_np:{gt_label_np.shape}') 347 | img_np = (img * 255.0).cpu().numpy().astype(np.uint8) 348 | # print(f'img_np:{img_np}') 349 | img_np = img_np.transpose(1, 2, 0) # [..., (2, 1, 0)] 350 | # img_np = cv2.cvtColor((img * 255.0).cpu().numpy(), cv2.COLOR_RGB2BGR) 351 | # print(img_np.shape) 352 | draw_box(img_np, 353 | gt_box_np, 354 | gt_label_np, 355 | relative_coord=True) 356 | -------------------------------------------------------------------------------- /src/detect.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/src/detect.py 
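Aside (illustrative sketch, not a file from this repository): the SxSx30 target layout that `yolov1_data_encoder` in src/dataset.py builds for the loss can be sanity-checked with a few lines like the following. The box coordinates and the 'dog' label are made-up example values, and the snippet assumes it is run with src/ on the Python path.

```
import torch
from dataset import yolov1_data_encoder, VOC_CLASSES

# One made-up ground-truth box in relative (x1, y1, x2, y2) form, labelled 'dog'.
boxes = torch.FloatTensor([[0.2, 0.3, 0.6, 0.7]])
labels = torch.LongTensor([VOC_CLASSES.index('dog')])  # class index 11

target = yolov1_data_encoder(boxes, labels)  # tensor of shape [7, 7, 30]

# The box centre (0.4, 0.5) falls into grid column 2, row 3 (cell size = 1/7),
# so target[3, 2] is the only non-zero cell.
cell = target[3, 2]
print(cell[:2])       # centre offset inside its cell -> tensor([0.8000, 0.5000])
print(cell[2:4])      # box width/height relative to the whole image -> tensor([0.4000, 0.4000])
print(cell[4].item(), cell[9].item())   # objectness targets for both box slots -> 1.0 1.0
print(cell[10:].nonzero())              # one-hot class entry at offset 11 ('dog')
```

The same layout is what decoder() in src/predict.py reads back out: it keeps the higher-confidence of the two box slots in each cell, converts the cell-relative centre and image-relative width/height into corner coordinates, and multiplies the class score by the objectness.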
-------------------------------------------------------------------------------- /src/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import MultiStepLR 5 | 6 | 7 | class WarmUpMultiStepLR(MultiStepLR): 8 | def __init__(self, optimizer: Optimizer, 9 | milestones: List[int], 10 | gamma: float = 0.1, 11 | warm_up_factor: float = 0.1, 12 | warm_up_iters: int = 500, 13 | last_epoch: int = -1): 14 | self.factor = warm_up_factor 15 | self.warm_up_iters = warm_up_iters 16 | super().__init__(optimizer, milestones, gamma, last_epoch) 17 | 18 | def get_lr(self) -> List[float]: 19 | if self.last_epoch < self.warm_up_iters: 20 | alpha = self.last_epoch / self.warm_up_iters 21 | factor = (1 - self.factor) * alpha + self.factor 22 | else: 23 | factor = 1 24 | 25 | return [lr * factor for lr in super().get_lr()] 26 | 27 | 28 | if __name__ == '__main__': 29 | last_epoch = 2 30 | for iter in range(1, 1000): 31 | factor = 0.1 32 | alpha = iter / 1000 33 | factor = (1 - factor) * alpha + factor 34 | print(f'factor:{factor}') 35 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type 2 | from enum import Enum 3 | import torchvision 4 | import torch 5 | from torch import nn 6 | from typing import Tuple, List, Optional, Union 7 | from torch import nn 8 | from torchvision import models as Models 9 | from os import path as osp 10 | import os 11 | from config import * 12 | 13 | 14 | def get_backbone(model_name: str): 15 | r""" 16 | get pre-trained base-network for yolo-v1, 17 | children[:5] do not require grad 18 | :param model_name: name of model 19 | :return: pre-layer of pre-trained model without FC 20 | """ 21 | model_dict = { 22 | 'resnet18': Models.resnet18(True), 23 | 'resnet50': Models.resnet50(True), 24 | 'resnet101': Models.resnet101(True) 25 | } 26 | ''' 27 | list(resnet18.children()) consists of following modules 28 | [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, 29 | [3] = MaxPool2d, [4] = Sequential(Bottleneck...), 30 | [5] = Sequential(Bottleneck...), 31 | [6] = Sequential(Bottleneck...), 32 | [7] = Sequential(Bottleneck...), 33 | [8] = AvgPool2d, [9] = Linear 34 | ''' 35 | # when input shape is [, 3, 448, 448], output shape is: 36 | feature_maps_shape = { 37 | 'resnet18': (512, 14, 14), 38 | 'resnet50': (2048, 14, 14), 39 | 'resnet101': (2048, 14, 14) 40 | } 41 | features = list(model_dict.get(model_name).children())[:-2] 42 | for parameters in [feature.parameters() for i, feature in enumerate(features) if i <= 4]: 43 | for parameter in parameters: 44 | parameter.requires_grad = False 45 | return nn.Sequential(*features), feature_maps_shape.get(model_name) 46 | 47 | 48 | class Yolov1(nn.Module): 49 | def __init__(self, backbone_name: str, grid_num=GRID_NUM, model_save_dir=MODEL_SAVE_DIR): 50 | def get_tuple_multiplied(input_tuple: tuple): 51 | res = 1.0 52 | for i in input_tuple: 53 | res *= i 54 | return int(res) 55 | 56 | super(Yolov1, self).__init__() 57 | self.model_save_dir = model_save_dir 58 | self.grid_num = grid_num 59 | # self.backbone_name = backbone_name 60 | self.backbone, feature_maps_shape = get_backbone(backbone_name) 61 | self.model_save_name = '{}_{}'.format(self.__class__.__name__, backbone_name) 62 | last_conv3x3_out_channel = 1024 63 | self.last_conv3x3 = nn.Sequential( 64 
| nn.Conv2d(in_channels=feature_maps_shape[0], out_channels=last_conv3x3_out_channel, 65 | kernel_size=3, stride=2, padding=1), 66 | nn.ReLU(True), 67 | nn.BatchNorm2d(last_conv3x3_out_channel) 68 | ) 69 | self.cls = nn.Sequential( 70 | nn.Linear(get_tuple_multiplied((last_conv3x3_out_channel, self.grid_num, self.grid_num)), 4096), 71 | nn.ReLU(True), 72 | nn.Dropout(), 73 | # nn.Linear(4096, 4096), 74 | # nn.ReLU(True), 75 | # nn.Dropout(), 76 | nn.Linear(4096, int(self.grid_num * self.grid_num * 30)), 77 | ) 78 | 79 | def forward(self, x): 80 | x = self.backbone(x) 81 | x = self.last_conv3x3(x) 82 | x = x.view(x.size(0), -1) 83 | x = self.cls(x) 84 | x = torch.sigmoid(x) # 归一化到0-1 85 | x = x.view(-1, self.grid_num, self.grid_num, 30) 86 | return x 87 | 88 | def save_model(self, step=None, optimizer=None, lr_scheduler=None): 89 | self.save_safely(self.state_dict(), self.model_save_dir, self.model_save_name + '.pkl') 90 | print('*** model weights saved successfully at {}!'.format( 91 | osp.join(self.model_save_dir, self.model_save_name + '.pkl'))) 92 | if optimizer and lr_scheduler and step is not None: 93 | temp = { 94 | 'step': step, 95 | 'lr_scheduler': lr_scheduler.state_dict(), 96 | 'optimizer': optimizer.state_dict(), 97 | } 98 | self.save_safely(temp, self.model_save_dir, self.model_save_name + '_para.pkl') 99 | print('*** auxiliary part saved successfully at {}!'.format( 100 | osp.join(self.model_save_dir, self.model_save_name + '.pkl'))) 101 | 102 | def load_model(self, optimizer=None, lr_scheduler=None): 103 | try: 104 | saved_model = torch.load(osp.join(self.model_save_dir, self.model_save_name + '.pkl'), 105 | map_location='cpu') 106 | self.load_state_dict(saved_model) 107 | print('*** loading model weight successfully!') 108 | except Exception: 109 | print('*** loading model weight fail!') 110 | 111 | if optimizer and lr_scheduler is not None: 112 | try: 113 | temp = torch.load(osp.join(self.model_save_dir, self.model_save_name + '_para.pkl'), map_location='cpu') 114 | lr_scheduler.load_state_dict(temp['lr_scheduler']) 115 | step = temp['step'] 116 | print('*** loading optimizer&lr_scheduler&step successfully!') 117 | return step 118 | except Exception: 119 | print('*** loading optimizer&lr_scheduler&step fail!') 120 | return 0 121 | 122 | @staticmethod 123 | def save_safely(file, dir_path, file_name): 124 | r""" 125 | save the file safely, if detect the file name conflict, 126 | save the new file first and remove the old file 127 | """ 128 | if not osp.exists(dir_path): 129 | os.mkdir(dir_path) 130 | print('*** dir not exist, created one') 131 | save_path = osp.join(dir_path, file_name) 132 | if osp.exists(save_path): 133 | temp_name = save_path + '.temp' 134 | torch.save(file, temp_name) 135 | os.remove(save_path) 136 | os.rename(temp_name, save_path) 137 | print('*** find the file conflict while saving, saved safely') 138 | else: 139 | torch.save(file, save_path) 140 | 141 | 142 | if __name__ == '__main__': 143 | from torch import optim 144 | from lr_scheduler import WarmUpMultiStepLR 145 | 146 | x = torch.rand(2, 3, 448, 448) 147 | for name in ['resnet18', 'resnet50', 'resnet101']: 148 | model, _ = get_backbone(name) 149 | step = 0 150 | yolo_model = Yolov1(backbone_name=name) 151 | optimizer = optim.SGD(yolo_model.parameters(), 152 | lr=LEARNING_RATE, 153 | momentum=MOMENTUM, 154 | weight_decay=WEIGHT_DECAY) 155 | scheduler = WarmUpMultiStepLR(optimizer, 156 | milestones=STEP_LR_SIZES, 157 | gamma=STEP_LR_GAMMA, 158 | factor=WARM_UP_FACTOR, 159 | 
num_iters=WARM_UP_NUM_ITERS) 160 | yolo_model.save_model(optimizer=optimizer, lr_scheduler=scheduler, step=step) 161 | yolo_model.load_model(optimizer=optimizer, lr_scheduler=scheduler) 162 | print(yolo_model.model_save_name) 163 | # y1 = model(x) 164 | # print(f'y1.shape:{y1.shape}') 165 | y2 = yolo_model(x) 166 | print(f'y2.shape:{y2.shape}') 167 | del yolo_model 168 | pass 169 | -------------------------------------------------------------------------------- /src/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from config import GRID_NUM, DEVICE 3 | from dataset import VOC_CLASSES 4 | from model import Yolov1 5 | from matplotlib import pyplot as plt 6 | from torchvision import transforms 7 | import PIL 8 | from PIL import Image 9 | import numpy as np 10 | from os import path as osp 11 | import os 12 | from numpy.random import shuffle 13 | 14 | 15 | def draw_box(img_np, boxes_np, tags_np, scores_np=None, relative_coord=False, save_path=None): 16 | if scores_np is None: 17 | scores_np = [1.0 for i in tags_np] 18 | # img = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY) 19 | h, w, _ = img_np.shape 20 | if relative_coord and len(boxes_np) > 0: 21 | boxes_np = np.array([ 22 | boxes_np[:, 0] * w, 23 | boxes_np[:, 1] * h, 24 | boxes_np[:, 2] * w, 25 | boxes_np[:, 3] * h, 26 | ]).T 27 | plt.figure(figsize=(10, 10)) 28 | colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist() 29 | currentAxis = plt.gca() 30 | for box, tag, score in zip(boxes_np, tags_np, scores_np): 31 | from dataset import VOC_CLASSES as LABLES 32 | tag = int(tag) 33 | label_name = LABLES[tag] 34 | display_txt = '%s: %.2f' % (label_name, score) 35 | coords = (box[0], box[1]), box[2] - box[0] + 1, box[3] - box[1] + 1 36 | color = colors[tag] 37 | currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2)) 38 | currentAxis.text(box[0], box[1], display_txt, bbox={'facecolor': color, 'alpha': 0.5}) 39 | plt.imshow(img_np) 40 | if save_path is not None: 41 | # fig, ax = plt.subplots() 42 | fig = plt.gcf() 43 | fig.savefig(save_path) 44 | plt.cla() 45 | plt.clf() 46 | plt.close('all') 47 | else: 48 | plt.show() 49 | 50 | 51 | def decoder(pred, obj_thres=0.1): 52 | r""" 53 | :param pred: the output of the yolov1 model, should be tensor of [1, grid_num, grid_num, 30] 54 | :param obj_thres: the threshold of objectness 55 | :return: list of [c, [boxes, labels]], boxes is [:4], labels is [4] 56 | """ 57 | pred = pred.cpu() 58 | assert pred.shape[0] == 1 59 | # i for W, j for H 60 | res = [[] for i in range(len(VOC_CLASSES))] 61 | # print(res) 62 | for h in range(GRID_NUM): 63 | for w in range(GRID_NUM): 64 | better_box = pred[0, h, w, :5] if pred[0, h, w, 4] > pred[0, h, w, 9] else pred[0, h, w, 5:10] 65 | if better_box[4] < obj_thres: 66 | continue 67 | better_box_xyxy = torch.FloatTensor(better_box.size()) 68 | # print(f'grid(cx,cy), (w,h), obj:{better_box}') 69 | better_box_xyxy[:2] = better_box[:2] / float(GRID_NUM) - 0.5 * better_box[2:4] 70 | better_box_xyxy[2:4] = better_box[:2] / float(GRID_NUM) + 0.5 * better_box[2:4] 71 | better_box_xyxy[0:4:2] += (w / float(GRID_NUM)) 72 | better_box_xyxy[1:4:2] += (h / float(GRID_NUM)) 73 | better_box_xyxy = better_box_xyxy.clamp(max=1.0, min=0.0) 74 | score, cls = pred[0, h, w, 10:].max(dim=0) 75 | # print(f'pre_cls_shape:{pred[0, w, h, 10:].shape}') 76 | from dataset import VOC_CLASSES as LABELS 77 | # print(f'score:{score}\tcls:{cls}\ttag:{LABELS[cls]}') 78 | better_box_xyxy[4] = score * better_box[4] 79 | 
res[cls].append(better_box_xyxy) 80 | # print(res) 81 | for i in range(len(VOC_CLASSES)): 82 | if len(res[i]) > 0: 83 | # res[i] = [box.unsqueeze(0) for box in res[i]] 84 | res[i] = torch.stack(res[i], 0) 85 | else: 86 | res[i] = torch.tensor([]) 87 | # print(res) 88 | return res 89 | 90 | 91 | def _nms(boxes, scores, overlap=0.5, top_k=None): 92 | r""" 93 | Apply non-maximum suppression at test time to avoid detecting too many 94 | overlapping bounding boxes for a given object. 95 | Args: 96 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 97 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 98 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 99 | top_k: (int) The Maximum number of box preds to consider. 100 | Return: 101 | The indices of the kept boxes with respect to num_priors. 102 | """ 103 | # boxes = boxes.detach() 104 | # keep shape [num_prior] type: Long 105 | keep = scores.new(scores.size(0)).zero_().long() 106 | # print('keep.shape:{}'.format(keep.shape)) 107 | # tensor.numel()用于计算tensor里面包含元素的总数,i.e. shape[0]*shape[1]... 108 | if boxes.numel() == 0: 109 | return keep 110 | x1 = boxes[:, 0] 111 | y1 = boxes[:, 1] 112 | x2 = boxes[:, 2] 113 | y2 = boxes[:, 3] 114 | # print('x1:{}\ny1:{}\nx2:{}\ny2:{}'.format(x1, y1, x2, y2)) 115 | # area shape[prior_num], 代表每个prior框的面积 116 | area = torch.mul(x2 - x1, y2 - y1) 117 | v, idx = scores.sort(0) # sort in ascending order 118 | # print(f'idx:{idx}') 119 | # I = I[v >= 0.01] 120 | if top_k is not None: 121 | # indices of the top-k largest vals 122 | idx = idx[-top_k:] 123 | # keep = torch.Tensor() 124 | count = 0 125 | # Returns the total number of elements in the input tensor. 126 | while idx.numel() > 0: 127 | i = idx[-1] # index of current largest val 128 | # keep.append(i) 129 | keep[count] = i 130 | count += 1 131 | if idx.size(0) == 1: 132 | break 133 | idx = idx[:-1] # remove kept element from view 134 | # load bboxes of next highest vals 135 | # torch.index_select(input, dim, index, out=None) 136 | # 将input里面dim维度上序号为idx的元素放到out里面去 137 | # >>> x 138 | # tensor([[1, 2, 3], 139 | # [3, 4, 5]]) 140 | # >>> z=torch.index_select(x,0,torch.tensor([1,0])) 141 | # >>> z 142 | # tensor([[3, 4, 5], 143 | # [1, 2, 3]]) 144 | xx1 = x1[idx] 145 | # torch.index_select(x1, 0, idx, out=xx1) 146 | yy1 = y1[idx] 147 | # torch.index_select(y1, 0, idx, out=yy1) 148 | xx2 = x2[idx] 149 | # torch.index_select(x2, 0, idx, out=xx2) 150 | yy2 = y2[idx] 151 | # torch.index_select(y2, 0, idx, out=yy2) 152 | 153 | # store element-wise max with next highest score 154 | # 将除置信度最高的prior框外的所有框进行clip以计算inter大小 155 | # print(f'xx1.shape:{xx1.shape}') 156 | xx1 = torch.clamp(xx1, min=float(x1[i])) 157 | yy1 = torch.clamp(yy1, min=float(y1[i])) 158 | xx2 = torch.clamp(xx2, max=float(x2[i])) 159 | yy2 = torch.clamp(yy2, max=float(y2[i])) 160 | # w.resize_as_(xx2) 161 | # h.resize_as_(yy2) 162 | w = xx2 - xx1 163 | h = yy2 - yy1 164 | # check sizes of xx1 and xx2.. 
after each iteration 165 | w = torch.clamp(w, min=0.0) 166 | h = torch.clamp(h, min=0.0) 167 | inter = w * h 168 | # IoU = i / (area(a) + area(b) - i) 169 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 170 | union = (rem_areas - inter) + area[i] 171 | IoU = inter / union # store result in iou 172 | # keep only elements with an IoU <= overlap 173 | # torch.le===>less and equal to 174 | idx = idx[IoU.le(overlap)] 175 | # print(keep, count) 176 | # keep 包含置信度从大到小的prior框的indices,count表示数量 177 | # print('keep.shape:{},count:{}'.format(keep.shape, count)) 178 | return keep, count 179 | 180 | 181 | def img_to_tensor_batch(img_path, size=(448, 448)): 182 | img = Image.open(img_path) 183 | img_resize = img.resize(size, PIL.Image.BILINEAR) 184 | img_tensor = transforms.ToTensor()(img_resize).unsqueeze(0) 185 | # print(f'img_tensor:{img_tensor.shape}') 186 | # print(f'img_tensor:{img_tensor}') 187 | return img_tensor, img 188 | 189 | 190 | def predict_one_img(img_path, model): 191 | # model = Yolov1(backbone_name=backbone_name) 192 | # model.load_model() 193 | img_tensor, img = img_to_tensor_batch(img_path) 194 | boxes, tags, scores = predict(img_tensor, model) 195 | img = np.array(img) 196 | draw_box(img_np=img, boxes_np=boxes, scores_np=scores, tags_np=tags, relative_coord=True) 197 | 198 | 199 | def predict(img_tensor, model): 200 | model.eval() 201 | img_tensor, model = img_tensor.to(DEVICE), model.to(DEVICE) 202 | with torch.no_grad(): 203 | out = model(img_tensor) 204 | # out:list[tensor[, 5]] 205 | out = decoder(out, obj_thres=0.3) 206 | boxes, tags, scores = [], [], [] 207 | for cls, pred_target in enumerate(out): 208 | if pred_target.shape[0] > 0: 209 | # print(pred_target.shape) 210 | b = pred_target[:, :4] 211 | p = pred_target[:, 4] 212 | # print(b, p) 213 | keep_idx, count = _nms(b, p, overlap=0.5) 214 | # keep:[, 5] 215 | keep = pred_target[keep_idx] 216 | for box in keep[..., :4]: boxes.append(box) 217 | for tag in range(count): tags.append(torch.LongTensor([cls])) 218 | for score in keep[..., 4]: scores.append(score) 219 | # print(f'*** boxes:{boxes}\ntags:{tags}\nscores:{scores}') 220 | if len(boxes) > 0: 221 | boxes = torch.stack(boxes, 0).numpy() # .squeeze(dim=0) 222 | tags = torch.stack(tags, 0).numpy() # .squeeze(dim=0) 223 | scores = torch.stack(scores, 0).numpy() # .squeeze(dim=0) 224 | # print(f'*** boxes:{boxes}\ntags:{tags}\nscores:{scores}') 225 | else: 226 | boxes = torch.FloatTensor([]) 227 | tags = torch.LongTensor([]) # .squeeze(dim=0) 228 | scores = torch.FloatTensor([]) # .squeeze(dim=0) 229 | # img, boxes, tags, scores = np.array(img), np.array(boxes), np.array(tags), np.array(scores) 230 | return boxes, tags, scores 231 | 232 | 233 | if __name__ == '__main__': 234 | # test: 235 | # fake_pred = torch.rand(1, GRID_NUM, GRID_NUM, 30) 236 | # decoder(fake_pred) 237 | CONTINUE = False # continue from breakpoint 238 | model = Yolov1(backbone_name='resnet50') 239 | model.load_model() 240 | # predict_one_img('../test_img/000001.jpg', model) 241 | # test_img_dir = '../test_img' 242 | test_img_dir = '/Users/chenlinwei/Dataset/VOC0712/VOC2012test/JPEGImages' 243 | for root, dirs, files in os.walk(test_img_dir, topdown=True): 244 | if test_img_dir == root: 245 | print(root, dirs, files) 246 | files = [i for i in files if any([j in i for j in ['jpg', 'png', 'jpeg', 'gif', 'tiff']])] 247 | shuffle(files) 248 | if CONTINUE: 249 | with open(osp.join(test_img_dir, 'tested.txt'), 'a') as _: 250 | pass 251 | with open(osp.join(test_img_dir, 'tested.txt'), 'r') 
as txt: 252 | txt = txt.readlines() 253 | txt = [i.strip() for i in txt] 254 | print(txt) 255 | files = [i for i in files if i not in txt] 256 | for file in files: 257 | file_path = os.path.join(root, file) 258 | print(f'*** testing:{file_path}') 259 | predict_one_img(file_path, model) 260 | with open(osp.join(test_img_dir, 'tested.txt'), 'a') as txt: 261 | txt.write(file + '\n') 262 | else: 263 | for file in files: 264 | file_path = os.path.join(root, file) 265 | print(f'*** testing:{file_path}') 266 | predict_one_img(file_path, model) 267 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | import torch.backends.cudnn as cudnn 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import cv2 10 | from matplotlib import pyplot as plt 11 | from torchvision import transforms 12 | from PIL import Image 13 | from dataset import detection_collate, VOCDetection 14 | from torch.utils import data 15 | 16 | # TODO:cal the mAP 17 | def get_test_data_set(args, percent_coord=False, year=None): 18 | if year == '2007': 19 | image_sets = (('2007test', 'test'),) 20 | elif year == '2012': 21 | image_sets = (('2012test', 'test'),) 22 | else: 23 | image_sets = (('2007test', 'test'), ('2012test', 'test')) 24 | from augmentations import Yolov1TestAugmentation 25 | dataset = VOCDetection(root=args.voc_data_set_root, 26 | image_sets=image_sets, 27 | # transform=Yolov1Augmentation(size=YOLOv1_PIC_SIZE, percent_coord=percent_coord)) 28 | transform=Yolov1TestAugmentation(size=448, percent_coord=percent_coord)) 29 | return data.DataLoader(dataset, 30 | args.batch_size, 31 | num_workers=args.num_workers, 32 | shuffle=True, 33 | collate_fn=detection_collate, 34 | pin_memory=False) 35 | 36 | 37 | def data_set_test(): 38 | from predict import draw_box 39 | 40 | # global TEST_MODE 41 | from train import config_parser 42 | 43 | args = config_parser() 44 | data_set = get_test_data_set(args, percent_coord=True, year='2007') 45 | for _, (imgs, gt_boxes, gt_labels, gt_outs) in enumerate(data_set): 46 | # print(f'img:{imgs}') 47 | # print(f'targets:{targets}') 48 | print(f'gt_encode:{gt_outs}') 49 | for gt_out in gt_outs: 50 | print(f'gt_out:{gt_out.shape}') 51 | print(gt_out.nonzero().transpose(1, 0)) 52 | gt_out_nonzero_split = torch.split(gt_out.nonzero().transpose(1, 0), dim=0, split_size_or_sections=1) 53 | print(f'gt_out_nonzero_split:{gt_out_nonzero_split}') 54 | print(f'gt_out:{gt_out[gt_out_nonzero_split]}') 55 | for img, gt_box, gt_label in zip(imgs, gt_boxes, gt_labels): 56 | gt_box_np = gt_box.cpu().numpy() 57 | gt_label_np = gt_label.cpu().numpy() 58 | print(f'gt_label_np:{gt_label_np}') 59 | print(f'gt_box_np{gt_box_np.shape},gt_label_np:{gt_label_np.shape}') 60 | img_np = (img * 255.0).cpu().numpy().astype(np.uint8) 61 | # print(f'img_np:{img_np}') 62 | img_np = img_np.transpose(1, 2, 0) # [..., (2, 1, 0)] 63 | # img_np = cv2.cvtColor((img * 255.0).cpu().numpy(), cv2.COLOR_RGB2BGR) 64 | # print(img_np.shape) 65 | draw_box(img_np, 66 | gt_box_np, 67 | gt_label_np, 68 | relative_coord=True) 69 | 70 | 71 | if __name__ == '__main__': 72 | data_set_test() 73 | pass 74 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from 
dataset import get_voc_data_set, yolov1_data_encoder 3 | from config import DEVICE 4 | from model import Yolov1 5 | from config import * 6 | from lr_scheduler import WarmUpMultiStepLR 7 | import time 8 | from yolov1loss import Yolov1Loss 9 | 10 | 11 | # from dataset import 12 | 13 | def config_parser(): 14 | parser = argparse.ArgumentParser(description='YOLOv1 Training With Pytorch') 15 | # train_set = parser.add_mutually_exclusive_group() 16 | # data set and backbone settings 17 | parser.add_argument('--voc_data_set_root', default='/Users/chenlinwei/Dataset/VOC0712', 18 | help='data_set root directory path') 19 | parser.add_argument('--batch_size', default=2, type=int, 20 | help='Batch size for training') 21 | parser.add_argument('--num_workers', default=0, type=int, 22 | help='Number of workers used in dataloading') 23 | # checkpoint saving 24 | # parser.add_argument('--save_folder', default='./saved_model/', 25 | # help='Directory for saving checkpoint models') 26 | parser.add_argument('--save_step', default=100, type=int, 27 | help='Number of training steps between checkpoint saves') 28 | # resume training 29 | parser.add_argument('--backbone', default='resnet18', choices=['resnet18', 'resnet50', 'resnet101'], 30 | help='pre-trained base model name.') 31 | # optimizer settings 32 | # parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 33 | # help='initial learning rate') 34 | # parser.add_argument('--momentum', default=0.9, type=float, 35 | # help='Momentum value for optim') 36 | # parser.add_argument('--weight_decay', default=5e-4, type=float, 37 | # help='Weight decay for SGD') 38 | # parser.add_argument('--gamma', default=0.1, type=float, 39 | # help='Gamma update for SGD') 40 | parser.add_argument('--cuda', default=True, type=bool, 41 | help='use cuda or not') 42 | 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def train(args): 48 | model = Yolov1(backbone_name=args.backbone) 49 | optimizer = optim.SGD(model.parameters(), 50 | lr=LEARNING_RATE, 51 | momentum=MOMENTUM, 52 | weight_decay=WEIGHT_DECAY) 53 | scheduler = WarmUpMultiStepLR(optimizer, 54 | milestones=STEP_LR_SIZES, 55 | gamma=STEP_LR_GAMMA, 56 | warm_up_factor=WARM_UP_FACTOR, 57 | warm_up_iters=WARM_UP_NUM_ITERS) 58 | step = model.load_model(optimizer=optimizer, lr_scheduler=scheduler) 59 | model.to(DEVICE) 60 | model.train() 61 | criterion = Yolov1Loss() 62 | while step < NUM_STEPS_TO_FINISH: 63 | data_set = get_voc_data_set(args, percent_coord=True) 64 | t1 = time.perf_counter() 65 | for _, (imgs, gt_boxes, gt_labels, gt_outs) in enumerate(data_set): 66 | step += 1 67 | scheduler.step() 68 | imgs = imgs.to(DEVICE) 69 | gt_outs = gt_outs.to(DEVICE) 70 | model_outs = model(imgs) 71 | loss = criterion(model_outs, gt_outs) 72 | optimizer.zero_grad() 73 | loss.backward() 74 | optimizer.step() 75 | t2 = time.perf_counter() 76 | print('step:{} | loss:{:.8f} | time:{:.4f}'.format(step, loss.item(), t2 - t1)) 77 | t1 = time.perf_counter() 78 | if step != 0 and step % args.save_step == 0: 79 | model.save_model(step, optimizer, scheduler) 80 | 81 | 82 | if __name__ == '__main__': 83 | args = config_parser() 84 | train(args) 85 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linwei-Chen/YOLOv1-Pytorch/aa2a6c152087e9ab5a9f15aae6b0a95a07ce9ad9/src/utils.py -------------------------------------------------------------------------------- /src/yolov1loss.py:
-------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from config import DEVICE, GRID_NUM 6 | 7 | 8 | class Yolov1Loss(nn.Module): 9 | def __init__(self, s=GRID_NUM, b=2, l_coord=5, l_noobj=0.5): 10 | super(Yolov1Loss, self).__init__() 11 | self.S = float(s) 12 | self.B = int(b) 13 | self.l_coord = l_coord 14 | self.l_noobj = l_noobj 15 | 16 | def compute_iou(self, box1, box2): 17 | r"""Compute the intersection over union of two sets of boxes; each box is [x1,y1,x2,y2]. 18 | Args: 19 | box1: (tensor) bounding boxes, sized [N,4]. 20 | box2: (tensor) bounding boxes, sized [M,4]. 21 | Return: 22 | (tensor) iou, sized [N,M]. 23 | """ 24 | N = box1.size(0) 25 | M = box2.size(0) 26 | r''' 27 | torch.max(input, other, out=None) → Tensor 28 | Each element of the tensor input is compared with the corresponding element 29 | of the tensor other and an element-wise maximum is taken. 30 | ''' 31 | # left top 32 | lt = torch.max( 33 | box1[:, :2].unsqueeze(1).expand(N, M, 2), # [N,2] -> [N,1,2] -> [N,M,2] 34 | box2[:, :2].unsqueeze(0).expand(N, M, 2), # [M,2] -> [1,M,2] -> [N,M,2] 35 | ) 36 | # right bottom 37 | rb = torch.min( 38 | box1[:, 2:].unsqueeze(1).expand(N, M, 2), # [N,2] -> [N,1,2] -> [N,M,2] 39 | box2[:, 2:].unsqueeze(0).expand(N, M, 2), # [M,2] -> [1,M,2] -> [N,M,2] 40 | ) 41 | 42 | wh = rb - lt # [N,M,2] 43 | wh[wh < 0] = 0 # clip at 0 44 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 45 | 46 | area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]) # [N,] 47 | area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) # [M,] 48 | area1 = area1.unsqueeze(1).expand_as(inter) # [N,] -> [N,1] -> [N,M] 49 | area2 = area2.unsqueeze(0).expand_as(inter) # [M,] -> [1,M] -> [N,M] 50 | 51 | iou = inter / (area1 + area2 - inter) 52 | return iou 53 | 54 | def forward(self, pred_tensor, target_tensor): 55 | r""" 56 | pred_tensor: (tensor) size(batchsize,S,S,Bx5+20=30) [x,y,w,h,c] 57 | target_tensor: (tensor) size(batchsize,S,S,30) 58 | """ 59 | N = pred_tensor.size()[0] 60 | # cells that contain an object 61 | coo_mask = target_tensor[:, :, :, 4] > 0 62 | # cells with no object 63 | noo_mask = target_tensor[:, :, :, 4] == 0 64 | coo_mask = coo_mask.unsqueeze(-1).expand_as(target_tensor) 65 | noo_mask = noo_mask.unsqueeze(-1).expand_as(target_tensor) 66 | # coo_pred:tensor[, 30] 67 | coo_pred = pred_tensor[coo_mask].view(-1, 30) 68 | # box[x1,y1,w1,h1,c1], [x2,y2,w2,h2,c2] 69 | box_pred = coo_pred[:, :10].contiguous().view(-1, 5) 70 | # class[...]
71 | class_pred = coo_pred[:, 10:] 72 | 73 | coo_target = target_tensor[coo_mask].view(-1, 30) 74 | box_target = coo_target[:, :10].contiguous().view(-1, 5) 75 | class_target = coo_target[:, 10:] 76 | 77 | # compute not contain obj loss 78 | noo_pred = pred_tensor[noo_mask].view(-1, 30) 79 | noo_target = target_tensor[noo_mask].view(-1, 30) 80 | # for cells with no object, only the two confidence entries (indices 4 and 9) contribute to the loss 81 | noo_pred_mask = torch.zeros(noo_pred.size(), dtype=torch.bool, device=DEVICE) 82 | noo_pred_mask.zero_() 83 | noo_pred_mask[:, 4] = 1 84 | noo_pred_mask[:, 9] = 1 85 | noo_pred_c = noo_pred[noo_pred_mask] 86 | noo_target_c = noo_target[noo_pred_mask] 87 | nooobj_loss = F.mse_loss(noo_pred_c, noo_target_c, reduction='sum') 88 | 89 | # compute contain obj loss 90 | coo_response_mask = torch.zeros(box_target.size(), dtype=torch.bool, device=DEVICE) 91 | coo_response_mask.zero_() 92 | coo_not_response_mask = torch.zeros(box_target.size(), dtype=torch.bool, device=DEVICE) 93 | coo_not_response_mask.zero_() 94 | box_target_iou = torch.zeros(box_target.size()).to(DEVICE) 95 | # pick the responsible box out of the two predictions in each cell 96 | for i in range(0, box_target.size()[0], 2): # choose the best iou box 97 | box1 = box_pred[i:i + 2] 98 | box1_xyxy = torch.zeros_like(box1) # same dtype/device as the predictions 99 | # (x,y,w,h) 100 | box1_xyxy[:, :2] = box1[:, :2] / self.S - 0.5 * box1[:, 2:4] 101 | box1_xyxy[:, 2:4] = box1[:, :2] / self.S + 0.5 * box1[:, 2:4] 102 | box2 = box_target[i].view(-1, 5) 103 | box2_xyxy = torch.zeros_like(box2) 104 | box2_xyxy[:, :2] = box2[:, :2] / self.S - 0.5 * box2[:, 2:4] 105 | box2_xyxy[:, 2:4] = box2[:, :2] / self.S + 0.5 * box2[:, 2:4] 106 | # iou(pred_box[2,], target_box[2,]) 107 | iou = self.compute_iou(box1_xyxy[:, :4], box2_xyxy[:, :4]) 108 | # the ground-truth box is matched to the predicted box with the larger IoU 109 | max_iou, max_index = iou.max(0) 110 | # print(f'max_iou:{max_iou}, max_index:{max_index}') 111 | max_index = max_index.to(DEVICE) 112 | 113 | coo_response_mask[i + max_index] = 1 114 | coo_not_response_mask[i + 1 - max_index] = 1 115 | ##### 116 | # we want the confidence score to equal the 117 | # intersection over union (IOU) between the predicted box 118 | # and the ground truth 119 | ##### 120 | box_target_iou[i + max_index, torch.LongTensor([4]).to(DEVICE)] = max_iou.to(DEVICE) 121 | 122 | box_target_iou = box_target_iou.to(DEVICE) 123 | # 1.response loss 124 | box_pred_response = box_pred[coo_response_mask].view(-1, 5) 125 | box_target_response_iou = box_target_iou[coo_response_mask].view(-1, 5) 126 | box_target_response = box_target[coo_response_mask].view(-1, 5) 127 | contain_loss = F.mse_loss(box_pred_response[:, 4], box_target_response_iou[:, 4], reduction='sum') 128 | loc_loss = F.mse_loss(box_pred_response[:, :2], box_target_response[:, :2], reduction='sum') + F.mse_loss( 129 | torch.sqrt(box_pred_response[:, 2:4]), torch.sqrt(box_target_response[:, 2:4]), reduction='sum') 130 | # 2.not response loss 131 | box_pred_not_response = box_pred[coo_not_response_mask].view(-1, 5) 132 | box_target_not_response = box_target[coo_not_response_mask].view(-1, 5) 133 | box_target_not_response[:, 4] = 0 134 | # not_contain_loss = F.mse_loss(box_pred_response[:,4],box_target_response[:,4],size_average=False) 135 | 136 | # note: some reference implementations mistakenly use box_pred_response here; the not-response boxes are intended 137 | not_contain_loss = F.mse_loss(box_pred_not_response[:, 4], box_target_not_response[:, 4], reduction='sum') 138 | 139 | # 3.class loss 140 | class_loss = F.mse_loss(class_pred, class_target, reduction='sum') 141 | 142 | return (self.l_coord * loc_loss + 2 * contain_loss + 143 | not_contain_loss + self.l_noobj * nooobj_loss + class_loss) / N 144 | 145 | 146 | if __name__ ==
'__main__': 147 | x1 = torch.rand(1, 7, 7, 30) 148 | x2 = torch.rand(1, 7, 7, 30) 149 | # smoke test: run the loss on random prediction/target tensors 150 | x = Yolov1Loss()(x1, x2) 151 | print(x.item()) 152 | --------------------------------------------------------------------------------
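Both the NMS routine in predict.py and `Yolov1Loss.compute_iou` rely on the same formula, IoU = inter / (area(a) + area(b) - inter). The standalone sketch below (not part of the repository; the helper name `iou_xyxy` and the example boxes are made up for illustration) recomputes it for two concrete boxes so the expected value can be checked by hand:

```
import torch

def iou_xyxy(box1, box2):
    # box1: [N,4], box2: [M,4], boxes given as [x1, y1, x2, y2]
    n, m = box1.size(0), box2.size(0)
    lt = torch.max(box1[:, :2].unsqueeze(1).expand(n, m, 2),
                   box2[:, :2].unsqueeze(0).expand(n, m, 2))   # top-left of intersection
    rb = torch.min(box1[:, 2:].unsqueeze(1).expand(n, m, 2),
                   box2[:, 2:].unsqueeze(0).expand(n, m, 2))   # bottom-right of intersection
    wh = (rb - lt).clamp(min=0)                                # empty overlaps contribute zero area
    inter = wh[..., 0] * wh[..., 1]
    area1 = ((box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])).unsqueeze(1)
    area2 = ((box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])).unsqueeze(0)
    return inter / (area1 + area2 - inter)

a = torch.tensor([[0., 0., 2., 2.]])  # area 4
b = torch.tensor([[1., 1., 3., 3.]])  # area 4, overlaps a in a 1x1 square
print(iou_xyxy(a, b))                 # tensor([[0.1429]]) == 1 / (4 + 4 - 1)
```

The same broadcasting pattern ([N,1,2] against [1,M,2]) is what lets compute_iou score every predicted box against every target box in a single call.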
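The loss also assumes the S x S x 30 layout described in the forward docstring: each grid cell stores two [x, y, w, h, c] boxes followed by 20 class scores. A minimal sketch of that slicing on a made-up tensor (the variable names here are illustrative, not taken from the repository):

```
import torch

S, B, C = 7, 2, 20                              # grid size, boxes per cell, VOC classes
grid = torch.rand(S, S, B * 5 + C)              # one image's 7x7x30 prediction grid
boxes = grid[..., :B * 5].reshape(S, S, B, 5)   # two [x, y, w, h, conf] boxes per cell
classes = grid[..., B * 5:]                     # 20 class scores per cell
print(boxes.shape, classes.shape)               # torch.Size([7, 7, 2, 5]) torch.Size([7, 7, 20])
```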