├── LICENSE
├── README.md
├── cfg
│   └── voc.py
├── data_utils
│   ├── build.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── _utils.py
│   │   └── voc.py
│   ├── evaluate
│   │   └── voc_eval.py
│   └── sampler.py
├── dataset
│   └── voc2007
│       └── __init__.py
├── detect.py
├── dist_comm.py
├── fit_voc_to_yolo.py
├── imgs
│   ├── demo1.png
│   ├── demo2.png
│   └── yolo.PNG
├── structure
│   ├── __init__.py
│   └── bounding_box.py
├── train.py
├── yolo
│   ├── __init__.py
│   ├── darknet.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── loss.py
│   └── yolov1.py
└── 中文.md

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 fantastic_levio
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # yolo_1_pytorch
2 | 
3 | ###### The simplest implementation of YOLO v1 in PyTorch √
4 | ##### Language: [中文](中文.md)
5 | ##### Paper: [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/pdf/1506.02640.pdf)
6 | ##### CSDN blog: [Blog walkthrough (in Chinese)](https://muzhan.blog.csdn.net/article/details/82588059)
7 | This repo is a concise implementation of YOLO v1. You can easily train the model and visualize the results.
8 | 
9 | ![img](https://github.com/leviome/yolo_1_pytorch/blob/master/imgs/yolo.PNG)
10 | | output tensor: S×S×(B∗5+C) | S: num of grids | B: num of boxes per grid | C: num of classes |
11 | |---|---|---|---|
12 | | 7x7x(2*5+20) | 7 | 2 | 20 |
13 | 
14 | ---
15 | ```
16 | git clone https://github.com/leviome/yolo_1_pytorch.git
17 | cd yolo_1_pytorch
18 | ```
19 | 
20 | Environment:
21 | ---
22 | - *Python3*
23 | - *Pytorch>=1.3*
24 | - *cv2*
25 | - *matplotlib*
26 | 
27 | Dataset preparation
28 | ---
29 | 1. Download the VOC2007 dataset:
30 | ```
31 | wget -c http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar
32 | wget -c http://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar
33 | wget -c http://pjreddie.com/media/files/VOCdevkit_08-Jun-2007.tar
34 | ```
35 | 2. Extract all the tars:
36 | ```
37 | tar xvf VOCtrainval_06-Nov-2007.tar
38 | tar xvf VOCtest_06-Nov-2007.tar
39 | tar xvf VOCdevkit_08-Jun-2007.tar
40 | ```
41 | 3. 
put the data into dataset/voc2007 and make the folder structure look like: 42 | ``` 43 | dataset 44 | ├── voc2007 45 | │   ├── Annotations 46 | │   ├── ImageSets 47 | │   ├── JPEGImages 48 | │   ├── Label 49 | │   ├── SegmentationClass 50 | │   └── SegmentationObject 51 | └── voc2012 52 | ``` 53 | 4. fit voc dataset to yolo model as pytorch dataset format: 54 | ``` 55 | python fit_voc_to_yolo.py 56 | ``` 57 | Train 58 | --- 59 | ``` 60 | python train.py 61 | ``` 62 | Detect single image 63 | --- 64 | ``` 65 | python detect.py 66 | ``` 67 | Demo 68 | --- 69 | ![imgs](https://github.com/leviome/yolo_1_pytorch/blob/master/imgs/demo1.png) 70 | ![imgs](https://github.com/leviome/yolo_1_pytorch/blob/master/imgs/demo2.png) 71 | -------------------------------------------------------------------------------- /cfg/voc.py: -------------------------------------------------------------------------------- 1 | train_cfg = dict() 2 | train_cfg['lr'] = [0.5e-3, 1e-5] 3 | train_cfg['epochs'] = 750 4 | train_cfg['milestone'] = [40, 60] 5 | train_cfg['gamma'] = 0.1 6 | train_cfg['batch_size'] = 16 7 | train_cfg['gpu_id'] = [0, 1] 8 | train_cfg['out_dir'] = 'experiment/VOCNet' 9 | train_cfg['resume'] = False 10 | train_cfg['use_sgd'] = True 11 | train_cfg['device'] = 'cuda' 12 | 13 | train_cfg['dataroot'] = './' 14 | 15 | train_cfg['img_size'] = [448] 16 | 17 | train_cfg['classes'] = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 18 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 19 | "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 20 | 21 | model_cfg = dict() 22 | model_cfg['model_type'] = 'DarkNet' 23 | model_cfg['class_num'] = len(train_cfg['classes']) 24 | model_cfg['backbone'] = 19 25 | model_cfg['box_num'] = 2 26 | model_cfg['ceil_size'] = 7 27 | model_cfg['pretrained'] = None # 'darknet19_448.conv.23' 28 | model_cfg['l_coord'] = 3 29 | model_cfg['l_obj'] = 3 30 | model_cfg['l_noobj'] = 0.5 31 | model_cfg['conv_mode'] = True 32 | 33 | cfg = dict() 34 | 35 | cfg['train_cfg'] = train_cfg 36 | cfg['model_cfg'] = model_cfg 37 | 38 | -------------------------------------------------------------------------------- /data_utils/build.py: -------------------------------------------------------------------------------- 1 | from data_utils.datasets import * 2 | from torch.utils import data 3 | import torch 4 | from data_utils.sampler import TrainingSampler, InferenceSampler 5 | import cv2 6 | import random 7 | import numpy as np 8 | import math 9 | 10 | 11 | class MutilScaleBatchCollator(object): 12 | def __init__(self, img_size, train): 13 | self.img_size = [a for a in range(min(img_size), max(img_size) + 32, 32)] 14 | # print(self.img_size) 15 | self.train = train 16 | 17 | def normlize(self, img): 18 | 19 | img = np.float32(img) if img.dtype != np.float32 else img.copy() 20 | 21 | return img / 255. 
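    # Note (added comment): normlize() above only casts the image to float32 and scales
    # pixel values to [0, 1]; no mean/std normalization is applied. process_image() below
    # resizes every image in the batch to the sampled size, rounded up to a multiple of 32,
    # presumably so the input stays divisible by the backbone's overall stride.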
22 | 23 | def process_image(self, meta, sized): 24 | images = [] 25 | 26 | for info in meta: 27 | img = info['img'] 28 | padding_img = img.copy() 29 | info['padding_width'] = padding_img.shape[1] 30 | info['padding_height'] = padding_img.shape[0] 31 | img_size = [sized, sized] 32 | 33 | img_size[0] = math.ceil(img_size[0] / 32) * 32 34 | img_size[1] = math.ceil(img_size[1] / 32) * 32 35 | img_size = (img_size[0], img_size[1]) 36 | 37 | padding_img = cv2.resize(padding_img, img_size) 38 | padding_img = self.normlize(padding_img) 39 | padding_img = torch.from_numpy(padding_img).permute(2, 0, 1).float() 40 | images.append(padding_img) 41 | 42 | return images 43 | 44 | def __call__(self, batch): 45 | meta = list(batch) 46 | if self.train: 47 | sized = random.choice(self.img_size) 48 | else: 49 | sized = sum(self.img_size) / float(len(self.img_size)) 50 | 51 | images = self.process_image(meta, sized) 52 | batch_imgs = torch.cat([a.unsqueeze(0) for a in images]) 53 | 54 | return batch_imgs, meta 55 | 56 | 57 | def make_dist_voc_loader(list_path, train=False, img_size=[(448, 448)], 58 | batch_size=4, num_workers=4): 59 | dataset = VOCDatasets(list_path, train) 60 | collator = MutilScaleBatchCollator(img_size, train) 61 | if train: 62 | sampler = TrainingSampler(len(dataset), shuffle=train) 63 | else: 64 | sampler = InferenceSampler(len(dataset)) 65 | 66 | data_loader = data.DataLoader(dataset=dataset, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | collate_fn=collator, 70 | sampler=sampler, 71 | pin_memory=True 72 | ) 73 | 74 | return data_loader 75 | -------------------------------------------------------------------------------- /data_utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from data_utils.datasets.voc import VOCDatasets -------------------------------------------------------------------------------- /data_utils/datasets/_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | import random 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | def get_max_overlap(box1, box2): 10 | # get overlap 11 | lt = np.maximum(box1[None, :2], box2[:, :2]) # [N,2] 12 | rb = np.minimum(box1[None, 2:], box2[:, 2:]) # [N,2] 13 | 14 | TO_REMOVE = 1 15 | wh = np.clip(rb - lt, TO_REMOVE, None) # [N,2] 16 | inter = wh[:, 0] * wh[:, 1] # [N] 17 | area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) 18 | area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) 19 | 20 | 21 | iou = inter / (min(area1 , np.min(area2))+1) 22 | return np.max(iou) 23 | 24 | def get_affine_transform(center, 25 | scale, 26 | rot, 27 | output_size, 28 | shift=np.array([0, 0], dtype=np.float32), 29 | inv=0): 30 | def get_3rd_point(a, b): 31 | direct = a - b 32 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 33 | 34 | def get_dir(src_point, rot_rad): 35 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 36 | src_result = [0, 0] 37 | src_result[0] = src_point[0] * cs - src_point[1] * sn 38 | src_result[1] = src_point[0] * sn + src_point[1] * cs 39 | return src_result 40 | 41 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 42 | scale = np.array([scale, scale], dtype=np.float32) 43 | 44 | scale_tmp = scale 45 | src_w = scale_tmp[0] 46 | dst_w = output_size[0] 47 | dst_h = output_size[1] 48 | 49 | rot_rad = np.pi * rot / 180 50 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 51 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 52 
| 53 | src = np.zeros((3, 2), dtype=np.float32) 54 | dst = np.zeros((3, 2), dtype=np.float32) 55 | src[0, :] = center + scale_tmp * shift 56 | src[1, :] = center + src_dir + scale_tmp * shift 57 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 58 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir 59 | 60 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 61 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 62 | 63 | if inv: 64 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 65 | else: 66 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 67 | 68 | return trans 69 | 70 | 71 | def get_random_crop_tran(img): 72 | def _get_border(border, size): 73 | i = 1 74 | while size - border // i <= border // i: 75 | i *= 2 76 | return border // i 77 | 78 | c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32) 79 | height, width = img.shape[0], img.shape[1] 80 | input_h = (height) + 1 81 | input_w = (width) + 1 82 | 83 | s = np.array([input_w, input_h], dtype=np.float32) 84 | s = s * np.random.choice(np.arange(0.8, 1.2, 0.1)) 85 | w_border = _get_border(128, img.shape[1]) 86 | h_border = _get_border(128, img.shape[0]) 87 | c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border) 88 | c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border) 89 | 90 | trans_input = get_affine_transform( 91 | c, s, 0, [img.shape[1], img.shape[0]]) 92 | 93 | return trans_input 94 | 95 | 96 | def random_affine(img, targets, degrees=10, translate=.1, scale=.1, shear=10, border=0): 97 | # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10)) 98 | # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4 99 | 100 | height = img.shape[0] + border * 2 101 | width = img.shape[1] + border * 2 102 | 103 | # Rotation and Scale 104 | R = np.eye(3) 105 | a = random.uniform(-degrees, degrees) 106 | # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations 107 | s = random.uniform(1 - scale, 1 + scale) 108 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s) 109 | 110 | # Translation 111 | T = np.eye(3) 112 | T[0, 2] = random.uniform(-translate, translate) * img.shape[0] + border # x translation (pixels) 113 | T[1, 2] = random.uniform(-translate, translate) * img.shape[1] + border # y translation (pixels) 114 | 115 | # Shear 116 | S = np.eye(3) 117 | S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) 118 | S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) 119 | 120 | # Combined rotation matrix 121 | M = S @ T @ R # ORDER IS IMPORTANT HERE!! 
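    # Since M is applied to column vectors as M @ p, the rotation/scale (R) acts first,
    # then the translation (T), then the shear (S). The warp below is skipped when M is
    # the identity and no border padding was requested.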
122 | changed = (border != 0) or (M != np.eye(3)).any() 123 | if changed: 124 | img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_AREA, borderValue=(128, 128, 128)) 125 | targets.warpAffine(M[:2], (width, height)) 126 | 127 | 128 | return img, targets 129 | 130 | 131 | class Grid(object): 132 | def __init__(self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.): 133 | self.use_h = use_h 134 | self.use_w = use_w 135 | self.rotate = rotate 136 | self.offset = offset 137 | self.ratio = ratio 138 | self.mode = mode 139 | self.st_prob = prob 140 | self.prob = prob 141 | 142 | def set_prob(self, epoch, max_epoch): 143 | self.prob = self.st_prob * epoch / max_epoch 144 | 145 | def __call__(self, img, label): 146 | if np.random.rand() > self.prob: 147 | return img, label 148 | h = img.size(1) 149 | w = img.size(2) 150 | self.d1 = 2 151 | self.d2 = min(h, w) 152 | hh = int(1.5 * h) 153 | ww = int(1.5 * w) 154 | d = np.random.randint(self.d1, self.d2) 155 | # d = self.d 156 | # self.l = int(d*self.ratio+0.5) 157 | if self.ratio == 1: 158 | self.l = np.random.randint(1, d) 159 | else: 160 | self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) 161 | mask = np.ones((hh, ww), np.float32) 162 | st_h = np.random.randint(d) 163 | st_w = np.random.randint(d) 164 | if self.use_h: 165 | for i in range(hh // d): 166 | s = d * i + st_h 167 | t = min(s + self.l, hh) 168 | mask[s:t, :] *= 0 169 | if self.use_w: 170 | for i in range(ww // d): 171 | s = d * i + st_w 172 | t = min(s + self.l, ww) 173 | mask[:, s:t] *= 0 174 | 175 | r = np.random.randint(self.rotate) 176 | mask = Image.fromarray(np.uint8(mask)) 177 | mask = mask.rotate(r) 178 | mask = np.asarray(mask) 179 | # mask = 1*(np.random.randint(0,3,[hh,ww])>0) 180 | mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2 + w] 181 | 182 | mask_cp = mask.copy() 183 | 184 | mask_tensor = torch.from_numpy(mask_cp).float() 185 | if self.mode == 1: 186 | mask_tensor = 1 - mask_tensor 187 | 188 | mask_tensor = mask_tensor.expand_as(img) 189 | if self.offset: 190 | offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float() 191 | offset = (1 - mask_tensor) * offset 192 | img = img * mask_tensor + offset 193 | else: 194 | img = img * mask_tensor 195 | 196 | return img, label 197 | 198 | 199 | -------------------------------------------------------------------------------- /data_utils/datasets/voc.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | from structure.bounding_box import * 3 | import cv2 4 | import random 5 | from data_utils.datasets._utils import get_random_crop_tran, get_max_overlap, random_affine, Grid 6 | import torch 7 | 8 | 9 | class VOCDatasets(data.Dataset): 10 | def __init__(self, list_file, train=False): 11 | 12 | self.train = train 13 | self.label_path = [] 14 | self.image_path = [] 15 | 16 | with open(list_file) as f: 17 | lines = f.readlines() 18 | self.num_samples = len(lines) 19 | for line in lines: 20 | splited = line.strip().split(' ') 21 | self.image_path.append(splited[0]) 22 | self.label_path.append(splited[1]) 23 | 24 | self.grid = Grid(True, True, rotate=1, offset=0, ratio=0.5, mode=1, prob=0.7) 25 | 26 | def _get_label(self, file, size): 27 | tmp = open(file, 'r') 28 | gt = [] 29 | labels = [] 30 | difficult = [] 31 | for f in tmp.readlines(): 32 | a = list(map(float, f.strip().split(','))) 33 | gt.append(a[0:4]) 34 | labels.append(int(a[4])) 35 | difficult.append(0) 36 | tmp.close() 37 | gt_list = 
BoxList(gt, size) 38 | gt_list.add_field('labels', labels) 39 | gt_list.add_field('difficult', np.asarray(difficult)) 40 | return gt_list 41 | 42 | def _get_img(self, img_file): 43 | img = cv2.imread(img_file)[:, :, ::-1].copy() 44 | return img 45 | 46 | def get_data(self, idx): 47 | file_name = self.image_path[idx] 48 | gt_path = self.label_path[idx] 49 | img = self._get_img(file_name) 50 | gt_list = self._get_label(gt_path, (img.shape[1], img.shape[0])) 51 | 52 | if self.train: 53 | img, gt_list = self._data_aug(img, gt_list) 54 | img = img.copy() 55 | 56 | meta = dict() 57 | meta['fileID'] = gt_path.split('.')[0].split('/')[-1].replace('.txt', '') 58 | meta['img_width'] = img.shape[1] 59 | meta['img_height'] = img.shape[0] 60 | meta['boxlist'] = gt_list.copy() 61 | meta['img'] = img 62 | 63 | return meta 64 | 65 | def _data_aug(self, img, gt_list): 66 | if random.random() > 0.5: 67 | img = cv2.flip(img, 1) 68 | gt_list.flip(1) 69 | 70 | if random.random() > 0.5: 71 | img = torch.from_numpy(img) / 255. 72 | img = img.permute((2, 0, 1)) 73 | 74 | img, label = self.grid(img, gt_list) 75 | img = img.permute((1, 2, 0)) 76 | img = img * 255 77 | img = img.numpy() 78 | img = img.astype(np.uint8) 79 | 80 | if random.random() > 0.2: 81 | img, gt_list = random_affine(img, gt_list, degrees=5, translate=.1, scale=.1, shear=2, border=0) 82 | 83 | if random.random() > 0.2: 84 | matrix = get_random_crop_tran(img) 85 | h, w, _ = img.shape 86 | img = cv2.warpAffine(img, matrix, (w, h)) 87 | gt_list.warpAffine(matrix, (w, h)) 88 | 89 | return img, gt_list 90 | 91 | def __getitem__(self, idx): 92 | meta = self.get_data(idx) 93 | return meta 94 | 95 | def __len__(self): 96 | return self.num_samples 97 | -------------------------------------------------------------------------------- /data_utils/evaluate/voc_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import numpy as np 4 | from structure.bounding_box import BoxList 5 | from structure.bounding_box import boxlist_iou 6 | 7 | 8 | def voc_evaluation(dataset, predictions, output_folder, box_only, **_): 9 | return do_voc_evaluation( 10 | dataset=dataset, 11 | predictions=predictions, 12 | output_folder=output_folder, 13 | # logger=logger, 14 | ) 15 | 16 | 17 | def do_voc_evaluation(dataset, predictions, output_folder, logger=None): 18 | # TODO need to make the use_07_metric format available 19 | # for the user to choose 20 | pred_boxlists = [] 21 | gt_boxlists = [] 22 | 23 | for image_id, prediction in predictions: 24 | img_info = dataset.get_img_info(image_id) 25 | image_width = img_info["width"] 26 | image_height = img_info["height"] 27 | prediction.resize((image_width, image_height)) 28 | pred_boxlists.append(prediction) 29 | gt_boxlist = dataset.get_groundtruth(image_id) 30 | gt_boxlists.append(gt_boxlist) 31 | 32 | result = eval_detection_voc( 33 | pred_boxlists=pred_boxlists, 34 | gt_boxlists=gt_boxlists, 35 | iou_thresh=0.5, 36 | use_07_metric=True, 37 | ) 38 | result_str = "mAP: {:.4f}\n".format(result["map"]) 39 | 40 | for i, ap in enumerate(result["ap"]): 41 | # TODO no background 42 | # if i == 0: # skip background 43 | # continue 44 | result_str += "{:<16}: {:.4f}\n".format( 45 | dataset.map_class_id_to_class_name(i), ap 46 | ) 47 | print(result_str) 48 | # logger.info(result_str) 49 | if output_folder: 50 | with open(os.path.join(output_folder, "result.txt"), "w") as fid: 51 | fid.write(result_str) 52 | return result 53 | 54 | 55 | def 
eval_detection_voc(pred_boxlists, gt_boxlists, iou_thresh=0.1, use_07_metric=False): 56 | """Evaluate on voc dataset. 57 | Args: 58 | pred_boxlists(list[BoxList]): pred boxlist, has labels and scores fields. 59 | gt_boxlists(list[BoxList]): ground truth boxlist, has labels field. 60 | iou_thresh: iou thresh 61 | use_07_metric: boolean 62 | Returns: 63 | dict represents the results 64 | """ 65 | assert len(gt_boxlists) == len( 66 | pred_boxlists 67 | ), "Length of gt and pred lists need to be same." 68 | 69 | prec, rec = calc_detection_voc_prec_rec( 70 | pred_boxlists=pred_boxlists, gt_boxlists=gt_boxlists, iou_thresh=iou_thresh 71 | ) 72 | ap = calc_detection_voc_ap(prec, rec, use_07_metric=use_07_metric) 73 | return {"ap": ap, "map": np.nanmean(ap)} 74 | 75 | 76 | def calc_detection_voc_prec_rec(gt_boxlists, pred_boxlists, iou_thresh=0.1): 77 | """Calculate precision and recall based on evaluation code of PASCAL VOC. 78 | This function calculates precision and recall of 79 | predicted bounding boxes obtained from a dataset which has :math:`N` 80 | images. 81 | The code is based on the evaluation code used in PASCAL VOC Challenge. 82 | """ 83 | n_pos = defaultdict(int) 84 | score = defaultdict(list) 85 | match = defaultdict(list) 86 | 87 | for gt_boxlist, pred_boxlist in zip(gt_boxlists, pred_boxlists): 88 | 89 | pred_bbox = pred_boxlist.box 90 | pred_label = pred_boxlist.get_field("labels") 91 | pred_score = pred_boxlist.get_field("scores") 92 | 93 | gt_bbox = gt_boxlist.box 94 | gt_label = gt_boxlist.get_field("labels") 95 | gt_difficult = gt_boxlist.get_field("difficult") 96 | 97 | for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)): 98 | pred_mask_l = pred_label == l 99 | 100 | pred_bbox_l = pred_bbox[pred_mask_l] 101 | pred_score_l = pred_score[pred_mask_l] 102 | 103 | # sort by score 104 | order = pred_score_l.argsort()[::-1] 105 | pred_bbox_l = pred_bbox_l[order] 106 | pred_score_l = pred_score_l[order] 107 | 108 | gt_mask_l = gt_label == l 109 | gt_bbox_l = gt_bbox[gt_mask_l] 110 | gt_difficult_l = gt_difficult[gt_mask_l] 111 | 112 | n_pos[l] += np.logical_not(gt_difficult_l).sum() 113 | score[l].extend(pred_score_l) 114 | 115 | if len(pred_bbox_l) == 0: 116 | continue 117 | if len(gt_bbox_l) == 0: 118 | match[l].extend((0,) * pred_bbox_l.shape[0]) 119 | continue 120 | 121 | # VOC evaluation follows integer typed bounding boxes. 
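            # In the VOC devkit, box coordinates are inclusive pixel indices, so a box's
            # width is x2 - x1 + 1; adding 1 to the max corner below reproduces that
            # convention before computing the IoU.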
122 | pred_bbox_l = pred_bbox_l.copy() 123 | pred_bbox_l[:, 2:] += 1 124 | gt_bbox_l = gt_bbox_l.copy() 125 | gt_bbox_l[:, 2:] += 1 126 | assert len(pred_bbox_l.shape) == 2 127 | assert len(gt_bbox_l.shape) == 2 128 | iou = boxlist_iou( 129 | BoxList(pred_bbox_l, gt_boxlist.size), 130 | BoxList(gt_bbox_l, gt_boxlist.size), 131 | ) 132 | gt_index = iou.argmax(axis=1) 133 | # set -1 if there is no matching ground truth 134 | gt_index[iou.max(axis=1) < iou_thresh] = -1 135 | del iou 136 | 137 | selec = np.zeros(gt_bbox_l.shape[0], dtype=bool) 138 | for gt_idx in gt_index: 139 | if gt_idx >= 0: 140 | if gt_difficult_l[gt_idx]: 141 | match[l].append(-1) 142 | else: 143 | if not selec[gt_idx]: 144 | match[l].append(1) 145 | else: 146 | match[l].append(0) 147 | selec[gt_idx] = True 148 | else: 149 | match[l].append(0) 150 | 151 | n_fg_class = max(n_pos.keys()) + 1 152 | prec = [None] * n_fg_class 153 | rec = [None] * n_fg_class 154 | 155 | for l in n_pos.keys(): 156 | score_l = np.array(score[l]) 157 | match_l = np.array(match[l], dtype=np.int8) 158 | 159 | order = score_l.argsort()[::-1] 160 | match_l = match_l[order] 161 | 162 | tp = np.cumsum(match_l == 1) 163 | fp = np.cumsum(match_l == 0) 164 | 165 | # If an element of fp + tp is 0, 166 | # the corresponding element of prec[l] is nan. 167 | prec[l] = tp / (fp + tp) 168 | # If n_pos[l] is 0, rec[l] is None. 169 | if n_pos[l] > 0: 170 | rec[l] = tp / n_pos[l] 171 | 172 | return prec, rec 173 | 174 | 175 | def calc_detection_voc_ap(prec, rec, use_07_metric=False): 176 | """Calculate average precisions based on evaluation code of PASCAL VOC. 177 | This function calculates average precisions 178 | from given precisions and recalls. 179 | The code is based on the evaluation code used in PASCAL VOC Challenge. 180 | Args: 181 | prec (list of numpy.array): A list of arrays. 182 | :obj:`prec[l]` indicates precision for class :math:`l`. 183 | If :obj:`prec[l]` is :obj:`None`, this function returns 184 | :obj:`numpy.nan` for class :math:`l`. 185 | rec (list of numpy.array): A list of arrays. 186 | :obj:`rec[l]` indicates recall for class :math:`l`. 187 | If :obj:`rec[l]` is :obj:`None`, this function returns 188 | :obj:`numpy.nan` for class :math:`l`. 189 | use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric 190 | for calculating average precision. The default value is 191 | :obj:`False`. 192 | Returns: 193 | ~numpy.ndarray: 194 | This function returns an array of average precisions. 195 | The :math:`l`-th value corresponds to the average precision 196 | for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is 197 | :obj:`None`, the corresponding value is set to :obj:`numpy.nan`. 
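
    Illustrative example (11-point metric): with prec[l] = [1.0, 0.5] at
    rec[l] = [0.5, 1.0], the best precision is 1.0 at recall thresholds 0.0-0.5
    and 0.5 at thresholds 0.6-1.0, so ap[l] = (6 * 1.0 + 5 * 0.5) / 11 ≈ 0.77.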
198 | """ 199 | 200 | n_fg_class = len(prec) 201 | ap = np.empty(n_fg_class) 202 | for l in range(n_fg_class): 203 | if prec[l] is None or rec[l] is None: 204 | ap[l] = np.nan 205 | continue 206 | 207 | if use_07_metric: 208 | # 11 point metric 209 | ap[l] = 0 210 | for t in np.arange(0.0, 1.1, 0.1): 211 | if np.sum(rec[l] >= t) == 0: 212 | p = 0 213 | else: 214 | p = np.max(np.nan_to_num(prec[l])[rec[l] >= t]) 215 | ap[l] += p / 11 216 | else: 217 | # correct AP calculation 218 | # first append sentinel values at the end 219 | mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0])) 220 | mrec = np.concatenate(([0], rec[l], [1])) 221 | 222 | mpre = np.maximum.accumulate(mpre[::-1])[::-1] 223 | 224 | # to calculate area under PR curve, look for points 225 | # where X axis (recall) changes value 226 | i = np.where(mrec[1:] != mrec[:-1])[0] 227 | 228 | # and sum (\Delta recall) * prec 229 | ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 230 | 231 | return ap 232 | -------------------------------------------------------------------------------- /data_utils/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | import dist_comm as comm 6 | 7 | 8 | class TrainingSampler(Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | 11 | It is especially useful in conjunction with 12 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 13 | process can pass a DistributedSampler instance as a DataLoader sampler, 14 | and load a subset of the original dataset that is exclusive to it. 15 | 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | 19 | Arguments: 20 | dataset: Dataset used for sampling. 21 | num_replicas (optional): Number of processes participating in 22 | distributed training. 23 | rank (optional): Rank of the current process within num_replicas. 24 | shuffle (optional): If true (default), sampler will shuffle the indices 25 | """ 26 | 27 | def __init__(self, size, shuffle=True): 28 | 29 | self.size = size 30 | self.num_replicas = comm.get_world_size() 31 | self.rank = comm.get_rank() 32 | self.epoch = 0 33 | self.num_samples = int(math.ceil(size * 1.0 / self.num_replicas)) 34 | self.total_size = self.num_samples * self.num_replicas 35 | self.shuffle = shuffle 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | if self.shuffle: 42 | indices = torch.randperm(self.size, generator=g).tolist() 43 | else: 44 | indices = list(range(self.size)) 45 | 46 | # add extra samples to make it evenly divisible 47 | indices += indices[:(self.total_size - len(indices))] 48 | assert len(indices) == self.total_size 49 | 50 | # subsample 51 | indices = indices[self.rank:self.total_size:self.num_replicas] 52 | assert len(indices) == self.num_samples 53 | 54 | return iter(indices) 55 | 56 | def __len__(self): 57 | return self.num_samples 58 | 59 | def set_epoch(self, epoch): 60 | self.epoch = epoch 61 | 62 | 63 | class InferenceSampler(Sampler): 64 | """ 65 | Produce indices for inference. 66 | Inference needs to run on the __exact__ set of samples, 67 | therefore when the total number of samples is not divisible by the number of workers, 68 | this sampler produces different number of samples on different workers. 
69 | """ 70 | 71 | def __init__(self, size: int): 72 | """ 73 | Args: 74 | size (int): the total number of data of the underlying dataset to sample from 75 | """ 76 | self._size = size 77 | assert size > 0 78 | self._rank = comm.get_rank() 79 | self._world_size = comm.get_world_size() 80 | 81 | shard_size = (self._size - 1) // self._world_size + 1 82 | begin = shard_size * self._rank 83 | end = min(shard_size * (self._rank + 1), self._size) 84 | self._local_indices = range(begin, end) 85 | 86 | def __iter__(self): 87 | yield from self._local_indices 88 | 89 | def __len__(self): 90 | return len(self._local_indices) 91 | -------------------------------------------------------------------------------- /dataset/voc2007/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | from yolo import create_yolov1 2 | from data_utils.build import VOCDatasets 3 | from cfg.voc import cfg 4 | from matplotlib import pyplot as plt 5 | import cv2 6 | import torch 7 | 8 | class_label = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 9 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 10 | "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 11 | 12 | 13 | def draw_box(img, boxlist, conf=0.1): 14 | try: 15 | box = boxlist.bbox 16 | except: 17 | box = boxlist.box 18 | label = boxlist.get_field('labels') 19 | try: 20 | scores = boxlist.get_field('scores') 21 | except: 22 | scores = [1 for i in range(len(label))] 23 | 24 | for b, l, s in zip(box, label, scores): 25 | if s > conf: 26 | p1 = (int(b[0]), int(b[1])) 27 | p2 = (int(b[2]), int(b[3])) 28 | cv2.rectangle(img, p1, p2, (255, 0, 0), 2) 29 | font = cv2.FONT_HERSHEY_SIMPLEX 30 | img = cv2.putText(img, class_label[int(l)], p1, font, 0.5, (255, 255, 0), 1) 31 | 32 | return img 33 | 34 | 35 | def _run(): 36 | dataset = VOCDatasets('./voc2007_test.txt') 37 | 38 | model_cfg = cfg['model_cfg'] 39 | model = create_yolov1(model_cfg) 40 | checkpoint = torch.load('best_model.pth') 41 | data_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()} 42 | model.load_state_dict(data_dict, strict=True) 43 | model.cuda() 44 | model.eval() 45 | print('load model...') 46 | 47 | idx = 88 48 | # img = dataset[idx]['img'] 49 | # boxlist = dataset[idx]['boxlist'] 50 | # 51 | # img = draw_box(img, boxlist) 52 | # plt.figure(figsize=(10, 10)) 53 | # plt.imshow(img) 54 | # plt.show() 55 | 56 | img = dataset[idx]['img'] 57 | img = cv2.resize(img, (448, 448)) 58 | img = torch.from_numpy(img).float() / 255. 59 | 60 | img = img.unsqueeze(0) 61 | img = img.permute(0, 3, 1, 2) 62 | img = img.cuda() 63 | boxlist = model(img) 64 | 65 | img = dataset[idx]['img'] 66 | 67 | boxlist[0].resize((img.shape[1], img.shape[0])) 68 | img = draw_box(img, boxlist[0], conf=0.1) 69 | plt.figure(figsize=(10, 10)) 70 | plt.imshow(img) 71 | plt.show() 72 | 73 | if __name__ == '__main__': 74 | _run() 75 | -------------------------------------------------------------------------------- /dist_comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 
5 | """ 6 | 7 | import functools 8 | import logging 9 | import numpy as np 10 | import pickle 11 | import torch 12 | import torch.distributed as dist 13 | 14 | _LOCAL_PROCESS_GROUP = None 15 | """ 16 | A torch process group which only includes processes that on the same machine as the current process. 17 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 18 | """ 19 | 20 | 21 | def get_world_size() -> int: 22 | if not dist.is_available(): 23 | return 1 24 | if not dist.is_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | 29 | def get_rank() -> int: 30 | if not dist.is_available(): 31 | return 0 32 | if not dist.is_initialized(): 33 | return 0 34 | return dist.get_rank() 35 | 36 | 37 | def get_local_rank() -> int: 38 | """ 39 | Returns: 40 | The rank of the current process within the local (per-machine) process group. 41 | """ 42 | if not dist.is_available(): 43 | return 0 44 | if not dist.is_initialized(): 45 | return 0 46 | assert _LOCAL_PROCESS_GROUP is not None 47 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 48 | 49 | 50 | def get_local_size() -> int: 51 | """ 52 | Returns: 53 | The size of the per-machine process group, 54 | i.e. the number of processes per machine. 55 | """ 56 | if not dist.is_available(): 57 | return 1 58 | if not dist.is_initialized(): 59 | return 1 60 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 61 | 62 | 63 | def is_main_process() -> bool: 64 | return get_rank() == 0 65 | 66 | 67 | def synchronize(): 68 | """ 69 | Helper function to synchronize (barrier) among all processes when 70 | using distributed training 71 | """ 72 | if not dist.is_available(): 73 | return 74 | if not dist.is_initialized(): 75 | return 76 | world_size = dist.get_world_size() 77 | if world_size == 1: 78 | return 79 | dist.barrier() 80 | 81 | 82 | @functools.lru_cache() 83 | def _get_global_gloo_group(): 84 | """ 85 | Return a process group based on gloo backend, containing all the ranks 86 | The result is cached. 87 | """ 88 | if dist.get_backend() == "nccl": 89 | return dist.new_group(backend="gloo") 90 | else: 91 | return dist.group.WORLD 92 | 93 | 94 | def _serialize_to_tensor(data, group): 95 | backend = dist.get_backend(group) 96 | assert backend in ["gloo", "nccl"] 97 | device = torch.device("cpu" if backend == "gloo" else "cuda") 98 | 99 | buffer = pickle.dumps(data) 100 | if len(buffer) > 1024 ** 3: 101 | logger = logging.getLogger(__name__) 102 | logger.warning( 103 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 104 | get_rank(), len(buffer) / (1024 ** 3), device 105 | ) 106 | ) 107 | storage = torch.ByteStorage.from_buffer(buffer) 108 | tensor = torch.ByteTensor(storage).to(device=device) 109 | return tensor 110 | 111 | 112 | def _pad_to_largest_tensor(tensor, group): 113 | """ 114 | Returns: 115 | list[int]: size of the tensor, on each rank 116 | Tensor: padded tensor that has the max size 117 | """ 118 | world_size = dist.get_world_size(group=group) 119 | assert ( 120 | world_size >= 1 121 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
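    # Each rank first exchanges its payload size via all_gather; because
    # torch.distributed.all_gather requires equal-shaped tensors, smaller payloads are
    # zero-padded up to the largest size and later trimmed back using the returned sizes.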
122 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 123 | size_list = [ 124 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 125 | ] 126 | dist.all_gather(size_list, local_size, group=group) 127 | size_list = [int(size.item()) for size in size_list] 128 | 129 | max_size = max(size_list) 130 | 131 | # we pad the tensor because torch all_gather does not support 132 | # gathering tensors of different shapes 133 | if local_size != max_size: 134 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 135 | tensor = torch.cat((tensor, padding), dim=0) 136 | return size_list, tensor 137 | 138 | 139 | def all_gather(data, group=None): 140 | """ 141 | Run all_gather on arbitrary picklable data (not necessarily tensors). 142 | 143 | Args: 144 | data: any picklable object 145 | group: a torch process group. By default, will use a group which 146 | contains all ranks on gloo backend. 147 | 148 | Returns: 149 | list[data]: list of data gathered from each rank 150 | """ 151 | if get_world_size() == 1: 152 | return [data] 153 | if group is None: 154 | group = _get_global_gloo_group() 155 | if dist.get_world_size(group) == 1: 156 | return [data] 157 | 158 | tensor = _serialize_to_tensor(data, group) 159 | 160 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 161 | max_size = max(size_list) 162 | 163 | # receiving Tensor from all ranks 164 | tensor_list = [ 165 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 166 | ] 167 | dist.all_gather(tensor_list, tensor, group=group) 168 | 169 | data_list = [] 170 | for size, tensor in zip(size_list, tensor_list): 171 | buffer = tensor.cpu().numpy().tobytes()[:size] 172 | data_list.append(pickle.loads(buffer)) 173 | 174 | return data_list 175 | 176 | 177 | def gather(data, dst=0, group=None): 178 | """ 179 | Run gather on arbitrary picklable data (not necessarily tensors). 180 | 181 | Args: 182 | data: any picklable object 183 | dst (int): destination rank 184 | group: a torch process group. By default, will use a group which 185 | contains all ranks on gloo backend. 186 | 187 | Returns: 188 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 189 | an empty list. 190 | """ 191 | if get_world_size() == 1: 192 | return [data] 193 | if group is None: 194 | group = _get_global_gloo_group() 195 | if dist.get_world_size(group=group) == 1: 196 | return [data] 197 | rank = dist.get_rank(group=group) 198 | 199 | tensor = _serialize_to_tensor(data, group) 200 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 201 | 202 | # receiving Tensor from all ranks 203 | if rank == dst: 204 | max_size = max(size_list) 205 | tensor_list = [ 206 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 207 | ] 208 | dist.gather(tensor, tensor_list, dst=dst, group=group) 209 | 210 | data_list = [] 211 | for size, tensor in zip(size_list, tensor_list): 212 | buffer = tensor.cpu().numpy().tobytes()[:size] 213 | data_list.append(pickle.loads(buffer)) 214 | return data_list 215 | else: 216 | dist.gather(tensor, [], dst=dst, group=group) 217 | return [] 218 | 219 | 220 | def shared_random_seed(): 221 | """ 222 | Returns: 223 | int: a random number that is the same across all workers. 224 | If workers need a shared RNG, they can use this shared seed to 225 | create one. 226 | 227 | All workers must call this function, otherwise it will deadlock. 
228 | """ 229 | ints = np.random.randint(2 ** 31) 230 | all_ints = all_gather(ints) 231 | return all_ints[0] 232 | 233 | 234 | def reduce_dict(input_dict, average=True): 235 | """ 236 | Reduce the values in the dictionary from all processes so that process with rank 237 | 0 has the reduced results. 238 | 239 | Args: 240 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 241 | average (bool): whether to do average or sum 242 | 243 | Returns: 244 | a dict with the same keys as input_dict, after reduction. 245 | """ 246 | world_size = get_world_size() 247 | if world_size < 2: 248 | return input_dict 249 | with torch.no_grad(): 250 | names = [] 251 | values = [] 252 | # sort the keys so that they are consistent across processes 253 | for k in sorted(input_dict.keys()): 254 | names.append(k) 255 | values.append(input_dict[k]) 256 | values = torch.stack(values, dim=0) 257 | dist.reduce(values, dst=0) 258 | if dist.get_rank() == 0 and average: 259 | # only main process gets accumulated, so only divide by 260 | # world_size in this case 261 | values /= world_size 262 | reduced_dict = {k: v for k, v in zip(names, values)} 263 | return reduced_dict 264 | -------------------------------------------------------------------------------- /fit_voc_to_yolo.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree as ET 2 | import os 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dir_path', type=str, default='dataset/') 10 | args = parser.parse_args() 11 | return args 12 | 13 | 14 | sets = [('2007', 'test'), ('2007', 'train'), ('2007', 'val')] # , ('2012', 'train'), ('2012', 'val')] 15 | 16 | classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 17 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 18 | "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 19 | 20 | 21 | def convert_xml(file_path, out_file): 22 | out_file = open(out_file, 'w') 23 | tree = ET.parse(file_path) 24 | root = tree.getroot() 25 | size = root.find('size') 26 | w = int(size.find('width').text) 27 | h = int(size.find('height').text) 28 | 29 | for obj in root.iter('object'): 30 | difficult = obj.find('difficult').text 31 | cls = obj.find('name').text 32 | if cls not in classes or int(difficult) == 1: 33 | continue 34 | cls_id = classes.index(cls) 35 | xmlbox = obj.find('bndbox') 36 | 37 | bb = (max(1, float(xmlbox.find('xmin').text)), 38 | max(1, float(xmlbox.find('ymin').text)), 39 | min(w - 1, float(xmlbox.find('xmax').text)), 40 | min(h - 1, float(xmlbox.find('ymax').text))) 41 | 42 | out_file.write(",".join([str(a) for a in bb]) + ',' + str(cls_id) + '\n') 43 | 44 | out_file.close() 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | args = parse_args() 50 | root_dir = args.dir_path 51 | 52 | for data_ in sets: 53 | 54 | if not os.path.exists(root_dir + 'voc%s/Label/' % (data_[0])): 55 | os.makedirs(root_dir + 'voc%s/Label/' % (data_[0])) 56 | 57 | name_list = open(root_dir + 'voc%s/ImageSets/Main/%s.txt' % (data_[0], data_[1])).read().strip().split() 58 | 59 | print(len(name_list)) 60 | name_list = tqdm(name_list) 61 | data_list = open('voc%s_%s.txt' % (data_[0], data_[1]), 'w') 62 | 63 | file_writer = '' 64 | for i, xml_name in enumerate(name_list): 65 | file_path = root_dir + 'voc%s/Annotations/%s.xml' % (data_[0], xml_name) 66 | label_file = root_dir + 'voc%s/Label/%s.txt' % (data_[0], 
xml_name) 67 | img_file = root_dir + 'voc%s/JPEGImages/%s.jpg' % (data_[0], xml_name) 68 | convert_xml(file_path, label_file) 69 | 70 | file_writer += img_file + ' ' + label_file + '\n' 71 | 72 | data_list.write(file_writer) 73 | file_writer = '' 74 | 75 | data_list.close() 76 | 77 | os.system('cat voc2007_train.txt voc2007_val.txt >> train.txt ') 78 | -------------------------------------------------------------------------------- /imgs/demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leviome/yolo_1_pytorch/e18705cbf22034d97e043e5faa22466c129238d5/imgs/demo1.png -------------------------------------------------------------------------------- /imgs/demo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leviome/yolo_1_pytorch/e18705cbf22034d97e043e5faa22466c129238d5/imgs/demo2.png -------------------------------------------------------------------------------- /imgs/yolo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leviome/yolo_1_pytorch/e18705cbf22034d97e043e5faa22466c129238d5/imgs/yolo.PNG -------------------------------------------------------------------------------- /structure/__init__.py: -------------------------------------------------------------------------------- 1 | from structure.bounding_box import * -------------------------------------------------------------------------------- /structure/bounding_box.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | FLIP_LR = 1 4 | FLIP_UD = 0 5 | 6 | 7 | def BoxListCat(boxlist): 8 | size = boxlist[0].size 9 | # print(size) 10 | bboxes = [] 11 | labels = [] 12 | scores = [] 13 | for b in boxlist: 14 | assert size == b.size, ("box must in the same scale") 15 | bboxes.append(b.box) 16 | labels.extend(b.get_field('labels')) 17 | scores.extend(b.get_field('scores')) 18 | 19 | box = np.concatenate(bboxes, axis=0) 20 | scores = np.asarray(scores) 21 | bboxes = BoxList(box, size) 22 | bboxes.add_field('scores', scores) 23 | bboxes.add_field('labels', labels) 24 | # print(bboxes.size) 25 | return bboxes 26 | 27 | 28 | class BoxList(object): 29 | 30 | def __init__(self, box, size, mode="xyxy"): 31 | self.box = np.asarray(box) 32 | self.size = size 33 | self.mode = mode 34 | self.extra_fields = {} 35 | assert self.mode == "xyxy" 36 | 37 | def copy(self): 38 | new = BoxList(self.box.copy(), self.size) 39 | for k, v in self.extra_fields.items(): 40 | new.extra_fields[k] = v.copy() 41 | return new 42 | 43 | def add_field(self, field, field_data): 44 | self.extra_fields[field] = field_data 45 | 46 | def get_field(self, field): 47 | return self.extra_fields[field] 48 | 49 | def has_field(self, field): 50 | return field in self.extra_fields 51 | 52 | def fields(self): 53 | return list(self.extra_fields.keys()) 54 | 55 | def _copy_extra_fields(self, bbox): 56 | for k, v in bbox.extra_fields.items(): 57 | self.extra_fields[k] = v 58 | 59 | def resize(self, r_size): 60 | w, h = self.size 61 | n_w, n_h = r_size 62 | 63 | self.box[:, 0] = self.box[:, 0] / w * n_w 64 | self.box[:, 1] = self.box[:, 1] / h * n_h 65 | self.box[:, 2] = self.box[:, 2] / w * n_w 66 | self.box[:, 3] = self.box[:, 3] / h * n_h 67 | 68 | self.size = r_size 69 | return # self.box 70 | 71 | def crop(self, box): 72 | x1, y1, x2, y2 = box 73 | w = x2 - x1 74 | h = y2 - y1 75 | 76 | self.box[:, 0] = 
np.clip(self.box[:, 0] - x1, 1, w) 77 | self.box[:, 1] = np.clip(self.box[:, 1] - y1, 1, h) 78 | self.box[:, 2] = np.clip(self.box[:, 2] - x1, 1, w) 79 | self.box[:, 3] = np.clip(self.box[:, 3] - y1, 1, h) 80 | self.size = (w, h) 81 | 82 | return 83 | 84 | def flip(self, ops): 85 | w, h = self.size 86 | if ops == FLIP_LR: 87 | bw = self.box[:, 2] - self.box[:, 0] 88 | self.box[:, 0] = np.clip(w - (self.box[:, 0] + bw), 1, w) 89 | self.box[:, 2] = np.clip(w - (self.box[:, 2] - bw), 1, w) 90 | self.box[:, 1] = np.clip(self.box[:, 1], 1, h) 91 | self.box[:, 3] = np.clip(self.box[:, 3], 1, h) 92 | if ops == FLIP_UD: 93 | bh = self.box[:, 3] - self.box[:, 1] 94 | self.box[:, 0] = np.clip(self.box[:, 0], 1, w) 95 | self.box[:, 2] = np.clip(self.box[:, 2], 1, w) 96 | self.box[:, 1] = np.clip(h - (self.box[:, 1] + bh), 1, h) 97 | self.box[:, 3] = np.clip(h - (self.box[:, 3] - bh), 1, h) 98 | else: 99 | ValueError("Only support 0,1") 100 | return 101 | 102 | def warpAffine(self, matrix, size=None): 103 | if size: 104 | self.size = size 105 | p1_list = np.concatenate([self.box[:, :2], np.ones((self.box.shape[0], 1))], axis=1) 106 | p2_list = np.concatenate([self.box[:, 2:4], np.ones((self.box.shape[0], 1))], axis=1) 107 | 108 | box = [] 109 | for p1, p2 in zip(p1_list, p2_list): 110 | new_pt1 = np.dot(matrix, p1.T) 111 | new_pt2 = np.dot(matrix, p2.T) 112 | 113 | new_pt1[0] = min(max(1, new_pt1[0]), self.size[0]) 114 | new_pt1[1] = min(max(1, new_pt1[1]), self.size[1]) 115 | new_pt2[0] = min(max(1, new_pt2[0]), self.size[0]) 116 | new_pt2[1] = min(max(1, new_pt2[1]), self.size[1]) 117 | 118 | box += [[new_pt1[0], new_pt1[1], new_pt2[0], new_pt2[1]]] 119 | 120 | self.box = np.asarray(box) 121 | 122 | def rot90(self, time): 123 | 124 | time = time % 4 125 | for i in range(time): 126 | w, h = self.size 127 | flipped = self.box.copy() 128 | flipped[:, 0] = self.box[:, 1] 129 | flipped[:, 2] = self.box[:, 3] 130 | flipped[:, 1] = w - self.box[:, 2] 131 | flipped[:, 3] = w - self.box[:, 0] 132 | self.box = flipped 133 | self.size = h, w 134 | 135 | w, h = self.size 136 | self.box[:, 0] = np.clip(self.box[:, 0], 1, w) 137 | self.box[:, 1] = np.clip(self.box[:, 1], 1, h) 138 | self.box[:, 2] = np.clip(self.box[:, 2], 1, w) 139 | self.box[:, 3] = np.clip(self.box[:, 3], 1, h) 140 | 141 | return 142 | 143 | def area(self): 144 | box = self.box 145 | assert box.shape[0] > 0, box 146 | assert box.shape[1] == 4, box 147 | if self.mode == "xyxy": 148 | TO_REMOVE = 0 149 | area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) 150 | else: 151 | raise RuntimeError("Should not be here") 152 | 153 | return area 154 | 155 | def __len__(self): 156 | return self.box.shape[0] 157 | 158 | 159 | # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py 160 | # with slight modifications 161 | def boxlist_iou(boxlist1, boxlist2): 162 | """Compute the intersection over union of two set of boxes. 163 | The box order must be (xmin, ymin, xmax, ymax). 164 | Arguments: 165 | box1: (BoxList) bounding boxes, sized [N,4]. 166 | box2: (BoxList) bounding boxes, sized [M,4]. 167 | Returns: 168 | (tensor) iou, sized [N,M]. 
169 |     Reference:
170 |         https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
171 |     """
172 |     if boxlist1.size != boxlist2.size:
173 |         raise RuntimeError(
174 |             "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
175 | 
176 |     N = len(boxlist1)
177 |     M = len(boxlist2)
178 | 
179 |     area1 = boxlist1.area()
180 |     area2 = boxlist2.area()
181 | 
182 |     box1, box2 = boxlist1.box, boxlist2.box
183 | 
184 |     lt = np.maximum(box1[:, None, :2], box2[:, :2])  # [N,M,2]
185 |     rb = np.minimum(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]
186 | 
187 |     TO_REMOVE = 1
188 | 
189 |     wh = np.clip(rb - lt + TO_REMOVE, 0, None)  # [N,M,2]
190 |     inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
191 | 
192 |     iou = inter / (area1[:, None] + area2 - inter)
193 |     return iou
194 | 
195 | 
196 | from torch import nn
197 | 
198 | 
199 | def pool_nms(heat, kernel=3):
200 |     heat = heat.sigmoid()
201 | 
202 |     pad = (kernel - 1) // 2
203 |     hmax = nn.functional.max_pool2d(
204 |         heat, (kernel, kernel), stride=1, padding=pad)
205 |     keep = (hmax == heat).float()
206 |     return heat * keep
207 | 
208 | 
209 | def pool_nms_no_sigmoid(heat, kernel=3):
210 |     pad = (kernel - 1) // 2
211 |     hmax = nn.functional.max_pool2d(
212 |         heat, (kernel, kernel), stride=1, padding=pad)
213 |     keep = (hmax == heat).float()
214 |     return heat * keep
215 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # --------------------------------------------------------
4 | # @Author: levi
5 | # Copyright (c) 2020
6 | # Created by Levi levio123@163.com
7 | # --------------------------------------------------------
8 | import torch
9 | import os.path as osp
10 | from cfg.voc import cfg
11 | from yolo.yolov1 import create_yolov1
12 | from data_utils.build import make_dist_voc_loader
13 | from torch.cuda import amp
14 | 
15 | 
16 | def train_step(epochs, model, train_loader, test_loader, optim, classes, device='cuda'):
17 |     scaler = amp.GradScaler(enabled=True)
18 |     for epoch in range(epochs):
19 |         print('epoch =', epoch)
20 |         for idx, (img, gt_info) in enumerate(train_loader):
21 |             optim.zero_grad()
22 |             img = img.to(device)
23 |             loss_dict = model(img, gt_info)
24 |             loss = sum(l for l in loss_dict.values())
25 |             print(round(loss.item(), 3), end=' ')
26 |             scaler.scale(loss).backward()
27 |             scaler.step(optim)  # unscale the gradients, then step the optimizer
28 |             scaler.update()  # adjust the loss scale for the next iteration
29 |         torch.save(model.state_dict(), 'best_model.pth')
30 | 
31 | 
32 | def train():
33 |     train_cfg = cfg['train_cfg']
34 |     model_cfg = cfg['model_cfg']
35 |     model_name = model_cfg['model_type']
36 |     epochs = train_cfg['epochs']
37 |     classes = train_cfg['classes']
38 |     lr = train_cfg['lr']
39 |     bs = train_cfg['batch_size']
40 |     device = train_cfg['device']
41 |     out_dir = train_cfg['out_dir']
42 |     resume = train_cfg['resume']
43 |     use_sgd = train_cfg['use_sgd']
44 |     mile = train_cfg['milestone']
45 |     gamma = train_cfg['gamma']
46 |     train_root = train_cfg['dataroot']
47 |     img_size = train_cfg['img_size']
48 | 
49 |     model = create_yolov1(model_cfg)
50 |     model = model.to(device)
51 |     if osp.exists('best_model.pth'):
52 |         checkpoint = torch.load('best_model.pth')
53 |         data_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
54 |         model.load_state_dict(data_dict, strict=True)
55 |         model.cuda()
56 | 
57 |     optimizer = torch.optim.Adam(model.parameters(),
58 |                                  lr=lr[0],
59 |                                  weight_decay=5e-5)
60 |     train_loader = make_dist_voc_loader(osp.join(train_root, 'train.txt'), 
61 | img_size=img_size, 62 | batch_size=bs, 63 | train=True, 64 | ) 65 | test_loader = make_dist_voc_loader(osp.join(train_root, 'voc2007_test.txt'), 66 | img_size=img_size, 67 | batch_size=16, 68 | train=False, 69 | ) 70 | 71 | train_step(epochs, model, train_loader, test_loader, optimizer, classes, device=device) 72 | 73 | 74 | if __name__ == '__main__': 75 | train() 76 | -------------------------------------------------------------------------------- /yolo/__init__.py: -------------------------------------------------------------------------------- 1 | from yolo.yolov1 import create_yolov1 2 | -------------------------------------------------------------------------------- /yolo/darknet.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | layer_configs = [ 7 | # Unit1 (2) 8 | (32, 3, True), 9 | (64, 3, True), 10 | # Unit2 (3) 11 | (128, 3, False), 12 | (64, 1, False), 13 | (128, 3, True), 14 | # Unit3 (3) 15 | (256, 3, False), 16 | (128, 1, False), 17 | (256, 3, True), 18 | # Unit4 (5) 19 | (512, 3, False), 20 | (256, 1, False), 21 | (512, 3, False), 22 | (256, 1, False), 23 | (512, 3, True), 24 | # Unit5 (5) 25 | (1024, 3, False), 26 | (512, 1, False), 27 | (1024, 3, False), 28 | (512, 1, False), 29 | (1024, 3, False), 30 | ] 31 | 32 | 33 | def load_conv_bn(buf, start, conv_model, bn_model): 34 | num_w = conv_model.weight.numel() 35 | 36 | num_b = bn_model.bias.numel() 37 | bn_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b])) 38 | start = start + num_b 39 | bn_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_b])) 40 | start = start + num_b 41 | bn_model.running_mean.copy_(torch.from_numpy(buf[start:start + num_b])) 42 | start = start + num_b 43 | bn_model.running_var.copy_(torch.from_numpy(buf[start:start + num_b])) 44 | start = start + num_b 45 | 46 | conv_weight = torch.from_numpy(buf[start:start + num_w]) 47 | conv_model.weight.data.copy_(conv_weight.view_as(conv_model.weight)) 48 | start = start + num_w 49 | 50 | return start 51 | 52 | 53 | class conv_block(nn.Module): 54 | 55 | def __init__(self, inplane, outplane, kernel_size, pool, stride=1): 56 | super(conv_block, self).__init__() 57 | 58 | pad = 1 if kernel_size == 3 else 0 59 | self.conv = nn.Conv2d(inplane, outplane, kernel_size, stride=stride, padding=pad, bias=False) 60 | self.bn = nn.BatchNorm2d(outplane) 61 | self.act = nn.LeakyReLU(0.1) 62 | self.pool = pool # MaxPool2d(2,stride = 2) 63 | 64 | def forward(self, x): 65 | out = self.conv(x) 66 | out = self.bn(out) 67 | out = self.act(out) 68 | 69 | if self.pool: 70 | out = F.max_pool2d(out, kernel_size=2, stride=2) 71 | 72 | return out 73 | 74 | 75 | class darknet_19(nn.Module): 76 | 77 | def __init__(self, cls_num=1000): 78 | super(darknet_19, self).__init__() 79 | self.class_num = cls_num 80 | self.feature = self.make_layers(3, layer_configs) 81 | 82 | def make_layers(self, inplane, cfg): 83 | layers = [] 84 | 85 | for outplane, kernel_size, pool in cfg: 86 | layers.append(conv_block(inplane, outplane, kernel_size, pool)) 87 | inplane = outplane 88 | 89 | return nn.Sequential(*layers) 90 | 91 | def load_weight(self, weight_file): 92 | print("Load pretrained models !") 93 | 94 | fp = open(weight_file, 'rb') 95 | header = np.fromfile(fp, count=4, dtype=np.int32) 96 | header = torch.from_numpy(header) 97 | buf = np.fromfile(fp, dtype=np.float32) 98 | 99 | start = 0 100 | for idx, m in 
enumerate(self.feature.modules()): 101 | if isinstance(m, nn.Conv2d): 102 | conv = m 103 | if isinstance(m, nn.BatchNorm2d): 104 | bn = m 105 | start = load_conv_bn(buf, start, conv, bn) 106 | 107 | assert start == buf.shape[0] 108 | 109 | def forward(self, x): 110 | 111 | output = self.feature(x) 112 | 113 | return output 114 | -------------------------------------------------------------------------------- /yolo/decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from structure.bounding_box import BoxList 3 | from torchvision.ops import nms 4 | 5 | 6 | def yolo_decoder(pred, img_size, conf=0.02, nms_threshold=0.5): 7 | ''' 8 | pred_cls = [C*2,S,S] 9 | pred_response = [2,S,S] 10 | pred_bboxes = [4*2,S,S] 11 | ''' 12 | 13 | pred_cls, pred_response, pred_bboxes = pred 14 | class_num, h, w = pred_cls.shape 15 | box_num = pred_response.shape[0] 16 | 17 | pred_cls = pred_cls.view(1, class_num, h, w).permute(2, 3, 0, 1).contiguous() 18 | pred_cls = pred_cls.repeat(1, 1, box_num, 1).view(-1, class_num) 19 | 20 | pred_bboxes = pred_bboxes.view(box_num, 4, h, w).permute(2, 3, 0, 1).contiguous().view(-1, 4) 21 | pred_response = pred_response.view(box_num, 1, h, w).permute(2, 3, 0, 1).contiguous().view(-1, 1) 22 | 23 | # 找最anchor中置信度最高的 24 | pred_mask = (pred_response > conf).view(-1) 25 | pred_bboxes = pred_bboxes[pred_mask] 26 | pred_response = pred_response[pred_mask] 27 | pred_cls = pred_cls[pred_mask] 28 | 29 | bboxes = [] 30 | scores = [] 31 | labels = [] 32 | 33 | for cls in range(class_num): 34 | score = pred_cls[:, cls].float() * pred_response[:, 0] 35 | mask_a = score.gt(conf) 36 | bbox = pred_bboxes[mask_a] 37 | cls_prob = score[mask_a] 38 | if bbox.shape[0] > 0: 39 | bbox[:, 2] = bbox[:, 2] * bbox[:, 2] 40 | bbox[:, 3] = bbox[:, 3] * bbox[:, 3] 41 | 42 | bbox[:, 0] = bbox[:, 0] - bbox[:, 2] / 2 43 | bbox[:, 1] = bbox[:, 1] - bbox[:, 3] / 2 44 | bbox[:, 2] = bbox[:, 0] + bbox[:, 2] 45 | bbox[:, 3] = bbox[:, 1] + bbox[:, 3] 46 | pre_cls_box = bbox.data 47 | pre_cls_score = cls_prob.data.view(-1) 48 | 49 | keep = nms(pre_cls_box, pre_cls_score, nms_threshold) 50 | # keep = [pre_cls_box, pre_cls_score, nms_threshold] 51 | 52 | for conf_keep, loc_keep in zip(pre_cls_score[keep], pre_cls_box[keep]): 53 | # for conf_keep, loc_keep in zip(pre_cls_score, pre_cls_box): 54 | bboxes.append(loc_keep.tolist()) 55 | scores.append(conf_keep.tolist()) 56 | labels.append(cls) 57 | 58 | if len(bboxes) > 0: 59 | scores = np.asarray(scores) 60 | bboxes = np.asarray(bboxes) 61 | labels = np.asarray(labels) 62 | box = BoxList(bboxes, (w, h)) 63 | box.add_field('scores', scores) 64 | box.add_field('labels', labels) 65 | else: 66 | box = BoxList(np.asarray([[0., 0., 1., 1.]]), (w, h)) 67 | box.add_field('scores', np.asarray([0.])) 68 | box.add_field('labels', np.asarray([0.])) 69 | 70 | box.resize(img_size) 71 | 72 | return box 73 | -------------------------------------------------------------------------------- /yolo/encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def yolo_encoder(box_list, ceil_size, box_num, cls_num): 5 | ''' 6 | pred_cls = [C,S,S] 7 | pred_response = [2,S,S] 8 | pred_bboxes = [4*2,S,S] 9 | ''' 10 | w, h = ceil_size 11 | box_list.resize(ceil_size) 12 | labels = box_list.get_field('labels') 13 | 14 | bb_class = np.zeros((cls_num, h, w)) 15 | bb_response = np.zeros((box_num, h, w)) 16 | bb_boxes = np.zeros((box_num * 4, h, w)) 17 | 18 | # TODO avoid 
/yolo/encoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def yolo_encoder(box_list, ceil_size, box_num, cls_num):
5 |     '''
6 |     pred_cls = [C,S,S]
7 |     pred_response = [2,S,S]
8 |     pred_bboxes = [4*2,S,S]
9 |     '''
10 |     w, h = ceil_size
11 |     box_list.resize(ceil_size)
12 |     labels = box_list.get_field('labels')
13 | 
14 |     bb_class = np.zeros((cls_num, h, w))
15 |     bb_response = np.zeros((box_num, h, w))
16 |     bb_boxes = np.zeros((box_num * 4, h, w))
17 | 
18 |     # TODO avoid loop
19 |     for gt, l in zip(box_list.box, labels):
20 |         local_x = min(int(round((gt[2] + gt[0]) / 2)), w - 1)
21 |         local_y = min(int(round((gt[3] + gt[1]) / 2)), h - 1)
22 | 
23 |         for j in range(box_num):
24 |             bb_response[j, local_y, local_x] = 1
25 |             bb_boxes[j * 4 + 0, local_y, local_x] = (gt[2] + gt[0]) / 2
26 |             bb_boxes[j * 4 + 1, local_y, local_x] = (gt[3] + gt[1]) / 2
27 |             bb_boxes[j * 4 + 2, local_y, local_x] = np.sqrt(max((gt[2] - gt[0]), 0.01))
28 |             bb_boxes[j * 4 + 3, local_y, local_x] = np.sqrt(max((gt[3] - gt[1]), 0.01))
29 | 
30 |         bb_class[l, local_y, local_x] = 1
31 |     boxes = (bb_class, bb_response, bb_boxes)
32 |     return boxes
33 | 
--------------------------------------------------------------------------------
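A small usage sketch of the encoder, assuming BoxList rescales its coordinates in place when resize() is called (which is how the decoder and loss code also use it) and that labels are 0-based indices into train_cfg['classes']; the box and label values here are made up for illustration:

```
import numpy as np
from structure.bounding_box import BoxList
from yolo.encoder import yolo_encoder

# one ground-truth box (x1, y1, x2, y2) in a 448x448 image, class index 11 ("dog" in the VOC list)
gt = BoxList(np.asarray([[50., 60., 200., 220.]]), (448, 448))
gt.add_field('labels', np.asarray([11]))

cls_map, obj_map, box_map = yolo_encoder(gt, (7, 7), box_num=2, cls_num=20)
print(cls_map.shape, obj_map.shape, box_map.shape)  # (20, 7, 7) (2, 7, 7) (8, 7, 7)
```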
/yolo/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from yolo.encoder import yolo_encoder
5 | 
6 | 
7 | class yolov1_loss(nn.Module):
8 |     def __init__(self, l_coord, l_obj, l_noobj):
9 |         super(yolov1_loss, self).__init__()
10 |         self.l_coord = l_coord
11 |         self.l_noobj = l_noobj
12 |         self.l_obj = l_obj
13 | 
14 |     def _prepare_target(self, meta, ceil_size, bbox_num, cls_num, device):
15 |         target_cls = []
16 |         target_obj = []
17 |         target_box = []
18 |         for target in meta:
19 |             t = target['boxlist']
20 |             t.resize(ceil_size)
21 |             cls, obj, box = yolo_encoder(t, ceil_size, bbox_num, cls_num)
22 |             target_cls.append(torch.from_numpy(cls).unsqueeze(dim=0).float())
23 |             target_obj.append(torch.from_numpy(obj).unsqueeze(dim=0).float())
24 |             target_box.append(torch.from_numpy(box).unsqueeze(dim=0).float())
25 |         target_cls = torch.cat(target_cls).to(device)
26 |         target_obj = torch.cat(target_obj).to(device)
27 |         target_box = torch.cat(target_box).to(device)
28 |         return target_cls, target_obj, target_box
29 | 
30 |     def offset2box(self, box):
31 |         box[:, 0] = box[:, 0] - box[:, 2] / 2
32 |         box[:, 1] = box[:, 1] - box[:, 3] / 2
33 |         box[:, 2] = box[:, 0] + box[:, 2]
34 |         box[:, 3] = box[:, 1] + box[:, 3]
35 | 
36 |         return box
37 | 
38 |     def get_kp_torch_batch(self, pred, conf, topk=100):
39 |         b, c, h, w = pred.shape
40 |         pred = pred.contiguous().view(-1)
41 |         pred[pred < conf] = 0
42 |         score, topk_idx = torch.topk(pred, k=topk)
43 | 
44 |         batch = torch.floor_divide(topk_idx, (h * w * c))
45 | 
46 |         cls = torch.floor_divide((topk_idx - batch * h * w * c), (h * w))
47 | 
48 |         channel = (topk_idx - batch * h * w * c) - (cls * h * w)
49 | 
50 |         x = channel % w
51 |         y = torch.floor_divide(channel, w)
52 | 
53 |         return x.view(-1), y.view(-1), cls.view(-1), batch.view(-1)
54 | 
55 |     def compute_iou(self, box1, box2):
56 |         '''Compute the element-wise intersection over union of two aligned sets of boxes, each box is [x1,y1,x2,y2].
57 |         Args:
58 |             box1: (tensor) bounding boxes, sized [N,4].
59 |             box2: (tensor) bounding boxes, sized [N,4].
60 |         Return:
61 |             (tensor) iou, sized [N,].
62 |         '''
63 | 
64 |         lt = torch.max(
65 |             box1[:, :2],  # [N,2]
66 |             box2[:, :2],  # [N,2]
67 |         )
68 | 
69 |         rb = torch.min(
70 |             box1[:, 2:],  # [N,2]
71 |             box2[:, 2:],  # [N,2]
72 |         )
73 | 
74 |         wh = rb - lt  # [N,2]
75 |         wh[wh < 0] = 0  # clip at 0
76 |         inter = wh[:, 0] * wh[:, 1]  # [N,]
77 | 
78 |         area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # [N,]
79 |         area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # [N,]
80 | 
81 |         iou = inter / (area1 + area2 - inter + 1e-4)
82 |         return iou
83 | 
84 |     def forward(self, pred, meta):
85 |         pred_cls, pred_response, pred_bboxes = pred
86 |         device = pred_cls.get_device()
87 | 
88 |         B_size, cls_num, h, w = pred_cls.shape
89 |         bbox_num = pred_response.shape[1]
90 | 
91 |         ceil_size = (w, h)
92 |         label_cls, label_response, label_bboxes = self._prepare_target(meta, ceil_size, bbox_num, cls_num, device)
93 | 
94 |         device = pred_cls.get_device()
95 |         label_cls = label_cls.to(device)
96 |         label_response = label_response.to(device)
97 |         label_bboxes = label_bboxes.to(device)
98 | 
99 |         with torch.no_grad():
100 |             tmp_response = label_response.sum(dim=1).unsqueeze(dim=1)
101 |             k = (tmp_response > 0.9).sum()
102 |             x_list, y_list, c_list, b_list = self.get_kp_torch_batch(tmp_response, conf=0.5, topk=int(k))
103 | 
104 |         t_responses = label_response[b_list, :, y_list, x_list]
105 |         p_responses = pred_response[b_list, :, y_list, x_list]
106 | 
107 |         t_boxes = label_bboxes[b_list, :, y_list, x_list]
108 |         p_boxes = pred_bboxes[b_list, :, y_list, x_list]
109 | 
110 |         t_classes = label_cls[b_list, :, y_list, x_list]
111 |         p_classes = pred_cls[b_list, :, y_list, x_list]
112 | 
113 |         loss_pos_cls = F.mse_loss(p_classes, t_classes, reduction='sum')
114 | 
115 |         t_offset = t_boxes.view(-1, 4)
116 |         p_offset = p_boxes.view(-1, 4)
117 |         with torch.no_grad():
118 |             t_box = self.offset2box(t_offset.clone().float()).to(device)
119 |             p_box = self.offset2box(p_offset.clone().float()).to(device)
120 |             iou = self.compute_iou(t_box, p_box).view(-1, bbox_num)
121 | 
122 |         idx = iou.argmax(dim=1)
123 |         idx = idx.unsqueeze(dim=1)
124 |         loss_pos_response = F.mse_loss(p_responses.gather(1, idx), iou.gather(1, idx), reduction='sum')
125 | 
126 |         idx = idx.unsqueeze(dim=1)
127 |         p_boxes = p_boxes.view(-1, bbox_num, 4)
128 |         t_boxes = t_boxes.view(-1, bbox_num, 4)
129 |         off_idx = idx.repeat(1, 1, 4)
130 |         loss_pos_offset = F.mse_loss(p_boxes.gather(1, off_idx), t_boxes.gather(1, off_idx), reduction='sum')
131 | 
132 |         neg_mask = label_response < 1
133 |         neg_pred = pred_response[neg_mask]
134 |         neg_target = label_response[neg_mask]
135 | 
136 |         loss_neg_response = F.mse_loss(neg_pred, neg_target, reduction='sum') / B_size * self.l_noobj
137 |         loss_pos_response = loss_pos_response / B_size * self.l_obj
138 |         loss_pos_offset = loss_pos_offset / B_size * self.l_coord
139 |         loss_pos_cls = loss_pos_cls / B_size
140 | 
141 |         return {'pObj': loss_pos_response,
142 |                 'nObj': loss_neg_response,
143 |                 'cls': loss_pos_cls,
144 |                 'offset': loss_pos_offset}
145 | 
--------------------------------------------------------------------------------
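Note that compute_iou is element-wise: it pairs row i of box1 with row i of box2 rather than building an NxM matrix, which is why the docstring reports a vector of length N (forward() relies on this to score each predictor against the target box in its own cell and then pick the responsible predictor via argmax). A quick sanity check with arbitrary box values and the weights from cfg/voc.py:

```
import torch
from yolo.loss import yolov1_loss

crit = yolov1_loss(l_coord=3, l_obj=3, l_noobj=0.5)
box1 = torch.tensor([[0., 0., 2., 2.], [0., 0., 1., 1.]])
box2 = torch.tensor([[1., 1., 3., 3.], [0., 0., 1., 1.]])
print(crit.compute_iou(box1, box2))  # one IoU per aligned pair: roughly 0.14 and 1.0
```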
/yolo/yolov1.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | from yolo.decoder import yolo_decoder
5 | from yolo.darknet import darknet_19, conv_block
6 | from yolo.loss import yolov1_loss
7 | 
8 | 
9 | def create_yolov1(cfg):
10 |     cls_num = cfg['class_num']
11 |     box_num = cfg['box_num']
12 |     ceil_size = cfg['ceil_size']
13 |     pretrained = cfg['pretrained']
14 |     l_coord = cfg['l_coord']
15 |     l_noobj = cfg['l_noobj']
16 |     l_obj = cfg['l_obj']
17 |     conv_mode = cfg['conv_mode']
18 |     model = YOLO(cls_num, box_num, ceil_size, pretrained, l_coord, l_obj, l_noobj, conv_mode)
19 | 
20 |     return model
21 | 
22 | 
23 | def fill_fc_weights(layers):
24 |     for m in layers.modules():
25 |         if isinstance(m, nn.Conv2d):
26 |             if m.bias is not None:
27 |                 nn.init.constant_(m.bias, 0)
28 | 
29 | 
30 | class YOLO(nn.Module):
31 | 
32 |     def __init__(self, cls_num, bbox_num=2, scale_size=7,
33 |                  pretrained=None,
34 |                  l_coord=5,
35 |                  l_obj=1,
36 |                  l_noobj=0.5,
37 |                  conv_mode=False
38 |                  ):
39 |         super(YOLO, self).__init__()
40 | 
41 |         self.cls_num = cls_num
42 |         self.conv_mode = conv_mode
43 |         self.backbone = darknet_19()
44 |         if pretrained is not None:
45 |             self.backbone.load_weight(pretrained)
46 | 
47 |         self.loss = yolov1_loss(l_coord, l_obj, l_noobj)
48 |         self.scale_size = scale_size
49 |         self.bbox_num = bbox_num
50 |         self.last_output = (5 * self.bbox_num + self.cls_num)
51 | 
52 |         self.local_layer = nn.Sequential()
53 |         self.local_layer.add_module('block_1', conv_block(1024, 1024, 3, False, 2))
54 |         self.local_layer.add_module('block_2', conv_block(1024, 1024, 3, False, 1))
55 |         self.local_layer.add_module('block_3', conv_block(1024, 1024, 3, False, 1))
56 |         self.local_layer.add_module('block_4', conv_block(1024, 1024, 3, False, 1))
57 |         fill_fc_weights(self.local_layer)
58 | 
59 |         if not self.conv_mode:
60 |             self.reg_layer = nn.Sequential()
61 |             self.reg_layer.add_module('local_layer', nn.Linear(1024 * 7 * 7, 4096))
62 |             self.reg_layer.add_module('leaky_local', nn.LeakyReLU(0.1, inplace=True))
63 |             self.reg_layer.add_module('dropout', nn.Dropout(0.5))
64 |             fill_fc_weights(self.reg_layer)
65 |             self.cls_pred = nn.Linear(4096, self.cls_num * self.scale_size * self.scale_size)
66 |             self.response_pred = nn.Linear(4096, self.bbox_num * self.scale_size * self.scale_size)
67 |             self.offset_pred = nn.Linear(4096, self.bbox_num * 4 * self.scale_size * self.scale_size)
68 |         else:
69 |             self.cls_pred = nn.Sequential(
70 |                 nn.Conv2d(1024, 256, 3, stride=1, padding=1),
71 |                 nn.ReLU(),
72 |                 nn.Conv2d(256, self.cls_num, 1, stride=1, padding=0)
73 |             )
74 |             self.response_pred = nn.Sequential(
75 |                 nn.Conv2d(1024, 256, 3, stride=1, padding=1),
76 |                 nn.ReLU(),
77 |                 nn.Conv2d(256, self.bbox_num, 1, stride=1, padding=0)
78 |             )
79 |             self.offset_pred = nn.Sequential(
80 |                 nn.Conv2d(1024, 256, 3, stride=1, padding=1),
81 |                 nn.ReLU(),
82 |                 nn.Conv2d(256, self.bbox_num * 4, 1, stride=1, padding=0)
83 |             )
84 | 
85 |         fill_fc_weights(self.cls_pred)
86 |         fill_fc_weights(self.response_pred)
87 |         fill_fc_weights(self.offset_pred)
88 | 
89 |     def gen_anchor(self, ceil):
90 | 
91 |         w, h = ceil
92 |         x = torch.linspace(0, w - 1, w).unsqueeze(dim=0).repeat(h, 1).unsqueeze(dim=0)
93 |         y = torch.linspace(0, h - 1, h).unsqueeze(dim=0).repeat(w, 1).unsqueeze(dim=0).permute(0, 2, 1)
94 |         anchor_xy = torch.cat((x, y), dim=0).view(-1, 2, h, w)
95 | 
96 |         return anchor_xy
97 | 
98 |     def forward(self, x, target=None, conf=0.02, nms_threshold=0.5):
99 |         B, c, h, w = x.shape
100 |         device = x.get_device()
101 |         img_size = (w, h)
102 |         output = self.backbone(x)
103 |         output = self.local_layer(output)
104 |         B, _, ceil_h, ceil_w = output.shape
105 |         ceil = (ceil_w, ceil_h)
106 |         anchor_xy = self.gen_anchor(ceil)
107 |         anchor_xy = anchor_xy.repeat(B, self.bbox_num, 1, 1, 1).to(device)
108 |         if self.conv_mode:
109 |             pred_cls = self.cls_pred(output)
110 |             pred_response = self.response_pred(output)
111 |             pred_bbox = self.offset_pred(output).view(B, self.bbox_num, 4, ceil_h, ceil_w)
112 |             pred_bbox[:, :, :2, :, :] += anchor_xy
113 |             pred_bbox = pred_bbox.view(B, -1, ceil_h, ceil_w)
114 |         else:
115 |             output = output.view(B, -1)
116 |             output = self.reg_layer(output)
117 |             pred_cls = self.cls_pred(output).view(B, self.cls_num, ceil_h, ceil_w)
118 |             pred_response = self.response_pred(output).view(B, self.bbox_num, ceil_h, ceil_w)
119 |             pred_bbox = self.offset_pred(output).view(B, self.bbox_num * 4, ceil_h, ceil_w)
120 |             pred_bbox = pred_bbox.view(B, self.bbox_num, 4, ceil_h, ceil_w)
121 |             pred_bbox[:, :, :2, :, :] += anchor_xy
122 |             pred_bbox = pred_bbox.view(B, -1, ceil_h, ceil_w)
123 | 
124 |         if target is None:
125 |             output = []
126 |             for bs in range(B):
127 |                 cls = pred_cls[bs, :, :, :]
128 |                 objness = pred_response[bs, :, :, :]
129 |                 bbox = pred_bbox[bs, :, :, :]
130 |                 pred = (cls, objness, bbox)
131 |                 output.append(yolo_decoder(pred, img_size, conf, nms_threshold))
132 |             return output
133 |         else:
134 |             pred = (pred_cls, pred_response, pred_bbox)
135 |             loss_dict = self.loss(pred, target)
136 |             return loss_dict
137 | 
138 | 
139 | if __name__ == '__main__':
140 |     from data_utils.datasets import VOCDatasets
141 | 
142 |     net = YOLO(20, conv_mode=True).cuda()
143 | 
144 |     input = torch.zeros(1, 3, 448, 448).cuda()
145 |     dataset = VOCDatasets('../train.txt', train=False)
146 |     data = net(input, [dataset[334]])
147 | 
148 |     print(data)
149 |     print(sum([data[d] for d in data.keys()]))
150 | 
--------------------------------------------------------------------------------
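For reference, a minimal inference sketch built from the config in cfg/voc.py; it assumes a CUDA device (forward() calls x.get_device()) and runs with untrained weights, so the detections are meaningless; detect.py remains the repo's proper entry point:

```
import torch
from cfg.voc import cfg
from yolo.yolov1 import create_yolov1

model = create_yolov1(cfg['model_cfg']).cuda().eval()
dummy = torch.zeros(1, 3, 448, 448).cuda()
with torch.no_grad():
    detections = model(dummy)  # no target given -> one decoded BoxList per image
print(detections[0].get_field('scores'))
```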
/中文.md:
--------------------------------------------------------------------------------
1 | # The simplest implementation of yolo v1 based on pytorch
2 | ### Reference (Chinese blog): https://muzhan.blog.csdn.net/article/details/82588059
--------------------------------------------------------------------------------