├── tool
│   ├── util.py
│   ├── visualize.py
│   ├── evaluate.py
│   ├── postprocess.py
│   ├── distributed.py
│   └── boxlist.py
├── model
│   ├── __init__.py
│   ├── head.py
│   ├── model.py
│   ├── bi.py
│   ├── loss.py
│   ├── efficientnet.py
│   └── utils.py
├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── efficientdet-anchor-free.iml
│   ├── deployment.xml
│   └── workspace.xml
├── README.md
├── config.py
├── data
│   ├── transform.py
│   └── dataset.py
└── train.py
/tool/util.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EfficientDet_anchor_free
2 | EfficientDet_anchor_free
3 |
4 |
5 |
6 | # Introduction
7 | This is an anchor-free EfficientDet implementation in PyTorch. I have also completed an anchor-based EfficientDet.
8 |
9 |
10 | ## Results
11 | | | This repo (anchor-free) | This repo (anchor-based) | Paper |
12 | | :----- | :----- | :------ | :------ |
13 | | network | EfficientNet-B0 | EfficientNet-B0 | EfficientNet-B0 |
14 | | dataset | COCO 2017 | VOC 07+12 | COCO 2017 |
15 | | notes | multi-scale | mixup, label smoothing, GIoU loss, cosine LR | multi-scale |
16 | | mAP | 32.9 | 68.5 | 32.4 |
17 | | b1-b7 | TODO | TODO | -- |
18 |
19 |
20 | The anchor-based EfficientDet still has some problems (low mAP and slow speed); I will fix them and share the code.
21 |
22 |
23 | ## Reference
24 | * [FCOS](https://github.com/tianzhi0549/FCOS)
25 | * [EfficientDet](https://arxiv.org/pdf/1911.09070.pdf)
26 |
27 |
28 |
29 |
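30 | ## Quick start
31 | A minimal smoke-test sketch (not a script from this repo): it builds the detector from the default `Config` and runs one inference pass on a dummy batch. It assumes the full backbone code (`model/efficientnet.py`, `model/utils.py`) is present and that the pretrained EfficientNet-B0 weights can be downloaded by `EfficientNet.from_pretrained`.
32 | ```python
33 | import torch
34 | from config import Config
35 | from model.model import EfficientDet_Free
36 |
37 | config = Config()
38 | model = EfficientDet_Free(config).eval()
39 |
40 | images = torch.randn(1, 3, 512, 512)   # padded batch tensor (N, C, H, W)
41 | image_sizes = [(512, 512)]              # original (height, width) of each image
42 | with torch.no_grad():
43 |     boxes, _ = model(images, image_sizes)  # one BoxList per image
44 | print(len(boxes[0]), 'detections above the score threshold')
45 | ```
46 | Training is driven by `train.py`, which takes all of its settings (dataset path, batch size, learning rate, schedule) from `config.py`.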
--------------------------------------------------------------------------------
/tool/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 |
3 | matplotlib.use('Agg')
4 | from matplotlib import pyplot as plt
5 | from matplotlib import patheffects, patches
6 |
7 |
8 | def show_img(img, figsize=None, fig=None, ax=None):
9 | if not ax:
10 | fig, ax = plt.subplots(figsize=figsize)
11 |
12 | ax.imshow(img)
13 | ax.get_xaxis().set_visible(False)
14 | ax.get_yaxis().set_visible(False)
15 |
16 | return fig, ax
17 |
18 |
19 | def draw_outline(obj, line_width):
20 | obj.set_path_effects(
21 | [
22 | patheffects.Stroke(linewidth=line_width, foreground='black'),
23 | patheffects.Normal(),
24 | ]
25 | )
26 |
27 |
28 | def draw_rect(ax, box):
29 | patch = ax.add_patch(
30 | patches.Rectangle(box[:2], *box[-2:], fill=False, edgecolor='white', lw=2)
31 | )
32 | draw_outline(patch, 4)
33 |
34 |
35 | def draw_text(ax, xy, txt, sz=14):
36 | text = ax.text(
37 | *xy, txt, verticalalignment='top', color='white', fontsize=sz, weight='bold'
38 | )
39 | draw_outline(text, 1)
40 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | class Config:
2 |
3 | def __init__(self):
4 | self.network = 'efficientnet-b0'
5 | self.local_rank = 0
6 | self.lr = 0.01
7 | self.l2 = 0.0001
8 | self.batch = 16
9 | self.epoch = 50
10 | self.n_save_sample = 5
11 |         self.out_channels = 256
12 | self.n_class = 81
13 | self.prior = 0.01
14 | self.threshold = 0.05
15 | self.top_n = 1000
16 | self.nms_threshold = 0.6
17 | self.post_top_n = 100
18 | self.fpn_strides = [8, 16, 32, 64, 128]
19 | self.gamma = 2.0
20 | self.alpha = 0.25
21 | self.iou_loss_type = 'giou'
22 | self.pos_radius = 1.5
23 | self.center_sample = True
24 | self.sizes = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 100000000]]
25 | self.train_min_size_range = (512,640)
26 | self.train_max_size = 800
27 | self.test_min_size = 512
28 | self.test_max_size = 800
29 | self.pixel_mean = [0.40789654, 0.44719302, 0.47026115]
30 | self.pixel_std = [0.28863828, 0.27408164, 0.27809835]
31 | self.size_divisible = 32
32 |         self.min_size = 0
33 |         self.n_conv = 4          # conv layers in each FCOS head tower (read as config.n_conv in model.py)
34 |         self.path = 'data/coco'  # dataset root read by train.py; point this at your COCO directory
--------------------------------------------------------------------------------
/model/head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | class Scale(nn.Module):
6 | def __init__(self, init=1.0):
7 | super().__init__()
8 |
9 | self.scale = nn.Parameter(torch.tensor([init], dtype=torch.float32))
10 |
11 | def forward(self, input):
12 | return input * self.scale
13 |
14 | class FCOSHead(nn.Module):
15 | def __init__(self, in_channel, n_class, n_conv, prior):
16 | super().__init__()
17 |
18 | n_class = n_class - 1
19 |
20 | cls_tower = []
21 | bbox_tower = []
22 |
23 | for i in range(n_conv):
24 | cls_tower.append(
25 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False)
26 | )
27 | cls_tower.append(nn.BatchNorm2d(in_channel))
28 | # cls_tower.append(nn.GroupNorm(32, in_channel))
29 | cls_tower.append(nn.ReLU())
30 |
31 | bbox_tower.append(
32 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False)
33 | )
34 |             bbox_tower.append(nn.BatchNorm2d(in_channel))
35 | # bbox_tower.append(nn.GroupNorm(32, in_channel))
36 | bbox_tower.append(nn.ReLU())
37 |
38 | self.cls_tower = nn.Sequential(*cls_tower)
39 | self.bbox_tower = nn.Sequential(*bbox_tower)
40 |
41 | self.cls_pred = nn.Conv2d(in_channel, n_class, 3, padding=1)
42 | self.bbox_pred = nn.Conv2d(in_channel, 4, 3, padding=1)
43 | self.center_pred = nn.Conv2d(in_channel, 1, 3, padding=1)
44 |
45 | self.apply(self.init_conv_std)
46 |
47 | prior_bias = -math.log((1 - prior) / prior)
48 | nn.init.constant_(self.cls_pred.bias, prior_bias)
49 |
50 | self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])
51 |
52 | def init_conv_std(self,module, std=0.01):
53 | if isinstance(module, nn.Conv2d):
54 | nn.init.normal_(module.weight, std=std)
55 |
56 | if module.bias is not None:
57 | nn.init.constant_(module.bias, 0)
58 |
59 | def forward(self, input):
60 | logits = []
61 | bboxes = []
62 | centers = []
63 |
64 | for feat, scale in zip(input, self.scales):
65 | cls_out = self.cls_tower(feat)
66 |
67 | logits.append(self.cls_pred(cls_out))
68 | centers.append(self.center_pred(cls_out))
69 |
70 | bbox_out = self.bbox_tower(feat)
71 | bbox_out = torch.exp(scale(self.bbox_pred(bbox_out)))
72 |
73 | bboxes.append(bbox_out)
74 |
75 | return logits, bboxes, centers
76 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .efficientnet import EfficientNet
4 | from .bi import BiFpn
5 | from .head import FCOSHead
6 | from tool import postprocess
7 | from .loss import FCOSLoss
8 | class EfficientDet_Free(nn.Module):
9 |
10 | def __init__(self,config):
11 | super(EfficientDet_Free,self).__init__()
12 | self.config = config
13 | self.backbone = EfficientNet.from_pretrained(config.network)
14 | self.fpn = BiFpn(in_channels=self.backbone.get_list_features()[-5:],out_channels=config.out_channels,len_input=5,bi=3)
15 | self.head = FCOSHead(config.out_channels, config.n_class, config.n_conv, config.prior)
16 | self.postprocessor = postprocess.FCOSPostprocessor(
17 | config.threshold,
18 | config.top_n,
19 | config.nms_threshold,
20 | config.post_top_n,
21 | config.min_size,
22 | config.n_class,
23 | )
24 | self.loss = FCOSLoss(
25 | config.sizes,
26 | config.gamma,
27 | config.alpha,
28 | config.iou_loss_type,
29 | config.center_sample,
30 | config.fpn_strides,
31 | config.pos_radius,
32 | )
33 | self.fpn_strides = config.fpn_strides
34 |
35 | def compute_location(self, features):
36 | locations = []
37 |
38 | for i, feat in enumerate(features):
39 | _, _, height, width = feat.shape
40 | location_per_level = self.compute_location_per_level(
41 | height, width, self.fpn_strides[i], feat.device
42 | )
43 | locations.append(location_per_level)
44 |
45 | return locations
46 |
47 | def compute_location_per_level(self, height, width, stride, device):
48 | shift_x = torch.arange(
49 | 0, width * stride, step=stride, dtype=torch.float32, device=device
50 | )
51 | shift_y = torch.arange(
52 | 0, height * stride, step=stride, dtype=torch.float32, device=device
53 | )
54 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x)
55 | shift_x = shift_x.reshape(-1)
56 | shift_y = shift_y.reshape(-1)
57 | location = torch.stack((shift_x, shift_y), 1) + stride // 2
58 |
59 | return location
60 |
61 | def forward(self, input, image_sizes=None, targets=None):
62 | features = self.backbone(input)[-5:]
63 | features = self.fpn(features)
64 | cls_pred, box_pred, center_pred = self.head(features)
65 | location = self.compute_location(features)
66 |
67 | if self.training:
68 | loss_cls, loss_box, loss_center = self.loss(
69 | location, cls_pred, box_pred, center_pred, targets
70 | )
71 | losses = {
72 | 'loss_cls': loss_cls,
73 | 'loss_box': loss_box,
74 | 'loss_center': loss_center,
75 | }
76 |
77 | return None, losses
78 |
79 | else:
80 | boxes = self.postprocessor(
81 | location, cls_pred, box_pred, center_pred, image_sizes
82 | )
83 |
84 | return boxes, None
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/data/transform.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import torchvision
5 | from torchvision.transforms import functional as F
6 |
7 |
8 | class Compose:
9 | def __init__(self, transforms):
10 | self.transforms = transforms
11 |
12 | def __call__(self, img, target):
13 | for t in self.transforms:
14 | img, target = t(img, target)
15 |
16 | return img, target
17 |
18 | def __repr__(self):
19 | format_str = self.__class__.__name__ + '('
20 | for t in self.transforms:
21 | format_str += '\n'
22 |             format_str += f'    {t}'
23 | format_str += '\n)'
24 |
25 | return format_str
26 |
27 |
28 | class Resize:
29 | def __init__(self, min_size, max_size):
30 | if not isinstance(min_size, (list, tuple)):
31 | min_size = (min_size,)
32 |
33 | self.min_size = min_size
34 | self.max_size = max_size
35 |
36 | def get_size(self, img_size):
37 | w, h = img_size
38 | size = random.choice(self.min_size)
39 | max_size = self.max_size
40 |
41 | if max_size is not None:
42 | min_orig = float(min((w, h)))
43 | max_orig = float(max((w, h)))
44 |
45 | if max_orig / min_orig * size > max_size:
46 | size = int(round(max_size * min_orig / max_orig))
47 |
48 | if (w <= h and w == size) or (h <= w and h == size):
49 | return h, w
50 |
51 | if w < h:
52 | ow = size
53 | oh = int(size * h / w)
54 |
55 | else:
56 | oh = size
57 | ow = int(size * w / h)
58 |
59 | return oh, ow
60 |
61 | def __call__(self, img, target):
62 | size = self.get_size(img.size)
63 | img = F.resize(img, size)
64 | target = target.resize(img.size)
65 |
66 | return img, target
67 |
68 |
69 | class RandomHorizontalFlip:
70 | def __init__(self, p=0.5):
71 | self.p = p
72 |
73 | def __call__(self, img, target):
74 | if random.random() < self.p:
75 | img = F.hflip(img)
76 | target = target.transpose(0)
77 |
78 | return img, target
79 |
80 |
81 | class ToTensor:
82 | def __call__(self, img, target):
83 | return F.to_tensor(img), target
84 |
85 |
86 | class Normalize:
87 | def __init__(self, mean, std):
88 | self.mean = mean
89 | self.std = std
90 |
91 | def __call__(self, img, target):
92 | img = F.normalize(img, mean=self.mean, std=self.std)
93 |
94 | return img, target
95 |
96 |
97 | def preset_transform(config, train=True):
98 | if train:
99 | if config.train_min_size_range[0] == -1:
100 | min_size = config.train_min_size
101 |
102 | else:
103 | min_size = list(
104 | range(
105 | config.train_min_size_range[0], config.train_min_size_range[1] + 1
106 | )
107 | )
108 |
109 | max_size = config.train_max_size
110 | flip = 0.5
111 |
112 | else:
113 | min_size = config.test_min_size
114 | max_size = config.test_max_size
115 | flip = 0
116 |
117 | normalize = Normalize(mean=config.pixel_mean, std=config.pixel_std)
118 |
119 | transform = Compose(
120 | [Resize(min_size, max_size), RandomHorizontalFlip(flip), ToTensor(), normalize]
121 | )
122 |
123 | return transform
124 |
--------------------------------------------------------------------------------
/tool/evaluate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import tempfile
3 | from collections import OrderedDict
4 |
5 | import numpy as np
6 | from pycocotools.coco import COCO
7 | from pycocotools.cocoeval import COCOeval
8 |
9 |
10 | def evaluate(dataset, predictions):
11 | coco_results = {}
12 | coco_results['bbox'] = make_coco_detection(predictions, dataset)
13 |
14 | results = COCOResult('bbox')
15 |
16 | with tempfile.NamedTemporaryFile() as f:
17 | path = f.name
18 | res = evaluate_predictions_on_coco(
19 | dataset.coco, coco_results['bbox'], path, 'bbox'
20 | )
21 | results.update(res)
22 |
23 | print(results)
24 |
25 | return res
26 |
27 |
28 | def evaluate_predictions_on_coco(coco_gt, results, result_file, iou_type):
29 | with open(result_file, 'w') as f:
30 | json.dump(results, f)
31 |
32 | coco_dt = coco_gt.loadRes(str(result_file)) if results else COCO()
33 |
34 | coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
35 | coco_eval.evaluate()
36 | coco_eval.accumulate()
37 | coco_eval.summarize()
38 |
39 | # compute_thresholds_for_classes(coco_eval)
40 |
41 | return coco_eval
42 |
43 |
44 | def compute_thresholds_for_classes(coco_eval):
45 | precision = coco_eval.eval['precision']
46 | precision = precision[0, :, :, 0, -1]
47 | scores = coco_eval.eval['scores']
48 | scores = scores[0, :, :, 0, -1]
49 |
50 | recall = np.linspace(0, 1, num=precision.shape[0])
51 | recall = recall[:, None]
52 |
53 | f1 = (2 * precision * recall) / (np.maximum(precision + recall, 1e-6))
54 | max_f1 = f1.max(0)
55 | max_f1_id = f1.argmax(0)
56 | scores = scores[max_f1_id, range(len(max_f1_id))]
57 |
58 | print('Maximum f1 for classes:')
59 | print(list(max_f1))
60 | print('Score thresholds for classes')
61 | print(list(scores))
62 |
63 |
64 | def make_coco_detection(predictions, dataset):
65 | coco_results = []
66 |
67 | for id, pred in enumerate(predictions):
68 | orig_id = dataset.id2img[id]
69 |
70 | if len(pred) == 0:
71 | continue
72 |
73 | img_meta = dataset.get_image_meta(id)
74 | width = img_meta['width']
75 | height = img_meta['height']
76 | pred = pred.resize((width, height))
77 | pred = pred.convert('xywh')
78 |
79 | boxes = pred.box.tolist()
80 | scores = pred.fields['scores'].tolist()
81 | labels = pred.fields['labels'].tolist()
82 |
83 | labels = [dataset.id2category[i] for i in labels]
84 |
85 | coco_results.extend(
86 | [
87 | {
88 | 'image_id': orig_id,
89 | 'category_id': labels[k],
90 | 'bbox': box,
91 | 'score': scores[k],
92 | }
93 | for k, box in enumerate(boxes)
94 | ]
95 | )
96 |
97 | return coco_results
98 |
99 |
100 | class COCOResult:
101 | METRICS = {
102 | 'bbox': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'],
103 | 'segm': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'],
104 | 'box_proposal': [
105 | 'AR@100',
106 | 'ARs@100',
107 | 'ARm@100',
108 | 'ARl@100',
109 | 'AR@1000',
110 | 'ARs@1000',
111 | 'ARm@1000',
112 | 'ARl@1000',
113 | ],
114 | 'keypoints': ['AP', 'AP50', 'AP75', 'APm', 'APl'],
115 | }
116 |
117 | def __init__(self, *iou_types):
118 | allowed_types = ("box_proposal", "bbox", "segm", "keypoints")
119 | assert all(iou_type in allowed_types for iou_type in iou_types)
120 | results = OrderedDict()
121 | for iou_type in iou_types:
122 | results[iou_type] = OrderedDict(
123 | [(metric, -1) for metric in COCOResult.METRICS[iou_type]]
124 | )
125 | self.results = results
126 |
127 | def update(self, coco_eval):
128 | if coco_eval is None:
129 | return
130 |
131 | assert isinstance(coco_eval, COCOeval)
132 | s = coco_eval.stats
133 | iou_type = coco_eval.params.iouType
134 | res = self.results[iou_type]
135 | metrics = COCOResult.METRICS[iou_type]
136 | for idx, metric in enumerate(metrics):
137 | res[metric] = s[idx]
138 |
139 | def __repr__(self):
140 | return repr(self.results)
141 |
--------------------------------------------------------------------------------
/model/bi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ConvBlock(nn.Module):
7 | def __init__(self,in_channels,out_channels,kernel = 3,strides = 1,padding=1,
8 | bias = True,act = True,bn = True):
9 | super(ConvBlock,self).__init__()
10 | self.conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,
11 | kernel_size=kernel,stride=strides,padding=padding,
12 | bias=bias)
13 | self.act = nn.ReLU(True) if act else None
14 |         self.bn = nn.BatchNorm2d(out_channels) if bn else None  # BN follows the conv, so it normalizes the conv output
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | if self.bn:
19 | x = self.bn(x)
20 | if self.act:
21 | x = self.act(x)
22 | return x
23 |
24 |
25 | class BiBlock(nn.Module):
26 |
27 | def __init__(self,feature_size):
28 | super(BiBlock,self).__init__()
29 |
30 |
31 | self.maxpooling = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
32 | self.p7_to_p6 = ConvBlock(feature_size,feature_size)
33 | self.p6_to_p5 = ConvBlock(feature_size,feature_size)
34 | self.p5_to_p4 = ConvBlock(feature_size,feature_size)
35 |
36 | self.p3 = ConvBlock(feature_size,feature_size)
37 | self.p4 = ConvBlock(feature_size,feature_size)
38 | self.p5 = ConvBlock(feature_size,feature_size)
39 | self.p6 = ConvBlock(feature_size,feature_size)
40 | self.p7 = ConvBlock(feature_size,feature_size)
41 |
42 |
43 |
44 | def forward(self, pyramids):
45 | p3,p4,p5,p6,p7 = pyramids
46 |
47 |         p7_to_p6 = F.interpolate(p7, size=p6.shape[-2:])
48 | # p7_to_p6 = F.upsample(p7,scale_factor=2)
49 | p7_to_p6 = self.p7_to_p6(p7_to_p6 + p6)
50 |
51 |         p6_to_p5 = F.interpolate(p7_to_p6, size=p5.shape[-2:])
52 | # p6_to_p5 = F.upsample(p7_to_p6,scale_factor=2)
53 | p6_to_p5 = self.p6_to_p5(p6_to_p5 + p5)
54 |
55 |         p5_to_p4 = F.interpolate(p6_to_p5, size=p4.shape[-2:])
56 | # p5_to_p4 = F.upsample(p6_to_p5,scale_factor=2)
57 | p5_to_p4 = self.p5_to_p4(p5_to_p4 + p4)
58 |
59 |         p4_to_p3 = F.interpolate(p5_to_p4, size=p3.shape[-2:])
60 | # p4_to_p3 = F.upsample(p5_to_p4,scale_factor=2)
61 | p3 = self.p3(p4_to_p3 + p3)
62 |
63 | p3_to_p4 = self.maxpooling(p3)
64 | p4 = self.p4(p3_to_p4 + p5_to_p4 + p4)
65 |
66 | p4_to_p5 = self.maxpooling(p4)
67 | p5 = self.p5(p4_to_p5 + p6_to_p5 + p5)
68 |
69 | p5_to_p6 = self.maxpooling(p5)
70 |         p5_to_p6 = F.interpolate(p5_to_p6, size=p6.shape[-2:])
71 | p6 = self.p6(p5_to_p6 + p7_to_p6 + p6)
72 |
73 | p6_to_p7 = self.maxpooling(p6)
74 |         p6_to_p7 = F.interpolate(p6_to_p7, size=p7.shape[-2:])
75 | p7 = self.p7(p6_to_p7 + p7)
76 |
77 | return p3,p4,p5,p6,p7
78 |
79 |
80 | class BiFpn(nn.Module):
81 |
82 | def __init__(self,in_channels,out_channels,len_input,bi = 3):
83 | super(BiFpn,self).__init__()
84 | assert len_input <= 5
85 | self.len_input = len_input
86 | self.bi = bi
87 | self.default = 5 - len_input
88 | for i in range(len_input):
89 | setattr(self, 'p{}'.format(str(i)), ConvBlock(in_channels=in_channels[i], out_channels=out_channels,
90 | kernel=1, strides=1, padding=0, act=False, bn=False))
91 | if self.default > 0:
92 | for i in range(self.default):
93 | setattr(self,'make_pyramid{}'.format(str(i)),ConvBlock(in_channels=in_channels[-1] if i == 0 else out_channels,out_channels=out_channels,kernel=3,strides=2,
94 | padding=1,act=False,bn=False))
95 | for i in range(bi):
96 | setattr(self, 'biblock{}'.format(str(i)), BiBlock(out_channels))
97 |
98 | def forward(self, inputs):
99 | pyramids = []
100 | for i in range(self.len_input):
101 | pyramids.append(getattr(self,'p{}'.format(str(i)))(inputs[i]))
102 |
103 | if self.default > 0:
104 | x = inputs[-1]
105 | for i in range(self.default):
106 | x = getattr(self,'make_pyramid{}'.format(str(i)))(x)
107 | pyramids.append(x)
108 |
109 | for i in range(self.bi):
110 | pyramids = getattr(self,'biblock{}'.format(str(i)))(pyramids)
111 |
112 | return pyramids
113 |
114 |
115 | if __name__ == '__main__':
116 | p3 = torch.randn(2,512,64,64)
117 | p4 = torch.randn(2,1024,32,32)
118 | p5 = torch.randn(2,2048,16,16)
119 |
120 | b = BiFpn([512,1024,2048],256,3,3)
121 | py = b([p3,p4,p5])
122 | for i in py:
123 | print(i.shape)
124 |
125 |
--------------------------------------------------------------------------------
/tool/postprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .boxlist import BoxList, boxlist_nms, remove_small_box, cat_boxlist
5 |
6 |
7 | class FCOSPostprocessor(nn.Module):
8 | def __init__(self, threshold, top_n, nms_threshold, post_top_n, min_size, n_class):
9 | super().__init__()
10 |
11 | self.threshold = threshold
12 | self.top_n = top_n
13 | self.nms_threshold = nms_threshold
14 | self.post_top_n = post_top_n
15 | self.min_size = min_size
16 | self.n_class = n_class
17 |
18 | def forward_single_feature_map(
19 | self, location, cls_pred, box_pred, center_pred, image_sizes
20 | ):
21 | batch, channel, height, width = cls_pred.shape
22 |
23 | cls_pred = cls_pred.view(batch, channel, height, width).permute(0, 2, 3, 1)
24 | cls_pred = cls_pred.reshape(batch, -1, channel).sigmoid()
25 |
26 | box_pred = box_pred.view(batch, 4, height, width).permute(0, 2, 3, 1)
27 | box_pred = box_pred.reshape(batch, -1, 4)
28 |
29 | center_pred = center_pred.view(batch, 1, height, width).permute(0, 2, 3, 1)
30 | center_pred = center_pred.reshape(batch, -1).sigmoid()
31 |
32 | candid_ids = cls_pred > self.threshold
33 | top_ns = candid_ids.view(batch, -1).sum(1)
34 | top_ns = top_ns.clamp(max=self.top_n)
35 |
36 | cls_pred = cls_pred * center_pred[:, :, None]
37 |
38 | results = []
39 |
40 | for i in range(batch):
41 | cls_p = cls_pred[i]
42 | candid_id = candid_ids[i]
43 | cls_p = cls_p[candid_id]
44 | candid_nonzero = candid_id.nonzero()
45 | box_loc = candid_nonzero[:, 0]
46 | class_id = candid_nonzero[:, 1] + 1
47 |
48 | box_p = box_pred[i]
49 | box_p = box_p[box_loc]
50 | loc = location[box_loc]
51 |
52 | top_n = top_ns[i]
53 |
54 | if candid_id.sum().item() > top_n.item():
55 |                 cls_p, top_k_id = cls_p.topk(int(top_n), sorted=False)
56 | class_id = class_id[top_k_id]
57 | box_p = box_p[top_k_id]
58 | loc = loc[top_k_id]
59 |
60 | detections = torch.stack(
61 | [
62 | loc[:, 0] - box_p[:, 0],
63 | loc[:, 1] - box_p[:, 1],
64 | loc[:, 0] + box_p[:, 2],
65 | loc[:, 1] + box_p[:, 3],
66 | ],
67 | 1,
68 | )
69 |
70 | height, width = image_sizes[i]
71 |
72 | boxlist = BoxList(detections, (int(width), int(height)), mode='xyxy')
73 | boxlist.fields['labels'] = class_id
74 | boxlist.fields['scores'] = torch.sqrt(cls_p)
75 | boxlist = boxlist.clip(remove_empty=False)
76 | boxlist = remove_small_box(boxlist, self.min_size)
77 |
78 | results.append(boxlist)
79 |
80 | return results
81 |
82 | def forward(self, location, cls_pred, box_pred, center_pred, image_sizes):
83 | boxes = []
84 |
85 | for loc, cls_p, box_p, center_p in zip(
86 | location, cls_pred, box_pred, center_pred
87 | ):
88 | boxes.append(
89 | self.forward_single_feature_map(
90 | loc, cls_p, box_p, center_p, image_sizes
91 | )
92 | )
93 |
94 | boxlists = list(zip(*boxes))
95 | boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
96 | boxlists = self.select_over_scales(boxlists)
97 |
98 | return boxlists
99 |
100 | def select_over_scales(self, boxlists):
101 | results = []
102 |
103 | for boxlist in boxlists:
104 | scores = boxlist.fields['scores']
105 | labels = boxlist.fields['labels']
106 | box = boxlist.box
107 |
108 | result = []
109 |
110 | for j in range(1, self.n_class):
111 | id = (labels == j).nonzero().view(-1)
112 | score_j = scores[id]
113 | box_j = box[id, :].view(-1, 4)
114 | box_by_class = BoxList(box_j, boxlist.size, mode='xyxy')
115 | box_by_class.fields['scores'] = score_j
116 | box_by_class = boxlist_nms(box_by_class, score_j, self.nms_threshold)
117 | n_label = len(box_by_class)
118 | box_by_class.fields['labels'] = torch.full(
119 | (n_label,), j, dtype=torch.int64, device=scores.device
120 | )
121 | result.append(box_by_class)
122 |
123 | result = cat_boxlist(result)
124 | n_detection = len(result)
125 |
126 | if n_detection > self.post_top_n > 0:
127 | scores = result.fields['scores']
128 | img_threshold, _ = torch.kthvalue(
129 | scores.cpu(), n_detection - self.post_top_n + 1
130 | )
131 | keep = scores >= img_threshold.item()
132 | keep = torch.nonzero(keep).squeeze(1)
133 | result = result[keep]
134 |
135 | results.append(result)
136 |
137 | return results
138 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torchvision import datasets
5 |
6 | from tool.boxlist import BoxList
7 |
8 |
9 | CLASS_NAME = [
10 | '__background__',
11 | 'person',
12 | 'bicycle',
13 | 'car',
14 | 'motorcycle',
15 | 'airplane',
16 | 'bus',
17 | 'train',
18 | 'truck',
19 | 'boat',
20 | 'traffic light',
21 | 'fire hydrant',
22 | 'stop sign',
23 | 'parking meter',
24 | 'bench',
25 | 'bird',
26 | 'cat',
27 | 'dog',
28 | 'horse',
29 | 'sheep',
30 | 'cow',
31 | 'elephant',
32 | 'bear',
33 | 'zebra',
34 | 'giraffe',
35 | 'backpack',
36 | 'umbrella',
37 | 'handbag',
38 | 'tie',
39 | 'suitcase',
40 | 'frisbee',
41 | 'skis',
42 | 'snowboard',
43 | 'sports ball',
44 | 'kite',
45 | 'baseball bat',
46 | 'baseball glove',
47 | 'skateboard',
48 | 'surfboard',
49 | 'tennis racket',
50 | 'bottle',
51 | 'wine glass',
52 | 'cup',
53 | 'fork',
54 | 'knife',
55 | 'spoon',
56 | 'bowl',
57 | 'banana',
58 | 'apple',
59 | 'sandwich',
60 | 'orange',
61 | 'broccoli',
62 | 'carrot',
63 | 'hot dog',
64 | 'pizza',
65 | 'donut',
66 | 'cake',
67 | 'chair',
68 | 'couch',
69 | 'potted plant',
70 | 'bed',
71 | 'dining table',
72 | 'toilet',
73 | 'tv',
74 | 'laptop',
75 | 'mouse',
76 | 'remote',
77 | 'keyboard',
78 | 'cell phone',
79 | 'microwave',
80 | 'oven',
81 | 'toaster',
82 | 'sink',
83 | 'refrigerator',
84 | 'book',
85 | 'clock',
86 | 'vase',
87 | 'scissors',
88 | 'teddy bear',
89 | 'hair drier',
90 | 'toothbrush',
91 | ]
92 |
93 |
94 | def has_only_empty_bbox(annot):
95 | return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in annot)
96 |
97 |
98 | def has_valid_annotation(annot):
99 | if len(annot) == 0:
100 | return False
101 |
102 | if has_only_empty_bbox(annot):
103 | return False
104 |
105 | return True
106 |
107 |
108 | class COCODataset(datasets.CocoDetection):
109 | def __init__(self, path, split, transform=None):
110 | root = os.path.join(path, '{}2017'.format(split))
111 | annot = os.path.join(path, 'annotations', 'instances_{}2017.json'.format(split))
112 |
113 | super().__init__(root, annot)
114 |
115 | self.ids = sorted(self.ids)
116 |
117 | if split == 'train':
118 | ids = []
119 |
120 | for id in self.ids:
121 | ann_ids = self.coco.getAnnIds(imgIds=id, iscrowd=None)
122 | annot = self.coco.loadAnns(ann_ids)
123 |
124 | if has_valid_annotation(annot):
125 | ids.append(id)
126 |
127 | self.ids = ids
128 |
129 | self.category2id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}
130 | self.id2category = {v: k for k, v in self.category2id.items()}
131 | self.id2img = {k: v for k, v in enumerate(self.ids)}
132 |
133 | self.transform = transform
134 |
135 | def __getitem__(self, index):
136 | img, annot = super().__getitem__(index)
137 |
138 | annot = [o for o in annot if o['iscrowd'] == 0]
139 |
140 | boxes = [o['bbox'] for o in annot]
141 | boxes = torch.as_tensor(boxes).reshape(-1, 4)
142 | target = BoxList(boxes, img.size, mode='xywh').convert('xyxy')
143 |
144 | classes = [o['category_id'] for o in annot]
145 | classes = [self.category2id[c] for c in classes]
146 | classes = torch.tensor(classes)
147 | target.fields['labels'] = classes
148 |
149 |         target = target.clip(remove_empty=True)
150 |
151 | if self.transform is not None:
152 | img, target = self.transform(img, target)
153 |
154 | return img, target, index
155 |
156 | def get_image_meta(self, index):
157 | id = self.id2img[index]
158 | img_data = self.coco.imgs[id]
159 |
160 | return img_data
161 |
162 |
163 | class ImageList:
164 | def __init__(self, tensors, sizes):
165 | self.tensors = tensors
166 | self.sizes = sizes
167 |
168 | def to(self, *args, **kwargs):
169 | tensor = self.tensors.to(*args, **kwargs)
170 |
171 | return ImageList(tensor, self.sizes)
172 |
173 |
174 | def image_list(tensors, size_divisible=0):
175 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
176 |
177 | if size_divisible > 0:
178 | stride = size_divisible
179 | max_size = list(max_size)
180 | max_size[1] = (max_size[1] | (stride - 1)) + 1
181 | max_size[2] = (max_size[2] | (stride - 1)) + 1
182 | max_size = tuple(max_size)
183 |
184 | shape = (len(tensors),) + max_size
185 | batch = tensors[0].new(*shape).zero_()
186 |
187 | for img, pad_img in zip(tensors, batch):
188 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
189 |
190 | sizes = [img.shape[-2:] for img in tensors]
191 |
192 | return ImageList(batch, sizes)
193 |
194 |
195 | def collate_fn(config):
196 | def collate_data(batch):
197 | batch = list(zip(*batch))
198 | imgs = image_list(batch[0], config.size_divisible)
199 | targets = batch[1]
200 | ids = batch[2]
201 |
202 | return imgs, targets, ids
203 |
204 | return collate_data
205 |
--------------------------------------------------------------------------------
/tool/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def all_gather(data):
45 | world_size = get_world_size()
46 |
47 | if world_size == 1:
48 | return [data]
49 |
50 | buffer = pickle.dumps(data)
51 | storage = torch.ByteStorage.from_buffer(buffer)
52 | tensor = torch.ByteTensor(storage).to('cuda')
53 |
54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
55 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
56 | dist.all_gather(size_list, local_size)
57 | size_list = [int(size.item()) for size in size_list]
58 | max_size = max(size_list)
59 |
60 | tensor_list = []
61 | for _ in size_list:
62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
63 |
64 | if local_size != max_size:
65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
66 | tensor = torch.cat((tensor, padding), 0)
67 |
68 | dist.all_gather(tensor_list, tensor)
69 |
70 | data_list = []
71 |
72 | for size, tensor in zip(size_list, tensor_list):
73 | buffer = tensor.cpu().numpy().tobytes()[:size]
74 | data_list.append(pickle.loads(buffer))
75 |
76 | return data_list
77 |
78 |
79 | def reduce_loss_dict(loss_dict):
80 | world_size = get_world_size()
81 |
82 | if world_size < 2:
83 | return loss_dict
84 |
85 | with torch.no_grad():
86 | keys = []
87 | losses = []
88 |
89 | for k in sorted(loss_dict.keys()):
90 | keys.append(k)
91 | losses.append(loss_dict[k])
92 |
93 | losses = torch.stack(losses, 0)
94 | dist.reduce(losses, dst=0)
95 |
96 | if dist.get_rank() == 0:
97 | losses /= world_size
98 |
99 | reduced_losses = {k: v for k, v in zip(keys, losses)}
100 |
101 | return reduced_losses
102 |
103 |
104 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
105 | # Code is copy-pasted exactly as in torch.utils.data.distributed.
106 | # FIXME remove this once c10d fixes the bug it has
107 |
108 |
109 | class DistributedSampler(Sampler):
110 | """Sampler that restricts data loading to a subset of the dataset.
111 | It is especially useful in conjunction with
112 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
113 | process can pass a DistributedSampler instance as a DataLoader sampler,
114 | and load a subset of the original dataset that is exclusive to it.
115 | .. note::
116 | Dataset is assumed to be of constant size.
117 | Arguments:
118 | dataset: Dataset used for sampling.
119 | num_replicas (optional): Number of processes participating in
120 | distributed training.
121 | rank (optional): Rank of the current process within num_replicas.
122 | """
123 |
124 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
125 | if num_replicas is None:
126 | if not dist.is_available():
127 | raise RuntimeError("Requires distributed package to be available")
128 | num_replicas = dist.get_world_size()
129 | if rank is None:
130 | if not dist.is_available():
131 | raise RuntimeError("Requires distributed package to be available")
132 | rank = dist.get_rank()
133 | self.dataset = dataset
134 | self.num_replicas = num_replicas
135 | self.rank = rank
136 | self.epoch = 0
137 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
138 | self.total_size = self.num_samples * self.num_replicas
139 | self.shuffle = shuffle
140 |
141 | def __iter__(self):
142 | if self.shuffle:
143 | # deterministically shuffle based on epoch
144 | g = torch.Generator()
145 | g.manual_seed(self.epoch)
146 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
147 | else:
148 | indices = torch.arange(len(self.dataset)).tolist()
149 |
150 | # add extra samples to make it evenly divisible
151 | indices += indices[: (self.total_size - len(indices))]
152 | assert len(indices) == self.total_size
153 |
154 | # subsample
155 | offset = self.num_samples * self.rank
156 | indices = indices[offset : offset + self.num_samples]
157 | assert len(indices) == self.num_samples
158 |
159 | return iter(indices)
160 |
161 | def __len__(self):
162 | return self.num_samples
163 |
164 | def set_epoch(self, epoch):
165 | self.epoch = epoch
166 |
--------------------------------------------------------------------------------
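The `DistributedSampler` docstring above describes pairing the sampler with a `DataLoader`; the sketch below (illustration only, not repo code) shows that pairing, including the per-epoch `set_epoch` call that re-seeds the deterministic shuffle so every process draws an epoch-dependent permutation. Note that `train.py` builds this sampler via `data_sampler` but does not currently call `set_epoch` between epochs.

```python
# Usage sketch for tool/distributed.py's DistributedSampler (illustration only).
# num_replicas and rank are fixed so the snippet runs without initializing torch.distributed.
import torch
from torch.utils.data import DataLoader, TensorDataset

from tool.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # epoch-dependent shuffle, identical across processes
    for (batch,) in loader:
        pass  # this rank sees only its own slice of the padded index list
```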
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn, optim
4 | from torch.utils.data import DataLoader, sampler
5 | from tqdm import tqdm
6 | from data.dataset import COCODataset, collate_fn
7 |
8 | from model.model import EfficientDet_Free
9 | from data.transform import preset_transform
10 | from tool.distributed import (
11 | get_rank,
12 | synchronize,
13 | reduce_loss_dict,
14 | DistributedSampler,
15 | all_gather
16 | )
17 | from tool.evaluate import evaluate
18 | from config import Config
19 |
20 | def accumulate_predictions(predictions):
21 | all_predictions = all_gather(predictions)
22 |
23 | if get_rank() != 0:
24 | return
25 |
26 | predictions = {}
27 |
28 | for p in all_predictions:
29 | predictions.update(p)
30 |
31 | ids = list(sorted(predictions.keys()))
32 |
33 | if len(ids) != ids[-1] + 1:
34 |         print('Evaluation results are not contiguous')
35 |
36 | predictions = [predictions[i] for i in ids]
37 |
38 | return predictions
39 |
40 | @torch.no_grad()
41 | def valid(args, epoch, loader, dataset, model, device):
42 | if args.distributed:
43 | model = model.module
44 |
45 | torch.cuda.empty_cache()
46 |
47 | model.eval()
48 |
49 | pbar = tqdm(loader, dynamic_ncols=True)
50 |
51 | preds = {}
52 |
53 | for images, targets, ids in pbar:
54 | model.zero_grad()
55 |
56 | images = images.to(device)
57 | targets = [target.to(device) for target in targets]
58 |
59 | pred, _ = model(images.tensors, images.sizes)
60 |
61 | pred = [p.to('cpu') for p in pred]
62 |
63 | preds.update({id: p for id, p in zip(ids, pred)})
64 |
65 | preds = accumulate_predictions(preds)
66 |
67 | if get_rank() != 0:
68 | return
69 |
70 | evaluate(dataset, preds)
71 |
72 | def train(args, epoch, loader, model, optimizer, device):
73 | model.train()
74 | if get_rank() == 0:
75 | pbar = tqdm(loader, dynamic_ncols=True)
76 |
77 | else:
78 | pbar = loader
79 | for images, targets, _ in pbar:
80 | model.zero_grad()
81 |
82 | images = images.to(device)
83 | targets = [target.to(device) for target in targets]
84 |
85 | _, loss_dict = model(images.tensors, targets=targets)
86 | loss_cls = loss_dict['loss_cls'].mean()
87 | loss_box = loss_dict['loss_box'].mean()
88 | loss_center = loss_dict['loss_center'].mean()
89 |
90 | loss = loss_cls + loss_box + loss_center
91 | loss.backward()
92 | nn.utils.clip_grad_norm_(model.parameters(), 10)
93 | optimizer.step()
94 |
95 | loss_reduced = reduce_loss_dict(loss_dict)
96 | loss_cls = loss_reduced['loss_cls'].mean().item()
97 | loss_box = loss_reduced['loss_box'].mean().item()
98 | loss_center = loss_reduced['loss_center'].mean().item()
99 |
100 | if get_rank() == 0:
101 | pbar.set_description(
102 | (
103 | 'epoch: {}; cls: {}; box: {}; center: {}'.format(epoch + 1,loss_cls,loss_box,loss_center)
104 | )
105 | )
106 |
107 | def data_sampler(dataset, shuffle, distributed):
108 | if distributed:
109 | return DistributedSampler(dataset, shuffle=shuffle)
110 |
111 | if shuffle:
112 | return sampler.RandomSampler(dataset)
113 |
114 | else:
115 | return sampler.SequentialSampler(dataset)
116 |
117 | if __name__ == '__main__':
118 | args = Config()
119 | n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
120 | args.distributed = n_gpu > 1
121 |
122 | if args.distributed:
123 | torch.cuda.set_device(args.local_rank)
124 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
125 | synchronize()
126 |
127 |     device = 'cuda'  # resolves to the GPU selected by torch.cuda.set_device in distributed runs
128 | train_set = COCODataset(args.path, 'train', preset_transform(args, train=True))
129 | valid_set = COCODataset(args.path, 'val', preset_transform(args, train=False))
130 |
131 |
132 | model = EfficientDet_Free(args)
133 |     # model.load_state_dict(torch.load('checkpoint/epoch-N.pt')['model'])  # optionally resume from a checkpoint
134 | model = model.to(device)
135 | optimizer = optim.SGD(
136 | model.parameters(),
137 | lr=args.lr,
138 | momentum=0.9,
139 | weight_decay=args.l2,
140 | nesterov=True,
141 | )
142 | scheduler = optim.lr_scheduler.MultiStepLR(
143 | optimizer, milestones=[16, 22], gamma=0.1
144 | )
145 | if args.distributed:
146 | model = nn.parallel.DistributedDataParallel(
147 | model,
148 | device_ids=[args.local_rank],
149 | output_device=args.local_rank,
150 | broadcast_buffers=False,
151 | )
152 |
153 | train_loader = DataLoader(
154 | train_set,
155 | batch_size=args.batch,
156 | sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed),
157 | num_workers=2,
158 | collate_fn=collate_fn(args),
159 | )
160 | valid_loader = DataLoader(
161 | valid_set,
162 | batch_size=args.batch,
163 | sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed),
164 | num_workers=2,
165 | collate_fn=collate_fn(args),
166 | )
167 | for epoch in range(args.epoch):
168 | train(args, epoch, train_loader, model, optimizer, device)
169 | valid(args, epoch, valid_loader, valid_set, model, device)
170 |
171 | scheduler.step()
172 |
173 | if get_rank() == 0:
174 |         os.makedirs('checkpoint', exist_ok=True)
175 |         torch.save(
176 |             {
177 |                 'model': (model.module if args.distributed else model).state_dict(),
178 |                 'optim': optimizer.state_dict(),
179 |             },
180 |             f'checkpoint/epoch-{epoch + 1}.pt',
181 |         )
--------------------------------------------------------------------------------
/tool/boxlist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import ops
3 |
4 |
5 | FLIP_LEFT_RIGHT = 0
6 | FLIP_TOP_BOTTOM = 1
7 |
8 |
9 | class BoxList:
10 | def __init__(self, box, image_size, mode='xyxy'):
11 | device = box.device if hasattr(box, 'device') else 'cpu'
12 | box = torch.as_tensor(box, dtype=torch.float32, device=device)
13 |
14 | self.box = box
15 | self.size = image_size
16 | self.mode = mode
17 |
18 | self.fields = {}
19 |
20 | def convert(self, mode):
21 | if mode == self.mode:
22 | return self
23 |
24 | x_min, y_min, x_max, y_max = self.split_to_xyxy()
25 |
26 | if mode == 'xyxy':
27 | box = torch.cat([x_min, y_min, x_max, y_max], -1)
28 | box = BoxList(box, self.size, mode=mode)
29 |
30 | elif mode == 'xywh':
31 | remove = 1
32 | box = torch.cat(
33 | [x_min, y_min, x_max - x_min + remove, y_max - y_min + remove], -1
34 | )
35 | box = BoxList(box, self.size, mode=mode)
36 |
37 | box.copy_field(self)
38 |
39 | return box
40 |
41 | def copy_field(self, box):
42 | for k, v in box.fields.items():
43 | self.fields[k] = v
44 |
45 | def area(self):
46 | box = self.box
47 |
48 | if self.mode == 'xyxy':
49 | remove = 1
50 |
51 | area = (box[:, 2] - box[:, 0] + remove) * (box[:, 3] - box[:, 1] + remove)
52 |
53 | elif self.mode == 'xywh':
54 | area = box[:, 2] * box[:, 3]
55 |
56 | return area
57 |
58 | def split_to_xyxy(self):
59 | if self.mode == 'xyxy':
60 | x_min, y_min, x_max, y_max = self.box.split(1, dim=-1)
61 |
62 | return x_min, y_min, x_max, y_max
63 |
64 | elif self.mode == 'xywh':
65 | remove = 1
66 | x_min, y_min, w, h = self.box.split(1, dim=-1)
67 |
68 | return (
69 | x_min,
70 | y_min,
71 | x_min + (w - remove).clamp(min=0),
72 | y_min + (h - remove).clamp(min=0),
73 | )
74 |
75 | def __len__(self):
76 | return self.box.shape[0]
77 |
78 | def __getitem__(self, index):
79 | box = BoxList(self.box[index], self.size, self.mode)
80 |
81 | for k, v in self.fields.items():
82 | box.fields[k] = v[index]
83 |
84 | return box
85 |
86 | def resize(self, size, *args, **kwargs):
87 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
88 |
89 | if ratios[0] == ratios[1]:
90 | ratio = ratios[0]
91 | scaled = self.box * ratio
92 | box = BoxList(scaled, size, mode=self.mode)
93 |
94 | for k, v in self.fields.items():
95 | if not isinstance(v, torch.Tensor):
96 | v = v.resize(size, *args, **kwargs)
97 |
98 | box.fields[k] = v
99 |
100 | return box
101 |
102 | ratio_w, ratio_h = ratios
103 | x_min, y_min, x_max, y_max = self.split_to_xyxy()
104 | scaled_x_min = x_min * ratio_w
105 | scaled_x_max = x_max * ratio_w
106 | scaled_y_min = y_min * ratio_h
107 | scaled_y_max = y_max * ratio_h
108 | scaled = torch.cat([scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max], -1)
109 | box = BoxList(scaled, size, mode='xyxy')
110 |
111 | for k, v in self.fields.items():
112 | if not isinstance(v, torch.Tensor):
113 | v = v.resize(size, *args, **kwargs)
114 |
115 | box.fields[k] = v
116 |
117 | return box.convert(self.mode)
118 |
119 | def transpose(self, method):
120 | width, height = self.size
121 | x_min, y_min, x_max, y_max = self.split_to_xyxy()
122 |
123 | if method == FLIP_LEFT_RIGHT:
124 | remove = 1
125 |
126 | transpose_x_min = width - x_max - remove
127 | transpose_x_max = width - x_min - remove
128 | transpose_y_min = y_min
129 | transpose_y_max = y_max
130 |
131 | elif method == FLIP_TOP_BOTTOM:
132 | transpose_x_min = x_min
133 | transpose_x_max = x_max
134 | transpose_y_min = height - y_max
135 | transpose_y_max = height - y_min
136 |
137 | transpose_box = torch.cat(
138 | [transpose_x_min, transpose_y_min, transpose_x_max, transpose_y_max], -1
139 | )
140 | box = BoxList(transpose_box, self.size, mode='xyxy')
141 |
142 | for k, v in self.fields.items():
143 | if not isinstance(v, torch.Tensor):
144 | v = v.transpose(method)
145 |
146 | box.fields[k] = v
147 |
148 | return box.convert(self.mode)
149 |
150 | def clip(self, remove_empty=True):
151 | remove = 1
152 |
153 | max_width = self.size[0] - remove
154 | max_height = self.size[1] - remove
155 |
156 | self.box[:, 0].clamp_(min=0, max=max_width)
157 | self.box[:, 1].clamp_(min=0, max=max_height)
158 | self.box[:, 2].clamp_(min=0, max=max_width)
159 | self.box[:, 3].clamp_(min=0, max=max_height)
160 |
161 | if remove_empty:
162 | box = self.box
163 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
164 |
165 | return self[keep]
166 |
167 | else:
168 | return self
169 |
170 | def to(self, device):
171 | box = BoxList(self.box.to(device), self.size, self.mode)
172 |
173 | for k, v in self.fields.items():
174 | if hasattr(v, 'to'):
175 | v = v.to(device)
176 |
177 | box.fields[k] = v
178 |
179 | return box
180 |
181 |
182 | def remove_small_box(boxlist, min_size):
183 | box = boxlist.convert('xywh').box
184 | _, _, w, h = box.unbind(dim=1)
185 | keep = (w >= min_size) & (h >= min_size)
186 | keep = keep.nonzero().squeeze(1)
187 |
188 | return boxlist[keep]
189 |
190 |
191 | def cat_boxlist(boxlists):
192 | size = boxlists[0].size
193 | mode = boxlists[0].mode
194 | field_keys = boxlists[0].fields.keys()
195 |
196 | box_cat = torch.cat([boxlist.box for boxlist in boxlists], 0)
197 | new_boxlist = BoxList(box_cat, size, mode)
198 |
199 | for field in field_keys:
200 | data = torch.cat([boxlist.fields[field] for boxlist in boxlists], 0)
201 | new_boxlist.fields[field] = data
202 |
203 | return new_boxlist
204 |
205 |
206 | def boxlist_nms(boxlist, scores, threshold, max_proposal=-1):
207 | if threshold <= 0:
208 | return boxlist
209 |
210 | mode = boxlist.mode
211 | boxlist = boxlist.convert('xyxy')
212 | box = boxlist.box
213 | keep = ops.nms(box, scores, threshold)
214 |
215 | if max_proposal > 0:
216 | keep = keep[:max_proposal]
217 |
218 | boxlist = boxlist[keep]
219 |
220 | return boxlist.convert(mode)
221 |
--------------------------------------------------------------------------------
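`BoxList` above is the container that the dataset, the post-processor, and the evaluator hand to each other; the short sketch below (illustration only, not repo code) exercises its main conversions.

```python
# Illustration of tool/boxlist.py's BoxList (not part of the repo).
import torch
from tool.boxlist import BoxList

boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])        # one box in xyxy
blist = BoxList(boxes, image_size=(100, 100), mode='xyxy')
blist.fields['labels'] = torch.tensor([1])

print(blist.convert('xywh').box)           # xyxy -> xywh (with the +1 convention used above)
print(blist.resize((200, 200)).box)        # rescale boxes for a resized image
print(len(blist.clip(remove_empty=True)))  # clamp to the image and drop degenerate boxes
```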
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | INF = 100000000
6 |
7 |
8 | class IOULoss(nn.Module):
9 | def __init__(self, loc_loss_type):
10 | super().__init__()
11 |
12 | self.loc_loss_type = loc_loss_type
13 |
14 | def forward(self, out, target, weight=None):
15 | pred_left, pred_top, pred_right, pred_bottom = out.unbind(1)
16 | target_left, target_top, target_right, target_bottom = target.unbind(1)
17 |
18 | target_area = (target_left + target_right) * (target_top + target_bottom)
19 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
20 |
21 | w_intersect = torch.min(pred_left, target_left) + torch.min(
22 | pred_right, target_right
23 | )
24 | h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(
25 | pred_top, target_top
26 | )
27 |
28 | area_intersect = w_intersect * h_intersect
29 | area_union = target_area + pred_area - area_intersect
30 |
31 | ious = (area_intersect + 1) / (area_union + 1)
32 |
33 | if self.loc_loss_type == 'iou':
34 | loss = -torch.log(ious)
35 |
36 | elif self.loc_loss_type == 'giou':
37 | g_w_intersect = torch.max(pred_left, target_left) + torch.max(
38 | pred_right, target_right
39 | )
40 | g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(
41 | pred_top, target_top
42 | )
43 | g_intersect = g_w_intersect * g_h_intersect + 1e-7
44 | gious = ious - (g_intersect - area_union) / g_intersect
45 |
46 | loss = 1 - gious
47 |
48 | if weight is not None and weight.sum() > 0:
49 | return (loss * weight).sum() / weight.sum()
50 |
51 | else:
52 | return loss.mean()
53 |
54 |
55 | def clip_sigmoid(input):
56 | out = torch.clamp(torch.sigmoid(input), min=1e-4, max=1 - 1e-4)
57 |
58 | return out
59 |
60 |
61 | class SigmoidFocalLoss(nn.Module):
62 | def __init__(self, gamma, alpha):
63 | super().__init__()
64 |
65 | self.gamma = gamma
66 | self.alpha = alpha
67 |
68 | def forward(self, out, target):
69 | n_class = out.shape[1]
70 | class_ids = torch.arange(
71 | 1, n_class + 1, dtype=target.dtype, device=target.device
72 | ).unsqueeze(0)
73 |
74 | t = target.unsqueeze(1)
75 | p = torch.sigmoid(out)
76 |
77 | gamma = self.gamma
78 | alpha = self.alpha
79 |
80 | term1 = (1 - p) ** gamma * torch.log(p)
81 | term2 = p ** gamma * torch.log(1 - p)
82 |
83 | # print(term1.sum(), term2.sum())
84 |
85 | loss = (
86 | -(t == class_ids).float() * alpha * term1
87 | - ((t != class_ids) * (t >= 0)).float() * (1 - alpha) * term2
88 | )
89 |
90 | return loss.sum()
91 |
92 |
93 | class FCOSLoss(nn.Module):
94 | def __init__(
95 | self, sizes, gamma, alpha, iou_loss_type, center_sample, fpn_strides, pos_radius
96 | ):
97 | super().__init__()
98 |
99 | self.sizes = sizes
100 |
101 | self.cls_loss = SigmoidFocalLoss(gamma, alpha)
102 | self.box_loss = IOULoss(iou_loss_type)
103 | self.center_loss = nn.BCEWithLogitsLoss()
104 |
105 | self.center_sample = center_sample
106 | self.strides = fpn_strides
107 | self.radius = pos_radius
108 |
109 | def prepare_target(self, points, targets):
110 | ex_size_of_interest = []
111 |
112 | for i, point_per_level in enumerate(points):
113 | size_of_interest_per_level = point_per_level.new_tensor(self.sizes[i])
114 | ex_size_of_interest.append(
115 | size_of_interest_per_level[None].expand(len(point_per_level), -1)
116 | )
117 |
118 | ex_size_of_interest = torch.cat(ex_size_of_interest, 0)
119 | n_point_per_level = [len(point_per_level) for point_per_level in points]
120 | point_all = torch.cat(points, dim=0)
121 | label, box_target = self.compute_target_for_location(
122 | point_all, targets, ex_size_of_interest, n_point_per_level
123 | )
124 |
125 | for i in range(len(label)):
126 | label[i] = torch.split(label[i], n_point_per_level, 0)
127 | box_target[i] = torch.split(box_target[i], n_point_per_level, 0)
128 |
129 | label_level_first = []
130 | box_target_level_first = []
131 |
132 | for level in range(len(points)):
133 | label_level_first.append(
134 | torch.cat([label_per_img[level] for label_per_img in label], 0)
135 | )
136 | box_target_level_first.append(
137 | torch.cat(
138 | [box_target_per_img[level] for box_target_per_img in box_target], 0
139 | )
140 | )
141 |
142 | return label_level_first, box_target_level_first
143 |
144 | def get_sample_region(self, gt, strides, n_point_per_level, xs, ys, radius=1):
145 | n_gt = gt.shape[0]
146 | n_loc = len(xs)
147 | gt = gt[None].expand(n_loc, n_gt, 4)
148 | center_x = (gt[..., 0] + gt[..., 2]) / 2
149 | center_y = (gt[..., 1] + gt[..., 3]) / 2
150 |
151 | if center_x[..., 0].sum() == 0:
152 | return xs.new_zeros(xs.shape, dtype=torch.uint8)
153 |
154 | begin = 0
155 |
156 | center_gt = gt.new_zeros(gt.shape)
157 |
158 | for level, n_p in enumerate(n_point_per_level):
159 | end = begin + n_p
160 | stride = strides[level] * radius
161 |
162 | x_min = center_x[begin:end] - stride
163 | y_min = center_y[begin:end] - stride
164 | x_max = center_x[begin:end] + stride
165 | y_max = center_y[begin:end] + stride
166 |
167 | center_gt[begin:end, :, 0] = torch.where(
168 | x_min > gt[begin:end, :, 0], x_min, gt[begin:end, :, 0]
169 | )
170 | center_gt[begin:end, :, 1] = torch.where(
171 | y_min > gt[begin:end, :, 1], y_min, gt[begin:end, :, 1]
172 | )
173 | center_gt[begin:end, :, 2] = torch.where(
174 | x_max > gt[begin:end, :, 2], gt[begin:end, :, 2], x_max
175 | )
176 | center_gt[begin:end, :, 3] = torch.where(
177 | y_max > gt[begin:end, :, 3], gt[begin:end, :, 3], y_max
178 | )
179 |
180 | begin = end
181 |
182 | left = xs[:, None] - center_gt[..., 0]
183 | right = center_gt[..., 2] - xs[:, None]
184 | top = ys[:, None] - center_gt[..., 1]
185 | bottom = center_gt[..., 3] - ys[:, None]
186 |
187 | center_bbox = torch.stack((left, top, right, bottom), -1)
188 | is_in_boxes = center_bbox.min(-1)[0] > 0
189 |
190 | return is_in_boxes
191 |
192 | def compute_target_for_location(
193 | self, locations, targets, sizes_of_interest, n_point_per_level
194 | ):
195 | labels = []
196 | box_targets = []
197 | xs, ys = locations[:, 0], locations[:, 1]
198 |
199 | for i in range(len(targets)):
200 | targets_per_img = targets[i]
201 | assert targets_per_img.mode == 'xyxy'
202 | bboxes = targets_per_img.box
203 | labels_per_img = targets_per_img.fields['labels']
204 | area = targets_per_img.area()
205 |
206 | l = xs[:, None] - bboxes[:, 0][None]
207 | t = ys[:, None] - bboxes[:, 1][None]
208 | r = bboxes[:, 2][None] - xs[:, None]
209 | b = bboxes[:, 3][None] - ys[:, None]
210 |
211 | box_targets_per_img = torch.stack([l, t, r, b], 2)
212 |
213 | if self.center_sample:
214 | is_in_boxes = self.get_sample_region(
215 | bboxes, self.strides, n_point_per_level, xs, ys, radius=self.radius
216 | )
217 |
218 | else:
219 | is_in_boxes = box_targets_per_img.min(2)[0] > 0
220 |
221 | max_box_targets_per_img = box_targets_per_img.max(2)[0]
222 |
223 | is_cared_in_level = (
224 | max_box_targets_per_img >= sizes_of_interest[:, [0]]
225 | ) & (max_box_targets_per_img <= sizes_of_interest[:, [1]])
226 |
227 | locations_to_gt_area = area[None].repeat(len(locations), 1)
228 | locations_to_gt_area[is_in_boxes == 0] = INF
229 | locations_to_gt_area[is_cared_in_level == 0] = INF
230 |
231 | locations_to_min_area, locations_to_gt_id = locations_to_gt_area.min(1)
232 |
233 | box_targets_per_img = box_targets_per_img[
234 | range(len(locations)), locations_to_gt_id
235 | ]
236 | labels_per_img = labels_per_img[locations_to_gt_id]
237 | labels_per_img[locations_to_min_area == INF] = 0
238 |
239 | labels.append(labels_per_img)
240 | box_targets.append(box_targets_per_img)
241 |
242 | return labels, box_targets
243 |
244 | def compute_centerness_targets(self, box_targets):
245 | left_right = box_targets[:, [0, 2]]
246 | top_bottom = box_targets[:, [1, 3]]
247 | centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * (
248 | top_bottom.min(-1)[0] / top_bottom.max(-1)[0]
249 | )
250 |
251 | return torch.sqrt(centerness)
252 |
253 | def forward(self, locations, cls_pred, box_pred, center_pred, targets):
254 | batch = cls_pred[0].shape[0]
255 | n_class = cls_pred[0].shape[1]
256 |
257 | labels, box_targets = self.prepare_target(locations, targets)
258 |
259 | cls_flat = []
260 | box_flat = []
261 | center_flat = []
262 |
263 | labels_flat = []
264 | box_targets_flat = []
265 |
266 | for i in range(len(labels)):
267 | cls_flat.append(cls_pred[i].permute(0, 2, 3, 1).reshape(-1, n_class))
268 | box_flat.append(box_pred[i].permute(0, 2, 3, 1).reshape(-1, 4))
269 | center_flat.append(center_pred[i].permute(0, 2, 3, 1).reshape(-1))
270 |
271 | labels_flat.append(labels[i].reshape(-1))
272 | box_targets_flat.append(box_targets[i].reshape(-1, 4))
273 |
274 | cls_flat = torch.cat(cls_flat, 0)
275 | box_flat = torch.cat(box_flat, 0)
276 | center_flat = torch.cat(center_flat, 0)
277 |
278 | labels_flat = torch.cat(labels_flat, 0)
279 | box_targets_flat = torch.cat(box_targets_flat, 0)
280 |
281 | pos_id = torch.nonzero(labels_flat > 0).squeeze(1)
282 |
283 | cls_loss = self.cls_loss(cls_flat, labels_flat.int()) / (pos_id.numel() + batch)
284 |
285 | box_flat = box_flat[pos_id]
286 | center_flat = center_flat[pos_id]
287 |
288 | box_targets_flat = box_targets_flat[pos_id]
289 |
290 | if pos_id.numel() > 0:
291 | center_targets = self.compute_centerness_targets(box_targets_flat)
292 |
293 | box_loss = self.box_loss(box_flat, box_targets_flat, center_targets)
294 | center_loss = self.center_loss(center_flat, center_targets)
295 |
296 | else:
297 | box_loss = box_flat.sum()
298 | center_loss = center_flat.sum()
299 |
300 | return cls_loss, box_loss, center_loss
301 |
--------------------------------------------------------------------------------
/model/efficientnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import (
6 | round_filters,
7 | round_repeats,
8 | drop_connect,
9 | get_same_padding_conv2d,
10 | get_model_params,
11 | efficientnet_params,
12 | load_pretrained_weights,
13 | Swish,
14 | MemoryEfficientSwish,
15 | )
16 |
17 | class MBConvBlock(nn.Module):
18 | """
19 | Mobile Inverted Residual Bottleneck Block
20 | Args:
21 |         block_args (namedtuple): BlockArgs, defined in model/utils.py
22 |         global_params (namedtuple): GlobalParams, defined in model/utils.py
23 | Attributes:
24 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
25 | """
26 |
27 | def __init__(self, block_args, global_params):
28 | super().__init__()
29 | self._block_args = block_args
30 | self._bn_mom = 1 - global_params.batch_norm_momentum
31 | self._bn_eps = global_params.batch_norm_epsilon
32 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
33 | self.id_skip = block_args.id_skip # skip connection and drop connect
34 |
35 | # Get static or dynamic convolution depending on image size
36 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
37 |
38 | # Expansion phase
39 | inp = self._block_args.input_filters # number of input channels
40 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
41 | if self._block_args.expand_ratio != 1:
42 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
43 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
44 | # Depthwise convolution phase
45 | k = self._block_args.kernel_size
46 | s = self._block_args.stride
47 | self._depthwise_conv = Conv2d(
48 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
49 | kernel_size=k, stride=s, bias=False)
50 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
51 |
52 | # Squeeze and Excitation layer, if desired
53 | if self.has_se:
54 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
55 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
56 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
57 |
58 | # Output phase
59 | final_oup = self._block_args.output_filters
60 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
61 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
62 | self._swish = MemoryEfficientSwish()
63 |
64 | def forward(self, inputs, drop_connect_rate=None):
65 | """
66 | :param inputs: input tensor
67 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
68 | :return: output of block
69 | """
70 |
71 | # Expansion and Depthwise Convolution
72 | x = inputs
73 | if self._block_args.expand_ratio != 1:
74 | x = self._swish(self._bn0(self._expand_conv(inputs)))
75 |
76 | x = self._swish(self._bn1(self._depthwise_conv(x)))
77 |
78 | # Squeeze and Excitation
79 | if self.has_se:
80 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
81 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
82 | x = torch.sigmoid(x_squeezed) * x
83 |
84 | x = self._bn2(self._project_conv(x))
85 |
86 | # Skip connection and drop connect
87 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
88 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
89 | if drop_connect_rate:
90 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
91 | x = x + inputs # skip connection
92 | return x
93 |
94 | def set_swish(self, memory_efficient=True):
95 | """Sets swish function as memory efficient (for training) or standard (for export)"""
96 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
97 |
98 |
99 | class EfficientNet(nn.Module):
100 | """
101 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
102 | Args:
103 | blocks_args (list): A list of BlockArgs to construct blocks
104 | global_params (namedtuple): A set of GlobalParams shared between blocks
105 | Example:
106 | model = EfficientNet.from_pretrained('efficientnet-b0')
107 | """
108 |
109 | def __init__(self, blocks_args=None, global_params=None):
110 | super().__init__()
111 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
112 |         assert len(blocks_args) > 0, 'blocks_args must not be empty'
113 | self._global_params = global_params
114 | self._blocks_args = blocks_args
115 |
116 | # Get static or dynamic convolution depending on image size
117 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
118 |
119 | # Batch norm parameters
120 | bn_mom = 1 - self._global_params.batch_norm_momentum
121 | bn_eps = self._global_params.batch_norm_epsilon
122 |
123 | # Stem
124 | in_channels = 3 # rgb
125 | out_channels = round_filters(32, self._global_params) # number of output channels
126 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
127 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
128 |
129 | # Build blocks
130 | self._blocks = nn.ModuleList([])
131 | for i in range(len(self._blocks_args)):
132 | # Update block input and output filters based on depth multiplier.
133 | self._blocks_args[i] = self._blocks_args[i]._replace(
134 | input_filters=round_filters(self._blocks_args[i].input_filters, self._global_params),
135 | output_filters=round_filters(self._blocks_args[i].output_filters, self._global_params),
136 | num_repeat=round_repeats(self._blocks_args[i].num_repeat, self._global_params)
137 | )
138 |
139 | # The first block needs to take care of stride and filter size increase.
140 | self._blocks.append(MBConvBlock(self._blocks_args[i], self._global_params))
141 | if self._blocks_args[i].num_repeat > 1:
142 | self._blocks_args[i] = self._blocks_args[i]._replace(input_filters=self._blocks_args[i].output_filters, stride=1)
143 | for _ in range(self._blocks_args[i].num_repeat - 1):
144 | self._blocks.append(MBConvBlock(self._blocks_args[i], self._global_params))
145 |
146 |         # Head
147 | in_channels = self._blocks_args[len(self._blocks_args)-1].output_filters # output of final block
148 | out_channels = round_filters(1280, self._global_params)
149 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
150 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
151 |
152 | # Final linear layer
153 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
154 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
155 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
156 | self._swish = MemoryEfficientSwish()
157 |
158 | def set_swish(self, memory_efficient=True):
159 | """Sets swish function as memory efficient (for training) or standard (for export)"""
160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
161 | for block in self._blocks:
162 | block.set_swish(memory_efficient)
163 |
164 |
165 | def extract_features(self, inputs):
166 | """ Returns output of the final convolution layer """
167 | # Stem
168 | x = self._swish(self._bn0(self._conv_stem(inputs)))
169 |
170 | P = []
171 | index = 0
172 | num_repeat = 0
173 | # Blocks
174 | for idx, block in enumerate(self._blocks):
175 | drop_connect_rate = self._global_params.drop_connect_rate
176 | if drop_connect_rate:
177 | drop_connect_rate *= float(idx) / len(self._blocks)
178 | x = block(x, drop_connect_rate=drop_connect_rate)
179 | num_repeat = num_repeat + 1
180 |             if num_repeat == self._blocks_args[index].num_repeat:
181 | num_repeat = 0
182 | index = index + 1
183 | P.append(x)
184 | return P
185 |
186 | def forward(self, inputs):
187 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
188 | # Convolution layers
189 | P = self.extract_features(inputs)
190 | return P
191 |
192 | @classmethod
193 | def from_name(cls, model_name, override_params=None):
194 | cls._check_model_name_is_valid(model_name)
195 | blocks_args, global_params = get_model_params(model_name, override_params)
196 | return cls(blocks_args, global_params)
197 |
198 |     @classmethod
199 |     def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
200 |         model = cls.from_name(model_name, override_params={'num_classes': num_classes})
201 |         load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
202 |         if in_channels != 3:
203 |             # Rebuild the stem so the backbone accepts a non-RGB number of input channels.
204 |             Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
205 |             out_channels = round_filters(32, model._global_params)
206 |             model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
207 |
208 |         return model
214 |
215 | @classmethod
216 | def get_image_size(cls, model_name):
217 | cls._check_model_name_is_valid(model_name)
218 | _, _, res, _ = efficientnet_params(model_name)
219 | return res
220 |
221 | @classmethod
222 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
223 | """ Validates model name. None that pretrained weights are only available for
224 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
225 | num_models = 4 if also_need_pretrained_weights else 8
226 | valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
227 | if model_name not in valid_models:
228 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
229 |
230 | def get_list_features(self):
231 | list_feature = []
232 | for idx in range(len(self._blocks_args)):
233 | list_feature.append(self._blocks_args[idx].output_filters)
234 |
235 | return list_feature
236 |
237 |
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
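A rough usage sketch for the backbone above (assuming the repository root is on PYTHONPATH): forward() returns the per-stage feature maps collected in extract_features, one entry per block-args stage from shallow to deep, rather than classification logits.

import torch
from model.efficientnet import EfficientNet

# Build efficientnet-b0 from its block-args definition (no pretrained weights are downloaded).
backbone = EfficientNet.from_name('efficientnet-b0')
backbone.eval()

with torch.no_grad():
    features = backbone(torch.randn(1, 3, 512, 512))

for i, f in enumerate(features):
    print(i, tuple(f.shape))  # later stages have more channels and lower spatial resolution

Note that the static same-padding convolutions are built for the nominal resolution returned by efficientnet_params (224 for b0); other input sizes still run, but the padding is computed for that nominal size.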
/model/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | import collections
4 | from functools import partial
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.utils import model_zoo
9 |
10 | ########################################################################
11 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
12 | ########################################################################
13 |
14 |
15 | # Parameters for the entire model (stem, all blocks, and head)
16 | GlobalParams = collections.namedtuple('GlobalParams', [
17 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
18 | 'num_classes', 'width_coefficient', 'depth_coefficient',
19 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
20 |
21 | # Parameters for an individual model block
22 | BlockArgs = collections.namedtuple('BlockArgs', [
23 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
24 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
25 |
26 | # Change namedtuple defaults
27 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
28 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
29 |
30 |
31 | class SwishImplementation(torch.autograd.Function):
32 | @staticmethod
33 | def forward(ctx, i):
34 | result = i * torch.sigmoid(i)
35 | ctx.save_for_backward(i)
36 | return result
37 |
38 | @staticmethod
39 | def backward(ctx, grad_output):
40 |         i = ctx.saved_tensors[0]
41 | sigmoid_i = torch.sigmoid(i)
42 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
43 |
44 |
45 | class MemoryEfficientSwish(nn.Module):
46 | def forward(self, x):
47 | return SwishImplementation.apply(x)
48 |
49 | class Swish(nn.Module):
50 | def forward(self, x):
51 | return x * torch.sigmoid(x)
52 |
53 |
54 | def round_filters(filters, global_params):
55 | """ Calculate and round number of filters based on depth multiplier. """
56 | multiplier = global_params.width_coefficient
57 | if not multiplier:
58 | return filters
59 | divisor = global_params.depth_divisor
60 | min_depth = global_params.min_depth
61 | filters *= multiplier
62 | min_depth = min_depth or divisor
63 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
64 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
65 | new_filters += divisor
66 | return int(new_filters)
67 |
68 |
69 | def round_repeats(repeats, global_params):
70 | """ Round number of filters based on depth multiplier. """
71 | multiplier = global_params.depth_coefficient
72 | if not multiplier:
73 | return repeats
74 | return int(math.ceil(multiplier * repeats))
75 |
76 |
77 | def drop_connect(inputs, p, training):
78 | """ Drop connect. """
79 | if not training: return inputs
80 | batch_size = inputs.shape[0]
81 | keep_prob = 1 - p
82 | random_tensor = keep_prob
83 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
84 | binary_tensor = torch.floor(random_tensor)
85 | output = inputs / keep_prob * binary_tensor
86 | return output
87 |
88 |
89 | def get_same_padding_conv2d(image_size=None):
90 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
91 | Static padding is necessary for ONNX exporting of models. """
92 | if image_size is None:
93 | return Conv2dDynamicSamePadding
94 | else:
95 | return partial(Conv2dStaticSamePadding, image_size=image_size)
96 |
97 |
98 | class Conv2dDynamicSamePadding(nn.Conv2d):
99 | """ 2D Convolutions like TensorFlow, for a dynamic image size """
100 |
101 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
102 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
103 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
104 |
105 | def forward(self, x):
106 | ih, iw = x.size()[-2:]
107 | kh, kw = self.weight.size()[-2:]
108 | sh, sw = self.stride
109 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
110 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
111 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
112 | if pad_h > 0 or pad_w > 0:
113 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
114 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
115 |
116 |
117 | class Conv2dStaticSamePadding(nn.Conv2d):
118 | """ 2D Convolutions like TensorFlow, for a fixed image size"""
119 |
120 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
121 | super().__init__(in_channels, out_channels, kernel_size, **kwargs)
122 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
123 |
124 | # Calculate padding based on image size and save it
125 | assert image_size is not None
126 | ih, iw = image_size if type(image_size) == list else [image_size, image_size]
127 | kh, kw = self.weight.size()[-2:]
128 | sh, sw = self.stride
129 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
130 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
131 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
132 | if pad_h > 0 or pad_w > 0:
133 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
134 | else:
135 | self.static_padding = Identity()
136 |
137 | def forward(self, x):
138 | x = self.static_padding(x)
139 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
140 | return x
141 |
142 |
143 | class Identity(nn.Module):
144 | def __init__(self, ):
145 | super(Identity, self).__init__()
146 |
147 | def forward(self, input):
148 | return input
149 |
150 |
151 | ########################################################################
152 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
153 | ########################################################################
154 |
155 |
156 | def efficientnet_params(model_name):
157 | """ Map EfficientNet model name to parameter coefficients. """
158 | params_dict = {
159 | # Coefficients: width,depth,res,dropout
160 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
161 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
162 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
163 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
164 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
165 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
166 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
167 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
168 | }
169 | return params_dict[model_name]
170 |
171 |
172 | class BlockDecoder(object):
173 | """ Block Decoder for readability, straight from the official TensorFlow repository """
174 |
175 | @staticmethod
176 | def _decode_block_string(block_string):
177 | """ Gets a block through a string notation of arguments. """
178 | assert isinstance(block_string, str)
179 |
180 | ops = block_string.split('_')
181 | options = {}
182 | for op in ops:
183 | splits = re.split(r'(\d.*)', op)
184 | if len(splits) >= 2:
185 | key, value = splits[:2]
186 | options[key] = value
187 |
188 | # Check stride
189 | assert (('s' in options and len(options['s']) == 1) or
190 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
191 |
192 | return BlockArgs(
193 | kernel_size=int(options['k']),
194 | num_repeat=int(options['r']),
195 | input_filters=int(options['i']),
196 | output_filters=int(options['o']),
197 | expand_ratio=int(options['e']),
198 | id_skip=('noskip' not in block_string),
199 | se_ratio=float(options['se']) if 'se' in options else None,
200 | stride=[int(options['s'][0])])
201 |
202 | @staticmethod
203 | def _encode_block_string(block):
204 | """Encodes a block to a string."""
205 | args = [
206 | 'r%d' % block.num_repeat,
207 | 'k%d' % block.kernel_size,
208 |             's%d%d' % (block.stride[0], block.stride[0]),  # stride is stored as a one-element list by _decode_block_string
209 | 'e%s' % block.expand_ratio,
210 | 'i%d' % block.input_filters,
211 | 'o%d' % block.output_filters
212 | ]
213 |         if block.se_ratio is not None and 0 < block.se_ratio <= 1:
214 | args.append('se%s' % block.se_ratio)
215 | if block.id_skip is False:
216 | args.append('noskip')
217 | return '_'.join(args)
218 |
219 | @staticmethod
220 | def decode(string_list):
221 | """
222 | Decodes a list of string notations to specify blocks inside the network.
223 | :param string_list: a list of strings, each string is a notation of block
224 | :return: a list of BlockArgs namedtuples of block args
225 | """
226 | assert isinstance(string_list, list)
227 | blocks_args = []
228 | for block_string in string_list:
229 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
230 | return blocks_args
231 |
232 | @staticmethod
233 | def encode(blocks_args):
234 | """
235 | Encodes a list of BlockArgs to a list of strings.
236 | :param blocks_args: a list of BlockArgs namedtuples of block args
237 | :return: a list of strings, each string is a notation of block
238 | """
239 | block_strings = []
240 | for block in blocks_args:
241 | block_strings.append(BlockDecoder._encode_block_string(block))
242 | return block_strings
243 |
244 |
245 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
246 | drop_connect_rate=0.2, image_size=None, num_classes=1000):
247 | """ Creates a efficientnet model. """
248 |
249 | blocks_args = [
250 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
251 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
252 | 'r3_k5_s22_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
253 | 'r1_k3_s22_e6_i192_o320_se0.25',
254 | ]
255 | blocks_args = BlockDecoder.decode(blocks_args)
256 |
257 | global_params = GlobalParams(
258 | batch_norm_momentum=0.99,
259 | batch_norm_epsilon=1e-3,
260 | dropout_rate=dropout_rate,
261 | drop_connect_rate=drop_connect_rate,
262 | # data_format='channels_last', # removed, this is always true in PyTorch
263 | num_classes=num_classes,
264 | width_coefficient=width_coefficient,
265 | depth_coefficient=depth_coefficient,
266 | depth_divisor=8,
267 | min_depth=None,
268 | image_size=image_size,
269 | )
270 |
271 | return blocks_args, global_params
272 |
273 |
274 | def get_model_params(model_name, override_params):
275 | """ Get the block args and global params for a given model """
276 | if model_name.startswith('efficientnet'):
277 | w, d, s, p = efficientnet_params(model_name)
278 | # note: all models have drop connect rate = 0.2
279 | blocks_args, global_params = efficientnet(
280 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
281 | else:
282 | raise NotImplementedError('model name is not pre-defined: %s' % model_name)
283 | if override_params:
284 | # ValueError will be raised here if override_params has fields not included in global_params.
285 | global_params = global_params._replace(**override_params)
286 | return blocks_args, global_params
287 |
288 |
289 | url_map = {
290 | 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
291 | 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
292 | 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
293 | 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
294 | 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
295 | 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
296 | 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
297 | 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
298 | }
299 |
300 |
301 | def load_pretrained_weights(model, model_name, load_fc=True):
302 | """ Loads pretrained weights, and downloads if loading for the first time. """
303 | state_dict = model_zoo.load_url(url_map[model_name])
304 | if load_fc:
305 | model.load_state_dict(state_dict)
306 | else:
307 | state_dict.pop('_fc.weight')
308 | state_dict.pop('_fc.bias')
309 | res = model.load_state_dict(state_dict, strict=False)
310 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
311 | print('Loaded pretrained weights for {}'.format(model_name))
--------------------------------------------------------------------------------
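A short sketch of the block-string notation and the compound-scaling helpers above (the printed values follow from the definitions in this file; the import path assumes the repository root is on PYTHONPATH):

from model.utils import BlockDecoder, get_model_params, round_filters, round_repeats

# 'r2_k5_s22_e6_i24_o40_se0.25': repeat 2 blocks, 5x5 depthwise kernel, stride 2,
# expand ratio 6, 24 -> 40 channels, squeeze-excitation ratio 0.25.
block = BlockDecoder.decode(['r2_k5_s22_e6_i24_o40_se0.25'])[0]
print(block.num_repeat, block.kernel_size, block.stride, block.expand_ratio)  # 2 5 [2] 6

# Compound scaling for efficientnet-b2 (width 1.1, depth 1.2).
_, global_params = get_model_params('efficientnet-b2', None)
print(round_filters(40, global_params))  # 40 * 1.1 = 44 -> rounded to a multiple of 8 -> 48
print(round_repeats(2, global_params))   # ceil(2 * 1.2) -> 3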
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------