├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── backbone.py ├── coco_eval.py ├── detect_process.py ├── efficientService.py ├── efficientdet ├── config.py ├── dataset.py ├── loss.py ├── model.py └── utils.py ├── efficientdet_test.py ├── efficientdet_test_videos.py ├── efficientnet ├── __init__.py ├── model.py ├── utils.py └── utils_extra.py ├── projects ├── coco.yml └── shape.yml ├── test ├── img.png ├── img_inferred_d0_official.jpg ├── img_inferred_d0_this_repo.jpg └── img_inferred_d0_this_repo_0.jpg ├── train.py ├── tutorial └── train_shape.ipynb └── utils ├── sync_batchnorm ├── __init__.py ├── batchnorm.py ├── batchnorm_reimpl.py ├── comm.py ├── replicate.py └── unittest.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | weights/*.pt 2 | *.pt 3 | weights/ 4 | */__pycache__ 5 | */runs 6 | */.idea 7 | */weights 8 | idea/ 9 | idea/*.* 10 | __pycache__/ 11 | __pycache__/*.* 12 | 13 | /.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 
21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 
48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 
90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | A system of object detection based on Yet-Another-EfficientDet-Pytorch 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | from io import BytesIO 4 | 5 | import requests as req 6 | from PIL import Image 7 | from flask import Flask, request, render_template 8 | 9 | import efficientService as service 10 | 11 | # flask web service 12 | app = Flask(__name__, template_folder="web") 13 | 14 | 15 | @app.route('/detect') 16 | def detect(): 17 | return render_template('detect.html') 18 | 19 | 20 | @app.route('/index') 21 | def index(): 22 | return render_template('index.html') 23 | 24 | 25 | @app.route('/detect/imageDetect', methods=['post']) 26 | def upload(): 27 | # step 1. receive image 28 | file = request.form.get('imageBase64Code') 29 | image_link = request.form.get("imageLink") 30 | 31 | if image_link: 32 | response = req.get(image_link) 33 | image = Image.open(BytesIO(response.content)) 34 | else: 35 | image = Image.open(BytesIO(base64.b64decode(file))) 36 | 37 | # step 2. detect image 38 | image_array = service.detect(image) 39 | 40 | # step 3. convert image_array to byte_array 41 | img = Image.fromarray(image_array, 'RGB') 42 | img_byte_array = io.BytesIO() 43 | img.save(img_byte_array, format='JPEG') 44 | 45 | # step 4. 
return image_info to page 46 | image_info = base64.b64encode(img_byte_array.getvalue()).decode('ascii') 47 | return image_info 48 | 49 | 50 | if __name__ == '__main__': 51 | app.jinja_env.auto_reload = True 52 | app.config['TEMPLATES_AUTO_RELOAD'] = True 53 | app.run(debug=False, port=8081) 54 | -------------------------------------------------------------------------------- /backbone.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet 9 | from efficientdet.utils import Anchors 10 | 11 | 12 | class EfficientDetBackbone(nn.Module): 13 | def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs): 14 | super(EfficientDetBackbone, self).__init__() 15 | self.compound_coef = compound_coef 16 | 17 | self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6] 18 | self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384] 19 | self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8] 20 | self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 21 | self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5] 22 | self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.] 23 | self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) 24 | self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])) 25 | conv_channel_coef = { 26 | # the channels of P3/P4/P5. 
27 | 0: [40, 112, 320], 28 | 1: [40, 112, 320], 29 | 2: [48, 120, 352], 30 | 3: [48, 136, 384], 31 | 4: [56, 160, 448], 32 | 5: [64, 176, 512], 33 | 6: [72, 200, 576], 34 | 7: [72, 200, 576], 35 | } 36 | 37 | num_anchors = len(self.aspect_ratios) * self.num_scales 38 | 39 | self.bifpn = nn.Sequential( 40 | *[BiFPN(self.fpn_num_filters[self.compound_coef], 41 | conv_channel_coef[compound_coef], 42 | True if _ == 0 else False, 43 | attention=True if compound_coef < 6 else False) 44 | for _ in range(self.fpn_cell_repeats[compound_coef])]) 45 | 46 | self.num_classes = num_classes 47 | self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, 48 | num_layers=self.box_class_repeats[self.compound_coef]) 49 | self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors, 50 | num_classes=num_classes, 51 | num_layers=self.box_class_repeats[self.compound_coef]) 52 | 53 | self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], **kwargs) 54 | 55 | self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights) 56 | 57 | def freeze_bn(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.BatchNorm2d): 60 | m.eval() 61 | 62 | def forward(self, inputs): 63 | max_size = inputs.shape[-1] 64 | 65 | _, p3, p4, p5 = self.backbone_net(inputs) 66 | 67 | features = (p3, p4, p5) 68 | features = self.bifpn(features) 69 | 70 | regression = self.regressor(features) 71 | classification = self.classifier(features) 72 | anchors = self.anchors(inputs, inputs.dtype) 73 | 74 | return features, regression, classification, anchors 75 | 76 | def init_backbone(self, path): 77 | state_dict = torch.load(path) 78 | try: 79 | ret = self.load_state_dict(state_dict, strict=False) 80 | print(ret) 81 | except RuntimeError as e: 82 | print('Ignoring ' + str(e) + '"') 83 | -------------------------------------------------------------------------------- /coco_eval.py: 
-------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | """ 4 | COCO-Style Evaluations 5 | 6 | put images here datasets/your_project_name/annotations/val_set_name/*.jpg 7 | put annotations here datasets/your_project_name/annotations/instances_{val_set_name}.json 8 | put weights here /path/to/your/weights/*.pth 9 | change compound_coef 10 | 11 | """ 12 | 13 | import json 14 | import os 15 | 16 | import argparse 17 | import torch 18 | import yaml 19 | from tqdm import tqdm 20 | from pycocotools.coco import COCO 21 | from pycocotools.cocoeval import COCOeval 22 | 23 | from backbone import EfficientDetBackbone 24 | from efficientdet.utils import BBoxTransform, ClipBoxes 25 | from utils.utils import preprocess, invert_affine, postprocess 26 | 27 | ap = argparse.ArgumentParser() 28 | ap.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters') 29 | ap.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet') 30 | ap.add_argument('-w', '--weights', type=str, default=None, help='/path/to/weights') 31 | ap.add_argument('--nms_threshold', type=float, default=0.5, help='nms threshold, don\'t change it if not for testing purposes') 32 | ap.add_argument('--cuda', type=bool, default=True) 33 | ap.add_argument('--device', type=int, default=0) 34 | ap.add_argument('--float16', type=bool, default=False) 35 | ap.add_argument('--override', type=bool, default=True, help='override previous bbox results file if exists') 36 | args = ap.parse_args() 37 | 38 | compound_coef = args.compound_coef 39 | nms_threshold = args.nms_threshold 40 | use_cuda = args.cuda 41 | gpu = args.device 42 | use_float16 = args.float16 43 | override_prev_results = args.override 44 | project_name = args.project 45 | weights_path = f'weights/efficientdet-d{compound_coef}.pth' if args.weights is None else args.weights 46 | 47 | print(f'running coco-style evaluation on project 
{project_name}, weights {weights_path}...') 48 | 49 | params = yaml.safe_load(open(f'projects/{project_name}.yml')) 50 | obj_list = params['obj_list'] 51 | 52 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 53 | 54 | 55 | def evaluate_coco(img_path, set_name, image_ids, coco, model, threshold=0.05): 56 | results = [] 57 | processed_image_ids = [] 58 | 59 | regressBoxes = BBoxTransform() 60 | clipBoxes = ClipBoxes() 61 | 62 | for image_id in tqdm(image_ids): 63 | image_info = coco.loadImgs(image_id)[0] 64 | image_path = img_path + image_info['file_name'] 65 | 66 | ori_imgs, framed_imgs, framed_metas = preprocess(image_path, max_size=input_sizes[compound_coef]) 67 | x = torch.from_numpy(framed_imgs[0]) 68 | 69 | if use_cuda: 70 | x = x.cuda(gpu) 71 | if use_float16: 72 | x = x.half() 73 | else: 74 | x = x.float() 75 | else: 76 | x = x.float() 77 | 78 | x = x.unsqueeze(0).permute(0, 3, 1, 2) 79 | features, regression, classification, anchors = model(x) 80 | 81 | preds = postprocess(x, 82 | anchors, regression, classification, 83 | regressBoxes, clipBoxes, 84 | threshold, nms_threshold) 85 | 86 | processed_image_ids.append(image_id) 87 | 88 | if not preds: 89 | continue 90 | 91 | preds = invert_affine(framed_metas, preds)[0] 92 | 93 | scores = preds['scores'] 94 | class_ids = preds['class_ids'] 95 | rois = preds['rois'] 96 | 97 | if rois.shape[0] > 0: 98 | # x1,y1,x2,y2 -> x1,y1,w,h 99 | rois[:, 2] -= rois[:, 0] 100 | rois[:, 3] -= rois[:, 1] 101 | 102 | bbox_score = scores 103 | 104 | for roi_id in range(rois.shape[0]): 105 | score = float(bbox_score[roi_id]) 106 | label = int(class_ids[roi_id]) 107 | box = rois[roi_id, :] 108 | 109 | if score < threshold: 110 | break 111 | image_result = { 112 | 'image_id': image_id, 113 | 'category_id': label + 1, 114 | 'score': float(score), 115 | 'bbox': box.tolist(), 116 | } 117 | 118 | results.append(image_result) 119 | 120 | if not len(results): 121 | raise Exception('the model does not provide any valid output, 
check model architecture and the data input') 122 | 123 | # write output 124 | filepath = f'{set_name}_bbox_results.json' 125 | if os.path.exists(filepath): 126 | os.remove(filepath) 127 | json.dump(results, open(filepath, 'w'), indent=4) 128 | 129 | return processed_image_ids 130 | 131 | 132 | def _eval(coco_gt, image_ids, pred_json_path): 133 | # load results in COCO evaluation tool 134 | coco_pred = coco_gt.loadRes(pred_json_path) 135 | 136 | # run COCO evaluation 137 | print('BBox') 138 | coco_eval = COCOeval(coco_gt, coco_pred, 'bbox') 139 | coco_eval.params.imgIds = image_ids 140 | coco_eval.evaluate() 141 | coco_eval.accumulate() 142 | coco_eval.summarize() 143 | 144 | 145 | if __name__ == '__main__': 146 | SET_NAME = params['val_set'] 147 | VAL_GT = f'datasets/{params["project_name"]}/annotations/instances_{SET_NAME}.json' 148 | VAL_IMGS = f'datasets/{params["project_name"]}/{SET_NAME}/' 149 | MAX_IMAGES = 10000 150 | coco_gt = COCO(VAL_GT) 151 | image_ids = coco_gt.getImgIds()[:MAX_IMAGES] 152 | 153 | if override_prev_results or not os.path.exists(f'{SET_NAME}_bbox_results.json'): 154 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list), 155 | ratios=eval(params['anchors_ratios']), scales=eval(params['anchors_scales'])) 156 | model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) 157 | model.requires_grad_(False) 158 | model.eval() 159 | 160 | if use_cuda: 161 | model.cuda(gpu) 162 | 163 | if use_float16: 164 | model.half() 165 | 166 | image_ids = evaluate_coco(VAL_IMGS, SET_NAME, image_ids, coco_gt, model) 167 | 168 | _eval(coco_gt, image_ids, f'{SET_NAME}_bbox_results.json') 169 | -------------------------------------------------------------------------------- /detect_process.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import json 4 | import time 5 | from io import BytesIO 6 | 7 | import redis as redis 8 | import 
requests as req 9 | from PIL import Image 10 | 11 | import efficientService as service 12 | 13 | # redis cache client 14 | RedisCache = redis.StrictRedis(host="localhost", port=6379, db=0) 15 | 16 | # the queue of expect to detect 17 | IMAGE_QUEUE = "imageQueue" 18 | 19 | # slice size every foreach 20 | 21 | BATCH_SIZE = 32 22 | # server sleep when queue>0 23 | 24 | SERVER_SLEEP = 0.1 25 | # server sleep when queue=0 26 | SERVER_SLEEP_IDLE = 0.5 27 | 28 | 29 | def detect_process(): 30 | while True: 31 | # 从redis中获取预测图像队列 32 | queue = RedisCache.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1) 33 | if len(queue) < 1: 34 | time.sleep(SERVER_SLEEP) 35 | continue 36 | 37 | print("classify_process is running") 38 | 39 | # 遍历队列 40 | for item in queue: 41 | # step 1. 获取队列中的图像信息 42 | item = json.loads(item); 43 | image_key = item.get("imageKey") 44 | image_link = item.get("imageUrl") 45 | response = req.get(image_link) 46 | image = Image.open(BytesIO(response.content)) 47 | 48 | # step 2. detect image 识别图片 49 | image_array = service.detect(image) 50 | 51 | # step 3. convert image_array to byte_array 52 | img = Image.fromarray(image_array, 'RGB') 53 | img_byte_array = io.BytesIO() 54 | img.save(img_byte_array, format='JPEG') 55 | 56 | # step 4. 
set result_info in redis 57 | image_info = base64.b64encode(img_byte_array.getvalue()).decode('ascii') 58 | 59 | RedisCache.hset(name=image_key, key="consultOut", value=image_info) 60 | 61 | # 删除队列中已识别的图片信息 62 | RedisCache.ltrim(IMAGE_QUEUE, BATCH_SIZE, -1) 63 | 64 | time.sleep(SERVER_SLEEP) 65 | 66 | 67 | if __name__ == '__main__': 68 | print("start classify_process") 69 | detect_process() 70 | -------------------------------------------------------------------------------- /efficientService.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | """ 4 | Simple Inference Script of EfficientDet-Pytorch 5 | """ 6 | import random 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | from torch.backends import cudnn 12 | 13 | from backbone import EfficientDetBackbone 14 | from efficientdet.utils import BBoxTransform, ClipBoxes 15 | from utils.utils import invert_affine, postprocess, aspectaware_resize_padding 16 | 17 | compound_coef = 0 18 | force_input_size = None # set None to use default size 19 | 20 | # replace this part with your project's anchor config 21 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)] 22 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)] 23 | 24 | threshold = 0.2 25 | iou_threshold = 0.2 26 | 27 | use_cuda = True 28 | use_float16 = False 29 | cudnn.fastest = True 30 | cudnn.benchmark = True 31 | 32 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 33 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 34 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 35 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 36 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 37 | 'bowl', 'banana', 
'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 38 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 39 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 40 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 41 | 'toothbrush'] 42 | 43 | # tf bilinear interpolation is different from any other's, just make do 44 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 45 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size 46 | 47 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list), 48 | ratios=anchor_ratios, scales=anchor_scales) 49 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth')) 50 | model.requires_grad_(False) 51 | model.eval() 52 | 53 | if use_cuda: 54 | model = model.cuda() 55 | if use_float16: 56 | model = model.half() 57 | 58 | 59 | def detect(image): 60 | # convert image to array 61 | frame = np.array(image) 62 | 63 | # convert to cv format 64 | frames = frame[:, :, ::-1] 65 | 66 | ori_imgs, framed_imgs, framed_metas = image_preprocess(frames, max_size=input_size) 67 | 68 | if use_cuda: 69 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0) 70 | else: 71 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0) 72 | 73 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2) 74 | 75 | with torch.no_grad(): 76 | features, regression, classification, anchors = model(x) 77 | 78 | regressBoxes = BBoxTransform() 79 | clipBoxes = ClipBoxes() 80 | 81 | out = postprocess(x, 82 | anchors, regression, classification, 83 | regressBoxes, clipBoxes, 84 | threshold, iou_threshold) 85 | 86 | out = invert_affine(framed_metas, out) 87 | render_frame = display(out, frame, imshow=True, imwrite=False) 88 | return render_frame 89 | 90 | 91 | def 
image_preprocess(image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)): 92 | ori_imgs = [image_path] 93 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs] 94 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size, 95 | means=None) for img in normalized_imgs] 96 | framed_imgs = [img_meta[0] for img_meta in imgs_meta] 97 | framed_metas = [img_meta[1:] for img_meta in imgs_meta] 98 | 99 | return ori_imgs, framed_imgs, framed_metas 100 | 101 | 102 | def display(preds, imgs, imshow=True, imwrite=False): 103 | imgs = [imgs] 104 | for i in range(len(imgs)): 105 | if len(preds[i]['rois']) == 0: 106 | continue 107 | 108 | for j in range(len(preds[i]['rois'])): 109 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int) 110 | color = [random.randint(0, 255) for _ in range(3)] 111 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), color, 2) 112 | obj = obj_list[preds[i]['class_ids'][j]] 113 | score = float(preds[i]['scores'][j]) 114 | 115 | label = obj 116 | label_size = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0] 117 | cv2.rectangle(imgs[i], (x1, y1), (x1 + label_size[0], y1 - label_size[1] - 3), color, -1) 118 | # cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score), 119 | # (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 120 | # [225, 255, 255], 1, lineType=cv2.LINE_AA) 121 | cv2.putText(imgs[i], label, (x1, y1 - 8), 0, 2 / 3, [225, 255, 255], thickness=1, 122 | lineType=cv2.LINE_AA) 123 | 124 | return imgs[i] 125 | -------------------------------------------------------------------------------- /efficientdet/config.py: -------------------------------------------------------------------------------- 1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", 2 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", 3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", 4 | 
"handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", 5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", 7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", 8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", 9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 10 | "teddy bear", "hair drier", "toothbrush"] 11 | 12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122), 13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17), 14 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60), 15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34), 16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181), 17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108), 18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26), 19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50), 20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33), 21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91), 22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130), 23 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3), 24 | (129, 87, 143), (49, 87, 40), 
(128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108), 25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95), 26 | (2, 20, 184), (122, 37, 185)] 27 | -------------------------------------------------------------------------------- /efficientdet/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from pycocotools.coco import COCO 7 | import cv2 8 | 9 | 10 | class CocoDataset(Dataset): 11 | def __init__(self, root_dir, set='train2017', transform=None): 12 | 13 | self.root_dir = root_dir 14 | self.set_name = set 15 | self.transform = transform 16 | 17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json')) 18 | self.image_ids = self.coco.getImgIds() 19 | 20 | self.load_classes() 21 | 22 | def load_classes(self): 23 | 24 | # load class names (name -> label) 25 | categories = self.coco.loadCats(self.coco.getCatIds()) 26 | categories.sort(key=lambda x: x['id']) 27 | 28 | self.classes = {} 29 | self.coco_labels = {} 30 | self.coco_labels_inverse = {} 31 | for c in categories: 32 | self.coco_labels[len(self.classes)] = c['id'] 33 | self.coco_labels_inverse[c['id']] = len(self.classes) 34 | self.classes[c['name']] = len(self.classes) 35 | 36 | # also load the reverse (label -> name) 37 | self.labels = {} 38 | for key, value in self.classes.items(): 39 | self.labels[value] = key 40 | 41 | def __len__(self): 42 | return len(self.image_ids) 43 | 44 | def __getitem__(self, idx): 45 | 46 | img = self.load_image(idx) 47 | annot = self.load_annotations(idx) 48 | sample = {'img': img, 'annot': annot} 49 | if self.transform: 50 | sample = self.transform(sample) 51 | return sample 52 | 53 | def load_image(self, image_index): 54 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 55 | path = 
os.path.join(self.root_dir, self.set_name, image_info['file_name']) 56 | img = cv2.imread(path) 57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 58 | 59 | return img.astype(np.float32) / 255. 60 | 61 | def load_annotations(self, image_index): 62 | # get ground truth annotations 63 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 64 | annotations = np.zeros((0, 5)) 65 | 66 | # some images appear to miss annotations 67 | if len(annotations_ids) == 0: 68 | return annotations 69 | 70 | # parse annotations 71 | coco_annotations = self.coco.loadAnns(annotations_ids) 72 | for idx, a in enumerate(coco_annotations): 73 | 74 | # some annotations have basically no width / height, skip them 75 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 76 | continue 77 | 78 | annotation = np.zeros((1, 5)) 79 | annotation[0, :4] = a['bbox'] 80 | annotation[0, 4] = self.coco_label_to_label(a['category_id']) 81 | annotations = np.append(annotations, annotation, axis=0) 82 | 83 | # transform from [x, y, w, h] to [x1, y1, x2, y2] 84 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2] 85 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3] 86 | 87 | return annotations 88 | 89 | def coco_label_to_label(self, coco_label): 90 | return self.coco_labels_inverse[coco_label] 91 | 92 | def label_to_coco_label(self, label): 93 | return self.coco_labels[label] 94 | 95 | 96 | def collater(data): 97 | imgs = [s['img'] for s in data] 98 | annots = [s['annot'] for s in data] 99 | scales = [s['scale'] for s in data] 100 | 101 | imgs = torch.from_numpy(np.stack(imgs, axis=0)) 102 | 103 | max_num_annots = max(annot.shape[0] for annot in annots) 104 | 105 | if max_num_annots > 0: 106 | 107 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1 108 | 109 | for idx, annot in enumerate(annots): 110 | if annot.shape[0] > 0: 111 | annot_padded[idx, :annot.shape[0], :] = annot 112 | else: 113 | annot_padded = torch.ones((len(annots), 1, 5)) * -1 114 | 115 
| imgs = imgs.permute(0, 3, 1, 2) 116 | 117 | return {'img': imgs, 'annot': annot_padded, 'scale': scales} 118 | 119 | 120 | class Resizer(object): 121 | """Convert ndarrays in sample to Tensors.""" 122 | 123 | def __init__(self, img_size=512): 124 | self.img_size = img_size 125 | 126 | def __call__(self, sample): 127 | image, annots = sample['img'], sample['annot'] 128 | height, width, _ = image.shape 129 | if height > width: 130 | scale = self.img_size / height 131 | resized_height = self.img_size 132 | resized_width = int(width * scale) 133 | else: 134 | scale = self.img_size / width 135 | resized_height = int(height * scale) 136 | resized_width = self.img_size 137 | 138 | image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR) 139 | 140 | new_image = np.zeros((self.img_size, self.img_size, 3)) 141 | new_image[0:resized_height, 0:resized_width] = image 142 | 143 | annots[:, :4] *= scale 144 | 145 | return {'img': torch.from_numpy(new_image).to(torch.float32), 'annot': torch.from_numpy(annots), 'scale': scale} 146 | 147 | 148 | class Augmenter(object): 149 | """Convert ndarrays in sample to Tensors.""" 150 | 151 | def __call__(self, sample, flip_x=0.5): 152 | if np.random.rand() < flip_x: 153 | image, annots = sample['img'], sample['annot'] 154 | image = image[:, ::-1, :] 155 | 156 | rows, cols, channels = image.shape 157 | 158 | x1 = annots[:, 0].copy() 159 | x2 = annots[:, 2].copy() 160 | 161 | x_tmp = x1.copy() 162 | 163 | annots[:, 0] = cols - x2 164 | annots[:, 2] = cols - x_tmp 165 | 166 | sample = {'img': image, 'annot': annots} 167 | 168 | return sample 169 | 170 | 171 | class Normalizer(object): 172 | 173 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 174 | self.mean = np.array([[mean]]) 175 | self.std = np.array([[std]]) 176 | 177 | def __call__(self, sample): 178 | image, annots = sample['img'], sample['annot'] 179 | 180 | return {'img': ((image.astype(np.float32) - self.mean) / 
self.std), 'annot': annots} 181 | -------------------------------------------------------------------------------- /efficientdet/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import cv2 4 | import numpy as np 5 | 6 | from efficientdet.utils import BBoxTransform, ClipBoxes 7 | from utils.utils import postprocess, invert_affine, display 8 | 9 | 10 | def calc_iou(a, b): 11 | # a(anchor) [boxes, (y1, x1, y2, x2)] 12 | # b(gt, coco-style) [boxes, (x1, y1, x2, y2)] 13 | 14 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) 15 | iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0]) 16 | ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1]) 17 | iw = torch.clamp(iw, min=0) 18 | ih = torch.clamp(ih, min=0) 19 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih 20 | ua = torch.clamp(ua, min=1e-8) 21 | intersection = iw * ih 22 | IoU = intersection / ua 23 | 24 | return IoU 25 | 26 | 27 | class FocalLoss(nn.Module): 28 | def __init__(self): 29 | super(FocalLoss, self).__init__() 30 | 31 | def forward(self, classifications, regressions, anchors, annotations, **kwargs): 32 | alpha = 0.25 33 | gamma = 2.0 34 | batch_size = classifications.shape[0] 35 | classification_losses = [] 36 | regression_losses = [] 37 | 38 | anchor = anchors[0, :, :] # assuming all image sizes are the same, which it is 39 | dtype = anchors.dtype 40 | 41 | anchor_widths = anchor[:, 3] - anchor[:, 1] 42 | anchor_heights = anchor[:, 2] - anchor[:, 0] 43 | anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths 44 | anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights 45 | 46 | for j in range(batch_size): 47 | 48 | classification = classifications[j, :, :] 49 | regression = regressions[j, :, :] 50 | 51 | bbox_annotation = annotations[j] 52 | bbox_annotation = 
bbox_annotation[bbox_annotation[:, 4] != -1] 53 | 54 | if bbox_annotation.shape[0] == 0: 55 | if torch.cuda.is_available(): 56 | regression_losses.append(torch.tensor(0).to(dtype).cuda()) 57 | classification_losses.append(torch.tensor(0).to(dtype).cuda()) 58 | else: 59 | regression_losses.append(torch.tensor(0).to(dtype)) 60 | classification_losses.append(torch.tensor(0).to(dtype)) 61 | 62 | continue 63 | 64 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4) 65 | 66 | IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4]) 67 | 68 | IoU_max, IoU_argmax = torch.max(IoU, dim=1) 69 | 70 | # compute the loss for classification 71 | targets = torch.ones_like(classification) * -1 72 | if torch.cuda.is_available(): 73 | targets = targets.cuda() 74 | 75 | targets[torch.lt(IoU_max, 0.4), :] = 0 76 | 77 | positive_indices = torch.ge(IoU_max, 0.5) 78 | 79 | num_positive_anchors = positive_indices.sum() 80 | 81 | assigned_annotations = bbox_annotation[IoU_argmax, :] 82 | 83 | targets[positive_indices, :] = 0 84 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1 85 | 86 | alpha_factor = torch.ones_like(targets) * alpha 87 | if torch.cuda.is_available(): 88 | alpha_factor = alpha_factor.cuda() 89 | 90 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor) 91 | focal_weight = torch.where(torch.eq(targets, 1.), 1. 
- classification, classification) 92 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma) 93 | 94 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification)) 95 | 96 | cls_loss = focal_weight * bce 97 | 98 | zeros = torch.zeros_like(cls_loss) 99 | if torch.cuda.is_available(): 100 | zeros = zeros.cuda() 101 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros) 102 | 103 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)) 104 | 105 | if positive_indices.sum() > 0: 106 | assigned_annotations = assigned_annotations[positive_indices, :] 107 | 108 | anchor_widths_pi = anchor_widths[positive_indices] 109 | anchor_heights_pi = anchor_heights[positive_indices] 110 | anchor_ctr_x_pi = anchor_ctr_x[positive_indices] 111 | anchor_ctr_y_pi = anchor_ctr_y[positive_indices] 112 | 113 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0] 114 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1] 115 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths 116 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights 117 | 118 | # efficientdet style 119 | gt_widths = torch.clamp(gt_widths, min=1) 120 | gt_heights = torch.clamp(gt_heights, min=1) 121 | 122 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi 123 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi 124 | targets_dw = torch.log(gt_widths / anchor_widths_pi) 125 | targets_dh = torch.log(gt_heights / anchor_heights_pi) 126 | 127 | targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw)) 128 | targets = targets.t() 129 | 130 | regression_diff = torch.abs(targets - regression[positive_indices, :]) 131 | 132 | regression_loss = torch.where( 133 | torch.le(regression_diff, 1.0 / 9.0), 134 | 0.5 * 9.0 * torch.pow(regression_diff, 2), 135 | regression_diff - 0.5 / 9.0 136 | ) 137 | regression_losses.append(regression_loss.mean()) 138 | else: 139 
| if torch.cuda.is_available(): 140 | regression_losses.append(torch.tensor(0).to(dtype).cuda()) 141 | else: 142 | regression_losses.append(torch.tensor(0).to(dtype)) 143 | 144 | # debug 145 | imgs = kwargs.get('imgs', None) 146 | if imgs is not None: 147 | regressBoxes = BBoxTransform() 148 | clipBoxes = ClipBoxes() 149 | obj_list = kwargs.get('obj_list', None) 150 | out = postprocess(imgs.detach(), 151 | torch.stack([anchors[0]] * imgs.shape[0], 0).detach(), regressions.detach(), classifications.detach(), 152 | regressBoxes, clipBoxes, 153 | 0.5, 0.3) 154 | imgs = imgs.permute(0, 2, 3, 1).cpu().numpy() 155 | imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8) 156 | imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs] 157 | display(out, imgs, obj_list, imshow=False, imwrite=True) 158 | 159 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), \ 160 | torch.stack(regression_losses).mean(dim=0, keepdim=True) 161 | -------------------------------------------------------------------------------- /efficientdet/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torchvision.ops.boxes import nms as nms_torch 4 | 5 | from efficientnet import EfficientNet as EffNet 6 | from efficientnet.utils import MemoryEfficientSwish, Swish 7 | from efficientnet.utils_extra import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding 8 | 9 | 10 | def nms(dets, thresh): 11 | return nms_torch(dets[:, :4], dets[:, 4], thresh) 12 | 13 | 14 | class SeparableConvBlock(nn.Module): 15 | """ 16 | created by Zylo117 17 | """ 18 | 19 | def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False): 20 | super(SeparableConvBlock, self).__init__() 21 | if out_channels is None: 22 | out_channels = in_channels 23 | 24 | # Q: whether separate conv 25 | # share bias between depthwise_conv and pointwise_conv 26 | # or 
just pointwise_conv apply bias. 27 | # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias. 28 | 29 | self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels, 30 | kernel_size=3, stride=1, groups=in_channels, bias=False) 31 | self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1) 32 | 33 | self.norm = norm 34 | if self.norm: 35 | # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow 36 | self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3) 37 | 38 | self.activation = activation 39 | if self.activation: 40 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish() 41 | 42 | def forward(self, x): 43 | x = self.depthwise_conv(x) 44 | x = self.pointwise_conv(x) 45 | 46 | if self.norm: 47 | x = self.bn(x) 48 | 49 | if self.activation: 50 | x = self.swish(x) 51 | 52 | return x 53 | 54 | 55 | class BiFPN(nn.Module): 56 | """ 57 | modified by Zylo117 58 | """ 59 | 60 | def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True): 61 | """ 62 | 63 | Args: 64 | num_channels: 65 | conv_channels: 66 | first_time: whether the input comes directly from the efficientnet, 67 | if True, downchannel it first, and downsample P5 to generate P6 then P7 68 | epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon 69 | onnx_export: if True, use Swish instead of MemoryEfficientSwish 70 | """ 71 | super(BiFPN, self).__init__() 72 | self.epsilon = epsilon 73 | # Conv layers 74 | self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) 75 | self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) 76 | self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) 77 | self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) 78 | self.conv4_down = SeparableConvBlock(num_channels, 
onnx_export=onnx_export) 79 | self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) 80 | self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) 81 | self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) 82 | 83 | # Feature scaling layers 84 | self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest') 85 | self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest') 86 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') 87 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') 88 | 89 | self.p4_downsample = MaxPool2dStaticSamePadding(3, 2) 90 | self.p5_downsample = MaxPool2dStaticSamePadding(3, 2) 91 | self.p6_downsample = MaxPool2dStaticSamePadding(3, 2) 92 | self.p7_downsample = MaxPool2dStaticSamePadding(3, 2) 93 | 94 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish() 95 | 96 | self.first_time = first_time 97 | if self.first_time: 98 | self.p5_down_channel = nn.Sequential( 99 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), 100 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 101 | ) 102 | self.p4_down_channel = nn.Sequential( 103 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), 104 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 105 | ) 106 | self.p3_down_channel = nn.Sequential( 107 | Conv2dStaticSamePadding(conv_channels[0], num_channels, 1), 108 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 109 | ) 110 | 111 | self.p5_to_p6 = nn.Sequential( 112 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), 113 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 114 | MaxPool2dStaticSamePadding(3, 2) 115 | ) 116 | self.p6_to_p7 = nn.Sequential( 117 | MaxPool2dStaticSamePadding(3, 2) 118 | ) 119 | 120 | self.p4_down_channel_2 = nn.Sequential( 121 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), 122 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 123 | ) 124 | 
self.p5_down_channel_2 = nn.Sequential( 125 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), 126 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), 127 | ) 128 | 129 | # Weight 130 | self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 131 | self.p6_w1_relu = nn.ReLU() 132 | self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 133 | self.p5_w1_relu = nn.ReLU() 134 | self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 135 | self.p4_w1_relu = nn.ReLU() 136 | self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 137 | self.p3_w1_relu = nn.ReLU() 138 | 139 | self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) 140 | self.p4_w2_relu = nn.ReLU() 141 | self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) 142 | self.p5_w2_relu = nn.ReLU() 143 | self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) 144 | self.p6_w2_relu = nn.ReLU() 145 | self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 146 | self.p7_w2_relu = nn.ReLU() 147 | 148 | self.attention = attention 149 | 150 | def forward(self, inputs): 151 | """ 152 | illustration of a minimal bifpn unit 153 | P7_0 -------------------------> P7_2 --------> 154 | |-------------| ↑ 155 | ↓ | 156 | P6_0 ---------> P6_1 ---------> P6_2 --------> 157 | |-------------|--------------↑ ↑ 158 | ↓ | 159 | P5_0 ---------> P5_1 ---------> P5_2 --------> 160 | |-------------|--------------↑ ↑ 161 | ↓ | 162 | P4_0 ---------> P4_1 ---------> P4_2 --------> 163 | |-------------|--------------↑ ↑ 164 | |--------------↓ | 165 | P3_0 -------------------------> P3_2 --------> 166 | """ 167 | 168 | # downsample channels using same-padding conv2d to target phase's if not the same 169 | # judge: same phase as target, 170 | # if same, pass; 171 | # elif earlier phase, downsample to target phase's by 
pooling 172 | # elif later phase, upsample to target phase's by nearest interpolation 173 | 174 | if self.attention: 175 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs) 176 | else: 177 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs) 178 | 179 | return p3_out, p4_out, p5_out, p6_out, p7_out 180 | 181 | def _forward_fast_attention(self, inputs): 182 | if self.first_time: 183 | p3, p4, p5 = inputs 184 | 185 | p6_in = self.p5_to_p6(p5) 186 | p7_in = self.p6_to_p7(p6_in) 187 | 188 | p3_in = self.p3_down_channel(p3) 189 | p4_in = self.p4_down_channel(p4) 190 | p5_in = self.p5_down_channel(p5) 191 | 192 | else: 193 | # P3_0, P4_0, P5_0, P6_0 and P7_0 194 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs 195 | 196 | # P7_0 to P7_2 197 | 198 | # Weights for P6_0 and P7_0 to P6_1 199 | p6_w1 = self.p6_w1_relu(self.p6_w1) 200 | weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) 201 | # Connections for P6_0 and P7_0 to P6_1 respectively 202 | p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))) 203 | 204 | # Weights for P5_0 and P6_0 to P5_1 205 | p5_w1 = self.p5_w1_relu(self.p5_w1) 206 | weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) 207 | # Connections for P5_0 and P6_0 to P5_1 respectively 208 | p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))) 209 | 210 | # Weights for P4_0 and P5_0 to P4_1 211 | p4_w1 = self.p4_w1_relu(self.p4_w1) 212 | weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) 213 | # Connections for P4_0 and P5_0 to P4_1 respectively 214 | p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up))) 215 | 216 | # Weights for P3_0 and P4_1 to P3_2 217 | p3_w1 = self.p3_w1_relu(self.p3_w1) 218 | weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) 219 | # Connections for P3_0 and P4_1 to P3_2 respectively 220 | p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * 
self.p3_upsample(p4_up))) 221 | 222 | if self.first_time: 223 | p4_in = self.p4_down_channel_2(p4) 224 | p5_in = self.p5_down_channel_2(p5) 225 | 226 | # Weights for P4_0, P4_1 and P3_2 to P4_2 227 | p4_w2 = self.p4_w2_relu(self.p4_w2) 228 | weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) 229 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively 230 | p4_out = self.conv4_down( 231 | self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))) 232 | 233 | # Weights for P5_0, P5_1 and P4_2 to P5_2 234 | p5_w2 = self.p5_w2_relu(self.p5_w2) 235 | weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) 236 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively 237 | p5_out = self.conv5_down( 238 | self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))) 239 | 240 | # Weights for P6_0, P6_1 and P5_2 to P6_2 241 | p6_w2 = self.p6_w2_relu(self.p6_w2) 242 | weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) 243 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively 244 | p6_out = self.conv6_down( 245 | self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out))) 246 | 247 | # Weights for P7_0 and P6_2 to P7_2 248 | p7_w2 = self.p7_w2_relu(self.p7_w2) 249 | weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) 250 | # Connections for P7_0 and P6_2 to P7_2 251 | p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))) 252 | 253 | return p3_out, p4_out, p5_out, p6_out, p7_out 254 | 255 | def _forward(self, inputs): 256 | if self.first_time: 257 | p3, p4, p5 = inputs 258 | 259 | p6_in = self.p5_to_p6(p5) 260 | p7_in = self.p6_to_p7(p6_in) 261 | 262 | p3_in = self.p3_down_channel(p3) 263 | p4_in = self.p4_down_channel(p4) 264 | p5_in = self.p5_down_channel(p5) 265 | 266 | else: 267 | # P3_0, P4_0, P5_0, P6_0 and P7_0 268 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs 269 | 270 | # P7_0 to P7_2 271 | 272 | 
# Connections for P6_0 and P7_0 to P6_1 respectively 273 | p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in))) 274 | 275 | # Connections for P5_0 and P6_0 to P5_1 respectively 276 | p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up))) 277 | 278 | # Connections for P4_0 and P5_0 to P4_1 respectively 279 | p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up))) 280 | 281 | # Connections for P3_0 and P4_1 to P3_2 respectively 282 | p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up))) 283 | 284 | if self.first_time: 285 | p4_in = self.p4_down_channel_2(p4) 286 | p5_in = self.p5_down_channel_2(p5) 287 | 288 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively 289 | p4_out = self.conv4_down( 290 | self.swish(p4_in + p4_up + self.p4_downsample(p3_out))) 291 | 292 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively 293 | p5_out = self.conv5_down( 294 | self.swish(p5_in + p5_up + self.p5_downsample(p4_out))) 295 | 296 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively 297 | p6_out = self.conv6_down( 298 | self.swish(p6_in + p6_up + self.p6_downsample(p5_out))) 299 | 300 | # Connections for P7_0 and P6_2 to P7_2 301 | p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out))) 302 | 303 | return p3_out, p4_out, p5_out, p6_out, p7_out 304 | 305 | 306 | class Regressor(nn.Module): 307 | """ 308 | modified by Zylo117 309 | """ 310 | 311 | def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False): 312 | super(Regressor, self).__init__() 313 | self.num_layers = num_layers 314 | self.num_layers = num_layers 315 | 316 | self.conv_list = nn.ModuleList( 317 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) 318 | self.bn_list = nn.ModuleList( 319 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in 320 | range(5)]) 321 | self.header = 
SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False) 322 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish() 323 | 324 | def forward(self, inputs): 325 | feats = [] 326 | for feat, bn_list in zip(inputs, self.bn_list): 327 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): 328 | feat = conv(feat) 329 | feat = bn(feat) 330 | feat = self.swish(feat) 331 | feat = self.header(feat) 332 | 333 | feat = feat.permute(0, 2, 3, 1) 334 | feat = feat.contiguous().view(feat.shape[0], -1, 4) 335 | 336 | feats.append(feat) 337 | 338 | feats = torch.cat(feats, dim=1) 339 | 340 | return feats 341 | 342 | 343 | class Classifier(nn.Module): 344 | """ 345 | modified by Zylo117 346 | """ 347 | 348 | def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_export=False): 349 | super(Classifier, self).__init__() 350 | self.num_anchors = num_anchors 351 | self.num_classes = num_classes 352 | self.num_layers = num_layers 353 | self.conv_list = nn.ModuleList( 354 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) 355 | self.bn_list = nn.ModuleList( 356 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in 357 | range(5)]) 358 | self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False) 359 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish() 360 | 361 | def forward(self, inputs): 362 | feats = [] 363 | for feat, bn_list in zip(inputs, self.bn_list): 364 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): 365 | feat = conv(feat) 366 | feat = bn(feat) 367 | feat = self.swish(feat) 368 | feat = self.header(feat) 369 | 370 | feat = feat.permute(0, 2, 3, 1) 371 | feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors, 372 | self.num_classes) 373 | feat = 
feat.contiguous().view(feat.shape[0], -1, self.num_classes) 374 | 375 | feats.append(feat) 376 | 377 | feats = torch.cat(feats, dim=1) 378 | feats = feats.sigmoid() 379 | 380 | return feats 381 | 382 | 383 | class EfficientNet(nn.Module): 384 | """ 385 | modified by Zylo117 386 | """ 387 | 388 | def __init__(self, compound_coef, load_weights=False): 389 | super(EfficientNet, self).__init__() 390 | model = EffNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights) 391 | del model._conv_head 392 | del model._bn1 393 | del model._avg_pooling 394 | del model._dropout 395 | del model._fc 396 | self.model = model 397 | 398 | def forward(self, x): 399 | x = self.model._conv_stem(x) 400 | x = self.model._bn0(x) 401 | x = self.model._swish(x) 402 | feature_maps = [] 403 | 404 | # TODO: temporarily storing extra tensor last_x and del it later might not be a good idea, 405 | # try recording stride changing when creating efficientnet, 406 | # and then apply it here. 407 | last_x = None 408 | for idx, block in enumerate(self.model._blocks): 409 | drop_connect_rate = self.model._global_params.drop_connect_rate 410 | if drop_connect_rate: 411 | drop_connect_rate *= float(idx) / len(self.model._blocks) 412 | x = block(x, drop_connect_rate=drop_connect_rate) 413 | 414 | if block._depthwise_conv.stride == [2, 2]: 415 | feature_maps.append(last_x) 416 | elif idx == len(self.model._blocks) - 1: 417 | feature_maps.append(x) 418 | last_x = x 419 | del last_x 420 | return feature_maps[1:] 421 | 422 | 423 | if __name__ == '__main__': 424 | from tensorboardX import SummaryWriter 425 | 426 | 427 | def count_parameters(model): 428 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 429 | -------------------------------------------------------------------------------- /efficientdet/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 
| 7 | class BBoxTransform(nn.Module): 8 | def forward(self, anchors, regression): 9 | """ 10 | decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py 11 | 12 | Args: 13 | anchors: [batchsize, boxes, (y1, x1, y2, x2)] 14 | regression: [batchsize, boxes, (dy, dx, dh, dw)] 15 | 16 | Returns: 17 | 18 | """ 19 | y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2 20 | x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2 21 | ha = anchors[..., 2] - anchors[..., 0] 22 | wa = anchors[..., 3] - anchors[..., 1] 23 | 24 | w = regression[..., 3].exp() * wa 25 | h = regression[..., 2].exp() * ha 26 | 27 | y_centers = regression[..., 0] * ha + y_centers_a 28 | x_centers = regression[..., 1] * wa + x_centers_a 29 | 30 | ymin = y_centers - h / 2. 31 | xmin = x_centers - w / 2. 32 | ymax = y_centers + h / 2. 33 | xmax = x_centers + w / 2. 34 | 35 | return torch.stack([xmin, ymin, xmax, ymax], dim=2) 36 | 37 | 38 | class ClipBoxes(nn.Module): 39 | 40 | def __init__(self): 41 | super(ClipBoxes, self).__init__() 42 | 43 | def forward(self, boxes, img): 44 | batch_size, num_channels, height, width = img.shape 45 | 46 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0) 47 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0) 48 | 49 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1) 50 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1) 51 | 52 | return boxes 53 | 54 | 55 | class Anchors(nn.Module): 56 | """ 57 | adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117 58 | """ 59 | 60 | def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs): 61 | super().__init__() 62 | self.anchor_scale = anchor_scale 63 | 64 | if pyramid_levels is None: 65 | self.pyramid_levels = [3, 4, 5, 6, 7] 66 | 67 | self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels]) 68 | self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 
3.0)])) 69 | self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]) 70 | 71 | self.last_anchors = {} 72 | self.last_shape = None 73 | 74 | def forward(self, image, dtype=torch.float32): 75 | """Generates multiscale anchor boxes. 76 | 77 | Args: 78 | image_size: integer number of input image size. The input image has the 79 | same dimension for width and height. The image_size should be divided by 80 | the largest feature stride 2^max_level. 81 | anchor_scale: float number representing the scale of size of the base 82 | anchor to the feature stride 2^level. 83 | anchor_configs: a dictionary with keys as the levels of anchors and 84 | values as a list of anchor configuration. 85 | 86 | Returns: 87 | anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all 88 | feature levels. 89 | Raises: 90 | ValueError: input size must be the multiple of largest feature stride. 91 | """ 92 | image_shape = image.shape[2:] 93 | 94 | if image_shape == self.last_shape and image.device in self.last_anchors: 95 | return self.last_anchors[image.device] 96 | 97 | if self.last_shape is None or self.last_shape != image_shape: 98 | self.last_shape = image_shape 99 | 100 | if dtype == torch.float16: 101 | dtype = np.float16 102 | else: 103 | dtype = np.float32 104 | 105 | boxes_all = [] 106 | for stride in self.strides: 107 | boxes_level = [] 108 | for scale, ratio in itertools.product(self.scales, self.ratios): 109 | if image_shape[1] % stride != 0: 110 | raise ValueError('input size must be divided by the stride.') 111 | base_anchor_size = self.anchor_scale * stride * scale 112 | anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0 113 | anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0 114 | 115 | x = np.arange(stride / 2, image_shape[1], stride) 116 | y = np.arange(stride / 2, image_shape[0], stride) 117 | xv, yv = np.meshgrid(x, y) 118 | xv = xv.reshape(-1) 119 | yv = yv.reshape(-1) 120 | 121 | # y1,x1,y2,x2 122 | boxes = np.vstack((yv - 
anchor_size_y_2, xv - anchor_size_x_2, 123 | yv + anchor_size_y_2, xv + anchor_size_x_2)) 124 | boxes = np.swapaxes(boxes, 0, 1) 125 | boxes_level.append(np.expand_dims(boxes, axis=1)) 126 | # concat anchors on the same level to the reshape NxAx4 127 | boxes_level = np.concatenate(boxes_level, axis=1) 128 | boxes_all.append(boxes_level.reshape([-1, 4])) 129 | 130 | anchor_boxes = np.vstack(boxes_all) 131 | 132 | anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device) 133 | anchor_boxes = anchor_boxes.unsqueeze(0) 134 | 135 | # save it for later use to reduce overhead 136 | self.last_anchors[image.device] = anchor_boxes 137 | return anchor_boxes 138 | -------------------------------------------------------------------------------- /efficientdet_test.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | """ 4 | Simple Inference Script of EfficientDet-Pytorch 5 | """ 6 | import random 7 | import time 8 | 9 | import cv2 10 | import numpy as np 11 | import torch 12 | from torch.backends import cudnn 13 | 14 | from backbone import EfficientDetBackbone 15 | from efficientdet.utils import BBoxTransform, ClipBoxes 16 | from utils.utils import preprocess, invert_affine, postprocess 17 | import efficientService as service 18 | from PIL import Image 19 | 20 | 21 | compound_coef = 0 22 | force_input_size = None # set None to use default size 23 | img_path = 'test/20130124152801556.jpg' 24 | 25 | # replace this part with your project's anchor config 26 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)] 27 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)] 28 | 29 | threshold = 0.2 30 | iou_threshold = 0.2 31 | 32 | use_cuda = True 33 | use_float16 = False 34 | cudnn.fastest = True 35 | cudnn.benchmark = True 36 | 37 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 38 | 'fire hydrant', '', 'stop sign', 'parking 
meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 39 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 40 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 41 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 42 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 43 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 44 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 45 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 46 | 'toothbrush'] 47 | 48 | # tf bilinear interpolation is different from any other's, just make do 49 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 50 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size 51 | ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size) 52 | 53 | if use_cuda: 54 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0) 55 | else: 56 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0) 57 | 58 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2) 59 | 60 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list), 61 | ratios=anchor_ratios, scales=anchor_scales) 62 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth')) 63 | model.requires_grad_(False) 64 | model.eval() 65 | 66 | if use_cuda: 67 | model = model.cuda() 68 | if use_float16: 69 | model = model.half() 70 | 71 | with torch.no_grad(): 72 | image = Image.open(img_path) 73 | frame = np.array(image) 74 | 75 | frame=service.detect(frame) 76 | 77 | features, regression, classification, anchors = model(x) 78 | 79 | regressBoxes = 
BBoxTransform() 80 | clipBoxes = ClipBoxes() 81 | 82 | out = postprocess(x, 83 | anchors, regression, classification, 84 | regressBoxes, clipBoxes, 85 | threshold, iou_threshold) 86 | 87 | 88 | def display(preds, imgs, imshow=True, imwrite=False): 89 | for i in range(len(imgs)): 90 | if len(preds[i]['rois']) == 0: 91 | continue 92 | 93 | for j in range(len(preds[i]['rois'])): 94 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int) 95 | color = [random.randint(0, 255) for _ in range(3)] 96 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), color, 2) 97 | obj = obj_list[preds[i]['class_ids'][j]] 98 | score = float(preds[i]['scores'][j]) 99 | 100 | label = obj 101 | label_size = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0] 102 | cv2.rectangle(imgs[i], (x1, y1), (x1 + label_size[0], y1 - label_size[1] - 3), color, -1) 103 | # cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score), 104 | # (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 105 | # [225, 255, 255], 1, lineType=cv2.LINE_AA) 106 | cv2.putText(imgs[i], label, (x1, y1 - 8), 0, 2 / 3, [225, 255, 255], thickness=1, 107 | lineType=cv2.LINE_AA) 108 | 109 | if imshow: 110 | cv2.imshow('img', imgs[i]) 111 | cv2.waitKey(0) 112 | 113 | if imwrite: 114 | cv2.imwrite(f'test/img_inferred_d{compound_coef}_this_repo_{i}.jpg', imgs[i]) 115 | 116 | 117 | out = invert_affine(framed_metas, out) 118 | display(out, ori_imgs, imshow=True, imwrite=False) 119 | 120 | print('running speed test...') 121 | with torch.no_grad(): 122 | print('test1: model inferring and postprocessing') 123 | print('inferring image for 10 times...') 124 | t1 = time.time() 125 | for _ in range(10): 126 | _, regression, classification, anchors = model(x) 127 | 128 | out = postprocess(x, 129 | anchors, regression, classification, 130 | regressBoxes, clipBoxes, 131 | threshold, iou_threshold) 132 | out = invert_affine(framed_metas, out) 133 | 134 | t2 = time.time() 135 | tact_time = (t2 - t1) / 10 136 | print(f'{tact_time} seconds, {1 / tact_time} FPS, 
@batch_size 1') 137 | 138 | # uncomment this if you want a extreme fps test 139 | # print('test2: model inferring only') 140 | # print('inferring images for batch_size 32 for 10 times...') 141 | # t1 = time.time() 142 | # x = torch.cat([x] * 32, 0) 143 | # for _ in range(10): 144 | # _, regression, classification, anchors = model(x) 145 | # 146 | # t2 = time.time() 147 | # tact_time = (t2 - t1) / 10 148 | # print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32') 149 | -------------------------------------------------------------------------------- /efficientdet_test_videos.py: -------------------------------------------------------------------------------- 1 | # Core Author: Zylo117 2 | # Script's Author: winter2897 3 | 4 | """ 5 | Simple Inference Script of EfficientDet-Pytorch for detecting objects on webcam 6 | """ 7 | import time 8 | import torch 9 | import cv2 10 | import numpy as np 11 | from torch.backends import cudnn 12 | from backbone import EfficientDetBackbone 13 | from efficientdet.utils import BBoxTransform, ClipBoxes 14 | from utils.utils import preprocess, invert_affine, postprocess, preprocess_video 15 | 16 | # Video's path 17 | video_src = 'videotest.mp4' # set int to use webcam, set str to read from a video file 18 | 19 | compound_coef = 0 20 | force_input_size = None # set None to use default size 21 | 22 | threshold = 0.2 23 | iou_threshold = 0.2 24 | 25 | use_cuda = True 26 | use_float16 = False 27 | cudnn.fastest = True 28 | cudnn.benchmark = True 29 | 30 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 31 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 32 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 33 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 34 | 'skateboard', 'surfboard', 'tennis racket', 
'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 35 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 36 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 37 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 38 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 39 | 'toothbrush'] 40 | 41 | # tf bilinear interpolation is different from any other's, just make do 42 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 43 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size 44 | 45 | # load model 46 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list)) 47 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth')) 48 | model.requires_grad_(False) 49 | model.eval() 50 | 51 | if use_cuda: 52 | model = model.cuda() 53 | if use_float16: 54 | model = model.half() 55 | 56 | # function for display 57 | def display(preds, imgs): 58 | for i in range(len(imgs)): 59 | if len(preds[i]['rois']) == 0: 60 | continue 61 | 62 | for j in range(len(preds[i]['rois'])): 63 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int) 64 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2) 65 | obj = obj_list[preds[i]['class_ids'][j]] 66 | score = float(preds[i]['scores'][j]) 67 | 68 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score), 69 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 70 | (255, 255, 0), 1) 71 | 72 | return imgs[i] 73 | # Box 74 | regressBoxes = BBoxTransform() 75 | clipBoxes = ClipBoxes() 76 | 77 | # Video capture 78 | cap = cv2.VideoCapture(video_src) 79 | 80 | while True: 81 | ret, frame = cap.read() 82 | if not ret: 83 | break 84 | 85 | # frame preprocessing 86 | ori_imgs, framed_imgs, framed_metas = preprocess_video(frame, max_size=input_size) 87 | 88 | if use_cuda: 89 
| x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0) 90 | else: 91 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0) 92 | 93 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2) 94 | 95 | # model predict 96 | with torch.no_grad(): 97 | features, regression, classification, anchors = model(x) 98 | 99 | out = postprocess(x, 100 | anchors, regression, classification, 101 | regressBoxes, clipBoxes, 102 | threshold, iou_threshold) 103 | 104 | # result 105 | out = invert_affine(framed_metas, out) 106 | img_show = display(out, ori_imgs) 107 | 108 | # show frame by frame 109 | cv2.imshow('frame',img_show) 110 | if cv2.waitKey(1) & 0xFF == ord('q'): 111 | break 112 | 113 | cap.release() 114 | cv2.destroyAllWindows() 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /efficientnet/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .model import EfficientNet 3 | from .utils import ( 4 | GlobalParams, 5 | BlockArgs, 6 | BlockDecoder, 7 | efficientnet, 8 | get_model_params, 9 | ) 10 | 11 | -------------------------------------------------------------------------------- /efficientnet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import ( 6 | round_filters, 7 | round_repeats, 8 | drop_connect, 9 | get_same_padding_conv2d, 10 | get_model_params, 11 | efficientnet_params, 12 | load_pretrained_weights, 13 | Swish, 14 | MemoryEfficientSwish, 15 | ) 16 | 17 | class MBConvBlock(nn.Module): 18 | """ 19 | Mobile Inverted Residual Bottleneck Block 20 | 21 | Args: 22 | block_args (namedtuple): BlockArgs, see above 23 | global_params (namedtuple): GlobalParam, see above 24 | 25 | Attributes: 26 | has_se (bool): Whether the block 
contains a Squeeze and Excitation layer. 27 | """ 28 | 29 | def __init__(self, block_args, global_params): 30 | super().__init__() 31 | self._block_args = block_args 32 | self._bn_mom = 1 - global_params.batch_norm_momentum 33 | self._bn_eps = global_params.batch_norm_epsilon 34 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 35 | self.id_skip = block_args.id_skip # skip connection and drop connect 36 | 37 | # Get static or dynamic convolution depending on image size 38 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 39 | 40 | # Expansion phase 41 | inp = self._block_args.input_filters # number of input channels 42 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 43 | if self._block_args.expand_ratio != 1: 44 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 45 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 46 | 47 | # Depthwise convolution phase 48 | k = self._block_args.kernel_size 49 | s = self._block_args.stride 50 | self._depthwise_conv = Conv2d( 51 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 52 | kernel_size=k, stride=s, bias=False) 53 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 54 | 55 | # Squeeze and Excitation layer, if desired 56 | if self.has_se: 57 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 58 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 59 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 60 | 61 | # Output phase 62 | final_oup = self._block_args.output_filters 63 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 64 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, 
momentum=self._bn_mom, eps=self._bn_eps) 65 | self._swish = MemoryEfficientSwish() 66 | 67 | def forward(self, inputs, drop_connect_rate=None): 68 | """ 69 | :param inputs: input tensor 70 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 71 | :return: output of block 72 | """ 73 | 74 | # Expansion and Depthwise Convolution 75 | x = inputs 76 | if self._block_args.expand_ratio != 1: 77 | x = self._expand_conv(inputs) 78 | x = self._bn0(x) 79 | x = self._swish(x) 80 | 81 | x = self._depthwise_conv(x) 82 | x = self._bn1(x) 83 | x = self._swish(x) 84 | 85 | # Squeeze and Excitation 86 | if self.has_se: 87 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 88 | x_squeezed = self._se_reduce(x_squeezed) 89 | x_squeezed = self._swish(x_squeezed) 90 | x_squeezed = self._se_expand(x_squeezed) 91 | x = torch.sigmoid(x_squeezed) * x 92 | 93 | x = self._project_conv(x) 94 | x = self._bn2(x) 95 | 96 | # Skip connection and drop connect 97 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 98 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 99 | if drop_connect_rate: 100 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 101 | x = x + inputs # skip connection 102 | return x 103 | 104 | def set_swish(self, memory_efficient=True): 105 | """Sets swish function as memory efficient (for training) or standard (for export)""" 106 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 107 | 108 | 109 | class EfficientNet(nn.Module): 110 | """ 111 | An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods 112 | 113 | Args: 114 | blocks_args (list): A list of BlockArgs to construct blocks 115 | global_params (namedtuple): A set of GlobalParams shared between blocks 116 | 117 | Example: 118 | model = EfficientNet.from_pretrained('efficientnet-b0') 119 | 120 | """ 121 | 122 | def __init__(self, blocks_args=None, global_params=None): 123 | super().__init__() 124 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 125 | assert len(blocks_args) > 0, 'block args must be greater than 0' 126 | self._global_params = global_params 127 | self._blocks_args = blocks_args 128 | 129 | # Get static or dynamic convolution depending on image size 130 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 131 | 132 | # Batch norm parameters 133 | bn_mom = 1 - self._global_params.batch_norm_momentum 134 | bn_eps = self._global_params.batch_norm_epsilon 135 | 136 | # Stem 137 | in_channels = 3 # rgb 138 | out_channels = round_filters(32, self._global_params) # number of output channels 139 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 140 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 141 | 142 | # Build blocks 143 | self._blocks = nn.ModuleList([]) 144 | for block_args in self._blocks_args: 145 | 146 | # Update block input and output filters based on depth multiplier. 147 | block_args = block_args._replace( 148 | input_filters=round_filters(block_args.input_filters, self._global_params), 149 | output_filters=round_filters(block_args.output_filters, self._global_params), 150 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 151 | ) 152 | 153 | # The first block needs to take care of stride and filter size increase. 
154 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 155 | if block_args.num_repeat > 1: 156 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 157 | for _ in range(block_args.num_repeat - 1): 158 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 159 | 160 | # Head 161 | in_channels = block_args.output_filters # output of final block 162 | out_channels = round_filters(1280, self._global_params) 163 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 164 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 165 | 166 | # Final linear layer 167 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 168 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 169 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 170 | self._swish = MemoryEfficientSwish() 171 | 172 | def set_swish(self, memory_efficient=True): 173 | """Sets swish function as memory efficient (for training) or standard (for export)""" 174 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 175 | for block in self._blocks: 176 | block.set_swish(memory_efficient) 177 | 178 | 179 | def extract_features(self, inputs): 180 | """ Returns output of the final convolution layer """ 181 | 182 | # Stem 183 | x = self._swish(self._bn0(self._conv_stem(inputs))) 184 | 185 | # Blocks 186 | for idx, block in enumerate(self._blocks): 187 | drop_connect_rate = self._global_params.drop_connect_rate 188 | if drop_connect_rate: 189 | drop_connect_rate *= float(idx) / len(self._blocks) 190 | x = block(x, drop_connect_rate=drop_connect_rate) 191 | # Head 192 | x = self._swish(self._bn1(self._conv_head(x))) 193 | 194 | return x 195 | 196 | def forward(self, inputs): 197 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. 
""" 198 | bs = inputs.size(0) 199 | # Convolution layers 200 | x = self.extract_features(inputs) 201 | 202 | # Pooling and final linear layer 203 | x = self._avg_pooling(x) 204 | x = x.view(bs, -1) 205 | x = self._dropout(x) 206 | x = self._fc(x) 207 | return x 208 | 209 | @classmethod 210 | def from_name(cls, model_name, override_params=None): 211 | cls._check_model_name_is_valid(model_name) 212 | blocks_args, global_params = get_model_params(model_name, override_params) 213 | return cls(blocks_args, global_params) 214 | 215 | @classmethod 216 | def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3): 217 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 218 | if load_weights: 219 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) 220 | if in_channels != 3: 221 | Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) 222 | out_channels = round_filters(32, model._global_params) 223 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 224 | return model 225 | 226 | @classmethod 227 | def get_image_size(cls, model_name): 228 | cls._check_model_name_is_valid(model_name) 229 | _, _, res, _ = efficientnet_params(model_name) 230 | return res 231 | 232 | @classmethod 233 | def _check_model_name_is_valid(cls, model_name): 234 | """ Validates model name. """ 235 | valid_models = ['efficientnet-b'+str(i) for i in range(9)] 236 | if model_name not in valid_models: 237 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 238 | -------------------------------------------------------------------------------- /efficientnet/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains helper functions for building the model and for loading model parameters. 
3 | These helper functions are built to mirror those in the official TensorFlow implementation. 4 | """ 5 | 6 | import re 7 | import math 8 | import collections 9 | from functools import partial 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.utils import model_zoo 14 | from .utils_extra import Conv2dStaticSamePadding 15 | 16 | ######################################################################## 17 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### 18 | ######################################################################## 19 | 20 | 21 | # Parameters for the entire model (stem, all blocks, and head) 22 | 23 | GlobalParams = collections.namedtuple('GlobalParams', [ 24 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 25 | 'num_classes', 'width_coefficient', 'depth_coefficient', 26 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 27 | 28 | # Parameters for an individual model block 29 | BlockArgs = collections.namedtuple('BlockArgs', [ 30 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 31 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) 32 | 33 | # Change namedtuple defaults 34 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 35 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 36 | 37 | 38 | class SwishImplementation(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, i): 41 | result = i * torch.sigmoid(i) 42 | ctx.save_for_backward(i) 43 | return result 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | i = ctx.saved_variables[0] 48 | sigmoid_i = torch.sigmoid(i) 49 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 50 | 51 | 52 | class MemoryEfficientSwish(nn.Module): 53 | def forward(self, x): 54 | return SwishImplementation.apply(x) 55 | 56 | 57 | class Swish(nn.Module): 58 | def forward(self, x): 59 | return x * torch.sigmoid(x) 60 | 61 | 62 | 
def round_filters(filters, global_params): 63 | """ Calculate and round number of filters based on depth multiplier. """ 64 | multiplier = global_params.width_coefficient 65 | if not multiplier: 66 | return filters 67 | divisor = global_params.depth_divisor 68 | min_depth = global_params.min_depth 69 | filters *= multiplier 70 | min_depth = min_depth or divisor 71 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 72 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 73 | new_filters += divisor 74 | return int(new_filters) 75 | 76 | 77 | def round_repeats(repeats, global_params): 78 | """ Round number of filters based on depth multiplier. """ 79 | multiplier = global_params.depth_coefficient 80 | if not multiplier: 81 | return repeats 82 | return int(math.ceil(multiplier * repeats)) 83 | 84 | 85 | def drop_connect(inputs, p, training): 86 | """ Drop connect. """ 87 | if not training: return inputs 88 | batch_size = inputs.shape[0] 89 | keep_prob = 1 - p 90 | random_tensor = keep_prob 91 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 92 | binary_tensor = torch.floor(random_tensor) 93 | output = inputs / keep_prob * binary_tensor 94 | return output 95 | 96 | 97 | def get_same_padding_conv2d(image_size=None): 98 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. 99 | Static padding is necessary for ONNX exporting of models. 
""" 100 | if image_size is None: 101 | return Conv2dDynamicSamePadding 102 | else: 103 | return partial(Conv2dStaticSamePadding, image_size=image_size) 104 | 105 | 106 | class Conv2dDynamicSamePadding(nn.Conv2d): 107 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 108 | 109 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 110 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 111 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 112 | 113 | def forward(self, x): 114 | ih, iw = x.size()[-2:] 115 | kh, kw = self.weight.size()[-2:] 116 | sh, sw = self.stride 117 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 118 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 119 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 120 | if pad_h > 0 or pad_w > 0: 121 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 122 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 123 | 124 | 125 | class Identity(nn.Module): 126 | def __init__(self, ): 127 | super(Identity, self).__init__() 128 | 129 | def forward(self, input): 130 | return input 131 | 132 | 133 | ######################################################################## 134 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## 135 | ######################################################################## 136 | 137 | 138 | def efficientnet_params(model_name): 139 | """ Map EfficientNet model name to parameter coefficients. 
""" 140 | params_dict = { 141 | # Coefficients: width,depth,res,dropout 142 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 143 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 144 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 145 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 146 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 147 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 148 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 149 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 150 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 151 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5), 152 | } 153 | return params_dict[model_name] 154 | 155 | 156 | class BlockDecoder(object): 157 | """ Block Decoder for readability, straight from the official TensorFlow repository """ 158 | 159 | @staticmethod 160 | def _decode_block_string(block_string): 161 | """ Gets a block through a string notation of arguments. """ 162 | assert isinstance(block_string, str) 163 | 164 | ops = block_string.split('_') 165 | options = {} 166 | for op in ops: 167 | splits = re.split(r'(\d.*)', op) 168 | if len(splits) >= 2: 169 | key, value = splits[:2] 170 | options[key] = value 171 | 172 | # Check stride 173 | assert (('s' in options and len(options['s']) == 1) or 174 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 175 | 176 | return BlockArgs( 177 | kernel_size=int(options['k']), 178 | num_repeat=int(options['r']), 179 | input_filters=int(options['i']), 180 | output_filters=int(options['o']), 181 | expand_ratio=int(options['e']), 182 | id_skip=('noskip' not in block_string), 183 | se_ratio=float(options['se']) if 'se' in options else None, 184 | stride=[int(options['s'][0])]) 185 | 186 | @staticmethod 187 | def _encode_block_string(block): 188 | """Encodes a block to a string.""" 189 | args = [ 190 | 'r%d' % block.num_repeat, 191 | 'k%d' % block.kernel_size, 192 | 's%d%d' % (block.strides[0], block.strides[1]), 193 | 'e%s' % block.expand_ratio, 194 | 'i%d' % block.input_filters, 195 | 'o%d' % block.output_filters 196 | ] 197 | 
if 0 < block.se_ratio <= 1: 198 | args.append('se%s' % block.se_ratio) 199 | if block.id_skip is False: 200 | args.append('noskip') 201 | return '_'.join(args) 202 | 203 | @staticmethod 204 | def decode(string_list): 205 | """ 206 | Decodes a list of string notations to specify blocks inside the network. 207 | 208 | :param string_list: a list of strings, each string is a notation of block 209 | :return: a list of BlockArgs namedtuples of block args 210 | """ 211 | assert isinstance(string_list, list) 212 | blocks_args = [] 213 | for block_string in string_list: 214 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 215 | return blocks_args 216 | 217 | @staticmethod 218 | def encode(blocks_args): 219 | """ 220 | Encodes a list of BlockArgs to a list of strings. 221 | 222 | :param blocks_args: a list of BlockArgs namedtuples of block args 223 | :return: a list of strings, each string is a notation of block 224 | """ 225 | block_strings = [] 226 | for block in blocks_args: 227 | block_strings.append(BlockDecoder._encode_block_string(block)) 228 | return block_strings 229 | 230 | 231 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, 232 | drop_connect_rate=0.2, image_size=None, num_classes=1000): 233 | """ Creates a efficientnet model. 
""" 234 | 235 | blocks_args = [ 236 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 237 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 238 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 239 | 'r1_k3_s11_e6_i192_o320_se0.25', 240 | ] 241 | blocks_args = BlockDecoder.decode(blocks_args) 242 | 243 | global_params = GlobalParams( 244 | batch_norm_momentum=0.99, 245 | batch_norm_epsilon=1e-3, 246 | dropout_rate=dropout_rate, 247 | drop_connect_rate=drop_connect_rate, 248 | # data_format='channels_last', # removed, this is always true in PyTorch 249 | num_classes=num_classes, 250 | width_coefficient=width_coefficient, 251 | depth_coefficient=depth_coefficient, 252 | depth_divisor=8, 253 | min_depth=None, 254 | image_size=image_size, 255 | ) 256 | 257 | return blocks_args, global_params 258 | 259 | 260 | def get_model_params(model_name, override_params): 261 | """ Get the block args and global params for a given model """ 262 | if model_name.startswith('efficientnet'): 263 | w, d, s, p = efficientnet_params(model_name) 264 | # note: all models have drop connect rate = 0.2 265 | blocks_args, global_params = efficientnet( 266 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 267 | else: 268 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 269 | if override_params: 270 | # ValueError will be raised here if override_params has fields not included in global_params. 
271 | global_params = global_params._replace(**override_params) 272 | return blocks_args, global_params 273 | 274 | 275 | url_map = { 276 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth', 277 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth', 278 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth', 279 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth', 280 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth', 281 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth', 282 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth', 283 | 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth', 284 | } 285 | 286 | url_map_advprop = { 287 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth', 288 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth', 289 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth', 290 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth', 291 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth', 292 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth', 293 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth', 294 | 'efficientnet-b7': 
'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth', 295 | 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth', 296 | } 297 | 298 | 299 | def load_pretrained_weights(model, model_name, load_fc=True, advprop=False): 300 | """ Loads pretrained weights, and downloads if loading for the first time. """ 301 | # AutoAugment or Advprop (different preprocessing) 302 | url_map_ = url_map_advprop if advprop else url_map 303 | state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu')) 304 | # state_dict = torch.load('../../weights/backbone_efficientnetb0.pth') 305 | if load_fc: 306 | ret = model.load_state_dict(state_dict, strict=False) 307 | print(ret) 308 | else: 309 | state_dict.pop('_fc.weight') 310 | state_dict.pop('_fc.bias') 311 | res = model.load_state_dict(state_dict, strict=False) 312 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' 313 | print('Loaded pretrained weights for {}'.format(model_name)) 314 | -------------------------------------------------------------------------------- /efficientnet/utils_extra.py: -------------------------------------------------------------------------------- 1 | # Author: Zylo117 2 | 3 | import math 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Conv2dStaticSamePadding(nn.Module): 10 | """ 11 | created by Zylo117 12 | The real keras/tensorflow conv2d with same padding 13 | """ 14 | 15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 18 | bias=bias, groups=groups) 19 | self.stride = self.conv.stride 20 | self.kernel_size = self.conv.kernel_size 21 | self.dilation = self.conv.dilation 22 | 23 | if isinstance(self.stride, int): 24 | self.stride = 
[self.stride] * 2 25 | elif len(self.stride) == 1: 26 | self.stride = [self.stride[0]] * 2 27 | 28 | if isinstance(self.kernel_size, int): 29 | self.kernel_size = [self.kernel_size] * 2 30 | elif len(self.kernel_size) == 1: 31 | self.kernel_size = [self.kernel_size[0]] * 2 32 | 33 | def forward(self, x): 34 | h, w = x.shape[-2:] 35 | 36 | h_step = math.ceil(w / self.stride[1]) 37 | v_step = math.ceil(h / self.stride[0]) 38 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) 39 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) 40 | 41 | extra_h = h_cover_len - w 42 | extra_v = v_cover_len - h 43 | 44 | left = extra_h // 2 45 | right = extra_h - left 46 | top = extra_v // 2 47 | bottom = extra_v - top 48 | 49 | x = F.pad(x, [left, right, top, bottom]) 50 | 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class MaxPool2dStaticSamePadding(nn.Module): 56 | """ 57 | created by Zylo117 58 | The real keras/tensorflow MaxPool2d with same padding 59 | """ 60 | 61 | def __init__(self, *args, **kwargs): 62 | super().__init__() 63 | self.pool = nn.MaxPool2d(*args, **kwargs) 64 | self.stride = self.pool.stride 65 | self.kernel_size = self.pool.kernel_size 66 | 67 | if isinstance(self.stride, int): 68 | self.stride = [self.stride] * 2 69 | elif len(self.stride) == 1: 70 | self.stride = [self.stride[0]] * 2 71 | 72 | if isinstance(self.kernel_size, int): 73 | self.kernel_size = [self.kernel_size] * 2 74 | elif len(self.kernel_size) == 1: 75 | self.kernel_size = [self.kernel_size[0]] * 2 76 | 77 | def forward(self, x): 78 | h, w = x.shape[-2:] 79 | 80 | h_step = math.ceil(w / self.stride[1]) 81 | v_step = math.ceil(h / self.stride[0]) 82 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) 83 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) 84 | 85 | extra_h = h_cover_len - w 86 | extra_v = v_cover_len - h 87 | 88 | left = extra_h // 2 89 | right = extra_h - left 90 | top = 
extra_v // 2 91 | bottom = extra_v - top 92 | 93 | x = F.pad(x, [left, right, top, bottom]) 94 | 95 | x = self.pool(x) 96 | return x 97 | -------------------------------------------------------------------------------- /projects/coco.yml: -------------------------------------------------------------------------------- 1 | project_name: coco # also the folder name of the dataset that under data_path folder 2 | train_set: train2017 3 | val_set: val2017 4 | num_gpus: 4 5 | 6 | # mean and std in RGB order, actually this part should remain unchanged as long as your dataset is similar to coco. 7 | mean: [0.485, 0.456, 0.406] 8 | std: [0.229, 0.224, 0.225] 9 | 10 | # this is coco anchors, change it if necessary 11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]' 12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]' 13 | 14 | # must match your dataset's category_id. 15 | # category_id is one_indexed, 16 | # for example, index of 'car' here is 2, while category_id of is 3 17 | obj_list: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 18 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 19 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie', 20 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 21 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 22 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 23 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv', 24 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 25 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 26 | 'toothbrush'] 
-------------------------------------------------------------------------------- /projects/shape.yml: -------------------------------------------------------------------------------- 1 | project_name: shape # also the name of the dataset folder under data_path 2 | train_set: train 3 | val_set: val 4 | num_gpus: 1 5 | 6 | # mean and std in RGB order; this part should remain unchanged as long as your dataset is similar to coco. 7 | mean: [0.485, 0.456, 0.406] 8 | std: [0.229, 0.224, 0.225] 9 | 10 | # these anchors are adapted to the dataset 11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]' 12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]' 13 | 14 | obj_list: ['rectangle', 'circle'] -------------------------------------------------------------------------------- /test/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img.png -------------------------------------------------------------------------------- /test/img_inferred_d0_official.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_official.jpg -------------------------------------------------------------------------------- /test/img_inferred_d0_this_repo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_this_repo.jpg -------------------------------------------------------------------------------- /test/img_inferred_d0_this_repo_0.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_this_repo_0.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # original author: signatrix 2 | # adapted from https://github.com/signatrix/efficientdet/blob/master/train.py 3 | # modified by Zylo117 4 | 5 | import datetime 6 | import os 7 | import argparse 8 | import traceback 9 | 10 | import torch 11 | import yaml 12 | from torch import nn 13 | from torch.utils.data import DataLoader 14 | from torchvision import transforms 15 | from efficientdet.dataset import CocoDataset, Resizer, Normalizer, Augmenter, collater 16 | from backbone import EfficientDetBackbone 17 | from tensorboardX import SummaryWriter 18 | import numpy as np 19 | from tqdm.autonotebook import tqdm 20 | 21 | from efficientdet.loss import FocalLoss 22 | from utils.sync_batchnorm import patch_replication_callback 23 | from utils.utils import replace_w_sync_bn, CustomDataParallel, get_last_weights, init_weights 24 | 25 | 26 | class Params: 27 | def __init__(self, project_file): 28 | self.params = yaml.safe_load(open(project_file).read()) 29 | 30 | def __getattr__(self, item): 31 | return self.params.get(item, None) 32 | 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser('Yet Another EfficientDet Pytorch: SOTA object detection network - Zylo117') 36 | parser.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters') 37 | parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet') 38 | parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader') 39 | parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices') 40 | parser.add_argument('--head_only', type=bool, 
default=False, 41 | help='whether finetunes only the regressor and the classifier, ' 42 | 'useful in early stage convergence or small/easy dataset') 43 | parser.add_argument('--lr', type=float, default=1e-4) 44 | parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, ' 45 | 'suggest using \'admaw\' until the' 46 | ' very final stage then switch to \'sgd\'') 47 | parser.add_argument('--alpha', type=float, default=0.25) 48 | parser.add_argument('--gamma', type=float, default=1.5) 49 | parser.add_argument('--num_epochs', type=int, default=500) 50 | parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases') 51 | parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving') 52 | parser.add_argument('--es_min_delta', type=float, default=0.0, 53 | help='Early stopping\'s parameter: minimum change loss to qualify as an improvement') 54 | parser.add_argument('--es_patience', type=int, default=0, 55 | help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. 
Set to 0 to disable this technique.') 56 | parser.add_argument('--data_path', type=str, default='datasets/', help='the root folder of dataset') 57 | parser.add_argument('--log_path', type=str, default='logs/') 58 | parser.add_argument('-w', '--load_weights', type=str, default=None, 59 | help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint') 60 | parser.add_argument('--saved_path', type=str, default='logs/') 61 | parser.add_argument('--debug', type=bool, default=False, help='whether visualize the predicted boxes of trainging, ' 62 | 'the output images will be in test/') 63 | 64 | args = parser.parse_args() 65 | return args 66 | 67 | 68 | class ModelWithLoss(nn.Module): 69 | def __init__(self, model, debug=False): 70 | super().__init__() 71 | self.criterion = FocalLoss() 72 | self.model = model 73 | self.debug = debug 74 | 75 | def forward(self, imgs, annotations, obj_list=None): 76 | _, regression, classification, anchors = self.model(imgs) 77 | if self.debug: 78 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations, 79 | imgs=imgs, obj_list=obj_list) 80 | else: 81 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations) 82 | return cls_loss, reg_loss 83 | 84 | 85 | def train(opt): 86 | params = Params(f'projects/{opt.project}.yml') 87 | 88 | if params.num_gpus == 0: 89 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 90 | 91 | if torch.cuda.is_available(): 92 | torch.cuda.manual_seed(42) 93 | else: 94 | torch.manual_seed(42) 95 | 96 | opt.saved_path = opt.saved_path + f'/{params.project_name}/' 97 | opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/' 98 | os.makedirs(opt.log_path, exist_ok=True) 99 | os.makedirs(opt.saved_path, exist_ok=True) 100 | 101 | training_params = {'batch_size': opt.batch_size, 102 | 'shuffle': True, 103 | 'drop_last': True, 104 | 'collate_fn': collater, 105 | 'num_workers': opt.num_workers} 106 | 107 | 
val_params = {'batch_size': opt.batch_size, 108 | 'shuffle': False, 109 | 'drop_last': True, 110 | 'collate_fn': collater, 111 | 'num_workers': opt.num_workers} 112 | 113 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536] 114 | training_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.train_set, 115 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std), 116 | Augmenter(), 117 | Resizer(input_sizes[opt.compound_coef])])) 118 | training_generator = DataLoader(training_set, **training_params) 119 | 120 | val_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.val_set, 121 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std), 122 | Resizer(input_sizes[opt.compound_coef])])) 123 | val_generator = DataLoader(val_set, **val_params) 124 | 125 | model = EfficientDetBackbone(num_classes=len(params.obj_list), compound_coef=opt.compound_coef, 126 | ratios=eval(params.anchors_ratios), scales=eval(params.anchors_scales)) 127 | 128 | # load last weights 129 | if opt.load_weights is not None: 130 | if opt.load_weights.endswith('.pth'): 131 | weights_path = opt.load_weights 132 | else: 133 | weights_path = get_last_weights(opt.saved_path) 134 | try: 135 | last_step = int(os.path.basename(weights_path).split('_')[-1].split('.')[0]) 136 | except: 137 | last_step = 0 138 | 139 | try: 140 | ret = model.load_state_dict(torch.load(weights_path), strict=False) 141 | except RuntimeError as e: 142 | print(f'[Warning] Ignoring {e}') 143 | print( 144 | '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. 
The rest of the weights should be loaded already.') 145 | 146 | print(f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}') 147 | else: 148 | last_step = 0 149 | print('[Info] initializing weights...') 150 | init_weights(model) 151 | 152 | # freeze backbone if train head_only 153 | if opt.head_only: 154 | def freeze_backbone(m): 155 | classname = m.__class__.__name__ 156 | for ntl in ['EfficientNet', 'BiFPN']: 157 | if ntl in classname: 158 | for param in m.parameters(): 159 | param.requires_grad = False 160 | 161 | model.apply(freeze_backbone) 162 | print('[Info] freezed backbone') 163 | 164 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 165 | # apply sync_bn when using multiple gpu and batch_size per gpu is lower than 4 166 | # useful when gpu memory is limited. 167 | # because when bn is disable, the training will be very unstable or slow to converge, 168 | # apply sync_bn can solve it, 169 | # by packing all mini-batch across all gpus as one batch and normalize, then send it back to all gpus. 170 | # but it would also slow down the training by a little bit. 
171 | if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4: 172 | model.apply(replace_w_sync_bn) 173 | use_sync_bn = True 174 | else: 175 | use_sync_bn = False 176 | 177 | writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/') 178 | 179 | # warp the model with loss function, to reduce the memory usage on gpu0 and speedup 180 | model = ModelWithLoss(model, debug=opt.debug) 181 | 182 | if params.num_gpus > 0: 183 | model = model.cuda() 184 | if params.num_gpus > 1: 185 | model = CustomDataParallel(model, params.num_gpus) 186 | if use_sync_bn: 187 | patch_replication_callback(model) 188 | 189 | if opt.optim == 'adamw': 190 | optimizer = torch.optim.AdamW(model.parameters(), opt.lr) 191 | else: 192 | optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True) 193 | 194 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True) 195 | 196 | epoch = 0 197 | best_loss = 1e5 198 | best_epoch = 0 199 | step = max(0, last_step) 200 | model.train() 201 | 202 | num_iter_per_epoch = len(training_generator) 203 | 204 | try: 205 | for epoch in range(opt.num_epochs): 206 | last_epoch = step // num_iter_per_epoch 207 | if epoch < last_epoch: 208 | continue 209 | 210 | epoch_loss = [] 211 | progress_bar = tqdm(training_generator) 212 | for iter, data in enumerate(progress_bar): 213 | if iter < step - last_epoch * num_iter_per_epoch: 214 | progress_bar.update() 215 | continue 216 | try: 217 | imgs = data['img'] 218 | annot = data['annot'] 219 | 220 | if params.num_gpus == 1: 221 | # if only one gpu, just send it to cuda:0 222 | # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here 223 | imgs = imgs.cuda() 224 | annot = annot.cuda() 225 | 226 | optimizer.zero_grad() 227 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list) 228 | cls_loss = cls_loss.mean() 229 | reg_loss = reg_loss.mean() 230 | 231 | loss = cls_loss + reg_loss 232 | if 
loss == 0 or not torch.isfinite(loss): 233 | continue 234 | 235 | loss.backward() 236 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) 237 | optimizer.step() 238 | 239 | epoch_loss.append(float(loss)) 240 | 241 | progress_bar.set_description( 242 | 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'.format( 243 | step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss.item(), 244 | reg_loss.item(), loss.item())) 245 | writer.add_scalars('Loss', {'train': loss}, step) 246 | writer.add_scalars('Regression_loss', {'train': reg_loss}, step) 247 | writer.add_scalars('Classfication_loss', {'train': cls_loss}, step) 248 | 249 | # log learning_rate 250 | current_lr = optimizer.param_groups[0]['lr'] 251 | writer.add_scalar('learning_rate', current_lr, step) 252 | 253 | step += 1 254 | 255 | if step % opt.save_interval == 0 and step > 0: 256 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth') 257 | print('checkpoint...') 258 | 259 | except Exception as e: 260 | print('[Error]', traceback.format_exc()) 261 | print(e) 262 | continue 263 | scheduler.step(np.mean(epoch_loss)) 264 | 265 | if epoch % opt.val_interval == 0: 266 | model.eval() 267 | loss_regression_ls = [] 268 | loss_classification_ls = [] 269 | for iter, data in enumerate(val_generator): 270 | with torch.no_grad(): 271 | imgs = data['img'] 272 | annot = data['annot'] 273 | 274 | if params.num_gpus == 1: 275 | imgs = imgs.cuda() 276 | annot = annot.cuda() 277 | 278 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list) 279 | cls_loss = cls_loss.mean() 280 | reg_loss = reg_loss.mean() 281 | 282 | loss = cls_loss + reg_loss 283 | if loss == 0 or not torch.isfinite(loss): 284 | continue 285 | 286 | loss_classification_ls.append(cls_loss.item()) 287 | loss_regression_ls.append(reg_loss.item()) 288 | 289 | cls_loss = np.mean(loss_classification_ls) 290 | reg_loss = np.mean(loss_regression_ls) 291 | loss = 
cls_loss + reg_loss 292 | 293 | print( 294 | 'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format( 295 | epoch, opt.num_epochs, cls_loss, reg_loss, loss)) 296 | writer.add_scalars('Loss', {'val': loss}, step) 297 | writer.add_scalars('Regression_loss', {'val': reg_loss}, step) 298 | writer.add_scalars('Classfication_loss', {'val': cls_loss}, step) 299 | 300 | if loss + opt.es_min_delta < best_loss: 301 | best_loss = loss 302 | best_epoch = epoch 303 | 304 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth') 305 | 306 | model.train() 307 | 308 | # Early stopping 309 | if epoch - best_epoch > opt.es_patience > 0: 310 | print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss)) 311 | break 312 | except KeyboardInterrupt: 313 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth') 314 | writer.close() 315 | writer.close() 316 | 317 | 318 | def save_checkpoint(model, name): 319 | if isinstance(model, CustomDataParallel): 320 | torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name)) 321 | else: 322 | torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name)) 323 | 324 | 325 | if __name__ == '__main__': 326 | opt = get_args() 327 | train(opt) 328 | -------------------------------------------------------------------------------- /tutorial/train_shape.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true, 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# EfficientDet Training On A Custom Dataset\n", 13 | "\n", 14 | "\n", 15 | "\n", 16 | "
\n", 17 | " \n", 18 | " View source on github\n", 19 | " \n", 20 | "\n", 21 | " \n", 22 | " Run in Google Colab\n", 23 | "
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "source": [ 29 | "## This tutorial will show you how to train a custom dataset.\n", 30 | "\n", 31 | "## For the sake of simplicity, I generated a dataset of different shapes, like rectangles, triangles, circles.\n", 32 | "\n", 33 | "## Please enable GPU support to accelerate on notebook setting if you are using colab.\n", 34 | "\n", 35 | "### 0. Install Requirements" 36 | ], 37 | "metadata": { 38 | "collapsed": false, 39 | "pycharm": { 40 | "name": "#%% md\n" 41 | } 42 | } 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "outputs": [], 48 | "source": [ 49 | "!pip install pycocotools numpy==1.16.0 opencv-python tqdm tensorboard tensorboardX pyyaml matplotlib\n", 50 | "!pip install torch==1.4.0\n", 51 | "!pip install torchvision==0.5.0" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "pycharm": { 56 | "name": "#%%\n" 57 | } 58 | } 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "source": [ 63 | "### 1. Prepare Custom Dataset/Pretrained Weights (Skip this part if you already have datasets and weights of your own)" 64 | ], 65 | "metadata": { 66 | "collapsed": false, 67 | "pycharm": { 68 | "name": "#%% md\n" 69 | } 70 | } 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "outputs": [], 76 | "source": [ 77 | "import os\n", 78 | "import sys\n", 79 | "if \"projects\" not in os.getcwd():\n", 80 | " !git clone --depth 1 https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch\n", 81 | " os.chdir('Yet-Another-EfficientDet-Pytorch')\n", 82 | " sys.path.append('.')\n", 83 | "else:\n", 84 | " !git pull\n", 85 | "\n", 86 | "# download and unzip dataset\n", 87 | "! mkdir datasets\n", 88 | "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.1/dataset_shape.tar.gz\n", 89 | "! tar xzf dataset_shape.tar.gz\n", 90 | "\n", 91 | "# download pretrained weights\n", 92 | "! mkdir weights\n", 93 | "! 
wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.0/efficientdet-d0.pth -O weights/efficientdet-d0.pth\n", 94 | "\n", 95 | "# prepare project file projects/shape.yml\n", 96 | "# showing its contents here\n", 97 | "! cat projects/shape.yml" 98 | ], 99 | "metadata": { 100 | "collapsed": false, 101 | "pycharm": { 102 | "name": "#%%\n" 103 | } 104 | } 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "source": [ 109 | "### 2. Training" 110 | ], 111 | "metadata": { 112 | "collapsed": false 113 | } 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "outputs": [], 119 | "source": [ 120 | "# consider this is a simple dataset, train head will be enough.\n", 121 | "! python train.py -c 0 -p shape --head_only True --lr 1e-3 --batch_size 32 --load_weights weights/efficientdet-d0.pth --num_epochs 50\n", 122 | "\n", 123 | "# the loss will be high at first\n", 124 | "# don't panic, be patient,\n", 125 | "# just wait for a little bit longer" 126 | ], 127 | "metadata": { 128 | "collapsed": false, 129 | "pycharm": { 130 | "name": "#%%\n" 131 | } 132 | } 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "source": [ 137 | "### 3. Evaluation" 138 | ], 139 | "metadata": { 140 | "collapsed": false 141 | } 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "outputs": [], 147 | "source": [ 148 | "! python coco_eval.py -c 0 -p shape -w logs/shape/efficientdet-d0_49_1400.pth" 149 | ], 150 | "metadata": { 151 | "collapsed": false, 152 | "pycharm": { 153 | "name": "#%%\n" 154 | } 155 | } 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "source": [ 160 | "### 4. Visualize" 161 | ], 162 | "metadata": { 163 | "collapsed": false, 164 | "pycharm": { 165 | "name": "#%% md\n" 166 | } 167 | } 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": "
", 176 | "image/png": "\n" 177 | }, 178 | "metadata": { 179 | "needs_background": "light" 180 | }, 181 | "output_type": "display_data" 182 | } 183 | ], 184 | "source": [ 185 | "import torch\n", 186 | "from torch.backends import cudnn\n", 187 | "\n", 188 | "from backbone import EfficientDetBackbone\n", 189 | "import cv2\n", 190 | "import matplotlib.pyplot as plt\n", 191 | "import numpy as np\n", 192 | "\n", 193 | "from efficientdet.utils import BBoxTransform, ClipBoxes\n", 194 | "from utils.utils import preprocess, invert_affine, postprocess\n", 195 | "\n", 196 | "compound_coef = 0\n", 197 | "force_input_size = None # set None to use default size\n", 198 | "img_path = 'datasets/shape/val/999.jpg'\n", 199 | "\n", 200 | "threshold = 0.2\n", 201 | "iou_threshold = 0.2\n", 202 | "\n", 203 | "use_cuda = True\n", 204 | "use_float16 = False\n", 205 | "cudnn.fastest = True\n", 206 | "cudnn.benchmark = True\n", 207 | "\n", 208 | "obj_list = ['rectangle', 'circle']\n", 209 | "\n", 210 | "# tf bilinear interpolation is different from any other's, just make do\n", 211 | "input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]\n", 212 | "input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size\n", 213 | "ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)\n", 214 | "\n", 215 | "if use_cuda:\n", 216 | " x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)\n", 217 | "else:\n", 218 | " x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)\n", 219 | "\n", 220 | "x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)\n", 221 | "\n", 222 | "model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),\n", 223 | "\n", 224 | " # replace this part with your project's anchor config\n", 225 | " ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],\n", 226 | " scales=[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])\n", 227 | "\n", 228 | 
"model.load_state_dict(torch.load('logs/shape/efficientdet-d0_49_1400.pth'))\n", 229 | "model.requires_grad_(False)\n", 230 | "model.eval()\n", 231 | "\n", 232 | "if use_cuda:\n", 233 | " model = model.cuda()\n", 234 | "if use_float16:\n", 235 | " model = model.half()\n", 236 | "\n", 237 | "with torch.no_grad():\n", 238 | " features, regression, classification, anchors = model(x)\n", 239 | "\n", 240 | " regressBoxes = BBoxTransform()\n", 241 | " clipBoxes = ClipBoxes()\n", 242 | "\n", 243 | " out = postprocess(x,\n", 244 | " anchors, regression, classification,\n", 245 | " regressBoxes, clipBoxes,\n", 246 | " threshold, iou_threshold)\n", 247 | "\n", 248 | "out = invert_affine(framed_metas, out)\n", 249 | "\n", 250 | "for i in range(len(ori_imgs)):\n", 251 | " if len(out[i]['rois']) == 0:\n", 252 | " continue\n", 253 | "\n", 254 | " for j in range(len(out[i]['rois'])):\n", 255 | " (x1, y1, x2, y2) = out[i]['rois'][j].astype(np.int)\n", 256 | " cv2.rectangle(ori_imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)\n", 257 | " obj = obj_list[out[i]['class_ids'][j]]\n", 258 | " score = float(out[i]['scores'][j])\n", 259 | "\n", 260 | " cv2.putText(ori_imgs[i], '{}, {:.3f}'.format(obj, score),\n", 261 | " (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,\n", 262 | " (255, 255, 0), 1)\n", 263 | "\n", 264 | " plt.imshow(ori_imgs[i])\n", 265 | "\n" 266 | ], 267 | "metadata": { 268 | "collapsed": false, 269 | "pycharm": { 270 | "name": "#%%\n" 271 | } 272 | } 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "Python 3", 278 | "language": "python", 279 | "name": "python3" 280 | }, 281 | "language_info": { 282 | "codemirror_mode": { 283 | "name": "ipython", 284 | "version": 2 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython2", 291 | "version": "2.7.6" 292 | } 293 | }, 294 | "nbformat": 4, 295 | "nbformat_minor": 0 296 | } 
-------------------------------------------------------------------------------- /utils/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | import contextlib 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.modules.batchnorm import _BatchNorm 18 | 19 | try: 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | except ImportError: 22 | ReduceAddCoalesced = Broadcast = None 23 | 24 | try: 25 | from jactorch.parallel.comm import SyncMaster 26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 27 | except ImportError: 28 | from .comm import SyncMaster 29 | from .replicate import DataParallelWithCallback 30 | 31 | __all__ = [ 32 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 33 | 'patch_sync_batchnorm', 'convert_model' 34 | ] 35 | 36 | 37 | def _sum_ft(tensor): 38 | """sum over the first and last dimension, e.g. (B, C, L) -> (C,)""" 39 | return tensor.sum(dim=0).sum(dim=-1) 40 | 41 | 42 | def _unsqueeze_ft(tensor): 43 | """add new dimensions at the front and the tail""" 44 | return tensor.unsqueeze(0).unsqueeze(-1) 45 | 46 | 47 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 48 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])  # NOTE: despite its name, the 'sum' field carries the mean (forward unpacks it as `mean, inv_std`) 49 | 50 | # Base class: BatchNorm whose training-time statistics are reduced across DataParallel replicas via SyncMaster 51 | class _SynchronizedBatchNorm(_BatchNorm): 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 53 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 54 | 55 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 56 | 57 | self._sync_master = SyncMaster(self._data_parallel_master) 58 | 59 | self._is_parallel = False 60 | self._parallel_id = None 61 | self._slave_pipe = None 62 | # NOTE(review): forward never calls self._check_input_dim, so input rank is not validated on either path 63 | def forward(self, input): 64 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
65 | if not (self._is_parallel and self.training): 66 | return F.batch_norm( 67 | input, self.running_mean, self.running_var, self.weight, self.bias, 68 | self.training, self.momentum, self.eps) 69 | 70 | # Resize the input to (B, C, -1). 71 | input_shape = input.size() 72 | input = input.view(input.size(0), self.num_features, -1) 73 | 74 | # Compute the sum and square-sum. 75 | sum_size = input.size(0) * input.size(2) 76 | input_sum = _sum_ft(input) 77 | input_ssum = _sum_ft(input ** 2) 78 | 79 | # Reduce-and-broadcast the statistics. 80 | if self._parallel_id == 0: 81 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 82 | else: 83 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 84 | 85 | # Compute the output. 86 | if self.affine: 87 | # MJY:: Fuse the multiplication for speed. 88 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 89 | else: 90 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 91 | 92 | # Reshape it. 93 | return output.view(input_shape) 94 | 95 | def __data_parallel_replicate__(self, ctx, copy_id): 96 | self._is_parallel = True 97 | self._parallel_id = copy_id 98 | 99 | # parallel_id == 0 means master device. 100 | if self._parallel_id == 0: 101 | ctx.sync_master = self._sync_master 102 | else: 103 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 104 | 105 | def _data_parallel_master(self, intermediates): 106 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 107 | 108 | # Always using same "device order" makes the ReduceAdd operation faster.
109 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 110 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 111 | 112 | to_reduce = [i[1][:2] for i in intermediates] 113 | to_reduce = [j for i in to_reduce for j in i] # flatten 114 | target_gpus = [i[1].sum.get_device() for i in intermediates] 115 | 116 | sum_size = sum([i[1].sum_size for i in intermediates]) 117 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 118 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 119 | 120 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 121 | 122 | outputs = [] 123 | for i, rec in enumerate(intermediates): 124 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 125 | 126 | return outputs 127 | 128 | def _compute_mean_std(self, sum_, ssum, size): 129 | """Compute the mean and the inverse standard deviation, 1 / sqrt(max(biased var, eps)), from the sum and square-sum. This method 130 | also maintains the moving averages on the master device (running_var is updated with the unbiased variance).""" 131 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 132 | mean = sum_ / size 133 | sumvar = ssum - sum_ * mean 134 | unbias_var = sumvar / (size - 1) 135 | bias_var = sumvar / size 136 | 137 | if hasattr(torch, 'no_grad'): 138 | with torch.no_grad(): 139 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 140 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 141 | else: 142 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 143 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 144 | 145 | return mean, bias_var.clamp(self.eps) ** -0.5  # clamp(eps) is max(var, eps), not var + eps 146 | 147 | 148 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 149 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 150 | mini-batch. 151 | 152 | ..
math:: 153 | 154 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 155 | 156 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 157 | standard-deviation are reduced across all devices during training. 158 | 159 | For example, when one uses `nn.DataParallel` to wrap the network during 160 | training, PyTorch's implementation normalize the tensor on each device using 161 | the statistics only on that device, which accelerated the computation and 162 | is also easy to implement, but the statistics might be inaccurate. 163 | Instead, in this synchronized version, the statistics will be computed 164 | over all training samples distributed on multiple devices. 165 | 166 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 167 | as the built-in PyTorch implementation. 168 | 169 | The mean and standard-deviation are calculated per-dimension over 170 | the mini-batches and gamma and beta are learnable parameter vectors 171 | of size C (where C is the input size). 172 | 173 | During training, this layer keeps a running estimate of its computed mean 174 | and variance. The running sum is kept with a default momentum of 0.1. 175 | 176 | During evaluation, this running mean/variance is used for normalization. 177 | 178 | Because the BatchNorm is done over the `C` dimension, computing statistics 179 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 180 | 181 | Args: 182 | num_features: num_features from an expected input of size 183 | `batch_size x num_features [x width]` 184 | eps: a value added to the denominator for numerical stability. 185 | Default: 1e-5 186 | momentum: the value used for the running_mean and running_var 187 | computation. Default: 0.1 188 | affine: a boolean value that when set to ``True``, gives the layer learnable 189 | affine parameters.
Default: ``True`` 190 | 191 | Shape:: 192 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 193 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 194 | 195 | Examples: 196 | >>> # With Learnable Parameters 197 | >>> m = SynchronizedBatchNorm1d(100) 198 | >>> # Without Learnable Parameters 199 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 200 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 201 | >>> output = m(input) 202 | """ 203 | 204 | def _check_input_dim(self, input): 205 | if input.dim() != 2 and input.dim() != 3: 206 | raise ValueError('expected 2D or 3D input (got {}D input)' 207 | .format(input.dim())) 208 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)  # NOTE(review): in recent PyTorch the base _check_input_dim raises NotImplementedError, so this may raise even for valid input; verify against the installed torch 209 | 210 | 211 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 212 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 213 | of 3d inputs 214 | 215 | .. math:: 216 | 217 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 218 | 219 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 220 | standard-deviation are reduced across all devices during training. 221 | 222 | For example, when one uses `nn.DataParallel` to wrap the network during 223 | training, PyTorch's implementation normalize the tensor on each device using 224 | the statistics only on that device, which accelerated the computation and 225 | is also easy to implement, but the statistics might be inaccurate. 226 | Instead, in this synchronized version, the statistics will be computed 227 | over all training samples distributed on multiple devices. 228 | 229 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 230 | as the built-in PyTorch implementation. 231 | 232 | The mean and standard-deviation are calculated per-dimension over 233 | the mini-batches and gamma and beta are learnable parameter vectors 234 | of size C (where C is the input size).
235 | 236 | During training, this layer keeps a running estimate of its computed mean 237 | and variance. The running sum is kept with a default momentum of 0.1. 238 | 239 | During evaluation, this running mean/variance is used for normalization. 240 | 241 | Because the BatchNorm is done over the `C` dimension, computing statistics 242 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 243 | 244 | Args: 245 | num_features: num_features from an expected input of 246 | size batch_size x num_features x height x width 247 | eps: a value added to the denominator for numerical stability. 248 | Default: 1e-5 249 | momentum: the value used for the running_mean and running_var 250 | computation. Default: 0.1 251 | affine: a boolean value that when set to ``True``, gives the layer learnable 252 | affine parameters. Default: ``True`` 253 | 254 | Shape:: 255 | - Input: :math:`(N, C, H, W)` 256 | - Output: :math:`(N, C, H, W)` (same shape as input) 257 | 258 | Examples: 259 | >>> # With Learnable Parameters 260 | >>> m = SynchronizedBatchNorm2d(100) 261 | >>> # Without Learnable Parameters 262 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 263 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 264 | >>> output = m(input) 265 | """ 266 | 267 | def _check_input_dim(self, input): 268 | if input.dim() != 4: 269 | raise ValueError('expected 4D input (got {}D input)' 270 | .format(input.dim())) 271 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)  # NOTE(review): in recent PyTorch the base _check_input_dim raises NotImplementedError, so this may raise even for valid input; verify against the installed torch 272 | 273 | 274 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 275 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 276 | of 4d inputs 277 | 278 | .. math:: 279 | 280 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 281 | 282 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 283 | standard-deviation are reduced across all devices during training.
284 | 285 | For example, when one uses `nn.DataParallel` to wrap the network during 286 | training, PyTorch's implementation normalize the tensor on each device using 287 | the statistics only on that device, which accelerated the computation and 288 | is also easy to implement, but the statistics might be inaccurate. 289 | Instead, in this synchronized version, the statistics will be computed 290 | over all training samples distributed on multiple devices. 291 | 292 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 293 | as the built-in PyTorch implementation. 294 | 295 | The mean and standard-deviation are calculated per-dimension over 296 | the mini-batches and gamma and beta are learnable parameter vectors 297 | of size C (where C is the input size). 298 | 299 | During training, this layer keeps a running estimate of its computed mean 300 | and variance. The running sum is kept with a default momentum of 0.1. 301 | 302 | During evaluation, this running mean/variance is used for normalization. 303 | 304 | Because the BatchNorm is done over the `C` dimension, computing statistics 305 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 306 | or Spatio-temporal BatchNorm 307 | 308 | Args: 309 | num_features: num_features from an expected input of 310 | size batch_size x num_features x depth x height x width 311 | eps: a value added to the denominator for numerical stability. 312 | Default: 1e-5 313 | momentum: the value used for the running_mean and running_var 314 | computation. Default: 0.1 315 | affine: a boolean value that when set to ``True``, gives the layer learnable 316 | affine parameters. 
Default: ``True`` 317 | 318 | Shape:: 319 | - Input: :math:`(N, C, D, H, W)` 320 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 321 | 322 | Examples: 323 | >>> # With Learnable Parameters 324 | >>> m = SynchronizedBatchNorm3d(100) 325 | >>> # Without Learnable Parameters 326 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 327 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 328 | >>> output = m(input) 329 | """ 330 | 331 | def _check_input_dim(self, input): 332 | if input.dim() != 5: 333 | raise ValueError('expected 5D input (got {}D input)' 334 | .format(input.dim())) 335 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 336 | 337 | 338 | @contextlib.contextmanager 339 | def patch_sync_batchnorm(): 340 | import torch.nn as nn 341 | 342 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 343 | 344 | nn.BatchNorm1d = SynchronizedBatchNorm1d 345 | nn.BatchNorm2d = SynchronizedBatchNorm2d 346 | nn.BatchNorm3d = SynchronizedBatchNorm3d 347 | 348 | yield 349 | 350 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 351 | 352 | 353 | def convert_model(module): 354 | """Traverse the input module and its child recursively 355 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d 356 | to SynchronizedBatchNorm*N*d 357 | 358 | Args: 359 | module: the input module needs to be convert to SyncBN model 360 | 361 | Examples: 362 | >>> import torch.nn as nn 363 | >>> import torchvision 364 | >>> # m is a standard pytorch model 365 | >>> m = torchvision.models.resnet18(True) 366 | >>> m = nn.DataParallel(m) 367 | >>> # after convert, m is using SyncBN 368 | >>> m = convert_model(m) 369 | """ 370 | if isinstance(module, torch.nn.DataParallel): 371 | mod = module.module 372 | mod = convert_model(mod) 373 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids) 374 | return mod 375 | 376 | mod = module 377 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 378 | 
torch.nn.modules.batchnorm.BatchNorm2d, 379 | torch.nn.modules.batchnorm.BatchNorm3d], 380 | [SynchronizedBatchNorm1d, 381 | SynchronizedBatchNorm2d, 382 | SynchronizedBatchNorm3d]): 383 | if isinstance(module, pth_module): 384 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 385 | mod.running_mean = module.running_mean 386 | mod.running_var = module.running_var 387 | if module.affine: 388 | mod.weight.data = module.weight.data.clone().detach() 389 | mod.bias.data = module.bias.data.clone().detach() 390 | 391 | for name, child in module.named_children(): 392 | mod.add_module(name, convert_model(child)) 393 | 394 | return mod 395 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File 
: comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 
63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 
114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 
36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        """Assert that ``torch.allclose(x, y)`` holds.

        On failure the message reports the largest absolute difference and the
        largest element-wise relative difference ``|x - y| / |y|``.
        """
        diff = x - y
        adiff = float(diff.abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            # Divide element-wise (not the scalar max abs error) so this is a real
            # relative error; entries where y == 0 produce inf/nan and are skipped.
            rel = (diff / y).abs()
            finite = rel[torch.isfinite(rel)]
            rdiff = float(finite.max()) if finite.numel() > 0 else 'NaN'

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)


# ------------------------------------------------------------------------------
# /utils/utils.py
# ------------------------------------------------------------------------------
# Author: Zylo117

import os

import cv2
import numpy as np
import torch
from glob import glob
from torch import nn
from torchvision.ops import nms
from typing import Union
import uuid

from utils.sync_batchnorm import SynchronizedBatchNorm2d

from torch.nn.init import _calculate_fan_in_and_fan_out, _no_grad_normal_
import math


def invert_affine(metas: Union[float, list, tuple], preds):
    """Map predicted boxes from network-input coordinates back to original image coordinates.

    Args:
        metas: either a single scale factor applied to every image, or a per-image
            sequence of ``(new_w, new_h, old_w, old_h, padding_w, padding_h)`` tuples,
            i.e. the ``framed_metas`` returned by ``preprocess`` / ``preprocess_video``.
        preds: list of dicts whose ``'rois'`` arrays hold ``x1, y1, x2, y2`` columns.

    Returns:
        ``preds``; each non-empty ``'rois'`` array is rescaled in place.
    """
    for i, pred in enumerate(preds):
        rois = pred['rois']
        if len(rois) == 0:
            continue
        # ``metas is float`` compared the object with the type itself and was never
        # true, so a scalar scale fell through to ``metas[i]`` and crashed.
        if isinstance(metas, (int, float)):
            rois[:, [0, 2]] = rois[:, [0, 2]] / metas
            rois[:, [1, 3]] = rois[:, [1, 3]] / metas
        else:
            new_w, new_h, old_w, old_h, padding_w, padding_h = metas[i]
            rois[:, [0, 2]] = rois[:, [0, 2]] / (new_w / old_w)
            rois[:, [1, 3]] = rois[:, [1, 3]] / (new_h / old_h)
    return preds


def aspectaware_resize_padding(image, width, height, interpolation=None, means=None):
    """Resize ``image`` keeping its aspect ratio and paste it top-left onto a ``height x width`` canvas.

    Args:
        image: array unpacked as ``(h, w, c)``.
        width, height: target canvas size.
        interpolation: optional ``cv2.resize`` interpolation flag.
        means: optional fill value for the padded area (defaults to 0).

    Returns:
        ``(canvas, new_w, new_h, old_w, old_h, padding_w, padding_h)`` where padding is
        the unfilled extent on the right / bottom.
    """
    old_h, old_w, c = image.shape
    # Width-limited when the image is relatively wider than the target (integer
    # cross-multiplication; for square targets this is exactly ``old_w > old_h``).
    # This keeps the resized image inside the canvas for non-square targets too.
    if old_w * height > old_h * width:
        new_w = width
        new_h = int(width / old_w * old_h)
    else:
        new_w = int(height / old_h * old_w)
        new_h = height

    # Canvas is (height, width); it used to be (height, height), wrong when width != height.
    canvas = np.zeros((height, width, c), np.float32)
    if means is not None:
        canvas[...] = means

    if new_w != old_w or new_h != old_h:
        if interpolation is None:
            image = cv2.resize(image, (new_w, new_h))
        else:
            image = cv2.resize(image, (new_w, new_h), interpolation=interpolation)

    padding_h = height - new_h
    padding_w = width - new_w

    if c > 1:
        canvas[:new_h, :new_w] = image
    else:
        # cv2.resize drops the channel axis of single-channel images.
        if len(image.shape) == 2:
            canvas[:new_h, :new_w, 0] = image
        else:
            canvas[:new_h, :new_w] = image

    return canvas, new_w, new_h, old_w, old_h, padding_w, padding_h


def _normalize_and_frame(ori_imgs, max_size, mean, std):
    """Normalize images, flip the channel order and letterbox each to ``max_size x max_size``."""
    normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
    imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
                                            means=None) for img in normalized_imgs]
    framed_imgs = [img_meta[0] for img_meta in imgs_meta]
    framed_metas = [img_meta[1:] for img_meta in imgs_meta]
    return framed_imgs, framed_metas


def preprocess(*image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
    """Read images from disk and prepare them for the network.

    Returns:
        ``(ori_imgs, framed_imgs, framed_metas)``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read one of the paths.
    """
    ori_imgs = [cv2.imread(img_path) for img_path in image_path]
    for img_path, img in zip(image_path, ori_imgs):
        # cv2.imread returns None instead of raising on unreadable paths.
        if img is None:
            raise FileNotFoundError(f'could not read image: {img_path}')
    framed_imgs, framed_metas = _normalize_and_frame(ori_imgs, max_size, mean, std)
    return ori_imgs, framed_imgs, framed_metas


def preprocess_video(*frame_from_video, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
    """Prepare already-decoded frames for the network.

    Returns:
        ``(ori_imgs, framed_imgs, framed_metas)`` with ``ori_imgs`` being the input frames.
    """
    ori_imgs = frame_from_video
    framed_imgs, framed_metas = _normalize_and_frame(ori_imgs, max_size, mean, std)
    return ori_imgs, framed_imgs, framed_metas


def _empty_prediction():
    """Return a fresh prediction dict with no detections."""
    return {
        'rois': np.array(()),
        'class_ids': np.array(()),
        'scores': np.array(()),
    }


def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
    """Turn raw head outputs into per-image detections (score threshold + NMS).

    ``classification`` is indexed as ``(batch, anchor, class)``; one result dict
    (``'rois'``, ``'class_ids'``, ``'scores'`` numpy arrays) is produced for each of
    the ``x.shape[0]`` images.
    """
    transformed_anchors = regressBoxes(anchors, regression)
    transformed_anchors = clipBoxes(transformed_anchors, x)
    scores = torch.max(classification, dim=2, keepdim=True)[0]
    scores_over_thresh = (scores > threshold)[:, :, 0]
    out = []
    for i in range(x.shape[0]):
        keep = scores_over_thresh[i, :]
        # Test this image's mask only; summing the whole batch mask let images
        # without detections fall through to NMS on empty tensors.
        if keep.sum() == 0:
            out.append(_empty_prediction())
            continue

        classification_per = classification[i, keep, ...].permute(1, 0)
        transformed_anchors_per = transformed_anchors[i, keep, ...]
        scores_per = scores[i, keep, ...]
        anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)

        if anchors_nms_idx.shape[0] == 0:
            out.append(_empty_prediction())
            continue

        scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
        boxes_ = transformed_anchors_per[anchors_nms_idx, :]

        out.append({
            'rois': boxes_.cpu().numpy(),
            'class_ids': classes_.cpu().numpy(),
            'scores': scores_.cpu().numpy(),
        })

    return out


def display(preds, imgs, obj_list, imshow=True, imwrite=False):
    """Draw boxes, class names and scores onto ``imgs`` (in place); optionally show and/or save them.

    Saved images go to ``test/<random hex>.jpg``.
    """
    for i in range(len(imgs)):
        if len(preds[i]['rois']) == 0:
            continue

        for j in range(len(preds[i]['rois'])):
            # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
            (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)
            cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
            obj = obj_list[preds[i]['class_ids'][j]]
            score = float(preds[i]['scores'][j])

            cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
                        (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (255, 255, 0), 1)
        if imshow:
            cv2.imshow('img', imgs[i])
            cv2.waitKey(0)

        if imwrite:
            os.makedirs('test/', exist_ok=True)
            cv2.imwrite(f'test/{uuid.uuid4().hex}.jpg', imgs[i])


def replace_w_sync_bn(m):
    """Recursively replace attributes of exact type ``nn.BatchNorm2d`` with ``SynchronizedBatchNorm2d``.

    Running statistics and (if affine) weight/bias objects are carried over by reference.
    """
    for var_name in dir(m):
        target_attr = getattr(m, var_name)
        # Exact type match on purpose: subclasses of BatchNorm2d are left alone.
        if type(target_attr) is torch.nn.BatchNorm2d:
            num_features = target_attr.num_features
            eps = target_attr.eps
            momentum = target_attr.momentum
            affine = target_attr.affine

            # get parameters
            running_mean = target_attr.running_mean
            running_var = target_attr.running_var
            if affine:
                weight = target_attr.weight
                bias = target_attr.bias

            setattr(m, var_name,
                    SynchronizedBatchNorm2d(num_features, eps, momentum, affine))

            target_attr = getattr(m, var_name)
            # set parameters
            target_attr.running_mean = running_mean
            target_attr.running_var = running_var
            if affine:
                target_attr.weight = weight
                target_attr.bias = bias

    for var_name, children in m.named_children():
        replace_w_sync_bn(children)


class CustomDataParallel(nn.DataParallel):
    """
    Force splitting data to all gpus instead of sending all data to cuda:0 and then moving around.
    """

    def __init__(self, module, num_gpus):
        super().__init__(module)
        self.num_gpus = num_gpus

    def scatter(self, inputs, kwargs, device_ids):
        """Split ``inputs[0]`` and ``inputs[1]`` along dim 0 into ``num_gpus`` chunks, one per ``cuda:<idx>``.

        ``device_ids`` is ignored; devices ``cuda:0 .. cuda:num_gpus-1`` are used.
        """
        # More like scatter and data prep at the same time. The point is we prep the data in such a way
        # that no scatter is necessary, and there's no need to shuffle stuff around different GPUs.
        batch_size = inputs[0].shape[0]
        split = batch_size // self.num_gpus

        chunks = []
        for device_idx in range(self.num_gpus):
            start = split * device_idx
            # The last device also takes the remainder, so no sample is silently
            # dropped when batch_size is not a multiple of num_gpus.
            end = batch_size if device_idx == self.num_gpus - 1 else start + split
            device = f'cuda:{device_idx}'
            chunks.append((inputs[0][start:end].to(device, non_blocking=True),
                           inputs[1][start:end].to(device, non_blocking=True)))

        return chunks, [kwargs] * self.num_gpus


def get_last_weights(weights_path):
    """Return the ``*.pth`` file in ``weights_path`` with the largest trailing step number.

    The step is the integer between the last ``_`` of the path and the first ``.`` after it.

    Raises:
        IndexError: if no ``*.pth`` file is found.
    """
    candidates = glob(weights_path + '/*.pth')
    if not candidates:
        raise IndexError(f'no .pth weights found in {weights_path}')
    latest = max(candidates, key=lambda x: int(x.rsplit('_')[-1].rsplit('.')[0]))
    print(f'using weights {latest}')
    return latest


def init_weights(model):
    """Initialize every ``nn.Conv2d`` of ``model`` in place.

    Convs whose module name contains ``"conv_list"`` or ``"header"`` use
    :func:`variance_scaling_`; all other convs use Kaiming-uniform. Biases are zeroed,
    except under ``"classifier.header"``, where they are set to ``-log((1 - 0.01) / 0.01)``
    (so that ``sigmoid(bias) == 0.01``).
    """
    for name, module in model.named_modules():
        is_conv_layer = isinstance(module, nn.Conv2d)

        if is_conv_layer:
            # Previously ``"conv_list" or "header" in name`` was always truthy, which
            # made the Kaiming branch unreachable.
            if "conv_list" in name or "header" in name:
                variance_scaling_(module.weight.data)
            else:
                nn.init.kaiming_uniform_(module.weight.data)

            if module.bias is not None:
                if "classifier.header" in name:
                    bias_value = -np.log((1 - 0.01) / 0.01)
                    torch.nn.init.constant_(module.bias, bias_value)
                else:
                    module.bias.data.zero_()


def variance_scaling_(tensor, gain=1.):
    # type: (Tensor, float) -> Tensor
    r"""
    Initializer for SeparableConv in Regressor/Classifier.

    Fills ``tensor`` in place from ``N(0, gain / fan_in)`` and returns it.
    Reference: https://keras.io/zh/initializers/ VarianceScaling
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = math.sqrt(gain / float(fan_in))

    return _no_grad_normal_(tensor, 0., std)
# ------------------------------------------------------------------------------