├── 05_OHEM
└── .gitkeep
├── 07_SSD
└── .gitkeep
├── 08_R_FCN
└── .gitkeep
├── 10_FPN
└── .gitkeep
├── 15_M2Det
└── .gitkeep
├── 03_Fast_RCNN
└── .gitkeep
├── 09_YOLO_v2
└── .gitkeep
├── 11_RetinaNet
└── .gitkeep
├── 12_Mask_RCNN
└── .gitkeep
├── 13_YOLO_v3
└── .gitkeep
├── 14_RefineDet
└── .gitkeep
├── .gitignore
├── 01_RCNN
└── model.py
├── Readme.md
├── 04_Faster_RCNN
└── model.py
└── 06_YOLO_v1
├── dataset.py
└── model.py
/05_OHEM/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/07_SSD/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/08_R_FCN/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/10_FPN/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/15_M2Det/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/03_Fast_RCNN/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/09_YOLO_v2/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/11_RetinaNet/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/12_Mask_RCNN/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/13_YOLO_v3/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/14_RefineDet/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.gitkeep
--------------------------------------------------------------------------------
/01_RCNN/model.py:
--------------------------------------------------------------------------------
1 | """
2 | pytorch内置的GeneralizeRCNN的实现已经不是原始的RCNN了
3 | 原始的RCNN非常朴素地先提取出2000个region,然后对每个region提取特征并进去后面的分类和回归模型,计算量非常庞大
4 | GeneralizeRCNN的实现中采用了FastRCNN的思路,先提取出所有feature map,然后在原图上提取region并映射会feature map上,然后再进入最后的模型
5 |
6 | 由于原始的RCNN的效率上确实比较差,现在也几乎没有单独使用它的必要了,暂时没有什么动力用pytorch实现它,这里先占个坑吧。
7 | """
8 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | 本项目使用Pytorch实现(我认为的)目标检测中的里程碑模型。
2 |
在自身学习之余,我希望能帮助读者们通过阅读风格一致的代码,来学习、使用这些模型,减少学习成本。
3 | 因为能力有限或者其他原因,我不能保证模型都完美地复现论文的细节,只希望能抓(一些)大放小。同时,我本人也不期望这里的模型能达到生产级别。
4 | 同时我想提醒读者,本项目仅仅提供了一些模型上的信息,事实上本项目涉及的这些里程碑模型在别的地方也很有可取之处,甚至起到了至关重要的作用。比如yolo v2在ImageNet的finetune过程,这些
5 | 都是本项目力所不逮的地方。
6 |
7 | This Repository uses Pytorch to implement milestone models(in my opinion) of objection detection.
8 |
I hope it can help you study and use these great models with less learning cost by same style codes.
9 | Due to limited capacity, I won't implement all the details of origin papers but only focus on (some)key points.
10 | Besides, models here are not necessarily good enough to be used directly for production environment. And I really want to remind you that this repository only provide some infomation of models, while these related milestone actually did something great beyond model architecture, which may be even more important than models. For example, yolo v2 finetune model on ImageNet to make model adaptive to 448 * 448 image.
11 |
12 |
13 | # progress
14 | + [ ] 01_RCNN
15 | + [ ] 02_OverFeat
16 | + [ ] 03_Fast_RCNN
17 | + [x] 04_Faster_RCNN
18 | + [ ] 05_OHEM
19 | + [x] 06_YOLO_v1
20 | + [ ] 07_SSD
21 | + [ ] 08_R_FCN
22 | + [ ] 09_YOLO_v2
23 | + [ ] 10_FPN
24 | + [ ] 11_RetinaNet
25 | + [ ] 12_Mask_RCNN
26 | + [ ] 13_YOLO_v3
27 | + [ ] 14_RefineDet
28 | + [ ] 15_M2Det
29 |
30 | # requirements
31 | python 3.7
32 |
opencv 3.4.2
33 | pytorch 1.4.0
34 |
35 | # running
36 | Basically you can run train_test.py in every folder.
37 |
38 | # note
39 | I have some unnecessary preference on format and order and something like that. I hope the folders' name wouldn't
40 | bother you.
41 |
42 | # reference
43 |
--------------------------------------------------------------------------------
/04_Faster_RCNN/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torchvision.models.detection import FasterRCNN
4 | from torchvision.models.detection.rpn import AnchorGenerator
5 |
6 | """
7 | 笔记
8 | 1. Faster RCNN的步骤
9 | a. Transform
10 | 1) normalize
11 | 2) resize
12 | b. 利用backbone网络来提取特征,下面的demo使用的是mobileNetV2的特征提取网络,即
13 | 1) 1个ConvBNRelu
14 | 2) 17个倒置残差:ConvBNRelu + Conv2d + BatchNorm2d
15 | 3) 1个ConvBNRelu
16 | c. 用RPN网络提取ROI,并计算它的指标(iou)作为loss之一
17 | d. 坐标位置回归,计算loss之二
18 | """
19 |
20 |
21 | def demo():
22 | # load a pre-trained model for classification and return
23 | # only the features
24 | backbone = torchvision.models.mobilenet_v2(pretrained=True).features
25 | # FasterRCNN needs to know the number of
26 | # output channels in a backbone. For mobilenet_v2, it's 1280
27 | # so we need to add it here
28 | backbone.out_channels = 1280
29 | # let's make the RPN generate 5 x 3 anchors per spatial
30 | # location, with 5 different sizes and 3 different aspect
31 | # ratios. We have a Tuple[Tuple[int]] because each feature
32 | # map could potentially have different sizes and
33 | # aspect ratios
34 | anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
35 | aspect_ratios=((0.5, 1.0, 2.0),))
36 | # let's define what are the feature maps that we will
37 | # use to perform the region of interest cropping, as well as
38 | # the size of the crop after rescaling.
39 | # if your backbone returns a Tensor, featmap_names is expected to
40 | # be [0]. More generally, the backbone should return an
41 | # OrderedDict[Tensor], and in featmap_names you can choose which
42 | # feature maps to use.
43 | roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
44 | output_size=7,
45 | sampling_ratio=2)
46 | # put the pieces together inside a FasterRCNN model
47 | model = FasterRCNN(backbone,
48 | num_classes=2,
49 | rpn_anchor_generator=anchor_generator,
50 | box_roi_pool=roi_pooler)
51 | model.eval()
52 | x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
53 | predictions = model(x)
54 | print(predictions)
55 |
56 |
57 | if __name__ == '__main__':
58 | demo()
59 |
--------------------------------------------------------------------------------
/06_YOLO_v1/dataset.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree as et
2 | from torch.utils.data import Dataset
3 | from torchvision.transforms import transforms
4 | import numpy as np
5 | import cv2
6 |
7 | CLASSES = ['person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
8 | 'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
9 | 'bottle', 'chair', 'dining table', 'potted plant', 'sofa', 'tvmonitor']
10 | DATASET_PATH = '00_Data'
11 | NUM_BBOX = 2
12 |
13 |
14 | def convert(size, box):
15 | """
16 | 将bbox的左上角点、右下角点坐标的格式,转换为bbox中心点+bbox的w,h的格式
17 | 并进行归一化
18 | """
19 | dw = 1. / size[0]
20 | dh = 1. / size[1]
21 | x = (box[0] + box[1]) / 2.0
22 | y = (box[2] + box[3]) / 2.0
23 | w = box[1] - box[0]
24 | h = box[3] - box[2]
25 | x = x * dw
26 | w = w * dw
27 | y = y * dh
28 | h = h * dh
29 | return x, y, w, h
30 |
31 |
32 | def convert_annotation(image_id):
33 | """
34 | 把图像image_id的xml文件转换为目标检测的label文件(txt)
35 | 其中包含物体的类别,bbox的左上角点坐标以及bbox的宽、高
36 | 并将四个物理量归一化
37 | """
38 | in_file = open(f"{DATASET_PATH}Annotations{image_id}")
39 | tree = et.parse(in_file)
40 | root = tree.getroot()
41 | size = root.find('size')
42 | w = int(size.find('width').text)
43 | h = int(size.find('height').text)
44 |
45 | image_id = image_id.split('.')[0]
46 | out_file = open(f'./labels/{image_id}.txt', 'w')
47 | for obj in root.iter('object'):
48 | difficult = obj.find('difficult').text
49 | cls = obj.find('name').text
50 | if cls not in CLASSES or int(difficult) == 1:
51 | continue
52 | cls_id = CLASSES.index(cls)
53 | xmlbox = obj.find('bndbox')
54 | points = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
55 | float(xmlbox.find('ymax').text))
56 | bb = convert((w, h), points)
57 | out_file.write(f"{cls_id} {''.join([str(a) for a in bb])}\n")
58 |
59 |
60 | class VOC2012(Dataset):
61 | def __init__(self, is_train=True, is_aug=True):
62 | """
63 | :param is_train: 调用的是训练集(True),还是验证集(False)
64 | :param is_aug: 是否进行数据增广
65 | """
66 | self.filenames = [] # 储存数据集的文件名称
67 | if is_train:
68 | with open(DATASET_PATH + "ImageSets/Main/train.txt", 'r') as f: # 调用包含训练集图像名称的txt文件
69 | self.filenames = [x.strip() for x in f]
70 | else:
71 | with open(DATASET_PATH + "ImageSets/Main/val.txt", 'r') as f:
72 | self.filenames = [x.strip() for x in f]
73 | self.imgpath = DATASET_PATH + "JPEGImages/" # 原始图像所在的路径
74 | self.labelpath = "./labels/" # 图像对应的label文件(.txt文件)的路径
75 | self.is_aug = is_aug
76 |
77 | def __len__(self):
78 | return len(self.filenames)
79 |
80 | def __getitem__(self, item):
81 | img = cv2.imread(self.imgpath + self.filenames[item] + ".jpg") # 读取原始图像
82 | h, w = img.shape[0:2]
83 | input_size = 448 # 输入YOLOv1网络的图像尺寸为448x448
84 | # 因为数据集内原始图像的尺寸是不定的,所以需要进行适当的padding,将原始图像padding成宽高一致的正方形
85 | # 然后再将Padding后的正方形图像缩放成448x448
86 | padw, padh = 0, 0 # 要记录宽高方向的padding具体数值,因为padding之后需要调整bbox的位置信息
87 | if h > w:
88 | padw = (h - w) // 2
89 | img = np.pad(img, ((0, 0), (padw, padw), (0, 0)), 'constant', constant_values=0)
90 | elif w > h:
91 | padh = (w - h) // 2
92 | img = np.pad(img, ((padh, padh), (0, 0), (0, 0)), 'constant', constant_values=0)
93 | img = cv2.resize(img, (input_size, input_size))
94 | # 图像增广部分,这里不做过多处理,因为改变bbox信息还蛮麻烦的
95 | if self.is_aug:
96 | aug = transforms.Compose([
97 | transforms.ToTensor()
98 | ])
99 | img = aug(img)
100 |
101 | # 读取图像对应的bbox信息,按1维的方式储存,每5个元素表示一个bbox的(cls,xc,yc,w,h)
102 | with open(self.labelpath + self.filenames[item] + ".txt") as f:
103 | bbox = f.read().split('\n')
104 | bbox = [x.split() for x in bbox]
105 | bbox = [float(x) for y in bbox for x in y]
106 | if len(bbox) % 5 != 0:
107 | raise ValueError("File:" + self.labelpath + self.filenames[item] + ".txt" + "——bbox Extraction Error!")
108 |
109 | # 根据padding、图像增广等操作,将原始的bbox数据转换为修改后图像的bbox数据
110 | for i in range(len(bbox) // 5):
111 | if padw != 0:
112 | bbox[i * 5 + 1] = (bbox[i * 5 + 1] * w + padw) / h
113 | bbox[i * 5 + 3] = (bbox[i * 5 + 3] * w) / h
114 | elif padh != 0:
115 | bbox[i * 5 + 2] = (bbox[i * 5 + 2] * h + padh) / w
116 | bbox[i * 5 + 4] = (bbox[i * 5 + 4] * h) / w
117 | # 此处可以写代码验证一下,查看padding后修改的bbox数值是否正确,在原图中画出bbox检验
118 |
119 | labels = convert_bbox2labels(bbox) # 将所有bbox的(cls,x,y,w,h)数据转换为训练时方便计算Loss的数据形式(7,7,5*B+cls_num)
120 | # 此处可以写代码验证一下,经过convert_bbox2labels函数后得到的labels变量中储存的数据是否正确
121 | labels = transforms.ToTensor()(labels)
122 | return img, labels
123 |
124 |
125 | def convert_bbox2labels(bbox):
126 | """
127 | 将bbox的(cls,x,y,w,h)数据转换为训练时方便计算Loss的数据形式(7,7,5*B+cls_num)
128 | 注意,输入的bbox的信息是(xc,yc,w,h)格式的,转换为labels后,bbox的信息转换为了(px,py,w,h)格式
129 | """
130 | grid_size = 1.0 / 7
131 | labels = np.zeros((7, 7, 5 * NUM_BBOX + len(CLASSES))) # 注意,此处需要根据不同数据集的类别个数进行修改
132 | for i in range(len(bbox) // 5):
133 | grid_x = int(bbox[i * 5 + 1] // grid_size) # 当前bbox中心落在第gridx个网格,列
134 | grid_y = int(bbox[i * 5 + 2] // grid_size) # 当前bbox中心落在第gridy个网格,行
135 | # (bbox中心坐标 - 网格左上角点的坐标)/网格大小 ==> bbox中心点的相对位置
136 | grid_px = bbox[i * 5 + 1] / grid_size - grid_x
137 | grid_py = bbox[i * 5 + 2] / grid_size - grid_y
138 | # 将第gridy行,gridx列的网格设置为负责当前ground truth的预测,置信度和对应类别概率均置为1
139 | labels[grid_y, grid_x, 0:5] = np.array([grid_px, grid_py, bbox[i * 5 + 3], bbox[i * 5 + 4], 1])
140 | labels[grid_y, grid_x, 5:10] = np.array([grid_px, grid_py, bbox[i * 5 + 3], bbox[i * 5 + 4], 1])
141 | labels[grid_y, grid_x, 10 + int(bbox[i * 5])] = 1
142 | return labels
143 |
--------------------------------------------------------------------------------
/06_YOLO_v1/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torchvision.models.resnet import resnet18
5 | from torch.utils.data import DataLoader
6 |
7 | from .dataset import NUM_BBOX, CLASSES, VOC2012
8 |
9 |
10 | class Loss_yolov1(nn.Module):
11 | def __init__(self):
12 | super(Loss_yolov1, self).__init__()
13 |
14 | def forward(self, pred, labels):
15 | """
16 | :param pred: (batchsize,30,7,7)的网络输出数据
17 | :param labels: (batchsize,30,7,7)的样本标签数据
18 | :return: 当前批次样本的平均损失
19 | """
20 | num_gridx, num_gridy = labels.size()[-2:] # 划分网格数量
21 | num_b = 2 # 每个网格的bbox数量
22 | num_cls = 20 # 类别数量
23 | noobj_confi_loss = 0. # 不含目标的网格损失(只有置信度损失)
24 | coor_loss = 0. # 含有目标的bbox的坐标损失
25 | obj_confi_loss = 0. # 含有目标的bbox的置信度损失
26 | class_loss = 0. # 含有目标的网格的类别损失
27 | n_batch = labels.size()[0] # batchsize的大小
28 |
29 | # 可以考虑用矩阵运算进行优化,提高速度,为了准确起见,这里还是用循环
30 | for i in range(n_batch): # batchsize循环
31 | for n in range(7): # x方向网格循环
32 | for m in range(7): # y方向网格循环
33 | if labels[i, 4, m, n] == 1: # 如果包含物体
34 | # 将数据(px,py,w,h)转换为(x1,y1,x2,y2)
35 | # 先将px,py转换为cx,cy,即相对网格的位置转换为标准化后实际的bbox中心位置cx,xy
36 | # 然后再利用(cx-w/2,cy-h/2,cx+w/2,cy+h/2)转换为xyxy形式,用于计算iou
37 | bbox1_pred_xyxy = ((pred[i, 0, m, n] + m) / num_gridx - pred[i, 2, m, n] / 2,
38 | (pred[i, 1, m, n] + n) / num_gridy - pred[i, 3, m, n] / 2,
39 | (pred[i, 0, m, n] + m) / num_gridx + pred[i, 2, m, n] / 2,
40 | (pred[i, 1, m, n] + n) / num_gridy + pred[i, 3, m, n] / 2)
41 | bbox2_pred_xyxy = ((pred[i, 5, m, n] + m) / num_gridx - pred[i, 7, m, n] / 2,
42 | (pred[i, 6, m, n] + n) / num_gridy - pred[i, 8, m, n] / 2,
43 | (pred[i, 5, m, n] + m) / num_gridx + pred[i, 7, m, n] / 2,
44 | (pred[i, 6, m, n] + n) / num_gridy + pred[i, 8, m, n] / 2)
45 | bbox_gt_xyxy = ((labels[i, 0, m, n] + m) / num_gridx - labels[i, 2, m, n] / 2,
46 | (labels[i, 1, m, n] + n) / num_gridy - labels[i, 3, m, n] / 2,
47 | (labels[i, 0, m, n] + m) / num_gridx + labels[i, 2, m, n] / 2,
48 | (labels[i, 1, m, n] + n) / num_gridy + labels[i, 3, m, n] / 2)
49 | iou1 = calculate_iou(bbox1_pred_xyxy, bbox_gt_xyxy)
50 | iou2 = calculate_iou(bbox2_pred_xyxy, bbox_gt_xyxy)
51 | # 选择iou大的bbox作为负责物体
52 | if iou1 >= iou2:
53 | coor_loss = coor_loss + 5 * (torch.sum((pred[i, 0:2, m, n] - labels[i, 0:2, m, n]) ** 2) \
54 | + torch.sum(
55 | (pred[i, 2:4, m, n].sqrt() - labels[i, 2:4, m, n].sqrt()) ** 2))
56 | obj_confi_loss = obj_confi_loss + (pred[i, 4, m, n] - iou1) ** 2
57 | # iou比较小的bbox不负责预测物体,因此confidence loss算在noobj中,注意,对于标签的置信度应该是iou2
58 | noobj_confi_loss = noobj_confi_loss + 0.5 * ((pred[i, 9, m, n] - iou2) ** 2)
59 | else:
60 | coor_loss = coor_loss + 5 * (torch.sum((pred[i, 5:7, m, n] - labels[i, 5:7, m, n]) ** 2) \
61 | + torch.sum(
62 | (pred[i, 7:9, m, n].sqrt() - labels[i, 7:9, m, n].sqrt()) ** 2))
63 | obj_confi_loss = obj_confi_loss + (pred[i, 9, m, n] - iou2) ** 2
64 | # iou比较小的bbox不负责预测物体,因此confidence loss算在noobj中,注意,对于标签的置信度应该是iou1
65 | noobj_confi_loss = noobj_confi_loss + 0.5 * ((pred[i, 4, m, n] - iou1) ** 2)
66 | class_loss = class_loss + torch.sum((pred[i, 10:, m, n] - labels[i, 10:, m, n]) ** 2)
67 | else: # 如果不包含物体
68 | noobj_confi_loss = noobj_confi_loss + 0.5 * torch.sum(pred[i, [4, 9], m, n] ** 2)
69 |
70 | loss = coor_loss + obj_confi_loss + noobj_confi_loss + class_loss
71 | # 此处可以写代码验证一下loss的大致计算是否正确,这个要验证起来比较麻烦,比较简洁的办法是,将输入的pred置为全1矩阵,再进行误差检查,会直观很多。
72 | return loss / n_batch
73 |
74 |
75 | def calculate_iou(bbox1, bbox2):
76 | """计算bbox1=(x1,y1,x2,y2)和bbox2=(x3,y3,x4,y4)两个bbox的iou"""
77 | intersect_bbox = [0., 0., 0., 0.] # bbox1和bbox2的交集
78 | if bbox1[2] < bbox2[0] or bbox1[0] > bbox2[2] or bbox1[3] < bbox2[1] or bbox1[1] > bbox2[3]:
79 | pass
80 | else:
81 | intersect_bbox[0] = max(bbox1[0], bbox2[0])
82 | intersect_bbox[1] = max(bbox1[1], bbox2[1])
83 | intersect_bbox[2] = min(bbox1[2], bbox2[2])
84 | intersect_bbox[3] = min(bbox1[3], bbox2[3])
85 |
86 | area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) # bbox1面积
87 | area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) # bbox2面积
88 | area_intersect = (intersect_bbox[2] - intersect_bbox[0]) * (intersect_bbox[3] - intersect_bbox[1]) # 交集面积
89 | # print(bbox1,bbox2)
90 | # print(intersect_bbox)
91 | # input()
92 |
93 | if area_intersect > 0:
94 | return area_intersect / (area1 + area2 - area_intersect) # 计算iou
95 | else:
96 | return 0
97 |
98 |
99 | class YOLO_v1(nn.module):
100 | def __init__(self):
101 | super(YOLO_v1, self).__init__()
102 | resnet = resnet18(pretrained=True) # 调用torchvision里的resnet34预训练模型
103 | resnet_out_channel = resnet.fc.in_features # 记录resnet全连接层之前的网络输出通道数,方便连入后续卷积网络中
104 | self.resnet = nn.Sequential(*list(resnet.children())[:-2]) # 去除resnet的最后两层
105 | # 以下是YOLOv1的最后四个卷积层
106 | self.Conv_layers = nn.Sequential(
107 | nn.Conv2d(resnet_out_channel, 1024, 3, padding=1),
108 | nn.BatchNorm2d(1024), # 为了加快训练,这里增加了BN层,原论文里YOLOv1是没有的
109 | nn.LeakyReLU(),
110 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1),
111 | nn.BatchNorm2d(1024),
112 | nn.LeakyReLU(),
113 | nn.Conv2d(1024, 1024, 3, padding=1),
114 | nn.BatchNorm2d(1024),
115 | nn.LeakyReLU(),
116 | nn.Conv2d(1024, 1024, 3, padding=1),
117 | nn.BatchNorm2d(1024),
118 | nn.LeakyReLU(),
119 | )
120 | # 以下是YOLOv1的最后2个全连接层
121 | self.Conn_layers = nn.Sequential(
122 | nn.Linear(7 * 7 * 1024, 4096),
123 | nn.LeakyReLU(),
124 | nn.Linear(4096, 7 * 7 * 30),
125 | nn.Sigmoid() # 增加sigmoid函数是为了将输出全部映射到(0,1)之间,因为如果出现负数或太大的数,后续计算loss会很麻烦
126 | )
127 |
128 | def forward(self, input_):
129 | input_ = self.resnet(input_)
130 | input_ = self.Conv_layers(input_)
131 | input_ = input_.view(input_.size()[0], -1)
132 | input_ = self.Conn_layers(input_)
133 | return input_.reshape(-1, (5 * NUM_BBOX + len(CLASSES)), 7, 7) # 记住最后要reshape一下输出数据
134 |
135 |
136 | def demo():
137 | epoch = 50
138 | batchsize = 5
139 | lr = 0.01
140 |
141 | trainData = VOC2012()
142 | trainDataLoader = DataLoader(VOC2012(is_train=True), batch_size=batchsize, shuffle=True)
143 |
144 | model = YOLO_v1().cuda()
145 | # model.children()里是按模块(Sequential)提取的子模块,而不是具体到每个层,具体可以参见pytorch帮助文档
146 | # 冻结resnet34特征提取层,特征提取层不参与参数更新
147 | for layer in model.children():
148 | layer.requires_grad = False
149 | break
150 | criterion = Loss_yolov1()
151 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
152 |
153 | for e in range(epoch):
154 | model.train()
155 | yl = torch.Tensor([0]).cuda()
156 | for i, (inputs, labels) in enumerate(trainDataLoader):
157 | inputs = inputs.cuda()
158 | labels = labels.float().cuda()
159 | pred = model(inputs)
160 | loss = criterion(pred, labels)
161 | optimizer.zero_grad()
162 | loss.backward()
163 | optimizer.step()
164 |
165 | print("Epoch %d/%d| Step %d/%d| Loss: %.2f" % (e, epoch, i, len(trainData) // batchsize, loss))
166 | yl = yl + loss
167 | if (e + 1) % 10 == 0:
168 | torch.save(model, "./models_pkl/YOLOv1_epoch" + str(e + 1) + ".pkl")
169 | # compute_val_map(model)
170 |
--------------------------------------------------------------------------------