├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── backbone.py
├── coco_eval.py
├── detect_process.py
├── efficientService.py
├── efficientdet
│   ├── config.py
│   ├── dataset.py
│   ├── loss.py
│   ├── model.py
│   └── utils.py
├── efficientdet_test.py
├── efficientdet_test_videos.py
├── efficientnet
│   ├── __init__.py
│   ├── model.py
│   ├── utils.py
│   └── utils_extra.py
├── projects
│   ├── coco.yml
│   └── shape.yml
├── test
│   ├── img.png
│   ├── img_inferred_d0_official.jpg
│   ├── img_inferred_d0_this_repo.jpg
│   └── img_inferred_d0_this_repo_0.jpg
├── train.py
├── tutorial
│   └── train_shape.ipynb
└── utils
    ├── sync_batchnorm
    │   ├── __init__.py
    │   ├── batchnorm.py
    │   ├── batchnorm_reimpl.py
    │   ├── comm.py
    │   ├── replicate.py
    │   └── unittest.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | weights/*.pt
2 | *.pt
3 | weights/
4 | */__pycache__
5 | */runs
6 | */.idea
7 | */weights
8 | idea/
9 | idea/*.*
10 | __pycache__/
11 | __pycache__/*.*
12 |
13 | /.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 | An object detection system based on Yet-Another-EfficientDet-Pytorch.
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 | from io import BytesIO
4 |
5 | import requests as req
6 | from PIL import Image
7 | from flask import Flask, request, render_template
8 |
9 | import efficientService as service
10 |
11 | # flask web service
12 | app = Flask(__name__, template_folder="web")
13 |
14 |
15 | @app.route('/detect')
16 | def detect():
17 | return render_template('detect.html')
18 |
19 |
20 | @app.route('/index')
21 | def index():
22 | return render_template('index.html')
23 |
24 |
25 | @app.route('/detect/imageDetect', methods=['POST'])
26 | def upload():
27 | # step 1. receive image
28 | file = request.form.get('imageBase64Code')
29 | image_link = request.form.get("imageLink")
30 |
31 | if image_link:
32 | response = req.get(image_link)
33 | image = Image.open(BytesIO(response.content))
34 | else:
35 | image = Image.open(BytesIO(base64.b64decode(file)))
36 |
37 | # step 2. detect image
38 | image_array = service.detect(image)
39 |
40 | # step 3. convert image_array to byte_array
41 | img = Image.fromarray(image_array, 'RGB')
42 | img_byte_array = io.BytesIO()
43 | img.save(img_byte_array, format='JPEG')
44 |
45 | # step 4. return image_info to page
46 | image_info = base64.b64encode(img_byte_array.getvalue()).decode('ascii')
47 | return image_info
48 |
49 |
50 | if __name__ == '__main__':
51 | app.jinja_env.auto_reload = True
52 | app.config['TEMPLATES_AUTO_RELOAD'] = True
53 | app.run(debug=False, port=8081)
54 |
--------------------------------------------------------------------------------
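
A minimal client sketch for the /detect/imageDetect endpoint above. The form field names (imageBase64Code, imageLink) and the port come from app.py; the host, input image and output filename are placeholders.

    # Sketch of a client for /detect/imageDetect; host and paths are placeholders.
    import base64

    import requests

    with open('test/img.png', 'rb') as f:
        encoded = base64.b64encode(f.read()).decode('ascii')

    resp = requests.post('http://127.0.0.1:8081/detect/imageDetect',
                         data={'imageBase64Code': encoded})
    resp.raise_for_status()

    # the endpoint returns a base64-encoded JPEG with the detections drawn on it
    with open('img_detected.jpg', 'wb') as f:
        f.write(base64.b64decode(resp.text))
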
/backbone.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import math
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
9 | from efficientdet.utils import Anchors
10 |
11 |
12 | class EfficientDetBackbone(nn.Module):
13 | def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
14 | super(EfficientDetBackbone, self).__init__()
15 | self.compound_coef = compound_coef
16 |
17 | self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
18 | self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
19 | self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
20 | self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
21 | self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
22 | self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
23 | self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
24 | self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
25 | conv_channel_coef = {
26 | # the channels of P3/P4/P5.
27 | 0: [40, 112, 320],
28 | 1: [40, 112, 320],
29 | 2: [48, 120, 352],
30 | 3: [48, 136, 384],
31 | 4: [56, 160, 448],
32 | 5: [64, 176, 512],
33 | 6: [72, 200, 576],
34 | 7: [72, 200, 576],
35 | }
36 |
37 | num_anchors = len(self.aspect_ratios) * self.num_scales
38 |
39 | self.bifpn = nn.Sequential(
40 | *[BiFPN(self.fpn_num_filters[self.compound_coef],
41 | conv_channel_coef[compound_coef],
42 | True if _ == 0 else False,
43 | attention=True if compound_coef < 6 else False)
44 | for _ in range(self.fpn_cell_repeats[compound_coef])])
45 |
46 | self.num_classes = num_classes
47 | self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
48 | num_layers=self.box_class_repeats[self.compound_coef])
49 | self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
50 | num_classes=num_classes,
51 | num_layers=self.box_class_repeats[self.compound_coef])
52 |
53 | self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef], **kwargs)
54 |
55 | self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)
56 |
57 | def freeze_bn(self):
58 | for m in self.modules():
59 | if isinstance(m, nn.BatchNorm2d):
60 | m.eval()
61 |
62 | def forward(self, inputs):
63 | max_size = inputs.shape[-1]
64 |
65 | _, p3, p4, p5 = self.backbone_net(inputs)
66 |
67 | features = (p3, p4, p5)
68 | features = self.bifpn(features)
69 |
70 | regression = self.regressor(features)
71 | classification = self.classifier(features)
72 | anchors = self.anchors(inputs, inputs.dtype)
73 |
74 | return features, regression, classification, anchors
75 |
76 | def init_backbone(self, path):
77 | state_dict = torch.load(path)
78 | try:
79 | ret = self.load_state_dict(state_dict, strict=False)
80 | print(ret)
81 | except RuntimeError as e:
82 |             print('Ignoring ' + str(e))
83 |
--------------------------------------------------------------------------------
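
As a quick sanity check, the backbone can be built with random weights and run on a dummy batch. A sketch (not part of the repo), using the d0 settings from the tables above (512x512 input, 80 classes by default):

    # Sketch: instantiate EfficientDetBackbone (d0, random weights) and run a dummy forward pass.
    import torch

    from backbone import EfficientDetBackbone

    model = EfficientDetBackbone(num_classes=80, compound_coef=0, load_weights=False)
    model.eval()

    x = torch.randn(1, 3, 512, 512)  # d0 input resolution
    with torch.no_grad():
        features, regression, classification, anchors = model(x)

    print([f.shape for f in features])  # P3..P7 feature maps
    print(regression.shape)             # [1, num_total_anchors, 4]
    print(classification.shape)         # [1, num_total_anchors, 80]
    print(anchors.shape)                # [1, num_total_anchors, 4]
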
/coco_eval.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | COCO-Style Evaluations
5 |
6 | put images here datasets/your_project_name/val_set_name/*.jpg
7 | put annotations here datasets/your_project_name/annotations/instances_{val_set_name}.json
8 | put weights here /path/to/your/weights/*.pth
9 | change compound_coef
10 |
11 | """
12 |
13 | import json
14 | import os
15 |
16 | import argparse
17 | import torch
18 | import yaml
19 | from tqdm import tqdm
20 | from pycocotools.coco import COCO
21 | from pycocotools.cocoeval import COCOeval
22 |
23 | from backbone import EfficientDetBackbone
24 | from efficientdet.utils import BBoxTransform, ClipBoxes
25 | from utils.utils import preprocess, invert_affine, postprocess
26 |
27 | ap = argparse.ArgumentParser()
28 | ap.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters')
29 | ap.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
30 | ap.add_argument('-w', '--weights', type=str, default=None, help='/path/to/weights')
31 | ap.add_argument('--nms_threshold', type=float, default=0.5, help='nms threshold, don\'t change it if not for testing purposes')
32 | ap.add_argument('--cuda', type=bool, default=True)
33 | ap.add_argument('--device', type=int, default=0)
34 | ap.add_argument('--float16', type=bool, default=False)
35 | ap.add_argument('--override', type=bool, default=True, help='override previous bbox results file if exists')
36 | args = ap.parse_args()
37 |
38 | compound_coef = args.compound_coef
39 | nms_threshold = args.nms_threshold
40 | use_cuda = args.cuda
41 | gpu = args.device
42 | use_float16 = args.float16
43 | override_prev_results = args.override
44 | project_name = args.project
45 | weights_path = f'weights/efficientdet-d{compound_coef}.pth' if args.weights is None else args.weights
46 |
47 | print(f'running coco-style evaluation on project {project_name}, weights {weights_path}...')
48 |
49 | params = yaml.safe_load(open(f'projects/{project_name}.yml'))
50 | obj_list = params['obj_list']
51 |
52 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
53 |
54 |
55 | def evaluate_coco(img_path, set_name, image_ids, coco, model, threshold=0.05):
56 | results = []
57 | processed_image_ids = []
58 |
59 | regressBoxes = BBoxTransform()
60 | clipBoxes = ClipBoxes()
61 |
62 | for image_id in tqdm(image_ids):
63 | image_info = coco.loadImgs(image_id)[0]
64 | image_path = img_path + image_info['file_name']
65 |
66 | ori_imgs, framed_imgs, framed_metas = preprocess(image_path, max_size=input_sizes[compound_coef])
67 | x = torch.from_numpy(framed_imgs[0])
68 |
69 | if use_cuda:
70 | x = x.cuda(gpu)
71 | if use_float16:
72 | x = x.half()
73 | else:
74 | x = x.float()
75 | else:
76 | x = x.float()
77 |
78 | x = x.unsqueeze(0).permute(0, 3, 1, 2)
79 | features, regression, classification, anchors = model(x)
80 |
81 | preds = postprocess(x,
82 | anchors, regression, classification,
83 | regressBoxes, clipBoxes,
84 | threshold, nms_threshold)
85 |
86 | processed_image_ids.append(image_id)
87 |
88 | if not preds:
89 | continue
90 |
91 | preds = invert_affine(framed_metas, preds)[0]
92 |
93 | scores = preds['scores']
94 | class_ids = preds['class_ids']
95 | rois = preds['rois']
96 |
97 | if rois.shape[0] > 0:
98 | # x1,y1,x2,y2 -> x1,y1,w,h
99 | rois[:, 2] -= rois[:, 0]
100 | rois[:, 3] -= rois[:, 1]
101 |
102 | bbox_score = scores
103 |
104 | for roi_id in range(rois.shape[0]):
105 | score = float(bbox_score[roi_id])
106 | label = int(class_ids[roi_id])
107 | box = rois[roi_id, :]
108 |
109 | if score < threshold:
110 | break
111 | image_result = {
112 | 'image_id': image_id,
113 | 'category_id': label + 1,
114 | 'score': float(score),
115 | 'bbox': box.tolist(),
116 | }
117 |
118 | results.append(image_result)
119 |
120 | if not len(results):
121 | raise Exception('the model does not provide any valid output, check model architecture and the data input')
122 |
123 | # write output
124 | filepath = f'{set_name}_bbox_results.json'
125 | if os.path.exists(filepath):
126 | os.remove(filepath)
127 | json.dump(results, open(filepath, 'w'), indent=4)
128 |
129 | return processed_image_ids
130 |
131 |
132 | def _eval(coco_gt, image_ids, pred_json_path):
133 | # load results in COCO evaluation tool
134 | coco_pred = coco_gt.loadRes(pred_json_path)
135 |
136 | # run COCO evaluation
137 | print('BBox')
138 | coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
139 | coco_eval.params.imgIds = image_ids
140 | coco_eval.evaluate()
141 | coco_eval.accumulate()
142 | coco_eval.summarize()
143 |
144 |
145 | if __name__ == '__main__':
146 | SET_NAME = params['val_set']
147 | VAL_GT = f'datasets/{params["project_name"]}/annotations/instances_{SET_NAME}.json'
148 | VAL_IMGS = f'datasets/{params["project_name"]}/{SET_NAME}/'
149 | MAX_IMAGES = 10000
150 | coco_gt = COCO(VAL_GT)
151 | image_ids = coco_gt.getImgIds()[:MAX_IMAGES]
152 |
153 | if override_prev_results or not os.path.exists(f'{SET_NAME}_bbox_results.json'):
154 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
155 | ratios=eval(params['anchors_ratios']), scales=eval(params['anchors_scales']))
156 | model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
157 | model.requires_grad_(False)
158 | model.eval()
159 |
160 | if use_cuda:
161 | model.cuda(gpu)
162 |
163 | if use_float16:
164 | model.half()
165 |
166 | image_ids = evaluate_coco(VAL_IMGS, SET_NAME, image_ids, coco_gt, model)
167 |
168 | _eval(coco_gt, image_ids, f'{SET_NAME}_bbox_results.json')
169 |
--------------------------------------------------------------------------------
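
Typical usage is via the command line, e.g. python coco_eval.py -p coco -c 0 -w weights/efficientdet-d0.pth (flags as defined by the argparse block above). The evaluation step itself is standard pycocotools usage; a standalone sketch of what _eval() does, with placeholder paths and assuming the project's val_set is val2017:

    # Sketch of the COCO bbox evaluation performed by _eval(); paths are placeholders.
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    coco_gt = COCO('datasets/coco/annotations/instances_val2017.json')  # ground truth
    coco_pred = coco_gt.loadRes('val2017_bbox_results.json')            # file written by evaluate_coco()

    coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
    coco_eval.params.imgIds = coco_gt.getImgIds()
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
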
/detect_process.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 | import json
4 | import time
5 | from io import BytesIO
6 |
7 | import redis as redis
8 | import requests as req
9 | from PIL import Image
10 |
11 | import efficientService as service
12 |
13 | # redis cache client
14 | RedisCache = redis.StrictRedis(host="localhost", port=6379, db=0)
15 |
16 | # redis list of images waiting to be detected
17 | IMAGE_QUEUE = "imageQueue"
18 |
19 | # number of items pulled from the queue per polling cycle
20 | BATCH_SIZE = 32
21 |
22 | # seconds to sleep between polling cycles
23 | SERVER_SLEEP = 0.1
24 |
25 | # seconds to sleep when the queue is empty (currently unused)
26 | SERVER_SLEEP_IDLE = 0.5
27 |
28 |
29 | def detect_process():
30 | while True:
31 |         # fetch a batch of pending images from the redis queue
32 | queue = RedisCache.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1)
33 | if len(queue) < 1:
34 | time.sleep(SERVER_SLEEP)
35 | continue
36 |
37 |         print("detect_process is running")
38 |
39 |         # iterate over the queued items
40 | for item in queue:
41 |             # step 1. read the image info from the queue item
42 |             item = json.loads(item)
43 | image_key = item.get("imageKey")
44 | image_link = item.get("imageUrl")
45 | response = req.get(image_link)
46 | image = Image.open(BytesIO(response.content))
47 |
48 |             # step 2. detect objects in the image
49 | image_array = service.detect(image)
50 |
51 | # step 3. convert image_array to byte_array
52 | img = Image.fromarray(image_array, 'RGB')
53 | img_byte_array = io.BytesIO()
54 | img.save(img_byte_array, format='JPEG')
55 |
56 | # step 4. set result_info in redis
57 | image_info = base64.b64encode(img_byte_array.getvalue()).decode('ascii')
58 |
59 | RedisCache.hset(name=image_key, key="consultOut", value=image_info)
60 |
61 |         # remove the processed items from the queue
62 | RedisCache.ltrim(IMAGE_QUEUE, BATCH_SIZE, -1)
63 |
64 | time.sleep(SERVER_SLEEP)
65 |
66 |
67 | if __name__ == '__main__':
68 |     print("start detect_process")
69 | detect_process()
70 |
--------------------------------------------------------------------------------
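
A producer-side sketch of the queue contract used above: push a JSON item with imageKey/imageUrl onto imageQueue, then poll the per-image hash for the consultOut field that detect_process() writes. The field names come from detect_process.py; the redis host, the image key and the URL are placeholders.

    # Sketch of a producer for the redis queue consumed by detect_process().
    import base64
    import json
    import time

    import redis

    cache = redis.StrictRedis(host='localhost', port=6379, db=0)

    job = {'imageKey': 'demo-image-1', 'imageUrl': 'http://example.com/some_image.jpg'}
    cache.rpush('imageQueue', json.dumps(job))

    # wait until the worker publishes the rendered result
    while not cache.hexists('demo-image-1', 'consultOut'):
        time.sleep(0.5)

    result_b64 = cache.hget('demo-image-1', 'consultOut')
    with open('demo-image-1.jpg', 'wb') as f:
        f.write(base64.b64decode(result_b64))
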
/efficientService.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | Simple inference service for EfficientDet-Pytorch
5 | """
6 | import random
7 |
8 | import cv2
9 | import numpy as np
10 | import torch
11 | from torch.backends import cudnn
12 |
13 | from backbone import EfficientDetBackbone
14 | from efficientdet.utils import BBoxTransform, ClipBoxes
15 | from utils.utils import invert_affine, postprocess, aspectaware_resize_padding
16 |
17 | compound_coef = 0
18 | force_input_size = None # set None to use default size
19 |
20 | # replace this part with your project's anchor config
21 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
22 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
23 |
24 | threshold = 0.2
25 | iou_threshold = 0.2
26 |
27 | use_cuda = True
28 | use_float16 = False
29 | cudnn.fastest = True
30 | cudnn.benchmark = True
31 |
32 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
33 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
34 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
35 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
36 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
37 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
38 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
39 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
40 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
41 | 'toothbrush']
42 |
43 | # tf bilinear interpolation differs from other implementations, so these sizes just make do
44 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
45 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
46 |
47 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
48 | ratios=anchor_ratios, scales=anchor_scales)
49 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth'))
50 | model.requires_grad_(False)
51 | model.eval()
52 |
53 | if use_cuda:
54 | model = model.cuda()
55 | if use_float16:
56 | model = model.half()
57 |
58 |
59 | def detect(image):
60 | # convert image to array
61 | frame = np.array(image)
62 |
63 | # convert to cv format
64 | frames = frame[:, :, ::-1]
65 |
66 | ori_imgs, framed_imgs, framed_metas = image_preprocess(frames, max_size=input_size)
67 |
68 | if use_cuda:
69 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
70 | else:
71 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
72 |
73 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
74 |
75 | with torch.no_grad():
76 | features, regression, classification, anchors = model(x)
77 |
78 | regressBoxes = BBoxTransform()
79 | clipBoxes = ClipBoxes()
80 |
81 | out = postprocess(x,
82 | anchors, regression, classification,
83 | regressBoxes, clipBoxes,
84 | threshold, iou_threshold)
85 |
86 | out = invert_affine(framed_metas, out)
87 | render_frame = display(out, frame, imshow=True, imwrite=False)
88 | return render_frame
89 |
90 |
91 | def image_preprocess(image, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
92 |     ori_imgs = [image]
93 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
94 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
95 | means=None) for img in normalized_imgs]
96 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
97 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
98 |
99 | return ori_imgs, framed_imgs, framed_metas
100 |
101 |
102 | def display(preds, imgs, imshow=True, imwrite=False):
103 | imgs = [imgs]
104 | for i in range(len(imgs)):
105 | if len(preds[i]['rois']) == 0:
106 | continue
107 |
108 | for j in range(len(preds[i]['rois'])):
109 |             (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)
110 | color = [random.randint(0, 255) for _ in range(3)]
111 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), color, 2)
112 | obj = obj_list[preds[i]['class_ids'][j]]
113 | score = float(preds[i]['scores'][j])
114 |
115 | label = obj
116 | label_size = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0]
117 | cv2.rectangle(imgs[i], (x1, y1), (x1 + label_size[0], y1 - label_size[1] - 3), color, -1)
118 | # cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
119 | # (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
120 | # [225, 255, 255], 1, lineType=cv2.LINE_AA)
121 | cv2.putText(imgs[i], label, (x1, y1 - 8), 0, 2 / 3, [225, 255, 255], thickness=1,
122 | lineType=cv2.LINE_AA)
123 |
124 | return imgs[i]
125 |
--------------------------------------------------------------------------------
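
Importing this module builds the d0 model and loads weights/efficientdet-d0.pth, and use_cuda is hard-coded to True, so a CUDA device and that weight file are required. A usage sketch with placeholder paths:

    # Sketch: run the module-level detector on a local image and save the rendered result.
    from PIL import Image

    import efficientService as service

    image = Image.open('test/img.png').convert('RGB')
    rendered = service.detect(image)  # numpy array (RGB) with boxes and labels drawn

    Image.fromarray(rendered, 'RGB').save('img_detected.jpg')
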
/efficientdet/config.py:
--------------------------------------------------------------------------------
1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
2 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
4 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
10 | "teddy bear", "hair drier", "toothbrush"]
11 |
12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
14 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
23 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
24 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
26 | (2, 20, 184), (122, 37, 185)]
27 |
--------------------------------------------------------------------------------
/efficientdet/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 | from pycocotools.coco import COCO
7 | import cv2
8 |
9 |
10 | class CocoDataset(Dataset):
11 | def __init__(self, root_dir, set='train2017', transform=None):
12 |
13 | self.root_dir = root_dir
14 | self.set_name = set
15 | self.transform = transform
16 |
17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
18 | self.image_ids = self.coco.getImgIds()
19 |
20 | self.load_classes()
21 |
22 | def load_classes(self):
23 |
24 | # load class names (name -> label)
25 | categories = self.coco.loadCats(self.coco.getCatIds())
26 | categories.sort(key=lambda x: x['id'])
27 |
28 | self.classes = {}
29 | self.coco_labels = {}
30 | self.coco_labels_inverse = {}
31 | for c in categories:
32 | self.coco_labels[len(self.classes)] = c['id']
33 | self.coco_labels_inverse[c['id']] = len(self.classes)
34 | self.classes[c['name']] = len(self.classes)
35 |
36 | # also load the reverse (label -> name)
37 | self.labels = {}
38 | for key, value in self.classes.items():
39 | self.labels[value] = key
40 |
41 | def __len__(self):
42 | return len(self.image_ids)
43 |
44 | def __getitem__(self, idx):
45 |
46 | img = self.load_image(idx)
47 | annot = self.load_annotations(idx)
48 | sample = {'img': img, 'annot': annot}
49 | if self.transform:
50 | sample = self.transform(sample)
51 | return sample
52 |
53 | def load_image(self, image_index):
54 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
55 | path = os.path.join(self.root_dir, self.set_name, image_info['file_name'])
56 | img = cv2.imread(path)
57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
58 |
59 | return img.astype(np.float32) / 255.
60 |
61 | def load_annotations(self, image_index):
62 | # get ground truth annotations
63 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
64 | annotations = np.zeros((0, 5))
65 |
66 | # some images appear to miss annotations
67 | if len(annotations_ids) == 0:
68 | return annotations
69 |
70 | # parse annotations
71 | coco_annotations = self.coco.loadAnns(annotations_ids)
72 | for idx, a in enumerate(coco_annotations):
73 |
74 | # some annotations have basically no width / height, skip them
75 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
76 | continue
77 |
78 | annotation = np.zeros((1, 5))
79 | annotation[0, :4] = a['bbox']
80 | annotation[0, 4] = self.coco_label_to_label(a['category_id'])
81 | annotations = np.append(annotations, annotation, axis=0)
82 |
83 | # transform from [x, y, w, h] to [x1, y1, x2, y2]
84 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
85 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3]
86 |
87 | return annotations
88 |
89 | def coco_label_to_label(self, coco_label):
90 | return self.coco_labels_inverse[coco_label]
91 |
92 | def label_to_coco_label(self, label):
93 | return self.coco_labels[label]
94 |
95 |
96 | def collater(data):
97 | imgs = [s['img'] for s in data]
98 | annots = [s['annot'] for s in data]
99 | scales = [s['scale'] for s in data]
100 |
101 | imgs = torch.from_numpy(np.stack(imgs, axis=0))
102 |
103 | max_num_annots = max(annot.shape[0] for annot in annots)
104 |
105 | if max_num_annots > 0:
106 |
107 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
108 |
109 | for idx, annot in enumerate(annots):
110 | if annot.shape[0] > 0:
111 | annot_padded[idx, :annot.shape[0], :] = annot
112 | else:
113 | annot_padded = torch.ones((len(annots), 1, 5)) * -1
114 |
115 | imgs = imgs.permute(0, 3, 1, 2)
116 |
117 | return {'img': imgs, 'annot': annot_padded, 'scale': scales}
118 |
119 |
120 | class Resizer(object):
121 |     """Resize and pad the image to img_size, rescale the annotations and convert to tensors."""
122 |
123 | def __init__(self, img_size=512):
124 | self.img_size = img_size
125 |
126 | def __call__(self, sample):
127 | image, annots = sample['img'], sample['annot']
128 | height, width, _ = image.shape
129 | if height > width:
130 | scale = self.img_size / height
131 | resized_height = self.img_size
132 | resized_width = int(width * scale)
133 | else:
134 | scale = self.img_size / width
135 | resized_height = int(height * scale)
136 | resized_width = self.img_size
137 |
138 | image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
139 |
140 | new_image = np.zeros((self.img_size, self.img_size, 3))
141 | new_image[0:resized_height, 0:resized_width] = image
142 |
143 | annots[:, :4] *= scale
144 |
145 | return {'img': torch.from_numpy(new_image).to(torch.float32), 'annot': torch.from_numpy(annots), 'scale': scale}
146 |
147 |
148 | class Augmenter(object):
149 |     """Randomly flip the image and its annotations horizontally."""
150 |
151 | def __call__(self, sample, flip_x=0.5):
152 | if np.random.rand() < flip_x:
153 | image, annots = sample['img'], sample['annot']
154 | image = image[:, ::-1, :]
155 |
156 | rows, cols, channels = image.shape
157 |
158 | x1 = annots[:, 0].copy()
159 | x2 = annots[:, 2].copy()
160 |
161 | x_tmp = x1.copy()
162 |
163 | annots[:, 0] = cols - x2
164 | annots[:, 2] = cols - x_tmp
165 |
166 | sample = {'img': image, 'annot': annots}
167 |
168 | return sample
169 |
170 |
171 | class Normalizer(object):
172 |
173 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
174 | self.mean = np.array([[mean]])
175 | self.std = np.array([[std]])
176 |
177 | def __call__(self, sample):
178 | image, annots = sample['img'], sample['annot']
179 |
180 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
181 |
--------------------------------------------------------------------------------
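
A sketch of wiring CocoDataset into a torch DataLoader with the transforms and collater defined above; the dataset root and set name are placeholders that follow the datasets/{project_name}/{set_name} layout used elsewhere in the repo:

    # Sketch: build a DataLoader from CocoDataset with the transforms and collater above.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from efficientdet.dataset import Augmenter, CocoDataset, Normalizer, Resizer, collater

    training_set = CocoDataset(root_dir='datasets/coco', set='train2017',
                               transform=transforms.Compose([Normalizer(), Augmenter(), Resizer(512)]))

    training_loader = DataLoader(training_set, batch_size=4, shuffle=True,
                                 num_workers=0, collate_fn=collater)

    batch = next(iter(training_loader))
    print(batch['img'].shape)    # [4, 3, 512, 512]
    print(batch['annot'].shape)  # [4, max_annotations_in_batch, 5]
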
/efficientdet/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import cv2
4 | import numpy as np
5 |
6 | from efficientdet.utils import BBoxTransform, ClipBoxes
7 | from utils.utils import postprocess, invert_affine, display
8 |
9 |
10 | def calc_iou(a, b):
11 | # a(anchor) [boxes, (y1, x1, y2, x2)]
12 | # b(gt, coco-style) [boxes, (x1, y1, x2, y2)]
13 |
14 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
15 | iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
16 | ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
17 | iw = torch.clamp(iw, min=0)
18 | ih = torch.clamp(ih, min=0)
19 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
20 | ua = torch.clamp(ua, min=1e-8)
21 | intersection = iw * ih
22 | IoU = intersection / ua
23 |
24 | return IoU
25 |
26 |
27 | class FocalLoss(nn.Module):
28 | def __init__(self):
29 | super(FocalLoss, self).__init__()
30 |
31 | def forward(self, classifications, regressions, anchors, annotations, **kwargs):
32 | alpha = 0.25
33 | gamma = 2.0
34 | batch_size = classifications.shape[0]
35 | classification_losses = []
36 | regression_losses = []
37 |
38 |         anchor = anchors[0, :, :]  # assuming all images in the batch share the same size, which is the case here
39 | dtype = anchors.dtype
40 |
41 | anchor_widths = anchor[:, 3] - anchor[:, 1]
42 | anchor_heights = anchor[:, 2] - anchor[:, 0]
43 | anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
44 | anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights
45 |
46 | for j in range(batch_size):
47 |
48 | classification = classifications[j, :, :]
49 | regression = regressions[j, :, :]
50 |
51 | bbox_annotation = annotations[j]
52 | bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
53 |
54 | if bbox_annotation.shape[0] == 0:
55 | if torch.cuda.is_available():
56 | regression_losses.append(torch.tensor(0).to(dtype).cuda())
57 | classification_losses.append(torch.tensor(0).to(dtype).cuda())
58 | else:
59 | regression_losses.append(torch.tensor(0).to(dtype))
60 | classification_losses.append(torch.tensor(0).to(dtype))
61 |
62 | continue
63 |
64 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
65 |
66 | IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
67 |
68 | IoU_max, IoU_argmax = torch.max(IoU, dim=1)
69 |
70 | # compute the loss for classification
71 | targets = torch.ones_like(classification) * -1
72 | if torch.cuda.is_available():
73 | targets = targets.cuda()
74 |
75 | targets[torch.lt(IoU_max, 0.4), :] = 0
76 |
77 | positive_indices = torch.ge(IoU_max, 0.5)
78 |
79 | num_positive_anchors = positive_indices.sum()
80 |
81 | assigned_annotations = bbox_annotation[IoU_argmax, :]
82 |
83 | targets[positive_indices, :] = 0
84 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
85 |
86 | alpha_factor = torch.ones_like(targets) * alpha
87 | if torch.cuda.is_available():
88 | alpha_factor = alpha_factor.cuda()
89 |
90 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
91 | focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
92 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
93 |
94 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
95 |
96 | cls_loss = focal_weight * bce
97 |
98 | zeros = torch.zeros_like(cls_loss)
99 | if torch.cuda.is_available():
100 | zeros = zeros.cuda()
101 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
102 |
103 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0))
104 |
105 | if positive_indices.sum() > 0:
106 | assigned_annotations = assigned_annotations[positive_indices, :]
107 |
108 | anchor_widths_pi = anchor_widths[positive_indices]
109 | anchor_heights_pi = anchor_heights[positive_indices]
110 | anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
111 | anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
112 |
113 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
114 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
115 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
116 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
117 |
118 | # efficientdet style
119 | gt_widths = torch.clamp(gt_widths, min=1)
120 | gt_heights = torch.clamp(gt_heights, min=1)
121 |
122 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
123 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
124 | targets_dw = torch.log(gt_widths / anchor_widths_pi)
125 | targets_dh = torch.log(gt_heights / anchor_heights_pi)
126 |
127 | targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw))
128 | targets = targets.t()
129 |
130 | regression_diff = torch.abs(targets - regression[positive_indices, :])
131 |
132 | regression_loss = torch.where(
133 | torch.le(regression_diff, 1.0 / 9.0),
134 | 0.5 * 9.0 * torch.pow(regression_diff, 2),
135 | regression_diff - 0.5 / 9.0
136 | )
137 | regression_losses.append(regression_loss.mean())
138 | else:
139 | if torch.cuda.is_available():
140 | regression_losses.append(torch.tensor(0).to(dtype).cuda())
141 | else:
142 | regression_losses.append(torch.tensor(0).to(dtype))
143 |
144 | # debug
145 | imgs = kwargs.get('imgs', None)
146 | if imgs is not None:
147 | regressBoxes = BBoxTransform()
148 | clipBoxes = ClipBoxes()
149 | obj_list = kwargs.get('obj_list', None)
150 | out = postprocess(imgs.detach(),
151 | torch.stack([anchors[0]] * imgs.shape[0], 0).detach(), regressions.detach(), classifications.detach(),
152 | regressBoxes, clipBoxes,
153 | 0.5, 0.3)
154 | imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
155 | imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
156 | imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
157 | display(out, imgs, obj_list, imshow=False, imwrite=True)
158 |
159 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
160 | torch.stack(regression_losses).mean(dim=0, keepdim=True)
161 |
--------------------------------------------------------------------------------
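
Note the mixed box conventions in calc_iou: the anchors are (y1, x1, y2, x2) while the ground-truth boxes are coco-style (x1, y1, x2, y2). A toy check, as a sketch:

    # Toy check of calc_iou's conventions: anchors (y1, x1, y2, x2) vs. gt (x1, y1, x2, y2).
    import torch

    from efficientdet.loss import calc_iou

    anchors = torch.tensor([[0., 0., 10., 10.],   # 10x10 box at the origin
                            [0., 0., 10., 20.]])  # y1, x1, y2, x2 -> 20 wide, 10 tall
    gt = torch.tensor([[0., 0., 10., 10.]])       # x1, y1, x2, y2 -> the same 10x10 box

    print(calc_iou(anchors, gt))  # IoU ~ [[1.0], [0.5]]
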
/efficientdet/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torchvision.ops.boxes import nms as nms_torch
4 |
5 | from efficientnet import EfficientNet as EffNet
6 | from efficientnet.utils import MemoryEfficientSwish, Swish
7 | from efficientnet.utils_extra import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding
8 |
9 |
10 | def nms(dets, thresh):
11 | return nms_torch(dets[:, :4], dets[:, 4], thresh)
12 |
13 |
14 | class SeparableConvBlock(nn.Module):
15 | """
16 | created by Zylo117
17 | """
18 |
19 | def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
20 | super(SeparableConvBlock, self).__init__()
21 | if out_channels is None:
22 | out_channels = in_channels
23 |
24 |         # Q: should a separable conv
25 |         # share a bias between depthwise_conv and pointwise_conv,
26 |         # or should only pointwise_conv apply a bias?
27 |         # A: Confirmed, only pointwise_conv applies a bias; depthwise_conv has no bias.
28 |
29 | self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
30 | kernel_size=3, stride=1, groups=in_channels, bias=False)
31 | self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)
32 |
33 | self.norm = norm
34 | if self.norm:
35 | # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
36 | self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)
37 |
38 | self.activation = activation
39 | if self.activation:
40 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
41 |
42 | def forward(self, x):
43 | x = self.depthwise_conv(x)
44 | x = self.pointwise_conv(x)
45 |
46 | if self.norm:
47 | x = self.bn(x)
48 |
49 | if self.activation:
50 | x = self.swish(x)
51 |
52 | return x
53 |
54 |
55 | class BiFPN(nn.Module):
56 | """
57 | modified by Zylo117
58 | """
59 |
60 | def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True):
61 | """
62 |
63 | Args:
64 | num_channels:
65 | conv_channels:
66 | first_time: whether the input comes directly from the efficientnet,
67 | if True, downchannel it first, and downsample P5 to generate P6 then P7
68 | epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon
69 | onnx_export: if True, use Swish instead of MemoryEfficientSwish
70 | """
71 | super(BiFPN, self).__init__()
72 | self.epsilon = epsilon
73 | # Conv layers
74 | self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
75 | self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
76 | self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
77 | self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
78 | self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
79 | self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
80 | self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
81 | self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
82 |
83 | # Feature scaling layers
84 | self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
85 | self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
86 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
87 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
88 |
89 | self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
90 | self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
91 | self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
92 | self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
93 |
94 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
95 |
96 | self.first_time = first_time
97 | if self.first_time:
98 | self.p5_down_channel = nn.Sequential(
99 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
100 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
101 | )
102 | self.p4_down_channel = nn.Sequential(
103 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
104 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
105 | )
106 | self.p3_down_channel = nn.Sequential(
107 | Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
108 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
109 | )
110 |
111 | self.p5_to_p6 = nn.Sequential(
112 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
113 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
114 | MaxPool2dStaticSamePadding(3, 2)
115 | )
116 | self.p6_to_p7 = nn.Sequential(
117 | MaxPool2dStaticSamePadding(3, 2)
118 | )
119 |
120 | self.p4_down_channel_2 = nn.Sequential(
121 | Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
122 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
123 | )
124 | self.p5_down_channel_2 = nn.Sequential(
125 | Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
126 | nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
127 | )
128 |
129 | # Weight
130 | self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
131 | self.p6_w1_relu = nn.ReLU()
132 | self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
133 | self.p5_w1_relu = nn.ReLU()
134 | self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
135 | self.p4_w1_relu = nn.ReLU()
136 | self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
137 | self.p3_w1_relu = nn.ReLU()
138 |
139 | self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
140 | self.p4_w2_relu = nn.ReLU()
141 | self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
142 | self.p5_w2_relu = nn.ReLU()
143 | self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
144 | self.p6_w2_relu = nn.ReLU()
145 | self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
146 | self.p7_w2_relu = nn.ReLU()
147 |
148 | self.attention = attention
149 |
150 | def forward(self, inputs):
151 | """
152 | illustration of a minimal bifpn unit
153 | P7_0 -------------------------> P7_2 -------->
154 | |-------------| ↑
155 | ↓ |
156 | P6_0 ---------> P6_1 ---------> P6_2 -------->
157 | |-------------|--------------↑ ↑
158 | ↓ |
159 | P5_0 ---------> P5_1 ---------> P5_2 -------->
160 | |-------------|--------------↑ ↑
161 | ↓ |
162 | P4_0 ---------> P4_1 ---------> P4_2 -------->
163 | |-------------|--------------↑ ↑
164 | |--------------↓ |
165 | P3_0 -------------------------> P3_2 -------->
166 | """
167 |
168 |         # bring the channels to the target width with a same-padding conv2d if they differ,
169 |         # then, for each input level relative to the target level:
170 |         # if it is the same level, pass it through;
171 |         # if it is an earlier (higher-resolution) level, downsample to the target by pooling;
172 |         # if it is a later (lower-resolution) level, upsample to the target by nearest interpolation
173 |
174 | if self.attention:
175 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward_fast_attention(inputs)
176 | else:
177 | p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs)
178 |
179 | return p3_out, p4_out, p5_out, p6_out, p7_out
180 |
181 | def _forward_fast_attention(self, inputs):
182 | if self.first_time:
183 | p3, p4, p5 = inputs
184 |
185 | p6_in = self.p5_to_p6(p5)
186 | p7_in = self.p6_to_p7(p6_in)
187 |
188 | p3_in = self.p3_down_channel(p3)
189 | p4_in = self.p4_down_channel(p4)
190 | p5_in = self.p5_down_channel(p5)
191 |
192 | else:
193 | # P3_0, P4_0, P5_0, P6_0 and P7_0
194 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs
195 |
196 | # P7_0 to P7_2
197 |
198 | # Weights for P6_0 and P7_0 to P6_1
199 | p6_w1 = self.p6_w1_relu(self.p6_w1)
200 | weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
201 | # Connections for P6_0 and P7_0 to P6_1 respectively
202 | p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))
203 |
204 | # Weights for P5_0 and P6_0 to P5_1
205 | p5_w1 = self.p5_w1_relu(self.p5_w1)
206 | weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
207 | # Connections for P5_0 and P6_0 to P5_1 respectively
208 | p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))
209 |
210 | # Weights for P4_0 and P5_0 to P4_1
211 | p4_w1 = self.p4_w1_relu(self.p4_w1)
212 | weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
213 | # Connections for P4_0 and P5_0 to P4_1 respectively
214 | p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))
215 |
216 | # Weights for P3_0 and P4_1 to P3_2
217 | p3_w1 = self.p3_w1_relu(self.p3_w1)
218 | weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
219 | # Connections for P3_0 and P4_1 to P3_2 respectively
220 | p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))
221 |
222 | if self.first_time:
223 | p4_in = self.p4_down_channel_2(p4)
224 | p5_in = self.p5_down_channel_2(p5)
225 |
226 | # Weights for P4_0, P4_1 and P3_2 to P4_2
227 | p4_w2 = self.p4_w2_relu(self.p4_w2)
228 | weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
229 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
230 | p4_out = self.conv4_down(
231 | self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))
232 |
233 | # Weights for P5_0, P5_1 and P4_2 to P5_2
234 | p5_w2 = self.p5_w2_relu(self.p5_w2)
235 | weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
236 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
237 | p5_out = self.conv5_down(
238 | self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))
239 |
240 | # Weights for P6_0, P6_1 and P5_2 to P6_2
241 | p6_w2 = self.p6_w2_relu(self.p6_w2)
242 | weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
243 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
244 | p6_out = self.conv6_down(
245 | self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))
246 |
247 | # Weights for P7_0 and P6_2 to P7_2
248 | p7_w2 = self.p7_w2_relu(self.p7_w2)
249 | weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
250 | # Connections for P7_0 and P6_2 to P7_2
251 | p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))
252 |
253 | return p3_out, p4_out, p5_out, p6_out, p7_out
254 |
255 | def _forward(self, inputs):
256 | if self.first_time:
257 | p3, p4, p5 = inputs
258 |
259 | p6_in = self.p5_to_p6(p5)
260 | p7_in = self.p6_to_p7(p6_in)
261 |
262 | p3_in = self.p3_down_channel(p3)
263 | p4_in = self.p4_down_channel(p4)
264 | p5_in = self.p5_down_channel(p5)
265 |
266 | else:
267 | # P3_0, P4_0, P5_0, P6_0 and P7_0
268 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs
269 |
270 | # P7_0 to P7_2
271 |
272 | # Connections for P6_0 and P7_0 to P6_1 respectively
273 | p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))
274 |
275 | # Connections for P5_0 and P6_0 to P5_1 respectively
276 | p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))
277 |
278 | # Connections for P4_0 and P5_0 to P4_1 respectively
279 | p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))
280 |
281 | # Connections for P3_0 and P4_1 to P3_2 respectively
282 | p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))
283 |
284 | if self.first_time:
285 | p4_in = self.p4_down_channel_2(p4)
286 | p5_in = self.p5_down_channel_2(p5)
287 |
288 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
289 | p4_out = self.conv4_down(
290 | self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))
291 |
292 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
293 | p5_out = self.conv5_down(
294 | self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))
295 |
296 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
297 | p6_out = self.conv6_down(
298 | self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))
299 |
300 | # Connections for P7_0 and P6_2 to P7_2
301 | p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))
302 |
303 | return p3_out, p4_out, p5_out, p6_out, p7_out
304 |
305 |
306 | class Regressor(nn.Module):
307 | """
308 | modified by Zylo117
309 | """
310 |
311 | def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False):
312 | super(Regressor, self).__init__()
313 | self.num_layers = num_layers
315 |
316 | self.conv_list = nn.ModuleList(
317 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
318 | self.bn_list = nn.ModuleList(
319 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
320 | range(5)])
321 | self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
322 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
323 |
324 | def forward(self, inputs):
325 | feats = []
326 | for feat, bn_list in zip(inputs, self.bn_list):
327 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
328 | feat = conv(feat)
329 | feat = bn(feat)
330 | feat = self.swish(feat)
331 | feat = self.header(feat)
332 |
333 | feat = feat.permute(0, 2, 3, 1)
334 | feat = feat.contiguous().view(feat.shape[0], -1, 4)
335 |
336 | feats.append(feat)
337 |
338 | feats = torch.cat(feats, dim=1)
339 |
340 | return feats
341 |
342 |
343 | class Classifier(nn.Module):
344 | """
345 | modified by Zylo117
346 | """
347 |
348 | def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_export=False):
349 | super(Classifier, self).__init__()
350 | self.num_anchors = num_anchors
351 | self.num_classes = num_classes
352 | self.num_layers = num_layers
353 | self.conv_list = nn.ModuleList(
354 | [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
355 | self.bn_list = nn.ModuleList(
356 | [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
357 | range(5)])
358 | self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
359 | self.swish = MemoryEfficientSwish() if not onnx_export else Swish()
360 |
361 | def forward(self, inputs):
362 | feats = []
363 | for feat, bn_list in zip(inputs, self.bn_list):
364 | for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
365 | feat = conv(feat)
366 | feat = bn(feat)
367 | feat = self.swish(feat)
368 | feat = self.header(feat)
369 |
370 | feat = feat.permute(0, 2, 3, 1)
371 | feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors,
372 | self.num_classes)
373 | feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes)
374 |
375 | feats.append(feat)
376 |
377 | feats = torch.cat(feats, dim=1)
378 | feats = feats.sigmoid()
379 |
380 | return feats
381 |
382 |
383 | class EfficientNet(nn.Module):
384 | """
385 | modified by Zylo117
386 | """
387 |
388 | def __init__(self, compound_coef, load_weights=False):
389 | super(EfficientNet, self).__init__()
390 | model = EffNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights)
391 | del model._conv_head
392 | del model._bn1
393 | del model._avg_pooling
394 | del model._dropout
395 | del model._fc
396 | self.model = model
397 |
398 | def forward(self, x):
399 | x = self.model._conv_stem(x)
400 | x = self.model._bn0(x)
401 | x = self.model._swish(x)
402 | feature_maps = []
403 |
404 | # TODO: temporarily storing the extra tensor last_x and deleting it later might not be a good idea;
405 | #  try recording the stride changes when creating the efficientnet,
406 | #  and then applying them here.
407 | last_x = None
408 | for idx, block in enumerate(self.model._blocks):
409 | drop_connect_rate = self.model._global_params.drop_connect_rate
410 | if drop_connect_rate:
411 | drop_connect_rate *= float(idx) / len(self.model._blocks)
412 | x = block(x, drop_connect_rate=drop_connect_rate)
413 |
414 | if block._depthwise_conv.stride == [2, 2]:
415 | feature_maps.append(last_x)
416 | elif idx == len(self.model._blocks) - 1:
417 | feature_maps.append(x)
418 | last_x = x
419 | del last_x
420 | return feature_maps[1:]
421 |
422 |
423 | if __name__ == '__main__':
424 | from tensorboardX import SummaryWriter
425 |
426 |
427 | def count_parameters(model):
428 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
429 |
--------------------------------------------------------------------------------
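A quick sanity check for the Regressor and Classifier heads defined in efficientdet/model.py above. This is a minimal sketch, assuming the repository root is on PYTHONPATH so the file imports as efficientdet.model; the five feature maps mimic a D0-sized BiFPN output (64 channels) for a 512x512 input at strides 8 to 128.

```python
import torch
from efficientdet.model import Regressor, Classifier  # assumes the repo root is on PYTHONPATH

in_channels, num_anchors, num_layers, num_classes = 64, 9, 3, 90

regressor = Regressor(in_channels, num_anchors, num_layers)
classifier = Classifier(in_channels, num_anchors, num_classes, num_layers)

# five BiFPN levels for a 512x512 input at strides 8, 16, 32, 64, 128
feats = [torch.randn(1, in_channels, s, s) for s in (64, 32, 16, 8, 4)]

boxes = regressor(feats)    # [1, 49104, 4]: 9 anchors x (64^2 + 32^2 + 16^2 + 8^2 + 4^2) cells
scores = classifier(feats)  # [1, 49104, 90], already sigmoid-activated
print(boxes.shape, scores.shape)
```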
/efficientdet/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | class BBoxTransform(nn.Module):
8 | def forward(self, anchors, regression):
9 | """
10 | decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py
11 |
12 | Args:
13 | anchors: [batchsize, boxes, (y1, x1, y2, x2)]
14 | regression: [batchsize, boxes, (dy, dx, dh, dw)]
15 |
16 | Returns:
17 | [batchsize, boxes, (x1, y1, x2, y2)]
18 | """
19 | y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2
20 | x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2
21 | ha = anchors[..., 2] - anchors[..., 0]
22 | wa = anchors[..., 3] - anchors[..., 1]
23 |
24 | w = regression[..., 3].exp() * wa
25 | h = regression[..., 2].exp() * ha
26 |
27 | y_centers = regression[..., 0] * ha + y_centers_a
28 | x_centers = regression[..., 1] * wa + x_centers_a
29 |
30 | ymin = y_centers - h / 2.
31 | xmin = x_centers - w / 2.
32 | ymax = y_centers + h / 2.
33 | xmax = x_centers + w / 2.
34 |
35 | return torch.stack([xmin, ymin, xmax, ymax], dim=2)
36 |
37 |
38 | class ClipBoxes(nn.Module):
39 |
40 | def __init__(self):
41 | super(ClipBoxes, self).__init__()
42 |
43 | def forward(self, boxes, img):
44 | batch_size, num_channels, height, width = img.shape
45 |
46 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
47 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
48 |
49 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1)
50 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1)
51 |
52 | return boxes
53 |
54 |
55 | class Anchors(nn.Module):
56 | """
57 | adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117
58 | """
59 |
60 | def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
61 | super().__init__()
62 | self.anchor_scale = anchor_scale
63 |
64 | if pyramid_levels is None:
65 | pyramid_levels = [3, 4, 5, 6, 7]
66 | self.pyramid_levels = pyramid_levels
67 | self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
68 | self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
69 | self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
70 |
71 | self.last_anchors = {}
72 | self.last_shape = None
73 |
74 | def forward(self, image, dtype=torch.float32):
75 | """Generates multiscale anchor boxes.
76 |
77 | Args:
78 | image_size: integer number of input image size. The input image has the
79 | same dimension for width and height. The image_size should be divided by
80 | the largest feature stride 2^max_level.
81 | anchor_scale: float number representing the scale of size of the base
82 | anchor to the feature stride 2^level.
83 | anchor_configs: a dictionary with keys as the levels of anchors and
84 | values as a list of anchor configuration.
85 |
86 | Returns:
87 | anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all
88 | feature levels.
89 | Raises:
90 | ValueError: input size must be the multiple of largest feature stride.
91 | """
92 | image_shape = image.shape[2:]
93 |
94 | if image_shape == self.last_shape and image.device in self.last_anchors:
95 | return self.last_anchors[image.device]
96 |
97 | if self.last_shape is None or self.last_shape != image_shape:
98 | self.last_shape = image_shape
99 |
100 | if dtype == torch.float16:
101 | dtype = np.float16
102 | else:
103 | dtype = np.float32
104 |
105 | boxes_all = []
106 | for stride in self.strides:
107 | boxes_level = []
108 | for scale, ratio in itertools.product(self.scales, self.ratios):
109 | if image_shape[1] % stride != 0:
110 | raise ValueError('input size must be divisible by the stride.')
111 | base_anchor_size = self.anchor_scale * stride * scale
112 | anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
113 | anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
114 |
115 | x = np.arange(stride / 2, image_shape[1], stride)
116 | y = np.arange(stride / 2, image_shape[0], stride)
117 | xv, yv = np.meshgrid(x, y)
118 | xv = xv.reshape(-1)
119 | yv = yv.reshape(-1)
120 |
121 | # y1,x1,y2,x2
122 | boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
123 | yv + anchor_size_y_2, xv + anchor_size_x_2))
124 | boxes = np.swapaxes(boxes, 0, 1)
125 | boxes_level.append(np.expand_dims(boxes, axis=1))
126 | # concat anchors on the same level, then reshape to NxAx4
127 | boxes_level = np.concatenate(boxes_level, axis=1)
128 | boxes_all.append(boxes_level.reshape([-1, 4]))
129 |
130 | anchor_boxes = np.vstack(boxes_all)
131 |
132 | anchor_boxes = torch.from_numpy(anchor_boxes.astype(dtype)).to(image.device)
133 | anchor_boxes = anchor_boxes.unsqueeze(0)
134 |
135 | # save it for later use to reduce overhead
136 | self.last_anchors[image.device] = anchor_boxes
137 | return anchor_boxes
138 |
--------------------------------------------------------------------------------
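A minimal usage sketch for the anchor and box utilities above (same PYTHONPATH assumption as before). With all regression offsets at zero, BBoxTransform simply re-orders each anchor from (y1, x1, y2, x2) to (x1, y1, x2, y2), which makes the decode math easy to verify:

```python
import torch
from efficientdet.utils import Anchors, BBoxTransform, ClipBoxes

img = torch.zeros(1, 3, 512, 512)         # H and W divisible by the largest stride, 128
anchors = Anchors(anchor_scale=4.)(img)   # [1, 49104, 4] in (y1, x1, y2, x2)

regression = torch.zeros_like(anchors)    # (dy, dx, dh, dw) all zero
boxes = BBoxTransform()(anchors, regression)   # [1, 49104, 4] in (x1, y1, x2, y2)

# zero offsets mean the decoded boxes are just the anchors, re-ordered yxyx -> xyxy
assert torch.allclose(boxes, anchors[..., [1, 0, 3, 2]], atol=1e-3)

boxes = ClipBoxes()(boxes, img)           # clamp x into [0, W-1], y into [0, H-1]
```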
/efficientdet_test.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | """
4 | Simple Inference Script of EfficientDet-Pytorch
5 | """
6 | import random
7 | import time
8 |
9 | import cv2
10 | import numpy as np
11 | import torch
12 | from torch.backends import cudnn
13 |
14 | from backbone import EfficientDetBackbone
15 | from efficientdet.utils import BBoxTransform, ClipBoxes
16 | from utils.utils import preprocess, invert_affine, postprocess
17 | import efficientService as service
18 | from PIL import Image
19 |
20 |
21 | compound_coef = 0
22 | force_input_size = None # set None to use default size
23 | img_path = 'test/20130124152801556.jpg'
24 |
25 | # replace this part with your project's anchor config
26 | anchor_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
27 | anchor_scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
28 |
29 | threshold = 0.2
30 | iou_threshold = 0.2
31 |
32 | use_cuda = True
33 | use_float16 = False
34 | cudnn.fastest = True
35 | cudnn.benchmark = True
36 |
37 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
38 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
39 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
40 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
41 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
42 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
43 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
44 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
45 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
46 | 'toothbrush']
47 |
48 | # tf bilinear interpolation differs from other frameworks', so these input sizes are a workaround
49 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
50 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
51 | ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)
52 |
53 | if use_cuda:
54 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
55 | else:
56 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
57 |
58 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
59 |
60 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),
61 | ratios=anchor_ratios, scales=anchor_scales)
62 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth'))
63 | model.requires_grad_(False)
64 | model.eval()
65 |
66 | if use_cuda:
67 | model = model.cuda()
68 | if use_float16:
69 | model = model.half()
70 |
71 | with torch.no_grad():
72 | image = Image.open(img_path)
73 | frame = np.array(image)
74 |
75 | frame = service.detect(frame)
76 |
77 | features, regression, classification, anchors = model(x)
78 |
79 | regressBoxes = BBoxTransform()
80 | clipBoxes = ClipBoxes()
81 |
82 | out = postprocess(x,
83 | anchors, regression, classification,
84 | regressBoxes, clipBoxes,
85 | threshold, iou_threshold)
86 |
87 |
88 | def display(preds, imgs, imshow=True, imwrite=False):
89 | for i in range(len(imgs)):
90 | if len(preds[i]['rois']) == 0:
91 | continue
92 |
93 | for j in range(len(preds[i]['rois'])):
94 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)  # np.int was removed in NumPy 1.24+
95 | color = [random.randint(0, 255) for _ in range(3)]
96 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), color, 2)
97 | obj = obj_list[preds[i]['class_ids'][j]]
98 | score = float(preds[i]['scores'][j])
99 |
100 | label = obj
101 | label_size = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0]
102 | cv2.rectangle(imgs[i], (x1, y1), (x1 + label_size[0], y1 - label_size[1] - 3), color, -1)
103 | # cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
104 | # (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
105 | # [225, 255, 255], 1, lineType=cv2.LINE_AA)
106 | cv2.putText(imgs[i], label, (x1, y1 - 8), 0, 2 / 3, [225, 255, 255], thickness=1,
107 | lineType=cv2.LINE_AA)
108 |
109 | if imshow:
110 | cv2.imshow('img', imgs[i])
111 | cv2.waitKey(0)
112 |
113 | if imwrite:
114 | cv2.imwrite(f'test/img_inferred_d{compound_coef}_this_repo_{i}.jpg', imgs[i])
115 |
116 |
117 | out = invert_affine(framed_metas, out)
118 | display(out, ori_imgs, imshow=True, imwrite=False)
119 |
120 | print('running speed test...')
121 | with torch.no_grad():
122 | print('test1: model inferring and postprocessing')
123 | print('inferring image for 10 times...')
124 | t1 = time.time()
125 | for _ in range(10):
126 | _, regression, classification, anchors = model(x)
127 |
128 | out = postprocess(x,
129 | anchors, regression, classification,
130 | regressBoxes, clipBoxes,
131 | threshold, iou_threshold)
132 | out = invert_affine(framed_metas, out)
133 |
134 | t2 = time.time()
135 | tact_time = (t2 - t1) / 10
136 | print(f'{tact_time} seconds, {1 / tact_time} FPS, @batch_size 1')
137 |
138 | # uncomment this if you want an extreme fps test
139 | # print('test2: model inferring only')
140 | # print('inferring images for batch_size 32 for 10 times...')
141 | # t1 = time.time()
142 | # x = torch.cat([x] * 32, 0)
143 | # for _ in range(10):
144 | # _, regression, classification, anchors = model(x)
145 | #
146 | # t2 = time.time()
147 | # tact_time = (t2 - t1) / 10
148 | # print(f'{tact_time} seconds, {32 / tact_time} FPS, @batch_size 32')
149 |
--------------------------------------------------------------------------------
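One caveat on the speed test in efficientdet_test.py above: CUDA kernel launches are asynchronous, so timing them with time.time() alone can misreport latency. A minimal sketch of a synchronized measurement, reusing the `model` and `x` already built in the script (the helper name is ours, not part of the repo):

```python
import time
import torch

def time_inference(model, x, n_runs=10, warmup=3):
    """Average forward latency in seconds, synchronizing CUDA around the timed region."""
    with torch.no_grad():
        for _ in range(warmup):           # warm-up: cudnn autotuning, lazy allocations
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        t1 = time.time()
        for _ in range(n_runs):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()      # wait for all queued kernels before stopping the clock
    return (time.time() - t1) / n_runs

# e.g. print(f'{1 / time_inference(model, x):.1f} FPS @ batch_size 1')
```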
/efficientdet_test_videos.py:
--------------------------------------------------------------------------------
1 | # Core Author: Zylo117
2 | # Script's Author: winter2897
3 |
4 | """
5 | Simple Inference Script of EfficientDet-Pytorch for detecting objects on webcam
6 | """
7 | import time
8 | import torch
9 | import cv2
10 | import numpy as np
11 | from torch.backends import cudnn
12 | from backbone import EfficientDetBackbone
13 | from efficientdet.utils import BBoxTransform, ClipBoxes
14 | from utils.utils import preprocess, invert_affine, postprocess, preprocess_video
15 |
16 | # Video's path
17 | video_src = 'videotest.mp4' # set int to use webcam, set str to read from a video file
18 |
19 | compound_coef = 0
20 | force_input_size = None # set None to use default size
21 |
22 | threshold = 0.2
23 | iou_threshold = 0.2
24 |
25 | use_cuda = True
26 | use_float16 = False
27 | cudnn.fastest = True
28 | cudnn.benchmark = True
29 |
30 | obj_list = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
31 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
32 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
33 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
34 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
35 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
36 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
37 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
38 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
39 | 'toothbrush']
40 |
41 | # tf bilinear interpolation differs from other frameworks', so these input sizes are a workaround
42 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
43 | input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size
44 |
45 | # load model
46 | model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list))
47 | model.load_state_dict(torch.load(f'weights/efficientdet-d{compound_coef}.pth'))
48 | model.requires_grad_(False)
49 | model.eval()
50 |
51 | if use_cuda:
52 | model = model.cuda()
53 | if use_float16:
54 | model = model.half()
55 |
56 | # function for display
57 | def display(preds, imgs):
58 | for i in range(len(imgs)):
59 | if len(preds[i]['rois']) == 0:
60 | continue
61 |
62 | for j in range(len(preds[i]['rois'])):
63 | (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)  # np.int was removed in NumPy 1.24+
64 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
65 | obj = obj_list[preds[i]['class_ids'][j]]
66 | score = float(preds[i]['scores'][j])
67 |
68 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
69 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
70 | (255, 255, 0), 1)
71 |
72 | return imgs[i]
73 | # Box
74 | regressBoxes = BBoxTransform()
75 | clipBoxes = ClipBoxes()
76 |
77 | # Video capture
78 | cap = cv2.VideoCapture(video_src)
79 |
80 | while True:
81 | ret, frame = cap.read()
82 | if not ret:
83 | break
84 |
85 | # frame preprocessing
86 | ori_imgs, framed_imgs, framed_metas = preprocess_video(frame, max_size=input_size)
87 |
88 | if use_cuda:
89 | x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)
90 | else:
91 | x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)
92 |
93 | x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)
94 |
95 | # model predict
96 | with torch.no_grad():
97 | features, regression, classification, anchors = model(x)
98 |
99 | out = postprocess(x,
100 | anchors, regression, classification,
101 | regressBoxes, clipBoxes,
102 | threshold, iou_threshold)
103 |
104 | # result
105 | out = invert_affine(framed_metas, out)
106 | img_show = display(out, ori_imgs)
107 |
108 | # show frame by frame
109 | cv2.imshow('frame', img_show)
110 | if cv2.waitKey(1) & 0xFF == ord('q'):
111 | break
112 |
113 | cap.release()
114 | cv2.destroyAllWindows()
115 |
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
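As the `video_src` comment in efficientdet_test_videos.py notes, an integer index selects a webcam instead of a file. A small hedged sketch of opening either source defensively before entering the detection loop:

```python
import cv2

video_src = 0  # 0 = default webcam; or a path such as 'videotest.mp4'

cap = cv2.VideoCapture(video_src)
if not cap.isOpened():
    raise RuntimeError(f'cannot open video source: {video_src!r}')

ret, frame = cap.read()                 # grab one frame to verify the stream works
if not ret:
    cap.release()
    raise RuntimeError('video source opened but returned no frames')
print('frame shape:', frame.shape)      # (H, W, 3), BGR order
cap.release()
```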
/efficientnet/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .model import EfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/efficientnet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import (
6 | round_filters,
7 | round_repeats,
8 | drop_connect,
9 | get_same_padding_conv2d,
10 | get_model_params,
11 | efficientnet_params,
12 | load_pretrained_weights,
13 | Swish,
14 | MemoryEfficientSwish,
15 | )
16 |
17 | class MBConvBlock(nn.Module):
18 | """
19 | Mobile Inverted Residual Bottleneck Block
20 |
21 | Args:
22 | block_args (namedtuple): BlockArgs, see above
23 | global_params (namedtuple): GlobalParam, see above
24 |
25 | Attributes:
26 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
27 | """
28 |
29 | def __init__(self, block_args, global_params):
30 | super().__init__()
31 | self._block_args = block_args
32 | self._bn_mom = 1 - global_params.batch_norm_momentum
33 | self._bn_eps = global_params.batch_norm_epsilon
34 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
35 | self.id_skip = block_args.id_skip # skip connection and drop connect
36 |
37 | # Get static or dynamic convolution depending on image size
38 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
39 |
40 | # Expansion phase
41 | inp = self._block_args.input_filters # number of input channels
42 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
43 | if self._block_args.expand_ratio != 1:
44 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
45 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
46 |
47 | # Depthwise convolution phase
48 | k = self._block_args.kernel_size
49 | s = self._block_args.stride
50 | self._depthwise_conv = Conv2d(
51 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
52 | kernel_size=k, stride=s, bias=False)
53 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
54 |
55 | # Squeeze and Excitation layer, if desired
56 | if self.has_se:
57 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
58 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
59 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
60 |
61 | # Output phase
62 | final_oup = self._block_args.output_filters
63 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
64 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
65 | self._swish = MemoryEfficientSwish()
66 |
67 | def forward(self, inputs, drop_connect_rate=None):
68 | """
69 | :param inputs: input tensor
70 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
71 | :return: output of block
72 | """
73 |
74 | # Expansion and Depthwise Convolution
75 | x = inputs
76 | if self._block_args.expand_ratio != 1:
77 | x = self._expand_conv(inputs)
78 | x = self._bn0(x)
79 | x = self._swish(x)
80 |
81 | x = self._depthwise_conv(x)
82 | x = self._bn1(x)
83 | x = self._swish(x)
84 |
85 | # Squeeze and Excitation
86 | if self.has_se:
87 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
88 | x_squeezed = self._se_reduce(x_squeezed)
89 | x_squeezed = self._swish(x_squeezed)
90 | x_squeezed = self._se_expand(x_squeezed)
91 | x = torch.sigmoid(x_squeezed) * x
92 |
93 | x = self._project_conv(x)
94 | x = self._bn2(x)
95 |
96 | # Skip connection and drop connect
97 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
98 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
99 | if drop_connect_rate:
100 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
101 | x = x + inputs # skip connection
102 | return x
103 |
104 | def set_swish(self, memory_efficient=True):
105 | """Sets swish function as memory efficient (for training) or standard (for export)"""
106 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
107 |
108 |
109 | class EfficientNet(nn.Module):
110 | """
111 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
112 |
113 | Args:
114 | blocks_args (list): A list of BlockArgs to construct blocks
115 | global_params (namedtuple): A set of GlobalParams shared between blocks
116 |
117 | Example:
118 | model = EfficientNet.from_pretrained('efficientnet-b0')
119 |
120 | """
121 |
122 | def __init__(self, blocks_args=None, global_params=None):
123 | super().__init__()
124 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
125 | assert len(blocks_args) > 0, 'blocks_args must not be empty'
126 | self._global_params = global_params
127 | self._blocks_args = blocks_args
128 |
129 | # Get static or dynamic convolution depending on image size
130 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
131 |
132 | # Batch norm parameters
133 | bn_mom = 1 - self._global_params.batch_norm_momentum
134 | bn_eps = self._global_params.batch_norm_epsilon
135 |
136 | # Stem
137 | in_channels = 3 # rgb
138 | out_channels = round_filters(32, self._global_params) # number of output channels
139 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
140 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
141 |
142 | # Build blocks
143 | self._blocks = nn.ModuleList([])
144 | for block_args in self._blocks_args:
145 |
146 | # Update block input and output filters based on depth multiplier.
147 | block_args = block_args._replace(
148 | input_filters=round_filters(block_args.input_filters, self._global_params),
149 | output_filters=round_filters(block_args.output_filters, self._global_params),
150 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
151 | )
152 |
153 | # The first block needs to take care of stride and filter size increase.
154 | self._blocks.append(MBConvBlock(block_args, self._global_params))
155 | if block_args.num_repeat > 1:
156 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
157 | for _ in range(block_args.num_repeat - 1):
158 | self._blocks.append(MBConvBlock(block_args, self._global_params))
159 |
160 | # Head
161 | in_channels = block_args.output_filters # output of final block
162 | out_channels = round_filters(1280, self._global_params)
163 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
164 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
165 |
166 | # Final linear layer
167 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
168 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
169 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
170 | self._swish = MemoryEfficientSwish()
171 |
172 | def set_swish(self, memory_efficient=True):
173 | """Sets swish function as memory efficient (for training) or standard (for export)"""
174 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
175 | for block in self._blocks:
176 | block.set_swish(memory_efficient)
177 |
178 |
179 | def extract_features(self, inputs):
180 | """ Returns output of the final convolution layer """
181 |
182 | # Stem
183 | x = self._swish(self._bn0(self._conv_stem(inputs)))
184 |
185 | # Blocks
186 | for idx, block in enumerate(self._blocks):
187 | drop_connect_rate = self._global_params.drop_connect_rate
188 | if drop_connect_rate:
189 | drop_connect_rate *= float(idx) / len(self._blocks)
190 | x = block(x, drop_connect_rate=drop_connect_rate)
191 | # Head
192 | x = self._swish(self._bn1(self._conv_head(x)))
193 |
194 | return x
195 |
196 | def forward(self, inputs):
197 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
198 | bs = inputs.size(0)
199 | # Convolution layers
200 | x = self.extract_features(inputs)
201 |
202 | # Pooling and final linear layer
203 | x = self._avg_pooling(x)
204 | x = x.view(bs, -1)
205 | x = self._dropout(x)
206 | x = self._fc(x)
207 | return x
208 |
209 | @classmethod
210 | def from_name(cls, model_name, override_params=None):
211 | cls._check_model_name_is_valid(model_name)
212 | blocks_args, global_params = get_model_params(model_name, override_params)
213 | return cls(blocks_args, global_params)
214 |
215 | @classmethod
216 | def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3):
217 | model = cls.from_name(model_name, override_params={'num_classes': num_classes})
218 | if load_weights:
219 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
220 | if in_channels != 3:
221 | Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
222 | out_channels = round_filters(32, model._global_params)
223 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
224 | return model
225 |
226 | @classmethod
227 | def get_image_size(cls, model_name):
228 | cls._check_model_name_is_valid(model_name)
229 | _, _, res, _ = efficientnet_params(model_name)
230 | return res
231 |
232 | @classmethod
233 | def _check_model_name_is_valid(cls, model_name):
234 | """ Validates model name. """
235 | valid_models = ['efficientnet-b'+str(i) for i in range(9)]
236 | if model_name not in valid_models:
237 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
238 |
--------------------------------------------------------------------------------
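A minimal usage sketch for the EfficientNet class above, built via from_name so no pretrained download is needed (assumes the repo root is on PYTHONPATH). It checks the feature-extractor and classifier output shapes for a b0 at its native 224x224 resolution:

```python
import torch
from efficientnet import EfficientNet  # re-exported by efficientnet/__init__.py

model = EfficientNet.from_name('efficientnet-b0', override_params={'num_classes': 10})
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.extract_features(x)   # final conv features before pooling
    logits = model(x)

print(feats.shape)   # torch.Size([1, 1280, 7, 7]); round_filters(1280, b0) stays 1280
print(logits.shape)  # torch.Size([1, 10])
```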
/efficientnet/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains helper functions for building the model and for loading model parameters.
3 | These helper functions are built to mirror those in the official TensorFlow implementation.
4 | """
5 |
6 | import re
7 | import math
8 | import collections
9 | from functools import partial
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | from torch.utils import model_zoo
14 | from .utils_extra import Conv2dStaticSamePadding
15 |
16 | ########################################################################
17 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
18 | ########################################################################
19 |
20 |
21 | # Parameters for the entire model (stem, all blocks, and head)
22 |
23 | GlobalParams = collections.namedtuple('GlobalParams', [
24 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
25 | 'num_classes', 'width_coefficient', 'depth_coefficient',
26 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
27 |
28 | # Parameters for an individual model block
29 | BlockArgs = collections.namedtuple('BlockArgs', [
30 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
31 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
32 |
33 | # Change namedtuple defaults
34 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
35 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
36 |
37 |
38 | class SwishImplementation(torch.autograd.Function):
39 | @staticmethod
40 | def forward(ctx, i):
41 | result = i * torch.sigmoid(i)
42 | ctx.save_for_backward(i)
43 | return result
44 |
45 | @staticmethod
46 | def backward(ctx, grad_output):
47 | i = ctx.saved_tensors[0]  # saved_variables is deprecated in favour of saved_tensors
48 | sigmoid_i = torch.sigmoid(i)
49 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
50 |
51 |
52 | class MemoryEfficientSwish(nn.Module):
53 | def forward(self, x):
54 | return SwishImplementation.apply(x)
55 |
56 |
57 | class Swish(nn.Module):
58 | def forward(self, x):
59 | return x * torch.sigmoid(x)
60 |
61 |
62 | def round_filters(filters, global_params):
63 | """ Calculate and round number of filters based on depth multiplier. """
64 | multiplier = global_params.width_coefficient
65 | if not multiplier:
66 | return filters
67 | divisor = global_params.depth_divisor
68 | min_depth = global_params.min_depth
69 | filters *= multiplier
70 | min_depth = min_depth or divisor
71 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
72 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
73 | new_filters += divisor
74 | return int(new_filters)
75 |
76 |
77 | def round_repeats(repeats, global_params):
78 | """ Round number of filters based on depth multiplier. """
79 | multiplier = global_params.depth_coefficient
80 | if not multiplier:
81 | return repeats
82 | return int(math.ceil(multiplier * repeats))
83 |
84 |
85 | def drop_connect(inputs, p, training):
86 | """ Drop connect. """
87 | if not training: return inputs
88 | batch_size = inputs.shape[0]
89 | keep_prob = 1 - p
90 | random_tensor = keep_prob
91 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
92 | binary_tensor = torch.floor(random_tensor)
93 | output = inputs / keep_prob * binary_tensor
94 | return output
95 |
96 |
97 | def get_same_padding_conv2d(image_size=None):
98 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
99 | Static padding is necessary for ONNX exporting of models. """
100 | if image_size is None:
101 | return Conv2dDynamicSamePadding
102 | else:
103 | return partial(Conv2dStaticSamePadding, image_size=image_size)
104 |
105 |
106 | class Conv2dDynamicSamePadding(nn.Conv2d):
107 | """ 2D Convolutions like TensorFlow, for a dynamic image size """
108 |
109 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
110 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
111 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
112 |
113 | def forward(self, x):
114 | ih, iw = x.size()[-2:]
115 | kh, kw = self.weight.size()[-2:]
116 | sh, sw = self.stride
117 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
118 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
119 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
120 | if pad_h > 0 or pad_w > 0:
121 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
122 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
123 |
124 |
125 | class Identity(nn.Module):
126 | def __init__(self, ):
127 | super(Identity, self).__init__()
128 |
129 | def forward(self, input):
130 | return input
131 |
132 |
133 | ########################################################################
134 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
135 | ########################################################################
136 |
137 |
138 | def efficientnet_params(model_name):
139 | """ Map EfficientNet model name to parameter coefficients. """
140 | params_dict = {
141 | # Coefficients: width,depth,res,dropout
142 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
143 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
144 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
145 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
146 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
147 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
148 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
149 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
150 | 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
151 | 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
152 | }
153 | return params_dict[model_name]
154 |
155 |
156 | class BlockDecoder(object):
157 | """ Block Decoder for readability, straight from the official TensorFlow repository """
158 |
159 | @staticmethod
160 | def _decode_block_string(block_string):
161 | """ Gets a block through a string notation of arguments. """
162 | assert isinstance(block_string, str)
163 |
164 | ops = block_string.split('_')
165 | options = {}
166 | for op in ops:
167 | splits = re.split(r'(\d.*)', op)
168 | if len(splits) >= 2:
169 | key, value = splits[:2]
170 | options[key] = value
171 |
172 | # Check stride
173 | assert (('s' in options and len(options['s']) == 1) or
174 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
175 |
176 | return BlockArgs(
177 | kernel_size=int(options['k']),
178 | num_repeat=int(options['r']),
179 | input_filters=int(options['i']),
180 | output_filters=int(options['o']),
181 | expand_ratio=int(options['e']),
182 | id_skip=('noskip' not in block_string),
183 | se_ratio=float(options['se']) if 'se' in options else None,
184 | stride=[int(options['s'][0])])
185 |
186 | @staticmethod
187 | def _encode_block_string(block):
188 | """Encodes a block to a string."""
189 | args = [
190 | 'r%d' % block.num_repeat,
191 | 'k%d' % block.kernel_size,
192 | 's%d%d' % (block.stride[0], block.stride[-1]),  # BlockArgs field is 'stride', not 'strides'
193 | 'e%s' % block.expand_ratio,
194 | 'i%d' % block.input_filters,
195 | 'o%d' % block.output_filters
196 | ]
197 | if block.se_ratio is not None and 0 < block.se_ratio <= 1:
198 | args.append('se%s' % block.se_ratio)
199 | if block.id_skip is False:
200 | args.append('noskip')
201 | return '_'.join(args)
202 |
203 | @staticmethod
204 | def decode(string_list):
205 | """
206 | Decodes a list of string notations to specify blocks inside the network.
207 |
208 | :param string_list: a list of strings, each string is a notation of block
209 | :return: a list of BlockArgs namedtuples of block args
210 | """
211 | assert isinstance(string_list, list)
212 | blocks_args = []
213 | for block_string in string_list:
214 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
215 | return blocks_args
216 |
217 | @staticmethod
218 | def encode(blocks_args):
219 | """
220 | Encodes a list of BlockArgs to a list of strings.
221 |
222 | :param blocks_args: a list of BlockArgs namedtuples of block args
223 | :return: a list of strings, each string is a notation of block
224 | """
225 | block_strings = []
226 | for block in blocks_args:
227 | block_strings.append(BlockDecoder._encode_block_string(block))
228 | return block_strings
229 |
230 |
231 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
232 | drop_connect_rate=0.2, image_size=None, num_classes=1000):
233 | """ Creates a efficientnet model. """
234 |
235 | blocks_args = [
236 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
237 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
238 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
239 | 'r1_k3_s11_e6_i192_o320_se0.25',
240 | ]
241 | blocks_args = BlockDecoder.decode(blocks_args)
242 |
243 | global_params = GlobalParams(
244 | batch_norm_momentum=0.99,
245 | batch_norm_epsilon=1e-3,
246 | dropout_rate=dropout_rate,
247 | drop_connect_rate=drop_connect_rate,
248 | # data_format='channels_last', # removed, this is always true in PyTorch
249 | num_classes=num_classes,
250 | width_coefficient=width_coefficient,
251 | depth_coefficient=depth_coefficient,
252 | depth_divisor=8,
253 | min_depth=None,
254 | image_size=image_size,
255 | )
256 |
257 | return blocks_args, global_params
258 |
259 |
260 | def get_model_params(model_name, override_params):
261 | """ Get the block args and global params for a given model """
262 | if model_name.startswith('efficientnet'):
263 | w, d, s, p = efficientnet_params(model_name)
264 | # note: all models have drop connect rate = 0.2
265 | blocks_args, global_params = efficientnet(
266 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
267 | else:
268 | raise NotImplementedError('model name is not pre-defined: %s' % model_name)
269 | if override_params:
270 | # ValueError will be raised here if override_params has fields not included in global_params.
271 | global_params = global_params._replace(**override_params)
272 | return blocks_args, global_params
273 |
274 |
275 | url_map = {
276 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth',
277 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth',
278 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth',
279 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth',
280 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth',
281 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth',
282 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth',
283 | 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth',
284 | }
285 |
286 | url_map_advprop = {
287 | 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth',
288 | 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth',
289 | 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth',
290 | 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth',
291 | 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth',
292 | 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth',
293 | 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth',
294 | 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth',
295 | 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth',
296 | }
297 |
298 |
299 | def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
300 | """ Loads pretrained weights, and downloads if loading for the first time. """
301 | # AutoAugment or Advprop (different preprocessing)
302 | url_map_ = url_map_advprop if advprop else url_map
303 | state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu'))
304 | # state_dict = torch.load('../../weights/backbone_efficientnetb0.pth')
305 | if load_fc:
306 | ret = model.load_state_dict(state_dict, strict=False)
307 | print(ret)
308 | else:
309 | state_dict.pop('_fc.weight')
310 | state_dict.pop('_fc.bias')
311 | res = model.load_state_dict(state_dict, strict=False)
312 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
313 | print('Loaded pretrained weights for {}'.format(model_name))
314 |
--------------------------------------------------------------------------------
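To make the compound-scaling helpers above concrete, here is a small worked example for efficientnet-b2 (width 1.1, depth 1.2): the 32-channel stem stays at 32 after snapping to a multiple of depth_divisor=8, the 1280-channel head grows to 1408, and a 2-repeat stage becomes 3 repeats.

```python
from efficientnet.utils import (efficientnet, efficientnet_params,
                                round_filters, round_repeats)

w, d, res, dropout = efficientnet_params('efficientnet-b2')   # (1.1, 1.2, 260, 0.3)
_, global_params = efficientnet(width_coefficient=w, depth_coefficient=d,
                                dropout_rate=dropout, image_size=res)

print(round_filters(32, global_params))    # 32   (32 * 1.1 = 35.2, snapped down to a multiple of 8)
print(round_filters(1280, global_params))  # 1408 (1280 * 1.1 = 1408, already a multiple of 8)
print(round_repeats(2, global_params))     # 3    (ceil(2 * 1.2))
```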
/efficientnet/utils_extra.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import math
4 |
5 | from torch import nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class Conv2dStaticSamePadding(nn.Module):
10 | """
11 | created by Zylo117
12 | The real keras/tensorflow conv2d with same padding
13 | """
14 |
15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
16 | super().__init__()
17 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
18 | bias=bias, groups=groups)
19 | self.stride = self.conv.stride
20 | self.kernel_size = self.conv.kernel_size
21 | self.dilation = self.conv.dilation
22 |
23 | if isinstance(self.stride, int):
24 | self.stride = [self.stride] * 2
25 | elif len(self.stride) == 1:
26 | self.stride = [self.stride[0]] * 2
27 |
28 | if isinstance(self.kernel_size, int):
29 | self.kernel_size = [self.kernel_size] * 2
30 | elif len(self.kernel_size) == 1:
31 | self.kernel_size = [self.kernel_size[0]] * 2
32 |
33 | def forward(self, x):
34 | h, w = x.shape[-2:]
35 |
36 | h_step = math.ceil(w / self.stride[1])
37 | v_step = math.ceil(h / self.stride[0])
38 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
39 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
40 |
41 | extra_h = h_cover_len - w
42 | extra_v = v_cover_len - h
43 |
44 | left = extra_h // 2
45 | right = extra_h - left
46 | top = extra_v // 2
47 | bottom = extra_v - top
48 |
49 | x = F.pad(x, [left, right, top, bottom])
50 |
51 | x = self.conv(x)
52 | return x
53 |
54 |
55 | class MaxPool2dStaticSamePadding(nn.Module):
56 | """
57 | created by Zylo117
58 | The real keras/tensorflow MaxPool2d with same padding
59 | """
60 |
61 | def __init__(self, *args, **kwargs):
62 | super().__init__()
63 | self.pool = nn.MaxPool2d(*args, **kwargs)
64 | self.stride = self.pool.stride
65 | self.kernel_size = self.pool.kernel_size
66 |
67 | if isinstance(self.stride, int):
68 | self.stride = [self.stride] * 2
69 | elif len(self.stride) == 1:
70 | self.stride = [self.stride[0]] * 2
71 |
72 | if isinstance(self.kernel_size, int):
73 | self.kernel_size = [self.kernel_size] * 2
74 | elif len(self.kernel_size) == 1:
75 | self.kernel_size = [self.kernel_size[0]] * 2
76 |
77 | def forward(self, x):
78 | h, w = x.shape[-2:]
79 |
80 | h_step = math.ceil(w / self.stride[1])
81 | v_step = math.ceil(h / self.stride[0])
82 | h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1)
83 | v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1)
84 |
85 | extra_h = h_cover_len - w
86 | extra_v = v_cover_len - h
87 |
88 | left = extra_h // 2
89 | right = extra_h - left
90 | top = extra_v // 2
91 | bottom = extra_v - top
92 |
93 | x = F.pad(x, [left, right, top, bottom])
94 |
95 | x = self.pool(x)
96 | return x
97 |
--------------------------------------------------------------------------------
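A small sketch of what the 'same'-padding wrappers above provide: the output spatial size is always ceil(input / stride), mirroring TensorFlow's padding='same', whereas a plain unpadded nn.Conv2d shrinks the map.

```python
import torch
from torch import nn
from efficientnet.utils_extra import Conv2dStaticSamePadding, MaxPool2dStaticSamePadding

x = torch.randn(1, 3, 65, 65)   # odd spatial size on purpose

same_conv = Conv2dStaticSamePadding(3, 8, kernel_size=5, stride=2, bias=False)
plain_conv = nn.Conv2d(3, 8, kernel_size=5, stride=2, bias=False)   # no padding

print(same_conv(x).shape)   # torch.Size([1, 8, 33, 33])  -> ceil(65 / 2), TF 'same' behaviour
print(plain_conv(x).shape)  # torch.Size([1, 8, 31, 31])  -> shrinks without padding

pool = MaxPool2dStaticSamePadding(kernel_size=3, stride=2)
print(pool(x).shape)        # torch.Size([1, 3, 33, 33])
```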
/projects/coco.yml:
--------------------------------------------------------------------------------
1 | project_name: coco # also the folder name of the dataset under the data_path folder
2 | train_set: train2017
3 | val_set: val2017
4 | num_gpus: 4
5 |
6 | # mean and std in RGB order; this should remain unchanged as long as your dataset is similar to COCO.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # this is coco anchors, change it if necessary
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | # must match your dataset's category_id.
15 | # category_id is one-indexed,
16 | # for example, the index of 'car' here is 2, while its category_id is 3
17 | obj_list: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
18 | 'fire hydrant', '', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
19 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '', 'backpack', 'umbrella', '', '', 'handbag', 'tie',
20 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
21 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', '', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
22 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
23 | 'cake', 'chair', 'couch', 'potted plant', 'bed', '', 'dining table', '', '', 'toilet', '', 'tv',
24 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
25 | 'refrigerator', '', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
26 | 'toothbrush']
--------------------------------------------------------------------------------
/projects/shape.yml:
--------------------------------------------------------------------------------
1 | project_name: shape # also the folder name of the dataset under the data_path folder
2 | train_set: train
3 | val_set: val
4 | num_gpus: 1
5 |
6 | # mean and std in RGB order; this should remain unchanged as long as your dataset is similar to COCO.
7 | mean: [0.485, 0.456, 0.406]
8 | std: [0.229, 0.224, 0.225]
9 |
10 | # these anchors are adapted to this dataset
11 | anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'
12 | anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
13 |
14 | obj_list: ['rectangle', 'circle']
--------------------------------------------------------------------------------
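Both project files are read by train.py's Params helper (shown below in this listing), which is a thin wrapper over yaml.safe_load; note that the anchor fields are plain strings that train.py later passes through eval(). A minimal sketch of the round trip, assuming it runs from the repository root:

```python
import yaml

# parse projects/shape.yml the same way train.py's Params class does
params = yaml.safe_load(open('projects/shape.yml').read())

print(params['obj_list'])    # ['rectangle', 'circle']
print(params['num_gpus'])    # 1

# the anchor fields are stored as strings and eval'd inside train.py
scales = eval(params['anchors_scales'])   # [1.0, 1.259..., 1.587...]
ratios = eval(params['anchors_ratios'])   # [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
```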
/test/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img.png
--------------------------------------------------------------------------------
/test/img_inferred_d0_official.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_official.jpg
--------------------------------------------------------------------------------
/test/img_inferred_d0_this_repo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_this_repo.jpg
--------------------------------------------------------------------------------
/test/img_inferred_d0_this_repo_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/anylots/DetectNet/2be42fc1c8439efc48ab337b1c73fe4ee9b9f78e/test/img_inferred_d0_this_repo_0.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # original author: signatrix
2 | # adapted from https://github.com/signatrix/efficientdet/blob/master/train.py
3 | # modified by Zylo117
4 |
5 | import datetime
6 | import os
7 | import argparse
8 | import traceback
9 |
10 | import torch
11 | import yaml
12 | from torch import nn
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 | from efficientdet.dataset import CocoDataset, Resizer, Normalizer, Augmenter, collater
16 | from backbone import EfficientDetBackbone
17 | from tensorboardX import SummaryWriter
18 | import numpy as np
19 | from tqdm.autonotebook import tqdm
20 |
21 | from efficientdet.loss import FocalLoss
22 | from utils.sync_batchnorm import patch_replication_callback
23 | from utils.utils import replace_w_sync_bn, CustomDataParallel, get_last_weights, init_weights
24 |
25 |
26 | class Params:
27 | def __init__(self, project_file):
28 | self.params = yaml.safe_load(open(project_file).read())
29 |
30 | def __getattr__(self, item):
31 | return self.params.get(item, None)
32 |
33 |
34 | def get_args():
35 | parser = argparse.ArgumentParser('Yet Another EfficientDet Pytorch: SOTA object detection network - Zylo117')
36 | parser.add_argument('-p', '--project', type=str, default='coco', help='project file that contains parameters')
37 | parser.add_argument('-c', '--compound_coef', type=int, default=0, help='coefficients of efficientdet')
38 | parser.add_argument('-n', '--num_workers', type=int, default=12, help='num_workers of dataloader')
39 | parser.add_argument('--batch_size', type=int, default=12, help='The number of images per batch among all devices')
40 | parser.add_argument('--head_only', type=bool, default=False,
41 | help='whether to finetune only the regressor and the classifier, '
42 | 'useful for early-stage convergence or a small/easy dataset')
43 | parser.add_argument('--lr', type=float, default=1e-4)
44 | parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
45 | 'suggest using \'adamw\' until the'
46 | ' very final stage then switch to \'sgd\'')
47 | parser.add_argument('--alpha', type=float, default=0.25)
48 | parser.add_argument('--gamma', type=float, default=1.5)
49 | parser.add_argument('--num_epochs', type=int, default=500)
50 | parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases')
51 | parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
52 | parser.add_argument('--es_min_delta', type=float, default=0.0,
53 | help='Early stopping\'s parameter: minimum change in loss to qualify as an improvement')
54 | parser.add_argument('--es_patience', type=int, default=0,
55 | help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.')
56 | parser.add_argument('--data_path', type=str, default='datasets/', help='the root folder of dataset')
57 | parser.add_argument('--log_path', type=str, default='logs/')
58 | parser.add_argument('-w', '--load_weights', type=str, default=None,
59 | help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
60 | parser.add_argument('--saved_path', type=str, default='logs/')
61 | parser.add_argument('--debug', type=bool, default=False, help='whether to visualize the predicted boxes during training; '
62 | 'the output images will be in test/')
63 |
64 | args = parser.parse_args()
65 | return args
66 |
67 |
68 | class ModelWithLoss(nn.Module):
69 | def __init__(self, model, debug=False):
70 | super().__init__()
71 | self.criterion = FocalLoss()
72 | self.model = model
73 | self.debug = debug
74 |
75 | def forward(self, imgs, annotations, obj_list=None):
76 | _, regression, classification, anchors = self.model(imgs)
77 | if self.debug:
78 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations,
79 | imgs=imgs, obj_list=obj_list)
80 | else:
81 | cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations)
82 | return cls_loss, reg_loss
83 |
84 |
85 | def train(opt):
86 | params = Params(f'projects/{opt.project}.yml')
87 |
88 | if params.num_gpus == 0:
89 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
90 |
91 | if torch.cuda.is_available():
92 | torch.cuda.manual_seed(42)
93 | else:
94 | torch.manual_seed(42)
95 |
96 | opt.saved_path = opt.saved_path + f'/{params.project_name}/'
97 | opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
98 | os.makedirs(opt.log_path, exist_ok=True)
99 | os.makedirs(opt.saved_path, exist_ok=True)
100 |
101 | training_params = {'batch_size': opt.batch_size,
102 | 'shuffle': True,
103 | 'drop_last': True,
104 | 'collate_fn': collater,
105 | 'num_workers': opt.num_workers}
106 |
107 | val_params = {'batch_size': opt.batch_size,
108 | 'shuffle': False,
109 | 'drop_last': True,
110 | 'collate_fn': collater,
111 | 'num_workers': opt.num_workers}
112 |
113 | input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
114 | training_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.train_set,
115 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
116 | Augmenter(),
117 | Resizer(input_sizes[opt.compound_coef])]))
118 | training_generator = DataLoader(training_set, **training_params)
119 |
120 | val_set = CocoDataset(root_dir=os.path.join(opt.data_path, params.project_name), set=params.val_set,
121 | transform=transforms.Compose([Normalizer(mean=params.mean, std=params.std),
122 | Resizer(input_sizes[opt.compound_coef])]))
123 | val_generator = DataLoader(val_set, **val_params)
124 |
125 | model = EfficientDetBackbone(num_classes=len(params.obj_list), compound_coef=opt.compound_coef,
126 | ratios=eval(params.anchors_ratios), scales=eval(params.anchors_scales))
127 |
128 | # load last weights
129 | if opt.load_weights is not None:
130 | if opt.load_weights.endswith('.pth'):
131 | weights_path = opt.load_weights
132 | else:
133 | weights_path = get_last_weights(opt.saved_path)
134 | try:
135 | last_step = int(os.path.basename(weights_path).split('_')[-1].split('.')[0])
136 | except Exception:
137 | last_step = 0
138 |
139 | try:
140 | ret = model.load_state_dict(torch.load(weights_path), strict=False)
141 | except RuntimeError as e:
142 | print(f'[Warning] Ignoring {e}')
143 | print(
144 | '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.')
145 |
146 | print(f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}')
147 | else:
148 | last_step = 0
149 | print('[Info] initializing weights...')
150 | init_weights(model)
151 |
152 | # freeze backbone if train head_only
153 | if opt.head_only:
154 | def freeze_backbone(m):
155 | classname = m.__class__.__name__
156 | for ntl in ['EfficientNet', 'BiFPN']:
157 | if ntl in classname:
158 | for param in m.parameters():
159 | param.requires_grad = False
160 |
161 | model.apply(freeze_backbone)
162 | print('[Info] froze backbone')
163 |
164 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
165 | # apply sync_bn when using multiple gpus and the batch_size per gpu is lower than 4,
166 | #  which is useful when gpu memory is limited.
167 | #  with bn statistics computed on such small per-gpu batches, training becomes unstable or slow to converge;
168 | #  sync_bn solves this by packing the mini-batches of all gpus into one batch for normalization
169 | #  and then sending the results back to each gpu,
170 | #  at the cost of slightly slower training.
171 | if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
172 | model.apply(replace_w_sync_bn)
173 | use_sync_bn = True
174 | else:
175 | use_sync_bn = False
176 |
177 | writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')
178 |
179 | # wrap the model with the loss function, to reduce the memory usage on gpu0 and speed up training
180 | model = ModelWithLoss(model, debug=opt.debug)
181 |
182 | if params.num_gpus > 0:
183 | model = model.cuda()
184 | if params.num_gpus > 1:
185 | model = CustomDataParallel(model, params.num_gpus)
186 | if use_sync_bn:
187 | patch_replication_callback(model)
188 |
189 | if opt.optim == 'adamw':
190 | optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
191 | else:
192 | optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)
193 |
194 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
195 |
196 | epoch = 0
197 | best_loss = 1e5
198 | best_epoch = 0
199 | step = max(0, last_step)
200 | model.train()
201 |
202 | num_iter_per_epoch = len(training_generator)
203 |
204 | try:
205 | for epoch in range(opt.num_epochs):
206 | last_epoch = step // num_iter_per_epoch
207 | if epoch < last_epoch:
208 | continue
209 |
210 | epoch_loss = []
211 | progress_bar = tqdm(training_generator)
212 | for iter, data in enumerate(progress_bar):
213 | if iter < step - last_epoch * num_iter_per_epoch:
214 | progress_bar.update()
215 | continue
216 | try:
217 | imgs = data['img']
218 | annot = data['annot']
219 |
220 | if params.num_gpus == 1:
221 | # if only one gpu, just send it to cuda:0
222 | # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
223 | imgs = imgs.cuda()
224 | annot = annot.cuda()
225 |
226 | optimizer.zero_grad()
227 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
228 | cls_loss = cls_loss.mean()
229 | reg_loss = reg_loss.mean()
230 |
231 | loss = cls_loss + reg_loss
232 | if loss == 0 or not torch.isfinite(loss):
233 | continue
234 |
235 | loss.backward()
236 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
237 | optimizer.step()
238 |
239 | epoch_loss.append(float(loss))
240 |
241 | progress_bar.set_description(
242 | 'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'.format(
243 | step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss.item(),
244 | reg_loss.item(), loss.item()))
245 | writer.add_scalars('Loss', {'train': loss}, step)
246 | writer.add_scalars('Regression_loss', {'train': reg_loss}, step)
247 |                     writer.add_scalars('Classification_loss', {'train': cls_loss}, step)
248 |
249 | # log learning_rate
250 | current_lr = optimizer.param_groups[0]['lr']
251 | writer.add_scalar('learning_rate', current_lr, step)
252 |
253 | step += 1
254 |
255 | if step % opt.save_interval == 0 and step > 0:
256 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
257 | print('checkpoint...')
258 |
259 | except Exception as e:
260 | print('[Error]', traceback.format_exc())
261 | print(e)
262 | continue
263 | scheduler.step(np.mean(epoch_loss))
264 |
265 | if epoch % opt.val_interval == 0:
266 | model.eval()
267 | loss_regression_ls = []
268 | loss_classification_ls = []
269 | for iter, data in enumerate(val_generator):
270 | with torch.no_grad():
271 | imgs = data['img']
272 | annot = data['annot']
273 |
274 | if params.num_gpus == 1:
275 | imgs = imgs.cuda()
276 | annot = annot.cuda()
277 |
278 | cls_loss, reg_loss = model(imgs, annot, obj_list=params.obj_list)
279 | cls_loss = cls_loss.mean()
280 | reg_loss = reg_loss.mean()
281 |
282 | loss = cls_loss + reg_loss
283 | if loss == 0 or not torch.isfinite(loss):
284 | continue
285 |
286 | loss_classification_ls.append(cls_loss.item())
287 | loss_regression_ls.append(reg_loss.item())
288 |
289 | cls_loss = np.mean(loss_classification_ls)
290 | reg_loss = np.mean(loss_regression_ls)
291 | loss = cls_loss + reg_loss
292 |
293 | print(
294 | 'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format(
295 | epoch, opt.num_epochs, cls_loss, reg_loss, loss))
296 | writer.add_scalars('Loss', {'val': loss}, step)
297 | writer.add_scalars('Regression_loss', {'val': reg_loss}, step)
298 |                 writer.add_scalars('Classification_loss', {'val': cls_loss}, step)
299 |
300 | if loss + opt.es_min_delta < best_loss:
301 | best_loss = loss
302 | best_epoch = epoch
303 |
304 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
305 |
306 | model.train()
307 |
308 | # Early stopping
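            # note the chained comparison: early stopping only triggers when es_patience > 0 and more than es_patience epochs have passed since the best epoch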
309 | if epoch - best_epoch > opt.es_patience > 0:
310 | print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
311 | break
312 | except KeyboardInterrupt:
313 | save_checkpoint(model, f'efficientdet-d{opt.compound_coef}_{epoch}_{step}.pth')
314 | writer.close()
315 | writer.close()
316 |
317 |
318 | def save_checkpoint(model, name):
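    # model is ModelWithLoss (possibly wrapped in CustomDataParallel); unwrap it so only the detector's state_dict is saved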
319 | if isinstance(model, CustomDataParallel):
320 | torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
321 | else:
322 | torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))
323 |
324 |
325 | if __name__ == '__main__':
326 | opt = get_args()
327 | train(opt)
328 |
--------------------------------------------------------------------------------
/tutorial/train_shape.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# EfficientDet Training On A Custom Dataset\n",
13 | "\n",
14 | "\n",
15 | "\n",
16 | "
"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "source": [
29 | "## This tutorial will show you how to train a custom dataset.\n",
30 | "\n",
31 | "## For the sake of simplicity, I generated a dataset of different shapes, like rectangles, triangles, circles.\n",
32 | "\n",
33 | "## Please enable GPU support to accelerate on notebook setting if you are using colab.\n",
34 | "\n",
35 | "### 0. Install Requirements"
36 | ],
37 | "metadata": {
38 | "collapsed": false,
39 | "pycharm": {
40 | "name": "#%% md\n"
41 | }
42 | }
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "outputs": [],
48 | "source": [
49 | "!pip install pycocotools numpy==1.16.0 opencv-python tqdm tensorboard tensorboardX pyyaml matplotlib\n",
50 | "!pip install torch==1.4.0\n",
51 | "!pip install torchvision==0.5.0"
52 | ],
53 | "metadata": {
54 | "collapsed": false,
55 | "pycharm": {
56 | "name": "#%%\n"
57 | }
58 | }
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "source": [
63 | "### 1. Prepare Custom Dataset/Pretrained Weights (Skip this part if you already have datasets and weights of your own)"
64 | ],
65 | "metadata": {
66 | "collapsed": false,
67 | "pycharm": {
68 | "name": "#%% md\n"
69 | }
70 | }
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "outputs": [],
76 | "source": [
77 | "import os\n",
78 | "import sys\n",
79 | "if \"projects\" not in os.getcwd():\n",
80 | " !git clone --depth 1 https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch\n",
81 | " os.chdir('Yet-Another-EfficientDet-Pytorch')\n",
82 | " sys.path.append('.')\n",
83 | "else:\n",
84 | " !git pull\n",
85 | "\n",
86 | "# download and unzip dataset\n",
87 | "! mkdir datasets\n",
88 | "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.1/dataset_shape.tar.gz\n",
89 | "! tar xzf dataset_shape.tar.gz\n",
90 | "\n",
91 | "# download pretrained weights\n",
92 | "! mkdir weights\n",
93 | "! wget https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/releases/download/1.0/efficientdet-d0.pth -O weights/efficientdet-d0.pth\n",
94 | "\n",
95 | "# prepare project file projects/shape.yml\n",
96 | "# showing its contents here\n",
97 | "! cat projects/shape.yml"
98 | ],
99 | "metadata": {
100 | "collapsed": false,
101 | "pycharm": {
102 | "name": "#%%\n"
103 | }
104 | }
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "source": [
109 | "### 2. Training"
110 | ],
111 | "metadata": {
112 | "collapsed": false
113 | }
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "outputs": [],
119 | "source": [
120 | "# consider this is a simple dataset, train head will be enough.\n",
121 | "! python train.py -c 0 -p shape --head_only True --lr 1e-3 --batch_size 32 --load_weights weights/efficientdet-d0.pth --num_epochs 50\n",
122 | "\n",
123 | "# the loss will be high at first\n",
124 | "# don't panic, be patient,\n",
125 | "# just wait for a little bit longer"
126 | ],
127 | "metadata": {
128 | "collapsed": false,
129 | "pycharm": {
130 | "name": "#%%\n"
131 | }
132 | }
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "source": [
137 | "### 3. Evaluation"
138 | ],
139 | "metadata": {
140 | "collapsed": false
141 | }
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "outputs": [],
147 | "source": [
148 | "! python coco_eval.py -c 0 -p shape -w logs/shape/efficientdet-d0_49_1400.pth"
149 | ],
150 | "metadata": {
151 | "collapsed": false,
152 | "pycharm": {
153 | "name": "#%%\n"
154 | }
155 | }
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "source": [
160 | "### 4. Visualize"
161 | ],
162 | "metadata": {
163 | "collapsed": false,
164 | "pycharm": {
165 | "name": "#%% md\n"
166 | }
167 | }
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 4,
172 | "outputs": [
173 | {
174 | "data": {
175 | "text/plain": "",
176 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de7RdVX3o8e9vrrX3PnkgCSRQTECCQAW9ijRCLARQSgX0Ch3FB/WWaNEoYC/Utha9HX1wvR311lFbexWIAsY7rIq2NrnUq+UhJvYK8VB5KYJBwyMDSUCeyck5e635u3/Mufde56xzcp775fl9xtjZa8219l7zZO/123PNNR+iqhhjTJHrdgaMMb3HAoMxpsQCgzGmxAKDMabEAoMxpsQCgzGmpC2BQUTOFpEHRWS7iFzZjmMYY9pH5rodg4gkwEPAWcDjwPeBC1X1R3N6IGNM27SjxHASsF1Vf6qqI8CXgfPacBxjTJukbXjPFcBjhfXHgZP394Jly5bpkUce2YasGGMa7rrrrqdUdflU9m1HYJgSEVkPrAc44ogjGBwc7FZWjJkXROSRqe7bjkuJncDhhfWVMW0UVd2gqqtVdfXy5VMKYsaYDmlHYPg+cIyIrBKRKvBOYHMbjmOMaZM5v5RQ1UxEPgh8C0iA61X1h3N9HGNM+7SljkFVvwF8ox3vbYxpP2v5aIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpscBgjCmxwGCMKbHAYIwpmTQwiMj1IrJLRO4vpB0kIjeLyE/i89KYLiLyKRHZLiL3isiJ7cy8MaY9plJi+Dxw9pi0K4FbVfUY4Na4DnAOcEx8rAeunptsGmM6adLAoKpbgF+MST4P2BiXNwLnF9K/oMEdwBIROWyuMmuM6YyZ1jEcqqpPxOWfA4fG5RXAY4X9Ho9pJSKyXkQGRWRw9+7dM8yGMaYdZl35qKoK6Axet0FVV6vq6uXLl882G8aYOTTTwPBk4xIhPu+K6TuBwwv7rYxpxpg+MtPAsBlYF5fXAZsK6RfFuxNrgOcKlxzGmD6RTraDiHwJOANYJiKPA38O/DVwo4hcDDwCvD3u/g3gXGA7sBd4TxvybIxps0kDg6peOMGmM8fZV4HLZpspY0x3WctHY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMiQUGY0yJBQZjTIkFBmNMyaSBQUQOF5Fvi8iPROSHInJ5TD9IRG4WkZ/E56UxXUTkUyKyXUTuFZET2/1HGGPm1lRKDBnwh6p6PLAGuExEjgeuBG5V1WOAW+M6wDnAMfGxHrh6znNtjGmrSQODqj6hqv8Rl18AHgBWAOcBG+NuG4Hz4/J5wBc0uANYIiKHzXnOjTFtM606BhE5EngtcCdwqKo+ETf9HDg0Lq8AHiu87PGYZozpE1MODCKyGPgn4ApVfb64TVUV0OkcWETWi8igiAzu3r17Oi81xrTZlAKDiFQIQeGLqvrPMfnJxiVCfN4V03cChxdevjKmjaKqG1R1taquXr58+Uzzb4xpg6nclRDgOuABVf3bwqbNwLq4vA7YVEi/KN6dWAM8V7jkMMb0gXQK+5wC/C5wn4jcHdM+Cvw1cKOIXAw8Arw9bvsGcC6wHdgLvGdOc2yMabtJA4OqfheQCTafOc7+Clw2y3wZY7rIWj4aY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMBhjSiwwGGNKLDAYY0osMPQtBQ1P/cx3OwNmXGm3M2BmbtgDCk5oBYhioJDC8nTS5+I9iukyerMWVjIgVajYN7Gn2MfRtxy1pNt5mL1KY0EVZLzIZLph0sAgIgPAFqAW9/+aqv65iKwCvgwcDNwF/K6qjohIDfgC8GvA08A7VHVHm/I/7510+rtBW1eExWvDYjF925brWXPa75XSx9t/oveYaXozTTxOwQvgE7ZtvS78DaddwLbvwOgihgWHbppKHcMw8EZVfQ1wAnC2iKwBPg58UlWPBp4BLo77Xww8E9M/Gfczc8gXTppt3/k8ojlChpChhYcUHkApfduW61Ey7ojP27ZcP+o97ojrY997ovSxxyylqUdVw3MhhGzb8rXRlyCm6yYNDBq8GFcr8aHAG4GvxfSNwPlx+by4Ttx+pojYxz5HxlbWvX7tOpwod275Ak6l+Ri7DoxaL6a9fu26cfcZ7zXF/Wf2cM33bTjtjHe0+X/NTNeU7kqISCIidwO7gJuBh4FnVTWLuzwOrIjLK4DHAOL25wiXG2Pfc72IDIrI4O7du2f3V8wjjjEfmjq8CiefdhEe13ycvPY9o9aBUevjpU223nh8b+vGcdOL28bb53tbw++FF4+XVqlny+1fGXPlYJcR3TalwKCquaqeAKwETgJeMdsDq+oGVV2tqquXL18+27ebZ1q/tr5ZXyeE8sT4j5PXrps0rbwPE77XnVtv2O9xJjyeeNAEJG/+Daec8c4+uZQIAcuTt/6HtPGPtnZRgJzmBVNcR2n9d2hv36qdVjsGVX0W+DbwemCJiDQqL1cCO+PyTuBwgLj9QEIlpGkDp4DM/S/snVs3xpN7fPvbtl/q4uVEX0SCMQR8q9TmIJ7sEgNz1twNHC6XZmWr1yR8Ti4+pLcbEU2aNxFZLiJL4vIC4CzgAUKAuCDutg7YFJc3x3Xi9ttU1cqGfWbGJ/4vM6V50jd//YUYncEjeNFm5bBPwEsG5CAaSxqxyNDjjdOm0o7hMGCjiCSEQHKjqt4kIj8CviwiHwN+AFwX978O+N8ish34BfDONuTbmM5r3EVVCcsS7rkoQiLgNBnVSCzEizScNR6QtHDrttOZn55JA4Oq3gu8dpz0nxLqG8am7wPeNie5M6ZjpnimjtktnWibjFl3o55moLPFC2v5aEzTZCdfuH7YV89wlZwqi6jzGIKSI3j2ABk5i3GMIKSkHAQ4KrwktO6ERlUFAG7SxlxSeO5ccLDAYExJ8SQsLjuGeIwFlcP5qwd+nY8eB//jkXMgGSHxipOFDGcjJBVInEdHBD+0kGq+kIF9v82lq3dSYRlOazh
pNPEa29FkKh1S2s8CQ78ZcxfCS6j0QrSna7mBcG2uAi4Lty0b2nBXZVZUS3l6ke1cPXg+I4t+Bx14DX7xLwD4i5fdz1WPvQIVyH1GmgAevAo4QRYNUWeIfNFzfOLH72Dg+Zfy/pP+hsW8DIeSZ54k1XgL15MxTlsVoNNBoue/S2YyHofixIeWhT5pPYotDqeTPhfvUUwvtHp0MTi4Qv8O8b3UGyx25lJBs3AC7mEH19z3Xv549b+QL3qKP111T7w9CVc9elwr4BWDXaFlJz7hT1fdR33Bs/zhSV/lmgcugthQ3aXgfRI7kAkpMkEbBy082s9KDP1GwDfbJYLzCcTbZJnLS/tO9B5TTp/L9xCPczle/KiOXyo91NTHCzl1nPPsSx9lw+C72HvYW8iWplz1yCtBfAgG0HyejJOcjz12LLic//74cejiXfzV9tdRee63uPTXPsMi91JyRvBkVMhwAp5hCvdHO85KDH3IFT4273Iy5/EuB3RUI+Tir8x00ufiPYrpo37tGiWJwi+q9NgN/YSUjD38w0MX8sKye8jqyoiOhIA2A43/BUf4X8E5RgaeY++yB/n0g2/H8xwJlXDcWHpyWpu7P2gGrMTQl1on1Z1bN+5nv/6g5JPv1CniyeRF/mHb77Hn0GcRUirJCBVfCaWcGQQH7/IQzPNK8xg4hzp4YcHTfPqOy7l0zf8i8Ysb0aPrLDD0nVi89Mopb3gH6itIXkEFsl4qko9HYiUbAJ7EeYa9MJBq8xe12/bI83z2nnfz/LLtJCSAUkeoeIfKDPs3NEpHLms0ksTFTi6qGU+/dBufGLyQK1bfQE2XoeLx6uhmzYv0Qmvl1atX6+DgYLez0Rda979zhjUBZfQXqMeHdmsMVRnr2hjJoSqQJt27ng7CrcK/ueN3GPmVe0lTZVgzyKuhhaPLSXWG+RMN7aMbdUAqzToWyYW620PFH8DAo6v5g1M2kFYcXj1OEsb/UGaYDZG7VHX1VPa1EkOfaf2qJtTGtq7rQ2nzG9jJPyQ0VMpwpIDmgiQwzCMMHfojcJ7cA7RO5hkHBWjdsSi+RyzdaeJIXQVNRhg66nvUKzug/lK8LqBanexvgOb/2/6C8wyy3gulN2M6TECT+KsozfPmM1v+ApKRDmfFg6+AT/Dq+NxdH0Uqbv+D4zb6a6i06nVFQbLYnT2nuWGG8cwCg5mXwriTgkfJJXSX3nfIQ53NhGg4kVVweHwOzw08yjA7EbefS4dSn4w8LGga6nA0aQWOGbLAYOadMIBKY0yEHIlnwd4Fz3Y2IyqhtJCMgMtJq1Bf+CLXbvsQ+X7qFFSV3OeF6ockvl1GsxLHE4aHGAb2xcc0WB2DmXdG/xomrQZjkpe2tp34Zq9sXI5XYWThU4ywnQUcO+5LVDVUTCooikiYoEOylC233YvPhcSlgAOfkCQJIyPTu0SywGDmIWneNiQX8iTUN6TNInnnKkK9KM5X8Aguh0olI0scQwyzYILXOHF4r4gTJBP8MPy/7zyErztgIXhP3StprNmt5xnVyvQaTFlgMPOXZJBAxgPUgEq+kMzV23a4PzvigSnu+WPg1WPSRk/n5Rr3qNPQ3+LUc2efvyILDGZ+EoA0jLIWrx7aGRT6jQUGM78JEJtkq8br9Ta66tHjmiWHxnKjM9bY5dwPk7jGJUCrQcJ3//UhsnpoF5EkrVN47VtfztbND5eWG+vTYXclzPzUGNJdQlMnCIGh3Ron/3g9M4tB4+P3vxHvssLWYjfulDSpUqlU25ZnKzGYeSz0hWj0SnCuh34nfUJSOD23bvpZaAwlvjlEXDtLOBYYzDyWx5GTOtdfaOylxITbj3iAwlzgrD1vVXN566afsfa8VaMuFQC2bn64eckw0SXFVFlgMPOTQKv7WWdPg2JAmCg41J55Gac88UNOPTukbf3XH4fbqHnoQDHRyT5esJiJHio7GdMFCkKs4NMunw6NCX99gqsPoL6Qnzz2p3B57IA1zlBvc5h/CwxmXiqOq5DEOZdTSToaHEbNECoKeWiUlIws5PiD3wDZwJhXSGHkKyk8GklzNx6HBQYzL7lGnwKXkbIMgGEdiUPkdSoP8VkdDo/Wcuo+w9WXUNv5qx3Lx3isjsHMUxL7RqTNE1TxiHZ23CSnLs6InaAjVRKpkTz9K1RePKKj+Sjlq6tHN6bbmnNMwgD7HRmlDcd2o8aR1HQvfgiOX3oqyfBEPSU6wwKDmZ8UINyq1FiBt/DR1aSdHIlVPLgs1C+4jNQv4ICnjmbg4dPZ/0gt7WeBwcw7njhQCwrkzTGcLln7caRU4dd+zif4vAo+4TXpf6HqhJF8uOP5GJWnrh7dmC5oTgGnAqTNiv0ay6ntOoaqr5AgaBIqIp3kZJrgtXDXQsNYBx7wk9zJ8C4Pg8ECCWH6uwxHhiNJElwCiSqLd5zOwJPH4nwVke7OzmWBwcxfpdbEwgfXfJbKU6sYqUOqA7i8iseRuizcyYDQUtLVIRnBST7pbUKX1XCx30NdE7x4quRUNTRrrmf7GNixhtfIhZAoKkp1f0O7dYAFBmMKqhzIpas/waJsKblkca4OT5Kn4Oqjb2fGOxiTnUROfLx0AZDQWMnlgCfzntreg/lPA79N7cVD0DxBcWi9u8N/W2AwZowFejR/8Ks3seD5pXjnqZOQuzyO5pyGywdfDfUC0Lq7MAEvHi+NYeihgifLa4yIY+D5Q/jj479JbfcqMs1RVdLUo5UJ364jphwYRCQRkR+IyE1xfZWI3Cki20XkKyJSjem1uL49bj+yPVk3Zu4571GUhCVcetwNLHryKJKsSpZDPRtGPOR1QZxvDTXf6PU4gUql0pyXwgN5kpFmVV7y7NH811fdgBs+kGpSwavgU08997i8u7OKTafEcDlQHJvq48AnVfVo4Bng4ph+MfBMTP9k3M+YvqAeRD2Cp6Iv54rXbeSQX7yCBdUBBgaESq1OxaWtXplTaBA1tHeY5rVEqiS6iCW/eCWXv/oGFrKKRByJpqS+QhZnLpcu99uY0tFFZCXwZuBzcV2ANwJfi7tsBM6Py+fFdeL2M6Xdw+IYM0ckdXiXoChVHLV8KZec8EUuX7GZdPup6ItLkGQvmsWZpRrTz+3nRE4HErI0BJCFPzuBPzr8Jj74uo1UdGkYIkbjxLduhAU+9IfochXDlEsMfwd8mFafj4OBZ1W1McTM48CKuLwCeAwgbn8u7j+KiKwXkUERGdy9e/cMs2/M3Au3M124a+FAkoSFrOSK0z/NFcdu4oAdJ7No6AAWiIM8xacjSJKhIuRewkkiaQgWmbB4eDEHPHwSAB9a+xlqrAAEL1kcth58Hk4ujWNDuC5PUDxp8yoReQuwS1XvEpEz5urAqroB2ABhUtu5el9j5oaMegKoMUCqjkvWXkPOEBlPc82Wq0gH4te3sheqe5Chg/G5oFrnvWs+huNAakcdDHwexwG0pvb28Rao4CQhx+Mbc1wWmkp3w1TaXZ4CvFVEzgUGgJcAfw8sEZE0lgpWAjvj/juBw4HHRSQFDgSenvOcGzPnJi+/JzJ6dvEPnzbZK74+Zt
01y+nN4noVfv3NE7/DTAdbmY1JA4OqfgT4CEAsMfyRqr5LRL4KXAB8GVgHbIov2RzXvxe336adGGXTmFmZ/Cvqm/sJmkEYoLkxe5WQZRlpHLXZq8c5CZWOLieEEyHMG9cICYL3Hpc5/v2W7fis0Kqyy5cSs6n6/BPgQyKynVCHcF1Mvw44OKZ/CLhydlk0pjc4bZwwOdL8SQ0TyHrApSleGtPNOTw5YRCmpDUgC+CbJZMc58LFRK/9dk6rC5eq3g7cHpd/Cpw0zj77gLfNQd6M6S0KoT7AhZXGaEqiuEZpQNM4uGwdtBoqFltTQoQRqeP7hJ4TYaq5UQPSNodu696tCRuoxZipiiOp+eI6ECoREyAtnM+NmgjFyaidwyti705RqNdBNQwa0yt6JyfG9Lrmr/7YX/JkzD5SSBNK9RdSuIYXIW2MDyMefBpmwFYp9K/oPOsrYcxckvGWJzrDQ7o4WHtmHEY+3qrsZlAACwzGdFejMJEQ+1z0RiWkBQZjuqh5MyIFXDZph6xOscBgTBdJaAIBoqz9jeNRqZMkSdtn3Z6MBQZjeoJABZKKouotMBgzn/lGJYOCSkbuM+r1etcbPFlg6JbCtIMebX5BitOWTec9ptKk1/Qep9Js+yCknHbWq9DKSOiUAYCS1X3o2t3B+gcLDF2SSY4XxWv4cjiVGCCAGCgyQsv6ECxi8FCaffgR8JKThTmUuva3mFloTj8ZorwKvOGsV4HLEIE8z6nVaqFi0qfh0QEWGLrE4WJACBOaeslwKqQaGtCER3H/kOZFSZV4W0vJcaTxfUw/anxwAnhyMqjAqWceTZIkJFKjng3T6gbemRKDtXxsu/F/yRsDdIxuBVfeZ2zkdmP2q0x8iMgiRm+TQmxISFNtJo+wB5ccQOKrqDYuIzrTh8JKDL/0ilOlj30eu5/pOB27HJtQi7L2zOPxMkSmPk7A27mOVVZi6IhicTFWMubKi074/c99lZ35Um75wFkAnHzNVvb5Ybz31Go1sizj7kvO5LVX3w7ikaSGjOQMa8ZRvMDmD57HPp8x4JJR7986XvH445Uexu5jOqp4njcqkyV8jq4Ga3/zWL777ftheEHYuUNjNVhg6Ihw/Qi+WXn4FMLvf/pf2J6+lB984FROvPpW8sQhPkNclQTP4PtOH/UeXhMWZDlDicfh2MGBADxMwiHA8v0ev/TTNGZ5nM4+pkMalw+Fy4r4ebgqrD3zVXz3mw+17kx0IDhYYOgYT6ahP/6ODN73uZt5NlkKcTxdJQkD/QhxDHNCKaFJcUAdxWmYijWJt7QuvvY2lrt9/J/3hz3reU4lcaNeOz4LBr1Bxl0MPJImrD37WL5zy33oyACpJGSZkqTSagylCY4wuc2c5KjbDSkgDAY7ODjY7Wy0ieDxuFwYTuBRhbd95hbEzXaqofANciguAXLPYj/M1z5wNksbwwo2Ws+pWhVCT5hJIA77q5cwUn0G37vlp6Au3s4sliCKpb/R77H2rUcjwl2qunoqR7XKxzZqxG6HJ3NhjP311/4bqQww0zO10R3XKZB7PIK4ChkJz7OID3x2E3vHvrUFhf7UrIwUcq2TkyMVOOVNR1FPn8G7HJ/JmLsVc1OxbIGhjUQbrRkT9gKXX/stnskHyLXOXBThnQsf33A9Ixclr8AjHMgjwNCoSq36rI9luqDwGaZJJXSucoqmntPfdAKnnPVyfHUPKgI+RaRxaRHGkBQJg83OhAWGDhAPF13zTX42lCJpBm7m/+3NhkyieBnVJprcj5CgvPuz/86TozJgH3Nf08IDQdQhCJLAGW9+Jae9ZRVrz3kZVOuIA685xT5YOoNRX6zysY3Eh0kIhgQedgup1sBl9Thy8MyEIb90VH8KFyskRSqMkDCU7ePyDTexaX3YnpHYB93PineUFbzXMDQ9NEsFruY45axjkQzI4fbb7wV15OpJ0+nXZ9n3pQN+LJD4jAVZjaFa7Bfh8xm/n1Mhi0WH1pDmkGceJ0JShR060Np/Npk3vaMRIJJQiek13LZ2zuE1D+1cqg6vcNrZr8RJGJF6JvcXLDC0k4bf9T+55pvU3CKG0ix8gLPoCOOlFRCaaYSTP02UDMVlFZKkFXgsMPSS2dcENz5PV3iriZYRmMnQDhYY2mjYKTVgDwn7CHMPOG1P24EwkSq4XMApvvBtaE6VaLqs200Dph4h7PvSRpX4OezF4fN6OHnb1GrNC2HqdOepKmje+hJYz0szXRYY2shJnFvAJ1TTVuGsHY1ZnYKoCw2vVdARu0VpZs4CQxsNx+eKCPVcC0ODz/1/uwOqGkoOmVNcpXAM7f6ow6a/WGBoo8bpOJwNU1GHxkuIVOf+P74xJJzD4xSKNzR9j8xVYPqHBYY2alzbp2lK3dVJfBLbIfi2XE548aAJPslIkgWtDZJM/CJjxmGBoY1q8dkTSvPOC67Nv94+zsAs2czbSRhjgaEDqkkVcmIzZt++/3R1IfCoIHkrMPRAB1rTZywwtFE91ikkXllYTaknOcNCOIHbcDwvoT7B+YSs0roLklhgMNNkgaGNGuM8JzpMlmUkCok6MqStw707IHeFSwnrdm2myQJDG6V5+KleVXkR0irOJ6Rt/fX2pD50sqr4QjsGCwxmmqYUGERkh4jcJyJ3i8hgTDtIRG4WkZ/E56UxXUTkUyKyXUTuFZET2/kH9LYQBf7yveeT1+vURag7z4C2pzWig+ZdjwO6PPeh6W/TKTG8QVVPKAwNdSVwq6oeA9wa1wHOAY6Jj/XA1XOV2b4Tx104DFgqdbxmqBeGJYsVkbHI7/I4Vl+jL8XEGl3rU8I8A7nzOAmPTIWRxJOrcmx9qPWaNv155pfXbC4lzgM2xuWNwPmF9C9ocAewREQOm8Vx+lZGuM6vAZ+77E0cJDm1TKkmVTJ1LMgctTxMO5b6BC9KPdn/bcY0nuaZpnhNEJ/gNcGrQ0ZykhHlKF7kY5f85+ZrrK+Ema6pBgYF/k1E7hKROPwHh6rqE3H558ChcXkF8FjhtY/HtFFEZL2IDIrI4O7du2eQ9d7XOIkdyvIcjuRZlIysnuNcxlAaGjo5JY6v4EnQ/Xay8upIgVy02QzaSeiLn9YqLK4Oce36t3JwsU2TFRnMNE01MJyqqicSLhMuE5HTihs1DDU9rd8lVd2gqqtVdfXy5RPPiNDP8kKLwwWJ8plLLuAo2cOw1MPMQq4eBvR0ORrbNzifhB6YE/AIHqiRg+SMNO4+iKJ+hOvffw5LkjEfrFU3mGmaUmBQ1Z3xeRfwdeAk4MnGJUJ83hV33wkcXnj5ypg27ySN/14VKngWq+dLl53PcbxAmkOaVck1RUmp+RAUMnXs70xOFTLxDDtAKkie4LTKUq/ceMlvcLRCtT76g53NUHJmfpr0KyMii0TkgMYy8JvA/cBmYF3cbR2wKS5vBi6KdyfWAM8VLjnmmVYhypOQiZCiXP/+3+J49vAStxfv64iHDBd6X8r+C1+ZU5w6Kt4xkCtp6lmZvMAN7zuTVXEUcamA7/qgIKafTWUEp0OBr0u4/ZUC/6iq3xSR7wM3isjFwCPA2
+P+3wDOBbYDe4H3zHmu+4VKHOpfWjNXC7wkgY2XtvfQbtRy5yZDNb8cemImKhF5AXiw2/mYomXAU93OxBT0Sz6hf/LaL/mE8fP6MlWdUoVer4z5+OBUp87qNhEZ7Ie89ks+oX/y2i/5hNnn1aqljDElFhiMMSW9Ehg2dDsD09Avee2XfEL/5LVf8gmzzGtPVD4aY3pLr5QYjDE9pOuBQUTOFpEHYzftKyd/RVvzcr2I7BKR+wtpPdm9XEQOF5Fvi8iPROSHInJ5L+ZXRAZEZJuI3BPz+ZcxfZWI3Bnz8xURqcb0WlzfHrcf2Yl8FvKbiMgPROSmHs9ne4dCUNWuPYAEeBg4CqgC9wDHdzE/pwEnAvcX0v4ncGVcvhL4eFw+F/i/hJZDa4A7O5zXw4AT4/IBwEPA8b2W33i8xXG5AtwZj38j8M6Yfg1wSVy+FLgmLr8T+EqH/18/BPwjcFNc79V87qgn+38AAAIxSURBVACWjUmbs8++Y3/IBH/c64FvFdY/Anyky3k6ckxgeBA4LC4fRmhzAXAtcOF4+3Up35uAs3o5v8BC4D+AkwmNb9Kx3wPgW8Dr43Ia95MO5W8lYWyRNwI3xROp5/IZjzleYJizz77blxJT6qLdZbPqXt4JsRj7WsKvcc/lNxbP7yZ0tLuZUEp8VlWzcfLSzGfc/hxwcCfyCfwd8GFaHdUP7tF8QhuGQijqlZaPfUFVVaS3pnUSkcXAPwFXqOrzUhjSrVfyq6o5cIKILCH0zn1Fl7NUIiJvAXap6l0icka38zMFp6rqThE5BLhZRH5c3Djbz77bJYZ+6KLds93LRaRCCApfVNV/jsk9m19VfRb4NqFIvkREGj9Mxbw08xm3Hwg83YHsnQK8VUR2AF8mXE78fQ/mE2j/UAjdDgzfB46JNb9VQiXO5i7naaye7F4uoWhwHfCAqv5tr+ZXRJbHkgIisoBQD/IAIUBcMEE+G/m/ALhN44VxO6nqR1R1paoeSfge3qaq7+q1fEKHhkLoVGXJfipRziXUqD8M/Lcu5+VLwBNAnXAddjHhuvFW4CfALcBBcV8BPh3zfR+wusN5PZVwnXkvcHd8nNtr+QVeDfwg5vN+4M9i+lHANkL3/K8CtZg+ENe3x+1HdeF7cAatuxI9l8+Yp3vi44eN82YuP3tr+WiMKen2pYQxpgdZYDDGlFhgMMaUWGAwxpRYYDDGlFhgMMaUWGAwxpRYYDDGlPx/iALj2+UtxycAAAAASUVORK5CYII=\n"
177 | },
178 | "metadata": {
179 | "needs_background": "light"
180 | },
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "import torch\n",
186 | "from torch.backends import cudnn\n",
187 | "\n",
188 | "from backbone import EfficientDetBackbone\n",
189 | "import cv2\n",
190 | "import matplotlib.pyplot as plt\n",
191 | "import numpy as np\n",
192 | "\n",
193 | "from efficientdet.utils import BBoxTransform, ClipBoxes\n",
194 | "from utils.utils import preprocess, invert_affine, postprocess\n",
195 | "\n",
196 | "compound_coef = 0\n",
197 | "force_input_size = None # set None to use default size\n",
198 | "img_path = 'datasets/shape/val/999.jpg'\n",
199 | "\n",
200 | "threshold = 0.2\n",
201 | "iou_threshold = 0.2\n",
202 | "\n",
203 | "use_cuda = True\n",
204 | "use_float16 = False\n",
205 | "cudnn.fastest = True\n",
206 | "cudnn.benchmark = True\n",
207 | "\n",
208 | "obj_list = ['rectangle', 'circle']\n",
209 | "\n",
210 | "# tf bilinear interpolation is different from any other's, just make do\n",
211 | "input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]\n",
212 | "input_size = input_sizes[compound_coef] if force_input_size is None else force_input_size\n",
213 | "ori_imgs, framed_imgs, framed_metas = preprocess(img_path, max_size=input_size)\n",
214 | "\n",
215 | "if use_cuda:\n",
216 | " x = torch.stack([torch.from_numpy(fi).cuda() for fi in framed_imgs], 0)\n",
217 | "else:\n",
218 | " x = torch.stack([torch.from_numpy(fi) for fi in framed_imgs], 0)\n",
219 | "\n",
220 | "x = x.to(torch.float32 if not use_float16 else torch.float16).permute(0, 3, 1, 2)\n",
221 | "\n",
222 | "model = EfficientDetBackbone(compound_coef=compound_coef, num_classes=len(obj_list),\n",
223 | "\n",
224 | " # replace this part with your project's anchor config\n",
225 | " ratios=[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)],\n",
226 | " scales=[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])\n",
227 | "\n",
228 | "model.load_state_dict(torch.load('logs/shape/efficientdet-d0_49_1400.pth'))\n",
229 | "model.requires_grad_(False)\n",
230 | "model.eval()\n",
231 | "\n",
232 | "if use_cuda:\n",
233 | " model = model.cuda()\n",
234 | "if use_float16:\n",
235 | " model = model.half()\n",
236 | "\n",
237 | "with torch.no_grad():\n",
238 | " features, regression, classification, anchors = model(x)\n",
239 | "\n",
240 | " regressBoxes = BBoxTransform()\n",
241 | " clipBoxes = ClipBoxes()\n",
242 | "\n",
243 | " out = postprocess(x,\n",
244 | " anchors, regression, classification,\n",
245 | " regressBoxes, clipBoxes,\n",
246 | " threshold, iou_threshold)\n",
247 | "\n",
248 | "out = invert_affine(framed_metas, out)\n",
249 | "\n",
250 | "for i in range(len(ori_imgs)):\n",
251 | " if len(out[i]['rois']) == 0:\n",
252 | " continue\n",
253 | "\n",
254 | " for j in range(len(out[i]['rois'])):\n",
255 | " (x1, y1, x2, y2) = out[i]['rois'][j].astype(np.int)\n",
256 | " cv2.rectangle(ori_imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)\n",
257 | " obj = obj_list[out[i]['class_ids'][j]]\n",
258 | " score = float(out[i]['scores'][j])\n",
259 | "\n",
260 | " cv2.putText(ori_imgs[i], '{}, {:.3f}'.format(obj, score),\n",
261 | " (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,\n",
262 | " (255, 255, 0), 1)\n",
263 | "\n",
264 | " plt.imshow(ori_imgs[i])\n",
265 | "\n"
266 | ],
267 | "metadata": {
268 | "collapsed": false,
269 | "pycharm": {
270 | "name": "#%%\n"
271 | }
272 | }
273 | }
274 | ],
275 | "metadata": {
276 | "kernelspec": {
277 | "display_name": "Python 3",
278 | "language": "python",
279 | "name": "python3"
280 | },
281 | "language_info": {
282 | "codemirror_mode": {
283 | "name": "ipython",
284 | "version": 2
285 | },
286 | "file_extension": ".py",
287 | "mimetype": "text/x-python",
288 | "name": "python",
289 | "nbconvert_exporter": "python",
290 | "pygments_lexer": "ipython2",
291 | "version": "2.7.6"
292 | }
293 | },
294 | "nbformat": 4,
295 | "nbformat_minor": 0
296 | }
--------------------------------------------------------------------------------
/utils/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .batchnorm import patch_sync_batchnorm, convert_model
13 | from .replicate import DataParallelWithCallback, patch_replication_callback
14 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 | import contextlib
13 |
14 | import torch
15 | import torch.nn.functional as F
16 |
17 | from torch.nn.modules.batchnorm import _BatchNorm
18 |
19 | try:
20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21 | except ImportError:
22 | ReduceAddCoalesced = Broadcast = None
23 |
24 | try:
25 | from jactorch.parallel.comm import SyncMaster
26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
27 | except ImportError:
28 | from .comm import SyncMaster
29 | from .replicate import DataParallelWithCallback
30 |
31 | __all__ = [
32 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
33 | 'patch_sync_batchnorm', 'convert_model'
34 | ]
35 |
36 |
37 | def _sum_ft(tensor):
38 | """sum over the first and last dimention"""
39 | return tensor.sum(dim=0).sum(dim=-1)
40 |
41 |
42 | def _unsqueeze_ft(tensor):
43 | """add new dimensions at the front and the tail"""
44 | return tensor.unsqueeze(0).unsqueeze(-1)
45 |
46 |
47 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
48 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
49 |
50 |
51 | class _SynchronizedBatchNorm(_BatchNorm):
52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
53 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
54 |
55 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
56 |
57 | self._sync_master = SyncMaster(self._data_parallel_master)
58 |
59 | self._is_parallel = False
60 | self._parallel_id = None
61 | self._slave_pipe = None
62 |
63 | def forward(self, input):
64 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
65 | if not (self._is_parallel and self.training):
66 | return F.batch_norm(
67 | input, self.running_mean, self.running_var, self.weight, self.bias,
68 | self.training, self.momentum, self.eps)
69 |
70 | # Resize the input to (B, C, -1).
71 | input_shape = input.size()
72 | input = input.view(input.size(0), self.num_features, -1)
73 |
74 | # Compute the sum and square-sum.
75 | sum_size = input.size(0) * input.size(2)
76 | input_sum = _sum_ft(input)
77 | input_ssum = _sum_ft(input ** 2)
78 |
79 | # Reduce-and-broadcast the statistics.
80 | if self._parallel_id == 0:
81 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
82 | else:
83 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
84 |
85 | # Compute the output.
86 | if self.affine:
87 | # MJY:: Fuse the multiplication for speed.
88 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
89 | else:
90 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
91 |
92 | # Reshape it.
93 | return output.view(input_shape)
94 |
95 | def __data_parallel_replicate__(self, ctx, copy_id):
96 | self._is_parallel = True
97 | self._parallel_id = copy_id
98 |
99 | # parallel_id == 0 means master device.
100 | if self._parallel_id == 0:
101 | ctx.sync_master = self._sync_master
102 | else:
103 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
104 |
105 | def _data_parallel_master(self, intermediates):
106 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
107 |
108 | # Always using same "device order" makes the ReduceAdd operation faster.
109 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
110 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
111 |
112 | to_reduce = [i[1][:2] for i in intermediates]
113 | to_reduce = [j for i in to_reduce for j in i] # flatten
114 | target_gpus = [i[1].sum.get_device() for i in intermediates]
115 |
116 | sum_size = sum([i[1].sum_size for i in intermediates])
117 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
118 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
119 |
120 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
121 |
122 | outputs = []
123 | for i, rec in enumerate(intermediates):
124 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
125 |
126 | return outputs
127 |
128 | def _compute_mean_std(self, sum_, ssum, size):
129 | """Compute the mean and standard-deviation with sum and square-sum. This method
130 | also maintains the moving average on the master device."""
131 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
132 | mean = sum_ / size
133 | sumvar = ssum - sum_ * mean
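        # identity: sum((x - mean)^2) = sum(x^2) - sum(x) * mean, so sumvar is the total squared deviation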
134 | unbias_var = sumvar / (size - 1)
135 | bias_var = sumvar / size
136 |
137 | if hasattr(torch, 'no_grad'):
138 | with torch.no_grad():
139 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
140 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
141 | else:
142 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
143 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
144 |
145 | return mean, bias_var.clamp(self.eps) ** -0.5
146 |
147 |
148 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
149 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
150 | mini-batch.
151 |
152 | .. math::
153 |
154 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
155 |
156 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
157 | standard-deviation are reduced across all devices during training.
158 |
159 | For example, when one uses `nn.DataParallel` to wrap the network during
160 |     training, PyTorch's implementation normalizes the tensor on each device using
161 |     the statistics only on that device, which accelerates the computation and
162 | is also easy to implement, but the statistics might be inaccurate.
163 | Instead, in this synchronized version, the statistics will be computed
164 | over all training samples distributed on multiple devices.
165 |
166 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
167 | as the built-in PyTorch implementation.
168 |
169 | The mean and standard-deviation are calculated per-dimension over
170 | the mini-batches and gamma and beta are learnable parameter vectors
171 | of size C (where C is the input size).
172 |
173 | During training, this layer keeps a running estimate of its computed mean
174 | and variance. The running sum is kept with a default momentum of 0.1.
175 |
176 | During evaluation, this running mean/variance is used for normalization.
177 |
178 | Because the BatchNorm is done over the `C` dimension, computing statistics
179 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
180 |
181 | Args:
182 | num_features: num_features from an expected input of size
183 | `batch_size x num_features [x width]`
184 | eps: a value added to the denominator for numerical stability.
185 | Default: 1e-5
186 | momentum: the value used for the running_mean and running_var
187 | computation. Default: 0.1
188 | affine: a boolean value that when set to ``True``, gives the layer learnable
189 | affine parameters. Default: ``True``
190 |
191 | Shape::
192 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
193 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
194 |
195 | Examples:
196 | >>> # With Learnable Parameters
197 | >>> m = SynchronizedBatchNorm1d(100)
198 | >>> # Without Learnable Parameters
199 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
200 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
201 | >>> output = m(input)
202 | """
203 |
204 | def _check_input_dim(self, input):
205 | if input.dim() != 2 and input.dim() != 3:
206 | raise ValueError('expected 2D or 3D input (got {}D input)'
207 | .format(input.dim()))
208 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
209 |
210 |
211 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
212 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
213 | of 3d inputs
214 |
215 | .. math::
216 |
217 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
218 |
219 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
220 | standard-deviation are reduced across all devices during training.
221 |
222 | For example, when one uses `nn.DataParallel` to wrap the network during
223 |     training, PyTorch's implementation normalizes the tensor on each device using
224 |     the statistics only on that device, which accelerates the computation and
225 | is also easy to implement, but the statistics might be inaccurate.
226 | Instead, in this synchronized version, the statistics will be computed
227 | over all training samples distributed on multiple devices.
228 |
229 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
230 | as the built-in PyTorch implementation.
231 |
232 | The mean and standard-deviation are calculated per-dimension over
233 | the mini-batches and gamma and beta are learnable parameter vectors
234 | of size C (where C is the input size).
235 |
236 | During training, this layer keeps a running estimate of its computed mean
237 | and variance. The running sum is kept with a default momentum of 0.1.
238 |
239 | During evaluation, this running mean/variance is used for normalization.
240 |
241 | Because the BatchNorm is done over the `C` dimension, computing statistics
242 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
243 |
244 | Args:
245 | num_features: num_features from an expected input of
246 | size batch_size x num_features x height x width
247 | eps: a value added to the denominator for numerical stability.
248 | Default: 1e-5
249 | momentum: the value used for the running_mean and running_var
250 | computation. Default: 0.1
251 | affine: a boolean value that when set to ``True``, gives the layer learnable
252 | affine parameters. Default: ``True``
253 |
254 | Shape::
255 | - Input: :math:`(N, C, H, W)`
256 | - Output: :math:`(N, C, H, W)` (same shape as input)
257 |
258 | Examples:
259 | >>> # With Learnable Parameters
260 | >>> m = SynchronizedBatchNorm2d(100)
261 | >>> # Without Learnable Parameters
262 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
263 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
264 | >>> output = m(input)
265 | """
266 |
267 | def _check_input_dim(self, input):
268 | if input.dim() != 4:
269 | raise ValueError('expected 4D input (got {}D input)'
270 | .format(input.dim()))
271 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
272 |
273 |
274 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
275 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
276 | of 4d inputs
277 |
278 | .. math::
279 |
280 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
281 |
282 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
283 | standard-deviation are reduced across all devices during training.
284 |
285 | For example, when one uses `nn.DataParallel` to wrap the network during
286 |     training, PyTorch's implementation normalizes the tensor on each device using
287 |     the statistics only on that device, which accelerates the computation and
288 | is also easy to implement, but the statistics might be inaccurate.
289 | Instead, in this synchronized version, the statistics will be computed
290 | over all training samples distributed on multiple devices.
291 |
292 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
293 | as the built-in PyTorch implementation.
294 |
295 | The mean and standard-deviation are calculated per-dimension over
296 | the mini-batches and gamma and beta are learnable parameter vectors
297 | of size C (where C is the input size).
298 |
299 | During training, this layer keeps a running estimate of its computed mean
300 | and variance. The running sum is kept with a default momentum of 0.1.
301 |
302 | During evaluation, this running mean/variance is used for normalization.
303 |
304 | Because the BatchNorm is done over the `C` dimension, computing statistics
305 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
306 | or Spatio-temporal BatchNorm
307 |
308 | Args:
309 | num_features: num_features from an expected input of
310 | size batch_size x num_features x depth x height x width
311 | eps: a value added to the denominator for numerical stability.
312 | Default: 1e-5
313 | momentum: the value used for the running_mean and running_var
314 | computation. Default: 0.1
315 | affine: a boolean value that when set to ``True``, gives the layer learnable
316 | affine parameters. Default: ``True``
317 |
318 | Shape::
319 | - Input: :math:`(N, C, D, H, W)`
320 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
321 |
322 | Examples:
323 | >>> # With Learnable Parameters
324 | >>> m = SynchronizedBatchNorm3d(100)
325 | >>> # Without Learnable Parameters
326 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
327 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
328 | >>> output = m(input)
329 | """
330 |
331 | def _check_input_dim(self, input):
332 | if input.dim() != 5:
333 | raise ValueError('expected 5D input (got {}D input)'
334 | .format(input.dim()))
335 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
336 |
337 |
338 | @contextlib.contextmanager
339 | def patch_sync_batchnorm():
340 | import torch.nn as nn
341 |
342 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
343 |
344 | nn.BatchNorm1d = SynchronizedBatchNorm1d
345 | nn.BatchNorm2d = SynchronizedBatchNorm2d
346 | nn.BatchNorm3d = SynchronizedBatchNorm3d
347 |
348 | yield
349 |
350 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
351 |
352 |
353 | def convert_model(module):
354 | """Traverse the input module and its child recursively
355 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
356 | to SynchronizedBatchNorm*N*d
357 |
358 | Args:
359 | module: the input module needs to be convert to SyncBN model
360 |
361 | Examples:
362 | >>> import torch.nn as nn
363 | >>> import torchvision
364 | >>> # m is a standard pytorch model
365 | >>> m = torchvision.models.resnet18(True)
366 | >>> m = nn.DataParallel(m)
367 | >>> # after convert, m is using SyncBN
368 | >>> m = convert_model(m)
369 | """
370 | if isinstance(module, torch.nn.DataParallel):
371 | mod = module.module
372 | mod = convert_model(mod)
373 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
374 | return mod
375 |
376 | mod = module
377 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
378 | torch.nn.modules.batchnorm.BatchNorm2d,
379 | torch.nn.modules.batchnorm.BatchNorm3d],
380 | [SynchronizedBatchNorm1d,
381 | SynchronizedBatchNorm2d,
382 | SynchronizedBatchNorm3d]):
383 | if isinstance(module, pth_module):
384 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
385 | mod.running_mean = module.running_mean
386 | mod.running_var = module.running_var
387 | if module.affine:
388 | mod.weight.data = module.weight.data.clone().detach()
389 | mod.bias.data = module.bias.data.clone().detach()
390 |
391 | for name, child in module.named_children():
392 | mod.add_module(name, convert_model(child))
393 |
394 | return mod
395 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
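        # send this copy's message to the master, block until the reduced result arrives, then acknowledge completion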
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel triggers a callback for each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually is the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 | (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/utils/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from glob import glob
9 | from torch import nn
10 | from torchvision.ops import nms
11 | from typing import Union
12 | import uuid
13 |
14 | from utils.sync_batchnorm import SynchronizedBatchNorm2d
15 |
16 | from torch.nn.init import _calculate_fan_in_and_fan_out, _no_grad_normal_
17 | import math
18 |
19 |
20 | def invert_affine(metas: Union[float, list, tuple], preds):
21 | for i in range(len(preds)):
22 | if len(preds[i]['rois']) == 0:
23 | continue
24 | else:
25 |             if isinstance(metas, float):
26 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / metas
27 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / metas
28 | else:
29 | new_w, new_h, old_w, old_h, padding_w, padding_h = metas[i]
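                # undo the aspect-aware resize: map the box coordinates from the resized/padded frame back to the original image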
30 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / (new_w / old_w)
31 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / (new_h / old_h)
32 | return preds
33 |
34 |
35 | def aspectaware_resize_padding(image, width, height, interpolation=None, means=None):
36 | old_h, old_w, c = image.shape
37 | if old_w > old_h:
38 | new_w = width
39 | new_h = int(width / old_w * old_h)
40 | else:
41 | new_w = int(height / old_h * old_w)
42 | new_h = height
43 |
44 |     canvas = np.zeros((height, width, c), np.float32)
45 | if means is not None:
46 | canvas[...] = means
47 |
48 | if new_w != old_w or new_h != old_h:
49 | if interpolation is None:
50 | image = cv2.resize(image, (new_w, new_h))
51 | else:
52 | image = cv2.resize(image, (new_w, new_h), interpolation=interpolation)
53 |
54 | padding_h = height - new_h
55 | padding_w = width - new_w
56 |
57 | if c > 1:
58 | canvas[:new_h, :new_w] = image
59 | else:
60 | if len(image.shape) == 2:
61 | canvas[:new_h, :new_w, 0] = image
62 | else:
63 | canvas[:new_h, :new_w] = image
64 |
65 | return canvas, new_w, new_h, old_w, old_h, padding_w, padding_h,
66 |
67 |
68 | def preprocess(*image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
69 | ori_imgs = [cv2.imread(img_path) for img_path in image_path]
70 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
71 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
72 | means=None) for img in normalized_imgs]
73 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
74 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
75 |
76 | return ori_imgs, framed_imgs, framed_metas
77 |
78 |
79 | def preprocess_video(*frame_from_video, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
80 | ori_imgs = frame_from_video
81 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
82 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
83 | means=None) for img in normalized_imgs]
84 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
85 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
86 |
87 | return ori_imgs, framed_imgs, framed_metas
88 |
89 |
90 | def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
91 | transformed_anchors = regressBoxes(anchors, regression)
92 | transformed_anchors = clipBoxes(transformed_anchors, x)
93 | scores = torch.max(classification, dim=2, keepdim=True)[0]
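    # per-anchor maximum class confidence; anchors whose best score is below the threshold are dropped before NMS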
94 | scores_over_thresh = (scores > threshold)[:, :, 0]
95 | out = []
96 | for i in range(x.shape[0]):
97 |         if scores_over_thresh[i].sum() == 0:
98 | out.append({
99 | 'rois': np.array(()),
100 | 'class_ids': np.array(()),
101 | 'scores': np.array(()),
102 | })
103 | continue
104 |
105 | classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0)
106 | transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...]
107 | scores_per = scores[i, scores_over_thresh[i, :], ...]
108 | anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)
109 |
110 | if anchors_nms_idx.shape[0] != 0:
111 | scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
112 | boxes_ = transformed_anchors_per[anchors_nms_idx, :]
113 |
114 | out.append({
115 | 'rois': boxes_.cpu().numpy(),
116 | 'class_ids': classes_.cpu().numpy(),
117 | 'scores': scores_.cpu().numpy(),
118 | })
119 | else:
120 | out.append({
121 | 'rois': np.array(()),
122 | 'class_ids': np.array(()),
123 | 'scores': np.array(()),
124 | })
125 |
126 | return out
127 |
128 |
129 | def display(preds, imgs, obj_list, imshow=True, imwrite=False):
130 | for i in range(len(imgs)):
131 | if len(preds[i]['rois']) == 0:
132 | continue
133 |
134 | for j in range(len(preds[i]['rois'])):
135 |             (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)
136 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
137 | obj = obj_list[preds[i]['class_ids'][j]]
138 | score = float(preds[i]['scores'][j])
139 |
140 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
141 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
142 | (255, 255, 0), 1)
143 | if imshow:
144 | cv2.imshow('img', imgs[i])
145 | cv2.waitKey(0)
146 |
147 | if imwrite:
148 | os.makedirs('test/', exist_ok=True)
149 | cv2.imwrite(f'test/{uuid.uuid4().hex}.jpg', imgs[i])
150 |
151 |
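# replace_w_sync_bn: recursively walks a module tree and swaps every
# torch.nn.BatchNorm2d for a SynchronizedBatchNorm2d built with the same
# hyper-parameters, copying the running statistics and (if affine) the learned
# weight/bias so that training can continue with synchronized batch norm.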
152 | def replace_w_sync_bn(m):
153 | for var_name in dir(m):
154 | target_attr = getattr(m, var_name)
155 | if type(target_attr) == torch.nn.BatchNorm2d:
156 | num_features = target_attr.num_features
157 | eps = target_attr.eps
158 | momentum = target_attr.momentum
159 | affine = target_attr.affine
160 |
161 | # get parameters
162 | running_mean = target_attr.running_mean
163 | running_var = target_attr.running_var
164 | if affine:
165 | weight = target_attr.weight
166 | bias = target_attr.bias
167 |
168 | setattr(m, var_name,
169 | SynchronizedBatchNorm2d(num_features, eps, momentum, affine))
170 |
171 | target_attr = getattr(m, var_name)
172 | # set parameters
173 | target_attr.running_mean = running_mean
174 | target_attr.running_var = running_var
175 | if affine:
176 | target_attr.weight = weight
177 | target_attr.bias = bias
178 |
179 | for var_name, children in m.named_children():
180 | replace_w_sync_bn(children)
181 |
182 |
183 | class CustomDataParallel(nn.DataParallel):
184 |     """
185 |     Force-splits the input batch across all GPUs instead of sending everything to cuda:0 and redistributing it afterwards.
186 |     """
187 |
188 | def __init__(self, module, num_gpus):
189 | super().__init__(module)
190 | self.num_gpus = num_gpus
191 |
192 | def scatter(self, inputs, kwargs, device_ids):
193 |         # Scatter and data prep rolled into one: each slice of the batch is moved straight to its
194 |         # target GPU, so no further scatter or inter-GPU shuffling is needed.
195 | devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
196 | splits = inputs[0].shape[0] // self.num_gpus
197 |
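        # Note: integer division splits the batch into equal chunks; any remainder
        # (when the batch size is not divisible by num_gpus) is silently dropped, so
        # the batch size should be a multiple of num_gpus.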
198 | return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True),
199 | inputs[1][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True))
200 | for device_idx in range(len(devices))], \
201 | [kwargs] * len(devices)
202 |
203 |
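# get_last_weights: picks the latest checkpoint in a directory by sorting the *.pth
# filenames on the integer after the last underscore (e.g. a name ending in
# '_4500.pth' sorts on 4500), so it relies on checkpoints following that naming scheme.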
204 | def get_last_weights(weights_path):
205 |     weights_path = glob(weights_path + '/*.pth')
206 | weights_path = sorted(weights_path,
207 | key=lambda x: int(x.rsplit('_')[-1].rsplit('.')[0]),
208 | reverse=True)[0]
209 | print(f'using weights {weights_path}')
210 | return weights_path
211 |
212 |
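# init_weights: convs whose names contain "conv_list" or "header" (the regressor /
# classifier heads) get the Keras-style variance-scaling init below, all other convs
# get Kaiming-uniform, and the classifier head bias is set to -log((1 - 0.01) / 0.01)
# so every anchor starts with a foreground probability of about 0.01 (the focal-loss
# prior used by RetinaNet).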
213 | def init_weights(model):
214 | for name, module in model.named_modules():
215 | is_conv_layer = isinstance(module, nn.Conv2d)
216 |
217 | if is_conv_layer:
218 |             if "conv_list" in name or "header" in name:
219 | variance_scaling_(module.weight.data)
220 | else:
221 | nn.init.kaiming_uniform_(module.weight.data)
222 |
223 | if module.bias is not None:
224 | if "classifier.header" in name:
225 | bias_value = -np.log((1 - 0.01) / 0.01)
226 | torch.nn.init.constant_(module.bias, bias_value)
227 | else:
228 | module.bias.data.zero_()
229 |
230 |
231 | def variance_scaling_(tensor, gain=1.):
232 | # type: (Tensor, float) -> Tensor
233 |     r"""
234 |     Variance-scaling initializer (fan-in mode) for the SeparableConv layers in the Regressor/Classifier heads.
235 |     Reference: VarianceScaling at https://keras.io/zh/initializers/
236 |     """
237 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
238 | std = math.sqrt(gain / float(fan_in))
239 |
240 | return _no_grad_normal_(tensor, 0., std)
241 |
--------------------------------------------------------------------------------