├── .gitignore ├── LICENSE ├── README.md ├── notebooks ├── basic_prediction.ipynb ├── eval_coco_map.ipynb ├── train_voc.ipynb └── voc_utils.py ├── tests ├── __init__.py ├── mocks │ ├── dog.jpg │ ├── person.jpg │ ├── yololayer_tiny_0_get_loss_0.p │ └── yololayer_tiny_0_get_region_boxes_0.p ├── utils_test.py ├── yolo_layer_test.py ├── yolov3_test.py └── yolov3_tiny_test.py └── yolov3_pytorch ├── __init__.py ├── fastai_utils.py ├── utils.py ├── yolo_layer.py ├── yolov3.py ├── yolov3_base.py └── yolov3_tiny.py /.gitignore: -------------------------------------------------------------------------------- 1 | fastai 2 | weights 3 | data 4 | .vscode 5 | *~ 6 | 7 | 8 | 9 | # Github defaults 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Olli Huotari 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yolov3_pytorch 2 | 3 | Yolov3 (+tiny) object detection - an object-oriented, Pythonic PyTorch implementation. 4 | 5 | Tested with pytorch 0.4.0 and python>3.5. Some basic tests are included in the [tests](https://github.com/holli/yolov3_pytorch/tree/master/tests) folder. 6 | 7 | The goal of this repo is a simple, Pythonic, object-oriented implementation that can be used as is and is also easy to train or modify. 8 | 9 | See https://pjreddie.com/darknet/yolo/ for a better explanation of how the yolov3 object detection system differs from others. 10 | 11 | # Pretrained Weights 12 | 13 | Pretrained weights are available at **http://www.ollihuotari.com/data/yolov3_pytorch/**. They are converted from https://pjreddie.com/darknet/yolo/. Check out the notebooks for examples of how to use them. 14 | 15 | # Notebook Examples 16 | 17 | - **https://github.com/holli/yolov3_pytorch/blob/master/notebooks/basic_prediction.ipynb** 18 | - shows basic model loading and prediction 19 | - **https://github.com/holli/yolov3_pytorch/blob/master/notebooks/eval_coco_map.ipynb** 20 | - mAP metric on the COCO evaluation dataset, mainly to verify that this implementation is close enough to the original 21 | - **https://github.com/holli/yolov3_pytorch/blob/master/notebooks/train_voc.ipynb** 22 | - training on a new dataset, using the VOC dataset as an example 23 | 24 | # Support / Commits 25 | 26 | Submit suggestions or feature requests as a GitHub Issue or Pull Request. Preferably include a test that shows what happens and what should happen. 27 | 28 | # Other Implementations 29 | 30 | There are several good earlier pytorch implementations, but many of them use the original cfg files to create the model. This works well but makes it harder to modify the model and test other approaches. Some of them don't include the yolov3-tiny model or don't work with images of different sizes (e.g. 608 pixels instead of the default 416).
Some nicer ones include: 31 | 32 | - https://github.com/marvis/pytorch-yolo3 33 | - https://github.com/andy-yun/pytorch-0.4-yolov3 34 | - https://github.com/jiasenlu/YOLOv3.pytorch 35 | 36 | # Licence 37 | 38 | Released under the MIT license (http://www.opensource.org/licenses/mit-license.php) 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /notebooks/voc_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from fastai.imports import * 3 | from fastai.dataset import * 4 | from matplotlib import patches, patheffects 5 | 6 | def get_voc_md(data_filenames, sz=416, data_max_lines=False, tfms_trn=None, tfms_val=None): 7 | if not tfms_trn: 8 | #tfms_trn = [RandomRotate(10, tfm_y=TfmType.COORD), RandomLighting(0.20, 0.20), RandomBlur()] 9 | #tfms_trn = [RandomLighting(0.20, 0.20), RandomBlur()] 10 | tfms_trn = [RandomBlur()] 11 | # tfms_trn = image_gen(normalizer=None, denorm=None, sz=sz, crop_type=CropType.RANDOM, 12 | tfms_trn = image_gen(normalizer=None, denorm=None, sz=sz, crop_type=CropType.NO, 13 | max_zoom=1.2, tfm_y=TfmType.COORD, tfms=tfms_trn) 14 | if not tfms_val: 15 | tfms_val = image_gen(normalizer=None, denorm=None, sz=sz, crop_type=CropType.NO, 16 | # tfms_val = image_gen(normalizer=None, denorm=None, sz=sz, crop_type=CropType.CENTER, 17 | max_zoom=1, tfm_y=TfmType.COORD, tfms=[]) 18 | 19 | data_lines = [] 20 | for f in data_filenames: 21 | with open(f, 'r') as file: 22 | arr = file.readlines() 23 | arr = [s.rstrip('\n') for s in arr] 24 | data_lines.append(arr) 25 | 26 | if data_max_lines: 27 | if type(data_max_lines) == int: 28 | data_max_lines = [data_max_lines, data_max_lines] 29 | for i in range(len(data_lines)): 30 | data_lines[i] = data_lines[i][:data_max_lines[i]] 31 | len(data_lines[i]) 32 | 33 | datasets = [ 34 | VocDataset(data_lines[0], transform=tfms_trn, path='', sz=sz), # train 35 | VocDataset(data_lines[1], transform=tfms_val, path='', sz=sz), # valid 36 | VocDataset(data_lines[0], transform=tfms_val, path='', sz=sz), # fix 37 | VocDataset(data_lines[1], transform=tfms_trn, path='', sz=sz), # aug 38 | None, None # test datasets 39 | ] 40 | 41 | md = ImageData(path = "/tmp", datasets=datasets, bs=32, num_workers=2, classes=VocDataset.CLASS_NAMES) 42 | md.trn_dl.pre_pad = md.val_dl.pre_pad = md.fix_dl.pre_pad = md.aug_dl.pre_pad = False 43 | 44 | return md 45 | 46 | 47 | # To be used for example with https://github.com/rafaelpadilla/Object-Detection-Metrics 48 | # python pascalvoc.py --gtfolder /tmp/ai_mAP_1/ground --detfolder /tmp/ai_mAP_1/pred -gtcoords rel -detcoords rel -imgsize 416,416 --noplot 49 | def create_detection_files(validation_ds, tmp_dir='/tmp/ai_mAP_1', remove_old=True): 50 | for p in ['pred', 'ground']: 51 | p = os.path.join(tmp_dir, p) 52 | if os.path.exists(p): 53 | if remove_old: 54 | for f in glob.glob(os.path.join(p, "*.txt")): 55 | os.remove(f) 56 | else: 57 | os.makedirs(p) 58 | 59 | for i in range(len(md.val_ds.fnames)): 60 | imgfile = md.val_ds.fnames[i] 61 | img_org = Image.open(imgfile).convert('RGB') 62 | img_resized = img_org.resize((sz, sz)) 63 | img_torch = image2torch(img_resized).cuda() 64 | all_boxes = model.predict_img(img_torch)[0] 65 | boxes = nms(all_boxes, 0.4) 66 | 67 | fname = os.path.split(imgfile)[-1] 68 | fname = fname.replace('.png','.txt').replace('.jpg','.txt') 69 | det_fname = os.path.join(tmp_dir, 'pred', fname) 70 | with open(det_fname, 'w') as f: 71 | for box in boxes: 72 | box = 
np.array([b.item() for b in box]) 73 | box[:2] -= box[2:4]/2 74 | arr = [int(box[-1]), box[-2]] + list(box[0:4]) 75 | s = ' '.join([str(a) for a in arr]) + '\n' 76 | _ = f.write(s) 77 | 78 | g_fname = os.path.join(tmp_dir, 'ground', fname) 79 | with open(g_fname, 'w') as f: 80 | for box in md.val_ds.get_y(i): 81 | box = np.array(box) 82 | box[1:3] -= box[-2:]/2 83 | arr = [int(box[0])] + list(box[1:5]) 84 | s = ' '.join([str(a) for a in arr]) + '\n' 85 | _ = f.write(s) 86 | 87 | 88 | 89 | # class VocDataset(Dataset): 90 | # /home/ohu/koodi/data/voc/VOCdevkit/VOC2007/JPEGImages/000012.jpg 91 | # /home/ohu/koodi/data/voc/VOCdevkit/VOC2007/labels/000012.txt 92 | # Parsing from https://pjreddie.com/media/files/voc_label.py 93 | class VocDataset(FilesDataset): 94 | CLASS_NAMES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 95 | 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') 96 | 97 | def __init__(self, fnames, transform, path, sz): 98 | super().__init__(fnames, transform, path) 99 | self.sz = sz 100 | 101 | # Data is in center_x, center_y, width, height 102 | # VocDataset.read_labels('/home/ohu/koodi/data/voc/VOCdevkit/VOC2007/labels/000009.txt', 0.03) 103 | @staticmethod 104 | def read_labels(lab_path, min_box_scale=0.03): 105 | if os.path.exists(lab_path) and os.path.getsize(lab_path): 106 | all_truths = np.loadtxt(lab_path) 107 | all_truths = all_truths.reshape(all_truths.size//5, 5) # to avoid single truth problem 108 | else: 109 | all_truths = np.array([]) 110 | 111 | truths = [] 112 | for t in all_truths: 113 | if t[3] < min_box_scale or t[4] < min_box_scale: 114 | continue 115 | #truths.append([all_truths[i][0], all_truths[i][1], truths[i][2], truths[i][3], truths[i][4]]) 116 | truths.append(t) 117 | return np.array(truths) 118 | 119 | 120 | def get_y(self, i): 121 | path = os.path.join(self.path, self.fnames[i]) 122 | path = path.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt') 123 | # print(path) 124 | arr = self.read_labels(path, 0.03) 125 | return arr 126 | 127 | def get_c(self): 128 | return 20 # class numbers gmm? 
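    # The get() override below converts each label from relative (class, cx, cy, w, h) format
    # to absolute corner coordinates so that the fastai coordinate transforms can be applied,
    # then converts the result back to relative center format, drops near-empty boxes and
    # pads/truncates to a fixed 50 boxes per image.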
129 | 130 | def get(self, tfm, x, y): # override so that tfm only handels part of the thingie 131 | # return (x,y) if tfm is None else tfm(x,y) 132 | w,h = x.shape[0], x.shape[1] 133 | #return x, y 134 | 135 | y1 = y[:, 0:1] 136 | y2 = y[:, 1:] 137 | y2[:, :2] -= y2[:, 2:]/2 # x1, y1, w, h 138 | y2[:, 2:] += y2[:, :2] # x1, y1, x2, y2 139 | y2[:, :] *= [h, w, h, w] # pixels 140 | #y2 *= model.width 141 | 142 | # swap y,x to x,y 143 | y2[:, 0], y2[:, 1] = y2[:, 1].copy(), y2[:, 0].copy() 144 | y2[:, 2], y2[:, 3] = y2[:, 3].copy(), y2[:, 2].copy() 145 | 146 | y2 = y2.reshape(-1) 147 | 148 | x, y2 = tfm(x,y2) 149 | 150 | y2 = y2.reshape(-1, 4) 151 | 152 | y2[:, 2:] -= y2[:, :2] 153 | y2[:, :2] += y2[:, 2:]/2 154 | 155 | # y2 /= model.width 156 | y2 /= self.sz 157 | # swap y,x to x,y 158 | y2[:, 1], y2[:, 0] = y2[:, 0].copy(), y2[:, 1].copy() 159 | y2[:, 3], y2[:, 2] = y2[:, 2].copy(), y2[:, 3].copy() 160 | 161 | y = np.concatenate((y1, y2), axis=1)[:50] # max 50 items 162 | y = y[(y[:, 3] > 0.001) & (y[:, 4] > 0.001)] 163 | 164 | if y.shape[0] < 50: 165 | y_pad = np.zeros((50-y.shape[0], 5)) 166 | y = np.concatenate((y, y_pad), 0) 167 | 168 | y = y.reshape(-1) 169 | return x, y 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/tests/__init__.py -------------------------------------------------------------------------------- /tests/mocks/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/tests/mocks/dog.jpg -------------------------------------------------------------------------------- /tests/mocks/person.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/tests/mocks/person.jpg -------------------------------------------------------------------------------- /tests/mocks/yololayer_tiny_0_get_loss_0.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/tests/mocks/yololayer_tiny_0_get_loss_0.p -------------------------------------------------------------------------------- /tests/mocks/yololayer_tiny_0_get_region_boxes_0.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/tests/mocks/yololayer_tiny_0_get_region_boxes_0.p -------------------------------------------------------------------------------- /tests/utils_test.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest, torch, unittest, bcolz 3 | import numpy as np 4 | # from unittest import mock 5 | # from unittest.mock import Mock 6 | # from testfixtures import tempdir 7 | 8 | from yolov3_pytorch import utils 9 | 10 | 11 | class NmsTest(unittest.TestCase): 12 | def test_empty(self): 13 | result = utils.nms([], .4) 14 | self.assertEqual(len(result), 0) 15 | 16 | def test_basic_big_array(self): 17 | boxes = [[.2, .2, .2, .2, .8], [.22, .22, .2, .2, 0.7], [.6, .6, .2, .2, 0.5], [.22, .22, .22, .22, 0.2], [.61, .61, .2, .2, 0.2], 18 | 
[.2, .2, .01, .01, .2]] 19 | result = utils.nms(boxes, .4) 20 | self.assertEqual(result, [[.2, .2, .2, .2, .8], [.6, .6, .2, .2, 0.5], [.2, .2, .01, .01, .2]]) 21 | 22 | def test_basic_array(self): 23 | boxes = [[.2, .2, .2, .2, .8], [.3, .2, .2, .2, 0.7]] 24 | result = utils.nms(boxes, .2) 25 | self.assertEqual(boxes, [[.2, .2, .2, .2, .8], [.3, .2, .2, .2, 0.7]], "nms should not change input") 26 | self.assertEqual(result, [[.2, .2, .2, .2, .8]]) 27 | 28 | boxes = [[.2, .2, .2, .2, .8], [.3, .2, .2, .2, 0.7]] 29 | result = utils.nms(boxes, .7) 30 | self.assertEqual(boxes, [[.2, .2, .2, .2, .8], [.3, .2, .2, .2, 0.7]], "nms should not change input") 31 | self.assertEqual(result, boxes) 32 | 33 | 34 | class BboxIouTest(unittest.TestCase): 35 | def test_bbox_iou_xywh_torch(self): 36 | box1 = torch.FloatTensor([.2, .2, .2, .2]) # from 0.1 to 0.3 37 | box2 = torch.FloatTensor([.3, .3, .2, .2]) # from 0.2 to 0.4 38 | iou = utils.bbox_iou(box1, box2, x1y1x2y2=False) 39 | 40 | intersect = (.1**2) 41 | iou_expected = intersect/(.2**2+.2**2-intersect) 42 | self.assertAlmostEqual(iou, iou_expected, places=5) 43 | 44 | def test_bbox_iou_xywh_numpy(self): 45 | box1 = np.array([.2, .2, .2, .2]) # from 0.1 to 0.3 46 | box2 = np.array([.3, .3, .2, .2]) # from 0.2 to 0.4 47 | iou = utils.bbox_iou(box1, box2, x1y1x2y2=False) 48 | 49 | intersect = (.1**2) 50 | iou_expected = intersect/(.2**2+.2**2-intersect) 51 | self.assertAlmostEqual(iou, iou_expected, places=5) 52 | 53 | def test_bbox_iou_xyxy(self): 54 | box1 = torch.FloatTensor([.1, .1, .3, .3]) 55 | box2 = torch.FloatTensor([.2, .2, .4, .4]) 56 | iou = utils.bbox_iou(box1, box2, x1y1x2y2=True) 57 | 58 | intersect = (.1**2) 59 | iou_expected = intersect/(.2**2+.2**2-intersect) 60 | self.assertAlmostEqual(iou, iou_expected, places=5) 61 | 62 | class MultiBboxIouTest(unittest.TestCase): 63 | def test_multi_bbox_iou_xywh(self): 64 | box1 = torch.FloatTensor([[.2, .2, .2, .2], [.8,.8,.2,.2]]) 65 | box2 = torch.FloatTensor([[.2, .2, .2, .2], [.7,.7,.2,.2]]) 66 | 67 | iou_from_singles_0 = utils.bbox_iou(box1[0], box2[0], x1y1x2y2=False) 68 | iou_from_singles_1 = utils.bbox_iou(box1[1], box2[1], x1y1x2y2=False) 69 | 70 | box1 = box1.t().view(4,2) 71 | box2 = box2.t().view(4,2) 72 | iou = utils.multi_bbox_ious(box1, box2, x1y1x2y2=False) 73 | 74 | self.assertAlmostEqual(iou[0], iou_from_singles_0, places=5) 75 | self.assertAlmostEqual(iou[1], iou_from_singles_1, places=5) 76 | 77 | def test_multi_bbox_iou_xyxy(self): 78 | box1 = torch.FloatTensor([[.1, .1, .2, .2], [.7,.7,.9,.9]]) 79 | box2 = torch.FloatTensor([[.1, .1, .2, .2], [.6,.6,.8,.8]]) 80 | 81 | iou_from_singles_0 = utils.bbox_iou(box1[0], box2[0], x1y1x2y2=True) 82 | iou_from_singles_1 = utils.bbox_iou(box1[1], box2[1], x1y1x2y2=True) 83 | 84 | box1 = box1.t().view(4,2) 85 | box2 = box2.t().view(4,2) 86 | iou = utils.multi_bbox_ious(box1, box2, x1y1x2y2=True) 87 | 88 | self.assertAlmostEqual(iou[0], iou_from_singles_0, places=5) 89 | self.assertAlmostEqual(iou[1], iou_from_singles_1, places=5) 90 | 91 | -------------------------------------------------------------------------------- /tests/yolo_layer_test.py: -------------------------------------------------------------------------------- 1 | import pytest, torch, unittest, bcolz, pickle 2 | import numpy as np 3 | # from unittest import mock 4 | # from unittest.mock import Mock 5 | # from testfixtures import tempdir 6 | 7 | from yolov3_pytorch import utils 8 | from yolov3_pytorch.yolo_layer import * 9 | 10 | 11 | class 
YoloGetRegionBoxesTest(unittest.TestCase): 12 | 13 | def test_get_region_boxes_large(self): 14 | # output = model(img_torch)[0] 15 | # region_boxes = model.yolo_0.get_region_boxes(output, conf_thresh=.25) 16 | # region_boxes = [[[i.item() for i in box] for box in region_boxes[0]]] 17 | # attrs = {'conf_thresh': .25, 'num_classes': model.yolo_0.num_classes, 'stride': model.yolo_0.stride, 18 | # 'anchors': model.yolo_0.anchors} 19 | # key_val = {'output': output, 'region_boxes': region_boxes, 'attrs': attrs} 20 | # pickle.dump(key_val, open("tests/data/yololayer_tiny_0_get_region_boxes_0.p", "wb")) 21 | 22 | key_val = pickle.load(open("tests/mocks/yololayer_tiny_0_get_region_boxes_0.p", "rb")) 23 | attrs = key_val['attrs']; output = key_val['output']; target = key_val['region_boxes'] 24 | 25 | yolo = YoloLayer(anchors=attrs['anchors'], stride=attrs['stride'], num_classes=attrs['num_classes']) 26 | 27 | region_boxes = yolo.get_region_boxes(output, conf_thresh=attrs['conf_thresh']) 28 | region_boxes = [[[float(i) for i in box] for box in region_boxes[0]]] 29 | 30 | region_boxes = sorted(region_boxes[0], key=lambda x: x[0]) 31 | target = sorted(target[0], key=lambda x: x[0]) 32 | 33 | self.assertEqual(region_boxes, target) 34 | 35 | 36 | class YoloGetLossTest(unittest.TestCase): 37 | 38 | # Pickle file was created when it was working. This is mainly to check that refactorings wont break the code. 39 | def test_get_loss_large_data(self): 40 | key_val = pickle.load(open("tests/mocks/yololayer_tiny_0_get_loss_0.p", "rb")) 41 | attrs = key_val['yolo']; output = key_val['output']; target = key_val['target']; losses = key_val['losses'] 42 | 43 | yolo = YoloLayer(anchors=attrs['anchors'], stride=attrs['stride'], num_classes=attrs['num_classes']) 44 | 45 | loss_total, loss_coord, loss_conf, loss_cls = yolo.get_loss(output, target, return_single_value=False) 46 | 47 | self.assertAlmostEqual(losses[0], loss_total.item()) 48 | self.assertAlmostEqual(losses[1], loss_coord.item()) 49 | self.assertAlmostEqual(losses[2], loss_conf.item()) 50 | self.assertAlmostEqual(losses[3], loss_cls.item()) 51 | 52 | def test_get_loss_large_data_cuda(self): 53 | key_val = pickle.load(open("tests/mocks/yololayer_tiny_0_get_loss_0.p", "rb")) 54 | attrs = key_val['yolo']; output = key_val['output']; target = key_val['target']; losses = key_val['losses'] 55 | 56 | output.to('cuda') 57 | 58 | yolo = YoloLayer(anchors=attrs['anchors'], stride=attrs['stride'], num_classes=attrs['num_classes']) 59 | 60 | loss_total, loss_coord, loss_conf, loss_cls = yolo.get_loss(output, target, return_single_value=False) 61 | 62 | self.assertAlmostEqual(losses[0], loss_total.item()) 63 | self.assertAlmostEqual(losses[1], loss_coord.item()) 64 | self.assertAlmostEqual(losses[2], loss_conf.item()) 65 | self.assertAlmostEqual(losses[3], loss_cls.item()) 66 | 67 | 68 | # def test_get_loss_calculated(self): 69 | # yolo = YoloLayer(anchors=attrs['anchors'], stride=attrs['stride'], num_classes=attrs['num_classes']) 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /tests/yolov3_test.py: -------------------------------------------------------------------------------- 1 | import pytest, torch, unittest, bcolz, pickle 2 | import numpy as np 3 | # from unittest import mock 4 | # from unittest.mock import Mock 5 | # from testfixtures import tempdir 6 | from PIL import Image 7 | 8 | from yolov3_pytorch import utils 9 | # from yolov3_pytorch.yolo_layer import * 10 | from yolov3_pytorch.yolov3 import 
* 11 | 12 | # import pdb; pdb.set_trace() 13 | 14 | class IntegrationYolov3Test(unittest.TestCase): 15 | 16 | def test_basic_process(self): 17 | model = Yolov3(num_classes=80) 18 | model.load_state_dict(torch.load('data/models/yolov3_coco_01.h5')) 19 | _ = model.eval() # .cuda() 20 | 21 | sz = 416 22 | imgfile = "tests/mocks/person.jpg" 23 | img_org = Image.open(imgfile).convert('RGB') 24 | img_resized = img_org.resize((sz, sz)) 25 | img_torch = utils.image2torch(img_resized) 26 | 27 | output_all = model(img_torch) 28 | self.assertEqual(len(output_all), 3, "Basic output should be for 3 yolo layers") 29 | self.assertEqual(list(output_all[0].shape), [1, 255, 13, 13], "Basic output shape should be correct") 30 | 31 | all_boxes = model.predict_img(img_torch)[0] 32 | self.assertTrue(len(all_boxes) > 2, "Should detect something in img") 33 | self.assertTrue(len(all_boxes) < 20, "Should not detect too much in img") 34 | 35 | nms_boxes = utils.nms(all_boxes, .4) 36 | persons = [a for a in nms_boxes if a[-1] == 0] 37 | 38 | self.assertEqual(1, len(persons), "Should detect one person in img") 39 | 40 | # If there is a little difference it might not be a bug, but its here to alert. 41 | # You can comment away tests that are failing if things are otherwise correct 42 | self.assertEqual(13, len(all_boxes), "Something has changed in predictions") 43 | previous_person = [0.36519092321395874, 0.5594505071640015, 0.13189448416233063, 0.6678988933563232, 0.9999792575836182, 1.0, 0] 44 | self.assertEqual(previous_person, persons[0], "Something has changed in predictions") 45 | 46 | 47 | -------------------------------------------------------------------------------- /tests/yolov3_tiny_test.py: -------------------------------------------------------------------------------- 1 | import pytest, torch, unittest, bcolz, pickle 2 | import numpy as np 3 | # from unittest import mock 4 | # from unittest.mock import Mock 5 | # from testfixtures import tempdir 6 | from PIL import Image 7 | 8 | from yolov3_pytorch import utils 9 | from yolov3_pytorch.yolo_layer import * 10 | from yolov3_pytorch.yolov3_tiny import * 11 | 12 | # import pdb; pdb.set_trace() 13 | 14 | class IntegrationYolov3TinyTest(unittest.TestCase): 15 | 16 | def test_basic_process(self): 17 | model = Yolov3Tiny(num_classes=80, use_wrong_previous_anchors=True) 18 | model.load_state_dict(torch.load('data/models/yolov3_tiny_coco_01.h5')) 19 | _ = model.eval() # .cuda() 20 | 21 | sz = 416 22 | imgfile = "tests/mocks/person.jpg" 23 | img_org = Image.open(imgfile).convert('RGB') 24 | img_resized = img_org.resize((sz, sz)) 25 | img_torch = utils.image2torch(img_resized) 26 | 27 | output_all = model(img_torch) 28 | self.assertEqual(len(output_all), 2, "Basic output should be for 2 yolo layers") 29 | self.assertEqual(list(output_all[0].shape), [1, 255, 13, 13], "Basic output shape should be correct") 30 | 31 | all_boxes = model.predict_img(img_torch)[0] 32 | self.assertTrue(len(all_boxes) > 2, "Should detect something in img") 33 | self.assertTrue(len(all_boxes) < 20, "Should not detect too much in img") 34 | 35 | nms_boxes = utils.nms(all_boxes, .4) 36 | persons = [a for a in nms_boxes if a[-1] == 0] 37 | 38 | self.assertEqual(1, len(persons), "Should detect one person in img") 39 | 40 | # If there is a little difference it might not be a bug, but its here to alert. 
41 | # You can comment away tests that are failing if things are otherwise correct 42 | self.assertEqual(9, len(all_boxes), "Something has changed in predictions") 43 | previous_person = [0.35972878336906433, 0.5600799322128296, 0.15276280045509338, 0.6586271524429321, 0.9670860767364502, 1.0, 0] 44 | self.assertEqual(previous_person, persons[0], "Something has changed in predictions") 45 | 46 | 47 | -------------------------------------------------------------------------------- /yolov3_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/holli/yolov3_pytorch/a01ce7a4c56634a7cd2acbb00c95c5eeaf142e33/yolov3_pytorch/__init__.py -------------------------------------------------------------------------------- /yolov3_pytorch/fastai_utils.py: -------------------------------------------------------------------------------- 1 | import fastai 2 | from fastai.imports import * 3 | from fastai.conv_learner import * 4 | from fastai.model import * 5 | from fastai.dataset import * 6 | import collections 7 | 8 | 9 | def learn_sched_plot(learn): 10 | type(learn.sched) 11 | fig, axs = plt.subplots(1,2,figsize=(12,4)) 12 | plt.sca(axs[0]) 13 | learn.sched.plot_loss(0, 1) 14 | plt.sca(axs[1]) 15 | learn.sched.plot_lr() 16 | 17 | 18 | class MultiArraysDataset(BaseDataset): 19 | def __init__(self, y, xs, sz=None): 20 | self.xs,self.y = xs,y 21 | self.sz = sz 22 | assert(all([len(y)==len(x) for x in xs])) 23 | 24 | #super().__init__(None) 25 | self.n = self.get_n() 26 | self.sz = self.get_sz() 27 | self.transform = None 28 | 29 | def get1item(self, idx): 30 | x,y = self.get_x(idx), self.get_y(idx) 31 | return (*x, y) 32 | 33 | def __getitem__(self, idx): 34 | return self.get1item(idx) 35 | 36 | def get_x(self, i): return [x[i] for x in self.xs] 37 | def get_y(self, i): return self.y[i] 38 | def get_n(self): return len(self.y) 39 | def get_sz(self): return self.sz 40 | 41 | 42 | class YoloLearner(Learner): 43 | def __init__(self, data, model, **kwargs): 44 | self.precompute = False 45 | super().__init__(data=data, models=None, tmp_name='/tmp/yolo_tmp/', **kwargs) 46 | self._model = model 47 | self.name = 'yolo_test_1' 48 | 49 | @property 50 | def model(self): 51 | if self.precompute: 52 | self._model.skip_backbone = True 53 | else: 54 | self._model.skip_backbone = False 55 | return self._model 56 | 57 | @property 58 | def data(self): return self.backbone_data if self.precompute else self.data_ 59 | 60 | def get_layer_groups(self): 61 | modules = list(self.model.children()) 62 | groups = list(split_by_idxs(modules, [1])) 63 | return groups 64 | 65 | def set_precomputed(self, force_predict=False): 66 | self.save_backbone_data(force_predict=force_predict) 67 | self.precompute = True 68 | 69 | @staticmethod 70 | def create_empty_bcolz(n, rootdir): 71 | if not os.path.exists(rootdir): os.makedirs(rootdir) 72 | return bcolz.carray(np.zeros(n, np.float32), chunklen=1, mode='w', rootdir=rootdir) 73 | 74 | @staticmethod 75 | def predict_backbone_to_bcolz(m, gen, arrs, workers=4): 76 | device = next(m.parameters()).device 77 | #arrs.trim(len(arr)) 78 | lock=threading.Lock() 79 | m.eval() 80 | for x_org, target in tqdm(gen): 81 | x_acts = m(x_org.to(device)) 82 | with lock: 83 | arrs[0].append(target) 84 | for i, x in enumerate(x_acts): 85 | arrs[i+1].append(to_np(x)) 86 | [a.flush() for a in arrs] 87 | 88 | def save_backbone_data(self, force_predict=False): 89 | tmpl = f'_{self.name}_{self.data_.sz}.bc' 90 | activations = [] 91 | 92 | # 
Checking what shape of activations backbone outputs 93 | sample_batch = next(iter(self.data_.trn_dl)) 94 | data_shapes = [(sample_batch[1][0].size(0),)] 95 | for out in self._model.backbone(sample_batch[0][0:1].to(next(self._model.backbone.parameters()).device)): 96 | data_shapes.append(out[0].detach().cpu().numpy().shape) 97 | 98 | # Creating or loading bcolz files 99 | dls = (self.data_.fix_dl, self.data_.val_dl) 100 | paths = ('x_act_trn', 'x_act_val') 101 | for i in range(len(dls)): 102 | name = os.path.join(self.tmp_path, paths[i]+tmpl) 103 | shapes_r = range(len(data_shapes)) 104 | 105 | if os.path.exists(os.path.join(name,str(1))) and not force_predict: 106 | acts = [bcolz.open(os.path.join(name,str(i))) for i in shapes_r] 107 | else: 108 | acts = [self.create_empty_bcolz((0,*data_shapes[i]),os.path.join(name,str(i))) for i in shapes_r] 109 | 110 | activations.append(acts) 111 | 112 | # Validate or predict activations 113 | m = self._model.backbone 114 | for acts, dl in zip(activations, dls): 115 | if any([len(a)>0 for a in acts]): 116 | # compare loaded activations to expected 117 | for a, s in zip(acts, data_shapes): 118 | exp_shape = (len(dl.dataset), *s) 119 | if a.shape != exp_shape: 120 | raise TypeError(f"Previous backbone activations won't match. You might want to call learn.save_backbone_data(force_predict=True) to erase previous or change the name of the model. {a.shape} != {exp_shape}: {a.rootdir}") 121 | else: 122 | self.predict_backbone_to_bcolz(m, dl, acts) 123 | 124 | ds_trn = MultiArraysDataset(activations[0][0], activations[0][1:], sz=self.data_.fix_dl.dataset.sz) 125 | ds_val = MultiArraysDataset(activations[1][0], activations[1][1:], sz=self.data_.fix_dl.dataset.sz) 126 | dl_trn = DataLoader(ds_trn, batch_size=self.data_.bs, shuffle=True) 127 | dl_val = DataLoader(ds_val, batch_size=self.data_.bs, shuffle=False) 128 | self.backbone_data = ModelData(self.data_.path, dl_trn, dl_val) 129 | 130 | 131 | class YoloLoss(): 132 | def __init__(self, model, max_history=10000, model_reset_overwrite=True): 133 | self.model, self.max_history = model, max_history 134 | self.reset() 135 | if model_reset_overwrite: 136 | model.reset = self.reset # fastai call's this before validation run 137 | 138 | def __call__(self, output, target): 139 | losses_all = [] 140 | total_losses = [] 141 | 142 | for i, layer in enumerate(self.model.get_loss_layers()): 143 | losses = layer.get_loss(output[i], target, return_single_value=False) 144 | total_losses.append(losses[0]) 145 | losses_all.append([l.item() for l in losses]) 146 | 147 | if self.max_history: 148 | self.history.append(losses_all) 149 | 150 | return sum(total_losses) 151 | 152 | def reset(self): 153 | if self.max_history: 154 | self.history = collections.deque(maxlen = self.max_history) 155 | else: 156 | self.history = None 157 | 158 | 159 | class YoloLossMetrics(): 160 | def __init__(self, yolo_loss): 161 | self.yolo_loss = yolo_loss 162 | self.set_n_layers() 163 | 164 | def set_n_layers(self): 165 | self.n_layers = len(self.yolo_loss.model.get_loss_layers()) 166 | 167 | def layer_losses(self): 168 | arr = [] 169 | for i in range(self.n_layers): 170 | l = lambda a=0,b=0,i=i: sum([h[i][0] for h in self.yolo_loss.history])/len(self.yolo_loss.history) 171 | l.__name__ = f"yolo_l_{i}" 172 | arr.append(l) 173 | return arr 174 | 175 | def individual_losses(self): 176 | arr = [] 177 | for i in range(1,4): 178 | l = lambda a=0,b=0,i=i: sum([sum([h[i] for h in his]) for his in self.yolo_loss.history])/(len(self.yolo_loss.history)) 179 | 
l.__name__ = ['total_loss', 'loss_coord', 'loss_conf', 'loss_cls'][i] 180 | arr.append(l) 181 | return arr 182 | -------------------------------------------------------------------------------- /yolov3_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from PIL import Image, ImageDraw, ImageFont 3 | import sys 4 | import os 5 | import time 6 | import math 7 | import torch 8 | import numpy as np 9 | from matplotlib import pyplot as plt, rcParams, animation, patches, patheffects 10 | 11 | 12 | 13 | def nms(boxes, nms_thresh): 14 | if len(boxes) == 0: 15 | return boxes 16 | 17 | confs = [(1-b[4]) for b in boxes] 18 | sorted_idx = np.argsort(confs) 19 | out_boxes = [] 20 | 21 | for i in range(len(boxes)): 22 | box_i = boxes[sorted_idx[i]] 23 | if confs[i] > -1: 24 | out_boxes.append(box_i) 25 | for j in range(i+1, len(boxes)): 26 | if confs[j] > -1: 27 | box_j = boxes[sorted_idx[j]] 28 | if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh: 29 | confs[j] = -1 30 | return out_boxes 31 | 32 | 33 | def do_detect(model, img, conf_thresh, nms_thresh, use_cuda=True): 34 | model.eval() 35 | img = image2torch(img) 36 | img = img.to(torch.device("cuda" if use_cuda else "cpu")) 37 | all_boxes = model.predict_img(img)[0] 38 | boxes = nms(all_boxes, nms_thresh) 39 | return boxes 40 | 41 | 42 | def image2torch(img): 43 | if isinstance(img, Image.Image): 44 | width = img.width 45 | height = img.height 46 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) 47 | img = img.view(height, width, 3).transpose(0,1).transpose(0,2).contiguous() 48 | img = img.view(1, 3, height, width) 49 | img = img.float().div(255.0) 50 | elif type(img) == np.ndarray: # cv2 image 51 | img = torch.from_numpy(img.transpose(2,0,1)).float().div(255.0).unsqueeze(0) 52 | else: 53 | print("unknown image type") 54 | exit(-1) 55 | return img 56 | 57 | 58 | def bbox_iou(box1, box2, x1y1x2y2=True): 59 | if x1y1x2y2: 60 | x1_min = min(box1[0], box2[0]) 61 | x2_max = max(box1[2], box2[2]) 62 | y1_min = min(box1[1], box2[1]) 63 | y2_max = max(box1[3], box2[3]) 64 | w1, h1 = box1[2] - box1[0], box1[3] - box1[1] 65 | w2, h2 = box2[2] - box2[0], box2[3] - box2[1] 66 | else: 67 | w1, h1 = box1[2], box1[3] 68 | w2, h2 = box2[2], box2[3] 69 | x1_min = min(box1[0]-w1/2.0, box2[0]-w2/2.0) 70 | x2_max = max(box1[0]+w1/2.0, box2[0]+w2/2.0) 71 | y1_min = min(box1[1]-h1/2.0, box2[1]-h2/2.0) 72 | y2_max = max(box1[1]+h1/2.0, box2[1]+h2/2.0) 73 | 74 | w_union = x2_max - x1_min 75 | h_union = y2_max - y1_min 76 | w_cross = w1 + w2 - w_union 77 | h_cross = h1 + h2 - h_union 78 | carea = 0 79 | if w_cross <= 0 or h_cross <= 0: 80 | return 0.0 81 | 82 | area1 = w1 * h1 83 | area2 = w2 * h2 84 | carea = w_cross * h_cross 85 | uarea = area1 + area2 - carea 86 | return float(carea/uarea) 87 | 88 | 89 | def multi_bbox_ious(boxes1, boxes2, x1y1x2y2=True): 90 | if x1y1x2y2: 91 | x1_min = torch.min(boxes1[0], boxes2[0]) 92 | x2_max = torch.max(boxes1[2], boxes2[2]) 93 | y1_min = torch.min(boxes1[1], boxes2[1]) 94 | y2_max = torch.max(boxes1[3], boxes2[3]) 95 | w1, h1 = boxes1[2] - boxes1[0], boxes1[3] - boxes1[1] 96 | w2, h2 = boxes2[2] - boxes2[0], boxes2[3] - boxes2[1] 97 | else: 98 | w1, h1 = boxes1[2], boxes1[3] 99 | w2, h2 = boxes2[2], boxes2[3] 100 | x1_min = torch.min(boxes1[0]-w1/2.0, boxes2[0]-w2/2.0) 101 | x2_max = torch.max(boxes1[0]+w1/2.0, boxes2[0]+w2/2.0) 102 | y1_min = torch.min(boxes1[1]-h1/2.0, boxes2[1]-h2/2.0) 103 | y2_max = torch.max(boxes1[1]+h1/2.0, 
boxes2[1]+h2/2.0) 104 | 105 | w_union = x2_max - x1_min 106 | h_union = y2_max - y1_min 107 | w_cross = w1 + w2 - w_union 108 | h_cross = h1 + h2 - h_union 109 | mask = (((w_cross <= 0) + (h_cross <= 0)) > 0) 110 | area1 = w1 * h1 111 | area2 = w2 * h2 112 | carea = w_cross * h_cross 113 | carea[mask] = 0 114 | uarea = area1 + area2 - carea 115 | return carea/uarea 116 | 117 | 118 | ################################################################### 119 | ## Plotting helpers 120 | 121 | # e.g. plot_multi_detections(img_tensor, model.predict_img(img_tensor)) 122 | def plot_multi_detections(imgs, results, figsize=None, **kwargs): 123 | if not figsize: 124 | figsize = (12, min(math.ceil(len(imgs)/3)*4, 30)) 125 | _, axes = plt.subplots(math.ceil(len(imgs)/3), 3, figsize=figsize) 126 | 127 | if type(imgs) == np.ndarray and len(imgs.shape) == 4: 128 | imgs = [imgs] 129 | 130 | classes = [] 131 | boxes = [] 132 | extras = [] 133 | for r in results: 134 | res = np.array([[float(b) for b in arr] for arr in r]) 135 | if len(res) > 0: 136 | cla = res[:, -1].astype(int) 137 | b = res[:, 0:4] 138 | e = ["{:.2f} ({:.2f})".format(float(y[4]), float(y[5])) for y in res] 139 | else: 140 | cla, b, e = [], [], [] 141 | classes.append(cla) 142 | boxes.append(b) 143 | extras.append(e) 144 | 145 | for j, ax in enumerate(axes.flat): 146 | if j >= len(imgs): 147 | #break 148 | plt.delaxes(ax) 149 | else: 150 | plot_img_boxes(imgs[j], boxes[j], classes[j], extras[j], plt_ax=ax, **kwargs) 151 | 152 | plt.tight_layout() 153 | 154 | 155 | def plot_img_detections(img, result_boxes, **kwargs): 156 | b = np.array(result_boxes) 157 | if len(b) > 0: 158 | classes = b[:, -1].astype(int) 159 | boxes = b[:, 0:4] 160 | else: 161 | classes, boxes = [], [] 162 | extras = ["{:.2f} ({:.2f})".format(b[4], b[5]) for b in result_boxes] 163 | return plot_img_boxes(img, boxes, classes, extras=extras, **kwargs) 164 | 165 | 166 | def plot_img_data(x, y, rows=2, figsize=(12, 8), **kwargs): 167 | _, axes = plt.subplots(rows, 3, figsize=figsize) 168 | 169 | for j, ax in enumerate(axes.flat): 170 | if j >= len(y): 171 | break 172 | targets = y[j] 173 | if isinstance(targets, torch.Tensor): 174 | targets = targets.clone().reshape(-1,5) 175 | classes = targets[:, 0].cpu().numpy().astype(int) 176 | else: 177 | classes = targets[:, 0].astype(int) 178 | plot_img_boxes(x[j], targets[:, 1:], classes, plt_ax=ax, **kwargs) 179 | 180 | plt.tight_layout() 181 | 182 | 183 | def plot_img_boxes(img, boxes, classes, extras=None, plt_ax=None, figsize=None, class_names=None, real_pixels=False, box_centered=True): 184 | if not plt_ax: 185 | _, plt_ax = plt.subplots(figsize=figsize) 186 | colors = np.array([[1,0,1],[0,0,1],[0,1,1],[0,1,0],[1,1,0],[1,0,0]]) 187 | 188 | if type(img) == PIL.Image.Image: 189 | width = img.width 190 | height = img.height 191 | elif type(img) in [torch.Tensor, np.ndarray]: 192 | # if len(img.shape)>3: img = img[0] 193 | if type(img) == torch.Tensor: 194 | img = img.clone().cpu().numpy() 195 | width = img.shape[2] 196 | height = img.shape[1] 197 | img = img.transpose(1,2,0) 198 | if (img < 1.01).all() and (img >= 0).all(): 199 | img = img.clip(0, 1) # avoid "Clipping input data to the valid range" warning after tensor roundings 200 | else: 201 | raise(f"Unkown type for image: {type(img)}") 202 | 203 | if len(boxes) > 0 and not real_pixels: 204 | boxes[:, 0] *= width; boxes[:, 2] *= width 205 | boxes[:, 1] *= height; boxes[:, 3] *= height 206 | 207 | for i in range(len(boxes)): 208 | b, class_id = boxes[i], classes[i] 209 | if b[0] 
== 0: 210 | break 211 | 212 | color = colors[class_id%len(colors)] 213 | 214 | if box_centered: 215 | x, y = (b[0]-b[2]/2, b[1]-b[3]/2) 216 | w, h = (b[2], b[3]) 217 | else: 218 | x, y = b[0], b[1] 219 | w, h = b[2], b[3] 220 | 221 | patch = plt_ax.add_patch(patches.Rectangle([x, y], w, h, fill=False, edgecolor=color, lw=2)) 222 | patch.set_path_effects([patheffects.Stroke(linewidth=3, foreground='black', alpha=0.5), patheffects.Normal()]) 223 | 224 | s = class_names[class_id] if class_names else str(class_id) 225 | if extras: 226 | s += "\n"+str(extras[i]) 227 | patch = plt_ax.text(x+2, y, s, verticalalignment='top', color=color, fontsize=16, weight='bold') 228 | patch.set_path_effects([patheffects.Stroke(linewidth=1, foreground='black', alpha=0.5), patheffects.Normal()]) 229 | 230 | _ = plt_ax.imshow(img) 231 | 232 | -------------------------------------------------------------------------------- /yolov3_pytorch/yolo_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .utils import bbox_iou, multi_bbox_ious 10 | 11 | 12 | class YoloLayer(nn.Module): 13 | def __init__(self, anchors, stride, num_classes): 14 | super().__init__() 15 | self.anchors, self.stride = np.array(anchors), stride 16 | self.num_classes = num_classes 17 | 18 | def get_masked_anchors(self): 19 | return self.anchors/self.stride 20 | 21 | def get_region_boxes(self, output, conf_thresh): 22 | if output.dim() == 3: output = output.unsqueeze(0) 23 | device = output.device # torch.device(torch_device) 24 | anchors = torch.from_numpy(self.get_masked_anchors().astype(np.float32)).to(device) 25 | 26 | nB = output.size(0) 27 | nA = len(anchors) 28 | nC = self.num_classes 29 | nH = output.size(2) 30 | nW = output.size(3) 31 | cls_anchor_dim = nB*nA*nH*nW 32 | 33 | assert output.size(1) == (5+nC)*nA 34 | 35 | # if you want to debug this is how you get the indexes where objectness is high 36 | #output = output.view(nB, nA, 5+nC, nH, nW) 37 | #inds = torch.nonzero((torch.sigmoid(output.view(nB, nA, 5+nC, nH, nW)[:,:,4,:,:]) > conf_thresh)) 38 | 39 | output = output.view(nB*nA, 5+nC, nH*nW).transpose(0,1).contiguous().view(5+nC, cls_anchor_dim) 40 | 41 | grid_x = torch.linspace(0, nW-1, nW).repeat(nB*nA, nH, 1).view(cls_anchor_dim).to(device) 42 | grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(cls_anchor_dim).to(device) 43 | ix = torch.LongTensor(range(0,2)).to(device) 44 | anchor_w = anchors.index_select(1, ix[0]).repeat(1, nB, nH*nW).view(cls_anchor_dim) 45 | anchor_h = anchors.index_select(1, ix[1]).repeat(1, nB, nH*nW).view(cls_anchor_dim) 46 | 47 | xs, ys = torch.sigmoid(output[0]) + grid_x, torch.sigmoid(output[1]) + grid_y 48 | ws, hs = torch.exp(output[2]) * anchor_w.detach(), torch.exp(output[3]) * anchor_h.detach() 49 | det_confs = torch.sigmoid(output[4]) 50 | 51 | cls_confs = torch.nn.Softmax(dim=1)(output[5:5+nC].transpose(0,1)).detach() 52 | cls_max_confs, cls_max_ids = torch.max(cls_confs, 1) 53 | cls_max_confs = cls_max_confs.view(-1) 54 | cls_max_ids = cls_max_ids.view(-1) 55 | 56 | 57 | det_confs = det_confs.to('cpu') #, non_blocking=True for torch 4.1? 
58 | cls_max_confs = cls_max_confs.to('cpu') 59 | cls_max_ids = cls_max_ids.to('cpu') 60 | xs, ys = xs.to('cpu'), ys.to('cpu') 61 | ws, hs = ws.to('cpu'), hs.to('cpu') 62 | 63 | 64 | all_boxes = [[] for i in range(nB)] 65 | 66 | inds = torch.LongTensor(range(0,len(det_confs))) 67 | for ind in inds[det_confs > conf_thresh]: 68 | bcx = xs[ind] 69 | bcy = ys[ind] 70 | bw = ws[ind] 71 | bh = hs[ind] 72 | # box = [bcx/nW, bcy/nH, bw/nW, bh/nH, det_confs[ind], cls_max_confs[ind], cls_max_ids[ind]] 73 | box = [bcx/nW, bcy/nH, bw/nW, bh/nH, det_confs[ind], cls_max_confs[ind], cls_max_ids[ind]] 74 | box = [i.item() for i in box] 75 | 76 | batch = math.ceil(ind/(nA*nH*nW)) 77 | all_boxes[batch].append(box) 78 | 79 | return all_boxes 80 | 81 | 82 | def build_targets(self, pred_boxes, target, anchors, nH, nW): 83 | self.ignore_thresh = 0.5 84 | self.truth_thresh = 1. 85 | 86 | # Works faster on CPU than on GPU. 87 | devi = torch.device('cpu') 88 | pred_boxes = pred_boxes.to(devi) 89 | target = target.to(devi) 90 | anchors = anchors.to(devi) 91 | 92 | #max_targets = target[0].view(-1,5).size(0) # 50 93 | nB = target.size(0) 94 | nA = len(anchors) 95 | 96 | anchor_step = anchors.size(1) # anchors[nA][anchor_step] 97 | conf_mask = torch.ones (nB, nA, nH, nW) 98 | coord_mask = torch.zeros(nB, nA, nH, nW) 99 | cls_mask = torch.zeros(nB, nA, nH, nW) 100 | tcoord = torch.zeros( 4, nB, nA, nH, nW) 101 | tconf = torch.zeros(nB, nA, nH, nW) 102 | tcls = torch.zeros(nB, nA, nH, nW) 103 | #twidth, theight = self.net_width/self.stride, self.net_height/self.stride 104 | twidth, theight = nW, nH 105 | nAnchors = nA*nH*nW 106 | 107 | for b in range(nB): 108 | cur_pred_boxes = pred_boxes[b*nAnchors:(b+1)*nAnchors].t() 109 | cur_ious = torch.zeros(nAnchors) 110 | tbox = target[b].view(-1,5) 111 | 112 | # If the bounding box prior is not the best but does overlap a ground truth object by 113 | # more than some threshold we ignore the prediction (conf_mask) 114 | for t in range(tbox.size(0)): 115 | if tbox[t][1] == 0: 116 | break 117 | gx, gy = tbox[t][1] * nW, tbox[t][2] * nH 118 | gw, gh = tbox[t][3] * twidth, tbox[t][4] * theight 119 | cur_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).repeat(nAnchors,1).t() 120 | cur_ious = torch.max(cur_ious, multi_bbox_ious(cur_pred_boxes, cur_gt_boxes, x1y1x2y2=False)) 121 | ignore_ix = cur_ious>self.ignore_thresh 122 | conf_mask[b][ignore_ix.view(nA,nH,nW)] = 0 123 | 124 | for t in range(tbox.size(0)): 125 | if tbox[t][1] == 0: 126 | break 127 | # nGT += 1 128 | gx, gy = tbox[t][1] * nW, tbox[t][2] * nH 129 | gw, gh = tbox[t][3] * twidth, tbox[t][4] * theight 130 | gw, gh = gw.float(), gh.float() 131 | gi, gj = int(gx), int(gy) 132 | 133 | tmp_gt_boxes = torch.FloatTensor([0, 0, gw, gh]).repeat(nA,1).t() 134 | anchor_boxes = torch.cat((torch.zeros(nA, anchor_step), anchors),1).t() 135 | _, best_n = torch.max(multi_bbox_ious(tmp_gt_boxes, anchor_boxes, x1y1x2y2=False), 0) 136 | 137 | coord_mask[b][best_n][gj][gi] = 1 138 | cls_mask [b][best_n][gj][gi] = 1 139 | conf_mask [b][best_n][gj][gi] = 1 140 | tcoord [0][b][best_n][gj][gi] = gx - gi 141 | tcoord [1][b][best_n][gj][gi] = gy - gj 142 | tcoord [2][b][best_n][gj][gi] = math.log(gw/anchors[best_n][0]) 143 | tcoord [3][b][best_n][gj][gi] = math.log(gh/anchors[best_n][1]) 144 | tcls [b][best_n][gj][gi] = tbox[t][0] 145 | tconf [b][best_n][gj][gi] = 1 # yolov1 would have used iou-value here 146 | 147 | return coord_mask, conf_mask, cls_mask, tcoord, tconf, tcls 148 | 149 | 150 | def get_loss(self, output, target, 
return_single_value=True): 151 | device = output.device 152 | 153 | anchors = torch.from_numpy(self.get_masked_anchors().astype(np.float32)).to(device) 154 | 155 | nB = output.data.size(0) # batch size 156 | nA = len(anchors) 157 | nC = self.num_classes 158 | nH = output.data.size(2) 159 | nW = output.data.size(3) 160 | cls_anchor_dim = nB*nA*nH*nW 161 | 162 | output = output.view(nB, nA, (5+nC), nH, nW) 163 | 164 | ix = torch.LongTensor(range(0,5)).to(device) 165 | coord = output.index_select(2, ix[0:4]).view(nB*nA, -1, nH*nW).transpose(0,1).contiguous().view(4,cls_anchor_dim) # x, y, w, h 166 | coord[0:2] = coord[0:2].sigmoid() # x, y: bx = σ(tx) (+ cx) 167 | conf = output.index_select(2, ix[4]).view(nB, nA, nH, nW).sigmoid() 168 | 169 | grid_x = torch.linspace(0, nW-1, nW).repeat(nB*nA, nH, 1).view(cls_anchor_dim).to(device) 170 | grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(cls_anchor_dim).to(device) 171 | anchor_w = anchors.index_select(1, ix[0]).repeat(1, nB*nH*nW).view(cls_anchor_dim) 172 | anchor_h = anchors.index_select(1, ix[1]).repeat(1, nB*nH*nW).view(cls_anchor_dim) 173 | 174 | pred_boxes = torch.FloatTensor(4, cls_anchor_dim).to(device) 175 | pred_boxes[0] = coord[0] + grid_x # bx = σ(tx) + cx 176 | pred_boxes[1] = coord[1] + grid_y 177 | pred_boxes[2] = coord[2].exp() * anchor_w # pw*e(tw) 178 | pred_boxes[3] = coord[3].exp() * anchor_h 179 | pred_boxes = pred_boxes.transpose(0,1).contiguous().view(-1,4) 180 | 181 | coord_mask, conf_mask, cls_mask, tcoord, tconf, tcls = \ 182 | self.build_targets(pred_boxes.detach(), target.detach(), anchors.detach(), nH, nW) 183 | 184 | cls_grid = torch.linspace(5,5+nC-1,nC).long().to(device) 185 | cls = output.index_select(2, cls_grid) 186 | cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(cls_anchor_dim, nC) 187 | cls_mask = (cls_mask == 1) 188 | tcls = tcls[cls_mask].long().view(-1) 189 | cls_mask = cls_mask.view(-1, 1).repeat(1,nC).to(device) 190 | cls = cls[cls_mask].view(-1, nC) 191 | 192 | tcoord = tcoord.view(4, cls_anchor_dim).to(device) 193 | tconf, tcls = tconf.to(device), tcls.to(device) 194 | coord_mask, conf_mask = coord_mask.view(cls_anchor_dim).to(device), conf_mask.to(device) 195 | 196 | loss_coord = nn.MSELoss(size_average=False)(coord*coord_mask, tcoord*coord_mask)/2 197 | loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask) 198 | loss_cls = nn.CrossEntropyLoss(size_average=False)(cls, tcls) if cls.size(0) > 0 else 0 199 | loss = loss_coord + loss_conf + loss_cls 200 | 201 | if math.isnan(loss.item()): 202 | print(conf, tconf) 203 | raise ValueError('YoloLayer has isnan in loss') 204 | #sys.exit(0) 205 | 206 | if return_single_value: return loss 207 | else: return [loss, loss_coord, loss_conf, loss_cls] 208 | -------------------------------------------------------------------------------- /yolov3_pytorch/yolov3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .yolo_layer import * 5 | from .yolov3_base import * 6 | 7 | class Yolov3(Yolov3Base): 8 | def __init__(self, num_classes=80): 9 | super().__init__() 10 | self.backbone = Darknet([1,2,8,8,4]) 11 | 12 | anchors_per_region = 3 13 | self.yolo_0_pre = Yolov3UpsamplePrep([512, 1024], 1024, anchors_per_region*(5+num_classes)) 14 | self.yolo_0 = YoloLayer(anchors=[(116., 90.), (156., 198.), (373., 326.)], stride=32, num_classes=num_classes) 15 | 16 | self.yolo_1_c = ConvBN(512, 256, 1) 17 | self.yolo_1_prep = 
Yolov3UpsamplePrep([256, 512], 512+256, anchors_per_region*(5+num_classes)) 18 | self.yolo_1 = YoloLayer(anchors=[(30., 61.), (62., 45.), (59., 119.)], stride=16, num_classes=num_classes) 19 | 20 | self.yolo_2_c = ConvBN(256, 128, 1) 21 | self.yolo_2_prep = Yolov3UpsamplePrep([128, 256], 256+128, anchors_per_region*(5+num_classes)) 22 | self.yolo_2 = YoloLayer(anchors=[(10., 13.), (16., 30.), (33., 23.)], stride=8, num_classes=num_classes) 23 | 24 | def get_loss_layers(self): 25 | return [self.yolo_0, self.yolo_1, self.yolo_2] 26 | 27 | def forward_yolo(self, xb): 28 | x, y0 = self.yolo_0_pre(xb[-1]) 29 | 30 | x = self.yolo_1_c(x) 31 | x = nn.Upsample(scale_factor=2, mode='nearest')(x) 32 | x = torch.cat([x, xb[-2]], 1) 33 | x, y1 = self.yolo_1_prep(x) 34 | 35 | x = self.yolo_2_c(x) 36 | x = nn.Upsample(scale_factor=2, mode='nearest')(x) 37 | x = torch.cat([x, xb[-3]], 1) 38 | x, y2 = self.yolo_2_prep(x) 39 | 40 | return [y0, y1, y2] 41 | 42 | 43 | ################################################################### 44 | ## Backbone and helper modules 45 | 46 | class DarknetBlock(nn.Module): 47 | def __init__(self, ch_in): 48 | super().__init__() 49 | ch_hid = ch_in//2 50 | self.conv1 = ConvBN(ch_in, ch_hid, kernel_size=1, stride=1, padding=0) 51 | self.conv2 = ConvBN(ch_hid, ch_in, kernel_size=3, stride=1, padding=1) 52 | 53 | def forward(self, x): return self.conv2(self.conv1(x)) + x 54 | 55 | 56 | class Darknet(nn.Module): 57 | def __init__(self, num_blocks, start_nf=32): 58 | super().__init__() 59 | nf = start_nf 60 | self.base = ConvBN(3, nf, kernel_size=3, stride=1) #, padding=1) 61 | self.layers = [] 62 | for i, nb in enumerate(num_blocks): 63 | # dn_layer = make_group_layer(nf, nb, stride=(1 if i==-1 else 2)) 64 | dn_layer = self.make_group_layer(nf, nb, stride=2) 65 | self.add_module(f"darknet_{i}", dn_layer) 66 | self.layers.append(dn_layer) 67 | nf *= 2 68 | 69 | def make_group_layer(self, ch_in, num_blocks, stride=2): 70 | layers = [ConvBN(ch_in, ch_in*2, stride=stride)] 71 | for i in range(num_blocks): layers.append(DarknetBlock(ch_in*2)) 72 | return nn.Sequential(*layers) 73 | 74 | def forward(self, x): 75 | y = [self.base(x)] 76 | for l in self.layers: 77 | y.append(l(y[-1])) 78 | return y 79 | 80 | 81 | class Yolov3UpsamplePrep(nn.Module): 82 | def __init__(self, filters_list, in_filters, out_filters): 83 | super().__init__() 84 | self.branch = nn.ModuleList([ 85 | ConvBN(in_filters, filters_list[0], 1), 86 | ConvBN(filters_list[0], filters_list[1], kernel_size=3), 87 | ConvBN(filters_list[1], filters_list[0], kernel_size=1), 88 | ConvBN(filters_list[0], filters_list[1], kernel_size=3), 89 | ConvBN(filters_list[1], filters_list[0], kernel_size=1),]) 90 | self.for_yolo = nn.ModuleList([ 91 | ConvBN(filters_list[0], filters_list[1], kernel_size=3), 92 | nn.Conv2d(filters_list[1], out_filters, kernel_size=1, stride=1, 93 | padding=0, bias=True)]) 94 | 95 | def forward(self, x): 96 | for m in self.branch: x = m(x) 97 | branch_out = x 98 | for m in self.for_yolo: x = m(x) 99 | return branch_out, x 100 | 101 | -------------------------------------------------------------------------------- /yolov3_pytorch/yolov3_base.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, Iterable, defaultdict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from abc import ABCMeta, abstractmethod 6 | import importlib 7 | 8 | from .yolo_layer import * 9 | 10 | 11 | class Yolov3Base(nn.Module, 
metaclass=ABCMeta): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | @abstractmethod 17 | def get_loss_layers(self): 18 | return [self.yolo_0, self.yolo_1] 19 | 20 | def forward_backbone(self, x): 21 | return self.backbone(x) 22 | 23 | def forward(self, x): 24 | shape = x.shape 25 | assert shape[1] == 3 and shape[2] % 32 == 0 and shape[3] % 32 == 0, f"Tensor shape should be [bs, 3, x*32, y*32], was {shape}" 26 | xb = self.forward_backbone(x) 27 | return self.forward_yolo(xb) 28 | 29 | def boxes_from_output(self, outputs, conf_thresh=0.25): 30 | all_boxes = [[] for j in range(outputs[0].size(0))] 31 | for i, layer in enumerate(self.get_loss_layers()): 32 | layer_boxes = layer.get_region_boxes(outputs[i], conf_thresh=conf_thresh) 33 | for j, layer_box in enumerate(layer_boxes): 34 | all_boxes[j] += layer_box 35 | 36 | return all_boxes 37 | 38 | def predict_img(self, imgs, conf_thresh=0.25): 39 | self.eval() 40 | if len(imgs.shape) == 3: imgs = imgs.unsqueeze(-1) 41 | 42 | outputs = self.forward(imgs) 43 | return self.boxes_from_output(outputs, conf_thresh) 44 | 45 | def freeze_backbone(self, requires_grad=False): 46 | for _, p in self.backbone.named_parameters(): 47 | p.requires_grad = requires_grad 48 | def unfreeze(self): 49 | for _, p in self.named_parameters(): 50 | p.requires_grad = True 51 | def freeze_info(self, print_all=False): 52 | d = defaultdict(set) 53 | print("Layer: param.requires_grad") 54 | for name, param in self.named_parameters(): 55 | if print_all: 56 | print(f"{name}: {param.requires_grad}") 57 | else: 58 | d[name.split('.')[0]].add(param.requires_grad) 59 | if not print_all: 60 | for k,v in d.items(): 61 | print(k, ': ', v) 62 | 63 | def load_backbone(self, h5_path): 64 | state_old = self.state_dict() 65 | state_new = torch.load(h5_path) 66 | 67 | skipped_layers = [] 68 | for k in list(state_new.keys()): 69 | if state_old[k].shape != state_new[k].shape: 70 | skipped_layers.append(k) 71 | del state_new[k] 72 | 73 | # for k in list(state_dict.keys()): 74 | # if k.startswith(('yolo_0_pre.15', 'yolo_1_pre.20')): 75 | # del state_dict[k] 76 | 77 | # Renaming some keys if needed for compatibility 78 | # state_dict = type(state_dict_org)() 79 | # for k_old in list(state_dict.keys()): 80 | # k_new = k_old.replace('backend', 'backbone') 81 | # state_dict[k_new] = state_dict_org[k_old] 82 | 83 | return self.load_state_dict(state_new, strict=False), skipped_layers 84 | 85 | 86 | ################################################################### 87 | ## Common helper modules 88 | 89 | # from fastai.models.darknet import ConvBN 90 | class ConvBN(nn.Module): 91 | "convolutional layer then batchnorm" 92 | 93 | def __init__(self, ch_in, ch_out, kernel_size = 3, stride=1, padding=None): 94 | super().__init__() 95 | if padding is None: padding = (kernel_size - 1) // 2 # we should never need to set padding 96 | self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 97 | self.bn = nn.BatchNorm2d(ch_out, momentum=0.01) 98 | self.relu = nn.LeakyReLU(0.1, inplace=True) 99 | 100 | def forward(self, x): return self.relu(self.bn(self.conv(x))) 101 | 102 | 103 | class Upsample(nn.Module): 104 | def __init__(self, stride=2): 105 | super().__init__() 106 | self.stride = stride 107 | def forward(self, x): 108 | assert(x.data.dim() == 4) 109 | return nn.Upsample(scale_factor=self.stride, mode='nearest')(x) 110 | -------------------------------------------------------------------------------- /yolov3_pytorch/yolov3_tiny.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .yolo_layer import * 5 | from .yolov3_base import * 6 | 7 | class Yolov3Tiny(Yolov3Base): 8 | 9 | def __init__(self, num_classes, use_wrong_previous_anchors=False): 10 | super().__init__() 11 | 12 | self.num_classes = num_classes 13 | self.return_out_boxes = False 14 | self.skip_backbone = False 15 | 16 | self.backbone = Yolov3TinyBackbone() 17 | 18 | anchors_per_region = 3 19 | self.yolo_0_pre = nn.Sequential(OrderedDict([ 20 | ('14_convbatch', ConvBN(256, 512, 3, 1, 1)), 21 | ('15_conv', nn.Conv2d(512, anchors_per_region*(5+self.num_classes), 1, 1, 0)), 22 | # ('16_yolo', YoloLayer()), 23 | ])) 24 | self.yolo_0 = YoloLayer(anchors=[(81.,82.), (135.,169.), (344.,319.)], stride=32, num_classes=num_classes) 25 | 26 | self.up_1 = nn.Sequential(OrderedDict([ 27 | ('17_convbatch', ConvBN(256, 128, 1, 1, 0)), 28 | ('18_upsample', Upsample(2)), 29 | ])) 30 | 31 | self.yolo_1_pre = nn.Sequential(OrderedDict([ 32 | ('19_convbatch', ConvBN(128+256, 256, 3, 1, 1)), 33 | ('20_conv', nn.Conv2d(256, anchors_per_region*(5+self.num_classes), 1, 1, 0)), 34 | # ('21_yolo', YoloLayer()), 35 | ])) 36 | 37 | # Tiny yolo weights were originally trained using wrong anchor mask 38 | # https://github.com/pjreddie/darknet/commit/f86901f6177dfc6116360a13cc06ab680e0c86b0#diff-2b0e16f442a744897f1606ff1a0f99d3L175 39 | if use_wrong_previous_anchors: 40 | yolo_1_anchors = [(23.,27.), (37.,58.), (81.,82.)] 41 | else: 42 | yolo_1_anchors = [(10.,14.), (23.,27.), (37.,58.)] 43 | 44 | self.yolo_1 = YoloLayer(anchors=yolo_1_anchors, stride=16.0, num_classes=num_classes) 45 | 46 | def get_loss_layers(self): 47 | return [self.yolo_0, self.yolo_1] 48 | 49 | def forward_yolo(self, xb): 50 | x_b_0, x_b_full = xb[0], xb[1] 51 | y0 = self.yolo_0_pre(x_b_full) 52 | 53 | x_up = self.up_1(x_b_full) 54 | x_up = torch.cat((x_up, x_b_0), 1) 55 | y1 = self.yolo_1_pre(x_up) 56 | 57 | return [y0, y1] 58 | 59 | 60 | ################################################################### 61 | ## Backbone and helper modules 62 | 63 | class MaxPoolStride1(nn.Module): 64 | def __init__(self): 65 | super().__init__() 66 | 67 | def forward(self, x): 68 | x = F.max_pool2d(F.pad(x, (0,1,0,1), mode='replicate'), 2, stride=1) 69 | return x 70 | 71 | 72 | class Yolov3TinyBackbone(nn.Module): 73 | def __init__(self, input_channels=3): 74 | super().__init__() 75 | self.layers_list = OrderedDict([ 76 | ('0_convbatch', ConvBN(input_channels, 16, 3, 1, 1)), 77 | ('1_max', nn.MaxPool2d(2, 2)), 78 | ('2_convbatch', ConvBN(16, 32, 3, 1, 1)), 79 | ('3_max', nn.MaxPool2d(2, 2)), 80 | ('4_convbatch', ConvBN(32, 64, 3, 1, 1)), 81 | ('5_max', nn.MaxPool2d(2, 2)), 82 | ('6_convbatch', ConvBN(64, 128, 3, 1, 1)), 83 | ('7_max', nn.MaxPool2d(2, 2)), 84 | ('8_convbatch', ConvBN(128, 256, 3, 1, 1)), 85 | ('9_max', nn.MaxPool2d(2, 2)), 86 | # ('9_max', nn.MaxPool2d(2, 2, ceil_mode=True)), 87 | ('10_convbatch', ConvBN(256, 512, 3, 1, 1)), 88 | ('11_max', MaxPoolStride1()), 89 | ('12_convbatch', ConvBN(512, 1024, 3, 1, 1)), 90 | ('13_convbatch', ConvBN(1024, 256, 1, 1, 0)), # padding = kernel_size-1//2 91 | ]) 92 | self.layers = nn.Sequential(self.layers_list) 93 | self.idx = 9 94 | 95 | def forward(self, x): 96 | x_b_0 = self.layers[:self.idx](x) 97 | x_b_full = self.layers[self.idx:](x_b_0) 98 | return x_b_0, x_b_full 99 | 100 | --------------------------------------------------------------------------------
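As a quick orientation to how the pieces above fit together, here is a minimal prediction sketch that mirrors the flow used in tests/yolov3_tiny_test.py. The weight file path is an assumption (a local download of the pretrained weights linked in the README) and is not part of the repository.

```python
import torch
from PIL import Image

from yolov3_pytorch import utils
from yolov3_pytorch.yolov3_tiny import Yolov3Tiny

# Tiny model with the converted COCO weights (assumed local path, see README "Pretrained Weights").
model = Yolov3Tiny(num_classes=80, use_wrong_previous_anchors=True)
model.load_state_dict(torch.load('data/models/yolov3_tiny_coco_01.h5'))
model.eval()

# Input must be RGB and sized to a multiple of 32 (416 is the default).
img = Image.open('tests/mocks/person.jpg').convert('RGB').resize((416, 416))
img_torch = utils.image2torch(img)

# predict_img returns, per image, a list of [cx, cy, w, h, conf, cls_conf, cls_id] in relative coordinates.
all_boxes = model.predict_img(img_torch, conf_thresh=0.25)[0]
boxes = utils.nms(all_boxes, nms_thresh=0.4)  # drop overlapping detections
```

The full Yolov3 model can be used the same way; only the constructor and the weight file differ (see tests/yolov3_test.py).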