├── png
│   └── cat.jpg
├── results
│   └── demo
│       ├── cat.jpg
│       └── dog.jpg
├── config
│   ├── FiraMono-Medium.otf
│   ├── tiny-yolo-voc.py
│   ├── yolo-voc.py
│   ├── tiny-yolo.py
│   └── yolo.py
├── network
│   ├── postprocesscv.py
│   ├── prior.py
│   ├── detect.py
│   ├── postprocess.py
│   └── yolo.py
├── README.md
├── README.zh.md
├── demo_cam.py
├── demo.py
├── demo_video.py
└── tools
    └── yad2t.py

/png/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/png/cat.jpg
--------------------------------------------------------------------------------
/results/demo/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/results/demo/cat.jpg
--------------------------------------------------------------------------------
/results/demo/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/results/demo/dog.jpg
--------------------------------------------------------------------------------
/config/FiraMono-Medium.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/config/FiraMono-Medium.otf
--------------------------------------------------------------------------------
/network/postprocesscv.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2


# draw proposal boxes using OpenCV
def draw_box_cv(cfg, image, label, box, c):
    h, w = image.shape[:2]
    thickness = (w + h) // 300
    left, top, right, bottom = box
    top, left = max(0, np.round(top).astype('int32')), max(0, np.round(left).astype('int32'))
    right, bottom = min(w, np.round(right).astype('int32')), min(h, np.round(bottom).astype('int32'))
    cv2.rectangle(image, (left, top), (right, bottom), cfg.colors[c], thickness)
    cv2.putText(image, label, (left, top - 5), 0, 0.5, cfg.colors[c], 1)
--------------------------------------------------------------------------------
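One caveat when reusing `cfg.colors` in this OpenCV renderer: the config files below build the palette with `colorsys.hsv_to_rgb`, i.e. RGB tuples, while OpenCV's drawing functions interpret color tuples as BGR, so the drawn hues differ from the PIL renderer in `network/postprocess.py`. A minimal sketch of an adjustment (hypothetical, not part of the repo):

```python
# Hypothetical tweak inside draw_box_cv: reverse each RGB tuple before
# handing it to OpenCV, which expects BGR channel order.
color_bgr = cfg.colors[c][::-1]  # (R, G, B) -> (B, G, R)
cv2.rectangle(image, (left, top), (right, bottom), color_bgr, thickness)
```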
/network/prior.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from itertools import product


# prior box in each position
class PriorBox(object):
    def __init__(self, cfg):
        super(PriorBox, self).__init__()
        self.image_size = cfg.image_size
        self.num_priors = cfg.anchor_num
        self.feat_size = cfg.feat_size
        self.anchors = np.array(cfg.anchors).reshape(-1, 2)

    def forward(self):
        mean = []
        for i, j in product(range(self.feat_size), repeat=2):
            cx = j
            cy = i
            for k in range(self.num_priors):
                w = self.anchors[k, 0]
                h = self.anchors[k, 1]
                mean += [cx, cy, w, h]

        output = torch.Tensor(mean).view(-1, 4)
        return output
--------------------------------------------------------------------------------
/network/detect.py:
--------------------------------------------------------------------------------
from torch.autograd import Function
from network.postprocess import box_to_corners, filter_box, non_max_suppression


# test phase's proposal process
class Detect(Function):
    def __init__(self, cfg, eval=False):
        self.class_num = cfg.class_num
        self.feat_size = cfg.feat_size
        if eval:
            self.nms_t, self.score_t = cfg.eval_nms_threshold, cfg.eval_score_threshold
        else:
            self.nms_t, self.score_t = cfg.nms_threshold, cfg.score_threshold

    def forward(self, box_pred, box_conf, box_prob, priors, img_shape, max_boxes=10):
        box_pred[..., 0:2] += priors[..., 0:2]
        box_pred[..., 2:] *= priors[..., 2:]
        boxes = box_to_corners(box_pred) / self.feat_size
        boxes, scores, classes = filter_box(boxes, box_conf, box_prob, self.score_t)
        if boxes.numel() == 0:
            return boxes, scores, classes
        boxes = boxes * img_shape.repeat(boxes.size(0), 1)
        keep, count = non_max_suppression(boxes, scores, self.nms_t)
        boxes = boxes[keep[:count]]
        scores = scores[keep[:count]]
        classes = classes[keep[:count]]
        return boxes, scores, classes
--------------------------------------------------------------------------------
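To make the decoding in `Detect.forward` concrete, here is a small self-contained sketch with dummy numbers (assuming the PyTorch 0.3-era API used throughout this repo): the network output `(sigmoid(tx), sigmoid(ty), exp(tw), exp(th))` is shifted by the prior's grid cell, scaled by its anchor, converted to corners, and normalized by the 13x13 feature size.

```python
import torch
from network.postprocess import box_to_corners

# one fake prediction laid out like a single entry of box_pred: (x, y, w, h)
box_pred = torch.Tensor([0.5, 0.5, 1.0, 1.0]).view(1, 1, 1, 4)
# its matching prior: grid cell (6, 6) with a 3.42 x 4.41 anchor, in cell units
prior = torch.Tensor([6., 6., 3.42, 4.41]).view(1, 1, 1, 4)

box_pred[..., 0:2] += prior[..., 0:2]    # box center: (6.5, 6.5) cells
box_pred[..., 2:] *= prior[..., 2:]      # box size: (3.42, 4.41) cells
corners = box_to_corners(box_pred) / 13  # normalized (x1, y1, x2, y2)
print(corners.view(-1))                  # ~ (0.368, 0.330, 0.631, 0.670)
```

Multiplying the normalized corners by the original image width and height (as `forward` does with `img_shape`) then yields pixel coordinates.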
/config/tiny-yolo-voc.py:
--------------------------------------------------------------------------------
import colorsys
import random

# create model information
tiny, voc = True, True
num = 7
flag, size_flag = [1] * num, []
pool = [0, 1, 2, 3, 4]

# anchor and class information
anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52]
classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
           'sofa', 'train', 'tvmonitor']
anchor_num = 5
class_num = len(classes)

# colors for drawing boxes
hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
random.seed(10101)  # Fixed seed for consistent colors across runs.
random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
random.seed(None)  # Reset seed to default.

# input image information
image_size = (416, 416)
feat_size = image_size[0] // 32

cuda = False

# demo parameters
eval = False
score_threshold = 0.5
nms_threshold = 0.4
iou_threshold = 0.6
--------------------------------------------------------------------------------
/config/yolo-voc.py:
--------------------------------------------------------------------------------
import colorsys
import random

# create model information
tiny, voc, num = False, True, 18
flag = [1, 1, 0] * 3 + [1, 0, 1] * 2 + [0, 1, 0]
size_flag = [3, 6, 9, 11, 14, 16]
pool = [0, 1, 4, 7, 13]

# anchor and class information
anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071]
classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
           'sofa', 'train', 'tvmonitor']
anchor_num = 5
class_num = len(classes)

# colors for drawing boxes
hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
random.seed(10101)  # Fixed seed for consistent colors across runs.
random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
random.seed(None)  # Reset seed to default.

# input image information
image_size = (416, 416)
feat_size = image_size[0] // 32

cuda = False

# demo parameters
eval = False
score_threshold = 0.5
nms_threshold = 0.4
iou_threshold = 0.6
--------------------------------------------------------------------------------
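For orientation: the `anchors` above are (width, height) pairs measured in grid cells, and with `image_size = (416, 416)` the grid is `feat_size = 416 // 32 = 13`, so one cell covers 32 input pixels. A quick sanity check of what a single anchor means in pixels:

```python
# the second anchor of yolo-voc above, in units of 13x13 grid cells
w_cells, h_cells = 3.19275, 4.00944
# one cell = 32 px at a 416x416 input, so this anchor is roughly 102 x 128 px
w_px, h_px = w_cells * 32, h_cells * 32
```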
/README.md:
--------------------------------------------------------------------------------
# YOLO-Pytorch

[中文说明](README.zh.md)

## Description

This is a PyTorch port of [YAD2K](https://github.com/allanzelener/YAD2K).

Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242) by Joseph Redmon and Ali Farhadi.

---

## Requirements

- PyTorch 0.3.0
- torchvision
- OpenCV (required for the camera and video demos)
- Python 3


## Usage

1. Download the Darknet model cfg and weights from the [official YOLO website](http://pjreddie.com/darknet/yolo/).

   ```bash
   # for example --- or any other matching cfg/weights pair
   wget http://pjreddie.com/media/files/yolo.weights
   wget https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
   ```

   Note: other variants such as `yolo-voc.cfg` can be downloaded the same way.

2. Convert the weights to `.pth`:

   ```bash
   python tools/yad2t.py --cfg_path=path-to-yolo-cfg --weight_path=path-to-yolo-weights --output_path=path-to-output-folder
   ```

   Note: by default the script assumes that

   - your `yolo.cfg` and `yolo.weights` are in the directory `config`
   - the output folder is `model`

3. Three demos (picture, camera, video):

   1. `demo.py`

      ```bash
      python demo.py --demo_path=pic-path --yolo_type=yolo --cuda=True
      ```

      Note: by default

      - pictures are read from the folder `results/demo`
      - `yolo_type` is `yolo`; three variants are supported: `[yolo, tiny-yolo-voc, yolo-voc]`

   2. `demo_cam.py`

      ```bash
      python demo_cam.py --trained_model=path-to-pth-from-step-2
      ```

   3. `demo_video.py`

      ```bash
      python demo_video.py --demo_path=video_path --trained_model=path-to-pth-from-step-2
      ```
--------------------------------------------------------------------------------
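For readers who prefer to drive the model from Python instead of the demo scripts, the following sketch mirrors what `demo.py` does; it assumes step 2 above has produced `model/yolo.pth`, and everything else comes straight from this repo:

```python
import importlib
import torch
from PIL import Image
from torchvision import transforms
from torch.autograd import Variable
from network.yolo import yolo

cfg = importlib.import_module('config.yolo')
net = yolo(cfg)
net.load_state_dict(torch.load('model/yolo.pth'))  # produced by tools/yad2t.py
net.eval()

image = Image.open('results/demo/dog.jpg')
x = transforms.Compose([transforms.Resize(cfg.image_size), transforms.ToTensor()])(image)
boxes, scores, classes = net(Variable(x.unsqueeze(0)), image.size)
```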
/README.zh.md:
--------------------------------------------------------------------------------
# YOLO-Pytorch

[English](README.md)

## Description

This project ports the Keras YOLOv2 implementation [YAD2K](https://github.com/allanzelener/YAD2K) to PyTorch.

Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242)

---

## Requirements

- PyTorch 0.3.0
- torchvision
- OpenCV
- Python 3

## Usage

1. Download a pretrained model and its cfg from the [official site](http://pjreddie.com/darknet/yolo/):

   ```bash
   # tiny-yolo-voc, yolo-voc and other variants are also available
   wget http://pjreddie.com/media/files/yolo.weights
   wget https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
   ```

2. Export the `.weights` parameters to a `.pth` file (see `tools/yad2t.py` for the implementation):

   ```bash
   python tools/yad2t.py --cfg_path=path-to-yolo-cfg --weight_path=path-to-yolo-weights --output_path=path-to-output-folder
   ```

   Notes: ① by default the weights and cfg downloaded in step 1 are assumed to live in the `config` folder; ② the exported `.pth` file is saved to the `model` folder by default.

3. Three demo programs:

   - `demo.py`: processes a single picture or a folder of pictures

     ```bash
     python demo.py --demo_path=pic-path --yolo_type=yolo --cuda=True
     ```

     Notes: ① by default all pictures in `results/demo` are processed; ② the `yolo` model is used by default, and `yolo-voc` or `tiny-yolo-voc` can be selected instead (after obtaining the corresponding converted model via the first two steps); ③ `--cuda` chooses whether to run on the GPU.

   - `demo_cam.py`: camera demo (GPU acceleration recommended)

     ```bash
     python demo_cam.py --trained_model=path-to-pth-from-step-2
     ```

     Note: ① see `demo_cam.py` for the remaining parameters.

   - `demo_video.py`: video demo (may still contain bugs)

     ```bash
     python demo_video.py --demo_path=video_path --trained_model=path-to-pth-from-step-2
     ```

     Note: ① `avi` and `mp4` files can be processed; other formats are untested.

### Folder layout

- `config`: pre-saved parameters
- `network`: the network implementation, image post-processing, and related operations
- `tools`: utilities such as the weight converter
- `results`: mainly stores experiment results

## TODO

- [ ] evaluate mAP
- [ ] add the training pipeline (to be decided)

## Issues

Bug reports and pull requests are welcome, thanks!
--------------------------------------------------------------------------------
/config/tiny-yolo.py:
--------------------------------------------------------------------------------
import colorsys
import random

# create model information
tiny, voc = True, False
num = 7
flag, size_flag = [1] * num, []
pool = [0, 1, 2, 3, 4]

# anchor and class information
anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071]
classes = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
           'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
           'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
           'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
           'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa',
           'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard',
           'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
           'teddy bear', 'hair drier', 'toothbrush']
anchor_num = 5
class_num = len(classes)

# colors for drawing boxes
hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
random.seed(10101)  # Fixed seed for consistent colors across runs.
random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
random.seed(None)  # Reset seed to default.

# input image information
image_size = (416, 416)
feat_size = image_size[0] // 32

cuda = False

# demo parameters
eval = False
score_threshold = 0.5
nms_threshold = 0.4
iou_threshold = 0.6
--------------------------------------------------------------------------------
/config/yolo.py:
--------------------------------------------------------------------------------
import colorsys
import random

# create model information
tiny, voc, num = False, False, 18
flag = [1, 1, 0] * 3 + [1, 0, 1] * 2 + [0, 1, 0]
size_flag = [3, 6, 9, 11, 14, 16]
pool = [0, 1, 4, 7, 13]

# anchor and class information
anchors = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]
classes = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
           'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
           'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
           'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
           'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa',
           'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard',
           'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
           'teddy bear', 'hair drier', 'toothbrush']
anchor_num = 5
class_num = len(classes)

# colors for drawing boxes
hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
random.seed(10101)  # Fixed seed for consistent colors across runs.
random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
random.seed(None)  # Reset seed to default.

# input image information
image_size = (608, 608)
feat_size = image_size[0] // 32

cuda = False

# demo parameters
eval = False
score_threshold = 0.5
nms_threshold = 0.4
iou_threshold = 0.6
--------------------------------------------------------------------------------
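A small worked consequence of the `(608, 608)` input above: the grid grows from 13x13 to 19x19, so this full-size model scores more candidate boxes per image than the VOC configs before filtering and NMS.

```python
feat_size = 608 // 32                  # 19, as computed above
num_candidates = feat_size ** 2 * 5    # 19 * 19 * 5 = 1805 boxes per image
```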
/demo_cam.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import argparse
import importlib
from torch.autograd import Variable
from network.postprocesscv import draw_box_cv
from network.yolo import yolo


def demo_cam(cfg, save_path, save):
    net = yolo(cfg)
    net.load_state_dict(torch.load(cfg.trained_model))
    if cfg.cuda: net = net.cuda()
    net.eval()
    cam = cv2.VideoCapture(0)
    if not cam.isOpened(): raise IOError("check your camera or the opencv library...")
    if save:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(save_path + '/out_camera.avi', fourcc, 20.0, (640, 480))
    while cam.isOpened():
        ret, image = cam.read()
        if ret:
            reimg = cv2.resize(image, cfg.image_size)
            reimg = torch.from_numpy(reimg.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
            reimg = Variable(reimg, volatile=True)
            if cfg.cuda: reimg = reimg.cuda()
            boxes, scores, classes = net(reimg, (image.shape[1], image.shape[0]))
            boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
            for i, c in list(enumerate(classes)):
                pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
                label = '{} {:.2f}'.format(pred_class, score)
                draw_box_cv(cfg, image, label, box, c)
            cv2.imshow('camera', image)
            if cv2.waitKey(1) & 0xFF == ord('q'): break
            if save: out.write(image)
    cam.release()
    if save: out.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # default path
    curdir = os.getcwd()
    trained_model = os.path.join(curdir, 'model/yolo.pth')

    # demo parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo_type', default='yolo', type=str)
    parser.add_argument('--save', default=False, type=bool)
    parser.add_argument('--save_path', default='results/demo/camera', type=str)
    parser.add_argument('--nms_threshold', default=0.4, type=float)
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--trained_model', default=trained_model, type=str)
    parser.add_argument('--cuda', default=True, type=bool)

    config = parser.parse_args()
    cfg = importlib.import_module('config.' + config.yolo_type)
    cfg.nms_threshold = config.nms_threshold
    cfg.score_threshold = config.score_threshold
    cfg.trained_model = config.trained_model
    cfg.cuda = config.cuda

    if not os.path.exists(config.save_path) and config.save: os.mkdir(config.save_path)
    demo_cam(cfg, config.save_path, config.save)
--------------------------------------------------------------------------------
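One pitfall with the flags above: `argparse` with `type=bool` converts via `bool(str)`, and any non-empty string is truthy, so `--cuda=False` on the command line still enables CUDA. A common workaround (a sketch, not part of the repo) is a small string-to-bool converter:

```python
import argparse

def str2bool(v):
    # map the usual spellings of true/false instead of relying on bool('False') == True
    return str(v).lower() in ('yes', 'true', 't', '1')

parser = argparse.ArgumentParser()
parser.add_argument('--cuda', default=True, type=str2bool)
```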
/demo.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
import importlib
from PIL import Image
from torchvision import transforms
from torch.autograd import Variable
from network.postprocess import draw_box
from network.yolo import yolo


# TODO: add a dataloader form
def demo(cfg, img_list, save_path):
    transform = transforms.Compose([transforms.Resize(cfg.image_size), transforms.ToTensor()])
    net = yolo(cfg)
    net.load_state_dict(torch.load(cfg.trained_model))
    if cfg.cuda: net = net.cuda()
    net.eval()
    for img in img_list:
        image = Image.open(img)
        reimg = Variable(transform(image).unsqueeze(0))
        if cfg.cuda: reimg = reimg.cuda()
        boxes, scores, classes = net(reimg, image.size)
        boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
        print('Find {} boxes for {}.'.format(len(boxes), img.split('/')[-1]))
        for i, c in list(enumerate(classes)):
            pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
            label = '{} {:.2f}'.format(pred_class, score)
            draw_box(cfg, image, label, box, c)
        image.save(os.path.join(save_path, img.split('/')[-1]), quality=90)


if __name__ == '__main__':
    # default path
    curdir = os.getcwd()
    demo_path = os.path.join(curdir, 'results/demo')
    trained_model = os.path.join(curdir, 'model/yolo.pth')

    # demo parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--demo_path', default=demo_path, type=str)
    parser.add_argument('--yolo_type', default='yolo', type=str)
    parser.add_argument('--nms_threshold', default=0.4, type=float)
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--trained_model', default=trained_model, type=str)
    parser.add_argument('--cuda', default=True, type=bool)

    config = parser.parse_args()
    cfg = importlib.import_module('config.' + config.yolo_type)
    cfg.nms_threshold = config.nms_threshold
    cfg.score_threshold = config.score_threshold
    cfg.demo_path = config.demo_path
    cfg.trained_model = config.trained_model
    cfg.cuda = config.cuda

    ext = ['.jpg', '.png']
    if os.path.isfile(cfg.demo_path) and os.path.splitext(cfg.demo_path)[-1] in ext:
        img_list = [cfg.demo_path]
        save_path = os.path.join(os.path.dirname(cfg.demo_path), 'out')
    elif os.path.isdir(cfg.demo_path):
        imgs = [fname for fname in os.listdir(cfg.demo_path) if os.path.splitext(fname)[-1] in ext]
        img_list = [os.path.join(cfg.demo_path, fname) for fname in imgs]
        save_path = os.path.join(cfg.demo_path, 'out')
    else:
        img_list = []  # neither a valid image file nor a directory
    if not img_list:
        raise IOError("illegal demo path ...")
    if not os.path.exists(save_path): os.mkdir(save_path)
    demo(cfg, img_list, save_path)
--------------------------------------------------------------------------------
/demo_video.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import argparse
import importlib
from torch.autograd import Variable
from timeit import default_timer as timer
from network.postprocesscv import draw_box_cv
from network.yolo import yolo


def demo_video(cfg, file, save_path, save, start_frame=0):
    net = yolo(cfg)
    net.load_state_dict(torch.load(cfg.trained_model))
    if cfg.cuda: net = net.cuda()
    net.eval()
    video = cv2.VideoCapture(file)
    if not video.isOpened():
        raise IOError('Could not open video, check your opencv and video')
    if save:
        fps = video.get(cv2.CAP_PROP_FPS)
        size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        out = cv2.VideoWriter(os.path.join(save_path, file.split('/')[-1]), cv2.VideoWriter_fourcc(*'XVID'), fps, size)

    if start_frame > 0:
        video.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    accum_time = 0
    curr_fps = 0
    fps = "FPS: "
    prev_time = timer()

    while True:
        retval, orig_image = video.read()
        if not retval:
            print('Done !!!')
            break
        reimg = cv2.resize(orig_image, cfg.image_size)
        reimg = torch.from_numpy(reimg.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
        x = Variable(reimg, volatile=True)
        if cfg.cuda: x = x.cuda()
        boxes, scores, classes = net(x, (orig_image.shape[1], orig_image.shape[0]))
        boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
        for i, c in list(enumerate(classes)):
            pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
            label = '{} {:.2f}'.format(pred_class, score)
            draw_box_cv(cfg, orig_image, label, box, c)
        curr_time = timer()
        exec_time = curr_time - prev_time
        prev_time = curr_time
        accum_time += exec_time
        curr_fps = curr_fps + 1
        if accum_time > 1:
            accum_time -= 1
            fps = "FPS:" + str(curr_fps)
            curr_fps = 0
        cv2.rectangle(orig_image, (0, 0), (50, 17), (255, 255, 255), -1)
        cv2.putText(orig_image, fps, (3, 10), 0, 0.35, (0, 0, 0), 1)
        if save:
            out.write(orig_image)
        cv2.imshow("yolo result", orig_image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    video.release()
    if save:
        out.release()


if __name__ == '__main__':
    # default path
    curdir = os.getcwd()
    trained_model = os.path.join(curdir, 'model/yolo.pth')
    demo_path = os.path.join(curdir, 'results/demo/video.avi')

    # demo parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo_type', default='yolo', type=str)
    parser.add_argument('--save', default=False, type=bool)
    parser.add_argument('--save_path', default='results/demo/video', type=str)
    parser.add_argument('--demo_path', default=demo_path, type=str)
    parser.add_argument('--nms_threshold', default=0.4, type=float)
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--trained_model', default=trained_model, type=str)
    parser.add_argument('--cuda', default=True, type=bool)

    config = parser.parse_args()
    cfg = importlib.import_module('config.' + config.yolo_type)
    cfg.nms_threshold = config.nms_threshold
    cfg.score_threshold = config.score_threshold
    cfg.trained_model = config.trained_model
    cfg.cuda = config.cuda

    ext = ['.mp4', '.avi']
    if not os.path.splitext(config.demo_path)[-1] in ext:
        raise IOError("illegal video form...")
    if not os.path.exists(config.save_path) and config.save: os.mkdir(config.save_path)
    demo_video(cfg, config.demo_path, config.save_path, config.save)
--------------------------------------------------------------------------------
/network/postprocess.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from PIL import ImageFont, ImageDraw


# (x, y, w, h) ---> (x1, y1, x2, y2)
def box_to_corners(box_pred):
    box_mins = box_pred[..., 0:2] - (box_pred[..., 2:] / 2.)
    box_maxes = box_pred[..., 0:2] + (box_pred[..., 2:] / 2.)
    return torch.cat([box_mins[..., 0:1], box_mins[..., 1:2],
                      box_maxes[..., 0:1], box_maxes[..., 1:2]], 3)


# remove the proposal detections whose score is below the threshold
def filter_box(boxes, box_conf, box_prob, threshold=.5):
    box_scores = box_conf.repeat(1, 1, 1, box_prob.size(3)) * box_prob
    box_class_scores, box_classes = torch.max(box_scores, dim=3)
    prediction_mask = box_class_scores > threshold
    prediction_mask4 = prediction_mask.unsqueeze(3).expand(boxes.size())

    boxes = torch.masked_select(boxes, prediction_mask4).contiguous().view(-1, 4)
    scores = torch.masked_select(box_class_scores, prediction_mask)
    classes = torch.masked_select(box_classes, prediction_mask)
    return boxes, scores, classes


# non-max-suppression process
def non_max_suppression(boxes, scores, overlap=0.5, top_k=200):
    keep = torch.Tensor(scores.size(0)).fill_(0).long()
    if boxes.is_cuda: keep = keep.cuda()
    if boxes.numel() == 0:
        return keep, 0
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1, yy1, xx2, yy2 = boxes.new(), boxes.new(), boxes.new(), boxes.new()
    w, h = boxes.new(), boxes.new()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1: break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # clip the remaining boxes to the current box (element-wise max/min)
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w, h = torch.clamp(xx2 - xx1, min=0.0), torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count


# draw proposal boxes
def draw_box(cfg, image, label, box, c):
    w, h = image.size
    font = ImageFont.truetype(font='./config/FiraMono-Medium.otf', size=np.round(3e-2 * h).astype('int32'))
    thickness = (w + h) // 300
    draw = ImageDraw.Draw(image)
    label_size = draw.textsize(label, font)
    left, top, right, bottom = box
    top, left = max(0, np.round(top).astype('int32')), max(0, np.round(left).astype('int32'))
    right, bottom = min(w, np.round(right).astype('int32')), min(h, np.round(bottom).astype('int32'))
    print(label, (left, top), (right, bottom))
    text_origin = np.array([left, top - label_size[1]]) if top - label_size[1] >= 0 else np.array([left, top + 1])
    for i in range(thickness):
        draw.rectangle([left + i, top + i, right - i, bottom - i], outline=cfg.colors[c])
    draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=cfg.colors[c])
    draw.text(text_origin, label, fill=(0, 0, 0), font=font)
    del draw
--------------------------------------------------------------------------------
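A tiny worked example of `non_max_suppression`: with two heavily overlapping boxes and one disjoint box at the 0.5 IoU threshold, the higher-scoring box of the overlapping pair and the disjoint box survive (dummy tensors, same 0.3-era API as above):

```python
import torch
from network.postprocess import non_max_suppression

boxes = torch.Tensor([[0, 0, 10, 10],     # overlaps the next box with IoU ~ 0.68
                      [1, 1, 11, 11],
                      [50, 50, 60, 60]])  # far away from both
scores = torch.Tensor([0.9, 0.8, 0.7])
keep, count = non_max_suppression(boxes, scores, overlap=0.5)
print(keep[:count])                       # indices 0 and 2 survive
```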
/tools/yad2t.py:
--------------------------------------------------------------------------------
import sys
import io
import os
import argparse
import configparser
import importlib
import torch
import numpy as np
from collections import defaultdict

sys.path.append('..')
from network.yolo import yolo


# rename each [xxx] section to the [xxx_num] form
def unique_config_sections(config_file):
    """Convert all config sections to have unique names.

    Adds unique suffixes to config sections for compatibility with configparser.
    """
    section_counters = defaultdict(int)
    output_stream = io.StringIO()
    with open(config_file) as fin:
        for line in fin:
            if line.startswith('['):
                section = line.strip().strip('[]')
                _section = section + '_' + str(section_counters[section])
                section_counters[section] += 1
                line = line.replace(section, _section)
            output_stream.write(line)
    output_stream.seek(0)
    return output_stream


def weight2pth(config_path, weights_path, output_path):
    assert config_path.endswith('.cfg'), '{} is not a .cfg file'.format(config_path)
    assert weights_path.endswith('.weights'), '{} is not a .weights file'.format(weights_path)
    # weights header
    weights_file = open(weights_path, 'rb')
    weights_header = np.ndarray(shape=(4,), dtype='int32', buffer=weights_file.read(16))
    print('Weights Header: ', weights_header)
    # convert config information
    unique_config_file = unique_config_sections(config_path)
    cfg_parser = configparser.ConfigParser()
    cfg_parser.read_file(unique_config_file)
    # network information
    cfgname = config_path.split('/')[-1].split('.')[0]
    cfg = importlib.import_module('config.' + cfgname)
    net = yolo(cfg)
    net_dict = net.state_dict()
    keys = list(net_dict.keys())
    key_num, count, prev_filter = 0, 0, 3
    print('loading the weights ...')
    for section in cfg_parser.sections():
        if section.startswith('convolutional'):
            filters = int(cfg_parser[section]['filters'])
            size = int(cfg_parser[section]['size'])
            bn = 'batch_normalize' in cfg_parser[section]
            activation = cfg_parser[section]['activation']
            # three special cases (the passthrough route changes the input width)
            if section == 'convolutional_20':
                prev_filter = 512
            elif section == 'convolutional_21':
                prev_filter = 1280
            elif section == 'convolutional_0':
                prev_filter = 3
            else:
                prev_filter = weights_shape[0]

            weights_shape = (filters, prev_filter, size, size)
            weights_size = np.product(weights_shape)
            print('conv2d', 'bn' if bn else '  ', activation, weights_shape)
            conv_bias = np.ndarray(
                shape=(filters,),
                dtype='float32',
                buffer=weights_file.read(filters * 4))
            count += filters
            if bn:
                bn_weights = np.ndarray(
                    shape=(3, filters),
                    dtype='float32',
                    buffer=weights_file.read(filters * 12))
                count += 3 * filters
                net_dict[keys[key_num + 1]].copy_(torch.from_numpy(bn_weights[0]))
                net_dict[keys[key_num + 2]].copy_(torch.from_numpy(conv_bias))
                net_dict[keys[key_num + 3]].copy_(torch.from_numpy(bn_weights[1]))
                net_dict[keys[key_num + 4]].copy_(torch.from_numpy(bn_weights[2]))
            else:
                net_dict[keys[key_num + 1]].copy_(torch.from_numpy(conv_bias))
            # load the convolutional layer weights
            conv_weights = np.ndarray(
                shape=weights_shape,
                dtype='float32',
                buffer=weights_file.read(weights_size * 4))
            count += weights_size
            net_dict[keys[key_num]].copy_(torch.from_numpy(conv_weights))
            key_num = key_num + 5 if bn else key_num + 1
        else:
            continue
    # check the conversion
    remaining_weights = len(weights_file.read()) // 4
    weights_file.close()
    print('Read {} of {} from Darknet weights.'.format(count, count + remaining_weights))
    if remaining_weights > 0:
        print('Warning: {} unused weights'.format(remaining_weights))
    # save the net.state_dict
    torch.save(net_dict, os.path.join(output_path, cfgname + '.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    curdir = os.path.abspath('..')
    cfg_path = os.path.join(curdir, 'config/yolo.cfg')
    weight_path = os.path.join(curdir, 'config/yolo.weights')
    output_path = os.path.join(curdir, 'model')
    # parameters
    parser.add_argument('--cfg_path', default=cfg_path, type=str)
    parser.add_argument('--weight_path', default=weight_path, type=str)
    parser.add_argument('--output_path', default=output_path, type=str)

    config = parser.parse_args()
    if not os.path.exists(config.output_path):
        os.mkdir(config.output_path)
    weight2pth(config.cfg_path, config.weight_path, config.output_path)
--------------------------------------------------------------------------------
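`unique_config_sections` exists because Darknet cfg files repeat section names (`[convolutional]` appears dozens of times), which `configparser` rejects. An illustration of what it produces, using a hypothetical throwaway cfg:

```python
# illustrative only: a tiny cfg with duplicated section names
with open('/tmp/demo.cfg', 'w') as f:  # hypothetical temp path
    f.write('[net]\nheight=416\n[convolutional]\nfilters=32\n[convolutional]\nfilters=64\n')
print(unique_config_sections('/tmp/demo.cfg').read())
# -> sections renamed to [net_0], [convolutional_0], [convolutional_1]
```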
20 | """ 21 | section_counters = defaultdict(int) 22 | output_stream = io.StringIO() 23 | with open(config_file) as fin: 24 | for line in fin: 25 | if line.startswith('['): 26 | section = line.strip().strip('[]') 27 | _section = section + '_' + str(section_counters[section]) 28 | section_counters[section] += 1 29 | line = line.replace(section, _section) 30 | output_stream.write(line) 31 | output_stream.seek(0) 32 | return output_stream 33 | 34 | 35 | def weight2pth(config_path, weights_path, output_path): 36 | assert config_path.endswith('.cfg'), '{} is not a .cfg file'.format(config_path) 37 | assert weights_path.endswith('.weights'), '{} is not a .weights file'.format(weights_path) 38 | # weights header 39 | weights_file = open(weights_path, 'rb') 40 | weights_header = np.ndarray(shape=(4,), dtype='int32', buffer=weights_file.read(16)) 41 | print('Weights Header: ', weights_header) 42 | # convert config information 43 | unique_config_file = unique_config_sections(config_path) 44 | cfg_parser = configparser.ConfigParser() 45 | cfg_parser.read_file(unique_config_file) 46 | # network information 47 | cfgname = config_path.split('/')[-1].split('.')[0] 48 | cfg = importlib.import_module('config.' + cfgname) 49 | net = yolo(cfg) 50 | net_dict = net.state_dict() 51 | keys = list(net_dict.keys()) 52 | key_num, count, prev_filter = 0, 0, 3 53 | print('loading the weights ...') 54 | for section in cfg_parser.sections(): 55 | if section.startswith('convolutional'): 56 | filters = int(cfg_parser[section]['filters']) 57 | size = int(cfg_parser[section]['size']) 58 | bn = 'batch_normalize' in cfg_parser[section] 59 | activation = cfg_parser[section]['activation'] 60 | # three special case 61 | if section == 'convolutional_20': 62 | prev_filter = 512 63 | elif section == 'convolutional_21': 64 | prev_filter = 1280 65 | elif section == 'convolutional_0': 66 | prev_filter = 3 67 | else: 68 | prev_filter = weights_shape[0] 69 | 70 | weights_shape = (filters, prev_filter, size, size) 71 | weights_size = np.product(weights_shape) 72 | print('conv2d', 'bn' if bn else ' ', activation, weights_shape) 73 | conv_bias = np.ndarray( 74 | shape=(filters,), 75 | dtype='float32', 76 | buffer=weights_file.read(filters * 4)) 77 | count += filters 78 | if bn: 79 | bn_weights = np.ndarray( 80 | shape=(3, filters), 81 | dtype='float32', 82 | buffer=weights_file.read(filters * 12)) 83 | count += 3 * filters 84 | net_dict[keys[key_num + 1]].copy_(torch.from_numpy(bn_weights[0])) 85 | net_dict[keys[key_num + 2]].copy_(torch.from_numpy(conv_bias)) 86 | net_dict[keys[key_num + 3]].copy_(torch.from_numpy(bn_weights[1])) 87 | net_dict[keys[key_num + 4]].copy_(torch.from_numpy(bn_weights[2])) 88 | else: 89 | net_dict[keys[key_num + 1]].copy_(torch.from_numpy(conv_bias)) 90 | # 导入卷积层参数 91 | conv_weights = np.ndarray( 92 | shape=weights_shape, 93 | dtype='float32', 94 | buffer=weights_file.read(weights_size * 4)) 95 | count += weights_size 96 | net_dict[keys[key_num]].copy_(torch.from_numpy(conv_weights)) 97 | key_num = key_num + 5 if bn else key_num + 1 98 | else: 99 | continue 100 | # check the convert 101 | remaining_weights = len(weights_file.read()) // 4 102 | weights_file.close() 103 | print('Read {} of {} from Darknet weights.'.format(count, count + remaining_weights)) 104 | if remaining_weights > 0: 105 | print('Warning: {} unused weights'.format(remaining_weights)) 106 | # save the net.state_dict 107 | torch.save(net_dict, os.path.join(output_path, cfgname + '.pth')) 108 | 109 | 110 | if __name__ == '__main__': 111 | 

# darknet feature detector
class DarknetBone(nn.Module):
    def __init__(self, cfg):
        super(DarknetBone, self).__init__()
        self.cfg = cfg
        in_channel, out_channel = 3, 16 if cfg.tiny else 32
        flag, pool, size_flag = cfg.flag, cfg.pool, cfg.size_flag
        layers1, layers2 = [], []
        for i in range(cfg.num):
            ksize = 1 if i in size_flag else 3
            if i < 13:
                layers1.append(ConvLayer(in_channel, out_channel, ksize, same_padding=True))
                layers1.append(nn.MaxPool2d(2)) if i in pool else None
                layers1 += [nn.ReflectionPad2d([0, 1, 0, 1]), nn.MaxPool2d(2, 1)] if i == 5 and cfg.tiny else []
            else:
                layers2.append(nn.MaxPool2d(2)) if i in pool else None
                layers2.append(ConvLayer(in_channel, out_channel, ksize, same_padding=True))
            in_channel, out_channel = out_channel, out_channel * 2 if flag[i] else out_channel // 2
        self.main1 = nn.Sequential(*layers1)
        self.main2 = nn.Sequential(*layers2)

    def forward(self, x):
        xd = self.main1(x)
        if self.cfg.tiny:
            return xd
        else:
            x = self.main2(xd)
            return x, xd


# YOLO
class Yolo(nn.Module):
    def __init__(self, cfg):
        super(Yolo, self).__init__()
        self.cfg = cfg
        self.prior = Variable(PriorBox(cfg).forward(), volatile=True)
        self.darknet = DarknetBone(cfg)
        if cfg.tiny:
            out = 1024 if cfg.voc else 512
            self.conv = nn.Sequential(
                ConvLayer(1024, out, 3, same_padding=True),
                nn.Conv2d(out, cfg.anchor_num * (cfg.class_num + 5), 1))
        else:
            self.conv1 = nn.Sequential(
                ConvLayer(1024, 1024, 3, same_padding=True),
                ConvLayer(1024, 1024, 3, same_padding=True))
            self.conv2 = nn.Sequential(
                ConvLayer(512, 64, 1, same_padding=True),
                ReorgLayer(2))
            self.conv = nn.Sequential(
                ConvLayer(1280, 1024, 3, same_padding=True),
                nn.Conv2d(1024, cfg.anchor_num * (cfg.class_num + 5), 1))

    def forward(self, x, img_shape=None):
        if self.cfg.tiny:
            x = self.conv(self.darknet(x))
        else:
            x1, x2 = self.darknet(x)
            x = self.conv(torch.cat([self.conv2(x2), self.conv1(x1)], 1))
        # extract each part
        b, c, h, w = x.size()
        feat = x.permute(0, 2, 3, 1).contiguous().view(b, -1, self.cfg.anchor_num, self.cfg.class_num + 5)
        box_xy, box_wh = F.sigmoid(feat[..., 0:2]), feat[..., 2:4].exp()
        box_conf, score_pred = F.sigmoid(feat[..., 4:5]), feat[..., 5:].contiguous()
        box_prob = F.softmax(score_pred, dim=3)
        box_pred = torch.cat([box_xy, box_wh], 3)
        # TODO: add training phase
        if self.training:
            return x
        else:
            width, height = img_shape
            img_shape = Variable(torch.Tensor([[width, height, width, height]]))
            if self.cfg.cuda: self.prior, img_shape = self.prior.cuda(), img_shape.cuda()
            self.prior = self.prior.view_as(box_pred)
            return Detect(self.cfg, self.cfg.eval)(box_pred, box_conf, box_prob, self.prior, img_shape)

# interface --- construct the different YOLO model variants
def yolo(cfg):
    model = Yolo(cfg)
    return model


# weight initialization
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        # MSRA-style scaling: std = sqrt(2 / (k * k * out_channels))
        m.weight.data.normal_(0.0, (2. / (m.kernel_size[0] ** 2 * m.out_channels)) ** 0.5)
    elif classname.find('BatchNorm') != -1:
        # Estimated variance, must be around 1
        m.weight.data.normal_(1.0, 0.02)
        # Estimated mean, must be around 0
        m.bias.data.fill_(0)
--------------------------------------------------------------------------------