├── .gitignore ├── .idea ├── ProbEn.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── CenterNet ├── LICENSE ├── centernet.py ├── change_weight.py ├── get_FNPI.py ├── get_map.py ├── model_data │ ├── people_classes.txt │ └── simhei.ttf ├── nets │ ├── CenterNet_yolov7.py │ ├── centernet.py │ ├── centernet_training.py │ ├── hourglass.py │ └── resnet50.py ├── predict.py ├── summary.py ├── train.py ├── utils │ ├── __init__.py │ ├── callbacks.py │ ├── dataloader.py │ ├── utils.py │ ├── utils_bbox.py │ ├── utils_fit.py │ └── utils_map.py ├── vision_for_centernet.py └── voc_annotation.py ├── ProbEn.py ├── ProbEn_time_test.py ├── README.md ├── detector_test.py ├── get_fusion_FNPI.py ├── get_fusion_FNPI_onlyYolo.py ├── get_fusion_map.py ├── get_fusion_map_onlyYolo.py ├── img ├── 1.jpg ├── 2.jpg └── street.jpg ├── img_out ├── 1.png ├── 2.png └── street.png ├── predict_with_probEn.py ├── predict_with_probEn_onlyYolo.py └── yolov7 ├── get_FNPI.py ├── get_map.py ├── kmeans_for_anchors.py ├── model_data ├── coco_classes.txt ├── simhei.ttf ├── voc_classes.txt └── yolo_anchors.txt ├── nets ├── SRModule.py ├── SR_Decoder.py ├── SR_Encoder.py ├── __init__.py ├── backbone.py ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── yolo().py ├── yolo.py └── yolo_training.py ├── predict.py ├── predict_RGB.py ├── predict_T.py ├── summary.py ├── train.py ├── utils ├── __init__.py ├── attentions.py ├── callbacks.py ├── dataloader.py ├── utils.py ├── utils_FNPI.py ├── utils_bbox.py ├── utils_fit.py └── utils_map.py ├── utils_coco ├── coco_annotation.py └── get_map_coco.py ├── voc_annotation.py ├── yolo.py ├── yolo_RGB.py └── yolo_T.py /.gitignore: -------------------------------------------------------------------------------- 1 | /yolov7/VOCdevkit/ 2 | /CenterNet/2007_val.txt 3 | /CenterNet/2007_train.txt 4 | -------------------------------------------------------------------------------- /.idea/ProbEn.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 16 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 37 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /CenterNet/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Bubbliiiing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CenterNet/change_weight.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def Combine_weights(ir_weights,rgb_weights,alpha,save_path): 6 | # lambda_ = np.random.beta(alpha, alpha) 7 | lambda_=alpha 8 | weights_1 = torch.load(ir_weights) 9 | # print(weights_1['head.reg_head.3.bias']) 10 | weights_2 = torch.load(rgb_weights) 11 | for key,value in weights_2.items(): 12 | if key in weights_1: 13 | weights_1[key]=torch.tensor(((1-lambda_)*weights_1[key].cpu().numpy()+lambda_ * weights_2[key].cpu().numpy()),device ='cuda:0') 14 | # print(weights_1['head.reg_head.3.bias']) 15 | torch.save(weights_1,r"D:\xiangmushiyan\centernet-pytorch-main\new_weights.pth") 16 | return save_path 17 | save_path=r'D:\xiangmushiyan\centernet-pytorch-main' 18 | ir_weights=r"D:\xiangmushiyan\centernet-pytorch-main\logs\loss_2022_10_17_16_36_01-yolov7-ir\best_epoch_weights.pth" 19 | rgb_weights=r"D:\xiangmushiyan\centernet-pytorch-main\logs\loss_2022_10_18_21_51_40-yolov7-rgb\best_epoch_weights.pth" 20 | alpha=0.99 21 | 22 | Combine_weights(ir_weights,rgb_weights,alpha,save_path) -------------------------------------------------------------------------------- /CenterNet/get_FNPI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_classes 8 | from yolov7.utils.utils_FNPI import get_FNPI 9 | from centernet import CenterNet 10 | 11 | if __name__ == "__main__": 12 | #------------------------------------------------------------------------------------------------------------------# 13 | # map_mode用于指定该文件运行时计算的内容 14 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。 15 | # map_mode为1代表仅仅获得预测结果。 16 | # map_mode为2代表仅仅获得真实框。 17 | # map_mode为3代表仅仅计算FNPI。 18 | #-------------------------------------------------------------------------------------------------------------------# 19 | map_mode = 0 20 | #--------------------------------------------------------------------------------------# 21 | # 此处的classes_path用于指定需要测量FNPI的类别 22 | # 一般情况下与训练和预测所用的classes_path一致即可 23 | #--------------------------------------------------------------------------------------# 24 | # classes_path = 'model_data/people_classes_KAIST.txt' 25 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\CenterNet\model_data\people_classes_voc4.txt' 26 | #--------------------------------------------------------------------------------------# 27 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件 28 | # 只有大于FNPI_IOU值才算检测成功 29 | #--------------------------------------------------------------------------------------# 30 | FNPI_IOU = 0.5 31 | #--------------------------------------------------------------------------------------# 32 | # confidence的设置与计算map时的设置情况不一样。 33 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。 34 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来 35 | #--------------------------------------------------------------------------------------# 36 | confidence = 0.5 37 | #--------------------------------------------------------------------------------------# 38 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 39 | # 该值也应该与预测时设置的nms_iou一致。 40 | #--------------------------------------------------------------------------------------# 41 | nms_iou = 0.3 42 | #-------------------------------------------------------# 43 | # 指向VOC数据集所在的文件夹 44 | # 默认指向根目录下的VOC数据集 45 | #-------------------------------------------------------# 46 | VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit' 47 | # VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist' 48 | #-------------------------------------------------------# 49 | # 结果输出的文件夹,默认为map_out 50 | #-------------------------------------------------------# 51 | FNPI_out_path = 'FNPI_out/FNPI_out_voc4' 52 | 53 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 54 | # image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split() 55 | 56 | if not os.path.exists(FNPI_out_path): 57 | os.makedirs(FNPI_out_path) 58 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')): 59 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth')) 60 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')): 61 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results')) 62 | 63 | class_names, _ = get_classes(classes_path) 64 | 65 | if map_mode == 0 or map_mode == 1: 66 | print("Load model.") 67 | centernet = CenterNet(confidence = confidence, nms_iou = nms_iou) 68 | print("Load model done.") 69 | 70 | print("Get predict result.") 71 | for image_id in tqdm(image_ids): 72 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 73 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/"+image_id+".jpg") 74 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/"+image_id+".jpg") 75 | image = Image.open(image_path) 76 | centernet.get_map_txt(image_id, image, class_names, FNPI_out_path) 77 | print("Get predict result done.") 78 | 79 | if map_mode == 0 or map_mode == 2: 80 | print("Get ground truth result.") 81 | for image_id in tqdm(image_ids): 82 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 83 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 84 | # root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot() 85 | for obj in root.findall('object'): 86 | difficult_flag = False 87 | if obj.find('difficult')!=None: 88 | difficult = obj.find('difficult').text 89 | if int(difficult)==1: 90 | difficult_flag = True 91 | obj_name = obj.find('name').text 92 | if obj_name not in class_names: 93 | continue 94 | bndbox = obj.find('bndbox') 95 | left = bndbox.find('xmin').text 96 | top = bndbox.find('ymin').text 97 | right = bndbox.find('xmax').text 98 | bottom = bndbox.find('ymax').text 99 | 100 | if difficult_flag: 101 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 102 | else: 103 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 104 | print("Get ground truth result done.") 105 | 106 | if map_mode == 0 or map_mode == 3: 107 | print("Get map.") 108 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path) 109 | print("Get map done.") -------------------------------------------------------------------------------- /CenterNet/get_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | from centernet import CenterNet 8 | from utils.utils import get_classes 9 | from utils.utils_map import get_coco_map, get_map 10 | 11 | if __name__ == "__main__": 12 | ''' 13 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。 14 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。 15 | 16 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值 17 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框, 18 | ''' 19 | #------------------------------------------------------------------------------------------------------------------# 20 | # map_mode用于指定该文件运行时计算的内容 21 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。 22 | # map_mode为1代表仅仅获得预测结果。 23 | # map_mode为2代表仅仅获得真实框。 24 | # map_mode为3代表仅仅计算VOC_map。 25 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 26 | #-------------------------------------------------------------------------------------------------------------------# 27 | map_mode = 0 28 | #--------------------------------------------------------------------------------------# 29 | # 此处的classes_path用于指定需要测量VOC_map的类别 30 | # 一般情况下与训练和预测所用的classes_path一致即可 31 | #--------------------------------------------------------------------------------------# 32 | classes_path = 'model_data/people_classes.txt' 33 | #--------------------------------------------------------------------------------------# 34 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。 35 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 36 | # 37 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。 38 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低, 39 | #--------------------------------------------------------------------------------------# 40 | MINOVERLAP = 0.5 41 | #--------------------------------------------------------------------------------------# 42 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP 43 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。 44 | # 45 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。 46 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。 47 | #--------------------------------------------------------------------------------------# 48 | confidence = 0.02 49 | #--------------------------------------------------------------------------------------# 50 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 51 | # 52 | # 该值一般不调整。 53 | #--------------------------------------------------------------------------------------# 54 | nms_iou = 0.5 55 | #---------------------------------------------------------------------------------------------------------------# 56 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。 57 | # 58 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。 59 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。 60 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。 61 | #---------------------------------------------------------------------------------------------------------------# 62 | score_threhold = 0.5 63 | #-------------------------------------------------------# 64 | # map_vis用于指定是否开启VOC_map计算的可视化 65 | #-------------------------------------------------------# 66 | map_vis = False 67 | #-------------------------------------------------------# 68 | # 指向VOC数据集所在的文件夹 69 | # 默认指向根目录下的VOC数据集 70 | #-------------------------------------------------------# 71 | VOCdevkit_path = 'VOCdevkit' 72 | #-------------------------------------------------------# 73 | # 结果输出的文件夹,默认为map_out 74 | #-------------------------------------------------------# 75 | map_out_path = 'map_out' 76 | 77 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 78 | 79 | if not os.path.exists(map_out_path): 80 | os.makedirs(map_out_path) 81 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')): 82 | os.makedirs(os.path.join(map_out_path, 'ground-truth')) 83 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')): 84 | os.makedirs(os.path.join(map_out_path, 'detection-results')) 85 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')): 86 | os.makedirs(os.path.join(map_out_path, 'images-optional')) 87 | 88 | class_names, _ = get_classes(classes_path) 89 | 90 | if map_mode == 0 or map_mode == 1: 91 | print("Load model.") 92 | centernet = CenterNet(confidence = confidence, nms_iou = nms_iou) 93 | print("Load model done.") 94 | 95 | print("Get predict result.") 96 | for image_id in tqdm(image_ids): 97 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 98 | image = Image.open(image_path) 99 | if map_vis: 100 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg")) 101 | centernet.get_map_txt(image_id, image, class_names, map_out_path) 102 | print("Get predict result done.") 103 | 104 | if map_mode == 0 or map_mode == 2: 105 | print("Get ground truth result.") 106 | for image_id in tqdm(image_ids): 107 | with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 108 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 109 | for obj in root.findall('object'): 110 | difficult_flag = False 111 | if obj.find('difficult')!=None: 112 | difficult = obj.find('difficult').text 113 | if int(difficult)==1: 114 | difficult_flag = True 115 | obj_name = obj.find('name').text 116 | if obj_name not in class_names: 117 | continue 118 | bndbox = obj.find('bndbox') 119 | left = bndbox.find('xmin').text 120 | top = bndbox.find('ymin').text 121 | right = bndbox.find('xmax').text 122 | bottom = bndbox.find('ymax').text 123 | 124 | if difficult_flag: 125 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 126 | else: 127 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 128 | print("Get ground truth result done.") 129 | 130 | if map_mode == 0 or map_mode == 3: 131 | print("Get map.") 132 | get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path) 133 | print("Get map done.") 134 | 135 | if map_mode == 4: 136 | print("Get map.") 137 | get_coco_map(class_names = class_names, path = map_out_path) 138 | print("Get map done.") 139 | -------------------------------------------------------------------------------- /CenterNet/model_data/people_classes.txt: -------------------------------------------------------------------------------- 1 | dog 2 | person 3 | cat 4 | car -------------------------------------------------------------------------------- /CenterNet/model_data/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/CenterNet/model_data/simhei.ttf -------------------------------------------------------------------------------- /CenterNet/nets/CenterNet_yolov7.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def autopad(k, p=None): 6 | if p is None: 7 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] 8 | return p 9 | 10 | class SiLU(nn.Module): 11 | @staticmethod 12 | def forward(x): 13 | return x * torch.sigmoid(x) 14 | 15 | class Conv(nn.Module): 16 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups 17 | super(Conv, self).__init__() 18 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 19 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) 20 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 21 | 22 | def forward(self, x): 23 | return self.act(self.bn(self.conv(x))) 24 | 25 | def fuseforward(self, x): 26 | return self.act(self.conv(x)) 27 | 28 | class Block(nn.Module): 29 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]): 30 | super(Block, self).__init__() 31 | c_ = int(c2 * e) 32 | 33 | self.ids = ids 34 | self.cv1 = Conv(c1, c_, 1, 1) 35 | self.cv2 = Conv(c1, c_, 1, 1) 36 | self.cv3 = nn.ModuleList( 37 | [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)] 38 | ) 39 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1) 40 | 41 | def forward(self, x): 42 | x_1 = self.cv1(x) 43 | x_2 = self.cv2(x) 44 | 45 | x_all = [x_1, x_2] 46 | for i in range(len(self.cv3)): 47 | x_2 = self.cv3[i](x_2) 48 | x_all.append(x_2) 49 | 50 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1)) 51 | return out 52 | 53 | class MP(nn.Module): 54 | def __init__(self, k=2): 55 | super(MP, self).__init__() 56 | self.m = nn.MaxPool2d(kernel_size=k, stride=k) 57 | 58 | def forward(self, x): 59 | return self.m(x) 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, c1, c2): 63 | super(Transition, self).__init__() 64 | self.cv1 = Conv(c1, c2, 1, 1) 65 | self.cv2 = Conv(c1, c2, 1, 1) 66 | self.cv3 = Conv(c2, c2, 3, 2) 67 | 68 | self.mp = MP() 69 | 70 | def forward(self, x): 71 | x_1 = self.mp(x) 72 | x_1 = self.cv1(x_1) 73 | 74 | x_2 = self.cv2(x) 75 | x_2 = self.cv3(x_2) 76 | 77 | return torch.cat([x_2, x_1], 1) 78 | 79 | class yolo_Backbone(nn.Module): 80 | def __init__(self, transition_channels, block_channels, n, phi, pretrained=False): 81 | super().__init__() 82 | #-----------------------------------------------# 83 | # 输入图片是640, 640, 3 84 | #-----------------------------------------------# 85 | ids = { 86 | 'l' : [-1, -3, -5, -6], 87 | 'x' : [-1, -3, -5, -7, -8], 88 | }[phi] 89 | self.stem = nn.Sequential( 90 | Conv(3, transition_channels, 3, 1), 91 | Conv(transition_channels, transition_channels * 2, 3, 2), 92 | Conv(transition_channels * 2, transition_channels * 2, 3, 1), 93 | ) 94 | self.dark2 = nn.Sequential( 95 | Conv(transition_channels * 2, transition_channels * 4, 3, 2), 96 | Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids), 97 | ) 98 | self.dark3 = nn.Sequential( 99 | Transition(transition_channels * 8, transition_channels * 4), 100 | Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids), 101 | ) 102 | self.dark4 = nn.Sequential( 103 | Transition(transition_channels * 16, transition_channels * 8), 104 | Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids), 105 | ) 106 | self.dark5 = nn.Sequential( 107 | Transition(transition_channels * 32, transition_channels * 16), 108 | Block(transition_channels * 32, block_channels * 8, transition_channels * 64, n=n, ids=ids), 109 | ) 110 | 111 | if pretrained: 112 | url = { 113 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth', 114 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth', 115 | }[phi] 116 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data") 117 | self.load_state_dict(checkpoint, strict=False) 118 | print("Load weights from " + url.split('/')[-1]) 119 | 120 | def forward(self, x): 121 | x = self.stem(x) 122 | x = self.dark2(x) 123 | #-----------------------------------------------# 124 | # dark3的输出为80, 80, 256,是一个有效特征层 125 | #-----------------------------------------------# 126 | x = self.dark3(x) 127 | feat1 = x 128 | #-----------------------------------------------# 129 | # dark4的输出为40, 40, 512,是一个有效特征层 130 | #-----------------------------------------------# 131 | x = self.dark4(x) 132 | feat2 = x 133 | #-----------------------------------------------# 134 | # dark5的输出为20, 20, 1024,是一个有效特征层 135 | #-----------------------------------------------# 136 | x = self.dark5(x) 137 | feat3 = x 138 | return feat3 139 | -------------------------------------------------------------------------------- /CenterNet/nets/centernet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | from torch import nn 5 | 6 | from CenterNet.nets.hourglass import * 7 | from CenterNet.nets.resnet50 import resnet50, resnet50_Decoder, resnet50_Head,yolov7 8 | 9 | 10 | class CenterNet_Resnet50(nn.Module): 11 | def __init__(self, num_classes = 20, pretrained = False): 12 | super(CenterNet_Resnet50, self).__init__() 13 | self.pretrained = pretrained 14 | # 512,512,3 -> 16,16,2048 15 | self.backbone = resnet50(pretrained = pretrained) 16 | # 16,16,2048 -> 128,128,64 17 | self.decoder = resnet50_Decoder(2048) 18 | #-----------------------------------------------------------------# 19 | # 对获取到的特征进行上采样,进行分类预测和回归预测 20 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes 21 | # -> 128, 128, 64 -> 128, 128, 2 22 | # -> 128, 128, 64 -> 128, 128, 2 23 | #-----------------------------------------------------------------# 24 | self.head = resnet50_Head(channel=64, num_classes=num_classes) 25 | 26 | self._init_weights() 27 | 28 | def freeze_backbone(self): 29 | for param in self.backbone.parameters(): 30 | param.requires_grad = False 31 | 32 | def unfreeze_backbone(self): 33 | for param in self.backbone.parameters(): 34 | param.requires_grad = True 35 | 36 | def _init_weights(self): 37 | if not self.pretrained: 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. / n)) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | 46 | self.head.cls_head[-1].weight.data.fill_(0) 47 | self.head.cls_head[-1].bias.data.fill_(-2.19) 48 | 49 | def forward(self, x): 50 | feat = self.backbone(x) 51 | return self.head(self.decoder(feat)) 52 | 53 | class CenterNet_HourglassNet(nn.Module): 54 | def __init__(self, heads, pretrained=False, num_stacks=2, n=5, cnv_dim=256, dims=[256, 256, 384, 384, 384, 512], modules = [2, 2, 2, 2, 2, 4]): 55 | super(CenterNet_HourglassNet, self).__init__() 56 | if pretrained: 57 | raise ValueError("HourglassNet has no pretrained model") 58 | 59 | self.nstack = num_stacks 60 | self.heads = heads 61 | 62 | curr_dim = dims[0] 63 | 64 | self.pre = nn.Sequential( 65 | conv2d(7, 3, 128, stride=2), 66 | residual(3, 128, 256, stride=2) 67 | ) 68 | 69 | self.kps = nn.ModuleList([ 70 | kp_module( 71 | n, dims, modules 72 | ) for _ in range(num_stacks) 73 | ]) 74 | 75 | self.cnvs = nn.ModuleList([ 76 | conv2d(3, curr_dim, cnv_dim) for _ in range(num_stacks) 77 | ]) 78 | 79 | self.inters = nn.ModuleList([ 80 | residual(3, curr_dim, curr_dim) for _ in range(num_stacks - 1) 81 | ]) 82 | 83 | self.inters_ = nn.ModuleList([ 84 | nn.Sequential( 85 | nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False), 86 | nn.BatchNorm2d(curr_dim) 87 | ) for _ in range(num_stacks - 1) 88 | ]) 89 | 90 | self.cnvs_ = nn.ModuleList([ 91 | nn.Sequential( 92 | nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False), 93 | nn.BatchNorm2d(curr_dim) 94 | ) for _ in range(num_stacks - 1) 95 | ]) 96 | 97 | for head in heads.keys(): 98 | if 'hm' in head: 99 | module = nn.ModuleList([ 100 | nn.Sequential( 101 | conv2d(3, cnv_dim, curr_dim, with_bn=False), 102 | nn.Conv2d(curr_dim, heads[head], (1, 1)) 103 | ) for _ in range(num_stacks) 104 | ]) 105 | self.__setattr__(head, module) 106 | for heat in self.__getattr__(head): 107 | heat[-1].weight.data.fill_(0) 108 | heat[-1].bias.data.fill_(-2.19) 109 | else: 110 | module = nn.ModuleList([ 111 | nn.Sequential( 112 | conv2d(3, cnv_dim, curr_dim, with_bn=False), 113 | nn.Conv2d(curr_dim, heads[head], (1, 1)) 114 | ) for _ in range(num_stacks) 115 | ]) 116 | self.__setattr__(head, module) 117 | 118 | 119 | self.relu = nn.ReLU(inplace=True) 120 | 121 | def freeze_backbone(self): 122 | freeze_list = [self.pre, self.kps] 123 | for module in freeze_list: 124 | for param in module.parameters(): 125 | param.requires_grad = False 126 | 127 | def unfreeze_backbone(self): 128 | freeze_list = [self.pre, self.kps] 129 | for module in freeze_list: 130 | for param in module.parameters(): 131 | param.requires_grad = True 132 | 133 | def forward(self, image): 134 | # print('image shape', image.shape) 135 | inter = self.pre(image) 136 | outs = [] 137 | 138 | for ind in range(self.nstack): 139 | kp = self.kps[ind](inter) 140 | cnv = self.cnvs[ind](kp) 141 | 142 | if ind < self.nstack - 1: 143 | inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv) 144 | inter = self.relu(inter) 145 | inter = self.inters[ind](inter) 146 | 147 | out = {} 148 | for head in self.heads: 149 | out[head] = self.__getattr__(head)[ind](cnv) 150 | outs.append(out) 151 | return outs 152 | def autopad(k, p=None): 153 | if p is None: 154 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] 155 | return p 156 | 157 | class SiLU(nn.Module): 158 | @staticmethod 159 | def forward(x): 160 | return x * torch.sigmoid(x) 161 | class Conv(nn.Module): 162 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups 163 | super(Conv, self).__init__() 164 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 165 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) 166 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 167 | 168 | def forward(self, x): 169 | return self.act(self.bn(self.conv(x))) 170 | 171 | def fuseforward(self, x): 172 | return self.act(self.conv(x)) 173 | 174 | 175 | class Block(nn.Module): 176 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]): 177 | super(Block, self).__init__() 178 | c_ = int(c2 * e) 179 | 180 | self.ids = ids 181 | self.cv1 = Conv(c1, c_, 1, 1) 182 | self.cv2 = Conv(c1, c_, 1, 1) 183 | self.cv3 = nn.ModuleList( 184 | [Conv(c_ if i == 0 else c2, c2, 3, 1) for i in range(n)] 185 | ) 186 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1) 187 | 188 | def forward(self, x): 189 | x_1 = self.cv1(x) 190 | x_2 = self.cv2(x) 191 | 192 | x_all = [x_1, x_2] 193 | for i in range(len(self.cv3)): 194 | x_2 = self.cv3[i](x_2) 195 | x_all.append(x_2) 196 | 197 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1)) 198 | return out 199 | 200 | 201 | class MP(nn.Module): 202 | def __init__(self, k=2): 203 | super(MP, self).__init__() 204 | self.m = nn.MaxPool2d(kernel_size=k, stride=k) 205 | 206 | def forward(self, x): 207 | return self.m(x) 208 | 209 | 210 | class Transition(nn.Module): 211 | def __init__(self, c1, c2): 212 | super(Transition, self).__init__() 213 | self.cv1 = Conv(c1, c2, 1, 1) 214 | self.cv2 = Conv(c1, c2, 1, 1) 215 | self.cv3 = Conv(c2, c2, 3, 2) 216 | 217 | self.mp = MP() 218 | 219 | def forward(self, x): 220 | x_1 = self.mp(x) 221 | x_1 = self.cv1(x_1) 222 | 223 | x_2 = self.cv2(x) 224 | x_2 = self.cv3(x_2) 225 | 226 | return torch.cat([x_2, x_1], 1) 227 | 228 | 229 | class CenterNet_yolov7(nn.Module): 230 | def __init__(self, num_classes=20, pretrained=False): 231 | super(CenterNet_yolov7, self).__init__() 232 | self.pretrained = pretrained 233 | # 512,512,3 -> 16,16,2048 234 | self.backbone = yolov7(pretrained=pretrained) 235 | # 16,16,2048 -> 128,128,64 236 | self.decoder = resnet50_Decoder(2048) 237 | # -----------------------------------------------------------------# 238 | # 对获取到的特征进行上采样,进行分类预测和回归预测 239 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes 240 | # -> 128, 128, 64 -> 128, 128, 2 241 | # -> 128, 128, 64 -> 128, 128, 2 242 | # -----------------------------------------------------------------# 243 | self.head = resnet50_Head(channel=64, num_classes=num_classes) 244 | 245 | self._init_weights() 246 | 247 | def freeze_backbone(self): 248 | for param in self.backbone.parameters(): 249 | param.requires_grad = False 250 | 251 | def unfreeze_backbone(self): 252 | for param in self.backbone.parameters(): 253 | param.requires_grad = True 254 | 255 | def _init_weights(self): 256 | if not self.pretrained: 257 | for m in self.modules(): 258 | if isinstance(m, nn.Conv2d): 259 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 260 | m.weight.data.normal_(0, math.sqrt(2. / n)) 261 | elif isinstance(m, nn.BatchNorm2d): 262 | m.weight.data.fill_(1) 263 | m.bias.data.zero_() 264 | 265 | self.head.cls_head[-1].weight.data.fill_(0) 266 | self.head.cls_head[-1].bias.data.fill_(-2.19) 267 | 268 | def forward(self, x): 269 | feat = self.backbone(x) 270 | return self.head(self.decoder(feat)) 271 | -------------------------------------------------------------------------------- /CenterNet/nets/centernet_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def focal_loss(pred, target): 9 | pred = pred.permute(0, 2, 3, 1) 10 | 11 | #-------------------------------------------------------------------------# 12 | # 找到每张图片的正样本和负样本 13 | # 一个真实框对应一个正样本 14 | # 除去正样本的特征点,其余为负样本 15 | #-------------------------------------------------------------------------# 16 | pos_inds = target.eq(1).float() 17 | neg_inds = target.lt(1).float() 18 | #-------------------------------------------------------------------------# 19 | # 正样本特征点附近的负样本的权值更小一些 20 | #-------------------------------------------------------------------------# 21 | neg_weights = torch.pow(1 - target, 4) 22 | 23 | pred = torch.clamp(pred, 1e-6, 1 - 1e-6) 24 | #-------------------------------------------------------------------------# 25 | # 计算focal loss。难分类样本权重大,易分类样本权重小。 26 | #-------------------------------------------------------------------------# 27 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 28 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds 29 | 30 | #-------------------------------------------------------------------------# 31 | # 进行损失的归一化 32 | #-------------------------------------------------------------------------# 33 | num_pos = pos_inds.float().sum() 34 | pos_loss = pos_loss.sum() 35 | neg_loss = neg_loss.sum() 36 | 37 | if num_pos == 0: 38 | loss = -neg_loss 39 | else: 40 | loss = -(pos_loss + neg_loss) / num_pos 41 | return loss 42 | 43 | def reg_l1_loss(pred, target, mask): 44 | #--------------------------------# 45 | # 计算l1_loss 46 | #--------------------------------# 47 | pred = pred.permute(0,2,3,1) 48 | expand_mask = torch.unsqueeze(mask,-1).repeat(1,1,1,2) 49 | 50 | loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum') 51 | loss = loss / (mask.sum() + 1e-4) 52 | return loss 53 | 54 | def weights_init(net, init_type='normal', init_gain=0.02): 55 | def init_func(m): 56 | classname = m.__class__.__name__ 57 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 58 | if init_type == 'normal': 59 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 60 | elif init_type == 'xavier': 61 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 62 | elif init_type == 'kaiming': 63 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 64 | elif init_type == 'orthogonal': 65 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 66 | else: 67 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 68 | elif classname.find('BatchNorm2d') != -1: 69 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 70 | torch.nn.init.constant_(m.bias.data, 0.0) 71 | print('initialize network with %s type' % init_type) 72 | net.apply(init_func) 73 | 74 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): 75 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 76 | if iters <= warmup_total_iters: 77 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 78 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start 79 | elif iters >= total_iters - no_aug_iter: 80 | lr = min_lr 81 | else: 82 | lr = min_lr + 0.5 * (lr - min_lr) * ( 83 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 84 | ) 85 | return lr 86 | 87 | def step_lr(lr, decay_rate, step_size, iters): 88 | if step_size < 1: 89 | raise ValueError("step_size must above 1.") 90 | n = iters // step_size 91 | out_lr = lr * decay_rate ** n 92 | return out_lr 93 | 94 | if lr_decay_type == "cos": 95 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 96 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 97 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 98 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 99 | else: 100 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 101 | step_size = total_iters / step_num 102 | func = partial(step_lr, lr, decay_rate, step_size) 103 | 104 | return func 105 | 106 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 107 | lr = lr_scheduler_func(epoch) 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr 110 | -------------------------------------------------------------------------------- /CenterNet/nets/hourglass.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | #-------------------------# 7 | # 卷积+标准化+激活函数 8 | #-------------------------# 9 | class conv2d(nn.Module): 10 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 11 | super(conv2d, self).__init__() 12 | 13 | pad = (k - 1) // 2 14 | self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn) 15 | self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential() 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | conv = self.conv(x) 20 | bn = self.bn(conv) 21 | relu = self.relu(bn) 22 | return relu 23 | 24 | #-------------------------# 25 | # 残差结构 26 | #-------------------------# 27 | class residual(nn.Module): 28 | def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True): 29 | super(residual, self).__init__() 30 | 31 | self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False) 32 | self.bn1 = nn.BatchNorm2d(out_dim) 33 | self.relu1 = nn.ReLU(inplace=True) 34 | 35 | self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False) 36 | self.bn2 = nn.BatchNorm2d(out_dim) 37 | 38 | self.skip = nn.Sequential( 39 | nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False), 40 | nn.BatchNorm2d(out_dim) 41 | ) if stride != 1 or inp_dim != out_dim else nn.Sequential() 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, x): 45 | conv1 = self.conv1(x) 46 | bn1 = self.bn1(conv1) 47 | relu1 = self.relu1(bn1) 48 | 49 | conv2 = self.conv2(relu1) 50 | bn2 = self.bn2(conv2) 51 | 52 | skip = self.skip(x) 53 | return self.relu(bn2 + skip) 54 | 55 | def make_layer(k, inp_dim, out_dim, modules, **kwargs): 56 | layers = [residual(k, inp_dim, out_dim, **kwargs)] 57 | for _ in range(modules - 1): 58 | layers.append(residual(k, out_dim, out_dim, **kwargs)) 59 | return nn.Sequential(*layers) 60 | 61 | def make_hg_layer(k, inp_dim, out_dim, modules, **kwargs): 62 | layers = [residual(k, inp_dim, out_dim, stride=2)] 63 | for _ in range(modules - 1): 64 | layers += [residual(k, out_dim, out_dim)] 65 | return nn.Sequential(*layers) 66 | 67 | def make_layer_revr(k, inp_dim, out_dim, modules, **kwargs): 68 | layers = [] 69 | for _ in range(modules - 1): 70 | layers.append(residual(k, inp_dim, inp_dim, **kwargs)) 71 | layers.append(residual(k, inp_dim, out_dim, **kwargs)) 72 | return nn.Sequential(*layers) 73 | 74 | 75 | class kp_module(nn.Module): 76 | def __init__(self, n, dims, modules, **kwargs): 77 | super(kp_module, self).__init__() 78 | self.n = n 79 | 80 | curr_mod = modules[0] 81 | next_mod = modules[1] 82 | 83 | curr_dim = dims[0] 84 | next_dim = dims[1] 85 | 86 | # 将输入进来的特征层进行两次残差卷积,便于和后面的层进行融合 87 | self.up1 = make_layer( 88 | 3, curr_dim, curr_dim, curr_mod, **kwargs 89 | ) 90 | 91 | # 进行下采样 92 | self.low1 = make_hg_layer( 93 | 3, curr_dim, next_dim, curr_mod, **kwargs 94 | ) 95 | 96 | # 构建U形结构的下一层 97 | if self.n > 1 : 98 | self.low2 = kp_module( 99 | n - 1, dims[1:], modules[1:], **kwargs 100 | ) 101 | else: 102 | self.low2 = make_layer( 103 | 3, next_dim, next_dim, next_mod, **kwargs 104 | ) 105 | 106 | # 将U形结构下一层反馈上来的层进行残差卷积 107 | self.low3 = make_layer_revr( 108 | 3, next_dim, curr_dim, curr_mod, **kwargs 109 | ) 110 | # 将U形结构下一层反馈上来的层进行上采样 111 | self.up2 = nn.Upsample(scale_factor=2) 112 | 113 | def forward(self, x): 114 | up1 = self.up1(x) 115 | low1 = self.low1(x) 116 | low2 = self.low2(low1) 117 | low3 = self.low3(low2) 118 | up2 = self.up2(low3) 119 | outputs = up1 + up2 120 | return outputs 121 | 122 | -------------------------------------------------------------------------------- /CenterNet/nets/resnet50.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import math 4 | import torch.nn as nn 5 | from torch.hub import load_state_dict_from_url 6 | from CenterNet.nets.CenterNet_yolov7 import yolo_Backbone 7 | model_urls = { 8 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 9 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 10 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 11 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 12 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 13 | } 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(Bottleneck, self).__init__() 20 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 23 | padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * 4) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv3(out) 43 | out = self.bn3(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | #-----------------------------------------------------------------# 54 | # 使用Renset50作为主干特征提取网络,最终会获得一个 55 | # 16x16x2048的有效特征层 56 | #-----------------------------------------------------------------# 57 | class ResNet(nn.Module): 58 | def __init__(self, block, layers, num_classes=1000): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | # 512,512,3 -> 256,256,64 62 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | # 256x256x64 -> 128x128x64 66 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 67 | 68 | # 128x128x64 -> 128x128x256 69 | self.layer1 = self._make_layer(block, 64, layers[0]) 70 | 71 | # 128x128x256 -> 64x64x512 72 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 73 | 74 | # 64x64x512 -> 32x32x1024 75 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 76 | 77 | # 32x32x1024 -> 16x16x2048 78 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 79 | 80 | self.avgpool = nn.AvgPool2d(7) 81 | self.fc = nn.Linear(512 * block.expansion, num_classes) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv2d(self.inplanes, planes * block.expansion, 96 | kernel_size=1, stride=stride, bias=False), 97 | nn.BatchNorm2d(planes * block.expansion), 98 | ) 99 | 100 | layers = [] 101 | layers.append(block(self.inplanes, planes, stride, downsample)) 102 | self.inplanes = planes * block.expansion 103 | for i in range(1, blocks): 104 | layers.append(block(self.inplanes, planes)) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.maxpool(x) 113 | 114 | x = self.layer1(x) 115 | x = self.layer2(x) 116 | x = self.layer3(x) 117 | x = self.layer4(x) 118 | 119 | x = self.avgpool(x) 120 | x = x.view(x.size(0), -1) 121 | x = self.fc(x) 122 | 123 | return x 124 | 125 | def resnet50(pretrained = True): 126 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 127 | if pretrained: 128 | state_dict = load_state_dict_from_url(model_urls['resnet50'], model_dir = 'model_data/') 129 | model.load_state_dict(state_dict) 130 | #----------------------------------------------------------# 131 | # 获取特征提取部分 132 | #----------------------------------------------------------# 133 | features = list([model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2, model.layer3, model.layer4]) 134 | features = nn.Sequential(*features) 135 | return features 136 | def yolov7(pretrained = True): 137 | phi = 'l' 138 | transition_channels = {'l': 32, 'x': 40}[phi] 139 | block_channels = 32 140 | n = {'l': 4, 'x': 6}[phi] 141 | model = yolo_Backbone(transition_channels, block_channels, n, phi, pretrained=False) 142 | if pretrained: 143 | state_dict = load_state_dict_from_url(model_urls['yolov7'], model_dir = 'model_data/') 144 | model.load_state_dict(state_dict) 145 | #----------------------------------------------------------# 146 | # 获取特征提取部分 147 | #----------------------------------------------------------# 148 | features = list([model.stem, model.dark2, model.dark3, model.dark4, model.dark5]) 149 | features = nn.Sequential(*features) 150 | return features 151 | class resnet50_Decoder(nn.Module): 152 | def __init__(self, inplanes, bn_momentum=0.1): 153 | super(resnet50_Decoder, self).__init__() 154 | self.bn_momentum = bn_momentum 155 | self.inplanes = inplanes 156 | self.deconv_with_bias = False 157 | 158 | #----------------------------------------------------------# 159 | # 16,16,2048 -> 32,32,256 -> 64,64,128 -> 128,128,64 160 | # 利用ConvTranspose2d进行上采样。 161 | # 每次特征层的宽高变为原来的两倍。 162 | #----------------------------------------------------------# 163 | self.deconv_layers = self._make_deconv_layer( 164 | num_layers=3, 165 | num_filters=[256, 128, 64], 166 | num_kernels=[4, 4, 4], 167 | ) 168 | 169 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 170 | layers = [] 171 | for i in range(num_layers): 172 | kernel = num_kernels[i] 173 | planes = num_filters[i] 174 | 175 | layers.append( 176 | nn.ConvTranspose2d( 177 | in_channels=self.inplanes, 178 | out_channels=planes, 179 | kernel_size=kernel, 180 | stride=2, 181 | padding=1, 182 | output_padding=0, 183 | bias=self.deconv_with_bias)) 184 | layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum)) 185 | layers.append(nn.ReLU(inplace=True)) 186 | self.inplanes = planes 187 | return nn.Sequential(*layers) 188 | 189 | def forward(self, x): 190 | return self.deconv_layers(x) 191 | 192 | 193 | class resnet50_Head(nn.Module): 194 | def __init__(self, num_classes=80, channel=64, bn_momentum=0.1): 195 | super(resnet50_Head, self).__init__() 196 | #-----------------------------------------------------------------# 197 | # 对获取到的特征进行上采样,进行分类预测和回归预测 198 | # 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes 199 | # -> 128, 128, 64 -> 128, 128, 2 200 | # -> 128, 128, 64 -> 128, 128, 2 201 | #-----------------------------------------------------------------# 202 | # 热力图预测部分 203 | self.cls_head = nn.Sequential( 204 | nn.Conv2d(64, channel, 205 | kernel_size=3, padding=1, bias=False), 206 | nn.BatchNorm2d(64, momentum=bn_momentum), 207 | nn.ReLU(inplace=True), 208 | nn.Conv2d(channel, num_classes, 209 | kernel_size=1, stride=1, padding=0)) 210 | # 宽高预测的部分 211 | self.wh_head = nn.Sequential( 212 | nn.Conv2d(64, channel, 213 | kernel_size=3, padding=1, bias=False), 214 | nn.BatchNorm2d(64, momentum=bn_momentum), 215 | nn.ReLU(inplace=True), 216 | nn.Conv2d(channel, 2, 217 | kernel_size=1, stride=1, padding=0)) 218 | 219 | # 中心点预测的部分 220 | self.reg_head = nn.Sequential( 221 | nn.Conv2d(64, channel, 222 | kernel_size=3, padding=1, bias=False), 223 | nn.BatchNorm2d(64, momentum=bn_momentum), 224 | nn.ReLU(inplace=True), 225 | nn.Conv2d(channel, 2, 226 | kernel_size=1, stride=1, padding=0)) 227 | 228 | def forward(self, x): 229 | hm = self.cls_head(x).sigmoid_() 230 | wh = self.wh_head(x) 231 | offset = self.reg_head(x) 232 | return hm, wh, offset 233 | 234 | -------------------------------------------------------------------------------- /CenterNet/predict.py: -------------------------------------------------------------------------------- 1 | #-----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #-----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from centernet import CenterNet 12 | 13 | if __name__ == "__main__": 14 | centernet = CenterNet() 15 | #----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | #----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | #-------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | #-------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | #----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | #----------------------------------------------------------------------------------------------------------# 42 | video_path = 0 43 | video_save_path = "" 44 | video_fps = 25.0 45 | #----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | #----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = r"D:\yolov7-pytorch-master\img\000001.jpg" 53 | #-------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | #-------------------------------------------------------------------------# 59 | dir_origin_path = r"D:\yolov7-pytorch-master\img\000001.jpg" 60 | dir_save_path = "img_out/" 61 | #-------------------------------------------------------------------------# 62 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 63 | # 64 | # heatmap_save_path仅在mode='heatmap'有效 65 | #-------------------------------------------------------------------------# 66 | heatmap_save_path = "model_data/heatmap_vision.png" 67 | #-------------------------------------------------------------------------# 68 | # simplify 使用Simplify onnx 69 | # onnx_save_path 指定了onnx的保存路径 70 | #-------------------------------------------------------------------------# 71 | simplify = True 72 | onnx_save_path = "model_data/models.onnx" 73 | 74 | if mode == "predict": 75 | ''' 76 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 77 | 2、如果想要获得预测框的坐标,可以进入centernet.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 78 | 3、如果想要利用预测框截取下目标,可以进入centernet.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 79 | 在原图上利用矩阵的方式进行截取。 80 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入centernet.detect_image函数,在绘图部分对predicted_class进行判断, 81 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 82 | ''' 83 | while True: 84 | img = input('Input image filename:') 85 | try: 86 | image = Image.open(img) 87 | except: 88 | print('Open Error! Try again!') 89 | continue 90 | else: 91 | r_image = centernet.detect_image(image, crop = crop, count=count) 92 | r_image.show() 93 | 94 | elif mode == "video": 95 | capture = cv2.VideoCapture(video_path) 96 | if video_save_path!="": 97 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 98 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 99 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 100 | 101 | ref, frame = capture.read() 102 | if not ref: 103 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 104 | 105 | fps = 0.0 106 | while(True): 107 | t1 = time.time() 108 | # 读取某一帧 109 | ref, frame = capture.read() 110 | if not ref: 111 | break 112 | # 格式转变,BGRtoRGB 113 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 114 | # 转变成Image 115 | frame = Image.fromarray(np.uint8(frame)) 116 | # 进行检测 117 | frame = np.array(centernet.detect_image(frame)) 118 | # RGBtoBGR满足opencv显示格式 119 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 120 | 121 | fps = ( fps + (1./(time.time()-t1)) ) / 2 122 | print("fps= %.2f"%(fps)) 123 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 124 | 125 | cv2.imshow("video",frame) 126 | c= cv2.waitKey(1) & 0xff 127 | if video_save_path!="": 128 | out.write(frame) 129 | 130 | if c==27: 131 | capture.release() 132 | break 133 | 134 | print("Video Detection Done!") 135 | capture.release() 136 | if video_save_path!="": 137 | print("Save processed video to the path :" + video_save_path) 138 | out.release() 139 | cv2.destroyAllWindows() 140 | 141 | elif mode == "fps": 142 | img = Image.open(fps_image_path) 143 | tact_time = centernet.get_FPS(img, test_interval) 144 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 145 | 146 | elif mode == "dir_predict": 147 | import os 148 | 149 | from tqdm import tqdm 150 | 151 | img_names = os.listdir(dir_origin_path) 152 | for img_name in tqdm(img_names): 153 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 154 | image_path = os.path.join(dir_origin_path, img_name) 155 | image = Image.open(image_path) 156 | r_image = centernet.detect_image(image) 157 | if not os.path.exists(dir_save_path): 158 | os.makedirs(dir_save_path) 159 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 160 | 161 | elif mode == "heatmap": 162 | while True: 163 | img = input('Input image filename:') 164 | try: 165 | image = Image.open(img) 166 | except: 167 | print('Open Error! Try again!') 168 | continue 169 | else: 170 | centernet.detect_heatmap(image, heatmap_save_path) 171 | 172 | elif mode == "export_onnx": 173 | centernet.convert_to_onnx(simplify, onnx_save_path) 174 | 175 | else: 176 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.") 177 | -------------------------------------------------------------------------------- /CenterNet/summary.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------# 2 | # 该部分代码用于看网络参数 3 | #--------------------------------------------# 4 | import torch 5 | from thop import clever_format, profile 6 | from torchsummary import summary 7 | 8 | from nets.centernet import CenterNet_HourglassNet, CenterNet_Resnet50 9 | 10 | if __name__ == "__main__": 11 | input_shape = [512, 512] 12 | num_classes = 20 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | model = CenterNet_Resnet50().to(device) 16 | summary(model, (3, input_shape[0], input_shape[1])) 17 | 18 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) 19 | flops, params = profile(model.to(device), (dummy_input, ), verbose=False) 20 | #--------------------------------------------------------# 21 | # flops * 2是因为profile没有将卷积作为两个operations 22 | # 有些论文将卷积算乘法、加法两个operations。此时乘2 23 | # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 24 | # 本代码选择乘2,参考YOLOX。 25 | #--------------------------------------------------------# 26 | flops = flops * 2 27 | flops, params = clever_format([flops, params], "%.3f") 28 | print('Total GFLOPS: %s' % (flops)) 29 | print('Total params: %s' % (params)) 30 | -------------------------------------------------------------------------------- /CenterNet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /CenterNet/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | #---------------------------------------------------------# 5 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 6 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 7 | #---------------------------------------------------------# 8 | def cvtColor(image): 9 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 10 | return image 11 | else: 12 | image = image.convert('RGB') 13 | return image 14 | 15 | #---------------------------------------------------# 16 | # 对输入图像进行resize 17 | #---------------------------------------------------# 18 | def resize_image(image, size, letterbox_image): 19 | iw, ih = image.size 20 | w, h = size 21 | if letterbox_image: 22 | scale = min(w/iw, h/ih) 23 | nw = int(iw*scale) 24 | nh = int(ih*scale) 25 | 26 | image = image.resize((nw,nh), Image.BICUBIC) 27 | new_image = Image.new('RGB', size, (128,128,128)) 28 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 29 | else: 30 | new_image = image.resize((w, h), Image.BICUBIC) 31 | return new_image 32 | 33 | #---------------------------------------------------# 34 | # 获得类 35 | #---------------------------------------------------# 36 | def get_classes(classes_path): 37 | with open(classes_path, encoding='utf-8') as f: 38 | class_names = f.readlines() 39 | class_names = [c.strip() for c in class_names] 40 | return class_names, len(class_names) 41 | 42 | def get_lr(optimizer): 43 | for param_group in optimizer.param_groups: 44 | return param_group['lr'] 45 | 46 | def preprocess_input(image): 47 | image = np.array(image,dtype = np.float32)[:, :,::-1] 48 | mean = [0.40789655, 0.44719303, 0.47026116] 49 | std = [0.2886383, 0.27408165, 0.27809834] 50 | return (image / 255. - mean) / std 51 | 52 | def show_config(**kwargs): 53 | print('Configurations:') 54 | print('-' * 70) 55 | print('|%25s | %40s|' % ('keys', 'values')) 56 | print('-' * 70) 57 | for key, value in kwargs.items(): 58 | print('|%25s | %40s|' % (str(key), str(value))) 59 | print('-' * 70) 60 | 61 | def download_weights(backbone, model_dir="./model_data"): 62 | import os 63 | from torch.hub import load_state_dict_from_url 64 | 65 | if backbone == "hourglass": 66 | raise ValueError("HourglassNet has no pretrained model") 67 | 68 | download_urls = { 69 | 'resnet50' : 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 70 | } 71 | url = download_urls[backbone] 72 | 73 | if not os.path.exists(model_dir): 74 | os.makedirs(model_dir) 75 | load_state_dict_from_url(url, model_dir) -------------------------------------------------------------------------------- /CenterNet/utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from nets.centernet_training import focal_loss, reg_l1_loss 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_lr 8 | 9 | 10 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, backbone, save_period, save_dir, local_rank=0): 11 | total_r_loss = 0 12 | total_c_loss = 0 13 | total_loss = 0 14 | val_loss = 0 15 | 16 | if local_rank == 0: 17 | print('Start Train') 18 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 19 | model_train.train() 20 | for iteration, batch in enumerate(gen): 21 | if iteration >= epoch_step: 22 | break 23 | with torch.no_grad(): 24 | if cuda: 25 | batch = [ann.cuda(local_rank) for ann in batch] 26 | batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks = batch 27 | 28 | #----------------------# 29 | # 清零梯度 30 | #----------------------# 31 | optimizer.zero_grad() 32 | if not fp16: 33 | if backbone=="resnet50": 34 | hm, wh, offset = model_train(batch_images) 35 | c_loss = focal_loss(hm, batch_hms) 36 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 37 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 38 | 39 | loss = c_loss + wh_loss + off_loss 40 | 41 | total_loss += loss.item() 42 | total_c_loss += c_loss.item() 43 | total_r_loss += wh_loss.item() + off_loss.item() 44 | elif backbone == "yolov7_brackbone": 45 | hm, wh, offset = model_train(batch_images) 46 | c_loss = focal_loss(hm, batch_hms) 47 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 48 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 49 | 50 | loss = c_loss + wh_loss + off_loss 51 | 52 | total_loss += loss.item() 53 | total_c_loss += c_loss.item() 54 | total_r_loss += wh_loss.item() + off_loss.item() 55 | 56 | elif backbone=="hourglass": 57 | outputs = model_train(batch_images) 58 | loss = 0 59 | c_loss_all = 0 60 | r_loss_all = 0 61 | index = 0 62 | for output in outputs: 63 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"] 64 | c_loss = focal_loss(hm, batch_hms) 65 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 66 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 67 | 68 | loss += c_loss + wh_loss + off_loss 69 | 70 | c_loss_all += c_loss 71 | r_loss_all += wh_loss + off_loss 72 | index += 1 73 | total_loss += loss.item() / index 74 | total_c_loss += c_loss_all.item() / index 75 | total_r_loss += r_loss_all.item() / index 76 | loss.backward() 77 | optimizer.step() 78 | else: 79 | from torch.cuda.amp import autocast 80 | with autocast(): 81 | if backbone=="resnet50": 82 | hm, wh, offset = model_train(batch_images) 83 | c_loss = focal_loss(hm, batch_hms) 84 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 85 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 86 | 87 | loss = c_loss + wh_loss + off_loss 88 | 89 | total_loss += loss.item() 90 | total_c_loss += c_loss.item() 91 | total_r_loss += wh_loss.item() + off_loss.item() 92 | if backbone == "yolov7_brackbone": 93 | hm, wh, offset = model_train(batch_images) 94 | c_loss = focal_loss(hm, batch_hms) 95 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 96 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 97 | 98 | loss = c_loss + wh_loss + off_loss 99 | 100 | total_loss += loss.item() 101 | total_c_loss += c_loss.item() 102 | total_r_loss += wh_loss.item() + off_loss.item() 103 | 104 | elif backbone=="hourglass": 105 | outputs = model_train(batch_images) 106 | loss = 0 107 | c_loss_all = 0 108 | r_loss_all = 0 109 | index = 0 110 | for output in outputs: 111 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"] 112 | c_loss = focal_loss(hm, batch_hms) 113 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 114 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 115 | 116 | loss += c_loss + wh_loss + off_loss 117 | 118 | c_loss_all += c_loss 119 | r_loss_all += wh_loss + off_loss 120 | index += 1 121 | total_loss += loss.item() / index 122 | total_c_loss += c_loss_all.item() / index 123 | total_r_loss += r_loss_all.item() / index 124 | 125 | #----------------------# 126 | # 反向传播 127 | #----------------------# 128 | scaler.scale(loss).backward() 129 | scaler.step(optimizer) 130 | scaler.update() 131 | 132 | if local_rank == 0: 133 | pbar.set_postfix(**{'total_r_loss' : total_r_loss / (iteration + 1), 134 | 'total_c_loss' : total_c_loss / (iteration + 1), 135 | 'lr' : get_lr(optimizer)}) 136 | pbar.update(1) 137 | 138 | if local_rank == 0: 139 | pbar.close() 140 | print('Finish Train') 141 | print('Start Validation') 142 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 143 | 144 | model_train.eval() 145 | for iteration, batch in enumerate(gen_val): 146 | if iteration >= epoch_step_val: 147 | break 148 | 149 | with torch.no_grad(): 150 | if cuda: 151 | batch = [ann.cuda(local_rank) for ann in batch] 152 | batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks = batch 153 | 154 | if backbone=="resnet50": 155 | hm, wh, offset = model_train(batch_images) 156 | c_loss = focal_loss(hm, batch_hms) 157 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 158 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 159 | 160 | loss = c_loss + wh_loss + off_loss 161 | 162 | val_loss += loss.item() 163 | elif backbone=="yolov7_brackbone": 164 | hm, wh, offset = model_train(batch_images) 165 | c_loss = focal_loss(hm, batch_hms) 166 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 167 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 168 | 169 | loss = c_loss + wh_loss + off_loss 170 | 171 | val_loss += loss.item() 172 | elif backbone=="hourglass": 173 | outputs = model_train(batch_images) 174 | index = 0 175 | loss = 0 176 | for output in outputs: 177 | hm, wh, offset = output["hm"].sigmoid(), output["wh"], output["reg"] 178 | c_loss = focal_loss(hm, batch_hms) 179 | wh_loss = 0.1 * reg_l1_loss(wh, batch_whs, batch_reg_masks) 180 | off_loss = reg_l1_loss(offset, batch_regs, batch_reg_masks) 181 | 182 | loss += c_loss + wh_loss + off_loss 183 | index += 1 184 | val_loss += loss.item() / index 185 | 186 | if local_rank == 0: 187 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 188 | pbar.update(1) 189 | 190 | if local_rank == 0: 191 | pbar.close() 192 | print('Finish Validation') 193 | loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val) 194 | eval_callback.on_epoch_end(epoch + 1, model_train) 195 | print('Epoch:'+ str(epoch+1) + '/' + str(Epoch)) 196 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) 197 | 198 | #-----------------------------------------------# 199 | # 保存权值 200 | #-----------------------------------------------# 201 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 202 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val))) 203 | 204 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 205 | print('Save best model to best_epoch_weights.pth') 206 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth")) 207 | 208 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /CenterNet/vision_for_centernet.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | if __name__ == "__main__": 5 | height, width, feat_stride = 128,128,1 6 | 7 | fig = plt.figure() 8 | ax = fig.add_subplot(121) 9 | plt.ylim(-10,17) 10 | plt.xlim(-10,17) 11 | 12 | shift_x = np.arange(0, width * feat_stride, feat_stride) 13 | shift_y = np.arange(0, height * feat_stride, feat_stride) 14 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 15 | boxes = np.stack([shift_x,shift_y,shift_x,shift_y],axis=-1).reshape([-1,4]).astype(np.float32) 16 | plt.scatter(boxes[3:,0],boxes[3:,1]) 17 | plt.scatter(boxes[0:3,0],boxes[0:3,1],c="r") 18 | ax.invert_yaxis() 19 | 20 | ax = fig.add_subplot(122) 21 | plt.ylim(-10,17) 22 | plt.xlim(-10,17) 23 | 24 | shift_x = np.arange(0, width * feat_stride, feat_stride) 25 | shift_y = np.arange(0, height * feat_stride, feat_stride) 26 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 27 | boxes = np.stack([shift_x,shift_y,shift_x,shift_y],axis=-1).reshape([-1,4]).astype(np.float32) 28 | plt.scatter(shift_x,shift_y) 29 | 30 | heatmap = np.random.uniform(0,1,[128,128,80]).reshape([-1,80]) 31 | reg = np.random.uniform(0,1,[128,128,2]).reshape([-1,2]) 32 | wh = np.random.uniform(5,20,[128,128,2]).reshape([-1,2]) 33 | 34 | boxes[:,:2] = boxes[:,:2] + reg 35 | boxes[:,2:] = boxes[:,2:] + reg 36 | plt.scatter(boxes[0:3,0],boxes[0:3,1]) 37 | boxes[:,:2] = boxes[:,:2] - wh/2 38 | boxes[:,2:] = boxes[:,2:] + wh/2 39 | for i in [0,1,2]: 40 | rect = plt.Rectangle([boxes[i, 0],boxes[i, 1]], wh[i,0], wh[i,1], color="r",fill=False) 41 | ax.add_patch(rect) 42 | 43 | ax.invert_yaxis() 44 | 45 | plt.show() 46 | -------------------------------------------------------------------------------- /CenterNet/voc_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | 5 | import numpy as np 6 | 7 | from utils.utils import get_classes 8 | 9 | #--------------------------------------------------------------------------------------------------------------------------------# 10 | # annotation_mode用于指定该文件运行时计算的内容 11 | # annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt 12 | # annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt 13 | # annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt 14 | #--------------------------------------------------------------------------------------------------------------------------------# 15 | annotation_mode = 0 16 | #-------------------------------------------------------------------# 17 | # 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息 18 | # 与训练和预测所用的classes_path一致即可 19 | # 如果生成的2007_train.txt里面没有目标信息 20 | # 那么就是因为classes没有设定正确 21 | # 仅在annotation_mode为0和2的时候有效 22 | #-------------------------------------------------------------------# 23 | classes_path =r'model_data\people_classes.txt' 24 | #--------------------------------------------------------------------------------------------------------------------------------# 25 | # trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1 26 | # train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1 27 | # 仅在annotation_mode为0和1的时候有效 28 | #--------------------------------------------------------------------------------------------------------------------------------# 29 | trainval_percent = 0.8 30 | train_percent = 0.8 31 | #-------------------------------------------------------# 32 | # 指向VOC数据集所在的文件夹 33 | # 默认指向根目录下的VOC数据集 34 | #-------------------------------------------------------# 35 | VOCdevkit_path = 'VOCdevkit' 36 | 37 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')] 38 | classes, _ = get_classes(classes_path) 39 | 40 | #-------------------------------------------------------# 41 | # 统计目标数量 42 | #-------------------------------------------------------# 43 | photo_nums = np.zeros(len(VOCdevkit_sets)) 44 | nums = np.zeros(len(classes)) 45 | def convert_annotation(year, image_id, list_file): 46 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8') 47 | tree=ET.parse(in_file) 48 | root = tree.getroot() 49 | 50 | for obj in root.iter('object'): 51 | difficult = 0 52 | if obj.find('difficult')!=None: 53 | difficult = obj.find('difficult').text 54 | cls = obj.find('name').text 55 | if cls not in classes or int(difficult)==1: 56 | continue 57 | cls_id = classes.index(cls) 58 | xmlbox = obj.find('bndbox') 59 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text))) 60 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) 61 | 62 | nums[classes.index(cls)] = nums[classes.index(cls)] + 1 63 | 64 | if __name__ == "__main__": 65 | random.seed(0) 66 | if " " in os.path.abspath(VOCdevkit_path): 67 | raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。") 68 | 69 | if annotation_mode == 0 or annotation_mode == 1: 70 | print("Generate txt in ImageSets.") 71 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations') 72 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main') 73 | temp_xml = os.listdir(xmlfilepath) 74 | total_xml = [] 75 | for xml in temp_xml: 76 | if xml.endswith(".xml"): 77 | total_xml.append(xml) 78 | 79 | num = len(total_xml) 80 | list = range(num) 81 | tv = int(num*trainval_percent) 82 | tr = int(tv*train_percent) 83 | trainval= random.sample(list,tv) 84 | train = random.sample(trainval,tr) 85 | 86 | print("train and val size",tv) 87 | print("train size",tr) 88 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') 89 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') 90 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') 91 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w') 92 | 93 | for i in list: 94 | name=total_xml[i][:-4]+'\n' 95 | if i in trainval: 96 | ftrainval.write(name) 97 | if i in train: 98 | ftrain.write(name) 99 | else: 100 | fval.write(name) 101 | else: 102 | ftest.write(name) 103 | 104 | ftrainval.close() 105 | ftrain.close() 106 | fval.close() 107 | ftest.close() 108 | print("Generate txt in ImageSets done.") 109 | 110 | if annotation_mode == 0 or annotation_mode == 2: 111 | print("Generate 2007_train.txt and 2007_val.txt for train.") 112 | type_index = 0 113 | for year, image_set in VOCdevkit_sets: 114 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split() 115 | list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8') 116 | for image_id in image_ids: 117 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id)) 118 | 119 | convert_annotation(year, image_id, list_file) 120 | list_file.write('\n') 121 | photo_nums[type_index] = len(image_ids) 122 | type_index += 1 123 | list_file.close() 124 | print("Generate 2007_train.txt and 2007_val.txt for train done.") 125 | 126 | def printTable(List1, List2): 127 | for i in range(len(List1[0])): 128 | print("|", end=' ') 129 | for j in range(len(List1)): 130 | print(List1[j][i].rjust(int(List2[j])), end=' ') 131 | print("|", end=' ') 132 | print() 133 | 134 | str_nums = [str(int(x)) for x in nums] 135 | tableData = [ 136 | classes, str_nums 137 | ] 138 | colWidths = [0]*len(tableData) 139 | len1 = 0 140 | for i in range(len(tableData)): 141 | for j in range(len(tableData[i])): 142 | if len(tableData[i][j]) > colWidths[i]: 143 | colWidths[i] = len(tableData[i][j]) 144 | printTable(tableData, colWidths) 145 | 146 | if photo_nums[0] <= 500: 147 | print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。") 148 | 149 | if np.sum(nums) == 0: 150 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 151 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 152 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 153 | print("(重要的事情说三遍)。") 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProbDet:基于概率决策融合的多模态目标检测 2 | 3 | --- 4 | 5 | ## 目录 6 | 1. [仓库更新 Top News](#仓库更新) 7 | 2. [训练步骤 How2train](#训练步骤) 8 | 3. [预测步骤 How2predict](#预测步骤) 9 | 4. [评估步骤 How2eval](#评估步骤) 10 | 5. [融合步骤 How2fusion](#融合步骤) 11 | 6. [参考资料 Reference](#Reference) 12 | 13 | ## Top News 14 | **`2023-01-10`**:**仓库创建,支持yolov7、CenterNet两种目标检测器** 15 | **`2023-01-14`**:**添加基于尺度不变特征和注意力机制的改进yolov7** 16 | 17 | ## 所需环境 18 | torch>=1.2.0 19 | 为了使用amp混合精度,推荐使用torch1.7.1以上的版本。 20 | 21 | 22 | ## 训练步骤(检测器) 23 | ### a、训练VOC07+12数据集 24 | 1. 数据集的准备 25 | **本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录** 26 | 27 | 2. 数据集的处理 28 | 修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。 29 | 30 | 3. 开始网络训练 31 | train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。 32 | 33 | 4. 训练结果预测 34 | 训练结果预测需要用到两个文件,分别是yolo.py和predict.py。我们首先需要去yolo.py里面修改model_path以及classes_path,这两个参数必须要修改。 35 | **model_path指向训练好的权值文件,在logs文件夹里。 36 | classes_path指向检测类别所对应的txt。** 37 | 完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。 38 | 39 | ### b、训练自己的数据集 40 | 1. 数据集的准备 41 | **本文使用VOC格式进行训练,训练前需要自己制作好数据集,** 42 | 训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中。 43 | 训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。 44 | 45 | 2. 数据集的处理 46 | 在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。 47 | 修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。 48 | 训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。 49 | model_data/cls_classes.txt文件内容为: 50 | ```python 51 | cat 52 | dog 53 | ... 54 | ``` 55 | 修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。 56 | 57 | 3. 开始网络训练 58 | **训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。** 59 | **classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!** 60 | 修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。 61 | 62 | 4. 训练结果预测 63 | 训练结果预测需要用到两个文件,分别是yolo.py和predict.py。在yolo.py里面修改model_path以及classes_path。 64 | **model_path指向训练好的权值文件,在logs文件夹里。 65 | classes_path指向检测类别所对应的txt。** 66 | 完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。 67 | 68 | ## 预测步骤(检测器) 69 | ### a、使用预训练权重 70 | 1. 下载完库后解压,在百度网盘下载权值,放入model_data,运行predict.py,输入 71 | ```python 72 | img/street.jpg 73 | ``` 74 | 2. 在predict.py里面进行设置可以进行fps测试和video视频检测。 75 | ### b、使用自己训练的权重 76 | 1. 按照训练步骤训练。 77 | 2. 在yolo.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,classes_path是model_path对应分的类**。 78 | ```python 79 | _defaults = { 80 | #--------------------------------------------------------------------------# 81 | # 使用自己训练好的模型进行预测一定要修改model_path和classes_path! 82 | # model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt 83 | # 84 | # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。 85 | # 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。 86 | # 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改 87 | #--------------------------------------------------------------------------# 88 | "model_path" : 'model_data/yolov7_weights.pth', 89 | "classes_path" : 'model_data/coco_classes.txt', 90 | #---------------------------------------------------------------------# 91 | # anchors_path代表先验框对应的txt文件,一般不修改。 92 | # anchors_mask用于帮助代码找到对应的先验框,一般不修改。 93 | #---------------------------------------------------------------------# 94 | "anchors_path" : 'model_data/yolo_anchors.txt', 95 | "anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 96 | #---------------------------------------------------------------------# 97 | # 输入图片的大小,必须为32的倍数。 98 | #---------------------------------------------------------------------# 99 | "input_shape" : [640, 640], 100 | #------------------------------------------------------# 101 | # 所使用到的yolov7的版本,本仓库一共提供两个: 102 | # l : 对应yolov7 103 | # x : 对应yolov7_x 104 | #------------------------------------------------------# 105 | "phi" : 'l', 106 | #---------------------------------------------------------------------# 107 | # 只有得分大于置信度的预测框会被保留下来 108 | #---------------------------------------------------------------------# 109 | "confidence" : 0.5, 110 | #---------------------------------------------------------------------# 111 | # 非极大抑制所用到的nms_iou大小 112 | #---------------------------------------------------------------------# 113 | "nms_iou" : 0.3, 114 | #---------------------------------------------------------------------# 115 | # 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize, 116 | # 在多次测试后,发现关闭letterbox_image直接resize的效果更好 117 | #---------------------------------------------------------------------# 118 | "letterbox_image" : True, 119 | #-------------------------------# 120 | # 是否使用Cuda 121 | # 没有GPU可以设置成False 122 | #-------------------------------# 123 | "cuda" : True, 124 | } 125 | ``` 126 | 3. 运行predict.py,输入 127 | ```python 128 | img/street.jpg 129 | ``` 130 | 4. 在predict.py里面进行设置可以进行fps测试和video视频检测。 131 | 132 | ## 评估步骤 (检测器) 133 | ### a、评估VOC07+12的测试集 134 | 1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。 135 | 2. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。** 136 | 3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。 137 | 138 | ### b、评估自己的数据集 139 | 1. 本文使用VOC格式进行评估。 140 | 2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。 141 | 3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。 142 | 4. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。** 143 | 5. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。 144 | 145 | ## 融合步骤(ProbEn) 146 | 1. 参考第4节检测器预测步骤,修改两种检测器配置 147 | 2. 选择predict_with_probEn.py中的模式:图片、视频、目录 148 | 3. 更改对应的路径,详情见注释 149 | 4. 运行predict_with_probEn.py 150 | 151 | ## Reference 152 | https://github.com/WongKinYiu/yolov7 153 | 154 | https://github.com/bubbliiiing/yolov7-pytorch 155 | 156 | https://github.com/bubbliiiing/centernet-pytorch 157 | 158 | https://github.com/Jamie725/Multimodal-Object-Detection-via-Probabilistic-Ensembling 159 | -------------------------------------------------------------------------------- /detector_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from yolov7.yolo import YOLO 8 | from CenterNet.centernet import CenterNet 9 | 10 | if __name__ == '__main__': 11 | centernet = CenterNet() 12 | yolo = YOLO() 13 | crop = False 14 | count = False 15 | 16 | # img = input('Input image filename:') 17 | img = 'img/1.jpg' 18 | try: 19 | image1 = Image.open(img) 20 | image2 = Image.open(img) 21 | except: 22 | print('Open Error! Try again!') 23 | else: 24 | print("-----------------") 25 | print("yolov7:") 26 | #r_image = yolo.detect_image(image1, crop=crop, count=count) 27 | #r_image.show() 28 | # tact_time = yolo.get_FPS(image1, test_interval=100) 29 | # print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1') 30 | dets_yolo, scores_yolo = yolo.detect_image_dets(image1) 31 | t1 = time.time() 32 | for _ in range(100): 33 | dets_yolo, scores_yolo = yolo.detect_image_dets(image1) 34 | t2 = time.time() 35 | tact_time = (t2 - t1) / 100 36 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1') 37 | # print(dets_yolo) 38 | # print(scores_yolo) 39 | 40 | print("-----------------") 41 | print("centernet:") 42 | # r_image2 = centernet.detect_image(image2, crop = crop, count=count) 43 | # r_image2.show() 44 | # tact_time = centernet.get_FPS(image2, test_interval=100) 45 | # print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1') 46 | dets_centernet, scores_centernet = centernet.detect_image_dets(image2) 47 | t1 = time.time() 48 | for _ in range(100): 49 | dets_centernet, scores_centernet = centernet.detect_image_dets(image2) 50 | t2 = time.time() 51 | tact_time = (t2 - t1) / 100 52 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1') 53 | # print(dets_centernet) 54 | #print(scores_centernet) -------------------------------------------------------------------------------- /get_fusion_FNPI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from yolov7.utils.utils import get_classes 9 | from yolov7.utils.utils_FNPI import get_FNPI 10 | from CenterNet.centernet import CenterNet 11 | from ProbEn import ProbEn 12 | from yolov7.yolo import YOLO 13 | 14 | if __name__ == "__main__": 15 | #------------------------------------------------------------------------------------------------------------------# 16 | # map_mode用于指定该文件运行时计算的内容 17 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。 18 | # map_mode为1代表仅仅获得预测结果。 19 | # map_mode为2代表仅仅获得真实框。 20 | # map_mode为3代表仅仅计算FNPI。 21 | #-------------------------------------------------------------------------------------------------------------------# 22 | map_mode = 0 23 | #--------------------------------------------------------------------------------------# 24 | # 此处的classes_path用于指定需要测量FNPI的类别 25 | # 一般情况下与训练和预测所用的classes_path一致即可 26 | #--------------------------------------------------------------------------------------# 27 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\voc_classes.txt' 28 | #--------------------------------------------------------------------------------------# 29 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件 30 | # 只有大于FNPI_IOU值才算检测成功 31 | #--------------------------------------------------------------------------------------# 32 | FNPI_IOU = 0.5 33 | #--------------------------------------------------------------------------------------# 34 | # confidence的设置与计算map时的设置情况不一样。 35 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。 36 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来 37 | #--------------------------------------------------------------------------------------# 38 | confidence_yolo = 0.5 39 | confidence_centernet = 0.5 40 | #--------------------------------------------------------------------------------------# 41 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 42 | # 该值也应该与预测时设置的nms_iou一致。 43 | #--------------------------------------------------------------------------------------# 44 | nms_iou = 0.3 45 | #-------------------------------------------------------# 46 | # 指向VOC数据集所在的文件夹 47 | # 默认指向根目录下的VOC数据集 48 | #-------------------------------------------------------# 49 | VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit' 50 | #-------------------------------------------------------# 51 | # 结果输出的文件夹,默认为map_out 52 | #-------------------------------------------------------# 53 | FNPI_out_path = 'FNPI_out/FNPI_out_ProbEn_VOC4' 54 | 55 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 56 | 57 | if not os.path.exists(FNPI_out_path): 58 | os.makedirs(FNPI_out_path) 59 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')): 60 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth')) 61 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')): 62 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results')) 63 | 64 | class_names, _ = get_classes(classes_path) 65 | 66 | if map_mode == 0 or map_mode == 1: 67 | print("Load model.") 68 | yolo = YOLO(confidence=confidence_yolo, nms_iou=nms_iou) 69 | centernet = CenterNet(confidence=confidence_centernet, nms_iou=nms_iou) 70 | proben = ProbEn() 71 | print("Load model done.") 72 | 73 | print("Get predict result.") 74 | for image_id in tqdm(image_ids): 75 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg") 76 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 77 | image = Image.open(image_path) 78 | 79 | dets_yolo, scores_yolo = yolo.detect_image_dets(image) 80 | dets_yolo = np.asarray(dets_yolo) 81 | scores_yolo = np.asarray(scores_yolo) 82 | 83 | dets_centernet, scores_centernet = centernet.detect_image_dets(image) 84 | dets_centernet = np.asarray(dets_centernet) 85 | scores_centernet = np.asarray(scores_centernet) 86 | 87 | proben.get_map_txt(image_id, class_names, FNPI_out_path, dets_yolo, scores_yolo, dets_centernet, scores_centernet) 88 | print("Get predict result done.") 89 | 90 | if map_mode == 0 or map_mode == 2: 91 | print("Get ground truth result.") 92 | for image_id in tqdm(image_ids): 93 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 94 | # root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot() 95 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 96 | for obj in root.findall('object'): 97 | difficult_flag = False 98 | if obj.find('difficult')!=None: 99 | difficult = obj.find('difficult').text 100 | if int(difficult)==1: 101 | difficult_flag = True 102 | obj_name = obj.find('name').text 103 | if obj_name not in class_names: 104 | continue 105 | bndbox = obj.find('bndbox') 106 | left = bndbox.find('xmin').text 107 | top = bndbox.find('ymin').text 108 | right = bndbox.find('xmax').text 109 | bottom = bndbox.find('ymax').text 110 | 111 | if difficult_flag: 112 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 113 | else: 114 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 115 | print("Get ground truth result done.") 116 | 117 | if map_mode == 0 or map_mode == 3: 118 | print("Get map.") 119 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path) 120 | print("Get map done.") -------------------------------------------------------------------------------- /get_fusion_FNPI_onlyYolo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from yolov7.utils.utils import get_classes 9 | from yolov7.utils.utils_FNPI import get_FNPI 10 | from ProbEn import ProbEn 11 | from yolov7.yolo_RGB import YOLO_RGB 12 | from yolov7.yolo_T import YOLO_T 13 | 14 | if __name__ == "__main__": 15 | #------------------------------------------------------------------------------------------------------------------# 16 | # map_mode用于指定该文件运行时计算的内容 17 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。 18 | # map_mode为1代表仅仅获得预测结果。 19 | # map_mode为2代表仅仅获得真实框。 20 | # map_mode为3代表仅仅计算FNPI。 21 | #-------------------------------------------------------------------------------------------------------------------# 22 | map_mode = 3 23 | #--------------------------------------------------------------------------------------# 24 | # 此处的classes_path用于指定需要测量FNPI的类别 25 | # 一般情况下与训练和预测所用的classes_path一致即可 26 | #--------------------------------------------------------------------------------------# 27 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\people_classes_KAIST.txt' 28 | #--------------------------------------------------------------------------------------# 29 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件 30 | # 只有大于FNPI_IOU值才算检测成功 31 | #--------------------------------------------------------------------------------------# 32 | FNPI_IOU = 0.5 33 | #--------------------------------------------------------------------------------------# 34 | # confidence的设置与计算map时的设置情况不一样。 35 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。 36 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来 37 | #--------------------------------------------------------------------------------------# 38 | confidence_RGB = 0.5 39 | confidence_T = 0.5 40 | #--------------------------------------------------------------------------------------# 41 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 42 | # 该值也应该与预测时设置的nms_iou一致。 43 | #--------------------------------------------------------------------------------------# 44 | nms_iou = 0.3 45 | #-------------------------------------------------------# 46 | # 指向VOC数据集所在的文件夹 47 | # 默认指向根目录下的VOC数据集 48 | #-------------------------------------------------------# 49 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist' 50 | #-------------------------------------------------------# 51 | # 结果输出的文件夹,默认为map_out 52 | #-------------------------------------------------------# 53 | FNPI_out_path = 'FNPI_out/FNPI_out_ProbEn_YOLO' 54 | 55 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split() 56 | 57 | if not os.path.exists(FNPI_out_path): 58 | os.makedirs(FNPI_out_path) 59 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')): 60 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth')) 61 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')): 62 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results')) 63 | 64 | class_names, _ = get_classes(classes_path) 65 | 66 | if map_mode == 0 or map_mode == 1: 67 | print("Load model.") 68 | yolo_rgb = YOLO_RGB(confidence=confidence_RGB, nms_iou=nms_iou) 69 | yolo_T = YOLO_T(confidence=confidence_T, nms_iou=nms_iou) 70 | proben = ProbEn() 71 | print("Load model done.") 72 | 73 | print("Get predict result.") 74 | for image_id in tqdm(image_ids): 75 | image_RGB_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg") 76 | image_T_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/" + image_id + ".jpg") 77 | 78 | image_rgb = Image.open(image_RGB_path) 79 | image_T = Image.open(image_T_path) 80 | 81 | dets_rgb, scores_rgb = yolo_rgb.detect_image_dets(image_rgb) 82 | dets_rgb = np.asarray(dets_rgb) 83 | scores_rgb = np.asarray(scores_rgb) 84 | 85 | dets_T, scores_T = yolo_T.detect_image_dets(image_T) 86 | dets_T = np.asarray(dets_T) 87 | scores_T = np.asarray(scores_T) 88 | 89 | proben.get_map_txt(image_id, class_names, FNPI_out_path, dets_rgb, scores_rgb, dets_T, scores_T) 90 | print("Get predict result done.") 91 | 92 | if map_mode == 0 or map_mode == 2: 93 | print("Get ground truth result.") 94 | for image_id in tqdm(image_ids): 95 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 96 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot() 97 | for obj in root.findall('object'): 98 | difficult_flag = False 99 | if obj.find('difficult')!=None: 100 | difficult = obj.find('difficult').text 101 | if int(difficult)==1: 102 | difficult_flag = True 103 | obj_name = obj.find('name').text 104 | if obj_name not in class_names: 105 | continue 106 | bndbox = obj.find('bndbox') 107 | left = bndbox.find('xmin').text 108 | top = bndbox.find('ymin').text 109 | right = bndbox.find('xmax').text 110 | bottom = bndbox.find('ymax').text 111 | 112 | if difficult_flag: 113 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 114 | else: 115 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 116 | print("Get ground truth result done.") 117 | 118 | if map_mode == 0 or map_mode == 3: 119 | print("Get map.") 120 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path) 121 | print("Get map done.") -------------------------------------------------------------------------------- /get_fusion_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import numpy as np 7 | from yolov7.utils.utils import get_classes 8 | from yolov7.utils.utils_map import get_coco_map, get_map 9 | from CenterNet.centernet import CenterNet 10 | from ProbEn import ProbEn 11 | from yolov7.yolo import YOLO 12 | 13 | if __name__ == "__main__": 14 | ''' 15 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。 16 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。 17 | 18 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值 19 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框, 20 | ''' 21 | # ------------------------------------------------------------------------------------------------------------------# 22 | # map_mode用于指定该文件运行时计算的内容 23 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。 24 | # map_mode为1代表仅仅获得预测结果。 25 | # map_mode为2代表仅仅获得真实框。 26 | # map_mode为3代表仅仅计算VOC_map。 27 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 28 | # -------------------------------------------------------------------------------------------------------------------# 29 | map_mode = 0 30 | # --------------------------------------------------------------------------------------# 31 | # 此处的classes_path用于指定需要测量VOC_map的类别 32 | # 一般情况下与训练和预测所用的classes_path一致即可 33 | # --------------------------------------------------------------------------------------# 34 | classes_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\model_data\voc_classes.txt' 35 | # --------------------------------------------------------------------------------------# 36 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。 37 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 38 | # 39 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。 40 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低, 41 | # --------------------------------------------------------------------------------------# 42 | MINOVERLAP = 0.5 43 | # --------------------------------------------------------------------------------------# 44 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP 45 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。 46 | # 47 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。 48 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。 49 | # --------------------------------------------------------------------------------------# 50 | confidence_yolo = 0.001 51 | confidence_centernet = 0.04 52 | # --------------------------------------------------------------------------------------# 53 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 54 | # 55 | # 该值一般不调整。 56 | # --------------------------------------------------------------------------------------# 57 | nms_iou = 0.5 58 | # ---------------------------------------------------------------------------------------------------------------# 59 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。 60 | # 61 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。 62 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。 63 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。 64 | # ---------------------------------------------------------------------------------------------------------------# 65 | score_threhold = 0.5 66 | # -------------------------------------------------------# 67 | # map_vis用于指定是否开启VOC_map计算的可视化 68 | # -------------------------------------------------------# 69 | map_vis = False 70 | # -------------------------------------------------------# 71 | # 指向VOC数据集所在的文件夹 72 | # 默认指向根目录下的VOC数据集 73 | # -------------------------------------------------------# 74 | VOCdevkit_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\VOCdevkit' 75 | # -------------------------------------------------------# 76 | # 结果输出的文件夹,默认为map_out 77 | # -------------------------------------------------------# 78 | map_out_path = 'map_out' 79 | 80 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 81 | 82 | if not os.path.exists(map_out_path): 83 | os.makedirs(map_out_path) 84 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')): 85 | os.makedirs(os.path.join(map_out_path, 'ground-truth')) 86 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')): 87 | os.makedirs(os.path.join(map_out_path, 'detection-results')) 88 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')): 89 | os.makedirs(os.path.join(map_out_path, 'images-optional')) 90 | 91 | class_names, _ = get_classes(classes_path) 92 | 93 | if map_mode == 0 or map_mode == 1: 94 | print("Load model.") 95 | yolo = YOLO(confidence=confidence_yolo, nms_iou=nms_iou) 96 | centernet = CenterNet(confidence=confidence_centernet, nms_iou=nms_iou) 97 | proben = ProbEn() 98 | print("Load model done.") 99 | 100 | print("Get predict result.") 101 | for image_id in tqdm(image_ids): 102 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/" + image_id + ".jpg") 103 | image = Image.open(image_path) 104 | if map_vis: 105 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg")) 106 | 107 | dets_yolo, scores_yolo = yolo.detect_image_dets(image) 108 | dets_yolo = np.asarray(dets_yolo) 109 | scores_yolo = np.asarray(scores_yolo) 110 | 111 | dets_centernet, scores_centernet = centernet.detect_image_dets(image) 112 | dets_centernet = np.asarray(dets_centernet) 113 | scores_centernet = np.asarray(scores_centernet) 114 | 115 | proben.get_map_txt(image_id, class_names, map_out_path, dets_yolo, scores_yolo, dets_centernet, scores_centernet) 116 | print("Get predict result done.") 117 | 118 | if map_mode == 0 or map_mode == 2: 119 | print("Get ground truth result.") 120 | for image_id in tqdm(image_ids): 121 | with open(os.path.join(map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f: 122 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/" + image_id + ".xml")).getroot() 123 | for obj in root.findall('object'): 124 | difficult_flag = False 125 | if obj.find('difficult') != None: 126 | difficult = obj.find('difficult').text 127 | if int(difficult) == 1: 128 | difficult_flag = True 129 | obj_name = obj.find('name').text 130 | if obj_name not in class_names: 131 | continue 132 | bndbox = obj.find('bndbox') 133 | left = bndbox.find('xmin').text 134 | top = bndbox.find('ymin').text 135 | right = bndbox.find('xmax').text 136 | bottom = bndbox.find('ymax').text 137 | 138 | if difficult_flag: 139 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 140 | else: 141 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 142 | print("Get ground truth result done.") 143 | 144 | if map_mode == 0 or map_mode == 3: 145 | print("Get map.") 146 | get_map(MINOVERLAP, True, score_threhold=score_threhold, path=map_out_path) 147 | print("Get map done.") 148 | 149 | if map_mode == 4: 150 | print("Get map.") 151 | get_coco_map(class_names=class_names, path=map_out_path) 152 | print("Get map done.") 153 | -------------------------------------------------------------------------------- /get_fusion_map_onlyYolo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import numpy as np 7 | from yolov7.utils.utils import get_classes 8 | from yolov7.utils.utils_map import get_coco_map, get_map 9 | 10 | from ProbEn import ProbEn 11 | from yolov7.yolo_RGB import YOLO_RGB 12 | from yolov7.yolo_T import YOLO_T 13 | 14 | if __name__ == "__main__": 15 | ''' 16 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。 17 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。 18 | 19 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值 20 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框, 21 | ''' 22 | # ------------------------------------------------------------------------------------------------------------------# 23 | # map_mode用于指定该文件运行时计算的内容 24 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。 25 | # map_mode为1代表仅仅获得预测结果。 26 | # map_mode为2代表仅仅获得真实框。 27 | # map_mode为3代表仅仅计算VOC_map。 28 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 29 | # -------------------------------------------------------------------------------------------------------------------# 30 | map_mode = 0 31 | # --------------------------------------------------------------------------------------# 32 | # 此处的classes_path用于指定需要测量VOC_map的类别 33 | # 一般情况下与训练和预测所用的classes_path一致即可 34 | # --------------------------------------------------------------------------------------# 35 | classes_path = r'E:\pythonProject\object-detection\ProbEn-master\yolov7\model_data\people_classes_KAIST.txt' 36 | # --------------------------------------------------------------------------------------# 37 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。 38 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 39 | # 40 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。 41 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低, 42 | # --------------------------------------------------------------------------------------# 43 | MINOVERLAP = 0.5 44 | # --------------------------------------------------------------------------------------# 45 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP 46 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。 47 | # 48 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。 49 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。 50 | # --------------------------------------------------------------------------------------# 51 | confidence_RGB = 0.001 52 | confidence_T = 0.001 53 | # --------------------------------------------------------------------------------------# 54 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 55 | # 56 | # 该值一般不调整。 57 | # --------------------------------------------------------------------------------------# 58 | nms_iou = 0.5 59 | # ---------------------------------------------------------------------------------------------------------------# 60 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。 61 | # 62 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。 63 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。 64 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。 65 | # ---------------------------------------------------------------------------------------------------------------# 66 | score_threhold = 0.5 67 | # -------------------------------------------------------# 68 | # map_vis用于指定是否开启VOC_map计算的可视化 69 | # -------------------------------------------------------# 70 | map_vis = False 71 | # -------------------------------------------------------# 72 | # 指向VOC数据集所在的文件夹 73 | # 默认指向根目录下的VOC数据集 74 | # -------------------------------------------------------# 75 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist' 76 | # -------------------------------------------------------# 77 | # 结果输出的文件夹,默认为map_out 78 | # -------------------------------------------------------# 79 | map_out_path = 'map_out_ProbEn_YOLO1' 80 | 81 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split() 82 | 83 | if not os.path.exists(map_out_path): 84 | os.makedirs(map_out_path) 85 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')): 86 | os.makedirs(os.path.join(map_out_path, 'ground-truth')) 87 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')): 88 | os.makedirs(os.path.join(map_out_path, 'detection-results')) 89 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')): 90 | os.makedirs(os.path.join(map_out_path, 'images-optional')) 91 | 92 | class_names, _ = get_classes(classes_path) 93 | 94 | if map_mode == 0 or map_mode == 1: 95 | print("Load model.") 96 | yolo_rgb = YOLO_RGB(confidence=confidence_RGB, nms_iou=nms_iou) 97 | yolo_T = YOLO_T(confidence=confidence_T, nms_iou=nms_iou) 98 | proben = ProbEn() 99 | print("Load model done.") 100 | 101 | print("Get predict result.") 102 | for image_id in tqdm(image_ids): 103 | image_RGB_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/" + image_id + ".jpg") 104 | image_T_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/" + image_id + ".jpg") 105 | 106 | image_rgb = Image.open(image_RGB_path) 107 | image_T = Image.open(image_T_path) 108 | 109 | if map_vis: 110 | image_rgb.save(os.path.join(map_out_path, "images-optional/rgb/" + image_id + ".jpg")) 111 | image_T.save(os.path.join(map_out_path, "images-optional/T/" + image_id + ".jpg")) 112 | 113 | dets_rgb, scores_rgb = yolo_rgb.detect_image_dets(image_rgb) 114 | dets_rgb = np.asarray(dets_rgb) 115 | scores_rgb = np.asarray(scores_rgb) 116 | 117 | dets_T, scores_T = yolo_T.detect_image_dets(image_T) 118 | dets_T = np.asarray(dets_T) 119 | scores_T = np.asarray(scores_T) 120 | 121 | proben.get_map_txt(image_id, class_names, map_out_path, dets_rgb, scores_rgb, dets_T, scores_T) 122 | print("Get predict result done.") 123 | 124 | if map_mode == 0 or map_mode == 2: 125 | print("Get ground truth result.") 126 | for image_id in tqdm(image_ids): 127 | with open(os.path.join(map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f: 128 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/" + image_id + ".xml")).getroot() 129 | for obj in root.findall('object'): 130 | difficult_flag = False 131 | if obj.find('difficult') != None: 132 | difficult = obj.find('difficult').text 133 | if int(difficult) == 1: 134 | difficult_flag = True 135 | obj_name = obj.find('name').text 136 | if obj_name not in class_names: 137 | continue 138 | bndbox = obj.find('bndbox') 139 | left = bndbox.find('xmin').text 140 | top = bndbox.find('ymin').text 141 | right = bndbox.find('xmax').text 142 | bottom = bndbox.find('ymax').text 143 | 144 | if difficult_flag: 145 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 146 | else: 147 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 148 | print("Get ground truth result done.") 149 | 150 | if map_mode == 0 or map_mode == 3: 151 | print("Get map.") 152 | get_map(MINOVERLAP, True, score_threhold=score_threhold, path=map_out_path) 153 | print("Get map done.") 154 | 155 | if map_mode == 4: 156 | print("Get map.") 157 | get_coco_map(class_names=class_names, path=map_out_path) 158 | print("Get map done.") 159 | -------------------------------------------------------------------------------- /img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/1.jpg -------------------------------------------------------------------------------- /img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/2.jpg -------------------------------------------------------------------------------- /img/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img/street.jpg -------------------------------------------------------------------------------- /img_out/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/1.png -------------------------------------------------------------------------------- /img_out/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/2.png -------------------------------------------------------------------------------- /img_out/street.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/img_out/street.png -------------------------------------------------------------------------------- /yolov7/get_FNPI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_classes 8 | from utils.utils_FNPI import get_FNPI 9 | from yolo import YOLO 10 | 11 | if __name__ == "__main__": 12 | #------------------------------------------------------------------------------------------------------------------# 13 | # map_mode用于指定该文件运行时计算的内容 14 | # map_mode为0代表整个FNPI计算流程,包括获得预测结果、获得真实框、计算FNPI。 15 | # map_mode为1代表仅仅获得预测结果。 16 | # map_mode为2代表仅仅获得真实框。 17 | # map_mode为3代表仅仅计算FNPI。 18 | #-------------------------------------------------------------------------------------------------------------------# 19 | map_mode = 0 20 | #--------------------------------------------------------------------------------------# 21 | # 此处的classes_path用于指定需要测量FNPI的类别 22 | # 一般情况下与训练和预测所用的classes_path一致即可 23 | #--------------------------------------------------------------------------------------# 24 | classes_path = 'model_data/people_classes_KAIST.txt' 25 | # classes_path = 'model_data/voc_classes.txt' 26 | #--------------------------------------------------------------------------------------# 27 | # FNPI_IOU作为判定预测框与真实框相匹配(即真实框所对应的目标被检测成功)的条件 28 | # 只有大于FNPI_IOU值才算检测成功 29 | #--------------------------------------------------------------------------------------# 30 | FNPI_IOU = 0.5 31 | #--------------------------------------------------------------------------------------# 32 | # confidence的设置与计算map时的设置情况不一样。 33 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,因此,map计算时的confidence的值应当设置的尽量小进而获得全部可能的预测框。 34 | # 而计算FNPI设置的置信度confidence应该与预测时的置信度一致,只有得分大于置信度的预测框会被保留下来 35 | #--------------------------------------------------------------------------------------# 36 | confidence = 0.5 37 | #--------------------------------------------------------------------------------------# 38 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 39 | # 该值也应该与预测时设置的nms_iou一致。 40 | #--------------------------------------------------------------------------------------# 41 | nms_iou = 0.3 42 | #-------------------------------------------------------# 43 | # 指向VOC数据集所在的文件夹 44 | # 默认指向根目录下的VOC数据集 45 | #-------------------------------------------------------# 46 | # VOCdevkit_path = r'E:\pythonProject\object-detection\yolov7-pytorch-master\VOCdevkit' 47 | VOCdevkit_path = r'D:\KAIST数据集\重新标注的kaist' 48 | #-------------------------------------------------------# 49 | # 结果输出的文件夹,默认为map_out 50 | #-------------------------------------------------------# 51 | FNPI_out_path = 'FNPI_out/FNPI_out_KAIST_RGB' 52 | 53 | # image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 54 | image_ids = open(os.path.join(VOCdevkit_path, "kaist_wash_picture_test/test.txt")).read().strip().split() 55 | 56 | if not os.path.exists(FNPI_out_path): 57 | os.makedirs(FNPI_out_path) 58 | if not os.path.exists(os.path.join(FNPI_out_path, 'ground-truth')): 59 | os.makedirs(os.path.join(FNPI_out_path, 'ground-truth')) 60 | if not os.path.exists(os.path.join(FNPI_out_path, 'detection-results')): 61 | os.makedirs(os.path.join(FNPI_out_path, 'detection-results')) 62 | 63 | class_names, _ = get_classes(classes_path) 64 | 65 | if map_mode == 0 or map_mode == 1: 66 | print("Load model.") 67 | yolo = YOLO(confidence = confidence, nms_iou = nms_iou) 68 | print("Load model done.") 69 | 70 | print("Get predict result.") 71 | for image_id in tqdm(image_ids): 72 | # image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 73 | # image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/lwir/"+image_id+".jpg") 74 | image_path = os.path.join(VOCdevkit_path, "kaist_wash_picture_test/visible/"+image_id+".jpg") 75 | image = Image.open(image_path) 76 | yolo.get_map_txt(image_id, image, class_names, FNPI_out_path) 77 | print("Get predict result done.") 78 | 79 | if map_mode == 0 or map_mode == 2: 80 | print("Get ground truth result.") 81 | for image_id in tqdm(image_ids): 82 | with open(os.path.join(FNPI_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 83 | # root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 84 | root = ET.parse(os.path.join(VOCdevkit_path, "kaist_wash_annotation_test/"+image_id+".xml")).getroot() 85 | for obj in root.findall('object'): 86 | difficult_flag = False 87 | if obj.find('difficult')!=None: 88 | difficult = obj.find('difficult').text 89 | if int(difficult)==1: 90 | difficult_flag = True 91 | obj_name = obj.find('name').text 92 | if obj_name not in class_names: 93 | continue 94 | bndbox = obj.find('bndbox') 95 | left = bndbox.find('xmin').text 96 | top = bndbox.find('ymin').text 97 | right = bndbox.find('xmax').text 98 | bottom = bndbox.find('ymax').text 99 | 100 | if difficult_flag: 101 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 102 | else: 103 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 104 | print("Get ground truth result done.") 105 | 106 | if map_mode == 0 or map_mode == 3: 107 | print("Get map.") 108 | get_FNPI(FNPI_IOU, True, path = FNPI_out_path) 109 | print("Get map done.") -------------------------------------------------------------------------------- /yolov7/get_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_classes 8 | from utils.utils_map import get_coco_map, get_map 9 | from yolo import YOLO 10 | 11 | if __name__ == "__main__": 12 | ''' 13 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。 14 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。 15 | 16 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值 17 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框, 18 | ''' 19 | #------------------------------------------------------------------------------------------------------------------# 20 | # map_mode用于指定该文件运行时计算的内容 21 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。 22 | # map_mode为1代表仅仅获得预测结果。 23 | # map_mode为2代表仅仅获得真实框。 24 | # map_mode为3代表仅仅计算VOC_map。 25 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 26 | #-------------------------------------------------------------------------------------------------------------------# 27 | map_mode = 0 28 | #--------------------------------------------------------------------------------------# 29 | # 此处的classes_path用于指定需要测量VOC_map的类别 30 | # 一般情况下与训练和预测所用的classes_path一致即可 31 | #--------------------------------------------------------------------------------------# 32 | classes_path = 'model_data/voc_classes.txt' 33 | #--------------------------------------------------------------------------------------# 34 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。 35 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 36 | # 37 | # 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。 38 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低, 39 | #--------------------------------------------------------------------------------------# 40 | MINOVERLAP = 0.5 41 | #--------------------------------------------------------------------------------------# 42 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP 43 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。 44 | # 45 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。 46 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。 47 | #--------------------------------------------------------------------------------------# 48 | confidence = 0.001 49 | #--------------------------------------------------------------------------------------# 50 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 51 | # 52 | # 该值一般不调整。 53 | #--------------------------------------------------------------------------------------# 54 | nms_iou = 0.5 55 | #---------------------------------------------------------------------------------------------------------------# 56 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。 57 | # 58 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。 59 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。 60 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。 61 | #---------------------------------------------------------------------------------------------------------------# 62 | score_threhold = 0.5 63 | #-------------------------------------------------------# 64 | # map_vis用于指定是否开启VOC_map计算的可视化 65 | #-------------------------------------------------------# 66 | map_vis = False 67 | #-------------------------------------------------------# 68 | # 指向VOC数据集所在的文件夹 69 | # 默认指向根目录下的VOC数据集 70 | #-------------------------------------------------------# 71 | VOCdevkit_path = 'VOCdevkit' 72 | #-------------------------------------------------------# 73 | # 结果输出的文件夹,默认为map_out 74 | #-------------------------------------------------------# 75 | map_out_path = 'map_out' 76 | 77 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 78 | 79 | if not os.path.exists(map_out_path): 80 | os.makedirs(map_out_path) 81 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')): 82 | os.makedirs(os.path.join(map_out_path, 'ground-truth')) 83 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')): 84 | os.makedirs(os.path.join(map_out_path, 'detection-results')) 85 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')): 86 | os.makedirs(os.path.join(map_out_path, 'images-optional')) 87 | 88 | class_names, _ = get_classes(classes_path) 89 | 90 | if map_mode == 0 or map_mode == 1: 91 | print("Load model.") 92 | yolo = YOLO(confidence = confidence, nms_iou = nms_iou) 93 | print("Load model done.") 94 | 95 | print("Get predict result.") 96 | for image_id in tqdm(image_ids): 97 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 98 | image = Image.open(image_path) 99 | if map_vis: 100 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg")) 101 | yolo.get_map_txt(image_id, image, class_names, map_out_path) 102 | print("Get predict result done.") 103 | 104 | if map_mode == 0 or map_mode == 2: 105 | print("Get ground truth result.") 106 | for image_id in tqdm(image_ids): 107 | with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 108 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 109 | for obj in root.findall('object'): 110 | difficult_flag = False 111 | if obj.find('difficult')!=None: 112 | difficult = obj.find('difficult').text 113 | if int(difficult)==1: 114 | difficult_flag = True 115 | obj_name = obj.find('name').text 116 | if obj_name not in class_names: 117 | continue 118 | bndbox = obj.find('bndbox') 119 | left = bndbox.find('xmin').text 120 | top = bndbox.find('ymin').text 121 | right = bndbox.find('xmax').text 122 | bottom = bndbox.find('ymax').text 123 | 124 | if difficult_flag: 125 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 126 | else: 127 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 128 | print("Get ground truth result done.") 129 | 130 | if map_mode == 0 or map_mode == 3: 131 | print("Get map.") 132 | get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path) 133 | print("Get map done.") 134 | 135 | if map_mode == 4: 136 | print("Get map.") 137 | get_coco_map(class_names = class_names, path = map_out_path) 138 | print("Get map done.") 139 | -------------------------------------------------------------------------------- /yolov7/kmeans_for_anchors.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------------------------------------------------------# 2 | # kmeans虽然会对数据集中的框进行聚类,但是很多数据集由于框的大小相近,聚类出来的9个框相差不大, 3 | # 这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,shape越小的特征层适合越大的先验框 4 | # 原始网络的先验框已经按大中小比例分配好了,不进行聚类也会有非常好的效果。 5 | #-------------------------------------------------------------------------------------------------------# 6 | import glob 7 | import xml.etree.ElementTree as ET 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | 14 | def cas_ratio(box,cluster): 15 | ratios_of_box_cluster = box / cluster 16 | ratios_of_cluster_box = cluster / box 17 | ratios = np.concatenate([ratios_of_box_cluster, ratios_of_cluster_box], axis = -1) 18 | 19 | return np.max(ratios, -1) 20 | 21 | def avg_ratio(box,cluster): 22 | return np.mean([np.min(cas_ratio(box[i],cluster)) for i in range(box.shape[0])]) 23 | 24 | def kmeans(box,k): 25 | #-------------------------------------------------------------# 26 | # 取出一共有多少框 27 | #-------------------------------------------------------------# 28 | row = box.shape[0] 29 | 30 | #-------------------------------------------------------------# 31 | # 每个框各个点的位置 32 | #-------------------------------------------------------------# 33 | distance = np.empty((row,k)) 34 | 35 | #-------------------------------------------------------------# 36 | # 最后的聚类位置 37 | #-------------------------------------------------------------# 38 | last_clu = np.zeros((row,)) 39 | 40 | np.random.seed() 41 | 42 | #-------------------------------------------------------------# 43 | # 随机选5个当聚类中心 44 | #-------------------------------------------------------------# 45 | cluster = box[np.random.choice(row,k,replace = False)] 46 | 47 | iter = 0 48 | while True: 49 | #-------------------------------------------------------------# 50 | # 计算当前框和先验框的宽高比例 51 | #-------------------------------------------------------------# 52 | for i in range(row): 53 | distance[i] = cas_ratio(box[i],cluster) 54 | 55 | #-------------------------------------------------------------# 56 | # 取出最小点 57 | #-------------------------------------------------------------# 58 | near = np.argmin(distance,axis=1) 59 | 60 | if (last_clu == near).all(): 61 | break 62 | 63 | #-------------------------------------------------------------# 64 | # 求每一个类的中位点 65 | #-------------------------------------------------------------# 66 | for j in range(k): 67 | cluster[j] = np.median( 68 | box[near == j],axis=0) 69 | 70 | last_clu = near 71 | if iter % 5 == 0: 72 | print('iter: {:d}. avg_ratio:{:.2f}'.format(iter, avg_ratio(box,cluster))) 73 | iter += 1 74 | 75 | return cluster, near 76 | 77 | def load_data(path): 78 | data = [] 79 | #-------------------------------------------------------------# 80 | # 对于每一个xml都寻找box 81 | #-------------------------------------------------------------# 82 | for xml_file in tqdm(glob.glob('{}/*xml'.format(path))): 83 | tree = ET.parse(xml_file) 84 | height = int(tree.findtext('./size/height')) 85 | width = int(tree.findtext('./size/width')) 86 | if height<=0 or width<=0: 87 | continue 88 | 89 | #-------------------------------------------------------------# 90 | # 对于每一个目标都获得它的宽高 91 | #-------------------------------------------------------------# 92 | for obj in tree.iter('object'): 93 | xmin = int(float(obj.findtext('bndbox/xmin'))) / width 94 | ymin = int(float(obj.findtext('bndbox/ymin'))) / height 95 | xmax = int(float(obj.findtext('bndbox/xmax'))) / width 96 | ymax = int(float(obj.findtext('bndbox/ymax'))) / height 97 | 98 | xmin = np.float64(xmin) 99 | ymin = np.float64(ymin) 100 | xmax = np.float64(xmax) 101 | ymax = np.float64(ymax) 102 | # 得到宽高 103 | data.append([xmax-xmin, ymax-ymin]) 104 | return np.array(data) 105 | 106 | if __name__ == '__main__': 107 | np.random.seed(0) 108 | #-------------------------------------------------------------# 109 | # 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml 110 | # 会生成yolo_anchors.txt 111 | #-------------------------------------------------------------# 112 | input_shape = [640, 640] 113 | anchors_num = 9 114 | #-------------------------------------------------------------# 115 | # 载入数据集,可以使用VOC的xml 116 | #-------------------------------------------------------------# 117 | path = 'VOCdevkit/VOC2007/Annotations' 118 | 119 | #-------------------------------------------------------------# 120 | # 载入所有的xml 121 | # 存储格式为转化为比例后的width,height 122 | #-------------------------------------------------------------# 123 | print('Load xmls.') 124 | data = load_data(path) 125 | print('Load xmls done.') 126 | 127 | #-------------------------------------------------------------# 128 | # 使用k聚类算法 129 | #-------------------------------------------------------------# 130 | print('K-means boxes.') 131 | cluster, near = kmeans(data, anchors_num) 132 | print('K-means boxes done.') 133 | data = data * np.array([input_shape[1], input_shape[0]]) 134 | cluster = cluster * np.array([input_shape[1], input_shape[0]]) 135 | 136 | #-------------------------------------------------------------# 137 | # 绘图 138 | #-------------------------------------------------------------# 139 | for j in range(anchors_num): 140 | plt.scatter(data[near == j][:,0], data[near == j][:,1]) 141 | plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black') 142 | plt.savefig("kmeans_for_anchors.jpg") 143 | plt.show() 144 | print('Save kmeans_for_anchors.jpg in root dir.') 145 | 146 | cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])] 147 | print('avg_ratio:{:.2f}'.format(avg_ratio(data, cluster))) 148 | print(cluster) 149 | 150 | f = open("yolo_anchors.txt", 'w') 151 | row = np.shape(cluster)[0] 152 | for i in range(row): 153 | if i == 0: 154 | x_y = "%d,%d" % (cluster[i][0], cluster[i][1]) 155 | else: 156 | x_y = ", %d,%d" % (cluster[i][0], cluster[i][1]) 157 | f.write(x_y) 158 | f.close() 159 | -------------------------------------------------------------------------------- /yolov7/model_data/coco_classes.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /yolov7/model_data/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1e12Leon/ProbDet/0e203e0dc827ad34f8c5eb87c953f16703b9a5d1/yolov7/model_data/simhei.ttf -------------------------------------------------------------------------------- /yolov7/model_data/voc_classes.txt: -------------------------------------------------------------------------------- 1 | dog 2 | person 3 | cat 4 | car -------------------------------------------------------------------------------- /yolov7/model_data/yolo_anchors.txt: -------------------------------------------------------------------------------- 1 | 23,44, 64,64, 37,117, 75,228, 140,146, 154,380, 307,273, 291,498, 536,541 -------------------------------------------------------------------------------- /yolov7/nets/SRModule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from nets.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from nets.SR_Decoder import Decoder 6 | from nets.SR_Encoder import EDSR 7 | 8 | 9 | class EDSRConv(torch.nn.Module): 10 | def __init__(self, in_ch, out_ch): 11 | super(EDSRConv, self).__init__() 12 | self.conv = torch.nn.Sequential( 13 | torch.nn.Conv2d(in_ch, out_ch, 3, padding=1), 14 | torch.nn.ReLU(inplace=True), 15 | torch.nn.Conv2d(out_ch, out_ch, 3, padding=1), 16 | ) 17 | 18 | self.residual_upsampler = torch.nn.Sequential( 19 | torch.nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False), 20 | ) 21 | 22 | def forward(self, input): 23 | return self.conv(input)+self.residual_upsampler(input) 24 | 25 | 26 | class DeepLab(nn.Module): 27 | def __init__(self, ch, c1=128, c2=512, factor=2): 28 | super(DeepLab, self).__init__() 29 | self.sr_decoder = Decoder(c1, c2) 30 | self.edsr = EDSR(num_channels=ch, input_channel=64, factor=8) 31 | self.factor = factor 32 | 33 | def forward(self, low_level_feat,x): 34 | x_sr = self.sr_decoder(x, low_level_feat, self.factor) 35 | x_sr_up = self.edsr(x_sr) 36 | return x_sr_up 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /yolov7/nets/SR_Decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from nets.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, c1, c2): 9 | super(Decoder, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(c1, c1 // 2, 1, bias=False) 12 | self.conv2 = nn.Conv2d(c2, c2 // 2, 1, bias=False) 13 | # self.bn1 = BatchNorm(48) 14 | self.relu = nn.ReLU() 15 | # self.pixel_shuffle = nn.PixelShuffle(4) 16 | # self.attention = AttentionModel(48+c2) 17 | self.last_conv = nn.Sequential(nn.Conv2d((c1 + c2) // 2, 256, kernel_size=3, stride=1, padding=1, bias=False), 18 | # BatchNorm(256), 19 | nn.ReLU(), 20 | # nn.Dropout(0.5), 21 | nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False), 22 | # BatchNorm(128), 23 | nn.ReLU(), 24 | # nn.Dropout(0.1), 25 | nn.Conv2d(128, 64, kernel_size=1, stride=1)) 26 | self._init_weight() 27 | 28 | def forward(self, x, low_level_feat, factor): 29 | low_level_feat = self.conv1(low_level_feat) 30 | low_level_feat = self.relu(low_level_feat) 31 | 32 | x = self.conv2(x) 33 | x = self.relu(x) 34 | x = F.interpolate(x, size=[i * (factor // 2) for i in low_level_feat.size()[2:]], mode='bilinear', 35 | align_corners=True) 36 | if factor > 1: 37 | low_level_feat = F.interpolate(low_level_feat, size=[i * (factor // 2) for i in low_level_feat.size()[2:]], 38 | mode='bilinear', align_corners=True) 39 | # x = self.pixel_shuffle(x) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight) 49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | -------------------------------------------------------------------------------- /yolov7/nets/SR_Encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def make_model(args, parent=False): 6 | return EDSR(args) 7 | 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size // 2), bias=bias) 13 | 14 | 15 | class Upsampler(nn.Sequential): 16 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 17 | 18 | m = [] 19 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 20 | for _ in range(int(math.log(scale, 2))): 21 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 22 | m.append(nn.PixelShuffle(2)) 23 | if bn: m.append(nn.BatchNorm2d(n_feat)) 24 | if act: m.append(act()) 25 | elif scale == 3: 26 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 27 | m.append(nn.PixelShuffle(3)) 28 | if bn: m.append(nn.BatchNorm2d(n_feat)) 29 | if act: m.append(act()) 30 | else: 31 | raise NotImplementedError 32 | 33 | super(Upsampler, self).__init__(*m) 34 | 35 | 36 | class ResBlock(nn.Module): 37 | def __init__(self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 38 | super(ResBlock, self).__init__() 39 | m = [] 40 | for i in range(2): 41 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 42 | if bn: m.append(nn.BatchNorm2d(n_feat)) 43 | if i == 0: m.append(act) 44 | 45 | self.body = nn.Sequential(*m) 46 | self.res_scale = res_scale 47 | 48 | def forward(self, x): 49 | res = self.body(x).mul(self.res_scale) 50 | res += x 51 | 52 | return res 53 | 54 | 55 | class EDSR(nn.Module): 56 | def __init__(self, num_channels=3, input_channel=64, factor=4, width=64, depth=16, kernel_size=3, 57 | conv=default_conv): 58 | super(EDSR, self).__init__() 59 | 60 | n_resblock = depth 61 | n_feats = width 62 | kernel_size = kernel_size 63 | scale = factor 64 | act = nn.ReLU() 65 | 66 | # rgb_mean = (0.4488, 0.4371, 0.4040) 67 | # rgb_std = (1.0, 1.0, 1.0) 68 | # self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std) 69 | 70 | # define head module 71 | m_head = [conv(input_channel, n_feats, kernel_size)] 72 | 73 | # define body module 74 | m_body = [ 75 | ResBlock( 76 | conv, n_feats, kernel_size, act=act, res_scale=1. 77 | ) for _ in range(n_resblock) 78 | ] 79 | m_body.append(conv(n_feats, n_feats, kernel_size)) 80 | 81 | # define tail module 82 | m_tail = [ 83 | Upsampler(conv, scale, n_feats, act=False), 84 | conv(n_feats, num_channels, kernel_size) 85 | ] 86 | 87 | # self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1) 88 | 89 | self.head = nn.Sequential(*m_head) 90 | self.body = nn.Sequential(*m_body) 91 | self.tail = nn.Sequential(*m_tail) 92 | 93 | def forward(self, x): 94 | # x = self.sub_mean(x) 95 | x = self.head(x) 96 | 97 | res = self.body(x) 98 | res += x 99 | 100 | x = self.tail(res) 101 | # x = self.add_mean(x) 102 | 103 | return x 104 | 105 | def load_state_dict(self, state_dict, strict=True): 106 | own_state = self.state_dict() 107 | for name, param in state_dict.items(): 108 | if name in own_state: 109 | if isinstance(param, nn.Parameter): 110 | param = param.data 111 | try: 112 | own_state[name].copy_(param) 113 | except Exception: 114 | if name.find('tail') == -1: 115 | raise RuntimeError('While copying the parameter named {}, ' 116 | 'whose dimensions in the model are {} and ' 117 | 'whose dimensions in the checkpoint are {}.' 118 | .format(name, own_state[name].size(), param.size())) 119 | elif strict: 120 | if name.find('tail') == -1: 121 | raise KeyError('unexpected key "{}" in state_dict' 122 | .format(name)) 123 | -------------------------------------------------------------------------------- /yolov7/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /yolov7/nets/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def autopad(k, p=None): 6 | if p is None: 7 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] 8 | return p 9 | 10 | class SiLU(nn.Module): 11 | @staticmethod 12 | def forward(x): 13 | return x * torch.sigmoid(x) 14 | 15 | class Conv(nn.Module): 16 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups 17 | super(Conv, self).__init__() 18 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 19 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) 20 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 21 | 22 | def forward(self, x): 23 | return self.act(self.bn(self.conv(x))) 24 | 25 | def fuseforward(self, x): 26 | return self.act(self.conv(x)) 27 | 28 | class Block(nn.Module): 29 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]): 30 | super(Block, self).__init__() 31 | c_ = int(c2 * e) 32 | 33 | self.ids = ids 34 | self.cv1 = Conv(c1, c_, 1, 1) 35 | self.cv2 = Conv(c1, c_, 1, 1) 36 | self.cv3 = nn.ModuleList( 37 | [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)] 38 | ) 39 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1) 40 | 41 | def forward(self, x): 42 | x_1 = self.cv1(x) 43 | x_2 = self.cv2(x) 44 | 45 | x_all = [x_1, x_2] 46 | for i in range(len(self.cv3)): 47 | x_2 = self.cv3[i](x_2) 48 | x_all.append(x_2) 49 | 50 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1)) 51 | return out 52 | 53 | class MP(nn.Module): 54 | def __init__(self, k=2): 55 | super(MP, self).__init__() 56 | self.m = nn.MaxPool2d(kernel_size=k, stride=k) 57 | 58 | def forward(self, x): 59 | return self.m(x) 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, c1, c2): 63 | super(Transition, self).__init__() 64 | self.cv1 = Conv(c1, c2, 1, 1) 65 | self.cv2 = Conv(c1, c2, 1, 1) 66 | self.cv3 = Conv(c2, c2, 3, 2) 67 | 68 | self.mp = MP() 69 | 70 | def forward(self, x): 71 | x_1 = self.mp(x) 72 | x_1 = self.cv1(x_1) 73 | 74 | x_2 = self.cv2(x) 75 | x_2 = self.cv3(x_2) 76 | 77 | return torch.cat([x_2, x_1], 1) 78 | 79 | class Backbone(nn.Module): 80 | def __init__(self, transition_channels, block_channels, n, phi, pretrained=False): 81 | super().__init__() 82 | #-----------------------------------------------# 83 | # 输入图片是640, 640, 3 84 | #-----------------------------------------------# 85 | ids = { 86 | 'l' : [-1, -3, -5, -6], 87 | 'x' : [-1, -3, -5, -7, -8], 88 | }[phi] 89 | self.stem = nn.Sequential( 90 | Conv(3, transition_channels, 3, 1), 91 | Conv(transition_channels, transition_channels * 2, 3, 2), 92 | Conv(transition_channels * 2, transition_channels * 2, 3, 1), 93 | ) 94 | self.dark2 = nn.Sequential( 95 | Conv(transition_channels * 2, transition_channels * 4, 3, 2), 96 | Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids), 97 | ) 98 | self.dark3 = nn.Sequential( 99 | Transition(transition_channels * 8, transition_channels * 4), 100 | Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids), 101 | ) 102 | self.dark4 = nn.Sequential( 103 | Transition(transition_channels * 16, transition_channels * 8), 104 | Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids), 105 | ) 106 | self.dark5 = nn.Sequential( 107 | Transition(transition_channels * 32, transition_channels * 16), 108 | Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids), 109 | ) 110 | 111 | if pretrained: 112 | url = { 113 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth', 114 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth', 115 | }[phi] 116 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data") 117 | self.load_state_dict(checkpoint, strict=False) 118 | print("Load weights from " + url.split('/')[-1]) 119 | 120 | def forward(self, x): 121 | x = self.stem(x) 122 | x = self.dark2(x) 123 | #-----------------------------------------------# 124 | # dark3的输出为80, 80, 256,是一个有效特征层 125 | #-----------------------------------------------# 126 | x = self.dark3(x) 127 | feat1 = x 128 | #-----------------------------------------------# 129 | # dark4的输出为40, 40, 512,是一个有效特征层 130 | #-----------------------------------------------# 131 | x = self.dark4(x) 132 | feat2 = x 133 | #-----------------------------------------------# 134 | # dark5的输出为20, 20, 1024,是一个有效特征层 135 | #-----------------------------------------------# 136 | x = self.dark5(x) 137 | feat3 = x 138 | return feat1, feat2, feat3 139 | -------------------------------------------------------------------------------- /yolov7/nets/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /yolov7/nets/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /yolov7/nets/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /yolov7/nets/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /yolov7/predict.py: -------------------------------------------------------------------------------- 1 | #-----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #-----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from yolo import YOLO 12 | 13 | if __name__ == "__main__": 14 | yolo = YOLO() 15 | #----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | #----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | #-------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | #-------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | #----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | #----------------------------------------------------------------------------------------------------------# 42 | video_path = "img/1.mp4" 43 | video_save_path = "2.mp4" 44 | video_fps = 25.0 45 | #----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | #----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = "img/street.jpg" 53 | #-------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | #-------------------------------------------------------------------------# 59 | dir_origin_path = "img/" 60 | dir_save_path = "img_out/" 61 | #-------------------------------------------------------------------------# 62 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 63 | # 64 | # heatmap_save_path仅在mode='heatmap'有效 65 | #-------------------------------------------------------------------------# 66 | heatmap_save_path = "model_data/heatmap_vision.png" 67 | #-------------------------------------------------------------------------# 68 | # simplify 使用Simplify onnx 69 | # onnx_save_path 指定了onnx的保存路径 70 | #-------------------------------------------------------------------------# 71 | simplify = True 72 | onnx_save_path = "model_data/models.onnx" 73 | 74 | if mode == "predict": 75 | ''' 76 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 77 | 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 78 | 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 79 | 在原图上利用矩阵的方式进行截取。 80 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断, 81 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 82 | ''' 83 | while True: 84 | img = input('Input image filename:') 85 | try: 86 | image = Image.open(img) 87 | except: 88 | print('Open Error! Try again!') 89 | continue 90 | else: 91 | r_image = yolo.detect_image(image, crop = crop, count=count) 92 | r_image.show() 93 | 94 | elif mode == "video": 95 | capture = cv2.VideoCapture(video_path) 96 | if video_save_path!="": 97 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 98 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 99 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 100 | 101 | ref, frame = capture.read() 102 | if not ref: 103 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 104 | 105 | fps = 0.0 106 | while(True): 107 | t1 = time.time() 108 | # 读取某一帧 109 | ref, frame = capture.read() 110 | if not ref: 111 | break 112 | # 格式转变,BGRtoRGB 113 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 114 | # 转变成Image 115 | frame = Image.fromarray(np.uint8(frame)) 116 | # 进行检测 117 | frame = np.array(yolo.detect_image(frame)) 118 | # RGBtoBGR满足opencv显示格式 119 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 120 | 121 | fps = ( fps + (1./(time.time()-t1)) ) / 2 122 | print("fps= %.2f"%(fps)) 123 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 124 | 125 | cv2.imshow("video",frame) 126 | c= cv2.waitKey(1) & 0xff 127 | if video_save_path!="": 128 | out.write(frame) 129 | 130 | if c==27: 131 | capture.release() 132 | break 133 | 134 | print("Video Detection Done!") 135 | capture.release() 136 | if video_save_path!="": 137 | print("Save processed video to the path :" + video_save_path) 138 | out.release() 139 | cv2.destroyAllWindows() 140 | 141 | elif mode == "fps": 142 | img = Image.open(fps_image_path) 143 | tact_time = yolo.get_FPS(img, test_interval) 144 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 145 | 146 | elif mode == "dir_predict": 147 | import os 148 | 149 | from tqdm import tqdm 150 | 151 | img_names = os.listdir(dir_origin_path) 152 | for img_name in tqdm(img_names): 153 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 154 | image_path = os.path.join(dir_origin_path, img_name) 155 | image = Image.open(image_path) 156 | r_image = yolo.detect_image(image) 157 | if not os.path.exists(dir_save_path): 158 | os.makedirs(dir_save_path) 159 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 160 | 161 | elif mode == "heatmap": 162 | while True: 163 | img = input('Input image filename:') 164 | try: 165 | image = Image.open(img) 166 | except: 167 | print('Open Error! Try again!') 168 | continue 169 | else: 170 | yolo.detect_heatmap(image, heatmap_save_path) 171 | 172 | elif mode == "export_onnx": 173 | yolo.convert_to_onnx(simplify, onnx_save_path) 174 | 175 | else: 176 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 177 | -------------------------------------------------------------------------------- /yolov7/predict_RGB.py: -------------------------------------------------------------------------------- 1 | #-----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #-----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from yolo_RGB import YOLO_RGB 12 | 13 | if __name__ == "__main__": 14 | yolo = YOLO_RGB() 15 | #----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | #----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | #-------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | #-------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | #----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | #----------------------------------------------------------------------------------------------------------# 42 | video_path = "img/1.mp4" 43 | video_save_path = "2.mp4" 44 | video_fps = 25.0 45 | #----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | #----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = "img/street.jpg" 53 | #-------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | #-------------------------------------------------------------------------# 59 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\lwir" 60 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\visible" 61 | dir_origin_path ="../CenterNet/img/" 62 | dir_save_path = "img_out/voc4_new/" 63 | #-------------------------------------------------------------------------# 64 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 65 | # 66 | # heatmap_save_path仅在mode='heatmap'有效 67 | #-------------------------------------------------------------------------# 68 | heatmap_save_path = "model_data/heatmap_vision.png" 69 | #-------------------------------------------------------------------------# 70 | # simplify 使用Simplify onnx 71 | # onnx_save_path 指定了onnx的保存路径 72 | #-------------------------------------------------------------------------# 73 | simplify = True 74 | onnx_save_path = "model_data/models.onnx" 75 | 76 | if mode == "predict": 77 | ''' 78 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 79 | 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 80 | 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 81 | 在原图上利用矩阵的方式进行截取。 82 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断, 83 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 84 | ''' 85 | while True: 86 | img = input('Input image filename:') 87 | try: 88 | image = Image.open(img) 89 | except: 90 | print('Open Error! Try again!') 91 | continue 92 | else: 93 | r_image = yolo.detect_image(image, crop = crop, count=count) 94 | r_image.show() 95 | 96 | elif mode == "video": 97 | capture = cv2.VideoCapture(video_path) 98 | if video_save_path!="": 99 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 100 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 101 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 102 | 103 | ref, frame = capture.read() 104 | if not ref: 105 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 106 | 107 | fps = 0.0 108 | while(True): 109 | t1 = time.time() 110 | # 读取某一帧 111 | ref, frame = capture.read() 112 | if not ref: 113 | break 114 | # 格式转变,BGRtoRGB 115 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 116 | # 转变成Image 117 | frame = Image.fromarray(np.uint8(frame)) 118 | # 进行检测 119 | frame = np.array(yolo.detect_image(frame)) 120 | # RGBtoBGR满足opencv显示格式 121 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 122 | 123 | fps = ( fps + (1./(time.time()-t1)) ) / 2 124 | print("fps= %.2f"%(fps)) 125 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 126 | 127 | cv2.imshow("video",frame) 128 | c= cv2.waitKey(1) & 0xff 129 | if video_save_path!="": 130 | out.write(frame) 131 | 132 | if c==27: 133 | capture.release() 134 | break 135 | 136 | print("Video Detection Done!") 137 | capture.release() 138 | if video_save_path!="": 139 | print("Save processed video to the path :" + video_save_path) 140 | out.release() 141 | cv2.destroyAllWindows() 142 | 143 | elif mode == "fps": 144 | img = Image.open(fps_image_path) 145 | tact_time = yolo.get_FPS(img, test_interval) 146 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 147 | 148 | elif mode == "dir_predict": 149 | import os 150 | 151 | from tqdm import tqdm 152 | 153 | img_names = os.listdir(dir_origin_path) 154 | for img_name in tqdm(img_names): 155 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 156 | image_path = os.path.join(dir_origin_path, img_name) 157 | image = Image.open(image_path) 158 | r_image = yolo.detect_image(image) 159 | if not os.path.exists(dir_save_path): 160 | os.makedirs(dir_save_path) 161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 162 | 163 | elif mode == "heatmap": 164 | while True: 165 | img = input('Input image filename:') 166 | try: 167 | image = Image.open(img) 168 | except: 169 | print('Open Error! Try again!') 170 | continue 171 | else: 172 | yolo.detect_heatmap(image, heatmap_save_path) 173 | 174 | elif mode == "export_onnx": 175 | yolo.convert_to_onnx(simplify, onnx_save_path) 176 | 177 | else: 178 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 179 | -------------------------------------------------------------------------------- /yolov7/predict_T.py: -------------------------------------------------------------------------------- 1 | #-----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #-----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from yolo_T import YOLO_T 12 | 13 | if __name__ == "__main__": 14 | yolo = YOLO_T() 15 | #----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | #----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | #-------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | #-------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | #----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | #----------------------------------------------------------------------------------------------------------# 42 | video_path = "img/1.mp4" 43 | video_save_path = "2.mp4" 44 | video_fps = 25.0 45 | #----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | #----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = "img/street.jpg" 53 | #-------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | #-------------------------------------------------------------------------# 59 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\lwir" 60 | # dir_origin_path = r"D:\KAIST数据集\重新标注的kaist\kaist_wash_picture_test\visible" 61 | dir_origin_path ="../CenterNet/img/" 62 | dir_save_path = "img_out/voc4_new/" 63 | #-------------------------------------------------------------------------# 64 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 65 | # 66 | # heatmap_save_path仅在mode='heatmap'有效 67 | #-------------------------------------------------------------------------# 68 | heatmap_save_path = "model_data/heatmap_vision.png" 69 | #-------------------------------------------------------------------------# 70 | # simplify 使用Simplify onnx 71 | # onnx_save_path 指定了onnx的保存路径 72 | #-------------------------------------------------------------------------# 73 | simplify = True 74 | onnx_save_path = "model_data/models.onnx" 75 | 76 | if mode == "predict": 77 | ''' 78 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 79 | 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 80 | 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 81 | 在原图上利用矩阵的方式进行截取。 82 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断, 83 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 84 | ''' 85 | while True: 86 | img = input('Input image filename:') 87 | try: 88 | image = Image.open(img) 89 | except: 90 | print('Open Error! Try again!') 91 | continue 92 | else: 93 | r_image = yolo.detect_image(image, crop = crop, count=count) 94 | r_image.show() 95 | 96 | elif mode == "video": 97 | capture = cv2.VideoCapture(video_path) 98 | if video_save_path!="": 99 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 100 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 101 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 102 | 103 | ref, frame = capture.read() 104 | if not ref: 105 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 106 | 107 | fps = 0.0 108 | while(True): 109 | t1 = time.time() 110 | # 读取某一帧 111 | ref, frame = capture.read() 112 | if not ref: 113 | break 114 | # 格式转变,BGRtoRGB 115 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 116 | # 转变成Image 117 | frame = Image.fromarray(np.uint8(frame)) 118 | # 进行检测 119 | frame = np.array(yolo.detect_image(frame)) 120 | # RGBtoBGR满足opencv显示格式 121 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 122 | 123 | fps = ( fps + (1./(time.time()-t1)) ) / 2 124 | print("fps= %.2f"%(fps)) 125 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 126 | 127 | cv2.imshow("video",frame) 128 | c= cv2.waitKey(1) & 0xff 129 | if video_save_path!="": 130 | out.write(frame) 131 | 132 | if c==27: 133 | capture.release() 134 | break 135 | 136 | print("Video Detection Done!") 137 | capture.release() 138 | if video_save_path!="": 139 | print("Save processed video to the path :" + video_save_path) 140 | out.release() 141 | cv2.destroyAllWindows() 142 | 143 | elif mode == "fps": 144 | img = Image.open(fps_image_path) 145 | tact_time = yolo.get_FPS(img, test_interval) 146 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 147 | 148 | elif mode == "dir_predict": 149 | import os 150 | 151 | from tqdm import tqdm 152 | 153 | img_names = os.listdir(dir_origin_path) 154 | for img_name in tqdm(img_names): 155 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 156 | image_path = os.path.join(dir_origin_path, img_name) 157 | image = Image.open(image_path) 158 | r_image = yolo.detect_image(image) 159 | if not os.path.exists(dir_save_path): 160 | os.makedirs(dir_save_path) 161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 162 | 163 | elif mode == "heatmap": 164 | while True: 165 | img = input('Input image filename:') 166 | try: 167 | image = Image.open(img) 168 | except: 169 | print('Open Error! Try again!') 170 | continue 171 | else: 172 | yolo.detect_heatmap(image, heatmap_save_path) 173 | 174 | elif mode == "export_onnx": 175 | yolo.convert_to_onnx(simplify, onnx_save_path) 176 | 177 | else: 178 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 179 | -------------------------------------------------------------------------------- /yolov7/summary.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------# 2 | # 该部分代码用于看网络结构 3 | #--------------------------------------------# 4 | import torch 5 | from thop import clever_format, profile 6 | 7 | from nets.yolo import YoloBody 8 | 9 | if __name__ == "__main__": 10 | input_shape = [640, 640] 11 | anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] 12 | num_classes = 80 13 | phi = 'l' 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | m = YoloBody(anchors_mask, num_classes, phi, False, phi_attention=1).to(device) 17 | for i in m.children(): 18 | print(i) 19 | print('==============================') 20 | 21 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) 22 | flops, params = profile(m.to(device), (dummy_input, ), verbose=False) 23 | #--------------------------------------------------------# 24 | # flops * 2是因为profile没有将卷积作为两个operations 25 | # 有些论文将卷积算乘法、加法两个operations。此时乘2 26 | # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 27 | # 本代码选择乘2,参考YOLOX。 28 | #--------------------------------------------------------# 29 | flops = flops * 2 30 | flops, params = clever_format([flops, params], "%.3f") 31 | print('Total GFLOPS: %s' % (flops)) 32 | print('Total params: %s' % (params)) 33 | -------------------------------------------------------------------------------- /yolov7/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /yolov7/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 7 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 8 | #---------------------------------------------------------# 9 | def cvtColor(image): 10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 11 | return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # 对输入图像进行resize 18 | #---------------------------------------------------# 19 | def resize_image(image, size, letterbox_image): 20 | iw, ih = image.size 21 | w, h = size 22 | if letterbox_image: 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | else: 31 | new_image = image.resize((w, h), Image.BICUBIC) 32 | return new_image 33 | 34 | #---------------------------------------------------# 35 | # 获得类 36 | #---------------------------------------------------# 37 | def get_classes(classes_path): 38 | with open(classes_path, encoding='utf-8') as f: 39 | class_names = f.readlines() 40 | class_names = [c.strip() for c in class_names] 41 | return class_names, len(class_names) 42 | 43 | #---------------------------------------------------# 44 | # 获得先验框 45 | #---------------------------------------------------# 46 | def get_anchors(anchors_path): 47 | '''loads the anchors from a file''' 48 | with open(anchors_path, encoding='utf-8') as f: 49 | anchors = f.readline() 50 | anchors = [float(x) for x in anchors.split(',')] 51 | anchors = np.array(anchors).reshape(-1, 2) 52 | return anchors, len(anchors) 53 | 54 | #---------------------------------------------------# 55 | # 获得学习率 56 | #---------------------------------------------------# 57 | def get_lr(optimizer): 58 | for param_group in optimizer.param_groups: 59 | return param_group['lr'] 60 | 61 | def preprocess_input(image): 62 | image /= 255.0 63 | return image 64 | 65 | def show_config(**kwargs): 66 | print('Configurations:') 67 | print('-' * 70) 68 | print('|%25s | %40s|' % ('keys', 'values')) 69 | print('-' * 70) 70 | for key, value in kwargs.items(): 71 | print('|%25s | %40s|' % (str(key), str(value))) 72 | print('-' * 70) 73 | 74 | def download_weights(phi, model_dir="./model_data"): 75 | import os 76 | from torch.hub import load_state_dict_from_url 77 | 78 | download_urls = { 79 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth', 80 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth', 81 | } 82 | url = download_urls[phi] 83 | 84 | if not os.path.exists(model_dir): 85 | os.makedirs(model_dir) 86 | load_state_dict_from_url(url, model_dir) -------------------------------------------------------------------------------- /yolov7/utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from yolov7.utils.utils import get_lr 7 | 8 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0): 9 | loss = 0 10 | val_loss = 0 11 | 12 | if local_rank == 0: 13 | print('Start Train') 14 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 15 | model_train.train() 16 | for iteration, batch in enumerate(gen): 17 | if iteration >= epoch_step: 18 | break 19 | 20 | images, targets = batch[0], batch[1] 21 | # print(batch) 22 | # print(targets) 23 | with torch.no_grad(): 24 | if cuda: 25 | images = images.cuda(local_rank) 26 | targets = targets.cuda(local_rank) 27 | #----------------------# 28 | # 清零梯度 29 | #----------------------# 30 | optimizer.zero_grad() 31 | if not fp16: 32 | #----------------------# 33 | # 前向传播 34 | #----------------------# 35 | outputs = model_train(images) 36 | loss_value = yolo_loss(outputs, targets, images) 37 | 38 | #----------------------# 39 | # 反向传播 40 | #----------------------# 41 | loss_value.backward() 42 | optimizer.step() 43 | else: 44 | from torch.cuda.amp import autocast 45 | with autocast(): 46 | #----------------------# 47 | # 前向传播 48 | #----------------------# 49 | outputs = model_train(images) 50 | 51 | loss_value = yolo_loss(outputs, targets, images) 52 | 53 | #----------------------# 54 | # 反向传播 55 | #----------------------# 56 | scaler.scale(loss_value).backward() 57 | scaler.step(optimizer) 58 | scaler.update() 59 | if ema: 60 | ema.update(model_train) 61 | 62 | loss += loss_value.item() 63 | 64 | if local_rank == 0: 65 | pbar.set_postfix(**{'loss' : loss / (iteration + 1), 66 | 'lr' : get_lr(optimizer)}) 67 | pbar.update(1) 68 | 69 | if local_rank == 0: 70 | pbar.close() 71 | print('Finish Train') 72 | print('Start Validation') 73 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 74 | 75 | if ema: 76 | model_train_eval = ema.ema 77 | else: 78 | model_train_eval = model_train.eval() 79 | 80 | for iteration, batch in enumerate(gen_val): 81 | if iteration >= epoch_step_val: 82 | break 83 | images, targets = batch[0], batch[1] 84 | with torch.no_grad(): 85 | if cuda: 86 | images = images.cuda(local_rank) 87 | targets = targets.cuda(local_rank) 88 | #----------------------# 89 | # 清零梯度 90 | #----------------------# 91 | optimizer.zero_grad() 92 | #----------------------# 93 | # 前向传播 94 | #----------------------# 95 | outputs = model_train_eval(images) 96 | loss_value = yolo_loss(outputs, targets, images) 97 | 98 | val_loss += loss_value.item() 99 | if local_rank == 0: 100 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 101 | pbar.update(1) 102 | 103 | if local_rank == 0: 104 | pbar.close() 105 | print('Finish Validation') 106 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) 107 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 108 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 109 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) 110 | 111 | #-----------------------------------------------# 112 | # 保存权值 113 | #-----------------------------------------------# 114 | if ema: 115 | save_state_dict = ema.ema.state_dict() 116 | else: 117 | save_state_dict = model.state_dict() 118 | 119 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 120 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) 121 | 122 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 123 | print('Save best model to best_epoch_weights.pth') 124 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 125 | 126 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /yolov7/utils_coco/coco_annotation.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------# 2 | # 用于处理COCO数据集,根据json文件生成txt文件用于训练 3 | #-------------------------------------------------------# 4 | import json 5 | import os 6 | from collections import defaultdict 7 | 8 | #-------------------------------------------------------# 9 | # 指向了COCO训练集与验证集图片的路径 10 | #-------------------------------------------------------# 11 | train_datasets_path = "coco_dataset/train2017" 12 | val_datasets_path = "coco_dataset/val2017" 13 | 14 | #-------------------------------------------------------# 15 | # 指向了COCO训练集与验证集标签的路径 16 | #-------------------------------------------------------# 17 | train_annotation_path = "coco_dataset/annotations/instances_train2017.json" 18 | val_annotation_path = "coco_dataset/annotations/instances_val2017.json" 19 | 20 | #-------------------------------------------------------# 21 | # 生成的txt文件路径 22 | #-------------------------------------------------------# 23 | train_output_path = "coco_train.txt" 24 | val_output_path = "coco_val.txt" 25 | 26 | if __name__ == "__main__": 27 | name_box_id = defaultdict(list) 28 | id_name = dict() 29 | f = open(train_annotation_path, encoding='utf-8') 30 | data = json.load(f) 31 | 32 | annotations = data['annotations'] 33 | for ant in annotations: 34 | id = ant['image_id'] 35 | name = os.path.join(train_datasets_path, '%012d.jpg' % id) 36 | cat = ant['category_id'] 37 | if cat >= 1 and cat <= 11: 38 | cat = cat - 1 39 | elif cat >= 13 and cat <= 25: 40 | cat = cat - 2 41 | elif cat >= 27 and cat <= 28: 42 | cat = cat - 3 43 | elif cat >= 31 and cat <= 44: 44 | cat = cat - 5 45 | elif cat >= 46 and cat <= 65: 46 | cat = cat - 6 47 | elif cat == 67: 48 | cat = cat - 7 49 | elif cat == 70: 50 | cat = cat - 9 51 | elif cat >= 72 and cat <= 82: 52 | cat = cat - 10 53 | elif cat >= 84 and cat <= 90: 54 | cat = cat - 11 55 | name_box_id[name].append([ant['bbox'], cat]) 56 | 57 | f = open(train_output_path, 'w') 58 | for key in name_box_id.keys(): 59 | f.write(key) 60 | box_infos = name_box_id[key] 61 | for info in box_infos: 62 | x_min = int(info[0][0]) 63 | y_min = int(info[0][1]) 64 | x_max = x_min + int(info[0][2]) 65 | y_max = y_min + int(info[0][3]) 66 | 67 | box_info = " %d,%d,%d,%d,%d" % ( 68 | x_min, y_min, x_max, y_max, int(info[1])) 69 | f.write(box_info) 70 | f.write('\n') 71 | f.close() 72 | 73 | name_box_id = defaultdict(list) 74 | id_name = dict() 75 | f = open(val_annotation_path, encoding='utf-8') 76 | data = json.load(f) 77 | 78 | annotations = data['annotations'] 79 | for ant in annotations: 80 | id = ant['image_id'] 81 | name = os.path.join(val_datasets_path, '%012d.jpg' % id) 82 | cat = ant['category_id'] 83 | if cat >= 1 and cat <= 11: 84 | cat = cat - 1 85 | elif cat >= 13 and cat <= 25: 86 | cat = cat - 2 87 | elif cat >= 27 and cat <= 28: 88 | cat = cat - 3 89 | elif cat >= 31 and cat <= 44: 90 | cat = cat - 5 91 | elif cat >= 46 and cat <= 65: 92 | cat = cat - 6 93 | elif cat == 67: 94 | cat = cat - 7 95 | elif cat == 70: 96 | cat = cat - 9 97 | elif cat >= 72 and cat <= 82: 98 | cat = cat - 10 99 | elif cat >= 84 and cat <= 90: 100 | cat = cat - 11 101 | name_box_id[name].append([ant['bbox'], cat]) 102 | 103 | f = open(val_output_path, 'w') 104 | for key in name_box_id.keys(): 105 | f.write(key) 106 | box_infos = name_box_id[key] 107 | for info in box_infos: 108 | x_min = int(info[0][0]) 109 | y_min = int(info[0][1]) 110 | x_max = x_min + int(info[0][2]) 111 | y_max = y_min + int(info[0][3]) 112 | 113 | box_info = " %d,%d,%d,%d,%d" % ( 114 | x_min, y_min, x_max, y_max, int(info[1])) 115 | f.write(box_info) 116 | f.write('\n') 117 | f.close() 118 | -------------------------------------------------------------------------------- /yolov7/utils_coco/get_map_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from pycocotools.coco import COCO 8 | from pycocotools.cocoeval import COCOeval 9 | from tqdm import tqdm 10 | 11 | from utils.utils import cvtColor, preprocess_input, resize_image 12 | from yolo import YOLO 13 | 14 | #---------------------------------------------------------------------------# 15 | # map_mode用于指定该文件运行时计算的内容 16 | # map_mode为0代表整个map计算流程,包括获得预测结果、计算map。 17 | # map_mode为1代表仅仅获得预测结果。 18 | # map_mode为2代表仅仅获得计算map。 19 | #---------------------------------------------------------------------------# 20 | map_mode = 0 21 | #-------------------------------------------------------# 22 | # 指向了验证集标签与图片路径 23 | #-------------------------------------------------------# 24 | cocoGt_path = 'coco_dataset/annotations/instances_val2017.json' 25 | dataset_img_path = 'coco_dataset/val2017' 26 | #-------------------------------------------------------# 27 | # 结果输出的文件夹,默认为map_out 28 | #-------------------------------------------------------# 29 | temp_save_path = 'map_out/coco_eval' 30 | 31 | class mAP_YOLO(YOLO): 32 | #---------------------------------------------------# 33 | # 检测图片 34 | #---------------------------------------------------# 35 | def detect_image(self, image_id, image, results, clsid2catid): 36 | #---------------------------------------------------# 37 | # 计算输入图片的高和宽 38 | #---------------------------------------------------# 39 | image_shape = np.array(np.shape(image)[0:2]) 40 | #---------------------------------------------------------# 41 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 42 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 43 | #---------------------------------------------------------# 44 | image = cvtColor(image) 45 | #---------------------------------------------------------# 46 | # 给图像增加灰条,实现不失真的resize 47 | # 也可以直接resize进行识别 48 | #---------------------------------------------------------# 49 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 50 | #---------------------------------------------------------# 51 | # 添加上batch_size维度 52 | #---------------------------------------------------------# 53 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 54 | 55 | with torch.no_grad(): 56 | images = torch.from_numpy(image_data) 57 | if self.cuda: 58 | images = images.cuda() 59 | #---------------------------------------------------------# 60 | # 将图像输入网络当中进行预测! 61 | #---------------------------------------------------------# 62 | outputs = self.net(images) 63 | outputs = self.bbox_util.decode_box(outputs) 64 | #---------------------------------------------------------# 65 | # 将预测框进行堆叠,然后进行非极大抑制 66 | #---------------------------------------------------------# 67 | outputs = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape, 68 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 69 | 70 | if outputs[0] is None: 71 | return results 72 | 73 | top_label = np.array(outputs[0][:, 6], dtype = 'int32') 74 | top_conf = outputs[0][:, 4] * outputs[0][:, 5] 75 | top_boxes = outputs[0][:, :4] 76 | 77 | for i, c in enumerate(top_label): 78 | result = {} 79 | top, left, bottom, right = top_boxes[i] 80 | 81 | result["image_id"] = int(image_id) 82 | result["category_id"] = clsid2catid[c] 83 | result["bbox"] = [float(left),float(top),float(right-left),float(bottom-top)] 84 | result["score"] = float(top_conf[i]) 85 | results.append(result) 86 | return results 87 | 88 | if __name__ == "__main__": 89 | if not os.path.exists(temp_save_path): 90 | os.makedirs(temp_save_path) 91 | 92 | cocoGt = COCO(cocoGt_path) 93 | ids = list(cocoGt.imgToAnns.keys()) 94 | clsid2catid = cocoGt.getCatIds() 95 | 96 | if map_mode == 0 or map_mode == 1: 97 | yolo = mAP_YOLO(confidence = 0.001, nms_iou = 0.65) 98 | 99 | with open(os.path.join(temp_save_path, 'eval_results.json'),"w") as f: 100 | results = [] 101 | for image_id in tqdm(ids): 102 | image_path = os.path.join(dataset_img_path, cocoGt.loadImgs(image_id)[0]['file_name']) 103 | image = Image.open(image_path) 104 | results = yolo.detect_image(image_id, image, results, clsid2catid) 105 | json.dump(results, f) 106 | 107 | if map_mode == 0 or map_mode == 2: 108 | cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json')) 109 | cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') 110 | cocoEval.evaluate() 111 | cocoEval.accumulate() 112 | cocoEval.summarize() 113 | print("Get map done.") 114 | -------------------------------------------------------------------------------- /yolov7/voc_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | 5 | import numpy as np 6 | 7 | from utils.utils import get_classes 8 | 9 | #--------------------------------------------------------------------------------------------------------------------------------# 10 | # annotation_mode用于指定该文件运行时计算的内容 11 | # annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt 12 | # annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt 13 | # annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt 14 | #--------------------------------------------------------------------------------------------------------------------------------# 15 | annotation_mode = 0 16 | #-------------------------------------------------------------------# 17 | # 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息 18 | # 与训练和预测所用的classes_path一致即可 19 | # 如果生成的2007_train.txt里面没有目标信息 20 | # 那么就是因为classes没有设定正确 21 | # 仅在annotation_mode为0和2的时候有效 22 | #-------------------------------------------------------------------# 23 | classes_path = r'D:\Deep_Learning_folds\ProbEn\yolov7\model_data\voc_classes.txt' 24 | #--------------------------------------------------------------------------------------------------------------------------------# 25 | # trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1 26 | # train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1 27 | # 仅在annotation_mode为0和1的时候有效 28 | #--------------------------------------------------------------------------------------------------------------------------------# 29 | trainval_percent = 0.9 30 | train_percent = 0.9 31 | #-------------------------------------------------------# 32 | # 指向VOC数据集所在的文件夹 33 | # 默认指向根目录下的VOC数据集 34 | #-------------------------------------------------------# 35 | VOCdevkit_path = 'D:\Deep_Learning_folds\ProbEn\yolov7\VOCdevkit' 36 | 37 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')] 38 | classes, _ = get_classes(classes_path) 39 | 40 | #-------------------------------------------------------# 41 | # 统计目标数量 42 | #-------------------------------------------------------# 43 | photo_nums = np.zeros(len(VOCdevkit_sets)) 44 | nums = np.zeros(len(classes)) 45 | def convert_annotation(year, image_id, list_file): 46 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8') 47 | tree=ET.parse(in_file) 48 | root = tree.getroot() 49 | 50 | for obj in root.iter('object'): 51 | difficult = 0 52 | if obj.find('difficult')!=None: 53 | difficult = obj.find('difficult').text 54 | cls = obj.find('name').text 55 | if cls not in classes or int(difficult)==1: 56 | continue 57 | cls_id = classes.index(cls) 58 | xmlbox = obj.find('bndbox') 59 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text))) 60 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) 61 | 62 | nums[classes.index(cls)] = nums[classes.index(cls)] + 1 63 | 64 | if __name__ == "__main__": 65 | random.seed(0) 66 | if " " in os.path.abspath(VOCdevkit_path): 67 | raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。") 68 | 69 | if annotation_mode == 0 or annotation_mode == 1: 70 | print("Generate txt in ImageSets.") 71 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations') 72 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main') 73 | temp_xml = os.listdir(xmlfilepath) 74 | total_xml = [] 75 | for xml in temp_xml: 76 | if xml.endswith(".xml"): 77 | total_xml.append(xml) 78 | 79 | num = len(total_xml) 80 | list = range(num) 81 | tv = int(num*trainval_percent) 82 | tr = int(tv*train_percent) 83 | trainval= random.sample(list,tv) 84 | train = random.sample(trainval,tr) 85 | 86 | print("train and val size",tv) 87 | print("train size",tr) 88 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') 89 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') 90 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') 91 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w') 92 | 93 | for i in list: 94 | name=total_xml[i][:-4]+'\n' 95 | if i in trainval: 96 | ftrainval.write(name) 97 | if i in train: 98 | ftrain.write(name) 99 | else: 100 | fval.write(name) 101 | else: 102 | ftest.write(name) 103 | 104 | ftrainval.close() 105 | ftrain.close() 106 | fval.close() 107 | ftest.close() 108 | print("Generate txt in ImageSets done.") 109 | 110 | if annotation_mode == 0 or annotation_mode == 2: 111 | print("Generate 2007_train.txt and 2007_val.txt for train.") 112 | type_index = 0 113 | for year, image_set in VOCdevkit_sets: 114 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split() 115 | list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8') 116 | for image_id in image_ids: 117 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id)) 118 | 119 | convert_annotation(year, image_id, list_file) 120 | list_file.write('\n') 121 | photo_nums[type_index] = len(image_ids) 122 | type_index += 1 123 | list_file.close() 124 | print("Generate 2007_train.txt and 2007_val.txt for train done.") 125 | 126 | def printTable(List1, List2): 127 | for i in range(len(List1[0])): 128 | print("|", end=' ') 129 | for j in range(len(List1)): 130 | print(List1[j][i].rjust(int(List2[j])), end=' ') 131 | print("|", end=' ') 132 | print() 133 | 134 | str_nums = [str(int(x)) for x in nums] 135 | tableData = [ 136 | classes, str_nums 137 | ] 138 | colWidths = [0]*len(tableData) 139 | len1 = 0 140 | for i in range(len(tableData)): 141 | for j in range(len(tableData[i])): 142 | if len(tableData[i][j]) > colWidths[i]: 143 | colWidths[i] = len(tableData[i][j]) 144 | printTable(tableData, colWidths) 145 | 146 | if photo_nums[0] <= 500: 147 | print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。") 148 | 149 | if np.sum(nums) == 0: 150 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 151 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 152 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 153 | print("(重要的事情说三遍)。") 154 | --------------------------------------------------------------------------------