12 |
13 | A Chinese-language introduction to the EfficientDet algorithm: [EfficientDet_CN.md](./EfficientDet_CN.md)
14 |
15 | > Using the dataset from a real competition, this project demonstrates, step by step, how to train, evaluate, and run inference with the recently open-sourced, near-SOTA PyTorch implementation of EfficientDet. As in the paper, we do not use any data augmentation, model ensembling, or other post-processing tricks to boost accuracy; if you want to add augmentation strategies, you can implement them in `efficientdet/dataset.py` (a minimal sketch follows this note);
16 | >
17 | > We also did not preprocess the underwater images with approaches such as [UWGAN_UIE](https://github.com/DataXujing/UWGAN_UIE), water-quality transfer (WQT), DG-YOLO, or underwater dehazing algorithms;
18 | >
19 | > We believe such tricks would further improve detection accuracy!
20 |
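A minimal sketch of what an extra augmentation in `efficientdet/dataset.py` could look like. It follows the same sample-dict convention as the existing `Normalizer`/`Augmenter`/`Resizer` transforms in that file; the class name and parameters below are our own illustration, not part of the repo:

```
import numpy as np


class RandomBrightnessContrast(object):
    """Photometric jitter that leaves the bounding boxes untouched.

    Operates on the same sample dict used by the transforms in efficientdet/dataset.py:
    'img' is an HxWx3 float array scaled to [0, 1], 'annot' is an [N, 5] array of
    (x1, y1, x2, y2, label).
    """

    def __init__(self, p=0.5, brightness=0.2, contrast=0.2):
        self.p = p
        self.brightness = brightness
        self.contrast = contrast

    def __call__(self, sample):
        if np.random.rand() < self.p:
            img, annot = sample['img'], sample['annot']
            alpha = 1.0 + np.random.uniform(-self.contrast, self.contrast)  # contrast factor
            beta = np.random.uniform(-self.brightness, self.brightness)     # brightness shift
            img = np.clip(alpha * img + beta, 0.0, 1.0)
            sample = {'img': img, 'annot': annot}
        return sample
```

Because it assumes pixel values in [0, 1], place it before the `Normalizer` in whichever `transforms.Compose([...])` builds the `CocoDataset` pipeline (typically in train.py).
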
21 | ### 1. Data Source
22 |
23 | The data comes from the [underwater object detection competition hosted on Kesci (科赛网)](https://www.kesci.com/home/competition/5e535a612537a0002ca864ac/content/2):
24 |
25 | 
26 |
27 | **Competition Overview**
28 |
29 | 「Background」 With the rapid development of ocean observation, underwater object detection plays an increasingly important role in naval coastal defense as well as in marine industries such as fishing and aquaculture, and underwater images are an important carrier of marine information. This competition asks participants to detect the locations of different seafood species (sea cucumber, sea urchin, scallop, starfish) in real seabed images.
30 |
31 | 
32 |
33 | 「Data」 The training set consists of 5,543 underwater optical images in JPG format with corresponding annotations; the test set for leaderboard A contains 800 images and the test set for leaderboard B contains 1,200 images.
34 |
35 | 「Evaluation Metric」 mAP (mean Average Precision)
36 |
37 | > Note: the data is provided by Peng Cheng Laboratory (鹏城实验室).
38 |
39 | ### 2. Data Conversion
40 |
41 | We place the data under the project's dataset directory:
42 |
43 | ```
44 | ..
45 | └─underwater
46 |    ├─Annotations   # XML annotations
47 |    └─JPEGImages    # original JPG images
48 | # First split into training and validation sets: we use a random 9:1 split, then convert the split data to COCO format (a sketch of this split follows the block)
49 | ```
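
A minimal sketch of the 9:1 random split (our own illustration; the directory layout is assumed to match the tree above, and `train.txt`/`val.txt` are assumed to simply list the image IDs consumed by voc2coco.py below):

```
import os
import random
import shutil

random.seed(0)

src_ann = 'dataset/underwater/Annotations'
src_img = 'dataset/underwater/JPEGImages'

stems = [os.path.splitext(f)[0] for f in os.listdir(src_ann) if f.endswith('.xml')]
random.shuffle(stems)

n_val = len(stems) // 10                      # 9:1 train/val split
splits = {'val': stems[:n_val], 'train': stems[n_val:]}

for split, names in splits.items():
    os.makedirs(f'dataset/underwater/{split}/Annotations', exist_ok=True)
    os.makedirs(f'dataset/underwater/{split}/JPEGImages', exist_ok=True)
    with open(f'{split}.txt', 'w') as f:      # image-ID list later passed to voc2coco.py
        f.write('\n'.join(names))
    for name in names:
        shutil.copy(os.path.join(src_ann, name + '.xml'),
                    f'dataset/underwater/{split}/Annotations/{name}.xml')
        shutil.copy(os.path.join(src_img, name + '.jpg'),
                    f'dataset/underwater/{split}/JPEGImages/{name}.jpg')
```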
50 |
51 | Directory structure after the train/val split:
52 |
53 | ```
54 | ..
55 | ├─train
56 | │ ├─Annotations
57 | │ └─JPEGImages
58 | └─val
59 | ├─Annotations
60 | └─JPEGImages
61 | ```
62 |
63 | Convert the VOC annotations to COCO format (a rough sketch of what the conversion produces is given after the commands):
64 |
65 | ```
66 | python voc2coco.py train.txt ./train/Annotations instances_train.json ./train/JPEGImages
67 | python voc2coco.py val.txt ./val/Annotations instances_val.json ./val/JPEGImages
68 | # the generated JSON files are stored under dataset/underwater/annotations/*.json
69 | ```
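
For readers who want to see what the conversion produces, here is a rough sketch of the core of a VOC-to-COCO converter. The bundled voc2coco.py handles its command-line arguments, the train.txt/val.txt file lists, and edge cases differently, so treat this only as an illustration of the target JSON structure (the class names are assumed to match `obj_list`, and `category_id` is one-indexed, which matters for the yml comment in step 7 of the next section):

```
import json
import os
import xml.etree.ElementTree as ET

CLASSES = ["holothurian", "echinus", "scallop", "starfish"]  # category_id = index + 1


def voc_dir_to_coco(ann_dir, out_json):
    images, annotations = [], []
    ann_id = 1
    for img_id, xml_name in enumerate(sorted(os.listdir(ann_dir)), start=1):
        root = ET.parse(os.path.join(ann_dir, xml_name)).getroot()
        size = root.find('size')
        images.append({'id': img_id,
                       'file_name': os.path.splitext(xml_name)[0] + '.jpg',
                       'width': int(size.find('width').text),
                       'height': int(size.find('height').text)})
        for obj in root.iter('object'):
            name = obj.find('name').text
            if name not in CLASSES:
                continue
            b = obj.find('bndbox')
            x1, y1 = float(b.find('xmin').text), float(b.find('ymin').text)
            x2, y2 = float(b.find('xmax').text), float(b.find('ymax').text)
            annotations.append({'id': ann_id, 'image_id': img_id,
                                'category_id': CLASSES.index(name) + 1,
                                'bbox': [x1, y1, x2 - x1, y2 - y1],      # COCO uses x, y, w, h
                                'area': (x2 - x1) * (y2 - y1), 'iscrowd': 0})
            ann_id += 1
    categories = [{'id': i + 1, 'name': c} for i, c in enumerate(CLASSES)]
    with open(out_json, 'w') as f:
        json.dump({'images': images, 'annotations': annotations, 'categories': categories}, f)
```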
70 |
71 |
72 | ### 3. Modify the EfficientDet Project Files
73 |
74 | 1. Create a dataset folder to hold the training and validation data
75 |
76 | ```
77 | dataset
78 | └─underwater          # name of the project dataset
79 | ├─annotations # instances_train.json,instances_val.json
80 | ├─train # train jpgs
81 | └─val # val jpgs
82 | ```
83 |
84 | 2. Create a logs folder
85 |
86 | The logs folder stores the tensorboardX logs and the model checkpoints saved during training.
87 |
88 | 3. Modify train.py [used for training] (a hedged reconstruction of the relevant command-line flags follows the snippet)
89 |
90 | ```
91 | def get_args():
92 | parser.add_argument('-p', '--project', type=str, default='underwater', help='project file that contains parameters')
93 | parser.add_argument('--batch_size', type=int, default=16, help='The number of images per batch among all devices')
94 |
95 | ```
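
For orientation, a hedged reconstruction of the argument parser as it is used throughout this tutorial. Only flags that actually appear in the commands later in this README are listed; the real train.py may define more options and different defaults:

```
import argparse


def get_args():
    parser = argparse.ArgumentParser('EfficientDet training (underwater)')
    parser.add_argument('-p', '--project', type=str, default='underwater',
                        help='project file that contains parameters')
    parser.add_argument('-c', '--compound_coef', type=int, default=2,
                        help='coefficient of efficientdet, i.e. D0-D7')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='the number of images per batch among all devices')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--load_weights', type=str, default=None,
                        help='path to pretrained weights, or "last" to resume from the latest checkpoint')
    parser.add_argument('--head_only', type=bool, default=False,
                        help='freeze the backbone and BiFPN, train only the regression/classification heads')
    parser.add_argument('--debug', type=bool, default=False,
                        help='write predicted boxes to the test/ folder during training')
    return parser.parse_args()
```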
96 |
97 | 4. Modify efficientdet_test.py [used for running the model on new images]
98 |
99 | ```
100 | # obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
101 | # 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
102 | # 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
103 | # 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
104 | # 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
105 | # 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
106 | # 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
107 | # 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
108 | # 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
109 | # 'toothbrush']
110 |
111 | obj_list = ["holothurian", "echinus", "scallop", "starfish"]  # replace with your own classes
112 | compound_coef = 2  # D0-D7
113 | model.load_state_dict(torch.load("./logs/underwater/efficientdet-d2_122_38106.pth"))  # path to your trained weights
114 | ```
115 |
116 | 5. Modify coco_eval.py [used for model evaluation]
117 |
118 | ```
119 | ap.add_argument('-p', '--project', type=str, default='underwater', help='project file that contains parameters')
120 | ```
121 |
122 | 6. Modify efficientdet/config.py
123 |
124 | ```
125 | # COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
126 | # "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
127 | # "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
128 | # "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
129 | # "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
130 | # "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
131 | # "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
132 | # "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
133 | # "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
134 | # "teddy bear", "hair drier", "toothbrush"]
135 | COCO_CLASSES = ["holothurian","echinus","scallop","starfish"]
136 | ```
137 |
138 | 7. Create a YAML config file (./projects/underwater.yml) [training configuration; a sketch of how it is consumed follows the block]
139 |
140 | ```
141 | project_name: underwater  # also the folder name of the dataset under the data_path folder
142 | train_set: train
143 | val_set: val
144 | num_gpus: 1
145 |
146 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
147 | mean: [0.485, 0.456, 0.406]
148 | std: [0.229, 0.224, 0.225]
149 |
150 | # this is coco anchors, change it if necessary
151 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
152 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
153 |
154 | # must match your dataset's category_id.
155 | # category_id is one_indexed,
156 | # for example, if the index of 'car' in this list is 2, its category_id in the annotations is 3
157 | obj_list: ["holothurian","echinus","scallop","starfish"]
158 |
159 | ```
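
This yml is consumed the same way coco_eval.py in this repo consumes it: parsed with `yaml.safe_load`, with the anchor strings `eval`'ed into Python lists when the model is built. A condensed view (compound_coef=2 mirrors the D2 setup used in this tutorial):

```
import yaml

from backbone import EfficientDetBackbone

# projects/underwater.yml is parsed into a plain dict ...
params = yaml.safe_load(open('projects/underwater.yml'))
obj_list = params['obj_list']          # ["holothurian", "echinus", "scallop", "starfish"]

# ... and the anchor strings become Python lists when the model is constructed
model = EfficientDetBackbone(compound_coef=2,
                             num_classes=len(obj_list),
                             ratios=eval(params['anchors_ratios']),
                             scales=eval(params['anchors_scales']))
```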
160 |
161 |
162 |
163 | ### 4. Train EfficientDet
164 |
165 | ```
166 | # train EfficientDet-D2 on your own dataset from scratch
167 | python train.py -c 2 --batch_size 16 --lr 1e-4
168 |
169 | # train efficientdet-d2 on your own dataset with COCO-pretrained weights (recommended)
170 | python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
171 | --load_weights /path/to/your/weights/efficientdet-d2.pth
172 |
173 | # with a coco-pretrained, you can even freeze the backbone and train heads only
174 | # to speed up training and help convergence (a sketch of the idea follows this block).
175 | python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
176 | --load_weights /path/to/your/weights/efficientdet-d2.pth \
177 | --head_only True
178 |
179 | # Early stopping:
180 | # press Ctrl+C,
181 | # the program will catch KeyboardInterrupt,
182 | # stop training, and save the current checkpoint.
183 |
184 | # resume training from the last checkpoint
185 | python train.py -c 2 --batch_size 8 --lr 1e-5 \
186 | --load_weights last \
187 | --head_only True
188 | ```
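
A sketch of the idea behind `--head_only` (the exact logic lives in train.py, which is not reproduced in this excerpt): freeze the EfficientNet backbone and the BiFPN, leaving only the `Regressor`/`Classifier` heads trainable. The attribute names come from backbone.py; the optimizer choice here is just an example:

```
import itertools

import torch
from backbone import EfficientDetBackbone

model = EfficientDetBackbone(num_classes=4, compound_coef=2)

# freeze feature extractor and BiFPN; only the box/class heads stay trainable
for param in itertools.chain(model.backbone_net.parameters(), model.bifpn.parameters()):
    param.requires_grad = False

model.freeze_bn()  # keep BatchNorm running statistics fixed as well

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
```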
189 |
190 | ### 5. Evaluate and Test EfficientDet
191 |
192 | 1. Evaluate the model with the COCO mAP metrics
193 |
194 | ```
195 | python coco_eval.py -p underwater -c 2 -w ./logs/underwater/efficientdet-d2_122_38106.pth
196 | ```
197 |
198 | ```
199 | # evaluation results on the validation set
200 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.381
201 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.714
202 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.368
203 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.170
204 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.351
205 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
206 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.149
207 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.433
208 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.464
209 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.267
210 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.429
211 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.507
212 |
213 | ```
214 |
215 | 2. Debugging during training
216 |
217 | ```
218 | # when you get bad results, you need to debug the training output.
219 | python train.py -c 2 --batch_size 8 --lr 1e-5 --debug True
220 |
221 | # then check the test/ folder, where you can visualize the predicted boxes during training.
222 | # Don't panic if you see countless wrong boxes; that happens in the early stage of training.
223 | # But if you still can't see a single reasonable box after several epochs, it's likely that
224 | # either the anchor config is inappropriate or the ground truth is corrupted (a quick anchor sanity check is sketched after this block).
225 | ```
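
The anchor configuration is the easier of the two causes to sanity-check. A quick, hedged check (our own addition, not part of the repo): compare the ground-truth box shapes in the training JSON against `anchors_ratios`/`anchors_scales` in projects/underwater.yml:

```
import numpy as np
from pycocotools.coco import COCO

coco = COCO('dataset/underwater/annotations/instances_train.json')
anns = coco.loadAnns(coco.getAnnIds())

# width/height of every labelled box (skip degenerate ones, as dataset.py does)
wh = np.array([a['bbox'][2:4] for a in anns if a['bbox'][2] >= 1 and a['bbox'][3] >= 1])
ratios = wh[:, 0] / wh[:, 1]

print('aspect-ratio percentiles (5/50/95):', np.percentile(ratios, [5, 50, 95]))
print('box-size percentiles in px (5/50/95):', np.percentile(wh.max(axis=1), [5, 50, 95]))
# If most ratios fall far outside roughly 0.5-2.0, the default COCO anchors_ratios
# (1.0, 1.0), (1.4, 0.7), (0.7, 1.4) in the yml are a poor fit and worth tuning.
```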
226 |
227 | 3. Run inference on new images (a condensed sketch of what efficientdet_test.py does follows the command)
228 |
229 | ```
230 | python efficientdet_test.py
231 | ```
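
A condensed view of what efficientdet_test.py does, built from the same helpers coco_eval.py uses (the image path and thresholds here are illustrative; the script in the repo is the authoritative version):

```
import torch

from backbone import EfficientDetBackbone
from efficientdet.utils import BBoxTransform, ClipBoxes
from utils.utils import preprocess, invert_affine, postprocess

obj_list = ["holothurian", "echinus", "scallop", "starfish"]
compound_coef = 2
input_size = 768                       # D2 entry of input_sizes, see coco_eval.py

model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list))
model.load_state_dict(torch.load('./logs/underwater/efficientdet-d2_122_38106.pth',
                                 map_location='cpu'))
model.requires_grad_(False)
model.eval()

# resize/pad the image, run the network, then undo the resize on the predicted boxes
ori_imgs, framed_imgs, framed_metas = preprocess('your_image.jpg', max_size=input_size)
x = torch.from_numpy(framed_imgs[0]).float().unsqueeze(0).permute(0, 3, 1, 2)

with torch.no_grad():
    features, regression, classification, anchors = model(x)

score_threshold, nms_threshold = 0.2, 0.2   # illustrative values
preds = postprocess(x, anchors, regression, classification,
                    BBoxTransform(), ClipBoxes(),
                    score_threshold, nms_threshold)
if preds:                                    # one dict per image: 'rois', 'class_ids', 'scores'
    preds = invert_affine(framed_metas, preds)[0]
    print(preds['rois'], preds['class_ids'], preds['scores'])
```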
232 |
233 | Inference speed is close to real-time:
234 |
235 | 
236 |
237 | 
238 |
239 | 
240 |
241 | 4. Visualize the results with TensorBoard:
242 |
243 | ```
244 | tensorboard --logdir logs/underwater/tensorboard
245 | ```
246 |
247 | 
248 |
249 | 
250 |
251 | 
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import math
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
9 | from efficientdet.utils import Anchors
10 |
11 |
12 | class EfficientDetBackbone(nn.Module):
13 | def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
14 | super(EfficientDetBackbone, self).__init__()
15 | self.compound_coef = compound_coef
16 |
17 | self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
18 | self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
19 | self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
20 | self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
21 | self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
22 | self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
23 | self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
24 | self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
25 | conv_channel_coef = {
26 | # the channels of P3/P4/P5.
27 | 0: [40, 112, 320],
28 | 1: [40, 112, 320],
29 | 2: [48, 120, 352],
30 | 3: [48, 136, 384],
31 | 4: [56, 160, 448],
32 | 5: [64, 176, 512],
33 | 6: [72, 200, 576],
34 | 7: [72, 200, 576],
35 | }
36 |
37 | num_anchors = len(self.aspect_ratios) * self.num_scales
38 |
39 | self.bifpn = nn.Sequential(
40 | *[BiFPN(self.fpn_num_filters[self.compound_coef],
41 | conv_channel_coef[compound_coef],
42 | True if _ == 0 else False,
43 | attention=True if compound_coef < 6 else False)
44 | for _ in range(self.fpn_cell_repeats[compound_coef])])
45 |
46 | self.num_classes = num_classes
47 | self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
48 | num_layers=self.box_class_repeats[self.compound_coef])
49 | self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
50 | num_classes=num_classes,
51 | num_layers=self.box_class_repeats[self.compound_coef])
52 |
53 | self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], **kwargs)
54 |
55 | self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)
56 |
57 | def freeze_bn(self):
58 | for m in self.modules():
59 | if isinstance(m, nn.BatchNorm2d):
60 | m.eval()
61 |
62 | def forward(self, inputs):
63 | max_size = inputs.shape[-1]
64 |
65 | _, p3, p4, p5 = self.backbone_net(inputs)
66 |
67 | features = (p3, p4, p5)
68 | features = self.bifpn(features)
69 |
70 | regression = self.regressor(features)
71 | classification = self.classifier(features)
72 | anchors = self.anchors(inputs, inputs.dtype)
73 |
74 | return features, regression, classification, anchors
75 |
76 | def init_backbone(self, path):
77 | state_dict = torch.load(path)
78 | try:
79 | ret = self.load_state_dict(state_dict, strict=False)
80 | print(ret)
81 | except RuntimeError as e:
82 |             print('Ignoring ' + str(e))
83 |
--------------------------------------------------------------------------------
/benchmark/coco_eval_result:
--------------------------------------------------------------------------------
1 | efficientdet-d0
2 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.326
3 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.502
4 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.342
5 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.118
6 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.376
7 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.509
8 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.268
9 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.402
10 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.430
11 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.172
12 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.502
13 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.624
14 |
15 | efficientdet-d1
16 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.382
17 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.568
18 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.407
19 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.181
20 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.437
21 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.555
22 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.304
23 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.465
24 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.496
25 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.265
26 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.562
27 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.673
28 |
29 | efficientdet-d2
30 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.415
31 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.603
32 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.440
33 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.226
34 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.471
35 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.567
36 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.321
37 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.497
38 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.529
39 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.315
40 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.595
41 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.672
42 |
43 | efficientdet-d3
44 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.449
45 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.637
46 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.480
47 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.272
48 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.491
49 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.602
50 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.342
51 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.533
52 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.567
53 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.383
54 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.615
55 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.710
56 |
57 | efficientdet-d4
58 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.481
59 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.672
60 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.520
61 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.320
62 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.528
63 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.625
64 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.357
65 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.565
66 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.600
67 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.436
68 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.649
69 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.728
70 |
71 | efficientdet-d5
72 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.495
73 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.687
74 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.532
75 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.333
76 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.540
77 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.632
78 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.367
79 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.584
80 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.621
81 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.467
82 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.662
83 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.739
84 |
85 | efficientdet-d6
86 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.501
87 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.692
88 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.540
89 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.338
90 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.544
91 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637
92 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.368
93 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.588
94 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.626
95 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.469
96 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.667
97 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.738
98 |
99 | efficientdet-d7
100 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.507
101 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.696
102 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.545
103 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.352
104 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.551
105 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.638
106 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.370
107 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.588
108 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.624
109 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.466
110 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.663
111 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.743
112 |
--------------------------------------------------------------------------------
/coco_eval.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | COCO-Style Evaluations
5 |
6 | put images here dataset/{project_name}/{val_set_name}/*.jpg
7 | put annotations here dataset/{project_name}/annotations/instances_{val_set_name}.json
8 | put weights here /path/to/your/weights/*.pth
9 | change compound_coef
10 |
11 | """
12 |
13 | import json
14 | import os
15 |
16 | import argparse
17 | import torch
18 | import yaml
19 | from tqdm import tqdm
20 | from pycocotools.coco import COCO
21 | from pycocotools.cocoeval import COCOeval
22 |
23 | from backbone import EfficientDetBackbone
24 | from efficientdet.utils import BBoxTransform, ClipBoxes
25 | from utils.utils import preprocess, invert_affine, postprocess
26 |
27 | ap = argparse.ArgumentParser()
28 | ap.add_argument('-p', '--project', type=str, default='underwater', help='project file that contains parameters')
29 | ap.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
30 | ap.add_argument('-w', '--weights', type=str, default=None, help='/path/to/weights')
31 | ap.add_argument('--nms_threshold', type=float, default=0.5, help='nms threshold, don\'t change it if not for testing purposes')
32 | ap.add_argument('--cuda', type=bool, default=True)
33 | ap.add_argument('--device', type=int, default=0)
34 | ap.add_argument('--float16', type=bool, default=False)
35 | args = ap.parse_args()
36 |
37 | compound_coef = args.compound_coef
38 | nms_threshold = args.nms_threshold
39 | use_cuda = args.cuda
40 | gpu = args.device
41 | use_float16 = args.float16
42 | project_name = args.project
43 | weights_path = 'weights/efficientdet-d{}.pth'.format(compound_coef) if args.weights is None else args.weights
44 |
45 | print('running coco-style evaluation on project {}, weights {}...'.format(project_name,weights_path))
46 |
47 | params = yaml.safe_load(open('projects/{}.yml'.format(project_name)))
48 | obj_list = params['obj_list']
49 |
50 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
51 |
52 |
53 | def evaluate_coco(img_path, set_name, image_ids, coco, model, threshold=0.05):
54 | results = []
55 | processed_image_ids = []
56 |
57 | regressBoxes = BBoxTransform()
58 | clipBoxes = ClipBoxes()
59 |
60 | for image_id in tqdm(image_ids):
61 | image_info = coco.loadImgs(image_id)[0]
62 | image_path = img_path + image_info['file_name']
63 |
64 | ori_imgs, framed_imgs, framed_metas = preprocess(image_path, max_size=input_sizes[compound_coef])
65 | x = torch.from_numpy(framed_imgs[0])
66 |
67 | if use_cuda:
68 | x = x.cuda(gpu)
69 | if use_float16:
70 | x = x.half()
71 | else:
72 | x = x.float()
73 | else:
74 | x = x.float()
75 |
76 | x = x.unsqueeze(0).permute(0, 3, 1, 2)
77 | features, regression, classification, anchors = model(x)
78 |
79 | preds = postprocess(x,
80 | anchors, regression, classification,
81 | regressBoxes, clipBoxes,
82 | threshold, nms_threshold)
83 |
84 | processed_image_ids.append(image_id)
85 |
86 | if not preds:
87 | continue
88 |
89 | preds = invert_affine(framed_metas, preds)[0]
90 |
91 | scores = preds['scores']
92 | class_ids = preds['class_ids']
93 | rois = preds['rois']
94 |
95 | if rois.shape[0] > 0:
96 | # x1,y1,x2,y2 -> x1,y1,w,h
97 | rois[:, 2] -= rois[:, 0]
98 | rois[:, 3] -= rois[:, 1]
99 |
100 | bbox_score = scores
101 |
102 | for roi_id in range(rois.shape[0]):
103 | score = float(bbox_score[roi_id])
104 | label = int(class_ids[roi_id])
105 | box = rois[roi_id, :]
106 |
107 | if score < threshold:
108 | break
109 | image_result = {
110 | 'image_id': image_id,
111 | 'category_id': label + 1,
112 | 'score': float(score),
113 | 'bbox': box.tolist(),
114 | }
115 |
116 | results.append(image_result)
117 |
118 | if not len(results):
119 | raise Exception('the model does not provide any valid output, check model architecture and the data input')
120 |
121 | # write output
122 | json.dump(results, open('{}_bbox_results.json'.format(set_name), 'w'), indent=4)
123 |
124 | return processed_image_ids
125 |
126 |
127 | def _eval(coco_gt, image_ids, pred_json_path):
128 | # load results in COCO evaluation tool
129 | coco_pred = coco_gt.loadRes(pred_json_path)
130 |
131 | # run COCO evaluation
132 | print('BBox')
133 | coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
134 | coco_eval.params.imgIds = image_ids
135 | coco_eval.evaluate()
136 | coco_eval.accumulate()
137 | coco_eval.summarize()
138 |
139 |
140 | if __name__ == '__main__':
141 | SET_NAME = params['val_set']
142 | VAL_GT = 'dataset/{}/annotations/instances_{}.json'.format(params["project_name"],SET_NAME)
143 | VAL_IMGS = 'dataset/{}/{}/'.format(params["project_name"],SET_NAME)
144 | MAX_IMAGES = 10000
145 | coco_gt = COCO(VAL_GT)
146 | image_ids = coco_gt.getImgIds()[:MAX_IMAGES]
147 |
148 | if not os.path.exists('{}_bbox_results.json'.format(SET_NAME)):
149 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
150 | ratios=eval(params['anchors_ratios']), scales=eval(params['anchors_scales']))
151 | model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
152 | model.requires_grad_(False)
153 | model.eval()
154 |
155 | if use_cuda:
156 | model.cuda(gpu)
157 |
158 | if use_float16:
159 | model.half()
160 |
161 | image_ids = evaluate_coco(VAL_IMGS, SET_NAME, image_ids, coco_gt, model)
162 |
163 | _eval(coco_gt, image_ids, '{}_bbox_results.json'.format(SET_NAME))
164 | else:
165 | _eval(coco_gt, image_ids, '{}_bbox_results.json'.format(SET_NAME))
166 |
--------------------------------------------------------------------------------
/efficientdet-d2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/efficientdet-d2.pth
--------------------------------------------------------------------------------
/efficientdet/config.py:
--------------------------------------------------------------------------------
1 | # COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
2 | # "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
3 | # "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
4 | # "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
5 | # "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
6 | # "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
7 | # "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
8 | # "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
9 | # "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
10 | # "teddy bear", "hair drier", "toothbrush"]
11 | COCO_CLASSES = ["holothurian","echinus","scallop","starfish"]
12 |
13 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
14 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
15 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
16 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
17 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
18 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
19 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
20 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
21 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
22 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
23 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
24 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
25 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
26 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
27 | (2, 20, 184), (122, 37, 185)]
28 |
--------------------------------------------------------------------------------
/efficientdet/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 | from pycocotools.coco import COCO
7 | import cv2
8 |
9 |
10 | class CocoDataset(Dataset):
11 | def __init__(self, root_dir, set='train', transform=None):
12 |
13 | self.root_dir = root_dir
14 | self.set_name = set
15 | self.transform = transform
16 |
17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
18 | self.image_ids = self.coco.getImgIds()
19 |
20 | self.load_classes()
21 |
22 | def load_classes(self):
23 |
24 | # load class names (name -> label)
25 | categories = self.coco.loadCats(self.coco.getCatIds())
26 | categories.sort(key=lambda x: x['id'])
27 |
28 | self.classes = {}
29 | self.coco_labels = {}
30 | self.coco_labels_inverse = {}
31 | for c in categories:
32 | self.coco_labels[len(self.classes)] = c['id']
33 | self.coco_labels_inverse[c['id']] = len(self.classes)
34 | self.classes[c['name']] = len(self.classes)
35 |
36 | # also load the reverse (label -> name)
37 | self.labels = {}
38 | for key, value in self.classes.items():
39 | self.labels[value] = key
40 |
41 | def __len__(self):
42 | return len(self.image_ids)
43 |
44 | def __getitem__(self, idx):
45 |
46 | img = self.load_image(idx)
47 | annot = self.load_annotations(idx)
48 | sample = {'img': img, 'annot': annot}
49 | if self.transform:
50 | sample = self.transform(sample)
51 | return sample
52 |
53 | def load_image(self, image_index):
54 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
55 | path = os.path.join(self.root_dir, self.set_name, image_info['file_name'])
56 | img = cv2.imread(path)
57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
58 |
59 | return img.astype(np.float32) / 255.
60 |
61 | def load_annotations(self, image_index):
62 | # get ground truth annotations
63 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
64 | annotations = np.zeros((0, 5))
65 |
66 | # some images appear to miss annotations
67 | if len(annotations_ids) == 0:
68 | return annotations
69 |
70 | # parse annotations
71 | coco_annotations = self.coco.loadAnns(annotations_ids)
72 | for idx, a in enumerate(coco_annotations):
73 |
74 | # some annotations have basically no width / height, skip them
75 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
76 | continue
77 |
78 | annotation = np.zeros((1, 5))
79 | annotation[0, :4] = a['bbox']
80 | annotation[0, 4] = self.coco_label_to_label(a['category_id'])
81 | annotations = np.append(annotations, annotation, axis=0)
82 |
83 | # transform from [x, y, w, h] to [x1, y1, x2, y2]
84 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
85 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3]
86 |
87 | return annotations
88 |
89 | def coco_label_to_label(self, coco_label):
90 | return self.coco_labels_inverse[coco_label]
91 |
92 | def label_to_coco_label(self, label):
93 | return self.coco_labels[label]
94 |
95 |
96 | def collater(data):
97 | imgs = [s['img'] for s in data]
98 | annots = [s['annot'] for s in data]
99 | scales = [s['scale'] for s in data]
100 |
101 | imgs = torch.from_numpy(np.stack(imgs, axis=0))
102 |
103 | max_num_annots = max(annot.shape[0] for annot in annots)
104 |
105 | if max_num_annots > 0:
106 |
107 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
108 |
109 | if max_num_annots > 0:
110 | for idx, annot in enumerate(annots):
111 | if annot.shape[0] > 0:
112 | annot_padded[idx, :annot.shape[0], :] = annot
113 | else:
114 | annot_padded = torch.ones((len(annots), 1, 5)) * -1
115 |
116 | imgs = imgs.permute(0, 3, 1, 2)
117 |
118 | return {'img': imgs, 'annot': annot_padded, 'scale': scales}
119 |
120 |
121 | class Resizer(object):
122 | """Convert ndarrays in sample to Tensors."""
123 |
124 | def __init__(self, img_size=512):
125 | self.img_size = img_size
126 |
127 | def __call__(self, sample):
128 | image, annots = sample['img'], sample['annot']
129 | height, width, _ = image.shape
130 | if height > width:
131 | scale = self.img_size / height
132 | resized_height = self.img_size
133 | resized_width = int(width * scale)
134 | else:
135 | scale = self.img_size / width
136 | resized_height = int(height * scale)
137 | resized_width = self.img_size
138 |
139 | image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
140 |
141 | new_image = np.zeros((self.img_size, self.img_size, 3))
142 | new_image[0:resized_height, 0:resized_width] = image
143 |
144 | annots[:, :4] *= scale
145 |
146 | return {'img': torch.from_numpy(new_image).to(torch.float32), 'annot': torch.from_numpy(annots), 'scale': scale}
147 |
148 |
149 | class Augmenter(object):
150 | """Convert ndarrays in sample to Tensors."""
151 |
152 | def __call__(self, sample, flip_x=0.5):
153 | if np.random.rand() < flip_x:
154 | image, annots = sample['img'], sample['annot']
155 | image = image[:, ::-1, :]
156 |
157 | rows, cols, channels = image.shape
158 |
159 | x1 = annots[:, 0].copy()
160 | x2 = annots[:, 2].copy()
161 |
162 | x_tmp = x1.copy()
163 |
164 | annots[:, 0] = cols - x2
165 | annots[:, 2] = cols - x_tmp
166 |
167 | sample = {'img': image, 'annot': annots}
168 |
169 | return sample
170 |
171 |
172 | class Normalizer(object):
173 |
174 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
175 | self.mean = np.array([[mean]])
176 | self.std = np.array([[std]])
177 |
178 | def __call__(self, sample):
179 | image, annots = sample['img'], sample['annot']
180 |
181 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
182 |
--------------------------------------------------------------------------------
/efficientdet/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import cv2
4 | import numpy as np
5 |
6 | from efficientdet.utils import BBoxTransform, ClipBoxes
7 | from utils.utils import postprocess, invert_affine, display
8 |
9 |
10 | def calc_iou(a, b):
11 | # a(anchor) [boxes, (y1, x1, y2, x2)]
12 | # b(gt, coco-style) [boxes, (x1, y1, x2, y2)]
13 |
14 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
15 | iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
16 | ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
17 | iw = torch.clamp(iw, min=0)
18 | ih = torch.clamp(ih, min=0)
19 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
20 | ua = torch.clamp(ua, min=1e-8)
21 | intersection = iw * ih
22 | IoU = intersection / ua
23 |
24 | return IoU
25 |
26 |
27 | class FocalLoss(nn.Module):
28 | def __init__(self):
29 | super(FocalLoss, self).__init__()
30 |
31 | def forward(self, classifications, regressions, anchors, annotations, **kwargs):
32 | alpha = 0.25
33 | gamma = 2.0
34 | batch_size = classifications.shape[0]
35 | classification_losses = []
36 | regression_losses = []
37 |
38 | anchor = anchors[0, :, :] # assuming all image sizes are the same, which it is
39 | dtype = anchors.dtype
40 |
41 | anchor_widths = anchor[:, 3] - anchor[:, 1]
42 | anchor_heights = anchor[:, 2] - anchor[:, 0]
43 | anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
44 | anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights
45 |
46 | for j in range(batch_size):
47 |
48 | classification = classifications[j, :, :]
49 | regression = regressions[j, :, :]
50 |
51 | bbox_annotation = annotations[j]
52 | bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
53 |
54 | if bbox_annotation.shape[0] == 0:
55 | if torch.cuda.is_available():
56 | regression_losses.append(torch.tensor(0).to(dtype).cuda())
57 | classification_losses.append(torch.tensor(0).to(dtype).cuda())
58 | else:
59 | regression_losses.append(torch.tensor(0).to(dtype))
60 | classification_losses.append(torch.tensor(0).to(dtype))
61 |
62 | continue
63 |
64 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
65 |
66 | IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
67 |
68 | IoU_max, IoU_argmax = torch.max(IoU, dim=1)
69 |
70 | # compute the loss for classification
71 | targets = torch.ones_like(classification) * -1
72 | if torch.cuda.is_available():
73 | targets = targets.cuda()
74 |
75 | targets[torch.lt(IoU_max, 0.4), :] = 0
76 |
77 | positive_indices = torch.ge(IoU_max, 0.5)
78 |
79 | num_positive_anchors = positive_indices.sum()
80 |
81 | assigned_annotations = bbox_annotation[IoU_argmax, :]
82 |
83 | targets[positive_indices, :] = 0
84 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
85 |
86 | alpha_factor = torch.ones_like(targets) * alpha
87 | if torch.cuda.is_available():
88 | alpha_factor = alpha_factor.cuda()
89 |
90 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
91 | focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
92 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
93 |
94 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
95 |
96 | cls_loss = focal_weight * bce
97 |
98 | zeros = torch.zeros_like(cls_loss)
99 | if torch.cuda.is_available():
100 | zeros = zeros.cuda()
101 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
102 |
103 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
104 |
105 | if positive_indices.sum() > 0:
106 | assigned_annotations = assigned_annotations[positive_indices, :]
107 |
108 | anchor_widths_pi = anchor_widths[positive_indices]
109 | anchor_heights_pi = anchor_heights[positive_indices]
110 | anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
111 | anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
112 |
113 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
114 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
115 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
116 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
117 |
118 | # efficientdet style
119 | gt_widths = torch.clamp(gt_widths, min=1)
120 | gt_heights = torch.clamp(gt_heights, min=1)
121 |
122 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
123 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
124 | targets_dw = torch.log(gt_widths / anchor_widths_pi)
125 | targets_dh = torch.log(gt_heights / anchor_heights_pi)
126 |
127 | targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
128 | targets = targets.t()
129 |
130 | regression_diff = torch.abs(targets - regression[positive_indices, :])
131 |
132 | regression_loss = torch.where(
133 | torch.le(regression_diff, 1.0 / 9.0),
134 | 0.5 * 9.0 * torch.pow(regression_diff, 2),
135 | regression_diff - 0.5 / 9.0
136 | )
137 | regression_losses.append(regression_loss.mean())
138 | else:
139 | if torch.cuda.is_available():
140 | regression_losses.append(torch.tensor(0).to(dtype).cuda())
141 | else:
142 | regression_losses.append(torch.tensor(0).to(dtype))
143 |
144 | # debug
145 | imgs = kwargs.get('imgs', None)
146 | if imgs is not None:
147 | regressBoxes = BBoxTransform()
148 | clipBoxes = ClipBoxes()
149 | obj_list = kwargs.get('obj_list', None)
150 | out = postprocess(imgs.detach(),
151 | torch.stack([anchors[0]] * imgs.shape[0], 0).detach(), regressions.detach(), classifications.detach(),
152 | regressBoxes, clipBoxes,
153 | 0.5, 0.3)
154 | imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
155 | imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
156 | imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
157 | display(out, imgs, obj_list, imshow=False, imwrite=True)
158 |
159 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
160 | torch.stack(regression_losses).mean(dim=0, keepdim=True)
161 |
--------------------------------------------------------------------------------
/efficientdet/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torchvision.ops.boxes import nms as nms_torch
4 |
5 | from efficientnet import EfficientNet as EffNet
6 | from efficientnet.utils import MemoryEfficientSwish, Swish
7 | from efficientnet.utils_extra import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding
8 |
9 |
10 | def nms(dets, thresh):
11 | return nms_torch(dets[:, :4], dets[:, 4], thresh)
12 |
13 |
14 | class SeparableConvBlock(nn.Module):
15 | """
16 | created by Zylo117
17 | """
18 |
19 | def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
20 | super(SeparableConvBlock, self).__init__()
21 | if out_channels is None:
22 | out_channels = in_channels
23 |
24 | # Q: whether separate conv
25 | # share bias between depthwise_conv and pointwise_conv
26 | # or just pointwise_conv apply bias.
27 | # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias.
28 |
29 | self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
30 | kernel_size=3, stride=1, groups=in_channels, bias=False)
31 | self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
32 |
33 | self.norm = norm
34 | if self.norm:
35 | # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
36 | self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)
37 |
38 | self.activation = activation
39 | if self.activation:
40 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
41 |
42 | def forward(self, x):
43 | x = self.depthwise_conv(x)
44 | x = self.pointwise_conv(x)
45 |
46 | if self.norm:
47 | x = self.bn(x)
48 |
49 | if self.activation:
50 | x = self.swish(x)
51 |
52 | return x
53 |
54 |
55 | class BiFPN(nn.Module):
56 | """
57 | modified by Zylo117
58 | """
59 |
60 | def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True):
61 | """
62 |
63 | Args:
64 | num_channels:
65 | conv_channels:
66 | first_time: whether the input comes directly from the efficientnet,
67 | if True, downchannel it first, and downsample P5 to generate P6 then P7
68 | epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon
69 | onnx_export: if True, use Swish instead of MemoryEfficientSwish
70 | """
71 | super(BiFPN, self).__init__()
72 | self.epsilon = epsilon
73 | # Conv layers
74 | self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
75 | self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
76 | self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
77 | self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
78 | self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
79 | self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
80 | self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
81 | self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
82 |
83 | # Feature scaling layers
84 | self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
85 | self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
86 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
87 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
88 |
89 | self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
90 | self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
91 | self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
92 | self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
93 |
94 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
95 |
96 | self.first_time = first_time
97 | if self.first_time:
98 | self.p5_down_channel = nn.Sequential(
99 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
100 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
101 | )
102 | self.p4_down_channel = nn.Sequential(
103 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
104 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
105 | )
106 | self.p3_down_channel = nn.Sequential(
107 | Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
108 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
109 | )
110 |
111 | self.p5_to_p6 = nn.Sequential(
112 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
113 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
114 | MaxPool2dStaticSamePadding(3, 2)
115 | )
116 | self.p6_to_p7 = nn.Sequential(
117 | MaxPool2dStaticSamePadding(3, 2)
118 | )
119 |
120 | self.p4_down_channel_2 = nn.Sequential(
121 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
122 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
123 | )
124 | self.p5_down_channel_2 = nn.Sequential(
125 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
126 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
127 | )
128 |
129 | # Weight
130 | self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
131 | self.p6_w1_relu = nn.ReLU()
132 | self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
133 | self.p5_w1_relu = nn.ReLU()
134 | self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
135 | self.p4_w1_relu = nn.ReLU()
136 | self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
137 | self.p3_w1_relu = nn.ReLU()
138 |
139 | self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
140 | self.p4_w2_relu = nn.ReLU()
141 | self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
142 | self.p5_w2_relu = nn.ReLU()
143 | self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
144 | self.p6_w2_relu = nn.ReLU()
145 | self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
146 | self.p7_w2_relu = nn.ReLU()
147 |
148 | self.attention = attention
149 |
150 | def forward(self, inputs):
151 | """
152 | illustration of a minimal bifpn unit
153 | P7_0 -------------------------> P7_2 -------->
154 | |-------------| ↑
155 | ↓ |
156 | P6_0 ---------> P6_1 ---------> P6_2 -------->
157 | |-------------|--------------↑ ↑
158 | ↓ |
159 | P5_0 ---------> P5_1 ---------> P5_2 -------->
160 | |-------------|--------------↑ ↑
161 | ↓ |
162 | P4_0 ---------> P4_1 ---------> P4_2 -------->
163 | |-------------|--------------↑ ↑
164 | |--------------↓ |
165 | P3_0 -------------------------> P3_2 -------->
166 | """
167 |
168 | # downsample channels using same-padding conv2d to target phase's if not the same
169 | # judge: same phase as target,
170 | # if same, pass;
171 | # elif earlier phase, downsample to target phase's by pooling
172 | # elif later phase, upsample to target phase's by nearest interpolation
173 |
174 | if self.attention:
175 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs)
176 | else:
177 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs)
178 |
179 | return p3_out, p4_out, p5_out, p6_out, p7_out
180 |
181 | def _forward_fast_attention(self, inputs):
182 | if self.first_time:
183 | p3, p4, p5 = inputs
184 |
185 | p6_in = self.p5_to_p6(p5)
186 | p7_in = self.p6_to_p7(p6_in)
187 |
188 | p3_in = self.p3_down_channel(p3)
189 | p4_in = self.p4_down_channel(p4)
190 | p5_in = self.p5_down_channel(p5)
191 |
192 | else:
193 | # P3_0, P4_0, P5_0, P6_0 and P7_0
194 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs
195 |
196 | # P7_0 to P7_2
197 |
198 | # Weights for P6_0 and P7_0 to P6_1
199 | p6_w1 = self.p6_w1_relu(self.p6_w1)
200 | weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
201 | # Connections for P6_0 and P7_0 to P6_1 respectively
202 | p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
203 |
204 | # Weights for P5_0 and P6_0 to P5_1
205 | p5_w1 = self.p5_w1_relu(self.p5_w1)
206 | weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
207 | # Connections for P5_0 and P6_0 to P5_1 respectively
208 | p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
209 |
210 | # Weights for P4_0 and P5_0 to P4_1
211 | p4_w1 = self.p4_w1_relu(self.p4_w1)
212 | weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
213 | # Connections for P4_0 and P5_0 to P4_1 respectively
214 | p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
215 |
216 | # Weights for P3_0 and P4_1 to P3_2
217 | p3_w1 = self.p3_w1_relu(self.p3_w1)
218 | weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
219 | # Connections for P3_0 and P4_1 to P3_2 respectively
220 | p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
221 |
222 | if self.first_time:
223 | p4_in = self.p4_down_channel_2(p4)
224 | p5_in = self.p5_down_channel_2(p5)
225 |
226 | # Weights for P4_0, P4_1 and P3_2 to P4_2
227 | p4_w2 = self.p4_w2_relu(self.p4_w2)
228 | weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
229 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
230 | p4_out = self.conv4_down(
231 | self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))
232 |
233 | # Weights for P5_0, P5_1 and P4_2 to P5_2
234 | p5_w2 = self.p5_w2_relu(self.p5_w2)
235 | weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
236 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
237 | p5_out = self.conv5_down(
238 | self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))
239 |
240 | # Weights for P6_0, P6_1 and P5_2 to P6_2
241 | p6_w2 = self.p6_w2_relu(self.p6_w2)
242 | weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
243 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
244 | p6_out = self.conv6_down(
245 | self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))
246 |
247 | # Weights for P7_0 and P6_2 to P7_2
248 | p7_w2 = self.p7_w2_relu(self.p7_w2)
249 | weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
250 | # Connections for P7_0 and P6_2 to P7_2
251 | p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
252 |
253 | return p3_out, p4_out, p5_out, p6_out, p7_out
254 |
255 | def _forward(self, inputs):
256 | if self.first_time:
257 | p3, p4, p5 = inputs
258 |
259 | p6_in = self.p5_to_p6(p5)
260 | p7_in = self.p6_to_p7(p6_in)
261 |
262 | p3_in = self.p3_down_channel(p3)
263 | p4_in = self.p4_down_channel(p4)
264 | p5_in = self.p5_down_channel(p5)
265 |
266 | else:
267 | # P3_0, P4_0, P5_0, P6_0 and P7_0
268 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs
269 |
270 | # P7_0 to P7_2
271 |
272 | # Connections for P6_0 and P7_0 to P6_1 respectively
273 | p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
274 |
275 | # Connections for P5_0 and P6_0 to P5_1 respectively
276 | p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
277 |
278 | # Connections for P4_0 and P5_0 to P4_1 respectively
279 | p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
280 |
281 | # Connections for P3_0 and P4_1 to P3_2 respectively
282 | p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
283 |
284 | if self.first_time:
285 | p4_in = self.p4_down_channel_2(p4)
286 | p5_in = self.p5_down_channel_2(p5)
287 |
288 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
289 | p4_out = self.conv4_down(
290 | self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
291 |
292 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
293 | p5_out = self.conv5_down(
294 | self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
295 |
296 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
297 | p6_out = self.conv6_down(
298 | self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
299 |
300 | # Connections for P7_0 and P6_2 to P7_2
301 | p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
302 |
303 | return p3_out, p4_out, p5_out, p6_out, p7_out
304 |
305 |
306 | class Regressor(nn.Module):
307 | """
308 | modified by Zylo117
309 | """
310 |
311 | def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False):
312 | super(Regressor, self).__init__()
313 | self.num_layers = num_layers
314 | self.num_layers = num_layers
315 |
316 | self.conv_list = nn.ModuleList(
317 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
318 | self.bn_list = nn.ModuleList(
319 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
320 | range(5)])
321 | self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
322 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
323 |
324 | def forward(self, inputs):
325 | feats = []
326 | for feat, bn_list in zip(inputs, self.bn_list):
327 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
328 | feat = conv(feat)
329 | feat = bn(feat)
330 | feat = self.swish(feat)
331 | feat = self.header(feat)
332 |
333 | feat = feat.permute(0, 2, 3, 1)
334 | feat = feat.contiguous().view(feat.shape[0], -1, 4)
335 |
336 | feats.append(feat)
337 |
338 | feats = torch.cat(feats, dim=1)
339 |
340 | return feats
341 |
342 |
343 | class Classifier(nn.Module):
344 | """
345 | modified by Zylo117
346 | """
347 |
348 | def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_export=False):
349 | super(Classifier, self).__init__()
350 | self.num_anchors = num_anchors
351 | self.num_classes = num_classes
352 | self.num_layers = num_layers
353 | self.conv_list = nn.ModuleList(
354 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
355 | self.bn_list = nn.ModuleList(
356 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
357 | range(5)])
358 | self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
359 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
360 |
361 | def forward(self, inputs):
362 | feats = []
363 | for feat, bn_list in zip(inputs, self.bn_list):
364 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
365 | feat = conv(feat)
366 | feat = bn(feat)
367 | feat = self.swish(feat)
368 | feat = self.header(feat)
369 |
370 | feat = feat.permute(0, 2, 3, 1)
371 | feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors,
372 | self.num_classes)
373 | feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes)
374 |
375 | feats.append(feat)
376 |
377 | feats = torch.cat(feats, dim=1)
378 | feats = feats.sigmoid()
379 |
380 | return feats
381 |
382 |
383 | class EfficientNet(nn.Module):
384 | """
385 | modified by Zylo117
386 | """
387 |
388 | def __init__(self, compound_coef, load_weights=False):
389 | super(EfficientNet, self).__init__()
390 | model = EffNet.from_pretrained('efficientnet-b{}'.format(compound_coef), load_weights)
391 | del model._conv_head
392 | del model._bn1
393 | del model._avg_pooling
394 | del model._dropout
395 | del model._fc
396 | self.model = model
397 |
398 | def forward(self, x):
399 | x = self.model._conv_stem(x)
400 | x = self.model._bn0(x)
401 | x = self.model._swish(x)
402 | feature_maps = []
403 |
404 | # TODO: temporarily storing extra tensor last_x and del it later might not be a good idea,
405 | # try recording stride changing when creating efficientnet,
406 | # and then apply it here.
407 | last_x = None
408 | for idx, block in enumerate(self.model._blocks):
409 | drop_connect_rate = self.model._global_params.drop_connect_rate
410 | if drop_connect_rate:
411 | drop_connect_rate *= float(idx) / len(self.model._blocks)
412 | x = block(x, drop_connect_rate=drop_connect_rate)
413 |
414 | if block._depthwise_conv.stride == [2, 2]:
415 | feature_maps.append(last_x)
416 | elif idx == len(self.model._blocks) - 1:
417 | feature_maps.append(x)
418 | last_x = x
419 | del last_x
420 | return feature_maps[1:]
421 |
422 |
423 | if __name__ == '__main__':
424 | from tensorboardX import SummaryWriter
425 |
426 |
427 | def count_parameters(model):
428 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
429 |
--------------------------------------------------------------------------------
/efficientdet/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | class BBoxTransform(nn.Module):
8 | def forward(self, anchors, regression):
9 | """
10 | decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py
11 |
12 | Args:
13 | anchors: [batchsize, boxes, (y1, x1, y2, x2)]
14 | regression: [batchsize, boxes, (dy, dx, dh, dw)]
15 |
16 | Returns:
17 |
18 | """
19 | y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2
20 | x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2
21 | ha = anchors[..., 2] - anchors[..., 0]
22 | wa = anchors[..., 3] - anchors[..., 1]
23 |
24 | w = regression[..., 3].exp() * wa
25 | h = regression[..., 2].exp() * ha
26 |
27 | y_centers = regression[..., 0] * ha + y_centers_a
28 | x_centers = regression[..., 1] * wa + x_centers_a
29 |
30 | ymin = y_centers - h / 2.
31 | xmin = x_centers - w / 2.
32 | ymax = y_centers + h / 2.
33 | xmax = x_centers + w / 2.
34 |
35 | return torch.stack([xmin, ymin, xmax, ymax], dim=2)
36 |
37 |
38 | class ClipBoxes(nn.Module):
39 |
40 | def __init__(self):
41 | super(ClipBoxes, self).__init__()
42 |
43 | def forward(self, boxes, img):
44 | batch_size, num_channels, height, width = img.shape
45 |
46 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
47 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
48 |
49 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1)
50 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1)
51 |
52 | return boxes
53 |
54 |
55 | class Anchors(nn.Module):
56 | """
57 | adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117
58 | """
59 |
60 | def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
61 | super().__init__()
62 | self.anchor_scale = anchor_scale
63 |
64 | if pyramid_levels is None:
65 | self.pyramid_levels = [3, 4, 5, 6, 7]
66 |
67 | self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
68 | self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
69 | self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
70 |
71 | self.last_anchors = {}
72 | self.last_shape = None
73 |
74 | def forward(self, image, dtype=torch.float32):
75 | """Generates multiscale anchor boxes.
76 |
77 | Args:
78 | image_size: integer number of input image size. The input image has the
79 | same dimension for width and height. The image_size should be divided by
80 | the largest feature stride 2^max_level.
81 | anchor_scale: float number representing the scale of size of the base
82 | anchor to the feature stride 2^level.
83 | anchor_configs: a dictionary with keys as the levels of anchors and
84 | values as a list of anchor configuration.
85 |
86 | Returns:
87 | anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all
88 | feature levels.
89 | Raises:
90 | ValueError: input size must be the multiple of largest feature stride.
91 | """
92 | image_shape = image.shape[2:]
93 |
94 | if image_shape == self.last_shape and image.device in self.last_anchors:
95 | return self.last_anchors[image.device]
96 |
97 | if self.last_shape is None or self.last_shape != image_shape:
98 | self.last_shape = image_shape
99 |
100 | if dtype == torch.float16:
101 | dtype = np.float16
102 | else:
103 | dtype = np.float32
104 |
105 | boxes_all = []
106 | for stride in self.strides:
107 | boxes_level = []
108 | for scale, ratio in itertools.product(self.scales, self.ratios):
109 | if image_shape[1] % stride != 0:
110 | raise ValueError('input size must be divided by the stride.')
111 | base_anchor_size = self.anchor_scale * stride * scale
112 | anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
113 | anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
114 |
115 | x = np.arange(stride / 2, image_shape[1], stride)
116 | y = np.arange(stride / 2, image_shape[0], stride)
117 | xv, yv = np.meshgrid(x, y)
118 | xv = xv.reshape(-1)
119 | yv = yv.reshape(-1)
120 |
121 | # y1,x1,y2,x2
122 | boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
123 | yv + anchor_size_y_2, xv + anchor_size_x_2))
124 | boxes = np.swapaxes(boxes, 0, 1)
125 | boxes_level.append(np.expand_dims(boxes, axis=1))
126 | # concat anchors on the same level to the reshape NxAx4
127 | boxes_level = np.concatenate(boxes_level, axis=1)
128 | boxes_all.append(boxes_level.reshape([-1, 4]))
129 |
130 | anchor_boxes = np.vstack(boxes_all)
131 |
132 | anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device)
133 | anchor_boxes = anchor_boxes.unsqueeze(0)
134 |
135 | # save it for later use to reduce overhead
136 | self.last_anchors[image.device] = anchor_boxes
137 | return anchor_boxes
138 |
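
As a quick sanity check on the anchor generation above (an illustrative sketch, not code from the repo), the size of the returned anchor tensor can be counted by hand: 3 scales × 3 ratios at every cell of pyramid levels 3-7. Assuming EfficientDet-D2's default input size of 768 (see `efficientdet_test.py` below):

```
# counts the anchors produced for a 768x768 input with the default levels/scales/ratios
input_size = 768
strides = [2 ** lvl for lvl in [3, 4, 5, 6, 7]]   # pyramid levels 3-7
anchors_per_cell = 3 * 3                          # 3 scales x 3 ratios
total = sum((input_size // s) ** 2 * anchors_per_cell for s in strides)
print(total)  # 110484 -> Anchors() returns a [1, 110484, 4] tensor
```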
--------------------------------------------------------------------------------
/efficientdet_test.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | Simple Inference Script of EfficientDet-Pytorch
5 | """
6 | import time
7 |
8 | import torch
9 | from torch.backends import cudnn
10 |
11 | from backbone import EfficientDetBackbone
12 | import cv2
13 | import numpy as np
14 |
15 | from efficientdet.utils import BBoxTransform, ClipBoxes
16 | from utils.utils import preprocess, invert_affine, postprocess
17 |
18 | compound_coef = 2
19 | force_input_size = None # set None to use default size
20 | img_path = "dataset/underwater/val/000008.jpg"
21 |
22 | # replace this part with your project's anchor config
23 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
24 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
25 |
26 | threshold = 0.2
27 | iou_threshold = 0.2
28 |
29 | use_cuda = True
30 | use_float16 = False
31 | cudnn.fastest = True
32 | cudnn.benchmark = True
33 |
34 | # obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
35 | # 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
36 | # 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
37 | # 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
38 | # 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
39 | # 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
40 | # 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
41 | # 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
42 | # 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
43 | # 'toothbrush']
44 |
45 | obj_list = ["holothurian","echinus","scallop","starfish"]
46 |
47 | # tf bilinear interpolation is different from any other's, just make do
48 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
49 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
50 | ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)
51 |
52 | if use_cuda:
53 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
54 | else:
55 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
56 |
57 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
58 |
59 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
60 | ratios=anchor_ratios, scales=anchor_scales)
61 | model.load_state_dict(torch.load("./logs/underwater/efficientdet-d2_122_38106.pth")) # model path
62 | model.requires_grad_(False)
63 | model.eval()
64 |
65 | if use_cuda:
66 | model = model.cuda()
67 | if use_float16:
68 | model = model.half()
69 |
70 | with torch.no_grad():
71 | features, regression, classification, anchors = model(x)
72 |
73 | regressBoxes = BBoxTransform()
74 | clipBoxes = ClipBoxes()
75 |
76 | out = postprocess(x,
77 | anchors, regression, classification,
78 | regressBoxes, clipBoxes,
79 | threshold, iou_threshold)
80 |
81 |
82 | def display(preds, imgs, imshow=True, imwrite=False):
83 | for i in range(len(imgs)):
84 | if len(preds[i]['rois']) == 0:
85 | continue
86 |
87 | for j in range(len(preds[i]['rois'])):
88 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int)
89 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
90 | obj = obj_list[preds[i]['class_ids'][j]]
91 | score = float(preds[i]['scores'][j])
92 |
93 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
94 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
95 | (255, 255, 0), 1)
96 |
97 | if imshow:
98 | cv2.imshow('img', imgs[i])
99 | cv2.waitKey(0)
100 |
101 | if imwrite:
102 | cv2.imwrite('test/img_inferred_d{}_this_repo_{}.jpg'.format(compound_coef,i), imgs[i])
103 |
104 |
105 | out = invert_affine(framed_metas, out)
106 | display(out, ori_imgs, imshow=False, imwrite=True)
107 |
108 | print('running speed test...')
109 | with torch.no_grad():
110 | print('test1: model inferring and postprocessing')
111 | print('inferring image for 10 times...')
112 | t1 = time.time()
113 | for _ in range(10):
114 | _, regression, classification, anchors = model(x)
115 |
116 | out = postprocess(x,
117 | anchors, regression, classification,
118 | regressBoxes, clipBoxes,
119 | threshold, iou_threshold)
120 | out = invert_affine(framed_metas, out)
121 |
122 | t2 = time.time()
123 | tact_time = (t2 - t1) / 10
124 | print('{} seconds, {} FPS, @batch_size 1'.format(tact_time,1 / tact_time))
125 |
126 | # uncomment this if you want an extreme fps test
127 | # print('test2: model inferring only')
128 | # print('inferring images for batch_size 32 for 10 times...')
129 | # t1 = time.time()
130 | # x = torch.cat([x] * 32, 0)
131 | # for _ in range(10):
132 | # _, regression, classification, anchors = model(x)
133 | #
134 | # t2 = time.time()
135 | # tact_time = (t2 - t1) / 10
136 | # print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32')
137 |
--------------------------------------------------------------------------------
/efficientnet/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .model import EfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/efficientnet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import (
6 | round_filters,
7 | round_repeats,
8 | drop_connect,
9 | get_same_padding_conv2d,
10 | get_model_params,
11 | efficientnet_params,
12 | load_pretrained_weights,
13 | Swish,
14 | MemoryEfficientSwish,
15 | )
16 |
17 | class MBConvBlock(nn.Module):
18 | """
19 | Mobile Inverted Residual Bottleneck Block
20 |
21 | Args:
22 | block_args (namedtuple): BlockArgs, see above
23 | global_params (namedtuple): GlobalParam, see above
24 |
25 | Attributes:
26 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
27 | """
28 |
29 | def __init__(self, block_args, global_params):
30 | super().__init__()
31 | self._block_args = block_args
32 | self._bn_mom = 1 - global_params.batch_norm_momentum
33 | self._bn_eps = global_params.batch_norm_epsilon
34 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
35 | self.id_skip = block_args.id_skip # skip connection and drop connect
36 |
37 | # Get static or dynamic convolution depending on image size
38 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
39 |
40 | # Expansion phase
41 | inp = self._block_args.input_filters # number of input channels
42 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
43 | if self._block_args.expand_ratio != 1:
44 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
45 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
46 |
47 | # Depthwise convolution phase
48 | k = self._block_args.kernel_size
49 | s = self._block_args.stride
50 | self._depthwise_conv = Conv2d(
51 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
52 | kernel_size=k, stride=s, bias=False)
53 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
54 |
55 | # Squeeze and Excitation layer, if desired
56 | if self.has_se:
57 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
58 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
59 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
60 |
61 | # Output phase
62 | final_oup = self._block_args.output_filters
63 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
64 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
65 | self._swish = MemoryEfficientSwish()
66 |
67 | def forward(self, inputs, drop_connect_rate=None):
68 | """
69 | :param inputs: input tensor
70 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
71 | :return: output of block
72 | """
73 |
74 | # Expansion and Depthwise Convolution
75 | x = inputs
76 | if self._block_args.expand_ratio != 1:
77 | x = self._expand_conv(inputs)
78 | x = self._bn0(x)
79 | x = self._swish(x)
80 |
81 | x = self._depthwise_conv(x)
82 | x = self._bn1(x)
83 | x = self._swish(x)
84 |
85 | # Squeeze and Excitation
86 | if self.has_se:
87 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
88 | x_squeezed = self._se_reduce(x_squeezed)
89 | x_squeezed = self._swish(x_squeezed)
90 | x_squeezed = self._se_expand(x_squeezed)
91 | x = torch.sigmoid(x_squeezed) * x
92 |
93 | x = self._project_conv(x)
94 | x = self._bn2(x)
95 |
96 | # Skip connection and drop connect
97 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
98 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
99 | if drop_connect_rate:
100 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
101 | x = x + inputs # skip connection
102 | return x
103 |
104 | def set_swish(self, memory_efficient=True):
105 | """Sets swish function as memory efficient (for training) or standard (for export)"""
106 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
107 |
108 |
109 | class EfficientNet(nn.Module):
110 | """
111 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
112 |
113 | Args:
114 | blocks_args (list): A list of BlockArgs to construct blocks
115 | global_params (namedtuple): A set of GlobalParams shared between blocks
116 |
117 | Example:
118 | model = EfficientNet.from_pretrained('efficientnet-b0')
119 |
120 | """
121 |
122 | def __init__(self, blocks_args=None, global_params=None):
123 | super().__init__()
124 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
125 | assert len(blocks_args) > 0, 'block args must be greater than 0'
126 | self._global_params = global_params
127 | self._blocks_args = blocks_args
128 |
129 | # Get static or dynamic convolution depending on image size
130 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
131 |
132 | # Batch norm parameters
133 | bn_mom = 1 - self._global_params.batch_norm_momentum
134 | bn_eps = self._global_params.batch_norm_epsilon
135 |
136 | # Stem
137 | in_channels = 3 # rgb
138 | out_channels = round_filters(32, self._global_params) # number of output channels
139 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
140 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
141 |
142 | # Build blocks
143 | self._blocks = nn.ModuleList([])
144 | for block_args in self._blocks_args:
145 |
146 | # Update block input and output filters based on depth multiplier.
147 | block_args = block_args._replace(
148 | input_filters=round_filters(block_args.input_filters, self._global_params),
149 | output_filters=round_filters(block_args.output_filters, self._global_params),
150 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
151 | )
152 |
153 | # The first block needs to take care of stride and filter size increase.
154 | self._blocks.append(MBConvBlock(block_args, self._global_params))
155 | if block_args.num_repeat > 1:
156 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
157 | for _ in range(block_args.num_repeat - 1):
158 | self._blocks.append(MBConvBlock(block_args, self._global_params))
159 |
160 | # Head
161 | in_channels = block_args.output_filters # output of final block
162 | out_channels = round_filters(1280, self._global_params)
163 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
164 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
165 |
166 | # Final linear layer
167 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
168 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
169 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
170 | self._swish = MemoryEfficientSwish()
171 |
172 | def set_swish(self, memory_efficient=True):
173 | """Sets swish function as memory efficient (for training) or standard (for export)"""
174 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
175 | for block in self._blocks:
176 | block.set_swish(memory_efficient)
177 |
178 |
179 | def extract_features(self, inputs):
180 | """ Returns output of the final convolution layer """
181 |
182 | # Stem
183 | x = self._swish(self._bn0(self._conv_stem(inputs)))
184 |
185 | # Blocks
186 | for idx, block in enumerate(self._blocks):
187 | drop_connect_rate = self._global_params.drop_connect_rate
188 | if drop_connect_rate:
189 | drop_connect_rate *= float(idx) / len(self._blocks)
190 | x = block(x, drop_connect_rate=drop_connect_rate)
191 | # Head
192 | x = self._swish(self._bn1(self._conv_head(x)))
193 |
194 | return x
195 |
196 | def forward(self, inputs):
197 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
198 | bs = inputs.size(0)
199 | # Convolution layers
200 | x = self.extract_features(inputs)
201 |
202 | # Pooling and final linear layer
203 | x = self._avg_pooling(x)
204 | x = x.view(bs, -1)
205 | x = self._dropout(x)
206 | x = self._fc(x)
207 | return x
208 |
209 | @classmethod
210 | def from_name(cls, model_name, override_params=None):
211 | cls._check_model_name_is_valid(model_name)
212 | blocks_args, global_params = get_model_params(model_name, override_params)
213 | return cls(blocks_args, global_params)
214 |
215 | @classmethod
216 | def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3):
217 | model = cls.from_name(model_name, override_params={'num_classes': num_classes})
218 | if load_weights:
219 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
220 | if in_channels != 3:
221 | Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
222 | out_channels = round_filters(32, model._global_params)
223 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
224 | return model
225 |
226 | @classmethod
227 | def get_image_size(cls, model_name):
228 | cls._check_model_name_is_valid(model_name)
229 | _, _, res, _ = efficientnet_params(model_name)
230 | return res
231 |
232 | @classmethod
233 | def _check_model_name_is_valid(cls, model_name):
234 | """ Validates model name. """
235 | valid_models = ['efficientnet-b'+str(i) for i in range(9)]
236 | if model_name not in valid_models:
237 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
238 |
--------------------------------------------------------------------------------
/efficientnet/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains helper functions for building the model and for loading model parameters.
3 | These helper functions are built to mirror those in the official TensorFlow implementation.
4 | """
5 |
6 | import re
7 | import math
8 | import collections
9 | from functools import partial
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | from torch.utils import model_zoo
14 | from .utils_extra import Conv2dStaticSamePadding
15 |
16 | ########################################################################
17 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
18 | ########################################################################
19 |
20 |
21 | # Parameters for the entire model (stem, all blocks, and head)
22 |
23 | GlobalParams = collections.namedtuple('GlobalParams', [
24 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
25 | 'num_classes', 'width_coefficient', 'depth_coefficient',
26 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
27 |
28 | # Parameters for an individual model block
29 | BlockArgs = collections.namedtuple('BlockArgs', [
30 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
31 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
32 |
33 | # Change namedtuple defaults
34 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
35 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
36 |
37 |
38 | class SwishImplementation(torch.autograd.Function):
39 | @staticmethod
40 | def forward(ctx, i):
41 | result = i * torch.sigmoid(i)
42 | ctx.save_for_backward(i)
43 | return result
44 |
45 | @staticmethod
46 | def backward(ctx, grad_output):
47 | i = ctx.saved_variables[0]
48 | sigmoid_i = torch.sigmoid(i)
49 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
50 |
51 |
52 | class MemoryEfficientSwish(nn.Module):
53 | def forward(self, x):
54 | return SwishImplementation.apply(x)
55 |
56 |
57 | class Swish(nn.Module):
58 | def forward(self, x):
59 | return x * torch.sigmoid(x)
60 |
61 |
62 | def round_filters(filters, global_params):
63 | """ Calculate and round number of filters based on depth multiplier. """
64 | multiplier = global_params.width_coefficient
65 | if not multiplier:
66 | return filters
67 | divisor = global_params.depth_divisor
68 | min_depth = global_params.min_depth
69 | filters *= multiplier
70 | min_depth = min_depth or divisor
71 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
72 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
73 | new_filters += divisor
74 | return int(new_filters)
75 |
76 |
77 | def round_repeats(repeats, global_params):
78 | """ Round number of filters based on depth multiplier. """
79 | multiplier = global_params.depth_coefficient
80 | if not multiplier:
81 | return repeats
82 | return int(math.ceil(multiplier * repeats))
83 |
84 |
85 | def drop_connect(inputs, p, training):
86 | """ Drop connect. """
87 | if not training: return inputs
88 | batch_size = inputs.shape[0]
89 | keep_prob = 1 - p
90 | random_tensor = keep_prob
91 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
92 | binary_tensor = torch.floor(random_tensor)
93 | output = inputs / keep_prob * binary_tensor
94 | return output
95 |
96 |
97 | def get_same_padding_conv2d(image_size=None):
98 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
99 | Static padding is necessary for ONNX exporting of models. """
100 | if image_size is None:
101 | return Conv2dDynamicSamePadding
102 | else:
103 | return partial(Conv2dStaticSamePadding, image_size=image_size)
104 |
105 |
106 | class Conv2dDynamicSamePadding(nn.Conv2d):
107 | """ 2D Convolutions like TensorFlow, for a dynamic image size """
108 |
109 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
110 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
111 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
112 |
113 | def forward(self, x):
114 | ih, iw = x.size()[-2:]
115 | kh, kw = self.weight.size()[-2:]
116 | sh, sw = self.stride
117 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
118 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
119 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
120 | if pad_h > 0 or pad_w > 0:
121 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
122 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
123 |
124 |
125 | class Identity(nn.Module):
126 | def __init__(self, ):
127 | super(Identity, self).__init__()
128 |
129 | def forward(self, input):
130 | return input
131 |
132 |
133 | ########################################################################
134 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
135 | ########################################################################
136 |
137 |
138 | def efficientnet_params(model_name):
139 | """ Map EfficientNet model name to parameter coefficients. """
140 | params_dict = {
141 | # Coefficients: width,depth,res,dropout
142 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
143 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
144 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
145 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
146 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
147 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
148 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
149 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
150 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
151 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
152 | }
153 | return params_dict[model_name]
154 |
155 |
156 | class BlockDecoder(object):
157 | """ Block Decoder for readability, straight from the official TensorFlow repository """
158 |
159 | @staticmethod
160 | def _decode_block_string(block_string):
161 | """ Gets a block through a string notation of arguments. """
162 | assert isinstance(block_string, str)
163 |
164 | ops = block_string.split('_')
165 | options = {}
166 | for op in ops:
167 | splits = re.split(r'(\d.*)', op)
168 | if len(splits) >= 2:
169 | key, value = splits[:2]
170 | options[key] = value
171 |
172 | # Check stride
173 | assert (('s' in options and len(options['s']) == 1) or
174 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
175 |
176 | return BlockArgs(
177 | kernel_size=int(options['k']),
178 | num_repeat=int(options['r']),
179 | input_filters=int(options['i']),
180 | output_filters=int(options['o']),
181 | expand_ratio=int(options['e']),
182 | id_skip=('noskip' not in block_string),
183 | se_ratio=float(options['se']) if 'se' in options else None,
184 | stride=[int(options['s'][0])])
185 |
186 | @staticmethod
187 | def _encode_block_string(block):
188 | """Encodes a block to a string."""
189 | args = [
190 | 'r%d' % block.num_repeat,
191 | 'k%d' % block.kernel_size,
192 | 's%d%d' % (block.strides[0], block.strides[1]),
193 | 'e%s' % block.expand_ratio,
194 | 'i%d' % block.input_filters,
195 | 'o%d' % block.output_filters
196 | ]
197 | if 0 < block.se_ratio <= 1:
198 | args.append('se%s' % block.se_ratio)
199 | if block.id_skip is False:
200 | args.append('noskip')
201 | return '_'.join(args)
202 |
203 | @staticmethod
204 | def decode(string_list):
205 | """
206 | Decodes a list of string notations to specify blocks inside the network.
207 |
208 | :param string_list: a list of strings, each string is a notation of block
209 | :return: a list of BlockArgs namedtuples of block args
210 | """
211 | assert isinstance(string_list, list)
212 | blocks_args = []
213 | for block_string in string_list:
214 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
215 | return blocks_args
216 |
217 | @staticmethod
218 | def encode(blocks_args):
219 | """
220 | Encodes a list of BlockArgs to a list of strings.
221 |
222 | :param blocks_args: a list of BlockArgs namedtuples of block args
223 | :return: a list of strings, each string is a notation of block
224 | """
225 | block_strings = []
226 | for block in blocks_args:
227 | block_strings.append(BlockDecoder._encode_block_string(block))
228 | return block_strings
229 |
230 |
231 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
232 | drop_connect_rate=0.2, image_size=None, num_classes=1000):
233 | """ Creates a efficientnet model. """
234 |
235 | blocks_args = [
236 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
237 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
238 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
239 | 'r1_k3_s11_e6_i192_o320_se0.25',
240 | ]
241 | blocks_args = BlockDecoder.decode(blocks_args)
242 |
243 | global_params = GlobalParams(
244 | batch_norm_momentum=0.99,
245 | batch_norm_epsilon=1e-3,
246 | dropout_rate=dropout_rate,
247 | drop_connect_rate=drop_connect_rate,
248 | # data_format='channels_last', # removed, this is always true in PyTorch
249 | num_classes=num_classes,
250 | width_coefficient=width_coefficient,
251 | depth_coefficient=depth_coefficient,
252 | depth_divisor=8,
253 | min_depth=None,
254 | image_size=image_size,
255 | )
256 |
257 | return blocks_args, global_params
258 |
259 |
260 | def get_model_params(model_name, override_params):
261 | """ Get the block args and global params for a given model """
262 | if model_name.startswith('efficientnet'):
263 | w, d, s, p = efficientnet_params(model_name)
264 | # note: all models have drop connect rate = 0.2
265 | blocks_args, global_params = efficientnet(
266 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
267 | else:
268 | raise NotImplementedError('model name is not pre-defined: %s' % model_name)
269 | if override_params:
270 | # ValueError will be raised here if override_params has fields not included in global_params.
271 | global_params = global_params._replace(**override_params)
272 | return blocks_args, global_params
273 |
274 |
275 | url_map = {
276 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
277 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
278 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
279 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
280 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
281 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
282 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
283 | 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
284 | }
285 |
286 | url_map_advprop = {
287 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
288 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
289 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
290 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
291 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
292 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
293 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
294 | 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
295 | 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
296 | }
297 |
298 |
299 | def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
300 | """ Loads pretrained weights, and downloads if loading for the first time. """
301 | # AutoAugment or Advprop (different preprocessing)
302 | url_map_ = url_map_advprop if advprop else url_map
303 | state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu'))
304 | # state_dict = torch.load('../../weights/backbone_efficientnetb0.pth')
305 | if load_fc:
306 | ret = model.load_state_dict(state_dict, strict=False)
307 | print(ret)
308 | else:
309 | state_dict.pop('_fc.weight')
310 | state_dict.pop('_fc.bias')
311 | res = model.load_state_dict(state_dict, strict=False)
312 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
313 | print('Loaded pretrained weights for {}'.format(model_name))
314 |
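
To make the block-string notation used by `efficientnet()` above more concrete, here is a small usage sketch (editorial, not part of the repo; it assumes the repository root is on `PYTHONPATH` so the package import resolves):

```
from efficientnet.utils import BlockDecoder

# decode the first block string listed in efficientnet()
args = BlockDecoder.decode(['r1_k3_s11_e1_i32_o16_se0.25'])[0]
print(args)
# BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16,
#           expand_ratio=1, id_skip=True, stride=[1], se_ratio=0.25)
```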
--------------------------------------------------------------------------------
/efficientnet/utils_extra.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import math
4 |
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Conv2dStaticSamePadding(nn.Module):
10 | """
11 | created by Zylo117
12 | The real keras/tensorflow conv2d with same padding
13 | """
14 |
15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
16 | super().__init__()
17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
18 | bias=bias, groups=groups)
19 | self.stride = self.conv.stride
20 | self.kernel_size = self.conv.kernel_size
21 | self.dilation = self.conv.dilation
22 |
23 | if isinstance(self.stride, int):
24 | self.stride = [self.stride] * 2
25 | elif len(self.stride) == 1:
26 | self.stride = [self.stride[0]] * 2
27 |
28 | if isinstance(self.kernel_size, int):
29 | self.kernel_size = [self.kernel_size] * 2
30 | elif len(self.kernel_size) == 1:
31 | self.kernel_size = [self.kernel_size[0]] * 2
32 |
33 | def forward(self, x):
34 | h, w = x.shape[-2:]
35 |
36 | h_step = math.ceil(w / self.stride[1])
37 | v_step = math.ceil(h / self.stride[0])
38 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
39 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
40 |
41 | extra_h = h_cover_len - w
42 | extra_v = v_cover_len - h
43 |
44 | left = extra_h // 2
45 | right = extra_h - left
46 | top = extra_v // 2
47 | bottom = extra_v - top
48 |
49 | x = F.pad(x, [left, right, top, bottom])
50 |
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class MaxPool2dStaticSamePadding(nn.Module):
56 | """
57 | created by Zylo117
58 | The real keras/tensorflow MaxPool2d with same padding
59 | """
60 |
61 | def __init__(self, *args, **kwargs):
62 | super().__init__()
63 | self.pool = nn.MaxPool2d(*args, **kwargs)
64 | self.stride = self.pool.stride
65 | self.kernel_size = self.pool.kernel_size
66 |
67 | if isinstance(self.stride, int):
68 | self.stride = [self.stride] * 2
69 | elif len(self.stride) == 1:
70 | self.stride = [self.stride[0]] * 2
71 |
72 | if isinstance(self.kernel_size, int):
73 | self.kernel_size = [self.kernel_size] * 2
74 | elif len(self.kernel_size) == 1:
75 | self.kernel_size = [self.kernel_size[0]] * 2
76 |
77 | def forward(self, x):
78 | h, w = x.shape[-2:]
79 |
80 | h_step = math.ceil(w / self.stride[1])
81 | v_step = math.ceil(h / self.stride[0])
82 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
83 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
84 |
85 | extra_h = h_cover_len - w
86 | extra_v = v_cover_len - h
87 |
88 | left = extra_h // 2
89 | right = extra_h - left
90 | top = extra_v // 2
91 | bottom = extra_v - top
92 |
93 | x = F.pad(x, [left, right, top, bottom])
94 |
95 | x = self.pool(x)
96 | return x
97 |
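
To make the SAME-padding arithmetic above concrete, here is a one-dimensional worked example (an editorial sketch, not part of the repo); the same computation is applied independently to the height and width in both classes:

```
import math

def same_pad_1d(size, kernel_size, stride):
    # mirrors the per-dimension math in Conv2dStaticSamePadding.forward
    steps = math.ceil(size / stride)                         # number of output positions
    cover_len = stride * (steps - 1) + 1 + (kernel_size - 1)
    extra = cover_len - size                                 # total padding needed
    return steps, extra // 2, extra - extra // 2             # output size, left pad, right pad

print(same_pad_1d(65, kernel_size=3, stride=2))  # (33, 1, 1): pad 1 px per side, 65 -> 33
```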
--------------------------------------------------------------------------------
/pic/data/img_test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/img_test1.jpg
--------------------------------------------------------------------------------
/pic/data/img_test2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/img_test2.jpg
--------------------------------------------------------------------------------
/pic/data/p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p0.png
--------------------------------------------------------------------------------
/pic/data/p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p1.png
--------------------------------------------------------------------------------
/pic/data/p2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p2.png
--------------------------------------------------------------------------------
/pic/data/p3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p3.png
--------------------------------------------------------------------------------
/pic/data/p4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p4.png
--------------------------------------------------------------------------------
/pic/data/p5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/data/p5.png
--------------------------------------------------------------------------------
/pic/p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p0.png
--------------------------------------------------------------------------------
/pic/p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p1.png
--------------------------------------------------------------------------------
/pic/p10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p10.png
--------------------------------------------------------------------------------
/pic/p11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p11.png
--------------------------------------------------------------------------------
/pic/p12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p12.png
--------------------------------------------------------------------------------
/pic/p13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p13.png
--------------------------------------------------------------------------------
/pic/p14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p14.png
--------------------------------------------------------------------------------
/pic/p15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p15.png
--------------------------------------------------------------------------------
/pic/p16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p16.png
--------------------------------------------------------------------------------
/pic/p17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p17.png
--------------------------------------------------------------------------------
/pic/p18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p18.png
--------------------------------------------------------------------------------
/pic/p19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p19.png
--------------------------------------------------------------------------------
/pic/p2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p2.png
--------------------------------------------------------------------------------
/pic/p20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p20.png
--------------------------------------------------------------------------------
/pic/p21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p21.png
--------------------------------------------------------------------------------
/pic/p3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p3.png
--------------------------------------------------------------------------------
/pic/p4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p4.png
--------------------------------------------------------------------------------
/pic/p5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p5.png
--------------------------------------------------------------------------------
/pic/p6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p6.png
--------------------------------------------------------------------------------
/pic/p7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p7.png
--------------------------------------------------------------------------------
/pic/p8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p8.png
--------------------------------------------------------------------------------
/pic/p9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/pic/p9.png
--------------------------------------------------------------------------------
/projects/coco.yml:
--------------------------------------------------------------------------------
1 | project_name: coco # also the folder name of the dataset that is under the data_path folder
2 | train_set: train2017
3 | val_set: val2017
4 | num_gpus: 4
5 |
6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # this is coco anchors, change it if necessary
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | # must match your dataset's category_id.
15 | # category_id is one_indexed,
16 | # for example, the index of 'car' here is 2, while its category_id is 3
17 | obj_list: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
18 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
19 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
20 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
21 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
22 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
23 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
24 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
25 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
26 | 'toothbrush']
--------------------------------------------------------------------------------
/projects/shape.yml:
--------------------------------------------------------------------------------
1 | project_name: shape # also the folder name of the dataset that is under the data_path folder
2 | train_set: train
3 | val_set: val
4 | num_gpus: 1
5 |
6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # this anchor is adapted to the dataset
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | obj_list: ['rectangle', 'circle']
--------------------------------------------------------------------------------
/projects/underwater.yml:
--------------------------------------------------------------------------------
1 | project_name: underwater # also the folder name of the dataset that is under the data_path folder
2 | train_set: train
3 | val_set: val
4 | num_gpus: 1
5 |
6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # this is coco anchors, change it if necessary
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | # must match your dataset's category_id.
15 | # category_id is one_indexed,
16 | # for example, the index of 'car' here is 2, while its category_id is 3
17 | obj_list: ["holothurian","echinus","scallop","starfish"]
18 |
--------------------------------------------------------------------------------
/readme_efficientdet.md:
--------------------------------------------------------------------------------
1 | # Yet Another EfficientDet Pytorch
2 |
3 | The pytorch re-implementation of the official [EfficientDet](https://github.com/google/automl/efficientdet) with SOTA performance in real time, original paper link: https://arxiv.org/abs/1911.09070
4 |
5 |
6 | # Performance
7 |
8 | ## Pretrained weights and benchmark
9 |
10 | The performance is very close to the paper's, and it is still SOTA.
11 |
12 | The speed/FPS test includes the time of post-processing with no jit/data precision trick.
13 |
14 | | coefficient | pth_download | GPU Mem(MB) | FPS | Extreme FPS (Batchsize 32) | mAP 0.5:0.95(this repo) | mAP 0.5:0.95(paper) |
15 | | :-----: | :-----: | :------: | :------: | :------: | :-----: | :-----: |
16 | | D0 | [efficientdet-d0.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d0.pth) | 1049 | 36.20 | 163.14 | 32.6 | 33.8
17 | | D1 | [efficientdet-d1.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d1.pth) | 1159 | 29.69 | 63.08 | 38.2 | 39.6
18 | | D2 | [efficientdet-d2.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d2.pth) | 1321 | 26.50 | 40.99 | 41.5 | 43.0
19 | | D3 | [efficientdet-d3.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d3.pth) | 1647 | 22.73 | - | 44.9 | 45.8
20 | | D4 | [efficientdet-d4.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d4.pth) | 1903 | 14.75 | - | 48.1 | 49.4
21 | | D5 | [efficientdet-d5.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d5.pth) | 2255 | 7.11 | - | 49.5 | 50.7
22 | | D6 | [efficientdet-d6.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d6.pth) | 2985 | 5.30 | - | 50.1 | 51.7
23 | | D7 | [efficientdet-d7.pth](https://github.com/zylo117/Yet-Another-Efficient-Pytorch/releases/download/1.0/efficientdet-d7.pth) | 3819 | 3.73 | - | 50.7 | 52.2
24 |
25 | ## Speed Test
26 |
27 | This pure-pytorch implementation is 26 times faster than the official TensorFlow version, without any tricks.
28 |
29 | | coefficient | Time | FPS | Ratio |
30 | | :------: | :------: | :------: | :-----: |
31 | | Official D0 (tf postprocess) | 0.713s | 1.40 | 1X |
32 | | Official D0 (numpy postprocess) | 0.477s | 2.09 | 1.49X |
33 | | **_Yet-Another-EfficientDet-D0_** | **_0.028s_** | **_36.20_** | **_25.86X_** |
34 |
35 |
36 | Test method:
37 |
38 | Run this test on 2080Ti, Ubuntu 19.10 x64.
39 | 1. Prepare two image tensors with the same content, sized (1,3,512,512) for pytorch and (1,512,512,3) for tensorflow.
40 | 2. Initialize everything by inferring once.
41 | 3. Run 10 times with batchsize 1 and calculate the average time, including post-processing and visualization, to make the test more practical.
42 |
43 | ___
44 | # Update log
45 |
46 | [2020-04-14] fixed loss function bug. please pull the latest code.
47 |
48 | [2020-04-14] for those who need help or can't get a good result after several epochs, check out this [tutorial](tutorial/train_shape.ipynb). You can run it on colab with GPU support.
49 |
50 | [2020-04-10] wrap the loss function within the training model, so that the memory usage will be balanced when training with multiple gpus, enabling training with a bigger batchsize.
51 |
52 | [2020-04-10] add D7 (D6 with larger input size and larger anchor scale) support and test its mAP
53 |
54 | [2020-04-09] allow custom anchor scales and ratios
55 |
56 | [2020-04-08] add D6 support and test its mAP
57 |
58 | [2020-04-08] add training script and its doc; update eval script and simple inference script.
59 |
60 | [2020-04-07] tested D0-D5 mAP, results seem nice, details can be found [here](benchmark/coco_eval_result)
61 |
62 | [2020-04-07] fix anchors strategies.
63 |
64 | [2020-04-06] adapt anchor strategies.
65 |
66 | [2020-04-05] create this repository.
67 |
68 | # Demo
69 |
70 | # install requirements
71 | pip install pycocotools numpy opencv-python tqdm tensorboard tensorboardX pyyaml
72 | pip install torch==1.4.0
73 | pip install torchvision==0.5.0
74 |
75 | # run the simple inference script
76 | python efficientdet_test.py
77 |
78 | # Training
79 |
80 | Training EfficientDet is a painful and time-consuming task. You shouldn't expect to get a good result within a day or two. Please be patient.
81 |
82 | Check out this [tutorial](tutorial/train_shape.ipynb) if you are new to this. You can run it on colab with GPU support.
83 |
84 | ## 1. Prepare your dataset
85 |
86 | # your dataset structure should be like this
87 | datasets/
88 | -your_project_name/
89 | -train_set_name/
90 | -*.jpg
91 | -val_set_name/
92 | -*.jpg
93 | -annotations
94 | -instances_{train_set_name}.json
95 | -instances_{val_set_name}.json
96 |
97 | # for example, coco2017
98 | datasets/
99 | -coco2017/
100 | -train2017/
101 | -000000000001.jpg
102 | -000000000002.jpg
103 | -000000000003.jpg
104 | -val2017/
105 | -000000000004.jpg
106 | -000000000005.jpg
107 | -000000000006.jpg
108 | -annotations
109 | -instances_train2017.json
110 | -instances_val2017.json
111 |
112 |
113 | ## 2. Manually set project-specific parameters
114 |
115 | # create a yml file {your_project_name}.yml under the 'projects' folder
116 | # modify it following 'coco.yml'
117 |
118 | # for example
119 | project_name: coco
120 | train_set: train2017
121 | val_set: val2017
122 | num_gpus: 4 # 0 means using cpu, 1-N means using gpus
123 |
124 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco.
125 | mean: [0.485, 0.456, 0.406]
126 | std: [0.229, 0.224, 0.225]
127 |
128 | # this is coco anchors, change it if necessary
129 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
130 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
131 |
132 | # objects from all labels from your dataset with the order from your annotations.
133 | # its index must match your dataset's category_id.
134 | # category_id is one_indexed,
135 | # for example, the index of 'car' here is 2, while its category_id is 3
136 | obj_list: ['person', 'bicycle', 'car', ...]
137 |
138 |
139 | ## 3.a. Train on coco from scratch
140 |
141 | # train efficientdet-d0 on coco from scratch
142 | # with batchsize 12
143 | # This takes time and requires changing
144 | # hyperparameters every few hours.
145 | # If you have months to kill, do it.
146 | # It's not like someone is going to achieve
147 | # a better score than the one in the paper.
148 | # The first few epochs will be rather unstable;
149 | # that's quite normal when you train from scratch.
150 |
151 | python train.py -c 0 --batch_size 12
152 |
153 | ## 3.b. Train a custom dataset from scratch
154 |
155 | # train efficientdet-d1 on a custom dataset
156 | # with batchsize 8 and learning rate 1e-5
157 |
158 | python train.py -c 1 --batch_size 8 --lr 1e-5
159 |
160 | ## 3.c. Train a custom dataset with pretrained weights (Highly Recommended)
161 |
162 | # train efficientdet-d2 on a custom dataset with pretrained weights
163 | # with batchsize 8 and learning rate 1e-5 for 10 epochs
164 |
165 | python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
166 | --load_weights /path/to/your/weights/efficientdet-d2.pth
167 |
168 | # with coco-pretrained weights, you can even freeze the backbone and train the heads only
169 | # to speed up training and help convergence.
170 |
171 | python train.py -c 2 --batch_size 8 --lr 1e-5 --num_epochs 10 \
172 | --load_weights /path/to/your/weights/efficientdet-d2.pth \
173 | --head_only True
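For reference, `--head_only True` simply freezes every EfficientNet and BiFPN sub-module so that only the regressor and classifier heads stay trainable; the sketch below is lifted almost verbatim from the `freeze_backbone` helper in `train.py` (shown later in this document).

```python
import torch.nn as nn

def freeze_backbone(m: nn.Module):
    # any sub-module whose class name contains 'EfficientNet' or 'BiFPN' is frozen,
    # leaving only the regressor/classifier heads with requires_grad=True
    classname = m.__class__.__name__
    for ntl in ['EfficientNet', 'BiFPN']:
        if ntl in classname:
            for param in m.parameters():
                param.requires_grad = False

# model.apply(freeze_backbone)  # train.py applies it recursively to every sub-module
```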
174 |
175 | ## 4. Early stopping a training session
176 |
177 | # while training, press Ctrl+c; the program will catch the KeyboardInterrupt,
178 | # stop training and save the current checkpoint.
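Under the hood this is just a `try/except KeyboardInterrupt` around the epoch loop, as in `train.py`; a self-contained toy version of the pattern looks like this (the model and loop body are stand-ins, not the real ones).

```python
# Toy version of the Ctrl+c handling in train.py: the epoch loop sits in a try
# block and the except branch saves a checkpoint before exiting.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)          # stand-in for the real EfficientDet model
try:
    for epoch in range(1000):
        pass                     # stand-in for one training epoch
except KeyboardInterrupt:
    torch.save(model.state_dict(), 'interrupted_checkpoint.pth')
    print('checkpoint saved, exiting')
```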
179 |
180 | ## 5. Resume training
181 |
182 | # let's say you started a training session like this.
183 |
184 | python train.py -c 2 --batch_size 8 --lr 1e-5 \
185 | --load_weights /path/to/your/weights/efficientdet-d2.pth \
186 | --head_only True
187 |
188 | # then you stopped it with Ctrl+c and it exited with a checkpoint
189 |
190 | # now you want to resume training from the last checkpoint
191 | # simply set load_weights to 'last'
192 |
193 | python train.py -c 2 --batch_size 8 --lr 1e-5 \
194 | --load_weights last \
195 | --head_only True
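When `--load_weights last` is given, `train.py` resolves it through `get_last_weights` from `utils/utils.py`, which is not reproduced in this document; a rough, hypothetical equivalent that picks the newest checkpoint in the save folder might look like this.

```python
# Hypothetical sketch of resolving '--load_weights last'; the real helper is
# get_last_weights in utils/utils.py, which is not shown in this document.
import glob
import os

def last_checkpoint(saved_path: str) -> str:
    weights = glob.glob(os.path.join(saved_path, 'efficientdet-d*.pth'))
    if not weights:
        raise FileNotFoundError(f'no checkpoints found under {saved_path}')
    # checkpoints are named efficientdet-d{coef}_{epoch}_{step}.pth,
    # so the most recently written file is the latest one
    return max(weights, key=os.path.getmtime)

# print(last_checkpoint('logs/coco/'))
```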
196 |
197 | ## 6. Evaluate model performance
198 |
199 | # eval on your_project, efficientdet-d5
200 |
201 | python coco_eval.py -p your_project_name -c 5 \
202 | -w /path/to/your/weights
203 |
204 | ## 7. Debug training (optional)
205 |
206 | # when you get a bad result, you need to debug the training output.
207 | python train.py -c 2 --batch_size 8 --lr 1e-5 --debug True
208 |
209 | # then check out the test/ folder, where you can visualize the predicted boxes during training
210 | # don't panic if you see countless wrong boxes, it happens when the training is at an early stage.
211 | # But if you still can't see a normal box after several epochs, not even one in any image,
212 | # then it's possible that either the anchor config is inappropriate or the ground truth is corrupted.
213 |
214 | # TODO
215 |
216 | - [X] re-implement efficientdet
217 | - [X] adapt anchor strategies
218 | - [X] mAP tests
219 | - [X] training-scripts
220 | - [X] efficientdet D6 supports
221 | - [X] efficientdet D7 supports
222 |
223 | # FAQ:
224 |
225 | **Q1. Why implement this when there are already several efficientdet pytorch projects?**
226 |
227 | A1: Because, AFAIK, none of them fully reproduces the algorithm of the official efficientdet, which is why their communities either could not achieve, or had a hard time achieving, the same score as the official efficientdet by training from scratch.
228 |
229 | **Q2: What exactly is the difference between this repository and the others?**
230 |
231 | A2: For example, these two are the most popular efficientdet-pytorch repositories:
232 |
233 | https://github.com/toandaominh1997/EfficientDet.Pytorch
234 |
235 | https://github.com/signatrix/efficientdet
236 |
237 | Here are the issues and why it is difficult for them to achieve the same score as the official one:
238 |
239 | The first one:
240 |
241 | 1. It alters EfficientNet the wrong way: the strides have been changed to fit the BiFPN, but we should be aware that efficientnet's great performance comes from its specific parameter combinations. Any slight alteration could lead to worse performance.
242 |
243 | The second one:
244 |
245 | 1. Pytorch's BatchNormalization is slightly different from TensorFlow's: momentum_pytorch = 1 - momentum_tensorflow. I would not have noticed this trap had I paid less attention. signatrix/efficientdet copied the parameter straight from TensorFlow, so the BN performs badly because the running mean and the running variance are dominated by the newest inputs (see the short sketch at the end of this answer).
246 |
247 | 2. Mis-implementation of Depthwise-Separable Conv2D. Depthwise-Separable Conv2D is Depthwise-Conv2D followed by Pointwise-Conv2D and a BiasAdd; there is only one BiasAdd, after the two Conv2Ds, while signatrix/efficientdet adds an extra BiasAdd to the Depthwise-Conv2D.
248 |
249 | 3. Misunderstanding of the first parameter of MaxPooling2D: the first parameter is kernel_size, not stride.
250 |
251 | 4. Missing BN after down-channeling the efficientnet output features.
252 |
253 | 5. Using the wrong output features of the efficientnet. This is a big one. It takes whatever output has conv.stride of 2, which is wrong. It should be the one whose next conv.stride is 2, or the final output of efficientnet.
254 |
255 | 6. It does not apply 'same' padding on Conv2D and Pooling (see the padding sketch further below in this answer).
256 |
257 | 7. Missing swish activation after several operations.
258 |
259 | 8. Missing Conv/BN operations in the BiFPN, Regressor and Classifier. This one is very tricky: unless you dig deep into the official implementation, you won't notice that some seemingly identical operations use different weights, as illustrated below.
260 |
261 |
262 | illustration of a minimal bifpn unit
263 | P7_0 -------------------------> P7_2 -------->
264 | |-------------| ↑
265 | ↓ |
266 | P6_0 ---------> P6_1 ---------> P6_2 -------->
267 | |-------------|--------------↑ ↑
268 | ↓ |
269 | P5_0 ---------> P5_1 ---------> P5_2 -------->
270 | |-------------|--------------↑ ↑
271 | ↓ |
272 | P4_0 ---------> P4_1 ---------> P4_2 -------->
273 | |-------------|--------------↑ ↑
274 | |--------------↓ |
275 | P3_0 -------------------------> P3_2 -------->
276 |
277 | For example, P4 is down-channeled to P4_0, which then goes to P4_1.
278 | Anyone might take it for granted that P4_0 also goes to P4_2 directly, right?
279 |
280 | That's where they are wrong:
281 | P4 should be down-channeled again, with different weights, to P4_0_another,
282 | which then goes to P4_2.
283 |
284 | And finally, some common issues: their anchor decoder and encoder are different from the original ones, but this is not the main reason they perform badly.
285 |
286 | Also, Conv2dStaticSamePadding from [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch) does not behave like TensorFlow's; the padding strategy is different. So I implemented a real tensorflow-style [Conv2dStaticSamePadding](efficientnet/utils_extra.py#L9) and [MaxPool2dStaticSamePadding](efficientnet/utils_extra.py#L55) myself.
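For readers unfamiliar with the difference: TensorFlow's 'SAME' padding is computed from the input size, kernel size and stride, and can be asymmetric (the extra pixel goes to the bottom/right), whereas a plain PyTorch Conv2d only accepts a fixed symmetric padding. A minimal sketch of the tensorflow-style behaviour, under those assumptions, is shown below; the repository's real implementation lives in efficientnet/utils_extra.py.

```python
# Minimal sketch of TensorFlow-style 'SAME' padding for a square kernel/stride:
# pad_total = max((ceil(in / s) - 1) * s + k - in, 0), and when pad_total is odd
# the extra pixel goes to the right/bottom, unlike symmetric PyTorch padding.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2dSamePadding(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=0)
        self.k, self.s = kernel_size, stride

    def forward(self, x):
        h, w = x.shape[-2:]
        pad_h = max((math.ceil(h / self.s) - 1) * self.s + self.k - h, 0)
        pad_w = max((math.ceil(w / self.s) - 1) * self.s + self.k - w, 0)
        # F.pad order is (left, right, top, bottom)
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return self.conv(x)

# a 3x3 / stride-2 conv on a 7x7 input gives ceil(7/2) = 4 output pixels per side:
# Conv2dSamePadding(3, 8, kernel_size=3, stride=2)(torch.randn(1, 3, 7, 7)).shape
```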
287 |
288 | Despite the above issues, they are great repositories that enlightened me, hence this repository.
289 |
290 | This repository is mainly based on [efficientdet](https://github.com/signatrix/efficientdet), with changes that make sure it performs as close as possible to the paper.
291 |
292 | Btw, debugging static-graph TensorFlow v1 is really painful. Don't try to export it with automation tools like tf-onnx or mmdnn; they will only cause more problems because of its custom/complex operations.
293 |
294 | And even if you succeed, like I did, you will have to deal with crazy, messed-up machine-generated code crammed into one class, which takes more time to refactor than translating it from scratch.
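Coming back to point 1 above, the BatchNorm momentum convention: nothing here is repo-specific, it is just how the two libraries define the update, so porting TensorFlow weights means flipping the value.

```python
# The two frameworks define BN momentum in opposite directions:
#   PyTorch:    running = (1 - momentum) * running + momentum * batch_stat
#   TensorFlow: running = decay * running + (1 - decay) * batch_stat
# so a TensorFlow decay of 0.99 corresponds to a PyTorch momentum of 0.01.
import torch.nn as nn

momentum_tensorflow = 0.99
good_bn = nn.BatchNorm2d(64, momentum=1 - momentum_tensorflow)  # i.e. 0.01

# Copying 0.99 straight into PyTorch (the signatrix/efficientdet mistake) makes each
# new batch contribute 99% of the running statistics, so they are dominated by the
# latest inputs instead of being a long-term average.
bad_bn = nn.BatchNorm2d(64, momentum=momentum_tensorflow)
```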
295 |
296 | **Q3: What should I do when I find a bug?**
297 |
298 | A3: Check the update log to see if it has already been fixed, then pull the latest code and try again. If that doesn't help, create a new issue and describe it in detail.
299 |
300 | # Known issues
301 |
302 | 1. The official EfficientDet uses TensorFlow bilinear interpolation to resize image inputs, which differs from many other implementations (opencv/pytorch), so the output is bound to be slightly different from the official one.
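A quick way to convince yourself that "bilinear" is not one single operation (this check is not part of the repo): the same PyTorch call with a different corner convention already produces different numbers, and TF1's resize_bilinear historically used yet another convention.

```python
# Bilinear resizing depends on the corner/center convention a library picks; PyTorch
# exposes it via align_corners, and TF1's resize_bilinear used yet another variant,
# which is why outputs can never match the official repo bit-for-bit.
import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

a = F.interpolate(x, size=(8, 8), mode='bilinear', align_corners=False)
b = F.interpolate(x, size=(8, 8), mode='bilinear', align_corners=True)

print((a - b).abs().max())  # clearly non-zero: same op name, different convention
```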
303 |
304 | # Visual Comparison
305 |
306 | Conclusion: they provide almost the same precision.
307 |
308 | ## This Repo
309 |
310 |
311 | ## Official EfficientDet
312 |
313 |
314 | ## References
315 |
316 | Appreciate the great work from the following repositories:
317 | - [google/automl](https://github.com/google/automl)
318 | - [lukemelas/EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch)
319 | - [signatrix/efficientdet](https://github.com/signatrix/efficientdet)
320 | - [vacancy/Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch)
321 |
322 | ## Donation
323 |
324 | If you like this repository, or if you'd like to support the author for any reason, you can donate to the author. Feel free to send me your name or a page introducing yourself, and I will make sure your name(s) appear on the sponsors list.
325 |
326 |
327 |
328 | ## Sponsors
329 |
330 | Sincere thanks for your generosity.
331 |
332 | [cndylan](https://github.com/cndylan)
333 | [claire-s11](https://github.com/claire-s11)
334 |
--------------------------------------------------------------------------------
/test/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/test/img.png
--------------------------------------------------------------------------------
/test/img_inferred_d0_official.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/test/img_inferred_d0_official.jpg
--------------------------------------------------------------------------------
/test/img_inferred_d0_this_repo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/test/img_inferred_d0_this_repo.jpg
--------------------------------------------------------------------------------
/test/img_inferred_d0_this_repo_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/EfficientDet_pytorch/b915a3e6a2c4ba6cacbaf9e0d84536cc34c27d59/test/img_inferred_d0_this_repo_0.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # original author: signatrix
2 | # adapted from https://github.com/signatrix/efficientdet/blob/master/train.py
3 | # modified by Zylo117
4 |
5 | import datetime
6 | import os
7 | import argparse
8 | import traceback
9 |
10 | import torch
11 | import yaml
12 | from torch import nn
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 | from efficientdet.dataset import CocoDataset, Resizer, Normalizer, Augmenter, collater
16 | from backbone import EfficientDetBackbone
17 | from tensorboardX import SummaryWriter
18 | import numpy as np
19 | from tqdm.autonotebook import tqdm
20 |
21 | from efficientdet.loss import FocalLoss
22 | from utils.utils import replace_w_sync_bn, CustomDataParallel, get_last_weights, init_weights
23 |
24 |
25 | class Params:
26 | def __init__(self, project_file):
27 | self.params = yaml.safe_load(open(project_file).read())
28 |
29 | def __getattr__(self, item):
30 | return self.params.get(item, None)
31 |
32 |
33 | def get_args():
34 | parser = argparse.ArgumentParser('Yet Another EfficientDet Pytorch: SOTA object detection network - Zylo117')
35 | parser.add_argument('-p', '--project', type=str, default='underwater', help='project file that contains parameters')
36 | parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
37 | parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
38 | parser.add_argument('--batch_size', type=int, default=16, help='The number of images per batch among all devices')
39 | parser.add_argument('--head_only', type=bool, default=False,
40 |                         help='whether to finetune only the regressor and the classifier, '
41 |                              'useful for early-stage convergence or a small/easy dataset')
42 | parser.add_argument('--lr', type=float, default=1e-4)
43 | parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
44 |                                                        'suggest using \'adamw\' until the'
45 | ' very final stage then switch to \'sgd\'')
46 | parser.add_argument('--alpha', type=float, default=0.25)
47 | parser.add_argument('--gamma', type=float, default=1.5)
48 | parser.add_argument('--num_epochs', type=int, default=500)
49 |     parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
50 | parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
51 | parser.add_argument('--es_min_delta', type=float, default=0.0,
52 | help='Early stopping\'s parameter: minimum change loss to qualify as an improvement')
53 | parser.add_argument('--es_patience', type=int, default=0,
54 | help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.')
55 | parser.add_argument('--data_path', type=str, default='dataset/', help='the root folder of dataset')
56 | parser.add_argument('--log_path', type=str, default='logs/')
57 | parser.add_argument('--load_weights', type=str, default=None,
58 | help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
59 | parser.add_argument('--saved_path', type=str, default='logs/')
60 |     parser.add_argument('--debug', type=bool, default=False, help='whether to visualize the predicted boxes during training, '
61 | 'the output images will be in test/')
62 |
63 | args = parser.parse_args()
64 | return args
65 |
66 |
67 | class ModelWithLoss(nn.Module):
68 | def __init__(self, model, debug=False):
69 | super().__init__()
70 | self.criterion = FocalLoss()
71 | self.model = model
72 | self.debug = debug
73 |
74 | def forward(self, imgs, annotations, obj_list=None):
75 | _, regression, classification, anchors = self.model(imgs)
76 | if self.debug:
77 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations,
78 | imgs=imgs, obj_list=obj_list)
79 | else:
80 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations)
81 | return cls_loss, reg_loss
82 |
83 |
84 | def train(opt):
85 | params = Params('projects/{}.yml'.format(opt.project))
86 |
87 | if params.num_gpus == 0:
88 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
89 |
90 | if torch.cuda.is_available():
91 | torch.cuda.manual_seed(42)
92 | else:
93 | torch.manual_seed(42)
94 |
95 | opt.saved_path = opt.saved_path + '/{}/'.format(params.project_name)
96 | opt.log_path = opt.log_path + '/{}/tensorboard/'.format(params.project_name)
97 | os.makedirs(opt.log_path, exist_ok=True)
98 | os.makedirs(opt.saved_path, exist_ok=True)
99 |
100 | training_params = {'batch_size': opt.batch_size,
101 | 'shuffle': True,
102 | 'drop_last': True,
103 | 'collate_fn': collater,
104 | 'num_workers': opt.num_workers}
105 |
106 | val_params = {'batch_size': opt.batch_size,
107 | 'shuffle': False,
108 | 'drop_last': True,
109 | 'collate_fn': collater,
110 | 'num_workers': opt.num_workers}
111 |
112 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
113 | training_set = CocoDataset(root_dir=opt.data_path + params.project_name, set=params.train_set,
114 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
115 | Augmenter(),
116 | Resizer(input_sizes[opt.compound_coef])]))
117 | training_generator = DataLoader(training_set, **training_params)
118 |
119 | val_set = CocoDataset(root_dir=opt.data_path + params.project_name, set=params.val_set,
120 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
121 | Resizer(input_sizes[opt.compound_coef])]))
122 | val_generator = DataLoader(val_set, **val_params)
123 |
124 | model = EfficientDetBackbone(num_classes=len(params.obj_list), compound_coef=opt.compound_coef,
125 | ratios=eval(params.anchors_ratios), scales=eval(params.anchors_scales))
126 |
127 | # load last weights
128 | if opt.load_weights is not None:
129 | if opt.load_weights.endswith('.pth'):
130 | weights_path = opt.load_weights
131 | else:
132 | weights_path = get_last_weights(opt.saved_path)
133 | try:
134 | last_step = int(os.path.basename(weights_path).split('_')[-1].split('.')[0])
135 | except:
136 | last_step = 0
137 |
138 | try:
139 | ret = model.load_state_dict(torch.load(weights_path), strict=False)
140 | except RuntimeError as e:
141 | print('[Warning] Ignoring {}'.format(e))
142 | print(
143 | '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.')
144 |
145 | print('[Info] loaded weights: {}, resuming checkpoint from step: {}'.format(os.path.basename(weights_path),last_step))
146 | else:
147 | last_step = 0
148 | print('[Info] initializing weights...')
149 | init_weights(model)
150 |
151 | # freeze backbone if train head_only
152 | if opt.head_only:
153 | def freeze_backbone(m):
154 | classname = m.__class__.__name__
155 | for ntl in ['EfficientNet', 'BiFPN']:
156 | if ntl in classname:
157 | for param in m.parameters():
158 | param.requires_grad = False
159 |
160 | model.apply(freeze_backbone)
161 | print('[Info] freezed backbone')
162 |
163 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
164 | # apply sync_bn when using multiple gpu and batch_size per gpu is lower than 4
165 | # useful when gpu memory is limited.
166 | # because when bn is disable, the training will be very unstable or slow to converge,
167 | # apply sync_bn can solve it,
168 | # by packing all mini-batch across all gpus as one batch and normalize, then send it back to all gpus.
169 | # but it would also slow down the training by a little bit.
170 | if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
171 | model.apply(replace_w_sync_bn)
172 |
173 | writer = SummaryWriter(opt.log_path + '/{}/'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
174 |
175 | # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
176 | model = ModelWithLoss(model, debug=opt.debug)
177 |
178 | if params.num_gpus > 0:
179 | model = model.cuda()
180 | if params.num_gpus > 1:
181 | model = CustomDataParallel(model, params.num_gpus)
182 |
183 | if opt.optim == 'adamw':
184 | optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
185 | else:
186 | optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)
187 |
188 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
189 |
190 | epoch = 0
191 | best_loss = 1e5
192 | best_epoch = 0
193 | step = max(0, last_step)
194 | model.train()
195 |
196 | num_iter_per_epoch = len(training_generator)
197 |
198 | try:
199 | for epoch in range(opt.num_epochs):
200 | last_epoch = step // num_iter_per_epoch
201 | if epoch < last_epoch:
202 | continue
203 |
204 | epoch_loss = []
205 | progress_bar = tqdm(training_generator)
206 | for iter, data in enumerate(progress_bar):
207 | if iter < step - last_epoch * num_iter_per_epoch:
208 | progress_bar.update()
209 | continue
210 | try:
211 | imgs = data['img']
212 | annot = data['annot']
213 |
214 | if params.num_gpus == 1:
215 | # if only one gpu, just send it to cuda:0
216 | # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
217 | imgs = imgs.cuda()
218 | annot = annot.cuda()
219 |
220 | optimizer.zero_grad()
221 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
222 | cls_loss = cls_loss.mean()
223 | reg_loss = reg_loss.mean()
224 |
225 | loss = cls_loss + reg_loss
226 | if loss == 0 or not torch.isfinite(loss):
227 | continue
228 |
229 | loss.backward()
230 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
231 | optimizer.step()
232 |
233 | epoch_loss.append(float(loss))
234 |
235 | progress_bar.set_description(
236 | 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'.format(
237 | step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss.item(),
238 | reg_loss.item(), loss.item()))
239 | writer.add_scalars('Loss', {'train': loss}, step)
240 | writer.add_scalars('Regression_loss', {'train': reg_loss}, step)
241 |                     writer.add_scalars('Classification_loss', {'train': cls_loss}, step)
242 |
243 | # log learning_rate
244 | current_lr = optimizer.param_groups[0]['lr']
245 | writer.add_scalar('learning_rate', current_lr, step)
246 |
247 | step += 1
248 |
249 | if step % opt.save_interval == 0 and step > 0:
250 | save_checkpoint(model, 'efficientdet-d{}_{}_{}.pth'.format(opt.compound_coef,epoch,step))
251 | print('checkpoint...')
252 |
253 | except Exception as e:
254 | print('[Error]', traceback.format_exc())
255 | print(e)
256 | continue
257 | scheduler.step(np.mean(epoch_loss))
258 |
259 | if epoch % opt.val_interval == 0:
260 | model.eval()
261 | loss_regression_ls = []
262 | loss_classification_ls = []
263 | for iter, data in enumerate(val_generator):
264 | with torch.no_grad():
265 | imgs = data['img']
266 | annot = data['annot']
267 |
268 | if params.num_gpus == 1:
269 | imgs = imgs.cuda()
270 | annot = annot.cuda()
271 |
272 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
273 | cls_loss = cls_loss.mean()
274 | reg_loss = reg_loss.mean()
275 |
276 | loss = cls_loss + reg_loss
277 | if loss == 0 or not torch.isfinite(loss):
278 | continue
279 |
280 | loss_classification_ls.append(cls_loss.item())
281 | loss_regression_ls.append(reg_loss.item())
282 |
283 | cls_loss = np.mean(loss_classification_ls)
284 | reg_loss = np.mean(loss_regression_ls)
285 | loss = cls_loss + reg_loss
286 |
287 | print(
288 | 'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format(
289 | epoch, opt.num_epochs, cls_loss, reg_loss, loss))
290 | writer.add_scalars('Total_loss', {'val': loss}, step)
291 | writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
292 |                 writer.add_scalars('Classification_loss', {'val': cls_loss}, step)
293 |
294 | if loss + opt.es_min_delta < best_loss:
295 | best_loss = loss
296 | best_epoch = epoch
297 |
298 | save_checkpoint(model, 'efficientdet-d{}_{}_{}.pth'.format(opt.compound_coef,epoch,step))
299 |
300 | # onnx export is not tested.
301 | # dummy_input = torch.rand(opt.batch_size, 3, 512, 512)
302 | # if torch.cuda.is_available():
303 | # dummy_input = dummy_input.cuda()
304 | # if isinstance(model, nn.DataParallel):
305 | # model.module.backbone_net.model.set_swish(memory_efficient=False)
306 | #
307 | # torch.onnx.export(model.module, dummy_input,
308 | # os.path.join(opt.saved_path, 'signatrix_efficientdet_coco.onnx'),
309 | # verbose=False)
310 | # model.module.backbone_net.model.set_swish(memory_efficient=True)
311 | # else:
312 | # model.backbone_net.model.set_swish(memory_efficient=False)
313 | #
314 | # torch.onnx.export(model, dummy_input,
315 | # os.path.join(opt.saved_path, 'signatrix_efficientdet_coco.onnx'),
316 | # verbose=False)
317 | # model.backbone_net.model.set_swish(memory_efficient=True)
318 |
319 | # Early stopping
320 | if epoch - best_epoch > opt.es_patience > 0:
321 | print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, loss))
322 | break
323 | except KeyboardInterrupt:
324 | save_checkpoint(model, 'efficientdet-d{}_{}_{}.pth'.format(opt.compound_coef,epoch,step))
325 | writer.close()
326 | writer.close()
327 |
328 |
329 | def save_checkpoint(model, name):
330 | if isinstance(model, CustomDataParallel):
331 | torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
332 | else:
333 | torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))
334 |
335 |
336 | if __name__ == '__main__':
337 | opt = get_args()
338 | train(opt)
339 |
--------------------------------------------------------------------------------
/tutorial/train_shape.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# EfficientDet Training On A Custom Dataset\n",
13 | "\n",
14 | "\n",
15 | "\n",
16 | ""
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "source": [
29 | "## This tutorial will show you how to train a custom dataset.\n",
30 | "\n",
31 | "## For the sake of simplicity, I generated a dataset of different shapes, like rectangles, triangles, circles.\n",
32 | "\n",
33 |     "## Please enable GPU support in the notebook settings to accelerate training if you are using colab.\n",
34 | "\n",
35 | "### 0. Install Requirements"
36 | ],
37 | "metadata": {
38 | "collapsed": false,
39 | "pycharm": {
40 | "name": "#%% md\n"
41 | }
42 | }
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "outputs": [],
48 | "source": [
49 | "!pip install pycocotools numpy==1.16.0 opencv-python tqdm tensorboard tensorboardX pyyaml matplotlib\n",
50 | "!pip install torch==1.4.0\n",
51 | "!pip install torchvision==0.5.0"
52 | ],
53 | "metadata": {
54 | "collapsed": false,
55 | "pycharm": {
56 | "name": "#%%\n"
57 | }
58 | }
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "source": [
63 | "### 1. Prepare Custom Dataset/Pretrained Weights (Skip this part if you already have datasets and weights of your own)"
64 | ],
65 | "metadata": {
66 | "collapsed": false,
67 | "pycharm": {
68 | "name": "#%% md\n"
69 | }
70 | }
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "outputs": [],
76 | "source": [
77 | "import os\n",
78 | "import sys\n",
79 | "if \"projects\" not in os.getcwd():\n",
80 | " !git clone --depth 1 https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch\n",
81 | " os.chdir('Yet-Another-EfficientDet-Pytorch')\n",
82 | " sys.path.append('.')\n",
83 | "else:\n",
84 | " !git pull\n",
85 | "\n",
86 | "# download and unzip dataset\n",
87 | "! mkdir datasets\n",
88 | "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.1/dataset_shape.tar.gz\n",
89 | "! tar xzf dataset_shape.tar.gz\n",
90 | "\n",
91 | "# download pretrained weights\n",
92 | "! mkdir weights\n",
93 | "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.0/efficientdet-d0.pth -O weights/efficientdet-d0.pth\n",
94 | "\n",
95 | "# prepare project file projects/shape.yml\n",
96 | "# showing its contents here\n",
97 | "! cat projects/shape.yml"
98 | ],
99 | "metadata": {
100 | "collapsed": false,
101 | "pycharm": {
102 | "name": "#%%\n"
103 | }
104 | }
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "source": [
109 | "### 2. Training"
110 | ],
111 | "metadata": {
112 | "collapsed": false
113 | }
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "outputs": [],
119 | "source": [
120 |     "# considering this is a simple dataset, training the head only will be enough.\n",
121 | "! python train.py -c 0 -p shape --head_only True --lr 1e-3 --batch_size 32 --load_weights weights/efficientdet-d0.pth --num_epochs 50\n",
122 | "\n",
123 | "# the loss will be high at first\n",
124 | "# don't panic, be patient,\n",
125 | "# just wait for a little bit longer"
126 | ],
127 | "metadata": {
128 | "collapsed": false,
129 | "pycharm": {
130 | "name": "#%%\n"
131 | }
132 | }
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "source": [
137 | "### 3. Evaluation"
138 | ],
139 | "metadata": {
140 | "collapsed": false
141 | }
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "outputs": [],
147 | "source": [
148 | "! python coco_eval.py -c 0 -p shape -w logs/shape/efficientdet-d0_49_1400.pth"
149 | ],
150 | "metadata": {
151 | "collapsed": false,
152 | "pycharm": {
153 | "name": "#%%\n"
154 | }
155 | }
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "source": [
160 | "### 4. Visualize"
161 | ],
162 | "metadata": {
163 | "collapsed": false,
164 | "pycharm": {
165 | "name": "#%% md\n"
166 | }
167 | }
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 4,
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": "",
176 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de7RdVX3o8e9vrrX3PnkgCSRQTECCQAW9ijRCLARQSgX0Ch3FB/WWaNEoYC/Utha9HX1wvR311lFbexWIAsY7rIq2NrnUq+UhJvYK8VB5KYJBwyMDSUCeyck5e635u3/Mufde56xzcp775fl9xtjZa8219l7zZO/123PNNR+iqhhjTJHrdgaMMb3HAoMxpsQCgzGmxAKDMabEAoMxpsQCgzGmpC2BQUTOFpEHRWS7iFzZjmMYY9pH5rodg4gkwEPAWcDjwPeBC1X1R3N6IGNM27SjxHASsF1Vf6qqI8CXgfPacBxjTJukbXjPFcBjhfXHgZP394Jly5bpkUce2YasGGMa7rrrrqdUdflU9m1HYJgSEVkPrAc44ogjGBwc7FZWjJkXROSRqe7bjkuJncDhhfWVMW0UVd2gqqtVdfXy5VMKYsaYDmlHYPg+cIyIrBKRKvBOYHMbjmOMaZM5v5RQ1UxEPgh8C0iA61X1h3N9HGNM+7SljkFVvwF8ox3vbYxpP2v5aIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpmTQwiMj1IrJLRO4vpB0kIjeLyE/i89KYLiLyKRHZLiL3isiJ7cy8MaY9plJi+Dxw9pi0K4FbVfUY4Na4DnAOcEx8rAeunptsGmM6adLAoKpbgF+MST4P2BiXNwLnF9K/oMEdwBIROWyuMmuM6YyZ1jEcqqpPxOWfA4fG5RXAY4X9Ho9pJSKyXkQGRWRw9+7dM8yGMaYdZl35qKoK6Axet0FVV6vq6uXLl882G8aYOTTTwPBk4xIhPu+K6TuBwwv7rYxpxpg+MtPAsBlYF5fXAZsK6RfFuxNrgOcKlxzGmD6RTraDiHwJOANYJiKPA38O/DVwo4hcDDwCvD3u/g3gXGA7sBd4TxvybIxps0kDg6peOMGmM8fZV4HLZpspY0x3WctHY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMyaSBQUQOF5Fvi8iPROSHInJ5TD9IRG4WkZ/E56UxXUTkUyKyXUTuFZET2/1HGGPm1lRKDBnwh6p6PLAGuExEjgeuBG5V1WOAW+M6wDnAMfGxHrh6znNtjGmrSQODqj6hqv8Rl18AHgBWAOcBG+NuG4Hz4/J5wBc0uANYIiKHzXnOjTFtM606BhE5EngtcCdwqKo+ETf9HDg0Lq8AHiu87PGYZozpE1MODCKyGPgn4ApVfb64TVUV0OkcWETWi8igiAzu3r17Oi81xrTZlAKDiFQIQeGLqvrPMfnJxiVCfN4V03cChxdevjKmjaKqG1R1taquXr58+Uzzb4xpg6nclRDgOuABVf3bwqbNwLq4vA7YVEi/KN6dWAM8V7jkMMb0gXQK+5wC/C5wn4jcHdM+Cvw1cKOIXAw8Arw9bvsGcC6wHdgLvGdOc2yMabtJA4OqfheQCTafOc7+Clw2y3wZY7rIWj4aY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMPQtBQ1P/cx3OwNmXGm3M2BmbtgDCk5oBYhioJDC8nTS5+I9iukyerMWVjIgVajYN7Gn2MfRtxy1pNt5mL1KY0EVZLzIZLph0sAgIgPAFqAW9/+aqv65iKwCvgwcDNwF/K6qjohIDfgC8GvA08A7VHVHm/I/7510+rtBW1eExWvDYjF925brWXPa75XSx9t/oveYaXozTTxOwQvgE7ZtvS78DaddwLbvwOgihgWHbppKHcMw8EZVfQ1wAnC2iKwBPg58UlWPBp4BLo77Xww8E9M/Gfczc8gXTppt3/k8ojlChpChhYcUHkApfduW61Ey7ojP27ZcP+o97ojrY997ovSxxyylqUdVw3MhhGzb8rXRlyCm6yYNDBq8GFcr8aHAG4GvxfSNwPlx+by4Ttx+pojYxz5HxlbWvX7tOpwod275Ak6l+Ri7DoxaL6a9fu26cfcZ7zXF/Wf2cM33bTjtjHe0+X/NTNeU7kqISCIidwO7gJuBh4FnVTWLuzwOrIjLK4DHAOL25wiXG2Pfc72IDIrI4O7du2f3V8wjjjEfmjq8CiefdhEe13ycvPY9o9aBUevjpU223nh8b+vGcdOL28bb53tbw++FF4+XVqlny+1fGXPlYJcR3TalwKCquaqeAKwETgJeMdsDq+oGVV2tqquXL18+27ebZ1q/tr5ZXyeE8sT4j5PXrps0rbwPE77XnVtv2O9xJjyeeNAEJG/+Daec8c4+uZQIAcuTt/6HtPGPtnZRgJzmBVNcR2n9d2hv36qdVjsGVX0W+DbwemCJiDQqL1cCO+PyTuBwgLj9QEIlpGkDp4DM/S/snVs3xpN7fPvbtl/q4uVEX0SCMQR8q9TmIJ7sEgNz1twNHC6XZmWr1yR8Ti4+pLcbEU2aNxFZLiJL4vIC4CzgAUKAuCDutg7YFJc3x3Xi9ttU1cqGfWbGJ/4vM6V50jd//YUYncEjeNFm5bBPwEsG5CAaSxqxyNDjjdOm0o7hMGCjiCSEQHKjqt4kIj8CviwiHwN+AFwX978O+N8ish34BfDONuTbmM5r3EVVCcsS7rkoQiLgNBnVSCzEizScNR6QtHDrttOZn55JA4Oq3gu8dpz0nxLqG8am7wPeNie5M6ZjpnimjtktnWibjFl3o55moLPFC2v5aEzTZCdfuH7YV89wlZwqi6jzGIKSI3j2ABk5i3GMIKSkHAQ4KrwktO6ERlUFAG7SxlxSeO5ccLDAYExJ8SQsLjuGeIwFlcP5qwd+nY8eB//jkXMgGSHxipOFDGcjJBVInEdHBD+0kGq+kIF9v82lq3dSYRlOazh
pNPEa29FkKh1S2s8CQ78ZcxfCS6j0QrSna7mBcG2uAi4Lty0b2nBXZVZUS3l6ke1cPXg+I4t+Bx14DX7xLwD4i5fdz1WPvQIVyH1GmgAevAo4QRYNUWeIfNFzfOLH72Dg+Zfy/pP+hsW8DIeSZ54k1XgL15MxTlsVoNNBoue/S2YyHofixIeWhT5pPYotDqeTPhfvUUwvtHp0MTi4Qv8O8b3UGyx25lJBs3AC7mEH19z3Xv549b+QL3qKP111T7w9CVc9elwr4BWDXaFlJz7hT1fdR33Bs/zhSV/lmgcugthQ3aXgfRI7kAkpMkEbBy082s9KDP1GwDfbJYLzCcTbZJnLS/tO9B5TTp/L9xCPczle/KiOXyo91NTHCzl1nPPsSx9lw+C72HvYW8iWplz1yCtBfAgG0HyejJOcjz12LLic//74cejiXfzV9tdRee63uPTXPsMi91JyRvBkVMhwAp5hCvdHO85KDH3IFT4273Iy5/EuB3RUI+Tir8x00ufiPYrpo37tGiWJwi+q9NgN/YSUjD38w0MX8sKye8jqyoiOhIA2A43/BUf4X8E5RgaeY++yB/n0g2/H8xwJlXDcWHpyWpu7P2gGrMTQl1on1Z1bN+5nv/6g5JPv1CniyeRF/mHb77Hn0GcRUirJCBVfCaWcGQQH7/IQzPNK8xg4hzp4YcHTfPqOy7l0zf8i8Ysb0aPrLDD0nVi89Mopb3gH6itIXkEFsl4qko9HYiUbAJ7EeYa9MJBq8xe12/bI83z2nnfz/LLtJCSAUkeoeIfKDPs3NEpHLms0ksTFTi6qGU+/dBufGLyQK1bfQE2XoeLx6uhmzYv0Qmvl1atX6+DgYLez0Rda979zhjUBZfQXqMeHdmsMVRnr2hjJoSqQJt27ng7CrcK/ueN3GPmVe0lTZVgzyKuhhaPLSXWG+RMN7aMbdUAqzToWyYW620PFH8DAo6v5g1M2kFYcXj1OEsb/UGaYDZG7VHX1VPa1EkOfaf2qJtTGtq7rQ2nzG9jJPyQ0VMpwpIDmgiQwzCMMHfojcJ7cA7RO5hkHBWjdsSi+RyzdaeJIXQVNRhg66nvUKzug/lK8LqBanexvgOb/2/6C8wyy3gulN2M6TECT+KsozfPmM1v+ApKRDmfFg6+AT/Dq+NxdH0Uqbv+D4zb6a6i06nVFQbLYnT2nuWGG8cwCg5mXwriTgkfJJXSX3nfIQ53NhGg4kVVweHwOzw08yjA7EbefS4dSn4w8LGga6nA0aQWOGbLAYOadMIBKY0yEHIlnwd4Fz3Y2IyqhtJCMgMtJq1Bf+CLXbvsQ+X7qFFSV3OeF6ockvl1GsxLHE4aHGAb2xcc0WB2DmXdG/xomrQZjkpe2tp34Zq9sXI5XYWThU4ywnQUcO+5LVDVUTCooikiYoEOylC233YvPhcSlgAOfkCQJIyPTu0SywGDmIWneNiQX8iTUN6TNInnnKkK9KM5X8Aguh0olI0scQwyzYILXOHF4r4gTJBP8MPy/7zyErztgIXhP3StprNmt5xnVyvQaTFlgMPOXZJBAxgPUgEq+kMzV23a4PzvigSnu+WPg1WPSRk/n5Rr3qNPQ3+LUc2efvyILDGZ+EoA0jLIWrx7aGRT6jQUGM78JEJtkq8br9Ta66tHjmiWHxnKjM9bY5dwPk7jGJUCrQcJ3//UhsnpoF5EkrVN47VtfztbND5eWG+vTYXclzPzUGNJdQlMnCIGh3Ron/3g9M4tB4+P3vxHvssLWYjfulDSpUqlU25ZnKzGYeSz0hWj0SnCuh34nfUJSOD23bvpZaAwlvjlEXDtLOBYYzDyWx5GTOtdfaOylxITbj3iAwlzgrD1vVXN566afsfa8VaMuFQC2bn64eckw0SXFVFlgMPOTQKv7WWdPg2JAmCg41J55Gac88UNOPTukbf3XH4fbqHnoQDHRyT5esJiJHio7GdMFCkKs4NMunw6NCX99gqsPoL6Qnzz2p3B57IA1zlBvc5h/CwxmXiqOq5DEOZdTSToaHEbNECoKeWiUlIws5PiD3wDZwJhXSGHkKyk8GklzNx6HBQYzL7lGnwKXkbIMgGEdiUPkdSoP8VkdDo/Wcuo+w9WXUNv5qx3Lx3isjsHMUxL7RqTNE1TxiHZ23CSnLs6InaAjVRKpkTz9K1RePKKj+Sjlq6tHN6bbmnNMwgD7HRmlDcd2o8aR1HQvfgiOX3oqyfBEPSU6wwKDmZ8UINyq1FiBt/DR1aSdHIlVPLgs1C+4jNQv4ICnjmbg4dPZ/0gt7WeBwcw7njhQCwrkzTGcLln7caRU4dd+zif4vAo+4TXpf6HqhJF8uOP5GJWnrh7dmC5oTgGnAqTNiv0ay6ntOoaqr5AgaBIqIp3kZJrgtXDXQsNYBx7wk9zJ8C4Pg8ECCWH6uwxHhiNJElwCiSqLd5zOwJPH4nwVke7OzmWBwcxfpdbEwgfXfJbKU6sYqUOqA7i8iseRuizcyYDQUtLVIRnBST7pbUKX1XCx30NdE7x4quRUNTRrrmf7GNixhtfIhZAoKkp1f0O7dYAFBmMKqhzIpas/waJsKblkca4OT5Kn4Oqjb2fGOxiTnUROfLx0AZDQWMnlgCfzntreg/lPA79N7cVD0DxBcWi9u8N/W2AwZowFejR/8Ks3seD5pXjnqZOQuzyO5pyGywdfDfUC0Lq7MAEvHi+NYeihgifLa4yIY+D5Q/jj479JbfcqMs1RVdLUo5UJ364jphwYRCQRkR+IyE1xfZWI3Cki20XkKyJSjem1uL49bj+yPVk3Zu4571GUhCVcetwNLHryKJKsSpZDPRtGPOR1QZxvDTXf6PU4gUql0pyXwgN5kpFmVV7y7NH811fdgBs+kGpSwavgU08997i8u7OKTafEcDlQHJvq48AnVfVo4Bng4ph+MfBMTP9k3M+YvqAeRD2Cp6Iv54rXbeSQX7yCBdUBBgaESq1OxaWtXplTaBA1tHeY5rVEqiS6iCW/eCWXv/oGFrKKRByJpqS+QhZnLpcu99uY0tFFZCXwZuBzcV2ANwJfi7tsBM6Py+fFdeL2M6Xdw+IYM0ckdXiXoChVHLV8KZec8EUuX7GZdPup6ItLkGQvmsWZpRrTz+3nRE4HErI0BJCFPzuBPzr8Jj74uo1UdGkYIkbjxLduhAU+9IfochXDlEsMfwd8mFafj4OBZ1W1McTM48CKuLwCeAwgbn8u7j+KiKwXkUERGdy9e/cMs2/M3Au3M124a+FAkoSFrOSK0z/NFcdu4oAdJ7No6AAWiIM8xacjSJKhIuRewkkiaQgWmbB4eDEHPHwSAB9a+xlqrAAEL1kcth58Hk4ujWNDuC5PUDxp8yoReQuwS1XvEpEz5urAqroB2ABhUtu5el9j5oaMegKoMUCqjkvWXkPOEBlPc82Wq0gH4te3sheqe5Chg/G5oFrnvWs+huNAakcdDHwexwG0pvb28Rao4CQhx+Mbc1wWmkp3w1TaXZ4CvFVEzgUGgJcAfw8sEZE0lgpWAjvj/juBw4HHRSQFDgSenvOcGzPnJi+/JzJ6dvEPnzbZK74+Zt
01y+nN4noVfv3NE7/DTAdbmY1JA4OqfgT4CEAsMfyRqr5LRL4KXAB8GVgHbIov2RzXvxe336adGGXTmFmZ/Cvqm/sJmkEYoLkxe5WQZRlpHLXZq8c5CZWOLieEEyHMG9cICYL3Hpc5/v2W7fis0Kqyy5cSs6n6/BPgQyKynVCHcF1Mvw44OKZ/CLhydlk0pjc4bZwwOdL8SQ0TyHrApSleGtPNOTw5YRCmpDUgC+CbJZMc58LFRK/9dk6rC5eq3g7cHpd/Cpw0zj77gLfNQd6M6S0KoT7AhZXGaEqiuEZpQNM4uGwdtBoqFltTQoQRqeP7hJ4TYaq5UQPSNodu696tCRuoxZipiiOp+eI6ECoREyAtnM+NmgjFyaidwyti705RqNdBNQwa0yt6JyfG9Lrmr/7YX/JkzD5SSBNK9RdSuIYXIW2MDyMefBpmwFYp9K/oPOsrYcxckvGWJzrDQ7o4WHtmHEY+3qrsZlAACwzGdFejMJEQ+1z0RiWkBQZjuqh5MyIFXDZph6xOscBgTBdJaAIBoqz9jeNRqZMkSdtn3Z6MBQZjeoJABZKKouotMBgzn/lGJYOCSkbuM+r1etcbPFlg6JbCtIMebX5BitOWTec9ptKk1/Qep9Js+yCknHbWq9DKSOiUAYCS1X3o2t3B+gcLDF2SSY4XxWv4cjiVGCCAGCgyQsv6ECxi8FCaffgR8JKThTmUuva3mFloTj8ZorwKvOGsV4HLEIE8z6nVaqFi0qfh0QEWGLrE4WJACBOaeslwKqQaGtCER3H/kOZFSZV4W0vJcaTxfUw/anxwAnhyMqjAqWceTZIkJFKjng3T6gbemRKDtXxsu/F/yRsDdIxuBVfeZ2zkdmP2q0x8iMgiRm+TQmxISFNtJo+wB5ccQOKrqDYuIzrTh8JKDL/0ilOlj30eu5/pOB27HJtQi7L2zOPxMkSmPk7A27mOVVZi6IhicTFWMubKi074/c99lZ35Um75wFkAnHzNVvb5Ybz31Go1sizj7kvO5LVX3w7ikaSGjOQMa8ZRvMDmD57HPp8x4JJR7986XvH445Uexu5jOqp4njcqkyV8jq4Ga3/zWL777ftheEHYuUNjNVhg6Ihw/Qi+WXn4FMLvf/pf2J6+lB984FROvPpW8sQhPkNclQTP4PtOH/UeXhMWZDlDicfh2MGBADxMwiHA8v0ev/TTNGZ5nM4+pkMalw+Fy4r4ebgqrD3zVXz3mw+17kx0IDhYYOgYT6ahP/6ODN73uZt5NlkKcTxdJQkD/QhxDHNCKaFJcUAdxWmYijWJt7QuvvY2lrt9/J/3hz3reU4lcaNeOz4LBr1Bxl0MPJImrD37WL5zy33oyACpJGSZkqTSagylCY4wuc2c5KjbDSkgDAY7ODjY7Wy0ieDxuFwYTuBRhbd95hbEzXaqofANciguAXLPYj/M1z5wNksbwwo2Ws+pWhVCT5hJIA77q5cwUn0G37vlp6Au3s4sliCKpb/R77H2rUcjwl2qunoqR7XKxzZqxG6HJ3NhjP311/4bqQww0zO10R3XKZB7PIK4ChkJz7OID3x2E3vHvrUFhf7UrIwUcq2TkyMVOOVNR1FPn8G7HJ/JmLsVc1OxbIGhjUQbrRkT9gKXX/stnskHyLXOXBThnQsf33A9Ixclr8AjHMgjwNCoSq36rI9luqDwGaZJJXSucoqmntPfdAKnnPVyfHUPKgI+RaRxaRHGkBQJg83OhAWGDhAPF13zTX42lCJpBm7m/+3NhkyieBnVJprcj5CgvPuz/86TozJgH3Nf08IDQdQhCJLAGW9+Jae9ZRVrz3kZVOuIA685xT5YOoNRX6zysY3Eh0kIhgQedgup1sBl9Thy8MyEIb90VH8KFyskRSqMkDCU7ePyDTexaX3YnpHYB93PineUFbzXMDQ9NEsFruY45axjkQzI4fbb7wV15OpJ0+nXZ9n3pQN+LJD4jAVZjaFa7Bfh8xm/n1Mhi0WH1pDmkGceJ0JShR060Np/Npk3vaMRIJJQiek13LZ2zuE1D+1cqg6vcNrZr8RJGJF6JvcXLDC0k4bf9T+55pvU3CKG0ix8gLPoCOOlFRCaaYSTP02UDMVlFZKkFXgsMPSS2dcENz5PV3iriZYRmMnQDhYY2mjYKTVgDwn7CHMPOG1P24EwkSq4XMApvvBtaE6VaLqs200Dph4h7PvSRpX4OezF4fN6OHnb1GrNC2HqdOepKmje+hJYz0szXRYY2shJnFvAJ1TTVuGsHY1ZnYKoCw2vVdARu0VpZs4CQxsNx+eKCPVcC0ODz/1/uwOqGkoOmVNcpXAM7f6ow6a/WGBoo8bpOJwNU1GHxkuIVOf+P74xJJzD4xSKNzR9j8xVYPqHBYY2alzbp2lK3dVJfBLbIfi2XE548aAJPslIkgWtDZJM/CJjxmGBoY1q8dkTSvPOC67Nv94+zsAs2czbSRhjgaEDqkkVcmIzZt++/3R1IfCoIHkrMPRAB1rTZywwtFE91ikkXllYTaknOcNCOIHbcDwvoT7B+YSs0roLklhgMNNkgaGNGuM8JzpMlmUkCok6MqStw707IHeFSwnrdm2myQJDG6V5+KleVXkR0irOJ6Rt/fX2pD50sqr4QjsGCwxmmqYUGERkh4jcJyJ3i8hgTDtIRG4WkZ/E56UxXUTkUyKyXUTuFZET2/kH9LYQBf7yveeT1+vURag7z4C2pzWig+ZdjwO6PPeh6W/TKTG8QVVPKAwNdSVwq6oeA9wa1wHOAY6Jj/XA1XOV2b4Tx104DFgqdbxmqBeGJYsVkbHI7/I4Vl+jL8XEGl3rU8I8A7nzOAmPTIWRxJOrcmx9qPWaNv155pfXbC4lzgM2xuWNwPmF9C9ocAewREQOm8Vx+lZGuM6vAZ+77E0cJDm1TKkmVTJ1LMgctTxMO5b6BC9KPdn/bcY0nuaZpnhNEJ/gNcGrQ0ZykhHlKF7kY5f85+ZrrK+Ema6pBgYF/k1E7hKROPwHh6rqE3H558ChcXkF8FjhtY/HtFFEZL2IDIrI4O7du2eQ9d7XOIkdyvIcjuRZlIysnuNcxlAaGjo5JY6v4EnQ/Xay8upIgVy02QzaSeiLn9YqLK4Oce36t3JwsU2TFRnMNE01MJyqqicSLhMuE5HTihs1DDU9rd8lVd2gqqtVdfXy5RPPiNDP8kKLwwWJ8plLLuAo2cOw1MPMQq4eBvR0ORrbNzifhB6YE/AIHqiRg+SMNO4+iKJ+hOvffw5LkjEfrFU3mGmaUmBQ1Z3xeRfwdeAk4MnGJUJ83hV33wkcXnj5ypg27ySN/14VKngWq+dLl53PcbxAmkOaVck1RUmp+RAUMnXs70xOFTLxDDtAKkie4LTKUq/ceMlvcLRCtT76g53NUHJmfpr0KyMii0TkgMYy8JvA/cBmYF3cbR2wKS5vBi6KdyfWAM8VLjnmmVYhypOQiZCiXP/+3+J49vAStxfv64iHDBd6X8r+C1+ZU5w6Kt4xkCtp6lmZvMAN7zuTVXEUcamA7/qgIKafTWUEp0OBr0u4/ZUC/6iq3xSR7wM3isjFwCPA2
+P+3wDOBbYDe4H3zHmu+4VKHOpfWjNXC7wkgY2XtvfQbtRy5yZDNb8cemImKhF5AXiw2/mYomXAU93OxBT0Sz6hf/LaL/mE8fP6MlWdUoVer4z5+OBUp87qNhEZ7Ie89ks+oX/y2i/5hNnn1aqljDElFhiMMSW9Ehg2dDsD09Avee2XfEL/5LVf8gmzzGtPVD4aY3pLr5QYjDE9pOuBQUTOFpEHYzftKyd/RVvzcr2I7BKR+wtpPdm9XEQOF5Fvi8iPROSHInJ5L+ZXRAZEZJuI3BPz+ZcxfZWI3Bnz8xURqcb0WlzfHrcf2Yl8FvKbiMgPROSmHs9ne4dCUNWuPYAEeBg4CqgC9wDHdzE/pwEnAvcX0v4ncGVcvhL4eFw+F/i/hJZDa4A7O5zXw4AT4/IBwEPA8b2W33i8xXG5AtwZj38j8M6Yfg1wSVy+FLgmLr8T+EqH/18/BPwjcFNc79V87qgn+38AAAIxSURBVACWjUmbs8++Y3/IBH/c64FvFdY/Anyky3k6ckxgeBA4LC4fRmhzAXAtcOF4+3Up35uAs3o5v8BC4D+AkwmNb9Kx3wPgW8Dr43Ia95MO5W8lYWyRNwI3xROp5/IZjzleYJizz77blxJT6qLdZbPqXt4JsRj7WsKvcc/lNxbP7yZ0tLuZUEp8VlWzcfLSzGfc/hxwcCfyCfwd8GFaHdUP7tF8QhuGQijqlZaPfUFVVaS3pnUSkcXAPwFXqOrzUhjSrVfyq6o5cIKILCH0zn1Fl7NUIiJvAXap6l0icka38zMFp6rqThE5BLhZRH5c3Djbz77bJYZ+6KLds93LRaRCCApfVNV/jsk9m19VfRb4NqFIvkREGj9Mxbw08xm3Hwg83YHsnQK8VUR2AF8mXE78fQ/mE2j/UAjdDgzfB46JNb9VQiXO5i7naaye7F4uoWhwHfCAqv5tr+ZXRJbHkgIisoBQD/IAIUBcMEE+G/m/ALhN44VxO6nqR1R1paoeSfge3qaq7+q1fEKHhkLoVGXJfipRziXUqD8M/Lcu5+VLwBNAnXAddjHhuvFW4CfALcBBcV8BPh3zfR+wusN5PZVwnXkvcHd8nNtr+QVeDfwg5vN+4M9i+lHANkL3/K8CtZg+ENe3x+1HdeF7cAatuxI9l8+Yp3vi44eN82YuP3tr+WiMKen2pYQxpgdZYDDGlFhgMMaUWGAwxpRYYDDGlFhgMMaUWGAwxpRYYDDGlPx/iALj2+UtxycAAAAASUVORK5CYII=\n"
177 | },
178 | "metadata": {
179 | "needs_background": "light"
180 | },
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "import torch\n",
186 | "from torch.backends import cudnn\n",
187 | "\n",
188 | "from backbone import EfficientDetBackbone\n",
189 | "import cv2\n",
190 | "import matplotlib.pyplot as plt\n",
191 | "import numpy as np\n",
192 | "\n",
193 | "from efficientdet.utils import BBoxTransform, ClipBoxes\n",
194 | "from utils.utils import preprocess, invert_affine, postprocess\n",
195 | "\n",
196 | "compound_coef = 0\n",
197 | "force_input_size = None # set None to use default size\n",
198 | "img_path = 'datasets/shape/val/999.jpg'\n",
199 | "\n",
200 | "threshold = 0.2\n",
201 | "iou_threshold = 0.2\n",
202 | "\n",
203 | "use_cuda = True\n",
204 | "use_float16 = False\n",
205 | "cudnn.fastest = True\n",
206 | "cudnn.benchmark = True\n",
207 | "\n",
208 | "obj_list = ['rectangle', 'circle']\n",
209 | "\n",
210 | "# tf bilinear interpolation is different from any other's, just make do\n",
211 | "input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]\n",
212 | "input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size\n",
213 | "ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)\n",
214 | "\n",
215 | "if use_cuda:\n",
216 | " x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)\n",
217 | "else:\n",
218 | " x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)\n",
219 | "\n",
220 | "x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)\n",
221 | "\n",
222 | "model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),\n",
223 | "\n",
224 | " # replace this part with your project's anchor config\n",
225 | " ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],\n",
226 | " scales=[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])\n",
227 | "\n",
228 | "model.load_state_dict(torch.load('logs/shape/efficientdet-d0_49_1400.pth'))\n",
229 | "model.requires_grad_(False)\n",
230 | "model.eval()\n",
231 | "\n",
232 | "if use_cuda:\n",
233 | " model = model.cuda()\n",
234 | "if use_float16:\n",
235 | " model = model.half()\n",
236 | "\n",
237 | "with torch.no_grad():\n",
238 | " features, regression, classification, anchors = model(x)\n",
239 | "\n",
240 | " regressBoxes = BBoxTransform()\n",
241 | " clipBoxes = ClipBoxes()\n",
242 | "\n",
243 | " out = postprocess(x,\n",
244 | " anchors, regression, classification,\n",
245 | " regressBoxes, clipBoxes,\n",
246 | " threshold, iou_threshold)\n",
247 | "\n",
248 | "out = invert_affine(framed_metas, out)\n",
249 | "\n",
250 | "for i in range(len(ori_imgs)):\n",
251 | " if len(out[i]['rois']) == 0:\n",
252 | " continue\n",
253 | "\n",
254 | " for j in range(len(out[i]['rois'])):\n",
255 | " (x1, y1, x2, y2) = out[i]['rois'][j].astype(np.int)\n",
256 | " cv2.rectangle(ori_imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)\n",
257 | " obj = obj_list[out[i]['class_ids'][j]]\n",
258 | " score = float(out[i]['scores'][j])\n",
259 | "\n",
260 | " cv2.putText(ori_imgs[i], '{}, {:.3f}'.format(obj, score),\n",
261 | " (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,\n",
262 | " (255, 255, 0), 1)\n",
263 | "\n",
264 | " plt.imshow(ori_imgs[i])\n",
265 | "\n"
266 | ],
267 | "metadata": {
268 | "collapsed": false,
269 | "pycharm": {
270 | "name": "#%%\n"
271 | }
272 | }
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": "Python 3",
278 | "language": "python",
279 | "name": "python3"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 2
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython2",
291 | "version": "2.7.6"
292 | }
293 | },
294 | "nbformat": 4,
295 | "nbformat_minor": 0
296 | }
--------------------------------------------------------------------------------
/utils/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .batchnorm import patch_sync_batchnorm, convert_model
13 | from .replicate import DataParallelWithCallback, patch_replication_callback
14 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 | import contextlib
13 |
14 | import torch
15 | import torch.nn.functional as F
16 |
17 | from torch.nn.modules.batchnorm import _BatchNorm
18 |
19 | try:
20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21 | except ImportError:
22 | ReduceAddCoalesced = Broadcast = None
23 |
24 | try:
25 | from jactorch.parallel.comm import SyncMaster
26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
27 | except ImportError:
28 | from .comm import SyncMaster
29 | from .replicate import DataParallelWithCallback
30 |
31 | __all__ = [
32 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
33 | 'patch_sync_batchnorm', 'convert_model'
34 | ]
35 |
36 |
37 | def _sum_ft(tensor):
38 |     """sum over the first and last dimension"""
39 | return tensor.sum(dim=0).sum(dim=-1)
40 |
41 |
42 | def _unsqueeze_ft(tensor):
43 | """add new dimensions at the front and the tail"""
44 | return tensor.unsqueeze(0).unsqueeze(-1)
45 |
46 |
47 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
48 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
49 |
50 |
51 | class _SynchronizedBatchNorm(_BatchNorm):
52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
53 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
54 |
55 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
56 |
57 | self._sync_master = SyncMaster(self._data_parallel_master)
58 |
59 | self._is_parallel = False
60 | self._parallel_id = None
61 | self._slave_pipe = None
62 |
63 | def forward(self, input):
64 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
65 | if not (self._is_parallel and self.training):
66 | return F.batch_norm(
67 | input, self.running_mean, self.running_var, self.weight, self.bias,
68 | self.training, self.momentum, self.eps)
69 |
70 | # Resize the input to (B, C, -1).
71 | input_shape = input.size()
72 | input = input.view(input.size(0), self.num_features, -1)
73 |
74 | # Compute the sum and square-sum.
75 | sum_size = input.size(0) * input.size(2)
76 | input_sum = _sum_ft(input)
77 | input_ssum = _sum_ft(input ** 2)
78 |
79 | # Reduce-and-broadcast the statistics.
80 | if self._parallel_id == 0:
81 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
82 | else:
83 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
84 |
85 | # Compute the output.
86 | if self.affine:
87 | # MJY:: Fuse the multiplication for speed.
88 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
89 | else:
90 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
91 |
92 | # Reshape it.
93 | return output.view(input_shape)
94 |
95 | def __data_parallel_replicate__(self, ctx, copy_id):
96 | self._is_parallel = True
97 | self._parallel_id = copy_id
98 |
99 | # parallel_id == 0 means master device.
100 | if self._parallel_id == 0:
101 | ctx.sync_master = self._sync_master
102 | else:
103 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
104 |
105 | def _data_parallel_master(self, intermediates):
106 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
107 |
108 | # Always using same "device order" makes the ReduceAdd operation faster.
109 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
110 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
111 |
112 | to_reduce = [i[1][:2] for i in intermediates]
113 | to_reduce = [j for i in to_reduce for j in i] # flatten
114 | target_gpus = [i[1].sum.get_device() for i in intermediates]
115 |
116 | sum_size = sum([i[1].sum_size for i in intermediates])
117 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
118 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
119 |
120 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
121 |
122 | outputs = []
123 | for i, rec in enumerate(intermediates):
124 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
125 |
126 | return outputs
127 |
128 | def _compute_mean_std(self, sum_, ssum, size):
129 | """Compute the mean and standard-deviation with sum and square-sum. This method
130 | also maintains the moving average on the master device."""
131 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
132 | mean = sum_ / size
133 | sumvar = ssum - sum_ * mean
134 | unbias_var = sumvar / (size - 1)
135 | bias_var = sumvar / size
136 |
137 | if hasattr(torch, 'no_grad'):
138 | with torch.no_grad():
139 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
140 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
141 | else:
142 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
143 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
144 |
145 | return mean, bias_var.clamp(self.eps) ** -0.5
146 |
147 |
148 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
149 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
150 | mini-batch.
151 |
152 | .. math::
153 |
154 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
155 |
156 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
157 | standard-deviation are reduced across all devices during training.
158 |
159 | For example, when one uses `nn.DataParallel` to wrap the network during
160 | training, PyTorch's implementation normalize the tensor on each device using
161 | the statistics only on that device, which accelerated the computation and
162 | is also easy to implement, but the statistics might be inaccurate.
163 | Instead, in this synchronized version, the statistics will be computed
164 | over all training samples distributed on multiple devices.
165 |
166 | Note that, for one-GPU or CPU-only case, this module behaves exactly same
167 | as the built-in PyTorch implementation.
168 |
169 | The mean and standard-deviation are calculated per-dimension over
170 | the mini-batches and gamma and beta are learnable parameter vectors
171 | of size C (where C is the input size).
172 |
173 | During training, this layer keeps a running estimate of its computed mean
174 | and variance. The running sum is kept with a default momentum of 0.1.
175 |
176 | During evaluation, this running mean/variance is used for normalization.
177 |
178 | Because the BatchNorm is done over the `C` dimension, computing statistics
179 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
180 |
181 | Args:
182 | num_features: num_features from an expected input of size
183 | `batch_size x num_features [x width]`
184 | eps: a value added to the denominator for numerical stability.
185 | Default: 1e-5
186 | momentum: the value used for the running_mean and running_var
187 | computation. Default: 0.1
188 | affine: a boolean value that when set to ``True``, gives the layer learnable
189 | affine parameters. Default: ``True``
190 |
191 | Shape::
192 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
193 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
194 |
195 | Examples:
196 | >>> # With Learnable Parameters
197 | >>> m = SynchronizedBatchNorm1d(100)
198 | >>> # Without Learnable Parameters
199 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
200 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
201 | >>> output = m(input)
202 | """
203 |
204 | def _check_input_dim(self, input):
205 | if input.dim() != 2 and input.dim() != 3:
206 | raise ValueError('expected 2D or 3D input (got {}D input)'
207 | .format(input.dim()))
208 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
209 |
210 |
211 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
212 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
213 | of 3d inputs
214 |
215 | .. math::
216 |
217 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
218 |
219 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
220 | standard-deviation are reduced across all devices during training.
221 |
222 | For example, when one uses `nn.DataParallel` to wrap the network during
223 | training, PyTorch's implementation normalize the tensor on each device using
224 | the statistics only on that device, which accelerated the computation and
225 | is also easy to implement, but the statistics might be inaccurate.
226 | Instead, in this synchronized version, the statistics will be computed
227 | over all training samples distributed on multiple devices.
228 |
229 | Note that, for one-GPU or CPU-only case, this module behaves exactly same
230 | as the built-in PyTorch implementation.
231 |
232 | The mean and standard-deviation are calculated per-dimension over
233 | the mini-batches and gamma and beta are learnable parameter vectors
234 | of size C (where C is the input size).
235 |
236 | During training, this layer keeps a running estimate of its computed mean
237 | and variance. The running sum is kept with a default momentum of 0.1.
238 |
239 | During evaluation, this running mean/variance is used for normalization.
240 |
241 | Because the BatchNorm is done over the `C` dimension, computing statistics
242 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
243 |
244 | Args:
245 | num_features: num_features from an expected input of
246 | size batch_size x num_features x height x width
247 | eps: a value added to the denominator for numerical stability.
248 | Default: 1e-5
249 | momentum: the value used for the running_mean and running_var
250 | computation. Default: 0.1
251 | affine: a boolean value that when set to ``True``, gives the layer learnable
252 | affine parameters. Default: ``True``
253 |
254 | Shape::
255 | - Input: :math:`(N, C, H, W)`
256 | - Output: :math:`(N, C, H, W)` (same shape as input)
257 |
258 | Examples:
259 | >>> # With Learnable Parameters
260 | >>> m = SynchronizedBatchNorm2d(100)
261 | >>> # Without Learnable Parameters
262 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
263 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
264 | >>> output = m(input)
265 | """
266 |
267 | def _check_input_dim(self, input):
268 | if input.dim() != 4:
269 | raise ValueError('expected 4D input (got {}D input)'
270 | .format(input.dim()))
271 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
272 |
273 |
274 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
275 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
276 | of 4d inputs
277 |
278 | .. math::
279 |
280 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
281 |
282 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
283 | standard-deviation are reduced across all devices during training.
284 |
285 | For example, when one uses `nn.DataParallel` to wrap the network during
286 | training, PyTorch's implementation normalize the tensor on each device using
287 | the statistics only on that device, which accelerated the computation and
288 | is also easy to implement, but the statistics might be inaccurate.
289 | Instead, in this synchronized version, the statistics will be computed
290 | over all training samples distributed on multiple devices.
291 |
292 | Note that, for one-GPU or CPU-only case, this module behaves exactly same
293 | as the built-in PyTorch implementation.
294 |
295 | The mean and standard-deviation are calculated per-dimension over
296 | the mini-batches and gamma and beta are learnable parameter vectors
297 | of size C (where C is the input size).
298 |
299 | During training, this layer keeps a running estimate of its computed mean
300 | and variance. The running sum is kept with a default momentum of 0.1.
301 |
302 | During evaluation, this running mean/variance is used for normalization.
303 |
304 | Because the BatchNorm is done over the `C` dimension, computing statistics
305 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
306 | or Spatio-temporal BatchNorm
307 |
308 | Args:
309 | num_features: num_features from an expected input of
310 | size batch_size x num_features x depth x height x width
311 | eps: a value added to the denominator for numerical stability.
312 | Default: 1e-5
313 | momentum: the value used for the running_mean and running_var
314 | computation. Default: 0.1
315 | affine: a boolean value that when set to ``True``, gives the layer learnable
316 | affine parameters. Default: ``True``
317 |
318 | Shape::
319 | - Input: :math:`(N, C, D, H, W)`
320 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
321 |
322 | Examples:
323 | >>> # With Learnable Parameters
324 | >>> m = SynchronizedBatchNorm3d(100)
325 | >>> # Without Learnable Parameters
326 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
327 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
328 | >>> output = m(input)
329 | """
330 |
331 | def _check_input_dim(self, input):
332 | if input.dim() != 5:
333 | raise ValueError('expected 5D input (got {}D input)'
334 | .format(input.dim()))
335 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
336 |
337 |
338 | @contextlib.contextmanager
339 | def patch_sync_batchnorm():
340 | import torch.nn as nn
341 |
342 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
343 |
344 | nn.BatchNorm1d = SynchronizedBatchNorm1d
345 | nn.BatchNorm2d = SynchronizedBatchNorm2d
346 | nn.BatchNorm3d = SynchronizedBatchNorm3d
347 |
348 | yield
349 |
350 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
351 |
352 |
353 | def convert_model(module):
354 |     """Traverse the input module and its children recursively
355 |        and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
356 |        with SynchronizedBatchNorm*N*d
357 |
358 | Args:
359 |         module: the input module to be converted to a SyncBN model
360 |
361 | Examples:
362 | >>> import torch.nn as nn
363 | >>> import torchvision
364 | >>> # m is a standard pytorch model
365 | >>> m = torchvision.models.resnet18(True)
366 | >>> m = nn.DataParallel(m)
367 | >>> # after convert, m is using SyncBN
368 | >>> m = convert_model(m)
369 | """
370 | if isinstance(module, torch.nn.DataParallel):
371 | mod = module.module
372 | mod = convert_model(mod)
373 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
374 | return mod
375 |
376 | mod = module
377 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
378 | torch.nn.modules.batchnorm.BatchNorm2d,
379 | torch.nn.modules.batchnorm.BatchNorm3d],
380 | [SynchronizedBatchNorm1d,
381 | SynchronizedBatchNorm2d,
382 | SynchronizedBatchNorm3d]):
383 | if isinstance(module, pth_module):
384 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
385 | mod.running_mean = module.running_mean
386 | mod.running_var = module.running_var
387 | if module.affine:
388 | mod.weight.data = module.weight.data.clone().detach()
389 | mod.bias.data = module.bias.data.clone().detach()
390 |
391 | for name, child in module.named_children():
392 | mod.add_module(name, convert_model(child))
393 |
394 | return mod
395 |
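
The two helpers above, `patch_sync_batchnorm()` and `convert_model()`, are the usual entry points for switching a model to SyncBN. Below is a minimal usage sketch, assuming this file lives at `utils/sync_batchnorm/batchnorm.py` as in the upstream Synchronized-BatchNorm-PyTorch package; the toy `nn.Sequential` model and the commented two-GPU wrapping are only for illustration:

```
import torch.nn as nn

# Assumed import path for the functions defined in the file listed above.
from utils.sync_batchnorm.batchnorm import convert_model, patch_sync_batchnorm

# Option 1: build a model with plain BatchNorm, then convert it afterwards.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net = convert_model(net)  # every nn.BatchNorm2d becomes SynchronizedBatchNorm2d

# If the model is already wrapped in nn.DataParallel, convert_model() re-wraps it
# in DataParallelWithCallback so the replication callback is actually invoked:
# net = convert_model(nn.DataParallel(net.cuda(), device_ids=[0, 1]))

# Option 2: temporarily patch nn.BatchNorm*d so that any module constructed
# inside the `with` block is built with SyncBN from the start.
with patch_sync_batchnorm():
    net2 = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
```
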
--------------------------------------------------------------------------------
/utils/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
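
`BatchNorm2dReimpl` is only used as a reference implementation for numerical checks. A minimal sketch of such a check against the built-in `nn.BatchNorm2d` (tensor sizes and the tolerance are arbitrary choices for illustration):

```
import torch
import torch.nn as nn

from utils.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

torch.manual_seed(0)
bn_ref = nn.BatchNorm2d(8)
bn_re = BatchNorm2dReimpl(8)

# Start both layers from the same affine parameters.
bn_re.weight.data.copy_(bn_ref.weight.data)
bn_re.bias.data.copy_(bn_ref.bias.data)

x = torch.randn(4, 8, 16, 16)
print(torch.allclose(bn_ref(x), bn_re(x), atol=1e-5))  # expected: True
```
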
--------------------------------------------------------------------------------
/utils/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel triggers a callback on each module, every slave device should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually is the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
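
`FutureResult` is the one-to-one channel through which `SyncMaster.run_master` hands the reduced result back to each registered slave. A tiny standalone sketch of that hand-off between two threads (the payload dict is made up for illustration):

```
import threading

from utils.sync_batchnorm.comm import FutureResult

future = FutureResult()

def producer():
    # In SyncMaster, run_master() plays this role: it puts the reduced
    # statistics into each registered slave's FutureResult.
    future.put({'mean': 0.0, 'inv_std': 1.0})

threading.Thread(target=producer).start()
print(future.get())  # blocks until put() is called, then prints the dict
```
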
--------------------------------------------------------------------------------
/utils/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` on each module will be invoked after it is created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from glob import glob
9 | from torch import nn
10 | from torchvision.ops import nms
11 | from typing import Union
12 | import uuid
13 |
14 | from utils.sync_batchnorm import SynchronizedBatchNorm2d
15 |
16 |
17 | def invert_affine(metas: Union[float, list, tuple], preds):
18 | for i in range(len(preds)):
19 | if len(preds[i]['rois']) == 0:
20 | continue
21 | else:
22 |             if isinstance(metas, float):  # metas is a single scale factor
23 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / metas
24 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / metas
25 | else:
26 | new_w, new_h, old_w, old_h, padding_w, padding_h = metas[i]
27 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / (new_w / old_w)
28 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / (new_h / old_h)
29 | return preds
30 |
31 |
32 | def aspectaware_resize_padding(image, width, height, interpolation=None, means=None):
33 | old_h, old_w, c = image.shape
34 | if old_w > old_h:
35 | new_w = width
36 | new_h = int(width / old_w * old_h)
37 | else:
38 | new_w = int(height / old_h * old_w)
39 | new_h = height
40 |
41 |     canvas = np.zeros((height, width, c), np.float32)
42 | if means is not None:
43 | canvas[...] = means
44 |
45 | if new_w != old_w or new_h != old_h:
46 | if interpolation is None:
47 | image = cv2.resize(image, (new_w, new_h))
48 | else:
49 | image = cv2.resize(image, (new_w, new_h), interpolation=interpolation)
50 |
51 | padding_h = height - new_h
52 | padding_w = width - new_w
53 |
54 | if c > 1:
55 | canvas[:new_h, :new_w] = image
56 | else:
57 | if len(image.shape) == 2:
58 | canvas[:new_h, :new_w, 0] = image
59 | else:
60 | canvas[:new_h, :new_w] = image
61 |
62 | return canvas, new_w, new_h, old_w, old_h, padding_w, padding_h,
63 |
64 |
65 | def preprocess(*image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
66 | ori_imgs = [cv2.imread(img_path) for img_path in image_path]
67 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
68 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
69 | means=None) for img in normalized_imgs]
70 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
71 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
72 |
73 | return ori_imgs, framed_imgs, framed_metas
74 |
75 |
76 | def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
77 | transformed_anchors = regressBoxes(anchors, regression)
78 | transformed_anchors = clipBoxes(transformed_anchors, x)
79 | scores = torch.max(classification, dim=2, keepdim=True)[0]
80 | scores_over_thresh = (scores > threshold)[:, :, 0]
81 | out = []
82 | for i in range(x.shape[0]):
83 |         if scores_over_thresh[i].sum() == 0:
84 |             out.append({
85 |                 'rois': np.array(()),
86 |                 'class_ids': np.array(()),
87 |                 'scores': np.array(()),
88 |             })
89 |             continue
90 | classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0)
91 | transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...]
92 | scores_per = scores[i, scores_over_thresh[i, :], ...]
93 | anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)
94 |
95 | if anchors_nms_idx.shape[0] != 0:
96 | scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
97 | boxes_ = transformed_anchors_per[anchors_nms_idx, :]
98 |
99 | out.append({
100 | 'rois': boxes_.cpu().numpy(),
101 | 'class_ids': classes_.cpu().numpy(),
102 | 'scores': scores_.cpu().numpy(),
103 | })
104 | else:
105 | out.append({
106 | 'rois': np.array(()),
107 | 'class_ids': np.array(()),
108 | 'scores': np.array(()),
109 | })
110 |
111 | return out
112 |
113 |
114 | def display(preds, imgs, obj_list, imshow=True, imwrite=False):
115 | for i in range(len(imgs)):
116 | if len(preds[i]['rois']) == 0:
117 | continue
118 |
119 | for j in range(len(preds[i]['rois'])):
120 |             (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)  # np.int is removed in newer NumPy
121 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
122 | obj = obj_list[preds[i]['class_ids'][j]]
123 | score = float(preds[i]['scores'][j])
124 |
125 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
126 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
127 | (255, 255, 0), 1)
128 | if imshow:
129 | cv2.imshow('img', imgs[i])
130 | cv2.waitKey(0)
131 |
132 | if imwrite:
133 | os.makedirs('test/', exist_ok=True)
134 | cv2.imwrite('test/{}.jpg'.format(uuid.uuid4().hex), imgs[i])
135 |
136 |
137 | def replace_w_sync_bn(m):
138 | for var_name in dir(m):
139 | target_attr = getattr(m, var_name)
140 | if type(target_attr) == torch.nn.BatchNorm2d:
141 | num_features = target_attr.num_features
142 | eps = target_attr.eps
143 | momentum = target_attr.momentum
144 | affine = target_attr.affine
145 |
146 | # get parameters
147 | running_mean = target_attr.running_mean
148 | running_var = target_attr.running_var
149 | if affine:
150 | weight = target_attr.weight
151 | bias = target_attr.bias
152 |
153 | setattr(m, var_name,
154 | SynchronizedBatchNorm2d(num_features, eps, momentum, affine))
155 |
156 | target_attr = getattr(m, var_name)
157 | # set parameters
158 | target_attr.running_mean = running_mean
159 | target_attr.running_var = running_var
160 | if affine:
161 | target_attr.weight = weight
162 | target_attr.bias = bias
163 |
164 | for var_name, children in m.named_children():
165 | replace_w_sync_bn(children)
166 |
167 |
168 | class CustomDataParallel(nn.DataParallel):
169 | """
170 |     Force splitting the data across all GPUs instead of sending everything to cuda:0 first and then scattering it.
171 | """
172 |
173 | def __init__(self, module, num_gpus):
174 | super().__init__(module)
175 | self.num_gpus = num_gpus
176 |
177 | def scatter(self, inputs, kwargs, device_ids):
178 | # More like scatter and data prep at the same time. The point is we prep the data in such a way
179 | # that no scatter is necessary, and there's no need to shuffle stuff around different GPUs.
180 | devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
181 | splits = inputs[0].shape[0] // self.num_gpus
182 |
183 | return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to('cuda:{}'.format(device_idx), non_blocking=True),
184 | inputs[1][splits * device_idx: splits * (device_idx + 1)].to('cuda:{}'.format(device_idx), non_blocking=True))
185 | for device_idx in range(len(devices))], \
186 | [kwargs] * len(devices)
187 |
188 |
189 | def get_last_weights(weights_path):
190 | weights_path = glob(weights_path + '/*.pth')
191 | weights_path = sorted(weights_path,
192 | key=lambda x: int(x.rsplit('_')[-1].rsplit('.')[0]),
193 | reverse=True)[0]
194 | print('using weights {}'.format(weights_path))
195 | return weights_path
196 |
197 |
198 | def init_weights(model):
199 | for name, module in model.named_modules():
200 | is_conv_layer = isinstance(module, nn.Conv2d)
201 |
202 | if is_conv_layer:
203 | nn.init.kaiming_uniform_(module.weight.data)
204 |
205 | if module.bias is not None:
206 | module.bias.data.zero_()
207 |
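
At inference time these helpers are chained roughly as preprocess → model → postprocess → invert_affine → display, which is what `efficientdet_test.py` does. A condensed sketch of that flow; `EfficientDetBackbone`, `BBoxTransform` and `ClipBoxes` are assumed to come from this code base (`backbone.py` and `efficientdet/utils.py`), and the image path, class list, input size, thresholds and checkpoint path are placeholders:

```
import torch

from backbone import EfficientDetBackbone
from efficientdet.utils import BBoxTransform, ClipBoxes
from utils.utils import preprocess, invert_affine, postprocess, display

obj_list = ['class_0', 'class_1']   # replace with the project's class names
compound_coef = 2                   # D0-D6
input_size = 768                    # should match the chosen compound_coef

model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list))
model.load_state_dict(torch.load('path/to/checkpoint.pth', map_location='cpu'))
model.requires_grad_(False)
model.eval()

# Resize + pad the image, run the detector, then map the boxes back to the
# original resolution and draw them.
ori_imgs, framed_imgs, framed_metas = preprocess('path/to/test.jpg', max_size=input_size)
x = torch.stack([torch.from_numpy(img) for img in framed_imgs]).permute(0, 3, 1, 2).float()

with torch.no_grad():
    _, regression, classification, anchors = model(x)
    preds = postprocess(x, anchors, regression, classification,
                        BBoxTransform(), ClipBoxes(),
                        threshold=0.2, iou_threshold=0.2)

preds = invert_affine(framed_metas, preds)
display(preds, ori_imgs, obj_list, imshow=False, imwrite=True)
```
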
--------------------------------------------------------------------------------
/voc2coco.py:
--------------------------------------------------------------------------------
1 | # pip install lxml
2 |
3 | import sys
4 | import os
5 | import json
6 | import xml.etree.ElementTree as ET
7 | import cv2
8 |
9 | START_BOUNDING_BOX_ID = 1
10 | PRE_DEFINE_CATEGORIES = {"holothurian":1,"echinus":2,"scallop":3,"starfish":4}
11 | # ["holothurian","echinus","scallop","starfish"]
12 | # If necessary, pre-define category and its id
13 | # PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
14 | # "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
15 | # "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
16 | # "motorbike": 14, "person": 15, "pottedplant": 16,
17 | # "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
18 |
19 |
20 | def get(root, name):
21 | vars = root.findall(name)
22 | return vars
23 |
24 |
25 | def get_and_check(root, name, length):
26 | vars = root.findall(name)
27 | if len(vars) == 0:
28 | raise NotImplementedError('Can not find %s in %s.'%(name, root.tag))
29 | if length > 0 and len(vars) != length:
30 | raise NotImplementedError('The size of %s is supposed to be %d, but is %d.'%(name, length, len(vars)))
31 | if length == 1:
32 | vars = vars[0]
33 | return vars
34 |
35 |
36 | def get_filename_as_int(filename):
37 | try:
38 | filename = os.path.splitext(filename)[0]
39 | return int(filename)
40 | except:
41 | raise NotImplementedError('Filename %s is supposed to be an integer.'%(filename))
42 |
43 |
44 | def convert(xml_list, xml_dir, json_file,img_dir):
45 | list_fp = open(xml_list, 'r')
46 | json_dict = {"images":[], "type": "instances", "annotations": [],
47 | "categories": []}
48 | categories = PRE_DEFINE_CATEGORIES
49 | bnd_id = START_BOUNDING_BOX_ID
50 | for line in list_fp:
51 | line_name = line.strip()
52 | line = line_name + ".xml"
53 | print("Processing %s"%(line))
54 | xml_f = os.path.join(xml_dir, line)
55 | tree = ET.parse(xml_f)
56 | root = tree.getroot()
57 | path = get(root, 'path')
58 | try:
59 | if len(path) == 1:
60 | filename = os.path.basename(path[0].text)
61 | elif len(path) == 0:
62 | filename = get_and_check(root, 'filename', 1).text
63 | except:
64 | filename = line_name + ".jpg"
65 | # raise NotImplementedError('%d paths found in %s'%(len(path), line))
66 | ## The filename must be a number
67 | image_id = get_filename_as_int(filename)
68 | try:
69 | size = get_and_check(root, 'size', 1)
70 | width = int(get_and_check(size, 'width', 1).text)
71 | height = int(get_and_check(size, 'height', 1).text)
72 | except:
73 | img = cv2.imread(img_dir+line_name+".jpg")
74 | height = img.shape[0]
75 | width = img.shape[1]
76 |
77 | image = {'file_name': filename, 'height': height, 'width': width,
78 | 'id':image_id}
79 | json_dict['images'].append(image)
80 |         ## Currently we do not support segmentation
81 | # segmented = get_and_check(root, 'segmented', 1).text
82 | # assert segmented == '0'
83 | for obj in get(root, 'object'):
84 | category = get_and_check(obj, 'name', 1).text
85 | if category not in categories:
86 | # new_id = len(categories)
87 | # categories[category] = new_id
88 | continue
89 | category_id = categories[category]
90 | bndbox = get_and_check(obj, 'bndbox', 1)
91 | xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
92 | ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
93 | xmax = int(get_and_check(bndbox, 'xmax', 1).text)
94 | ymax = int(get_and_check(bndbox, 'ymax', 1).text)
95 | assert(xmax > xmin)
96 | assert(ymax > ymin)
97 | o_width = abs(xmax - xmin)
98 | o_height = abs(ymax - ymin)
99 | ann = {'area': o_width*o_height, 'iscrowd': 0, 'image_id':
100 | image_id, 'bbox':[xmin, ymin, o_width, o_height],
101 | 'category_id': category_id, 'id': bnd_id, 'ignore': 0,
102 | 'segmentation': []}
103 | json_dict['annotations'].append(ann)
104 | bnd_id = bnd_id + 1
105 |
106 | for cate, cid in categories.items():
107 | cat = {'supercategory': 'none', 'id': cid, 'name': cate}
108 | json_dict['categories'].append(cat)
109 | json_fp = open(json_file, 'w')
110 | json_str = json.dumps(json_dict)
111 | json_fp.write(json_str)
112 | json_fp.close()
113 | list_fp.close()
114 |
115 |
116 | if __name__ == '__main__':
117 |     if len(sys.argv) < 5:
118 |         print('4 arguments are needed.')
119 |         print('Usage: %s XML_LIST.txt XML_DIR OUTPUT_JSON.json IMG_DIR'%(sys.argv[0]))
120 |         exit(1)
121 |
122 | convert(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])
123 |
124 | # python voc2coco.py xmllist.txt ../Annotations output.json ../JPEGImages
--------------------------------------------------------------------------------
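
`voc2coco.py` expects its first argument to be a plain-text list of annotation basenames, one per line and without the `.xml` extension. A small helper sketch for producing such a list from an `Annotations` directory (the function name is hypothetical; the train/val paths match the convert commands used in this project):

```
import os

def write_xml_list(ann_dir, list_file):
    # Write the basenames (without extension) of all .xml files in ann_dir,
    # one per line, in the format voc2coco.py expects for XML_LIST.txt.
    names = sorted(os.path.splitext(f)[0] for f in os.listdir(ann_dir) if f.endswith('.xml'))
    with open(list_file, 'w') as fp:
        fp.write('\n'.join(names) + '\n')

write_xml_list('./train/Annotations', 'train.txt')
write_xml_list('./val/Annotations', 'val.txt')
```
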