├── png
│   └── cat.jpg
├── results
│   └── demo
│       ├── cat.jpg
│       └── dog.jpg
├── config
│   ├── FiraMono-Medium.otf
│   ├── tiny-yolo-voc.py
│   ├── yolo-voc.py
│   ├── tiny-yolo.py
│   └── yolo.py
├── network
│   ├── postprocesscv.py
│   ├── prior.py
│   ├── detect.py
│   ├── postprocess.py
│   └── yolo.py
├── README.md
├── README.zh.md
├── demo_cam.py
├── demo.py
├── demo_video.py
└── tools
    └── yad2t.py
/png/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/png/cat.jpg
--------------------------------------------------------------------------------
/results/demo/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/results/demo/cat.jpg
--------------------------------------------------------------------------------
/results/demo/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/results/demo/dog.jpg
--------------------------------------------------------------------------------
/config/FiraMono-Medium.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/YOLO-pytorch/HEAD/config/FiraMono-Medium.otf
--------------------------------------------------------------------------------
/network/postprocesscv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | # draw proposal boxes using OpenCV
6 | def draw_box_cv(cfg, image, label, box, c):
7 | h, w = image.shape[:2]
8 | thickness = (w + h) // 300
9 | left, top, right, bottom = box
10 | top, left = max(0, np.round(top).astype('int32')), max(0, np.round(left).astype('int32'))
11 | right, bottom = min(w, np.round(right).astype('int32')), min(h, np.round(bottom).astype('int32'))
12 | cv2.rectangle(image, (left, top), (right, bottom), cfg.colors[c], thickness)
13 | cv2.putText(image, label, (left, top - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, cfg.colors[c], 1)
14 |
--------------------------------------------------------------------------------
/network/prior.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from itertools import product
4 |
5 |
6 | # prior box in each position
7 | class PriorBox(object):
8 | def __init__(self, cfg):
9 | super(PriorBox, self).__init__()
10 | self.image_size = cfg.image_size
11 | self.num_priors = cfg.anchor_num
12 | self.feat_size = cfg.feat_size
13 | self.anchors = np.array(cfg.anchors).reshape(-1, 2)
14 |
15 | def forward(self):
16 | mean = []
17 | for i, j in product(range(self.feat_size), repeat=2):
18 | cx = j
19 | cy = i
20 | for k in range(self.num_priors):
21 | w = self.anchors[k, 0]
22 | h = self.anchors[k, 1]
23 | mean += [cx, cy, w, h]
24 |
25 | output = torch.Tensor(mean).view(-1, 4)
26 | return output
27 |
--------------------------------------------------------------------------------
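`PriorBox` lays out one `(cx, cy, w, h)` prior per anchor per grid cell, in feature-map units. A minimal sketch (not repo code; run from the repo root) of what `forward()` returns under the 416×416, 5-anchor VOC settings — `Cfg` is a hypothetical stand-in for a config module:

```python
import torch
from network.prior import PriorBox

class Cfg:  # hypothetical stand-in for config/yolo-voc.py
    image_size = (416, 416)
    anchor_num = 5
    feat_size = 416 // 32  # 13
    anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587,
               8.09892, 9.47112, 4.84053, 11.2364, 10.0071]

priors = PriorBox(Cfg).forward()
print(priors.size())  # torch.Size([845, 4]) -- 13 * 13 cells * 5 anchors
print(priors[0])      # first cell, first anchor: cx=0, cy=0, w=1.3221, h=1.73145
```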
/network/detect.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Function
2 | from network.postprocess import box_to_corners, filter_box, non_max_suppression
3 |
4 |
5 | # test phase's proposal process
6 | class Detect(Function):
7 | def __init__(self, cfg, eval=False):
8 | self.class_num = cfg.class_num
9 | self.feat_size = cfg.feat_size
10 | if eval:
11 | self.nms_t, self.score_t = cfg.eval_nms_threshold, cfg.eval_score_threshold
12 | else:
13 | self.nms_t, self.score_t = cfg.nms_threshold, cfg.score_threshold
14 |
15 | def forward(self, box_pred, box_conf, box_prob, priors, img_shape, max_boxes=10):
16 | box_pred[..., 0:2] += priors[..., 0:2]
17 | box_pred[..., 2:] *= priors[..., 2:]
18 | boxes = box_to_corners(box_pred) / self.feat_size
19 | boxes, scores, classes = filter_box(boxes, box_conf, box_prob, self.score_t)
20 | if boxes.numel() == 0:
21 | return boxes, scores, classes
22 | boxes = boxes * img_shape.repeat(boxes.size(0), 1)
23 | keep, count = non_max_suppression(boxes, scores, self.nms_t)
24 | boxes = boxes[keep[:count]]
25 | scores = scores[keep[:count]]
26 | classes = classes[keep[:count]]
27 | return boxes, scores, classes
28 |
--------------------------------------------------------------------------------
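`Detect.forward` decodes a prediction in three steps: add the cell offset to the sigmoid-ed center, scale width/height by the anchor prior, then convert to corners and rescale from grid units to pixels. A worked sketch with made-up activations, mirroring the arithmetic above:

```python
# Made-up numbers for one prediction on a 13x13 grid and a 640x480 image.
feat_size = 13                                  # 416 // 32
img_w, img_h = 640, 480
sig_x, sig_y = 0.5, 0.5                         # sigmoid(tx), sigmoid(ty) from the net
exp_w, exp_h = 1.0, 1.0                         # exp(tw), exp(th) from the net
prior = (6, 4, 3.19275, 4.00944)                # cell (6, 4), second yolo-voc anchor
cx, cy = sig_x + prior[0], sig_y + prior[1]     # center in grid units: (6.5, 4.5)
w, h = exp_w * prior[2], exp_h * prior[3]       # size in grid units
x1, x2 = (cx - w / 2) / feat_size * img_w, (cx + w / 2) / feat_size * img_w
y1, y2 = (cy - h / 2) / feat_size * img_h, (cy + h / 2) / feat_size * img_h
print([round(v, 1) for v in (x1, y1, x2, y2)])  # [241.4, 92.1, 398.6, 240.2]
```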
/config/tiny-yolo-voc.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import random
3 |
4 | # create model information
5 | tiny, voc = True, True
6 | num = 7
7 | flag, size_flag = [1] * num, []
8 | pool = [0, 1, 2, 3, 4]
9 |
10 | # anchor and classes information
11 | anchors = [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52]
12 | classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
13 | 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
14 | 'sofa', 'train', 'tvmonitor']
15 | anchor_num = 5
16 | class_num = len(classes)
17 |
18 | # color for draw boxes
19 | hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
20 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
21 | colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
22 | random.seed(10101) # Fixed seed for consistent colors across runs.
23 | random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
24 | random.seed(None) # Reset seed to default.
25 |
26 | # input image information
27 | image_size = (416, 416)
28 | feat_size = image_size[0] // 32
29 |
30 | cuda = False
31 |
32 | # demo parameter
33 | eval = False
34 | score_threshold = 0.5
35 | nms_threshold = 0.4
36 | iou_threshold = 0.6
37 |
--------------------------------------------------------------------------------
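The color block above spreads the 20 class hues evenly around the HSV wheel, then shuffles them under a fixed seed so the palette is distinct and stable across runs. A quick check (a sketch, not repo code):

```python
import colorsys

class_num = 20
hsv = [(x / class_num, 1., 1.) for x in range(class_num)]
rgb = [tuple(int(c * 255) for c in colorsys.hsv_to_rgb(*t)) for t in hsv]
print(rgb[0], rgb[5])  # (255, 0, 0) and (127, 255, 0): hues 90 degrees apart
```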
/config/yolo-voc.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import random
3 |
4 | # create model information
5 | tiny, voc, num = False, True, 18
6 | flag = [1, 1, 0] * 3 + [1, 0, 1] * 2 + [0, 1, 0]
7 | size_flag = [3, 6, 9, 11, 14, 16]
8 | pool = [0, 1, 4, 7, 13]
9 |
10 | # anchor and classes information
11 | anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071]
12 | classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
13 | 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
14 | 'sofa', 'train', 'tvmonitor']
15 | anchor_num = 5
16 | class_num = len(classes)
17 |
18 | # color for draw boxes
19 | hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
20 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
21 | colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
22 | random.seed(10101) # Fixed seed for consistent colors across runs.
23 | random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
24 | random.seed(None) # Reset seed to default.
25 |
26 | # input image information
27 | image_size = (416, 416)
28 | feat_size = image_size[0] // 32
29 |
30 | cuda = False
31 |
32 | # demo parameter
33 | eval = False
34 | score_threshold = 0.5
35 | nms_threshold = 0.4
36 | iou_threshold = 0.6
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLO-Pytorch
2 | [Chinese version](README.zh.md)
3 |
4 | ## Description
5 |
6 | This is a PyTorch version of [YAD2K](https://github.com/allanzelener/YAD2K).
7 |
8 | Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242) by Joseph Redmon and Ali Farhadi.
9 |
10 | ![](png/cat.jpg)
11 |
12 | ---
13 |
14 | ## Requirements
15 |
16 | - PyTorch 0.3.0
17 | - torchvision
18 | - OpenCV (required for the camera and video demos)
19 | - Python 3
20 |
21 |
22 | ## Usage
23 |
24 | 1. Download Darknet model cfg and weights from the [official YOLO website](http://pjreddie.com/darknet/yolo/).
25 |
26 | ```bash
27 | # for example -- other cfg and weights versions are also available
28 | wget http://pjreddie.com/media/files/yolo.weights
29 | wget https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
30 | ```
31 | Note: other variants such as `yolo-voc.cfg` can be downloaded the same way
32 |
33 | 2. Convert the weights to `.pth`
34 |
35 | ```bash
36 | python tools/yad2t.py --cfg_path=path-to-yolo-cfg --weight_path=path-to-yolo-weights --output_path=path-to-output-folder
37 | ```
38 |
39 | Note: defaults
40 |
41 | - the script expects `yolo.cfg` and `yolo.weights` in the `config` directory
42 | - the converted weights are written to the `model` folder
43 |
44 | 3. Three demos (picture, camera, video)
45 |
46 | 1. `demo.py`
47 |
48 | ```bash
49 | python demo.py --demo_path=pic-path --yolo_type=yolo --cuda=True
50 | ```
51 |
52 | Note: defaults
53 |
54 | - pictures are read from the folder `results/demo`
55 | - `--yolo_type` defaults to `yolo`; available configs: `yolo`, `tiny-yolo`, `yolo-voc`, `tiny-yolo-voc`
56 |
57 | 2. `demo_cam.py`
58 |
59 | ```bash
60 | python demo_cam.py --trained_model=path-to-pth-model
61 | ```
62 |
63 | 3. `demo_video.py`
64 |
65 | ```bash
66 | python demo_video.py --demo_path=video_path --trained_model=path-to-pth-model
67 | ```
68 |
69 |
70 |
--------------------------------------------------------------------------------
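Putting the three README steps together (a sketch assuming the default repo layout; run from the repo root):

```bash
# 1. fetch cfg and weights into config/
wget -P config http://pjreddie.com/media/files/yolo.weights
wget -P config https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
# 2. convert to model/yolo.pth
python tools/yad2t.py --cfg_path=config/yolo.cfg --weight_path=config/yolo.weights --output_path=model
# 3. run the image demo on the bundled samples; results are written to results/demo/out
python demo.py --demo_path=results/demo --yolo_type=yolo --cuda=True
```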
/README.zh.md:
--------------------------------------------------------------------------------
1 | # YOLO-Pytorch
2 |
3 | [English](README.md)
4 |
5 | ## Description
6 |
7 | This ports the Keras implementation of YOLOv2, [YAD2K](https://github.com/allanzelener/YAD2K), to PyTorch.
8 |
9 | Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242)
10 |
11 | 
12 |
13 | ---
14 |
15 | ## Requirements
16 |
17 | - PyTorch 0.3.0
18 | - torchvision
19 | - OpenCV
20 | - Python 3
21 |
22 | ## Usage
23 |
24 | 1. Download the pre-trained weights and cfg from the [official website](http://pjreddie.com/darknet/yolo/):
25 |
26 | ```bash
27 | # tiny-yolo-voc, yolo-voc and other variants are also available
28 | wget http://pjreddie.com/media/files/yolo.weights
29 | wget https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolo.cfg
30 | ```
31 |
32 | 2. Export the `.weights` parameters to a `.pth` file (see `tools/yad2t.py` for the implementation):
33 |
34 | ```bash
35 | python tools/yad2t.py --cfg_path=path-to-yolo-cfg --weight_path=path-to-yolo-weights --output_path=path-to-output-folder
36 | ```
37 |
38 | Note: (1) by default, the weights and cfg downloaded in step 1 are expected in the `config` folder; (2) the exported `.pth` file is saved to the `model` folder.
39 |
40 | 3. Three demo programs:
41 |
42 | - `demo.py`: process a single image or a folder of images
43 |
44 | ```bash
45 | python demo.py --demo_path=pic-path --yolo_type=yolo --cuda=True
46 | ```
47 |
48 | Note: (1) by default, all images in `results/demo` are processed; (2) the `yolo` model is used by default; `yolo-voc` or `tiny-yolo-voc` can be chosen instead (after obtaining the corresponding trained model via the two steps above); (3) `--cuda` selects GPU mode
49 |
50 | - `demo_cam.py`: camera demo (GPU acceleration recommended)
51 |
52 | ```bash
53 | python demo_cam.py --trained_model=path-to-pth-model
54 | ```
55 |
56 | Note: see `demo_cam.py` for the remaining parameters
57 |
58 | - `demo_video.py`: video demo (may still contain bugs)
59 |
60 | ```bash
61 | python demo_video.py --demo_path=video_path --trained_model=path-to-pth-model
62 | ```
63 |
64 | Note: `avi` and `mp4` files can be processed; other formats are untested
65 |
66 | ### Folder overview
67 |
68 | - `config`: pre-saved parameters
69 | - `network`: the network implementation, image post-processing, etc.
70 | - `tools`: utilities such as the weight converter
71 | - `results`: experiment results
72 |
73 | ## TODO
74 |
75 | - [ ] Evaluate mAP
76 | - [ ] Add the training procedure (tentative)
77 |
78 | ## Issues
79 |
80 | Bug reports and pull requests are welcome. Thanks!
--------------------------------------------------------------------------------
/config/tiny-yolo.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import random
3 |
4 | # create model information
5 | tiny, voc = True, False
6 | num = 7
7 | flag, size_flag = [1] * num, []
8 | pool = [0, 1, 2, 3, 4]
9 |
10 | # anchor and classes information
11 | anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071]
12 | classes = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
13 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
14 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
15 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
16 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
17 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa',
18 | 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard',
19 | 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
20 | 'teddy bear', 'hair drier', 'toothbrush']
21 | anchor_num = 5
22 | class_num = len(classes)
23 |
24 | # color for draw boxes
25 | hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
26 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
27 | colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
28 | random.seed(10101) # Fixed seed for consistent colors across runs.
29 | random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
30 | random.seed(None) # Reset seed to default.
31 |
32 | # input image information
33 | image_size = (416, 416)
34 | feat_size = image_size[0] // 32
35 |
36 | cuda = False
37 |
38 | # demo parameter
39 | eval = False
40 | score_threshold = 0.5
41 | nms_threshold = 0.4
42 | iou_threshold = 0.6
43 |
--------------------------------------------------------------------------------
/config/yolo.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import random
3 |
4 | # create model information
5 | tiny, voc, num = False, False, 18
6 | flag = [1, 1, 0] * 3 + [1, 0, 1] * 2 + [0, 1, 0]
7 | size_flag = [3, 6, 9, 11, 14, 16]
8 | pool = [0, 1, 4, 7, 13]
9 |
10 | # anchor and classes information
11 | anchors = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]
12 | classes = ['person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
13 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
14 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
15 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
16 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
17 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'sofa',
18 | 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse', 'remote', 'keyboard',
19 | 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
20 | 'teddy bear', 'hair drier', 'toothbrush']
21 | anchor_num = 5
22 | class_num = len(classes)
23 |
24 | # color for draw boxes
25 | hsv_tuples = [(x / class_num, 1., 1.) for x in range(class_num)]
26 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
27 | colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
28 | random.seed(10101) # Fixed seed for consistent colors across runs.
29 | random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
30 | random.seed(None) # Reset seed to default.
31 |
32 | # input image information
33 | image_size = (608, 608)
34 | feat_size = image_size[0] // 32
35 |
36 | cuda = False
37 |
38 | # demo parameter
39 | eval = False
40 | score_threshold = 0.5
41 | nms_threshold = 0.4
42 | iou_threshold = 0.6
43 |
--------------------------------------------------------------------------------
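The demo scripts load these config modules by name with `importlib` (see `demo.py` below), so `--yolo_type` simply selects a file in `config/`. A minimal sketch:

```python
import importlib

cfg = importlib.import_module('config.yolo')  # selects config/yolo.py
print(cfg.image_size, cfg.feat_size)          # (608, 608) 19
print(cfg.anchor_num * cfg.feat_size ** 2)    # 1805 priors on the 19x19 grid
print(len(cfg.classes))                       # 80 COCO classes
```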
/demo_cam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import argparse
5 | import importlib
6 | from torch.autograd import Variable
7 | from network.postprocesscv import draw_box_cv
8 | from network.yolo import yolo
9 |
10 |
11 | def demo_cam(cfg, save_path, save):
12 | net = yolo(cfg)
13 | net.load_state_dict(torch.load(cfg.trained_model))
14 | if cfg.cuda: net = net.cuda()
15 | net.eval()
16 | cam = cv2.VideoCapture(0)
17 | if not cam.isOpened(): raise IOError("check your camera or the opencv library...")
18 | if save:
19 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
20 | out = cv2.VideoWriter(save_path + '/out_camera.avi', fourcc, 20.0, (640, 480))
21 | while cam.isOpened():
22 | ret, image = cam.read()
23 | if ret:
24 | reimg = cv2.resize(image, cfg.image_size)
25 | reimg = torch.from_numpy(reimg.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
26 | reimg = Variable(reimg, volatile=True)
27 | if cfg.cuda: reimg = reimg.cuda()
28 | boxes, scores, classes = net(reimg, (image.shape[1], image.shape[0]))
29 | boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
30 | for i, c in list(enumerate(classes)):
31 | pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
32 | label = '{} {:.2f}'.format(pred_class, score)
33 | draw_box_cv(cfg, image, label, box, c)
34 | cv2.imshow('camera', image)
35 | if cv2.waitKey(1) & 0xFF == ord('q'): break
36 | if save: out.write(image)
37 | cam.release()
38 | if save: out.release()
39 | cv2.destroyAllWindows()
40 |
41 |
42 | if __name__ == '__main__':
43 | # default path
44 | curdir = os.getcwd()
45 | trained_model = os.path.join(curdir, 'model/yolo.pth')
46 |
47 | # demo parameters
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--yolo_type', default='yolo', type=str)
50 | parser.add_argument('--save', default=False, type=bool)
51 | parser.add_argument('--save_path', default='results/demo/camera', type=str)
52 | parser.add_argument('--nms_threshold', default=0.4, type=float)
53 | parser.add_argument('--score_threshold', default=0.5, type=float)
54 | parser.add_argument('--trained_model', default=trained_model, type=str)
55 | parser.add_argument('--cuda', default=True, type=bool)
56 |
57 | config = parser.parse_args()
58 | cfg = importlib.import_module('config.' + config.yolo_type)
59 | cfg.nms_threshold = config.nms_threshold
60 | cfg.score_threshold = config.score_threshold
61 | cfg.trained_model = config.trained_model
62 | cfg.cuda = config.cuda
63 |
64 | if not os.path.exists(config.save_path) and config.save: os.mkdir(config.save_path)
65 | demo_cam(cfg, config.save_path, config.save)
66 |
--------------------------------------------------------------------------------
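One caveat with the argument parsing above: `type=bool` converts any non-empty string to `True`, so `--save=False` actually enables saving. A common workaround (a sketch, not repo code) is a small `str2bool` helper:

```python
import argparse

def str2bool(v):
    # argparse passes the raw string; bool("False") would be True
    return str(v).lower() in ('yes', 'true', 't', '1')

parser = argparse.ArgumentParser()
parser.add_argument('--save', default=False, type=str2bool)
print(parser.parse_args(['--save=False']).save)  # False, as intended
```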
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import importlib
5 | from PIL import Image
6 | from torchvision import transforms
7 | from torch.autograd import Variable
8 | from network.postprocess import draw_box
9 | from network.yolo import yolo
10 |
11 |
12 | # TODO: add a dataloader form
13 | def demo(cfg, img_list, save_path):
14 | transform = transforms.Compose([transforms.Resize(cfg.image_size), transforms.ToTensor()])
15 | net = yolo(cfg)
16 | net.load_state_dict(torch.load(cfg.trained_model))
17 | if cfg.cuda: net = net.cuda()
18 | net.eval()
19 | for img in img_list:
20 | image = Image.open(img)
21 | reimg = Variable(transform(image).unsqueeze(0))
22 | if cfg.cuda: reimg = reimg.cuda()
23 | boxes, scores, classes = net(reimg, image.size)
24 | boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
25 | print('Find {} boxes for {}.'.format(len(boxes), img.split('/')[-1]))
26 | for i, c in list(enumerate(classes)):
27 | pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
28 | label = '{} {:.2f}'.format(pred_class, score)
29 | draw_box(cfg, image, label, box, c)
30 | image.save(os.path.join(save_path, img.split('/')[-1]), quality=90)
31 |
32 |
33 | if __name__ == '__main__':
34 | # default path
35 | curdir = os.getcwd()
36 | demo_path = os.path.join(curdir, 'results/demo')
37 | trained_model = os.path.join(curdir, 'model/yolo.pth')
38 |
39 | # demo parameters
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('--demo_path', default=demo_path, type=str)
42 | parser.add_argument('--yolo_type', default='yolo', type=str)
43 | parser.add_argument('--nms_threshold', default=0.4, type=float)
44 | parser.add_argument('--score_threshold', default=0.5, type=float)
45 | parser.add_argument('--trained_model', default=trained_model, type=str)
46 | parser.add_argument('--cuda', default=True, type=bool)
47 |
48 | config = parser.parse_args()
49 | cfg = importlib.import_module('config.' + config.yolo_type)
50 | cfg.nms_threshold = config.nms_threshold
51 | cfg.score_threshold = config.score_threshold
52 | cfg.demo_path = config.demo_path
53 | cfg.trained_model = config.trained_model
54 | cfg.cuda = config.cuda
55 |
56 | ext = ['.jpg', '.png']
57 | if os.path.isfile(cfg.demo_path) and os.path.splitext(cfg.demo_path)[-1] in ext:
58 | img_list = [cfg.demo_path]
59 | save_path = os.path.join(os.path.dirname(cfg.demo_path), 'out')
60 | elif os.path.isdir(cfg.demo_path):
61 | imgs = [fname for fname in os.listdir(cfg.demo_path) if os.path.splitext(fname)[-1] in ext]
62 | img_list = [os.path.join(cfg.demo_path, fname) for fname in imgs]
63 | save_path = os.path.join(cfg.demo_path, 'out')
64 | else:
65 | raise IOError("illegal demo path ...")
66 | if not img_list:
67 | raise IOError("no .jpg/.png images found in the demo path ...")
68 | if not os.path.exists(save_path): os.mkdir(save_path)
69 | demo(cfg, img_list, save_path)
70 |
--------------------------------------------------------------------------------
/demo_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import argparse
5 | import importlib
6 | from torch.autograd import Variable
7 | from timeit import default_timer as timer
8 | from network.postprocesscv import draw_box_cv
9 | from network.yolo import yolo
10 |
11 |
12 | def demo_video(cfg, file, save_path, save, start_frame=0):
13 | net = yolo(cfg)
14 | net.load_state_dict(torch.load(cfg.trained_model))
15 | if cfg.cuda: net = net.cuda()
16 | net.eval()
17 | video = cv2.VideoCapture(file)
18 | if not video.isOpened():
19 | raise IOError('Could not open video, check your opencv and video')
20 | if save:
21 | fps = video.get(cv2.CAP_PROP_FPS)
22 | size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
23 | out = cv2.VideoWriter(os.path.join(save_path, file.split('/')[-1]), cv2.VideoWriter_fourcc(*'XVID'), fps, size)
24 |
25 | if start_frame > 0:
26 | video.set(cv2.CAP_PROP_POS_FRAMES, start_frame)  # seek by frame index, matching the parameter name
27 | accum_time = 0
28 | curr_fps = 0
29 | fps = "FPS: "
30 | prev_time = timer()
31 |
32 | while True:
33 | retval, orig_image = video.read()
34 | if not retval:
35 | print('Done !!!')
36 | return
37 | reimg = cv2.resize(orig_image, cfg.image_size)
38 | reimg = torch.from_numpy(reimg.transpose(2, 0, 1)).float().div(255.0).unsqueeze(0)
39 | x = Variable(reimg, volatile=True)
40 | if cfg.cuda: x = x.cuda()
41 | boxes, scores, classes = net(x, (orig_image.shape[1], orig_image.shape[0]))
42 | boxes, scores, classes = boxes.data.cpu(), scores.data.cpu(), classes.data.cpu()
43 | for i, c in list(enumerate(classes)):
44 | pred_class, box, score = cfg.classes[c], boxes[i], scores[i]
45 | label = '{} {:.2f}'.format(pred_class, score)
46 | draw_box_cv(cfg, orig_image, label, box, c)
47 | curr_time = timer()
48 | exec_time = curr_time - prev_time
49 | prev_time = curr_time
50 | accum_time += exec_time
51 | curr_fps = curr_fps + 1
52 | if accum_time > 1:
53 | accum_time -= 1
54 | fps = "FPS:" + str(curr_fps)
55 | curr_fps = 0
56 | cv2.rectangle(orig_image, (0, 0), (50, 17), (255, 255, 255), -1)
57 | cv2.putText(orig_image, fps, (3, 10), 0, 0.35, (0, 0, 0), 1)
58 | if save:
59 | out.write(orig_image)
60 | cv2.imshow("yolo result", orig_image)
61 | if cv2.waitKey(1) & 0xFF == ord('q'):
62 | break
63 | video.release()
64 | if save:
65 | out.release()
66 |
67 |
68 | if __name__ == '__main__':
69 | # default path
70 | curdir = os.getcwd()
71 | trained_model = os.path.join(curdir, 'model/yolo.pth')
72 | demo_path = os.path.join(curdir, 'results/demo/video.avi')
73 |
74 | # demo parameters
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument('--yolo_type', default='yolo', type=str)
77 | parser.add_argument('--save', default=False, type=bool)
78 | parser.add_argument('--save_path', default='results/demo/video', type=str)
79 | parser.add_argument('--demo_path', default=demo_path, type=str)
80 | parser.add_argument('--nms_threshold', default=0.4, type=float)
81 | parser.add_argument('--score_threshold', default=0.5, type=float)
82 | parser.add_argument('--trained_model', default=trained_model, type=str)
83 | parser.add_argument('--cuda', default=True, type=bool)
84 |
85 | config = parser.parse_args()
86 | cfg = importlib.import_module('config.' + config.yolo_type)
87 | cfg.nms_threshold = config.nms_threshold
88 | cfg.score_threshold = config.score_threshold
89 | cfg.trained_model = config.trained_model
90 | cfg.cuda = config.cuda
91 |
92 | ext = ['.mp4', '.avi']
93 | if not os.path.splitext(config.demo_path)[-1] in ext:
94 | raise IOError("illegal video form...")
95 | if not os.path.exists(config.save_path) and config.save: os.mkdir(config.save_path)
96 | demo_video(cfg, config.demo_path, config.save_path, config.save)
97 |
--------------------------------------------------------------------------------
/network/postprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import ImageFont, ImageDraw
4 |
5 |
6 | # (x,y,w,h)--->(x1, y1, x2, y2)
7 | def box_to_corners(box_pred):
8 | box_mins = box_pred[..., 0:2] - (box_pred[..., 2:] / 2.)
9 | box_maxes = box_pred[..., 0:2] + (box_pred[..., 2:] / 2.)
10 | return torch.cat([box_mins[..., 0:1], box_mins[..., 1:2],
11 | box_maxes[..., 0:1], box_maxes[..., 1:2]], 3)
12 |
13 |
14 | # remove the proposal detector which is less than the threshold
15 | def filter_box(boxes, box_conf, box_prob, threshold=.5):
16 | box_scores = box_conf.repeat(1, 1, 1, box_prob.size(3)) * box_prob
17 | box_class_scores, box_classes = torch.max(box_scores, dim=3)
18 | prediction_mask = box_class_scores > threshold
19 | prediction_mask4 = prediction_mask.unsqueeze(3).expand(boxes.size())
20 |
21 | boxes = torch.masked_select(boxes, prediction_mask4).contiguous().view(-1, 4)
22 | scores = torch.masked_select(box_class_scores, prediction_mask)
23 | classes = torch.masked_select(box_classes, prediction_mask)
24 | return boxes, scores, classes
25 |
26 |
27 | # non-max-suppression process
28 | def non_max_suppression(boxes, scores, overlap=0.5, top_k=200):
29 | keep = torch.Tensor(scores.size(0)).fill_(0).long()
30 | if boxes.is_cuda: keep = keep.cuda()
31 | if boxes.numel() == 0:
32 | return keep
33 | x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
34 | area = torch.mul(x2 - x1, y2 - y1)
35 | v, idx = scores.sort(0) # sort in ascending order
36 | # I = I[v >= 0.01]
37 | idx = idx[-top_k:] # indices of the top-k largest vals
38 | xx1, yy1, xx2, yy2 = boxes.new(), boxes.new(), boxes.new(), boxes.new()
39 | w, h = boxes.new(), boxes.new()
40 | count = 0
41 | while idx.numel() > 0:
42 | i = idx[-1] # index of current largest val
43 | # keep.append(i)
44 | keep[count] = i
45 | count += 1
46 | if idx.size(0) == 1: break
47 | idx = idx[:-1] # remove kept element from view
48 | # load bboxes of next highest vals
49 | torch.index_select(x1, 0, idx, out=xx1)
50 | torch.index_select(y1, 0, idx, out=yy1)
51 | torch.index_select(x2, 0, idx, out=xx2)
52 | torch.index_select(y2, 0, idx, out=yy2)
53 | # store element-wise max with next highest score
54 | xx1 = torch.clamp(xx1, min=x1[i])
55 | yy1 = torch.clamp(yy1, min=y1[i])
56 | xx2 = torch.clamp(xx2, max=x2[i])
57 | yy2 = torch.clamp(yy2, max=y2[i])
58 | w.resize_as_(xx2)
59 | h.resize_as_(yy2)
60 | # check sizes of xx1 and xx2.. after each iteration
61 | w, h = torch.clamp(xx2 - xx1, min=0.0), torch.clamp(yy2 - yy1, min=0.0)
62 | inter = w * h
63 | # IoU = i / (area(a) + area(b) - i)
64 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
65 | union = (rem_areas - inter) + area[i]
66 | IoU = inter / union # store result in iou
67 | # keep only elements with an IoU <= overlap
68 | idx = idx[IoU.le(overlap)]
69 | return keep, count
70 |
71 |
72 | # draw proposal boxes
73 | def draw_box(cfg, image, label, box, c):
74 | w, h = image.size
75 | font = ImageFont.truetype(font='./config/FiraMono-Medium.otf', size=np.round(3e-2 * h).astype('int32'))
76 | thickness = (w + h) // 300
77 | draw = ImageDraw.Draw(image)
78 | label_size = draw.textsize(label, font)
79 | left, top, right, bottom = box
80 | top, left = max(0, np.round(top).astype('int32')), max(0, np.round(left).astype('int32'))
81 | right, bottom = min(w, np.round(right).astype('int32')), min(h, np.round(bottom).astype('int32'))
82 | print(label, (left, top), (right, bottom))
83 | text_origin = np.array([left, top - label_size[1]]) if top - label_size[1] >= 0 else np.array([left, top + 1])
84 | for i in range(thickness):
85 | draw.rectangle([left + i, top + i, right - i, bottom - i], outline=cfg.colors[c])
86 | draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=cfg.colors[c])
87 | draw.text(text_origin, label, fill=(0, 0, 0), font=font)
88 | del draw
89 |
--------------------------------------------------------------------------------
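A quick sanity check of `non_max_suppression` (a sketch, not repo code): of two heavily overlapping boxes, only the higher-scoring one should survive, while a disjoint box is untouched:

```python
import torch
from network.postprocess import non_max_suppression

boxes = torch.Tensor([[0., 0., 10., 10.],     # highest score
                      [1., 1., 11., 11.],     # IoU ~0.68 with the first box
                      [20., 20., 30., 30.]])  # disjoint
scores = torch.Tensor([0.9, 0.8, 0.7])
keep, count = non_max_suppression(boxes, scores, overlap=0.5)
print(keep[:count])  # indices 0 and 2: box 1 is suppressed
```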
/tools/yad2t.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import io
3 | import os
4 | import argparse
5 | import configparser
6 | import importlib
7 | import torch
8 | import numpy as np
9 | from collections import defaultdict
10 |
11 | sys.path.append('..')
12 | from network.yolo import yolo
13 |
14 |
15 | # rename each [xxx] section to a unique [xxx_n] form
16 | def unique_config_sections(config_file):
17 | """Convert all config sections to have unique names.
18 |
19 | Adds unique suffixes to config sections for compatibility with configparser.
20 | """
21 | section_counters = defaultdict(int)
22 | output_stream = io.StringIO()
23 | with open(config_file) as fin:
24 | for line in fin:
25 | if line.startswith('['):
26 | section = line.strip().strip('[]')
27 | _section = section + '_' + str(section_counters[section])
28 | section_counters[section] += 1
29 | line = line.replace(section, _section)
30 | output_stream.write(line)
31 | output_stream.seek(0)
32 | return output_stream
33 |
34 |
35 | def weight2pth(config_path, weights_path, output_path):
36 | assert config_path.endswith('.cfg'), '{} is not a .cfg file'.format(config_path)
37 | assert weights_path.endswith('.weights'), '{} is not a .weights file'.format(weights_path)
38 | # weights header
39 | weights_file = open(weights_path, 'rb')
40 | weights_header = np.ndarray(shape=(4,), dtype='int32', buffer=weights_file.read(16))
41 | print('Weights Header: ', weights_header)
42 | # convert config information
43 | unique_config_file = unique_config_sections(config_path)
44 | cfg_parser = configparser.ConfigParser()
45 | cfg_parser.read_file(unique_config_file)
46 | # network information
47 | cfgname = config_path.split('/')[-1].split('.')[0]
48 | cfg = importlib.import_module('config.' + cfgname)
49 | net = yolo(cfg)
50 | net_dict = net.state_dict()
51 | keys = list(net_dict.keys())
52 | key_num, count, prev_filter = 0, 0, 3
53 | print('loading the weights ...')
54 | for section in cfg_parser.sections():
55 | if section.startswith('convolutional'):
56 | filters = int(cfg_parser[section]['filters'])
57 | size = int(cfg_parser[section]['size'])
58 | bn = 'batch_normalize' in cfg_parser[section]
59 | activation = cfg_parser[section]['activation']
60 | # three special cases where the input channel count cannot be inferred from the previous conv
61 | if section == 'convolutional_20':
62 | prev_filter = 512
63 | elif section == 'convolutional_21':
64 | prev_filter = 1280
65 | elif section == 'convolutional_0':
66 | prev_filter = 3
67 | else:
68 | prev_filter = weights_shape[0]
69 |
70 | weights_shape = (filters, prev_filter, size, size)
71 | weights_size = np.product(weights_shape)
72 | print('conv2d', 'bn' if bn else ' ', activation, weights_shape)
73 | conv_bias = np.ndarray(
74 | shape=(filters,),
75 | dtype='float32',
76 | buffer=weights_file.read(filters * 4))
77 | count += filters
78 | if bn:
79 | bn_weights = np.ndarray(
80 | shape=(3, filters),
81 | dtype='float32',
82 | buffer=weights_file.read(filters * 12))
83 | count += 3 * filters
84 | net_dict[keys[key_num + 1]].copy_(torch.from_numpy(bn_weights[0]))
85 | net_dict[keys[key_num + 2]].copy_(torch.from_numpy(conv_bias))
86 | net_dict[keys[key_num + 3]].copy_(torch.from_numpy(bn_weights[1]))
87 | net_dict[keys[key_num + 4]].copy_(torch.from_numpy(bn_weights[2]))
88 | else:
89 | net_dict[keys[key_num + 1]].copy_(torch.from_numpy(conv_bias))
90 | # load the convolutional kernel weights
91 | conv_weights = np.ndarray(
92 | shape=weights_shape,
93 | dtype='float32',
94 | buffer=weights_file.read(weights_size * 4))
95 | count += weights_size
96 | net_dict[keys[key_num]].copy_(torch.from_numpy(conv_weights))
97 | key_num = key_num + 5 if bn else key_num + 1
98 | else:
99 | continue
100 | # check the convert
101 | remaining_weights = len(weights_file.read()) // 4
102 | weights_file.close()
103 | print('Read {} of {} from Darknet weights.'.format(count, count + remaining_weights))
104 | if remaining_weights > 0:
105 | print('Warning: {} unused weights'.format(remaining_weights))
106 | # save the net.state_dict
107 | torch.save(net_dict, os.path.join(output_path, cfgname + '.pth'))
108 |
109 |
110 | if __name__ == '__main__':
111 | parser = argparse.ArgumentParser()
112 | curdir = os.path.abspath('..')
113 | cfg_path = os.path.join(curdir, 'config/yolo.cfg')
114 | weight_path = os.path.join(curdir, 'config/yolo.weights')
115 | output_path = os.path.join(curdir, 'model')
116 | # parameters
117 | parser.add_argument('--cfg_path', default=cfg_path, type=str)
118 | parser.add_argument('--weight_path', default=weight_path, type=str)
119 | parser.add_argument('--output_path', default=output_path, type=str)
120 |
121 | config = parser.parse_args()
122 | if not os.path.exists(config.output_path):
123 | os.mkdir(config.output_path)
124 | weight2pth(config.cfg_path, config.weight_path, config.output_path)
125 |
--------------------------------------------------------------------------------
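As a worked example of the byte accounting in `weight2pth`: the first layer of `yolo.cfg` is a batch-normalised 3×3 convolution with 32 filters over 3 input channels, so the loop consumes:

```python
# Worked byte count for convolutional_0 of yolo.cfg (32 filters, 3x3, BN, 3 input channels).
filters, prev_filter, size = 32, 3, 3
bias_floats = filters                               # read first; copied into the BN bias here
bn_floats = 3 * filters                             # scale, running mean, running var
conv_floats = filters * prev_filter * size * size   # 864 kernel weights
print((bias_floats + bn_floats + conv_floats) * 4)  # 3968 bytes read from yolo.weights
```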
/network/yolo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from network.prior import PriorBox
6 | from network.detect import Detect
7 |
8 |
9 | # module1: conv+bn+leaky_relu
10 | class ConvLayer(nn.Module):
11 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False):
12 | super(ConvLayer, self).__init__()
13 | padding = kernel_size // 2 if same_padding else 0
14 |
15 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False)
16 | self.bn = nn.BatchNorm2d(out_channels)
17 | self.relu = nn.LeakyReLU(0.1, inplace=True) if relu else None
18 |
19 | def forward(self, x):
20 | x = self.conv(x)
21 | x = self.bn(x)
22 | if self.relu is not None:
23 | x = self.relu(x)
24 | return x
25 |
26 |
27 | # reorg layer
28 | class ReorgLayer(nn.Module):
29 | def __init__(self, stride=2):
30 | super(ReorgLayer, self).__init__()
31 | self.stride = stride
32 |
33 | def forward(self, x):
34 | B, C, H, W = x.size()
35 | s = self.stride
36 | x = x.view(B, C, H // s, s, W // s, s).transpose(3, 4).contiguous()
37 | x = x.view(B, C, H // s * W // s, s * s).transpose(2, 3).contiguous()
38 | x = x.view(B, C, s * s, H // s, W // s).transpose(1, 2).contiguous()
39 | return x.view(B, s * s * C, H // s, W // s)
40 |
41 |
42 | # darknet feature detector
43 | class DarknetBone(nn.Module):
44 | def __init__(self, cfg):
45 | super(DarknetBone, self).__init__()
46 | self.cfg = cfg
47 | in_channel, out_channel = 3, 16 if cfg.tiny else 32
48 | flag, pool, size_flag = cfg.flag, cfg.pool, cfg.size_flag
49 | layers1, layers2 = [], []
50 | for i in range(cfg.num):
51 | ksize = 1 if i in size_flag else 3
52 | if i < 13:
53 | layers1.append(ConvLayer(in_channel, out_channel, ksize, same_padding=True))
54 | if i in pool: layers1.append(nn.MaxPool2d(2))
55 | if i == 5 and cfg.tiny: layers1 += [nn.ReflectionPad2d([0, 1, 0, 1]), nn.MaxPool2d(2, 1)]
56 | else:
57 | if i in pool: layers2.append(nn.MaxPool2d(2))
58 | layers2.append(ConvLayer(in_channel, out_channel, ksize, same_padding=True))
59 | in_channel, out_channel = out_channel, out_channel * 2 if flag[i] else out_channel // 2
60 | self.main1 = nn.Sequential(*layers1)
61 | self.main2 = nn.Sequential(*layers2)
62 |
63 | def forward(self, x):
64 | xd = self.main1(x)
65 | if self.cfg.tiny:
66 | return xd
67 | else:
68 | x = self.main2(xd)
69 | return x, xd
70 |
71 |
72 | # YOLO
73 | class Yolo(nn.Module):
74 | def __init__(self, cfg):
75 | super(Yolo, self).__init__()
76 | self.cfg = cfg
77 | self.prior = Variable(PriorBox(cfg).forward(), volatile=True)
78 | self.darknet = DarknetBone(cfg)
79 | if cfg.tiny:
80 | out = 1024 if cfg.voc else 512
81 | self.conv = nn.Sequential(
82 | ConvLayer(1024, out, 3, same_padding=True),
83 | nn.Conv2d(out, cfg.anchor_num * (cfg.class_num + 5), 1))
84 | else:
85 | self.conv1 = nn.Sequential(
86 | ConvLayer(1024, 1024, 3, same_padding=True),
87 | ConvLayer(1024, 1024, 3, same_padding=True))
88 | self.conv2 = nn.Sequential(
89 | ConvLayer(512, 64, 1, same_padding=True),
90 | ReorgLayer(2))
91 | self.conv = nn.Sequential(
92 | ConvLayer(1280, 1024, 3, same_padding=True),
93 | nn.Conv2d(1024, cfg.anchor_num * (cfg.class_num + 5), 1))
94 |
95 | def forward(self, x, img_shape=None):
96 | if self.cfg.tiny:
97 | x = self.conv(self.darknet(x))
98 | else:
99 | x1, x2 = self.darknet(x)
100 | x = self.conv(torch.cat([self.conv2(x2), self.conv1(x1)], 1))
101 | # extract each part
102 | b, c, h, w = x.size()
103 | feat = x.permute(0, 2, 3, 1).contiguous().view(b, -1, self.cfg.anchor_num, self.cfg.class_num + 5)
104 | box_xy, box_wh = F.sigmoid(feat[..., 0:2]), feat[..., 2:4].exp()
105 | box_conf, score_pred = F.sigmoid(feat[..., 4:5]), feat[..., 5:].contiguous()
106 | box_prob = F.softmax(score_pred, dim=3)
107 | box_pred = torch.cat([box_xy, box_wh], 3)
108 | # TODO: add training phase
109 | if self.training:
110 | return x
111 | else:
112 | width, height = img_shape
113 | img_shape = Variable(torch.Tensor([[width, height, width, height]]))
114 | if self.cfg.cuda: self.prior, img_shape = self.prior.cuda(), img_shape.cuda()
115 | self.prior = self.prior.view_as(box_pred)
116 | return Detect(self.cfg, self.cfg.eval)(box_pred, box_conf, box_prob, self.prior, img_shape)
117 |
118 |
119 | # interface --- construct different type yolo model
120 | def yolo(cfg):
121 | model = Yolo(cfg)
122 | return model
123 |
124 |
125 | # weight initialize
126 | def weights_init(m):
127 | classname = m.__class__.__name__
128 | if classname.find('Conv2d') != -1:
129 | m.weight.data.normal_(0.0, (2. / (m.kernel_size[0] ** 2 * m.out_channels)) ** 0.5)  # He init: std = sqrt(2 / n)
130 | elif classname.find('BatchNorm') != -1:
131 | # Estimated variance, must be around 1
132 | m.weight.data.normal_(1.0, 0.02)
133 | # Estimated mean, must be around 0
134 | m.bias.data.fill_(0)
135 |
--------------------------------------------------------------------------------
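A shape sanity check for `ReorgLayer` (a sketch, not repo code): with stride 2 it trades spatial resolution for channels, which is how the 26×26 passthrough features can be concatenated with the 13×13 trunk in `Yolo.forward`:

```python
import torch
from torch.autograd import Variable
from network.yolo import ReorgLayer

x = Variable(torch.randn(1, 64, 26, 26))  # passthrough features after the 1x1 conv
y = ReorgLayer(stride=2)(x)
print(y.size())  # (1, 256, 13, 13): channels grow by s*s, spatial dims shrink by s
```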