├── .gitignore ├── LICENSE ├── README.md ├── demo ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg ├── 9.jpg └── video.gif ├── mAP_evaluation.py ├── requirements.txt ├── src ├── config.py ├── dataset.py ├── loss.py ├── model.py └── utils.py ├── test_dataset.py ├── test_video.py ├── train.py └── trained_models ├── signatrix_efficientdet_coco.onnx └── signatrix_efficientdet_coco.pth /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .DS_Store 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Signatrix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientDet: Scalable and Efficient Object Detection 2 | 3 | ## Introduction 4 | 5 | Here is our pytorch implementation of the model described in the paper **EfficientDet: Scalable and Efficient Object Detection** [paper](https://arxiv.org/abs/1911.09070) (*Note*: We also provide pre-trained weights, which you could see at ./trained_models) 6 |
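A minimal sketch of loading that checkpoint for inference (not part of the original scripts; it assumes a CUDA device and that you run it from the repository root so the pickled model can resolve the `src` package, mirroring what `mAP_evaluation.py` and `test_video.py` do):

```python
import torch

# The provided .pth stores the whole nn.DataParallel-wrapped model object
# (saved with torch.save(model, ...) in train.py), so it is loaded as an
# object and unwrapped with .module before running inference.
model = torch.load("trained_models/signatrix_efficientdet_coco.pth").module
model = model.cuda().eval()
```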

7 | ![demo](demo/video.gif)
8 | *An example of our model's output.*
9 |
10 | 11 | 12 | ## Datasets 13 | 14 | 15 | | Dataset | Classes | #Train images | #Validation images | 16 | |------------------------|:---------:|:-----------------------:|:----------------------------:| 17 | | COCO2017 | 80 | 118k | 5k | 18 | 19 | Create a data folder under the repository, 20 | 21 | ``` 22 | cd {repo_root} 23 | mkdir data 24 | ``` 25 | 26 | - **COCO**: 27 | Download the coco images and annotations from [coco website](http://cocodataset.org/#download). Make sure to put the files as the following structure: 28 | ``` 29 | COCO 30 | ├── annotations 31 | │ ├── instances_train2017.json 32 | │ └── instances_val2017.json 33 | │── images 34 | ├── train2017 35 | └── val2017 36 | ``` 37 | 38 | ## How to use our code 39 | 40 | With our code, you can: 41 | 42 | * **Train your model** by running **python train.py** 43 | * **Evaluate mAP for COCO dataset** by running **python mAP_evaluation.py** 44 | * **Test your model for COCO dataset** by running **python test_dataset.py --pretrained_model path/to/trained_model** 45 | * **Test your model for video** by running **python test_video.py --pretrained_model path/to/trained_model --input path/to/input/file --output path/to/output/file** 46 | 47 | ## Experiments 48 | 49 | We trained our model by using 3 NVIDIA GTX 1080Ti. Below is mAP (mean average precision) for COCO val2017 dataset 50 | 51 | | Average Precision | IoU=0.50:0.95 | area= all | maxDets=100 | 0.314 | 52 | |-----------------------|:-------------------:|:-----------------:|:-----------------:|:-------------:| 53 | | Average Precision | IoU=0.50 | area= all | maxDets=100 | 0.461 | 54 | | Average Precision | IoU=0.75 | area= all | maxDets=100 | 0.343 | 55 | | Average Precision | IoU=0.50:0.95 | area= small | maxDets=100 | 0.093 | 56 | | Average Precision | IoU=0.50:0.95 | area= medium | maxDets=100 | 0.358 | 57 | | Average Precision | IoU=0.50:0.95 | area= large | maxDets=100 | 0.517 | 58 | | Average Recall | IoU=0.50:0.95 | area= all | maxDets=1 | 0.268 | 59 | | Average Recall | IoU=0.50:0.95 | area= all | maxDets=10 | 0.382 | 60 | | Average Recall | IoU=0.50:0.95 | area= all | maxDets=100 | 0.403 | 61 | | Average Recall | IoU=0.50:0.95 | area= small | maxDets=100 | 0.117 | 62 | | Average Recall | IoU=0.50:0.95 | area= medium | maxDets=100 | 0.486 | 63 | | Average Recall | IoU=0.50:0.95 | area= large | maxDets=100 | 0.625 | 64 | 65 | 66 | ## Results 67 | 68 | Some predictions are shown below: 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | ## Requirements 78 | 79 | * **python 3.6** 80 | * **pytorch 1.2** 81 | * **opencv (cv2)** 82 | * **tensorboard** 83 | * **tensorboardX** (This library could be skipped if you do not use SummaryWriter) 84 | * **pycocotools** 85 | * **efficientnet_pytorch** 86 | 87 | ## References 88 | - Mingxing Tan, Ruoming Pang, Quoc V. Le. "EfficientDet: Scalable and Efficient Object Detection." [EfficientDet](https://arxiv.org/abs/1911.09070). 
89 | - Our implementation borrows some parts from [RetinaNet.Pytorch](https://github.com/yhenon/pytorch-retinanet) 90 | 91 | 92 | ## Citation 93 | 94 | @article{EfficientDetSignatrix, 95 | Author = {Signatrix GmbH}, 96 | Title = {A Pytorch Implementation of EfficientDet Object Detection}, 97 | Journal = {https://github.com/signatrix/efficientdet}, 98 | Year = {2020} 99 | } -------------------------------------------------------------------------------- /demo/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/1.jpg -------------------------------------------------------------------------------- /demo/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/2.jpg -------------------------------------------------------------------------------- /demo/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/3.jpg -------------------------------------------------------------------------------- /demo/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/4.jpg -------------------------------------------------------------------------------- /demo/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/5.jpg -------------------------------------------------------------------------------- /demo/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/6.jpg -------------------------------------------------------------------------------- /demo/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/7.jpg -------------------------------------------------------------------------------- /demo/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/8.jpg -------------------------------------------------------------------------------- /demo/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/9.jpg -------------------------------------------------------------------------------- /demo/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/demo/video.gif -------------------------------------------------------------------------------- /mAP_evaluation.py: -------------------------------------------------------------------------------- 1 | from src.dataset import CocoDataset, Resizer, Normalizer 2 | from torchvision import transforms 3 | from pycocotools.cocoeval import COCOeval 4 
| import json 5 | import torch 6 | 7 | 8 | def evaluate_coco(dataset, model, threshold=0.05): 9 | model.eval() 10 | with torch.no_grad(): 11 | results = [] 12 | image_ids = [] 13 | 14 | for index in range(len(dataset)): 15 | data = dataset[index] 16 | scale = data['scale'] 17 | scores, labels, boxes = model(data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0)) 18 | boxes /= scale 19 | 20 | if boxes.shape[0] > 0: 21 | 22 | boxes[:, 2] -= boxes[:, 0] 23 | boxes[:, 3] -= boxes[:, 1] 24 | 25 | for box_id in range(boxes.shape[0]): 26 | score = float(scores[box_id]) 27 | label = int(labels[box_id]) 28 | box = boxes[box_id, :] 29 | 30 | if score < threshold: 31 | break 32 | 33 | image_result = { 34 | 'image_id': dataset.image_ids[index], 35 | 'category_id': dataset.label_to_coco_label(label), 36 | 'score': float(score), 37 | 'bbox': box.tolist(), 38 | } 39 | 40 | results.append(image_result) 41 | 42 | # append image to list of processed images 43 | image_ids.append(dataset.image_ids[index]) 44 | 45 | # print progress 46 | print('{}/{}'.format(index, len(dataset)), end='\r') 47 | 48 | if not len(results): 49 | return 50 | 51 | # write output 52 | json.dump(results, open('{}_bbox_results.json'.format(dataset.set_name), 'w'), indent=4) 53 | 54 | # load results in COCO evaluation tool 55 | coco_true = dataset.coco 56 | coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(dataset.set_name)) 57 | 58 | # run COCO evaluation 59 | coco_eval = COCOeval(coco_true, coco_pred, 'bbox') 60 | coco_eval.params.imgIds = image_ids 61 | coco_eval.evaluate() 62 | coco_eval.accumulate() 63 | coco_eval.summarize() 64 | 65 | 66 | if __name__ == '__main__': 67 | efficientdet = torch.load("trained_models/signatrix_efficientdet_coco.pth").module 68 | efficientdet.cuda() 69 | dataset_val = CocoDataset("data/COCO", set='val2017', 70 | transform=transforms.Compose([Normalizer(), Resizer()])) 71 | evaluate_coco(dataset_val, efficientdet) 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | efficientnet_pytorch 2 | tensorboardX 3 | pycocotools -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", 2 | "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", 3 | "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", 4 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", 5 | "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 6 | "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", 7 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", 8 | "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", 9 | "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", 10 | "teddy bear", "hair drier", "toothbrush"] 11 | 12 | colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122), 13 | (80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17), 14 | (82, 144, 81), (92, 
7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60), 15 | (16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34), 16 | (53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181), 17 | (42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108), 18 | (52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26), 19 | (122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50), 20 | (56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33), 21 | (105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91), 22 | (31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130), 23 | (31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3), 24 | (129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108), 25 | (81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95), 26 | (2, 20, 184), (122, 37, 185)] 27 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | from pycocotools.coco import COCO 7 | import cv2 8 | 9 | 10 | class CocoDataset(Dataset): 11 | def __init__(self, root_dir, set='train2017', transform=None): 12 | 13 | self.root_dir = root_dir 14 | self.set_name = set 15 | self.transform = transform 16 | 17 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json')) 18 | self.image_ids = self.coco.getImgIds() 19 | 20 | self.load_classes() 21 | 22 | def load_classes(self): 23 | 24 | # load class names (name -> label) 25 | categories = self.coco.loadCats(self.coco.getCatIds()) 26 | categories.sort(key=lambda x: x['id']) 27 | 28 | self.classes = {} 29 | self.coco_labels = {} 30 | self.coco_labels_inverse = {} 31 | for c in categories: 32 | self.coco_labels[len(self.classes)] = c['id'] 33 | self.coco_labels_inverse[c['id']] = len(self.classes) 34 | self.classes[c['name']] = len(self.classes) 35 | 36 | # also load the reverse (label -> name) 37 | self.labels = {} 38 | for key, value in self.classes.items(): 39 | self.labels[value] = key 40 | 41 | def __len__(self): 42 | return len(self.image_ids) 43 | 44 | def __getitem__(self, idx): 45 | 46 | img = self.load_image(idx) 47 | annot = self.load_annotations(idx) 48 | sample = {'img': img, 'annot': annot} 49 | if self.transform: 50 | sample = self.transform(sample) 51 | return sample 52 | 53 | def load_image(self, image_index): 54 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 55 | path = os.path.join(self.root_dir, 'images', self.set_name, image_info['file_name']) 56 | img = cv2.imread(path) 57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 58 | 59 | # if len(img.shape) == 2: 60 | # img = skimage.color.gray2rgb(img) 61 | 62 | return img.astype(np.float32) / 255. 
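# Note: load_annotations() below returns an (N, 5) float array per image. COCO
# stores boxes as [x, y, width, height]; they are converted here to [x1, y1, x2, y2]
# corners, with the contiguous class label (0..79) stored in column 4.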
63 | 64 | def load_annotations(self, image_index): 65 | # get ground truth annotations 66 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 67 | annotations = np.zeros((0, 5)) 68 | 69 | # some images appear to miss annotations 70 | if len(annotations_ids) == 0: 71 | return annotations 72 | 73 | # parse annotations 74 | coco_annotations = self.coco.loadAnns(annotations_ids) 75 | for idx, a in enumerate(coco_annotations): 76 | 77 | # some annotations have basically no width / height, skip them 78 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 79 | continue 80 | 81 | annotation = np.zeros((1, 5)) 82 | annotation[0, :4] = a['bbox'] 83 | annotation[0, 4] = self.coco_label_to_label(a['category_id']) 84 | annotations = np.append(annotations, annotation, axis=0) 85 | 86 | # transform from [x, y, w, h] to [x1, y1, x2, y2] 87 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2] 88 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3] 89 | 90 | return annotations 91 | 92 | def coco_label_to_label(self, coco_label): 93 | return self.coco_labels_inverse[coco_label] 94 | 95 | def label_to_coco_label(self, label): 96 | return self.coco_labels[label] 97 | 98 | def num_classes(self): 99 | return 80 100 | 101 | 102 | def collater(data): 103 | imgs = [s['img'] for s in data] 104 | annots = [s['annot'] for s in data] 105 | scales = [s['scale'] for s in data] 106 | 107 | imgs = torch.from_numpy(np.stack(imgs, axis=0)) 108 | 109 | max_num_annots = max(annot.shape[0] for annot in annots) 110 | 111 | if max_num_annots > 0: 112 | 113 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1 114 | 115 | if max_num_annots > 0: 116 | for idx, annot in enumerate(annots): 117 | if annot.shape[0] > 0: 118 | annot_padded[idx, :annot.shape[0], :] = annot 119 | else: 120 | annot_padded = torch.ones((len(annots), 1, 5)) * -1 121 | 122 | imgs = imgs.permute(0, 3, 1, 2) 123 | 124 | return {'img': imgs, 'annot': annot_padded, 'scale': scales} 125 | 126 | 127 | class Resizer(object): 128 | """Convert ndarrays in sample to Tensors.""" 129 | 130 | def __call__(self, sample, common_size=512): 131 | image, annots = sample['img'], sample['annot'] 132 | height, width, _ = image.shape 133 | if height > width: 134 | scale = common_size / height 135 | resized_height = common_size 136 | resized_width = int(width * scale) 137 | else: 138 | scale = common_size / width 139 | resized_height = int(height * scale) 140 | resized_width = common_size 141 | 142 | image = cv2.resize(image, (resized_width, resized_height)) 143 | 144 | new_image = np.zeros((common_size, common_size, 3)) 145 | new_image[0:resized_height, 0:resized_width] = image 146 | 147 | annots[:, :4] *= scale 148 | 149 | return {'img': torch.from_numpy(new_image), 'annot': torch.from_numpy(annots), 'scale': scale} 150 | 151 | 152 | class Augmenter(object): 153 | """Convert ndarrays in sample to Tensors.""" 154 | 155 | def __call__(self, sample, flip_x=0.5): 156 | if np.random.rand() < flip_x: 157 | image, annots = sample['img'], sample['annot'] 158 | image = image[:, ::-1, :] 159 | 160 | rows, cols, channels = image.shape 161 | 162 | x1 = annots[:, 0].copy() 163 | x2 = annots[:, 2].copy() 164 | 165 | x_tmp = x1.copy() 166 | 167 | annots[:, 0] = cols - x2 168 | annots[:, 2] = cols - x_tmp 169 | 170 | sample = {'img': image, 'annot': annots} 171 | 172 | return sample 173 | 174 | 175 | class Normalizer(object): 176 | 177 | def __init__(self): 178 | self.mean = np.array([[[0.485, 0.456, 0.406]]]) 179 | self.std = 
np.array([[[0.229, 0.224, 0.225]]]) 180 | 181 | def __call__(self, sample): 182 | image, annots = sample['img'], sample['annot'] 183 | 184 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots} 185 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def calc_iou(a, b): 6 | 7 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) 8 | iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0]) 9 | ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1]) 10 | iw = torch.clamp(iw, min=0) 11 | ih = torch.clamp(ih, min=0) 12 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih 13 | ua = torch.clamp(ua, min=1e-8) 14 | intersection = iw * ih 15 | IoU = intersection / ua 16 | 17 | return IoU 18 | 19 | 20 | class FocalLoss(nn.Module): 21 | def __init__(self): 22 | super(FocalLoss, self).__init__() 23 | 24 | def forward(self, classifications, regressions, anchors, annotations): 25 | alpha = 0.25 26 | gamma = 2.0 27 | batch_size = classifications.shape[0] 28 | classification_losses = [] 29 | regression_losses = [] 30 | 31 | anchor = anchors[0, :, :] 32 | 33 | anchor_widths = anchor[:, 2] - anchor[:, 0] 34 | anchor_heights = anchor[:, 3] - anchor[:, 1] 35 | anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths 36 | anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights 37 | 38 | for j in range(batch_size): 39 | 40 | classification = classifications[j, :, :] 41 | regression = regressions[j, :, :] 42 | 43 | bbox_annotation = annotations[j, :, :] 44 | bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1] 45 | 46 | if bbox_annotation.shape[0] == 0: 47 | if torch.cuda.is_available(): 48 | regression_losses.append(torch.tensor(0).float().cuda()) 49 | classification_losses.append(torch.tensor(0).float().cuda()) 50 | else: 51 | regression_losses.append(torch.tensor(0).float()) 52 | classification_losses.append(torch.tensor(0).float()) 53 | 54 | continue 55 | 56 | classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4) 57 | 58 | IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) 59 | 60 | IoU_max, IoU_argmax = torch.max(IoU, dim=1) 61 | 62 | # compute the loss for classification 63 | targets = torch.ones(classification.shape) * -1 64 | if torch.cuda.is_available(): 65 | targets = targets.cuda() 66 | 67 | targets[torch.lt(IoU_max, 0.4), :] = 0 68 | 69 | positive_indices = torch.ge(IoU_max, 0.5) 70 | 71 | num_positive_anchors = positive_indices.sum() 72 | 73 | assigned_annotations = bbox_annotation[IoU_argmax, :] 74 | 75 | targets[positive_indices, :] = 0 76 | targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1 77 | 78 | alpha_factor = torch.ones(targets.shape) * alpha 79 | if torch.cuda.is_available(): 80 | alpha_factor = alpha_factor.cuda() 81 | 82 | alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor) 83 | focal_weight = torch.where(torch.eq(targets, 1.), 1. 
- classification, classification) 84 | focal_weight = alpha_factor * torch.pow(focal_weight, gamma) 85 | 86 | bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification)) 87 | 88 | cls_loss = focal_weight * bce 89 | 90 | zeros = torch.zeros(cls_loss.shape) 91 | if torch.cuda.is_available(): 92 | zeros = zeros.cuda() 93 | cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros) 94 | 95 | classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0)) 96 | 97 | 98 | if positive_indices.sum() > 0: 99 | assigned_annotations = assigned_annotations[positive_indices, :] 100 | 101 | anchor_widths_pi = anchor_widths[positive_indices] 102 | anchor_heights_pi = anchor_heights[positive_indices] 103 | anchor_ctr_x_pi = anchor_ctr_x[positive_indices] 104 | anchor_ctr_y_pi = anchor_ctr_y[positive_indices] 105 | 106 | gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0] 107 | gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1] 108 | gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths 109 | gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights 110 | 111 | gt_widths = torch.clamp(gt_widths, min=1) 112 | gt_heights = torch.clamp(gt_heights, min=1) 113 | 114 | targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi 115 | targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi 116 | targets_dw = torch.log(gt_widths / anchor_widths_pi) 117 | targets_dh = torch.log(gt_heights / anchor_heights_pi) 118 | 119 | targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh)) 120 | targets = targets.t() 121 | 122 | norm = torch.Tensor([[0.1, 0.1, 0.2, 0.2]]) 123 | if torch.cuda.is_available(): 124 | norm = norm.cuda() 125 | targets = targets / norm 126 | 127 | regression_diff = torch.abs(targets - regression[positive_indices, :]) 128 | 129 | regression_loss = torch.where( 130 | torch.le(regression_diff, 1.0 / 9.0), 131 | 0.5 * 9.0 * torch.pow(regression_diff, 2), 132 | regression_diff - 0.5 / 9.0 133 | ) 134 | regression_losses.append(regression_loss.mean()) 135 | else: 136 | if torch.cuda.is_available(): 137 | regression_losses.append(torch.tensor(0).float().cuda()) 138 | else: 139 | regression_losses.append(torch.tensor(0).float()) 140 | 141 | return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, 142 | keepdim=True) 143 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from efficientnet_pytorch import EfficientNet as EffNet 5 | from src.utils import BBoxTransform, ClipBoxes, Anchors 6 | from src.loss import FocalLoss 7 | from torchvision.ops.boxes import nms as nms_torch 8 | 9 | 10 | def nms(dets, thresh): 11 | return nms_torch(dets[:, :4], dets[:, 4], thresh) 12 | 13 | 14 | class ConvBlock(nn.Module): 15 | def __init__(self, num_channels): 16 | super(ConvBlock, self).__init__() 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1, groups=num_channels), 19 | nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, padding=0), 20 | nn.BatchNorm2d(num_features=num_channels, momentum=0.9997, eps=4e-5), nn.ReLU()) 21 | 22 | def forward(self, input): 23 | return self.conv(input) 24 | 25 | 26 | class BiFPN(nn.Module): 27 | def __init__(self, num_channels, epsilon=1e-4): 28 | 
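# This layer implements the paper's "fast normalized fusion": every fused map is
# a weighted sum of its inputs, with learnable scalar weights passed through ReLU
# and normalized by their sum plus self.epsilon (see forward() below).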
super(BiFPN, self).__init__() 29 | self.epsilon = epsilon 30 | # Conv layers 31 | self.conv6_up = ConvBlock(num_channels) 32 | self.conv5_up = ConvBlock(num_channels) 33 | self.conv4_up = ConvBlock(num_channels) 34 | self.conv3_up = ConvBlock(num_channels) 35 | self.conv4_down = ConvBlock(num_channels) 36 | self.conv5_down = ConvBlock(num_channels) 37 | self.conv6_down = ConvBlock(num_channels) 38 | self.conv7_down = ConvBlock(num_channels) 39 | 40 | # Feature scaling layers 41 | self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest') 42 | self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest') 43 | self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') 44 | self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') 45 | 46 | self.p4_downsample = nn.MaxPool2d(kernel_size=2) 47 | self.p5_downsample = nn.MaxPool2d(kernel_size=2) 48 | self.p6_downsample = nn.MaxPool2d(kernel_size=2) 49 | self.p7_downsample = nn.MaxPool2d(kernel_size=2) 50 | 51 | # Weight 52 | self.p6_w1 = nn.Parameter(torch.ones(2)) 53 | self.p6_w1_relu = nn.ReLU() 54 | self.p5_w1 = nn.Parameter(torch.ones(2)) 55 | self.p5_w1_relu = nn.ReLU() 56 | self.p4_w1 = nn.Parameter(torch.ones(2)) 57 | self.p4_w1_relu = nn.ReLU() 58 | self.p3_w1 = nn.Parameter(torch.ones(2)) 59 | self.p3_w1_relu = nn.ReLU() 60 | 61 | self.p4_w2 = nn.Parameter(torch.ones(3)) 62 | self.p4_w2_relu = nn.ReLU() 63 | self.p5_w2 = nn.Parameter(torch.ones(3)) 64 | self.p5_w2_relu = nn.ReLU() 65 | self.p6_w2 = nn.Parameter(torch.ones(3)) 66 | self.p6_w2_relu = nn.ReLU() 67 | self.p7_w2 = nn.Parameter(torch.ones(2)) 68 | self.p7_w2_relu = nn.ReLU() 69 | 70 | def forward(self, inputs): 71 | """ 72 | P7_0 -------------------------- P7_2 --------> 73 | 74 | P6_0 ---------- P6_1 ---------- P6_2 --------> 75 | 76 | P5_0 ---------- P5_1 ---------- P5_2 --------> 77 | 78 | P4_0 ---------- P4_1 ---------- P4_2 --------> 79 | 80 | P3_0 -------------------------- P3_2 --------> 81 | """ 82 | 83 | # P3_0, P4_0, P5_0, P6_0 and P7_0 84 | p3_in, p4_in, p5_in, p6_in, p7_in = inputs 85 | # P7_0 to P7_2 86 | # Weights for P6_0 and P7_0 to P6_1 87 | p6_w1 = self.p6_w1_relu(self.p6_w1) 88 | weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) 89 | # Connections for P6_0 and P7_0 to P6_1 respectively 90 | p6_up = self.conv6_up(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)) 91 | # Weights for P5_0 and P6_0 to P5_1 92 | p5_w1 = self.p5_w1_relu(self.p5_w1) 93 | weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) 94 | # Connections for P5_0 and P6_0 to P5_1 respectively 95 | p5_up = self.conv5_up(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)) 96 | # Weights for P4_0 and P5_0 to P4_1 97 | p4_w1 = self.p4_w1_relu(self.p4_w1) 98 | weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) 99 | # Connections for P4_0 and P5_0 to P4_1 respectively 100 | p4_up = self.conv4_up(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)) 101 | 102 | # Weights for P3_0 and P4_1 to P3_2 103 | p3_w1 = self.p3_w1_relu(self.p3_w1) 104 | weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) 105 | # Connections for P3_0 and P4_1 to P3_2 respectively 106 | p3_out = self.conv3_up(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)) 107 | 108 | # Weights for P4_0, P4_1 and P3_2 to P4_2 109 | p4_w2 = self.p4_w2_relu(self.p4_w2) 110 | weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) 111 | # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively 112 | p4_out = self.conv4_down( 113 | weight[0] * p4_in + weight[1] * p4_up + weight[2] * 
self.p4_downsample(p3_out)) 114 | # Weights for P5_0, P5_1 and P4_2 to P5_2 115 | p5_w2 = self.p5_w2_relu(self.p5_w2) 116 | weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) 117 | # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively 118 | p5_out = self.conv5_down( 119 | weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)) 120 | # Weights for P6_0, P6_1 and P5_2 to P6_2 121 | p6_w2 = self.p6_w2_relu(self.p6_w2) 122 | weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) 123 | # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively 124 | p6_out = self.conv6_down( 125 | weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)) 126 | # Weights for P7_0 and P6_2 to P7_2 127 | p7_w2 = self.p7_w2_relu(self.p7_w2) 128 | weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) 129 | # Connections for P7_0 and P6_2 to P7_2 130 | p7_out = self.conv7_down(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)) 131 | 132 | return p3_out, p4_out, p5_out, p6_out, p7_out 133 | 134 | 135 | class Regressor(nn.Module): 136 | def __init__(self, in_channels, num_anchors, num_layers): 137 | super(Regressor, self).__init__() 138 | layers = [] 139 | for _ in range(num_layers): 140 | layers.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) 141 | layers.append(nn.ReLU(True)) 142 | self.layers = nn.Sequential(*layers) 143 | self.header = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) 144 | 145 | def forward(self, inputs): 146 | inputs = self.layers(inputs) 147 | inputs = self.header(inputs) 148 | output = inputs.permute(0, 2, 3, 1) 149 | return output.contiguous().view(output.shape[0], -1, 4) 150 | 151 | 152 | class Classifier(nn.Module): 153 | def __init__(self, in_channels, num_anchors, num_classes, num_layers): 154 | super(Classifier, self).__init__() 155 | self.num_anchors = num_anchors 156 | self.num_classes = num_classes 157 | layers = [] 158 | for _ in range(num_layers): 159 | layers.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) 160 | layers.append(nn.ReLU(True)) 161 | self.layers = nn.Sequential(*layers) 162 | self.header = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) 163 | self.act = nn.Sigmoid() 164 | 165 | def forward(self, inputs): 166 | inputs = self.layers(inputs) 167 | inputs = self.header(inputs) 168 | inputs = self.act(inputs) 169 | inputs = inputs.permute(0, 2, 3, 1) 170 | output = inputs.contiguous().view(inputs.shape[0], inputs.shape[1], inputs.shape[2], self.num_anchors, 171 | self.num_classes) 172 | return output.contiguous().view(output.shape[0], -1, self.num_classes) 173 | 174 | 175 | class EfficientNet(nn.Module): 176 | def __init__(self, ): 177 | super(EfficientNet, self).__init__() 178 | model = EffNet.from_pretrained('efficientnet-b0') 179 | del model._conv_head 180 | del model._bn1 181 | del model._avg_pooling 182 | del model._dropout 183 | del model._fc 184 | self.model = model 185 | 186 | def forward(self, x): 187 | x = self.model._swish(self.model._bn0(self.model._conv_stem(x))) 188 | feature_maps = [] 189 | for idx, block in enumerate(self.model._blocks): 190 | drop_connect_rate = self.model._global_params.drop_connect_rate 191 | if drop_connect_rate: 192 | drop_connect_rate *= float(idx) / len(self.model._blocks) 193 | x = block(x, drop_connect_rate=drop_connect_rate) 194 | if block._depthwise_conv.stride == [2, 2]: 195 | feature_maps.append(x) 196 | 197 | return 
feature_maps[1:] 198 | 199 | 200 | class EfficientDet(nn.Module): 201 | def __init__(self, num_anchors=9, num_classes=20, compound_coef=0): 202 | super(EfficientDet, self).__init__() 203 | self.compound_coef = compound_coef 204 | 205 | self.num_channels = [64, 88, 112, 160, 224, 288, 384, 384][self.compound_coef] 206 | 207 | self.conv3 = nn.Conv2d(40, self.num_channels, kernel_size=1, stride=1, padding=0) 208 | self.conv4 = nn.Conv2d(80, self.num_channels, kernel_size=1, stride=1, padding=0) 209 | self.conv5 = nn.Conv2d(192, self.num_channels, kernel_size=1, stride=1, padding=0) 210 | self.conv6 = nn.Conv2d(192, self.num_channels, kernel_size=3, stride=2, padding=1) 211 | self.conv7 = nn.Sequential(nn.ReLU(), 212 | nn.Conv2d(self.num_channels, self.num_channels, kernel_size=3, stride=2, padding=1)) 213 | 214 | self.bifpn = nn.Sequential(*[BiFPN(self.num_channels) for _ in range(min(2 + self.compound_coef, 8))]) 215 | 216 | self.num_classes = num_classes 217 | self.regressor = Regressor(in_channels=self.num_channels, num_anchors=num_anchors, 218 | num_layers=3 + self.compound_coef // 3) 219 | self.classifier = Classifier(in_channels=self.num_channels, num_anchors=num_anchors, num_classes=num_classes, 220 | num_layers=3 + self.compound_coef // 3) 221 | 222 | self.anchors = Anchors() 223 | self.regressBoxes = BBoxTransform() 224 | self.clipBoxes = ClipBoxes() 225 | self.focalLoss = FocalLoss() 226 | 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 230 | m.weight.data.normal_(0, math.sqrt(2. / n)) 231 | elif isinstance(m, nn.BatchNorm2d): 232 | m.weight.data.fill_(1) 233 | m.bias.data.zero_() 234 | 235 | prior = 0.01 236 | 237 | self.classifier.header.weight.data.fill_(0) 238 | self.classifier.header.bias.data.fill_(-math.log((1.0 - prior) / prior)) 239 | 240 | self.regressor.header.weight.data.fill_(0) 241 | self.regressor.header.bias.data.fill_(0) 242 | 243 | self.backbone_net = EfficientNet() 244 | 245 | def freeze_bn(self): 246 | for m in self.modules(): 247 | if isinstance(m, nn.BatchNorm2d): 248 | m.eval() 249 | 250 | def forward(self, inputs): 251 | if len(inputs) == 2: 252 | is_training = True 253 | img_batch, annotations = inputs 254 | else: 255 | is_training = False 256 | img_batch = inputs 257 | 258 | c3, c4, c5 = self.backbone_net(img_batch) 259 | p3 = self.conv3(c3) 260 | p4 = self.conv4(c4) 261 | p5 = self.conv5(c5) 262 | p6 = self.conv6(c5) 263 | p7 = self.conv7(p6) 264 | 265 | features = [p3, p4, p5, p6, p7] 266 | features = self.bifpn(features) 267 | 268 | regression = torch.cat([self.regressor(feature) for feature in features], dim=1) 269 | classification = torch.cat([self.classifier(feature) for feature in features], dim=1) 270 | anchors = self.anchors(img_batch) 271 | 272 | if is_training: 273 | return self.focalLoss(classification, regression, anchors, annotations) 274 | else: 275 | transformed_anchors = self.regressBoxes(anchors, regression) 276 | transformed_anchors = self.clipBoxes(transformed_anchors, img_batch) 277 | 278 | scores = torch.max(classification, dim=2, keepdim=True)[0] 279 | 280 | scores_over_thresh = (scores > 0.05)[0, :, 0] 281 | 282 | if scores_over_thresh.sum() == 0: 283 | return [torch.zeros(0), torch.zeros(0), torch.zeros(0, 4)] 284 | 285 | classification = classification[:, scores_over_thresh, :] 286 | transformed_anchors = transformed_anchors[:, scores_over_thresh, :] 287 | scores = scores[:, scores_over_thresh, :] 288 | 289 | anchors_nms_idx = 
nms(torch.cat([transformed_anchors, scores], dim=2)[0, :, :], 0.5) 290 | 291 | nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1) 292 | 293 | return [nms_scores, nms_class, transformed_anchors[0, anchors_nms_idx, :]] 294 | 295 | 296 | if __name__ == '__main__': 297 | from tensorboardX import SummaryWriter 298 | def count_parameters(model): 299 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 300 | 301 | model = EfficientDet(num_classes=80) 302 | print (count_parameters(model)) 303 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BBoxTransform(nn.Module): 7 | 8 | def __init__(self, mean=None, std=None): 9 | super(BBoxTransform, self).__init__() 10 | if mean is None: 11 | self.mean = torch.from_numpy(np.array([0, 0, 0, 0]).astype(np.float32)) 12 | else: 13 | self.mean = mean 14 | if std is None: 15 | self.std = torch.from_numpy(np.array([0.1, 0.1, 0.2, 0.2]).astype(np.float32)) 16 | else: 17 | self.std = std 18 | if torch.cuda.is_available(): 19 | self.mean = self.mean.cuda() 20 | self.std = self.std.cuda() 21 | 22 | def forward(self, boxes, deltas): 23 | 24 | widths = boxes[:, :, 2] - boxes[:, :, 0] 25 | heights = boxes[:, :, 3] - boxes[:, :, 1] 26 | ctr_x = boxes[:, :, 0] + 0.5 * widths 27 | ctr_y = boxes[:, :, 1] + 0.5 * heights 28 | 29 | dx = deltas[:, :, 0] * self.std[0] + self.mean[0] 30 | dy = deltas[:, :, 1] * self.std[1] + self.mean[1] 31 | dw = deltas[:, :, 2] * self.std[2] + self.mean[2] 32 | dh = deltas[:, :, 3] * self.std[3] + self.mean[3] 33 | 34 | pred_ctr_x = ctr_x + dx * widths 35 | pred_ctr_y = ctr_y + dy * heights 36 | pred_w = torch.exp(dw) * widths 37 | pred_h = torch.exp(dh) * heights 38 | 39 | pred_boxes_x1 = pred_ctr_x - 0.5 * pred_w 40 | pred_boxes_y1 = pred_ctr_y - 0.5 * pred_h 41 | pred_boxes_x2 = pred_ctr_x + 0.5 * pred_w 42 | pred_boxes_y2 = pred_ctr_y + 0.5 * pred_h 43 | 44 | pred_boxes = torch.stack([pred_boxes_x1, pred_boxes_y1, pred_boxes_x2, pred_boxes_y2], dim=2) 45 | 46 | return pred_boxes 47 | 48 | 49 | class ClipBoxes(nn.Module): 50 | 51 | def __init__(self): 52 | super(ClipBoxes, self).__init__() 53 | 54 | def forward(self, boxes, img): 55 | batch_size, num_channels, height, width = img.shape 56 | 57 | boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0) 58 | boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0) 59 | 60 | boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width) 61 | boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height) 62 | 63 | return boxes 64 | 65 | 66 | class Anchors(nn.Module): 67 | def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None): 68 | super(Anchors, self).__init__() 69 | 70 | if pyramid_levels is None: 71 | self.pyramid_levels = [3, 4, 5, 6, 7] 72 | if strides is None: 73 | self.strides = [2 ** x for x in self.pyramid_levels] 74 | if sizes is None: 75 | self.sizes = [2 ** (x + 2) for x in self.pyramid_levels] 76 | if ratios is None: 77 | self.ratios = np.array([0.5, 1, 2]) 78 | if scales is None: 79 | self.scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]) 80 | 81 | def forward(self, image): 82 | 83 | image_shape = image.shape[2:] 84 | image_shape = np.array(image_shape) 85 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels] 86 | 87 | all_anchors = np.zeros((0, 4)).astype(np.float32) 88 | 
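# For each pyramid level P3-P7: build the 9 base anchors (3 ratios x 3 scales)
# for that level's size, shift them to every feature-map cell with shift(), and
# concatenate everything into a single (1, num_anchors, 4) tensor.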
89 | for idx, p in enumerate(self.pyramid_levels): 90 | anchors = generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales) 91 | shifted_anchors = shift(image_shapes[idx], self.strides[idx], anchors) 92 | all_anchors = np.append(all_anchors, shifted_anchors, axis=0) 93 | 94 | all_anchors = np.expand_dims(all_anchors, axis=0) 95 | 96 | anchors = torch.from_numpy(all_anchors.astype(np.float32)) 97 | if torch.cuda.is_available(): 98 | anchors = anchors.cuda() 99 | return anchors 100 | 101 | 102 | def generate_anchors(base_size=16, ratios=None, scales=None): 103 | if ratios is None: 104 | ratios = np.array([0.5, 1, 2]) 105 | 106 | if scales is None: 107 | scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]) 108 | 109 | num_anchors = len(ratios) * len(scales) 110 | anchors = np.zeros((num_anchors, 4)) 111 | anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T 112 | areas = anchors[:, 2] * anchors[:, 3] 113 | anchors[:, 2] = np.sqrt(areas / np.repeat(ratios, len(scales))) 114 | anchors[:, 3] = anchors[:, 2] * np.repeat(ratios, len(scales)) 115 | anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T 116 | anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T 117 | 118 | return anchors 119 | 120 | 121 | def compute_shape(image_shape, pyramid_levels): 122 | image_shape = np.array(image_shape[:2]) 123 | image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels] 124 | return image_shapes 125 | 126 | 127 | def shift(shape, stride, anchors): 128 | shift_x = (np.arange(0, shape[1]) + 0.5) * stride 129 | shift_y = (np.arange(0, shape[0]) + 0.5) * stride 130 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 131 | shifts = np.vstack(( 132 | shift_x.ravel(), shift_y.ravel(), 133 | shift_x.ravel(), shift_y.ravel() 134 | )).transpose() 135 | 136 | A = anchors.shape[0] 137 | K = shifts.shape[0] 138 | all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))) 139 | all_anchors = all_anchors.reshape((K * A, 4)) 140 | 141 | return all_anchors 142 | -------------------------------------------------------------------------------- /test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torchvision import transforms 5 | from src.dataset import CocoDataset, Resizer, Normalizer 6 | from src.config import COCO_CLASSES, colors 7 | import cv2 8 | import shutil 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser( 13 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH") 14 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images") 15 | parser.add_argument("--data_path", type=str, default="data/COCO", help="the root folder of dataset") 16 | parser.add_argument("--cls_threshold", type=float, default=0.5) 17 | parser.add_argument("--nms_threshold", type=float, default=0.5) 18 | parser.add_argument("--pretrained_model", type=str, default="trained_models/signatrix_efficientdet_coco.pth") 19 | parser.add_argument("--output", type=str, default="predictions") 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | 25 | def test(opt): 26 | model = torch.load(opt.pretrained_model) 27 | model.cuda() 28 | dataset = CocoDataset(opt.data_path, set='val2017', transform=transforms.Compose([Normalizer(), Resizer()])) 29 | 30 | if os.path.isdir(opt.output): 31 | shutil.rmtree(opt.output) 32 | os.makedirs(opt.output) 
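# The model returns detections sorted by descending score (torchvision NMS keeps
# indices in decreasing-score order), so the loop below can break at the first
# box whose score falls under --cls_threshold instead of checking every box.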
33 | 34 | for index in range(len(dataset)): 35 | data = dataset[index] 36 | scale = data['scale'] 37 | with torch.no_grad(): 38 | scores, labels, boxes = model(data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0)) 39 | boxes /= scale 40 | 41 | if boxes.shape[0] > 0: 42 | image_info = dataset.coco.loadImgs(dataset.image_ids[index])[0] 43 | path = os.path.join(dataset.root_dir, 'images', dataset.set_name, image_info['file_name']) 44 | output_image = cv2.imread(path) 45 | 46 | for box_id in range(boxes.shape[0]): 47 | pred_prob = float(scores[box_id]) 48 | if pred_prob < opt.cls_threshold: 49 | break 50 | pred_label = int(labels[box_id]) 51 | xmin, ymin, xmax, ymax = boxes[box_id, :] 52 | color = colors[pred_label] 53 | xmin = int(round(float(xmin), 0)) 54 | ymin = int(round(float(ymin), 0)) 55 | xmax = int(round(float(xmax), 0)) 56 | ymax = int(round(float(ymax), 0)) 57 | cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2) 58 | text_size = cv2.getTextSize(COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0] 59 | 60 | cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1) 61 | cv2.putText( 62 | output_image, COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, 63 | (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, 64 | (255, 255, 255), 1) 65 | 66 | cv2.imwrite("{}/{}_prediction.jpg".format(opt.output, image_info["file_name"][:-4]), output_image) 67 | 68 | 69 | if __name__ == "__main__": 70 | opt = get_args() 71 | test(opt) 72 | -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from src.config import COCO_CLASSES, colors 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser( 10 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH") 11 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images") 12 | parser.add_argument("--cls_threshold", type=float, default=0.5) 13 | parser.add_argument("--nms_threshold", type=float, default=0.5) 14 | parser.add_argument("--pretrained_model", type=str, default="trained_models/signatrix_efficientdet_coco.pth") 15 | parser.add_argument("--input", type=str, default="test_videos/input.mp4") 16 | parser.add_argument("--output", type=str, default="test_videos/output.mp4") 17 | 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def test(opt): 23 | model = torch.load(opt.pretrained_model).module 24 | if torch.cuda.is_available(): 25 | model.cuda() 26 | 27 | cap = cv2.VideoCapture(opt.input) 28 | out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), int(cap.get(cv2.CAP_PROP_FPS)), 29 | (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))) 30 | while cap.isOpened(): 31 | flag, image = cap.read() 32 | output_image = np.copy(image) 33 | if flag: 34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 35 | else: 36 | break 37 | height, width = image.shape[:2] 38 | image = image.astype(np.float32) / 255 39 | image[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 40 | image[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 41 | image[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 42 | if height > width: 43 | scale = opt.image_size / height 44 | resized_height = opt.image_size 45 | resized_width = int(width * scale) 46 | else: 47 | 
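# width >= height: as in src/dataset.py's Resizer, the longer side is scaled to
# --image_size and the shorter side is zero-padded to a square input further down.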
scale = opt.image_size / width 48 | resized_height = int(height * scale) 49 | resized_width = opt.image_size 50 | 51 | image = cv2.resize(image, (resized_width, resized_height)) 52 | 53 | new_image = np.zeros((opt.image_size, opt.image_size, 3)) 54 | new_image[0:resized_height, 0:resized_width] = image 55 | new_image = np.transpose(new_image, (2, 0, 1)) 56 | new_image = new_image[None, :, :, :] 57 | new_image = torch.Tensor(new_image) 58 | if torch.cuda.is_available(): 59 | new_image = new_image.cuda() 60 | with torch.no_grad(): 61 | scores, labels, boxes = model(new_image) 62 | boxes /= scale 63 | if boxes.shape[0] == 0: 64 | continue 65 | 66 | for box_id in range(boxes.shape[0]): 67 | pred_prob = float(scores[box_id]) 68 | if pred_prob < opt.cls_threshold: 69 | break 70 | pred_label = int(labels[box_id]) 71 | xmin, ymin, xmax, ymax = boxes[box_id, :] 72 | color = colors[pred_label] 73 | cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2) 74 | text_size = cv2.getTextSize(COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0] 75 | cv2.rectangle(output_image, (xmin, ymin), (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1) 76 | cv2.putText( 77 | output_image, COCO_CLASSES[pred_label] + ' : %.2f' % pred_prob, 78 | (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, 79 | (255, 255, 255), 1) 80 | out.write(output_image) 81 | 82 | cap.release() 83 | out.release() 84 | 85 | 86 | if __name__ == "__main__": 87 | opt = get_args() 88 | test(opt) 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from src.dataset import CocoDataset, Resizer, Normalizer, Augmenter, collater 8 | from src.model import EfficientDet 9 | from tensorboardX import SummaryWriter 10 | import shutil 11 | import numpy as np 12 | from tqdm.autonotebook import tqdm 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser( 17 | "EfficientDet: Scalable and Efficient Object Detection implementation by Signatrix GmbH") 18 | parser.add_argument("--image_size", type=int, default=512, help="The common width and height for all images") 19 | parser.add_argument("--batch_size", type=int, default=8, help="The number of images per batch") 20 | parser.add_argument("--lr", type=float, default=1e-4) 21 | parser.add_argument('--alpha', type=float, default=0.25) 22 | parser.add_argument('--gamma', type=float, default=1.5) 23 | parser.add_argument("--num_epochs", type=int, default=500) 24 | parser.add_argument("--test_interval", type=int, default=1, help="Number of epoches between testing phases") 25 | parser.add_argument("--es_min_delta", type=float, default=0.0, 26 | help="Early stopping's parameter: minimum change loss to qualify as an improvement") 27 | parser.add_argument("--es_patience", type=int, default=0, 28 | help="Early stopping's parameter: number of epochs with no improvement after which training will be stopped. 
Set to 0 to disable this technique.") 29 | parser.add_argument("--data_path", type=str, default="data/COCO", help="the root folder of dataset") 30 | parser.add_argument("--log_path", type=str, default="tensorboard/signatrix_efficientdet_coco") 31 | parser.add_argument("--saved_path", type=str, default="trained_models") 32 | 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def train(opt): 38 | num_gpus = 1 39 | if torch.cuda.is_available(): 40 | num_gpus = torch.cuda.device_count() 41 | torch.cuda.manual_seed(123) 42 | else: 43 | torch.manual_seed(123) 44 | 45 | training_params = {"batch_size": opt.batch_size * num_gpus, 46 | "shuffle": True, 47 | "drop_last": True, 48 | "collate_fn": collater, 49 | "num_workers": 12} 50 | 51 | test_params = {"batch_size": opt.batch_size, 52 | "shuffle": False, 53 | "drop_last": False, 54 | "collate_fn": collater, 55 | "num_workers": 12} 56 | 57 | training_set = CocoDataset(root_dir=opt.data_path, set="train2017", 58 | transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()])) 59 | training_generator = DataLoader(training_set, **training_params) 60 | 61 | test_set = CocoDataset(root_dir=opt.data_path, set="val2017", 62 | transform=transforms.Compose([Normalizer(), Resizer()])) 63 | test_generator = DataLoader(test_set, **test_params) 64 | 65 | model = EfficientDet(num_classes=training_set.num_classes()) 66 | 67 | 68 | if os.path.isdir(opt.log_path): 69 | shutil.rmtree(opt.log_path) 70 | os.makedirs(opt.log_path) 71 | 72 | if not os.path.isdir(opt.saved_path): 73 | os.makedirs(opt.saved_path) 74 | 75 | writer = SummaryWriter(opt.log_path) 76 | if torch.cuda.is_available(): 77 | model = model.cuda() 78 | model = nn.DataParallel(model) 79 | 80 | optimizer = torch.optim.Adam(model.parameters(), opt.lr) 81 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True) 82 | 83 | best_loss = 1e5 84 | best_epoch = 0 85 | model.train() 86 | 87 | num_iter_per_epoch = len(training_generator) 88 | for epoch in range(opt.num_epochs): 89 | model.train() 90 | # if torch.cuda.is_available(): 91 | # model.module.freeze_bn() 92 | # else: 93 | # model.freeze_bn() 94 | epoch_loss = [] 95 | progress_bar = tqdm(training_generator) 96 | for iter, data in enumerate(progress_bar): 97 | try: 98 | optimizer.zero_grad() 99 | if torch.cuda.is_available(): 100 | cls_loss, reg_loss = model([data['img'].cuda().float(), data['annot'].cuda()]) 101 | else: 102 | cls_loss, reg_loss = model([data['img'].float(), data['annot']]) 103 | 104 | cls_loss = cls_loss.mean() 105 | reg_loss = reg_loss.mean() 106 | loss = cls_loss + reg_loss 107 | if loss == 0: 108 | continue 109 | loss.backward() 110 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) 111 | optimizer.step() 112 | epoch_loss.append(float(loss)) 113 | total_loss = np.mean(epoch_loss) 114 | 115 | progress_bar.set_description( 116 | 'Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. 
Batch loss: {:.5f} Total loss: {:.5f}'.format( 117 | epoch + 1, opt.num_epochs, iter + 1, num_iter_per_epoch, cls_loss, reg_loss, loss, 118 | total_loss)) 119 | writer.add_scalar('Train/Total_loss', total_loss, epoch * num_iter_per_epoch + iter) 120 | writer.add_scalar('Train/Regression_loss', reg_loss, epoch * num_iter_per_epoch + iter) 121 | writer.add_scalar('Train/Classfication_loss (focal loss)', cls_loss, epoch * num_iter_per_epoch + iter) 122 | 123 | except Exception as e: 124 | print(e) 125 | continue 126 | scheduler.step(np.mean(epoch_loss)) 127 | 128 | if epoch % opt.test_interval == 0: 129 | model.eval() 130 | loss_regression_ls = [] 131 | loss_classification_ls = [] 132 | for iter, data in enumerate(test_generator): 133 | with torch.no_grad(): 134 | if torch.cuda.is_available(): 135 | cls_loss, reg_loss = model([data['img'].cuda().float(), data['annot'].cuda()]) 136 | else: 137 | cls_loss, reg_loss = model([data['img'].float(), data['annot']]) 138 | 139 | cls_loss = cls_loss.mean() 140 | reg_loss = reg_loss.mean() 141 | 142 | loss_classification_ls.append(float(cls_loss)) 143 | loss_regression_ls.append(float(reg_loss)) 144 | 145 | cls_loss = np.mean(loss_classification_ls) 146 | reg_loss = np.mean(loss_regression_ls) 147 | loss = cls_loss + reg_loss 148 | 149 | print( 150 | 'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'.format( 151 | epoch + 1, opt.num_epochs, cls_loss, reg_loss, 152 | np.mean(loss))) 153 | writer.add_scalar('Test/Total_loss', loss, epoch) 154 | writer.add_scalar('Test/Regression_loss', reg_loss, epoch) 155 | writer.add_scalar('Test/Classfication_loss (focal loss)', cls_loss, epoch) 156 | 157 | if loss + opt.es_min_delta < best_loss: 158 | best_loss = loss 159 | best_epoch = epoch 160 | torch.save(model, os.path.join(opt.saved_path, "signatrix_efficientdet_coco.pth")) 161 | 162 | dummy_input = torch.rand(opt.batch_size, 3, 512, 512) 163 | if torch.cuda.is_available(): 164 | dummy_input = dummy_input.cuda() 165 | if isinstance(model, nn.DataParallel): 166 | model.module.backbone_net.model.set_swish(memory_efficient=False) 167 | 168 | torch.onnx.export(model.module, dummy_input, 169 | os.path.join(opt.saved_path, "signatrix_efficientdet_coco.onnx"), 170 | verbose=False) 171 | model.module.backbone_net.model.set_swish(memory_efficient=True) 172 | else: 173 | model.backbone_net.model.set_swish(memory_efficient=False) 174 | 175 | torch.onnx.export(model, dummy_input, 176 | os.path.join(opt.saved_path, "signatrix_efficientdet_coco.onnx"), 177 | verbose=False) 178 | model.backbone_net.model.set_swish(memory_efficient=True) 179 | 180 | # Early stopping 181 | if epoch - best_epoch > opt.es_patience > 0: 182 | print("Stop training at epoch {}. 
The lowest loss achieved is {}".format(epoch, loss)) 183 | break 184 | writer.close() 185 | 186 | 187 | if __name__ == "__main__": 188 | opt = get_args() 189 | train(opt) 190 | -------------------------------------------------------------------------------- /trained_models/signatrix_efficientdet_coco.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/trained_models/signatrix_efficientdet_coco.onnx -------------------------------------------------------------------------------- /trained_models/signatrix_efficientdet_coco.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/signatrix/efficientdet/36faca88a0eec24a1473c3443ae9d7af03f20cae/trained_models/signatrix_efficientdet_coco.pth --------------------------------------------------------------------------------