├── data ├── __init__.py ├── scripts │ ├── COCO2017.sh │ ├── VOC2012.sh │ └── VOC2007.sh ├── coco.py ├── voc0712.py └── transform.py ├── utils ├── __init__.py ├── com_paras_flops.py ├── misc.py └── kmeans_anchor.py ├── .gitignore ├── config ├── __init__.py └── yolov2_config.py ├── requirements.txt ├── train.sh ├── backbone ├── __init__.py └── darknet19.py ├── models ├── build.py ├── basic.py ├── loss.py ├── matcher.py └── yolov2.py ├── README.md ├── eval.py ├── evaluator ├── cocoapi_evaluator.py └── vocapi_evaluator.py ├── test.py ├── tools.py └── train.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | *pt 4 | *.pth 5 | *.pkl 6 | *.pyc 7 | det_results 8 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .yolov2_config import yolov2_config 2 | 3 | 4 | def build_model_config(args): 5 | return yolov2_config[args.version] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | 3 | torchvision 4 | 5 | opencv-python 6 | 7 | thop 8 | 9 | scipy 10 | 11 | matplotlib 12 | 13 | numpy 14 | 15 | pycocotools 16 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --cuda \ 3 | -d coco \ 4 | -ms \ 5 | -bs 16 \ 6 | -accu 4 \ 7 | --lr 0.001 \ 8 | --max_epoch 200 \ 9 | --lr_epoch 100 150 \ 10 | -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .darknet19 import build_darknet19 2 | 3 | 4 | def build_backbone(model_name='darknet19', pretrained=False): 5 | if model_name == 'darknet19': 6 | backbone, feat_dims = build_darknet19(pretrained) 7 | 8 | return backbone, feat_dims 9 | -------------------------------------------------------------------------------- /utils/com_paras_flops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from thop import profile 3 | 4 | 5 | def FLOPs_and_Params(model, img_size, device): 6 | x = torch.randn(1, 3, img_size, img_size).to(device) 7 | print('==============================') 8 | flops, params = profile(model, inputs=(x, )) 9 | print('==============================') 10 | print('FLOPs : {:.2f} B'.format(flops / 1e9)) 11 | print('Params : {:.2f} M'.format(params / 1e6)) 12 | 13 | 14 | if __name__ == "__main__": 15 | pass 16 | -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | from .yolov2 import YOLOv2 2 | 3 | 4 | def build_yolov2(args, cfg, device, input_size, num_classes=20, trainable=False): 5 | anchor_size = cfg['anchor_size'][args.dataset] 6 | 7 | model = YOLOv2( 8 | cfg=cfg, 9 | 
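        # (added comments) the remaining keyword arguments are wired in from the config
        # and the CLI: `device`/`input_size`/`num_classes`/`trainable` describe the run,
        # `conf_thresh`, `nms_thresh` and `topk` control post-processing, and
        # `anchor_size` is the per-dataset prior list selected from cfg above.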
device=device, 10 | input_size=input_size, 11 | num_classes=num_classes, 12 | trainable=trainable, 13 | conf_thresh=args.conf_thresh, 14 | nms_thresh=args.nms_thresh, 15 | topk=args.topk, 16 | anchor_size=anchor_size 17 | ) 18 | 19 | return model 20 | -------------------------------------------------------------------------------- /config/yolov2_config.py: -------------------------------------------------------------------------------- 1 | # yolov2 config 2 | 3 | 4 | yolov2_config = { 5 | 'yolov2': { 6 | # model 7 | 'backbone': 'darknet19', 8 | 'pretrained': True, 9 | 'stride': 32, # P5 10 | 'reorg_dim': 64, 11 | 'head_dim': 1024, 12 | # anchor size 13 | 'anchor_size': { 14 | 'voc': [[1.19, 1.98], [2.79, 4.59], [4.53, 8.92], [8.06, 5.29], [10.32, 10.65]], 15 | 'coco': [[0.53, 0.79], [1.71, 2.36], [2.89, 6.44], [6.33, 3.79], [9.03, 9.74]] 16 | }, 17 | # matcher 18 | 'ignore_thresh': 0.5, 19 | }, 20 | } -------------------------------------------------------------------------------- /data/scripts/COCO2017.sh: -------------------------------------------------------------------------------- 1 | mkdir COCO 2 | cd COCO 3 | 4 | wget http://images.cocodataset.org/zips/train2017.zip 5 | wget http://images.cocodataset.org/zips/val2017.zip 6 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 7 | wget http://images.cocodataset.org/zips/test2017.zip 8 | wget http://images.cocodataset.org/annotations/image_info_test2017.zip  9 | 10 | unzip train2017.zip 11 | unzip val2017.zip 12 | unzip annotations_trainval2017.zip 13 | unzip test2017.zip 14 | unzip image_info_test2017.zip 15 | 16 | # rm -f train2017.zip 17 | # rm -f val2017.zip 18 | # rm -f annotations_trainval2017.zip 19 | # rm -f test2017.zip 20 | # rm -f image_info_test2017.zip 21 | -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 
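# (added note) Like VOC2012.sh above, this script accepts an optional download
# directory as its first argument, e.g. `sh data/scripts/VOC2007.sh ~/data`;
# with no argument it downloads and extracts into ~/data/.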
25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /models/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Conv(nn.Module): 6 | def __init__(self, in_ch, out_ch, k=1, p=0, s=1, d=1, act=True): 7 | super(Conv, self).__init__() 8 | self.convs = nn.Sequential( 9 | nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, dilation=d, bias=False), 10 | nn.BatchNorm2d(out_ch), 11 | nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity() 12 | ) 13 | 14 | def forward(self, x): 15 | return self.convs(x) 16 | 17 | 18 | class reorg_layer(nn.Module): 19 | def __init__(self, stride): 20 | super(reorg_layer, self).__init__() 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | batch_size, channels, height, width = x.size() 25 | _height, _width = height // self.stride, width // self.stride 26 | 27 | x = x.view(batch_size, channels, _height, self.stride, _width, self.stride).transpose(3, 4).contiguous() 28 | x = x.view(batch_size, channels, _height * _width, self.stride * self.stride).transpose(2, 3).contiguous() 29 | x = x.view(batch_size, channels, self.stride * self.stride, _height, _width).transpose(1, 2).contiguous() 30 | x = x.view(batch_size, -1, _height, _width) 31 | 32 | return x 33 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def detection_collate(batch): 6 | """Custom collate fn for dealing with batches of images that have a different 7 | number of associated object annotations (bounding boxes). 
8 | 9 | Arguments: 10 | batch: (tuple) A tuple of tensor images and lists of annotations 11 | 12 | Return: 13 | A tuple containing: 14 | 1) (tensor) batch of images stacked on their 0 dim 15 | 2) (list of tensors) annotations for a given image are stacked on 16 | 0 dim 17 | """ 18 | targets = [] 19 | imgs = [] 20 | for sample in batch: 21 | imgs.append(sample[0]) 22 | targets.append(torch.FloatTensor(sample[1])) 23 | return torch.stack(imgs, 0), targets 24 | 25 | 26 | def load_weight(model, path_to_ckpt=None): 27 | # check 28 | if path_to_ckpt is None: 29 | print('no weight file ...') 30 | return model 31 | 32 | checkpoint_state_dict = torch.load(path_to_ckpt, map_location='cpu') 33 | # model state dict 34 | model_state_dict = model.state_dict() 35 | # check 36 | for k in list(checkpoint_state_dict.keys()): 37 | if k in model_state_dict: 38 | shape_model = tuple(model_state_dict[k].shape) 39 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape) 40 | if shape_model != shape_checkpoint: 41 | checkpoint_state_dict.pop(k) 42 | else: 43 | checkpoint_state_dict.pop(k) 44 | print(k) 45 | 46 | model.load_state_dict(checkpoint_state_dict) 47 | print('Finished loading model!') 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class MSEWithLogitsLoss(nn.Module): 7 | def __init__(self): 8 | super(MSEWithLogitsLoss, self).__init__() 9 | 10 | def forward(self, logits, targets, mask): 11 | inputs = torch.clamp(torch.sigmoid(logits), min=1e-4, max=1.0 - 1e-4) 12 | 13 | # 被忽略的先验框的mask都是-1,不参与loss计算 14 | pos_id = (mask==1.0).float() 15 | neg_id = (mask==0.0).float() 16 | pos_loss = pos_id * (inputs - targets)**2 17 | neg_loss = neg_id * (inputs)**2 18 | loss = 5.0*pos_loss + 1.0*neg_loss 19 | 20 | return loss 21 | 22 | 23 | def iou_score(bboxes_a, bboxes_b): 24 | """ 25 | bbox_1 : [B*N, 4] = [x1, y1, x2, y2] 26 | bbox_2 : [B*N, 4] = [x1, y1, x2, y2] 27 | """ 28 | tl = torch.max(bboxes_a[:, :2], bboxes_b[:, :2]) 29 | br = torch.min(bboxes_a[:, 2:], bboxes_b[:, 2:]) 30 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 31 | area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) 32 | 33 | en = (tl < br).type(tl.type()).prod(dim=1) 34 | area_i = torch.prod(br - tl, 1) * en # * ((tl < br).all()) 35 | return area_i / (area_a + area_b - area_i) 36 | 37 | 38 | def compute_loss(pred_conf, pred_cls, pred_txtytwth, targets): 39 | batch_size = pred_conf.size(0) 40 | # 损失函数 41 | conf_loss_function = MSEWithLogitsLoss() 42 | cls_loss_function = nn.CrossEntropyLoss(reduction='none') 43 | txty_loss_function = nn.BCEWithLogitsLoss(reduction='none') 44 | twth_loss_function = nn.MSELoss(reduction='none') 45 | 46 | # 预测 47 | pred_conf = pred_conf[..., 0] # [B, HW,] 48 | pred_cls = pred_cls.permute(0, 2, 1) # [B, C, HW] 49 | pred_txty = pred_txtytwth[..., :2] # [B, HW, 2] 50 | pred_twth = pred_txtytwth[..., 2:] # [B, HW, 2] 51 | 52 | # 标签 53 | gt_conf = targets[..., 0].float() # [B, HW,] 54 | gt_obj = targets[..., 1].float() # [B, HW,] 55 | gt_cls = targets[..., 2].long() # [B, HW,] 56 | gt_txty = targets[..., 3:5].float() # [B, HW, 2] 57 | gt_twth = targets[..., 5:7].float() # [B, HW, 2] 58 | gt_box_scale_weight = targets[..., 7] # [B, HW,] 59 | gt_mask = (gt_box_scale_weight > 0.).float() # [B, HW,] 60 | 61 | # 置信度损失 62 | conf_loss = conf_loss_function(pred_conf, gt_conf, 
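                                    # (added note) gt_obj is passed as the `mask` argument of
                                    # MSEWithLogitsLoss above: mask == 1 marks positive anchors,
                                    # mask == 0 negatives, and mask == -1 the ignored anchors
                                    # produced by the matcher, which contribute nothing to the
                                    # confidence loss.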
gt_obj) 63 | conf_loss = conf_loss.sum() / batch_size 64 | 65 | # 类别损失 66 | cls_loss = cls_loss_function(pred_cls, gt_cls) * gt_mask 67 | cls_loss = cls_loss.sum() / batch_size 68 | 69 | # 边界框txty的损失 70 | txty_loss = txty_loss_function(pred_txty, gt_txty).sum(-1) * gt_mask * gt_box_scale_weight 71 | txty_loss = txty_loss.sum() / batch_size 72 | 73 | # 边界框twth的损失 74 | twth_loss = twth_loss_function(pred_twth, gt_twth).sum(-1) * gt_mask * gt_box_scale_weight 75 | twth_loss = twth_loss.sum() / batch_size 76 | bbox_loss = txty_loss + twth_loss 77 | 78 | #总的损失 79 | total_loss = conf_loss + cls_loss + bbox_loss 80 | 81 | return conf_loss, cls_loss, bbox_loss, total_loss 82 | 83 | 84 | if __name__ == "__main__": 85 | pass 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # PyTorch_YOLOv2 3 | 这个YOLOv2项目是配合我在知乎专栏上连载的《YOLO入门教程》而创建的: 4 | 5 | https://zhuanlan.zhihu.com/c_1364967262269693952 6 | 7 | 感兴趣的小伙伴可以配合着上面的专栏来一起学习,入门目标检测。 8 | 9 | 这里也诚挚推荐我的另一个YOLO项目,训练更加稳定,性能更好呦 10 | 11 | https://github.com/yjh0410/PyTorch_YOLO-Family 12 | 13 | # 配置环境 14 | - 我们建议使用anaconda来创建虚拟环境: 15 | ```Shell 16 | conda create -n yolo python=3.6 17 | ``` 18 | 19 | - 然后,激活虚拟环境: 20 | ```Shell 21 | conda activate yolo 22 | ``` 23 | 24 | - 配置环境: 25 | 运行下方的命令即可一键配置相关的深度学习环境: 26 | ```Shell 27 | pip install -r requirements.txt 28 | ``` 29 | 如果您已经学习了笔者之前的YOLOv1项目,那么就不需要再次创建该虚拟环境了,二者的环境是可以共用的。 30 | 31 | ## 训练所使用的tricks 32 | 33 | - [x] batch norm 34 | - [x] hi-res classifier 35 | - [x] convolutional 36 | - [x] anchor boxes 37 | - [x] better backbone: resnet50 38 | - [x] dimension priors 39 | - [x] location prediction 40 | - [x] passthrough 41 | - [x] multi-scale 42 | - [x] hi-red detector 43 | 44 | ## 数据集 45 | 46 | ### VOC2007与VOC2012数据集 47 | 48 | 读者可以从下面的百度网盘链接来下载VOC2007和VOC2012数据集 49 | 50 | 链接:https://pan.baidu.com/s/1qClcQXSXjP8FEnsP_RrZjg 51 | 52 | 提取码:zrcj 53 | 54 | 读者会获得 ```VOCdevkit.zip```压缩包, 分别包含 ```VOCdevkit/VOC2007``` 和 ```VOCdevkit/VOC2012```两个文件夹,分别是VOC2007数据集和VOC2012数据集. 55 | 56 | ### COCO 2017 数据集 57 | 58 | * 自己下载 59 | 60 | 运行 ```sh data/scripts/COCO2017.sh```,将会获得 COCO train2017, val2017, test2017三个数据集. 
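A note added for clarity: judging from how `data/coco.py` and `evaluator/cocoapi_evaluator.py` resolve paths, the layout expected under `--root` is roughly the following (the exact JSON files depend on which splits you evaluate):

```
COCO/
├── annotations/
│   ├── instances_train2017.json
│   ├── instances_val2017.json
│   └── image_info_test-dev2017.json
├── train2017/
├── val2017/
└── test2017/
```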
61 | 62 | * 百度网盘下载: 63 | 64 | 这里,笔者也提供了由笔者下好的COCO数据集的百度网盘链接: 65 | 66 | 链接:https://pan.baidu.com/s/1XQqeHgNMp8U-ohbEWuT2CA 67 | 68 | 提取码:l1e5 69 | 70 | ## 实验结果 71 | VOC2007 test 测试集 72 | 73 | | Model | Input size | mAP | Weight | 74 | |-------------------|--------------|---------|--------| 75 | | YOLOv2 | 320×320 | 73.4 | - | 76 | | YOLOv2 | 416×416 | 77.1 | - | 77 | | YOLOv2 | 512×512 | 78.0 | - | 78 | | YOLOv2 | 608×608 | 78.3 | [github](https://github.com/yjh0410/PyTorch_YOLOv2/releases/download/yolov2_weight/yolov2_voc.pth) | 79 | 80 | 81 | COCO val 验证集 82 | 83 | | Model | Input size | AP | AP50 | Weight| 84 | |-------------------|----------------|---------|-----------|-------| 85 | | YOLOv2 | 320×320 | 24.1 | 42.8 | - | 86 | | YOLOv2 | 416×416 | 27.2 | 47.3 | - | 87 | | YOLOv2 | 512×512 | 28.8 | 50.0 | - | 88 | | YOLOv2 | 608×608 | 29.7 | 51.7 | [github](https://github.com/yjh0410/PyTorch_YOLOv2/releases/download/yolov2_weight/yolov2_coco.pth) | 89 | 90 | 大家可以点击表格中的[github]()来下载模型权重文件。 91 | 92 | # 训练模型 93 | 运行下方的命令可开始在```VOC```数据集上进行训练: 94 | ```Shell 95 | python train.py \ 96 | --cuda \ 97 | -d voc \ 98 | -ms \ 99 | -bs 16 \ 100 | -accu 4 \ 101 | --lr 0.001 \ 102 | --max_epoch 200 \ 103 | --lr_epoch 100 150 \ 104 | ``` 105 | 其中,`-bs 16`表示我们设置batch size为16,`-accu 4`表示我们累加梯度4次,以此来近似使用64 batch size的训练效果。 106 | 倘若使用者将`-bs`设置更小,如8,请务必将`-accu`也做相应的调整,如8,以确保`-bs x -accu = 64`,否则,可能会出现训练不稳定的问题。 107 | 108 | # 测试模型 109 | 运行下方的命令可开始在```VOC```数据集上进行训练: 110 | ```Shell 111 | python test.py \ 112 | --cuda \ 113 | -d voc \ 114 | -size 416 \ 115 | --weight path/to/weight \ 116 | ``` 117 | 118 | 119 | # 验证模型 120 | 运行下方的命令可开始在```VOC```数据集上进行训练: 121 | ```Shell 122 | python eval.py \ 123 | --cuda \ 124 | -d voc \ 125 | -size 416 \ 126 | --weight path/to/weight \ 127 | ``` -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | from data.transform import BaseTransform 6 | from evaluator.cocoapi_evaluator import COCOAPIEvaluator 7 | from evaluator.vocapi_evaluator import VOCAPIEvaluator 8 | from utils.misc import load_weight 9 | 10 | from config import build_model_config 11 | from models.build import build_yolov2 12 | 13 | 14 | parser = argparse.ArgumentParser(description='YOLOv2 Detector Evaluation') 15 | parser.add_argument('-d', '--dataset', default='voc', 16 | help='voc, coco-val, coco-test.') 17 | parser.add_argument('--root', default='/mnt/share/ssd2/dataset', 18 | help='data root') 19 | 20 | parser.add_argument('-v', '--version', default='yolov2', 21 | help='yolo.') 22 | parser.add_argument('--coco_test', action='store_true', default=False, 23 | help='evaluate model on coco-test') 24 | parser.add_argument('--conf_thresh', default=0.001, type=float, 25 | help='得分阈值') 26 | parser.add_argument('--nms_thresh', default=0.50, type=float, 27 | help='NMS 阈值') 28 | parser.add_argument('--topk', default=1000, type=int, 29 | help='topk predicted candidates') 30 | parser.add_argument('--weight', type=str, default=None, 31 | help='Trained state_dict file path to open') 32 | 33 | parser.add_argument('-size', '--input_size', default=416, type=int, 34 | help='input_size') 35 | parser.add_argument('--cuda', action='store_true', default=False, 36 | help='Use cuda') 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | 42 | def voc_test(model, device, input_size, val_transform): 43 | data_root = os.path.join(args.root, 'VOCdevkit') 44 | evaluator = 
VOCAPIEvaluator( 45 | data_root=data_root, 46 | img_size=input_size, 47 | device=device, 48 | transform=val_transform, 49 | display=True 50 | ) 51 | 52 | # VOC evaluation 53 | evaluator.evaluate(model) 54 | 55 | 56 | def coco_test(model, device, input_size, val_transform, test=False): 57 | data_root = os.path.join(args.root, 'COCO') 58 | if test: 59 | # test-dev 60 | print('test on test-dev 2017') 61 | evaluator = COCOAPIEvaluator( 62 | data_dir=data_root, 63 | img_size=input_size, 64 | device=device, 65 | testset=True, 66 | transform=val_transform 67 | ) 68 | 69 | else: 70 | # eval 71 | evaluator = COCOAPIEvaluator( 72 | data_dir=data_root, 73 | img_size=input_size, 74 | device=device, 75 | testset=False, 76 | transform=val_transform 77 | ) 78 | 79 | # COCO evaluation 80 | evaluator.evaluate(model) 81 | 82 | 83 | if __name__ == '__main__': 84 | # dataset 85 | if args.dataset == 'voc': 86 | print('eval on voc ...') 87 | num_classes = 20 88 | elif args.dataset == 'coco': 89 | print('eval on coco-val ...') 90 | num_classes = 80 91 | else: 92 | print('unknow dataset !! we only support voc, coco !!!') 93 | exit(0) 94 | 95 | # cuda 96 | if args.cuda: 97 | print('use cuda') 98 | device = torch.device("cuda") 99 | else: 100 | device = torch.device("cpu") 101 | 102 | # 构建模型配置文件 103 | cfg = build_model_config(args) 104 | 105 | # 构建模型 106 | model = build_yolov2(args, cfg, device, args.input_size, num_classes, trainable=False) 107 | 108 | # 加载已训练好的模型权重 109 | model = load_weight(model, args.weight) 110 | model.to(device).eval() 111 | 112 | val_transform = BaseTransform(args.input_size) 113 | 114 | # evaluation 115 | with torch.no_grad(): 116 | if args.dataset == 'voc': 117 | voc_test(model, device, args.input_size, val_transform) 118 | elif args.dataset == 'coco': 119 | if args.coco_test: 120 | coco_test(model, device, args.input_size, val_transform, test=True) 121 | else: 122 | coco_test(model, device, args.input_size, val_transform, test=False) 123 | -------------------------------------------------------------------------------- /evaluator/cocoapi_evaluator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | import torch 4 | import numpy as np 5 | from pycocotools.cocoeval import COCOeval 6 | 7 | from data.coco import COCODataset 8 | 9 | 10 | class COCOAPIEvaluator(): 11 | """ 12 | COCO AP Evaluation class. 13 | All the data in the val2017 dataset are processed \ 14 | and evaluated by COCO API. 15 | """ 16 | def __init__(self, data_dir, img_size, device, testset=False, transform=None): 17 | """ 18 | Args: 19 | data_dir (str): dataset root directory 20 | img_size (int): image size after preprocess. images are resized \ 21 | to squares whose shape is (img_size, img_size). 22 | confthre (float): 23 | confidence threshold ranging from 0 to 1, \ 24 | which is defined in the config file. 25 | nmsthre (float): 26 | IoU threshold of non-max supression ranging from 0 to 1. 27 | """ 28 | self.img_size = img_size 29 | self.transform = transform 30 | self.device = device 31 | self.map = -1. 
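        # (added note) hypothetical usage sketch, based on the evaluate() method below:
        #   evaluator = COCOAPIEvaluator(data_dir='path/to/COCO', img_size=416,
        #                                device=device, transform=BaseTransform(416))
        #   ap50_95, ap50 = evaluator.evaluate(model)
        # `self.map` caches the latest AP@[.5:.95] so callers can read it back after a run.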
32 | 33 | self.testset = testset 34 | if self.testset: 35 | json_file='image_info_test-dev2017.json' 36 | image_set = 'test2017' 37 | else: 38 | json_file='instances_val2017.json' 39 | image_set='val2017' 40 | 41 | self.dataset = COCODataset( 42 | data_dir=data_dir, 43 | img_size=img_size, 44 | json_file=json_file, 45 | transform=None, 46 | image_set=image_set) 47 | 48 | 49 | def evaluate(self, model): 50 | """ 51 | COCO average precision (AP) Evaluation. Iterate inference on the test dataset 52 | and the results are evaluated by COCO API. 53 | Args: 54 | model : model object 55 | Returns: 56 | ap50_95 (float) : calculated COCO AP for IoU=50:95 57 | ap50 (float) : calculated COCO AP for IoU=50 58 | """ 59 | model.eval() 60 | ids = [] 61 | data_dict = [] 62 | num_images = len(self.dataset) 63 | print('total number of images: %d' % (num_images)) 64 | 65 | # start testing 66 | for index in range(num_images): # all the data in val2017 67 | if index % 500 == 0: 68 | print('[Eval: %d / %d]'%(index, num_images)) 69 | 70 | img, id_ = self.dataset.pull_image(index) # load a batch 71 | if self.transform is not None: 72 | x = torch.from_numpy(self.transform(img)[0][:, :, (2, 1, 0)]).permute(2, 0, 1) 73 | x = x.unsqueeze(0).to(self.device) 74 | scale = np.array([[img.shape[1], img.shape[0], 75 | img.shape[1], img.shape[0]]]) 76 | 77 | id_ = int(id_) 78 | ids.append(id_) 79 | with torch.no_grad(): 80 | outputs = model(x) 81 | bboxes, scores, labels = outputs 82 | bboxes *= scale 83 | for i, box in enumerate(bboxes): 84 | x1 = float(box[0]) 85 | y1 = float(box[1]) 86 | x2 = float(box[2]) 87 | y2 = float(box[3]) 88 | label = self.dataset.class_ids[int(labels[i])] 89 | 90 | bbox = [x1, y1, x2 - x1, y2 - y1] 91 | score = float(scores[i]) # object score * class score 92 | A = {"image_id": id_, "category_id": label, "bbox": bbox, 93 | "score": score} # COCO json format 94 | data_dict.append(A) 95 | 96 | annType = ['segm', 'bbox', 'keypoints'] 97 | 98 | # Evaluate the Dt (detection) json comparing with the ground truth 99 | if len(data_dict) > 0: 100 | print('evaluating ......') 101 | cocoGt = self.dataset.coco 102 | # workaround: temporarily write data to json file because pycocotools can't process dict in py36. 
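            # (added note) the detections collected above follow the COCO results format
            # ([{"image_id", "category_id", "bbox", "score"}, ...]); they are dumped to JSON,
            # loaded back with cocoGt.loadRes(), and scored by COCOeval, whose stats[0] is
            # AP@[.5:.95] and stats[1] is AP@.5.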
103 | if self.testset: 104 | json.dump(data_dict, open('yolo_2017.json', 'w')) 105 | cocoDt = cocoGt.loadRes('yolo_2017.json') 106 | else: 107 | _, tmp = tempfile.mkstemp() 108 | json.dump(data_dict, open(tmp, 'w')) 109 | cocoDt = cocoGt.loadRes(tmp) 110 | cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1]) 111 | cocoEval.params.imgIds = ids 112 | cocoEval.evaluate() 113 | cocoEval.accumulate() 114 | cocoEval.summarize() 115 | 116 | ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1] 117 | print('ap50_95 : ', ap50_95) 118 | print('ap50 : ', ap50) 119 | self.map = ap50_95 120 | self.ap50_95 = ap50_95 121 | self.ap50 = ap50 122 | 123 | return ap50_95, ap50 124 | else: 125 | return 0, 0 126 | 127 | -------------------------------------------------------------------------------- /backbone/darknet19.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | 6 | model_urls = { 7 | "darknet19": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet19.pth", 8 | } 9 | 10 | 11 | __all__ = ['darknet19'] 12 | 13 | 14 | class Conv_BN_LeakyReLU(nn.Module): 15 | def __init__(self, in_channels, out_channels, ksize, padding=0, stride=1, dilation=1): 16 | super(Conv_BN_LeakyReLU, self).__init__() 17 | self.convs = nn.Sequential( 18 | nn.Conv2d(in_channels, out_channels, ksize, padding=padding, stride=stride, dilation=dilation), 19 | nn.BatchNorm2d(out_channels), 20 | nn.LeakyReLU(0.1, inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | return self.convs(x) 25 | 26 | 27 | class DarkNet_19(nn.Module): 28 | def __init__(self): 29 | super(DarkNet_19, self).__init__() 30 | # backbone network : DarkNet-19 31 | # output : stride = 2, c = 32 32 | self.conv_1 = nn.Sequential( 33 | Conv_BN_LeakyReLU(3, 32, 3, 1), 34 | nn.MaxPool2d((2,2), 2), 35 | ) 36 | 37 | # output : stride = 4, c = 64 38 | self.conv_2 = nn.Sequential( 39 | Conv_BN_LeakyReLU(32, 64, 3, 1), 40 | nn.MaxPool2d((2,2), 2) 41 | ) 42 | 43 | # output : stride = 8, c = 128 44 | self.conv_3 = nn.Sequential( 45 | Conv_BN_LeakyReLU(64, 128, 3, 1), 46 | Conv_BN_LeakyReLU(128, 64, 1), 47 | Conv_BN_LeakyReLU(64, 128, 3, 1), 48 | nn.MaxPool2d((2,2), 2) 49 | ) 50 | 51 | # output : stride = 8, c = 256 52 | self.conv_4 = nn.Sequential( 53 | Conv_BN_LeakyReLU(128, 256, 3, 1), 54 | Conv_BN_LeakyReLU(256, 128, 1), 55 | Conv_BN_LeakyReLU(128, 256, 3, 1), 56 | ) 57 | 58 | # output : stride = 16, c = 512 59 | self.maxpool_4 = nn.MaxPool2d((2, 2), 2) 60 | self.conv_5 = nn.Sequential( 61 | Conv_BN_LeakyReLU(256, 512, 3, 1), 62 | Conv_BN_LeakyReLU(512, 256, 1), 63 | Conv_BN_LeakyReLU(256, 512, 3, 1), 64 | Conv_BN_LeakyReLU(512, 256, 1), 65 | Conv_BN_LeakyReLU(256, 512, 3, 1), 66 | ) 67 | 68 | # output : stride = 32, c = 1024 69 | self.maxpool_5 = nn.MaxPool2d((2, 2), 2) 70 | self.conv_6 = nn.Sequential( 71 | Conv_BN_LeakyReLU(512, 1024, 3, 1), 72 | Conv_BN_LeakyReLU(1024, 512, 1), 73 | Conv_BN_LeakyReLU(512, 1024, 3, 1), 74 | Conv_BN_LeakyReLU(1024, 512, 1), 75 | Conv_BN_LeakyReLU(512, 1024, 3, 1) 76 | ) 77 | 78 | def forward(self, x): 79 | """ 80 | Input: 81 | x: (Tensor) -> [B, 3, H, W] 82 | Output: 83 | output: (Dict) { 84 | 'c3': c3 -> Tensor[B, C3, H/8, W/8], 85 | 'c4': c4 -> Tensor[B, C4, H/16, W/16], 86 | 'c5': c5 -> Tensor[B, C5, H/32, W/32], 87 | } 88 | """ 89 | c1 = self.conv_1(x) # [B, C1, H/2, W/2] 90 | c2 = self.conv_2(c1) # [B, C2, H/4, W/4] 91 | c3 = self.conv_3(c2) # [B, C3, H/8, W/8] 92 | c3 = self.conv_4(c3) # [B, C3, 
H/8, W/8] 93 | c4 = self.conv_5(self.maxpool_4(c3)) # [B, C4, H/16, W/16] 94 | c5 = self.conv_6(self.maxpool_5(c4)) # [B, C5, H/32, W/32] 95 | 96 | output = { 97 | 'c3': c3, 98 | 'c4': c4, 99 | 'c5': c5 100 | } 101 | return output 102 | 103 | 104 | def build_darknet19(pretrained=False): 105 | # model 106 | model = DarkNet_19() 107 | feat_dims = [256, 512, 1024] 108 | 109 | # load weight 110 | if pretrained: 111 | print('Loading pretrained weight ...') 112 | url = model_urls['darknet19'] 113 | # checkpoint state dict 114 | checkpoint_state_dict = torch.hub.load_state_dict_from_url( 115 | url=url, map_location="cpu", check_hash=True) 116 | # model state dict 117 | model_state_dict = model.state_dict() 118 | # check 119 | for k in list(checkpoint_state_dict.keys()): 120 | if k in model_state_dict: 121 | shape_model = tuple(model_state_dict[k].shape) 122 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape) 123 | if shape_model != shape_checkpoint: 124 | checkpoint_state_dict.pop(k) 125 | else: 126 | checkpoint_state_dict.pop(k) 127 | print(k) 128 | 129 | model.load_state_dict(checkpoint_state_dict) 130 | 131 | return model, feat_dims 132 | 133 | 134 | if __name__ == '__main__': 135 | import time 136 | model, feats = build_darknet19(pretrained=True) 137 | x = torch.randn(1, 3, 224, 224) 138 | t0 = time.time() 139 | outputs = model(x) 140 | t1 = time.time() 141 | print('Time: ', t1 - t0) 142 | for k in outputs.keys(): 143 | print(outputs[k].shape) 144 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def compute_iou(anchor_boxes, gt_box): 6 | """计算先验框和真实框之间的IoU 7 | Input: \n 8 | anchor_boxes: [K, 4] \n 9 | gt_box: [1, 4] \n 10 | Output: \n 11 | iou : [K,] \n 12 | """ 13 | 14 | # anchor box : 15 | ab_x1y1_x2y2 = np.zeros([len(anchor_boxes), 4]) 16 | # 计算先验框的左上角点坐标和右下角点坐标 17 | ab_x1y1_x2y2[:, 0] = anchor_boxes[:, 0] - anchor_boxes[:, 2] / 2 # xmin 18 | ab_x1y1_x2y2[:, 1] = anchor_boxes[:, 1] - anchor_boxes[:, 3] / 2 # ymin 19 | ab_x1y1_x2y2[:, 2] = anchor_boxes[:, 0] + anchor_boxes[:, 2] / 2 # xmax 20 | ab_x1y1_x2y2[:, 3] = anchor_boxes[:, 1] + anchor_boxes[:, 3] / 2 # ymax 21 | w_ab, h_ab = anchor_boxes[:, 2], anchor_boxes[:, 3] 22 | 23 | # gt_box : 24 | # 我们将真实框扩展成[K, 4], 便于计算IoU. 25 | gt_box_expand = np.repeat(gt_box, len(anchor_boxes), axis=0) 26 | 27 | gb_x1y1_x2y2 = np.zeros([len(anchor_boxes), 4]) 28 | # 计算真实框的左上角点坐标和右下角点坐标 29 | gb_x1y1_x2y2[:, 0] = gt_box_expand[:, 0] - gt_box_expand[:, 2] / 2 # xmin 30 | gb_x1y1_x2y2[:, 1] = gt_box_expand[:, 1] - gt_box_expand[:, 3] / 2 # ymin 31 | gb_x1y1_x2y2[:, 2] = gt_box_expand[:, 0] + gt_box_expand[:, 2] / 2 # xmax 32 | gb_x1y1_x2y2[:, 3] = gt_box_expand[:, 1] + gt_box_expand[:, 3] / 2 # ymin 33 | w_gt, h_gt = gt_box_expand[:, 2], gt_box_expand[:, 3] 34 | 35 | # 计算IoU 36 | S_gt = w_gt * h_gt 37 | S_ab = w_ab * h_ab 38 | I_w = np.minimum(gb_x1y1_x2y2[:, 2], ab_x1y1_x2y2[:, 2]) - np.maximum(gb_x1y1_x2y2[:, 0], ab_x1y1_x2y2[:, 0]) 39 | I_h = np.minimum(gb_x1y1_x2y2[:, 3], ab_x1y1_x2y2[:, 3]) - np.maximum(gb_x1y1_x2y2[:, 1], ab_x1y1_x2y2[:, 1]) 40 | S_I = I_h * I_w 41 | U = S_gt + S_ab - S_I + 1e-20 42 | IoU = S_I / U 43 | 44 | return IoU 45 | 46 | 47 | def set_anchors(anchor_size): 48 | """将输入进来的只包含wh的先验框尺寸转换成[N, 4]的ndarray类型, 49 | 包含先验框的中心点坐标和宽高wh,中心点坐标设为0. 
\n 50 | Input: \n 51 | anchor_size: list -> [[h_1, w_1], \n 52 | [h_2, w_2], \n 53 | ..., \n 54 | [h_n, w_n]]. \n 55 | Output: \n 56 | anchor_boxes: ndarray -> [[0, 0, anchor_w, anchor_h], \n 57 | [0, 0, anchor_w, anchor_h], \n 58 | ... \n 59 | [0, 0, anchor_w, anchor_h]]. \n 60 | """ 61 | anchor_number = len(anchor_size) 62 | anchor_boxes = np.zeros([anchor_number, 4]) 63 | for index, size in enumerate(anchor_size): 64 | anchor_w, anchor_h = size 65 | anchor_boxes[index] = np.array([0, 0, anchor_w, anchor_h]) 66 | 67 | return anchor_boxes 68 | 69 | 70 | def generate_txtytwth(gt_label, w, h, s, anchor_size, ignore_thresh): 71 | xmin, ymin, xmax, ymax = gt_label[:-1] 72 | # 计算真实边界框的中心点和宽高 73 | c_x = (xmax + xmin) / 2 * w 74 | c_y = (ymax + ymin) / 2 * h 75 | box_w = (xmax - xmin) * w 76 | box_h = (ymax - ymin) * h 77 | 78 | if box_w < 1e-4 or box_h < 1e-4: 79 | # print('not a valid data !!!') 80 | return False 81 | 82 | # 将真是边界框的尺寸映射到网格的尺度上去 83 | c_x_s = c_x / s 84 | c_y_s = c_y / s 85 | box_ws = box_w / s 86 | box_hs = box_h / s 87 | 88 | # 计算中心点所落在的网格的坐标 89 | grid_x = int(c_x_s) 90 | grid_y = int(c_y_s) 91 | 92 | # 获得先验框的中心点坐标和宽高, 93 | # 这里,我们设置所有的先验框的中心点坐标为0 94 | anchor_boxes = set_anchors(anchor_size) 95 | gt_box = np.array([[0, 0, box_ws, box_hs]]) 96 | 97 | # 计算先验框和真实框之间的IoU 98 | iou = compute_iou(anchor_boxes, gt_box) 99 | 100 | # 只保留大于ignore_thresh的先验框去做正样本匹配, 101 | iou_mask = (iou > ignore_thresh) 102 | 103 | result = [] 104 | if iou_mask.sum() == 0: 105 | # 如果所有的先验框算出的IoU都小于阈值,那么就将IoU最大的那个先验框分配给正样本. 106 | # 其他的先验框统统视为负样本 107 | index = np.argmax(iou) 108 | p_w, p_h = anchor_size[index] 109 | tx = c_x_s - grid_x 110 | ty = c_y_s - grid_y 111 | tw = np.log(box_ws / p_w) 112 | th = np.log(box_hs / p_h) 113 | weight = 2.0 - (box_w / w) * (box_h / h) 114 | 115 | result.append([index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax]) 116 | 117 | return result 118 | 119 | else: 120 | # 有至少一个先验框的IoU超过了阈值. 
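                # (added example) with ignore_thresh = 0.5, a gt box whose IoUs with the priors
                # are 0.70 and 0.55 keeps only the 0.70 prior as the positive sample below;
                # the 0.55 prior is marked as ignored (weight -1) and priors under 0.5 stay
                # negative.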
121 | # 但我们只保留超过阈值的那些先验框中IoU最大的,其他的先验框忽略掉,不参与loss计算。 122 | # 而小于阈值的先验框统统视为负样本。 123 | best_index = np.argmax(iou) 124 | for index, iou_m in enumerate(iou_mask): 125 | if iou_m: 126 | if index == best_index: 127 | p_w, p_h = anchor_size[index] 128 | tx = c_x_s - grid_x 129 | ty = c_y_s - grid_y 130 | tw = np.log(box_ws / p_w) 131 | th = np.log(box_hs / p_h) 132 | weight = 2.0 - (box_w / w) * (box_h / h) 133 | 134 | result.append([index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax]) 135 | else: 136 | # 对于被忽略的先验框,我们将其权重weight设置为-1 137 | result.append([index, grid_x, grid_y, 0., 0., 0., 0., -1.0, 0., 0., 0., 0.]) 138 | 139 | return result 140 | 141 | 142 | def gt_creator(input_size, stride, label_lists, anchor_size, ignore_thresh): 143 | # 必要的参数 144 | batch_size = len(label_lists) 145 | s = stride 146 | w = input_size 147 | h = input_size 148 | ws = w // s 149 | hs = h // s 150 | anchor_number = len(anchor_size) 151 | gt_tensor = np.zeros([batch_size, hs, ws, anchor_number, 1+1+4+1+4]) 152 | 153 | # 制作正样本 154 | for batch_index in range(batch_size): 155 | for gt_label in label_lists[batch_index]: 156 | # get a bbox coords 157 | gt_class = int(gt_label[-1]) 158 | results = generate_txtytwth(gt_label, w, h, s, anchor_size, ignore_thresh) 159 | if results: 160 | for result in results: 161 | index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax = result 162 | if weight > 0.: 163 | if grid_y < gt_tensor.shape[1] and grid_x < gt_tensor.shape[2]: 164 | gt_tensor[batch_index, grid_y, grid_x, index, 0] = 1.0 165 | gt_tensor[batch_index, grid_y, grid_x, index, 1] = gt_class 166 | gt_tensor[batch_index, grid_y, grid_x, index, 2:6] = np.array([tx, ty, tw, th]) 167 | gt_tensor[batch_index, grid_y, grid_x, index, 6] = weight 168 | gt_tensor[batch_index, grid_y, grid_x, index, 7:] = np.array([xmin, ymin, xmax, ymax]) 169 | else: 170 | # 对于那些被忽略的先验框,其gt_obj参数为-1,weight权重也是-1 171 | gt_tensor[batch_index, grid_y, grid_x, index, 0] = -1.0 172 | gt_tensor[batch_index, grid_y, grid_x, index, 6] = -1.0 173 | 174 | gt_tensor = gt_tensor.reshape(batch_size, hs * ws * anchor_number, 1+1+4+1+4) 175 | 176 | return gt_tensor 177 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import cv2 5 | import os 6 | import time 7 | 8 | from utils.misc import load_weight 9 | from data.voc0712 import VOCDetection, VOC_CLASSES 10 | from data.coco import COCODataset, coco_class_index, coco_class_labels 11 | from data.transform import BaseTransform 12 | 13 | from config import build_model_config 14 | from models.build import build_yolov2 15 | 16 | 17 | parser = argparse.ArgumentParser(description='YOLOv2 Detection') 18 | parser.add_argument('-d', '--dataset', default='voc', 19 | help='voc, coco-val.') 20 | parser.add_argument('--root', default='/mnt/share/ssd2/dataset', 21 | help='data root') 22 | parser.add_argument('-size', '--input_size', default=416, type=int, 23 | help='输入图像尺寸') 24 | 25 | parser.add_argument('-v', '--version', default='yolov2', 26 | help='yolo') 27 | parser.add_argument('--weight', default=None, 28 | type=str, help='模型权重的路径') 29 | parser.add_argument('--conf_thresh', default=0.1, type=float, 30 | help='得分阈值') 31 | parser.add_argument('--nms_thresh', default=0.50, type=float, 32 | help='NMS 阈值') 33 | parser.add_argument('--topk', default=100, type=int, 34 | help='topk predicted candidates') 35 | 36 
| parser.add_argument('-vs', '--visual_threshold', default=0.33, type=float, 37 | help='用于可视化的阈值参数') 38 | parser.add_argument('--cuda', action='store_true', default=False, 39 | help='use cuda.') 40 | parser.add_argument('--save', action='store_true', default=False, 41 | help='save vis results.') 42 | 43 | args = parser.parse_args() 44 | 45 | 46 | def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4): 47 | x1, y1, x2, y2 = bbox 48 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 49 | t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0] 50 | # plot bbox 51 | cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2) 52 | 53 | if label is not None: 54 | # plot title bbox 55 | cv2.rectangle(img, (x1, y1-t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1) 56 | # put the test on the title bbox 57 | cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA) 58 | 59 | return img 60 | 61 | 62 | def visualize(img, 63 | bboxes, 64 | scores, 65 | labels, 66 | vis_thresh, 67 | class_colors, 68 | class_names, 69 | class_indexs=None, 70 | dataset_name='voc'): 71 | ts = 0.4 72 | for i, bbox in enumerate(bboxes): 73 | if scores[i] > vis_thresh: 74 | cls_id = int(labels[i]) 75 | if dataset_name == 'coco': 76 | cls_color = class_colors[cls_id] 77 | cls_id = class_indexs[cls_id] 78 | else: 79 | cls_color = class_colors[cls_id] 80 | 81 | if len(class_names) > 1: 82 | mess = '%s: %.2f' % (class_names[cls_id], scores[i]) 83 | else: 84 | cls_color = [255, 0, 0] 85 | mess = None 86 | img = plot_bbox_labels(img, bbox, mess, cls_color, text_scale=ts) 87 | 88 | return img 89 | 90 | 91 | def test(args, model, device, testset, transform, class_colors=None, class_names=None, class_indexs=None): 92 | save_path = os.path.join('det_results/', args.dataset, args.version) 93 | os.makedirs(save_path, exist_ok=True) 94 | 95 | num_images = len(testset) 96 | for index in range(num_images): 97 | print('Testing image {:d}/{:d}....'.format(index+1, num_images)) 98 | img, _ = testset.pull_image(index) 99 | h, w, _ = img.shape 100 | 101 | # 预处理图像,并将其转换为tensor类型 102 | x = torch.from_numpy(transform(img)[0][:, :, (2, 1, 0)]).permute(2, 0, 1) 103 | x = x.unsqueeze(0).to(device) 104 | 105 | t0 = time.time() 106 | # 前向推理 107 | bboxes, scores, labels = model(x) 108 | print("detection time used ", time.time() - t0, "s") 109 | 110 | # 将预测的输出映射到原图的尺寸上去 111 | scale = np.array([[w, h, w, h]]) 112 | bboxes *= scale 113 | 114 | # 可视化检测结果 115 | img_processed = visualize( 116 | img=img, 117 | bboxes=bboxes, 118 | scores=scores, 119 | labels=labels, 120 | vis_thresh=args.visual_threshold, 121 | class_colors=class_colors, 122 | class_names=class_names, 123 | class_indexs=class_indexs, 124 | dataset_name=args.dataset 125 | ) 126 | cv2.imshow('detection', img_processed) 127 | cv2.waitKey(0) 128 | 129 | # 保存可视化结果 130 | if args.save: 131 | cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed) 132 | 133 | 134 | if __name__ == '__main__': 135 | # 是否使用cuda 136 | if args.cuda: 137 | print('use cuda') 138 | device = torch.device("cuda") 139 | else: 140 | device = torch.device("cpu") 141 | 142 | # 输入图像的尺寸 143 | input_size = args.input_size 144 | 145 | # 构建数据集 146 | if args.dataset == 'voc': 147 | data_root = os.path.join(args.root, 'VOCdevkit') 148 | # 加载VOC2007 test数据集 149 | print('test on voc ...') 150 | class_names = VOC_CLASSES 151 | class_indexs = None 152 | num_classes = 20 153 | dataset = VOCDetection( 154 | root=data_root, 155 | 
img_size=input_size, 156 | image_sets=[('2007', 'test')], 157 | transform=None 158 | ) 159 | 160 | elif args.dataset == 'coco': 161 | data_root = os.path.join(args.root, 'COCO') 162 | # 加载COCO val数据集 163 | print('test on coco-val ...') 164 | class_names = coco_class_labels 165 | class_indexs = coco_class_index 166 | num_classes = 80 167 | dataset = COCODataset( 168 | data_dir=data_root, 169 | json_file='instances_val2017.json', 170 | image_set='val2017', 171 | img_size=input_size) 172 | 173 | # 用于可视化,给不同类别的边界框赋予不同的颜色,为了便于区分。 174 | np.random.seed(0) 175 | class_colors = [(np.random.randint(255), 176 | np.random.randint(255), 177 | np.random.randint(255)) for _ in range(num_classes)] 178 | 179 | # 构建模型配置文件 180 | cfg = build_model_config(args) 181 | 182 | # 构建模型 183 | model = build_yolov2(args, cfg, device, input_size, num_classes, trainable=False) 184 | 185 | # 加载已训练好的模型权重 186 | model = load_weight(model, args.weight) 187 | model.to(device).eval() 188 | print('Finished loading model!') 189 | 190 | val_transform = BaseTransform(input_size) 191 | 192 | # 开始测试 193 | test(args=args, 194 | model=model, 195 | device=device, 196 | testset=dataset, 197 | transform=val_transform, 198 | class_colors=class_colors, 199 | class_names=class_names, 200 | class_indexs=class_indexs, 201 | ) 202 | -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import cv2 8 | from pycocotools.coco import COCO 9 | 10 | 11 | coco_class_labels = ('background', 12 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 13 | 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 14 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 15 | 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 16 | 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 17 | 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 18 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 19 | 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 20 | 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 21 | 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 22 | 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 23 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 24 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') 25 | 26 | coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 27 | 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 28 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 29 | 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 30 | 31 | 32 | class COCODataset(Dataset): 33 | """ 34 | COCO dataset class. 35 | """ 36 | def __init__(self, data_dir='COCO', 37 | json_file='instances_train2017.json', 38 | image_set='train2017', 39 | img_size=None, 40 | transform=None, 41 | ): 42 | """ 43 | COCO dataset initialization. Annotation data are read into memory by COCO API. 
44 | Args: 45 | data_dir (str): dataset root directory 46 | json_file (str): COCO json file image_set 47 | image_set (str): COCO data image_set (e.g. 'train2017' or 'val2017') 48 | img_size (int): target image size after pre-processing 49 | min_size (int): bounding boxes smaller than this are ignored 50 | debug (bool): if True, only one data id is selected from the dataset 51 | """ 52 | self.data_dir = data_dir 53 | self.json_file = json_file 54 | self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file)) 55 | self.ids = self.coco.getImgIds() 56 | self.class_ids = sorted(self.coco.getCatIds()) 57 | self.image_set = image_set 58 | self.max_labels = 50 59 | self.img_size = img_size 60 | self.transform = transform 61 | 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | 67 | def pull_image(self, index): 68 | id_ = self.ids[index] 69 | img_file = os.path.join(self.data_dir, self.image_set, 70 | '{:012}'.format(id_) + '.jpg') 71 | img = cv2.imread(img_file) 72 | 73 | if self.json_file == 'instances_val5k.json' and img is None: 74 | img_file = os.path.join(self.data_dir, 'train2017', 75 | '{:012}'.format(id_) + '.jpg') 76 | img = cv2.imread(img_file) 77 | 78 | elif self.json_file == 'image_info_test-dev2017.json' and img is None: 79 | img_file = os.path.join(self.data_dir, 'test2017', 80 | '{:012}'.format(id_) + '.jpg') 81 | img = cv2.imread(img_file) 82 | 83 | elif self.json_file == 'image_info_test2017.json' and img is None: 84 | img_file = os.path.join(self.data_dir, 'test2017', 85 | '{:012}'.format(id_) + '.jpg') 86 | img = cv2.imread(img_file) 87 | 88 | return img, id_ 89 | 90 | 91 | def __getitem__(self, index): 92 | im, gt, h, w = self.pull_item(index) 93 | 94 | return im, gt 95 | 96 | 97 | def pull_item(self, index): 98 | id_ = self.ids[index] 99 | 100 | anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=None) 101 | annotations = self.coco.loadAnns(anno_ids) 102 | 103 | # load image and preprocess 104 | img_file = os.path.join(self.data_dir, self.image_set, 105 | '{:012}'.format(id_) + '.jpg') 106 | img = cv2.imread(img_file) 107 | 108 | if self.json_file == 'instances_val5k.json' and img is None: 109 | img_file = os.path.join(self.data_dir, 'train2017', 110 | '{:012}'.format(id_) + '.jpg') 111 | img = cv2.imread(img_file) 112 | 113 | assert img is not None 114 | 115 | height, width, channels = img.shape 116 | 117 | # COCOAnnotation Transform 118 | # start here : 119 | target = [] 120 | for anno in annotations: 121 | x1 = np.max((0, anno['bbox'][0])) 122 | y1 = np.max((0, anno['bbox'][1])) 123 | x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1)))) 124 | y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1)))) 125 | if anno['area'] > 0 and x2 >= x1 and y2 >= y1: 126 | label_ind = anno['category_id'] 127 | cls_id = self.class_ids.index(label_ind) 128 | x1 /= width 129 | y1 /= height 130 | x2 /= width 131 | y2 /= height 132 | 133 | target.append([x1, y1, x2, y2, cls_id]) # [xmin, ymin, xmax, ymax, label_ind] 134 | # end here . 
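        # (added note) each row of `target` is [xmin, ymin, xmax, ymax, cls_id]: the box is
        # normalized to [0, 1] by the image width/height, and cls_id is the contiguous index
        # into self.class_ids (the sorted COCO category ids), not the raw category_id.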
135 | 136 | # data augmentation 137 | if self.transform is not None: 138 | if len(target) == 0: 139 | target = np.zeros([1, 5]) 140 | else: 141 | target = np.array(target) 142 | 143 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 144 | # to rgb 145 | img = img[:, :, (2, 1, 0)] 146 | # img = img.transpose(2, 0, 1) 147 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 148 | 149 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 150 | 151 | 152 | if __name__ == "__main__": 153 | from transform import Augmentation, BaseTransform 154 | 155 | img_size = 640 156 | pixel_mean = (0.406, 0.456, 0.485) # BGR 157 | pixel_std = (0.225, 0.224, 0.229) # BGR 158 | data_root = '/mnt/share/ssd2/dataset/COCO' 159 | transform = Augmentation(img_size, pixel_mean, pixel_std) 160 | transform = BaseTransform(img_size, pixel_mean, pixel_std) 161 | 162 | img_size = 640 163 | dataset = COCODataset( 164 | data_dir=data_root, 165 | img_size=img_size, 166 | transform=transform 167 | ) 168 | 169 | for i in range(1000): 170 | im, gt, h, w = dataset.pull_item(i) 171 | 172 | # to numpy 173 | image = im.permute(1, 2, 0).numpy() 174 | # to BGR 175 | image = image[..., (2, 1, 0)] 176 | # denormalize 177 | image = (image * pixel_std + pixel_mean) * 255 178 | # to 179 | image = image.astype(np.uint8).copy() 180 | 181 | # draw bbox 182 | for box in gt: 183 | xmin, ymin, xmax, ymax, _ = box 184 | xmin *= img_size 185 | ymin *= img_size 186 | xmax *= img_size 187 | ymax *= img_size 188 | image = cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0,0,255), 2) 189 | cv2.imshow('gt', image) 190 | cv2.waitKey(0) 191 | -------------------------------------------------------------------------------- /data/voc0712.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | import os.path as osp 9 | import torch 10 | import torch.utils.data as data 11 | import cv2 12 | import numpy as np 13 | import xml.etree.ElementTree as ET 14 | 15 | 16 | VOC_CLASSES = ( # always index 0 17 | 'aeroplane', 'bicycle', 'bird', 'boat', 18 | 'bottle', 'bus', 'car', 'cat', 'chair', 19 | 'cow', 'diningtable', 'dog', 'horse', 20 | 'motorbike', 'person', 'pottedplant', 21 | 'sheep', 'sofa', 'train', 'tvmonitor') 22 | 23 | 24 | class VOCAnnotationTransform(object): 25 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 26 | Initilized with a dictionary lookup of classnames to indexes 27 | 28 | Arguments: 29 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 30 | (default: alphabetic indexing of VOC's 20 classes) 31 | keep_difficult (bool, optional): keep difficult instances or not 32 | (default: False) 33 | height (int): height 34 | width (int): width 35 | """ 36 | 37 | def __init__(self, class_to_ind=None, keep_difficult=False): 38 | self.class_to_ind = class_to_ind or dict( 39 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 40 | self.keep_difficult = keep_difficult 41 | 42 | def __call__(self, target, width, height): 43 | """ 44 | Arguments: 45 | target (annotation) : the target annotation to be made usable 46 | will be an ET.Element 47 | Returns: 48 | a list containing lists of bounding boxes [bbox coords, class name] 49 | """ 50 | res = [] 51 | for obj in target.iter('object'): 52 | difficult = 
int(obj.find('difficult').text) == 1 53 | if not self.keep_difficult and difficult: 54 | continue 55 | name = obj.find('name').text.lower().strip() 56 | bbox = obj.find('bndbox') 57 | 58 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 59 | bndbox = [] 60 | for i, pt in enumerate(pts): 61 | cur_pt = int(bbox.find(pt).text) - 1 62 | # scale height or width 63 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 64 | bndbox.append(cur_pt) 65 | label_idx = self.class_to_ind[name] 66 | bndbox.append(label_idx) 67 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 68 | # img_id = target.find('filename').text[:-4] 69 | 70 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] 71 | 72 | 73 | class VOCDetection(data.Dataset): 74 | """VOC Detection Dataset Object 75 | 76 | input is image, target is annotation 77 | 78 | Arguments: 79 | root (string): filepath to VOCdevkit folder. 80 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 81 | transform (callable, optional): transformation to perform on the 82 | input image 83 | target_transform (callable, optional): transformation to perform on the 84 | target `annotation` 85 | (eg: take in caption string, return tensor of word indices) 86 | dataset_name (string, optional): which dataset to load 87 | (default: 'VOC2007') 88 | """ 89 | 90 | def __init__(self, 91 | root, 92 | img_size=None, 93 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 94 | transform=None, 95 | target_transform=VOCAnnotationTransform(), 96 | dataset_name='VOC0712' 97 | ): 98 | self.root = root 99 | self.img_size = img_size 100 | self.image_set = image_sets 101 | self.transform = transform 102 | self.target_transform = target_transform 103 | self.name = dataset_name 104 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 105 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 106 | self.ids = list() 107 | for (year, name) in image_sets: 108 | rootpath = osp.join(self.root, 'VOC' + year) 109 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 110 | self.ids.append((rootpath, line.strip())) 111 | 112 | 113 | def __getitem__(self, index): 114 | im, gt, h, w = self.pull_item(index) 115 | 116 | return im, gt 117 | 118 | 119 | def __len__(self): 120 | return len(self.ids) 121 | 122 | 123 | def pull_item(self, index): 124 | img_id = self.ids[index] 125 | 126 | target = ET.parse(self._annopath % img_id).getroot() 127 | img = cv2.imread(self._imgpath % img_id) 128 | height, width, channels = img.shape 129 | 130 | if self.target_transform is not None: 131 | target = self.target_transform(target, width, height) 132 | 133 | # basic augmentation(SSDAugmentation or BaseTransform) 134 | if self.transform is not None: 135 | # check labels 136 | if len(target) == 0: 137 | target = np.zeros([1, 5]) 138 | else: 139 | target = np.array(target) 140 | 141 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 142 | # to rgb 143 | img = img[:, :, (2, 1, 0)] 144 | # img = img.transpose(2, 0, 1) 145 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 146 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width 147 | # return torch.from_numpy(img), target, height, width 148 | 149 | 150 | def pull_image(self, index): 151 | '''Returns the original image object at index in PIL form 152 | 153 | Note: not using self.__getitem__(), as any transformations passed in 154 | could mess up this functionality. 
155 | 156 | Argument: 157 | index (int): index of img to show 158 | Return: 159 | PIL img 160 | ''' 161 | img_id = self.ids[index] 162 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR), img_id 163 | 164 | 165 | def pull_anno(self, index): 166 | '''Returns the original annotation of image at index 167 | 168 | Note: not using self.__getitem__(), as any transformations passed in 169 | could mess up this functionality. 170 | 171 | Argument: 172 | index (int): index of img to get annotation of 173 | Return: 174 | list: [img_id, [(label, bbox coords),...]] 175 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 176 | ''' 177 | img_id = self.ids[index] 178 | anno = ET.parse(self._annopath % img_id).getroot() 179 | gt = self.target_transform(anno, 1, 1) 180 | return img_id[1], gt 181 | 182 | 183 | if __name__ == "__main__": 184 | from transform import Augmentation, BaseTransform 185 | 186 | img_size = 640 187 | pixel_mean = (0.406, 0.456, 0.485) # BGR 188 | pixel_std = (0.225, 0.224, 0.229) # BGR 189 | data_root = 'D:\\python_work\\object-detection\\dataset\\VOCdevkit' 190 | transform = Augmentation(img_size, pixel_mean, pixel_std) 191 | transform = BaseTransform(img_size, pixel_mean, pixel_std) 192 | 193 | # dataset 194 | dataset = VOCDetection( 195 | root=data_root, 196 | img_size=img_size, 197 | image_sets=[('2007', 'trainval')], 198 | transform=transform 199 | ) 200 | 201 | for i in range(1000): 202 | im, gt, h, w = dataset.pull_item(i) 203 | 204 | # to numpy 205 | image = im.permute(1, 2, 0).numpy() 206 | # to BGR 207 | image = image[..., (2, 1, 0)] 208 | # denormalize 209 | image = (image * pixel_std + pixel_mean) * 255 210 | # to 211 | image = image.astype(np.uint8).copy() 212 | 213 | # draw bbox 214 | for box in gt: 215 | xmin, ymin, xmax, ymax, _ = box 216 | xmin *= img_size 217 | ymin *= img_size 218 | xmax *= img_size 219 | ymax *= img_size 220 | image = cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0,0,255), 2) 221 | cv2.imshow('gt', image) 222 | cv2.waitKey(0) 223 | -------------------------------------------------------------------------------- /utils/kmeans_anchor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import argparse 4 | import os 5 | import sys 6 | sys.path.append('..') 7 | 8 | from data.voc0712 import VOCDetection 9 | from data.coco import COCODataset 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='kmeans for anchor box') 14 | parser.add_argument('--root', default='/mnt/share/ssd2/dataset', 15 | help='data root') 16 | parser.add_argument('-d', '--dataset', default='coco', 17 | help='coco, widerface, crowdhuman') 18 | parser.add_argument('-na', '--num_anchorbox', default=5, type=int, 19 | help='number of anchor box.') 20 | parser.add_argument('-size', '--img_size', default=416, type=int, 21 | help='input size.') 22 | parser.add_argument('-ab', '--absolute', action='store_true', default=False, 23 | help='absolute coords.') 24 | return parser.parse_args() 25 | 26 | args = parse_args() 27 | 28 | 29 | class Box(): 30 | def __init__(self, x, y, w, h): 31 | self.x = x 32 | self.y = y 33 | self.w = w 34 | self.h = h 35 | 36 | 37 | def iou(box1, box2): 38 | x1, y1, w1, h1 = box1.x, box1.y, box1.w, box1.h 39 | x2, y2, w2, h2 = box2.x, box2.y, box2.w, box2.h 40 | 41 | S_1 = w1 * h1 42 | S_2 = w2 * h2 43 | 44 | xmin_1, ymin_1 = x1 - w1 / 2, y1 - h1 / 2 45 | xmax_1, ymax_1 = x1 + w1 / 2, y1 + h1 / 2 46 | xmin_2, ymin_2 = x2 - w2 / 2, y2 
- h2 / 2 47 | xmax_2, ymax_2 = x2 + w2 / 2, y2 + h2 / 2 48 | 49 | I_w = min(xmax_1, xmax_2) - max(xmin_1, xmin_2) 50 | I_h = min(ymax_1, ymax_2) - max(ymin_1, ymin_2) 51 | if I_w < 0 or I_h < 0: 52 | return 0 53 | I = I_w * I_h 54 | 55 | IoU = I / (S_1 + S_2 - I) 56 | 57 | return IoU 58 | 59 | 60 | def init_centroids(boxes, n_anchors): 61 | """ 62 | We use kmeans++ to initialize centroids. 63 | """ 64 | centroids = [] 65 | boxes_num = len(boxes) 66 | 67 | centroid_index = int(np.random.choice(boxes_num, 1)[0]) 68 | centroids.append(boxes[centroid_index]) 69 | print(centroids[0].w,centroids[0].h) 70 | 71 | for centroid_index in range(0, n_anchors-1): 72 | sum_distance = 0 73 | distance_thresh = 0 74 | distance_list = [] 75 | cur_sum = 0 76 | 77 | for box in boxes: 78 | min_distance = 1 79 | for centroid_i, centroid in enumerate(centroids): 80 | distance = (1 - iou(box, centroid)) 81 | if distance < min_distance: 82 | min_distance = distance 83 | sum_distance += min_distance 84 | distance_list.append(min_distance) 85 | 86 | distance_thresh = sum_distance * np.random.random() 87 | 88 | for i in range(0, boxes_num): 89 | cur_sum += distance_list[i] 90 | if cur_sum > distance_thresh: 91 | centroids.append(boxes[i]) 92 | print(boxes[i].w, boxes[i].h) 93 | break 94 | return centroids 95 | 96 | 97 | def do_kmeans(n_anchors, boxes, centroids): 98 | loss = 0 99 | groups = [] 100 | new_centroids = [] 101 | # for box in centroids: 102 | # print('box: ', box.x, box.y, box.w, box.h) 103 | # exit() 104 | for i in range(n_anchors): 105 | groups.append([]) 106 | new_centroids.append(Box(0, 0, 0, 0)) 107 | 108 | for box in boxes: 109 | min_distance = 1 110 | group_index = 0 111 | for centroid_index, centroid in enumerate(centroids): 112 | distance = (1 - iou(box, centroid)) 113 | if distance < min_distance: 114 | min_distance = distance 115 | group_index = centroid_index 116 | groups[group_index].append(box) 117 | loss += min_distance 118 | new_centroids[group_index].w += box.w 119 | new_centroids[group_index].h += box.h 120 | 121 | for i in range(n_anchors): 122 | new_centroids[i].w /= max(len(groups[i]), 1) 123 | new_centroids[i].h /= max(len(groups[i]), 1) 124 | 125 | return new_centroids, groups, loss# / len(boxes) 126 | 127 | 128 | def anchor_box_kmeans(total_gt_boxes, n_anchors, loss_convergence, iters, plus=True): 129 | """ 130 | This function will use k-means to get appropriate anchor boxes for train dataset. 131 | Input: 132 | total_gt_boxes: 133 | n_anchor : int -> the number of anchor boxes. 134 | loss_convergence : float -> threshold of iterating convergence. 135 | iters: int -> the number of iterations for training kmeans. 136 | Output: anchor_boxes : list -> [[w1, h1], [w2, h2], ..., [wn, hn]]. 
137 | """ 138 | boxes = total_gt_boxes 139 | centroids = [] 140 | if plus: 141 | centroids = init_centroids(boxes, n_anchors) 142 | else: 143 | total_indexs = range(len(boxes)) 144 | sample_indexs = random.sample(total_indexs, n_anchors) 145 | for i in sample_indexs: 146 | centroids.append(boxes[i]) 147 | 148 | # iterate k-means 149 | centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids) 150 | iterations = 1 151 | while(True): 152 | centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids) 153 | iterations += 1 154 | print("Loss = %f" % loss) 155 | if abs(old_loss - loss) < loss_convergence or iterations > iters: 156 | break 157 | old_loss = loss 158 | 159 | for centroid in centroids: 160 | print(centroid.w, centroid.h) 161 | 162 | print("k-means result : ") 163 | for centroid in centroids: 164 | if args.absolute: 165 | print("w, h: ", round(centroid.w, 2), round(centroid.h, 2), 166 | "area: ", round(centroid.w, 2) * round(centroid.h, 2)) 167 | else: 168 | print("w, h: ", round(centroid.w / 32, 2), round(centroid.h / 32, 2), 169 | "area: ", round(centroid.w / 32, 2) * round(centroid.h / 32, 2)) 170 | 171 | return centroids 172 | 173 | 174 | if __name__ == "__main__": 175 | 176 | n_anchors = args.num_anchorbox 177 | img_size = args.img_size 178 | 179 | loss_convergence = 1e-6 180 | iters_n = 1000 181 | 182 | boxes = [] 183 | if args.dataset == 'voc': 184 | data_root = os.path.join(args.root, 'VOCdevkit') 185 | dataset = VOCDetection(root=data_root,img_size=img_size) 186 | 187 | # VOC 188 | for i in range(len(dataset)): 189 | if i % 5000 == 0: 190 | print('Loading voc data [%d / %d]' % (i+1, len(dataset))) 191 | 192 | # For VOC 193 | img, _ = dataset.pull_image(i) 194 | w, h = img.shape[1], img.shape[0] 195 | _, annotation = dataset.pull_anno(i) 196 | 197 | # prepare bbox datas 198 | for box_and_label in annotation: 199 | box = box_and_label[:-1] 200 | xmin, ymin, xmax, ymax = box 201 | bw = (xmax - xmin) / max(w, h) * img_size 202 | bh = (ymax - ymin) / max(w, h) * img_size 203 | # check bbox 204 | if bw < 1.0 or bh < 1.0: 205 | continue 206 | boxes.append(Box(0, 0, bw, bh)) 207 | break 208 | 209 | elif args.dataset == 'coco': 210 | data_root = os.path.join(args.root, 'COCO') 211 | dataset = COCODataset(data_dir=data_root, img_size=img_size) 212 | 213 | for i in range(len(dataset)): 214 | if i % 5000 == 0: 215 | print('Loading coco datat [%d / %d]' % (i+1, len(dataset))) 216 | 217 | # For COCO 218 | img, _ = dataset.pull_image(i) 219 | w, h = img.shape[1], img.shape[0] 220 | annotation = dataset.pull_anno(i) 221 | 222 | # prepare bbox datas 223 | for box_and_label in annotation: 224 | box = box_and_label[:-1] 225 | xmin, ymin, xmax, ymax = box 226 | bw = (xmax - xmin) / max(w, h) * img_size 227 | bh = (ymax - ymin) / max(w, h) * img_size 228 | # check bbox 229 | if bw < 1.0 or bh < 1.0: 230 | continue 231 | boxes.append(Box(0, 0, bw, bh)) 232 | 233 | print("Number of all bboxes: ", len(boxes)) 234 | print("Start k-means !") 235 | centroids = anchor_box_kmeans(boxes, n_anchors, loss_convergence, iters_n, plus=True) 236 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from data import * 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # We use ignore thresh to decide which anchor box can be kept. 
7 | ignore_thresh = IGNORE_THRESH 8 | 9 | 10 | class MSEWithLogitsLoss(nn.Module): 11 | def __init__(self, reduction='mean'): 12 | super(MSEWithLogitsLoss, self).__init__() 13 | self.reduction = reduction 14 | 15 | def forward(self, logits, targets, mask): 16 | inputs = torch.clamp(torch.sigmoid(logits), min=1e-4, max=1.0 - 1e-4) 17 | 18 | # 被忽略的先验框的mask都是-1,不参与loss计算 19 | pos_id = (mask==1.0).float() 20 | neg_id = (mask==0.0).float() 21 | pos_loss = pos_id * (inputs - targets)**2 22 | neg_loss = neg_id * (inputs)**2 23 | loss = 5.0*pos_loss + 1.0*neg_loss 24 | 25 | if self.reduction == 'mean': 26 | batch_size = logits.size(0) 27 | loss = torch.sum(loss) / batch_size 28 | 29 | return loss 30 | 31 | else: 32 | return loss 33 | 34 | 35 | def compute_iou(anchor_boxes, gt_box): 36 | """计算先验框和真实框之间的IoU 37 | Input: \n 38 | anchor_boxes: [K, 4] \n 39 | gt_box: [1, 4] \n 40 | Output: \n 41 | iou : [K,] \n 42 | """ 43 | 44 | # anchor box : 45 | ab_x1y1_x2y2 = np.zeros([len(anchor_boxes), 4]) 46 | # 计算先验框的左上角点坐标和右下角点坐标 47 | ab_x1y1_x2y2[:, 0] = anchor_boxes[:, 0] - anchor_boxes[:, 2] / 2 # xmin 48 | ab_x1y1_x2y2[:, 1] = anchor_boxes[:, 1] - anchor_boxes[:, 3] / 2 # ymin 49 | ab_x1y1_x2y2[:, 2] = anchor_boxes[:, 0] + anchor_boxes[:, 2] / 2 # xmax 50 | ab_x1y1_x2y2[:, 3] = anchor_boxes[:, 1] + anchor_boxes[:, 3] / 2 # ymax 51 | w_ab, h_ab = anchor_boxes[:, 2], anchor_boxes[:, 3] 52 | 53 | # gt_box : 54 | # 我们将真实框扩展成[K, 4], 便于计算IoU. 55 | gt_box_expand = np.repeat(gt_box, len(anchor_boxes), axis=0) 56 | 57 | gb_x1y1_x2y2 = np.zeros([len(anchor_boxes), 4]) 58 | # 计算真实框的左上角点坐标和右下角点坐标 59 | gb_x1y1_x2y2[:, 0] = gt_box_expand[:, 0] - gt_box_expand[:, 2] / 2 # xmin 60 | gb_x1y1_x2y2[:, 1] = gt_box_expand[:, 1] - gt_box_expand[:, 3] / 2 # ymin 61 | gb_x1y1_x2y2[:, 2] = gt_box_expand[:, 0] + gt_box_expand[:, 2] / 2 # xmax 62 | gb_x1y1_x2y2[:, 3] = gt_box_expand[:, 1] + gt_box_expand[:, 3] / 2 # ymin 63 | w_gt, h_gt = gt_box_expand[:, 2], gt_box_expand[:, 3] 64 | 65 | # 计算IoU 66 | S_gt = w_gt * h_gt 67 | S_ab = w_ab * h_ab 68 | I_w = np.minimum(gb_x1y1_x2y2[:, 2], ab_x1y1_x2y2[:, 2]) - np.maximum(gb_x1y1_x2y2[:, 0], ab_x1y1_x2y2[:, 0]) 69 | I_h = np.minimum(gb_x1y1_x2y2[:, 3], ab_x1y1_x2y2[:, 3]) - np.maximum(gb_x1y1_x2y2[:, 1], ab_x1y1_x2y2[:, 1]) 70 | S_I = I_h * I_w 71 | U = S_gt + S_ab - S_I + 1e-20 72 | IoU = S_I / U 73 | 74 | return IoU 75 | 76 | 77 | def set_anchors(anchor_size): 78 | """将输入进来的只包含wh的先验框尺寸转换成[N, 4]的ndarray类型, 79 | 包含先验框的中心点坐标和宽高wh,中心点坐标设为0. \n 80 | Input: \n 81 | anchor_size: list -> [[h_1, w_1], \n 82 | [h_2, w_2], \n 83 | ..., \n 84 | [h_n, w_n]]. \n 85 | Output: \n 86 | anchor_boxes: ndarray -> [[0, 0, anchor_w, anchor_h], \n 87 | [0, 0, anchor_w, anchor_h], \n 88 | ... \n 89 | [0, 0, anchor_w, anchor_h]]. 
\n 90 | """ 91 | anchor_number = len(anchor_size) 92 | anchor_boxes = np.zeros([anchor_number, 4]) 93 | for index, size in enumerate(anchor_size): 94 | anchor_w, anchor_h = size 95 | anchor_boxes[index] = np.array([0, 0, anchor_w, anchor_h]) 96 | 97 | return anchor_boxes 98 | 99 | 100 | def generate_txtytwth(gt_label, w, h, s, anchor_size): 101 | xmin, ymin, xmax, ymax = gt_label[:-1] 102 | # 计算真实边界框的中心点和宽高 103 | c_x = (xmax + xmin) / 2 * w 104 | c_y = (ymax + ymin) / 2 * h 105 | box_w = (xmax - xmin) * w 106 | box_h = (ymax - ymin) * h 107 | 108 | if box_w < 1e-4 or box_h < 1e-4: 109 | # print('not a valid data !!!') 110 | return False 111 | 112 | # 将真是边界框的尺寸映射到网格的尺度上去 113 | c_x_s = c_x / s 114 | c_y_s = c_y / s 115 | box_ws = box_w / s 116 | box_hs = box_h / s 117 | 118 | # 计算中心点所落在的网格的坐标 119 | grid_x = int(c_x_s) 120 | grid_y = int(c_y_s) 121 | 122 | # 获得先验框的中心点坐标和宽高, 123 | # 这里,我们设置所有的先验框的中心点坐标为0 124 | anchor_boxes = set_anchors(anchor_size) 125 | gt_box = np.array([[0, 0, box_ws, box_hs]]) 126 | 127 | # 计算先验框和真实框之间的IoU 128 | iou = compute_iou(anchor_boxes, gt_box) 129 | 130 | # 只保留大于ignore_thresh的先验框去做正样本匹配, 131 | iou_mask = (iou > ignore_thresh) 132 | 133 | result = [] 134 | if iou_mask.sum() == 0: 135 | # 如果所有的先验框算出的IoU都小于阈值,那么就将IoU最大的那个先验框分配给正样本. 136 | # 其他的先验框统统视为负样本 137 | index = np.argmax(iou) 138 | p_w, p_h = anchor_size[index] 139 | tx = c_x_s - grid_x 140 | ty = c_y_s - grid_y 141 | tw = np.log(box_ws / p_w) 142 | th = np.log(box_hs / p_h) 143 | weight = 2.0 - (box_w / w) * (box_h / h) 144 | 145 | result.append([index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax]) 146 | 147 | return result 148 | 149 | else: 150 | # 有至少一个先验框的IoU超过了阈值. 151 | # 但我们只保留超过阈值的那些先验框中IoU最大的,其他的先验框忽略掉,不参与loss计算。 152 | # 而小于阈值的先验框统统视为负样本。 153 | best_index = np.argmax(iou) 154 | for index, iou_m in enumerate(iou_mask): 155 | if iou_m: 156 | if index == best_index: 157 | p_w, p_h = anchor_size[index] 158 | tx = c_x_s - grid_x 159 | ty = c_y_s - grid_y 160 | tw = np.log(box_ws / p_w) 161 | th = np.log(box_hs / p_h) 162 | weight = 2.0 - (box_w / w) * (box_h / h) 163 | 164 | result.append([index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax]) 165 | else: 166 | # 对于被忽略的先验框,我们将其权重weight设置为-1 167 | result.append([index, grid_x, grid_y, 0., 0., 0., 0., -1.0, 0., 0., 0., 0.]) 168 | 169 | return result 170 | 171 | 172 | def gt_creator(input_size, stride, label_lists, anchor_size): 173 | # 必要的参数 174 | batch_size = len(label_lists) 175 | s = stride 176 | w = input_size 177 | h = input_size 178 | ws = w // s 179 | hs = h // s 180 | anchor_number = len(anchor_size) 181 | gt_tensor = np.zeros([batch_size, hs, ws, anchor_number, 1+1+4+1+4]) 182 | 183 | # 制作正样本 184 | for batch_index in range(batch_size): 185 | for gt_label in label_lists[batch_index]: 186 | # get a bbox coords 187 | gt_class = int(gt_label[-1]) 188 | results = generate_txtytwth(gt_label, w, h, s, anchor_size) 189 | if results: 190 | for result in results: 191 | index, grid_x, grid_y, tx, ty, tw, th, weight, xmin, ymin, xmax, ymax = result 192 | if weight > 0.: 193 | if grid_y < gt_tensor.shape[1] and grid_x < gt_tensor.shape[2]: 194 | gt_tensor[batch_index, grid_y, grid_x, index, 0] = 1.0 195 | gt_tensor[batch_index, grid_y, grid_x, index, 1] = gt_class 196 | gt_tensor[batch_index, grid_y, grid_x, index, 2:6] = np.array([tx, ty, tw, th]) 197 | gt_tensor[batch_index, grid_y, grid_x, index, 6] = weight 198 | gt_tensor[batch_index, grid_y, grid_x, index, 7:] = np.array([xmin, ymin, xmax, ymax]) 199 | else: 200 | # 
对于那些被忽略的先验框,其gt_obj参数为-1,weight权重也是-1 201 | gt_tensor[batch_index, grid_y, grid_x, index, 0] = -1.0 202 | gt_tensor[batch_index, grid_y, grid_x, index, 6] = -1.0 203 | 204 | gt_tensor = gt_tensor.reshape(batch_size, hs * ws * anchor_number, 1+1+4+1+4) 205 | 206 | return gt_tensor 207 | 208 | 209 | def iou_score(bboxes_a, bboxes_b): 210 | """ 211 | bbox_1 : [B*N, 4] = [x1, y1, x2, y2] 212 | bbox_2 : [B*N, 4] = [x1, y1, x2, y2] 213 | """ 214 | tl = torch.max(bboxes_a[:, :2], bboxes_b[:, :2]) 215 | br = torch.min(bboxes_a[:, 2:], bboxes_b[:, 2:]) 216 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 217 | area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) 218 | 219 | en = (tl < br).type(tl.type()).prod(dim=1) 220 | area_i = torch.prod(br - tl, 1) * en # * ((tl < br).all()) 221 | return area_i / (area_a + area_b - area_i) 222 | 223 | 224 | def loss(pred_conf, pred_cls, pred_txtytwth, pred_iou, label, num_classes): 225 | # 损失函数 226 | conf_loss_function = MSEWithLogitsLoss(reduction='mean') 227 | cls_loss_function = nn.CrossEntropyLoss(reduction='none') 228 | txty_loss_function = nn.BCEWithLogitsLoss(reduction='none') 229 | twth_loss_function = nn.MSELoss(reduction='none') 230 | iou_loss_function = nn.SmoothL1Loss(reduction='none') 231 | 232 | # 预测 233 | pred_conf = pred_conf[:, :, 0] 234 | pred_cls = pred_cls.permute(0, 2, 1) 235 | pred_txty = pred_txtytwth[:, :, :2] 236 | pred_twth = pred_txtytwth[:, :, 2:] 237 | pred_iou = pred_iou[:, :, 0] 238 | 239 | # 标签 240 | gt_conf = label[:, :, 0].float() 241 | gt_obj = label[:, :, 1].float() 242 | gt_cls = label[:, :, 2].long() 243 | gt_txty = label[:, :, 3:5].float() 244 | gt_twth = label[:, :, 5:7].float() 245 | gt_box_scale_weight = label[:, :, 7] 246 | gt_iou = (gt_box_scale_weight > 0.).float() 247 | gt_mask = (gt_box_scale_weight > 0.).float() 248 | 249 | batch_size = pred_conf.size(0) 250 | # 置信度损失 251 | conf_loss = conf_loss_function(pred_conf, gt_conf, gt_obj) 252 | 253 | # 类别损失 254 | cls_loss = torch.sum(cls_loss_function(pred_cls, gt_cls) * gt_mask) / batch_size 255 | 256 | # 边界框的位置损失 257 | txty_loss = torch.sum(torch.sum(txty_loss_function(pred_txty, gt_txty), dim=-1) * gt_box_scale_weight * gt_mask) / batch_size 258 | twth_loss = torch.sum(torch.sum(twth_loss_function(pred_twth, gt_twth), dim=-1) * gt_box_scale_weight * gt_mask) / batch_size 259 | bbox_loss = txty_loss + twth_loss 260 | 261 | # iou 损失 262 | iou_loss = torch.sum(iou_loss_function(pred_iou, gt_iou) * gt_mask) / batch_size 263 | 264 | return conf_loss, cls_loss, bbox_loss, iou_loss 265 | 266 | 267 | if __name__ == "__main__": 268 | gt_box = np.array([[0.0, 0.0, 10, 10]]) 269 | anchor_boxes = np.array([[0.0, 0.0, 10, 10], 270 | [0.0, 0.0, 4, 4], 271 | [0.0, 0.0, 8, 8], 272 | [0.0, 0.0, 16, 16] 273 | ]) 274 | iou = compute_iou(anchor_boxes, gt_box) 275 | print(iou) -------------------------------------------------------------------------------- /models/yolov2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .basic import Conv, reorg_layer 6 | from backbone import build_backbone 7 | 8 | import numpy as np 9 | from .loss import iou_score, compute_loss 10 | 11 | 12 | class YOLOv2(nn.Module): 13 | def __init__(self, 14 | cfg, 15 | device, 16 | input_size=416, 17 | num_classes=20, 18 | trainable=False, 19 | conf_thresh=0.001, 20 | nms_thresh=0.6, 21 | topk=100, 22 | anchor_size=None): 23 | super(YOLOv2, self).__init__() 24 | self.device = 
device 25 | self.input_size = input_size 26 | self.num_classes = num_classes 27 | self.trainable = trainable 28 | self.conf_thresh = conf_thresh 29 | self.nms_thresh = nms_thresh 30 | self.stride = cfg['stride'] 31 | self.topk = topk 32 | 33 | # Anchor box config 34 | self.anchor_size = torch.tensor(anchor_size) # [KA, 2] 35 | self.num_anchors = len(anchor_size) 36 | self.anchor_boxes = self.create_grid(input_size) 37 | 38 | # 主干网络 39 | self.backbone, feat_dims = build_backbone(cfg['backbone'], cfg['pretrained']) 40 | 41 | # 检测头 42 | self.convsets_1 = nn.Sequential( 43 | Conv(feat_dims[-1], cfg['head_dim'], k=3, p=1), 44 | Conv(cfg['head_dim'], cfg['head_dim'], k=3, p=1) 45 | ) 46 | 47 | # 融合高分辨率的特征信息 48 | self.route_layer = Conv(feat_dims[-2], cfg['reorg_dim'], k=1) 49 | self.reorg = reorg_layer(stride=2) 50 | 51 | # 检测头 52 | self.convsets_2 = Conv(cfg['head_dim']+cfg['reorg_dim']*4, cfg['head_dim'], k=3, p=1) 53 | 54 | # 预测层 55 | self.pred = nn.Conv2d(cfg['head_dim'], self.num_anchors*(1 + 4 + self.num_classes), 1) 56 | 57 | 58 | if self.trainable: 59 | self.init_bias() 60 | 61 | 62 | def init_bias(self): 63 | # init bias 64 | init_prob = 0.01 65 | bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob)) 66 | nn.init.constant_(self.pred.bias[..., :self.num_anchors], bias_value) 67 | nn.init.constant_(self.pred.bias[..., 1*self.num_anchors:(1+self.num_classes)*self.num_anchors], bias_value) 68 | 69 | 70 | def create_grid(self, input_size): 71 | w, h = input_size, input_size 72 | # 生成G矩阵 73 | fmp_w, fmp_h = w // self.stride, h // self.stride 74 | grid_y, grid_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)]) 75 | # [H, W, 2] -> [HW, 2] 76 | grid_xy = torch.stack([grid_x, grid_y], dim=-1).float().view(-1, 2) 77 | # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2] 78 | grid_xy = grid_xy[:, None, :].repeat(1, self.num_anchors, 1) 79 | 80 | # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] 81 | anchor_wh = self.anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1) 82 | 83 | # [HW, KA, 4] -> [M, 4] 84 | anchor_boxes = torch.cat([grid_xy, anchor_wh], dim=-1) 85 | anchor_boxes = anchor_boxes.view(-1, 4).to(self.device) 86 | 87 | return anchor_boxes 88 | 89 | 90 | def set_grid(self, input_size): 91 | self.input_size = input_size 92 | self.anchor_boxes = self.create_grid(input_size) 93 | 94 | 95 | def decode_boxes(self, anchors, txtytwth_pred): 96 | """将txtytwth预测换算成边界框的左上角点坐标和右下角点坐标 \n 97 | Input: \n 98 | txtytwth_pred : [B, H*W*KA, 4] \n 99 | Output: \n 100 | x1y1x2y2_pred : [B, H*W*KA, 4] \n 101 | """ 102 | # 获得边界框的中心点坐标和宽高 103 | # b_x = sigmoid(tx) + gride_x 104 | # b_y = sigmoid(ty) + gride_y 105 | xy_pred = torch.sigmoid(txtytwth_pred[..., :2]) + anchors[..., :2] 106 | # b_w = anchor_w * exp(tw) 107 | # b_h = anchor_h * exp(th) 108 | wh_pred = torch.exp(txtytwth_pred[..., 2:]) * anchors[..., 2:] 109 | 110 | # [B, H*W*KA, 4] 111 | xywh_pred = torch.cat([xy_pred, wh_pred], -1) * self.stride 112 | 113 | # 将中心点坐标和宽高换算成边界框的左上角点坐标和右下角点坐标 114 | x1y1x2y2_pred = torch.zeros_like(xywh_pred) 115 | x1y1x2y2_pred[..., :2] = xywh_pred[..., :2] - xywh_pred[..., 2:] * 0.5 116 | x1y1x2y2_pred[..., 2:] = xywh_pred[..., :2] + xywh_pred[..., 2:] * 0.5 117 | 118 | return x1y1x2y2_pred 119 | 120 | 121 | def nms(self, bboxes, scores): 122 | """"Pure Python NMS baseline.""" 123 | x1 = bboxes[:, 0] #xmin 124 | y1 = bboxes[:, 1] #ymin 125 | x2 = bboxes[:, 2] #xmax 126 | y2 = bboxes[:, 3] #ymax 127 | 128 | areas = (x2 - x1) * (y2 - y1) 129 | order = scores.argsort()[::-1] 130 | 131 | keep = [] 132 | while order.size > 0: 
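            # (Added note) Greedy NMS: take the highest-scoring remaining box, keep it,
            # measure its IoU against every other remaining box, and discard the ones
            # whose IoU exceeds self.nms_thresh; repeat until no boxes are left.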
133 | i = order[0] 134 | keep.append(i) 135 | # 计算交集的左上角点和右下角点的坐标 136 | xx1 = np.maximum(x1[i], x1[order[1:]]) 137 | yy1 = np.maximum(y1[i], y1[order[1:]]) 138 | xx2 = np.minimum(x2[i], x2[order[1:]]) 139 | yy2 = np.minimum(y2[i], y2[order[1:]]) 140 | # 计算交集的宽高 141 | w = np.maximum(1e-10, xx2 - xx1) 142 | h = np.maximum(1e-10, yy2 - yy1) 143 | # 计算交集的面积 144 | inter = w * h 145 | 146 | # 计算交并比 147 | iou = inter / (areas[i] + areas[order[1:]] - inter) 148 | # 滤除超过nms阈值的检测框 149 | inds = np.where(iou <= self.nms_thresh)[0] 150 | order = order[inds + 1] 151 | 152 | return keep 153 | 154 | 155 | def postprocess(self, conf_pred, cls_pred, reg_pred): 156 | """ 157 | Input: 158 | conf_pred: (Tensor) [H*W*KA, 1] 159 | cls_pred: (Tensor) [H*W*KA, C] 160 | reg_pred: (Tensor) [H*W*KA, 4] 161 | """ 162 | anchors = self.anchor_boxes 163 | 164 | # (H x W x KA x C,) 165 | scores = (torch.sigmoid(conf_pred) * torch.softmax(cls_pred, dim=-1)).flatten() 166 | 167 | # Keep top k top scoring indices only. 168 | num_topk = min(self.topk, reg_pred.size(0)) 169 | 170 | # torch.sort is actually faster than .topk (at least on GPUs) 171 | predicted_prob, topk_idxs = scores.sort(descending=True) 172 | topk_scores = predicted_prob[:num_topk] 173 | topk_idxs = topk_idxs[:num_topk] 174 | 175 | # filter out the proposals with low confidence score 176 | keep_idxs = topk_scores > self.conf_thresh 177 | scores = topk_scores[keep_idxs] 178 | topk_idxs = topk_idxs[keep_idxs] 179 | 180 | anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor') 181 | labels = topk_idxs % self.num_classes 182 | 183 | reg_pred = reg_pred[anchor_idxs] 184 | anchors = anchors[anchor_idxs] 185 | 186 | # 解算边界框, 并归一化边界框: [H*W*KA, 4] 187 | bboxes = self.decode_boxes(anchors, reg_pred) 188 | 189 | # to cpu 190 | scores = scores.cpu().numpy() 191 | labels = labels.cpu().numpy() 192 | bboxes = bboxes.cpu().numpy() 193 | 194 | # NMS 195 | keep = np.zeros(len(bboxes), dtype=np.int) 196 | for i in range(self.num_classes): 197 | inds = np.where(labels == i)[0] 198 | if len(inds) == 0: 199 | continue 200 | c_bboxes = bboxes[inds] 201 | c_scores = scores[inds] 202 | c_keep = self.nms(c_bboxes, c_scores) 203 | keep[inds[c_keep]] = 1 204 | 205 | keep = np.where(keep > 0) 206 | bboxes = bboxes[keep] 207 | scores = scores[keep] 208 | labels = labels[keep] 209 | 210 | # 归一化边界框 211 | bboxes = bboxes / self.input_size 212 | bboxes = np.clip(bboxes, 0., 1.) 
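        # (Added note) Final outputs: `bboxes` are xyxy coordinates normalized to [0, 1]
        # with respect to the square network input, `scores` are sigmoid(conf) * softmax(cls)
        # values, and `labels` are class indices; callers (e.g. the evaluators) rescale the
        # boxes by the original image width/height before use.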
213 | 214 | return bboxes, scores, labels 215 | 216 | 217 | @torch.no_grad() 218 | def inference(self, x): 219 | # backbone主干网络 220 | feats = self.backbone(x) 221 | c4, c5 = feats['c4'], feats['c5'] 222 | 223 | # 处理c5特征 224 | p5 = self.convsets_1(c5) 225 | 226 | # 融合c4特征 227 | p4 = self.reorg(self.route_layer(c4)) 228 | p5 = torch.cat([p4, p5], dim=1) 229 | 230 | # 处理p5特征 231 | p5 = self.convsets_2(p5) 232 | 233 | # 预测 234 | prediction = self.pred(p5) 235 | 236 | B, abC, H, W = prediction.size() 237 | KA = self.num_anchors 238 | NC = self.num_classes 239 | 240 | # [B, KA * C, H, W] -> [B, H, W, KA * C] -> [B, H*W, KA*C] 241 | prediction = prediction.permute(0, 2, 3, 1).contiguous().view(B, -1, abC) 242 | 243 | # 从pred中分离出objectness预测、类别class预测、bbox的txtytwth预测 244 | # [B, H*W, KA*C] -> [B, H*W, KA] -> [B, H*W*KA, 1] 245 | conf_pred = prediction[..., :KA].contiguous().view(B, -1, 1) 246 | # [B, H*W, KA*C] -> [B, H*W, KA*NC] -> [B, H*W*KA, NC] 247 | cls_pred = prediction[..., 1*KA : (1+NC)*KA].contiguous().view(B, -1, NC) 248 | # [B, H*W, KA*C] -> [B, H*W, KA*4] -> [B, H*W, KA, 4] 249 | txtytwth_pred = prediction[..., (1+NC)*KA:].contiguous().view(B, -1, 4) 250 | 251 | # 测试时,笔者默认batch是1, 252 | # 因此,我们不需要用batch这个维度,用[0]将其取走。 253 | conf_pred = conf_pred[0] #[H*W*KA, 1] 254 | cls_pred = cls_pred[0] #[H*W*KA, NC] 255 | txtytwth_pred = txtytwth_pred[0] #[H*W*KA, 4] 256 | 257 | # 后处理 258 | bboxes, scores, labels = self.postprocess(conf_pred, cls_pred, txtytwth_pred) 259 | 260 | return bboxes, scores, labels 261 | 262 | 263 | def forward(self, x, targets=None): 264 | if not self.trainable: 265 | return self.inference(x) 266 | else: 267 | # backbone主干网络 268 | feats = self.backbone(x) 269 | c4, c5 = feats['c4'], feats['c5'] 270 | 271 | # 处理c5特征 272 | p5 = self.convsets_1(c5) 273 | 274 | # 融合c4特征 275 | p4 = self.reorg(self.route_layer(c4)) 276 | p5 = torch.cat([p4, p5], dim=1) 277 | 278 | # 处理p5特征 279 | p5 = self.convsets_2(p5) 280 | 281 | # 预测 282 | prediction = self.pred(p5) 283 | 284 | B, abC, H, W = prediction.size() 285 | KA = self.num_anchors 286 | NC = self.num_classes 287 | 288 | # [B, KA * C, H, W] -> [B, H, W, KA * C] -> [B, H*W, KA*C] 289 | prediction = prediction.permute(0, 2, 3, 1).contiguous().view(B, H*W, abC) 290 | 291 | # 从pred中分离出objectness预测、类别class预测、bbox的txtytwth预测 292 | # [B, H*W, KA*C] -> [B, H*W, KA] -> [B, H*W*KA, 1] 293 | conf_pred = prediction[..., :KA].contiguous().view(B, -1, 1) 294 | # [B, H*W, KA*C] -> [B, H*W, KA*NC] -> [B, H*W*KA, NC] 295 | cls_pred = prediction[..., 1*KA : (1+NC)*KA].contiguous().view(B, -1, NC) 296 | # [B, H*W, KA*C] -> [B, H*W, KA*4] -> [B, H*W*KA, 4] 297 | txtytwth_pred = prediction[..., (1+NC)*KA:].contiguous().view(B, -1, 4) 298 | 299 | # 解算边界框 300 | x1y1x2y2_pred = (self.decode_boxes(self.anchor_boxes, txtytwth_pred) / self.input_size).view(-1, 4) 301 | x1y1x2y2_gt = targets[:, :, 7:].view(-1, 4) 302 | 303 | # 计算预测框和真实框之间的IoU 304 | iou_pred = iou_score(x1y1x2y2_pred, x1y1x2y2_gt).view(B, -1, 1) 305 | 306 | # 将IoU作为置信度的学习目标 307 | with torch.no_grad(): 308 | gt_conf = iou_pred.clone() 309 | 310 | # 将IoU作为置信度的学习目标 311 | # [obj, cls, txtytwth, x1y1x2y2] -> [conf, obj, cls, txtytwth] 312 | targets = torch.cat([gt_conf, targets[:, :, :7]], dim=2) 313 | 314 | # 计算损失 315 | ( 316 | conf_loss, 317 | cls_loss, 318 | bbox_loss, 319 | total_loss 320 | ) = compute_loss( 321 | pred_conf=conf_pred, 322 | pred_cls=cls_pred, 323 | pred_txtytwth=txtytwth_pred, 324 | targets=targets, 325 | ) 326 | 327 | return conf_loss, cls_loss, bbox_loss, total_loss 328 | 
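# (Added note) Minimal inference sketch for this module -- a hedged example, not part of
# the original file. It assumes the 'yolov2' entry of config/yolov2_config.py, CPU only,
# and pretrained=False so no backbone weights need to be fetched. Because this file uses
# relative imports, the snippet is meant to live in a script at the repository root:
#
#   import torch
#   from config.yolov2_config import yolov2_config
#   from models.yolov2 import YOLOv2
#
#   cfg = dict(yolov2_config['yolov2'], pretrained=False)
#   model = YOLOv2(cfg=cfg,
#                  device=torch.device('cpu'),
#                  input_size=416,
#                  num_classes=20,                        # VOC classes
#                  trainable=False,                       # forward() -> inference()
#                  anchor_size=cfg['anchor_size']['voc'])
#   model.eval()
#   bboxes, scores, labels = model(torch.randn(1, 3, 416, 416))  # dummy batch of 1
#   # bboxes are xyxy, normalized to [0, 1]; multiply by the original image
#   # width/height before drawing, as the evaluators do.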
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import random 5 | import argparse 6 | import time 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from copy import deepcopy 12 | 13 | from data.coco import COCODataset 14 | from data.voc0712 import VOCDetection 15 | from data.transform import Augmentation, BaseTransform 16 | 17 | from utils.misc import detection_collate 18 | from utils.com_paras_flops import FLOPs_and_Params 19 | from evaluator.cocoapi_evaluator import COCOAPIEvaluator 20 | from evaluator.vocapi_evaluator import VOCAPIEvaluator 21 | 22 | from config import build_model_config 23 | from models.build import build_yolov2 24 | from models.matcher import gt_creator 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser(description='YOLOv2 Detection') 28 | # 基本参数 29 | parser.add_argument('--cuda', action='store_true', default=False, 30 | help='use cuda.') 31 | parser.add_argument('--tfboard', action='store_true', default=False, 32 | help='use tensorboard') 33 | parser.add_argument('--eval_epoch', type=int, 34 | default=10, help='interval between evaluations') 35 | parser.add_argument('--save_folder', default='weights/', type=str, 36 | help='Gamma update for SGD') 37 | parser.add_argument('--num_workers', default=8, type=int, 38 | help='Number of workers used in dataloading') 39 | 40 | # 模型参数 41 | parser.add_argument('-v', '--version', default='yolov2', 42 | help='build yolo') 43 | parser.add_argument('--conf_thresh', default=0.001, type=float, 44 | help='Confidence threshold') 45 | parser.add_argument('--nms_thresh', default=0.50, type=float, 46 | help='NMS threshold') 47 | parser.add_argument('--topk', default=1000, type=int, 48 | help='topk predicted candidates') 49 | 50 | # 训练配置 51 | parser.add_argument('-bs', '--batch_size', default=8, type=int, 52 | help='Batch size for training') 53 | parser.add_argument('-accu', '--accumulate', default=8, type=int, 54 | help='gradient accumulate.') 55 | parser.add_argument('-no_wp', '--no_warm_up', action='store_true', default=False, 56 | help='yes or no to choose using warmup strategy to train') 57 | parser.add_argument('--wp_epoch', type=int, default=1, 58 | help='The upper bound of warm-up') 59 | parser.add_argument('--start_epoch', type=int, default=0, 60 | help='start epoch to train') 61 | parser.add_argument('-r', '--resume', default=None, type=str, 62 | help='keep training') 63 | parser.add_argument('-ms', '--multi_scale', action='store_true', default=False, 64 | help='use multi-scale trick') 65 | parser.add_argument('--max_epoch', type=int, default=200, 66 | help='The upper bound of warm-up') 67 | parser.add_argument('--lr_epoch', nargs='+', default=[100, 150], type=int, 68 | help='lr epoch to decay') 69 | 70 | # 优化器参数 71 | parser.add_argument('--lr', default=1e-3, type=float, 72 | help='initial learning rate') 73 | parser.add_argument('--momentum', default=0.9, type=float, 74 | help='Momentum value for optim') 75 | parser.add_argument('--weight_decay', default=5e-4, type=float, 76 | help='Weight decay for SGD') 77 | parser.add_argument('--gamma', default=0.1, type=float, 78 | help='Gamma update for SGD') 79 | 80 | # 数据集参数 81 | parser.add_argument('-d', '--dataset', default='voc', 82 | help='voc or coco') 83 | parser.add_argument('--root', default='/mnt/share/ssd2/dataset', 84 | help='data root') 
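    # (Added note) --root is the parent directory of the dataset folders:
    # build_dataset() below joins it with 'VOCdevkit' for '-d voc' and with 'COCO'
    # for '-d coco'.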
85 | 86 | return parser.parse_args() 87 | 88 | 89 | def train(): 90 | args = parse_args() 91 | print("Setting Arguments.. : ", args) 92 | print("----------------------------------------------------------") 93 | 94 | path_to_save = os.path.join(args.save_folder, args.dataset, args.version) 95 | os.makedirs(path_to_save, exist_ok=True) 96 | 97 | # 是否使用cuda 98 | if args.cuda: 99 | print('use cuda') 100 | device = torch.device("cuda") 101 | else: 102 | device = torch.device("cpu") 103 | 104 | # 是否使用多尺度训练 105 | if args.multi_scale: 106 | print('use the multi-scale trick ...') 107 | train_size = 640 108 | val_size = 416 109 | else: 110 | train_size = 416 111 | val_size = 416 112 | 113 | # 构建yolov2的配置文件 114 | cfg = build_model_config(args) 115 | 116 | # 构建dataset类和dataloader类 117 | dataset, num_classes, evaluator = build_dataset(args, device, train_size, val_size) 118 | 119 | # 构建dataloader类 120 | dataloader = build_dataloader(args, dataset) 121 | 122 | # 构建我们的模型 123 | model = build_yolov2(args, cfg, device, train_size, num_classes, trainable=True) 124 | model.to(device).train() 125 | 126 | # compute FLOPs and Params 127 | model_copy = deepcopy(model) 128 | model_copy.trainable = False 129 | model_copy.eval() 130 | model_copy.set_grid(val_size) 131 | FLOPs_and_Params(model=model_copy, 132 | img_size=val_size, 133 | device=device) 134 | del model_copy 135 | 136 | # 使用 tensorboard 可视化训练过程 137 | if args.tfboard: 138 | print('use tensorboard') 139 | from torch.utils.tensorboard import SummaryWriter 140 | c_time = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) 141 | log_path = os.path.join('log/coco/', args.version, c_time) 142 | os.makedirs(log_path, exist_ok=True) 143 | 144 | writer = SummaryWriter(log_path) 145 | 146 | # keep training 147 | if args.resume is not None: 148 | print('keep training model: %s' % (args.resume)) 149 | model.load_state_dict(torch.load(args.resume, map_location=device)) 150 | 151 | # 构建训练优化器 152 | base_lr = args.lr 153 | tmp_lr = base_lr 154 | optimizer = optim.SGD(model.parameters(), 155 | lr=args.lr, 156 | momentum=args.momentum, 157 | weight_decay=args.weight_decay 158 | ) 159 | 160 | max_epoch = args.max_epoch # 最大训练轮次 161 | lr_epoch = args.lr_epoch 162 | epoch_size = len(dataloader) # 每一训练轮次的迭代次数 163 | 164 | # 开始训练 165 | best_map = -1. 166 | t0 = time.time() 167 | for epoch in range(args.start_epoch, max_epoch): 168 | 169 | # 使用阶梯学习率衰减策略 170 | if epoch in lr_epoch: 171 | tmp_lr = tmp_lr * 0.1 172 | set_lr(optimizer, tmp_lr) 173 | 174 | 175 | for iter_i, (images, targets) in enumerate(dataloader): 176 | ni = iter_i+epoch*epoch_size 177 | # 使用warm-up策略来调整早期的学习率 178 | if not args.no_warm_up: 179 | if epoch < args.wp_epoch: 180 | nw = args.wp_epoch*epoch_size 181 | tmp_lr = base_lr * pow((ni)*1. 
/ (nw), 4) 182 | set_lr(optimizer, tmp_lr) 183 | 184 | elif epoch == args.wp_epoch and iter_i == 0: 185 | tmp_lr = base_lr 186 | set_lr(optimizer, tmp_lr) 187 | 188 | # 多尺度训练 189 | if iter_i % 10 == 0 and iter_i > 0 and args.multi_scale: 190 | # 随机选择一个新的尺寸 191 | train_size = random.randint(10, 19) * 32 192 | model.set_grid(train_size) 193 | if args.multi_scale: 194 | # 插值 195 | images = torch.nn.functional.interpolate(images, size=train_size, mode='bilinear', align_corners=False) 196 | 197 | # 制作训练标签 198 | targets = [label.tolist() for label in targets] 199 | targets = gt_creator( 200 | input_size=train_size, 201 | stride=cfg['stride'], 202 | label_lists=targets, 203 | anchor_size=cfg['anchor_size'][args.dataset], 204 | ignore_thresh=cfg['ignore_thresh'] 205 | ) 206 | 207 | # to device 208 | images = images.to(device) 209 | targets = torch.tensor(targets).float().to(device) 210 | 211 | # 前向推理和计算损失 212 | conf_loss, cls_loss, bbox_loss, total_loss = model(images, targets=targets) 213 | 214 | # 梯度累加 & 反向传播 215 | total_loss /= args.accumulate 216 | total_loss.backward() 217 | 218 | # 更新 219 | if ni % args.accumulate == 0: 220 | optimizer.step() 221 | optimizer.zero_grad() 222 | 223 | if iter_i % 10 == 0: 224 | if args.tfboard: 225 | # viz loss 226 | writer.add_scalar('obj loss', conf_loss.item(), iter_i + epoch * epoch_size) 227 | writer.add_scalar('cls loss', cls_loss.item(), iter_i + epoch * epoch_size) 228 | writer.add_scalar('box loss', bbox_loss.item(), iter_i + epoch * epoch_size) 229 | 230 | t1 = time.time() 231 | print('[Epoch %d/%d][Iter %d/%d][lr %.6f]' 232 | '[Loss: obj %.2f || cls %.2f || bbox %.2f || total %.2f || size %d || time: %.2f]' 233 | % (epoch+1, max_epoch, iter_i, epoch_size, tmp_lr, 234 | conf_loss.item(), 235 | cls_loss.item(), 236 | bbox_loss.item(), 237 | total_loss.item(), 238 | train_size, t1-t0), 239 | flush=True) 240 | 241 | t0 = time.time() 242 | 243 | # evaluation 244 | if epoch % args.eval_epoch == 0 or (epoch + 1) == max_epoch: 245 | model.trainable = False 246 | model.set_grid(val_size) 247 | model.eval() 248 | 249 | # evaluate 250 | evaluator.evaluate(model) 251 | 252 | # convert to training mode. 
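            # (Added note) Restore the training state after evaluation: the trainable
            # flag re-enables loss computation in forward(), set_grid() rebuilds the
            # anchor grid for the (possibly multi-scale) training size, and train()
            # switches BatchNorm back to training behaviour. A checkpoint is only
            # written below when the mAP improves.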
253 | model.trainable = True 254 | model.set_grid(train_size) 255 | model.train() 256 | 257 | cur_map = evaluator.map 258 | if cur_map > best_map: 259 | # update best-map 260 | best_map = cur_map 261 | # save model 262 | print('Saving state, epoch:', epoch + 1) 263 | weight_name = '{}_epoch_{}_{:.1f}.pth'.format(args.version, epoch + 1, best_map*100) 264 | checkpoint_path = os.path.join(path_to_save, weight_name) 265 | torch.save(model.state_dict(), checkpoint_path) 266 | 267 | 268 | def set_lr(optimizer, lr): 269 | for param_group in optimizer.param_groups: 270 | param_group['lr'] = lr 271 | 272 | 273 | def build_dataset(args, device, train_size, val_size): 274 | pixel_mean = (0.406, 0.456, 0.485) # BGR 275 | pixel_std = (0.225, 0.224, 0.229) # BGR 276 | train_transform = Augmentation(train_size, pixel_mean, pixel_std) 277 | val_transform = BaseTransform(val_size, pixel_mean, pixel_std) 278 | 279 | # 构建dataset类和dataloader类 280 | if args.dataset == 'voc': 281 | data_root = os.path.join(args.root, 'VOCdevkit') 282 | # 加载voc数据集 283 | num_classes = 20 284 | dataset = VOCDetection( 285 | root=data_root, 286 | transform=train_transform 287 | ) 288 | 289 | evaluator = VOCAPIEvaluator( 290 | data_root=data_root, 291 | img_size=val_size, 292 | device=device, 293 | transform=val_transform 294 | ) 295 | 296 | elif args.dataset == 'coco': 297 | # 加载COCO数据集 298 | data_root = os.path.join(args.root, 'COCO') 299 | num_classes = 80 300 | dataset = COCODataset( 301 | data_dir=data_root, 302 | img_size=train_size, 303 | transform=train_transform 304 | ) 305 | 306 | evaluator = COCOAPIEvaluator( 307 | data_dir=data_root, 308 | img_size=val_size, 309 | device=device, 310 | transform=val_transform 311 | ) 312 | 313 | else: 314 | print('unknow dataset !! Only support voc and coco !!') 315 | exit(0) 316 | 317 | print('Training model on:', args.dataset) 318 | print('The dataset size:', len(dataset)) 319 | print("----------------------------------------------------------") 320 | 321 | 322 | return dataset, num_classes, evaluator 323 | 324 | 325 | def build_dataloader(args, dataset): 326 | dataloader = torch.utils.data.DataLoader( 327 | dataset, 328 | batch_size=args.batch_size, 329 | shuffle=True, 330 | collate_fn=detection_collate, 331 | num_workers=args.num_workers, 332 | pin_memory=True 333 | ) 334 | 335 | return dataloader 336 | 337 | 338 | if __name__ == '__main__': 339 | train() 340 | -------------------------------------------------------------------------------- /evaluator/vocapi_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from data.voc0712 import VOCDetection, VOC_CLASSES 4 | import os 5 | import time 6 | import numpy as np 7 | import pickle 8 | 9 | import xml.etree.ElementTree as ET 10 | 11 | 12 | class VOCAPIEvaluator(): 13 | """ VOC AP Evaluation class """ 14 | def __init__(self, data_root, img_size, device, transform, set_type='test', year='2007', display=False): 15 | self.data_root = data_root 16 | self.img_size = img_size 17 | self.device = device 18 | self.transform = transform 19 | self.labelmap = VOC_CLASSES 20 | self.set_type = set_type 21 | self.year = year 22 | self.display = display 23 | 24 | # path 25 | self.devkit_path = data_root + 'VOC' + year 26 | self.annopath = os.path.join(data_root, 'VOC2007', 'Annotations', '%s.xml') 27 | self.imgpath = os.path.join(data_root, 'VOC2007', 'JPEGImages', '%s.jpg') 28 | self.imgsetpath = os.path.join(data_root, 'VOC2007', 'ImageSets', 'Main', 
set_type+'.txt') 29 | self.output_dir = self.get_output_dir('voc_eval/', self.set_type) 30 | 31 | # dataset 32 | self.dataset = VOCDetection(root=data_root, 33 | image_sets=[('2007', set_type)], 34 | transform=transform 35 | ) 36 | 37 | def evaluate(self, net): 38 | net.eval() 39 | num_images = len(self.dataset) 40 | # all detections are collected into: 41 | # all_boxes[cls][image] = N x 5 array of detections in 42 | # (x1, y1, x2, y2, score) 43 | self.all_boxes = [[[] for _ in range(num_images)] 44 | for _ in range(len(self.labelmap))] 45 | 46 | # timers 47 | det_file = os.path.join(self.output_dir, 'detections.pkl') 48 | 49 | for i in range(num_images): 50 | im, gt, h, w = self.dataset.pull_item(i) 51 | 52 | x = Variable(im.unsqueeze(0)).to(self.device) 53 | t0 = time.time() 54 | # forward 55 | bboxes, scores, labels = net(x) 56 | detect_time = time.time() - t0 57 | scale = np.array([[w, h, w, h]]) 58 | bboxes *= scale 59 | 60 | for j in range(len(self.labelmap)): 61 | inds = np.where(labels == j)[0] 62 | if len(inds) == 0: 63 | self.all_boxes[j][i] = np.empty([0, 5], dtype=np.float32) 64 | continue 65 | c_bboxes = bboxes[inds] 66 | c_scores = scores[inds] 67 | c_dets = np.hstack((c_bboxes, 68 | c_scores[:, np.newaxis])).astype(np.float32, 69 | copy=False) 70 | self.all_boxes[j][i] = c_dets 71 | 72 | if i % 500 == 0: 73 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, num_images, detect_time)) 74 | 75 | with open(det_file, 'wb') as f: 76 | pickle.dump(self.all_boxes, f, pickle.HIGHEST_PROTOCOL) 77 | 78 | print('Evaluating detections') 79 | self.evaluate_detections(self.all_boxes) 80 | 81 | print('Mean AP: ', self.map) 82 | 83 | 84 | def parse_rec(self, filename): 85 | """ Parse a PASCAL VOC xml file """ 86 | tree = ET.parse(filename) 87 | objects = [] 88 | for obj in tree.findall('object'): 89 | obj_struct = {} 90 | obj_struct['name'] = obj.find('name').text 91 | obj_struct['pose'] = obj.find('pose').text 92 | obj_struct['truncated'] = int(obj.find('truncated').text) 93 | obj_struct['difficult'] = int(obj.find('difficult').text) 94 | bbox = obj.find('bndbox') 95 | obj_struct['bbox'] = [int(bbox.find('xmin').text), 96 | int(bbox.find('ymin').text), 97 | int(bbox.find('xmax').text), 98 | int(bbox.find('ymax').text)] 99 | objects.append(obj_struct) 100 | 101 | return objects 102 | 103 | 104 | def get_output_dir(self, name, phase): 105 | """Return the directory where experimental artifacts are placed. 106 | If the directory does not exist, it is created. 107 | A canonical path is built using the name from an imdb and a network 108 | (if not None). 
109 | """ 110 | filedir = os.path.join(name, phase) 111 | if not os.path.exists(filedir): 112 | os.makedirs(filedir) 113 | return filedir 114 | 115 | 116 | def get_voc_results_file_template(self, cls): 117 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 118 | filename = 'det_' + self.set_type + '_%s.txt' % (cls) 119 | filedir = os.path.join(self.devkit_path, 'results') 120 | if not os.path.exists(filedir): 121 | os.makedirs(filedir) 122 | path = os.path.join(filedir, filename) 123 | return path 124 | 125 | 126 | def write_voc_results_file(self, all_boxes): 127 | for cls_ind, cls in enumerate(self.labelmap): 128 | if self.display: 129 | print('Writing {:s} VOC results file'.format(cls)) 130 | filename = self.get_voc_results_file_template(cls) 131 | with open(filename, 'wt') as f: 132 | for im_ind, index in enumerate(self.dataset.ids): 133 | dets = all_boxes[cls_ind][im_ind] 134 | if dets == []: 135 | continue 136 | # the VOCdevkit expects 1-based indices 137 | for k in range(dets.shape[0]): 138 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 139 | format(index[1], dets[k, -1], 140 | dets[k, 0] + 1, dets[k, 1] + 1, 141 | dets[k, 2] + 1, dets[k, 3] + 1)) 142 | 143 | 144 | def do_python_eval(self, use_07=True): 145 | cachedir = os.path.join(self.devkit_path, 'annotations_cache') 146 | aps = [] 147 | # The PASCAL VOC metric changed in 2010 148 | use_07_metric = use_07 149 | print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No')) 150 | if not os.path.isdir(self.output_dir): 151 | os.mkdir(self.output_dir) 152 | for i, cls in enumerate(self.labelmap): 153 | filename = self.get_voc_results_file_template(cls) 154 | rec, prec, ap = self.voc_eval(detpath=filename, 155 | classname=cls, 156 | cachedir=cachedir, 157 | ovthresh=0.5, 158 | use_07_metric=use_07_metric 159 | ) 160 | aps += [ap] 161 | print('AP for {} = {:.4f}'.format(cls, ap)) 162 | with open(os.path.join(self.output_dir, cls + '_pr.pkl'), 'wb') as f: 163 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 164 | if self.display: 165 | self.map = np.mean(aps) 166 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 167 | print('~~~~~~~~') 168 | print('Results:') 169 | for ap in aps: 170 | print('{:.3f}'.format(ap)) 171 | print('{:.3f}'.format(np.mean(aps))) 172 | print('~~~~~~~~') 173 | print('') 174 | print('--------------------------------------------------------------') 175 | print('Results computed with the **unofficial** Python eval code.') 176 | print('Results should be very close to the official MATLAB eval code.') 177 | print('--------------------------------------------------------------') 178 | else: 179 | self.map = np.mean(aps) 180 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 181 | 182 | 183 | def voc_ap(self, rec, prec, use_07_metric=True): 184 | """ ap = voc_ap(rec, prec, [use_07_metric]) 185 | Compute VOC AP given precision and recall. 186 | If use_07_metric is true, uses the 187 | VOC 07 11 point method (default:True). 188 | """ 189 | if use_07_metric: 190 | # 11 point metric 191 | ap = 0. 192 | for t in np.arange(0., 1.1, 0.1): 193 | if np.sum(rec >= t) == 0: 194 | p = 0 195 | else: 196 | p = np.max(prec[rec >= t]) 197 | ap = ap + p / 11. 
198 | else: 199 | # correct AP calculation 200 | # first append sentinel values at the end 201 | mrec = np.concatenate(([0.], rec, [1.])) 202 | mpre = np.concatenate(([0.], prec, [0.])) 203 | 204 | # compute the precision envelope 205 | for i in range(mpre.size - 1, 0, -1): 206 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 207 | 208 | # to calculate area under PR curve, look for points 209 | # where X axis (recall) changes value 210 | i = np.where(mrec[1:] != mrec[:-1])[0] 211 | 212 | # and sum (\Delta recall) * prec 213 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 214 | return ap 215 | 216 | 217 | def voc_eval(self, detpath, classname, cachedir, ovthresh=0.5, use_07_metric=True): 218 | if not os.path.isdir(cachedir): 219 | os.mkdir(cachedir) 220 | cachefile = os.path.join(cachedir, 'annots.pkl') 221 | # read list of images 222 | with open(self.imgsetpath, 'r') as f: 223 | lines = f.readlines() 224 | imagenames = [x.strip() for x in lines] 225 | if not os.path.isfile(cachefile): 226 | # load annots 227 | recs = {} 228 | for i, imagename in enumerate(imagenames): 229 | recs[imagename] = self.parse_rec(self.annopath % (imagename)) 230 | if i % 100 == 0 and self.display: 231 | print('Reading annotation for {:d}/{:d}'.format( 232 | i + 1, len(imagenames))) 233 | # save 234 | if self.display: 235 | print('Saving cached annotations to {:s}'.format(cachefile)) 236 | with open(cachefile, 'wb') as f: 237 | pickle.dump(recs, f) 238 | else: 239 | # load 240 | with open(cachefile, 'rb') as f: 241 | recs = pickle.load(f) 242 | 243 | # extract gt objects for this class 244 | class_recs = {} 245 | npos = 0 246 | for imagename in imagenames: 247 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 248 | bbox = np.array([x['bbox'] for x in R]) 249 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 250 | det = [False] * len(R) 251 | npos = npos + sum(~difficult) 252 | class_recs[imagename] = {'bbox': bbox, 253 | 'difficult': difficult, 254 | 'det': det} 255 | 256 | # read dets 257 | detfile = detpath.format(classname) 258 | with open(detfile, 'r') as f: 259 | lines = f.readlines() 260 | if any(lines) == 1: 261 | 262 | splitlines = [x.strip().split(' ') for x in lines] 263 | image_ids = [x[0] for x in splitlines] 264 | confidence = np.array([float(x[1]) for x in splitlines]) 265 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 266 | 267 | # sort by confidence 268 | sorted_ind = np.argsort(-confidence) 269 | sorted_scores = np.sort(-confidence) 270 | BB = BB[sorted_ind, :] 271 | image_ids = [image_ids[x] for x in sorted_ind] 272 | 273 | # go down dets and mark TPs and FPs 274 | nd = len(image_ids) 275 | tp = np.zeros(nd) 276 | fp = np.zeros(nd) 277 | for d in range(nd): 278 | R = class_recs[image_ids[d]] 279 | bb = BB[d, :].astype(float) 280 | ovmax = -np.inf 281 | BBGT = R['bbox'].astype(float) 282 | if BBGT.size > 0: 283 | # compute overlaps 284 | # intersection 285 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 286 | iymin = np.maximum(BBGT[:, 1], bb[1]) 287 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 288 | iymax = np.minimum(BBGT[:, 3], bb[3]) 289 | iw = np.maximum(ixmax - ixmin, 0.) 290 | ih = np.maximum(iymax - iymin, 0.) 
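                    # (Added note) Standard IoU against every GT box of this class:
                    #     IoU = inter / (area(det) + area(gt) - inter).
                    # The detection counts as a true positive only if the best IoU
                    # exceeds ovthresh (0.5 here) and the matched GT box is neither
                    # marked 'difficult' nor already claimed by an earlier detection.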
291 | inters = iw * ih 292 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 293 | (BBGT[:, 2] - BBGT[:, 0]) * 294 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 295 | overlaps = inters / uni 296 | ovmax = np.max(overlaps) 297 | jmax = np.argmax(overlaps) 298 | 299 | if ovmax > ovthresh: 300 | if not R['difficult'][jmax]: 301 | if not R['det'][jmax]: 302 | tp[d] = 1. 303 | R['det'][jmax] = 1 304 | else: 305 | fp[d] = 1. 306 | else: 307 | fp[d] = 1. 308 | 309 | # compute precision recall 310 | fp = np.cumsum(fp) 311 | tp = np.cumsum(tp) 312 | rec = tp / float(npos) 313 | # avoid divide by zero in case the first detection matches a difficult 314 | # ground truth 315 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 316 | ap = self.voc_ap(rec, prec, use_07_metric) 317 | else: 318 | rec = -1. 319 | prec = -1. 320 | ap = -1. 321 | 322 | return rec, prec, ap 323 | 324 | 325 | def evaluate_detections(self, box_list): 326 | self.write_voc_results_file(box_list) 327 | self.do_python_eval() 328 | 329 | 330 | if __name__ == '__main__': 331 | pass -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import types 4 | from numpy import random 5 | 6 | 7 | def intersect(box_a, box_b): 8 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 9 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 10 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 11 | return inter[:, 0] * inter[:, 1] 12 | 13 | 14 | def jaccard_numpy(box_a, box_b): 15 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 16 | is simply the intersection over union of two boxes. 17 | E.g.: 18 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 19 | Args: 20 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 21 | box_b: Single bounding box, Shape: [4] 22 | Return: 23 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 24 | """ 25 | inter = intersect(box_a, box_b) 26 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 27 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 28 | area_b = ((box_b[2]-box_b[0]) * 29 | (box_b[3]-box_b[1])) # [A,B] 30 | union = area_a + area_b - inter 31 | return inter / union # [A,B] 32 | 33 | 34 | class Compose(object): 35 | """Composes several augmentations together. 36 | Args: 37 | transforms (List[Transform]): list of transforms to compose. 38 | Example: 39 | >>> augmentations.Compose([ 40 | >>> transforms.CenterCrop(10), 41 | >>> transforms.ToTensor(), 42 | >>> ]) 43 | """ 44 | 45 | def __init__(self, transforms): 46 | self.transforms = transforms 47 | 48 | def __call__(self, img, boxes=None, labels=None): 49 | for t in self.transforms: 50 | img, boxes, labels = t(img, boxes, labels) 51 | return img, boxes, labels 52 | 53 | 54 | class ConvertFromInts(object): 55 | def __call__(self, image, boxes=None, labels=None): 56 | return image.astype(np.float32), boxes, labels 57 | 58 | 59 | class Normalize(object): 60 | def __init__(self, mean=None, std=None): 61 | self.mean = np.array(mean, dtype=np.float32) 62 | self.std = np.array(std, dtype=np.float32) 63 | 64 | def __call__(self, image, boxes=None, labels=None): 65 | image = image.astype(np.float32) 66 | image /= 255. 
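        # (Added note) Normalization convention used throughout the repo: scale to
        # [0, 1], subtract the (BGR-ordered) mean, divide by the std; the visualization
        # code inverts it with (image * std + mean) * 255.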
67 | image -= self.mean 68 | image /= self.std 69 | 70 | return image, boxes, labels 71 | 72 | 73 | class ToAbsoluteCoords(object): 74 | def __call__(self, image, boxes=None, labels=None): 75 | height, width, channels = image.shape 76 | boxes[:, 0] *= width 77 | boxes[:, 2] *= width 78 | boxes[:, 1] *= height 79 | boxes[:, 3] *= height 80 | 81 | return image, boxes, labels 82 | 83 | 84 | class ToPercentCoords(object): 85 | def __call__(self, image, boxes=None, labels=None): 86 | height, width, channels = image.shape 87 | boxes[:, 0] /= width 88 | boxes[:, 2] /= width 89 | boxes[:, 1] /= height 90 | boxes[:, 3] /= height 91 | 92 | return image, boxes, labels 93 | 94 | 95 | class ConvertColor(object): 96 | def __init__(self, current='BGR', transform='HSV'): 97 | self.transform = transform 98 | self.current = current 99 | 100 | def __call__(self, image, boxes=None, labels=None): 101 | if self.current == 'BGR' and self.transform == 'HSV': 102 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 103 | elif self.current == 'HSV' and self.transform == 'BGR': 104 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 105 | else: 106 | raise NotImplementedError 107 | return image, boxes, labels 108 | 109 | 110 | class Resize(object): 111 | def __init__(self, size=640): 112 | self.size = size 113 | 114 | def __call__(self, image, boxes=None, labels=None): 115 | image = cv2.resize(image, (self.size, self.size)) 116 | return image, boxes, labels 117 | 118 | 119 | class RandomSaturation(object): 120 | def __init__(self, lower=0.5, upper=1.5): 121 | self.lower = lower 122 | self.upper = upper 123 | assert self.upper >= self.lower, "contrast upper must be >= lower." 124 | assert self.lower >= 0, "contrast lower must be non-negative." 125 | 126 | def __call__(self, image, boxes=None, labels=None): 127 | if random.randint(2): 128 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 129 | 130 | return image, boxes, labels 131 | 132 | 133 | class RandomHue(object): 134 | def __init__(self, delta=18.0): 135 | assert delta >= 0.0 and delta <= 360.0 136 | self.delta = delta 137 | 138 | def __call__(self, image, boxes=None, labels=None): 139 | if random.randint(2): 140 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 141 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 142 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 143 | return image, boxes, labels 144 | 145 | 146 | class RandomLightingNoise(object): 147 | def __init__(self): 148 | self.perms = ((0, 1, 2), (0, 2, 1), 149 | (1, 0, 2), (1, 2, 0), 150 | (2, 0, 1), (2, 1, 0)) 151 | 152 | def __call__(self, image, boxes=None, labels=None): 153 | if random.randint(2): 154 | swap = self.perms[random.randint(len(self.perms))] 155 | shuffle = SwapChannels(swap) # shuffle channels 156 | image = shuffle(image) 157 | return image, boxes, labels 158 | 159 | 160 | class RandomContrast(object): 161 | def __init__(self, lower=0.5, upper=1.5): 162 | self.lower = lower 163 | self.upper = upper 164 | assert self.upper >= self.lower, "contrast upper must be >= lower." 165 | assert self.lower >= 0, "contrast lower must be non-negative." 
166 | 167 | # expects float image 168 | def __call__(self, image, boxes=None, labels=None): 169 | if random.randint(2): 170 | alpha = random.uniform(self.lower, self.upper) 171 | image *= alpha 172 | return image, boxes, labels 173 | 174 | 175 | class RandomBrightness(object): 176 | def __init__(self, delta=32): 177 | assert delta >= 0.0 178 | assert delta <= 255.0 179 | self.delta = delta 180 | 181 | def __call__(self, image, boxes=None, labels=None): 182 | if random.randint(2): 183 | delta = random.uniform(-self.delta, self.delta) 184 | image += delta 185 | return image, boxes, labels 186 | 187 | 188 | class RandomSampleCrop(object): 189 | """Crop 190 | Arguments: 191 | img (Image): the image being input during training 192 | boxes (Tensor): the original bounding boxes in pt form 193 | labels (Tensor): the class labels for each bbox 194 | mode (float tuple): the min and max jaccard overlaps 195 | Return: 196 | (img, boxes, classes) 197 | img (Image): the cropped image 198 | boxes (Tensor): the adjusted bounding boxes in pt form 199 | labels (Tensor): the class labels for each bbox 200 | """ 201 | def __init__(self): 202 | self.sample_options = ( 203 | # using entire original input image 204 | None, 205 | # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 206 | (0.1, None), 207 | (0.3, None), 208 | (0.7, None), 209 | (0.9, None), 210 | # randomly sample a patch 211 | (None, None), 212 | ) 213 | 214 | def __call__(self, image, boxes=None, labels=None): 215 | height, width, _ = image.shape 216 | while True: 217 | # randomly choose a mode 218 | sample_id = np.random.randint(len(self.sample_options)) 219 | mode = self.sample_options[sample_id] 220 | if mode is None: 221 | return image, boxes, labels 222 | 223 | min_iou, max_iou = mode 224 | if min_iou is None: 225 | min_iou = float('-inf') 226 | if max_iou is None: 227 | max_iou = float('inf') 228 | 229 | # max trails (50) 230 | for _ in range(50): 231 | current_image = image 232 | 233 | w = random.uniform(0.3 * width, width) 234 | h = random.uniform(0.3 * height, height) 235 | 236 | # aspect ratio constraint b/t .5 & 2 237 | if h / w < 0.5 or h / w > 2: 238 | continue 239 | 240 | left = random.uniform(width - w) 241 | top = random.uniform(height - h) 242 | 243 | # convert to integer rect x1,y1,x2,y2 244 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 245 | 246 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes 247 | overlap = jaccard_numpy(boxes, rect) 248 | 249 | # is min and max overlap constraint satisfied? if not try again 250 | if overlap.min() < min_iou and max_iou < overlap.max(): 251 | continue 252 | 253 | # cut the crop from the image 254 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], 255 | :] 256 | 257 | # keep overlap with gt box IF center in sampled patch 258 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 259 | 260 | # mask in all gt boxes that above and to the left of centers 261 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 262 | 263 | # mask in all gt boxes that under and to the right of centers 264 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 265 | 266 | # mask in that both m1 and m2 are true 267 | mask = m1 * m2 268 | 269 | # have any valid boxes? 
try again if not 270 | if not mask.any(): 271 | continue 272 | 273 | # take only matching gt boxes 274 | current_boxes = boxes[mask, :].copy() 275 | 276 | # take only matching gt labels 277 | current_labels = labels[mask] 278 | 279 | # should we use the box left and top corner or the crop's 280 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], 281 | rect[:2]) 282 | # adjust to crop (by substracting crop's left,top) 283 | current_boxes[:, :2] -= rect[:2] 284 | 285 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], 286 | rect[2:]) 287 | # adjust to crop (by substracting crop's left,top) 288 | current_boxes[:, 2:] -= rect[:2] 289 | 290 | return current_image, current_boxes, current_labels 291 | 292 | 293 | class Expand(object): 294 | def __init__(self, mean): 295 | self.mean = mean 296 | 297 | def __call__(self, image, boxes, labels): 298 | if random.randint(2): 299 | return image, boxes, labels 300 | 301 | height, width, depth = image.shape 302 | ratio = random.uniform(1, 4) 303 | left = random.uniform(0, width*ratio - width) 304 | top = random.uniform(0, height*ratio - height) 305 | 306 | expand_image = np.zeros( 307 | (int(height*ratio), int(width*ratio), depth), 308 | dtype=image.dtype) 309 | expand_image[:, :, :] = self.mean 310 | expand_image[int(top):int(top + height), 311 | int(left):int(left + width)] = image 312 | image = expand_image 313 | 314 | boxes = boxes.copy() 315 | boxes[:, :2] += (int(left), int(top)) 316 | boxes[:, 2:] += (int(left), int(top)) 317 | 318 | return image, boxes, labels 319 | 320 | 321 | class RandomMirror(object): 322 | def __call__(self, image, boxes, classes): 323 | _, width, _ = image.shape 324 | if random.randint(2): 325 | image = image[:, ::-1] 326 | boxes = boxes.copy() 327 | boxes[:, 0::2] = width - boxes[:, 2::-2] 328 | return image, boxes, classes 329 | 330 | 331 | class SwapChannels(object): 332 | """Transforms a tensorized image by swapping the channels in the order 333 | specified in the swap tuple. 
334 | Args: 335 | swaps (int triple): final order of channels 336 | eg: (2, 1, 0) 337 | """ 338 | 339 | def __init__(self, swaps): 340 | self.swaps = swaps 341 | 342 | def __call__(self, image): 343 | """ 344 | Args: 345 | image (Tensor): image tensor to be transformed 346 | Return: 347 | a tensor with channels swapped according to swap 348 | """ 349 | # if torch.is_tensor(image): 350 | # image = image.data.cpu().numpy() 351 | # else: 352 | # image = np.array(image) 353 | image = image[:, :, self.swaps] 354 | return image 355 | 356 | 357 | class PhotometricDistort(object): 358 | def __init__(self): 359 | self.pd = [ 360 | RandomContrast(), 361 | ConvertColor(transform='HSV'), 362 | RandomSaturation(), 363 | RandomHue(), 364 | ConvertColor(current='HSV', transform='BGR'), 365 | RandomContrast() 366 | ] 367 | self.rand_brightness = RandomBrightness() 368 | 369 | def __call__(self, image, boxes, labels): 370 | im = image.copy() 371 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 372 | if random.randint(2): 373 | distort = Compose(self.pd[:-1]) 374 | else: 375 | distort = Compose(self.pd[1:]) 376 | im, boxes, labels = distort(im, boxes, labels) 377 | return im, boxes, labels 378 | 379 | 380 | class Augmentation(object): 381 | def __init__(self, size=640, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)): 382 | self.mean = mean 383 | self.size = size 384 | self.std = std 385 | self.augment = Compose([ 386 | ConvertFromInts(), # 将int类型转换为float32类型 387 | ToAbsoluteCoords(), # 将归一化的相对坐标转换为绝对坐标 388 | PhotometricDistort(), # 图像颜色增强 389 | Expand(self.mean), # 扩充增强 390 | RandomSampleCrop(), # 随机剪裁 391 | RandomMirror(), # 随机水平镜像 392 | ToPercentCoords(), # 将绝对坐标转换为归一化的相对坐标 393 | Resize(self.size), # resize操作 394 | Normalize(self.mean, self.std) # 图像颜色归一化 395 | ]) 396 | 397 | def __call__(self, img, boxes, labels): 398 | return self.augment(img, boxes, labels) 399 | 400 | 401 | class BaseTransform: 402 | def __init__(self, size, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)): 403 | self.size = size 404 | self.mean = np.array(mean, dtype=np.float32) 405 | self.std = np.array(std, dtype=np.float32) 406 | 407 | def __call__(self, image, boxes=None, labels=None): 408 | # resize 409 | image = cv2.resize(image, (self.size, self.size)).astype(np.float32) 410 | # normalize 411 | image /= 255. 412 | image -= self.mean 413 | image /= self.std 414 | 415 | return image, boxes, labels 416 | --------------------------------------------------------------------------------
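(Added note) A minimal usage sketch for the transforms in data/transform.py, using made-up toy data; it assumes the classes above are importable as data.transform and is not part of the original repository.

    import numpy as np
    from data.transform import Augmentation, BaseTransform

    # dummy 480x640 BGR image, one box in normalized xyxy form, one class label
    img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    boxes = np.array([[0.1, 0.2, 0.5, 0.8]], dtype=np.float32)
    labels = np.array([7])

    # training-time augmentation: color jitter, expand, random crop, mirror, resize, normalize
    train_tf = Augmentation(size=416)
    img_t, boxes_t, labels_t = train_tf(img, boxes, labels)
    # img_t: (416, 416, 3) float32, mean/std normalized; boxes_t remain normalized to [0, 1]

    # validation-time transform: resize + normalize only, boxes left untouched
    val_tf = BaseTransform(size=416)
    img_v, _, _ = val_tf(img, boxes, labels)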