├── tool ├── util.py ├── visualize.py ├── evaluate.py ├── postprocess.py ├── distributed.py └── boxlist.py ├── model ├── __init__.py ├── head.py ├── model.py ├── bi.py ├── loss.py ├── efficientnet.py └── utils.py ├── .idea ├── misc.xml ├── modules.xml ├── efficientdet-anchor-free.iml ├── deployment.xml └── workspace.xml ├── README.md ├── config.py ├── data ├── transform.py └── dataset.py └── train.py /tool/util.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/efficientdet-anchor-free.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientDet_anchor_free 2 | EfficientDet_anchor_free 3 | 4 | 5 | 6 | # Introduction 7 | This is an anchor-free EfficientDet implemented in PyTorch. I have also completed EfficientDet_anchor-based.
8 | 9 | 10 | ## Results 11 | | |This report(anchor-free)| This report(anchor-based)|Paper | 12 | | :----- | :----- | :------ |:------ | 13 | |network|Efficientnet-b0|Efficientnet-b0|Efficientnet-b0| 14 | |datasets|COCO2017|VOC0712|COCO2017| 15 | |notes|Multi-scales|mixup-up ,label smooth, giou loss, cosine lr|Multi-scales| 16 | |MAPS|32.9|68.5|32.4| 17 | |b1-b7|TODO|TODO||--| 18 | 19 | 20 | There are some problems in EfficientDet_anchor-based,which has low Maps and slow speed. I will fix it and share the code. 21 | 22 | 23 | ## Reference 24 | * [FCOS](https://github.com/tianzhi0549/FCOS) 25 | * [EfficientDet](https://arxiv.org/pdf/1911.09070.pdf) 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /tool/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | from matplotlib import pyplot as plt 5 | from matplotlib import patheffects, patches 6 | 7 | 8 | def show_img(img, figsize=None, fig=None, ax=None): 9 | if not ax: 10 | fig, ax = plt.subplots(figsize=figsize) 11 | 12 | ax.imshow(img) 13 | ax.get_xaxis().set_visible(False) 14 | ax.get_yaxis().set_visible(False) 15 | 16 | return fig, ax 17 | 18 | 19 | def draw_outline(obj, line_width): 20 | obj.set_path_effects( 21 | [ 22 | patheffects.Stroke(linewidth=line_width, foreground='black'), 23 | patheffects.Normal(), 24 | ] 25 | ) 26 | 27 | 28 | def draw_rect(ax, box): 29 | patch = ax.add_patch( 30 | patches.Rectangle(box[:2], *box[-2:], fill=False, edgecolor='white', lw=2) 31 | ) 32 | draw_outline(patch, 4) 33 | 34 | 35 | def draw_text(ax, xy, txt, sz=14): 36 | text = ax.text( 37 | *xy, txt, verticalalignment='top', color='white', fontsize=sz, weight='bold' 38 | ) 39 | draw_outline(text, 1) 40 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | class 
Config: 2 | 3 | def __init__(self): 4 | self.network = 'efficientnet-b0' 5 | self.local_rank = 0 6 | self.lr = 0.01 7 | self.l2 = 0.0001 8 | self.batch = 16 9 | self.epoch = 50 10 | self.n_save_sample = 5 11 | self.out_channel = 256 12 | self.n_class = 81 13 | self.prior = 0.01 14 | self.threshold = 0.05 15 | self.top_n = 1000 16 | self.nms_threshold = 0.6 17 | self.post_top_n = 100 18 | self.fpn_strides = [8, 16, 32, 64, 128] 19 | self.gamma = 2.0 20 | self.alpha = 0.25 21 | self.iou_loss_type = 'giou' 22 | self.pos_radius = 1.5 23 | self.center_sample = True 24 | self.sizes = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, 100000000]] 25 | self.train_min_size_range = (512,640) 26 | self.train_max_size = 800 27 | self.test_min_size = 512 28 | self.test_max_size = 800 29 | self.pixel_mean = [0.40789654, 0.44719302, 0.47026115] 30 | self.pixel_std = [0.28863828, 0.27408164, 0.27809835] 31 | self.size_divisible = 32 32 | self.min_size = 0 33 | 34 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /model/head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | class Scale(nn.Module): 6 | def __init__(self, init=1.0): 7 | super().__init__() 8 | 9 | self.scale = nn.Parameter(torch.tensor([init], dtype=torch.float32)) 10 | 11 | def forward(self, input): 12 | return input * self.scale 13 | 14 | class FCOSHead(nn.Module): 15 | def __init__(self, in_channel, n_class, n_conv, prior): 16 | super().__init__() 17 | 18 | n_class = n_class - 1 19 | 20 | cls_tower = [] 
21 | bbox_tower = [] 22 | 23 | for i in range(n_conv): 24 | cls_tower.append( 25 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False) 26 | ) 27 | cls_tower.append(nn.BatchNorm2d(in_channel)) 28 | # cls_tower.append(nn.GroupNorm(32, in_channel)) 29 | cls_tower.append(nn.ReLU()) 30 | 31 | bbox_tower.append( 32 | nn.Conv2d(in_channel, in_channel, 3, padding=1, bias=False) 33 | ) 34 | cls_tower.append(nn.BatchNorm2d(in_channel)) 35 | # bbox_tower.append(nn.GroupNorm(32, in_channel)) 36 | bbox_tower.append(nn.ReLU()) 37 | 38 | self.cls_tower = nn.Sequential(*cls_tower) 39 | self.bbox_tower = nn.Sequential(*bbox_tower) 40 | 41 | self.cls_pred = nn.Conv2d(in_channel, n_class, 3, padding=1) 42 | self.bbox_pred = nn.Conv2d(in_channel, 4, 3, padding=1) 43 | self.center_pred = nn.Conv2d(in_channel, 1, 3, padding=1) 44 | 45 | self.apply(self.init_conv_std) 46 | 47 | prior_bias = -math.log((1 - prior) / prior) 48 | nn.init.constant_(self.cls_pred.bias, prior_bias) 49 | 50 | self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)]) 51 | 52 | def init_conv_std(self,module, std=0.01): 53 | if isinstance(module, nn.Conv2d): 54 | nn.init.normal_(module.weight, std=std) 55 | 56 | if module.bias is not None: 57 | nn.init.constant_(module.bias, 0) 58 | 59 | def forward(self, input): 60 | logits = [] 61 | bboxes = [] 62 | centers = [] 63 | 64 | for feat, scale in zip(input, self.scales): 65 | cls_out = self.cls_tower(feat) 66 | 67 | logits.append(self.cls_pred(cls_out)) 68 | centers.append(self.center_pred(cls_out)) 69 | 70 | bbox_out = self.bbox_tower(feat) 71 | bbox_out = torch.exp(scale(self.bbox_pred(bbox_out))) 72 | 73 | bboxes.append(bbox_out) 74 | 75 | return logits, bboxes, centers 76 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .efficientnet import EfficientNet 4 | from .bi import 
BiFpn 5 | from .head import FCOSHead 6 | from tool import postprocess 7 | from .loss import FCOSLoss 8 | class EfficientDet_Free(nn.Module): 9 | 10 | def __init__(self,config): 11 | super(EfficientDet_Free,self).__init__() 12 | self.config = config 13 | self.backbone = EfficientNet.from_pretrained(config.network) 14 | self.fpn = BiFpn(in_channels=self.backbone.get_list_features()[-5:],out_channels=config.out_channels,len_input=5,bi=3) 15 | self.head = FCOSHead(config.out_channels, config.n_class, config.n_conv, config.prior) 16 | self.postprocessor = postprocess.FCOSPostprocessor( 17 | config.threshold, 18 | config.top_n, 19 | config.nms_threshold, 20 | config.post_top_n, 21 | config.min_size, 22 | config.n_class, 23 | ) 24 | self.loss = FCOSLoss( 25 | config.sizes, 26 | config.gamma, 27 | config.alpha, 28 | config.iou_loss_type, 29 | config.center_sample, 30 | config.fpn_strides, 31 | config.pos_radius, 32 | ) 33 | self.fpn_strides = config.fpn_strides 34 | 35 | def compute_location(self, features): 36 | locations = [] 37 | 38 | for i, feat in enumerate(features): 39 | _, _, height, width = feat.shape 40 | location_per_level = self.compute_location_per_level( 41 | height, width, self.fpn_strides[i], feat.device 42 | ) 43 | locations.append(location_per_level) 44 | 45 | return locations 46 | 47 | def compute_location_per_level(self, height, width, stride, device): 48 | shift_x = torch.arange( 49 | 0, width * stride, step=stride, dtype=torch.float32, device=device 50 | ) 51 | shift_y = torch.arange( 52 | 0, height * stride, step=stride, dtype=torch.float32, device=device 53 | ) 54 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x) 55 | shift_x = shift_x.reshape(-1) 56 | shift_y = shift_y.reshape(-1) 57 | location = torch.stack((shift_x, shift_y), 1) + stride // 2 58 | 59 | return location 60 | 61 | def forward(self, input, image_sizes=None, targets=None): 62 | features = self.backbone(input)[-5:] 63 | features = self.fpn(features) 64 | cls_pred, box_pred, 
center_pred = self.head(features) 65 | location = self.compute_location(features) 66 | 67 | if self.training: 68 | loss_cls, loss_box, loss_center = self.loss( 69 | location, cls_pred, box_pred, center_pred, targets 70 | ) 71 | losses = { 72 | 'loss_cls': loss_cls, 73 | 'loss_box': loss_box, 74 | 'loss_center': loss_center, 75 | } 76 | 77 | return None, losses 78 | 79 | else: 80 | boxes = self.postprocessor( 81 | location, cls_pred, box_pred, center_pred, image_sizes 82 | ) 83 | 84 | return boxes, None 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torchvision 5 | from torchvision.transforms import functional as F 6 | 7 | 8 | class Compose: 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, img, target): 13 | for t in self.transforms: 14 | img, target = t(img, target) 15 | 16 | return img, target 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(' 20 | for t in self.transforms: 21 | format_str += '\n' 22 | format_str += ' {t}' 23 | format_str += '\n)' 24 | 25 | return format_str 26 | 27 | 28 | class Resize: 29 | def __init__(self, min_size, max_size): 30 | if not isinstance(min_size, (list, tuple)): 31 | min_size = (min_size,) 32 | 33 | self.min_size = min_size 34 | self.max_size = max_size 35 | 36 | def get_size(self, img_size): 37 | w, h = img_size 38 | size = random.choice(self.min_size) 39 | max_size = self.max_size 40 | 41 | if max_size is not None: 42 | min_orig = float(min((w, h))) 43 | max_orig = float(max((w, h))) 44 | 45 | if max_orig / min_orig * size > max_size: 46 | size = int(round(max_size * min_orig / max_orig)) 47 | 48 | if (w <= h and w == size) or (h <= w and h == size): 49 | return h, w 50 | 51 | if w < h: 52 | ow = size 53 | oh = int(size * h / w) 54 | 55 | else: 56 | oh 
= size 57 | ow = int(size * w / h) 58 | 59 | return oh, ow 60 | 61 | def __call__(self, img, target): 62 | size = self.get_size(img.size) 63 | img = F.resize(img, size) 64 | target = target.resize(img.size) 65 | 66 | return img, target 67 | 68 | 69 | class RandomHorizontalFlip: 70 | def __init__(self, p=0.5): 71 | self.p = p 72 | 73 | def __call__(self, img, target): 74 | if random.random() < self.p: 75 | img = F.hflip(img) 76 | target = target.transpose(0) 77 | 78 | return img, target 79 | 80 | 81 | class ToTensor: 82 | def __call__(self, img, target): 83 | return F.to_tensor(img), target 84 | 85 | 86 | class Normalize: 87 | def __init__(self, mean, std): 88 | self.mean = mean 89 | self.std = std 90 | 91 | def __call__(self, img, target): 92 | img = F.normalize(img, mean=self.mean, std=self.std) 93 | 94 | return img, target 95 | 96 | 97 | def preset_transform(config, train=True): 98 | if train: 99 | if config.train_min_size_range[0] == -1: 100 | min_size = config.train_min_size 101 | 102 | else: 103 | min_size = list( 104 | range( 105 | config.train_min_size_range[0], config.train_min_size_range[1] + 1 106 | ) 107 | ) 108 | 109 | max_size = config.train_max_size 110 | flip = 0.5 111 | 112 | else: 113 | min_size = config.test_min_size 114 | max_size = config.test_max_size 115 | flip = 0 116 | 117 | normalize = Normalize(mean=config.pixel_mean, std=config.pixel_std) 118 | 119 | transform = Compose( 120 | [Resize(min_size, max_size), RandomHorizontalFlip(flip), ToTensor(), normalize] 121 | ) 122 | 123 | return transform 124 | -------------------------------------------------------------------------------- /tool/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | from pycocotools.coco import COCO 7 | from pycocotools.cocoeval import COCOeval 8 | 9 | 10 | def evaluate(dataset, predictions): 11 | coco_results = {} 12 | 
coco_results['bbox'] = make_coco_detection(predictions, dataset) 13 | 14 | results = COCOResult('bbox') 15 | 16 | with tempfile.NamedTemporaryFile() as f: 17 | path = f.name 18 | res = evaluate_predictions_on_coco( 19 | dataset.coco, coco_results['bbox'], path, 'bbox' 20 | ) 21 | results.update(res) 22 | 23 | print(results) 24 | 25 | return res 26 | 27 | 28 | def evaluate_predictions_on_coco(coco_gt, results, result_file, iou_type): 29 | with open(result_file, 'w') as f: 30 | json.dump(results, f) 31 | 32 | coco_dt = coco_gt.loadRes(str(result_file)) if results else COCO() 33 | 34 | coco_eval = COCOeval(coco_gt, coco_dt, iou_type) 35 | coco_eval.evaluate() 36 | coco_eval.accumulate() 37 | coco_eval.summarize() 38 | 39 | # compute_thresholds_for_classes(coco_eval) 40 | 41 | return coco_eval 42 | 43 | 44 | def compute_thresholds_for_classes(coco_eval): 45 | precision = coco_eval.eval['precision'] 46 | precision = precision[0, :, :, 0, -1] 47 | scores = coco_eval.eval['scores'] 48 | scores = scores[0, :, :, 0, -1] 49 | 50 | recall = np.linspace(0, 1, num=precision.shape[0]) 51 | recall = recall[:, None] 52 | 53 | f1 = (2 * precision * recall) / (np.maximum(precision + recall, 1e-6)) 54 | max_f1 = f1.max(0) 55 | max_f1_id = f1.argmax(0) 56 | scores = scores[max_f1_id, range(len(max_f1_id))] 57 | 58 | print('Maximum f1 for classes:') 59 | print(list(max_f1)) 60 | print('Score thresholds for classes') 61 | print(list(scores)) 62 | 63 | 64 | def make_coco_detection(predictions, dataset): 65 | coco_results = [] 66 | 67 | for id, pred in enumerate(predictions): 68 | orig_id = dataset.id2img[id] 69 | 70 | if len(pred) == 0: 71 | continue 72 | 73 | img_meta = dataset.get_image_meta(id) 74 | width = img_meta['width'] 75 | height = img_meta['height'] 76 | pred = pred.resize((width, height)) 77 | pred = pred.convert('xywh') 78 | 79 | boxes = pred.box.tolist() 80 | scores = pred.fields['scores'].tolist() 81 | labels = pred.fields['labels'].tolist() 82 | 83 | labels = 
[dataset.id2category[i] for i in labels] 84 | 85 | coco_results.extend( 86 | [ 87 | { 88 | 'image_id': orig_id, 89 | 'category_id': labels[k], 90 | 'bbox': box, 91 | 'score': scores[k], 92 | } 93 | for k, box in enumerate(boxes) 94 | ] 95 | ) 96 | 97 | return coco_results 98 | 99 | 100 | class COCOResult: 101 | METRICS = { 102 | 'bbox': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'], 103 | 'segm': ['AP', 'AP50', 'AP75', 'APs', 'APm', 'APl'], 104 | 'box_proposal': [ 105 | 'AR@100', 106 | 'ARs@100', 107 | 'ARm@100', 108 | 'ARl@100', 109 | 'AR@1000', 110 | 'ARs@1000', 111 | 'ARm@1000', 112 | 'ARl@1000', 113 | ], 114 | 'keypoints': ['AP', 'AP50', 'AP75', 'APm', 'APl'], 115 | } 116 | 117 | def __init__(self, *iou_types): 118 | allowed_types = ("box_proposal", "bbox", "segm", "keypoints") 119 | assert all(iou_type in allowed_types for iou_type in iou_types) 120 | results = OrderedDict() 121 | for iou_type in iou_types: 122 | results[iou_type] = OrderedDict( 123 | [(metric, -1) for metric in COCOResult.METRICS[iou_type]] 124 | ) 125 | self.results = results 126 | 127 | def update(self, coco_eval): 128 | if coco_eval is None: 129 | return 130 | 131 | assert isinstance(coco_eval, COCOeval) 132 | s = coco_eval.stats 133 | iou_type = coco_eval.params.iouType 134 | res = self.results[iou_type] 135 | metrics = COCOResult.METRICS[iou_type] 136 | for idx, metric in enumerate(metrics): 137 | res[metric] = s[idx] 138 | 139 | def __repr__(self): 140 | return repr(self.results) 141 | -------------------------------------------------------------------------------- /model/bi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__(self,in_channels,out_channels,kernel = 3,strides = 1,padding=1, 8 | bias = True,act = True,bn = True): 9 | super(ConvBlock,self).__init__() 10 | self.conv = 
nn.Conv2d(in_channels=in_channels,out_channels=out_channels, 11 | kernel_size=kernel,stride=strides,padding=padding, 12 | bias=bias) 13 | self.act = nn.ReLU(True) if act else None 14 | self.bn = nn.BatchNorm2d(in_channels) if bn else None 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | if self.bn: 19 | x = self.bn(x) 20 | if self.act: 21 | x = self.act(x) 22 | return x 23 | 24 | 25 | class BiBlock(nn.Module): 26 | 27 | def __init__(self,feature_size): 28 | super(BiBlock,self).__init__() 29 | 30 | 31 | self.maxpooling = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 32 | self.p7_to_p6 = ConvBlock(feature_size,feature_size) 33 | self.p6_to_p5 = ConvBlock(feature_size,feature_size) 34 | self.p5_to_p4 = ConvBlock(feature_size,feature_size) 35 | 36 | self.p3 = ConvBlock(feature_size,feature_size) 37 | self.p4 = ConvBlock(feature_size,feature_size) 38 | self.p5 = ConvBlock(feature_size,feature_size) 39 | self.p6 = ConvBlock(feature_size,feature_size) 40 | self.p7 = ConvBlock(feature_size,feature_size) 41 | 42 | 43 | 44 | def forward(self, pyramids): 45 | p3,p4,p5,p6,p7 = pyramids 46 | 47 | p7_to_p6 = F.upsample(p7,size=p6.shape[-2:]) 48 | # p7_to_p6 = F.upsample(p7,scale_factor=2) 49 | p7_to_p6 = self.p7_to_p6(p7_to_p6 + p6) 50 | 51 | p6_to_p5 = F.upsample(p7_to_p6,p5.shape[-2:]) 52 | # p6_to_p5 = F.upsample(p7_to_p6,scale_factor=2) 53 | p6_to_p5 = self.p6_to_p5(p6_to_p5 + p5) 54 | 55 | p5_to_p4 = F.upsample(p6_to_p5,size=p4.shape[-2:]) 56 | # p5_to_p4 = F.upsample(p6_to_p5,scale_factor=2) 57 | p5_to_p4 = self.p5_to_p4(p5_to_p4 + p4) 58 | 59 | p4_to_p3 = F.upsample(p5_to_p4,size=p3.shape[-2:]) 60 | # p4_to_p3 = F.upsample(p5_to_p4,scale_factor=2) 61 | p3 = self.p3(p4_to_p3 + p3) 62 | 63 | p3_to_p4 = self.maxpooling(p3) 64 | p4 = self.p4(p3_to_p4 + p5_to_p4 + p4) 65 | 66 | p4_to_p5 = self.maxpooling(p4) 67 | p5 = self.p5(p4_to_p5 + p6_to_p5 + p5) 68 | 69 | p5_to_p6 = self.maxpooling(p5) 70 | p5_to_p6 = F.upsample(p5_to_p6,size=p6.shape[-2:]) 71 | p6 = 
self.p6(p5_to_p6 + p7_to_p6 + p6) 72 | 73 | p6_to_p7 = self.maxpooling(p6) 74 | p6_to_p7 = F.upsample(p6_to_p7,size=p7.shape[-2:]) 75 | p7 = self.p7(p6_to_p7 + p7) 76 | 77 | return p3,p4,p5,p6,p7 78 | 79 | 80 | class BiFpn(nn.Module): 81 | 82 | def __init__(self,in_channels,out_channels,len_input,bi = 3): 83 | super(BiFpn,self).__init__() 84 | assert len_input <= 5 85 | self.len_input = len_input 86 | self.bi = bi 87 | self.default = 5 - len_input 88 | for i in range(len_input): 89 | setattr(self, 'p{}'.format(str(i)), ConvBlock(in_channels=in_channels[i], out_channels=out_channels, 90 | kernel=1, strides=1, padding=0, act=False, bn=False)) 91 | if self.default > 0: 92 | for i in range(self.default): 93 | setattr(self,'make_pyramid{}'.format(str(i)),ConvBlock(in_channels=in_channels[-1] if i == 0 else out_channels,out_channels=out_channels,kernel=3,strides=2, 94 | padding=1,act=False,bn=False)) 95 | for i in range(bi): 96 | setattr(self, 'biblock{}'.format(str(i)), BiBlock(out_channels)) 97 | 98 | def forward(self, inputs): 99 | pyramids = [] 100 | for i in range(self.len_input): 101 | pyramids.append(getattr(self,'p{}'.format(str(i)))(inputs[i])) 102 | 103 | if self.default > 0: 104 | x = inputs[-1] 105 | for i in range(self.default): 106 | x = getattr(self,'make_pyramid{}'.format(str(i)))(x) 107 | pyramids.append(x) 108 | 109 | for i in range(self.bi): 110 | pyramids = getattr(self,'biblock{}'.format(str(i)))(pyramids) 111 | 112 | return pyramids 113 | 114 | 115 | if __name__ == '__main__': 116 | p3 = torch.randn(2,512,64,64) 117 | p4 = torch.randn(2,1024,32,32) 118 | p5 = torch.randn(2,2048,16,16) 119 | 120 | b = BiFpn([512,1024,2048],256,3,3) 121 | py = b([p3,p4,p5]) 122 | for i in py: 123 | print(i.shape) 124 | 125 | -------------------------------------------------------------------------------- /tool/postprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .boxlist 
import BoxList, boxlist_nms, remove_small_box, cat_boxlist 5 | 6 | 7 | class FCOSPostprocessor(nn.Module): 8 | def __init__(self, threshold, top_n, nms_threshold, post_top_n, min_size, n_class): 9 | super().__init__() 10 | 11 | self.threshold = threshold 12 | self.top_n = top_n 13 | self.nms_threshold = nms_threshold 14 | self.post_top_n = post_top_n 15 | self.min_size = min_size 16 | self.n_class = n_class 17 | 18 | def forward_single_feature_map( 19 | self, location, cls_pred, box_pred, center_pred, image_sizes 20 | ): 21 | batch, channel, height, width = cls_pred.shape 22 | 23 | cls_pred = cls_pred.view(batch, channel, height, width).permute(0, 2, 3, 1) 24 | cls_pred = cls_pred.reshape(batch, -1, channel).sigmoid() 25 | 26 | box_pred = box_pred.view(batch, 4, height, width).permute(0, 2, 3, 1) 27 | box_pred = box_pred.reshape(batch, -1, 4) 28 | 29 | center_pred = center_pred.view(batch, 1, height, width).permute(0, 2, 3, 1) 30 | center_pred = center_pred.reshape(batch, -1).sigmoid() 31 | 32 | candid_ids = cls_pred > self.threshold 33 | top_ns = candid_ids.view(batch, -1).sum(1) 34 | top_ns = top_ns.clamp(max=self.top_n) 35 | 36 | cls_pred = cls_pred * center_pred[:, :, None] 37 | 38 | results = [] 39 | 40 | for i in range(batch): 41 | cls_p = cls_pred[i] 42 | candid_id = candid_ids[i] 43 | cls_p = cls_p[candid_id] 44 | candid_nonzero = candid_id.nonzero() 45 | box_loc = candid_nonzero[:, 0] 46 | class_id = candid_nonzero[:, 1] + 1 47 | 48 | box_p = box_pred[i] 49 | box_p = box_p[box_loc] 50 | loc = location[box_loc] 51 | 52 | top_n = top_ns[i] 53 | 54 | if candid_id.sum().item() > top_n.item(): 55 | cls_p, top_k_id = cls_p.topk(top_n, sorted=False) 56 | class_id = class_id[top_k_id] 57 | box_p = box_p[top_k_id] 58 | loc = loc[top_k_id] 59 | 60 | detections = torch.stack( 61 | [ 62 | loc[:, 0] - box_p[:, 0], 63 | loc[:, 1] - box_p[:, 1], 64 | loc[:, 0] + box_p[:, 2], 65 | loc[:, 1] + box_p[:, 3], 66 | ], 67 | 1, 68 | ) 69 | 70 | height, width = image_sizes[i] 71 
| 72 | boxlist = BoxList(detections, (int(width), int(height)), mode='xyxy') 73 | boxlist.fields['labels'] = class_id 74 | boxlist.fields['scores'] = torch.sqrt(cls_p) 75 | boxlist = boxlist.clip(remove_empty=False) 76 | boxlist = remove_small_box(boxlist, self.min_size) 77 | 78 | results.append(boxlist) 79 | 80 | return results 81 | 82 | def forward(self, location, cls_pred, box_pred, center_pred, image_sizes): 83 | boxes = [] 84 | 85 | for loc, cls_p, box_p, center_p in zip( 86 | location, cls_pred, box_pred, center_pred 87 | ): 88 | boxes.append( 89 | self.forward_single_feature_map( 90 | loc, cls_p, box_p, center_p, image_sizes 91 | ) 92 | ) 93 | 94 | boxlists = list(zip(*boxes)) 95 | boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] 96 | boxlists = self.select_over_scales(boxlists) 97 | 98 | return boxlists 99 | 100 | def select_over_scales(self, boxlists): 101 | results = [] 102 | 103 | for boxlist in boxlists: 104 | scores = boxlist.fields['scores'] 105 | labels = boxlist.fields['labels'] 106 | box = boxlist.box 107 | 108 | result = [] 109 | 110 | for j in range(1, self.n_class): 111 | id = (labels == j).nonzero().view(-1) 112 | score_j = scores[id] 113 | box_j = box[id, :].view(-1, 4) 114 | box_by_class = BoxList(box_j, boxlist.size, mode='xyxy') 115 | box_by_class.fields['scores'] = score_j 116 | box_by_class = boxlist_nms(box_by_class, score_j, self.nms_threshold) 117 | n_label = len(box_by_class) 118 | box_by_class.fields['labels'] = torch.full( 119 | (n_label,), j, dtype=torch.int64, device=scores.device 120 | ) 121 | result.append(box_by_class) 122 | 123 | result = cat_boxlist(result) 124 | n_detection = len(result) 125 | 126 | if n_detection > self.post_top_n > 0: 127 | scores = result.fields['scores'] 128 | img_threshold, _ = torch.kthvalue( 129 | scores.cpu(), n_detection - self.post_top_n + 1 130 | ) 131 | keep = scores >= img_threshold.item() 132 | keep = torch.nonzero(keep).squeeze(1) 133 | result = result[keep] 134 | 135 | 
results.append(result) 136 | 137 | return results 138 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision import datasets 5 | 6 | from tool.boxlist import BoxList 7 | 8 | 9 | CLASS_NAME = [ 10 | '__background__', 11 | 'person', 12 | 'bicycle', 13 | 'car', 14 | 'motorcycle', 15 | 'airplane', 16 | 'bus', 17 | 'train', 18 | 'truck', 19 | 'boat', 20 | 'traffic light', 21 | 'fire hydrant', 22 | 'stop sign', 23 | 'parking meter', 24 | 'bench', 25 | 'bird', 26 | 'cat', 27 | 'dog', 28 | 'horse', 29 | 'sheep', 30 | 'cow', 31 | 'elephant', 32 | 'bear', 33 | 'zebra', 34 | 'giraffe', 35 | 'backpack', 36 | 'umbrella', 37 | 'handbag', 38 | 'tie', 39 | 'suitcase', 40 | 'frisbee', 41 | 'skis', 42 | 'snowboard', 43 | 'sports ball', 44 | 'kite', 45 | 'baseball bat', 46 | 'baseball glove', 47 | 'skateboard', 48 | 'surfboard', 49 | 'tennis racket', 50 | 'bottle', 51 | 'wine glass', 52 | 'cup', 53 | 'fork', 54 | 'knife', 55 | 'spoon', 56 | 'bowl', 57 | 'banana', 58 | 'apple', 59 | 'sandwich', 60 | 'orange', 61 | 'broccoli', 62 | 'carrot', 63 | 'hot dog', 64 | 'pizza', 65 | 'donut', 66 | 'cake', 67 | 'chair', 68 | 'couch', 69 | 'potted plant', 70 | 'bed', 71 | 'dining table', 72 | 'toilet', 73 | 'tv', 74 | 'laptop', 75 | 'mouse', 76 | 'remote', 77 | 'keyboard', 78 | 'cell phone', 79 | 'microwave', 80 | 'oven', 81 | 'toaster', 82 | 'sink', 83 | 'refrigerator', 84 | 'book', 85 | 'clock', 86 | 'vase', 87 | 'scissors', 88 | 'teddy bear', 89 | 'hair drier', 90 | 'toothbrush', 91 | ] 92 | 93 | 94 | def has_only_empty_bbox(annot): 95 | return all(any(o <= 1 for o in obj['bbox'][2:]) for obj in annot) 96 | 97 | 98 | def has_valid_annotation(annot): 99 | if len(annot) == 0: 100 | return False 101 | 102 | if has_only_empty_bbox(annot): 103 | return False 104 | 105 | return True 106 | 107 | 108 | class 
COCODataset(datasets.CocoDetection): 109 | def __init__(self, path, split, transform=None): 110 | root = os.path.join(path, '{}2017'.format(split)) 111 | annot = os.path.join(path, 'annotations', 'instances_{}2017.json'.format(split)) 112 | 113 | super().__init__(root, annot) 114 | 115 | self.ids = sorted(self.ids) 116 | 117 | if split == 'train': 118 | ids = [] 119 | 120 | for id in self.ids: 121 | ann_ids = self.coco.getAnnIds(imgIds=id, iscrowd=None) 122 | annot = self.coco.loadAnns(ann_ids) 123 | 124 | if has_valid_annotation(annot): 125 | ids.append(id) 126 | 127 | self.ids = ids 128 | 129 | self.category2id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())} 130 | self.id2category = {v: k for k, v in self.category2id.items()} 131 | self.id2img = {k: v for k, v in enumerate(self.ids)} 132 | 133 | self.transform = transform 134 | 135 | def __getitem__(self, index): 136 | img, annot = super().__getitem__(index) 137 | 138 | annot = [o for o in annot if o['iscrowd'] == 0] 139 | 140 | boxes = [o['bbox'] for o in annot] 141 | boxes = torch.as_tensor(boxes).reshape(-1, 4) 142 | target = BoxList(boxes, img.size, mode='xywh').convert('xyxy') 143 | 144 | classes = [o['category_id'] for o in annot] 145 | classes = [self.category2id[c] for c in classes] 146 | classes = torch.tensor(classes) 147 | target.fields['labels'] = classes 148 | 149 | target.clip(remove_empty=True) 150 | 151 | if self.transform is not None: 152 | img, target = self.transform(img, target) 153 | 154 | return img, target, index 155 | 156 | def get_image_meta(self, index): 157 | id = self.id2img[index] 158 | img_data = self.coco.imgs[id] 159 | 160 | return img_data 161 | 162 | 163 | class ImageList: 164 | def __init__(self, tensors, sizes): 165 | self.tensors = tensors 166 | self.sizes = sizes 167 | 168 | def to(self, *args, **kwargs): 169 | tensor = self.tensors.to(*args, **kwargs) 170 | 171 | return ImageList(tensor, self.sizes) 172 | 173 | 174 | def image_list(tensors, size_divisible=0): 175 | 
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 176 | 177 | if size_divisible > 0: 178 | stride = size_divisible 179 | max_size = list(max_size) 180 | max_size[1] = (max_size[1] | (stride - 1)) + 1 181 | max_size[2] = (max_size[2] | (stride - 1)) + 1 182 | max_size = tuple(max_size) 183 | 184 | shape = (len(tensors),) + max_size 185 | batch = tensors[0].new(*shape).zero_() 186 | 187 | for img, pad_img in zip(tensors, batch): 188 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 189 | 190 | sizes = [img.shape[-2:] for img in tensors] 191 | 192 | return ImageList(batch, sizes) 193 | 194 | 195 | def collate_fn(config): 196 | def collate_data(batch): 197 | batch = list(zip(*batch)) 198 | imgs = image_list(batch[0], config.size_divisible) 199 | targets = batch[1] 200 | ids = batch[2] 201 | 202 | return imgs, targets, ids 203 | 204 | return collate_data 205 | -------------------------------------------------------------------------------- /tool/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def all_gather(data): 45 | world_size = get_world_size() 46 | 47 | if world_size == 1: 48 | return [data] 49 | 50 | buffer = pickle.dumps(data) 51 | storage = 
torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to('cuda') 53 | 54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 55 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 56 | dist.all_gather(size_list, local_size) 57 | size_list = [int(size.item()) for size in size_list] 58 | max_size = max(size_list) 59 | 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 63 | 64 | if local_size != max_size: 65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 66 | tensor = torch.cat((tensor, padding), 0) 67 | 68 | dist.all_gather(tensor_list, tensor) 69 | 70 | data_list = [] 71 | 72 | for size, tensor in zip(size_list, tensor_list): 73 | buffer = tensor.cpu().numpy().tobytes()[:size] 74 | data_list.append(pickle.loads(buffer)) 75 | 76 | return data_list 77 | 78 | 79 | def reduce_loss_dict(loss_dict): 80 | world_size = get_world_size() 81 | 82 | if world_size < 2: 83 | return loss_dict 84 | 85 | with torch.no_grad(): 86 | keys = [] 87 | losses = [] 88 | 89 | for k in sorted(loss_dict.keys()): 90 | keys.append(k) 91 | losses.append(loss_dict[k]) 92 | 93 | losses = torch.stack(losses, 0) 94 | dist.reduce(losses, dst=0) 95 | 96 | if dist.get_rank() == 0: 97 | losses /= world_size 98 | 99 | reduced_losses = {k: v for k, v in zip(keys, losses)} 100 | 101 | return reduced_losses 102 | 103 | 104 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 105 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 106 | # FIXME remove this once c10d fixes the bug it has 107 | 108 | 109 | class DistributedSampler(Sampler): 110 | """Sampler that restricts data loading to a subset of the dataset. 111 | It is especially useful in conjunction with 112 | :class:`torch.nn.parallel.DistributedDataParallel`. 
In such case, each 113 | process can pass a DistributedSampler instance as a DataLoader sampler, 114 | and load a subset of the original dataset that is exclusive to it. 115 | .. note:: 116 | Dataset is assumed to be of constant size. 117 | Arguments: 118 | dataset: Dataset used for sampling. 119 | num_replicas (optional): Number of processes participating in 120 | distributed training. 121 | rank (optional): Rank of the current process within num_replicas. 122 | """ 123 | 124 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 125 | if num_replicas is None: 126 | if not dist.is_available(): 127 | raise RuntimeError("Requires distributed package to be available") 128 | num_replicas = dist.get_world_size() 129 | if rank is None: 130 | if not dist.is_available(): 131 | raise RuntimeError("Requires distributed package to be available") 132 | rank = dist.get_rank() 133 | self.dataset = dataset 134 | self.num_replicas = num_replicas 135 | self.rank = rank 136 | self.epoch = 0 137 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 138 | self.total_size = self.num_samples * self.num_replicas 139 | self.shuffle = shuffle 140 | 141 | def __iter__(self): 142 | if self.shuffle: 143 | # deterministically shuffle based on epoch 144 | g = torch.Generator() 145 | g.manual_seed(self.epoch) 146 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 147 | else: 148 | indices = torch.arange(len(self.dataset)).tolist() 149 | 150 | # add extra samples to make it evenly divisible 151 | indices += indices[: (self.total_size - len(indices))] 152 | assert len(indices) == self.total_size 153 | 154 | # subsample 155 | offset = self.num_samples * self.rank 156 | indices = indices[offset : offset + self.num_samples] 157 | assert len(indices) == self.num_samples 158 | 159 | return iter(indices) 160 | 161 | def __len__(self): 162 | return self.num_samples 163 | 164 | def set_epoch(self, epoch): 165 | self.epoch = epoch 166 | 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn, optim 4 | from torch.utils.data import DataLoader, sampler 5 | from tqdm import tqdm 6 | from data.dataset import COCODataset, collate_fn 7 | 8 | from model.model import EfficientDet_Free 9 | from data.transform import preset_transform 10 | from tool.distributed import ( 11 | get_rank, 12 | synchronize, 13 | reduce_loss_dict, 14 | DistributedSampler, 15 | all_gather 16 | ) 17 | from tool.evaluate import evaluate 18 | from config import Config 19 | 20 | def accumulate_predictions(predictions): 21 | all_predictions = all_gather(predictions) 22 | 23 | if get_rank() != 0: 24 | return 25 | 26 | predictions = {} 27 | 28 | for p in all_predictions: 29 | predictions.update(p) 30 | 31 | ids = list(sorted(predictions.keys())) 32 | 33 | if len(ids) != ids[-1] + 1: 34 | print('Evaluation results is not contiguous') 35 | 36 | predictions = [predictions[i] for i in ids] 37 | 38 | return predictions 39 | 40 | @torch.no_grad() 41 | def valid(args, epoch, loader, dataset, model, device): 42 | if args.distributed: 43 | model = model.module 44 | 45 | torch.cuda.empty_cache() 46 | 47 | model.eval() 48 | 49 | pbar = tqdm(loader, dynamic_ncols=True) 50 | 51 | preds = {} 52 | 53 | for images, targets, ids in pbar: 54 | model.zero_grad() 55 | 56 | images = images.to(device) 57 | targets = [target.to(device) for target in targets] 58 | 59 | pred, _ = model(images.tensors, images.sizes) 60 | 61 | pred = [p.to('cpu') for p in pred] 62 | 63 | preds.update({id: p for id, p in zip(ids, pred)}) 64 | 65 | preds = accumulate_predictions(preds) 66 | 67 | if get_rank() != 0: 68 | return 69 | 70 | evaluate(dataset, preds) 71 | 72 | def train(args, epoch, loader, model, optimizer, device): 73 | model.train() 74 | if get_rank() == 0: 75 | pbar = tqdm(loader, 
dynamic_ncols=True) 76 | 77 | else: 78 | pbar = loader 79 | for images, targets, _ in pbar: 80 | model.zero_grad() 81 | 82 | images = images.to(device) 83 | targets = [target.to(device) for target in targets] 84 | 85 | _, loss_dict = model(images.tensors, targets=targets) 86 | loss_cls = loss_dict['loss_cls'].mean() 87 | loss_box = loss_dict['loss_box'].mean() 88 | loss_center = loss_dict['loss_center'].mean() 89 | 90 | loss = loss_cls + loss_box + loss_center 91 | loss.backward() 92 | nn.utils.clip_grad_norm_(model.parameters(), 10) 93 | optimizer.step() 94 | 95 | loss_reduced = reduce_loss_dict(loss_dict) 96 | loss_cls = loss_reduced['loss_cls'].mean().item() 97 | loss_box = loss_reduced['loss_box'].mean().item() 98 | loss_center = loss_reduced['loss_center'].mean().item() 99 | 100 | if get_rank() == 0: 101 | pbar.set_description( 102 | ( 103 | 'epoch: {}; cls: {}; box: {}; center: {}'.format(epoch + 1,loss_cls,loss_box,loss_center) 104 | ) 105 | ) 106 | 107 | def data_sampler(dataset, shuffle, distributed): 108 | if distributed: 109 | return DistributedSampler(dataset, shuffle=shuffle) 110 | 111 | if shuffle: 112 | return sampler.RandomSampler(dataset) 113 | 114 | else: 115 | return sampler.SequentialSampler(dataset) 116 | 117 | if __name__ == '__main__': 118 | args = Config() 119 | n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 120 | args.distributed = n_gpu > 1 121 | 122 | if args.distributed: 123 | torch.cuda.set_device(args.local_rank) 124 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 125 | synchronize() 126 | 127 | device = 'cuda:0' 128 | train_set = COCODataset(args.path, 'train', preset_transform(args, train=True)) 129 | valid_set = COCODataset(args.path, 'val', preset_transform(args, train=False)) 130 | 131 | 132 | model = EfficientDet_Free(args) 133 | model.load_state_dict() 134 | model = model.to(device) 135 | optimizer = optim.SGD( 136 | model.parameters(), 137 | lr=args.lr, 138 | 
momentum=0.9, 139 | weight_decay=args.l2, 140 | nesterov=True, 141 | ) 142 | scheduler = optim.lr_scheduler.MultiStepLR( 143 | optimizer, milestones=[16, 22], gamma=0.1 144 | ) 145 | if args.distributed: 146 | model = nn.parallel.DistributedDataParallel( 147 | model, 148 | device_ids=[args.local_rank], 149 | output_device=args.local_rank, 150 | broadcast_buffers=False, 151 | ) 152 | 153 | train_loader = DataLoader( 154 | train_set, 155 | batch_size=args.batch, 156 | sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed), 157 | num_workers=2, 158 | collate_fn=collate_fn(args), 159 | ) 160 | valid_loader = DataLoader( 161 | valid_set, 162 | batch_size=args.batch, 163 | sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed), 164 | num_workers=2, 165 | collate_fn=collate_fn(args), 166 | ) 167 | for epoch in range(args.epoch): 168 | train(args, epoch, train_loader, model, optimizer, device) 169 | valid(args, epoch, valid_loader, valid_set, model, device) 170 | 171 | scheduler.step() 172 | 173 | if get_rank() == 0: 174 | torch.save( 175 | {'model': model.module.state_dict(), 'optim': optimizer.state_dict()}, 176 | 'checkpoint/epoch-{epoch + 1}.pt', 177 | ) 178 | 179 | -------------------------------------------------------------------------------- /tool/boxlist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import ops 3 | 4 | 5 | FLIP_LEFT_RIGHT = 0 6 | FLIP_TOP_BOTTOM = 1 7 | 8 | 9 | class BoxList: 10 | def __init__(self, box, image_size, mode='xyxy'): 11 | device = box.device if hasattr(box, 'device') else 'cpu' 12 | box = torch.as_tensor(box, dtype=torch.float32, device=device) 13 | 14 | self.box = box 15 | self.size = image_size 16 | self.mode = mode 17 | 18 | self.fields = {} 19 | 20 | def convert(self, mode): 21 | if mode == self.mode: 22 | return self 23 | 24 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 25 | 26 | if mode == 'xyxy': 27 | box 
= torch.cat([x_min, y_min, x_max, y_max], -1) 28 | box = BoxList(box, self.size, mode=mode) 29 | 30 | elif mode == 'xywh': 31 | remove = 1 32 | box = torch.cat( 33 | [x_min, y_min, x_max - x_min + remove, y_max - y_min + remove], -1 34 | ) 35 | box = BoxList(box, self.size, mode=mode) 36 | 37 | box.copy_field(self) 38 | 39 | return box 40 | 41 | def copy_field(self, box): 42 | for k, v in box.fields.items(): 43 | self.fields[k] = v 44 | 45 | def area(self): 46 | box = self.box 47 | 48 | if self.mode == 'xyxy': 49 | remove = 1 50 | 51 | area = (box[:, 2] - box[:, 0] + remove) * (box[:, 3] - box[:, 1] + remove) 52 | 53 | elif self.mode == 'xywh': 54 | area = box[:, 2] * box[:, 3] 55 | 56 | return area 57 | 58 | def split_to_xyxy(self): 59 | if self.mode == 'xyxy': 60 | x_min, y_min, x_max, y_max = self.box.split(1, dim=-1) 61 | 62 | return x_min, y_min, x_max, y_max 63 | 64 | elif self.mode == 'xywh': 65 | remove = 1 66 | x_min, y_min, w, h = self.box.split(1, dim=-1) 67 | 68 | return ( 69 | x_min, 70 | y_min, 71 | x_min + (w - remove).clamp(min=0), 72 | y_min + (h - remove).clamp(min=0), 73 | ) 74 | 75 | def __len__(self): 76 | return self.box.shape[0] 77 | 78 | def __getitem__(self, index): 79 | box = BoxList(self.box[index], self.size, self.mode) 80 | 81 | for k, v in self.fields.items(): 82 | box.fields[k] = v[index] 83 | 84 | return box 85 | 86 | def resize(self, size, *args, **kwargs): 87 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) 88 | 89 | if ratios[0] == ratios[1]: 90 | ratio = ratios[0] 91 | scaled = self.box * ratio 92 | box = BoxList(scaled, size, mode=self.mode) 93 | 94 | for k, v in self.fields.items(): 95 | if not isinstance(v, torch.Tensor): 96 | v = v.resize(size, *args, **kwargs) 97 | 98 | box.fields[k] = v 99 | 100 | return box 101 | 102 | ratio_w, ratio_h = ratios 103 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 104 | scaled_x_min = x_min * ratio_w 105 | scaled_x_max = x_max * ratio_w 106 | scaled_y_min 
= y_min * ratio_h 107 | scaled_y_max = y_max * ratio_h 108 | scaled = torch.cat([scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max], -1) 109 | box = BoxList(scaled, size, mode='xyxy') 110 | 111 | for k, v in self.fields.items(): 112 | if not isinstance(v, torch.Tensor): 113 | v = v.resize(size, *args, **kwargs) 114 | 115 | box.fields[k] = v 116 | 117 | return box.convert(self.mode) 118 | 119 | def transpose(self, method): 120 | width, height = self.size 121 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 122 | 123 | if method == FLIP_LEFT_RIGHT: 124 | remove = 1 125 | 126 | transpose_x_min = width - x_max - remove 127 | transpose_x_max = width - x_min - remove 128 | transpose_y_min = y_min 129 | transpose_y_max = y_max 130 | 131 | elif method == FLIP_TOP_BOTTOM: 132 | transpose_x_min = x_min 133 | transpose_x_max = x_max 134 | transpose_y_min = height - y_max 135 | transpose_y_max = height - y_min 136 | 137 | transpose_box = torch.cat( 138 | [transpose_x_min, transpose_y_min, transpose_x_max, transpose_y_max], -1 139 | ) 140 | box = BoxList(transpose_box, self.size, mode='xyxy') 141 | 142 | for k, v in self.fields.items(): 143 | if not isinstance(v, torch.Tensor): 144 | v = v.transpose(method) 145 | 146 | box.fields[k] = v 147 | 148 | return box.convert(self.mode) 149 | 150 | def clip(self, remove_empty=True): 151 | remove = 1 152 | 153 | max_width = self.size[0] - remove 154 | max_height = self.size[1] - remove 155 | 156 | self.box[:, 0].clamp_(min=0, max=max_width) 157 | self.box[:, 1].clamp_(min=0, max=max_height) 158 | self.box[:, 2].clamp_(min=0, max=max_width) 159 | self.box[:, 3].clamp_(min=0, max=max_height) 160 | 161 | if remove_empty: 162 | box = self.box 163 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) 164 | 165 | return self[keep] 166 | 167 | else: 168 | return self 169 | 170 | def to(self, device): 171 | box = BoxList(self.box.to(device), self.size, self.mode) 172 | 173 | for k, v in self.fields.items(): 174 | if hasattr(v, 
'to'): 175 | v = v.to(device) 176 | 177 | box.fields[k] = v 178 | 179 | return box 180 | 181 | 182 | def remove_small_box(boxlist, min_size): 183 | box = boxlist.convert('xywh').box 184 | _, _, w, h = box.unbind(dim=1) 185 | keep = (w >= min_size) & (h >= min_size) 186 | keep = keep.nonzero().squeeze(1) 187 | 188 | return boxlist[keep] 189 | 190 | 191 | def cat_boxlist(boxlists): 192 | size = boxlists[0].size 193 | mode = boxlists[0].mode 194 | field_keys = boxlists[0].fields.keys() 195 | 196 | box_cat = torch.cat([boxlist.box for boxlist in boxlists], 0) 197 | new_boxlist = BoxList(box_cat, size, mode) 198 | 199 | for field in field_keys: 200 | data = torch.cat([boxlist.fields[field] for boxlist in boxlists], 0) 201 | new_boxlist.fields[field] = data 202 | 203 | return new_boxlist 204 | 205 | 206 | def boxlist_nms(boxlist, scores, threshold, max_proposal=-1): 207 | if threshold <= 0: 208 | return boxlist 209 | 210 | mode = boxlist.mode 211 | boxlist = boxlist.convert('xyxy') 212 | box = boxlist.box 213 | keep = ops.nms(box, scores, threshold) 214 | 215 | if max_proposal > 0: 216 | keep = keep[:max_proposal] 217 | 218 | boxlist = boxlist[keep] 219 | 220 | return boxlist.convert(mode) 221 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | INF = 100000000 6 | 7 | 8 | class IOULoss(nn.Module): 9 | def __init__(self, loc_loss_type): 10 | super().__init__() 11 | 12 | self.loc_loss_type = loc_loss_type 13 | 14 | def forward(self, out, target, weight=None): 15 | pred_left, pred_top, pred_right, pred_bottom = out.unbind(1) 16 | target_left, target_top, target_right, target_bottom = target.unbind(1) 17 | 18 | target_area = (target_left + target_right) * (target_top + target_bottom) 19 | pred_area = (pred_left + pred_right) * (pred_top + pred_bottom) 20 | 21 | w_intersect = torch.min(pred_left, 
target_left) + torch.min( 22 | pred_right, target_right 23 | ) 24 | h_intersect = torch.min(pred_bottom, target_bottom) + torch.min( 25 | pred_top, target_top 26 | ) 27 | 28 | area_intersect = w_intersect * h_intersect 29 | area_union = target_area + pred_area - area_intersect 30 | 31 | ious = (area_intersect + 1) / (area_union + 1) 32 | 33 | if self.loc_loss_type == 'iou': 34 | loss = -torch.log(ious) 35 | 36 | elif self.loc_loss_type == 'giou': 37 | g_w_intersect = torch.max(pred_left, target_left) + torch.max( 38 | pred_right, target_right 39 | ) 40 | g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max( 41 | pred_top, target_top 42 | ) 43 | g_intersect = g_w_intersect * g_h_intersect + 1e-7 44 | gious = ious - (g_intersect - area_union) / g_intersect 45 | 46 | loss = 1 - gious 47 | 48 | if weight is not None and weight.sum() > 0: 49 | return (loss * weight).sum() / weight.sum() 50 | 51 | else: 52 | return loss.mean() 53 | 54 | 55 | def clip_sigmoid(input): 56 | out = torch.clamp(torch.sigmoid(input), min=1e-4, max=1 - 1e-4) 57 | 58 | return out 59 | 60 | 61 | class SigmoidFocalLoss(nn.Module): 62 | def __init__(self, gamma, alpha): 63 | super().__init__() 64 | 65 | self.gamma = gamma 66 | self.alpha = alpha 67 | 68 | def forward(self, out, target): 69 | n_class = out.shape[1] 70 | class_ids = torch.arange( 71 | 1, n_class + 1, dtype=target.dtype, device=target.device 72 | ).unsqueeze(0) 73 | 74 | t = target.unsqueeze(1) 75 | p = torch.sigmoid(out) 76 | 77 | gamma = self.gamma 78 | alpha = self.alpha 79 | 80 | term1 = (1 - p) ** gamma * torch.log(p) 81 | term2 = p ** gamma * torch.log(1 - p) 82 | 83 | # print(term1.sum(), term2.sum()) 84 | 85 | loss = ( 86 | -(t == class_ids).float() * alpha * term1 87 | - ((t != class_ids) * (t >= 0)).float() * (1 - alpha) * term2 88 | ) 89 | 90 | return loss.sum() 91 | 92 | 93 | class FCOSLoss(nn.Module): 94 | def __init__( 95 | self, sizes, gamma, alpha, iou_loss_type, center_sample, fpn_strides, pos_radius 96 | 
): 97 | super().__init__() 98 | 99 | self.sizes = sizes 100 | 101 | self.cls_loss = SigmoidFocalLoss(gamma, alpha) 102 | self.box_loss = IOULoss(iou_loss_type) 103 | self.center_loss = nn.BCEWithLogitsLoss() 104 | 105 | self.center_sample = center_sample 106 | self.strides = fpn_strides 107 | self.radius = pos_radius 108 | 109 | def prepare_target(self, points, targets): 110 | ex_size_of_interest = [] 111 | 112 | for i, point_per_level in enumerate(points): 113 | size_of_interest_per_level = point_per_level.new_tensor(self.sizes[i]) 114 | ex_size_of_interest.append( 115 | size_of_interest_per_level[None].expand(len(point_per_level), -1) 116 | ) 117 | 118 | ex_size_of_interest = torch.cat(ex_size_of_interest, 0) 119 | n_point_per_level = [len(point_per_level) for point_per_level in points] 120 | point_all = torch.cat(points, dim=0) 121 | label, box_target = self.compute_target_for_location( 122 | point_all, targets, ex_size_of_interest, n_point_per_level 123 | ) 124 | 125 | for i in range(len(label)): 126 | label[i] = torch.split(label[i], n_point_per_level, 0) 127 | box_target[i] = torch.split(box_target[i], n_point_per_level, 0) 128 | 129 | label_level_first = [] 130 | box_target_level_first = [] 131 | 132 | for level in range(len(points)): 133 | label_level_first.append( 134 | torch.cat([label_per_img[level] for label_per_img in label], 0) 135 | ) 136 | box_target_level_first.append( 137 | torch.cat( 138 | [box_target_per_img[level] for box_target_per_img in box_target], 0 139 | ) 140 | ) 141 | 142 | return label_level_first, box_target_level_first 143 | 144 | def get_sample_region(self, gt, strides, n_point_per_level, xs, ys, radius=1): 145 | n_gt = gt.shape[0] 146 | n_loc = len(xs) 147 | gt = gt[None].expand(n_loc, n_gt, 4) 148 | center_x = (gt[..., 0] + gt[..., 2]) / 2 149 | center_y = (gt[..., 1] + gt[..., 3]) / 2 150 | 151 | if center_x[..., 0].sum() == 0: 152 | return xs.new_zeros(xs.shape, dtype=torch.uint8) 153 | 154 | begin = 0 155 | 156 | center_gt = 
gt.new_zeros(gt.shape) 157 | 158 | for level, n_p in enumerate(n_point_per_level): 159 | end = begin + n_p 160 | stride = strides[level] * radius 161 | 162 | x_min = center_x[begin:end] - stride 163 | y_min = center_y[begin:end] - stride 164 | x_max = center_x[begin:end] + stride 165 | y_max = center_y[begin:end] + stride 166 | 167 | center_gt[begin:end, :, 0] = torch.where( 168 | x_min > gt[begin:end, :, 0], x_min, gt[begin:end, :, 0] 169 | ) 170 | center_gt[begin:end, :, 1] = torch.where( 171 | y_min > gt[begin:end, :, 1], y_min, gt[begin:end, :, 1] 172 | ) 173 | center_gt[begin:end, :, 2] = torch.where( 174 | x_max > gt[begin:end, :, 2], gt[begin:end, :, 2], x_max 175 | ) 176 | center_gt[begin:end, :, 3] = torch.where( 177 | y_max > gt[begin:end, :, 3], gt[begin:end, :, 3], y_max 178 | ) 179 | 180 | begin = end 181 | 182 | left = xs[:, None] - center_gt[..., 0] 183 | right = center_gt[..., 2] - xs[:, None] 184 | top = ys[:, None] - center_gt[..., 1] 185 | bottom = center_gt[..., 3] - ys[:, None] 186 | 187 | center_bbox = torch.stack((left, top, right, bottom), -1) 188 | is_in_boxes = center_bbox.min(-1)[0] > 0 189 | 190 | return is_in_boxes 191 | 192 | def compute_target_for_location( 193 | self, locations, targets, sizes_of_interest, n_point_per_level 194 | ): 195 | labels = [] 196 | box_targets = [] 197 | xs, ys = locations[:, 0], locations[:, 1] 198 | 199 | for i in range(len(targets)): 200 | targets_per_img = targets[i] 201 | assert targets_per_img.mode == 'xyxy' 202 | bboxes = targets_per_img.box 203 | labels_per_img = targets_per_img.fields['labels'] 204 | area = targets_per_img.area() 205 | 206 | l = xs[:, None] - bboxes[:, 0][None] 207 | t = ys[:, None] - bboxes[:, 1][None] 208 | r = bboxes[:, 2][None] - xs[:, None] 209 | b = bboxes[:, 3][None] - ys[:, None] 210 | 211 | box_targets_per_img = torch.stack([l, t, r, b], 2) 212 | 213 | if self.center_sample: 214 | is_in_boxes = self.get_sample_region( 215 | bboxes, self.strides, n_point_per_level, xs, ys, 
radius=self.radius 216 | ) 217 | 218 | else: 219 | is_in_boxes = box_targets_per_img.min(2)[0] > 0 220 | 221 | max_box_targets_per_img = box_targets_per_img.max(2)[0] 222 | 223 | is_cared_in_level = ( 224 | max_box_targets_per_img >= sizes_of_interest[:, [0]] 225 | ) & (max_box_targets_per_img <= sizes_of_interest[:, [1]]) 226 | 227 | locations_to_gt_area = area[None].repeat(len(locations), 1) 228 | locations_to_gt_area[is_in_boxes == 0] = INF 229 | locations_to_gt_area[is_cared_in_level == 0] = INF 230 | 231 | locations_to_min_area, locations_to_gt_id = locations_to_gt_area.min(1) 232 | 233 | box_targets_per_img = box_targets_per_img[ 234 | range(len(locations)), locations_to_gt_id 235 | ] 236 | labels_per_img = labels_per_img[locations_to_gt_id] 237 | labels_per_img[locations_to_min_area == INF] = 0 238 | 239 | labels.append(labels_per_img) 240 | box_targets.append(box_targets_per_img) 241 | 242 | return labels, box_targets 243 | 244 | def compute_centerness_targets(self, box_targets): 245 | left_right = box_targets[:, [0, 2]] 246 | top_bottom = box_targets[:, [1, 3]] 247 | centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * ( 248 | top_bottom.min(-1)[0] / top_bottom.max(-1)[0] 249 | ) 250 | 251 | return torch.sqrt(centerness) 252 | 253 | def forward(self, locations, cls_pred, box_pred, center_pred, targets): 254 | batch = cls_pred[0].shape[0] 255 | n_class = cls_pred[0].shape[1] 256 | 257 | labels, box_targets = self.prepare_target(locations, targets) 258 | 259 | cls_flat = [] 260 | box_flat = [] 261 | center_flat = [] 262 | 263 | labels_flat = [] 264 | box_targets_flat = [] 265 | 266 | for i in range(len(labels)): 267 | cls_flat.append(cls_pred[i].permute(0, 2, 3, 1).reshape(-1, n_class)) 268 | box_flat.append(box_pred[i].permute(0, 2, 3, 1).reshape(-1, 4)) 269 | center_flat.append(center_pred[i].permute(0, 2, 3, 1).reshape(-1)) 270 | 271 | labels_flat.append(labels[i].reshape(-1)) 272 | box_targets_flat.append(box_targets[i].reshape(-1, 4)) 273 | 
274 | cls_flat = torch.cat(cls_flat, 0) 275 | box_flat = torch.cat(box_flat, 0) 276 | center_flat = torch.cat(center_flat, 0) 277 | 278 | labels_flat = torch.cat(labels_flat, 0) 279 | box_targets_flat = torch.cat(box_targets_flat, 0) 280 | 281 | pos_id = torch.nonzero(labels_flat > 0).squeeze(1) 282 | 283 | cls_loss = self.cls_loss(cls_flat, labels_flat.int()) / (pos_id.numel() + batch) 284 | 285 | box_flat = box_flat[pos_id] 286 | center_flat = center_flat[pos_id] 287 | 288 | box_targets_flat = box_targets_flat[pos_id] 289 | 290 | if pos_id.numel() > 0: 291 | center_targets = self.compute_centerness_targets(box_targets_flat) 292 | 293 | box_loss = self.box_loss(box_flat, box_targets_flat, center_targets) 294 | center_loss = self.center_loss(center_flat, center_targets) 295 | 296 | else: 297 | box_loss = box_flat.sum() 298 | center_loss = center_flat.sum() 299 | 300 | return cls_loss, box_loss, center_loss 301 | -------------------------------------------------------------------------------- /model/efficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import ( 6 | round_filters, 7 | round_repeats, 8 | drop_connect, 9 | get_same_padding_conv2d, 10 | get_model_params, 11 | efficientnet_params, 12 | load_pretrained_weights, 13 | Swish, 14 | MemoryEfficientSwish, 15 | ) 16 | 17 | class MBConvBlock(nn.Module): 18 | """ 19 | Mobile Inverted Residual Bottleneck Block 20 | Args: 21 | block_args (namedtuple): BlockArgs, see above 22 | global_params (namedtuple): GlobalParam, see above 23 | Attributes: 24 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
25 | """ 26 | 27 | def __init__(self, block_args, global_params): 28 | super().__init__() 29 | self._block_args = block_args 30 | self._bn_mom = 1 - global_params.batch_norm_momentum 31 | self._bn_eps = global_params.batch_norm_epsilon 32 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 33 | self.id_skip = block_args.id_skip # skip connection and drop connect 34 | 35 | # Get static or dynamic convolution depending on image size 36 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 37 | 38 | # Expansion phase 39 | inp = self._block_args.input_filters # number of input channels 40 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 41 | if self._block_args.expand_ratio != 1: 42 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 43 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 44 | # Depthwise convolution phase 45 | k = self._block_args.kernel_size 46 | s = self._block_args.stride 47 | self._depthwise_conv = Conv2d( 48 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 49 | kernel_size=k, stride=s, bias=False) 50 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 51 | 52 | # Squeeze and Excitation layer, if desired 53 | if self.has_se: 54 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 55 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 56 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 57 | 58 | # Output phase 59 | final_oup = self._block_args.output_filters 60 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 61 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 62 | self._swish = 
MemoryEfficientSwish() 63 | 64 | def forward(self, inputs, drop_connect_rate=None): 65 | """ 66 | :param inputs: input tensor 67 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 68 | :return: output of block 69 | """ 70 | 71 | # Expansion and Depthwise Convolution 72 | x = inputs 73 | if self._block_args.expand_ratio != 1: 74 | x = self._swish(self._bn0(self._expand_conv(inputs))) 75 | 76 | x = self._swish(self._bn1(self._depthwise_conv(x))) 77 | 78 | # Squeeze and Excitation 79 | if self.has_se: 80 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 81 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 82 | x = torch.sigmoid(x_squeezed) * x 83 | 84 | x = self._bn2(self._project_conv(x)) 85 | 86 | # Skip connection and drop connect 87 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 88 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 89 | if drop_connect_rate: 90 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 91 | x = x + inputs # skip connection 92 | return x 93 | 94 | def set_swish(self, memory_efficient=True): 95 | """Sets swish function as memory efficient (for training) or standard (for export)""" 96 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 97 | 98 | 99 | class EfficientNet(nn.Module): 100 | """ 101 | An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods 102 | Args: 103 | blocks_args (list): A list of BlockArgs to construct blocks 104 | global_params (namedtuple): A set of GlobalParams shared between blocks 105 | Example: 106 | model = EfficientNet.from_pretrained('efficientnet-b0') 107 | """ 108 | 109 | def __init__(self, blocks_args=None, global_params=None): 110 | super().__init__() 111 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 112 | assert len(blocks_args) > 0, 'block args must be greater than 0' 113 | self._global_params = global_params 114 | self._blocks_args = blocks_args 115 | 116 | # Get static or dynamic convolution depending on image size 117 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 118 | 119 | # Batch norm parameters 120 | bn_mom = 1 - self._global_params.batch_norm_momentum 121 | bn_eps = self._global_params.batch_norm_epsilon 122 | 123 | # Stem 124 | in_channels = 3 # rgb 125 | out_channels = round_filters(32, self._global_params) # number of output channels 126 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 127 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 128 | 129 | # Build blocks 130 | self._blocks = nn.ModuleList([]) 131 | for i in range(len(self._blocks_args)): 132 | # Update block input and output filters based on depth multiplier. 133 | self._blocks_args[i] = self._blocks_args[i]._replace( 134 | input_filters=round_filters(self._blocks_args[i].input_filters, self._global_params), 135 | output_filters=round_filters(self._blocks_args[i].output_filters, self._global_params), 136 | num_repeat=round_repeats(self._blocks_args[i].num_repeat, self._global_params) 137 | ) 138 | 139 | # The first block needs to take care of stride and filter size increase. 
140 | self._blocks.append(MBConvBlock(self._blocks_args[i], self._global_params)) 141 | if self._blocks_args[i].num_repeat > 1: 142 | self._blocks_args[i] = self._blocks_args[i]._replace(input_filters=self._blocks_args[i].output_filters, stride=1) 143 | for _ in range(self._blocks_args[i].num_repeat - 1): 144 | self._blocks.append(MBConvBlock(self._blocks_args[i], self._global_params)) 145 | 146 | # Head'efficientdet-d0': 'efficientnet-b0', 147 | in_channels = self._blocks_args[len(self._blocks_args)-1].output_filters # output of final block 148 | out_channels = round_filters(1280, self._global_params) 149 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 150 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 151 | 152 | # Final linear layer 153 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 154 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 155 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 156 | self._swish = MemoryEfficientSwish() 157 | 158 | def set_swish(self, memory_efficient=True): 159 | """Sets swish function as memory efficient (for training) or standard (for export)""" 160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 161 | for block in self._blocks: 162 | block.set_swish(memory_efficient) 163 | 164 | 165 | def extract_features(self, inputs): 166 | """ Returns output of the final convolution layer """ 167 | # Stem 168 | x = self._swish(self._bn0(self._conv_stem(inputs))) 169 | 170 | P = [] 171 | index = 0 172 | num_repeat = 0 173 | # Blocks 174 | for idx, block in enumerate(self._blocks): 175 | drop_connect_rate = self._global_params.drop_connect_rate 176 | if drop_connect_rate: 177 | drop_connect_rate *= float(idx) / len(self._blocks) 178 | x = block(x, drop_connect_rate=drop_connect_rate) 179 | num_repeat = num_repeat + 1 180 | if(num_repeat == self._blocks_args[index].num_repeat): 181 | num_repeat = 0 182 | index = index + 1 
183 | P.append(x) 184 | return P 185 | 186 | def forward(self, inputs): 187 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ 188 | # Convolution layers 189 | P = self.extract_features(inputs) 190 | return P 191 | 192 | @classmethod 193 | def from_name(cls, model_name, override_params=None): 194 | cls._check_model_name_is_valid(model_name) 195 | blocks_args, global_params = get_model_params(model_name, override_params) 196 | return cls(blocks_args, global_params) 197 | 198 | @classmethod 199 | def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3): 200 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 201 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 202 | if in_channels != 3: 203 | Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) 204 | out_channels = round_filters(32, model._global_params) 205 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 206 | return model 207 | 208 | @classmethod 209 | def from_pretrained(cls, model_name, num_classes=1000): 210 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 211 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 212 | 213 | return model 214 | 215 | @classmethod 216 | def get_image_size(cls, model_name): 217 | cls._check_model_name_is_valid(model_name) 218 | _, _, res, _ = efficientnet_params(model_name) 219 | return res 220 | 221 | @classmethod 222 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): 223 | """ Validates model name. None that pretrained weights are only available for 224 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. 
""" 225 | num_models = 4 if also_need_pretrained_weights else 8 226 | valid_models = ['efficientnet-b'+str(i) for i in range(num_models)] 227 | if model_name not in valid_models: 228 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 229 | 230 | def get_list_features(self): 231 | list_feature = [] 232 | for idx in range(len(self._blocks_args)): 233 | list_feature.append(self._blocks_args[idx].output_filters) 234 | 235 | return list_feature 236 | 237 | 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import collections 4 | from functools import partial 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils import model_zoo 9 | 10 | ######################################################################## 11 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### 12 | ######################################################################## 13 | 14 | 15 | # Parameters for the entire model (stem, all blocks, and head) 16 | GlobalParams = collections.namedtuple('GlobalParams', [ 17 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 18 | 'num_classes', 'width_coefficient', 'depth_coefficient', 19 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 20 | 21 | # Parameters for an individual model block 22 | BlockArgs = collections.namedtuple('BlockArgs', [ 23 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 24 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) 25 | 26 | # Change namedtuple defaults 27 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 28 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 29 | 30 | 31 | class SwishImplementation(torch.autograd.Function): 32 | @staticmethod 33 | def forward(ctx, 
i): 34 | result = i * torch.sigmoid(i) 35 | ctx.save_for_backward(i) 36 | return result 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | i = ctx.saved_variables[0] 41 | sigmoid_i = torch.sigmoid(i) 42 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 43 | 44 | 45 | class MemoryEfficientSwish(nn.Module): 46 | def forward(self, x): 47 | return SwishImplementation.apply(x) 48 | 49 | class Swish(nn.Module): 50 | def forward(self, x): 51 | return x * torch.sigmoid(x) 52 | 53 | 54 | def round_filters(filters, global_params): 55 | """ Calculate and round number of filters based on depth multiplier. """ 56 | multiplier = global_params.width_coefficient 57 | if not multiplier: 58 | return filters 59 | divisor = global_params.depth_divisor 60 | min_depth = global_params.min_depth 61 | filters *= multiplier 62 | min_depth = min_depth or divisor 63 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 64 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 65 | new_filters += divisor 66 | return int(new_filters) 67 | 68 | 69 | def round_repeats(repeats, global_params): 70 | """ Round number of filters based on depth multiplier. """ 71 | multiplier = global_params.depth_coefficient 72 | if not multiplier: 73 | return repeats 74 | return int(math.ceil(multiplier * repeats)) 75 | 76 | 77 | def drop_connect(inputs, p, training): 78 | """ Drop connect. """ 79 | if not training: return inputs 80 | batch_size = inputs.shape[0] 81 | keep_prob = 1 - p 82 | random_tensor = keep_prob 83 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 84 | binary_tensor = torch.floor(random_tensor) 85 | output = inputs / keep_prob * binary_tensor 86 | return output 87 | 88 | 89 | def get_same_padding_conv2d(image_size=None): 90 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. 91 | Static padding is necessary for ONNX exporting of models. 
""" 92 | if image_size is None: 93 | return Conv2dDynamicSamePadding 94 | else: 95 | return partial(Conv2dStaticSamePadding, image_size=image_size) 96 | 97 | 98 | class Conv2dDynamicSamePadding(nn.Conv2d): 99 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 100 | 101 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 102 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 103 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 104 | 105 | def forward(self, x): 106 | ih, iw = x.size()[-2:] 107 | kh, kw = self.weight.size()[-2:] 108 | sh, sw = self.stride 109 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 110 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 111 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 112 | if pad_h > 0 or pad_w > 0: 113 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 114 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 115 | 116 | 117 | class Conv2dStaticSamePadding(nn.Conv2d): 118 | """ 2D Convolutions like TensorFlow, for a fixed image size""" 119 | 120 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): 121 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 122 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 123 | 124 | # Calculate padding based on image size and save it 125 | assert image_size is not None 126 | ih, iw = image_size if type(image_size) == list else [image_size, image_size] 127 | kh, kw = self.weight.size()[-2:] 128 | sh, sw = self.stride 129 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 130 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 131 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * 
self.dilation[1] + 1 - iw, 0) 132 | if pad_h > 0 or pad_w > 0: 133 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 134 | else: 135 | self.static_padding = Identity() 136 | 137 | def forward(self, x): 138 | x = self.static_padding(x) 139 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 140 | return x 141 | 142 | 143 | class Identity(nn.Module): 144 | def __init__(self, ): 145 | super(Identity, self).__init__() 146 | 147 | def forward(self, input): 148 | return input 149 | 150 | 151 | ######################################################################## 152 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## 153 | ######################################################################## 154 | 155 | 156 | def efficientnet_params(model_name): 157 | """ Map EfficientNet model name to parameter coefficients. """ 158 | params_dict = { 159 | # Coefficients: width,depth,res,dropout 160 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 161 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 162 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 163 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 164 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 165 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 166 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 167 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 168 | } 169 | return params_dict[model_name] 170 | 171 | 172 | class BlockDecoder(object): 173 | """ Block Decoder for readability, straight from the official TensorFlow repository """ 174 | 175 | @staticmethod 176 | def _decode_block_string(block_string): 177 | """ Gets a block through a string notation of arguments. 
""" 178 | assert isinstance(block_string, str) 179 | 180 | ops = block_string.split('_') 181 | options = {} 182 | for op in ops: 183 | splits = re.split(r'(\d.*)', op) 184 | if len(splits) >= 2: 185 | key, value = splits[:2] 186 | options[key] = value 187 | 188 | # Check stride 189 | assert (('s' in options and len(options['s']) == 1) or 190 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 191 | 192 | return BlockArgs( 193 | kernel_size=int(options['k']), 194 | num_repeat=int(options['r']), 195 | input_filters=int(options['i']), 196 | output_filters=int(options['o']), 197 | expand_ratio=int(options['e']), 198 | id_skip=('noskip' not in block_string), 199 | se_ratio=float(options['se']) if 'se' in options else None, 200 | stride=[int(options['s'][0])]) 201 | 202 | @staticmethod 203 | def _encode_block_string(block): 204 | """Encodes a block to a string.""" 205 | args = [ 206 | 'r%d' % block.num_repeat, 207 | 'k%d' % block.kernel_size, 208 | 's%d%d' % (block.strides[0], block.strides[1]), 209 | 'e%s' % block.expand_ratio, 210 | 'i%d' % block.input_filters, 211 | 'o%d' % block.output_filters 212 | ] 213 | if 0 < block.se_ratio <= 1: 214 | args.append('se%s' % block.se_ratio) 215 | if block.id_skip is False: 216 | args.append('noskip') 217 | return '_'.join(args) 218 | 219 | @staticmethod 220 | def decode(string_list): 221 | """ 222 | Decodes a list of string notations to specify blocks inside the network. 223 | :param string_list: a list of strings, each string is a notation of block 224 | :return: a list of BlockArgs namedtuples of block args 225 | """ 226 | assert isinstance(string_list, list) 227 | blocks_args = [] 228 | for block_string in string_list: 229 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 230 | return blocks_args 231 | 232 | @staticmethod 233 | def encode(blocks_args): 234 | """ 235 | Encodes a list of BlockArgs to a list of strings. 
236 | :param blocks_args: a list of BlockArgs namedtuples of block args 237 | :return: a list of strings, each string is a notation of block 238 | """ 239 | block_strings = [] 240 | for block in blocks_args: 241 | block_strings.append(BlockDecoder._encode_block_string(block)) 242 | return block_strings 243 | 244 | 245 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, 246 | drop_connect_rate=0.2, image_size=None, num_classes=1000): 247 | """ Creates a efficientnet model. """ 248 | 249 | blocks_args = [ 250 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 251 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 252 | 'r3_k5_s22_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 253 | 'r1_k3_s22_e6_i192_o320_se0.25', 254 | ] 255 | blocks_args = BlockDecoder.decode(blocks_args) 256 | 257 | global_params = GlobalParams( 258 | batch_norm_momentum=0.99, 259 | batch_norm_epsilon=1e-3, 260 | dropout_rate=dropout_rate, 261 | drop_connect_rate=drop_connect_rate, 262 | # data_format='channels_last', # removed, this is always true in PyTorch 263 | num_classes=num_classes, 264 | width_coefficient=width_coefficient, 265 | depth_coefficient=depth_coefficient, 266 | depth_divisor=8, 267 | min_depth=None, 268 | image_size=image_size, 269 | ) 270 | 271 | return blocks_args, global_params 272 | 273 | 274 | def get_model_params(model_name, override_params): 275 | """ Get the block args and global params for a given model """ 276 | if model_name.startswith('efficientnet'): 277 | w, d, s, p = efficientnet_params(model_name) 278 | # note: all models have drop connect rate = 0.2 279 | blocks_args, global_params = efficientnet( 280 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 281 | else: 282 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 283 | if override_params: 284 | # ValueError will be raised here if override_params has fields not included in global_params. 
285 | global_params = global_params._replace(**override_params) 286 | return blocks_args, global_params 287 | 288 | 289 | url_map = { 290 | 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', 291 | 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', 292 | 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', 293 | 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', 294 | 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', 295 | 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', 296 | 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', 297 | 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', 298 | } 299 | 300 | 301 | def load_pretrained_weights(model, model_name, load_fc=True): 302 | """ Loads pretrained weights, and downloads if loading for the first time. 
""" 303 | state_dict = model_zoo.load_url(url_map[model_name]) 304 | if load_fc: 305 | model.load_state_dict(state_dict) 306 | else: 307 | state_dict.pop('_fc.weight') 308 | state_dict.pop('_fc.bias') 309 | res = model.load_state_dict(state_dict, strict=False) 310 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' 311 | print('Loaded pretrained weights for {}'.format(model_name)) -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 94 | 95 | 96 | 112 | 113 | 114 | 115 | 116 | true 117 | DEFINITION_ORDER 118 | 119 | 120 | 121 | 122 | 123 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 |