├── .gitignore
├── .idea
│   ├── ProbEn.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── CenterNet
│   ├── LICENSE
│   ├── centernet.py
│   ├── change_weight.py
│   ├── get_FNPI.py
│   ├── get_map.py
│   ├── model_data
│   │   ├── people_classes.txt
│   │   └── simhei.ttf
│   ├── nets
│   │   ├── CenterNet_yolov7.py
│   │   ├── centernet.py
│   │   ├── centernet_training.py
│   │   ├── hourglass.py
│   │   └── resnet50.py
│   ├── predict.py
│   ├── summary.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── dataloader.py
│   │   ├── utils.py
│   │   ├── utils_bbox.py
│   │   ├── utils_fit.py
│   │   └── utils_map.py
│   ├── vision_for_centernet.py
│   └── voc_annotation.py
├── ProbEn.py
├── ProbEn_time_test.py
├── README.md
├── detector_test.py
├── get_fusion_FNPI.py
├── get_fusion_FNPI_onlyYolo.py
├── get_fusion_map.py
├── get_fusion_map_onlyYolo.py
├── img
│   ├── 1.jpg
│   ├── 2.jpg
│   └── street.jpg
├── img_out
│   ├── 1.png
│   ├── 2.png
│   └── street.png
├── predict_with_probEn.py
├── predict_with_probEn_onlyYolo.py
└── yolov7
    ├── get_FNPI.py
    ├── get_map.py
    ├── kmeans_for_anchors.py
    ├── model_data
    │   ├── coco_classes.txt
    │   ├── simhei.ttf
    │   ├── voc_classes.txt
    │   └── yolo_anchors.txt
    ├── nets
    │   ├── SRModule.py
    │   ├── SR_Decoder.py
    │   ├── SR_Encoder.py
    │   ├── __init__.py
    │   ├── backbone.py
    │   ├── sync_batchnorm
    │   │   ├── __init__.py
    │   │   ├── batchnorm.py
    │   │   ├── comm.py
    │   │   ├── replicate.py
    │   │   └── unittest.py
    │   ├── yolo().py
    │   ├── yolo.py
    │   └── yolo_training.py
    ├── predict.py
    ├── predict_RGB.py
    ├── predict_T.py
    ├── summary.py
    ├── train.py
    ├── utils
    │   ├── __init__.py
    │   ├── attentions.py
    │   ├── callbacks.py
    │   ├── dataloader.py
    │   ├── utils.py
    │   ├── utils_FNPI.py
    │   ├── utils_bbox.py
    │   ├── utils_fit.py
    │   └── utils_map.py
    ├── utils_coco
    │   ├── coco_annotation.py
    │   └── get_map_coco.py
    ├── voc_annotation.py
    ├── yolo.py
    ├── yolo_RGB.py
    └── yolo_T.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /yolov7/VOCdevkit/
2 | /CenterNet/2007_val.txt
3 | /CenterNet/2007_train.txt
4 |
--------------------------------------------------------------------------------
/CenterNet/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Bubbliiiing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/CenterNet/change_weight.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def Combine_weights(ir_weights,rgb_weights,alpha,save_path):
6 | # lambda_ = np.random.beta(alpha, alpha)
7 | lambda_=alpha
8 | weights_1 = torch.load(ir_weights)
9 | # print(weights_1['head.reg_head.3.bias'])
10 | weights_2 = torch.load(rgb_weights)
11 | for key,value in weights_2.items():
12 | if key in weights_1:
13 | weights_1[key]=torch.tensor(((1-lambda_)*weights_1[key].cpu().numpy()+lambda_ * weights_2[key].cpu().numpy()),device ='cuda:0')
14 | # print(weights_1['head.reg_head.3.bias'])
15 |     torch.save(weights_1, save_path + r"\new_weights.pth")  # save under save_path instead of a hard-coded absolute path
16 | return save_path
17 | save_path=r'D:\xiangmushiyan\centernet-pytorch-main'
18 | ir_weights=r"D:\xiangmushiyan\centernet-pytorch-main\logs\loss_2022_10_17_16_36_01-yolov7-ir\best_epoch_weights.pth"
19 | rgb_weights=r"D:\xiangmushiyan\centernet-pytorch-main\logs\loss_2022_10_18_21_51_40-yolov7-rgb\best_epoch_weights.pth"
20 | alpha=0.99
21 |
22 | Combine_weights(ir_weights,rgb_weights,alpha,save_path)
--------------------------------------------------------------------------------
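
Combine_weights above blends an infrared checkpoint and an RGB checkpoint by linear interpolation, w = (1 - alpha) * w_ir + alpha * w_rgb, for every key the two state dicts share. A minimal sketch of that blend on hypothetical dummy state dicts (not the real checkpoint paths used above):

    import torch

    # Hypothetical tiny "checkpoints": two state dicts with a shared key.
    w_ir  = {"head.bias": torch.tensor([1.0, 2.0])}
    w_rgb = {"head.bias": torch.tensor([3.0, 6.0])}

    alpha   = 0.99  # same value the script uses; weights the RGB model heavily
    blended = {k: (1 - alpha) * w_ir[k] + alpha * w_rgb[k] for k in w_ir if k in w_rgb}
    print(blended["head.bias"])  # tensor([2.9800, 5.9600])
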
/CenterNet/get_FNPI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 | from utils.utils import get_classes
8 | from yolov7.utils.utils_FNPI import get_FNPI
9 | from centernet import CenterNet
10 |
11 | if __name__ == "__main__":
12 | #------------------------------------------------------------------------------------------------------------------#
13 |     #   map_mode is used to specify what this script computes when it runs.
14 |     #   map_mode = 0 runs the whole FNPI pipeline: get predictions, get ground-truth boxes, compute FNPI.
15 |     #   map_mode = 1 only obtains the prediction results.
16 |     #   map_mode = 2 only obtains the ground-truth boxes.
17 |     #   map_mode = 3 only computes FNPI.
18 | #-------------------------------------------------------------------------------------------------------------------#
19 | map_mode = 0
20 | #--------------------------------------------------------------------------------------#
21 |     #   classes_path here specifies the classes for which FNPI is measured.
22 |     #   In general it should be the same classes_path used for training and prediction.
23 | #--------------------------------------------------------------------------------------#
24 | # classes_path = 'model_data/people_classes_KAIST.txt'
25 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\CenterNet\model_data\people_classes_voc4.txt'
26 | #--------------------------------------------------------------------------------------#
27 |     #   FNPI_IOU is the condition for deciding that a predicted box matches a ground-truth box (i.e. the object in that ground-truth box counts as detected).
28 |     #   Only an overlap greater than FNPI_IOU counts as a successful detection.
29 | #--------------------------------------------------------------------------------------#
30 | FNPI_IOU = 0.5
31 | #--------------------------------------------------------------------------------------#
32 |     #   The confidence here is set differently from the one used when computing mAP.
33 |     #   Because of how mAP is computed, the network needs nearly all predicted boxes for mAP, so the mAP confidence should be set as low as possible to keep every possible box.
34 |     #   For FNPI, the confidence should match the one used at prediction time: only boxes whose score exceeds it are kept.
35 | #--------------------------------------------------------------------------------------#
36 | confidence = 0.5
37 | #--------------------------------------------------------------------------------------#
38 |     #   Non-maximum-suppression IoU used at prediction time; the larger it is, the less strict NMS becomes.
39 |     #   This value should also match the nms_iou used at prediction time.
40 | #--------------------------------------------------------------------------------------#
41 | nms_iou = 0.3
42 | #-------------------------------------------------------#
43 |     #   Points to the folder containing the VOC dataset.
44 |     #   By default it points to the VOC dataset in the repository root.
45 | #-------------------------------------------------------#
46 | VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit'
47 | # VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist'
48 | #-------------------------------------------------------#
49 |     #   Output folder for the results, map_out by default.
50 | #-------------------------------------------------------#
51 | FNPI_out_path = 'FNPI_out/FNPI_out_voc4'
52 |
53 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
54 | # image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split()
55 |
56 | if not os.path.exists(FNPI_out_path):
57 | os.makedirs(FNPI_out_path)
58 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')):
59 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth'))
60 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')):
61 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results'))
62 |
63 | class_names, _ = get_classes(classes_path)
64 |
65 | if map_mode == 0 or map_mode == 1:
66 | print("Load model.")
67 | centernet = CenterNet(confidence = confidence, nms_iou = nms_iou)
68 | print("Load model done.")
69 |
70 | print("Get predict result.")
71 | for image_id in tqdm(image_ids):
72 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
73 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/"+image_id+".jpg")
74 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/"+image_id+".jpg")
75 | image = Image.open(image_path)
76 | centernet.get_map_txt(image_id, image, class_names, FNPI_out_path)
77 | print("Get predict result done.")
78 |
79 | if map_mode == 0 or map_mode == 2:
80 | print("Get ground truth result.")
81 | for image_id in tqdm(image_ids):
82 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
83 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
84 | # root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot()
85 | for obj in root.findall('object'):
86 | difficult_flag = False
87 | if obj.find('difficult')!=None:
88 | difficult = obj.find('difficult').text
89 | if int(difficult)==1:
90 | difficult_flag = True
91 | obj_name = obj.find('name').text
92 | if obj_name not in class_names:
93 | continue
94 | bndbox = obj.find('bndbox')
95 | left = bndbox.find('xmin').text
96 | top = bndbox.find('ymin').text
97 | right = bndbox.find('xmax').text
98 | bottom = bndbox.find('ymax').text
99 |
100 | if difficult_flag:
101 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
102 | else:
103 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
104 | print("Get ground truth result done.")
105 |
106 | if map_mode == 0 or map_mode == 3:
107 | print("Get map.")
108 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path)
109 | print("Get map done.")
--------------------------------------------------------------------------------
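
The FNPI_IOU = 0.5 setting above is the matching criterion: a ground-truth box only counts as detected (i.e. does not contribute a false negative) if some kept prediction overlaps it with IoU above the threshold. A small illustrative sketch of that test on hypothetical boxes, independent of the project's own utils_FNPI implementation:

    def iou(box_a, box_b):
        # Boxes are (left, top, right, bottom) in pixels.
        ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
        ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
        inter  = max(0, ix2 - ix1) * max(0, iy2 - iy1)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        return inter / (area_a + area_b - inter)

    gt   = (100, 100, 200, 300)  # hypothetical ground-truth pedestrian box
    pred = (110,  90, 210, 290)  # hypothetical detection kept after the confidence filter
    print(iou(gt, pred) >= 0.5)  # True -> this ground truth is not counted as a false negative
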
/CenterNet/get_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 | from centernet import CenterNet
8 | from utils.utils import get_classes
9 | from utils.utils_map import get_coco_map, get_map
10 |
11 | if __name__ == "__main__":
12 | '''
13 |     Unlike AP, Recall and Precision are not area-based quantities, so their values change with the confidence threshold.
14 |     By default, the Recall and Precision reported by this script are the values at a confidence threshold of 0.5.
15 | 
16 |     Because of how mAP is computed, the network needs nearly all predicted boxes so that Recall and Precision can be evaluated at every threshold.
17 |     Therefore the txt files in map_out/detection-results/ usually contain more boxes than a direct predict run; the goal is to list every possible prediction.
18 | '''
19 | #------------------------------------------------------------------------------------------------------------------#
20 |     #   map_mode is used to specify what this script computes when it runs.
21 |     #   map_mode = 0 runs the whole mAP pipeline: get predictions, get ground-truth boxes, compute VOC_map.
22 |     #   map_mode = 1 only obtains the prediction results.
23 |     #   map_mode = 2 only obtains the ground-truth boxes.
24 |     #   map_mode = 3 only computes VOC_map.
25 |     #   map_mode = 4 uses the COCO toolbox to compute 0.50:0.95 mAP on the current dataset. This requires the predictions, the ground truth, and an installed pycocotools.
26 | #-------------------------------------------------------------------------------------------------------------------#
27 | map_mode = 0
28 | #--------------------------------------------------------------------------------------#
29 |     #   classes_path here specifies the classes for which VOC_map is measured.
30 |     #   In general it should be the same classes_path used for training and prediction.
31 | #--------------------------------------------------------------------------------------#
32 | classes_path = 'model_data/people_classes.txt'
33 | #--------------------------------------------------------------------------------------#
34 |     #   MINOVERLAP specifies which mAP0.x you want, i.e. the IoU threshold used in the mAP definition.
35 |     #   For example, set MINOVERLAP = 0.75 to compute mAP0.75.
36 |     #
37 |     #   A predicted box counts as a positive if its overlap with a ground-truth box exceeds MINOVERLAP, otherwise it is a negative.
38 |     #   The larger MINOVERLAP is, the more accurate a prediction must be to count as a positive, and the lower the resulting mAP.
39 | #--------------------------------------------------------------------------------------#
40 | MINOVERLAP = 0.5
41 | #--------------------------------------------------------------------------------------#
42 |     #   Because of how mAP is computed, the network needs nearly all predicted boxes in order to compute mAP,
43 |     #   so confidence should be set as low as possible to keep every possible prediction.
44 |     #
45 |     #   This value is normally left unchanged: mAP needs nearly all predicted boxes, so this confidence must not be altered casually.
46 |     #   To get Recall and Precision at a different threshold, change score_threhold below instead.
47 | #--------------------------------------------------------------------------------------#
48 | confidence = 0.02
49 | #--------------------------------------------------------------------------------------#
50 |     #   Non-maximum-suppression IoU used at prediction time; the larger it is, the less strict NMS becomes.
51 |     #
52 |     #   This value is normally left unchanged.
53 | #--------------------------------------------------------------------------------------#
54 | nms_iou = 0.5
55 | #---------------------------------------------------------------------------------------------------------------#
56 |     #   Unlike AP, Recall and Precision are not area-based quantities, so their values change with the threshold.
57 |     #
58 |     #   By default, the Recall and Precision reported here correspond to a threshold of 0.5 (defined here as score_threhold).
59 |     #   Because mAP needs nearly all predicted boxes, the confidence defined above must not be altered casually.
60 |     #   score_threhold is defined separately to represent that threshold, so the matching Recall and Precision can be read off during the mAP computation.
61 | #---------------------------------------------------------------------------------------------------------------#
62 | score_threhold = 0.5
63 | #-------------------------------------------------------#
64 |     #   map_vis specifies whether to enable visualisation during the VOC_map computation.
65 | #-------------------------------------------------------#
66 | map_vis = False
67 | #-------------------------------------------------------#
68 |     #   Points to the folder containing the VOC dataset.
69 |     #   By default it points to the VOC dataset in the repository root.
70 | #-------------------------------------------------------#
71 | VOCdevkit_path = 'VOCdevkit'
72 | #-------------------------------------------------------#
73 |     #   Output folder for the results, map_out by default.
74 | #-------------------------------------------------------#
75 | map_out_path = 'map_out'
76 |
77 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
78 |
79 | if not os.path.exists(map_out_path):
80 | os.makedirs(map_out_path)
81 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
82 | os.makedirs(os.path.join(map_out_path, 'ground-truth'))
83 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
84 | os.makedirs(os.path.join(map_out_path, 'detection-results'))
85 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
86 | os.makedirs(os.path.join(map_out_path, 'images-optional'))
87 |
88 | class_names, _ = get_classes(classes_path)
89 |
90 | if map_mode == 0 or map_mode == 1:
91 | print("Load model.")
92 | centernet = CenterNet(confidence = confidence, nms_iou = nms_iou)
93 | print("Load model done.")
94 |
95 | print("Get predict result.")
96 | for image_id in tqdm(image_ids):
97 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
98 | image = Image.open(image_path)
99 | if map_vis:
100 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
101 | centernet.get_map_txt(image_id, image, class_names, map_out_path)
102 | print("Get predict result done.")
103 |
104 | if map_mode == 0 or map_mode == 2:
105 | print("Get ground truth result.")
106 | for image_id in tqdm(image_ids):
107 | with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
108 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
109 | for obj in root.findall('object'):
110 | difficult_flag = False
111 | if obj.find('difficult')!=None:
112 | difficult = obj.find('difficult').text
113 | if int(difficult)==1:
114 | difficult_flag = True
115 | obj_name = obj.find('name').text
116 | if obj_name not in class_names:
117 | continue
118 | bndbox = obj.find('bndbox')
119 | left = bndbox.find('xmin').text
120 | top = bndbox.find('ymin').text
121 | right = bndbox.find('xmax').text
122 | bottom = bndbox.find('ymax').text
123 |
124 | if difficult_flag:
125 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
126 | else:
127 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
128 | print("Get ground truth result done.")
129 |
130 | if map_mode == 0 or map_mode == 3:
131 | print("Get map.")
132 | get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
133 | print("Get map done.")
134 |
135 | if map_mode == 4:
136 | print("Get map.")
137 | get_coco_map(class_names = class_names, path = map_out_path)
138 | print("Get map done.")
139 |
--------------------------------------------------------------------------------
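
The comments above separate two thresholds: confidence = 0.02 keeps nearly every box so the whole precision-recall curve (and hence mAP) can be computed, while score_threhold = 0.5 is only used to read one Recall/Precision pair off that curve. A toy sketch of that read-off on hypothetical detections for a single class:

    # Hypothetical scored detections for one class: (score, matched_a_ground_truth)
    dets   = [(0.95, True), (0.80, True), (0.55, False), (0.30, True), (0.10, False)]
    num_gt = 4
    score_threhold = 0.5

    kept      = [tp for score, tp in dets if score >= score_threhold]
    precision = sum(kept) / len(kept)  # 2 / 3
    recall    = sum(kept) / num_gt     # 2 / 4
    print(precision, recall)
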
/CenterNet/model_data/people_classes.txt:
--------------------------------------------------------------------------------
1 | dog
2 | person
3 | cat
4 | car
--------------------------------------------------------------------------------
/CenterNet/model_data/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/CenterNet/model_data/simhei.ttf
--------------------------------------------------------------------------------
/CenterNet/nets/CenterNet_yolov7.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def autopad(k, p=None):
6 | if p is None:
7 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
8 | return p
9 |
10 | class SiLU(nn.Module):
11 | @staticmethod
12 | def forward(x):
13 | return x * torch.sigmoid(x)
14 |
15 | class Conv(nn.Module):
16 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups
17 | super(Conv, self).__init__()
18 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
19 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
20 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
21 |
22 | def forward(self, x):
23 | return self.act(self.bn(self.conv(x)))
24 |
25 | def fuseforward(self, x):
26 | return self.act(self.conv(x))
27 |
28 | class Block(nn.Module):
29 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
30 | super(Block, self).__init__()
31 | c_ = int(c2 * e)
32 |
33 | self.ids = ids
34 | self.cv1 = Conv(c1, c_, 1, 1)
35 | self.cv2 = Conv(c1, c_, 1, 1)
36 | self.cv3 = nn.ModuleList(
37 | [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
38 | )
39 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
40 |
41 | def forward(self, x):
42 | x_1 = self.cv1(x)
43 | x_2 = self.cv2(x)
44 |
45 | x_all = [x_1, x_2]
46 | for i in range(len(self.cv3)):
47 | x_2 = self.cv3[i](x_2)
48 | x_all.append(x_2)
49 |
50 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
51 | return out
52 |
53 | class MP(nn.Module):
54 | def __init__(self, k=2):
55 | super(MP, self).__init__()
56 | self.m = nn.MaxPool2d(kernel_size=k, stride=k)
57 |
58 | def forward(self, x):
59 | return self.m(x)
60 |
61 | class Transition(nn.Module):
62 | def __init__(self, c1, c2):
63 | super(Transition, self).__init__()
64 | self.cv1 = Conv(c1, c2, 1, 1)
65 | self.cv2 = Conv(c1, c2, 1, 1)
66 | self.cv3 = Conv(c2, c2, 3, 2)
67 |
68 | self.mp = MP()
69 |
70 | def forward(self, x):
71 | x_1 = self.mp(x)
72 | x_1 = self.cv1(x_1)
73 |
74 | x_2 = self.cv2(x)
75 | x_2 = self.cv3(x_2)
76 |
77 | return torch.cat([x_2, x_1], 1)
78 |
79 | class yolo_Backbone(nn.Module):
80 | def __init__(self, transition_channels, block_channels, n, phi, pretrained=False):
81 | super().__init__()
82 | #-----------------------------------------------#
83 |         #   The input image is 640, 640, 3.
84 | #-----------------------------------------------#
85 | ids = {
86 | 'l' : [-1, -3, -5, -6],
87 | 'x' : [-1, -3, -5, -7, -8],
88 | }[phi]
89 | self.stem = nn.Sequential(
90 | Conv(3, transition_channels, 3, 1),
91 | Conv(transition_channels, transition_channels * 2, 3, 2),
92 | Conv(transition_channels * 2, transition_channels * 2, 3, 1),
93 | )
94 | self.dark2 = nn.Sequential(
95 | Conv(transition_channels * 2, transition_channels * 4, 3, 2),
96 | Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids),
97 | )
98 | self.dark3 = nn.Sequential(
99 | Transition(transition_channels * 8, transition_channels * 4),
100 | Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids),
101 | )
102 | self.dark4 = nn.Sequential(
103 | Transition(transition_channels * 16, transition_channels * 8),
104 | Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids),
105 | )
106 | self.dark5 = nn.Sequential(
107 | Transition(transition_channels * 32, transition_channels * 16),
108 | Block(transition_channels * 32, block_channels * 8, transition_channels * 64, n=n, ids=ids),
109 | )
110 |
111 | if pretrained:
112 | url = {
113 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth',
114 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth',
115 | }[phi]
116 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data")
117 | self.load_state_dict(checkpoint, strict=False)
118 | print("Load weights from " + url.split('/')[-1])
119 |
120 | def forward(self, x):
121 | x = self.stem(x)
122 | x = self.dark2(x)
123 | #-----------------------------------------------#
124 |         #   The output of dark3 is 80, 80, 512 (transition_channels * 16); it is an effective feature layer.
125 | #-----------------------------------------------#
126 | x = self.dark3(x)
127 | feat1 = x
128 | #-----------------------------------------------#
129 |         #   The output of dark4 is 40, 40, 1024 (transition_channels * 32); it is an effective feature layer.
130 | #-----------------------------------------------#
131 | x = self.dark4(x)
132 | feat2 = x
133 | #-----------------------------------------------#
134 |         #   The output of dark5 is 20, 20, 2048 (the Block above is widened to transition_channels * 64); this is the feature layer that gets returned.
135 | #-----------------------------------------------#
136 | x = self.dark5(x)
137 | feat3 = x
138 | return feat3
139 |
--------------------------------------------------------------------------------
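
yolo_Backbone returns only the deepest feature map (feat3 from dark5), whose Block is widened to transition_channels * 64 channels so it can feed the CenterNet decoder. A quick shape-check sketch, assuming the repository root is on the Python path, for the 'l' configuration that yolov7() in CenterNet/nets/resnet50.py instantiates (transition_channels=32, block_channels=32, n=4):

    import torch
    from CenterNet.nets.CenterNet_yolov7 import yolo_Backbone

    net = yolo_Backbone(transition_channels=32, block_channels=32, n=4, phi='l')
    # CenterNet feeds 512x512 inputs; five stride-2 stages give 512 / 32 = 16,
    # and the widened dark5 Block outputs 32 * 64 = 2048 channels.
    feat = net(torch.randn(1, 3, 512, 512))
    print(feat.shape)  # torch.Size([1, 2048, 16, 16])
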
/CenterNet/nets/centernet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch
5 |
6 | from CenterNet.nets.hourglass import *
7 | from CenterNet.nets.resnet50 import resnet50, resnet50_Decoder, resnet50_Head,yolov7
8 |
9 |
10 | class CenterNet_Resnet50(nn.Module):
11 | def __init__(self, num_classes = 20, pretrained = False):
12 | super(CenterNet_Resnet50, self).__init__()
13 | self.pretrained = pretrained
14 | # 512,512,3 -> 16,16,2048
15 | self.backbone = resnet50(pretrained = pretrained)
16 | # 16,16,2048 -> 128,128,64
17 | self.decoder = resnet50_Decoder(2048)
18 | #-----------------------------------------------------------------#
19 |         #   Upsample the extracted features, then perform classification and regression prediction on them.
20 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
21 | # -> 128, 128, 64 -> 128, 128, 2
22 | # -> 128, 128, 64 -> 128, 128, 2
23 | #-----------------------------------------------------------------#
24 | self.head = resnet50_Head(channel=64, num_classes=num_classes)
25 |
26 | self._init_weights()
27 |
28 | def freeze_backbone(self):
29 | for param in self.backbone.parameters():
30 | param.requires_grad = False
31 |
32 | def unfreeze_backbone(self):
33 | for param in self.backbone.parameters():
34 | param.requires_grad = True
35 |
36 | def _init_weights(self):
37 | if not self.pretrained:
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, math.sqrt(2. / n))
42 | elif isinstance(m, nn.BatchNorm2d):
43 | m.weight.data.fill_(1)
44 | m.bias.data.zero_()
45 |
46 | self.head.cls_head[-1].weight.data.fill_(0)
47 | self.head.cls_head[-1].bias.data.fill_(-2.19)
48 |
49 | def forward(self, x):
50 | feat = self.backbone(x)
51 | return self.head(self.decoder(feat))
52 |
53 | class CenterNet_HourglassNet(nn.Module):
54 | def __init__(self, heads, pretrained=False, num_stacks=2, n=5, cnv_dim=256, dims=[256, 256, 384, 384, 384, 512], modules = [2, 2, 2, 2, 2, 4]):
55 | super(CenterNet_HourglassNet, self).__init__()
56 | if pretrained:
57 | raise ValueError("HourglassNet has no pretrained model")
58 |
59 | self.nstack = num_stacks
60 | self.heads = heads
61 |
62 | curr_dim = dims[0]
63 |
64 | self.pre = nn.Sequential(
65 | conv2d(7, 3, 128, stride=2),
66 | residual(3, 128, 256, stride=2)
67 | )
68 |
69 | self.kps = nn.ModuleList([
70 | kp_module(
71 | n, dims, modules
72 | ) for _ in range(num_stacks)
73 | ])
74 |
75 | self.cnvs = nn.ModuleList([
76 | conv2d(3, curr_dim, cnv_dim) for _ in range(num_stacks)
77 | ])
78 |
79 | self.inters = nn.ModuleList([
80 | residual(3, curr_dim, curr_dim) for _ in range(num_stacks - 1)
81 | ])
82 |
83 | self.inters_ = nn.ModuleList([
84 | nn.Sequential(
85 | nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False),
86 | nn.BatchNorm2d(curr_dim)
87 | ) for _ in range(num_stacks - 1)
88 | ])
89 |
90 | self.cnvs_ = nn.ModuleList([
91 | nn.Sequential(
92 | nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False),
93 | nn.BatchNorm2d(curr_dim)
94 | ) for _ in range(num_stacks - 1)
95 | ])
96 |
97 | for head in heads.keys():
98 | if 'hm' in head:
99 | module = nn.ModuleList([
100 | nn.Sequential(
101 | conv2d(3, cnv_dim, curr_dim, with_bn=False),
102 | nn.Conv2d(curr_dim, heads[head], (1, 1))
103 | ) for _ in range(num_stacks)
104 | ])
105 | self.__setattr__(head, module)
106 | for heat in self.__getattr__(head):
107 | heat[-1].weight.data.fill_(0)
108 | heat[-1].bias.data.fill_(-2.19)
109 | else:
110 | module = nn.ModuleList([
111 | nn.Sequential(
112 | conv2d(3, cnv_dim, curr_dim, with_bn=False),
113 | nn.Conv2d(curr_dim, heads[head], (1, 1))
114 | ) for _ in range(num_stacks)
115 | ])
116 | self.__setattr__(head, module)
117 |
118 |
119 | self.relu = nn.ReLU(inplace=True)
120 |
121 | def freeze_backbone(self):
122 | freeze_list = [self.pre, self.kps]
123 | for module in freeze_list:
124 | for param in module.parameters():
125 | param.requires_grad = False
126 |
127 | def unfreeze_backbone(self):
128 | freeze_list = [self.pre, self.kps]
129 | for module in freeze_list:
130 | for param in module.parameters():
131 | param.requires_grad = True
132 |
133 | def forward(self, image):
134 | # print('image shape', image.shape)
135 | inter = self.pre(image)
136 | outs = []
137 |
138 | for ind in range(self.nstack):
139 | kp = self.kps[ind](inter)
140 | cnv = self.cnvs[ind](kp)
141 |
142 | if ind < self.nstack - 1:
143 | inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
144 | inter = self.relu(inter)
145 | inter = self.inters[ind](inter)
146 |
147 | out = {}
148 | for head in self.heads:
149 | out[head] = self.__getattr__(head)[ind](cnv)
150 | outs.append(out)
151 | return outs
152 | def autopad(k, p=None):
153 | if p is None:
154 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
155 | return p
156 |
157 | class SiLU(nn.Module):
158 | @staticmethod
159 | def forward(x):
160 | return x * torch.sigmoid(x)
161 | class Conv(nn.Module):
162 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups
163 | super(Conv, self).__init__()
164 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
165 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
166 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
167 |
168 | def forward(self, x):
169 | return self.act(self.bn(self.conv(x)))
170 |
171 | def fuseforward(self, x):
172 | return self.act(self.conv(x))
173 |
174 |
175 | class Block(nn.Module):
176 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
177 | super(Block, self).__init__()
178 | c_ = int(c2 * e)
179 |
180 | self.ids = ids
181 | self.cv1 = Conv(c1, c_, 1, 1)
182 | self.cv2 = Conv(c1, c_, 1, 1)
183 | self.cv3 = nn.ModuleList(
184 | [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)]
185 | )
186 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
187 |
188 | def forward(self, x):
189 | x_1 = self.cv1(x)
190 | x_2 = self.cv2(x)
191 |
192 | x_all = [x_1, x_2]
193 | for i in range(len(self.cv3)):
194 | x_2 = self.cv3[i](x_2)
195 | x_all.append(x_2)
196 |
197 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
198 | return out
199 |
200 |
201 | class MP(nn.Module):
202 | def __init__(self, k=2):
203 | super(MP, self).__init__()
204 | self.m = nn.MaxPool2d(kernel_size=k, stride=k)
205 |
206 | def forward(self, x):
207 | return self.m(x)
208 |
209 |
210 | class Transition(nn.Module):
211 | def __init__(self, c1, c2):
212 | super(Transition, self).__init__()
213 | self.cv1 = Conv(c1, c2, 1, 1)
214 | self.cv2 = Conv(c1, c2, 1, 1)
215 | self.cv3 = Conv(c2, c2, 3, 2)
216 |
217 | self.mp = MP()
218 |
219 | def forward(self, x):
220 | x_1 = self.mp(x)
221 | x_1 = self.cv1(x_1)
222 |
223 | x_2 = self.cv2(x)
224 | x_2 = self.cv3(x_2)
225 |
226 | return torch.cat([x_2, x_1], 1)
227 |
228 |
229 | class CenterNet_yolov7(nn.Module):
230 | def __init__(self, num_classes=20, pretrained=False):
231 | super(CenterNet_yolov7, self).__init__()
232 | self.pretrained = pretrained
233 | # 512,512,3 -> 16,16,2048
234 | self.backbone = yolov7(pretrained=pretrained)
235 | # 16,16,2048 -> 128,128,64
236 | self.decoder = resnet50_Decoder(2048)
237 | # -----------------------------------------------------------------#
238 |         #   Upsample the extracted features, then perform classification and regression prediction on them.
239 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
240 | # -> 128, 128, 64 -> 128, 128, 2
241 | # -> 128, 128, 64 -> 128, 128, 2
242 | # -----------------------------------------------------------------#
243 | self.head = resnet50_Head(channel=64, num_classes=num_classes)
244 |
245 | self._init_weights()
246 |
247 | def freeze_backbone(self):
248 | for param in self.backbone.parameters():
249 | param.requires_grad = False
250 |
251 | def unfreeze_backbone(self):
252 | for param in self.backbone.parameters():
253 | param.requires_grad = True
254 |
255 | def _init_weights(self):
256 | if not self.pretrained:
257 | for m in self.modules():
258 | if isinstance(m, nn.Conv2d):
259 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
260 | m.weight.data.normal_(0, math.sqrt(2. / n))
261 | elif isinstance(m, nn.BatchNorm2d):
262 | m.weight.data.fill_(1)
263 | m.bias.data.zero_()
264 |
265 | self.head.cls_head[-1].weight.data.fill_(0)
266 | self.head.cls_head[-1].bias.data.fill_(-2.19)
267 |
268 | def forward(self, x):
269 | feat = self.backbone(x)
270 | return self.head(self.decoder(feat))
271 |
--------------------------------------------------------------------------------
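
CenterNet_Resnet50 (and CenterNet_yolov7 in the same way) runs the backbone, the three-step deconvolution decoder, and then the head, which returns a class heatmap plus width/height and centre-offset maps at 1/4 of the input resolution. An illustrative forward pass with random weights, assuming the repository root is on the Python path:

    import torch
    from CenterNet.nets.centernet import CenterNet_Resnet50

    model = CenterNet_Resnet50(num_classes=4, pretrained=False).eval()
    with torch.no_grad():
        hm, wh, offset = model(torch.randn(1, 3, 512, 512))
    print(hm.shape, wh.shape, offset.shape)
    # torch.Size([1, 4, 128, 128]) torch.Size([1, 2, 128, 128]) torch.Size([1, 2, 128, 128])
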
/CenterNet/nets/centernet_training.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def focal_loss(pred, target):
9 | pred = pred.permute(0, 2, 3, 1)
10 |
11 | #-------------------------------------------------------------------------#
12 |     #   Find the positive and negative samples in each image.
13 |     #   Each ground-truth box corresponds to one positive sample.
14 |     #   All feature points other than the positives are negatives.
15 | #-------------------------------------------------------------------------#
16 | pos_inds = target.eq(1).float()
17 | neg_inds = target.lt(1).float()
18 | #-------------------------------------------------------------------------#
19 |     #   Negative samples near a positive feature point get a smaller weight.
20 | #-------------------------------------------------------------------------#
21 | neg_weights = torch.pow(1 - target, 4)
22 |
23 | pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
24 | #-------------------------------------------------------------------------#
25 |     #   Compute the focal loss: hard samples get larger weights, easy samples smaller ones.
26 | #-------------------------------------------------------------------------#
27 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
28 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
29 |
30 | #-------------------------------------------------------------------------#
31 |     #   Normalise the loss.
32 | #-------------------------------------------------------------------------#
33 | num_pos = pos_inds.float().sum()
34 | pos_loss = pos_loss.sum()
35 | neg_loss = neg_loss.sum()
36 |
37 | if num_pos == 0:
38 | loss = -neg_loss
39 | else:
40 | loss = -(pos_loss + neg_loss) / num_pos
41 | return loss
42 |
43 | def reg_l1_loss(pred, target, mask):
44 | #--------------------------------#
45 |     #   Compute the L1 loss.
46 | #--------------------------------#
47 | pred = pred.permute(0,2,3,1)
48 | expand_mask = torch.unsqueeze(mask,-1).repeat(1,1,1,2)
49 |
50 | loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum')
51 | loss = loss / (mask.sum() + 1e-4)
52 | return loss
53 |
54 | def weights_init(net, init_type='normal', init_gain=0.02):
55 | def init_func(m):
56 | classname = m.__class__.__name__
57 | if hasattr(m, 'weight') and classname.find('Conv') != -1:
58 | if init_type == 'normal':
59 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
60 | elif init_type == 'xavier':
61 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
62 | elif init_type == 'kaiming':
63 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
64 | elif init_type == 'orthogonal':
65 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
66 | else:
67 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
68 | elif classname.find('BatchNorm2d') != -1:
69 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
70 | torch.nn.init.constant_(m.bias.data, 0.0)
71 | print('initialize network with %s type' % init_type)
72 | net.apply(init_func)
73 |
74 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
75 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
76 | if iters <= warmup_total_iters:
77 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
78 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
79 | elif iters >= total_iters - no_aug_iter:
80 | lr = min_lr
81 | else:
82 | lr = min_lr + 0.5 * (lr - min_lr) * (
83 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
84 | )
85 | return lr
86 |
87 | def step_lr(lr, decay_rate, step_size, iters):
88 | if step_size < 1:
89 | raise ValueError("step_size must above 1.")
90 | n = iters // step_size
91 | out_lr = lr * decay_rate ** n
92 | return out_lr
93 |
94 | if lr_decay_type == "cos":
95 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
96 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
97 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
98 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
99 | else:
100 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
101 | step_size = total_iters / step_num
102 | func = partial(step_lr, lr, decay_rate, step_size)
103 |
104 | return func
105 |
106 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
107 | lr = lr_scheduler_func(epoch)
108 | for param_group in optimizer.param_groups:
109 | param_group['lr'] = lr
110 |
--------------------------------------------------------------------------------
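
focal_loss above is the penalty-reduced focal loss used by CenterNet-style detectors: feature points where target == 1 are positives, everything else is a negative down-weighted by (1 - target)^4, and the sum is normalised by the number of positives. A tiny numerical check on a 1-class 2x2 heatmap, assuming the repository root is on the Python path:

    import torch
    from CenterNet.nets.centernet_training import focal_loss

    # One positive at (0, 0); the remaining three points are plain negatives.
    target = torch.zeros(1, 2, 2, 1)
    target[0, 0, 0, 0] = 1.0

    pred = torch.full((1, 1, 2, 2), 0.1)  # network output layout is B, C, H, W
    pred[0, 0, 0, 0] = 0.9                # confident score at the positive location
    print(focal_loss(pred, target))       # ~0.004; it grows if the peak score is lowered
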
/CenterNet/nets/hourglass.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | #-------------------------#
7 |   Convolution + batch normalisation + activation
8 | #-------------------------#
9 | class conv2d(nn.Module):
10 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
11 | super(conv2d, self).__init__()
12 |
13 | pad = (k - 1) // 2
14 | self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn)
15 | self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential()
16 | self.relu = nn.ReLU(inplace=True)
17 |
18 | def forward(self, x):
19 | conv = self.conv(x)
20 | bn = self.bn(conv)
21 | relu = self.relu(bn)
22 | return relu
23 |
24 | #-------------------------#
25 |   Residual block
26 | #-------------------------#
27 | class residual(nn.Module):
28 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
29 | super(residual, self).__init__()
30 |
31 | self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False)
32 | self.bn1 = nn.BatchNorm2d(out_dim)
33 | self.relu1 = nn.ReLU(inplace=True)
34 |
35 | self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False)
36 | self.bn2 = nn.BatchNorm2d(out_dim)
37 |
38 | self.skip = nn.Sequential(
39 | nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
40 | nn.BatchNorm2d(out_dim)
41 | ) if stride != 1 or inp_dim != out_dim else nn.Sequential()
42 | self.relu = nn.ReLU(inplace=True)
43 |
44 | def forward(self, x):
45 | conv1 = self.conv1(x)
46 | bn1 = self.bn1(conv1)
47 | relu1 = self.relu1(bn1)
48 |
49 | conv2 = self.conv2(relu1)
50 | bn2 = self.bn2(conv2)
51 |
52 | skip = self.skip(x)
53 | return self.relu(bn2 + skip)
54 |
55 | def make_layer(k, inp_dim, out_dim, modules, **kwargs):
56 | layers = [residual(k, inp_dim, out_dim, **kwargs)]
57 | for _ in range(modules - 1):
58 | layers.append(residual(k, out_dim, out_dim, **kwargs))
59 | return nn.Sequential(*layers)
60 |
61 | def make_hg_layer(k, inp_dim, out_dim, modules, **kwargs):
62 | layers = [residual(k, inp_dim, out_dim, stride=2)]
63 | for _ in range(modules - 1):
64 | layers += [residual(k, out_dim, out_dim)]
65 | return nn.Sequential(*layers)
66 |
67 | def make_layer_revr(k, inp_dim, out_dim, modules, **kwargs):
68 | layers = []
69 | for _ in range(modules - 1):
70 | layers.append(residual(k, inp_dim, inp_dim, **kwargs))
71 | layers.append(residual(k, inp_dim, out_dim, **kwargs))
72 | return nn.Sequential(*layers)
73 |
74 |
75 | class kp_module(nn.Module):
76 | def __init__(self, n, dims, modules, **kwargs):
77 | super(kp_module, self).__init__()
78 | self.n = n
79 |
80 | curr_mod = modules[0]
81 | next_mod = modules[1]
82 |
83 | curr_dim = dims[0]
84 | next_dim = dims[1]
85 |
86 |         # Apply residual convolutions to the incoming feature layer so it can later be fused with the layers coming back up.
87 | self.up1 = make_layer(
88 | 3, curr_dim, curr_dim, curr_mod, **kwargs
89 | )
90 |
91 |         # Downsample.
92 | self.low1 = make_hg_layer(
93 | 3, curr_dim, next_dim, curr_mod, **kwargs
94 | )
95 |
96 |         # Build the next level of the U-shaped structure.
97 | if self.n > 1 :
98 | self.low2 = kp_module(
99 | n - 1, dims[1:], modules[1:], **kwargs
100 | )
101 | else:
102 | self.low2 = make_layer(
103 | 3, next_dim, next_dim, next_mod, **kwargs
104 | )
105 |
106 |         # Apply residual convolutions to the features fed back from the next level of the U.
107 | self.low3 = make_layer_revr(
108 | 3, next_dim, curr_dim, curr_mod, **kwargs
109 | )
110 |         # Upsample the features fed back from the next level of the U.
111 | self.up2 = nn.Upsample(scale_factor=2)
112 |
113 | def forward(self, x):
114 | up1 = self.up1(x)
115 | low1 = self.low1(x)
116 | low2 = self.low2(low1)
117 | low3 = self.low3(low2)
118 | up2 = self.up2(low3)
119 | outputs = up1 + up2
120 | return outputs
121 |
122 |
--------------------------------------------------------------------------------
/CenterNet/nets/resnet50.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import math
4 | import torch.nn as nn
5 | from torch.hub import load_state_dict_from_url
6 | from CenterNet.nets.CenterNet_yolov7 import yolo_Backbone
7 | model_urls = {
8 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
9 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
10 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
11 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
12 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
13 | }
14 |
15 | class Bottleneck(nn.Module):
16 | expansion = 4
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(Bottleneck, self).__init__()
20 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
23 | padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * 4)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | residual = x
33 |
34 | out = self.conv1(x)
35 | out = self.bn1(out)
36 | out = self.relu(out)
37 |
38 | out = self.conv2(out)
39 | out = self.bn2(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv3(out)
43 | out = self.bn3(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 | #-----------------------------------------------------------------#
54 | #   ResNet50 is used as the backbone feature extractor; it ultimately produces
55 | #   an effective 16x16x2048 feature layer.
56 | #-----------------------------------------------------------------#
57 | class ResNet(nn.Module):
58 | def __init__(self, block, layers, num_classes=1000):
59 | self.inplanes = 64
60 | super(ResNet, self).__init__()
61 | # 512,512,3 -> 256,256,64
62 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.relu = nn.ReLU(inplace=True)
65 | # 256x256x64 -> 128x128x64
66 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
67 |
68 | # 128x128x64 -> 128x128x256
69 | self.layer1 = self._make_layer(block, 64, layers[0])
70 |
71 | # 128x128x256 -> 64x64x512
72 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
73 |
74 | # 64x64x512 -> 32x32x1024
75 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
76 |
77 | # 32x32x1024 -> 16x16x2048
78 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
79 |
80 | self.avgpool = nn.AvgPool2d(7)
81 | self.fc = nn.Linear(512 * block.expansion, num_classes)
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
86 | m.weight.data.normal_(0, math.sqrt(2. / n))
87 | elif isinstance(m, nn.BatchNorm2d):
88 | m.weight.data.fill_(1)
89 | m.bias.data.zero_()
90 |
91 | def _make_layer(self, block, planes, blocks, stride=1):
92 | downsample = None
93 | if stride != 1 or self.inplanes != planes * block.expansion:
94 | downsample = nn.Sequential(
95 | nn.Conv2d(self.inplanes, planes * block.expansion,
96 | kernel_size=1, stride=stride, bias=False),
97 | nn.BatchNorm2d(planes * block.expansion),
98 | )
99 |
100 | layers = []
101 | layers.append(block(self.inplanes, planes, stride, downsample))
102 | self.inplanes = planes * block.expansion
103 | for i in range(1, blocks):
104 | layers.append(block(self.inplanes, planes))
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, x):
109 | x = self.conv1(x)
110 | x = self.bn1(x)
111 | x = self.relu(x)
112 | x = self.maxpool(x)
113 |
114 | x = self.layer1(x)
115 | x = self.layer2(x)
116 | x = self.layer3(x)
117 | x = self.layer4(x)
118 |
119 | x = self.avgpool(x)
120 | x = x.view(x.size(0), -1)
121 | x = self.fc(x)
122 |
123 | return x
124 |
125 | def resnet50(pretrained = True):
126 | model = ResNet(Bottleneck, [3, 4, 6, 3])
127 | if pretrained:
128 | state_dict = load_state_dict_from_url(model_urls['resnet50'], model_dir = 'model_data/')
129 | model.load_state_dict(state_dict)
130 | #----------------------------------------------------------#
131 |     #   Keep only the feature-extraction part.
132 | #----------------------------------------------------------#
133 | features = list([model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3, model.layer4])
134 | features = nn.Sequential(*features)
135 | return features
136 | def yolov7(pretrained = True):
137 | phi = 'l'
138 | transition_channels = {'l': 32, 'x': 40}[phi]
139 | block_channels = 32
140 | n = {'l': 4, 'x': 6}[phi]
141 | model = yolo_Backbone(transition_channels, block_channels, n, phi, pretrained=False)
142 | if pretrained:
143 |         state_dict = load_state_dict_from_url(model_urls['yolov7'], model_dir = 'model_data/')  # note: model_urls above has no 'yolov7' entry, so pretrained=True would raise a KeyError here
144 | model.load_state_dict(state_dict)
145 | #----------------------------------------------------------#
146 |     #   Keep only the feature-extraction part.
147 | #----------------------------------------------------------#
148 | features = list([model.stem, model.dark2, model.dark3, model.dark4, model.dark5])
149 | features = nn.Sequential(*features)
150 | return features
151 | class resnet50_Decoder(nn.Module):
152 | def __init__(self, inplanes, bn_momentum=0.1):
153 | super(resnet50_Decoder, self).__init__()
154 | self.bn_momentum = bn_momentum
155 | self.inplanes = inplanes
156 | self.deconv_with_bias = False
157 |
158 | #----------------------------------------------------------#
159 | # 16,16,2048 -> 32,32,256 -> 64,64,128 -> 128,128,64
160 |         #   ConvTranspose2d is used for upsampling;
161 |         #   each step doubles the width and height of the feature layer.
162 | #----------------------------------------------------------#
163 | self.deconv_layers = self._make_deconv_layer(
164 | num_layers=3,
165 | num_filters=[256, 128, 64],
166 | num_kernels=[4, 4, 4],
167 | )
168 |
169 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
170 | layers = []
171 | for i in range(num_layers):
172 | kernel = num_kernels[i]
173 | planes = num_filters[i]
174 |
175 | layers.append(
176 | nn.ConvTranspose2d(
177 | in_channels=self.inplanes,
178 | out_channels=planes,
179 | kernel_size=kernel,
180 | stride=2,
181 | padding=1,
182 | output_padding=0,
183 | bias=self.deconv_with_bias))
184 | layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
185 | layers.append(nn.ReLU(inplace=True))
186 | self.inplanes = planes
187 | return nn.Sequential(*layers)
188 |
189 | def forward(self, x):
190 | return self.deconv_layers(x)
191 |
192 |
193 | class resnet50_Head(nn.Module):
194 | def __init__(self, num_classes=80, channel=64, bn_momentum=0.1):
195 | super(resnet50_Head, self).__init__()
196 | #-----------------------------------------------------------------#
197 |         #   Upsample the extracted features, then perform classification and regression prediction on them.
198 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
199 | # -> 128, 128, 64 -> 128, 128, 2
200 | # -> 128, 128, 64 -> 128, 128, 2
201 | #-----------------------------------------------------------------#
202 |         # Heatmap prediction branch
203 | self.cls_head = nn.Sequential(
204 | nn.Conv2d(64, channel,
205 | kernel_size=3, padding=1, bias=False),
206 | nn.BatchNorm2d(64, momentum=bn_momentum),
207 | nn.ReLU(inplace=True),
208 | nn.Conv2d(channel, num_classes,
209 | kernel_size=1, stride=1, padding=0))
210 |         # Width/height prediction branch
211 | self.wh_head = nn.Sequential(
212 | nn.Conv2d(64, channel,
213 | kernel_size=3, padding=1, bias=False),
214 | nn.BatchNorm2d(64, momentum=bn_momentum),
215 | nn.ReLU(inplace=True),
216 | nn.Conv2d(channel, 2,
217 | kernel_size=1, stride=1, padding=0))
218 |
219 |         # Centre-offset prediction branch
220 | self.reg_head = nn.Sequential(
221 | nn.Conv2d(64, channel,
222 | kernel_size=3, padding=1, bias=False),
223 | nn.BatchNorm2d(64, momentum=bn_momentum),
224 | nn.ReLU(inplace=True),
225 | nn.Conv2d(channel, 2,
226 | kernel_size=1, stride=1, padding=0))
227 |
228 | def forward(self, x):
229 | hm = self.cls_head(x).sigmoid_()
230 | wh = self.wh_head(x)
231 | offset = self.reg_head(x)
232 | return hm, wh, offset
233 |
234 |
--------------------------------------------------------------------------------
/CenterNet/predict.py:
--------------------------------------------------------------------------------
1 | #-----------------------------------------------------------------------#
2 | #   predict.py integrates single-image prediction, camera/video detection, FPS testing and
3 | #   directory-sweep detection into one script; switch between them by setting mode.
4 | #-----------------------------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from centernet import CenterNet
12 |
13 | if __name__ == "__main__":
14 | centernet = CenterNet()
15 | #----------------------------------------------------------------------------------------------------------#
16 |     #   mode specifies the kind of test to run:
17 |     #   'predict'       single-image prediction. To modify the prediction flow (saving images, cropping objects, etc.), read the detailed notes below first.
18 |     #   'video'         video detection; a camera or a video file can be used, see the notes below.
19 |     #   'fps'           FPS test, using img/street.jpg; see the notes below.
20 |     #   'dir_predict'   sweep a folder, detect and save. By default it reads the img folder and saves to img_out; see the notes below.
21 |     #   'heatmap'       heatmap visualisation of the prediction results; see the notes below.
22 |     #   'export_onnx'   export the model to ONNX; requires PyTorch 1.7.1 or newer.
23 | #----------------------------------------------------------------------------------------------------------#
24 | mode = "predict"
25 | #-------------------------------------------------------------------------#
26 |     #   crop    whether to crop out the detected objects after single-image prediction
27 |     #   count   whether to count the detected objects
28 |     #   crop and count only take effect when mode='predict'
29 | #-------------------------------------------------------------------------#
30 | crop = False
31 | count = False
32 | #----------------------------------------------------------------------------------------------------------#
33 |     #   video_path        path of the video; video_path=0 means detect from the camera.
34 |     #   To detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the repository root.
35 |     #   video_save_path   path where the video is saved; video_save_path="" means do not save.
36 |     #   To save the video, set e.g. video_save_path = "yyy.mp4" to save it as yyy.mp4 in the repository root.
37 |     #   video_fps         fps of the saved video.
38 |     #
39 |     #   video_path, video_save_path and video_fps only take effect when mode='video'.
40 |     #   When saving a video, either press ctrl+c to quit or run through to the last frame to complete the save.
41 | #----------------------------------------------------------------------------------------------------------#
42 | video_path = 0
43 | video_save_path = ""
44 | video_fps = 25.0
45 | #----------------------------------------------------------------------------------------------------------#
46 |     #   test_interval     how many times the image is detected when measuring FPS; in theory, a larger test_interval gives a more accurate FPS.
47 |     #   fps_image_path    the image used for the FPS test.
48 |     #
49 |     #   test_interval and fps_image_path only take effect when mode='fps'.
50 | #----------------------------------------------------------------------------------------------------------#
51 | test_interval = 100
52 | fps_image_path = r"D:\yolov7-pytorch-master\img\000001.jpg"
53 | #-------------------------------------------------------------------------#
54 |     #   dir_origin_path   folder containing the images to detect.
55 |     #   dir_save_path     folder where the detected images are saved.
56 |     #
57 |     #   dir_origin_path and dir_save_path only take effect when mode='dir_predict'.
58 | #-------------------------------------------------------------------------#
59 | dir_origin_path = r"D:\yolov7-pytorch-master\img\000001.jpg"
60 | dir_save_path = "img_out/"
61 | #-------------------------------------------------------------------------#
62 |     #   heatmap_save_path   where the heatmap is saved, under model_data by default.
63 |     #
64 |     #   heatmap_save_path only takes effect when mode='heatmap'.
65 | #-------------------------------------------------------------------------#
66 | heatmap_save_path = "model_data/heatmap_vision.png"
67 | #-------------------------------------------------------------------------#
68 |     #   simplify         whether to simplify the exported ONNX model.
69 |     #   onnx_save_path   where the ONNX model is saved.
70 | #-------------------------------------------------------------------------#
71 | simplify = True
72 | onnx_save_path = "model_data/models.onnx"
73 |
74 | if mode == "predict":
75 | '''
76 |         1. To save the detected image, call r_image.save("img.jpg"); just edit predict.py directly.
77 |         2. To get the coordinates of the predicted boxes, go into centernet.detect_image and read top, left, bottom, right in the drawing section.
78 |         3. To crop out the detected objects, go into centernet.detect_image and use the top, left, bottom, right values
79 |            obtained in the drawing section to slice the original image as an array.
80 |         4. To write extra text on the output image, e.g. the number of a particular detected class, go into centernet.detect_image and test predicted_class in the drawing section;
81 |            e.g. if predicted_class == 'car': tells you whether the current object is a car, then keep a count. Use draw.text to write the text.
82 | '''
83 | while True:
84 | img = input('Input image filename:')
85 | try:
86 | image = Image.open(img)
87 | except:
88 | print('Open Error! Try again!')
89 | continue
90 | else:
91 | r_image = centernet.detect_image(image, crop = crop, count=count)
92 | r_image.show()
93 |
94 | elif mode == "video":
95 | capture = cv2.VideoCapture(video_path)
96 | if video_save_path!="":
97 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
98 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
99 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
100 |
101 | ref, frame = capture.read()
102 | if not ref:
103 |             raise ValueError("Failed to read the camera (or video). Check that the camera is installed correctly (or that the video path is filled in correctly).")
104 |
105 | fps = 0.0
106 | while(True):
107 | t1 = time.time()
108 |             # Read one frame
109 | ref, frame = capture.read()
110 | if not ref:
111 | break
112 |             # Convert format, BGR to RGB
113 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
114 |             # Convert to a PIL Image
115 | frame = Image.fromarray(np.uint8(frame))
116 |             # Run detection
117 | frame = np.array(centernet.detect_image(frame))
118 |             # RGB to BGR to match OpenCV's display format
119 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
120 |
121 | fps = ( fps + (1./(time.time()-t1)) ) / 2
122 | print("fps= %.2f"%(fps))
123 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
124 |
125 | cv2.imshow("video",frame)
126 | c= cv2.waitKey(1) & 0xff
127 | if video_save_path!="":
128 | out.write(frame)
129 |
130 | if c==27:
131 | capture.release()
132 | break
133 |
134 | print("Video Detection Done!")
135 | capture.release()
136 | if video_save_path!="":
137 | print("Save processed video to the path :" + video_save_path)
138 | out.release()
139 | cv2.destroyAllWindows()
140 |
141 | elif mode == "fps":
142 | img = Image.open(fps_image_path)
143 | tact_time = centernet.get_FPS(img, test_interval)
144 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
145 |
146 | elif mode == "dir_predict":
147 | import os
148 |
149 | from tqdm import tqdm
150 |
151 | img_names = os.listdir(dir_origin_path)
152 | for img_name in tqdm(img_names):
153 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
154 | image_path = os.path.join(dir_origin_path, img_name)
155 | image = Image.open(image_path)
156 | r_image = centernet.detect_image(image)
157 | if not os.path.exists(dir_save_path):
158 | os.makedirs(dir_save_path)
159 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
160 |
161 | elif mode == "heatmap":
162 | while True:
163 | img = input('Input image filename:')
164 | try:
165 | image = Image.open(img)
166 | except:
167 | print('Open Error! Try again!')
168 | continue
169 | else:
170 | centernet.detect_heatmap(image, heatmap_save_path)
171 |
172 | elif mode == "export_onnx":
173 | centernet.convert_to_onnx(simplify, onnx_save_path)
174 |
175 | else:
176 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")
177 |
--------------------------------------------------------------------------------
/CenterNet/summary.py:
--------------------------------------------------------------------------------
1 | #--------------------------------------------#
2 | #   This script is used to inspect the network's parameters.
3 | #--------------------------------------------#
4 | import torch
5 | from thop import clever_format, profile
6 | from torchsummary import summary
7 |
8 | from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50
9 |
10 | if __name__ == "__main__":
11 | input_shape = [512, 512]
12 | num_classes = 20
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | model = CenterNet_Resnet50().to(device)
16 | summary(model, (3, input_shape[0], input_shape[1]))
17 |
18 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
19 | flops, params = profile(model.to(device), (dummy_input, ), verbose=False)
20 | #--------------------------------------------------------#
21 | # flops * 2是因为profile没有将卷积作为两个operations
22 | # 有些论文将卷积算乘法、加法两个operations。此时乘2
23 | # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2
24 | # 本代码选择乘2,参考YOLOX。
25 | #--------------------------------------------------------#
26 | flops = flops * 2
27 | flops, params = clever_format([flops, params], "%.3f")
28 | print('Total FLOPs: %s' % (flops))
29 | print('Total params: %s' % (params))
30 |
--------------------------------------------------------------------------------
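In CenterNet/summary.py above, `thop.profile` reports multiply-accumulate counts, which is why the script doubles the result before formatting (the YOLOX convention of counting the multiply and the add as separate operations). A minimal sketch of that conversion on a toy layer, assuming thop is installed; the Conv2d only stands in for the real detector:

```python
import torch
from torch import nn
from thop import clever_format, profile

# A toy layer standing in for the detector; any nn.Module can be profiled the same way.
layer = nn.Conv2d(3, 16, kernel_size=3, padding=1)
dummy = torch.randn(1, 3, 64, 64)

macs, params = profile(layer, (dummy,), verbose=False)   # thop counts multiply-accumulates
flops = macs * 2                                          # count multiply and add as two operations
macs, flops, params = clever_format([macs, flops, params], "%.3f")
print('MACs: %s, FLOPs: %s, Params: %s' % (macs, flops, params))
```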
/CenterNet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/CenterNet/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | #---------------------------------------------------------#
5 | # 将图像转换成RGB图像,防止灰度图在预测时报错。
6 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
7 | #---------------------------------------------------------#
8 | def cvtColor(image):
9 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
10 | return image
11 | else:
12 | image = image.convert('RGB')
13 | return image
14 |
15 | #---------------------------------------------------#
16 | # 对输入图像进行resize
17 | #---------------------------------------------------#
18 | def resize_image(image, size, letterbox_image):
19 | iw, ih = image.size
20 | w, h = size
21 | if letterbox_image:
22 | scale = min(w/iw, h/ih)
23 | nw = int(iw*scale)
24 | nh = int(ih*scale)
25 |
26 | image = image.resize((nw,nh), Image.BICUBIC)
27 | new_image = Image.new('RGB', size, (128,128,128))
28 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
29 | else:
30 | new_image = image.resize((w, h), Image.BICUBIC)
31 | return new_image
32 |
33 | #---------------------------------------------------#
34 | # 获得类
35 | #---------------------------------------------------#
36 | def get_classes(classes_path):
37 | with open(classes_path, encoding='utf-8') as f:
38 | class_names = f.readlines()
39 | class_names = [c.strip() for c in class_names]
40 | return class_names, len(class_names)
41 |
42 | def get_lr(optimizer):
43 | for param_group in optimizer.param_groups:
44 | return param_group['lr']
45 |
46 | def preprocess_input(image):
47 | image = np.array(image,dtype = np.float32)[:, :,::-1]
48 | mean = [0.40789655, 0.44719303, 0.47026116]
49 | std = [0.2886383, 0.27408165, 0.27809834]
50 | return (image / 255. - mean) / std
51 |
52 | def show_config(**kwargs):
53 | print('Configurations:')
54 | print('-' * 70)
55 | print('|%25s | %40s|' % ('keys', 'values'))
56 | print('-' * 70)
57 | for key, value in kwargs.items():
58 | print('|%25s | %40s|' % (str(key), str(value)))
59 | print('-' * 70)
60 |
61 | def download_weights(backbone, model_dir="./model_data"):
62 | import os
63 | from torch.hub import load_state_dict_from_url
64 |
65 | if backbone == "hourglass":
66 | raise ValueError("HourglassNet has no pretrained model")
67 |
68 | download_urls = {
69 | 'resnet50' : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
70 | }
71 | url = download_urls[backbone]
72 |
73 | if not os.path.exists(model_dir):
74 | os.makedirs(model_dir)
75 | load_state_dict_from_url(url, model_dir)
--------------------------------------------------------------------------------
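In CenterNet/utils/utils.py above, the helpers form the detector's preprocessing chain: force a three-channel RGB image, letterbox-resize to the network input, then normalize with the dataset mean/std (preprocess_input also reverses the channel order before normalizing). A minimal sketch of how they compose, assuming it is run from the CenterNet directory; the image path is illustrative only:

```python
import numpy as np
from PIL import Image

from utils.utils import cvtColor, preprocess_input, resize_image  # run from the CenterNet directory

image = Image.open("img/street.jpg")                            # illustrative path
image = cvtColor(image)                                         # guarantee a 3-channel RGB image
image = resize_image(image, (512, 512), letterbox_image=True)   # pad-resize to 512x512 without distortion
x = preprocess_input(image)                                     # reverse channels, scale and normalize
x = np.expand_dims(np.transpose(x, (2, 0, 1)), 0)               # HWC -> 1x3x512x512 for the network
print(x.shape)
```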
/CenterNet/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from nets.centernet_training import focal_loss, reg_l1_loss
5 | from tqdm import tqdm
6 |
7 | from utils.utils import get_lr
8 |
9 |
10 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, backbone, save_period, save_dir, local_rank=0):
11 | total_r_loss = 0
12 | total_c_loss = 0
13 | total_loss = 0
14 | val_loss = 0
15 |
16 | if local_rank == 0:
17 | print('Start Train')
18 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
19 | model_train.train()
20 | for iteration, batch in enumerate(gen):
21 | if iteration >= epoch_step:
22 | break
23 | with torch.no_grad():
24 | if cuda:
25 | batch = [ann.cuda(local_rank) for ann in batch]
26 | batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks = batch
27 |
28 | #----------------------#
29 | # 清零梯度
30 | #----------------------#
31 | optimizer.zero_grad()
32 | if not fp16:
33 | if backbone=="resnet50":
34 | hm, wh, offset = model_train(batch_images)
35 | c_loss = focal_loss(hm, batch_hms)
36 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
37 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
38 |
39 | loss = c_loss + wh_loss + off_loss
40 |
41 | total_loss += loss.item()
42 | total_c_loss += c_loss.item()
43 | total_r_loss += wh_loss.item() + off_loss.item()
44 | elif backbone == "yolov7_brackbone":
45 | hm, wh, offset = model_train(batch_images)
46 | c_loss = focal_loss(hm, batch_hms)
47 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
48 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
49 |
50 | loss = c_loss + wh_loss + off_loss
51 |
52 | total_loss += loss.item()
53 | total_c_loss += c_loss.item()
54 | total_r_loss += wh_loss.item() + off_loss.item()
55 |
56 | elif backbone=="hourglass":
57 | outputs = model_train(batch_images)
58 | loss = 0
59 | c_loss_all = 0
60 | r_loss_all = 0
61 | index = 0
62 | for output in outputs:
63 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"]
64 | c_loss = focal_loss(hm, batch_hms)
65 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
66 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
67 |
68 | loss += c_loss + wh_loss + off_loss
69 |
70 | c_loss_all += c_loss
71 | r_loss_all += wh_loss + off_loss
72 | index += 1
73 | total_loss += loss.item() / index
74 | total_c_loss += c_loss_all.item() / index
75 | total_r_loss += r_loss_all.item() / index
76 | loss.backward()
77 | optimizer.step()
78 | else:
79 | from torch.cuda.amp import autocast
80 | with autocast():
81 | if backbone=="resnet50":
82 | hm, wh, offset = model_train(batch_images)
83 | c_loss = focal_loss(hm, batch_hms)
84 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
85 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
86 |
87 | loss = c_loss + wh_loss + off_loss
88 |
89 | total_loss += loss.item()
90 | total_c_loss += c_loss.item()
91 | total_r_loss += wh_loss.item() + off_loss.item()
92 | elif backbone == "yolov7_brackbone":
93 | hm, wh, offset = model_train(batch_images)
94 | c_loss = focal_loss(hm, batch_hms)
95 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
96 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
97 |
98 | loss = c_loss + wh_loss + off_loss
99 |
100 | total_loss += loss.item()
101 | total_c_loss += c_loss.item()
102 | total_r_loss += wh_loss.item() + off_loss.item()
103 |
104 | elif backbone=="hourglass":
105 | outputs = model_train(batch_images)
106 | loss = 0
107 | c_loss_all = 0
108 | r_loss_all = 0
109 | index = 0
110 | for output in outputs:
111 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"]
112 | c_loss = focal_loss(hm, batch_hms)
113 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
114 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
115 |
116 | loss += c_loss + wh_loss + off_loss
117 |
118 | c_loss_all += c_loss
119 | r_loss_all += wh_loss + off_loss
120 | index += 1
121 | total_loss += loss.item() / index
122 | total_c_loss += c_loss_all.item() / index
123 | total_r_loss += r_loss_all.item() / index
124 |
125 | #----------------------#
126 | # 反向传播
127 | #----------------------#
128 | scaler.scale(loss).backward()
129 | scaler.step(optimizer)
130 | scaler.update()
131 |
132 | if local_rank == 0:
133 | pbar.set_postfix(**{'total_r_loss' : total_r_loss / (iteration + 1),
134 | 'total_c_loss' : total_c_loss / (iteration + 1),
135 | 'lr' : get_lr(optimizer)})
136 | pbar.update(1)
137 |
138 | if local_rank == 0:
139 | pbar.close()
140 | print('Finish Train')
141 | print('Start Validation')
142 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
143 |
144 | model_train.eval()
145 | for iteration, batch in enumerate(gen_val):
146 | if iteration >= epoch_step_val:
147 | break
148 |
149 | with torch.no_grad():
150 | if cuda:
151 | batch = [ann.cuda(local_rank) for ann in batch]
152 | batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks = batch
153 |
154 | if backbone=="resnet50":
155 | hm, wh, offset = model_train(batch_images)
156 | c_loss = focal_loss(hm, batch_hms)
157 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
158 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
159 |
160 | loss = c_loss + wh_loss + off_loss
161 |
162 | val_loss += loss.item()
163 | elif backbone=="yolov7_brackbone":
164 | hm, wh, offset = model_train(batch_images)
165 | c_loss = focal_loss(hm, batch_hms)
166 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
167 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
168 |
169 | loss = c_loss + wh_loss + off_loss
170 |
171 | val_loss += loss.item()
172 | elif backbone=="hourglass":
173 | outputs = model_train(batch_images)
174 | index = 0
175 | loss = 0
176 | for output in outputs:
177 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"]
178 | c_loss = focal_loss(hm, batch_hms)
179 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks)
180 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks)
181 |
182 | loss += c_loss + wh_loss + off_loss
183 | index += 1
184 | val_loss += loss.item() / index
185 |
186 | if local_rank == 0:
187 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
188 | pbar.update(1)
189 |
190 | if local_rank == 0:
191 | pbar.close()
192 | print('Finish Validation')
193 | loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
194 | eval_callback.on_epoch_end(epoch + 1, model_train)
195 | print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
196 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
197 |
198 | #-----------------------------------------------#
199 | # 保存权值
200 | #-----------------------------------------------#
201 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
202 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
203 |
204 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
205 | print('Save best model to best_epoch_weights.pth')
206 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
207 |
208 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
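In CenterNet/utils/utils_fit.py above, every branch builds the same loss: a focal loss on the class heatmap, an L1 loss on the box width/height weighted by 0.1, and an L1 loss on the centre offsets (the hourglass backbone repeats this per intermediate output and averages). A minimal, self-contained sketch of that composition; the two toy losses below merely stand in for focal_loss and reg_l1_loss from nets/centernet_training.py:

```python
import torch
import torch.nn.functional as F

# Stand-ins for nets.centernet_training.focal_loss / reg_l1_loss, so the composition
# runs on its own; the real definitions live in that module.
def toy_focal_loss(pred, target):
    return F.binary_cross_entropy(pred, target)

def toy_reg_l1_loss(pred, target, mask):
    mask = mask.unsqueeze(1).expand_as(pred)
    return F.l1_loss(pred * mask, target * mask, reduction="sum") / (mask.sum() + 1e-4)

b, c, h, w = 2, 20, 128, 128
hm_pred, hm_true   = torch.rand(b, c, h, w), torch.rand(b, c, h, w)
wh_pred, wh_true   = torch.rand(b, 2, h, w), torch.rand(b, 2, h, w)
off_pred, off_true = torch.rand(b, 2, h, w), torch.rand(b, 2, h, w)
mask = (torch.rand(b, h, w) > 0.99).float()   # 1 at pixels that hold an object centre

# Same composition as fit_one_epoch: heatmap term + down-weighted size term + offset term.
loss = toy_focal_loss(hm_pred, hm_true) \
       + 0.1 * toy_reg_l1_loss(wh_pred, wh_true, mask) \
       + toy_reg_l1_loss(off_pred, off_true, mask)
print(float(loss))
```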
/CenterNet/vision_for_centernet.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | if __name__ == "__main__":
5 | height, width, feat_stride = 128,128,1
6 |
7 | fig = plt.figure()
8 | ax = fig.add_subplot(121)
9 | plt.ylim(-10,17)
10 | plt.xlim(-10,17)
11 |
12 | shift_x = np.arange(0, width * feat_stride, feat_stride)
13 | shift_y = np.arange(0, height * feat_stride, feat_stride)
14 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
15 | boxes = np.stack([shift_x,shift_y,shift_x,shift_y],axis=-1).reshape([-1,4]).astype(np.float32)
16 | plt.scatter(boxes[3:,0],boxes[3:,1])
17 | plt.scatter(boxes[0:3,0],boxes[0:3,1],c="r")
18 | ax.invert_yaxis()
19 |
20 | ax = fig.add_subplot(122)
21 | plt.ylim(-10,17)
22 | plt.xlim(-10,17)
23 |
24 | shift_x = np.arange(0, width * feat_stride, feat_stride)
25 | shift_y = np.arange(0, height * feat_stride, feat_stride)
26 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
27 | boxes = np.stack([shift_x,shift_y,shift_x,shift_y],axis=-1).reshape([-1,4]).astype(np.float32)
28 | plt.scatter(shift_x,shift_y)
29 |
30 | heatmap = np.random.uniform(0,1,[128,128,80]).reshape([-1,80])
31 | reg = np.random.uniform(0,1,[128,128,2]).reshape([-1,2])
32 | wh = np.random.uniform(5,20,[128,128,2]).reshape([-1,2])
33 |
34 | boxes[:,:2] = boxes[:,:2] + reg
35 | boxes[:,2:] = boxes[:,2:] + reg
36 | plt.scatter(boxes[0:3,0],boxes[0:3,1])
37 | boxes[:,:2] = boxes[:,:2] - wh/2
38 | boxes[:,2:] = boxes[:,2:] + wh/2
39 | for i in [0,1,2]:
40 | rect = plt.Rectangle([boxes[i, 0],boxes[i, 1]], wh[i,0], wh[i,1], color="r",fill=False)
41 | ax.add_patch(rect)
42 |
43 | ax.invert_yaxis()
44 |
45 | plt.show()
46 |
--------------------------------------------------------------------------------
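CenterNet/vision_for_centernet.py above visualises, on random numbers, how CenterNet decodes a box: each feature-map cell is a candidate centre, shifted by the predicted offset reg and then grown by half the predicted width/height on each side. The same decode written out for a single cell:

```python
import numpy as np

rng = np.random.default_rng(0)
cx, cy = 40, 64                      # feature-map cell acting as the candidate centre
reg = rng.uniform(0, 1, 2)           # predicted sub-pixel offset of the centre
wh  = rng.uniform(5, 20, 2)          # predicted box width and height

# Decode exactly as in the visualisation: shift the centre, then grow by wh / 2.
center = np.array([cx, cy]) + reg
xmin, ymin = center - wh / 2
xmax, ymax = center + wh / 2
print([round(v, 2) for v in (xmin, ymin, xmax, ymax)])
```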
/CenterNet/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import xml.etree.ElementTree as ET
4 |
5 | import numpy as np
6 |
7 | from utils.utils import get_classes
8 |
9 | #--------------------------------------------------------------------------------------------------------------------------------#
10 | # annotation_mode用于指定该文件运行时计算的内容
11 | # annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
12 | # annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
13 | # annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
14 | #--------------------------------------------------------------------------------------------------------------------------------#
15 | annotation_mode = 0
16 | #-------------------------------------------------------------------#
17 | # 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息
18 | # 与训练和预测所用的classes_path一致即可
19 | # 如果生成的2007_train.txt里面没有目标信息
20 | # 那么就是因为classes没有设定正确
21 | # 仅在annotation_mode为0和2的时候有效
22 | #-------------------------------------------------------------------#
23 | classes_path =r'model_data\people_classes.txt'
24 | #--------------------------------------------------------------------------------------------------------------------------------#
25 | # trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1
26 | # train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1
27 | # 仅在annotation_mode为0和1的时候有效
28 | #--------------------------------------------------------------------------------------------------------------------------------#
29 | trainval_percent = 0.8
30 | train_percent = 0.8
31 | #-------------------------------------------------------#
32 | # 指向VOC数据集所在的文件夹
33 | # 默认指向根目录下的VOC数据集
34 | #-------------------------------------------------------#
35 | VOCdevkit_path = 'VOCdevkit'
36 |
37 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
38 | classes, _ = get_classes(classes_path)
39 |
40 | #-------------------------------------------------------#
41 | # 统计目标数量
42 | #-------------------------------------------------------#
43 | photo_nums = np.zeros(len(VOCdevkit_sets))
44 | nums = np.zeros(len(classes))
45 | def convert_annotation(year, image_id, list_file):
46 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
47 | tree=ET.parse(in_file)
48 | root = tree.getroot()
49 |
50 | for obj in root.iter('object'):
51 | difficult = 0
52 | if obj.find('difficult')!=None:
53 | difficult = obj.find('difficult').text
54 | cls = obj.find('name').text
55 | if cls not in classes or int(difficult)==1:
56 | continue
57 | cls_id = classes.index(cls)
58 | xmlbox = obj.find('bndbox')
59 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
60 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
61 |
62 | nums[classes.index(cls)] = nums[classes.index(cls)] + 1
63 |
64 | if __name__ == "__main__":
65 | random.seed(0)
66 | if " " in os.path.abspath(VOCdevkit_path):
67 | raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")
68 |
69 | if annotation_mode == 0 or annotation_mode == 1:
70 | print("Generate txt in ImageSets.")
71 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
72 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
73 | temp_xml = os.listdir(xmlfilepath)
74 | total_xml = []
75 | for xml in temp_xml:
76 | if xml.endswith(".xml"):
77 | total_xml.append(xml)
78 |
79 | num = len(total_xml)
80 | list = range(num)
81 | tv = int(num*trainval_percent)
82 | tr = int(tv*train_percent)
83 | trainval= random.sample(list,tv)
84 | train = random.sample(trainval,tr)
85 |
86 | print("train and val size",tv)
87 | print("train size",tr)
88 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
89 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
90 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
91 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
92 |
93 | for i in list:
94 | name=total_xml[i][:-4]+'\n'
95 | if i in trainval:
96 | ftrainval.write(name)
97 | if i in train:
98 | ftrain.write(name)
99 | else:
100 | fval.write(name)
101 | else:
102 | ftest.write(name)
103 |
104 | ftrainval.close()
105 | ftrain.close()
106 | fval.close()
107 | ftest.close()
108 | print("Generate txt in ImageSets done.")
109 |
110 | if annotation_mode == 0 or annotation_mode == 2:
111 | print("Generate 2007_train.txt and 2007_val.txt for train.")
112 | type_index = 0
113 | for year, image_set in VOCdevkit_sets:
114 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
115 | list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
116 | for image_id in image_ids:
117 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
118 |
119 | convert_annotation(year, image_id, list_file)
120 | list_file.write('\n')
121 | photo_nums[type_index] = len(image_ids)
122 | type_index += 1
123 | list_file.close()
124 | print("Generate 2007_train.txt and 2007_val.txt for train done.")
125 |
126 | def printTable(List1, List2):
127 | for i in range(len(List1[0])):
128 | print("|", end=' ')
129 | for j in range(len(List1)):
130 | print(List1[j][i].rjust(int(List2[j])), end=' ')
131 | print("|", end=' ')
132 | print()
133 |
134 | str_nums = [str(int(x)) for x in nums]
135 | tableData = [
136 | classes, str_nums
137 | ]
138 | colWidths = [0]*len(tableData)
139 | len1 = 0
140 | for i in range(len(tableData)):
141 | for j in range(len(tableData[i])):
142 | if len(tableData[i][j]) > colWidths[i]:
143 | colWidths[i] = len(tableData[i][j])
144 | printTable(tableData, colWidths)
145 |
146 | if photo_nums[0] <= 500:
147 | print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")
148 |
149 | if np.sum(nums) == 0:
150 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
151 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
152 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
153 | print("(重要的事情说三遍)。")
154 |
--------------------------------------------------------------------------------
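Each line that CenterNet/voc_annotation.py above writes to 2007_train.txt / 2007_val.txt is an absolute image path followed by space-separated boxes of the form `xmin,ymin,xmax,ymax,class_id`. A minimal sketch of a parser for that format; the example line is made up for illustration:

```python
# Illustrative annotation line in the format produced by voc_annotation.py:
# <image path> <xmin,ymin,xmax,ymax,cls_id> <xmin,ymin,xmax,ymax,cls_id> ...
line = "/data/VOCdevkit/VOC2007/JPEGImages/000001.jpg 48,240,195,371,0 8,12,352,498,1"

parts = line.split()
image_path = parts[0]
boxes = [list(map(int, box.split(","))) for box in parts[1:]]
print(image_path)
print(boxes)   # [[48, 240, 195, 371, 0], [8, 12, 352, 498, 1]]
```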
/README.md:
--------------------------------------------------------------------------------
1 | # ProbDet: Multimodal Object Detection via Probabilistic Decision Fusion
2 |
3 | ---
4 |
5 | ## Contents
6 | 1. [Top News](#top-news)
7 | 2. [How2train: Training Steps](#training-steps-detector)
8 | 3. [How2predict: Prediction Steps](#prediction-steps-detector)
9 | 4. [How2eval: Evaluation Steps](#evaluation-steps-detector)
10 | 5. [How2fusion: Fusion Steps](#fusion-steps-proben)
11 | 6. [Reference](#reference)
12 |
13 | ## Top News
14 | **`2023-01-10`**: **Repository created; supports two detectors, yolov7 and CenterNet.**
15 | **`2023-01-14`**: **Added an improved yolov7 based on scale-invariant features and an attention mechanism.**
16 |
17 | ## Environment
18 | torch>=1.2.0
19 | To use AMP mixed precision, torch 1.7.1 or newer is recommended.
20 |
21 |
22 | ## Training Steps (Detector)
23 | ### a. Training on the VOC07+12 dataset
24 | 1. Dataset preparation
25 | **Training uses the VOC format. Download the VOC07+12 dataset and unpack it into the project root before training.**
26 |
27 | 2. Dataset processing
28 | Set annotation_mode=2 in voc_annotation.py and run it to generate 2007_train.txt and 2007_val.txt in the project root.
29 |
30 | 3. Start training
31 | The default parameters of train.py are set for the VOC dataset; simply run train.py to start training.
32 |
33 | 4. Predicting with the trained model
34 | Prediction uses two files, yolo.py and predict.py. First change model_path and classes_path in yolo.py; these two parameters must be modified.
35 | **model_path points to the trained weight file in the logs folder.
36 | classes_path points to the txt listing the detection classes.**
37 | After editing, run predict.py and enter an image path to run detection.
38 |
39 | ### b. Training on your own dataset
40 | 1. Dataset preparation
41 | **Training uses the VOC format, so prepare your dataset accordingly before training.**
42 | Place the annotation files in the Annotation folder under VOCdevkit/VOC2007.
43 | Place the image files in the JPEGImages folder under VOCdevkit/VOC2007.
44 |
45 | 2. Dataset processing
46 | After placing the data, run voc_annotation.py to generate the 2007_train.txt and 2007_val.txt used for training.
47 | Edit the parameters in voc_annotation.py. For a first run, only classes_path needs changing; it points to the txt listing the detection classes.
48 | When training on your own data, create a cls_classes.txt containing the classes you want to distinguish.
49 | The content of model_data/cls_classes.txt, for example:
50 | ```python
51 | cat
52 | dog
53 | ...
54 | ```
55 | Point classes_path in voc_annotation.py at cls_classes.txt and run voc_annotation.py.
56 |
57 | 3. Start training
58 | **There are many training parameters, all in train.py; read the comments carefully after downloading the repository. The most important one is still classes_path in train.py.**
59 | **classes_path points to the txt listing the detection classes, the same txt used by voc_annotation.py. It must be changed when training on your own dataset!**
60 | After editing classes_path, run train.py; after several epochs the weights are written to the logs folder.
61 |
62 | 4. Predicting with the trained model
63 | Prediction uses two files, yolo.py and predict.py. Change model_path and classes_path in yolo.py.
64 | **model_path points to the trained weight file in the logs folder.
65 | classes_path points to the txt listing the detection classes.**
66 | After editing, run predict.py and enter an image path to run detection.
67 |
68 | ## Prediction Steps (Detector)
69 | ### a. Using pretrained weights
70 | 1. After downloading and unpacking the repository, download the weights from Baidu Netdisk, place them in model_data, run predict.py, and enter
71 | ```python
72 | img/street.jpg
73 | ```
74 | 2. The settings inside predict.py also allow FPS testing and video detection.
75 | ### b. Using your own trained weights
76 | 1. Train the model following the training steps.
77 | 2. In yolo.py, change model_path and classes_path in the block below so that they match your trained files; **model_path points to the weight file under the logs folder, and classes_path is the class list that model_path was trained with**.
78 | ```python
79 | _defaults = {
80 |     #--------------------------------------------------------------------------#
81 |     #   To predict with your own trained model, model_path and classes_path must be changed!
82 |     #   model_path points to the weight file under logs; classes_path points to the txt under model_data
83 |     #
84 |     #   After training, logs contains several weight files; pick one with a low validation loss.
85 |     #   A low validation loss does not guarantee a high mAP; it only means those weights generalize well on the validation set.
86 |     #   If a shape mismatch occurs, also check the model_path and classes_path used during training.
87 |     #--------------------------------------------------------------------------#
88 |     "model_path"        : 'model_data/yolov7_weights.pth',
89 |     "classes_path"      : 'model_data/coco_classes.txt',
90 |     #---------------------------------------------------------------------#
91 |     #   anchors_path is the txt holding the anchor priors; it is normally not changed.
92 |     #   anchors_mask helps the code find the matching anchors; it is normally not changed.
93 |     #---------------------------------------------------------------------#
94 |     "anchors_path"      : 'model_data/yolo_anchors.txt',
95 |     "anchors_mask"      : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
96 |     #---------------------------------------------------------------------#
97 |     #   Input image size; it must be a multiple of 32.
98 |     #---------------------------------------------------------------------#
99 |     "input_shape"       : [640, 640],
100 |     #------------------------------------------------------#
101 |     #   The yolov7 version to use; this repository provides two:
102 |     #   l : yolov7
103 |     #   x : yolov7_x
104 |     #------------------------------------------------------#
105 |     "phi"               : 'l',
106 |     #---------------------------------------------------------------------#
107 |     #   Only predictions whose score exceeds this confidence are kept.
108 |     #---------------------------------------------------------------------#
109 |     "confidence"        : 0.5,
110 |     #---------------------------------------------------------------------#
111 |     #   The nms_iou threshold used for non-maximum suppression.
112 |     #---------------------------------------------------------------------#
113 |     "nms_iou"           : 0.3,
114 |     #---------------------------------------------------------------------#
115 |     #   Whether to use letterbox_image for distortion-free resizing of the input;
116 |     #   repeated tests suggest that plain resizing without letterbox_image works better.
117 |     #---------------------------------------------------------------------#
118 |     "letterbox_image"   : True,
119 |     #-------------------------------#
120 |     #   Whether to use CUDA.
121 |     #   Set this to False if there is no GPU.
122 |     #-------------------------------#
123 |     "cuda"              : True,
124 | }
125 | ```
126 | 3. Run predict.py and enter
127 | ```python
128 | img/street.jpg
129 | ```
130 | 4. The settings inside predict.py also allow FPS testing and video detection.
131 |
132 | ## Evaluation Steps (Detector)
133 | ### a. Evaluating the VOC07+12 test set
134 | 1. Evaluation uses the VOC format. VOC07+12 already provides a test split, so there is no need to run voc_annotation.py to generate the txt files under ImageSets.
135 | 2. In yolo.py, change model_path and classes_path. **model_path points to the trained weight file in the logs folder; classes_path points to the txt listing the detection classes.**
136 | 3. Run get_map.py; the evaluation results are saved in the map_out folder.
137 |
138 | ### b. Evaluating your own dataset
139 | 1. Evaluation uses the VOC format.
140 | 2. If voc_annotation.py was already run before training, the dataset has been split into training, validation and test sets automatically. To change the test-set ratio, edit trainval_percent in voc_annotation.py: trainval_percent sets the (train+val):test ratio, 9:1 by default, and train_percent sets the train:val ratio within (train+val), also 9:1 by default.
141 | 3. After splitting the test set with voc_annotation.py, change classes_path in get_map.py; it points to the same class txt used for training and must be changed when evaluating your own dataset.
142 | 4. In yolo.py, change model_path and classes_path. **model_path points to the trained weight file in the logs folder; classes_path points to the txt listing the detection classes.**
143 | 5. Run get_map.py; the evaluation results are saved in the map_out folder.
144 |
145 | ## Fusion Steps (ProbEn)
146 | 1. Configure both detectors following the detector prediction steps above
147 | 2. Choose the mode in predict_with_probEn.py: image, video or directory
148 | 3. Change the corresponding paths; see the comments for details
149 | 4. Run predict_with_probEn.py (a minimal sketch of the score-fusion idea is given below)
150 |
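ProbEn.py implements the fusion used by predict_with_probEn.py. Purely as an illustration of the probabilistic-ensembling idea the project is named after, and not necessarily the exact formulas in ProbEn.py, two detectors' confidences for the same object can be combined under an independence assumption with a uniform prior, and overlapping boxes merged by a confidence-weighted average:

```python
import numpy as np

def fuse_scores(s1, s2):
    # Bayesian fusion of two detection confidences, assuming the detectors
    # are conditionally independent and the class prior is uniform.
    pos = s1 * s2
    neg = (1.0 - s1) * (1.0 - s2)
    return pos / (pos + neg)

def fuse_boxes(box1, box2, s1, s2):
    # Confidence-weighted average of two overlapping boxes [xmin, ymin, xmax, ymax].
    box1, box2 = np.asarray(box1, float), np.asarray(box2, float)
    return (s1 * box1 + s2 * box2) / (s1 + s2)

print(fuse_scores(0.6, 0.7))                                   # ~0.78: agreement raises the score
print(fuse_boxes([100, 80, 180, 300], [104, 84, 178, 296], 0.6, 0.7))
```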
151 | ## Reference
152 | https://github.com/WongKinYiu/yolov7
153 |
154 | https://github.com/bubbliiiing/yolov7-pytorch
155 |
156 | https://github.com/bubbliiiing/centernet-pytorch
157 |
158 | https://github.com/Jamie725/Multimodal-Object-Detection-via-Probabilistic-Ensembling
159 |
--------------------------------------------------------------------------------
/detector_test.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 | from yolov7.yolo import YOLO
8 | from CenterNet.centernet import CenterNet
9 |
10 | if __name__ == '__main__':
11 | centernet = CenterNet()
12 | yolo = YOLO()
13 | crop = False
14 | count = False
15 |
16 | # img = input('Input image filename:')
17 | img = 'img/1.jpg'
18 | try:
19 | image1 = Image.open(img)
20 | image2 = Image.open(img)
21 | except:
22 | print('Open Error! Try again!')
23 | else:
24 | print("-----------------")
25 | print("yolov7:")
26 | #r_image = yolo.detect_image(image1, crop=crop, count=count)
27 | #r_image.show()
28 | # tact_time = yolo.get_FPS(image1, test_interval=100)
29 | # print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
30 | dets_yolo, scores_yolo = yolo.detect_image_dets(image1)
31 | t1 = time.time()
32 | for _ in range(100):
33 | dets_yolo, scores_yolo = yolo.detect_image_dets(image1)
34 | t2 = time.time()
35 | tact_time = (t2 - t1) / 100
36 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + ' FPS, @batch_size 1')
37 | # print(dets_yolo)
38 | # print(scores_yolo)
39 |
40 | print("-----------------")
41 | print("centernet:")
42 | # r_image2 = centernet.detect_image(image2, crop = crop, count=count)
43 | # r_image2.show()
44 | # tact_time = centernet.get_FPS(image2, test_interval=100)
45 | # print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
46 | dets_centernet, scores_centernet = centernet.detect_image_dets(image2)
47 | t1 = time.time()
48 | for _ in range(100):
49 | dets_centernet, scores_centernet = centernet.detect_image_dets(image2)
50 | t2 = time.time()
51 | tact_time = (t2 - t1) / 100
52 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + ' FPS, @batch_size 1')
53 | # print(dets_centernet)
54 | #print(scores_centernet)
--------------------------------------------------------------------------------
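The two timing blocks in detector_test.py above follow the same pattern: one warm-up call, then the average wall-clock time over 100 calls. A small sketch of that pattern factored into a reusable helper; the usage lines are commented out because they need the detectors loaded by the script:

```python
import time

def benchmark(fn, *args, warmup=1, runs=100):
    """Average wall-clock seconds per call, mirroring the 100-run loops above."""
    for _ in range(warmup):          # one untimed call to absorb lazy initialisation
        fn(*args)
    t1 = time.time()
    for _ in range(runs):
        fn(*args)
    return (time.time() - t1) / runs

# Usage with either detector (names as in detector_test.py):
#   tact_time = benchmark(yolo.detect_image_dets, image1)
#   print(f"{tact_time:.4f} seconds, {1 / tact_time:.2f} FPS, @batch_size 1")
```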
/get_fusion_FNPI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 | import numpy as np
7 |
8 | from yolov7.utils.utils import get_classes
9 | from yolov7.utils.utils_FNPI import get_FNPI
10 | from CenterNet.centernet import CenterNet
11 | from ProbEn import ProbEn
12 | from yolov7.yolo import YOLO
13 |
14 | if __name__ == "__main__":
15 | #------------------------------------------------------------------------------------------------------------------#
16 | # map_mode用于指定该文件运行时计算的内容
17 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。
18 | # map_mode为1代表仅仅获得预测结果。
19 | # map_mode为2代表仅仅获得真实框。
20 | # map_mode为3代表仅仅计算FNPI。
21 | #-------------------------------------------------------------------------------------------------------------------#
22 | map_mode = 0
23 | #--------------------------------------------------------------------------------------#
24 | # 此处的classes_path用于指定需要测量FNPI的类别
25 | # 一般情况下与训练和预测所用的classes_path一致即可
26 | #--------------------------------------------------------------------------------------#
27 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\voc_classes.txt'
28 | #--------------------------------------------------------------------------------------#
29 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件
30 | # 只有大于FNPI_IOU值才算检测成功
31 | #--------------------------------------------------------------------------------------#
32 | FNPI_IOU = 0.5
33 | #--------------------------------------------------------------------------------------#
34 | # confidence的设置与计算map时的设置情况不一样。
35 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。
36 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来
37 | #--------------------------------------------------------------------------------------#
38 | confidence_yolo = 0.5
39 | confidence_centernet = 0.5
40 | #--------------------------------------------------------------------------------------#
41 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
42 | # 该值也应该与预测时设置的nms_iou一致。
43 | #--------------------------------------------------------------------------------------#
44 | nms_iou = 0.3
45 | #-------------------------------------------------------#
46 | # 指向VOC数据集所在的文件夹
47 | # 默认指向根目录下的VOC数据集
48 | #-------------------------------------------------------#
49 | VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit'
50 | #-------------------------------------------------------#
51 | # 结果输出的文件夹,默认为map_out
52 | #-------------------------------------------------------#
53 | FNPI_out_path = 'FNPI_out/FNPI_out_ProbEn_VOC4'
54 |
55 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
56 |
57 | if not os.path.exists(FNPI_out_path):
58 | os.makedirs(FNPI_out_path)
59 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')):
60 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth'))
61 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')):
62 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results'))
63 |
64 | class_names, _ = get_classes(classes_path)
65 |
66 | if map_mode == 0 or map_mode == 1:
67 | print("Load model.")
68 | yolo = YOLO(confidence=confidence_yolo, nms_iou=nms_iou)
69 | centernet = CenterNet(confidence=confidence_centernet, nms_iou=nms_iou)
70 | proben = ProbEn()
71 | print("Load model done.")
72 |
73 | print("Get predict result.")
74 | for image_id in tqdm(image_ids):
75 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg")
76 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
77 | image = Image.open(image_path)
78 |
79 | dets_yolo, scores_yolo = yolo.detect_image_dets(image)
80 | dets_yolo = np.asarray(dets_yolo)
81 | scores_yolo = np.asarray(scores_yolo)
82 |
83 | dets_centernet, scores_centernet = centernet.detect_image_dets(image)
84 | dets_centernet = np.asarray(dets_centernet)
85 | scores_centernet = np.asarray(scores_centernet)
86 |
87 | proben.get_map_txt(image_id, class_names, FNPI_out_path, dets_yolo, scores_yolo, dets_centernet, scores_centernet)
88 | print("Get predict result done.")
89 |
90 | if map_mode == 0 or map_mode == 2:
91 | print("Get ground truth result.")
92 | for image_id in tqdm(image_ids):
93 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
94 | # root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot()
95 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
96 | for obj in root.findall('object'):
97 | difficult_flag = False
98 | if obj.find('difficult')!=None:
99 | difficult = obj.find('difficult').text
100 | if int(difficult)==1:
101 | difficult_flag = True
102 | obj_name = obj.find('name').text
103 | if obj_name not in class_names:
104 | continue
105 | bndbox = obj.find('bndbox')
106 | left = bndbox.find('xmin').text
107 | top = bndbox.find('ymin').text
108 | right = bndbox.find('xmax').text
109 | bottom = bndbox.find('ymax').text
110 |
111 | if difficult_flag:
112 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
113 | else:
114 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
115 | print("Get ground truth result done.")
116 |
117 | if map_mode == 0 or map_mode == 3:
118 | print("Get map.")
119 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path)
120 | print("Get map done.")
--------------------------------------------------------------------------------
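In get_fusion_FNPI.py above, a ground-truth box counts as detected only if some fused prediction overlaps it with IoU above FNPI_IOU; anything left unmatched is a missed object (false negative). The repository's metric lives in yolov7/utils/utils_FNPI.py; the snippet below is only a self-contained illustration of that matching rule:

```python
import numpy as np

def iou(a, b):
    """IoU of two boxes given as [xmin, ymin, xmax, ymax]."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    return inter / (area(a) + area(b) - inter + 1e-9)

def missed_ground_truths(gt_boxes, det_boxes, iou_thr=0.5):
    """Ground-truth boxes with no detection above the IoU threshold (i.e. false negatives)."""
    return [g for g in gt_boxes if all(iou(g, d) < iou_thr for d in det_boxes)]

gts  = [[10, 10, 60, 120], [200, 40, 260, 180]]
dets = [[12, 14, 58, 118]]                      # only the first object was found
print(missed_ground_truths(gts, dets))          # -> [[200, 40, 260, 180]]
```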
/get_fusion_FNPI_onlyYolo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 | import numpy as np
7 |
8 | from yolov7.utils.utils import get_classes
9 | from yolov7.utils.utils_FNPI import get_FNPI
10 | from ProbEn import ProbEn
11 | from yolov7.yolo_RGB import YOLO_RGB
12 | from yolov7.yolo_T import YOLO_T
13 |
14 | if __name__ == "__main__":
15 | #------------------------------------------------------------------------------------------------------------------#
16 | # map_mode用于指定该文件运行时计算的内容
17 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。
18 | # map_mode为1代表仅仅获得预测结果。
19 | # map_mode为2代表仅仅获得真实框。
20 | # map_mode为3代表仅仅计算FNPI。
21 | #-------------------------------------------------------------------------------------------------------------------#
22 | map_mode = 3
23 | #--------------------------------------------------------------------------------------#
24 | # 此处的classes_path用于指定需要测量FNPI的类别
25 | # 一般情况下与训练和预测所用的classes_path一致即可
26 | #--------------------------------------------------------------------------------------#
27 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\people_classes_KAIST.txt'
28 | #--------------------------------------------------------------------------------------#
29 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件
30 | # 只有大于FNPI_IOU值才算检测成功
31 | #--------------------------------------------------------------------------------------#
32 | FNPI_IOU = 0.5
33 | #--------------------------------------------------------------------------------------#
34 | # confidence的设置与计算map时的设置情况不一样。
35 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。
36 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来
37 | #--------------------------------------------------------------------------------------#
38 | confidence_RGB = 0.5
39 | confidence_T = 0.5
40 | #--------------------------------------------------------------------------------------#
41 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
42 | # 该值也应该与预测时设置的nms_iou一致。
43 | #--------------------------------------------------------------------------------------#
44 | nms_iou = 0.3
45 | #-------------------------------------------------------#
46 | # 指向VOC数据集所在的文件夹
47 | # 默认指向根目录下的VOC数据集
48 | #-------------------------------------------------------#
49 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist'
50 | #-------------------------------------------------------#
51 | # 结果输出的文件夹,默认为map_out
52 | #-------------------------------------------------------#
53 | FNPI_out_path = 'FNPI_out/FNPI_out_ProbEn_YOLO'
54 |
55 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split()
56 |
57 | if not os.path.exists(FNPI_out_path):
58 | os.makedirs(FNPI_out_path)
59 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')):
60 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth'))
61 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')):
62 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results'))
63 |
64 | class_names, _ = get_classes(classes_path)
65 |
66 | if map_mode == 0 or map_mode == 1:
67 | print("Load model.")
68 | yolo_rgb = YOLO_RGB(confidence=confidence_RGB, nms_iou=nms_iou)
69 | yolo_T = YOLO_T(confidence=confidence_T, nms_iou=nms_iou)
70 | proben = ProbEn()
71 | print("Load model done.")
72 |
73 | print("Get predict result.")
74 | for image_id in tqdm(image_ids):
75 | image_RGB_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg")
76 | image_T_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/" + image_id + ".jpg")
77 |
78 | image_rgb = Image.open(image_RGB_path)
79 | image_T = Image.open(image_T_path)
80 |
81 | dets_rgb, scores_rgb = yolo_rgb.detect_image_dets(image_rgb)
82 | dets_rgb = np.asarray(dets_rgb)
83 | scores_rgb = np.asarray(scores_rgb)
84 |
85 | dets_T, scores_T = yolo_T.detect_image_dets(image_T)
86 | dets_T = np.asarray(dets_T)
87 | scores_T = np.asarray(scores_T)
88 |
89 | proben.get_map_txt(image_id, class_names, FNPI_out_path, dets_rgb, scores_rgb, dets_T, scores_T)
90 | print("Get predict result done.")
91 |
92 | if map_mode == 0 or map_mode == 2:
93 | print("Get ground truth result.")
94 | for image_id in tqdm(image_ids):
95 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
96 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot()
97 | for obj in root.findall('object'):
98 | difficult_flag = False
99 | if obj.find('difficult')!=None:
100 | difficult = obj.find('difficult').text
101 | if int(difficult)==1:
102 | difficult_flag = True
103 | obj_name = obj.find('name').text
104 | if obj_name not in class_names:
105 | continue
106 | bndbox = obj.find('bndbox')
107 | left = bndbox.find('xmin').text
108 | top = bndbox.find('ymin').text
109 | right = bndbox.find('xmax').text
110 | bottom = bndbox.find('ymax').text
111 |
112 | if difficult_flag:
113 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
114 | else:
115 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
116 | print("Get ground truth result done.")
117 |
118 | if map_mode == 0 or map_mode == 3:
119 | print("Get map.")
120 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path)
121 | print("Get map done.")
--------------------------------------------------------------------------------
/get_fusion_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 | import numpy as np
7 | from yolov7.utils.utils import get_classes
8 | from yolov7.utils.utils_map import get_coco_map, get_map
9 | from CenterNet.centernet import CenterNet
10 | from ProbEn import ProbEn
11 | from yolov7.yolo import YOLO
12 |
13 | if __name__ == "__main__":
14 | '''
15 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。
16 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。
17 |
18 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值
19 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框,
20 | '''
21 | # ------------------------------------------------------------------------------------------------------------------#
22 | # map_mode用于指定该文件运行时计算的内容
23 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。
24 | # map_mode为1代表仅仅获得预测结果。
25 | # map_mode为2代表仅仅获得真实框。
26 | # map_mode为3代表仅仅计算VOC_map。
27 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
28 | # -------------------------------------------------------------------------------------------------------------------#
29 | map_mode = 0
30 | # --------------------------------------------------------------------------------------#
31 | # 此处的classes_path用于指定需要测量VOC_map的类别
32 | # 一般情况下与训练和预测所用的classes_path一致即可
33 | # --------------------------------------------------------------------------------------#
34 | classes_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\model_data\voc_classes.txt'
35 | # --------------------------------------------------------------------------------------#
36 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
37 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
38 | #
39 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。
40 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低,
41 | # --------------------------------------------------------------------------------------#
42 | MINOVERLAP = 0.5
43 | # --------------------------------------------------------------------------------------#
44 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP
45 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。
46 | #
47 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。
48 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。
49 | # --------------------------------------------------------------------------------------#
50 | confidence_yolo = 0.001
51 | confidence_centernet = 0.04
52 | # --------------------------------------------------------------------------------------#
53 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
54 | #
55 | # 该值一般不调整。
56 | # --------------------------------------------------------------------------------------#
57 | nms_iou = 0.5
58 | # ---------------------------------------------------------------------------------------------------------------#
59 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。
60 | #
61 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。
62 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。
63 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。
64 | # ---------------------------------------------------------------------------------------------------------------#
65 | score_threhold = 0.5
66 | # -------------------------------------------------------#
67 | # map_vis用于指定是否开启VOC_map计算的可视化
68 | # -------------------------------------------------------#
69 | map_vis = False
70 | # -------------------------------------------------------#
71 | # 指向VOC数据集所在的文件夹
72 | # 默认指向根目录下的VOC数据集
73 | # -------------------------------------------------------#
74 | VOCdevkit_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\VOCdevkit'
75 | # -------------------------------------------------------#
76 | # 结果输出的文件夹,默认为map_out
77 | # -------------------------------------------------------#
78 | map_out_path = 'map_out'
79 |
80 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
81 |
82 | if not os.path.exists(map_out_path):
83 | os.makedirs(map_out_path)
84 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
85 | os.makedirs(os.path.join(map_out_path, 'ground-truth'))
86 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
87 | os.makedirs(os.path.join(map_out_path, 'detection-results'))
88 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
89 | os.makedirs(os.path.join(map_out_path, 'images-optional'))
90 |
91 | class_names, _ = get_classes(classes_path)
92 |
93 | if map_mode == 0 or map_mode == 1:
94 | print("Load model.")
95 | yolo = YOLO(confidence=confidence_yolo, nms_iou=nms_iou)
96 | centernet = CenterNet(confidence=confidence_centernet, nms_iou=nms_iou)
97 | proben = ProbEn()
98 | print("Load model done.")
99 |
100 | print("Get predict result.")
101 | for image_id in tqdm(image_ids):
102 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
103 | image = Image.open(image_path)
104 | if map_vis:
105 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
106 |
107 | dets_yolo, scores_yolo = yolo.detect_image_dets(image)
108 | dets_yolo = np.asarray(dets_yolo)
109 | scores_yolo = np.asarray(scores_yolo)
110 |
111 | dets_centernet, scores_centernet = centernet.detect_image_dets(image)
112 | dets_centernet = np.asarray(dets_centernet)
113 | scores_centernet = np.asarray(scores_centernet)
114 |
115 | proben.get_map_txt(image_id, class_names, map_out_path, dets_yolo, scores_yolo, dets_centernet, scores_centernet)
116 | print("Get predict result done.")
117 |
118 | if map_mode == 0 or map_mode == 2:
119 | print("Get ground truth result.")
120 | for image_id in tqdm(image_ids):
121 | with open(os.path.join(map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
122 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/" + image_id + ".xml")).getroot()
123 | for obj in root.findall('object'):
124 | difficult_flag = False
125 | if obj.find('difficult') != None:
126 | difficult = obj.find('difficult').text
127 | if int(difficult) == 1:
128 | difficult_flag = True
129 | obj_name = obj.find('name').text
130 | if obj_name not in class_names:
131 | continue
132 | bndbox = obj.find('bndbox')
133 | left = bndbox.find('xmin').text
134 | top = bndbox.find('ymin').text
135 | right = bndbox.find('xmax').text
136 | bottom = bndbox.find('ymax').text
137 |
138 | if difficult_flag:
139 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
140 | else:
141 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
142 | print("Get ground truth result done.")
143 |
144 | if map_mode == 0 or map_mode == 3:
145 | print("Get map.")
146 | get_map(MINOVERLAP, True, score_threhold=score_threhold, path=map_out_path)
147 | print("Get map done.")
148 |
149 | if map_mode == 4:
150 | print("Get map.")
151 | get_coco_map(class_names=class_names, path=map_out_path)
152 | print("Get map done.")
153 |
--------------------------------------------------------------------------------
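In get_fusion_map.py above, the confidence thresholds are set near zero because mAP integrates precision over the whole recall range, which requires keeping almost every prediction; score_threhold then selects the single operating point at which Recall and Precision are reported. A simplified, self-contained sketch of how average precision is computed from score-ranked detections (the repository's full implementation is yolov7/utils/utils_map.py):

```python
import numpy as np

def voc_ap(scores, is_true_positive, num_ground_truth):
    """Average precision from score-ranked detections (monotone P-R envelope, step integration)."""
    order = np.argsort(-np.asarray(scores, dtype=float))
    tp = np.asarray(is_true_positive, dtype=float)[order]
    fp = 1.0 - tp
    recall = np.cumsum(tp) / max(num_ground_truth, 1)
    precision = np.cumsum(tp) / np.maximum(np.cumsum(tp) + np.cumsum(fp), 1e-9)

    mrec = np.concatenate(([0.0], recall, [recall[-1]]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(len(mpre) - 2, -1, -1):          # make precision non-increasing in recall
        mpre[i] = max(mpre[i], mpre[i + 1])
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))

# Low-confidence detections matter: dropping them would truncate the recall axis.
scores = [0.95, 0.90, 0.60, 0.30, 0.05]
tps    = [1,    1,    0,    1,    1]     # whether each detection matched a ground-truth box
print(voc_ap(scores, tps, num_ground_truth=5))
```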
/get_fusion_map_onlyYolo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 | import numpy as np
7 | from yolov7.utils.utils import get_classes
8 | from yolov7.utils.utils_map import get_coco_map, get_map
9 |
10 | from ProbEn import ProbEn
11 | from yolov7.yolo_RGB import YOLO_RGB
12 | from yolov7.yolo_T import YOLO_T
13 |
14 | if __name__ == "__main__":
15 | '''
16 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。
17 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。
18 |
19 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值
20 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框,
21 | '''
22 | # ------------------------------------------------------------------------------------------------------------------#
23 | # map_mode用于指定该文件运行时计算的内容
24 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。
25 | # map_mode为1代表仅仅获得预测结果。
26 | # map_mode为2代表仅仅获得真实框。
27 | # map_mode为3代表仅仅计算VOC_map。
28 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
29 | # -------------------------------------------------------------------------------------------------------------------#
30 | map_mode = 0
31 | # --------------------------------------------------------------------------------------#
32 | # 此处的classes_path用于指定需要测量VOC_map的类别
33 | # 一般情况下与训练和预测所用的classes_path一致即可
34 | # --------------------------------------------------------------------------------------#
35 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\people_classes_KAIST.txt'
36 | # --------------------------------------------------------------------------------------#
37 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
38 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
39 | #
40 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。
41 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低,
42 | # --------------------------------------------------------------------------------------#
43 | MINOVERLAP = 0.5
44 | # --------------------------------------------------------------------------------------#
45 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP
46 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。
47 | #
48 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。
49 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。
50 | # --------------------------------------------------------------------------------------#
51 | confidence_RGB = 0.001
52 | confidence_T = 0.001
53 | # --------------------------------------------------------------------------------------#
54 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
55 | #
56 | # 该值一般不调整。
57 | # --------------------------------------------------------------------------------------#
58 | nms_iou = 0.5
59 | # ---------------------------------------------------------------------------------------------------------------#
60 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。
61 | #
62 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。
63 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。
64 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。
65 | # ---------------------------------------------------------------------------------------------------------------#
66 | score_threhold = 0.5
67 | # -------------------------------------------------------#
68 | # map_vis用于指定是否开启VOC_map计算的可视化
69 | # -------------------------------------------------------#
70 | map_vis = False
71 | # -------------------------------------------------------#
72 | # 指向VOC数据集所在的文件夹
73 | # 默认指向根目录下的VOC数据集
74 | # -------------------------------------------------------#
75 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist'
76 | # -------------------------------------------------------#
77 | # 结果输出的文件夹,默认为map_out
78 | # -------------------------------------------------------#
79 | map_out_path = 'map_out_ProbEn_YOLO1'
80 |
81 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split()
82 |
83 | if not os.path.exists(map_out_path):
84 | os.makedirs(map_out_path)
85 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
86 | os.makedirs(os.path.join(map_out_path, 'ground-truth'))
87 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
88 | os.makedirs(os.path.join(map_out_path, 'detection-results'))
89 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
90 | os.makedirs(os.path.join(map_out_path, 'images-optional'))
91 |
92 | class_names, _ = get_classes(classes_path)
93 |
94 | if map_mode == 0 or map_mode == 1:
95 | print("Load model.")
96 | yolo_rgb = YOLO_RGB(confidence=confidence_RGB, nms_iou=nms_iou)
97 | yolo_T = YOLO_T(confidence=confidence_T, nms_iou=nms_iou)
98 | proben = ProbEn()
99 | print("Load model done.")
100 |
101 | print("Get predict result.")
102 | for image_id in tqdm(image_ids):
103 | image_RGB_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg")
104 | image_T_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/" + image_id + ".jpg")
105 |
106 | image_rgb = Image.open(image_RGB_path)
107 | image_T = Image.open(image_T_path)
108 |
109 | if map_vis:
110 | image_rgb.save(os.path.join(map_out_path, "images-optional/rgb/" + image_id + ".jpg"))
111 | image_T.save(os.path.join(map_out_path, "images-optional/T/" + image_id + ".jpg"))
112 |
113 | dets_rgb, scores_rgb = yolo_rgb.detect_image_dets(image_rgb)
114 | dets_rgb = np.asarray(dets_rgb)
115 | scores_rgb = np.asarray(scores_rgb)
116 |
117 | dets_T, scores_T = yolo_T.detect_image_dets(image_T)
118 | dets_T = np.asarray(dets_T)
119 | scores_T = np.asarray(scores_T)
120 |
121 | proben.get_map_txt(image_id, class_names, map_out_path, dets_rgb, scores_rgb, dets_T, scores_T)
122 | print("Get predict result done.")
123 |
124 | if map_mode == 0 or map_mode == 2:
125 | print("Get ground truth result.")
126 | for image_id in tqdm(image_ids):
127 | with open(os.path.join(map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
128 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/" + image_id + ".xml")).getroot()
129 | for obj in root.findall('object'):
130 | difficult_flag = False
131 | if obj.find('difficult') != None:
132 | difficult = obj.find('difficult').text
133 | if int(difficult) == 1:
134 | difficult_flag = True
135 | obj_name = obj.find('name').text
136 | if obj_name not in class_names:
137 | continue
138 | bndbox = obj.find('bndbox')
139 | left = bndbox.find('xmin').text
140 | top = bndbox.find('ymin').text
141 | right = bndbox.find('xmax').text
142 | bottom = bndbox.find('ymax').text
143 |
144 | if difficult_flag:
145 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
146 | else:
147 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
148 | print("Get ground truth result done.")
149 |
150 | if map_mode == 0 or map_mode == 3:
151 | print("Get map.")
152 | get_map(MINOVERLAP, True, score_threhold=score_threhold, path=map_out_path)
153 | print("Get map done.")
154 |
155 | if map_mode == 4:
156 | print("Get map.")
157 | get_coco_map(class_names=class_names, path=map_out_path)
158 | print("Get map done.")
159 |
--------------------------------------------------------------------------------
/img/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/1.jpg
--------------------------------------------------------------------------------
/img/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/2.jpg
--------------------------------------------------------------------------------
/img/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/street.jpg
--------------------------------------------------------------------------------
/img_out/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/1.png
--------------------------------------------------------------------------------
/img_out/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/2.png
--------------------------------------------------------------------------------
/img_out/street.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/street.png
--------------------------------------------------------------------------------
/yolov7/get_FNPI.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 | from utils.utils import get_classes
8 | from utils.utils_FNPI import get_FNPI
9 | from yolo import YOLO
10 |
11 | if __name__ == "__main__":
12 | #------------------------------------------------------------------------------------------------------------------#
13 | # map_mode用于指定该文件运行时计算的内容
14 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。
15 | # map_mode为1代表仅仅获得预测结果。
16 | # map_mode为2代表仅仅获得真实框。
17 | # map_mode为3代表仅仅计算FNPI。
18 | #-------------------------------------------------------------------------------------------------------------------#
19 | map_mode = 0
20 | #--------------------------------------------------------------------------------------#
21 | # 此处的classes_path用于指定需要测量FNPI的类别
22 | # 一般情况下与训练和预测所用的classes_path一致即可
23 | #--------------------------------------------------------------------------------------#
24 | classes_path = 'model_data/people_classes_KAIST.txt'
25 | # classes_path = 'model_data/voc_classes.txt'
26 | #--------------------------------------------------------------------------------------#
27 | #   FNPI_IOU is the condition for a predicted box to match a ground-truth box (i.e. for the ground-truth target to count as detected).
28 | #   Only an overlap greater than FNPI_IOU counts as a successful detection.
29 | #--------------------------------------------------------------------------------------#
30 | FNPI_IOU = 0.5
31 | #--------------------------------------------------------------------------------------#
32 | #   confidence is set differently here than when computing mAP.
33 | #   Because of how mAP is computed, the network must keep nearly all predicted boxes, so the confidence used for mAP should be set as low as possible to collect every possible prediction.
34 | #   For FNPI, confidence should match the value used at prediction time: only boxes whose score exceeds it are kept.
35 | #--------------------------------------------------------------------------------------#
36 | confidence = 0.5
37 | #--------------------------------------------------------------------------------------#
38 | #   Non-maximum-suppression IoU used at prediction time; the larger it is, the less strict NMS becomes.
39 | #   This value should also match the nms_iou used at prediction time.
40 | #--------------------------------------------------------------------------------------#
41 | nms_iou = 0.3
42 | #-------------------------------------------------------#
43 | #   Points to the folder containing the VOC dataset.
44 | #   By default it points to the VOC dataset in the root directory.
45 | #-------------------------------------------------------#
46 | # VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit'
47 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist'
48 | #-------------------------------------------------------#
49 | #   Output folder for the results.
50 | #-------------------------------------------------------#
51 | FNPI_out_path = 'FNPI_out/FNPI_out_KAIST_RGB'
52 |
53 | # image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
54 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split()
55 |
56 | if not os.path.exists(FNPI_out_path):
57 | os.makedirs(FNPI_out_path)
58 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')):
59 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth'))
60 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')):
61 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results'))
62 |
63 | class_names, _ = get_classes(classes_path)
64 |
65 | if map_mode == 0 or map_mode == 1:
66 | print("Load model.")
67 | yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
68 | print("Load model done.")
69 |
70 | print("Get predict result.")
71 | for image_id in tqdm(image_ids):
72 | # image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
73 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/"+image_id+".jpg")
74 | image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/"+image_id+".jpg")
75 | image = Image.open(image_path)
76 | yolo.get_map_txt(image_id, image, class_names, FNPI_out_path)
77 | print("Get predict result done.")
78 |
79 | if map_mode == 0 or map_mode == 2:
80 | print("Get ground truth result.")
81 | for image_id in tqdm(image_ids):
82 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
83 | # root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
84 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot()
85 | for obj in root.findall('object'):
86 | difficult_flag = False
87 | if obj.find('difficult')!=None:
88 | difficult = obj.find('difficult').text
89 | if int(difficult)==1:
90 | difficult_flag = True
91 | obj_name = obj.find('name').text
92 | if obj_name not in class_names:
93 | continue
94 | bndbox = obj.find('bndbox')
95 | left = bndbox.find('xmin').text
96 | top = bndbox.find('ymin').text
97 | right = bndbox.find('xmax').text
98 | bottom = bndbox.find('ymax').text
99 |
100 | if difficult_flag:
101 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
102 | else:
103 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
104 | print("Get ground truth result done.")
105 |
106 | if map_mode == 0 or map_mode == 3:
107 | print("Get map.")
108 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path)
109 | print("Get map done.")
--------------------------------------------------------------------------------
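The matching rule in the comments above (a ground-truth box counts as detected only when some predicted box overlaps it with an IoU above FNPI_IOU) can be illustrated independently of utils_FNPI. A minimal sketch; the helper names iou_xyxy and false_negatives and the plain-tuple box format are illustrative assumptions, not the repository's actual API:

    def iou_xyxy(a, b):
        # a, b: (xmin, ymin, xmax, ymax)
        ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
        ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
        area_a = (a[2] - a[0]) * (a[3] - a[1])
        area_b = (b[2] - b[0]) * (b[3] - b[1])
        return inter / (area_a + area_b - inter + 1e-9)

    def false_negatives(gt_boxes, pred_boxes, iou_thresh=0.5):
        # count ground-truth boxes of one image/class that no prediction overlaps strongly enough
        return sum(
            1 for gt in gt_boxes
            if not any(iou_xyxy(gt, p) > iou_thresh for p in pred_boxes)
        )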
/yolov7/get_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import xml.etree.ElementTree as ET
3 |
4 | from PIL import Image
5 | from tqdm import tqdm
6 |
7 | from utils.utils import get_classes
8 | from utils.utils_map import get_coco_map, get_map
9 | from yolo import YOLO
10 |
11 | if __name__ == "__main__":
12 | '''
13 | Unlike AP, Recall and Precision are not area-based metrics, so their values change with the confidence threshold.
14 | By default, the Recall and Precision reported by this script are the values obtained at a confidence threshold of 0.5.
15 |
16 | Because of how mAP is computed, the network must keep nearly all predicted boxes so that Recall and Precision can be evaluated at every threshold.
17 | The txt files in map_out/detection-results/ therefore usually contain more boxes than a direct predict run: the goal is to list every possible prediction.
18 | '''
19 | #------------------------------------------------------------------------------------------------------------------#
20 | #   map_mode specifies what this script computes when it is run:
21 | #   map_mode = 0 runs the whole mAP pipeline: get prediction results, get ground-truth boxes, then compute the VOC mAP.
22 | #   map_mode = 1 only obtains the prediction results.
23 | #   map_mode = 2 only obtains the ground-truth boxes.
24 | #   map_mode = 3 only computes the VOC mAP.
25 | #   map_mode = 4 uses the COCO toolbox to compute the 0.50:0.95 mAP of the current dataset. Predictions and ground truth must be obtained first, and pycocotools must be installed.
26 | #-------------------------------------------------------------------------------------------------------------------#
27 | map_mode = 0
28 | #--------------------------------------------------------------------------------------#
29 | #   classes_path points to the classes for which the VOC mAP is measured.
30 | #   In general it should match the classes_path used for training and prediction.
31 | #--------------------------------------------------------------------------------------#
32 | classes_path = 'model_data/voc_classes.txt'
33 | #--------------------------------------------------------------------------------------#
34 | #   MINOVERLAP selects which mAP0.x to compute, i.e. the IoU threshold used when matching predictions to ground truth.
35 | #   For example, to compute mAP0.75, set MINOVERLAP = 0.75.
36 | #
37 | #   A predicted box counts as a positive sample when its overlap with a ground-truth box exceeds MINOVERLAP; otherwise it is a negative sample.
38 | #   The larger MINOVERLAP is, the more precisely a box must be localised to count as a positive, so the resulting mAP is lower.
39 | #--------------------------------------------------------------------------------------#
40 | MINOVERLAP = 0.5
41 | #--------------------------------------------------------------------------------------#
42 | #   Because of how mAP is computed, the network must keep nearly all predicted boxes in order to compute mAP.
43 | #   The confidence should therefore be set as low as possible to collect every possible prediction.
44 | #
45 | #   This value is normally left unchanged: computing mAP requires nearly all predicted boxes, so do not change this confidence casually.
46 | #   To get Recall and Precision at a different threshold, change score_threhold below instead.
47 | #--------------------------------------------------------------------------------------#
48 | confidence = 0.001
49 | #--------------------------------------------------------------------------------------#
50 | #   Non-maximum-suppression IoU used at prediction time; the larger it is, the less strict NMS becomes.
51 | #
52 | #   This value is normally left unchanged.
53 | #--------------------------------------------------------------------------------------#
54 | nms_iou = 0.5
55 | #---------------------------------------------------------------------------------------------------------------#
56 | #   Unlike AP, Recall and Precision are not area-based metrics, so their values change with the threshold.
57 | #
58 | #   By default, the Recall and Precision reported here correspond to a threshold of 0.5 (defined here as score_threhold).
59 | #   Because computing mAP requires nearly all predicted boxes, the confidence defined above must not be changed casually.
60 | #   A separate score_threhold is defined so that, while computing mAP, the Recall and Precision at that threshold can be reported.
61 | #---------------------------------------------------------------------------------------------------------------#
62 | score_threhold = 0.5
63 | #-------------------------------------------------------#
64 | #   map_vis toggles visualisation during the VOC mAP computation.
65 | #-------------------------------------------------------#
66 | map_vis = False
67 | #-------------------------------------------------------#
68 | #   Points to the folder containing the VOC dataset.
69 | #   By default it points to the VOC dataset in the root directory.
70 | #-------------------------------------------------------#
71 | VOCdevkit_path = 'VOCdevkit'
72 | #-------------------------------------------------------#
73 | #   Output folder for the results, map_out by default.
74 | #-------------------------------------------------------#
75 | map_out_path = 'map_out'
76 |
77 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()
78 |
79 | if not os.path.exists(map_out_path):
80 | os.makedirs(map_out_path)
81 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
82 | os.makedirs(os.path.join(map_out_path, 'ground-truth'))
83 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
84 | os.makedirs(os.path.join(map_out_path, 'detection-results'))
85 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
86 | os.makedirs(os.path.join(map_out_path, 'images-optional'))
87 |
88 | class_names, _ = get_classes(classes_path)
89 |
90 | if map_mode == 0 or map_mode == 1:
91 | print("Load model.")
92 | yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
93 | print("Load model done.")
94 |
95 | print("Get predict result.")
96 | for image_id in tqdm(image_ids):
97 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
98 | image = Image.open(image_path)
99 | if map_vis:
100 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
101 | yolo.get_map_txt(image_id, image, class_names, map_out_path)
102 | print("Get predict result done.")
103 |
104 | if map_mode == 0 or map_mode == 2:
105 | print("Get ground truth result.")
106 | for image_id in tqdm(image_ids):
107 | with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
108 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot()
109 | for obj in root.findall('object'):
110 | difficult_flag = False
111 | if obj.find('difficult')!=None:
112 | difficult = obj.find('difficult').text
113 | if int(difficult)==1:
114 | difficult_flag = True
115 | obj_name = obj.find('name').text
116 | if obj_name not in class_names:
117 | continue
118 | bndbox = obj.find('bndbox')
119 | left = bndbox.find('xmin').text
120 | top = bndbox.find('ymin').text
121 | right = bndbox.find('xmax').text
122 | bottom = bndbox.find('ymax').text
123 |
124 | if difficult_flag:
125 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
126 | else:
127 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
128 | print("Get ground truth result done.")
129 |
130 | if map_mode == 0 or map_mode == 3:
131 | print("Get map.")
132 | get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
133 | print("Get map done.")
134 |
135 | if map_mode == 4:
136 | print("Get map.")
137 | get_coco_map(class_names = class_names, path = map_out_path)
138 | print("Get map done.")
139 |
--------------------------------------------------------------------------------
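The ground-truth txt files written above use one fixed-format line per object: the class name, the four corner coordinates, and an optional trailing "difficult" flag. A minimal sketch of reading such a file back, assuming only the format produced by the writer loop (the function name is illustrative):

    def read_ground_truth(txt_path):
        # each line: "<class name> xmin ymin xmax ymax [difficult]"
        boxes = []
        with open(txt_path) as f:
            for line in f:
                parts = line.split()
                difficult = parts[-1] == "difficult"
                coords = parts[-5:-1] if difficult else parts[-4:]
                name = " ".join(parts[:-5] if difficult else parts[:-4])  # class names may contain spaces
                xmin, ymin, xmax, ymax = map(float, coords)
                boxes.append((name, xmin, ymin, xmax, ymax, difficult))
        return boxes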
/yolov7/kmeans_for_anchors.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------------------------------------------------#
2 | #   Although k-means clusters the boxes in the dataset, many datasets have boxes of similar sizes, so the 9 clustered anchors end up close to one another,
3 | #   which can actually hurt training: different feature levels suit anchors of different sizes, and the smaller the feature map, the larger the anchors it should use.
4 | #   The original network's anchors are already assigned by small/medium/large scale, so skipping the clustering also gives very good results.
5 | #-------------------------------------------------------------------------------------------------------#
6 | import glob
7 | import xml.etree.ElementTree as ET
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 |
14 | def cas_ratio(box,cluster):
15 | ratios_of_box_cluster = box / cluster
16 | ratios_of_cluster_box = cluster / box
17 | ratios = np.concatenate([ratios_of_box_cluster, ratios_of_cluster_box], axis = -1)
18 |
19 | return np.max(ratios, -1)
20 |
21 | def avg_ratio(box,cluster):
22 | return np.mean([np.min(cas_ratio(box[i],cluster)) for i in range(box.shape[0])])
23 |
24 | def kmeans(box,k):
25 | #-------------------------------------------------------------#
26 | #   Number of boxes in total
27 | #-------------------------------------------------------------#
28 | row = box.shape[0]
29 |
30 | #-------------------------------------------------------------#
31 | #   Distance from each box to each cluster centre
32 | #-------------------------------------------------------------#
33 | distance = np.empty((row,k))
34 |
35 | #-------------------------------------------------------------#
36 | #   Cluster assignment from the previous iteration
37 | #-------------------------------------------------------------#
38 | last_clu = np.zeros((row,))
39 |
40 | np.random.seed()
41 |
42 | #-------------------------------------------------------------#
43 | #   Randomly pick k boxes as the initial cluster centres
44 | #-------------------------------------------------------------#
45 | cluster = box[np.random.choice(row,k,replace = False)]
46 |
47 | iter = 0
48 | while True:
49 | #-------------------------------------------------------------#
50 | #   Compute the width/height ratio between each box and the current anchors
51 | #-------------------------------------------------------------#
52 | for i in range(row):
53 | distance[i] = cas_ratio(box[i],cluster)
54 |
55 | #-------------------------------------------------------------#
56 | #   Assign each box to its nearest cluster centre
57 | #-------------------------------------------------------------#
58 | near = np.argmin(distance,axis=1)
59 |
60 | if (last_clu == near).all():
61 | break
62 |
63 | #-------------------------------------------------------------#
64 | #   Take the median point of each cluster
65 | #-------------------------------------------------------------#
66 | for j in range(k):
67 | cluster[j] = np.median(
68 | box[near == j],axis=0)
69 |
70 | last_clu = near
71 | if iter % 5 == 0:
72 | print('iter: {:d}. avg_ratio:{:.2f}'.format(iter, avg_ratio(box,cluster)))
73 | iter += 1
74 |
75 | return cluster, near
76 |
77 | def load_data(path):
78 | data = []
79 | #-------------------------------------------------------------#
80 | #   Collect the boxes from every xml file
81 | #-------------------------------------------------------------#
82 | for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
83 | tree = ET.parse(xml_file)
84 | height = int(tree.findtext('./size/height'))
85 | width = int(tree.findtext('./size/width'))
86 | if height<=0 or width<=0:
87 | continue
88 |
89 | #-------------------------------------------------------------#
90 | #   Get the width and height of every object
91 | #-------------------------------------------------------------#
92 | for obj in tree.iter('object'):
93 | xmin = int(float(obj.findtext('bndbox/xmin'))) / width
94 | ymin = int(float(obj.findtext('bndbox/ymin'))) / height
95 | xmax = int(float(obj.findtext('bndbox/xmax'))) / width
96 | ymax = int(float(obj.findtext('bndbox/ymax'))) / height
97 |
98 | xmin = np.float64(xmin)
99 | ymin = np.float64(ymin)
100 | xmax = np.float64(xmax)
101 | ymax = np.float64(ymax)
102 | # get the width and height (as ratios of the image size)
103 | data.append([xmax-xmin, ymax-ymin])
104 | return np.array(data)
105 |
106 | if __name__ == '__main__':
107 | np.random.seed(0)
108 | #-------------------------------------------------------------#
109 | #   Running this script clusters the xml files in './VOCdevkit/VOC2007/Annotations'
110 | #   and generates yolo_anchors.txt
111 | #-------------------------------------------------------------#
112 | input_shape = [640, 640]
113 | anchors_num = 9
114 | #-------------------------------------------------------------#
115 | #   Load the dataset; VOC-style xml annotations can be used
116 | #-------------------------------------------------------------#
117 | path = 'VOCdevkit/VOC2007/Annotations'
118 |
119 | #-------------------------------------------------------------#
120 | #   Load all xml files
121 | #   Boxes are stored as width,height converted to ratios of the image size
122 | #-------------------------------------------------------------#
123 | print('Load xmls.')
124 | data = load_data(path)
125 | print('Load xmls done.')
126 |
127 | #-------------------------------------------------------------#
128 | #   Run the k-means clustering
129 | #-------------------------------------------------------------#
130 | print('K-means boxes.')
131 | cluster, near = kmeans(data, anchors_num)
132 | print('K-means boxes done.')
133 | data = data * np.array([input_shape[1], input_shape[0]])
134 | cluster = cluster * np.array([input_shape[1], input_shape[0]])
135 |
136 | #-------------------------------------------------------------#
137 | #   Plot the clusters
138 | #-------------------------------------------------------------#
139 | for j in range(anchors_num):
140 | plt.scatter(data[near == j][:,0], data[near == j][:,1])
141 | plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
142 | plt.savefig("kmeans_for_anchors.jpg")
143 | plt.show()
144 | print('Save kmeans_for_anchors.jpg in root dir.')
145 |
146 | cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])]
147 | print('avg_ratio:{:.2f}'.format(avg_ratio(data, cluster)))
148 | print(cluster)
149 |
150 | f = open("yolo_anchors.txt", 'w')
151 | row = np.shape(cluster)[0]
152 | for i in range(row):
153 | if i == 0:
154 | x_y = "%d,%d" % (cluster[i][0], cluster[i][1])
155 | else:
156 | x_y = ", %d,%d" % (cluster[i][0], cluster[i][1])
157 | f.write(x_y)
158 | f.close()
159 |
--------------------------------------------------------------------------------
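The write loop above produces yolo_anchors.txt as a single line of comma-separated width,height pairs sorted by area (the checked-in model_data/yolo_anchors.txt has exactly this shape). A minimal sketch of reading it back into a (num_anchors, 2) array; the function name is illustrative:

    import numpy as np

    def load_anchors(path="model_data/yolo_anchors.txt"):
        # one line of comma-separated integers: w1,h1, w2,h2, ...
        with open(path) as f:
            values = [float(v) for v in f.readline().split(",")]
        return np.array(values).reshape(-1, 2)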
/yolov7/model_data/coco_classes.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorbike
5 | aeroplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | sofa
59 | pottedplant
60 | bed
61 | diningtable
62 | toilet
63 | tvmonitor
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 |
--------------------------------------------------------------------------------
/yolov7/model_data/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/yolov7/model_data/simhei.ttf
--------------------------------------------------------------------------------
/yolov7/model_data/voc_classes.txt:
--------------------------------------------------------------------------------
1 | dog
2 | person
3 | cat
4 | car
--------------------------------------------------------------------------------
/yolov7/model_data/yolo_anchors.txt:
--------------------------------------------------------------------------------
1 | 23,44, 64,64, 37,117, 75,228, 140,146, 154,380, 307,273, 291,498, 536,541
--------------------------------------------------------------------------------
/yolov7/nets/SRModule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from nets.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from nets.SR_Decoder import Decoder
6 | from nets.SR_Encoder import EDSR
7 |
8 |
9 | class EDSRConv(torch.nn.Module):
10 | def __init__(self, in_ch, out_ch):
11 | super(EDSRConv, self).__init__()
12 | self.conv = torch.nn.Sequential(
13 | torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
14 | torch.nn.ReLU(inplace=True),
15 | torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
16 | )
17 |
18 | self.residual_upsampler = torch.nn.Sequential(
19 | torch.nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
20 | )
21 |
22 | def forward(self, input):
23 | return self.conv(input)+self.residual_upsampler(input)
24 |
25 |
26 | class DeepLab(nn.Module):
27 | def __init__(self, ch, c1=128, c2=512, factor=2):
28 | super(DeepLab, self).__init__()
29 | self.sr_decoder = Decoder(c1, c2)
30 | self.edsr = EDSR(num_channels=ch, input_channel=64, factor=8)
31 | self.factor = factor
32 |
33 | def forward(self, low_level_feat,x):
34 | x_sr = self.sr_decoder(x, low_level_feat, self.factor)
35 | x_sr_up = self.edsr(x_sr)
36 | return x_sr_up
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
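With the defaults above (c1=128, c2=512, factor=2 and an EDSR tail built with factor=8), the module fuses a shallow c1-channel feature map with a deeper c2-channel one and returns a ch-channel super-resolved image at 8x the fused resolution. A shape sanity-check sketch only; the example tensor sizes (80x80 and 40x40 maps, as from a 640x640 input) are assumptions:

    import torch
    from nets.SRModule import DeepLab

    model = DeepLab(ch=3)                          # defaults: c1=128, c2=512, factor=2
    low_level_feat = torch.randn(1, 128, 80, 80)   # assumed shallow feature map
    x = torch.randn(1, 512, 40, 40)                # assumed deep feature map
    out = model(low_level_feat, x)
    print(out.shape)  # torch.Size([1, 3, 640, 640]): fused at 80x80, then upsampled x8 by EDSR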
/yolov7/nets/SR_Decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from nets.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, c1, c2):
9 | super(Decoder, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(c1, c1 // 2, 1, bias=False)
12 | self.conv2 = nn.Conv2d(c2, c2 // 2, 1, bias=False)
13 | # self.bn1 = BatchNorm(48)
14 | self.relu = nn.ReLU()
15 | # self.pixel_shuffle = nn.PixelShuffle(4)
16 | # self.attention = AttentionModel(48+c2)
17 | self.last_conv = nn.Sequential(nn.Conv2d((c1 + c2) // 2, 256, kernel_size=3, stride=1, padding=1, bias=False),
18 | # BatchNorm(256),
19 | nn.ReLU(),
20 | # nn.Dropout(0.5),
21 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
22 | # BatchNorm(128),
23 | nn.ReLU(),
24 | # nn.Dropout(0.1),
25 | nn.Conv2d(128, 64, kernel_size=1, stride=1))
26 | self._init_weight()
27 |
28 | def forward(self, x, low_level_feat, factor):
29 | low_level_feat = self.conv1(low_level_feat)
30 | low_level_feat = self.relu(low_level_feat)
31 |
32 | x = self.conv2(x)
33 | x = self.relu(x)
34 | x = F.interpolate(x, size=[i * (factor // 2) for i in low_level_feat.size()[2:]], mode='bilinear',
35 | align_corners=True)
36 | if factor > 1:
37 | low_level_feat = F.interpolate(low_level_feat, size=[i * (factor // 2) for i in low_level_feat.size()[2:]],
38 | mode='bilinear', align_corners=True)
39 | # x = self.pixel_shuffle(x)
40 | x = torch.cat((x, low_level_feat), dim=1)
41 | x = self.last_conv(x)
42 |
43 | return x
44 |
45 | def _init_weight(self):
46 | for m in self.modules():
47 | if isinstance(m, nn.Conv2d):
48 | torch.nn.init.kaiming_normal_(m.weight)
49 | elif isinstance(m, SynchronizedBatchNorm2d):
50 | m.weight.data.fill_(1)
51 | m.bias.data.zero_()
52 | elif isinstance(m, nn.BatchNorm2d):
53 | m.weight.data.fill_(1)
54 | m.bias.data.zero_()
55 |
56 |
--------------------------------------------------------------------------------
/yolov7/nets/SR_Encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def make_model(args, parent=False):
6 | return EDSR(args)
7 |
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
10 | return nn.Conv2d(
11 | in_channels, out_channels, kernel_size,
12 | padding=(kernel_size // 2), bias=bias)
13 |
14 |
15 | class Upsampler(nn.Sequential):
16 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
17 |
18 | m = []
19 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
20 | for _ in range(int(math.log(scale, 2))):
21 | m.append(conv(n_feat, 4 * n_feat, 3, bias))
22 | m.append(nn.PixelShuffle(2))
23 | if bn: m.append(nn.BatchNorm2d(n_feat))
24 | if act: m.append(act())
25 | elif scale == 3:
26 | m.append(conv(n_feat, 9 * n_feat, 3, bias))
27 | m.append(nn.PixelShuffle(3))
28 | if bn: m.append(nn.BatchNorm2d(n_feat))
29 | if act: m.append(act())
30 | else:
31 | raise NotImplementedError
32 |
33 | super(Upsampler, self).__init__(*m)
34 |
35 |
36 | class ResBlock(nn.Module):
37 | def __init__(self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
38 | super(ResBlock, self).__init__()
39 | m = []
40 | for i in range(2):
41 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
42 | if bn: m.append(nn.BatchNorm2d(n_feat))
43 | if i == 0: m.append(act)
44 |
45 | self.body = nn.Sequential(*m)
46 | self.res_scale = res_scale
47 |
48 | def forward(self, x):
49 | res = self.body(x).mul(self.res_scale)
50 | res += x
51 |
52 | return res
53 |
54 |
55 | class EDSR(nn.Module):
56 | def __init__(self, num_channels=3, input_channel=64, factor=4, width=64, depth=16, kernel_size=3,
57 | conv=default_conv):
58 | super(EDSR, self).__init__()
59 |
60 | n_resblock = depth
61 | n_feats = width
62 | kernel_size = kernel_size
63 | scale = factor
64 | act = nn.ReLU()
65 |
66 | # rgb_mean = (0.4488, 0.4371, 0.4040)
67 | # rgb_std = (1.0, 1.0, 1.0)
68 | # self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std)
69 |
70 | # define head module
71 | m_head = [conv(input_channel, n_feats, kernel_size)]
72 |
73 | # define body module
74 | m_body = [
75 | ResBlock(
76 | conv, n_feats, kernel_size, act=act, res_scale=1.
77 | ) for _ in range(n_resblock)
78 | ]
79 | m_body.append(conv(n_feats, n_feats, kernel_size))
80 |
81 | # define tail module
82 | m_tail = [
83 | Upsampler(conv, scale, n_feats, act=False),
84 | conv(n_feats, num_channels, kernel_size)
85 | ]
86 |
87 | # self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1)
88 |
89 | self.head = nn.Sequential(*m_head)
90 | self.body = nn.Sequential(*m_body)
91 | self.tail = nn.Sequential(*m_tail)
92 |
93 | def forward(self, x):
94 | # x = self.sub_mean(x)
95 | x = self.head(x)
96 |
97 | res = self.body(x)
98 | res += x
99 |
100 | x = self.tail(res)
101 | # x = self.add_mean(x)
102 |
103 | return x
104 |
105 | def load_state_dict(self, state_dict, strict=True):
106 | own_state = self.state_dict()
107 | for name, param in state_dict.items():
108 | if name in own_state:
109 | if isinstance(param, nn.Parameter):
110 | param = param.data
111 | try:
112 | own_state[name].copy_(param)
113 | except Exception:
114 | if name.find('tail') == -1:
115 | raise RuntimeError('While copying the parameter named {}, '
116 | 'whose dimensions in the model are {} and '
117 | 'whose dimensions in the checkpoint are {}.'
118 | .format(name, own_state[name].size(), param.size()))
119 | elif strict:
120 | if name.find('tail') == -1:
121 | raise KeyError('unexpected key "{}" in state_dict'
122 | .format(name))
123 |
--------------------------------------------------------------------------------
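The Upsampler above only supports scales that are powers of two (repeated PixelShuffle(2)) or exactly 3, so EDSR's factor must be 2^n or 3. A minimal usage sketch, assuming a 64-channel input feature map matching the default input_channel:

    import torch
    from nets.SR_Encoder import EDSR

    sr = EDSR(num_channels=3, input_channel=64, factor=4)  # factor must be a power of two or 3
    feat = torch.randn(1, 64, 40, 40)
    print(sr(feat).shape)  # torch.Size([1, 3, 160, 160])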
/yolov7/nets/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/yolov7/nets/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def autopad(k, p=None):
6 | if p is None:
7 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
8 | return p
9 |
10 | class SiLU(nn.Module):
11 | @staticmethod
12 | def forward(x):
13 | return x * torch.sigmoid(x)
14 |
15 | class Conv(nn.Module):
16 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups
17 | super(Conv, self).__init__()
18 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
19 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
20 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
21 |
22 | def forward(self, x):
23 | return self.act(self.bn(self.conv(x)))
24 |
25 | def fuseforward(self, x):
26 | return self.act(self.conv(x))
27 |
28 | class Block(nn.Module):
29 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
30 | super(Block, self).__init__()
31 | c_ = int(c2 * e)
32 |
33 | self.ids = ids
34 | self.cv1 = Conv(c1, c_, 1, 1)
35 | self.cv2 = Conv(c1, c_, 1, 1)
36 | self.cv3 = nn.ModuleList(
37 | [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
38 | )
39 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
40 |
41 | def forward(self, x):
42 | x_1 = self.cv1(x)
43 | x_2 = self.cv2(x)
44 |
45 | x_all = [x_1, x_2]
46 | for i in range(len(self.cv3)):
47 | x_2 = self.cv3[i](x_2)
48 | x_all.append(x_2)
49 |
50 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
51 | return out
52 |
53 | class MP(nn.Module):
54 | def __init__(self, k=2):
55 | super(MP, self).__init__()
56 | self.m = nn.MaxPool2d(kernel_size=k, stride=k)
57 |
58 | def forward(self, x):
59 | return self.m(x)
60 |
61 | class Transition(nn.Module):
62 | def __init__(self, c1, c2):
63 | super(Transition, self).__init__()
64 | self.cv1 = Conv(c1, c2, 1, 1)
65 | self.cv2 = Conv(c1, c2, 1, 1)
66 | self.cv3 = Conv(c2, c2, 3, 2)
67 |
68 | self.mp = MP()
69 |
70 | def forward(self, x):
71 | x_1 = self.mp(x)
72 | x_1 = self.cv1(x_1)
73 |
74 | x_2 = self.cv2(x)
75 | x_2 = self.cv3(x_2)
76 |
77 | return torch.cat([x_2, x_1], 1)
78 |
79 | class Backbone(nn.Module):
80 | def __init__(self, transition_channels, block_channels, n, phi, pretrained=False):
81 | super().__init__()
82 | #-----------------------------------------------#
83 | #   The input image is 640, 640, 3
84 | #-----------------------------------------------#
85 | ids = {
86 | 'l' : [-1, -3, -5, -6],
87 | 'x' : [-1, -3, -5, -7, -8],
88 | }[phi]
89 | self.stem = nn.Sequential(
90 | Conv(3, transition_channels, 3, 1),
91 | Conv(transition_channels, transition_channels * 2, 3, 2),
92 | Conv(transition_channels * 2, transition_channels * 2, 3, 1),
93 | )
94 | self.dark2 = nn.Sequential(
95 | Conv(transition_channels * 2, transition_channels * 4, 3, 2),
96 | Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids),
97 | )
98 | self.dark3 = nn.Sequential(
99 | Transition(transition_channels * 8, transition_channels * 4),
100 | Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids),
101 | )
102 | self.dark4 = nn.Sequential(
103 | Transition(transition_channels * 16, transition_channels * 8),
104 | Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids),
105 | )
106 | self.dark5 = nn.Sequential(
107 | Transition(transition_channels * 32, transition_channels * 16),
108 | Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids),
109 | )
110 |
111 | if pretrained:
112 | url = {
113 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth',
114 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth',
115 | }[phi]
116 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data")
117 | self.load_state_dict(checkpoint, strict=False)
118 | print("Load weights from " + url.split('/')[-1])
119 |
120 | def forward(self, x):
121 | x = self.stem(x)
122 | x = self.dark2(x)
123 | #-----------------------------------------------#
124 | #   dark3 outputs an 80, 80, 256 feature map, one of the effective feature layers
125 | #-----------------------------------------------#
126 | x = self.dark3(x)
127 | feat1 = x
128 | #-----------------------------------------------#
129 | #   dark4 outputs a 40, 40, 512 feature map, one of the effective feature layers
130 | #-----------------------------------------------#
131 | x = self.dark4(x)
132 | feat2 = x
133 | #-----------------------------------------------#
134 | #   dark5 outputs a 20, 20, 1024 feature map, one of the effective feature layers
135 | #-----------------------------------------------#
136 | x = self.dark5(x)
137 | feat3 = x
138 | return feat1, feat2, feat3
139 |
--------------------------------------------------------------------------------
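The backbone returns the three effective feature layers feat1, feat2 and feat3 at strides 8, 16 and 32 of the input. A minimal sketch of instantiating it; the constructor values (transition_channels=32, block_channels=32, n=4 for phi='l') are assumptions taken from the common YOLOv7-l configuration rather than read from this repository, so the resulting channel counts may differ from the sizes quoted in the comments above:

    import torch
    from nets.backbone import Backbone

    # assumed 'l' configuration; adjust to match the values actually used in yolo.py
    backbone = Backbone(transition_channels=32, block_channels=32, n=4, phi='l', pretrained=False)
    feats = backbone(torch.randn(1, 3, 640, 640))
    for f in feats:
        print(f.shape)  # three maps at 1/8, 1/16 and 1/32 of the input resolution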
/yolov7/nets/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/yolov7/nets/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | - During replication, since data parallel will trigger a callback on each module, all slave devices should
59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be collected
61 | and passed to a registered callback.
62 | - After receiving the messages, the master device should gather the information and determine the message to be passed
63 | back to each slave device.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 | Register a slave device.
85 | Args:
86 | identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 | The messages are first collected from each device (including the master device), and then
101 | a callback will be invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 | master_msg: the message that the master want to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
/yolov7/nets/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 | Note that, as all modules are isomorphic, we assign each sub-module a context
32 | (shared among multiple copies of this module on different devices).
33 | Through this context, different copies can share some information.
34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 | of any slave copies.
36 | """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
51 | the original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/yolov7/nets/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/yolov7/predict.py:
--------------------------------------------------------------------------------
1 | #-----------------------------------------------------------------------#
2 | #   predict.py bundles single-image prediction, camera/video detection,
3 | #   FPS testing and folder-sweep detection into one script; switch between them by setting mode.
4 | #-----------------------------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from yolo import YOLO
12 |
13 | if __name__ == "__main__":
14 | yolo = YOLO()
15 | #----------------------------------------------------------------------------------------------------------#
16 | #   mode specifies the kind of test to run:
17 | #   'predict'       single-image prediction. To change the prediction process (save the image, crop targets, etc.), read the detailed notes below first.
18 | #   'video'         video detection; a camera or a video file can be used, see the notes below.
19 | #   'fps'           FPS test, using img/street.jpg, see the notes below.
20 | #   'dir_predict'   sweep a folder and save the detections. By default it reads the img folder and saves to img_out, see the notes below.
21 | #   'heatmap'       heat-map visualisation of the prediction results, see the notes below.
22 | #   'export_onnx'   export the model to onnx; requires pytorch 1.7.1 or newer.
23 | #----------------------------------------------------------------------------------------------------------#
24 | mode = "predict"
25 | #-------------------------------------------------------------------------#
26 | #   crop                whether to crop out the detected targets after single-image prediction
27 | #   count               whether to count the detected targets
28 | #   crop and count are only used when mode='predict'
29 | #-------------------------------------------------------------------------#
30 | crop = False
31 | count = False
32 | #----------------------------------------------------------------------------------------------------------#
33 | #   video_path          path of the video; video_path=0 means use the camera.
34 | #                       To detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the root directory.
35 | #   video_save_path     path for saving the video; video_save_path="" means do not save.
36 | #                       To save the video, set e.g. video_save_path = "yyy.mp4" to save yyy.mp4 in the root directory.
37 | #   video_fps           fps of the saved video.
38 | #
39 | #   video_path, video_save_path and video_fps are only used when mode='video'.
40 | #   When saving a video, exit with ctrl+c or let it run to the last frame so the save step completes properly.
41 | #----------------------------------------------------------------------------------------------------------#
42 | video_path = "img/1.mp4"
43 | video_save_path = "2.mp4"
44 | video_fps = 25.0
45 | #----------------------------------------------------------------------------------------------------------#
46 | #   test_interval       number of detections run when measuring FPS; in theory, the larger it is, the more accurate the FPS.
47 | #   fps_image_path      image used for the FPS test.
48 | #
49 | #   test_interval and fps_image_path are only used when mode='fps'.
50 | #----------------------------------------------------------------------------------------------------------#
51 | test_interval = 100
52 | fps_image_path = "img/street.jpg"
53 | #-------------------------------------------------------------------------#
54 | #   dir_origin_path     folder containing the images to detect
55 | #   dir_save_path       folder where the detected images are saved
56 | #
57 | #   dir_origin_path and dir_save_path are only used when mode='dir_predict'.
58 | #-------------------------------------------------------------------------#
59 | dir_origin_path = "img/"
60 | dir_save_path = "img_out/"
61 | #-------------------------------------------------------------------------#
62 | #   heatmap_save_path   save path of the heat map, under model_data by default
63 | #
64 | #   heatmap_save_path is only used when mode='heatmap'.
65 | #-------------------------------------------------------------------------#
66 | heatmap_save_path = "model_data/heatmap_vision.png"
67 | #-------------------------------------------------------------------------#
68 | #   simplify            whether to simplify the exported onnx model (Simplify onnx)
69 | #   onnx_save_path      save path of the onnx model
70 | #-------------------------------------------------------------------------#
71 | simplify = True
72 | onnx_save_path = "model_data/models.onnx"
73 |
74 | if mode == "predict":
75 | '''
76 | 1. To save the detected image, call r_image.save("img.jpg"); edit it directly here in predict.py.
77 | 2. To get the coordinates of the predicted boxes, go into yolo.detect_image and read top, left, bottom, right in the drawing section.
78 | 3. To crop the targets with the predicted boxes, go into yolo.detect_image and use the top, left, bottom, right values obtained
79 | in the drawing section to slice the original image as an array.
80 | 4. To write extra text on the result image, e.g. the number of a particular class detected, go into yolo.detect_image and check predicted_class in the drawing section,
81 | e.g. if predicted_class == 'car': tells you whether the current target is a car; keep a counter and use draw.text to write the text.
82 | '''
83 | while True:
84 | img = input('Input image filename:')
85 | try:
86 | image = Image.open(img)
87 | except:
88 | print('Open Error! Try again!')
89 | continue
90 | else:
91 | r_image = yolo.detect_image(image, crop = crop, count=count)
92 | r_image.show()
93 |
94 | elif mode == "video":
95 | capture = cv2.VideoCapture(video_path)
96 | if video_save_path!="":
97 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
98 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
99 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
100 |
101 | ref, frame = capture.read()
102 | if not ref:
103 | raise ValueError("Failed to read the camera (video) correctly. Please check that the camera is installed correctly (or that the video path is filled in correctly).")
104 |
105 | fps = 0.0
106 | while(True):
107 | t1 = time.time()
108 | # read one frame
109 | ref, frame = capture.read()
110 | if not ref:
111 | break
112 | # convert the colour format, BGR to RGB
113 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
114 | # convert to a PIL Image
115 | frame = Image.fromarray(np.uint8(frame))
116 | # run detection
117 | frame = np.array(yolo.detect_image(frame))
118 | # RGB to BGR to match OpenCV's display format
119 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
120 |
121 | fps = ( fps + (1./(time.time()-t1)) ) / 2
122 | print("fps= %.2f"%(fps))
123 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
124 |
125 | cv2.imshow("video",frame)
126 | c= cv2.waitKey(1) & 0xff
127 | if video_save_path!="":
128 | out.write(frame)
129 |
130 | if c==27:
131 | capture.release()
132 | break
133 |
134 | print("Video Detection Done!")
135 | capture.release()
136 | if video_save_path!="":
137 | print("Save processed video to the path :" + video_save_path)
138 | out.release()
139 | cv2.destroyAllWindows()
140 |
141 | elif mode == "fps":
142 | img = Image.open(fps_image_path)
143 | tact_time = yolo.get_FPS(img, test_interval)
144 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
145 |
146 | elif mode == "dir_predict":
147 | import os
148 |
149 | from tqdm import tqdm
150 |
151 | img_names = os.listdir(dir_origin_path)
152 | for img_name in tqdm(img_names):
153 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
154 | image_path = os.path.join(dir_origin_path, img_name)
155 | image = Image.open(image_path)
156 | r_image = yolo.detect_image(image)
157 | if not os.path.exists(dir_save_path):
158 | os.makedirs(dir_save_path)
159 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
160 |
161 | elif mode == "heatmap":
162 | while True:
163 | img = input('Input image filename:')
164 | try:
165 | image = Image.open(img)
166 | except:
167 | print('Open Error! Try again!')
168 | continue
169 | else:
170 | yolo.detect_heatmap(image, heatmap_save_path)
171 |
172 | elif mode == "export_onnx":
173 | yolo.convert_to_onnx(simplify, onnx_save_path)
174 |
175 | else:
176 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
177 |
--------------------------------------------------------------------------------
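Point 1 of the notes above (saving a detection rather than only showing it) amounts to calling detect_image once and saving the returned PIL image. A minimal sketch using the sample image already shipped in img/; the output filename is arbitrary:

    from PIL import Image
    from yolo import YOLO

    yolo = YOLO()
    r_image = yolo.detect_image(Image.open("img/street.jpg"))
    r_image.save("img_out/street.png", quality=95, subsampling=0)  # same save options as dir_predict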
/yolov7/predict_RGB.py:
--------------------------------------------------------------------------------
1 | #-----------------------------------------------------------------------#
2 | #   predict.py bundles single-image prediction, camera/video detection,
3 | #   FPS testing and folder-sweep detection into one script; switch between them by setting mode.
4 | #-----------------------------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from yolo_RGB import YOLO_RGB
12 |
13 | if __name__ == "__main__":
14 | yolo = YOLO_RGB()
15 | #----------------------------------------------------------------------------------------------------------#
16 | #   mode specifies the kind of test to run:
17 | #   'predict'       single-image prediction. To change the prediction process (save the image, crop targets, etc.), read the detailed notes below first.
18 | #   'video'         video detection; a camera or a video file can be used, see the notes below.
19 | #   'fps'           FPS test, using img/street.jpg, see the notes below.
20 | #   'dir_predict'   sweep a folder and save the detections. By default it reads the img folder and saves to img_out, see the notes below.
21 | #   'heatmap'       heat-map visualisation of the prediction results, see the notes below.
22 | #   'export_onnx'   export the model to onnx; requires pytorch 1.7.1 or newer.
23 | #----------------------------------------------------------------------------------------------------------#
24 | mode = "predict"
25 | #-------------------------------------------------------------------------#
26 | #   crop                whether to crop out the detected targets after single-image prediction
27 | #   count               whether to count the detected targets
28 | #   crop and count are only used when mode='predict'
29 | #-------------------------------------------------------------------------#
30 | crop = False
31 | count = False
32 | #----------------------------------------------------------------------------------------------------------#
33 | #   video_path          path of the video; video_path=0 means use the camera.
34 | #                       To detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the root directory.
35 | #   video_save_path     path for saving the video; video_save_path="" means do not save.
36 | #                       To save the video, set e.g. video_save_path = "yyy.mp4" to save yyy.mp4 in the root directory.
37 | #   video_fps           fps of the saved video.
38 | #
39 | #   video_path, video_save_path and video_fps are only used when mode='video'.
40 | #   When saving a video, exit with ctrl+c or let it run to the last frame so the save step completes properly.
41 | #----------------------------------------------------------------------------------------------------------#
42 | video_path = "img/1.mp4"
43 | video_save_path = "2.mp4"
44 | video_fps = 25.0
45 | #----------------------------------------------------------------------------------------------------------#
46 | #   test_interval       number of detections run when measuring FPS; in theory, the larger it is, the more accurate the FPS.
47 | #   fps_image_path      image used for the FPS test.
48 | #
49 | #   test_interval and fps_image_path are only used when mode='fps'.
50 | #----------------------------------------------------------------------------------------------------------#
51 | test_interval = 100
52 | fps_image_path = "img/street.jpg"
53 | #-------------------------------------------------------------------------#
54 | #   dir_origin_path     folder containing the images to detect
55 | #   dir_save_path       folder where the detected images are saved
56 | #
57 | #   dir_origin_path and dir_save_path are only used when mode='dir_predict'.
58 | #-------------------------------------------------------------------------#
59 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\lwir"
60 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\visible"
61 | dir_origin_path ="../CenterNet/img/"
62 | dir_save_path = "img_out/voc4_new/"
63 | #-------------------------------------------------------------------------#
64 | #   heatmap_save_path   save path of the heat map, under model_data by default
65 | #
66 | #   heatmap_save_path is only used when mode='heatmap'.
67 | #-------------------------------------------------------------------------#
68 | heatmap_save_path = "model_data/heatmap_vision.png"
69 | #-------------------------------------------------------------------------#
70 | #   simplify            whether to simplify the exported onnx model (Simplify onnx)
71 | #   onnx_save_path      save path of the onnx model
72 | #-------------------------------------------------------------------------#
73 | simplify = True
74 | onnx_save_path = "model_data/models.onnx"
75 |
76 | if mode == "predict":
77 | '''
78 | 1. To save the detected image, call r_image.save("img.jpg"); edit it directly here in predict.py.
79 | 2. To get the coordinates of the predicted boxes, go into yolo.detect_image and read top, left, bottom, right in the drawing section.
80 | 3. To crop the targets with the predicted boxes, go into yolo.detect_image and use the top, left, bottom, right values obtained
81 | in the drawing section to slice the original image as an array.
82 | 4. To write extra text on the result image, e.g. the number of a particular class detected, go into yolo.detect_image and check predicted_class in the drawing section,
83 | e.g. if predicted_class == 'car': tells you whether the current target is a car; keep a counter and use draw.text to write the text.
84 | '''
85 | while True:
86 | img = input('Input image filename:')
87 | try:
88 | image = Image.open(img)
89 | except:
90 | print('Open Error! Try again!')
91 | continue
92 | else:
93 | r_image = yolo.detect_image(image, crop = crop, count=count)
94 | r_image.show()
95 |
96 | elif mode == "video":
97 | capture = cv2.VideoCapture(video_path)
98 | if video_save_path!="":
99 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
100 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
101 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
102 |
103 | ref, frame = capture.read()
104 | if not ref:
105 | raise ValueError("Failed to read the camera (video) correctly. Please check that the camera is installed correctly (or that the video path is filled in correctly).")
106 |
107 | fps = 0.0
108 | while(True):
109 | t1 = time.time()
110 | # read one frame
111 | ref, frame = capture.read()
112 | if not ref:
113 | break
114 | # convert the colour format, BGR to RGB
115 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
116 | # convert to a PIL Image
117 | frame = Image.fromarray(np.uint8(frame))
118 | # run detection
119 | frame = np.array(yolo.detect_image(frame))
120 | # RGB to BGR to match OpenCV's display format
121 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
122 |
123 | fps = ( fps + (1./(time.time()-t1)) ) / 2
124 | print("fps= %.2f"%(fps))
125 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
126 |
127 | cv2.imshow("video",frame)
128 | c= cv2.waitKey(1) & 0xff
129 | if video_save_path!="":
130 | out.write(frame)
131 |
132 | if c==27:
133 | capture.release()
134 | break
135 |
136 | print("Video Detection Done!")
137 | capture.release()
138 | if video_save_path!="":
139 | print("Save processed video to the path :" + video_save_path)
140 | out.release()
141 | cv2.destroyAllWindows()
142 |
143 | elif mode == "fps":
144 | img = Image.open(fps_image_path)
145 | tact_time = yolo.get_FPS(img, test_interval)
146 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
147 |
148 | elif mode == "dir_predict":
149 | import os
150 |
151 | from tqdm import tqdm
152 |
153 | img_names = os.listdir(dir_origin_path)
154 | for img_name in tqdm(img_names):
155 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
156 | image_path = os.path.join(dir_origin_path, img_name)
157 | image = Image.open(image_path)
158 | r_image = yolo.detect_image(image)
159 | if not os.path.exists(dir_save_path):
160 | os.makedirs(dir_save_path)
161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
162 |
163 | elif mode == "heatmap":
164 | while True:
165 | img = input('Input image filename:')
166 | try:
167 | image = Image.open(img)
168 | except:
169 | print('Open Error! Try again!')
170 | continue
171 | else:
172 | yolo.detect_heatmap(image, heatmap_save_path)
173 |
174 | elif mode == "export_onnx":
175 | yolo.convert_to_onnx(simplify, onnx_save_path)
176 |
177 | else:
178 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
179 |
--------------------------------------------------------------------------------
/yolov7/predict_T.py:
--------------------------------------------------------------------------------
1 | #-----------------------------------------------------------------------#
2 | #   predict.py bundles single-image prediction, camera/video detection,
3 | #   FPS testing and folder-sweep detection into one script; switch between them by setting mode.
4 | #-----------------------------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from yolo_T import YOLO_T
12 |
13 | if __name__ == "__main__":
14 | yolo = YOLO_T()
15 | #----------------------------------------------------------------------------------------------------------#
16 | #   mode specifies the kind of test to run:
17 | #   'predict'       single-image prediction. To change the prediction process (save the image, crop targets, etc.), read the detailed notes below first.
18 | #   'video'         video detection; a camera or a video file can be used, see the notes below.
19 | #   'fps'           FPS test, using img/street.jpg, see the notes below.
20 | #   'dir_predict'   sweep a folder and save the detections. By default it reads the img folder and saves to img_out, see the notes below.
21 | #   'heatmap'       heat-map visualisation of the prediction results, see the notes below.
22 | #   'export_onnx'   export the model to onnx; requires pytorch 1.7.1 or newer.
23 | #----------------------------------------------------------------------------------------------------------#
24 | mode = "predict"
25 | #-------------------------------------------------------------------------#
26 | #   crop                whether to crop out the detected targets after single-image prediction
27 | #   count               whether to count the detected targets
28 | #   crop and count are only used when mode='predict'
29 | #-------------------------------------------------------------------------#
30 | crop = False
31 | count = False
32 | #----------------------------------------------------------------------------------------------------------#
33 | # video_path        path of the video; video_path=0 means using the camera
34 | #                    to detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the root directory.
35 | # video_save_path   path for saving the video; video_save_path="" means the video is not saved
36 | #                    to save the video, set e.g. video_save_path = "yyy.mp4" to save it as yyy.mp4 in the root directory.
37 | # video_fps         fps of the saved video
38 | #
39 | # video_path, video_save_path and video_fps are only effective when mode='video'
40 | # when saving a video, exit with ctrl+c or let it run to the last frame for the saving step to complete properly.
41 | #----------------------------------------------------------------------------------------------------------#
42 | video_path = "img/1.mp4"
43 | video_save_path = "2.mp4"
44 | video_fps = 25.0
45 | #----------------------------------------------------------------------------------------------------------#
46 | # test_interval    number of detections used when measuring fps; in theory, the larger test_interval is, the more accurate the fps.
47 | # fps_image_path   image used for the fps test
48 | #
49 | # test_interval and fps_image_path are only effective when mode='fps'
50 | #----------------------------------------------------------------------------------------------------------#
51 | test_interval = 100
52 | fps_image_path = "img/street.jpg"
53 | #-------------------------------------------------------------------------#
54 | # dir_origin_path  folder containing the images to detect
55 | # dir_save_path    folder where the detected images are saved
56 | #
57 | # dir_origin_path and dir_save_path are only effective when mode='dir_predict'
58 | #-------------------------------------------------------------------------#
59 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\lwir"
60 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\visible"
61 | dir_origin_path ="../CenterNet/img/"
62 | dir_save_path = "img_out/voc4_new/"
63 | #-------------------------------------------------------------------------#
64 | # heatmap_save_path  path for saving the heatmap, under model_data by default
65 | #
66 | # heatmap_save_path is only effective when mode='heatmap'
67 | #-------------------------------------------------------------------------#
68 | heatmap_save_path = "model_data/heatmap_vision.png"
69 | #-------------------------------------------------------------------------#
70 | # simplify         simplify the exported onnx model with onnx-simplifier
71 | # onnx_save_path   path for saving the onnx model
72 | #-------------------------------------------------------------------------#
73 | simplify = True
74 | onnx_save_path = "model_data/models.onnx"
75 |
76 | if mode == "predict":
77 | '''
78 | 1. To save the detected image, use r_image.save("img.jpg"); this can be done directly in predict.py.
79 | 2. To get the coordinates of the prediction boxes, go into the yolo.detect_image function and read top, left, bottom and right in the drawing section.
80 | 3. To crop out the detected objects, go into the yolo.detect_image function and use the top, left, bottom and right values obtained
81 |    in the drawing section to slice the original image as an array.
82 | 4. To write extra text on the result image, such as the number of detections of a particular class, go into the yolo.detect_image function and test predicted_class in the drawing section;
83 |    for example, if predicted_class == 'car': tells you whether the current object is a car, so you can keep a counter. Use draw.text to write the text.
84 | '''
85 | while True:
86 | img = input('Input image filename:')
87 | try:
88 | image = Image.open(img)
89 | except:
90 | print('Open Error! Try again!')
91 | continue
92 | else:
93 | r_image = yolo.detect_image(image, crop = crop, count=count)
94 | r_image.show()
95 |
96 | elif mode == "video":
97 | capture = cv2.VideoCapture(video_path)
98 | if video_save_path!="":
99 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
100 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
101 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
102 |
103 | ref, frame = capture.read()
104 | if not ref:
105 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
106 |
107 | fps = 0.0
108 | while(True):
109 | t1 = time.time()
110 | # read one frame
111 | ref, frame = capture.read()
112 | if not ref:
113 | break
114 | # convert the format from BGR to RGB
115 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
116 | # convert to a PIL Image
117 | frame = Image.fromarray(np.uint8(frame))
118 | # run detection
119 | frame = np.array(yolo.detect_image(frame))
120 | # RGB to BGR, to match OpenCV's display format
121 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
122 |
123 | fps = ( fps + (1./(time.time()-t1)) ) / 2
124 | print("fps= %.2f"%(fps))
125 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
126 |
127 | cv2.imshow("video",frame)
128 | c= cv2.waitKey(1) & 0xff
129 | if video_save_path!="":
130 | out.write(frame)
131 |
132 | if c==27:
133 | capture.release()
134 | break
135 |
136 | print("Video Detection Done!")
137 | capture.release()
138 | if video_save_path!="":
139 | print("Save processed video to the path :" + video_save_path)
140 | out.release()
141 | cv2.destroyAllWindows()
142 |
143 | elif mode == "fps":
144 | img = Image.open(fps_image_path)
145 | tact_time = yolo.get_FPS(img, test_interval)
146 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + ' FPS, @batch_size 1')
147 |
148 | elif mode == "dir_predict":
149 | import os
150 |
151 | from tqdm import tqdm
152 |
153 | img_names = os.listdir(dir_origin_path)
154 | for img_name in tqdm(img_names):
155 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
156 | image_path = os.path.join(dir_origin_path, img_name)
157 | image = Image.open(image_path)
158 | r_image = yolo.detect_image(image)
159 | if not os.path.exists(dir_save_path):
160 | os.makedirs(dir_save_path)
161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
162 |
163 | elif mode == "heatmap":
164 | while True:
165 | img = input('Input image filename:')
166 | try:
167 | image = Image.open(img)
168 | except:
169 | print('Open Error! Try again!')
170 | continue
171 | else:
172 | yolo.detect_heatmap(image, heatmap_save_path)
173 |
174 | elif mode == "export_onnx":
175 | yolo.convert_to_onnx(simplify, onnx_save_path)
176 |
177 | else:
178 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
179 |
--------------------------------------------------------------------------------
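Note: the frame rate printed in the video mode above (line 123 of predict_T.py, fps = (fps + 1./(time.time()-t1)) / 2) averages the previous estimate with the instantaneous rate of the current frame, i.e. an exponential moving average with weight 1/2. A minimal, self-contained sketch of the same idea; the FpsMeter helper and its alpha parameter are illustrative, not part of the repository:

    import time

    class FpsMeter:
        """Exponentially smoothed frames-per-second estimate (illustrative helper)."""
        def __init__(self, alpha=0.5):
            self.alpha = alpha        # 0.5 reproduces the update used in predict_T.py
            self.fps = 0.0
            self.t_last = time.time()

        def tick(self):
            now = time.time()
            inst = 1.0 / max(now - self.t_last, 1e-6)   # instantaneous FPS of this frame
            self.t_last = now
            # new estimate = alpha * previous estimate + (1 - alpha) * instantaneous value
            self.fps = self.alpha * self.fps + (1.0 - self.alpha) * inst
            return self.fps
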
/yolov7/summary.py:
--------------------------------------------------------------------------------
1 | #--------------------------------------------#
2 | # This script is used to inspect the network structure
3 | #--------------------------------------------#
4 | import torch
5 | from thop import clever_format, profile
6 |
7 | from nets.yolo import YoloBody
8 |
9 | if __name__ == "__main__":
10 | input_shape = [640, 640]
11 | anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
12 | num_classes = 80
13 | phi = 'l'
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | m = YoloBody(anchors_mask, num_classes, phi, False, phi_attention=1).to(device)
17 | for i in m.children():
18 | print(i)
19 | print('==============================')
20 |
21 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
22 | flops, params = profile(m.to(device), (dummy_input, ), verbose=False)
23 | #--------------------------------------------------------#
24 | # flops * 2 because profile does not count a convolution as two operations.
25 | # Some papers count a convolution as both a multiplication and an addition; in that case multiply by 2.
26 | # Some papers only count the multiplications and ignore the additions; in that case do not multiply by 2.
27 | # This code multiplies by 2, following YOLOX.
28 | #--------------------------------------------------------#
29 | flops = flops * 2
30 | flops, params = clever_format([flops, params], "%.3f")
31 | print('Total GFLOPS: %s' % (flops))
32 | print('Total params: %s' % (params))
33 |
--------------------------------------------------------------------------------
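Note: the flops * 2 step in summary.py reflects that thop's profile reports multiply-accumulate operations (MACs); counting the multiply and the add separately doubles the number. A rough sanity check for a single convolution layer; the layer shape below is chosen only for illustration:

    # MACs of one Conv2d: k*k*C_in*C_out*H_out*W_out (bias and other ops ignored)
    k, c_in, c_out, h_out, w_out = 3, 64, 128, 160, 160
    macs = k * k * c_in * c_out * h_out * w_out
    flops = 2 * macs          # count multiply and add as two operations, as summary.py does
    print('MACs : %.3f G' % (macs / 1e9))
    print('FLOPs: %.3f G' % (flops / 1e9))
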
/yolov7/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/yolov7/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | #---------------------------------------------------------#
6 | # Convert the image to an RGB image, to prevent grayscale images from raising errors during prediction.
7 | # The code only supports prediction on RGB images; all other image types are converted to RGB.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 | return image
12 | else:
13 | image = image.convert('RGB')
14 | return image
15 |
16 | #---------------------------------------------------#
17 | # Resize the input image
18 | #---------------------------------------------------#
19 | def resize_image(image, size, letterbox_image):
20 | iw, ih = image.size
21 | w, h = size
22 | if letterbox_image:
23 | scale = min(w/iw, h/ih)
24 | nw = int(iw*scale)
25 | nh = int(ih*scale)
26 |
27 | image = image.resize((nw,nh), Image.BICUBIC)
28 | new_image = Image.new('RGB', size, (128,128,128))
29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 | else:
31 | new_image = image.resize((w, h), Image.BICUBIC)
32 | return new_image
33 |
34 | #---------------------------------------------------#
35 | # Get the class names
36 | #---------------------------------------------------#
37 | def get_classes(classes_path):
38 | with open(classes_path, encoding='utf-8') as f:
39 | class_names = f.readlines()
40 | class_names = [c.strip() for c in class_names]
41 | return class_names, len(class_names)
42 |
43 | #---------------------------------------------------#
44 | # Get the anchor boxes
45 | #---------------------------------------------------#
46 | def get_anchors(anchors_path):
47 | '''loads the anchors from a file'''
48 | with open(anchors_path, encoding='utf-8') as f:
49 | anchors = f.readline()
50 | anchors = [float(x) for x in anchors.split(',')]
51 | anchors = np.array(anchors).reshape(-1, 2)
52 | return anchors, len(anchors)
53 |
54 | #---------------------------------------------------#
55 | # Get the learning rate
56 | #---------------------------------------------------#
57 | def get_lr(optimizer):
58 | for param_group in optimizer.param_groups:
59 | return param_group['lr']
60 |
61 | def preprocess_input(image):
62 | image /= 255.0
63 | return image
64 |
65 | def show_config(**kwargs):
66 | print('Configurations:')
67 | print('-' * 70)
68 | print('|%25s | %40s|' % ('keys', 'values'))
69 | print('-' * 70)
70 | for key, value in kwargs.items():
71 | print('|%25s | %40s|' % (str(key), str(value)))
72 | print('-' * 70)
73 |
74 | def download_weights(phi, model_dir="./model_data"):
75 | import os
76 | from torch.hub import load_state_dict_from_url
77 |
78 | download_urls = {
79 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth',
80 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth',
81 | }
82 | url = download_urls[phi]
83 |
84 | if not os.path.exists(model_dir):
85 | os.makedirs(model_dir)
86 | load_state_dict_from_url(url, model_dir)
--------------------------------------------------------------------------------
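Note: with letterbox_image=True, resize_image above scales the image by min(w/iw, h/ih) and fills the remainder with gray (128,128,128), so the aspect ratio is preserved. A worked example of the arithmetic; the input size is chosen only for illustration:

    from PIL import Image

    iw, ih = 1280, 720            # original image size (illustrative)
    w, h = 640, 640               # network input size
    scale = min(w / iw, h / ih)   # 0.5 here: limited by the wider side
    nw, nh = int(iw * scale), int(ih * scale)   # 640 x 360 after scaling
    dx, dy = (w - nw) // 2, (h - nh) // 2       # (0, 140): gray bars above and below

    img = Image.new('RGB', (iw, ih), (0, 0, 0))            # stand-in for a real photo
    boxed = Image.new('RGB', (w, h), (128, 128, 128))
    boxed.paste(img.resize((nw, nh), Image.BICUBIC), (dx, dy))
    # a point (x, y) on the 640x640 canvas maps back to ((x - dx) / scale, (y - dy) / scale)
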
/yolov7/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from yolov7.utils.utils import get_lr
7 |
8 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
9 | loss = 0
10 | val_loss = 0
11 |
12 | if local_rank == 0:
13 | print('Start Train')
14 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
15 | model_train.train()
16 | for iteration, batch in enumerate(gen):
17 | if iteration >= epoch_step:
18 | break
19 |
20 | images, targets = batch[0], batch[1]
21 | # print(batch)
22 | # print(targets)
23 | with torch.no_grad():
24 | if cuda:
25 | images = images.cuda(local_rank)
26 | targets = targets.cuda(local_rank)
27 | #----------------------#
28 | # zero the gradients
29 | #----------------------#
30 | optimizer.zero_grad()
31 | if not fp16:
32 | #----------------------#
33 | # forward pass
34 | #----------------------#
35 | outputs = model_train(images)
36 | loss_value = yolo_loss(outputs, targets, images)
37 |
38 | #----------------------#
39 | # backward pass
40 | #----------------------#
41 | loss_value.backward()
42 | optimizer.step()
43 | else:
44 | from torch.cuda.amp import autocast
45 | with autocast():
46 | #----------------------#
47 | # forward pass
48 | #----------------------#
49 | outputs = model_train(images)
50 |
51 | loss_value = yolo_loss(outputs, targets, images)
52 |
53 | #----------------------#
54 | # backward pass
55 | #----------------------#
56 | scaler.scale(loss_value).backward()
57 | scaler.step(optimizer)
58 | scaler.update()
59 | if ema:
60 | ema.update(model_train)
61 |
62 | loss += loss_value.item()
63 |
64 | if local_rank == 0:
65 | pbar.set_postfix(**{'loss' : loss / (iteration + 1),
66 | 'lr' : get_lr(optimizer)})
67 | pbar.update(1)
68 |
69 | if local_rank == 0:
70 | pbar.close()
71 | print('Finish Train')
72 | print('Start Validation')
73 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
74 |
75 | if ema:
76 | model_train_eval = ema.ema
77 | else:
78 | model_train_eval = model_train.eval()
79 |
80 | for iteration, batch in enumerate(gen_val):
81 | if iteration >= epoch_step_val:
82 | break
83 | images, targets = batch[0], batch[1]
84 | with torch.no_grad():
85 | if cuda:
86 | images = images.cuda(local_rank)
87 | targets = targets.cuda(local_rank)
88 | #----------------------#
89 | # zero the gradients
90 | #----------------------#
91 | optimizer.zero_grad()
92 | #----------------------#
93 | # forward pass
94 | #----------------------#
95 | outputs = model_train_eval(images)
96 | loss_value = yolo_loss(outputs, targets, images)
97 |
98 | val_loss += loss_value.item()
99 | if local_rank == 0:
100 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
101 | pbar.update(1)
102 |
103 | if local_rank == 0:
104 | pbar.close()
105 | print('Finish Validation')
106 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
107 | eval_callback.on_epoch_end(epoch + 1, model_train_eval)
108 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
109 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
110 |
111 | #-----------------------------------------------#
112 | # save the weights
113 | #-----------------------------------------------#
114 | if ema:
115 | save_state_dict = ema.ema.state_dict()
116 | else:
117 | save_state_dict = model.state_dict()
118 |
119 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
120 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
121 |
122 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
123 | print('Save best model to best_epoch_weights.pth')
124 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
125 |
126 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
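Note: the fp16 branch of fit_one_epoch follows the standard PyTorch automatic mixed precision recipe: run the forward pass under autocast, then scale the loss with the GradScaler (created by the caller and passed in as `scaler`) before the backward pass so small fp16 gradients do not underflow. A minimal sketch of that pattern in isolation; the model, data and loss below are placeholders, and a CUDA device is assumed:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(8, 1).cuda()      # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = GradScaler()                     # created once, outside the training loop

    x = torch.randn(4, 8).cuda()
    y = torch.randn(4, 1).cuda()

    optimizer.zero_grad()
    with autocast():                          # forward pass in mixed precision
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()             # scaled backward pass
    scaler.step(optimizer)                    # unscales the gradients, then steps
    scaler.update()                           # adjusts the scale factor for the next iteration
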
/yolov7/utils_coco/coco_annotation.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------#
2 | # Processes the COCO dataset: generates the txt files used for training from the json annotation files
3 | #-------------------------------------------------------#
4 | import json
5 | import os
6 | from collections import defaultdict
7 |
8 | #-------------------------------------------------------#
9 | # Paths to the COCO training and validation images
10 | #-------------------------------------------------------#
11 | train_datasets_path = "coco_dataset/train2017"
12 | val_datasets_path = "coco_dataset/val2017"
13 |
14 | #-------------------------------------------------------#
15 | # Paths to the COCO training and validation annotations
16 | #-------------------------------------------------------#
17 | train_annotation_path = "coco_dataset/annotations/instances_train2017.json"
18 | val_annotation_path = "coco_dataset/annotations/instances_val2017.json"
19 |
20 | #-------------------------------------------------------#
21 | # Paths of the generated txt files
22 | #-------------------------------------------------------#
23 | train_output_path = "coco_train.txt"
24 | val_output_path = "coco_val.txt"
25 |
26 | if __name__ == "__main__":
27 | name_box_id = defaultdict(list)
28 | id_name = dict()
29 | f = open(train_annotation_path, encoding='utf-8')
30 | data = json.load(f)
31 |
32 | annotations = data['annotations']
33 | for ant in annotations:
34 | id = ant['image_id']
35 | name = os.path.join(train_datasets_path, '%012d.jpg' % id)
36 | cat = ant['category_id']
37 | if cat >= 1 and cat <= 11:
38 | cat = cat - 1
39 | elif cat >= 13 and cat <= 25:
40 | cat = cat - 2
41 | elif cat >= 27 and cat <= 28:
42 | cat = cat - 3
43 | elif cat >= 31 and cat <= 44:
44 | cat = cat - 5
45 | elif cat >= 46 and cat <= 65:
46 | cat = cat - 6
47 | elif cat == 67:
48 | cat = cat - 7
49 | elif cat == 70:
50 | cat = cat - 9
51 | elif cat >= 72 and cat <= 82:
52 | cat = cat - 10
53 | elif cat >= 84 and cat <= 90:
54 | cat = cat - 11
55 | name_box_id[name].append([ant['bbox'], cat])
56 |
57 | f = open(train_output_path, 'w')
58 | for key in name_box_id.keys():
59 | f.write(key)
60 | box_infos = name_box_id[key]
61 | for info in box_infos:
62 | x_min = int(info[0][0])
63 | y_min = int(info[0][1])
64 | x_max = x_min + int(info[0][2])
65 | y_max = y_min + int(info[0][3])
66 |
67 | box_info = " %d,%d,%d,%d,%d" % (
68 | x_min, y_min, x_max, y_max, int(info[1]))
69 | f.write(box_info)
70 | f.write('\n')
71 | f.close()
72 |
73 | name_box_id = defaultdict(list)
74 | id_name = dict()
75 | f = open(val_annotation_path, encoding='utf-8')
76 | data = json.load(f)
77 |
78 | annotations = data['annotations']
79 | for ant in annotations:
80 | id = ant['image_id']
81 | name = os.path.join(val_datasets_path, '%012d.jpg' % id)
82 | cat = ant['category_id']
83 | if cat >= 1 and cat <= 11:
84 | cat = cat - 1
85 | elif cat >= 13 and cat <= 25:
86 | cat = cat - 2
87 | elif cat >= 27 and cat <= 28:
88 | cat = cat - 3
89 | elif cat >= 31 and cat <= 44:
90 | cat = cat - 5
91 | elif cat >= 46 and cat <= 65:
92 | cat = cat - 6
93 | elif cat == 67:
94 | cat = cat - 7
95 | elif cat == 70:
96 | cat = cat - 9
97 | elif cat >= 72 and cat <= 82:
98 | cat = cat - 10
99 | elif cat >= 84 and cat <= 90:
100 | cat = cat - 11
101 | name_box_id[name].append([ant['bbox'], cat])
102 |
103 | f = open(val_output_path, 'w')
104 | for key in name_box_id.keys():
105 | f.write(key)
106 | box_infos = name_box_id[key]
107 | for info in box_infos:
108 | x_min = int(info[0][0])
109 | y_min = int(info[0][1])
110 | x_max = x_min + int(info[0][2])
111 | y_max = y_min + int(info[0][3])
112 |
113 | box_info = " %d,%d,%d,%d,%d" % (
114 | x_min, y_min, x_max, y_max, int(info[1]))
115 | f.write(box_info)
116 | f.write('\n')
117 | f.close()
118 |
--------------------------------------------------------------------------------
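Note: the long if/elif chain in coco_annotation.py collapses COCO's sparse category ids (1-90, with 10 unused ids) into contiguous class indices 0-79. An equivalent, table-driven way to build the same mapping with pycocotools; the variable names below are illustrative:

    from pycocotools.coco import COCO

    coco = COCO("coco_dataset/annotations/instances_val2017.json")
    cat_ids = sorted(coco.getCatIds())                    # the 80 valid ids in ascending order
    catid2clsid = {c: i for i, c in enumerate(cat_ids)}   # e.g. 1 -> 0, 13 -> 11, 90 -> 79

    # COCO bboxes are [x_min, y_min, width, height]; the txt lines store x_min,y_min,x_max,y_max
    def coco_box_to_corners(box):
        x, y, w, h = box
        return int(x), int(y), int(x) + int(w), int(y) + int(h)
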
/yolov7/utils_coco/get_map_coco.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from pycocotools.coco import COCO
8 | from pycocotools.cocoeval import COCOeval
9 | from tqdm import tqdm
10 |
11 | from utils.utils import cvtColor, preprocess_input, resize_image
12 | from yolo import YOLO
13 |
14 | #---------------------------------------------------------------------------#
15 | # map_mode specifies what this script computes when it is run
16 | # map_mode of 0 runs the whole mAP pipeline: obtaining the predictions and computing the mAP.
17 | # map_mode of 1 only obtains the predictions.
18 | # map_mode of 2 only computes the mAP.
19 | #---------------------------------------------------------------------------#
20 | map_mode = 0
21 | #-------------------------------------------------------#
22 | # Paths to the validation set annotations and images
23 | #-------------------------------------------------------#
24 | cocoGt_path = 'coco_dataset/annotations/instances_val2017.json'
25 | dataset_img_path = 'coco_dataset/val2017'
26 | #-------------------------------------------------------#
27 | # Folder for the output results, map_out by default
28 | #-------------------------------------------------------#
29 | temp_save_path = 'map_out/coco_eval'
30 |
31 | class mAP_YOLO(YOLO):
32 | #---------------------------------------------------#
33 | # Detect an image
34 | #---------------------------------------------------#
35 | def detect_image(self, image_id, image, results, clsid2catid):
36 | #---------------------------------------------------#
37 | # Compute the height and width of the input image
38 | #---------------------------------------------------#
39 | image_shape = np.array(np.shape(image)[0:2])
40 | #---------------------------------------------------------#
41 | # Convert the image to an RGB image here, to prevent grayscale images from raising errors during prediction.
42 | # The code only supports prediction on RGB images; all other image types are converted to RGB.
43 | #---------------------------------------------------------#
44 | image = cvtColor(image)
45 | #---------------------------------------------------------#
46 | # Add gray bars to the image to achieve a distortion-free resize.
47 | # A plain resize can also be used for detection.
48 | #---------------------------------------------------------#
49 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
50 | #---------------------------------------------------------#
51 | # Add the batch_size dimension
52 | #---------------------------------------------------------#
53 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
54 |
55 | with torch.no_grad():
56 | images = torch.from_numpy(image_data)
57 | if self.cuda:
58 | images = images.cuda()
59 | #---------------------------------------------------------#
60 | # Feed the image into the network for prediction!
61 | #---------------------------------------------------------#
62 | outputs = self.net(images)
63 | outputs = self.bbox_util.decode_box(outputs)
64 | #---------------------------------------------------------#
65 | # Stack the prediction boxes, then apply non-maximum suppression
66 | #---------------------------------------------------------#
67 | outputs = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
68 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
69 |
70 | if outputs[0] is None:
71 | return results
72 |
73 | top_label = np.array(outputs[0][:, 6], dtype = 'int32')
74 | top_conf = outputs[0][:, 4] * outputs[0][:, 5]
75 | top_boxes = outputs[0][:, :4]
76 |
77 | for i, c in enumerate(top_label):
78 | result = {}
79 | top, left, bottom, right = top_boxes[i]
80 |
81 | result["image_id"] = int(image_id)
82 | result["category_id"] = clsid2catid[c]
83 | result["bbox"] = [float(left),float(top),float(right-left),float(bottom-top)]
84 | result["score"] = float(top_conf[i])
85 | results.append(result)
86 | return results
87 |
88 | if __name__ == "__main__":
89 | if not os.path.exists(temp_save_path):
90 | os.makedirs(temp_save_path)
91 |
92 | cocoGt = COCO(cocoGt_path)
93 | ids = list(cocoGt.imgToAnns.keys())
94 | clsid2catid = cocoGt.getCatIds()
95 |
96 | if map_mode == 0 or map_mode == 1:
97 | yolo = mAP_YOLO(confidence = 0.001, nms_iou = 0.65)
98 |
99 | with open(os.path.join(temp_save_path, 'eval_results.json'),"w") as f:
100 | results = []
101 | for image_id in tqdm(ids):
102 | image_path = os.path.join(dataset_img_path, cocoGt.loadImgs(image_id)[0]['file_name'])
103 | image = Image.open(image_path)
104 | results = yolo.detect_image(image_id, image, results, clsid2catid)
105 | json.dump(results, f)
106 |
107 | if map_mode == 0 or map_mode == 2:
108 | cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json'))
109 | cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
110 | cocoEval.evaluate()
111 | cocoEval.accumulate()
112 | cocoEval.summarize()
113 | print("Get map done.")
114 |
--------------------------------------------------------------------------------
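Note: in mAP_YOLO.detect_image above, the boxes coming out of non-max suppression are corner coordinates in (top, left, bottom, right) order, while COCOeval expects [x_min, y_min, width, height]; the score is the product of objectness and class confidence. A small sketch of that conversion on its own; the function name and the values are illustrative:

    def to_coco_result(image_id, category_id, box, obj_conf, cls_conf):
        """Convert one (top, left, bottom, right) box into a COCO detection record."""
        top, left, bottom, right = box
        return {
            "image_id": int(image_id),
            "category_id": int(category_id),
            "bbox": [float(left), float(top), float(right - left), float(bottom - top)],
            "score": float(obj_conf * cls_conf),
        }

    print(to_coco_result(42, 1, (50.0, 30.0, 150.0, 110.0), 0.9, 0.8))
    # -> bbox [30.0, 50.0, 80.0, 100.0], score 0.72
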
/yolov7/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import xml.etree.ElementTree as ET
4 |
5 | import numpy as np
6 |
7 | from utils.utils import get_classes
8 |
9 | #--------------------------------------------------------------------------------------------------------------------------------#
10 | # annotation_mode specifies what this script computes when it is run
11 | # annotation_mode of 0 runs the whole annotation pipeline: the txt files in VOCdevkit/VOC2007/ImageSets as well as the 2007_train.txt and 2007_val.txt used for training
12 | # annotation_mode of 1 only generates the txt files in VOCdevkit/VOC2007/ImageSets
13 | # annotation_mode of 2 only generates the 2007_train.txt and 2007_val.txt used for training
14 | #--------------------------------------------------------------------------------------------------------------------------------#
15 | annotation_mode = 0
16 | #-------------------------------------------------------------------#
17 | # Must be modified: used to generate the object information in 2007_train.txt and 2007_val.txt
18 | # It must be consistent with the classes_path used for training and prediction
19 | # If the generated 2007_train.txt contains no object information,
20 | # it is because the classes have not been set correctly
21 | # Only effective when annotation_mode is 0 or 2
22 | #-------------------------------------------------------------------#
23 | classes_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\model_data\voc_classes.txt'
24 | #--------------------------------------------------------------------------------------------------------------------------------#
25 | # trainval_percent specifies the ratio of (training set + validation set) to test set; by default (training set + validation set):test set = 9:1
26 | # train_percent specifies the ratio of training set to validation set within (training set + validation set); by default training set:validation set = 9:1
27 | # Only effective when annotation_mode is 0 or 1
28 | #--------------------------------------------------------------------------------------------------------------------------------#
29 | trainval_percent = 0.9
30 | train_percent = 0.9
31 | #-------------------------------------------------------#
32 | # Points to the folder containing the VOC dataset
33 | # By default it points to the VOC dataset in the root directory
34 | #-------------------------------------------------------#
35 | VOCdevkit_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\VOCdevkit'  # raw string so the backslashes are not treated as escape sequences
36 |
37 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
38 | classes, _ = get_classes(classes_path)
39 |
40 | #-------------------------------------------------------#
41 | # Count the number of objects per class
42 | #-------------------------------------------------------#
43 | photo_nums = np.zeros(len(VOCdevkit_sets))
44 | nums = np.zeros(len(classes))
45 | def convert_annotation(year, image_id, list_file):
46 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
47 | tree=ET.parse(in_file)
48 | root = tree.getroot()
49 |
50 | for obj in root.iter('object'):
51 | difficult = 0
52 | if obj.find('difficult')!=None:
53 | difficult = obj.find('difficult').text
54 | cls = obj.find('name').text
55 | if cls not in classes or int(difficult)==1:
56 | continue
57 | cls_id = classes.index(cls)
58 | xmlbox = obj.find('bndbox')
59 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
60 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
61 |
62 | nums[classes.index(cls)] = nums[classes.index(cls)] + 1
63 |
64 | if __name__ == "__main__":
65 | random.seed(0)
66 | if " " in os.path.abspath(VOCdevkit_path):
67 | raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")
68 |
69 | if annotation_mode == 0 or annotation_mode == 1:
70 | print("Generate txt in ImageSets.")
71 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
72 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
73 | temp_xml = os.listdir(xmlfilepath)
74 | total_xml = []
75 | for xml in temp_xml:
76 | if xml.endswith(".xml"):
77 | total_xml.append(xml)
78 |
79 | num = len(total_xml)
80 | list = range(num)
81 | tv = int(num*trainval_percent)
82 | tr = int(tv*train_percent)
83 | trainval= random.sample(list,tv)
84 | train = random.sample(trainval,tr)
85 |
86 | print("train and val size",tv)
87 | print("train size",tr)
88 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
89 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
90 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
91 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
92 |
93 | for i in list:
94 | name=total_xml[i][:-4]+'\n'
95 | if i in trainval:
96 | ftrainval.write(name)
97 | if i in train:
98 | ftrain.write(name)
99 | else:
100 | fval.write(name)
101 | else:
102 | ftest.write(name)
103 |
104 | ftrainval.close()
105 | ftrain.close()
106 | fval.close()
107 | ftest.close()
108 | print("Generate txt in ImageSets done.")
109 |
110 | if annotation_mode == 0 or annotation_mode == 2:
111 | print("Generate 2007_train.txt and 2007_val.txt for train.")
112 | type_index = 0
113 | for year, image_set in VOCdevkit_sets:
114 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
115 | list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
116 | for image_id in image_ids:
117 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
118 |
119 | convert_annotation(year, image_id, list_file)
120 | list_file.write('\n')
121 | photo_nums[type_index] = len(image_ids)
122 | type_index += 1
123 | list_file.close()
124 | print("Generate 2007_train.txt and 2007_val.txt for train done.")
125 |
126 | def printTable(List1, List2):
127 | for i in range(len(List1[0])):
128 | print("|", end=' ')
129 | for j in range(len(List1)):
130 | print(List1[j][i].rjust(int(List2[j])), end=' ')
131 | print("|", end=' ')
132 | print()
133 |
134 | str_nums = [str(int(x)) for x in nums]
135 | tableData = [
136 | classes, str_nums
137 | ]
138 | colWidths = [0]*len(tableData)
139 | len1 = 0
140 | for i in range(len(tableData)):
141 | for j in range(len(tableData[i])):
142 | if len(tableData[i][j]) > colWidths[i]:
143 | colWidths[i] = len(tableData[i][j])
144 | printTable(tableData, colWidths)
145 |
146 | if photo_nums[0] <= 500:
147 | print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")
148 |
149 | if np.sum(nums) == 0:
150 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
151 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
152 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
153 | print("(重要的事情说三遍)。")
154 |
--------------------------------------------------------------------------------
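Note: with the default trainval_percent = 0.9 and train_percent = 0.9, voc_annotation.py first draws 90% of the annotated images as trainval and then 90% of those as train, so the final split is 81% train / 9% val / 10% test. A quick check of the arithmetic; the image count below is illustrative:

    num = 1000                              # total annotated images (illustrative)
    trainval_percent, train_percent = 0.9, 0.9
    tv = int(num * trainval_percent)        # 900 images in train + val
    tr = int(tv * train_percent)            # 810 images in train
    print('train:', tr, 'val:', tv - tr, 'test:', num - tv)   # train: 810 val: 90 test: 100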