├── .gitignore
├── LICENSE
├── README.md
├── demo
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   ├── 5.jpg
│   ├── 6.jpg
│   ├── 7.jpg
│   ├── 8.jpg
│   ├── 9.jpg
│   └── video.gif
├── mAP_evaluation.py
├── requirements.txt
├── src
│   ├── config.py
│   ├── dataset.py
│   ├── loss.py
│   ├── model.py
│   └── utils.py
├── test_dataset.py
├── test_video.py
├── train.py
└── trained_models
    ├── signatrix_efficientdet_coco.onnx
    └── signatrix_efficientdet_coco.pth
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | .DS_Store
107 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Signatrix
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EfficientDet: Scalable and Efficient Object Detection
2 |
3 | ## Introduction
4 |
5 | Here is our PyTorch implementation of the model described in the paper [**EfficientDet: Scalable and Efficient Object Detection**](https://arxiv.org/abs/1911.09070). (*Note*: we also provide pre-trained weights, which you can find under ./trained_models)
6 |
7 | 
8 | An example of our model's output.
9 |
10 |
11 |
12 | ## Datasets
13 |
14 |
15 | | Dataset | Classes | #Train images | #Validation images |
16 | |------------------------|:---------:|:-----------------------:|:----------------------------:|
17 | | COCO2017 | 80 | 118k | 5k |
18 |
19 | Create a data folder under the repository,
20 |
21 | ```
22 | cd {repo_root}
23 | mkdir data
24 | ```
25 |
26 | - **COCO**:
27 | Download the COCO images and annotations from the [COCO website](http://cocodataset.org/#download). Make sure to put the files in the following structure:
28 | ```
29 | COCO
30 | ├── annotations
31 | │   ├── instances_train2017.json
32 | │   └── instances_val2017.json
33 | └── images
34 |     ├── train2017
35 |     └── val2017
36 | ```
37 |
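One way to populate this layout, assuming the standard download links from the COCO website (adjust paths and tooling to your setup):

```
cd data
mkdir -p COCO/images && cd COCO
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip train2017.zip -d images && unzip val2017.zip -d images
unzip annotations_trainval2017.zip
```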
38 | ## How to use our code
39 |
40 | With our code, you can:
41 |
42 | * **Train your model** by running `python train.py`
43 | * **Evaluate mAP on the COCO dataset** by running `python mAP_evaluation.py`
44 | * **Test your model on the COCO dataset** by running `python test_dataset.py --pretrained_model path/to/trained_model`
45 | * **Test your model on a video** by running `python test_video.py --pretrained_model path/to/trained_model --input path/to/input/file --output path/to/output/file` (example invocations are shown below)
46 |
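For example, with the data folder prepared as above and the pre-trained weights shipped in ./trained_models, typical invocations look like this (the video input/output paths are placeholders):

```
python train.py --batch_size 8 --data_path data/COCO
python mAP_evaluation.py
python test_dataset.py --pretrained_model trained_models/signatrix_efficientdet_coco.pth --output predictions
python test_video.py --pretrained_model trained_models/signatrix_efficientdet_coco.pth --input test_videos/input.mp4 --output test_videos/output.mp4
```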
47 | ## Experiments
48 |
49 | We trained our model using 3 NVIDIA GTX 1080 Ti GPUs. Below is the mAP (mean average precision) on the COCO val2017 dataset:
50 |
51 | | Metric            | IoU       | Area   | maxDets | Value |
52 | |-------------------|:---------:|:------:|:-------:|:-----:|
53 | | Average Precision | 0.50:0.95 | all    | 100     | 0.314 |
54 | | Average Precision | 0.50      | all    | 100     | 0.461 |
55 | | Average Precision | 0.75      | all    | 100     | 0.343 |
56 | | Average Precision | 0.50:0.95 | small  | 100     | 0.093 |
57 | | Average Precision | 0.50:0.95 | medium | 100     | 0.358 |
58 | | Average Precision | 0.50:0.95 | large  | 100     | 0.517 |
59 | | Average Recall    | 0.50:0.95 | all    | 1       | 0.268 |
60 | | Average Recall    | 0.50:0.95 | all    | 10      | 0.382 |
61 | | Average Recall    | 0.50:0.95 | all    | 100     | 0.403 |
62 | | Average Recall    | 0.50:0.95 | small  | 100     | 0.117 |
63 | | Average Recall    | 0.50:0.95 | medium | 100     | 0.486 |
64 | | Average Recall    | 0.50:0.95 | large  | 100     | 0.625 |
65 |
66 | ## Results
67 |
68 | Some predictions are shown below:
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 | ## Requirements
78 |
79 | * **python 3.6**
80 | * **pytorch 1.2**
81 | * **opencv (cv2)**
82 | * **tensorboard**
83 | * **tensorboardX** (This library could be skipped if you do not use SummaryWriter)
84 | * **pycocotools**
85 | * **efficientnet_pytorch**
86 |
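A minimal environment setup could look like the following; the exact torch/torchvision builds should match your CUDA toolchain, and `opencv-python`/`tensorboard` are the pip package names assumed here for the cv2 and tensorboard requirements:

```
pip install torch torchvision opencv-python tensorboard
pip install -r requirements.txt
```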
87 | ## References
88 | - Mingxing Tan, Ruoming Pang, Quoc V. Le. "EfficientDet: Scalable and Efficient Object Detection." [EfficientDet](https://arxiv.org/abs/1911.09070).
89 | - Our implementation borrows some parts from [RetinaNet.Pytorch](https://github.com/yhenon/pytorch-retinanet).
90 |
91 |
92 | ## Citation
93 |
94 |     @article{EfficientDetSignatrix,
95 |         Author = {Signatrix GmbH},
96 |         Title = {A Pytorch Implementation of EfficientDet Object Detection},
97 |         Journal = {https://github.com/signatrix/efficientdet},
98 |         Year = {2020}
99 |     }
--------------------------------------------------------------------------------
/demo/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/1.jpg
--------------------------------------------------------------------------------
/demo/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/2.jpg
--------------------------------------------------------------------------------
/demo/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/3.jpg
--------------------------------------------------------------------------------
/demo/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/4.jpg
--------------------------------------------------------------------------------
/demo/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/5.jpg
--------------------------------------------------------------------------------
/demo/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/6.jpg
--------------------------------------------------------------------------------
/demo/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/7.jpg
--------------------------------------------------------------------------------
/demo/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/8.jpg
--------------------------------------------------------------------------------
/demo/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/9.jpg
--------------------------------------------------------------------------------
/demo/video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/video.gif
--------------------------------------------------------------------------------
/mAP_evaluation.py:
--------------------------------------------------------------------------------
1 | from src.dataset import CocoDataset, Resizer, Normalizer
2 | from torchvision import transforms
3 | from pycocotools.cocoeval import COCOeval
4 | import json
5 | import torch
6 |
7 |
8 | def evaluate_coco(dataset, model, threshold=0.05):
9 | model.eval()
10 | with torch.no_grad():
11 | results = []
12 | image_ids = []
13 |
14 | for index in range(len(dataset)):
15 | data = dataset[index]
16 | scale = data['scale']
17 | scores, labels, boxes = model(data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
18 | boxes /= scale
19 |
20 | if boxes.shape[0] > 0:
21 |
22 | boxes[:, 2] -= boxes[:, 0]
23 | boxes[:, 3] -= boxes[:, 1]
24 |
25 | for box_id in range(boxes.shape[0]):
26 | score = float(scores[box_id])
27 | label = int(labels[box_id])
28 | box = boxes[box_id, :]
29 |
30 | if score < threshold:
31 | break
32 |
33 | image_result = {
34 | 'image_id': dataset.image_ids[index],
35 | 'category_id': dataset.label_to_coco_label(label),
36 | 'score': float(score),
37 | 'bbox': box.tolist(),
38 | }
39 |
40 | results.append(image_result)
41 |
42 | # append image to list of processed images
43 | image_ids.append(dataset.image_ids[index])
44 |
45 | # print progress
46 | print('{}/{}'.format(index, len(dataset)), end='\r')
47 |
48 | if not len(results):
49 | return
50 |
51 | # write output
52 | json.dump(results, open('{}_bbox_results.json'.format(dataset.set_name), 'w'), indent=4)
53 |
54 | # load results in COCO evaluation tool
55 | coco_true = dataset.coco
56 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(dataset.set_name))
57 |
58 | # run COCO evaluation
59 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
60 | coco_eval.params.imgIds = image_ids
61 | coco_eval.evaluate()
62 | coco_eval.accumulate()
63 | coco_eval.summarize()
64 |
65 |
66 | if __name__ == '__main__':
67 | efficientdet = torch.load("trained_models/signatrix_efficientdet_coco.pth").module
68 | efficientdet.cuda()
69 | dataset_val = CocoDataset("data/COCO", set='val2017',
70 | transform=transforms.Compose([Normalizer(), Resizer()]))
71 | evaluate_coco(dataset_val, efficientdet)
72 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | efficientnet_pytorch
2 | tensorboardX
3 | pycocotools
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
2 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
4 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
10 | "teddy bear", "hair drier", "toothbrush"]
11 |
12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
14 | (82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
23 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
24 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
26 | (2, 20, 184), (122, 37, 185)]
27 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset, DataLoader
6 | from pycocotools.coco import COCO
7 | import cv2
8 |
9 |
10 | class CocoDataset(Dataset):
11 | def __init__(self, root_dir, set='train2017', transform=None):
12 |
13 | self.root_dir = root_dir
14 | self.set_name = set
15 | self.transform = transform
16 |
17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
18 | self.image_ids = self.coco.getImgIds()
19 |
20 | self.load_classes()
21 |
22 | def load_classes(self):
23 |
24 | # load class names (name -> label)
25 | categories = self.coco.loadCats(self.coco.getCatIds())
26 | categories.sort(key=lambda x: x['id'])
27 |
28 | self.classes = {}
29 | self.coco_labels = {}
30 | self.coco_labels_inverse = {}
31 | for c in categories:
32 | self.coco_labels[len(self.classes)] = c['id']
33 | self.coco_labels_inverse[c['id']] = len(self.classes)
34 | self.classes[c['name']] = len(self.classes)
35 |
36 | # also load the reverse (label -> name)
37 | self.labels = {}
38 | for key, value in self.classes.items():
39 | self.labels[value] = key
40 |
41 | def __len__(self):
42 | return len(self.image_ids)
43 |
44 | def __getitem__(self, idx):
45 |
46 | img = self.load_image(idx)
47 | annot = self.load_annotations(idx)
48 | sample = {'img': img, 'annot': annot}
49 | if self.transform:
50 | sample = self.transform(sample)
51 | return sample
52 |
53 | def load_image(self, image_index):
54 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
55 | path = os.path.join(self.root_dir, 'images', self.set_name, image_info['file_name'])
56 | img = cv2.imread(path)
57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
58 |
59 | # if len(img.shape) == 2:
60 | # img = skimage.color.gray2rgb(img)
61 |
62 | return img.astype(np.float32) / 255.
63 |
64 | def load_annotations(self, image_index):
65 | # get ground truth annotations
66 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
67 | annotations = np.zeros((0, 5))
68 |
69 | # some images appear to miss annotations
70 | if len(annotations_ids) == 0:
71 | return annotations
72 |
73 | # parse annotations
74 | coco_annotations = self.coco.loadAnns(annotations_ids)
75 | for idx, a in enumerate(coco_annotations):
76 |
77 | # some annotations have basically no width / height, skip them
78 | if a['bbox'][2] < 1 or a['bbox'][3] < 1:
79 | continue
80 |
81 | annotation = np.zeros((1, 5))
82 | annotation[0, :4] = a['bbox']
83 | annotation[0, 4] = self.coco_label_to_label(a['category_id'])
84 | annotations = np.append(annotations, annotation, axis=0)
85 |
86 | # transform from [x, y, w, h] to [x1, y1, x2, y2]
87 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
88 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3]
89 |
90 | return annotations
91 |
92 | def coco_label_to_label(self, coco_label):
93 | return self.coco_labels_inverse[coco_label]
94 |
95 | def label_to_coco_label(self, label):
96 | return self.coco_labels[label]
97 |
98 | def num_classes(self):
99 | return 80
100 |
101 |
102 | def collater(data):
103 | imgs = [s['img'] for s in data]
104 | annots = [s['annot'] for s in data]
105 | scales = [s['scale'] for s in data]
106 |
107 | imgs = torch.from_numpy(np.stack(imgs, axis=0))
108 |
109 | max_num_annots = max(annot.shape[0] for annot in annots)
110 |
111 | if max_num_annots > 0:
112 |
113 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1
114 |
115 | if max_num_annots > 0:
116 | for idx, annot in enumerate(annots):
117 | if annot.shape[0] > 0:
118 | annot_padded[idx, :annot.shape[0], :] = annot
119 | else:
120 | annot_padded = torch.ones((len(annots), 1, 5)) * -1
121 |
122 | imgs = imgs.permute(0, 3, 1, 2)
123 |
124 | return {'img': imgs, 'annot': annot_padded, 'scale': scales}
125 |
126 |
127 | class Resizer(object):
128 |     """Resize the image so its longer side equals common_size, pad to a square, and convert ndarrays to Tensors."""
129 |
130 | def __call__(self, sample, common_size=512):
131 | image, annots = sample['img'], sample['annot']
132 | height, width, _ = image.shape
133 | if height > width:
134 | scale = common_size / height
135 | resized_height = common_size
136 | resized_width = int(width * scale)
137 | else:
138 | scale = common_size / width
139 | resized_height = int(height * scale)
140 | resized_width = common_size
141 |
142 | image = cv2.resize(image, (resized_width, resized_height))
143 |
144 | new_image = np.zeros((common_size, common_size, 3))
145 | new_image[0:resized_height, 0:resized_width] = image
146 |
147 | annots[:, :4] *= scale
148 |
149 | return {'img': torch.from_numpy(new_image), 'annot': torch.from_numpy(annots), 'scale': scale}
150 |
151 |
152 | class Augmenter(object):
153 |     """Randomly flip the image and its annotations horizontally."""
154 |
155 | def __call__(self, sample, flip_x=0.5):
156 | if np.random.rand() < flip_x:
157 | image, annots = sample['img'], sample['annot']
158 | image = image[:, ::-1, :]
159 |
160 | rows, cols, channels = image.shape
161 |
162 | x1 = annots[:, 0].copy()
163 | x2 = annots[:, 2].copy()
164 |
165 | x_tmp = x1.copy()
166 |
167 | annots[:, 0] = cols - x2
168 | annots[:, 2] = cols - x_tmp
169 |
170 | sample = {'img': image, 'annot': annots}
171 |
172 | return sample
173 |
174 |
175 | class Normalizer(object):
176 |
177 | def __init__(self):
178 | self.mean = np.array([[[0.485, 0.456, 0.406]]])
179 | self.std = np.array([[[0.229, 0.224, 0.225]]])
180 |
181 | def __call__(self, sample):
182 | image, annots = sample['img'], sample['annot']
183 |
184 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
185 |
--------------------------------------------------------------------------------
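The dataset and transform classes above are combined exactly as in train.py; a minimal usage sketch (assuming the data/COCO layout described in the README):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from src.dataset import CocoDataset, Normalizer, Augmenter, Resizer, collater

# same transform pipeline as train.py: normalize, random horizontal flip, resize/pad to 512
training_set = CocoDataset(root_dir="data/COCO", set="train2017",
                           transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
loader = DataLoader(training_set, batch_size=2, shuffle=True, collate_fn=collater, num_workers=2)

batch = next(iter(loader))
print(batch['img'].shape)    # torch.Size([2, 3, 512, 512]) -- channel-first, padded to 512x512
print(batch['annot'].shape)  # (2, max_boxes_in_batch, 5)   -- rows are [x1, y1, x2, y2, label], padded with -1
```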
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def calc_iou(a, b):
6 |
7 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
8 | iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
9 | ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])
10 | iw = torch.clamp(iw, min=0)
11 | ih = torch.clamp(ih, min=0)
12 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
13 | ua = torch.clamp(ua, min=1e-8)
14 | intersection = iw * ih
15 | IoU = intersection / ua
16 |
17 | return IoU
18 |
19 |
20 | class FocalLoss(nn.Module):
21 | def __init__(self):
22 | super(FocalLoss, self).__init__()
23 |
24 | def forward(self, classifications, regressions, anchors, annotations):
25 | alpha = 0.25
26 | gamma = 2.0
27 | batch_size = classifications.shape[0]
28 | classification_losses = []
29 | regression_losses = []
30 |
31 | anchor = anchors[0, :, :]
32 |
33 | anchor_widths = anchor[:, 2] - anchor[:, 0]
34 | anchor_heights = anchor[:, 3] - anchor[:, 1]
35 | anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
36 | anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights
37 |
38 | for j in range(batch_size):
39 |
40 | classification = classifications[j, :, :]
41 | regression = regressions[j, :, :]
42 |
43 | bbox_annotation = annotations[j, :, :]
44 | bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]
45 |
46 | if bbox_annotation.shape[0] == 0:
47 | if torch.cuda.is_available():
48 | regression_losses.append(torch.tensor(0).float().cuda())
49 | classification_losses.append(torch.tensor(0).float().cuda())
50 | else:
51 | regression_losses.append(torch.tensor(0).float())
52 | classification_losses.append(torch.tensor(0).float())
53 |
54 | continue
55 |
56 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
57 |
58 | IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4])
59 |
60 | IoU_max, IoU_argmax = torch.max(IoU, dim=1)
61 |
62 | # compute the loss for classification
63 | targets = torch.ones(classification.shape) * -1
64 | if torch.cuda.is_available():
65 | targets = targets.cuda()
66 |
67 | targets[torch.lt(IoU_max, 0.4), :] = 0
68 |
69 | positive_indices = torch.ge(IoU_max, 0.5)
70 |
71 | num_positive_anchors = positive_indices.sum()
72 |
73 | assigned_annotations = bbox_annotation[IoU_argmax, :]
74 |
75 | targets[positive_indices, :] = 0
76 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
77 |
78 | alpha_factor = torch.ones(targets.shape) * alpha
79 | if torch.cuda.is_available():
80 | alpha_factor = alpha_factor.cuda()
81 |
82 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
83 | focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
84 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
85 |
86 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))
87 |
88 | cls_loss = focal_weight * bce
89 |
90 | zeros = torch.zeros(cls_loss.shape)
91 | if torch.cuda.is_available():
92 | zeros = zeros.cuda()
93 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)
94 |
95 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))
96 |
97 |
98 | if positive_indices.sum() > 0:
99 | assigned_annotations = assigned_annotations[positive_indices, :]
100 |
101 | anchor_widths_pi = anchor_widths[positive_indices]
102 | anchor_heights_pi = anchor_heights[positive_indices]
103 | anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
104 | anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
105 |
106 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
107 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
108 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
109 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights
110 |
111 | gt_widths = torch.clamp(gt_widths, min=1)
112 | gt_heights = torch.clamp(gt_heights, min=1)
113 |
114 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
115 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
116 | targets_dw = torch.log(gt_widths / anchor_widths_pi)
117 | targets_dh = torch.log(gt_heights / anchor_heights_pi)
118 |
119 | targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
120 | targets = targets.t()
121 |
122 | norm = torch.Tensor([[0.1, 0.1, 0.2, 0.2]])
123 | if torch.cuda.is_available():
124 | norm = norm.cuda()
125 | targets = targets / norm
126 |
127 | regression_diff = torch.abs(targets - regression[positive_indices, :])
128 |
129 | regression_loss = torch.where(
130 | torch.le(regression_diff, 1.0 / 9.0),
131 | 0.5 * 9.0 * torch.pow(regression_diff, 2),
132 | regression_diff - 0.5 / 9.0
133 | )
134 | regression_losses.append(regression_loss.mean())
135 | else:
136 | if torch.cuda.is_available():
137 | regression_losses.append(torch.tensor(0).float().cuda())
138 | else:
139 | regression_losses.append(torch.tensor(0).float())
140 |
141 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0,
142 | keepdim=True)
143 |
--------------------------------------------------------------------------------
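For reference, the classification term implemented above is the focal loss of Lin et al., with alpha = 0.25 and gamma = 2.0 hard-coded in `forward` and normalized by the number of positive anchors; the regression term is a smooth-L1 loss on the normalized (dx, dy, dw, dh) box offsets with the quadratic/linear transition at 1/9:

```latex
\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t),
\qquad
L_{\mathrm{reg}}(x) = \begin{cases}
    \tfrac{9}{2}\,x^{2} & \text{if } |x| \le \tfrac{1}{9} \\
    |x| - \tfrac{1}{18} & \text{otherwise}
\end{cases}
```

where p_t is the predicted probability of the assigned class for a positive anchor (and 1 - p for negatives), and x is the difference between the regression target and the prediction.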
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import math
4 | from efficientnet_pytorch import EfficientNet as EffNet
5 | from src.utils import BBoxTransform, ClipBoxes, Anchors
6 | from src.loss import FocalLoss
7 | from torchvision.ops.boxes import nms as nms_torch
8 |
9 |
10 | def nms(dets, thresh):
11 | return nms_torch(dets[:, :4], dets[:, 4], thresh)
12 |
13 |
14 | class ConvBlock(nn.Module):
15 | def __init__(self, num_channels):
16 | super(ConvBlock, self).__init__()
17 | self.conv = nn.Sequential(
18 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels),
19 | nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0),
20 | nn.BatchNorm2d(num_features=num_channels, momentum=0.9997, eps=4e-5), nn.ReLU())
21 |
22 | def forward(self, input):
23 | return self.conv(input)
24 |
25 |
26 | class BiFPN(nn.Module):
27 | def __init__(self, num_channels, epsilon=1e-4):
28 | super(BiFPN, self).__init__()
29 | self.epsilon = epsilon
30 | # Conv layers
31 | self.conv6_up = ConvBlock(num_channels)
32 | self.conv5_up = ConvBlock(num_channels)
33 | self.conv4_up = ConvBlock(num_channels)
34 | self.conv3_up = ConvBlock(num_channels)
35 | self.conv4_down = ConvBlock(num_channels)
36 | self.conv5_down = ConvBlock(num_channels)
37 | self.conv6_down = ConvBlock(num_channels)
38 | self.conv7_down = ConvBlock(num_channels)
39 |
40 | # Feature scaling layers
41 | self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
42 | self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
43 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
44 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
45 |
46 | self.p4_downsample = nn.MaxPool2d(kernel_size=2)
47 | self.p5_downsample = nn.MaxPool2d(kernel_size=2)
48 | self.p6_downsample = nn.MaxPool2d(kernel_size=2)
49 | self.p7_downsample = nn.MaxPool2d(kernel_size=2)
50 |
51 | # Weight
52 | self.p6_w1 = nn.Parameter(torch.ones(2))
53 | self.p6_w1_relu = nn.ReLU()
54 | self.p5_w1 = nn.Parameter(torch.ones(2))
55 | self.p5_w1_relu = nn.ReLU()
56 | self.p4_w1 = nn.Parameter(torch.ones(2))
57 | self.p4_w1_relu = nn.ReLU()
58 | self.p3_w1 = nn.Parameter(torch.ones(2))
59 | self.p3_w1_relu = nn.ReLU()
60 |
61 | self.p4_w2 = nn.Parameter(torch.ones(3))
62 | self.p4_w2_relu = nn.ReLU()
63 | self.p5_w2 = nn.Parameter(torch.ones(3))
64 | self.p5_w2_relu = nn.ReLU()
65 | self.p6_w2 = nn.Parameter(torch.ones(3))
66 | self.p6_w2_relu = nn.ReLU()
67 | self.p7_w2 = nn.Parameter(torch.ones(2))
68 | self.p7_w2_relu = nn.ReLU()
69 |
70 | def forward(self, inputs):
71 | """
72 | P7_0 -------------------------- P7_2 -------->
73 |
74 | P6_0 ---------- P6_1 ---------- P6_2 -------->
75 |
76 | P5_0 ---------- P5_1 ---------- P5_2 -------->
77 |
78 | P4_0 ---------- P4_1 ---------- P4_2 -------->
79 |
80 | P3_0 -------------------------- P3_2 -------->
81 | """
82 |
83 | # P3_0, P4_0, P5_0, P6_0 and P7_0
84 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs
85 | # P7_0 to P7_2
86 | # Weights for P6_0 and P7_0 to P6_1
87 | p6_w1 = self.p6_w1_relu(self.p6_w1)
88 | weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
89 | # Connections for P6_0 and P7_0 to P6_1 respectively
90 | p6_up = self.conv6_up(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))
91 | # Weights for P5_0 and P6_0 to P5_1
92 | p5_w1 = self.p5_w1_relu(self.p5_w1)
93 | weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
94 | # Connections for P5_0 and P6_0 to P5_1 respectively
95 | p5_up = self.conv5_up(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))
96 | # Weights for P4_0 and P5_0 to P4_1
97 | p4_w1 = self.p4_w1_relu(self.p4_w1)
98 | weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
99 | # Connections for P4_0 and P5_0 to P4_1 respectively
100 | p4_up = self.conv4_up(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up))
101 |
102 | # Weights for P3_0 and P4_1 to P3_2
103 | p3_w1 = self.p3_w1_relu(self.p3_w1)
104 | weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
105 | # Connections for P3_0 and P4_1 to P3_2 respectively
106 | p3_out = self.conv3_up(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up))
107 |
108 | # Weights for P4_0, P4_1 and P3_2 to P4_2
109 | p4_w2 = self.p4_w2_relu(self.p4_w2)
110 | weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
111 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
112 | p4_out = self.conv4_down(
113 | weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))
114 | # Weights for P5_0, P5_1 and P4_2 to P5_2
115 | p5_w2 = self.p5_w2_relu(self.p5_w2)
116 | weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
117 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
118 | p5_out = self.conv5_down(
119 | weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))
120 | # Weights for P6_0, P6_1 and P5_2 to P6_2
121 | p6_w2 = self.p6_w2_relu(self.p6_w2)
122 | weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
123 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
124 | p6_out = self.conv6_down(
125 | weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out))
126 | # Weights for P7_0 and P6_2 to P7_2
127 | p7_w2 = self.p7_w2_relu(self.p7_w2)
128 | weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
129 | # Connections for P7_0 and P6_2 to P7_2
130 | p7_out = self.conv7_down(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))
131 |
132 | return p3_out, p4_out, p5_out, p6_out, p7_out
133 |
134 |
135 | class Regressor(nn.Module):
136 | def __init__(self, in_channels, num_anchors, num_layers):
137 | super(Regressor, self).__init__()
138 | layers = []
139 | for _ in range(num_layers):
140 | layers.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
141 | layers.append(nn.ReLU(True))
142 | self.layers = nn.Sequential(*layers)
143 | self.header = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
144 |
145 | def forward(self, inputs):
146 | inputs = self.layers(inputs)
147 | inputs = self.header(inputs)
148 | output = inputs.permute(0, 2, 3, 1)
149 | return output.contiguous().view(output.shape[0], -1, 4)
150 |
151 |
152 | class Classifier(nn.Module):
153 | def __init__(self, in_channels, num_anchors, num_classes, num_layers):
154 | super(Classifier, self).__init__()
155 | self.num_anchors = num_anchors
156 | self.num_classes = num_classes
157 | layers = []
158 | for _ in range(num_layers):
159 | layers.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
160 | layers.append(nn.ReLU(True))
161 | self.layers = nn.Sequential(*layers)
162 | self.header = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
163 | self.act = nn.Sigmoid()
164 |
165 | def forward(self, inputs):
166 | inputs = self.layers(inputs)
167 | inputs = self.header(inputs)
168 | inputs = self.act(inputs)
169 | inputs = inputs.permute(0, 2, 3, 1)
170 | output = inputs.contiguous().view(inputs.shape[0], inputs.shape[1], inputs.shape[2], self.num_anchors,
171 | self.num_classes)
172 | return output.contiguous().view(output.shape[0], -1, self.num_classes)
173 |
174 |
175 | class EfficientNet(nn.Module):
176 | def __init__(self, ):
177 | super(EfficientNet, self).__init__()
178 | model = EffNet.from_pretrained('efficientnet-b0')
179 | del model._conv_head
180 | del model._bn1
181 | del model._avg_pooling
182 | del model._dropout
183 | del model._fc
184 | self.model = model
185 |
186 | def forward(self, x):
187 | x = self.model._swish(self.model._bn0(self.model._conv_stem(x)))
188 | feature_maps = []
189 | for idx, block in enumerate(self.model._blocks):
190 | drop_connect_rate = self.model._global_params.drop_connect_rate
191 | if drop_connect_rate:
192 | drop_connect_rate *= float(idx) / len(self.model._blocks)
193 | x = block(x, drop_connect_rate=drop_connect_rate)
194 | if block._depthwise_conv.stride == [2, 2]:
195 | feature_maps.append(x)
196 |
197 | return feature_maps[1:]
198 |
199 |
200 | class EfficientDet(nn.Module):
201 | def __init__(self, num_anchors=9, num_classes=20, compound_coef=0):
202 | super(EfficientDet, self).__init__()
203 | self.compound_coef = compound_coef
204 |
205 | self.num_channels = [64, 88, 112, 160, 224, 288, 384, 384][self.compound_coef]
206 |
207 | self.conv3 = nn.Conv2d(40, self.num_channels, kernel_size=1, stride=1, padding=0)
208 | self.conv4 = nn.Conv2d(80, self.num_channels, kernel_size=1, stride=1, padding=0)
209 | self.conv5 = nn.Conv2d(192, self.num_channels, kernel_size=1, stride=1, padding=0)
210 | self.conv6 = nn.Conv2d(192, self.num_channels, kernel_size=3, stride=2, padding=1)
211 | self.conv7 = nn.Sequential(nn.ReLU(),
212 | nn.Conv2d(self.num_channels, self.num_channels, kernel_size=3, stride=2, padding=1))
213 |
214 | self.bifpn = nn.Sequential(*[BiFPN(self.num_channels) for _ in range(min(2 + self.compound_coef, 8))])
215 |
216 | self.num_classes = num_classes
217 | self.regressor = Regressor(in_channels=self.num_channels, num_anchors=num_anchors,
218 | num_layers=3 + self.compound_coef // 3)
219 | self.classifier = Classifier(in_channels=self.num_channels, num_anchors=num_anchors, num_classes=num_classes,
220 | num_layers=3 + self.compound_coef // 3)
221 |
222 | self.anchors = Anchors()
223 | self.regressBoxes = BBoxTransform()
224 | self.clipBoxes = ClipBoxes()
225 | self.focalLoss = FocalLoss()
226 |
227 | for m in self.modules():
228 | if isinstance(m, nn.Conv2d):
229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
230 | m.weight.data.normal_(0, math.sqrt(2. / n))
231 | elif isinstance(m, nn.BatchNorm2d):
232 | m.weight.data.fill_(1)
233 | m.bias.data.zero_()
234 |
235 | prior = 0.01
236 |
237 | self.classifier.header.weight.data.fill_(0)
238 | self.classifier.header.bias.data.fill_(-math.log((1.0 - prior) / prior))
239 |
240 | self.regressor.header.weight.data.fill_(0)
241 | self.regressor.header.bias.data.fill_(0)
242 |
243 | self.backbone_net = EfficientNet()
244 |
245 | def freeze_bn(self):
246 | for m in self.modules():
247 | if isinstance(m, nn.BatchNorm2d):
248 | m.eval()
249 |
250 | def forward(self, inputs):
251 | if len(inputs) == 2:
252 | is_training = True
253 | img_batch, annotations = inputs
254 | else:
255 | is_training = False
256 | img_batch = inputs
257 |
258 | c3, c4, c5 = self.backbone_net(img_batch)
259 | p3 = self.conv3(c3)
260 | p4 = self.conv4(c4)
261 | p5 = self.conv5(c5)
262 | p6 = self.conv6(c5)
263 | p7 = self.conv7(p6)
264 |
265 | features = [p3, p4, p5, p6, p7]
266 | features = self.bifpn(features)
267 |
268 | regression = torch.cat([self.regressor(feature) for feature in features], dim=1)
269 | classification = torch.cat([self.classifier(feature) for feature in features], dim=1)
270 | anchors = self.anchors(img_batch)
271 |
272 | if is_training:
273 | return self.focalLoss(classification, regression, anchors, annotations)
274 | else:
275 | transformed_anchors = self.regressBoxes(anchors, regression)
276 | transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)
277 |
278 | scores = torch.max(classification, dim=2, keepdim=True)[0]
279 |
280 | scores_over_thresh = (scores > 0.05)[0, :, 0]
281 |
282 | if scores_over_thresh.sum() == 0:
283 | return [torch.zeros(0), torch.zeros(0), torch.zeros(0, 4)]
284 |
285 | classification = classification[:, scores_over_thresh, :]
286 | transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
287 | scores = scores[:, scores_over_thresh, :]
288 |
289 | anchors_nms_idx = nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5)
290 |
291 | nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
292 |
293 | return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]]
294 |
295 |
296 | if __name__ == '__main__':
297 | from tensorboardX import SummaryWriter
298 | def count_parameters(model):
299 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
300 |
301 | model = EfficientDet(num_classes=80)
302 | print (count_parameters(model))
303 |
--------------------------------------------------------------------------------
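One behavior of `EfficientDet.forward` above that is easy to miss: it switches between training and inference based on `len(inputs)`, so a `[images, annotations]` pair returns the (classification, regression) losses while a plain image batch returns post-NMS detections; inference is therefore done one image at a time throughout this repository. A small inference sketch (purely illustrative, randomly initialized detection heads):

```python
import torch
from src.model import EfficientDet

device = "cuda" if torch.cuda.is_available() else "cpu"
model = EfficientDet(num_classes=80).to(device)  # downloads the pretrained efficientnet-b0 backbone
model.eval()

dummy = torch.rand(1, 3, 512, 512, device=device)  # one 512x512 image, as produced by Resizer
with torch.no_grad():
    scores, labels, boxes = model(dummy)

# detections are already filtered by the 0.05 score threshold and 0.5 NMS threshold
print(scores.shape, labels.shape, boxes.shape)
```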
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class BBoxTransform(nn.Module):
7 |
8 | def __init__(self, mean=None, std=None):
9 | super(BBoxTransform, self).__init__()
10 | if mean is None:
11 | self.mean = torch.from_numpy(np.array([0, 0, 0, 0]).astype(np.float32))
12 | else:
13 | self.mean = mean
14 | if std is None:
15 | self.std = torch.from_numpy(np.array([0.1, 0.1, 0.2, 0.2]).astype(np.float32))
16 | else:
17 | self.std = std
18 | if torch.cuda.is_available():
19 | self.mean = self.mean.cuda()
20 | self.std = self.std.cuda()
21 |
22 | def forward(self, boxes, deltas):
23 |
24 | widths = boxes[:, :, 2] - boxes[:, :, 0]
25 | heights = boxes[:, :, 3] - boxes[:, :, 1]
26 | ctr_x = boxes[:, :, 0] + 0.5 * widths
27 | ctr_y = boxes[:, :, 1] + 0.5 * heights
28 |
29 | dx = deltas[:, :, 0] * self.std[0] + self.mean[0]
30 | dy = deltas[:, :, 1] * self.std[1] + self.mean[1]
31 | dw = deltas[:, :, 2] * self.std[2] + self.mean[2]
32 | dh = deltas[:, :, 3] * self.std[3] + self.mean[3]
33 |
34 | pred_ctr_x = ctr_x + dx * widths
35 | pred_ctr_y = ctr_y + dy * heights
36 | pred_w = torch.exp(dw) * widths
37 | pred_h = torch.exp(dh) * heights
38 |
39 | pred_boxes_x1 = pred_ctr_x - 0.5 * pred_w
40 | pred_boxes_y1 = pred_ctr_y - 0.5 * pred_h
41 | pred_boxes_x2 = pred_ctr_x + 0.5 * pred_w
42 | pred_boxes_y2 = pred_ctr_y + 0.5 * pred_h
43 |
44 | pred_boxes = torch.stack([pred_boxes_x1, pred_boxes_y1, pred_boxes_x2, pred_boxes_y2], dim=2)
45 |
46 | return pred_boxes
47 |
48 |
49 | class ClipBoxes(nn.Module):
50 |
51 | def __init__(self):
52 | super(ClipBoxes, self).__init__()
53 |
54 | def forward(self, boxes, img):
55 | batch_size, num_channels, height, width = img.shape
56 |
57 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
58 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)
59 |
60 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width)
61 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height)
62 |
63 | return boxes
64 |
65 |
66 | class Anchors(nn.Module):
67 | def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None):
68 | super(Anchors, self).__init__()
69 |
70 |         # keep caller-supplied values when given; otherwise fall back to the defaults
71 |         self.pyramid_levels = (pyramid_levels if pyramid_levels is not None
72 |                                else [3, 4, 5, 6, 7])
73 |         self.strides = (strides if strides is not None
74 |                         else [2 ** x for x in self.pyramid_levels])
75 |         self.sizes = (sizes if sizes is not None
76 |                       else [2 ** (x + 2) for x in self.pyramid_levels])
77 |         self.ratios = ratios if ratios is not None else np.array([0.5, 1, 2])
78 |         self.scales = (scales if scales is not None
79 |                        else np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
80 |
81 | def forward(self, image):
82 |
83 | image_shape = image.shape[2:]
84 | image_shape = np.array(image_shape)
85 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
86 |
87 | all_anchors = np.zeros((0, 4)).astype(np.float32)
88 |
89 | for idx, p in enumerate(self.pyramid_levels):
90 | anchors = generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales)
91 | shifted_anchors = shift(image_shapes[idx], self.strides[idx], anchors)
92 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0)
93 |
94 | all_anchors = np.expand_dims(all_anchors, axis=0)
95 |
96 | anchors = torch.from_numpy(all_anchors.astype(np.float32))
97 | if torch.cuda.is_available():
98 | anchors = anchors.cuda()
99 | return anchors
100 |
101 |
102 | def generate_anchors(base_size=16, ratios=None, scales=None):
103 | if ratios is None:
104 | ratios = np.array([0.5, 1, 2])
105 |
106 | if scales is None:
107 | scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
108 |
109 | num_anchors = len(ratios) * len(scales)
110 | anchors = np.zeros((num_anchors, 4))
111 | anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T
112 | areas = anchors[:, 2] * anchors[:, 3]
113 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales)))
114 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales))
115 | anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
116 | anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T
117 |
118 | return anchors
119 |
120 |
121 | def compute_shape(image_shape, pyramid_levels):
122 | image_shape = np.array(image_shape[:2])
123 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
124 | return image_shapes
125 |
126 |
127 | def shift(shape, stride, anchors):
128 | shift_x = (np.arange(0, shape[1]) + 0.5) * stride
129 | shift_y = (np.arange(0, shape[0]) + 0.5) * stride
130 | shift_x, shift_y = np.meshgrid(shift_x, shift_y)
131 | shifts = np.vstack((
132 | shift_x.ravel(), shift_y.ravel(),
133 | shift_x.ravel(), shift_y.ravel()
134 | )).transpose()
135 |
136 | A = anchors.shape[0]
137 | K = shifts.shape[0]
138 | all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
139 | all_anchors = all_anchors.reshape((K * A, 4))
140 |
141 | return all_anchors
142 |
--------------------------------------------------------------------------------
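As a sanity check on the anchor generation above: the default ratios and scales give 3 x 3 = 9 anchors per location, and for a 512 x 512 input the five pyramid levels (strides 8, 16, 32, 64, 128) contribute 64^2 + 32^2 + 16^2 + 8^2 + 4^2 = 5456 locations, i.e. 49,104 anchors in total. A quick check (a small sketch, not part of the repository):

```python
import torch
from src.utils import Anchors, generate_anchors

per_location = generate_anchors(base_size=32)          # base anchors for P3 (size 2 ** (3 + 2) = 32)
print(per_location.shape)                              # (9, 4)

all_anchors = Anchors()(torch.zeros(1, 3, 512, 512))   # anchors shifted over every pyramid level
print(all_anchors.shape)                               # torch.Size([1, 49104, 4])
```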
/test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from torchvision import transforms
5 | from src.dataset import CocoDataset, Resizer, Normalizer
6 | from src.config import COCO_CLASSES, colors
7 | import cv2
8 | import shutil
9 |
10 |
11 | def get_args():
12 | parser = argparse.ArgumentParser(
13 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH")
14 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images")
15 | parser.add_argument("--data_path", type=str, default="data/COCO", help="the root folder of dataset")
16 | parser.add_argument("--cls_threshold", type=float, default=0.5)
17 | parser.add_argument("--nms_threshold", type=float, default=0.5)
18 | parser.add_argument("--pretrained_model", type=str, default="trained_models/signatrix_efficientdet_coco.pth")
19 | parser.add_argument("--output", type=str, default="predictions")
20 | args = parser.parse_args()
21 | return args
22 |
23 |
24 |
25 | def test(opt):
26 | model = torch.load(opt.pretrained_model)
27 | model.cuda()
28 | dataset = CocoDataset(opt.data_path, set='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))
29 |
30 | if os.path.isdir(opt.output):
31 | shutil.rmtree(opt.output)
32 | os.makedirs(opt.output)
33 |
34 | for index in range(len(dataset)):
35 | data = dataset[index]
36 | scale = data['scale']
37 | with torch.no_grad():
38 | scores, labels, boxes = model(data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
39 | boxes /= scale
40 |
41 | if boxes.shape[0] > 0:
42 | image_info = dataset.coco.loadImgs(dataset.image_ids[index])[0]
43 | path = os.path.join(dataset.root_dir, 'images', dataset.set_name, image_info['file_name'])
44 | output_image = cv2.imread(path)
45 |
46 | for box_id in range(boxes.shape[0]):
47 | pred_prob = float(scores[box_id])
48 | if pred_prob < opt.cls_threshold:
49 | break
50 | pred_label = int(labels[box_id])
51 | xmin, ymin, xmax, ymax = boxes[box_id, :]
52 | color = colors[pred_label]
53 | xmin = int(round(float(xmin), 0))
54 | ymin = int(round(float(ymin), 0))
55 | xmax = int(round(float(xmax), 0))
56 | ymax = int(round(float(ymax), 0))
57 | cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
58 | text_size = cv2.getTextSize(COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
59 |
60 | cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1)
61 | cv2.putText(
62 | output_image, COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob,
63 | (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
64 | (255, 255, 255), 1)
65 |
66 | cv2.imwrite("{}/{}_prediction.jpg".format(opt.output, image_info["file_name"][:-4]), output_image)
67 |
68 |
69 | if __name__ == "__main__":
70 | opt = get_args()
71 | test(opt)
72 |
--------------------------------------------------------------------------------
/test_video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from src.config import COCO_CLASSES, colors
4 | import cv2
5 | import numpy as np
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser(
10 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH")
11 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images")
12 | parser.add_argument("--cls_threshold", type=float, default=0.5)
13 | parser.add_argument("--nms_threshold", type=float, default=0.5)
14 | parser.add_argument("--pretrained_model", type=str, default="trained_models/signatrix_efficientdet_coco.pth")
15 | parser.add_argument("--input", type=str, default="test_videos/input.mp4")
16 | parser.add_argument("--output", type=str, default="test_videos/output.mp4")
17 |
18 | args = parser.parse_args()
19 | return args
20 |
21 |
22 | def test(opt):
23 | model = torch.load(opt.pretrained_model).module
24 | if torch.cuda.is_available():
25 | model.cuda()
26 |
27 | cap = cv2.VideoCapture(opt.input)
28 | out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), int(cap.get(cv2.CAP_PROP_FPS)),
29 | (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
30 | while cap.isOpened():
31 | flag, image = cap.read()
32 | output_image = np.copy(image)
33 | if flag:
34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
35 | else:
36 | break
37 | height, width = image.shape[:2]
38 | image = image.astype(np.float32) / 255
39 | image[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
40 | image[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
41 | image[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
42 | if height > width:
43 | scale = opt.image_size / height
44 | resized_height = opt.image_size
45 | resized_width = int(width * scale)
46 | else:
47 | scale = opt.image_size / width
48 | resized_height = int(height * scale)
49 | resized_width = opt.image_size
50 |
51 | image = cv2.resize(image, (resized_width, resized_height))
52 |
53 | new_image = np.zeros((opt.image_size, opt.image_size, 3))
54 | new_image[0:resized_height, 0:resized_width] = image
55 | new_image = np.transpose(new_image, (2, 0, 1))
56 | new_image = new_image[None, :, :, :]
57 | new_image = torch.Tensor(new_image)
58 | if torch.cuda.is_available():
59 | new_image = new_image.cuda()
60 | with torch.no_grad():
61 | scores, labels, boxes = model(new_image)
62 | boxes /= scale
63 | if boxes.shape[0] == 0:
64 | continue
65 |
66 | for box_id in range(boxes.shape[0]):
67 | pred_prob = float(scores[box_id])
68 | if pred_prob < opt.cls_threshold:
69 | break
70 | pred_label = int(labels[box_id])
71 |                 xmin, ymin, xmax, ymax = [int(round(float(coord))) for coord in boxes[box_id, :]]
72 | color = colors[pred_label]
73 | cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
74 | text_size = cv2.getTextSize(COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
75 | cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1)
76 | cv2.putText(
77 | output_image, COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob,
78 | (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
79 | (255, 255, 255), 1)
80 | out.write(output_image)
81 |
82 | cap.release()
83 | out.release()
84 |
85 |
86 | if __name__ == "__main__":
87 | opt = get_args()
88 | test(opt)
89 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torchvision import transforms
7 | from src.dataset import CocoDataset, Resizer, Normalizer, Augmenter, collater
8 | from src.model import EfficientDet
9 | from tensorboardX import SummaryWriter
10 | import shutil
11 | import numpy as np
12 | from tqdm.autonotebook import tqdm
13 |
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser(
17 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH")
18 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images")
19 | parser.add_argument("--batch_size", type=int, default=8, help="The number of images per batch")
20 | parser.add_argument("--lr", type=float, default=1e-4)
21 | parser.add_argument('--alpha', type=float, default=0.25)
22 | parser.add_argument('--gamma', type=float, default=1.5)
23 | parser.add_argument("--num_epochs", type=int, default=500)
24 |     parser.add_argument("--test_interval", type=int, default=1, help="Number of epochs between testing phases")
25 | parser.add_argument("--es_min_delta", type=float, default=0.0,
26 | help="Early stopping's parameter: minimum change loss to qualify as an improvement")
27 | parser.add_argument("--es_patience", type=int, default=0,
28 | help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.")
29 | parser.add_argument("--data_path", type=str, default="data/COCO", help="the root folder of dataset")
30 | parser.add_argument("--log_path", type=str, default="tensorboard/signatrix_efficientdet_coco")
31 | parser.add_argument("--saved_path", type=str, default="trained_models")
32 |
33 | args = parser.parse_args()
34 | return args
35 |
36 |
37 | def train(opt):
38 | num_gpus = 1
39 | if torch.cuda.is_available():
40 | num_gpus = torch.cuda.device_count()
41 | torch.cuda.manual_seed(123)
42 | else:
43 | torch.manual_seed(123)
44 |
45 | training_params = {"batch_size": opt.batch_size * num_gpus,
46 | "shuffle": True,
47 | "drop_last": True,
48 | "collate_fn": collater,
49 | "num_workers": 12}
50 |
51 | test_params = {"batch_size": opt.batch_size,
52 | "shuffle": False,
53 | "drop_last": False,
54 | "collate_fn": collater,
55 | "num_workers": 12}
56 |
57 | training_set = CocoDataset(root_dir=opt.data_path, set="train2017",
58 | transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
59 | training_generator = DataLoader(training_set, **training_params)
60 |
61 | test_set = CocoDataset(root_dir=opt.data_path, set="val2017",
62 | transform=transforms.Compose([Normalizer(), Resizer()]))
63 | test_generator = DataLoader(test_set, **test_params)
64 |
65 | model = EfficientDet(num_classes=training_set.num_classes())
66 |
67 |
68 | if os.path.isdir(opt.log_path):
69 | shutil.rmtree(opt.log_path)
70 | os.makedirs(opt.log_path)
71 |
72 | if not os.path.isdir(opt.saved_path):
73 | os.makedirs(opt.saved_path)
74 |
75 | writer = SummaryWriter(opt.log_path)
76 | if torch.cuda.is_available():
77 | model = model.cuda()
78 | model = nn.DataParallel(model)
79 |
80 | optimizer = torch.optim.Adam(model.parameters(), opt.lr)
81 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
82 |
83 | best_loss = 1e5
84 | best_epoch = 0
85 | model.train()
86 |
87 | num_iter_per_epoch = len(training_generator)
88 | for epoch in range(opt.num_epochs):
89 | model.train()
90 | # if torch.cuda.is_available():
91 | # model.module.freeze_bn()
92 | # else:
93 | # model.freeze_bn()
94 | epoch_loss = []
95 | progress_bar = tqdm(training_generator)
96 | for iter, data in enumerate(progress_bar):
97 | try:
98 | optimizer.zero_grad()
99 | if torch.cuda.is_available():
100 | cls_loss, reg_loss = model([data['img'].cuda().float(), data['annot'].cuda()])
101 | else:
102 | cls_loss, reg_loss = model([data['img'].float(), data['annot']])
103 |
104 | cls_loss = cls_loss.mean()
105 | reg_loss = reg_loss.mean()
106 | loss = cls_loss + reg_loss
107 | if loss == 0:
108 | continue
109 | loss.backward()
110 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
111 | optimizer.step()
112 | epoch_loss.append(float(loss))
113 | total_loss = np.mean(epoch_loss)
114 |
115 | progress_bar.set_description(
116 | 'Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Batch loss: {:.5f} Total loss: {:.5f}'.format(
117 | epoch + 1, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss, reg_loss, loss,
118 | total_loss))
119 | writer.add_scalar('Train/Total_loss', total_loss, epoch * num_iter_per_epoch + iter)
120 | writer.add_scalar('Train/Regression_loss', reg_loss, epoch * num_iter_per_epoch + iter)
121 |                 writer.add_scalar('Train/Classification_loss (focal loss)', cls_loss, epoch * num_iter_per_epoch + iter)
122 |
123 | except Exception as e:
124 | print(e)
125 | continue
126 | scheduler.step(np.mean(epoch_loss))
127 |
128 | if epoch % opt.test_interval == 0:
129 | model.eval()
130 | loss_regression_ls = []
131 | loss_classification_ls = []
132 | for iter, data in enumerate(test_generator):
133 | with torch.no_grad():
134 | if torch.cuda.is_available():
135 | cls_loss, reg_loss = model([data['img'].cuda().float(), data['annot'].cuda()])
136 | else:
137 | cls_loss, reg_loss = model([data['img'].float(), data['annot']])
138 |
139 | cls_loss = cls_loss.mean()
140 | reg_loss = reg_loss.mean()
141 |
142 | loss_classification_ls.append(float(cls_loss))
143 | loss_regression_ls.append(float(reg_loss))
144 |
145 | cls_loss = np.mean(loss_classification_ls)
146 | reg_loss = np.mean(loss_regression_ls)
147 | loss = cls_loss + reg_loss
148 |
149 | print(
150 | 'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format(
151 | epoch + 1, opt.num_epochs, cls_loss, reg_loss,
152 | np.mean(loss)))
153 | writer.add_scalar('Test/Total_loss', loss, epoch)
154 | writer.add_scalar('Test/Regression_loss', reg_loss, epoch)
155 |             writer.add_scalar('Test/Classification_loss (focal loss)', cls_loss, epoch)
156 |
157 | if loss + opt.es_min_delta < best_loss:
158 | best_loss = loss
159 | best_epoch = epoch
160 | torch.save(model, os.path.join(opt.saved_path, "signatrix_efficientdet_coco.pth"))
161 |
162 | dummy_input = torch.rand(opt.batch_size, 3, 512, 512)
163 | if torch.cuda.is_available():
164 | dummy_input = dummy_input.cuda()
165 | if isinstance(model, nn.DataParallel):
166 | model.module.backbone_net.model.set_swish(memory_efficient=False)
167 |
168 | torch.onnx.export(model.module, dummy_input,
169 | os.path.join(opt.saved_path, "signatrix_efficientdet_coco.onnx"),
170 | verbose=False)
171 | model.module.backbone_net.model.set_swish(memory_efficient=True)
172 | else:
173 | model.backbone_net.model.set_swish(memory_efficient=False)
174 |
175 | torch.onnx.export(model, dummy_input,
176 | os.path.join(opt.saved_path, "signatrix_efficientdet_coco.onnx"),
177 | verbose=False)
178 | model.backbone_net.model.set_swish(memory_efficient=True)
179 |
180 | # Early stopping
181 | if epoch - best_epoch > opt.es_patience > 0:
182 | print("Stop training at epoch {}. The lowest loss achieved is {}".format(epoch, loss))
183 | break
184 | writer.close()
185 |
186 |
187 | if __name__ == "__main__":
188 | opt = get_args()
189 | train(opt)
190 |
--------------------------------------------------------------------------------
/trained_models/signatrix_efficientdet_coco.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/trained_models/signatrix_efficientdet_coco.onnx
--------------------------------------------------------------------------------
/trained_models/signatrix_efficientdet_coco.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/trained_models/signatrix_efficientdet_coco.pth
--------------------------------------------------------------------------------