├── README.md
├── backbone
│   ├── __init__.py
│   ├── feature_pyramid_network.py
│   ├── mobilenetv2_model.py
│   ├── resnet101_fpn_model.py
│   ├── resnet50_fpn_model.py
│   └── vgg_model.py
├── coco91_indices.json
├── coco_to_voc.py
├── draw_box_utils.py
├── loss_and_lr20220612-095042.png
├── mAP.png
├── my_dataset.py
├── network_files
│   ├── __init__.py
│   ├── boxes.py
│   ├── cawb.py
│   ├── det_utils.py
│   ├── faster_rcnn_framework.py
│   ├── image_list.py
│   ├── roi_head.py
│   ├── rpn_function.py
│   └── transform.py
├── pascal_voc_classes.json
├── plot_curve.py
├── predict.py
├── results20220611-205355.txt
├── split_data.py
├── train_mobilenetv2.py
├── train_multi_GPU.py
├── train_res50_fpn.py
├── train_utils
│   ├── __init__.py
│   ├── coco_eval.py
│   ├── distributed_utils.py
│   ├── group_by_aspect_ratio.py
│   └── train_eval_utils.py
├── transforms.py
└── validation.py
/README.md:
--------------------------------------------------------------------------------
1 | # Reproduction of the AC-FPN paper (the AM module is not used here: in our tests, adding the AM module did not noticeably improve accuracy, while it increased computation and model weight size)
2 | ## The training code in this project is mainly from the Bilibili creator 霹雳吧啦wz: https://b23.tv/HvMiDy . The AC-FPN code itself was written from scratch; if you repost this work, please credit the source.
3 | # Environment
4 | ① Python 3.6/3.7/3.8
5 |
6 | ② PyTorch 1.7.1 (note: version 1.6.0 or later is required, because the official mixed precision training is only supported from 1.6.0 onward)
7 |
8 | ③ pycocotools (Linux: pip install pycocotools; Windows: pip install pycocotools-windows, no extra Visual Studio install needed)
9 |
10 | ④ Ubuntu or CentOS (Windows is not recommended)
11 |
12 | ⑤ Training on a GPU is strongly recommended
13 |
14 | ⑥ See requirements.txt for the detailed environment (a quick sanity-check sketch follows below)
15 |
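A quick, optional sanity check (a minimal sketch, not part of the repo) to confirm that the installed PyTorch meets the mixed-precision requirement and that a GPU is visible:

```python
import torch

# the training scripts rely on native automatic mixed precision, which needs PyTorch >= 1.6.0
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (1, 6), "PyTorch >= 1.6.0 is required for AMP training"
```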
16 | # Project structure
17 | ├── backbone: feature extraction networks, including the AC-FPN implementation
18 |
19 | ├── network_files: Faster R-CNN network (including the Fast R-CNN head, RPN and related modules)
20 |
21 | ├── train_utils: training and validation utilities (including cocotools)
22 |
23 | ├── my_dataset.py: custom dataset for reading the COCO dataset
24 |
25 | ├── train_res50_fpn.py: train with resnet50 + AC-FPN as the backbone
26 |
27 | ├── train_multi_GPU.py: for users training with multiple GPUs
28 |
29 | ├── predict.py: simple prediction script that runs inference with trained weights
30 |
31 | ├── validation.py: evaluate the COCO metrics of validation/test data with trained weights and generate a record_mAP.txt file
32 |
33 | └── pascal_voc_classes.json: Pascal VOC label file
34 | # Pretrained weights
35 | ResNet50+FPN backbone: https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth . Note: remember to rename the downloaded weights; for example, train_res50_fpn.py reads fasterrcnn_resnet50_fpn_coco.pth, not fasterrcnn_resnet50_fpn_coco-258fb6c6.pth (a renaming sketch follows).
36 |
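A minimal renaming sketch (assuming the file was downloaded into the project root; adjust the paths to your setup):

```python
import torch

# hypothetical local paths; point src at wherever you saved the downloaded file
src = "fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"
dst = "fasterrcnn_resnet50_fpn_coco.pth"   # the name train_res50_fpn.py expects

state_dict = torch.load(src, map_location="cpu")  # load once to make sure the file is intact
torch.save(state_dict, dst)
```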
37 | # Dataset download (a COCO-format dataset is used by default)
38 | - COCO official website: https://cocodataset.org/
39 |
40 | - If you are not familiar with the dataset, see this blog post by Bilibili creator 霹雳吧啦wz: https://blog.csdn.net/qq_37541097/article/details/113247318
41 | - Taking the coco2017 dataset as an example, download these three files:
42 |
43 | - 2017 Train images [118K/18GB]: all images used during training
44 |
45 | - 2017 Val images [5K/1GB]: all images used during validation
46 |
47 | - 2017 Train/Val annotations [241MB]: annotation json files for the training and validation sets
48 |
49 | Unzip everything into a coco2017 folder to obtain the following structure (a loading sketch follows the tree):
50 |
51 | ├── coco2017: dataset root directory
52 |
53 | ├── train2017: folder with all training images (118287 images)
54 |
55 | ├── val2017: folder with all validation images (5000 images)
56 |
57 | └── annotations: annotation folder
58 |
59 | ├── instances_train2017.json: training-set annotations for object detection and segmentation
60 |
61 | ├── instances_val2017.json: validation-set annotations for object detection and segmentation
62 |
63 | ├── captions_train2017.json: training-set annotations for image captioning
64 |
65 | ├── captions_val2017.json: validation-set annotations for image captioning
66 |
67 | ├── person_keypoints_train2017.json: training-set annotations for human keypoint detection
68 |
69 | └── person_keypoints_val2017.json: validation-set annotations for human keypoint detection
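A minimal loading sketch based on my_dataset.py (the "coco2017" path below is an assumption; point it at the dataset root shown in the tree above):

```python
from my_dataset import CocoDetection

# root is the coco2017 directory that contains train2017/, val2017/ and annotations/
train_dataset = CocoDetection(root="coco2017", dataset="train")
val_dataset = CocoDetection(root="coco2017", dataset="val")
print(len(train_dataset), len(val_dataset))  # roughly 118287 and 5000 images
```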
70 | # Training
71 | - Make sure the dataset is prepared in advance
72 | - Make sure the corresponding pretrained weights have been downloaded in advance
73 | - For single-GPU training, use the train_res50_fpn.py training script directly
74 | - For multi-GPU training, use torchrun --nproc_per_node=8 train_multi_GPU.py, where nproc_per_node is the number of GPUs; if you train with four cards, set the initial learning rate to 0.01
75 | - To restrict training to specific GPU devices, prefix the command with CUDA_VISIBLE_DEVICES=0,3 (for example, to use only the 1st and 4th GPUs): CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py
76 | # Notes
77 | - When using the training scripts, set --data-path to the root directory that contains your coco2017 folder
78 | - When using the prediction script, set weights_path to the path of the weights you trained yourself (a minimal loading sketch follows this list)
79 | - When using validation.py, make sure your validation or test set contains objects of every class, and only modify --num-classes, --data-path and --weights-path; avoid changing any other code
80 |
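A minimal loading sketch in the spirit of predict.py (assumptions: FasterRCNN follows a torchvision-style (backbone, num_classes) constructor, and the training script saved the weights under a "model" key; adjust both to your setup):

```python
import torch
from backbone import resnet50_fpn_backbone
from network_files import FasterRCNN

backbone = resnet50_fpn_backbone(pretrain_path="")       # AC-FPN backbone from this repo
model = FasterRCNN(backbone=backbone, num_classes=91)    # 90 COCO classes + background

weights_path = "save_weights/model.pth"                  # hypothetical path to your trained weights
checkpoint = torch.load(weights_path, map_location="cpu")
model.load_state_dict(checkpoint.get("model", checkpoint))  # handle wrapped or plain state dicts
model.eval()
```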
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet50_fpn_model import resnet50_fpn_backbone
2 | from .resnet101_fpn_model import resnet101_fpn_backbone
--------------------------------------------------------------------------------
/backbone/feature_pyramid_network.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch.nn as nn
4 | import torch
5 | from torch import Tensor
6 | import torch.nn.functional as F
7 |
8 | from torch.jit.annotations import Tuple, List, Dict
9 |
10 | # AC-FPN implementation
11 |
12 | # Implements the CxAM and CnAM modules
13 | class CxAM(nn.Module):
14 | def __init__(self, in_channels, out_channels, reduction=8):
15 | super(CxAM, self).__init__()
16 | self.key_conv = nn.Conv2d(in_channels, out_channels//reduction, 1)
17 | self.query_conv = nn.Conv2d(in_channels, out_channels//reduction, 1)
18 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
19 | self.sigmoid = nn.Sigmoid()
20 | self.avg = nn.AdaptiveAvgPool2d(1)
21 |
22 | def forward(self, x):
23 | m_batchsize, C, width, height = x.size()
24 |
25 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) # B x N x C'
26 |
27 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) # B x C' x N
28 |
29 | R = torch.bmm(proj_query, proj_key).view(m_batchsize, width*height, width, height) # B x N x W x H
30 | # global average pooling first: R then has shape B x N x 1 x 1; after view, its shape is B x 1 x W x H
31 | attention_R = self.sigmoid(self.avg(R).view(m_batchsize, -1, width, height)) # B x 1 x W x H
32 |
33 | proj_value = self.value_conv(x)
34 |
35 | out = proj_value * attention_R  # B x C x W x H
36 |
37 | return out
38 |
39 |
40 | class CnAM(nn.Module):
41 | def __init__(self, in_channels, out_channels, reduction=8):
42 | super(CnAM, self).__init__()
43 | # corresponds to P, Z and S in the original paper
44 | self.Z_conv = nn.Conv2d(in_channels, out_channels // reduction, 1)
45 | self.P_conv = nn.Conv2d(in_channels, out_channels // reduction, 1)
46 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
47 | self.sigmoid = nn.Sigmoid()
48 | self.avg = nn.AdaptiveAvgPool2d(1)
49 |
50 | # CnAM uses F5 from the FPN and the feature map F output by the CEM
51 | def forward(self, F5, F):
52 | m_batchsize, C, width, height = F5.size()
53 |
54 | proj_query = self.P_conv(F5).view(m_batchsize, -1, width*height).permute(0, 2, 1) # B x N x C''
55 |
56 | proj_key = self.Z_conv(F5).view(m_batchsize, -1, width * height) # B x C'' x N
57 |
58 | S = torch.bmm(proj_query, proj_key).view(m_batchsize, width * height, width, height) # B x N x W x H
59 | attention_S = self.sigmoid(self.avg(S).view(m_batchsize, -1, width, height)) # B x 1 x W x H
60 |
61 | proj_value = self.value_conv(F)
62 |
63 | out = proj_value * attention_S  # B x C x W x H
64 |
65 | return out
66 |
67 | class DenseBlock(nn.Module):
68 | def __init__(self, input_num, num1, num2, rate, drop_out):
69 | super(DenseBlock, self).__init__()
70 |
71 | # C: 2048 --> 512 --> 256
72 | self.conv1x1 = nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1)
73 | self.ConvGN = nn.GroupNorm(num_groups=32, num_channels=num1)
74 | self.relu1 = nn.ReLU(inplace=True)
75 | self.dilaconv = nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3, padding=1 * rate, dilation=rate)
76 | self.relu2 = nn.ReLU(inplace=True)
77 | self.drop = nn.Dropout(p=drop_out)
78 |
79 | def forward(self, x):
80 | x = self.ConvGN(self.conv1x1(x))
81 | x = self.relu1(x)
82 | x = self.dilaconv(x)
83 | x = self.relu2(x)
84 | x = self.drop(x)
85 | return x
86 |
87 |
88 | class DenseAPP(nn.Module):
89 | def __init__(self, num_channels=2048):
90 | super(DenseAPP, self).__init__()
91 | self.drop_out = 0.1
92 | self.channels1 = 512
93 | self.channels2 = 256
94 | self.num_channels = num_channels
95 | self.aspp3 = DenseBlock(self.num_channels, num1=self.channels1, num2=self.channels2, rate=3,
96 | drop_out=self.drop_out)
97 | self.aspp6 = DenseBlock(self.num_channels + self.channels2 * 1, num1=self.channels1, num2=self.channels2,
98 | rate=6,
99 | drop_out=self.drop_out)
100 | self.aspp12 = DenseBlock(self.num_channels + self.channels2 * 2, num1=self.channels1, num2=self.channels2,
101 | rate=12,
102 | drop_out=self.drop_out)
103 | self.aspp18 = DenseBlock(self.num_channels + self.channels2 * 3, num1=self.channels1, num2=self.channels2,
104 | rate=18,
105 | drop_out=self.drop_out)
106 | self.aspp24 = DenseBlock(self.num_channels + self.channels2 * 4, num1=self.channels1, num2=self.channels2,
107 | rate=24,
108 | drop_out=self.drop_out)
109 | self.conv1x1 = nn.Conv2d(in_channels=5*self.channels2, out_channels=256, kernel_size=1)
110 | self.ConvGN = nn.GroupNorm(num_groups=32, num_channels=256)
111 |
112 | def forward(self, feature):
113 | aspp3 = self.aspp3(feature)
114 | feature = torch.cat((aspp3, feature), dim=1)
115 | aspp6 = self.aspp6(feature)
116 | feature = torch.cat((aspp6, feature), dim=1)
117 | aspp12 = self.aspp12(feature)
118 | feature = torch.cat((aspp12, feature), dim=1)
119 | aspp18 = self.aspp18(feature)
120 | feature = torch.cat((aspp18, feature), dim=1)
121 | aspp24 = self.aspp24(feature)
122 |
123 | x = torch.cat((aspp3, aspp6, aspp12, aspp18, aspp24), dim=1)
124 | out = self.ConvGN(self.conv1x1(x))
125 | return out
126 |
127 |
128 | class FeaturePyramidNetwork(nn.Module):
129 | def __init__(self, in_channels_list, out_channels, extra_blocks=None):
130 | super().__init__()
131 | self.dense = DenseAPP(num_channels=in_channels_list[-1])
132 |
133 | # -------- AM modules; comment these out if you don't want to use them --------#
134 | self.CxAM = CxAM(in_channels=256, out_channels=256)
135 | self.CnAM = CnAM(in_channels=256, out_channels=256)
136 | # -------------------------------------------------#
137 |
138 | # 1x1 convs used to adjust the channels of the resnet feature maps (layer1,2,3,4)
139 | self.inner_blocks = nn.ModuleList()
140 | # 3x3 convs applied to the adjusted feature maps to obtain the corresponding prediction feature maps
141 | self.layer_blocks = nn.ModuleList()
142 | for in_channels in in_channels_list:
143 | if in_channels == 0:
144 | continue
145 | inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
146 | layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
147 | self.inner_blocks.append(inner_block_module)
148 | self.layer_blocks.append(layer_block_module)
149 |
150 | # initialize parameters now to avoid modifying the initialization of top_blocks
151 | for m in self.children():
152 | if isinstance(m, nn.Conv2d):
153 | nn.init.kaiming_uniform_(m.weight, a=1)
154 | nn.init.constant_(m.bias, 0)
155 |
156 | self.extra_blocks = extra_blocks
157 |
158 | def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
159 | """
160 | This is equivalent to self.inner_blocks[idx](x),
161 | but torchscript doesn't support this yet
162 | """
163 | num_blocks = len(self.inner_blocks)
164 | if idx < 0:
165 | idx += num_blocks
166 | i = 0
167 | out = x
168 | for module in self.inner_blocks:
169 | if i == idx:
170 | out = module(x)
171 | i += 1
172 | return out
173 |
174 | def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
175 | """
176 | This is equivalent to self.layer_blocks[idx](x),
177 | but torchscript doesn't support this yet
178 | """
179 | num_blocks = len(self.layer_blocks)
180 | if idx < 0:
181 | idx += num_blocks
182 | i = 0
183 | out = x
184 | for module in self.layer_blocks:
185 | if i == idx:
186 | out = module(x)
187 | i += 1
188 | return out
189 |
190 | def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
191 | """
192 | Computes the FPN for a set of feature maps.
193 | Arguments:
194 | x (OrderedDict[Tensor]): feature maps for each feature level.
195 | Returns:
196 | results (OrderedDict[Tensor]): feature maps after FPN layers.
197 | They are ordered from highest resolution first.
198 | """
199 | # unpack OrderedDict into two lists for easier handling
200 | names = list(x.keys())
201 | x = list(x.values())
202 |
203 | # feed C5 into DenseAPP to obtain contextual information
204 | dense = self.dense(x[-1])
205 |
206 | # adjust the channels of resnet layer4 to the specified out_channels
207 | # last_inner = self.inner_blocks[-1](x[-1])
208 | last_inner = self.get_result_from_inner_blocks(x[-1], -1)
209 |
210 | # feed dense into the CxAM and CnAM modules; comment out the next three lines to disable the AM modules
211 | cxam = self.CxAM(dense)
212 | cnam = self.CnAM(dense, last_inner)
213 | result = cxam + cnam
214 |
215 | # results holds each prediction feature map
216 | results = []
217 | # pass the channel-adjusted layer4 feature map through a 3x3 conv to get the corresponding prediction feature map
218 | # results.append(self.layer_blocks[-1](last_inner))
219 |
220 | # without the AM modules
221 | # P5 = dense + self.get_result_from_layer_blocks(last_inner, -1)
222 |
223 | # with the AM modules
224 | P5 = result + self.get_result_from_layer_blocks(last_inner, -1)
225 |
226 | results.append(P5)
227 |
228 | for idx in range(len(x) - 2, -1, -1):
229 | inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
230 | feat_shape = inner_lateral.shape[-2:]
231 | inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
232 | last_inner = inner_lateral + inner_top_down
233 | results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
234 |
235 | # generate prediction feature map 5 (the extra "pool" level) on top of the layer4 prediction feature map
236 | if self.extra_blocks is not None:
237 | results, names = self.extra_blocks(results, x, names)
238 |
239 | # make it back an OrderedDict
240 | out = OrderedDict([(k, v) for k, v in zip(names, results)])
241 |
242 | return out
243 |
244 |
245 | class LastLevelMaxPool(torch.nn.Module):
246 | """
247 | Applies a max_pool2d on top of the last feature map
248 | """
249 |
250 | def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
251 | names.append("pool")
252 | x.append(F.max_pool2d(x[-1], 1, 2, 0)) # input, kernel_size, stride, padding
253 | return x, names
254 |
--------------------------------------------------------------------------------
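A shape-check sketch for the modules above (a minimal example, assuming the repository root is on the Python path); it runs a dummy C5 map through DenseAPP and the two attention modules:

```python
import torch
from backbone.feature_pyramid_network import DenseAPP, CxAM, CnAM

c5 = torch.randn(1, 2048, 25, 25)                    # dummy C5 map from resnet layer4
dense = DenseAPP(num_channels=2048)(c5)               # contextual features: 1 x 256 x 25 x 25
f5 = torch.randn(1, 256, 25, 25)                      # stands in for the channel-reduced layer4 map

cxam_out = CxAM(in_channels=256, out_channels=256)(dense)
cnam_out = CnAM(in_channels=256, out_channels=256)(dense, f5)  # same argument order as in the FPN forward
print(dense.shape, cxam_out.shape, cnam_out.shape)    # all 1 x 256 x 25 x 25
```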
/backbone/mobilenetv2_model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torchvision.ops import misc
4 |
5 |
6 | def _make_divisible(ch, divisor=8, min_ch=None):
7 | """
8 | This function is taken from the original tf repo.
9 | It ensures that all layers have a channel number that is divisible by 8
10 | It can be seen here:
11 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
12 | """
13 | if min_ch is None:
14 | min_ch = divisor
15 | new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
16 | # Make sure that round down does not go down by more than 10%.
17 | if new_ch < 0.9 * ch:
18 | new_ch += divisor
19 | return new_ch
20 |
21 |
22 | class ConvBNReLU(nn.Sequential):
23 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1, norm_layer=None):
24 | padding = (kernel_size - 1) // 2
25 | if norm_layer is None:
26 | norm_layer = nn.BatchNorm2d
27 | super(ConvBNReLU, self).__init__(
28 | nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False),
29 | norm_layer(out_channel),
30 | nn.ReLU6(inplace=True)
31 | )
32 |
33 |
34 | class InvertedResidual(nn.Module):
35 | def __init__(self, in_channel, out_channel, stride, expand_ratio, norm_layer=None):
36 | super(InvertedResidual, self).__init__()
37 | hidden_channel = in_channel * expand_ratio
38 | self.use_shortcut = stride == 1 and in_channel == out_channel
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 |
42 | layers = []
43 | if expand_ratio != 1:
44 | # 1x1 pointwise conv
45 | layers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1, norm_layer=norm_layer))
46 | layers.extend([
47 | # 3x3 depthwise conv
48 | ConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel, norm_layer=norm_layer),
49 | # 1x1 pointwise conv(linear)
50 | nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),
51 | norm_layer(out_channel),
52 | ])
53 |
54 | self.conv = nn.Sequential(*layers)
55 |
56 | def forward(self, x):
57 | if self.use_shortcut:
58 | return x + self.conv(x)
59 | else:
60 | return self.conv(x)
61 |
62 |
63 | class MobileNetV2(nn.Module):
64 | def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8, weights_path=None, norm_layer=None):
65 | super(MobileNetV2, self).__init__()
66 | block = InvertedResidual
67 | input_channel = _make_divisible(32 * alpha, round_nearest)
68 | last_channel = _make_divisible(1280 * alpha, round_nearest)
69 |
70 | if norm_layer is None:
71 | norm_layer = nn.BatchNorm2d
72 |
73 | inverted_residual_setting = [
74 | # t, c, n, s
75 | [1, 16, 1, 1],
76 | [6, 24, 2, 2],
77 | [6, 32, 3, 2],
78 | [6, 64, 4, 2],
79 | [6, 96, 3, 1],
80 | [6, 160, 3, 2],
81 | [6, 320, 1, 1],
82 | ]
83 |
84 | features = []
85 | # conv1 layer
86 | features.append(ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer))
87 | # building inverted residual blocks
88 | for t, c, n, s in inverted_residual_setting:
89 | output_channel = _make_divisible(c * alpha, round_nearest)
90 | for i in range(n):
91 | stride = s if i == 0 else 1
92 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
93 | input_channel = output_channel
94 | # building last several layers
95 | features.append(ConvBNReLU(input_channel, last_channel, 1, norm_layer=norm_layer))
96 | # combine feature layers
97 | self.features = nn.Sequential(*features)
98 |
99 | # building classifier
100 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
101 | self.classifier = nn.Sequential(
102 | nn.Dropout(0.2),
103 | nn.Linear(last_channel, num_classes)
104 | )
105 |
106 | if weights_path is None:
107 | # weight initialization
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
111 | if m.bias is not None:
112 | nn.init.zeros_(m.bias)
113 | elif isinstance(m, nn.BatchNorm2d):
114 | nn.init.ones_(m.weight)
115 | nn.init.zeros_(m.bias)
116 | elif isinstance(m, nn.Linear):
117 | nn.init.normal_(m.weight, 0, 0.01)
118 | nn.init.zeros_(m.bias)
119 | else:
120 | self.load_state_dict(torch.load(weights_path))
121 |
122 | def forward(self, x):
123 | x = self.features(x)
124 | x = self.avgpool(x)
125 | x = torch.flatten(x, 1)
126 | x = self.classifier(x)
127 | return x
128 |
--------------------------------------------------------------------------------
/backbone/resnet101_fpn_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.jit.annotations import List, Dict
7 | from torchvision.ops.misc import FrozenBatchNorm2d
8 |
9 | from .feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
16 | super(Bottleneck, self).__init__()
17 | if norm_layer is None:
18 | norm_layer = nn.BatchNorm2d
19 |
20 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
21 | kernel_size=1, stride=1, bias=False) # squeeze channels
22 | self.bn1 = norm_layer(out_channel)
23 | # -----------------------------------------
24 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
25 | kernel_size=3, stride=stride, bias=False, padding=1)
26 | self.bn2 = norm_layer(out_channel)
27 | # -----------------------------------------
28 | self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
29 | kernel_size=1, stride=1, bias=False) # unsqueeze channels
30 | self.bn3 = norm_layer(out_channel * self.expansion)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.downsample = downsample
33 |
34 | def forward(self, x):
35 | identity = x
36 | if self.downsample is not None:
37 | identity = self.downsample(x)
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv3(out)
48 | out = self.bn3(out)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class ResNet(nn.Module):
57 |
58 | def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
59 | super(ResNet, self).__init__()
60 | if norm_layer is None:
61 | norm_layer = nn.BatchNorm2d
62 | self._norm_layer = norm_layer
63 |
64 | self.include_top = include_top
65 | self.in_channel = 64
66 |
67 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
68 | padding=3, bias=False)
69 | self.bn1 = norm_layer(self.in_channel)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
72 | self.layer1 = self._make_layer(block, 64, blocks_num[0])
73 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
74 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
75 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
76 | if self.include_top:
77 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
78 | self.fc = nn.Linear(512 * block.expansion, num_classes)
79 |
80 | for m in self.modules():
81 | if isinstance(m, nn.Conv2d):
82 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
83 |
84 | def _make_layer(self, block, channel, block_num, stride=1):
85 | norm_layer = self._norm_layer
86 | downsample = None
87 | if stride != 1 or self.in_channel != channel * block.expansion:
88 | downsample = nn.Sequential(
89 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
90 | norm_layer(channel * block.expansion))
91 |
92 | layers = []
93 | layers.append(block(self.in_channel, channel, downsample=downsample,
94 | stride=stride, norm_layer=norm_layer))
95 | self.in_channel = channel * block.expansion
96 |
97 | for _ in range(1, block_num):
98 | layers.append(block(self.in_channel, channel, norm_layer=norm_layer))
99 |
100 | return nn.Sequential(*layers)
101 |
102 | def forward(self, x):
103 | x = self.conv1(x)
104 | x = self.bn1(x)
105 | x = self.relu(x)
106 | x = self.maxpool(x)
107 |
108 | x = self.layer1(x)
109 | x = self.layer2(x)
110 | x = self.layer3(x)
111 | x = self.layer4(x)
112 |
113 | if self.include_top:
114 | x = self.avgpool(x)
115 | x = torch.flatten(x, 1)
116 | x = self.fc(x)
117 |
118 | return x
119 |
120 |
121 | def overwrite_eps(model, eps):
122 | """
123 | This method overwrites the default eps values of all the
124 | FrozenBatchNorm2d layers of the model with the provided value.
125 | This is necessary to address the BC-breaking change introduced
126 | by the bug-fix at pytorch/vision#2933. The overwrite is applied
127 | only when the pretrained weights are loaded to maintain compatibility
128 | with previous versions.
129 |
130 | Args:
131 | model (nn.Module): The model on which we perform the overwrite.
132 | eps (float): The new value of eps.
133 | """
134 | for module in model.modules():
135 | if isinstance(module, FrozenBatchNorm2d):
136 | module.eps = eps
137 |
138 |
139 | class IntermediateLayerGetter(nn.ModuleDict):
140 | """
141 | Module wrapper that returns intermediate layers from a model
142 | It has a strong assumption that the modules have been registered
143 | into the model in the same order as they are used.
144 | This means that one should **not** reuse the same nn.Module
145 | twice in the forward if you want this to work.
146 | Additionally, it is only able to query submodules that are directly
147 | assigned to the model. So if `model` is passed, `model.feature1` can
148 | be returned, but not `model.feature1.layer2`.
149 | Arguments:
150 | model (nn.Module): model on which we will extract the features
151 | return_layers (Dict[name, new_name]): a dict containing the names
152 | of the modules for which the activations will be returned as
153 | the key of the dict, and the value of the dict is the name
154 | of the returned activation (which the user can specify).
155 | """
156 | __annotations__ = {
157 | "return_layers": Dict[str, str],
158 | }
159 |
160 | def __init__(self, model, return_layers):
161 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
162 | raise ValueError("return_layers are not present in model")
163 |
164 | orig_return_layers = return_layers
165 | return_layers = {str(k): str(v) for k, v in return_layers.items()}
166 | layers = OrderedDict()
167 |
168 | # iterate over the model's child modules and store them in order in an OrderedDict
169 | # keep only layer4 and everything before it; discard the later, unused modules
170 | for name, module in model.named_children():
171 | layers[name] = module
172 | if name in return_layers:
173 | del return_layers[name]
174 | if not return_layers:
175 | break
176 |
177 | super(IntermediateLayerGetter, self).__init__(layers)
178 | self.return_layers = orig_return_layers
179 |
180 | def forward(self, x):
181 | out = OrderedDict()
182 | # iterate over all child modules in order and run the forward pass,
183 | # collecting the outputs of layer1, layer2, layer3 and layer4
184 | for name, module in self.items():
185 | x = module(x)
186 | if name in self.return_layers:
187 | out_name = self.return_layers[name]
188 | out[out_name] = x
189 | return out
190 |
191 |
192 | class BackboneWithFPN(nn.Module):
193 | """
194 | Adds a FPN on top of a model.
195 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
196 | extract a submodel that returns the feature maps specified in return_layers.
197 | The same limitations of IntermediatLayerGetter apply here.
198 | Arguments:
199 | backbone (nn.Module)
200 | return_layers (Dict[name, new_name]): a dict containing the names
201 | of the modules for which the activations will be returned as
202 | the key of the dict, and the value of the dict is the name
203 | of the returned activation (which the user can specify).
204 | in_channels_list (List[int]): number of channels for each feature map
205 | that is returned, in the order they are present in the OrderedDict
206 | out_channels (int): number of channels in the FPN.
207 | extra_blocks: ExtraFPNBlock
208 | Attributes:
209 | out_channels (int): the number of channels in the FPN
210 | """
211 |
212 | def __init__(self, backbone, return_layers, in_channels, out_channels, extra_blocks=None):
213 | super(BackboneWithFPN, self).__init__()
214 |
215 | if extra_blocks is None:
216 | extra_blocks = LastLevelMaxPool()
217 |
218 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
219 | self.fpn = FeaturePyramidNetwork(
220 | in_channels_list=in_channels, out_channels=out_channels,
221 | extra_blocks=extra_blocks,
222 | )
223 |
224 | self.out_channels = out_channels
225 |
226 | def forward(self, x):
227 | x = self.body(x)
228 | x = self.fpn(x)
229 | return x
230 |
231 |
232 | def resnet101_fpn_backbone(pretrain_path="./backbone/resnet101.pth",
233 | norm_layer=FrozenBatchNorm2d,  # FrozenBatchNorm2d behaves like BatchNorm2d, but its parameters cannot be updated
234 | trainable_layers=3,
235 | returned_layers=None,
236 | extra_blocks=None):
237 | """
238 | Build the resnet101_fpn backbone
239 | Args:
240 | pretrain_path: path to the resnet101 pretrained weights; leave empty if not used
241 | norm_layer: defaults to FrozenBatchNorm2d, i.e. a bn layer whose parameters are not updated (with a very small batch_size, bn can hurt rather than help)
242 | if your GPU memory allows a large batch_size, you can pass a normal BatchNorm2d instead
243 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
244 | trainable_layers: which layers to train
245 | returned_layers: which layers' outputs should be returned
246 | extra_blocks: extra blocks added on top of the output feature maps
247 |
248 | Returns:
249 |
250 | """
251 | resnet_backbone = ResNet(Bottleneck, [3, 4, 23, 3],
252 | include_top=False,
253 | norm_layer=norm_layer)
254 |
255 | if norm_layer is FrozenBatchNorm2d:
256 | overwrite_eps(resnet_backbone, 0.0)
257 |
258 | if pretrain_path != "":
259 | assert os.path.exists(pretrain_path), "{} does not exist.".format(pretrain_path)
260 | # load pretrained weights
261 | print("loading pretrained weights:", resnet_backbone.load_state_dict(torch.load(pretrain_path), strict=False))
262 |
263 | # select layers that won't be frozen
264 | assert 0 <= trainable_layers <= 5
265 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
266 |
267 | # if you want to train all layers, don't forget that conv1 is followed by bn1
268 | if trainable_layers == 5:
269 | layers_to_train.append("bn1")
270 |
271 | # freeze layers
272 | for name, parameter in resnet_backbone.named_parameters():
273 | # freeze every parameter that does not belong to a layer in layers_to_train
274 | if all([not name.startswith(layer) for layer in layers_to_train]):
275 | parameter.requires_grad_(False)
276 |
277 | if extra_blocks is None:
278 | extra_blocks = LastLevelMaxPool()
279 |
280 | if returned_layers is None:
281 | returned_layers = [1, 2, 3, 4]
282 | # the returned layer indices must be greater than 0 and less than 5
283 | assert min(returned_layers) > 0 and max(returned_layers) < 5
284 |
285 | # return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
286 | return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}
287 |
288 | # in_channel is the channel count of layer4's output feature map = 2048
289 | in_channels_stage2 = resnet_backbone.in_channel // 8  # 256
290 | # record the channels of each feature map that the resnet provides to the fpn
291 | in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
292 | # channels of each feature map after the fpn
293 | out_channels = 256
294 | return BackboneWithFPN(resnet_backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
295 |
--------------------------------------------------------------------------------
/backbone/resnet50_fpn_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.jit.annotations import List, Dict
7 | from torchvision.ops.misc import FrozenBatchNorm2d
8 |
9 | from .feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
16 | super(Bottleneck, self).__init__()
17 | if norm_layer is None:
18 | norm_layer = nn.BatchNorm2d
19 |
20 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
21 | kernel_size=1, stride=1, bias=False) # squeeze channels
22 | self.bn1 = norm_layer(out_channel)
23 | # -----------------------------------------
24 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
25 | kernel_size=3, stride=stride, bias=False, padding=1)
26 | self.bn2 = norm_layer(out_channel)
27 | # -----------------------------------------
28 | self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
29 | kernel_size=1, stride=1, bias=False) # unsqueeze channels
30 | self.bn3 = norm_layer(out_channel * self.expansion)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.downsample = downsample
33 |
34 | def forward(self, x):
35 | identity = x
36 | if self.downsample is not None:
37 | identity = self.downsample(x)
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv3(out)
48 | out = self.bn3(out)
49 |
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class ResNet(nn.Module):
57 |
58 | def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
59 | super(ResNet, self).__init__()
60 | if norm_layer is None:
61 | norm_layer = nn.BatchNorm2d
62 | self._norm_layer = norm_layer
63 |
64 | self.include_top = include_top
65 | self.in_channel = 64
66 |
67 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
68 | padding=3, bias=False)
69 | self.bn1 = norm_layer(self.in_channel)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
72 | self.layer1 = self._make_layer(block, 64, blocks_num[0])
73 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
74 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
75 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
76 | if self.include_top:
77 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
78 | self.fc = nn.Linear(512 * block.expansion, num_classes)
79 |
80 | for m in self.modules():
81 | if isinstance(m, nn.Conv2d):
82 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
83 |
84 | def _make_layer(self, block, channel, block_num, stride=1):
85 | norm_layer = self._norm_layer
86 | downsample = None
87 | if stride != 1 or self.in_channel != channel * block.expansion:
88 | downsample = nn.Sequential(
89 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
90 | norm_layer(channel * block.expansion))
91 |
92 | layers = []
93 | layers.append(block(self.in_channel, channel, downsample=downsample,
94 | stride=stride, norm_layer=norm_layer))
95 | self.in_channel = channel * block.expansion
96 |
97 | for _ in range(1, block_num):
98 | layers.append(block(self.in_channel, channel, norm_layer=norm_layer))
99 |
100 | return nn.Sequential(*layers)
101 |
102 | def forward(self, x):
103 | x = self.conv1(x)
104 | x = self.bn1(x)
105 | x = self.relu(x)
106 | x = self.maxpool(x)
107 |
108 | x = self.layer1(x)
109 | x = self.layer2(x)
110 | x = self.layer3(x)
111 | x = self.layer4(x)
112 |
113 | if self.include_top:
114 | x = self.avgpool(x)
115 | x = torch.flatten(x, 1)
116 | x = self.fc(x)
117 |
118 | return x
119 |
120 |
121 | def overwrite_eps(model, eps):
122 | """
123 | This method overwrites the default eps values of all the
124 | FrozenBatchNorm2d layers of the model with the provided value.
125 | This is necessary to address the BC-breaking change introduced
126 | by the bug-fix at pytorch/vision#2933. The overwrite is applied
127 | only when the pretrained weights are loaded to maintain compatibility
128 | with previous versions.
129 |
130 | Args:
131 | model (nn.Module): The model on which we perform the overwrite.
132 | eps (float): The new value of eps.
133 | """
134 | for module in model.modules():
135 | if isinstance(module, FrozenBatchNorm2d):
136 | module.eps = eps
137 |
138 |
139 | class IntermediateLayerGetter(nn.ModuleDict):
140 | """
141 | Module wrapper that returns intermediate layers from a model
142 | It has a strong assumption that the modules have been registered
143 | into the model in the same order as they are used.
144 | This means that one should **not** reuse the same nn.Module
145 | twice in the forward if you want this to work.
146 | Additionally, it is only able to query submodules that are directly
147 | assigned to the model. So if `model` is passed, `model.feature1` can
148 | be returned, but not `model.feature1.layer2`.
149 | Arguments:
150 | model (nn.Module): model on which we will extract the features
151 | return_layers (Dict[name, new_name]): a dict containing the names
152 | of the modules for which the activations will be returned as
153 | the key of the dict, and the value of the dict is the name
154 | of the returned activation (which the user can specify).
155 | """
156 | __annotations__ = {
157 | "return_layers": Dict[str, str],
158 | }
159 |
160 | def __init__(self, model, return_layers):
161 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
162 | raise ValueError("return_layers are not present in model")
163 |
164 | orig_return_layers = return_layers
165 | return_layers = {str(k): str(v) for k, v in return_layers.items()}
166 | layers = OrderedDict()
167 |
168 | # iterate over the model's child modules and store them in order in an OrderedDict
169 | # keep only layer4 and everything before it; discard the later, unused modules
170 | for name, module in model.named_children():
171 | layers[name] = module
172 | if name in return_layers:
173 | del return_layers[name]
174 | if not return_layers:
175 | break
176 |
177 | super(IntermediateLayerGetter, self).__init__(layers)
178 | self.return_layers = orig_return_layers
179 |
180 | def forward(self, x):
181 | out = OrderedDict()
182 | # iterate over all child modules in order and run the forward pass,
183 | # collecting the outputs of layer1, layer2, layer3 and layer4
184 | for name, module in self.items():
185 | x = module(x)
186 | if name in self.return_layers:
187 | out_name = self.return_layers[name]
188 | out[out_name] = x
189 | return out
190 |
191 |
192 | class BackboneWithFPN(nn.Module):
193 | """
194 | Adds a FPN on top of a model.
195 | Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
196 | extract a submodel that returns the feature maps specified in return_layers.
197 | The same limitations of IntermediatLayerGetter apply here.
198 | Arguments:
199 | backbone (nn.Module)
200 | return_layers (Dict[name, new_name]): a dict containing the names
201 | of the modules for which the activations will be returned as
202 | the key of the dict, and the value of the dict is the name
203 | of the returned activation (which the user can specify).
204 | in_channels_list (List[int]): number of channels for each feature map
205 | that is returned, in the order they are present in the OrderedDict
206 | out_channels (int): number of channels in the FPN.
207 | extra_blocks: ExtraFPNBlock
208 | Attributes:
209 | out_channels (int): the number of channels in the FPN
210 | """
211 |
212 | def __init__(self, backbone, return_layers, in_channels, out_channels, extra_blocks=None):
213 | super(BackboneWithFPN, self).__init__()
214 |
215 | if extra_blocks is None:
216 | extra_blocks = LastLevelMaxPool()
217 |
218 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
219 | self.fpn = FeaturePyramidNetwork(
220 | in_channels_list=in_channels,
221 | out_channels=out_channels,
222 | extra_blocks=extra_blocks,
223 | )
224 |
225 | self.out_channels = out_channels
226 |
227 | def forward(self, x):
228 | x = self.body(x)
229 | x = self.fpn(x)
230 | return x
231 |
232 |
233 | def resnet50_fpn_backbone(pretrain_path="",
234 | norm_layer=FrozenBatchNorm2d,  # FrozenBatchNorm2d behaves like BatchNorm2d, but its parameters cannot be updated
235 | trainable_layers=3,
236 | returned_layers=None,
237 | extra_blocks=None):
238 | """
239 | 搭建resnet50_fpn——backbone
240 | Args:
241 | pretrain_path: resnet50的预训练权重,如果不使用就默认为空
242 | norm_layer: 官方默认的是FrozenBatchNorm2d,即不会更新参数的bn层(因为如果batch_size设置的很小会导致效果更差,还不如不用bn层)
243 | 如果自己的GPU显存很大可以设置很大的batch_size,那么自己可以传入正常的BatchNorm2d层
244 | (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
245 | trainable_layers: 指定训练哪些层结构
246 | returned_layers: 指定哪些层的输出需要返回
247 | extra_blocks: 在输出的特征层基础上额外添加的层结构
248 |
249 | Returns:
250 |
251 | """
252 | resnet_backbone = ResNet(Bottleneck, [3, 4, 6, 3],
253 | include_top=False,
254 | norm_layer=norm_layer)
255 |
256 | if norm_layer is FrozenBatchNorm2d:
257 | overwrite_eps(resnet_backbone, 0.0)
258 |
259 | if pretrain_path != "":
260 | assert os.path.exists(pretrain_path), "{} does not exist.".format(pretrain_path)
261 | # load pretrained weights
262 | print(resnet_backbone.load_state_dict(torch.load(pretrain_path), strict=False))
263 |
264 | # select layers that won't be frozen
265 | assert 0 <= trainable_layers <= 5
266 | layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
267 |
268 | # if you want to train all layers, don't forget that conv1 is followed by bn1
269 | if trainable_layers == 5:
270 | layers_to_train.append("bn1")
271 |
272 | # freeze layers
273 | for name, parameter in resnet_backbone.named_parameters():
274 | # freeze every parameter that does not belong to a layer in layers_to_train
275 | if all([not name.startswith(layer) for layer in layers_to_train]):
276 | parameter.requires_grad_(False)
277 |
278 | if extra_blocks is None:
279 | extra_blocks = LastLevelMaxPool()
280 |
281 | if returned_layers is None:
282 | returned_layers = [1, 2, 3, 4]
283 | # the returned layer indices must be greater than 0 and less than 5
284 | assert min(returned_layers) > 0 and max(returned_layers) < 5
285 |
286 | # return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
287 | return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}
288 |
289 | # in_channel is the channel count of layer4's output feature map = 2048
290 | in_channels_stage2 = resnet_backbone.in_channel // 8  # 256
291 | # record the channels of each feature map that resnet50 provides to the fpn
292 | in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
293 | # channels of each feature map after the fpn
294 | out_channels = 256
295 | return BackboneWithFPN(resnet_backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
296 |
--------------------------------------------------------------------------------
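A minimal smoke-test sketch (run from the repository root, without pretrained weights) showing the pyramid levels this backbone produces:

```python
import torch
from backbone import resnet50_fpn_backbone

backbone = resnet50_fpn_backbone(pretrain_path="")   # no pretrained weights for this check
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

for name, feat in feats.items():
    print(name, tuple(feat.shape))   # expected keys: '0', '1', '2', '3', 'pool', all with 256 channels
```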
/backbone/vgg_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class VGG(nn.Module):
6 | def __init__(self, features, class_num=1000, init_weights=False, weights_path=None):
7 | super(VGG, self).__init__()
8 | self.features = features
9 | self.classifier = nn.Sequential(
10 | nn.Linear(512*7*7, 4096),
11 | nn.ReLU(True),
12 | nn.Dropout(p=0.5),
13 | nn.Linear(4096, 4096),
14 | nn.ReLU(True),
15 | nn.Dropout(p=0.5),
16 | nn.Linear(4096, class_num)
17 | )
18 | if init_weights and weights_path is None:
19 | self._initialize_weights()
20 |
21 | if weights_path is not None:
22 | self.load_state_dict(torch.load(weights_path))
23 |
24 | def forward(self, x):
25 | # N x 3 x 224 x 224
26 | x = self.features(x)
27 | # N x 512 x 7 x 7
28 | x = torch.flatten(x, start_dim=1)
29 | # N x 512*7*7
30 | x = self.classifier(x)
31 | return x
32 |
33 | def _initialize_weights(self):
34 | for m in self.modules():
35 | if isinstance(m, nn.Conv2d):
36 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37 | nn.init.xavier_uniform_(m.weight)
38 | if m.bias is not None:
39 | nn.init.constant_(m.bias, 0)
40 | elif isinstance(m, nn.Linear):
41 | nn.init.xavier_uniform_(m.weight)
42 | # nn.init.normal_(m.weight, 0, 0.01)
43 | nn.init.constant_(m.bias, 0)
44 |
45 |
46 | def make_features(cfg: list):
47 | layers = []
48 | in_channels = 3
49 | for v in cfg:
50 | if v == "M":
51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
52 | else:
53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
54 | layers += [conv2d, nn.ReLU(True)]
55 | in_channels = v
56 | return nn.Sequential(*layers)
57 |
58 |
59 | cfgs = {
60 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
61 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
62 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
63 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
64 | }
65 |
66 |
67 | def vgg(model_name="vgg16", weights_path=None):
68 | assert model_name in cfgs, "Warning: model name {} not in cfgs dict!".format(model_name)
69 | cfg = cfgs[model_name]
70 |
71 | model = VGG(make_features(cfg), weights_path=weights_path)
72 | return model
73 |
--------------------------------------------------------------------------------
/coco91_indices.json:
--------------------------------------------------------------------------------
1 | {
2 | "1": "person",
3 | "2": "bicycle",
4 | "3": "car",
5 | "4": "motorcycle",
6 | "5": "airplane",
7 | "6": "bus",
8 | "7": "train",
9 | "8": "truck",
10 | "9": "boat",
11 | "10": "traffic light",
12 | "11": "fire hydrant",
13 | "12": "N/A",
14 | "13": "stop sign",
15 | "14": "parking meter",
16 | "15": "bench",
17 | "16": "bird",
18 | "17": "cat",
19 | "18": "dog",
20 | "19": "horse",
21 | "20": "sheep",
22 | "21": "cow",
23 | "22": "elephant",
24 | "23": "bear",
25 | "24": "zebra",
26 | "25": "giraffe",
27 | "26": "N/A",
28 | "27": "backpack",
29 | "28": "umbrella",
30 | "29": "N/A",
31 | "30": "N/A",
32 | "31": "handbag",
33 | "32": "tie",
34 | "33": "suitcase",
35 | "34": "frisbee",
36 | "35": "skis",
37 | "36": "snowboard",
38 | "37": "sports ball",
39 | "38": "kite",
40 | "39": "baseball bat",
41 | "40": "baseball glove",
42 | "41": "skateboard",
43 | "42": "surfboard",
44 | "43": "tennis racket",
45 | "44": "bottle",
46 | "45": "N/A",
47 | "46": "wine glass",
48 | "47": "cup",
49 | "48": "fork",
50 | "49": "knife",
51 | "50": "spoon",
52 | "51": "bowl",
53 | "52": "banana",
54 | "53": "apple",
55 | "54": "sandwich",
56 | "55": "orange",
57 | "56": "broccoli",
58 | "57": "carrot",
59 | "58": "hot dog",
60 | "59": "pizza",
61 | "60": "donut",
62 | "61": "cake",
63 | "62": "chair",
64 | "63": "couch",
65 | "64": "potted plant",
66 | "65": "bed",
67 | "66": "N/A",
68 | "67": "dining table",
69 | "68": "N/A",
70 | "69": "N/A",
71 | "70": "toilet",
72 | "71": "N/A",
73 | "72": "tv",
74 | "73": "laptop",
75 | "74": "mouse",
76 | "75": "remote",
77 | "76": "keyboard",
78 | "77": "cell phone",
79 | "78": "microwave",
80 | "79": "oven",
81 | "80": "toaster",
82 | "81": "sink",
83 | "82": "refrigerator",
84 | "83": "N/A",
85 | "84": "book",
86 | "85": "clock",
87 | "86": "vase",
88 | "87": "scissors",
89 | "88": "teddy bear",
90 | "89": "hair drier",
91 | "90": "toothbrush"
92 | }
--------------------------------------------------------------------------------
/coco_to_voc.py:
--------------------------------------------------------------------------------
1 | '''
2 | Convert all annotations of the coco dataset to voc format, without changing the image file names.
3 | Note: some of the original images are grayscale; images detected as non-RGB are not copied to the new folder.
4 | Last updated: 2019-11-19
5 | '''
6 | # this package can be downloaded from https://github.com/cocodataset/cocoapi/tree/master/PythonAPI, or you can use a modified coco.py directly
7 | from pycocotools.coco import COCO
8 | import os, cv2, shutil
9 | from lxml import etree, objectify
10 | from tqdm import tqdm
11 | from PIL import Image
12 |
13 | CKimg_dir = './coco2017_voc/images'
14 | CKanno_dir = './coco2017_voc/annotations'
15 |
16 |
17 | # if the output folder does not exist, create it; if it already exists, delete it and recreate it
18 | def mkr(path):
19 | if os.path.exists(path):
20 | shutil.rmtree(path)
21 | os.mkdir(path)
22 | else:
23 | os.mkdir(path)
24 |
25 |
26 | def save_annotations(filename, objs, filepath):
27 | annopath = CKanno_dir + "/" + filename[:-3] + "xml"  # path where the generated xml file is saved
28 | dst_path = CKimg_dir + "/" + filename
29 | img_path = filepath
30 | img = cv2.imread(img_path)
31 | im = Image.open(img_path)
32 | if im.mode != "RGB":
33 | print(filename + " not a RGB image")
34 | im.close()
35 | return
36 | im.close()
37 | shutil.copy(img_path, dst_path)  # copy the original image to the target folder
38 | E = objectify.ElementMaker(annotate=False)
39 | anno_tree = E.annotation(
40 | E.folder('1'),
41 | E.filename(filename),
42 | E.source(
43 | E.database('CKdemo'),
44 | E.annotation('VOC'),
45 | E.image('CK')
46 | ),
47 | E.size(
48 | E.width(img.shape[1]),
49 | E.height(img.shape[0]),
50 | E.depth(img.shape[2])
51 | ),
52 | E.segmented(0)
53 | )
54 | for obj in objs:
55 | E2 = objectify.ElementMaker(annotate=False)
56 | anno_tree2 = E2.object(
57 | E.name(obj[0]),
58 | E.pose(),
59 | E.truncated("0"),
60 | E.difficult(0),
61 | E.bndbox(
62 | E.xmin(obj[2]),
63 | E.ymin(obj[3]),
64 | E.xmax(obj[4]),
65 | E.ymax(obj[5])
66 | )
67 | )
68 | anno_tree.append(anno_tree2)
69 | etree.ElementTree(anno_tree).write(annopath, pretty_print=True)
70 |
71 |
72 | def showbycv(coco, dataType, img, classes, origin_image_dir, verbose=False):
73 | filename = img['file_name']
74 | filepath = os.path.join(origin_image_dir, dataType, filename)
75 | I = cv2.imread(filepath)
76 | annIds = coco.getAnnIds(imgIds=img['id'], iscrowd=None)
77 | anns = coco.loadAnns(annIds)
78 | objs = []
79 | for ann in anns:
80 | name = classes[ann['category_id']]
81 | if 'bbox' in ann:
82 | bbox = ann['bbox']
83 | xmin = (int)(bbox[0])
84 | ymin = (int)(bbox[1])
85 | xmax = (int)(bbox[2] + bbox[0])
86 | ymax = (int)(bbox[3] + bbox[1])
87 | obj = [name, 1.0, xmin, ymin, xmax, ymax]
88 | objs.append(obj)
89 | if verbose:
90 | cv2.rectangle(I, (xmin, ymin), (xmax, ymax), (255, 0, 0))
91 | cv2.putText(I, name, (xmin, ymin), 3, 1, (0, 0, 255))
92 | save_annotations(filename, objs, filepath)
93 | if verbose:
94 | cv2.imshow("img", I)
95 | cv2.waitKey(0)
96 |
97 |
98 | def catid2name(coco):  # build a dict mapping category ids to names
99 | classes = dict()
100 | for cat in coco.dataset['categories']:
101 | classes[cat['id']] = cat['name']
102 | # print(str(cat['id'])+":"+cat['name'])
103 | return classes
104 |
105 |
106 | def get_CK5(origin_anno_dir, origin_image_dir, verbose=False):
107 | dataTypes = ['val2017']
108 | for dataType in dataTypes:
109 | annFile = 'instances_{}.json'.format(dataType)
110 | annpath = os.path.join(origin_anno_dir, annFile)
111 | coco = COCO(annpath)
112 | classes = catid2name(coco)
113 | imgIds = coco.getImgIds()
114 | # imgIds = imgIds[0:1000]  # for testing: take a subset of images and check the saved results
115 | for imgId in tqdm(imgIds):
116 | img = coco.loadImgs(imgId)[0]
117 | showbycv(coco, dataType, img, classes, origin_image_dir, verbose=verbose)
118 |
119 |
120 | def main():
121 | base_dir = './coco2017_voc'  # step 1: a new folder that holds the converted images and annotations
122 | image_dir = os.path.join(base_dir, 'images')  # the images and annotations subfolders are created inside the folder above
123 | anno_dir = os.path.join(base_dir, 'annotations')
124 | mkr(image_dir)
125 | mkr(anno_dir)
126 | origin_image_dir = './coco2017'  # step 2: location of the original coco images
127 | origin_anno_dir = './coco2017/annotations'  # step 3: location of the original coco annotations
128 | print(origin_anno_dir)
129 | verbose = True  # switch for visually checking the annotations; if True, the boxes are drawn on the images
130 | get_CK5(origin_anno_dir, origin_image_dir, verbose)
131 |
132 |
133 | if __name__ == "__main__":
134 | main()
135 |
136 | # split_traintest()
--------------------------------------------------------------------------------
/draw_box_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import PIL.ImageDraw as ImageDraw
3 | import PIL.ImageFont as ImageFont
4 | import numpy as np
5 |
6 | STANDARD_COLORS = [
7 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
8 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
9 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
10 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
11 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
12 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
13 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
14 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
15 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
16 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
17 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
18 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
19 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
20 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
21 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
22 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
23 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
24 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
25 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
26 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
27 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
28 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
29 | 'WhiteSmoke', 'Yellow', 'YellowGreen'
30 | ]
31 |
32 |
33 | def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
34 | for i in range(boxes.shape[0]):
35 | if scores[i] > thresh:
36 | box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
37 | if classes[i] in category_index.keys():
38 | class_name = category_index[classes[i]]
39 | else:
40 | class_name = 'N/A'
41 | display_str = str(class_name)
42 | display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
43 | box_to_display_str_map[box].append(display_str)
44 | box_to_color_map[box] = STANDARD_COLORS[
45 | classes[i] % len(STANDARD_COLORS)]
46 | else:
47 | break  # the output scores are already sorted, so once one score fails the threshold the rest will fail too
48 |
49 |
50 | def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
51 | try:
52 | font = ImageFont.truetype('arial.ttf', 24)
53 | except IOError:
54 | font = ImageFont.load_default()
55 |
56 | # If the total height of the display strings added to the top of the bounding
57 | # box exceeds the top of the image, stack the strings below the bounding box
58 | # instead of above.
59 | display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
60 | # Each display_str has a top and bottom margin of 0.05x.
61 | total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
62 |
63 | if top > total_display_str_height:
64 | text_bottom = top
65 | else:
66 | text_bottom = bottom + total_display_str_height
67 | # Reverse list and print from bottom to top.
68 | for display_str in box_to_display_str_map[box][::-1]:
69 | text_width, text_height = font.getsize(display_str)
70 | margin = np.ceil(0.05 * text_height)
71 | draw.rectangle([(left, text_bottom - text_height - 2 * margin),
72 | (left + text_width, text_bottom)], fill=color)
73 | draw.text((left + margin, text_bottom - text_height - margin),
74 | display_str,
75 | fill='black',
76 | font=font)
77 | text_bottom -= text_height - 2 * margin
78 |
79 |
80 | def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
81 | box_to_display_str_map = collections.defaultdict(list)
82 | box_to_color_map = collections.defaultdict(str)
83 |
84 | filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
85 |
86 | # Draw all boxes onto image.
87 | draw = ImageDraw.Draw(image)
88 | im_width, im_height = image.size
89 | for box, color in box_to_color_map.items():
90 | xmin, ymin, xmax, ymax = box
91 | (left, right, top, bottom) = (xmin * 1, xmax * 1,
92 | ymin * 1, ymax * 1)
93 | draw.line([(left, top), (left, bottom), (right, bottom),
94 | (right, top), (left, top)], width=line_thickness, fill=color)
95 | draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
96 |
--------------------------------------------------------------------------------
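A small usage sketch for draw_box (the box, score and category_index below are made-up values for illustration; rendering relies on the Pillow APIs used above):

```python
import numpy as np
from PIL import Image
from draw_box_utils import draw_box

image = Image.new("RGB", (640, 480), (255, 255, 255))       # blank canvas
boxes = np.array([[50., 60., 300., 400.]])                   # xmin, ymin, xmax, ymax
classes = np.array([1])                                      # keys into category_index
scores = np.array([0.92])
category_index = {1: "person"}                               # e.g. loaded from coco91_indices.json

draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=4)
image.save("demo_box.png")
```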
/loss_and_lr20220612-095042.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RooKichenn/AC-FPN/d61107cc69b1a669738b14d846779c1de564e3a9/loss_and_lr20220612-095042.png
--------------------------------------------------------------------------------
/mAP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RooKichenn/AC-FPN/d61107cc69b1a669738b14d846779c1de564e3a9/mAP.png
--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | import torch
5 | from PIL import Image
6 | import torch.utils.data as data
7 | from pycocotools.coco import COCO
8 |
9 |
10 | def _coco_remove_images_without_annotations(dataset, ids):
11 | """
12 | Remove images in the coco dataset that have no objects, or whose objects have extremely small area
13 | refer to:
14 | https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
15 | :param dataset:
16 | :param cat_list:
17 | :return:
18 | """
19 | def _has_only_empty_bbox(anno):
20 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
21 |
22 | def _has_valid_annotation(anno):
23 | # if it's empty, there is no annotation
24 | if len(anno) == 0:
25 | return False
26 | # if all boxes have close to zero area, there is no annotation
27 | if _has_only_empty_bbox(anno):
28 | return False
29 |
30 | return True
31 |
32 | valid_ids = []
33 | for ds_idx, img_id in enumerate(ids):
34 | ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
35 | anno = dataset.loadAnns(ann_ids)
36 |
37 | if _has_valid_annotation(anno):
38 | valid_ids.append(img_id)
39 |
40 | return valid_ids
41 |
42 |
43 | class CocoDetection(data.Dataset):
44 | """`MS Coco Detection `_ Dataset.
45 | Args:
46 | root (string): Root directory where images are downloaded to.
47 | annFile (string): Path to json annotation file.
48 | transforms (callable, optional): A function/transform that takes input sample and its target as entry
49 | and returns a transformed version.
50 | """
51 |
52 | def __init__(self, root, dataset="train", transforms=None):
53 | super(CocoDetection, self).__init__()
54 | assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
55 | anno_file = "instances_{}2017.json".format(dataset)
56 | assert os.path.exists(root), "file '{}' does not exist.".format(root)
57 | self.img_root = os.path.join(root, "{}2017".format(dataset))
58 | assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
59 | self.anno_path = os.path.join(root, "annotations", anno_file)
60 | assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
61 |
62 | self.mode = dataset
63 | self.transforms = transforms
64 | self.coco = COCO(self.anno_path)
65 |
66 | # get the mapping between coco category indices and class names
67 | # note: the indices of the 80 object classes are not contiguous; although there are only 80 classes, they are still indexed following the 91-category (stuff91) numbering
68 | data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
69 | max_index = max(data_classes.keys()) # 90
70 | # set missing class names to N/A
71 | coco_classes = {}
72 | for k in range(1, max_index + 1):
73 | if k in data_classes:
74 | coco_classes[k] = data_classes[k]
75 | else:
76 | coco_classes[k] = "N/A"
77 |
78 | if dataset == "train":
79 | json_str = json.dumps(coco_classes, indent=4)
80 | with open("coco91_indices.json", "w") as f:
81 | f.write(json_str)
82 |
83 | self.coco_classes = coco_classes
84 |
85 | ids = list(sorted(self.coco.imgs.keys()))
86 | if dataset == "train":
87 | # 移除没有目标,或者目标面积非常小的数据
88 | valid_ids = _coco_remove_images_without_annotations(self.coco, ids)
89 | self.ids = valid_ids
90 | else:
91 | self.ids = ids
92 |
93 | def parse_targets(self,
94 | img_id: int,
95 | coco_targets: list,
96 | w: int = None,
97 | h: int = None):
98 | assert w > 0
99 | assert h > 0
100 |
101 | # 只筛选出单个对象的情况
102 | anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
103 |
104 | boxes = [obj["bbox"] for obj in anno]
105 |
106 | # guard against no boxes via resizing
107 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
108 | # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
109 | boxes[:, 2:] += boxes[:, :2]
110 | boxes[:, 0::2].clamp_(min=0, max=w)
111 | boxes[:, 1::2].clamp_(min=0, max=h)
112 |
113 | classes = [obj["category_id"] for obj in anno]
114 | classes = torch.tensor(classes, dtype=torch.int64)
115 |
116 | area = torch.tensor([obj["area"] for obj in anno])
117 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
118 |
119 | # 筛选出合法的目标,即x_max>x_min且y_max>y_min
120 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
121 | boxes = boxes[keep]
122 | classes = classes[keep]
123 | area = area[keep]
124 | iscrowd = iscrowd[keep]
125 |
126 | target = {}
127 | target["boxes"] = boxes
128 | target["labels"] = classes
129 | target["image_id"] = torch.tensor([img_id])
130 |
131 | # for conversion to coco api
132 | target["area"] = area
133 | target["iscrowd"] = iscrowd
134 |
135 | return target
136 |
137 | def __getitem__(self, index):
138 | """
139 | Args:
140 | index (int): Index
141 | Returns:
142 |             tuple: Tuple (image, target). target is the parsed annotation dict (boxes, labels, image_id, area, iscrowd).
143 | """
144 | coco = self.coco
145 | img_id = self.ids[index]
146 | ann_ids = coco.getAnnIds(imgIds=img_id)
147 | coco_target = coco.loadAnns(ann_ids)
148 |
149 | path = coco.loadImgs(img_id)[0]['file_name']
150 | img = Image.open(os.path.join(self.img_root, path)).convert('RGB')
151 |
152 | w, h = img.size
153 | target = self.parse_targets(img_id, coco_target, w, h)
154 | if self.transforms is not None:
155 | img, target = self.transforms(img, target)
156 |
157 | return img, target
158 |
159 | def __len__(self):
160 | return len(self.ids)
161 |
162 | def get_height_and_width(self, index):
163 | coco = self.coco
164 | img_id = self.ids[index]
165 |
166 | img_info = coco.loadImgs(img_id)[0]
167 | w = img_info["width"]
168 | h = img_info["height"]
169 | return h, w
170 |
171 | @staticmethod
172 | def collate_fn(batch):
173 | return tuple(zip(*batch))
174 |
175 |
176 | # train = CocoDetection("/data/coco_data/", dataset="train")
177 | # print(len(train))
178 | # t = train[0]
179 | # print(t)
--------------------------------------------------------------------------------
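The `CocoDetection` class above follows the standard PyTorch `Dataset` protocol; its static `collate_fn` keeps variable-sized images and per-image target dicts as tuples rather than stacking them into a single tensor. A minimal usage sketch (the data path is a placeholder and no transforms are passed, so images come back as PIL objects):

```python
from torch.utils.data import DataLoader

from my_dataset import CocoDetection

# Placeholder path: point it at the directory that contains train2017/ and annotations/
train_dataset = CocoDetection("/data/coco2017", dataset="train")

train_loader = DataLoader(train_dataset,
                          batch_size=2,
                          shuffle=True,
                          num_workers=4,
                          collate_fn=CocoDetection.collate_fn)  # tuples, no stacking

images, targets = next(iter(train_loader))
print(len(images), targets[0]["boxes"].shape, targets[0]["labels"])
```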
/network_files/__init__.py:
--------------------------------------------------------------------------------
1 | from .faster_rcnn_framework import FasterRCNN, FastRCNNPredictor
2 | from .rpn_function import AnchorsGenerator
3 | from .cawb import CosineAnnealingWarmbootingLR
--------------------------------------------------------------------------------
/network_files/boxes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple
3 | from torch import Tensor
4 | import torchvision
5 |
6 |
7 | def nms(boxes, scores, iou_threshold):
8 | # type: (Tensor, Tensor, float) -> Tensor
9 | """
10 | Performs non-maximum suppression (NMS) on the boxes according
11 | to their intersection-over-union (IoU).
12 |
13 | NMS iteratively removes lower scoring boxes which have an
14 | IoU greater than iou_threshold with another (higher scoring)
15 | box.
16 |
17 | Parameters
18 | ----------
19 |     boxes : Tensor[N, 4]
20 | boxes to perform NMS on. They
21 | are expected to be in (x1, y1, x2, y2) format
22 | scores : Tensor[N]
23 | scores for each one of the boxes
24 | iou_threshold : float
25 | discards all overlapping
26 |         boxes with IoU > iou_threshold
27 |
28 | Returns
29 | -------
30 | keep : Tensor
31 | int64 tensor with the indices
32 | of the elements that have been kept
33 | by NMS, sorted in decreasing order of scores
34 | """
35 | return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
36 |
37 |
38 | def batched_nms(boxes, scores, idxs, iou_threshold):
39 | # type: (Tensor, Tensor, Tensor, float) -> Tensor
40 | """
41 | Performs non-maximum suppression in a batched fashion.
42 |
43 |     Each index value corresponds to a category, and NMS
44 | will not be applied between elements of different categories.
45 |
46 | Parameters
47 | ----------
48 | boxes : Tensor[N, 4]
49 | boxes where NMS will be performed. They
50 | are expected to be in (x1, y1, x2, y2) format
51 | scores : Tensor[N]
52 | scores for each one of the boxes
53 | idxs : Tensor[N]
54 | indices of the categories for each one of the boxes.
55 | iou_threshold : float
56 | discards all overlapping boxes
57 |         with IoU > iou_threshold
58 |
59 | Returns
60 | -------
61 | keep : Tensor
62 | int64 tensor with the indices of
63 | the elements that have been kept by NMS, sorted
64 | in decreasing order of scores
65 | """
66 | if boxes.numel() == 0:
67 | return torch.empty((0,), dtype=torch.int64, device=boxes.device)
68 |
69 |     # strategy: in order to perform NMS independently per class,
70 | # we add an offset to all the boxes. The offset is dependent
71 | # only on the class idx, and is large enough so that boxes
72 | # from different classes do not overlap
73 | # 获取所有boxes中最大的坐标值(xmin, ymin, xmax, ymax)
74 | max_coordinate = boxes.max()
75 |
76 | # to(): Performs Tensor dtype and/or device conversion
77 | # 为每一个类别/每一层生成一个很大的偏移量
78 | # 这里的to只是让生成tensor的dytpe和device与boxes保持一致
79 | offsets = idxs.to(boxes) * (max_coordinate + 1)
80 | # boxes加上对应层的偏移量后,保证不同类别/层之间boxes不会有重合的现象
81 | boxes_for_nms = boxes + offsets[:, None]
82 | keep = nms(boxes_for_nms, scores, iou_threshold)
83 | return keep
84 |
85 |
86 | def remove_small_boxes(boxes, min_size):
87 | # type: (Tensor, float) -> Tensor
88 | """
89 |     Remove boxes which contain at least one side smaller than min_size.
90 | 移除宽高小于指定阈值的索引
91 | Arguments:
92 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
93 | min_size (float): minimum size
94 |
95 | Returns:
96 | keep (Tensor[K]): indices of the boxes that have both sides
97 | larger than min_size
98 | """
99 | ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] # 预测boxes的宽和高
100 | # keep = (ws >= min_size) & (hs >= min_size) # 当满足宽,高都大于给定阈值时为True
101 | keep = torch.logical_and(torch.ge(ws, min_size), torch.ge(hs, min_size))
102 | # nonzero(): Returns a tensor containing the indices of all non-zero elements of input
103 | # keep = keep.nonzero().squeeze(1)
104 | keep = torch.where(keep)[0]
105 | return keep
106 |
107 |
108 | def clip_boxes_to_image(boxes, size):
109 | # type: (Tensor, Tuple[int, int]) -> Tensor
110 | """
111 | Clip boxes so that they lie inside an image of size `size`.
112 | 裁剪预测的boxes信息,将越界的坐标调整到图片边界上
113 |
114 | Arguments:
115 | boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
116 | size (Tuple[height, width]): size of the image
117 |
118 | Returns:
119 | clipped_boxes (Tensor[N, 4])
120 | """
121 | dim = boxes.dim()
122 | boxes_x = boxes[..., 0::2] # x1, x2
123 | boxes_y = boxes[..., 1::2] # y1, y2
124 | height, width = size
125 |
126 | if torchvision._is_tracing():
127 | boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
128 | boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
129 | boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
130 | boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
131 | else:
132 | boxes_x = boxes_x.clamp(min=0, max=width) # 限制x坐标范围在[0,width]之间
133 | boxes_y = boxes_y.clamp(min=0, max=height) # 限制y坐标范围在[0,height]之间
134 |
135 | clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
136 | return clipped_boxes.reshape(boxes.shape)
137 |
138 |
139 | def box_area(boxes):
140 | """
141 | Computes the area of a set of bounding boxes, which are specified by its
142 | (x1, y1, x2, y2) coordinates.
143 |
144 | Arguments:
145 | boxes (Tensor[N, 4]): boxes for which the area will be computed. They
146 | are expected to be in (x1, y1, x2, y2) format
147 |
148 | Returns:
149 | area (Tensor[N]): area for each box
150 | """
151 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
152 |
153 |
154 | def box_iou(boxes1, boxes2):
155 | """
156 | Return intersection-over-union (Jaccard index) of boxes.
157 |
158 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
159 |
160 | Arguments:
161 | boxes1 (Tensor[N, 4])
162 | boxes2 (Tensor[M, 4])
163 |
164 | Returns:
165 | iou (Tensor[N, M]): the NxM matrix containing the pairwise
166 | IoU values for every element in boxes1 and boxes2
167 | """
168 | area1 = box_area(boxes1)
169 | area2 = box_area(boxes2)
170 |
171 | # When the shapes do not match,
172 | # the shape of the returned output tensor follows the broadcasting rules
173 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # left-top [N,M,2]
174 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # right-bottom [N,M,2]
175 |
176 | wh = (rb - lt).clamp(min=0) # [N,M,2]
177 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
178 |
179 | iou = inter / (area1[:, None] + area2 - inter)
180 | return iou
181 |
182 |
--------------------------------------------------------------------------------
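A small sanity-check sketch for the two helpers above that are easiest to misread: `box_iou` broadcasting and the per-class offset trick in `batched_nms`. The boxes, scores and labels below are made up for illustration:

```python
import torch

from network_files.boxes import box_iou, batched_nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [0., 0., 10., 10.]])
scores = torch.tensor([0.9, 0.8, 0.7])
labels = torch.tensor([0, 0, 1])  # the third box belongs to a different class

# IoU of box 0 vs box 1: inter = 9 * 9 = 81, union = 100 + 100 - 81 = 119 -> ~0.68
print(box_iou(boxes[:1], boxes[1:2]))

# Plain NMS would drop box 2 as a duplicate of box 0; batched_nms keeps it because the
# class-dependent offset moves each class into a disjoint coordinate range first.
keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): box 1 suppressed by box 0, box 2 kept
```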
/network_files/cawb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Sep 6 19:10:49 2021
4 |
5 | @author: hdb
6 | """
7 |
8 | import torch.optim as optim
9 | import torch
10 | import torch.nn as nn
11 | import argparse
12 | import math
13 | from copy import copy
14 | import matplotlib.pyplot as plt
15 |
16 |
17 | class CosineAnnealingWarmbootingLR:
18 | # cawb learning rate scheduler: given the warm booting steps, calculate the learning rate automatically
19 |
20 | def __init__(self, optimizer, epochs=0, eta_min=0.05, steps=[], step_scale=0.8, lf=None, batchs=0, warmup_epoch=0, epoch_scale=1.0):
21 | self.warmup_iters = batchs * warmup_epoch
22 | self.optimizer = optimizer
23 | self.eta_min = eta_min
24 | self.iters = -1
25 | self.iters_batch = -1
26 | self.base_lr = [group['lr'] for group in optimizer.param_groups]
27 | self.step_scale = step_scale
28 | steps.sort()
29 | self.steps = [warmup_epoch] + [i for i in steps if (i < epochs and i > warmup_epoch)] + [epochs]
30 | self.gap = 0
31 | self.last_epoch = 0
32 | self.lf = lf
33 | self.epoch_scale = epoch_scale
34 |
35 | # Initialize epochs and base learning rates
36 | for group in optimizer.param_groups:
37 | group.setdefault('initial_lr', group['lr'])
38 |
39 | def step(self, external_iter = None):
40 | self.iters += 1
41 | if external_iter is not None:
42 | self.iters = external_iter
43 |
44 | # cos warm boot policy
45 | iters = self.iters + self.last_epoch
46 | scale = 1.0
47 | for i in range(len(self.steps)-1):
48 | if (iters <= self.steps[i+1]):
49 | self.gap = self.steps[i+1] - self.steps[i]
50 | iters = iters - self.steps[i]
51 |
52 | if i != len(self.steps)-2:
53 | self.gap += self.epoch_scale
54 | break
55 | scale *= self.step_scale
56 |
57 | if self.lf is None:
58 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
59 | group['lr'] = scale * lr * ((((1 + math.cos(iters * math.pi / self.gap)) / 2) ** 1.0) * (1.0 - self.eta_min) + self.eta_min)
60 | else:
61 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
62 | group['lr'] = scale * lr * self.lf(iters, self.gap)
63 |
64 | return self.optimizer.param_groups[0]['lr']
65 |
66 | def step_batch(self):
67 | self.iters_batch += 1
68 |
69 | if self.iters_batch < self.warmup_iters:
70 |
71 | rate = self.iters_batch / self.warmup_iters
72 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
73 | group['lr'] = lr * rate
74 | return self.optimizer.param_groups[0]['lr']
75 | else:
76 | return None
77 |
78 |
79 | def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir='./LR.png'):
80 | # Plot LR simulating training for full epochs
81 | optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
82 | y = []
83 | for _ in range(scheduler.last_epoch):
84 | y.append(None)
85 | for _ in range(scheduler.last_epoch, epochs):
86 | y.append(scheduler.step())
87 |
88 | plt.plot(y, '.-', label='LR')
89 | plt.xlabel('epoch')
90 | plt.ylabel('LR')
91 | plt.grid()
92 | plt.xlim(0, epochs)
93 | plt.ylim(0)
94 | plt.tight_layout()
95 | plt.savefig(save_dir, dpi=200)
96 |
97 |
98 | class model(nn.Module):
99 | def __init__(self):
100 | super().__init__()
101 |
102 | self.conv = nn.Conv2d(3,3,3)
103 |
104 | def forward(self, x):
105 | return self.conv(x)
106 |
107 |
108 | def train(opt):
109 |
110 | net = model()
111 | data = [1] * 50
112 |
113 | optimizer = optim.Adam(net.parameters(), lr=0.1)
114 |
115 | lf = lambda x, y=opt.epochs: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.8 + 0.2
116 | # lf = lambda x, y=opt.epochs: (1.0 - (x / y)) * 0.9 + 0.1
117 | scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=opt.epochs, steps=opt.cawb_steps, step_scale=0.7,
118 | lf=lf, batchs=len(data), warmup_epoch=5)
119 | last_epoch = 0
120 | scheduler.last_epoch = last_epoch # if resume from given model
121 | plot_lr_scheduler(optimizer, scheduler, opt.epochs) # 目前不能画出 warmup 的曲线
122 |
123 |
124 | for i in range(opt.epochs):
125 |
126 | for b in range(len(data)):
127 |             lr = scheduler.step_batch()  # before the backward pass
128 | print(lr)
129 | # training
130 | # loss
131 | # backward
132 |
133 |
134 | scheduler.step()
135 |
136 | return 0
137 |
138 |
139 | if __name__ == '__main__':
140 | parser = argparse.ArgumentParser()
141 | # parser.add_argument('--epochs', type=int, default=150)
142 | # parser.add_argument('--scheduler_lr', type=str, default='cawb', help='the learning rate scheduler, cos/cawb')
143 | # parser.add_argument('--cawb_steps', nargs='+', type=int, default=[50, 100, 150], help='the cawb learning rate scheduler steps')
144 | parser.add_argument('--epochs', type=int, default=45)
145 | parser.add_argument('--scheduler_lr', type=str, default='cawb', help='the learning rate scheduler, cos/cawb')
146 | parser.add_argument('--cawb_steps', nargs='+', type=int, default=[15, 30, 45],
147 | help='the cawb learning rate scheduler steps')
148 | opt = parser.parse_args()
149 |
150 | train(opt)
151 |
152 |
--------------------------------------------------------------------------------
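For reference, a sketch of how this scheduler is meant to be driven inside a training loop, mirroring the `train()` demo above: `step_batch()` is called every iteration (it only adjusts the learning rate during warmup) and `step()` once per epoch. The stand-in model, optimizer settings and step points below are illustrative placeholders:

```python
import math

import torch
import torch.optim as optim

from network_files import CosineAnnealingWarmbootingLR

net = torch.nn.Conv2d(3, 3, 3)                    # stand-in model
optimizer = optim.SGD(net.parameters(), lr=0.01)

epochs, iters_per_epoch = 30, 100
lf = lambda x, y: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.8 + 0.2
scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=epochs, steps=[10, 20, 30],
                                         step_scale=0.7, lf=lf,
                                         batchs=iters_per_epoch, warmup_epoch=1)

for epoch in range(epochs):
    for it in range(iters_per_epoch):
        scheduler.step_batch()   # linear warmup; returns None once warmup is over
        # forward / loss / loss.backward() / optimizer.step() would go here
    lr = scheduler.step()        # cosine warm-booting update at the epoch boundary
```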
/network_files/det_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from typing import List, Tuple
4 | from torch import Tensor
5 |
6 |
7 | class BalancedPositiveNegativeSampler(object):
8 | """
9 | This class samples batches, ensuring that they contain a fixed proportion of positives
10 | """
11 |
12 | def __init__(self, batch_size_per_image, positive_fraction):
13 | # type: (int, float) -> None
14 | """
15 | Arguments:
16 | batch_size_per_image (int): number of elements to be selected per image
17 | positive_fraction (float): percentage of positive elements per batch
18 | """
19 | self.batch_size_per_image = batch_size_per_image
20 | self.positive_fraction = positive_fraction
21 |
22 | def __call__(self, matched_idxs):
23 | # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
24 | """
25 | Arguments:
26 | matched idxs: list of tensors containing -1, 0 or positive values.
27 | Each tensor corresponds to a specific image.
28 | -1 values are ignored, 0 are considered as negatives and > 0 as
29 | positives.
30 |
31 | Returns:
32 | pos_idx (list[tensor])
33 | neg_idx (list[tensor])
34 |
35 | Returns two lists of binary masks for each image.
36 | The first list contains the positive elements that were selected,
37 |         and the second list the negative examples.
38 | """
39 | pos_idx = []
40 | neg_idx = []
41 | # 遍历每张图像的matched_idxs
42 | for matched_idxs_per_image in matched_idxs:
43 | # >= 1的为正样本, nonzero返回非零元素索引
44 | # positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
45 | positive = torch.where(torch.ge(matched_idxs_per_image, 1))[0]
46 | # = 0的为负样本
47 | # negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
48 | negative = torch.where(torch.eq(matched_idxs_per_image, 0))[0]
49 |
50 | # 指定正样本的数量
51 | num_pos = int(self.batch_size_per_image * self.positive_fraction)
52 | # protect against not enough positive examples
53 | # 如果正样本数量不够就直接采用所有正样本
54 | num_pos = min(positive.numel(), num_pos)
55 | # 指定负样本数量
56 | num_neg = self.batch_size_per_image - num_pos
57 | # protect against not enough negative examples
58 | # 如果负样本数量不够就直接采用所有负样本
59 | num_neg = min(negative.numel(), num_neg)
60 |
61 | # randomly select positive and negative examples
62 | # Returns a random permutation of integers from 0 to n - 1.
63 |
64 | # -------------------------------------------------------------------------------------------#
65 | # -------------------------------------------------------------------------------------------#
66 | # -------------------------------------------------------------------------------------------#
67 | # 分层采样
68 |
69 | # 首先将positive和negative分为三层
70 | k = 3
71 | # 每层有几个数据
72 | pk = positive.numel() // 3
73 | fk = negative.numel() // 3
74 |
75 | positive01 = positive[0:pk]
76 | positive02 = positive[pk:pk*2]
77 | positive03 = positive[pk*2:]
78 |
79 | negative01 = negative[0:fk]
80 | negative02 = negative[fk:fk*2]
81 | negative03 = negative[fk*2:]
82 |
83 | # 每层采集数据个数
84 | num_pos_k = num_pos // 3
85 | num_neg_k = num_neg // 3
86 | rep01 = positive01[torch.randperm(positive01.numel(), device=positive.device)[:num_pos_k]]
87 | rep02 = positive02[torch.randperm(positive02.numel(), device=positive.device)[:num_pos_k]]
88 | rep03 = positive03[torch.randperm(positive03.numel(), device=positive.device)[:num_pos_k]]
89 |
90 | ref01 = negative01[torch.randperm(negative01.numel(), device=negative.device)[:num_neg_k]]
91 | ref02 = negative02[torch.randperm(negative02.numel(), device=negative.device)[:num_neg_k]]
92 | ref03 = negative03[torch.randperm(negative03.numel(), device=negative.device)[:num_neg_k]]
93 |
94 | pos_idx_per_image = torch.cat((rep01, rep02, rep03))
95 | neg_idx_per_image = torch.cat((ref01, ref02, ref03))
96 | # -------------------------------------------------------------------------------------------#
97 | # -------------------------------------------------------------------------------------------#
98 | # -------------------------------------------------------------------------------------------#
99 |
100 | # create binary mask from indices
101 | pos_idx_per_image_mask = torch.zeros_like(
102 | matched_idxs_per_image, dtype=torch.uint8
103 | )
104 | neg_idx_per_image_mask = torch.zeros_like(
105 | matched_idxs_per_image, dtype=torch.uint8
106 | )
107 |
108 | pos_idx_per_image_mask[pos_idx_per_image] = 1
109 | neg_idx_per_image_mask[neg_idx_per_image] = 1
110 |
111 | pos_idx.append(pos_idx_per_image_mask)
112 | neg_idx.append(neg_idx_per_image_mask)
113 |
114 | return pos_idx, neg_idx
115 |
116 |
117 | @torch.jit._script_if_tracing
118 | def encode_boxes(reference_boxes, proposals, weights):
119 | # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
120 | """
121 | Encode a set of proposals with respect to some
122 | reference boxes
123 |
124 | Arguments:
125 | reference_boxes (Tensor): reference boxes(gt)
126 | proposals (Tensor): boxes to be encoded(anchors)
127 | weights:
128 | """
129 |
130 | # perform some unpacking to make it JIT-fusion friendly
131 | wx = weights[0]
132 | wy = weights[1]
133 | ww = weights[2]
134 | wh = weights[3]
135 |
136 | # unsqueeze()
137 | # Returns a new tensor with a dimension of size one inserted at the specified position.
138 | proposals_x1 = proposals[:, 0].unsqueeze(1)
139 | proposals_y1 = proposals[:, 1].unsqueeze(1)
140 | proposals_x2 = proposals[:, 2].unsqueeze(1)
141 | proposals_y2 = proposals[:, 3].unsqueeze(1)
142 |
143 | reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
144 | reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
145 | reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
146 | reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
147 |
148 | # implementation starts here
149 | # parse widths and heights
150 | ex_widths = proposals_x2 - proposals_x1
151 | ex_heights = proposals_y2 - proposals_y1
152 | # parse coordinate of center point
153 | ex_ctr_x = proposals_x1 + 0.5 * ex_widths
154 | ex_ctr_y = proposals_y1 + 0.5 * ex_heights
155 |
156 | gt_widths = reference_boxes_x2 - reference_boxes_x1
157 | gt_heights = reference_boxes_y2 - reference_boxes_y1
158 | gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
159 | gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
160 |
161 | targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
162 | targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
163 | targets_dw = ww * torch.log(gt_widths / ex_widths)
164 | targets_dh = wh * torch.log(gt_heights / ex_heights)
165 |
166 | targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
167 | return targets
168 |
169 |
170 | class BoxCoder(object):
171 | """
172 | This class encodes and decodes a set of bounding boxes into
173 | the representation used for training the regressors.
174 | """
175 |
176 | def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
177 | # type: (Tuple[float, float, float, float], float) -> None
178 | """
179 | Arguments:
180 | weights (4-element tuple)
181 | bbox_xform_clip (float)
182 | """
183 | self.weights = weights
184 | self.bbox_xform_clip = bbox_xform_clip
185 |
186 | def encode(self, reference_boxes, proposals):
187 | # type: (List[Tensor], List[Tensor]) -> List[Tensor]
188 | """
189 | 结合anchors和与之对应的gt计算regression参数
190 | Args:
191 | reference_boxes: List[Tensor] 每个proposal/anchor对应的gt_boxes
192 | proposals: List[Tensor] anchors/proposals
193 |
194 | Returns: regression parameters
195 |
196 | """
197 | # 统计每张图像的anchors个数,方便后面拼接在一起处理后在分开
198 | # reference_boxes和proposal数据结构相同
199 | boxes_per_image = [len(b) for b in reference_boxes]
200 | reference_boxes = torch.cat(reference_boxes, dim=0)
201 | proposals = torch.cat(proposals, dim=0)
202 |
203 | # targets_dx, targets_dy, targets_dw, targets_dh
204 | targets = self.encode_single(reference_boxes, proposals)
205 | return targets.split(boxes_per_image, 0)
206 |
207 | def encode_single(self, reference_boxes, proposals):
208 | """
209 | Encode a set of proposals with respect to some
210 | reference boxes
211 |
212 | Arguments:
213 | reference_boxes (Tensor): reference boxes
214 | proposals (Tensor): boxes to be encoded
215 | """
216 | dtype = reference_boxes.dtype
217 | device = reference_boxes.device
218 | weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
219 | targets = encode_boxes(reference_boxes, proposals, weights)
220 |
221 | return targets
222 |
223 | def decode(self, rel_codes, boxes):
224 | # type: (Tensor, List[Tensor]) -> Tensor
225 | """
226 |
227 | Args:
228 | rel_codes: bbox regression parameters
229 | boxes: anchors/proposals
230 |
231 | Returns:
232 |
233 | """
234 | assert isinstance(boxes, (list, tuple))
235 | assert isinstance(rel_codes, torch.Tensor)
236 | boxes_per_image = [b.size(0) for b in boxes]
237 | concat_boxes = torch.cat(boxes, dim=0)
238 |
239 | box_sum = 0
240 | for val in boxes_per_image:
241 | box_sum += val
242 |
243 | # 将预测的bbox回归参数应用到对应anchors上得到预测bbox的坐标
244 | pred_boxes = self.decode_single(
245 | rel_codes, concat_boxes
246 | )
247 |
248 | # 防止pred_boxes为空时导致reshape报错
249 | if box_sum > 0:
250 | pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
251 |
252 | return pred_boxes
253 |
254 | def decode_single(self, rel_codes, boxes):
255 | """
256 | From a set of original boxes and encoded relative box offsets,
257 | get the decoded boxes.
258 |
259 | Arguments:
260 | rel_codes (Tensor): encoded boxes (bbox regression parameters)
261 | boxes (Tensor): reference boxes (anchors/proposals)
262 | """
263 | boxes = boxes.to(rel_codes.dtype)
264 |
265 | # xmin, ymin, xmax, ymax
266 | widths = boxes[:, 2] - boxes[:, 0] # anchor/proposal宽度
267 | heights = boxes[:, 3] - boxes[:, 1] # anchor/proposal高度
268 | ctr_x = boxes[:, 0] + 0.5 * widths # anchor/proposal中心x坐标
269 | ctr_y = boxes[:, 1] + 0.5 * heights # anchor/proposal中心y坐标
270 |
271 | wx, wy, ww, wh = self.weights # RPN中为[1,1,1,1], fastrcnn中为[10,10,5,5]
272 | dx = rel_codes[:, 0::4] / wx # 预测anchors/proposals的中心坐标x回归参数
273 | dy = rel_codes[:, 1::4] / wy # 预测anchors/proposals的中心坐标y回归参数
274 | dw = rel_codes[:, 2::4] / ww # 预测anchors/proposals的宽度回归参数
275 | dh = rel_codes[:, 3::4] / wh # 预测anchors/proposals的高度回归参数
276 |
277 | # limit max value, prevent sending too large values into torch.exp()
278 | # self.bbox_xform_clip=math.log(1000. / 16) 4.135
279 | dw = torch.clamp(dw, max=self.bbox_xform_clip)
280 | dh = torch.clamp(dh, max=self.bbox_xform_clip)
281 |
282 | pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
283 | pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
284 | pred_w = torch.exp(dw) * widths[:, None]
285 | pred_h = torch.exp(dh) * heights[:, None]
286 |
287 | # xmin
288 | pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
289 | # ymin
290 | pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
291 | # xmax
292 | pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
293 | # ymax
294 | pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
295 |
296 | pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
297 | return pred_boxes
298 |
299 |
300 | class Matcher(object):
301 | BELOW_LOW_THRESHOLD = -1
302 | BETWEEN_THRESHOLDS = -2
303 |
304 | __annotations__ = {
305 | 'BELOW_LOW_THRESHOLD': int,
306 | 'BETWEEN_THRESHOLDS': int,
307 | }
308 |
309 | def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
310 | # type: (float, float, bool) -> None
311 | """
312 | Args:
313 | high_threshold (float): quality values greater than or equal to
314 | this value are candidate matches.
315 | low_threshold (float): a lower quality threshold used to stratify
316 | matches into three levels:
317 | 1) matches >= high_threshold
318 | 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
319 | 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
320 | allow_low_quality_matches (bool): if True, produce additional matches
321 | for predictions that have only low-quality match candidates. See
322 | set_low_quality_matches_ for more details.
323 | """
324 | self.BELOW_LOW_THRESHOLD = -1
325 | self.BETWEEN_THRESHOLDS = -2
326 | assert low_threshold <= high_threshold
327 | self.high_threshold = high_threshold # 0.7
328 | self.low_threshold = low_threshold # 0.3
329 | self.allow_low_quality_matches = allow_low_quality_matches
330 |
331 | def __call__(self, match_quality_matrix):
332 | """
333 | 计算anchors与每个gtboxes匹配的iou最大值,并记录索引,
334 |         iou<low_threshold索引值为-1, low_threshold<=iou<high_threshold索引值为-2
335 |         Args:
336 |             match_quality_matrix (Tensor[float]): an MxN tensor, containing the
337 |             pairwise quality between M ground-truth elements and N predicted elements.
338 | 
339 |         Returns:
340 |             matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
341 |             [0, M - 1] or a negative value indicating that prediction i could not
342 |             be matched.
343 |         """
344 |         if match_quality_matrix.numel() == 0:
345 |             # empty targets or proposals not supported during training
346 |             if match_quality_matrix.shape[0] == 0:
347 |                 raise ValueError(
348 |                     "No ground-truth boxes available for one of the images "
349 |                     "during training")
350 |             else:
351 |                 raise ValueError(
352 |                     "No proposal boxes available for one of the images "
353 |                     "during training")
354 | 
355 |         # match_quality_matrix is M (gt) x N (predicted)
356 |         # Max over gt elements (dim 0) to find best gt candidate for each prediction
357 |         # M x N 的每一列代表一个anchors与所有gt的匹配iou值
358 |         # matched_vals代表每列的最大值,即每个anchors与所有gt匹配iou的最大值
359 |         # matches对应最大值所在的索引
360 |         matched_vals, matches = match_quality_matrix.max(dim=0)  # the dimension to reduce.
361 |         if self.allow_low_quality_matches:
362 |             all_matches = matches.clone()
363 |         else:
364 |             all_matches = None
365 | 
366 |         # Assign candidate matches with low quality to negative (unassigned) values
367 |         # 计算iou小于low_threshold的索引
368 |         below_low_threshold = matched_vals < self.low_threshold
369 |         # 计算iou在low_threshold与high_threshold之间的索引值
370 |         between_thresholds = (matched_vals >= self.low_threshold) & (
371 | matched_vals < self.high_threshold
372 | )
373 | # iou小于low_threshold的matches索引置为-1
374 | matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD # -1
375 |
376 | # iou在[low_threshold, high_threshold]之间的matches索引置为-2
377 | matches[between_thresholds] = self.BETWEEN_THRESHOLDS # -2
378 |
379 | if self.allow_low_quality_matches:
380 | assert all_matches is not None
381 | self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
382 |
383 | return matches
384 |
385 | def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
386 | """
387 | Produce additional matches for predictions that have only low-quality matches.
388 | Specifically, for each ground-truth find the set of predictions that have
389 | maximum overlap with it (including ties); for each prediction in that set, if
390 | it is unmatched, then match it to the ground-truth with which it has the highest
391 | quality value.
392 | """
393 | # For each gt, find the prediction with which it has highest quality
394 | # 对于每个gt boxes寻找与其iou最大的anchor,
395 | # highest_quality_foreach_gt为匹配到的最大iou值
396 | highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) # the dimension to reduce.
397 |
398 | # Find highest quality match available, even if it is low, including ties
399 | # 寻找每个gt boxes与其iou最大的anchor索引,一个gt匹配到的最大iou可能有多个anchor
400 | # gt_pred_pairs_of_highest_quality = torch.nonzero(
401 | # match_quality_matrix == highest_quality_foreach_gt[:, None]
402 | # )
403 | gt_pred_pairs_of_highest_quality = torch.where(
404 | torch.eq(match_quality_matrix, highest_quality_foreach_gt[:, None])
405 | )
406 | # Example gt_pred_pairs_of_highest_quality:
407 | # tensor([[ 0, 39796],
408 | # [ 1, 32055],
409 | # [ 1, 32070],
410 | # [ 2, 39190],
411 | # [ 2, 40255],
412 | # [ 3, 40390],
413 | # [ 3, 41455],
414 | # [ 4, 45470],
415 | # [ 5, 45325],
416 | # [ 5, 46390]])
417 | # Each row is a (gt index, prediction index)
418 | # Note how gt items 1, 2, 3, and 5 each have two ties
419 |
420 | # gt_pred_pairs_of_highest_quality[:, 0]代表是对应的gt index(不需要)
421 | # pre_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
422 | pre_inds_to_update = gt_pred_pairs_of_highest_quality[1]
423 | # 保留该anchor匹配gt最大iou的索引,即使iou低于设定的阈值
424 | matches[pre_inds_to_update] = all_matches[pre_inds_to_update]
425 |
426 |
427 | def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
428 | """
429 | very similar to the smooth_l1_loss from pytorch, but with
430 | the extra beta parameter
431 | """
432 | n = torch.abs(input - target)
433 | # cond = n < beta
434 | cond = torch.lt(n, beta)
435 | loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
436 | if size_average:
437 | return loss.mean()
438 | return loss.sum()
439 |
--------------------------------------------------------------------------------
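The weights `(10., 10., 5., 5.)` mentioned in `decode_single` are the Fast R-CNN defaults (`RoIHeads` falls back to them when `bbox_reg_weights` is None). A quick round-trip sketch with one made-up anchor/gt pair shows that `decode` applied to the output of `encode` recovers the ground-truth box:

```python
import torch

from network_files.det_utils import BoxCoder

coder = BoxCoder(weights=(10., 10., 5., 5.))    # Fast R-CNN weights; the RPN uses (1., 1., 1., 1.)

anchors = torch.tensor([[10., 10., 50., 50.]])  # one anchor in (x1, y1, x2, y2) format
gt = torch.tensor([[12., 8., 60., 48.]])        # its matched ground-truth box

deltas = coder.encode([gt], [anchors])[0]       # regression targets (dx, dy, dw, dh)
decoded = coder.decode(deltas, [anchors])       # apply the deltas back onto the anchor

print(deltas)
print(decoded.reshape(-1, 4))                   # ~[[12., 8., 60., 48.]]
```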
/network_files/faster_rcnn_framework.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import OrderedDict
3 | from typing import Tuple, List, Dict, Optional, Union
4 |
5 | import torch
6 | from torch import nn, Tensor
7 | import torch.nn.functional as F
8 | from torchvision.ops import MultiScaleRoIAlign
9 |
10 | from .roi_head import RoIHeads
11 | from .transform import GeneralizedRCNNTransform
12 | from .rpn_function import AnchorsGenerator, RPNHead, RegionProposalNetwork
13 |
14 |
15 | class FasterRCNNBase(nn.Module):
16 | """
17 | Main class for Generalized R-CNN.
18 |
19 | Arguments:
20 | backbone (nn.Module):
21 | rpn (nn.Module):
22 | roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
23 | detections / masks from it.
24 | transform (nn.Module): performs the data transformation from the inputs to feed into
25 | the model
26 | """
27 |
28 | def __init__(self, backbone, rpn, roi_heads, transform):
29 | super(FasterRCNNBase, self).__init__()
30 | self.transform = transform
31 | self.backbone = backbone
32 | self.rpn = rpn
33 | self.roi_heads = roi_heads
34 | # used only on torchscript mode
35 | self._has_warned = False
36 |
37 | @torch.jit.unused
38 | def eager_outputs(self, losses, detections):
39 | # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
40 | if self.training:
41 | return losses
42 |
43 | return detections
44 |
45 | def forward(self, images, targets=None):
46 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
47 | """
48 | Arguments:
49 | images (list[Tensor]): images to be processed
50 | targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
51 |
52 | Returns:
53 | result (list[BoxList] or dict[Tensor]): the output from the model.
54 | During training, it returns a dict[Tensor] which contains the losses.
55 | During testing, it returns list[BoxList] contains additional fields
56 | like `scores`, `labels` and `mask` (for Mask R-CNN models).
57 |
58 | """
59 | if self.training and targets is None:
60 | raise ValueError("In training mode, targets should be passed")
61 |
62 | if self.training:
63 | assert targets is not None
64 | for target in targets: # 进一步判断传入的target的boxes参数是否符合规定
65 | boxes = target["boxes"]
66 | if isinstance(boxes, torch.Tensor):
67 | if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
68 |                     raise ValueError("Expected target boxes to be a tensor "
69 | "of shape [N, 4], got {:}.".format(
70 | boxes.shape))
71 | else:
72 | raise ValueError("Expected target boxes to be of type "
73 | "Tensor, got {:}.".format(type(boxes)))
74 |
75 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
76 | for img in images:
77 | val = img.shape[-2:]
78 | assert len(val) == 2 # 防止输入的是个一维向量
79 | original_image_sizes.append((val[0], val[1]))
80 | # original_image_sizes = [img.shape[-2:] for img in images]
81 |
82 | images, targets = self.transform(images, targets) # 对图像进行预处理
83 |
84 | # print(images.tensors.shape)
85 | features = self.backbone(images.tensors) # 将图像输入backbone得到特征图
86 | if isinstance(features, torch.Tensor): # 若只在一层特征层上预测,将feature放入有序字典中,并编号为‘0’
87 | features = OrderedDict([('0', features)]) # 若在多层特征层上预测,传入的就是一个有序字典
88 |
89 | # 将特征层以及标注target信息传入rpn中
90 | # proposals: List[Tensor], Tensor_shape: [num_proposals, 4],
91 | # 每个proposals是绝对坐标,且为(x1, y1, x2, y2)格式
92 | proposals, proposal_losses = self.rpn(images, features, targets)
93 |
94 | # 将rpn生成的数据以及标注target信息传入fast rcnn后半部分
95 | detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
96 |
97 | # 对网络的预测结果进行后处理(主要将bboxes还原到原图像尺度上)
98 | detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
99 |
100 | losses = {}
101 | losses.update(detector_losses)
102 | losses.update(proposal_losses)
103 |
104 | if torch.jit.is_scripting():
105 | if not self._has_warned:
106 | warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
107 | self._has_warned = True
108 | return losses, detections
109 | else:
110 | return self.eager_outputs(losses, detections)
111 |
112 | # if self.training:
113 | # return losses
114 | #
115 | # return detections
116 |
117 |
118 | class TwoMLPHead(nn.Module):
119 | """
120 | Standard heads for FPN-based models
121 |
122 | Arguments:
123 | in_channels (int): number of input channels
124 | representation_size (int): size of the intermediate representation
125 | """
126 |
127 | def __init__(self, in_channels, representation_size):
128 | super(TwoMLPHead, self).__init__()
129 |
130 | self.fc6 = nn.Linear(in_channels, representation_size)
131 | self.fc7 = nn.Linear(representation_size, representation_size)
132 |
133 | def forward(self, x):
134 | x = x.flatten(start_dim=1)
135 |
136 | x = F.relu(self.fc6(x))
137 | x = F.relu(self.fc7(x))
138 |
139 | return x
140 |
141 |
142 | class FastRCNNPredictor(nn.Module):
143 | """
144 | Standard classification + bounding box regression layers
145 | for Fast R-CNN.
146 |
147 | Arguments:
148 | in_channels (int): number of input channels
149 | num_classes (int): number of output classes (including background)
150 | """
151 |
152 | def __init__(self, in_channels, num_classes):
153 | super(FastRCNNPredictor, self).__init__()
154 | self.cls_score = nn.Linear(in_channels, num_classes)
155 | self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
156 |
157 | def forward(self, x):
158 | if x.dim() == 4:
159 | assert list(x.shape[2:]) == [1, 1]
160 | x = x.flatten(start_dim=1)
161 | scores = self.cls_score(x)
162 | bbox_deltas = self.bbox_pred(x)
163 |
164 | return scores, bbox_deltas
165 |
166 |
167 | class FasterRCNN(FasterRCNNBase):
168 | """
169 | Implements Faster R-CNN.
170 |
171 | The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
172 | image, and should be in 0-1 range. Different images can have different sizes.
173 |
174 | The behavior of the model changes depending if it is in training or evaluation mode.
175 |
176 | During training, the model expects both the input tensors, as well as a targets (list of dictionary),
177 | containing:
178 | - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
179 | between 0 and H and 0 and W
180 | - labels (Int64Tensor[N]): the class label for each ground-truth box
181 |
182 | The model returns a Dict[Tensor] during training, containing the classification and regression
183 | losses for both the RPN and the R-CNN.
184 |
185 | During inference, the model requires only the input tensors, and returns the post-processed
186 | predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
187 | follows:
188 | - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
189 | 0 and H and 0 and W
190 | - labels (Int64Tensor[N]): the predicted labels for each image
191 |         - scores (Tensor[N]): the scores of each prediction
192 |
193 | Arguments:
194 | backbone (nn.Module): the network used to compute the features for the model.
195 | It should contain a out_channels attribute, which indicates the number of output
196 | channels that each feature map has (and it should be the same for all feature maps).
197 |             The backbone should return a single Tensor or an OrderedDict[Tensor].
198 | num_classes (int): number of output classes of the model (including the background).
199 | If box_predictor is specified, num_classes should be None.
200 | min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
201 | max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
202 | image_mean (Tuple[float, float, float]): mean values used for input normalization.
203 | They are generally the mean values of the dataset on which the backbone has been trained
204 | on
205 | image_std (Tuple[float, float, float]): std values used for input normalization.
206 | They are generally the std values of the dataset on which the backbone has been trained on
207 | rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
208 | maps.
209 | rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
210 | rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
211 | rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
212 | rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
213 | rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
214 | rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
215 | rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
216 | considered as positive during training of the RPN.
217 | rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
218 | considered as negative during training of the RPN.
219 | rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
220 | for computing the loss
221 | rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
222 | of the RPN
223 | rpn_score_thresh (float): during inference, only return proposals with a classification score
224 | greater than rpn_score_thresh
225 | box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
226 | the locations indicated by the bounding boxes
227 | box_head (nn.Module): module that takes the cropped feature maps as input
228 | box_predictor (nn.Module): module that takes the output of box_head and returns the
229 | classification logits and box regression deltas.
230 | box_score_thresh (float): during inference, only return proposals with a classification score
231 | greater than box_score_thresh
232 | box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
233 | box_detections_per_img (int): maximum number of detections per image, for all classes.
234 | box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
235 | considered as positive during training of the classification head
236 | box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
237 | considered as negative during training of the classification head
238 | box_batch_size_per_image (int): number of proposals that are sampled during training of the
239 | classification head
240 | box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
241 | of the classification head
242 | bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
243 | bounding boxes
244 |
245 | """
246 |
247 | def __init__(self, backbone, num_classes=None,
248 | # transform parameter
249 | min_size=800, max_size=1333, # 预处理resize时限制的最小尺寸与最大尺寸
250 | image_mean=None, image_std=None, # 预处理normalize时使用的均值和方差
251 | # RPN parameters
252 | rpn_anchor_generator=None, rpn_head=None,
253 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, # rpn中在nms处理前保留的proposal数(根据score)
254 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, # rpn中在nms处理后保留的proposal数
255 | rpn_nms_thresh=0.7, # rpn中进行nms处理时使用的iou阈值
256 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, # rpn计算损失时,采集正负样本设置的阈值
257 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, # rpn计算损失时采样的样本数,以及正样本占总样本的比例
258 | rpn_score_thresh=0.0,
259 | # Box parameters
260 | box_roi_pool=None, box_head=None, box_predictor=None,
261 | # 移除低目标概率 fast rcnn中进行nms处理的阈值 对预测结果根据score排序取前100个目标
262 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
263 | box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, # fast rcnn计算误差时,采集正负样本设置的阈值
264 | box_batch_size_per_image=512, box_positive_fraction=0.25, # fast rcnn计算误差时采样的样本数,以及正样本占所有样本的比例
265 | bbox_reg_weights=None):
266 | if not hasattr(backbone, "out_channels"):
267 | raise ValueError(
268 |                 "backbone should contain an attribute out_channels "
269 |                 "specifying the number of output channels (assumed to be the "
270 |                 "same for all the levels)"
271 | )
272 |
273 | assert isinstance(rpn_anchor_generator, (AnchorsGenerator, type(None)))
274 | assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
275 |
276 | if num_classes is not None:
277 | if box_predictor is not None:
278 | raise ValueError("num_classes should be None when box_predictor "
279 | "is specified")
280 | else:
281 | if box_predictor is None:
282 | raise ValueError("num_classes should not be None when box_predictor "
283 | "is not specified")
284 |
285 | # 预测特征层的channels
286 | out_channels = backbone.out_channels
287 |
288 | # 若anchor生成器为空,则自动生成针对resnet50_fpn的anchor生成器
289 | if rpn_anchor_generator is None:
290 | anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
291 | aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
292 | rpn_anchor_generator = AnchorsGenerator(
293 | anchor_sizes, aspect_ratios
294 | )
295 |
296 | # 生成RPN通过滑动窗口预测网络部分
297 | if rpn_head is None:
298 | rpn_head = RPNHead(
299 | out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
300 | )
301 |
302 | # 默认rpn_pre_nms_top_n_train = 2000, rpn_pre_nms_top_n_test = 1000,
303 | # 默认rpn_post_nms_top_n_train = 2000, rpn_post_nms_top_n_test = 1000,
304 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
305 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
306 |
307 | # 定义整个RPN框架
308 | rpn = RegionProposalNetwork(
309 | rpn_anchor_generator, rpn_head,
310 | rpn_fg_iou_thresh, rpn_bg_iou_thresh,
311 | rpn_batch_size_per_image, rpn_positive_fraction,
312 | rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
313 | score_thresh=rpn_score_thresh)
314 |
315 | # Multi-scale RoIAlign pooling
316 | if box_roi_pool is None:
317 | box_roi_pool = MultiScaleRoIAlign(
318 | featmap_names=['0', '1', '2', '3'], # 在哪些特征层进行roi pooling
319 | output_size=[7, 7],
320 | sampling_ratio=2)
321 |
322 | # fast RCNN中roi pooling后的展平处理两个全连接层部分
323 | if box_head is None:
324 | resolution = box_roi_pool.output_size[0] # 默认等于7
325 | representation_size = 1024
326 | box_head = TwoMLPHead(
327 | out_channels * resolution ** 2,
328 | representation_size
329 | )
330 |
331 | # 在box_head的输出上预测部分
332 | if box_predictor is None:
333 | representation_size = 1024
334 | box_predictor = FastRCNNPredictor(
335 | representation_size,
336 | num_classes)
337 |
338 | # 将roi pooling, box_head以及box_predictor结合在一起
339 | roi_heads = RoIHeads(
340 | # box
341 | box_roi_pool, box_head, box_predictor,
342 | box_fg_iou_thresh, box_bg_iou_thresh, # 0.5 0.5
343 | box_batch_size_per_image, box_positive_fraction, # 512 0.25
344 | bbox_reg_weights,
345 | box_score_thresh, box_nms_thresh, box_detections_per_img) # 0.05 0.5 100
346 |
347 | if image_mean is None:
348 | image_mean = [0.485, 0.456, 0.406]
349 | if image_std is None:
350 | image_std = [0.229, 0.224, 0.225]
351 |
352 | # 对数据进行标准化,缩放,打包成batch等处理部分
353 | transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
354 |
355 | super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
356 |
--------------------------------------------------------------------------------
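A hedged end-to-end sketch of building a detector from the `FasterRCNN` class above. The backbone factory name `resnet50_fpn_backbone` is an assumption about what `backbone/resnet50_fpn_model.py` exposes (check the actual export in `backbone/__init__.py`); everything else relies on the defaults documented in the constructor:

```python
import torch

from backbone import resnet50_fpn_backbone   # assumed export; verify in backbone/__init__.py
from network_files import FasterRCNN

# FasterRCNN requires backbone.out_channels (256 for an FPN-style backbone).
backbone = resnet50_fpn_backbone()
model = FasterRCNN(backbone=backbone, num_classes=91)  # 90 COCO classes + background

model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 600, 800)])     # list of CHW tensors in [0, 1]

print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])
```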
/network_files/image_list.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from torch import Tensor
3 |
4 |
5 | class ImageList(object):
6 | """
7 | Structure that holds a list of images (of possibly
8 | varying sizes) as a single tensor.
9 | This works by padding the images to the same size,
10 | and storing in a field the original sizes of each image
11 | """
12 |
13 | def __init__(self, tensors, image_sizes):
14 | # type: (Tensor, List[Tuple[int, int]]) -> None
15 | """
16 | Arguments:
17 | tensors (tensor) padding后的图像数据
18 | image_sizes (list[tuple[int, int]]) padding前的图像尺寸
19 | """
20 | self.tensors = tensors
21 | self.image_sizes = image_sizes
22 |
23 | def to(self, device):
24 | # type: (Device) -> ImageList # noqa
25 | cast_tensor = self.tensors.to(device)
26 | return ImageList(cast_tensor, self.image_sizes)
27 |
28 |
--------------------------------------------------------------------------------
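A tiny sketch of what `ImageList` holds after `GeneralizedRCNNTransform` has batched two differently sized images (the shapes below are made up): the padded tensor plus the pre-padding sizes that the RPN, RoI heads and postprocessing later need:

```python
import torch

from network_files.image_list import ImageList

# Two resized images, say 800x1066 and 600x1216, padded into one 2x3x800x1216 batch.
batched = torch.zeros(2, 3, 800, 1216)
image_list = ImageList(batched, [(800, 1066), (600, 1216)])

print(image_list.tensors.shape)   # torch.Size([2, 3, 800, 1216])
print(image_list.image_sizes)     # [(800, 1066), (600, 1216)]
```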
/network_files/roi_head.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Dict, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 | import torch.nn.functional as F
6 |
7 | from . import det_utils
8 | from . import boxes as box_ops
9 |
10 |
11 | def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
12 | # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
13 | """
14 | Computes the loss for Faster R-CNN.
15 |
16 | Arguments:
17 | class_logits : 预测类别概率信息,shape=[num_anchors, num_classes]
18 |         box_regression : 预测目标边界框回归信息
19 | labels : 真实类别信息
20 | regression_targets : 真实目标边界框信息
21 |
22 | Returns:
23 | classification_loss (Tensor)
24 | box_loss (Tensor)
25 | """
26 |
27 | labels = torch.cat(labels, dim=0)
28 | regression_targets = torch.cat(regression_targets, dim=0)
29 |
30 | # 计算类别损失信息
31 | classification_loss = F.cross_entropy(class_logits, labels)
32 |
33 | # get indices that correspond to the regression targets for
34 | # the corresponding ground truth labels, to be used with
35 | # advanced indexing
36 | # 返回标签类别大于0的索引
37 | # sampled_pos_inds_subset = torch.nonzero(torch.gt(labels, 0)).squeeze(1)
38 | sampled_pos_inds_subset = torch.where(torch.gt(labels, 0))[0]
39 |
40 | # 返回标签类别大于0位置的类别信息
41 | labels_pos = labels[sampled_pos_inds_subset]
42 |
43 | # shape=[num_proposal, num_classes]
44 | N, num_classes = class_logits.shape
45 | box_regression = box_regression.reshape(N, -1, 4)
46 |
47 | # 计算边界框损失信息
48 | box_loss = det_utils.smooth_l1_loss(
49 | # 获取指定索引proposal的指定类别box信息
50 | box_regression[sampled_pos_inds_subset, labels_pos],
51 | regression_targets[sampled_pos_inds_subset],
52 | beta=1 / 9,
53 | size_average=False,
54 | ) / labels.numel()
55 |
56 | return classification_loss, box_loss
57 |
58 |
59 | class RoIHeads(torch.nn.Module):
60 | __annotations__ = {
61 | 'box_coder': det_utils.BoxCoder,
62 | 'proposal_matcher': det_utils.Matcher,
63 | 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
64 | }
65 |
66 | def __init__(self,
67 | box_roi_pool, # Multi-scale RoIAlign pooling
68 | box_head, # TwoMLPHead
69 | box_predictor, # FastRCNNPredictor
70 | # Faster R-CNN training
71 | fg_iou_thresh, bg_iou_thresh, # default: 0.5, 0.5
72 | batch_size_per_image, positive_fraction, # default: 512, 0.25
73 | bbox_reg_weights, # None
74 | # Faster R-CNN inference
75 | score_thresh, # default: 0.05
76 | nms_thresh, # default: 0.5
77 | detection_per_img): # default: 100
78 | super(RoIHeads, self).__init__()
79 |
80 | self.box_similarity = box_ops.box_iou
81 | # assign ground-truth boxes for each proposal
82 | self.proposal_matcher = det_utils.Matcher(
83 | fg_iou_thresh, # default: 0.5
84 | bg_iou_thresh, # default: 0.5
85 | allow_low_quality_matches=False)
86 |
87 | self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
88 | batch_size_per_image, # default: 512
89 | positive_fraction) # default: 0.25
90 |
91 | if bbox_reg_weights is None:
92 | bbox_reg_weights = (10., 10., 5., 5.)
93 | self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
94 |
95 | self.box_roi_pool = box_roi_pool # Multi-scale RoIAlign pooling
96 | self.box_head = box_head # TwoMLPHead
97 | self.box_predictor = box_predictor # FastRCNNPredictor
98 |
99 | self.score_thresh = score_thresh # default: 0.05
100 | self.nms_thresh = nms_thresh # default: 0.5
101 | self.detection_per_img = detection_per_img # default: 100
102 |
103 | def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
104 | # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
105 | """
106 | 为每个proposal匹配对应的gt_box,并划分到正负样本中
107 | Args:
108 | proposals:
109 | gt_boxes:
110 | gt_labels:
111 |
112 | Returns:
113 |
114 | """
115 | matched_idxs = []
116 | labels = []
117 | # 遍历每张图像的proposals, gt_boxes, gt_labels信息
118 | for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
119 | if gt_boxes_in_image.numel() == 0: # 该张图像中没有gt框,为背景
120 | # background image
121 | device = proposals_in_image.device
122 | clamped_matched_idxs_in_image = torch.zeros(
123 | (proposals_in_image.shape[0],), dtype=torch.int64, device=device
124 | )
125 | labels_in_image = torch.zeros(
126 | (proposals_in_image.shape[0],), dtype=torch.int64, device=device
127 | )
128 | else:
129 | # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
130 | # 计算proposal与每个gt_box的iou重合度
131 | match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
132 |
133 | # 计算proposal与每个gt_box匹配的iou最大值,并记录索引,
134 | # iou < low_threshold索引值为 -1, low_threshold <= iou < high_threshold索引值为 -2
135 | matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
136 |
137 | # 限制最小值,防止匹配标签时出现越界的情况
138 | # 注意-1, -2对应的gt索引会调整到0,获取的标签类别为第0个gt的类别(实际上并不是),后续会进一步处理
139 | clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
140 | # 获取proposal匹配到的gt对应标签
141 | labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
142 | labels_in_image = labels_in_image.to(dtype=torch.int64)
143 |
144 | # label background (below the low threshold)
145 | # 将gt索引为-1的类别设置为0,即背景,负样本
146 | bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD # -1
147 | labels_in_image[bg_inds] = 0
148 |
149 | # label ignore proposals (between low and high threshold)
150 | # 将gt索引为-2的类别设置为-1, 即废弃样本
151 | ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS # -2
152 | labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
153 |
154 | matched_idxs.append(clamped_matched_idxs_in_image)
155 | labels.append(labels_in_image)
156 | return matched_idxs, labels
157 |
158 | def subsample(self, labels):
159 | # type: (List[Tensor]) -> List[Tensor]
160 | # BalancedPositiveNegativeSampler
161 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
162 | sampled_inds = []
163 | # 遍历每张图片的正负样本索引
164 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
165 | # 记录所有采集样本索引(包括正样本和负样本)
166 | # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
167 | img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
168 | sampled_inds.append(img_sampled_inds)
169 | return sampled_inds
170 |
171 | def add_gt_proposals(self, proposals, gt_boxes):
172 | # type: (List[Tensor], List[Tensor]) -> List[Tensor]
173 | """
174 | 将gt_boxes拼接到proposal后面
175 | Args:
176 | proposals: 一个batch中每张图像rpn预测的boxes
177 | gt_boxes: 一个batch中每张图像对应的真实目标边界框
178 |
179 | Returns:
180 |
181 | """
182 | proposals = [
183 | torch.cat((proposal, gt_box))
184 | for proposal, gt_box in zip(proposals, gt_boxes)
185 | ]
186 | return proposals
187 |
188 | def check_targets(self, targets):
189 | # type: (Optional[List[Dict[str, Tensor]]]) -> None
190 | assert targets is not None
191 | assert all(["boxes" in t for t in targets])
192 | assert all(["labels" in t for t in targets])
193 |
194 | def select_training_samples(self,
195 | proposals, # type: List[Tensor]
196 | targets # type: Optional[List[Dict[str, Tensor]]]
197 | ):
198 | # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
199 | """
200 | 划分正负样本,统计对应gt的标签以及边界框回归信息
201 | list元素个数为batch_size
202 | Args:
203 | proposals: rpn预测的boxes
204 | targets:
205 |
206 | Returns:
207 |
208 | """
209 |
210 | # 检查target数据是否为空
211 | self.check_targets(targets)
212 | # 如果不加这句,jit.script会不通过(看不懂)
213 | assert targets is not None
214 |
215 | dtype = proposals[0].dtype
216 | device = proposals[0].device
217 |
218 | # 获取标注好的boxes以及labels信息
219 | gt_boxes = [t["boxes"].to(dtype) for t in targets]
220 | gt_labels = [t["labels"] for t in targets]
221 |
222 | # append ground-truth bboxes to proposal
223 | # 将gt_boxes拼接到proposal后面
224 | proposals = self.add_gt_proposals(proposals, gt_boxes)
225 |
226 | # get matching gt indices for each proposal
227 | # 为每个proposal匹配对应的gt_box,并划分到正负样本中
228 | matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
229 | # sample a fixed proportion of positive-negative proposals
230 | # 按给定数量和比例采样正负样本
231 | sampled_inds = self.subsample(labels)
232 | matched_gt_boxes = []
233 | num_images = len(proposals)
234 |
235 | # 遍历每张图像
236 | for img_id in range(num_images):
237 | # 获取每张图像的正负样本索引
238 | img_sampled_inds = sampled_inds[img_id]
239 | # 获取对应正负样本的proposals信息
240 | proposals[img_id] = proposals[img_id][img_sampled_inds]
241 | # 获取对应正负样本的真实类别信息
242 | labels[img_id] = labels[img_id][img_sampled_inds]
243 | # 获取对应正负样本的gt索引信息
244 | matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
245 |
246 | gt_boxes_in_image = gt_boxes[img_id]
247 | if gt_boxes_in_image.numel() == 0:
248 | gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
249 | # 获取对应正负样本的gt box信息
250 | matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
251 |
252 | # 根据gt和proposal计算边框回归参数(针对gt的)
253 | regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
254 | return proposals, labels, regression_targets
255 |
256 | def postprocess_detections(self,
257 | class_logits, # type: Tensor
258 | box_regression, # type: Tensor
259 | proposals, # type: List[Tensor]
260 | image_shapes # type: List[Tuple[int, int]]
261 | ):
262 | # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
263 | """
264 | 对网络的预测数据进行后处理,包括
265 | (1)根据proposal以及预测的回归参数计算出最终bbox坐标
266 | (2)对预测类别结果进行softmax处理
267 | (3)裁剪预测的boxes信息,将越界的坐标调整到图片边界上
268 | (4)移除所有背景信息
269 | (5)移除低概率目标
270 | (6)移除小尺寸目标
271 | (7)执行nms处理,并按scores进行排序
272 | (8)根据scores排序返回前topk个目标
273 | Args:
274 | class_logits: 网络预测类别概率信息
275 | box_regression: 网络预测的边界框回归参数
276 | proposals: rpn输出的proposal
277 | image_shapes: 打包成batch前每张图像的宽高
278 |
279 | Returns:
280 |
281 | """
282 | device = class_logits.device
283 |         # number of predicted classes
284 |         num_classes = class_logits.shape[-1]
285 | 
286 |         # number of predicted bboxes per image
287 |         boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
288 |         # compute the final bbox coordinates from the proposals and the predicted regression parameters
289 |         pred_boxes = self.box_coder.decode(box_regression, proposals)
290 | 
291 |         # apply softmax to the predicted class scores
292 | pred_scores = F.softmax(class_logits, -1)
293 |
294 | # split boxes and scores per image
295 |         # split the results by the number of predicted bboxes per image
296 | pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
297 | pred_scores_list = pred_scores.split(boxes_per_image, 0)
298 |
299 | all_boxes = []
300 | all_scores = []
301 | all_labels = []
302 |         # iterate over the predictions of each image
303 |         for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
304 |             # clip the predicted boxes so out-of-bound coordinates are moved onto the image border
305 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
306 |
307 | # create labels for each prediction
308 | labels = torch.arange(num_classes, device=device)
309 | labels = labels.view(1, -1).expand_as(scores)
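            # labels now has shape [num_boxes, num_classes]; column j holds class index j for
            # every box, so the flattening below pairs each per-class score with its class label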
310 |
311 | # remove prediction with the background label
312 |             # drop column 0 (index 0 is the background class)
313 | boxes = boxes[:, 1:]
314 | scores = scores[:, 1:]
315 | labels = labels[:, 1:]
316 |
317 | # batch everything, by making every class prediction be a separate instance
318 | boxes = boxes.reshape(-1, 4)
319 | scores = scores.reshape(-1)
320 | labels = labels.reshape(-1)
321 |
322 | # remove low scoring boxes
323 |             # remove low-scoring detections (self.score_thresh, 0.05 by default)
324 | # gt: Computes input > other element-wise.
325 | # inds = torch.nonzero(torch.gt(scores, self.score_thresh)).squeeze(1)
326 | inds = torch.where(torch.gt(scores, self.score_thresh))[0]
327 | boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
328 |
329 | # remove empty boxes
330 |             # remove tiny boxes
331 | keep = box_ops.remove_small_boxes(boxes, min_size=1.)
332 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
333 |
334 |             # non-maximum suppression, done independently per class
335 |             # run NMS; the kept indices are returned sorted by score in descending order
336 | keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
337 |
338 | # keep only topk scoring predictions
339 |             # keep only the top-k highest-scoring detections
340 | keep = keep[:self.detection_per_img]
341 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
342 |
343 | all_boxes.append(boxes)
344 | all_scores.append(scores)
345 | all_labels.append(labels)
346 |
347 | return all_boxes, all_scores, all_labels
348 |
349 | def forward(self,
350 | features, # type: Dict[str, Tensor]
351 | proposals, # type: List[Tensor]
352 | image_shapes, # type: List[Tuple[int, int]]
353 | targets=None # type: Optional[List[Dict[str, Tensor]]]
354 | ):
355 | # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
356 | """
357 | Arguments:
358 | features (List[Tensor])
359 | proposals (List[Tensor[N, 4]])
360 | image_shapes (List[Tuple[H, W]])
361 | targets (List[Dict])
362 | """
363 |
364 |         # check that the targets have the correct data types
365 |         if targets is not None:
366 |             for t in targets:
367 |                 floating_point_types = (torch.float, torch.double, torch.half)
368 |                 assert t["boxes"].dtype in floating_point_types, "target boxes must be of float type"
369 |                 assert t["labels"].dtype == torch.int64, "target labels must be of int64 type"
370 |
371 | if self.training:
372 |             # split proposals into positive/negative samples and collect matched gt labels and regression targets
373 | proposals, labels, regression_targets = self.select_training_samples(proposals, targets)
374 | else:
375 | labels = None
376 | regression_targets = None
377 |
378 |         # pass the sampled proposals through the multi-scale RoIAlign pooling layer
379 | # box_features_shape: [num_proposals, channel, height, width]
380 | box_features = self.box_roi_pool(features, proposals, image_shapes)
381 |
382 |         # two fully-connected layers applied after roi_pooling
383 | # box_features_shape: [num_proposals, representation_size]
384 | box_features = self.box_head(box_features)
385 |
386 |         # then predict the class scores and the box regression parameters
387 | class_logits, box_regression = self.box_predictor(box_features)
388 |
389 | result = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
390 | losses = {}
391 | if self.training:
392 | assert labels is not None and regression_targets is not None
393 | loss_classifier, loss_box_reg = fastrcnn_loss(
394 | class_logits, box_regression, labels, regression_targets)
395 | losses = {
396 | "loss_classifier": loss_classifier,
397 | "loss_box_reg": loss_box_reg
398 | }
399 | else:
400 | boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
401 | num_images = len(boxes)
402 | for i in range(num_images):
403 | result.append(
404 | {
405 | "boxes": boxes[i],
406 | "labels": labels[i],
407 | "scores": scores[i],
408 | }
409 | )
410 |
411 | return result, losses
412 |
--------------------------------------------------------------------------------
/network_files/transform.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Tuple, Dict, Optional
3 |
4 | import torch
5 | from torch import nn, Tensor
6 | import torchvision
7 |
8 | from .image_list import ImageList
9 |
10 |
11 | @torch.jit.unused
12 | def _resize_image_onnx(image, self_min_size, self_max_size):
13 | # type: (Tensor, float, float) -> Tensor
14 | from torch.onnx import operators
15 | im_shape = operators.shape_as_tensor(image)[-2:]
16 | min_size = torch.min(im_shape).to(dtype=torch.float32)
17 | max_size = torch.max(im_shape).to(dtype=torch.float32)
18 | scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
19 |
20 | image = torch.nn.functional.interpolate(
21 | image[None], scale_factor=scale_factor, mode="bilinear", recompute_scale_factor=True,
22 | align_corners=False)[0]
23 |
24 | return image
25 |
26 |
27 | def _resize_image(image, self_min_size, self_max_size):
28 | # type: (Tensor, float, float) -> Tensor
29 | im_shape = torch.tensor(image.shape[-2:])
30 |     min_size = float(torch.min(im_shape))    # smaller of height and width
31 |     max_size = float(torch.max(im_shape))    # larger of height and width
32 |     scale_factor = self_min_size / min_size  # scale that maps the smaller side to the specified min size
33 |
34 |     # if this scale would make the larger side exceed the specified max size
35 |     if max_size * scale_factor > self_max_size:
36 |         scale_factor = self_max_size / max_size  # use the scale that maps the larger side to the max size instead
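    # e.g. with self_min_size=800 and self_max_size=1333, a 500x1400 image first gets
    # scale 800/500 = 1.6, but 1400*1.6 = 2240 > 1333, so the scale becomes 1333/1400 ≈ 0.95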
37 |
38 |     # interpolate resizes the image via interpolation
39 |     # image[None] adds a batch dimension: [C, H, W] -> [1, C, H, W]
40 |     # bilinear mode only supports 4D tensors
41 | image = torch.nn.functional.interpolate(
42 | image[None], scale_factor=scale_factor, mode="bilinear", recompute_scale_factor=True,
43 | align_corners=False)[0]
44 |
45 | return image
46 |
47 |
48 | class GeneralizedRCNNTransform(nn.Module):
49 | """
50 | Performs input / target transformation before feeding the data to a GeneralizedRCNN
51 | model.
52 |
53 |     The transformations it performs are:
54 |         - input normalization (mean subtraction and std division)
55 |         - input / target resizing to match min_size / max_size
56 | 
57 |     It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
58 | """
59 |
60 | def __init__(self, min_size, max_size, image_mean, image_std):
61 | super(GeneralizedRCNNTransform, self).__init__()
62 | if not isinstance(min_size, (list, tuple)):
63 | min_size = (min_size,)
64 |         self.min_size = min_size      # min side length(s) the image is resized to
65 |         self.max_size = max_size      # max side length the image may have after resizing
66 |         self.image_mean = image_mean  # per-channel mean used for normalization
67 |         self.image_std = image_std    # per-channel std used for normalization
68 |
69 | def normalize(self, image):
70 |         """Normalize the image (subtract the mean, divide by the std)"""
71 | dtype, device = image.dtype, image.device
72 | mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
73 | std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
74 | # [:, None, None]: shape [3] -> [3, 1, 1]
75 | return (image - mean[:, None, None]) / std[:, None, None]
76 |
77 | def torch_choice(self, k):
78 | # type: (List[int]) -> int
79 | """
80 | Implements `random.choice` via torch ops so it can be compiled with
81 | TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
82 | is fixed.
83 | """
84 | index = int(torch.empty(1).uniform_(0., float(len(k))).item())
85 | return k[index]
86 |
87 | def resize(self, image, target):
88 | # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
89 | """
90 |         Resize the image into the configured size range and rescale the bboxes accordingly
91 |         Args:
92 |             image: input image
93 |             target: annotations of the input image (including bboxes)
94 | 
95 |         Returns:
96 |             image: resized image
97 |             target: annotations with bboxes rescaled to the new image size
98 | """
99 | # image shape is [channel, height, width]
100 | h, w = image.shape[-2:]
101 |
102 | if self.training:
103 |             size = float(self.torch_choice(self.min_size))  # target min side length; note this is self.min_size, not min_size
104 |         else:
105 |             # FIXME assume for now that testing uses the largest scale
106 |             size = float(self.min_size[-1])  # target min side length; note this is self.min_size, not min_size
107 |
108 | if torchvision._is_tracing():
109 | image = _resize_image_onnx(image, size, float(self.max_size))
110 | else:
111 | image = _resize_image(image, size, float(self.max_size))
112 |
113 | if target is None:
114 | return image, target
115 | bbox = target["boxes"]
116 |         # rescale the bboxes by the same ratio as the image
117 | bbox = resize_boxes(bbox, [h, w], image.shape[-2:])
118 | target["boxes"] = bbox
119 |
120 | return image, target
121 |
122 | # _onnx_batch_images() is an implementation of
123 | # batch_images() that is supported by ONNX tracing.
124 | @torch.jit.unused
125 | def _onnx_batch_images(self, images, size_divisible=32):
126 | # type: (List[Tensor], int) -> Tensor
127 | max_size = []
128 | for i in range(images[0].dim()):
129 | max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
130 | max_size.append(max_size_i)
131 | stride = size_divisible
132 | max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
133 | max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
134 | max_size = tuple(max_size)
135 |
136 | # work around for
137 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
138 | # which is not yet supported in onnx
139 | padded_imgs = []
140 | for img in images:
141 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
142 | padded_img = torch.nn.functional.pad(img, [0, padding[2], 0, padding[1], 0, padding[0]])
143 | padded_imgs.append(padded_img)
144 |
145 | return torch.stack(padded_imgs)
146 |
147 | def max_by_axis(self, the_list):
148 | # type: (List[List[int]]) -> List[int]
149 | maxes = the_list[0]
150 | for sublist in the_list[1:]:
151 | for index, item in enumerate(sublist):
152 | maxes[index] = max(maxes[index], item)
153 | return maxes
154 |
155 | def batch_images(self, images, size_divisible=32):
156 | # type: (List[Tensor], int) -> Tensor
157 | """
158 |         Pack a list of images into one batch tensor (every tensor in the batch gets the same shape)
159 |         Args:
160 |             images: list of input images
161 |             size_divisible: round the image height and width up to a multiple of this value
162 | 
163 |         Returns:
164 |             batched_imgs: the packed batch tensor
165 | """
166 |
167 | if torchvision._is_tracing():
168 | # batch_images() does not export well to ONNX
169 | # call _onnx_batch_images() instead
170 | return self._onnx_batch_images(images, size_divisible)
171 |
172 |         # compute the maximum channel, height and width over all images in the batch
173 | max_size = self.max_by_axis([list(img.shape) for img in images])
174 |
175 | stride = float(size_divisible)
176 | # max_size = list(max_size)
177 |         # round height up to the nearest multiple of stride
178 |         max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
179 |         # round width up to the nearest multiple of stride
180 | max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
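        # e.g. with stride=32 a maximum width of 1021 becomes ceil(1021/32)*32 = 32*32 = 1024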
181 |
182 | # [batch, channel, height, width]
183 | batch_shape = [len(images)] + max_size
184 |
185 |         # create a zero-filled tensor of shape batch_shape
186 | batched_imgs = images[0].new_full(batch_shape, 0)
187 | for img, pad_img in zip(images, batched_imgs):
188 |             # copy each input image into its slice of batched_imgs, aligned to the top-left corner so bbox coordinates stay valid
189 |             # this guarantees that every image in the batch fed to the network has the same shape
190 | # copy_: Copies the elements from src into self tensor and returns self
191 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
192 |
193 | return batched_imgs
194 |
195 | def postprocess(self,
196 | result, # type: List[Dict[str, Tensor]]
197 | image_shapes, # type: List[Tuple[int, int]]
198 | original_image_sizes # type: List[Tuple[int, int]]
199 | ):
200 | # type: (...) -> List[Dict[str, Tensor]]
201 | """
202 |         Post-process the network predictions (mainly map the bboxes back to the original image scale)
203 |         Args:
204 |             result: list(dict), network predictions, len(result) == batch_size
205 |             image_shapes: list(torch.Size), image sizes after preprocessing/resizing, len(image_shapes) == batch_size
206 |             original_image_sizes: list(torch.Size), original image sizes, len(original_image_sizes) == batch_size
207 |
208 | Returns:
209 |
210 | """
211 | if self.training:
212 | return result
213 |
214 |         # iterate over the predictions of each image and map the boxes back to the original scale
215 |         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
216 |             boxes = pred["boxes"]
217 |             boxes = resize_boxes(boxes, im_s, o_im_s)  # rescale the bboxes back to the original image scale
218 | result[i]["boxes"] = boxes
219 | return result
220 |
221 | def __repr__(self):
222 |         """Custom repr so that printing an instance shows its configuration"""
223 | format_string = self.__class__.__name__ + '('
224 | _indent = '\n '
225 | format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
226 | format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
227 | self.max_size)
228 | format_string += '\n)'
229 | return format_string
230 |
231 | def forward(self,
232 | images, # type: List[Tensor]
233 | targets=None # type: Optional[List[Dict[str, Tensor]]]
234 | ):
235 | # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
236 | images = [img for img in images]
237 | for i in range(len(images)):
238 | image = images[i]
239 | target_index = targets[i] if targets is not None else None
240 |
241 | if image.dim() != 3:
242 | raise ValueError("images is expected to be a list of 3d tensors "
243 | "of shape [C, H, W], got {}".format(image.shape))
244 |             image = self.normalize(image)                # normalize the image
245 |             image, target_index = self.resize(image, target_index)   # resize the image and its bboxes into the configured range
246 | images[i] = image
247 | if targets is not None and target_index is not None:
248 | targets[i] = target_index
249 |
250 |         # record the image sizes after resizing
251 |         image_sizes = [img.shape[-2:] for img in images]
252 |         images = self.batch_images(images)  # pack the images into one batch tensor
253 | image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
254 |
255 | for image_size in image_sizes:
256 | assert len(image_size) == 2
257 | image_sizes_list.append((image_size[0], image_size[1]))
258 |
259 | image_list = ImageList(images, image_sizes_list)
260 | return image_list, targets
261 |
262 |
263 | def resize_boxes(boxes, original_size, new_size):
264 | # type: (Tensor, List[int], List[int]) -> Tensor
265 | """
266 |     Rescale the boxes according to how the image was resized
267 | 
268 |     Arguments:
269 |         original_size: image size before resizing
270 |         new_size: image size after resizing
271 | """
272 | ratios = [
273 | torch.tensor(s, dtype=torch.float32, device=boxes.device) /
274 | torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
275 | for s, s_orig in zip(new_size, original_size)
276 | ]
277 | ratios_height, ratios_width = ratios
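    # e.g. a 600x800 image resized to 750x1000 gives ratios (1.25, 1.25), so the box
    # [40, 30, 200, 160] becomes [50.0, 37.5, 250.0, 200.0]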
278 | # Removes a tensor dimension, boxes [minibatch, 4]
279 | # Returns a tuple of all slices along a given dimension, already without it.
280 | xmin, ymin, xmax, ymax = boxes.unbind(1)
281 | xmin = xmin * ratios_width
282 | xmax = xmax * ratios_width
283 | ymin = ymin * ratios_height
284 | ymax = ymax * ratios_height
285 | return torch.stack((xmin, ymin, xmax, ymax), dim=1)
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
--------------------------------------------------------------------------------
/pascal_voc_classes.json:
--------------------------------------------------------------------------------
1 | {
2 | "aeroplane": 1,
3 | "bicycle": 2,
4 | "bird": 3,
5 | "boat": 4,
6 | "bottle": 5,
7 | "bus": 6,
8 | "car": 7,
9 | "cat": 8,
10 | "chair": 9,
11 | "cow": 10,
12 | "diningtable": 11,
13 | "dog": 12,
14 | "horse": 13,
15 | "motorbike": 14,
16 | "person": 15,
17 | "pottedplant": 16,
18 | "sheep": 17,
19 | "sofa": 18,
20 | "train": 19,
21 | "tvmonitor": 20
22 | }
--------------------------------------------------------------------------------
/plot_curve.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def plot_loss_and_lr(train_loss, learning_rate):
6 | try:
7 | x = list(range(len(train_loss)))
8 | fig, ax1 = plt.subplots(1, 1)
9 | ax1.plot(x, train_loss, 'r', label='loss')
10 | ax1.set_xlabel("step")
11 | ax1.set_ylabel("loss")
12 | ax1.set_title("Train Loss and lr")
13 | plt.legend(loc='best')
14 |
15 | ax2 = ax1.twinx()
16 | ax2.plot(x, learning_rate, label='lr')
17 | ax2.set_ylabel("learning rate")
18 |         ax2.set_xlim(0, len(train_loss))  # set the x-axis range to the number of steps
19 | plt.legend(loc='best')
20 |
21 | handles1, labels1 = ax1.get_legend_handles_labels()
22 | handles2, labels2 = ax2.get_legend_handles_labels()
23 | plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
24 |
25 |         fig.subplots_adjust(right=0.8)  # leave room on the right so the saved figure is not clipped
26 | fig.savefig('./loss_and_lr{}.png'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
27 | plt.close()
28 |         print("successfully saved loss curve!")
29 | except Exception as e:
30 | print(e)
31 |
32 |
33 | def plot_map(mAP):
34 | try:
35 | x = list(range(len(mAP)))
36 |         plt.plot(x, mAP, label='mAP')
37 | plt.xlabel('epoch')
38 | plt.ylabel('mAP')
39 | plt.title('Eval mAP')
40 | plt.xlim(0, len(mAP))
41 | plt.legend(loc='best')
42 | plt.savefig('./mAP.png')
43 | plt.close()
44 |         print("successfully saved mAP curve!")
45 | except Exception as e:
46 | print(e)
47 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 |
5 | import torch
6 | import torchvision
7 | from PIL import Image
8 | import matplotlib.pyplot as plt
9 |
10 | from torchvision import transforms
11 | from network_files import FasterRCNN, FastRCNNPredictor, AnchorsGenerator
12 | from backbone import resnet50_fpn_backbone, MobileNetV2
13 | from draw_box_utils import draw_box
14 |
15 |
16 | def create_model(num_classes):
17 | # mobileNetv2+faster_RCNN
18 | # backbone = MobileNetV2().features
19 | # backbone.out_channels = 1280
20 | #
21 | # anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
22 | # aspect_ratios=((0.5, 1.0, 2.0),))
23 | #
24 | # roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
25 | # output_size=[7, 7],
26 | # sampling_ratio=2)
27 | #
28 | # model = FasterRCNN(backbone=backbone,
29 | # num_classes=num_classes,
30 | # rpn_anchor_generator=anchor_generator,
31 | # box_roi_pool=roi_pooler)
32 |
33 | # resNet50+fpn+faster_RCNN
34 |     # note: norm_layer here must match the one used in the training script
35 | backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
36 | model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_score_thresh=0.5)
37 |
38 | return model
39 |
40 |
41 | def time_synchronized():
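    # wait for any pending CUDA kernels to finish so the timestamp reflects completed GPU work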
42 | torch.cuda.synchronize() if torch.cuda.is_available() else None
43 | return time.time()
44 |
45 |
46 | def main():
47 | # get devices
48 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49 | print("using {} device.".format(device))
50 |
51 | # create model
52 | model = create_model(num_classes=21)
53 |
54 | # load train weights
55 | train_weights = "./save_weights/model.pth"
56 |     assert os.path.exists(train_weights), "{} file does not exist.".format(train_weights)
57 | model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
58 | model.to(device)
59 |
60 | # read class_indict
61 | label_json_path = './pascal_voc_classes.json'
62 |     assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
63 | json_file = open(label_json_path, 'r')
64 | class_dict = json.load(json_file)
65 | json_file.close()
66 | category_index = {v: k for k, v in class_dict.items()}
67 |
68 | # load image
69 | original_img = Image.open("./test.jpg")
70 |
71 | # from pil image to tensor, do not normalize image
72 | data_transform = transforms.Compose([transforms.ToTensor()])
73 | img = data_transform(original_img)
74 | # expand batch dimension
75 | img = torch.unsqueeze(img, dim=0)
76 |
77 |     model.eval()  # switch to evaluation mode
78 | with torch.no_grad():
79 |         # warm-up: run one dummy forward pass so initialization is not counted in the timed inference
80 | img_height, img_width = img.shape[-2:]
81 | init_img = torch.zeros((1, 3, img_height, img_width), device=device)
82 | model(init_img)
83 |
84 | t_start = time_synchronized()
85 | predictions = model(img.to(device))[0]
86 | t_end = time_synchronized()
87 | print("inference+NMS time: {}".format(t_end - t_start))
88 |
89 | predict_boxes = predictions["boxes"].to("cpu").numpy()
90 | predict_classes = predictions["labels"].to("cpu").numpy()
91 | predict_scores = predictions["scores"].to("cpu").numpy()
92 |
93 | if len(predict_boxes) == 0:
94 |             print("no objects detected!")
95 |
96 | draw_box(original_img,
97 | predict_boxes,
98 | predict_classes,
99 | predict_scores,
100 | category_index,
101 | thresh=0.5,
102 | line_thickness=3)
103 | plt.imshow(original_img)
104 | plt.show()
105 |         # save the image with the drawn predictions
106 | original_img.save("test_result.jpg")
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
112 |
--------------------------------------------------------------------------------
/results20220611-205355.txt:
--------------------------------------------------------------------------------
1 | epoch:0 0.2680 0.4651 0.2795 0.1493 0.2999 0.3332 0.2498 0.4026 0.4225 0.2445 0.4575 0.5215 0.6781 0.020000
2 | epoch:1 0.2827 0.4795 0.2967 0.1537 0.3183 0.3679 0.2605 0.4194 0.4410 0.2440 0.4783 0.5668 0.5997 0.020000
3 | epoch:2 0.2807 0.4727 0.2956 0.1567 0.3137 0.3573 0.2637 0.4233 0.4467 0.2567 0.4866 0.5714 0.5861 0.020000
4 | epoch:3 0.2844 0.4783 0.2985 0.1588 0.3187 0.3628 0.2658 0.4269 0.4506 0.2563 0.4863 0.5796 0.5784 0.020000
5 | epoch:4 0.2843 0.4749 0.3017 0.1605 0.3169 0.3628 0.2642 0.4198 0.4430 0.2643 0.4759 0.5508 0.5737 0.020000
6 | epoch:5 0.2834 0.4682 0.3012 0.1543 0.3219 0.3605 0.2641 0.4147 0.4340 0.2403 0.4713 0.5627 0.5693 0.020000
7 | epoch:6 0.2862 0.4717 0.3047 0.1582 0.3138 0.3741 0.2676 0.4243 0.4459 0.2526 0.4770 0.5665 0.5659 0.020000
8 | epoch:7 0.2875 0.4756 0.3069 0.1613 0.3164 0.3734 0.2695 0.4281 0.4493 0.2606 0.4798 0.5793 0.5627 0.020000
9 | epoch:8 0.3437 0.5393 0.3680 0.1964 0.3780 0.4459 0.2982 0.4704 0.4942 0.2960 0.5331 0.6295 0.4886 0.002000
10 | epoch:9 0.3517 0.5503 0.3764 0.2005 0.3893 0.4562 0.3033 0.4748 0.4983 0.2977 0.5391 0.6418 0.4670 0.002000
11 | epoch:10 0.3543 0.5540 0.3809 0.2035 0.3891 0.4591 0.3036 0.4799 0.5037 0.3071 0.5426 0.6402 0.4558 0.002000
12 | epoch:11 0.3555 0.5522 0.3816 0.2035 0.3903 0.4634 0.3041 0.4763 0.4992 0.3021 0.5390 0.6388 0.4403 0.000200
13 | epoch:12 0.3549 0.5528 0.3801 0.2039 0.3916 0.4655 0.3048 0.4770 0.5000 0.3033 0.5395 0.6428 0.4373 0.000200
14 | epoch:13 0.3574 0.5554 0.3831 0.2016 0.3943 0.4704 0.3060 0.4758 0.4987 0.2979 0.5388 0.6445 0.4355 0.000200
15 | epoch:14 0.3575 0.5554 0.3835 0.2053 0.3922 0.4657 0.3041 0.4756 0.4985 0.3038 0.5363 0.6351 0.4338 0.000200
16 |
--------------------------------------------------------------------------------
/split_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 |
5 | def main():
6 |     random.seed(0)  # fix the random seed so the split is reproducible
7 |
8 | files_path = "./VOCdevkit/VOC2012/Annotations"
9 | assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)
10 |
11 | val_rate = 0.2
12 |
13 | files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
14 | files_num = len(files_name)
15 | val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
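    # e.g. with 1000 annotation files and val_rate=0.2, 200 indices are drawn for the val split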
16 | train_files = []
17 | val_files = []
18 | for index, file_name in enumerate(files_name):
19 | if index in val_index:
20 | val_files.append(file_name)
21 | else:
22 | train_files.append(file_name)
23 |
24 | try:
25 | train_f = open("train.txt", "x")
26 | eval_f = open("val.txt", "x")
27 | train_f.write("\n".join(train_files))
28 | eval_f.write("\n".join(val_files))
29 | except FileExistsError as e:
30 | print(e)
31 | exit(1)
32 |
33 |
34 | if __name__ == '__main__':
35 | main()
36 |
--------------------------------------------------------------------------------
/train_mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 |
4 | import torch
5 | import torchvision
6 |
7 | import transforms
8 | from network_files import FasterRCNN, AnchorsGenerator
9 | from backbone import MobileNetV2, vgg
10 | from my_dataset import VOCDataSet
11 | from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
12 | from train_utils import train_eval_utils as utils
13 |
14 |
15 | def create_model(num_classes):
16 | # https://download.pytorch.org/models/vgg16-397923af.pth
17 |     # to use vgg16, download its pretrained weights, uncomment the next two lines, and comment out the two mobilenetv2 lines below
18 |     # vgg_feature = vgg(model_name="vgg16", weights_path="./backbone/vgg16.pth").features
19 |     # backbone = torch.nn.Sequential(*list(vgg_feature._modules.values())[:-1])  # drop the last MaxPool layer in features
20 | # backbone.out_channels = 512
21 |
22 | # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
23 | backbone = MobileNetV2(weights_path="./backbone/mobilenet_v2.pth").features
24 |     backbone.out_channels = 1280  # number of channels of the backbone's output feature map
25 |
26 | anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
27 | aspect_ratios=((0.5, 1.0, 2.0),))
28 |
29 |     roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],  # feature maps on which to run RoI pooling
30 |                                                     output_size=[7, 7],   # output size of the RoI-pooled feature map
31 |                                                     sampling_ratio=2)  # sampling ratio for RoIAlign
32 |
33 | model = FasterRCNN(backbone=backbone,
34 | num_classes=num_classes,
35 | rpn_anchor_generator=anchor_generator,
36 | box_roi_pool=roi_pooler)
37 |
38 | return model
39 |
40 |
41 | def main():
42 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43 | print("Using {} device training.".format(device.type))
44 |
45 |     # file used to record the coco_info metrics
46 | results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
47 |
48 |     # create the folder for saving weights if it does not exist
49 | if not os.path.exists("save_weights"):
50 | os.makedirs("save_weights")
51 |
52 | data_transform = {
53 | "train": transforms.Compose([transforms.ToTensor(),
54 | transforms.RandomHorizontalFlip(0.5)]),
55 | "val": transforms.Compose([transforms.ToTensor()])
56 | }
57 |
58 | VOC_root = "./" # VOCdevkit
59 | aspect_ratio_group_factor = 3
60 | batch_size = 8
61 |     amp = False  # whether to use mixed-precision training (requires GPU support)
62 |
63 | # check voc root
64 | if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
65 |         raise FileNotFoundError("VOCdevkit does not exist in path: '{}'.".format(VOC_root))
66 |
67 | # load train data set
68 | # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
69 | train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")
70 | train_sampler = None
71 |
72 |     # whether to group images with similar aspect ratios into the same batch
73 |     # doing so reduces the GPU memory needed for training; enabled by default
74 |     if aspect_ratio_group_factor >= 0:
75 |         train_sampler = torch.utils.data.RandomSampler(train_dataset)
76 |         # bucket every image's aspect ratio into one of the bins and record the bin index
77 |         group_ids = create_aspect_ratio_groups(train_dataset, k=aspect_ratio_group_factor)
78 |         # each batch is drawn from a single aspect-ratio bin
79 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size)
80 |
81 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
82 | print('Using %g dataloader workers' % nw)
83 |
84 |     # note: collate_fn is custom because each sample contains an image and its targets, so the default collate cannot batch them
85 |     if train_sampler:
86 |         # when sampling by aspect ratio, the dataloader must use batch_sampler instead of batch_size
87 | train_data_loader = torch.utils.data.DataLoader(train_dataset,
88 | batch_sampler=train_batch_sampler,
89 | pin_memory=True,
90 | num_workers=nw,
91 | collate_fn=train_dataset.collate_fn)
92 | else:
93 | train_data_loader = torch.utils.data.DataLoader(train_dataset,
94 | batch_size=batch_size,
95 | shuffle=True,
96 | pin_memory=True,
97 | num_workers=nw,
98 | collate_fn=train_dataset.collate_fn)
99 |
100 | # load validation data set
101 | # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
102 | val_dataset = VOCDataSet(VOC_root, "2012", data_transform["val"], "val.txt")
103 | val_data_loader = torch.utils.data.DataLoader(val_dataset,
104 | batch_size=1,
105 | shuffle=False,
106 | pin_memory=True,
107 | num_workers=nw,
108 | collate_fn=val_dataset.collate_fn)
109 |
110 |     # create the model; num_classes equals background + 20 VOC classes
111 | model = create_model(num_classes=21)
112 | # print(model)
113 |
114 | model.to(device)
115 |
116 | scaler = torch.cuda.amp.GradScaler() if amp else None
117 |
118 | train_loss = []
119 | learning_rate = []
120 | val_map = []
121 |
122 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
123 | # first frozen backbone and train 5 epochs #
124 |     #  i.e. freeze the backbone feature extractor and train the RPN and prediction heads  #
125 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
126 | for param in model.backbone.parameters():
127 | param.requires_grad = False
128 |
129 | # define optimizer
130 | params = [p for p in model.parameters() if p.requires_grad]
131 | optimizer = torch.optim.SGD(params, lr=0.005,
132 | momentum=0.9, weight_decay=0.0005)
133 |
134 | init_epochs = 5
135 | for epoch in range(init_epochs):
136 | # train for one epoch, printing every 10 iterations
137 | mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
138 | device, epoch, print_freq=50,
139 | warmup=True, scaler=scaler)
140 | train_loss.append(mean_loss.item())
141 | learning_rate.append(lr)
142 |
143 | # evaluate on the test dataset
144 | coco_info = utils.evaluate(model, val_data_loader, device=device)
145 |
146 | # write into txt
147 | with open(results_file, "a") as f:
148 |             # the written data includes the coco metrics plus the loss and learning rate
149 | result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
150 | txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
151 | f.write(txt + "\n")
152 |
153 | val_map.append(coco_info[1]) # pascal mAP
154 |
155 | torch.save(model.state_dict(), "./save_weights/pretrain.pth")
156 |
157 | # # # # # # # # # # # # # # # # # # # # # # # # # # # #
158 | # second unfrozen backbone and train all network #
159 |     #  i.e. unfreeze the backbone feature extractor and train the whole network  #
160 | # # # # # # # # # # # # # # # # # # # # # # # # # # # #
161 |
162 |     # keep the lowest backbone layers frozen
163 | for name, parameter in model.backbone.named_parameters():
164 | split_name = name.split(".")[0]
165 | if split_name in ["0", "1", "2", "3"]:
166 | parameter.requires_grad = False
167 | else:
168 | parameter.requires_grad = True
169 |
170 | # define optimizer
171 | params = [p for p in model.parameters() if p.requires_grad]
172 | optimizer = torch.optim.SGD(params, lr=0.005,
173 | momentum=0.9, weight_decay=0.0005)
174 | # learning rate scheduler
175 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
176 | step_size=3,
177 | gamma=0.33)
178 | num_epochs = 20
179 | for epoch in range(init_epochs, num_epochs+init_epochs, 1):
180 | # train for one epoch, printing every 50 iterations
181 | mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
182 | device, epoch, print_freq=50,
183 | warmup=True, scaler=scaler)
184 | train_loss.append(mean_loss.item())
185 | learning_rate.append(lr)
186 |
187 | # update the learning rate
188 | lr_scheduler.step()
189 |
190 | # evaluate on the test dataset
191 | coco_info = utils.evaluate(model, val_data_loader, device=device)
192 |
193 | # write into txt
194 | with open(results_file, "a") as f:
195 | # 写入的数据包括coco指标还有loss和learning rate
196 | result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
197 | txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
198 | f.write(txt + "\n")
199 |
200 | val_map.append(coco_info[1]) # pascal mAP
201 |
202 | # save weights
203 |         # only save the weights of the last 5 epochs
204 | if epoch in range(num_epochs+init_epochs)[-5:]:
205 | save_files = {
206 | 'model': model.state_dict(),
207 | 'optimizer': optimizer.state_dict(),
208 | 'lr_scheduler': lr_scheduler.state_dict(),
209 | 'epoch': epoch}
210 | torch.save(save_files, "./save_weights/mobile-model-{}.pth".format(epoch))
211 |
212 | # plot loss and lr curve
213 | if len(train_loss) != 0 and len(learning_rate) != 0:
214 | from plot_curve import plot_loss_and_lr
215 | plot_loss_and_lr(train_loss, learning_rate)
216 |
217 | # plot mAP curve
218 | if len(val_map) != 0:
219 | from plot_curve import plot_map
220 | plot_map(val_map)
221 |
222 |
223 | if __name__ == "__main__":
224 | main()
225 |
--------------------------------------------------------------------------------
/train_multi_GPU.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import datetime
4 |
5 | import torch
6 |
7 | import transforms
8 | from my_dataset import CocoDetection
9 | from backbone import resnet50_fpn_backbone, resnet101_fpn_backbone
10 | from network_files import FasterRCNN, FastRCNNPredictor
11 | import train_utils.train_eval_utils as utils
12 | from train_utils import GroupedBatchSampler, create_aspect_ratio_groups, init_distributed_mode, save_on_master, mkdir
13 |
14 |
15 | def create_model(num_classes, load_pretrain_weights=True):
16 |     # note: the backbone uses FrozenBatchNorm2d by default, i.e. the bn parameters are not updated
17 |     # this avoids degraded results when batch_size is very small (with little GPU memory, keep the default FrozenBatchNorm2d)
18 |     # with plenty of GPU memory you can use a large batch_size and set norm_layer to a regular BatchNorm2d
19 |     # trainable_layers covers ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']; 5 means all of them are trained
20 | # resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
21 | backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d,
22 | trainable_layers=4)
23 | #backbone = resnet101_fpn_backbone(norm_layer=torch.nn.BatchNorm2d,
24 | # trainable_layers=4)
25 |     # when training on your own dataset do not change the 91 here; change the num_classes argument that is passed in
26 | model = FasterRCNN(backbone=backbone, num_classes=91)
27 |
28 | if load_pretrain_weights:
29 |         # load pretrained model weights
30 | # https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
31 | weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
32 | missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
33 | if len(missing_keys) != 0 or len(unexpected_keys) != 0:
34 | print("missing_keys: ", missing_keys)
35 | print("unexpected_keys: ", unexpected_keys)
36 |
37 | # get number of input features for the classifier
38 | in_features = model.roi_heads.box_predictor.cls_score.in_features
39 | # replace the pre-trained head with a new one
40 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
41 |
42 | return model
43 |
44 |
45 | def main(args):
46 | init_distributed_mode(args)
47 | print(args)
48 |
49 | device = torch.device(args.device)
50 |
51 |     # file used to record the coco_info metrics
52 | results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
53 |
54 | # Data loading code
55 | print("Loading data")
56 |
57 | data_transform = {
58 | "train": transforms.Compose([transforms.ToTensor(),
59 | transforms.RandomHorizontalFlip(0.5)]),
60 | "val": transforms.Compose([transforms.ToTensor()])
61 | }
62 |
63 | COCO_root = args.data_path
64 |
65 | # load train data set
66 | # coco2017 -> annotations -> instances_train2017.json
67 | train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
68 |
69 | # load validation data set
70 | # coco2017 -> annotations -> instances_val2017.json
71 | val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
72 |
73 | print("Creating data loaders")
74 | if args.distributed:
75 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
76 | test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
77 | else:
78 | train_sampler = torch.utils.data.RandomSampler(train_dataset)
79 | test_sampler = torch.utils.data.SequentialSampler(val_dataset)
80 |
81 | if args.aspect_ratio_group_factor >= 0:
82 |         # bucket every image's aspect ratio into one of the bins and record the bin index
83 | group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
84 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
85 | else:
86 | train_batch_sampler = torch.utils.data.BatchSampler(
87 | train_sampler, args.batch_size, drop_last=True)
88 |
89 | data_loader = torch.utils.data.DataLoader(
90 | train_dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
91 | collate_fn=train_dataset.collate_fn)
92 |
93 | data_loader_test = torch.utils.data.DataLoader(
94 | val_dataset, batch_size=1,
95 | sampler=test_sampler, num_workers=args.workers,
96 | collate_fn=train_dataset.collate_fn)
97 |
98 | print("Creating model")
99 |     # create the model; num_classes equals background + number of object classes
100 | model = create_model(num_classes=args.num_classes + 1)
101 | model.to(device)
102 |
103 | model_without_ddp = model
104 | if args.distributed:
105 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
106 | model_without_ddp = model.module
107 |
108 | params = [p for p in model.parameters() if p.requires_grad]
109 | optimizer = torch.optim.SGD(
110 | params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
111 |
112 | scaler = torch.cuda.amp.GradScaler() if args.amp else None
113 |
114 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
115 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
116 |
117 |     # if a resume checkpoint path (weights from a previous run) is given, continue training from it
118 | if args.resume:
119 | # If map_location is missing, torch.load will first load the module to CPU
120 | # and then copy each parameter to where it was saved,
121 | # which would result in all processes on the same machine using the same set of devices.
122 |         checkpoint = torch.load(args.resume, map_location='cpu')  # load the saved checkpoint (including optimizer and lr scheduler state)
123 | model_without_ddp.load_state_dict(checkpoint['model'])
124 | optimizer.load_state_dict(checkpoint['optimizer'])
125 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
126 | args.start_epoch = checkpoint['epoch'] + 1
127 | if args.amp and "scaler" in checkpoint:
128 | scaler.load_state_dict(checkpoint["scaler"])
129 |
130 | train_loss = []
131 | learning_rate = []
132 | val_map = []
133 |
134 | print("Start training")
135 | start_time = time.time()
136 | for epoch in range(args.start_epoch, args.epochs):
137 | if args.distributed:
138 | train_sampler.set_epoch(epoch)
139 | mean_loss, lr = utils.train_one_epoch(model, optimizer, data_loader,
140 | device, epoch, args.print_freq,
141 | warmup=True, scaler=scaler)
142 |
143 | # update learning rate
144 | lr_scheduler.step()
145 |
146 | # evaluate after every epoch
147 | coco_info = utils.evaluate(model, data_loader_test, device=device)
148 |
149 |         # only write results on the main process
150 | if args.rank in [-1, 0]:
151 | train_loss.append(mean_loss.item())
152 | learning_rate.append(lr)
153 | val_map.append(coco_info[1]) # pascal mAP
154 |
155 | # write into txt
156 | with open(results_file, "a") as f:
157 |             # the written data includes the coco metrics plus the loss and learning rate
158 | result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
159 | txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
160 | f.write(txt + "\n")
161 |
162 | if args.output_dir:
163 |             # only save weights on the main node
164 | save_files = {'model': model_without_ddp.state_dict(),
165 | 'optimizer': optimizer.state_dict(),
166 | 'lr_scheduler': lr_scheduler.state_dict(),
167 | 'args': args,
168 | 'epoch': epoch}
169 | if args.amp:
170 | save_files["scaler"] = scaler.state_dict()
171 | save_on_master(save_files,
172 | os.path.join(args.output_dir, f'model_{epoch}.pth'))
173 |
174 | total_time = time.time() - start_time
175 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
176 | print('Training time {}'.format(total_time_str))
177 |
178 | if args.rank in [-1, 0]:
179 | # plot loss and lr curve
180 | if len(train_loss) != 0 and len(learning_rate) != 0:
181 | from plot_curve import plot_loss_and_lr
182 | plot_loss_and_lr(train_loss, learning_rate)
183 |
184 | # plot mAP curve
185 | if len(val_map) != 0:
186 | from plot_curve import plot_map
187 | plot_map(val_map)
188 |
189 |
190 | if __name__ == "__main__":
191 | import argparse
192 |
193 | parser = argparse.ArgumentParser(
194 | description=__doc__)
195 |
196 |     # root directory of the training data (coco2017)
197 | parser.add_argument('--data-path', default='/root/wu_datasets/ReCurrentPapper/data/coco2017', help='dataset')
198 |     # training device type
199 | parser.add_argument('--device', default='cuda', help='device')
200 |     # number of object classes to detect (excluding background)
201 | parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
202 |     # batch_size per GPU
203 | parser.add_argument('-b', '--batch-size', default=2, type=int,
204 | help='images per gpu, the total batch size is $NGPU x batch_size')
205 |     # epoch to resume training from
206 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
207 |     # total number of training epochs
208 | parser.add_argument('--epochs', default=12, type=int, metavar='N',
209 | help='number of total epochs to run')
210 |     # number of data loading / preprocessing workers
211 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
212 | help='number of data loading workers (default: 4)')
213 |     # learning rate; scale it with the number of GPUs and the batch size: 0.02 / 8 * num_GPU
214 | parser.add_argument('--lr', default=0.02, type=float,
215 | help='initial learning rate, 0.02 is the default value for training '
216 | 'on 8 gpus and 2 images_per_gpu')
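    # e.g. with 4 GPUs and 2 images per GPU: 0.02 / 8 * 4 = 0.01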
217 |     # SGD momentum
218 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
219 | help='momentum')
220 |     # SGD weight_decay
221 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
222 | metavar='W', help='weight decay (default: 1e-4)',
223 | dest='weight_decay')
224 |     # parameter for torch.optim.lr_scheduler.StepLR
225 | parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
226 |     # parameters for torch.optim.lr_scheduler.MultiStepLR
227 | parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int,
228 | help='decrease lr every step-size epochs')
229 |     # parameter for torch.optim.lr_scheduler.MultiStepLR
230 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
231 |     # how often to print training info
232 | parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
233 |     # directory where outputs are saved
234 | parser.add_argument('--output-dir', default='./multi_train', help='path where to save')
235 |     # resume training from a previous checkpoint
236 | parser.add_argument('--resume', default='', help='resume from checkpoint')
237 | parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
238 |
239 |     # number of processes to launch (processes, not threads)
240 | parser.add_argument('--world-size', default=4, type=int,
241 | help='number of distributed processes')
242 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
243 |     # whether to use mixed-precision training (requires GPU support for mixed precision)
244 | parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
245 |
246 | args = parser.parse_args()
247 |
248 |     # if an output directory is specified, create it when it does not exist
249 | if args.output_dir:
250 | mkdir(args.output_dir)
251 |
252 | main(args)
253 |
--------------------------------------------------------------------------------
/train_res50_fpn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 |
4 | import torch
5 | import math
6 | import transforms
7 | from network_files import FasterRCNN, FastRCNNPredictor
8 | from backbone import resnet50_fpn_backbone
9 | from my_dataset import CocoDetection
10 | from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
11 | from train_utils import train_eval_utils as utils
12 | from network_files import CosineAnnealingWarmbootingLR
13 |
14 |
15 | def create_model(num_classes):
16 |     # note: the backbone uses FrozenBatchNorm2d by default, i.e. the bn parameters are not updated
17 |     # this avoids degraded results when batch_size is very small (with little GPU memory, keep the default FrozenBatchNorm2d)
18 |     # with plenty of GPU memory you can use a large batch_size and set norm_layer to a regular BatchNorm2d
19 |     # trainable_layers covers ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']; 5 means all of them are trained
20 | backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d,
21 | trainable_layers=4)
22 |     # when training on your own dataset do not change the 91 here; change the num_classes argument that is passed in
23 | model = FasterRCNN(backbone=backbone, num_classes=91)
24 |     # load pretrained model weights
25 | # https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
26 | weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
27 | missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
28 | if len(missing_keys) != 0 or len(unexpected_keys) != 0:
29 | print("missing_keys: ", missing_keys)
30 | print("unexpected_keys: ", unexpected_keys)
31 |
32 | # get number of input features for the classifier
33 | in_features = model.roi_heads.box_predictor.cls_score.in_features
34 | # replace the pre-trained head with a new one
35 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
36 |
37 | return model
38 |
39 |
40 | def main(parser_data):
41 | device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
42 | print("Using {} device training.".format(device.type))
43 |
44 |     # file used to record the coco_info metrics
45 | results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
46 |
47 | data_transform = {
48 | "train": transforms.Compose([transforms.ToTensor(),
49 | transforms.RandomHorizontalFlip(0.5)]),
50 | "val": transforms.Compose([transforms.ToTensor()])
51 | }
52 |
53 | COCO_root = args.data_path
54 |
55 | # load train data set
56 | # coco2017 -> annotations -> instances_train2017.json
57 | train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
58 | train_sampler = None
59 |
60 |     # whether to group images with similar aspect ratios into the same batch
61 |     # doing so reduces the GPU memory needed for training; enabled by default
62 | if args.aspect_ratio_group_factor >= 0:
63 | train_sampler = torch.utils.data.RandomSampler(train_dataset)
64 |         # bucket every image's aspect ratio into one of the bins and record the bin index
65 | group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
66 |         # each batch is drawn from a single aspect-ratio bin
67 | train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
68 |
69 |     # note: collate_fn is custom because each sample contains an image and its targets, so the default collate cannot batch them
70 | batch_size = parser_data.batch_size
71 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
72 | print('Using %g dataloader workers' % nw)
73 | if train_sampler:
74 |         # when sampling by aspect ratio, the dataloader must use batch_sampler instead of batch_size
75 | train_data_loader = torch.utils.data.DataLoader(train_dataset,
76 | batch_sampler=train_batch_sampler,
77 | pin_memory=True,
78 | num_workers=nw,
79 | collate_fn=train_dataset.collate_fn)
80 | else:
81 | train_data_loader = torch.utils.data.DataLoader(train_dataset,
82 | batch_size=batch_size,
83 | shuffle=True,
84 | pin_memory=True,
85 | num_workers=nw,
86 | collate_fn=train_dataset.collate_fn)
87 |
88 | # load validation data set
89 | # coco2017 -> annotations -> instances_val2017.json
90 | val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
91 | val_data_loader = torch.utils.data.DataLoader(val_dataset,
92 | batch_size=1,
93 | shuffle=False,
94 | pin_memory=True,
95 | num_workers=nw,
96 | collate_fn=train_dataset.collate_fn)
97 |     # create the model; num_classes equals background + number of object classes
98 | model = create_model(num_classes=parser_data.num_classes + 1)
99 | # print(model)
100 |
101 | model.to(device)
102 |
103 | # define optimizer
104 | params = [p for p in model.parameters() if p.requires_grad]
105 | optimizer = torch.optim.SGD(params, lr=0.005,
106 | momentum=0.9, weight_decay=0.0005)
107 |
108 | scaler = torch.cuda.amp.GradScaler() if args.amp else None
109 |
110 | # learning rate scheduler
111 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
112 | # step_size=3,
113 | # gamma=0.33)
114 | lf = lambda x, y=args.epochs: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.8 + 0.2
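    # the cosine factor decays from 1.0 at x=0 ((1+cos(0))/2*0.8+0.2) to 0.2 at x=y ((1+cos(pi))/2*0.8+0.2)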
115 | # lf = lambda x, y=opt.epochs: (1.0 - (x / y)) * 0.9 + 0.1
116 | lr_scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=args.epochs, steps=args.cawb_steps, step_scale=0.7,
117 | lf=lf, batchs=len(train_dataset), warmup_epoch=0)
118 |
119 |     # if a checkpoint from a previous run is specified, resume training from it
120 | if parser_data.resume != "":
121 | checkpoint = torch.load(parser_data.resume, map_location='cpu')
122 | model.load_state_dict(checkpoint['model'])
123 | optimizer.load_state_dict(checkpoint['optimizer'])
124 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
125 | parser_data.start_epoch = checkpoint['epoch'] + 1
126 | if args.amp and "scaler" in checkpoint:
127 | scaler.load_state_dict(checkpoint["scaler"])
128 |         print("resuming training from epoch {}...".format(parser_data.start_epoch))
129 |
130 | train_loss = []
131 | learning_rate = []
132 | val_map = []
133 |
134 | for epoch in range(parser_data.start_epoch, parser_data.epochs):
135 | # train for one epoch, printing every 10 iterations
136 | mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
137 | device=device, epoch=epoch,
138 | print_freq=50, warmup=True,
139 | scaler=scaler)
140 | train_loss.append(mean_loss.item())
141 | learning_rate.append(lr)
142 |
143 | # update the learning rate
144 | lr_scheduler.step()
145 |
146 | # evaluate on the test dataset
147 | coco_info = utils.evaluate(model, val_data_loader, device=device)
148 |
149 | # write into txt
150 | with open(results_file, "a") as f:
151 |             # the written data includes the coco metrics plus the loss and learning rate
152 | result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
153 | txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
154 | f.write(txt + "\n")
155 |
156 | val_map.append(coco_info[1]) # pascal mAP
157 |
158 | # save weights
159 | save_files = {
160 | 'model': model.state_dict(),
161 | 'optimizer': optimizer.state_dict(),
162 | # 'lr_scheduler': lr_scheduler.state_dict(),
163 | 'epoch': epoch}
164 | if args.amp:
165 | save_files["scaler"] = scaler.state_dict()
166 | torch.save(save_files, "./save_weights/resNetFpn-model-{}.pth".format(epoch))
167 |
168 | # plot loss and lr curve
169 | if len(train_loss) != 0 and len(learning_rate) != 0:
170 | from plot_curve import plot_loss_and_lr
171 | plot_loss_and_lr(train_loss, learning_rate)
172 |
173 | # plot mAP curve
174 | if len(val_map) != 0:
175 | from plot_curve import plot_map
176 | plot_map(val_map)
177 |
178 |
179 | if __name__ == "__main__":
180 | import argparse
181 |
182 | parser = argparse.ArgumentParser(
183 | description=__doc__)
184 |
185 |     # training device type
186 | parser.add_argument('--device', default='cuda:1', help='device')
187 |     # root directory of the training dataset (coco2017)
188 | parser.add_argument('--data-path', default='./coco2017', help='dataset')
189 |     # number of object classes to detect (excluding background)
190 | parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
191 |     # directory where outputs are saved
192 | parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
193 |     # to resume training, specify the checkpoint saved by the previous run
194 | parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
195 |     # epoch to resume training from
196 | parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
197 |     # total number of training epochs
198 | parser.add_argument('--epochs', default=150, type=int, metavar='N',
199 | help='number of total epochs to run')
200 |     # training batch size
201 | parser.add_argument('--batch_size', default=16, type=int, metavar='N',
202 | help='batch size when training.')
203 | parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
204 |     # whether to use mixed-precision training (requires GPU support for mixed precision)
205 | parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
206 |     # steps for the cawb (cosine annealing warm-booting) lr scheduler
207 | parser.add_argument('--cawb_steps', nargs='+', type=int, default=[50, 100, 150],
208 | help='the cawb learning rate scheduler steps')
209 |
210 | args = parser.parse_args()
211 | print(args)
212 |
213 |     # create the folder for saving weights if it does not exist
214 | if not os.path.exists(args.output_dir):
215 | os.makedirs(args.output_dir)
216 |
217 | main(args)
218 |
--------------------------------------------------------------------------------
/train_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
2 | from .distributed_utils import init_distributed_mode, save_on_master, mkdir
3 | from .coco_eval import EvalCOCOMetric
4 |
--------------------------------------------------------------------------------
/train_utils/coco_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import copy
3 |
4 | import numpy as np
5 | from pycocotools.coco import COCO
6 | from pycocotools.cocoeval import COCOeval
7 | import pycocotools.mask as mask_util
8 | from .distributed_utils import all_gather, is_main_process
9 |
10 |
11 | def merge(img_ids, eval_results):
12 |     """Gather data from all processes and merge it"""
13 | all_img_ids = all_gather(img_ids)
14 | all_eval_results = all_gather(eval_results)
15 |
16 | merged_img_ids = []
17 | for p in all_img_ids:
18 | merged_img_ids.extend(p)
19 |
20 | merged_eval_results = []
21 | for p in all_eval_results:
22 | merged_eval_results.extend(p)
23 |
24 | merged_img_ids = np.array(merged_img_ids)
25 |
26 | # keep only unique (and in sorted order) images
27 |     # remove duplicate image ids; in multi-GPU training the same image may be assigned to several processes so every process handles the same number of images
28 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
29 | merged_eval_results = [merged_eval_results[i] for i in idx]
30 |
31 | return list(merged_img_ids), merged_eval_results
32 |
33 |
34 | class EvalCOCOMetric:
35 | def __init__(self,
36 | coco: COCO = None,
37 | iou_type: str = None,
38 | results_file_name: str = "predict_results.json",
39 | classes_mapping: dict = None):
40 | self.coco = copy.deepcopy(coco)
41 |         self.img_ids = []  # ids of the images processed by this process
42 | self.results = []
43 | self.aggregation_results = None
44 | self.classes_mapping = classes_mapping
45 | self.coco_evaluator = None
46 | assert iou_type in ["bbox", "segm", "keypoints"]
47 | self.iou_type = iou_type
48 | self.results_file_name = results_file_name
49 |
50 | def prepare_for_coco_detection(self, targets, outputs):
51 |         """Convert predictions into the format required by COCOeval (object detection task)"""
52 |         # iterate over the predictions of each image
53 | for target, output in zip(targets, outputs):
54 | if len(output) == 0:
55 | continue
56 |
57 | img_id = int(target["image_id"])
58 | if img_id in self.img_ids:
59 |                 # avoid duplicate entries
60 | continue
61 | self.img_ids.append(img_id)
62 | per_image_boxes = output["boxes"]
63 |             # coco_eval expects each box in [x_min, y_min, w, h] format
64 |             # our predicted boxes are [x_min, y_min, x_max, y_max], so convert the format
65 | per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
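            # e.g. an [x_min, y_min, x_max, y_max] box [50, 60, 150, 200] becomes [50, 60, 100, 140]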
66 | per_image_classes = output["labels"].tolist()
67 | per_image_scores = output["scores"].tolist()
68 |
69 | res_list = []
70 |             # iterate over each detected object
71 | for object_score, object_class, object_box in zip(
72 | per_image_scores, per_image_classes, per_image_boxes):
73 | object_score = float(object_score)
74 | class_idx = int(object_class)
75 | if self.classes_mapping is not None:
76 | class_idx = int(self.classes_mapping[str(class_idx)])
77 | # We recommend rounding coordinates to the nearest tenth of a pixel
78 | # to reduce resulting JSON file size.
79 | object_box = [round(b, 2) for b in object_box.tolist()]
80 |
81 | res = {"image_id": img_id,
82 | "category_id": class_idx,
83 | "bbox": object_box,
84 | "score": round(object_score, 3)}
85 | res_list.append(res)
86 | self.results.append(res_list)
87 |
88 | def prepare_for_coco_segmentation(self, targets, outputs):
89 |         """Convert predictions into the format required by COCOeval (instance segmentation task)"""
90 |         # iterate over the predictions of each image
91 | for target, output in zip(targets, outputs):
92 | if len(output) == 0:
93 | continue
94 |
95 | img_id = int(target["image_id"])
96 | if img_id in self.img_ids:
97 |                 # avoid duplicate entries
98 | continue
99 |
100 | self.img_ids.append(img_id)
101 | per_image_masks = output["masks"]
102 | per_image_classes = output["labels"].tolist()
103 | per_image_scores = output["scores"].tolist()
104 |
105 | masks = per_image_masks > 0.5
106 |
107 | res_list = []
108 |             # iterate over each detected object
109 | for mask, label, score in zip(masks, per_image_classes, per_image_scores):
110 | rle = mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
111 | rle["counts"] = rle["counts"].decode("utf-8")
112 |
113 | class_idx = int(label)
114 | if self.classes_mapping is not None:
115 | class_idx = int(self.classes_mapping[str(class_idx)])
116 |
117 | res = {"image_id": img_id,
118 | "category_id": class_idx,
119 | "segmentation": rle,
120 | "score": round(score, 3)}
121 | res_list.append(res)
122 | self.results.append(res_list)
123 |
124 | def update(self, targets, outputs):
125 | if self.iou_type == "bbox":
126 | self.prepare_for_coco_detection(targets, outputs)
127 | elif self.iou_type == "segm":
128 | self.prepare_for_coco_segmentation(targets, outputs)
129 | else:
130 |             raise KeyError(f"unsupported iou_type: {self.iou_type}")
131 |
132 | def synchronize_results(self):
133 |         # synchronize the data across all processes
134 | eval_ids, eval_results = merge(self.img_ids, self.results)
135 | self.aggregation_results = {"img_ids": eval_ids, "results": eval_results}
136 |
137 |         # only the main process needs to save the results
138 | if is_main_process():
139 | results = []
140 | [results.extend(i) for i in eval_results]
141 | # write predict results into json file
142 | json_str = json.dumps(results, indent=4)
143 | with open(self.results_file_name, 'w') as json_file:
144 | json_file.write(json_str)
145 |
146 | def evaluate(self):
147 |         # evaluation only needs to run on the main process
148 | if is_main_process():
149 | # accumulate predictions from all images
150 | coco_true = self.coco
151 | coco_pre = coco_true.loadRes(self.results_file_name)
152 |
153 | self.coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType=self.iou_type)
154 |
155 | self.coco_evaluator.evaluate()
156 | self.coco_evaluator.accumulate()
157 | print(f"IoU metric: {self.iou_type}")
158 | self.coco_evaluator.summarize()
159 |
160 | coco_info = self.coco_evaluator.stats.tolist() # numpy to list
161 | return coco_info
162 | else:
163 | return None
164 |
--------------------------------------------------------------------------------
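
The box-format conversion done in `prepare_for_coco_detection` above can be illustrated with the following standalone sketch (not part of the repository; the image id, category id and score are made-up placeholders):

```python
import torch

# a single predicted box in [x_min, y_min, x_max, y_max] format
boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])
# same in-place conversion as in prepare_for_coco_detection:
# subtract the top-left corner from the bottom-right corner -> [x_min, y_min, w, h]
boxes[:, 2:] -= boxes[:, :2]

res = {"image_id": 42,      # placeholder image id
       "category_id": 1,    # placeholder category id
       "bbox": [round(b, 2) for b in boxes[0].tolist()],  # [10.0, 20.0, 100.0, 200.0]
       "score": 0.987}
print(res)
```
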
/train_utils/distributed_utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, deque
2 | import datetime
3 | import pickle
4 | import time
5 | import errno
6 | import os
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | class SmoothedValue(object):
13 | """Track a series of values and provide access to smoothed values over a
14 | window or the global series average.
15 | """
16 | def __init__(self, window_size=20, fmt=None):
17 | if fmt is None:
18 | fmt = "{value:.4f} ({global_avg:.4f})"
19 |         self.deque = deque(maxlen=window_size)  # a deque is roughly a list with a fixed maximum length
20 | self.total = 0.0
21 | self.count = 0
22 | self.fmt = fmt
23 |
24 | def update(self, value, n=1):
25 | self.deque.append(value)
26 | self.count += n
27 | self.total += value * n
28 |
29 | def synchronize_between_processes(self):
30 | """
31 | Warning: does not synchronize the deque!
32 | """
33 | if not is_dist_avail_and_initialized():
34 | return
35 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
36 | dist.barrier()
37 | dist.all_reduce(t)
38 | t = t.tolist()
39 | self.count = int(t[0])
40 | self.total = t[1]
41 |
42 | @property
43 |     def median(self):  # @property exposes median as a read-only attribute
44 | d = torch.tensor(list(self.deque))
45 | return d.median().item()
46 |
47 | @property
48 | def avg(self):
49 | d = torch.tensor(list(self.deque), dtype=torch.float32)
50 | return d.mean().item()
51 |
52 | @property
53 | def global_avg(self):
54 | return self.total / self.count
55 |
56 | @property
57 | def max(self):
58 | return max(self.deque)
59 |
60 | @property
61 | def value(self):
62 | return self.deque[-1]
63 |
64 | def __str__(self):
65 | return self.fmt.format(
66 | median=self.median,
67 | avg=self.avg,
68 | global_avg=self.global_avg,
69 | max=self.max,
70 | value=self.value)
71 |
72 |
73 | def all_gather(data):
74 | """
75 |     Gather data from every process.
76 | Run all_gather on arbitrary picklable data (not necessarily tensors)
77 | Args:
78 | data: any picklable object
79 | Returns:
80 | list[data]: list of data gathered from each rank
81 | """
82 |     world_size = get_world_size()  # number of processes
83 | if world_size == 1:
84 | return [data]
85 |
86 | data_list = [None] * world_size
87 | dist.all_gather_object(data_list, data)
88 |
89 | return data_list
90 |
91 |
92 | def reduce_dict(input_dict, average=True):
93 | """
94 | Args:
95 | input_dict (dict): all the values will be reduced
96 | average (bool): whether to do average or sum
97 | Reduce the values in the dictionary from all processes so that all processes
98 | have the averaged results. Returns a dict with the same fields as
99 | input_dict, after reduction.
100 | """
101 | world_size = get_world_size()
102 |     if world_size < 2:  # single-GPU case
103 | return input_dict
104 |     with torch.no_grad():  # multi-GPU case
105 | names = []
106 | values = []
107 | # sort the keys so that they are consistent across processes
108 | for k in sorted(input_dict.keys()):
109 | names.append(k)
110 | values.append(input_dict[k])
111 | values = torch.stack(values, dim=0)
112 | dist.all_reduce(values)
113 | if average:
114 | values /= world_size
115 |
116 | reduced_dict = {k: v for k, v in zip(names, values)}
117 | return reduced_dict
118 |
119 |
120 | class MetricLogger(object):
121 | def __init__(self, delimiter="\t"):
122 | self.meters = defaultdict(SmoothedValue)
123 | self.delimiter = delimiter
124 |
125 | def update(self, **kwargs):
126 | for k, v in kwargs.items():
127 | if isinstance(v, torch.Tensor):
128 | v = v.item()
129 | assert isinstance(v, (float, int))
130 | self.meters[k].update(v)
131 |
132 | def __getattr__(self, attr):
133 | if attr in self.meters:
134 | return self.meters[attr]
135 | if attr in self.__dict__:
136 | return self.__dict__[attr]
137 | raise AttributeError("'{}' object has no attribute '{}'".format(
138 | type(self).__name__, attr))
139 |
140 | def __str__(self):
141 | loss_str = []
142 | for name, meter in self.meters.items():
143 | loss_str.append(
144 | "{}: {}".format(name, str(meter))
145 | )
146 | return self.delimiter.join(loss_str)
147 |
148 | def synchronize_between_processes(self):
149 | for meter in self.meters.values():
150 | meter.synchronize_between_processes()
151 |
152 | def add_meter(self, name, meter):
153 | self.meters[name] = meter
154 |
155 | def log_every(self, iterable, print_freq, header=None):
156 | i = 0
157 | if not header:
158 | header = ""
159 | start_time = time.time()
160 | end = time.time()
161 | iter_time = SmoothedValue(fmt='{avg:.4f}')
162 | data_time = SmoothedValue(fmt='{avg:.4f}')
163 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
164 | if torch.cuda.is_available():
165 | log_msg = self.delimiter.join([header,
166 | '[{0' + space_fmt + '}/{1}]',
167 | 'eta: {eta}',
168 | '{meters}',
169 | 'time: {time}',
170 | 'data: {data}',
171 | 'max mem: {memory:.0f}'])
172 | else:
173 | log_msg = self.delimiter.join([header,
174 | '[{0' + space_fmt + '}/{1}]',
175 | 'eta: {eta}',
176 | '{meters}',
177 | 'time: {time}',
178 | 'data: {data}'])
179 | MB = 1024.0 * 1024.0
180 | for obj in iterable:
181 | data_time.update(time.time() - end)
182 | yield obj
183 | iter_time.update(time.time() - end)
184 | if i % print_freq == 0 or i == len(iterable) - 1:
185 | eta_second = iter_time.global_avg * (len(iterable) - i)
186 | eta_string = str(datetime.timedelta(seconds=eta_second))
187 | if torch.cuda.is_available():
188 | print(log_msg.format(i, len(iterable),
189 | eta=eta_string,
190 | meters=str(self),
191 | time=str(iter_time),
192 | data=str(data_time),
193 | memory=torch.cuda.max_memory_allocated() / MB))
194 | else:
195 | print(log_msg.format(i, len(iterable),
196 | eta=eta_string,
197 | meters=str(self),
198 | time=str(iter_time),
199 | data=str(data_time)))
200 | i += 1
201 | end = time.time()
202 | total_time = time.time() - start_time
203 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
204 |         print('{} Total time: {} ({:.4f} s / it)'.format(header,
205 |                                                           total_time_str,
206 |                                                           total_time / len(iterable)))
208 |
209 |
210 | def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
211 |
212 | def f(x):
213 |         """Return a learning-rate multiplier for the given step."""
214 |         if x >= warmup_iters:  # once the step count reaches warmup_iters, the factor stays at 1
215 | return 1
216 | alpha = float(x) / warmup_iters
217 |         # during warmup the factor increases linearly from warmup_factor to 1
218 | return warmup_factor * (1 - alpha) + alpha
219 |
220 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
221 |
222 |
223 | def mkdir(path):
224 | try:
225 | os.makedirs(path)
226 | except OSError as e:
227 | if e.errno != errno.EEXIST:
228 | raise
229 |
230 |
231 | def setup_for_distributed(is_master):
232 | """
233 |     This function disables printing when not in the master process.
234 | """
235 | import builtins as __builtin__
236 | builtin_print = __builtin__.print
237 |
238 | def print(*args, **kwargs):
239 | force = kwargs.pop('force', False)
240 | if is_master or force:
241 | builtin_print(*args, **kwargs)
242 |
243 | __builtin__.print = print
244 |
245 |
246 | def is_dist_avail_and_initialized():
247 |     """Check whether the distributed environment is available and initialized."""
248 | if not dist.is_available():
249 | return False
250 | if not dist.is_initialized():
251 | return False
252 | return True
253 |
254 |
255 | def get_world_size():
256 | if not is_dist_avail_and_initialized():
257 | return 1
258 | return dist.get_world_size()
259 |
260 |
261 | def get_rank():
262 | if not is_dist_avail_and_initialized():
263 | return 0
264 | return dist.get_rank()
265 |
266 |
267 | def is_main_process():
268 | return get_rank() == 0
269 |
270 |
271 | def save_on_master(*args, **kwargs):
272 | if is_main_process():
273 | torch.save(*args, **kwargs)
274 |
275 |
276 | def init_distributed_mode(args):
277 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
278 | args.rank = int(os.environ["RANK"])
279 | args.world_size = int(os.environ['WORLD_SIZE'])
280 | args.gpu = int(os.environ['LOCAL_RANK'])
281 | elif 'SLURM_PROCID' in os.environ:
282 | args.rank = int(os.environ['SLURM_PROCID'])
283 | args.gpu = args.rank % torch.cuda.device_count()
284 | else:
285 | print('Not using distributed mode')
286 | args.distributed = False
287 | return
288 |
289 | args.distributed = True
290 |
291 | torch.cuda.set_device(args.gpu)
292 | args.dist_backend = 'nccl'
293 | print('| distributed init (rank {}): {}'.format(
294 | args.rank, args.dist_url), flush=True)
295 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
296 | world_size=args.world_size, rank=args.rank)
297 | torch.distributed.barrier()
298 | setup_for_distributed(args.rank == 0)
299 |
300 |
--------------------------------------------------------------------------------
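
A minimal usage sketch for the `MetricLogger` / `SmoothedValue` utilities defined above (not part of the repository; the loss values and learning rate are placeholders):

```python
from train_utils.distributed_utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

for loss in [0.92, 0.81, 0.75]:          # placeholder loss values
    logger.update(loss=loss, lr=0.005)   # every keyword argument becomes a SmoothedValue meter

print(logger)                  # e.g. "loss: 0.7500 (0.8267)  lr: 0.005000"
print(logger.loss.global_avg)  # meters are also reachable as attributes via __getattr__
```
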
/train_utils/group_by_aspect_ratio.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | from collections import defaultdict
3 | import copy
4 | from itertools import repeat, chain
5 | import math
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data
10 | from torch.utils.data.sampler import BatchSampler, Sampler
11 | from torch.utils.model_zoo import tqdm
12 | import torchvision
13 |
14 | from PIL import Image
15 |
16 |
17 | def _repeat_to_at_least(iterable, n):
18 | repeat_times = math.ceil(n / len(iterable))
19 | repeated = chain.from_iterable(repeat(iterable, repeat_times))
20 | return list(repeated)
21 |
22 |
23 | class GroupedBatchSampler(BatchSampler):
24 | """
25 | Wraps another sampler to yield a mini-batch of indices.
26 |     It enforces that each batch only contains elements from the same group.
27 |     It also tries to provide mini-batches that follow an ordering as close
28 |     as possible to the ordering of the original sampler.
29 | Arguments:
30 | sampler (Sampler): Base sampler.
31 | group_ids (list[int]): If the sampler produces indices in range [0, N),
32 | `group_ids` must be a list of `N` ints which contains the group id of each sample.
33 | The group ids must be a continuous set of integers starting from
34 | 0, i.e. they must be in the range [0, num_groups).
35 | batch_size (int): Size of mini-batch.
36 | """
37 | def __init__(self, sampler, group_ids, batch_size):
38 | if not isinstance(sampler, Sampler):
39 | raise ValueError(
40 | "sampler should be an instance of "
41 | "torch.utils.data.Sampler, but got sampler={}".format(sampler)
42 | )
43 | self.sampler = sampler
44 | self.group_ids = group_ids
45 | self.batch_size = batch_size
46 |
47 | def __iter__(self):
48 | buffer_per_group = defaultdict(list)
49 | samples_per_group = defaultdict(list)
50 |
51 | num_batches = 0
52 | for idx in self.sampler:
53 | group_id = self.group_ids[idx]
54 | buffer_per_group[group_id].append(idx)
55 | samples_per_group[group_id].append(idx)
56 | if len(buffer_per_group[group_id]) == self.batch_size:
57 | yield buffer_per_group[group_id]
58 | num_batches += 1
59 | del buffer_per_group[group_id]
60 | assert len(buffer_per_group[group_id]) < self.batch_size
61 |
62 | # now we have run out of elements that satisfy
63 | # the group criteria, let's return the remaining
64 | # elements so that the size of the sampler is
65 | # deterministic
66 | expected_num_batches = len(self)
67 | num_remaining = expected_num_batches - num_batches
68 | if num_remaining > 0:
69 | # for the remaining batches, take first the buffers with largest number
70 | # of elements
71 | for group_id, _ in sorted(buffer_per_group.items(),
72 | key=lambda x: len(x[1]), reverse=True):
73 | remaining = self.batch_size - len(buffer_per_group[group_id])
74 | samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
75 | buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
76 | assert len(buffer_per_group[group_id]) == self.batch_size
77 | yield buffer_per_group[group_id]
78 | num_remaining -= 1
79 | if num_remaining == 0:
80 | break
81 | assert num_remaining == 0
82 |
83 | def __len__(self):
84 | return len(self.sampler) // self.batch_size
85 |
86 |
87 | def _compute_aspect_ratios_slow(dataset, indices=None):
88 | print("Your dataset doesn't support the fast path for "
89 | "computing the aspect ratios, so will iterate over "
90 | "the full dataset and load every image instead. "
91 | "This might take some time...")
92 | if indices is None:
93 | indices = range(len(dataset))
94 |
95 | class SubsetSampler(Sampler):
96 | def __init__(self, indices):
97 | self.indices = indices
98 |
99 | def __iter__(self):
100 | return iter(self.indices)
101 |
102 | def __len__(self):
103 | return len(self.indices)
104 |
105 | sampler = SubsetSampler(indices)
106 | data_loader = torch.utils.data.DataLoader(
107 | dataset, batch_size=1, sampler=sampler,
108 | num_workers=14, # you might want to increase it for faster processing
109 | collate_fn=lambda x: x[0])
110 | aspect_ratios = []
111 | with tqdm(total=len(dataset)) as pbar:
112 | for _i, (img, _) in enumerate(data_loader):
113 | pbar.update(1)
114 | height, width = img.shape[-2:]
115 | aspect_ratio = float(width) / float(height)
116 | aspect_ratios.append(aspect_ratio)
117 | return aspect_ratios
118 |
119 |
120 | def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
121 | if indices is None:
122 | indices = range(len(dataset))
123 | aspect_ratios = []
124 | for i in indices:
125 | height, width = dataset.get_height_and_width(i)
126 | aspect_ratio = float(width) / float(height)
127 | aspect_ratios.append(aspect_ratio)
128 | return aspect_ratios
129 |
130 |
131 | def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
132 | if indices is None:
133 | indices = range(len(dataset))
134 | aspect_ratios = []
135 | for i in indices:
136 | img_info = dataset.coco.imgs[dataset.ids[i]]
137 | aspect_ratio = float(img_info["width"]) / float(img_info["height"])
138 | aspect_ratios.append(aspect_ratio)
139 | return aspect_ratios
140 |
141 |
142 | def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
143 | if indices is None:
144 | indices = range(len(dataset))
145 | aspect_ratios = []
146 | for i in indices:
147 | # this doesn't load the data into memory, because PIL loads it lazily
148 | width, height = Image.open(dataset.images[i]).size
149 | aspect_ratio = float(width) / float(height)
150 | aspect_ratios.append(aspect_ratio)
151 | return aspect_ratios
152 |
153 |
154 | def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
155 | if indices is None:
156 | indices = range(len(dataset))
157 |
158 | ds_indices = [dataset.indices[i] for i in indices]
159 | return compute_aspect_ratios(dataset.dataset, ds_indices)
160 |
161 |
162 | def compute_aspect_ratios(dataset, indices=None):
163 | if hasattr(dataset, "get_height_and_width"):
164 | return _compute_aspect_ratios_custom_dataset(dataset, indices)
165 |
166 | if isinstance(dataset, torchvision.datasets.CocoDetection):
167 | return _compute_aspect_ratios_coco_dataset(dataset, indices)
168 |
169 | if isinstance(dataset, torchvision.datasets.VOCDetection):
170 | return _compute_aspect_ratios_voc_dataset(dataset, indices)
171 |
172 | if isinstance(dataset, torch.utils.data.Subset):
173 | return _compute_aspect_ratios_subset_dataset(dataset, indices)
174 |
175 | # slow path
176 | return _compute_aspect_ratios_slow(dataset, indices)
177 |
178 |
179 | def _quantize(x, bins):
180 | bins = copy.deepcopy(bins)
181 | bins = sorted(bins)
182 |     # bisect_right returns the index where y would be inserted into the sorted bins (to the right of equal values)
183 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
184 | return quantized
185 |
186 |
187 | def create_aspect_ratio_groups(dataset, k=0):
188 |     # compute the width/height ratio of every image in the dataset
189 |     aspect_ratios = compute_aspect_ratios(dataset)
190 |     # build 2*k+1 bin edges spaced evenly (on a log scale) over the interval [0.5, 2]
191 |     bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
192 |
193 |     # map every aspect ratio to the index of the bin it falls into
194 |     groups = _quantize(aspect_ratios, bins)
195 |     # count number of elements per group
196 |     # (how many images fall into each bin)
197 | counts = np.unique(groups, return_counts=True)[1]
198 | fbins = [0] + bins + [np.inf]
199 | print("Using {} as bins for aspect ratio quantization".format(fbins))
200 | print("Count of instances per bin: {}".format(counts))
201 | return groups
202 |
--------------------------------------------------------------------------------
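
A sketch of how `create_aspect_ratio_groups` and `GroupedBatchSampler` are typically wired into a DataLoader (not repository code; it assumes `train_dataset` has already been built, e.g. the CocoDetection dataset from my_dataset.py, and exposes a `collate_fn`):

```python
import torch
from train_utils.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups

# `train_dataset` is assumed to exist already (e.g. CocoDetection from my_dataset.py)
train_sampler = torch.utils.data.RandomSampler(train_dataset)
# group images by aspect ratio so each batch contains images of similar shape
group_ids = create_aspect_ratio_groups(train_dataset, k=3)   # 2*3+1 bin edges over [0.5, 2]
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size=8)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_sampler=train_batch_sampler,
                                           num_workers=4,
                                           collate_fn=train_dataset.collate_fn)
```
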
/train_utils/train_eval_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import time
4 |
5 | import torch
6 |
7 | import train_utils.distributed_utils as utils
8 | from .coco_eval import EvalCOCOMetric
9 |
10 |
11 | def train_one_epoch(model, optimizer, data_loader, device, epoch,
12 | print_freq=50, warmup=False, scaler=None):
13 | model.train()
14 | metric_logger = utils.MetricLogger(delimiter=" ")
15 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
16 | header = 'Epoch: [{}]'.format(epoch)
17 |
18 | lr_scheduler = None
19 |     if epoch == 0 and warmup is True:  # enable warmup during the first epoch (epoch=0), i.e. a gentle warm-up phase
20 | warmup_factor = 1.0 / 1000
21 | warmup_iters = min(1000, len(data_loader) - 1)
22 |
23 | lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
24 |
25 | mloss = torch.zeros(1).to(device) # mean losses
26 | for i, [images, targets] in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
27 | images = list(image.to(device) for image in images)
28 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
29 |
30 |         # autocast context manager for mixed-precision training; it has no effect when running on CPU
31 | with torch.cuda.amp.autocast(enabled=scaler is not None):
32 | loss_dict = model(images, targets)
33 |
34 | losses = sum(loss for loss in loss_dict.values())
35 |
36 | # reduce losses over all GPUs for logging purpose
37 | loss_dict_reduced = utils.reduce_dict(loss_dict)
38 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
39 |
40 | loss_value = losses_reduced.item()
41 |         # keep a running mean of the training loss
42 | mloss = (mloss * i + loss_value) / (i + 1) # update mean losses
43 |
44 |         if not math.isfinite(loss_value):  # stop training if the loss becomes infinite or NaN
45 | print("Loss is {}, stopping training".format(loss_value))
46 | print(loss_dict_reduced)
47 | sys.exit(1)
48 |
49 | optimizer.zero_grad()
50 | if scaler is not None:
51 | scaler.scale(losses).backward()
52 | scaler.step(optimizer)
53 | scaler.update()
54 | else:
55 | losses.backward()
56 | optimizer.step()
57 |
58 |         if lr_scheduler is not None:  # step the warmup scheduler during the first epoch
59 | lr_scheduler.step()
60 |
61 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
62 | now_lr = optimizer.param_groups[0]["lr"]
63 | metric_logger.update(lr=now_lr)
64 |
65 | return mloss, now_lr
66 |
67 |
68 | @torch.no_grad()
69 | def evaluate(model, data_loader, device):
70 | cpu_device = torch.device("cpu")
71 | model.eval()
72 | metric_logger = utils.MetricLogger(delimiter=" ")
73 | header = "Test: "
74 |
75 | det_metric = EvalCOCOMetric(data_loader.dataset.coco, iou_type="bbox", results_file_name="det_results.json")
76 | for image, targets in metric_logger.log_every(data_loader, 100, header):
77 | image = list(img.to(device) for img in image)
78 |
79 |         # skip GPU-specific calls when running on CPU
80 | if device != torch.device("cpu"):
81 | torch.cuda.synchronize(device)
82 |
83 | model_time = time.time()
84 | outputs = model(image)
85 |
86 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
87 | model_time = time.time() - model_time
88 |
89 | det_metric.update(targets, outputs)
90 | metric_logger.update(model_time=model_time)
91 |
92 | # gather the stats from all processes
93 | metric_logger.synchronize_between_processes()
94 | print("Averaged stats:", metric_logger)
95 |
96 |     # synchronize the results across all processes
97 | det_metric.synchronize_results()
98 |
99 | if utils.is_main_process():
100 | coco_info = det_metric.evaluate()
101 | else:
102 | coco_info = None
103 |
104 | return coco_info
105 |
--------------------------------------------------------------------------------
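
A sketch of how `train_one_epoch` and `evaluate` are usually driven from a training script (not repository code; `model`, `train_loader`, `val_loader` and `device` are assumed to exist, e.g. as built in train_res50_fpn.py, and the optimizer/scheduler hyperparameters are placeholders):

```python
import torch
from train_utils.train_eval_utils import train_one_epoch, evaluate

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None  # optional AMP

for epoch in range(26):
    mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device,
                                    epoch, print_freq=50, warmup=True, scaler=scaler)
    lr_scheduler.step()
    coco_info = evaluate(model, val_loader, device=device)  # COCO stats (main process only)
```
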
/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | from torchvision.transforms import functional as F
3 |
4 |
5 | class Compose(object):
6 |     """Compose several transform functions."""
7 | def __init__(self, transforms):
8 | self.transforms = transforms
9 |
10 | def __call__(self, image, target):
11 | for t in self.transforms:
12 | image, target = t(image, target)
13 | return image, target
14 |
15 |
16 | class ToTensor(object):
17 |     """Convert a PIL image to a Tensor."""
18 | def __call__(self, image, target):
19 | image = F.to_tensor(image)
20 | return image, target
21 |
22 |
23 | class RandomHorizontalFlip(object):
24 |     """Randomly flip the image and its bboxes horizontally."""
25 | def __init__(self, prob=0.5):
26 | self.prob = prob
27 |
28 | def __call__(self, image, target):
29 | if random.random() < self.prob:
30 | height, width = image.shape[-2:]
31 |             image = image.flip(-1)  # flip the image horizontally
32 | bbox = target["boxes"]
33 | # bbox: xmin, ymin, xmax, ymax
34 |             bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the corresponding bbox x-coordinates
35 | target["boxes"] = bbox
36 | return image, target
37 |
--------------------------------------------------------------------------------
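
The transforms above are typically composed as follows (a sketch mirroring the `data_transform` dict used in the training and validation scripts); note that `ToTensor` must come before `RandomHorizontalFlip`, since the flip operates on tensor dimensions:

```python
import transforms

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}
```
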
/validation.py:
--------------------------------------------------------------------------------
1 | """
2 | This script loads trained model weights and computes the COCO metrics on the
3 | validation/test set, as well as the per-class mAP (IoU=0.5).
4 | """
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torchvision
11 | from tqdm import tqdm
12 | import numpy as np
13 | from torchvision.models.feature_extraction import create_feature_extractor
14 |
15 | import transforms
16 | from network_files import FasterRCNN, AnchorsGenerator
17 | from my_dataset import CocoDetection
18 | from backbone import resnet50_fpn_backbone
19 | from train_utils import EvalCOCOMetric
20 |
21 |
22 | def summarize(self, catId=None):
23 | """
24 | Compute and display summary metrics for evaluation results.
25 |     Note this function can *only* be applied to the default parameter setting
26 | """
27 |
28 | def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
29 | p = self.params
30 | iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
31 | titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
32 | typeStr = '(AP)' if ap == 1 else '(AR)'
33 | iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
34 | if iouThr is None else '{:0.2f}'.format(iouThr)
35 |
36 | aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
37 | mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
38 |
39 | if ap == 1:
40 | # dimension of precision: [TxRxKxAxM]
41 | s = self.eval['precision']
42 | # IoU
43 | if iouThr is not None:
44 | t = np.where(iouThr == p.iouThrs)[0]
45 | s = s[t]
46 |
47 | if isinstance(catId, int):
48 | s = s[:, :, catId, aind, mind]
49 | else:
50 | s = s[:, :, :, aind, mind]
51 |
52 | else:
53 | # dimension of recall: [TxKxAxM]
54 | s = self.eval['recall']
55 | if iouThr is not None:
56 | t = np.where(iouThr == p.iouThrs)[0]
57 | s = s[t]
58 |
59 | if isinstance(catId, int):
60 | s = s[:, catId, aind, mind]
61 | else:
62 | s = s[:, :, aind, mind]
63 |
64 | if len(s[s > -1]) == 0:
65 | mean_s = -1
66 | else:
67 | mean_s = np.mean(s[s > -1])
68 |
69 | print_string = iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)
70 | return mean_s, print_string
71 |
72 |     if not self.eval:
73 |         raise Exception('Please run accumulate() first')
74 |
75 |     stats, print_list = [0] * 12, [""] * 12
76 |     stats[0], print_list[0] = _summarize(1)
77 |     stats[1], print_list[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
78 |     stats[2], print_list[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
79 |     stats[3], print_list[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
80 |     stats[4], print_list[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
81 |     stats[5], print_list[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
82 |     stats[6], print_list[6] = _summarize(0, maxDets=self.params.maxDets[0])
83 |     stats[7], print_list[7] = _summarize(0, maxDets=self.params.maxDets[1])
84 |     stats[8], print_list[8] = _summarize(0, maxDets=self.params.maxDets[2])
85 |     stats[9], print_list[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
86 |     stats[10], print_list[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
87 |     stats[11], print_list[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
88 |
89 |     print_info = "\n".join(print_list)
90 |
91 | return stats, print_info
92 |
93 |
94 | def main(parser_data):
95 | device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
96 |     print("Using {} device for validation.".format(device.type))
97 |
98 | data_transform = {
99 | "val": transforms.Compose([transforms.ToTensor()])
100 | }
101 |
102 |     # read the class index file
103 | label_json_path = './coco91_indices.json'
104 |     assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
105 | with open(label_json_path, 'r') as f:
106 | category_index = json.load(f)
107 |
108 | coco_root = parser_data.data_path
109 |
110 |     # note: a custom collate_fn is needed here because each sample contains both an image and its targets, so the default batching cannot be used
111 | batch_size = parser_data.batch_size
112 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
113 | print('Using %g dataloader workers' % nw)
114 |
115 | # load validation data set
116 | val_dataset = CocoDetection(coco_root, "val", data_transform["val"])
117 | val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
118 | batch_size=batch_size,
119 | shuffle=False,
120 | pin_memory=True,
121 | num_workers=nw,
122 | collate_fn=val_dataset.collate_fn)
123 |
124 | # create model
125 | backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
126 | model = FasterRCNN(backbone=backbone, num_classes=parser_data.num_classes + 1)
127 |
128 |     # load your own trained model weights
129 | weights_path = parser_data.weights
130 | assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
131 | model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
132 | # print(model)
133 |
134 | model.to(device)
135 |
136 | # evaluate on the val dataset
137 | cpu_device = torch.device("cpu")
138 |
139 | det_metric = EvalCOCOMetric(val_dataset.coco, "bbox", "det_results.json")
140 | model.eval()
141 | with torch.no_grad():
142 | for image, targets in tqdm(val_dataset_loader, desc="validation..."):
143 |             # move the images to the specified device
144 | image = list(img.to(device) for img in image)
145 |
146 | # inference
147 | outputs = model(image)
148 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
149 | det_metric.update(targets, outputs)
150 |
151 | det_metric.synchronize_results()
152 | det_metric.evaluate()
153 |
154 | # calculate COCO info for all classes
155 | coco_stats, print_coco = summarize(det_metric.coco_evaluator)
156 |
157 |     # calculate VOC-style mAP (IoU=0.5) for each class
158 | voc_map_info_list = []
159 | classes = [v for v in category_index.values() if v != "N/A"]
160 | for i in range(len(classes)):
161 | stats, _ = summarize(det_metric.coco_evaluator, catId=i)
162 | voc_map_info_list.append(" {:15}: {}".format(classes[i], stats[1]))
163 |
164 | print_voc = "\n".join(voc_map_info_list)
165 | print(print_voc)
166 |
167 |     # save the validation results to a txt file
168 | with open("record_mAP.txt", "w") as f:
169 | record_lines = ["COCO results:",
170 | print_coco,
171 | "",
172 | "mAP(IoU=0.5) for each category:",
173 | print_voc]
174 | f.write("\n".join(record_lines))
175 |
176 |
177 | if __name__ == "__main__":
178 | import argparse
179 |
180 | parser = argparse.ArgumentParser(
181 | description=__doc__)
182 |
183 |     # device type to use
184 | parser.add_argument('--device', default='cuda', help='device')
185 |
186 |     # number of object classes to detect
187 | parser.add_argument('--num-classes', type=int, default=90, help='number of classes')
188 |
189 |     # dataset root directory (the coco2017 root)
190 | parser.add_argument('--data-path', default='./coco2017', help='dataset root')
191 |
192 |     # path to the trained weights file
193 |     parser.add_argument('--weights', default='./multi_train/model_25.pth', type=str, help='path to the trained weights')
194 |
195 | # batch size
196 |     parser.add_argument('--batch_size', default=1, type=int, metavar='N',
197 |                         help='batch size used for validation')
198 |
199 | args = parser.parse_args()
200 |
201 | main(args)
--------------------------------------------------------------------------------