├── model ├── __init__.py ├── backbone │ ├── __init__.py │ ├── darknet.py │ └── mobilenetv2.py └── yolo_layer.py ├── test_img.png ├── font ├── FiraMono-Medium.otf └── SIL Open Font License.txt ├── cfgs ├── yolo_detect.yml └── yolo_train.yml ├── LICENSE ├── .gitignore ├── README.md ├── detect.py ├── utils └── utils.py ├── train.py └── .idea └── workspace.xml /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TommyAnqi/YOLOv3-Pytorch/HEAD/test_img.png -------------------------------------------------------------------------------- /font/FiraMono-Medium.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TommyAnqi/YOLOv3-Pytorch/HEAD/font/FiraMono-Medium.otf -------------------------------------------------------------------------------- /cfgs/yolo_detect.yml: -------------------------------------------------------------------------------- 1 | backbones_network: 'mobilenetv2' 2 | weightfile: '' 3 | 4 | # general setting: 5 | 6 | output_dir: outputs/ 7 | save_interval: 5 8 | save_path: save/ 9 | dir_logs: logs/ 10 | inference_path: save/tl2/mobilenetv2/model-best.pth 11 | height: 416 12 | width: 416 13 | seed: 123 14 | 15 | # loading options. 16 | use_cuda: True 17 | 18 | anchors: [[10, 13], [16, 30], [33, 23], 19 | [30, 61], [62, 45], [59, 119], 20 | [116, 90], [156, 198], [373, 326]] 21 | classes_names: ['g', 'r', 'y', 'n'] 22 | use_all_gt: False -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . import darknet 2 | from . import mobilenetv2 3 | import pdb 4 | 5 | 6 | def backbone_fn(opt): 7 | if opt.backbones_network == "darknet21": 8 | model = darknet.darknet21(opt.weightfile) 9 | elif opt.backbones_network == "darknet53": 10 | model = darknet.darknet53(opt.weightfile) 11 | elif opt.backbones_network == "mobilenetv2": 12 | model = mobilenetv2.mobilenetv2(opt.weightfile) 13 | elif opt.backbones_network == "modified_mobilenetv2": 14 | model = mobilenetv2.mobilenetv2(opt.weightfile) 15 | else: 16 | pdb.set_trace() 17 | return model -------------------------------------------------------------------------------- /cfgs/yolo_train.yml: -------------------------------------------------------------------------------- 1 | # training file 2 | backbones_network: 'mobilenetv2' 3 | weightfile: '' 4 | annotation_path: 'cfgs/train_TL_NI_mixed_crop_07092.txt' 5 | anchors: [[10, 13], [16, 30], [33, 23], 6 | [30, 61], [62, 45], [59, 119], 7 | [116, 90], [156, 198], [373, 326]] 8 | classes_names: ['g', 'r', 'y', 'n'] 9 | 10 | # general setting: 11 | batch_size: 10 12 | max_epochs: 100 13 | val_split: 0.1 14 | 15 | output_dir: outputs/ 16 | save_interval: 5 17 | save_path: save/ 18 | dir_logs: logs/ 19 | height: 416 20 | width: 416 21 | seed: 123 22 | 23 | # optimization: 24 | optimizer: 'adam' 25 | backbone_lr: 0.0000001 26 | lr: 0.0000001 27 | 28 | weight_decay: 0 29 | decay_gamma: 0.1 30 | decay_step: 20 31 | 32 | # loading options. 
33 | use_cuda: True 34 | n_iter: 0 35 | start_epoch: 0 36 | start_from: ' ' 37 | load_best_score: 0 38 | display_interval: 5 39 | seen: 0 40 | use_all_gt: False -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TommyML 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv3-Pytorch 2 | PyTorch implementation of YOLOv3, including training and inference based on darknet and mobilenetv2 3 | ## Introduction 4 | YOLO is one of the most popular one-stage object detectors. In March 2018, [YOLOv3: An Incremental Improvement](https://pjreddie.com/media/files/papers/YOLOv3.pdf), which is extremely fast and accurate, was released. The aim of this project is to replicate the [Darknet](https://github.com/pjreddie/darknet) implementation. It also supports training the YOLOv3 network with various backbones such as MobileNetV2 and Darknet (53 or 21). If you have any questions or ideas about this repo, leave a comment or email anqitommy@gmail.com 5 | 6 | 7 | --- 8 | ## Quick Start 9 | 1. Download YOLOv3 mobilenetv2 full weights from [BaiduDisk](https://pan.baidu.com/s/15SS5CtdXcIRzSwdB4w0h3Q), password:j7oz. 10 | 2. Create a new folder to store the weights and modify the inference_path in ./cfgs/yolo_detect.yml. 11 | 3. Run detect.py with the test_img.png. 12 | 13 | 14 | --- 15 | ## Training 16 | 1. Generate your own annotation file and class names file. 17 | One row per image; 18 | Row format: `image_file_path box1 box2 ... boxN`; 19 | Box format: `x_min,y_min,x_max,y_max,class_id` (no spaces). 20 | Here is an example: 21 | ``` 22 | path/to/img1.jpg 50,100,150,200,0 30,50,200,120,3 23 | path/to/img2.jpg 120,300,250,600,2 24 | ... 25 | ``` 26 | 2. If you want to use the original pretrained weights for YOLOv3: 27 | Download YOLOv3 darknet53 backbone weights from [BaiduDisk](https://pan.baidu.com/s/1N3jN6imnsbsquk04J2G_-Q), password:w6fm. 28 | 29 | 3. Modify yolo_train.yml and start training. 30 | `python train.py` 31 | To start from your own trained weights or checkpoint weights, modify the training parameters and the weightfile in yolo_train.yml. 32 | Remember to set the annotation_path to your own annotation file, and update classes_names, anchors, and save_path.
If you want to use mobilnetv2 as the backbone net, modify the `backbones_network` 33 | 34 | 35 | --- 36 | 37 | ## Todo list: 38 | - [x] Training 39 | - [x] Multiscale training 40 | - [x] Mobilnetv2 backends 41 | - [ ] Multiscale testing 42 | - [ ] Soft-nms 43 | - [ ] Multiple-GPU training 44 | - [ ] mAP Evaluation 45 | - [ ] Extend to YOLO-FCN 46 | 47 | 48 | --- 49 | ## Requirements 50 | - Python 3.6 51 | - Pytorch 0.4.0 52 | - TensorboardX 53 | - Cuda 9.0 or higher 54 | 55 | 56 | --- 57 | 58 | ## Citation 59 | - [qqwweee/keras-yolo3](https://github.com/qqwweee/keras-yolo3) 60 | - [BobLiu20/YOLOv3_PyTorch](https://github.com/BobLiu20/YOLOv3_PyTorch) 61 | - [xiaochus/MobileNetV2](https://github.com/xiaochus/MobileNetV2) 62 | -------------------------------------------------------------------------------- /model/backbone/darknet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from collections import OrderedDict 5 | 6 | __all__ = ['darknet21', 'darknet53'] 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, inplanes, planes): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1, 13 | stride=1, padding=0, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes[0]) 15 | self.relu1 = nn.LeakyReLU(0.1) 16 | self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3, 17 | stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes[1]) 19 | self.relu2 = nn.LeakyReLU(0.1) 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.bn1(out) 26 | out = self.relu1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.relu2(out) 31 | 32 | out += residual 33 | return out 34 | 35 | 36 | class DarkNet(nn.Module): 37 | def __init__(self, layers): 38 | super(DarkNet, self).__init__() 39 | self.inplanes = 32 40 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 41 | self.bn1 = nn.BatchNorm2d(self.inplanes) 42 | self.relu1 = nn.LeakyReLU(0.1) 43 | 44 | self.layer1 = self._make_layer([32, 64], layers[0]) 45 | self.layer2 = self._make_layer([64, 128], layers[1]) 46 | self.layer3 = self._make_layer([128, 256], layers[2]) 47 | self.layer4 = self._make_layer([256, 512], layers[3]) 48 | self.layer5 = self._make_layer([512, 1024], layers[4]) 49 | 50 | self.layers_out_filters = [64, 128, 256, 512, 1024] 51 | self.init_params() 52 | 53 | # initialize the parameters in convolution and batch normalization layers 54 | def init_params(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | 63 | def _make_layer(self, planes, blocks): 64 | layers = [] 65 | # downsample 66 | layers.append(("ds_conv", nn.Conv2d(self.inplanes, planes[1], kernel_size=3, 67 | stride=2, padding=1, bias=False))) 68 | layers.append(("ds_bn", nn.BatchNorm2d(planes[1]))) 69 | layers.append(("ds_relu", nn.LeakyReLU(0.1))) 70 | # blocks 71 | self.inplanes = planes[1] 72 | for i in range(0, blocks): 73 | layers.append(("residual_{}".format(i), BasicBlock(self.inplanes, planes))) 74 | return nn.Sequential(OrderedDict(layers)) 75 | 76 | def forward(self, x): 77 | x = self.conv1(x) 78 | x = self.bn1(x) 79 | x = self.relu1(x) 80 | 81 | x = self.layer1(x) 82 | x = self.layer2(x) 83 | out3 = self.layer3(x) 84 | out4 = self.layer4(out3) 85 | out5 = self.layer5(out4) 86 | 87 | return out3, out4, out5 88 | 89 | 90 | def darknet21(pretrained, **kwargs): 91 | """Constructs a darknet-21 model. 92 | """ 93 | model = DarkNet([1, 1, 2, 2, 1]) 94 | if pretrained: 95 | if isinstance(pretrained, str): 96 | model.load_state_dict(torch.load(pretrained)) 97 | else: 98 | raise Exception("darknet request a pretrained path. got [{}]".format(pretrained)) 99 | return model 100 | 101 | 102 | def darknet53(pretrained, **kwargs): 103 | """Constructs a darknet-53 model. 104 | """ 105 | model = DarkNet([1, 2, 8, 8, 4]) 106 | if pretrained: 107 | if isinstance(pretrained, str): 108 | model.load_state_dict(torch.load(pretrained)) 109 | else: 110 | raise Exception("darknet request a pretrained path. got [{}]".format(pretrained)) 111 | return model 112 | -------------------------------------------------------------------------------- /font/SIL Open Font License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Mozilla Foundation https://mozilla.org/ with Reserved Font Name Fira Mono. 2 | 3 | Copyright (c) 2014, Telefonica S.A. 4 | 5 | This Font Software is licensed under the SIL Open Font License, Version 1.1. 6 | This license is copied below, and is also available with a FAQ at: http://scripts.sil.org/OFL 7 | 8 | ----------------------------------------------------------- 9 | SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 10 | ----------------------------------------------------------- 11 | 12 | PREAMBLE 13 | The goals of the Open Font License (OFL) are to stimulate worldwide development of collaborative font projects, to support the font creation efforts of academic and linguistic communities, and to provide a free and open framework in which fonts may be shared and improved in partnership with others. 14 | 15 | The OFL allows the licensed fonts to be used, studied, modified and redistributed freely as long as they are not sold by themselves. The fonts, including any derivative works, can be bundled, embedded, redistributed and/or sold with any software provided that any reserved names are not used by derivative works. The fonts and derivatives, however, cannot be released under any other type of license. The requirement for fonts to remain under this license does not apply to any document created using the fonts or their derivatives. 16 | 17 | DEFINITIONS 18 | "Font Software" refers to the set of files released by the Copyright Holder(s) under this license and clearly marked as such. This may include source files, build scripts and documentation. 19 | 20 | "Reserved Font Name" refers to any names specified as such after the copyright statement(s). 
21 | 22 | "Original Version" refers to the collection of Font Software components as distributed by the Copyright Holder(s). 23 | 24 | "Modified Version" refers to any derivative made by adding to, deleting, or substituting -- in part or in whole -- any of the components of the Original Version, by changing formats or by porting the Font Software to a new environment. 25 | 26 | "Author" refers to any designer, engineer, programmer, technical writer or other person who contributed to the Font Software. 27 | 28 | PERMISSION & CONDITIONS 29 | Permission is hereby granted, free of charge, to any person obtaining a copy of the Font Software, to use, study, copy, merge, embed, modify, redistribute, and sell modified and unmodified copies of the Font Software, subject to the following conditions: 30 | 31 | 1) Neither the Font Software nor any of its individual components, in Original or Modified Versions, may be sold by itself. 32 | 33 | 2) Original or Modified Versions of the Font Software may be bundled, redistributed and/or sold with any software, provided that each copy contains the above copyright notice and this license. These can be included either as stand-alone text files, human-readable headers or in the appropriate machine-readable metadata fields within text or binary files as long as those fields can be easily viewed by the user. 34 | 35 | 3) No Modified Version of the Font Software may use the Reserved Font Name(s) unless explicit written permission is granted by the corresponding Copyright Holder. This restriction only applies to the primary font name as presented to the users. 36 | 37 | 4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font Software shall not be used to promote, endorse or advertise any Modified Version, except to acknowledge the contribution(s) of the Copyright Holder(s) and the Author(s) or with their explicit written permission. 38 | 39 | 5) The Font Software, modified or unmodified, in part or in whole, must be distributed entirely under this license, and must not be distributed under any other license. The requirement for fonts to remain under this license does not apply to any document created using the Font Software. 40 | 41 | TERMINATION 42 | This license becomes null and void if any of the above conditions are not met. 43 | 44 | DISCLAIMER 45 | THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM OTHER DEALINGS IN THE FONT SOFTWARE. 
-------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import colorsys 3 | import numpy as np 4 | import argparse 5 | import yaml 6 | from utils.utils import update_values, letterbox_image 7 | from model import yolo_layer 8 | import time 9 | import cv2 10 | 11 | 12 | class yolo_detect(object): 13 | def __init__(self, args): 14 | self.args = args 15 | self.model_image_size = (args.height, args.width) 16 | self.class_names = self.args.classes_names 17 | self.model = self.creat_model() 18 | self.model.eval() 19 | 20 | def creat_model(self): 21 | model_path = self.args.inference_path 22 | model = yolo_layer.yolov3layer(self.args) 23 | model.load_state_dict(torch.load(model_path)) 24 | hsv_tuples = [(x / len(self.class_names), 1., 1.) 25 | for x in range(len(self.class_names))] 26 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) 27 | self.colors = list( 28 | map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), 29 | self.colors)) 30 | if self.args.use_cuda: 31 | if self.args.mGPUs: 32 | model = torch.nn.DataParallel(model).cuda() 33 | else: 34 | model = model.cuda() 35 | return model 36 | 37 | def detect(self, image): 38 | time_crop_img_s = time.time() 39 | if self.model_image_size != (None, None): 40 | assert self.model_image_size[0]%32 == 0, 'Multiples of 32 required' 41 | assert self.model_image_size[1]%32 == 0, 'Multiples of 32 required' 42 | boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size))) 43 | else: 44 | new_image_size = (image.width - (image.width % 32), 45 | image.height - (image.height % 32)) 46 | time_crop_img_e = time.time() 47 | print('Time consuming of image crop:', time_crop_img_e - time_crop_img_s) 48 | time_inference_s = time.time() 49 | image_data = np.array(boxed_image, dtype='float32') / 255.0 50 | img = torch.from_numpy(image_data).float().cuda() 51 | img = img.unsqueeze(0) 52 | img = img.view(img.shape[0], img.shape[1], img.shape[2], img.shape[3]).permute(0, 3, 1, 2).contiguous() 53 | img_ori_shape = torch.Tensor([image.shape[0], image.shape[1]]) 54 | dets, images, classes = self.model.detect(img, img_ori_shape) 55 | time_inference_e = time.time() 56 | print('Time consuming of inference:', time_inference_e - time_inference_s) 57 | dets_arr = dets.cpu().numpy() 58 | classes_arr = classes.cpu().numpy() 59 | time_draw_s = time.time() 60 | # font = ImageFont.truetype(font='font/FiraMono-Medium.otf', 61 | # size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32')) 62 | for i, c in enumerate(classes_arr): 63 | c = int(c) 64 | predicted_class = self.class_names[c] 65 | box = dets_arr[i] 66 | 67 | top, left, bottom, right, score = box 68 | label = '{} {:.2f}'.format(predicted_class, score) 69 | 70 | top = max(0, np.floor(top + 0.5).astype('int32')) 71 | left = max(0, np.floor(left + 0.5).astype('int32')) 72 | bottom = min(image.shape[0], np.floor(bottom + 0.5).astype('int32')) 73 | right = min(image.shape[1], np.floor(right + 0.5).astype('int32')) 74 | print(label, (left, top), (right, bottom)) 75 | 76 | cv2.rectangle(image, (left, top), (right, bottom), self.colors[c], 2) 77 | time_draw_e = time.time() 78 | print('Time consuming of Drawing image:', time_draw_e - time_draw_s) 79 | return image 80 | 81 | 82 | def detect_img_test(yolo): 83 | while 1: 84 | image = input("input image:") 85 | try: 86 | time_load_start = time.time() 87 | image = cv2.imread(image) 88 | time_load_end = 
time.time() 89 | print('Time consuming of image loading:', time_load_end - time_load_start) 90 | except: 91 | print('Open Error! Try again!') 92 | continue 93 | else: 94 | r_image = yolo.detect(image) 95 | time_end = time.time() 96 | print('Time consuming of object detection based on yolo3', time_end - time_load_start) 97 | cv2.namedWindow("result", cv2.WINDOW_NORMAL) 98 | cv2.imshow('result', r_image) 99 | cv2.waitKey(0) 100 | 101 | 102 | def main(): 103 | args = make_args() 104 | with open(args.cfg_path, 'r') as handle: 105 | options_yaml = yaml.load(handle) 106 | update_values(options_yaml, vars(args)) 107 | detect_img_test(yolo_detect(args)) 108 | 109 | 110 | def make_args(): 111 | # load the optional parameters and update new arguments 112 | parser = argparse.ArgumentParser() 113 | # # Data input settings 114 | parser.add_argument('--cfg_path', type=str, default='cfgs/Yolo_detect.yml', help='load config') 115 | parser.add_argument('--use_cuda', type=bool, default=True, help='whether use gpu.') 116 | parser.add_argument('--mGPUs', type=bool, default=False, help='whether use mgpu.') 117 | return parser.parse_args() 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | 123 | 124 | -------------------------------------------------------------------------------- /model/backbone/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from collections import OrderedDict 5 | 6 | __all__ = ['mobilenetv2'] 7 | 8 | 9 | def conv_bn(inp, oup, stride): 10 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU6(inplace=True)) 13 | 14 | 15 | def pointwise_conv_bn(inp, oup): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU6(inplace=True)) 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, expand_ratio): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = round(inp * expand_ratio) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | 31 | if expand_ratio == 1: 32 | self.conv = nn.Sequential( 33 | # dw 34 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 35 | nn.BatchNorm2d(hidden_dim), 36 | nn.ReLU6(inplace=True), 37 | # pw-linear 38 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(oup), 40 | ) 41 | else: 42 | self.conv = nn.Sequential( 43 | # pw 44 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 45 | nn.BatchNorm2d(hidden_dim), 46 | nn.ReLU6(inplace=True), 47 | # dw 48 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # pw-linear 52 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 53 | nn.BatchNorm2d(oup), 54 | ) 55 | 56 | def forward(self, x): 57 | if self.use_res_connect: 58 | return x + self.conv(x) 59 | else: 60 | return self.conv(x) 61 | 62 | 63 | class Mobilenetv2(nn.Module): 64 | def __init__(self, width_mult=1.): 65 | super(Mobilenetv2, self).__init__() 66 | self.input_channel = 32 67 | last_channel = 1280 68 | self.width_mult = width_mult 69 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 70 | 71 | # building first layer 72 | self.layer0 = nn.Sequential(conv_bn(3, self.input_channel, 2)) 73 | 74 | # building inverted residual blocks 75 | 
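# The _make_layer calls below follow the usual MobileNetV2 (t, c, n, s) convention:
# t is the expansion factor handed to InvertedResidual as expand_ratio, c is the
# output channel count (scaled by width_mult), n is how many blocks are stacked,
# and s is the stride of the first block in the stack (the remaining blocks use stride 1).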
self.layer1 = self._make_layer(1, 16, 1, 1) 76 | self.layer2 = self._make_layer(6, 24, 2, 2) 77 | 78 | # the 52*52 feature map output 79 | self.layer3 = self._make_layer(6, 32, 3, 2) 80 | self.layer3_output = pointwise_conv_bn(self.input_channel, 128) 81 | 82 | # the 26*26 feature map output 83 | self.layer4 = self._make_layer(6, 64, 4, 2) 84 | self.layer5 = self._make_layer(6, 96, 3, 1) 85 | self.layer5_output = pointwise_conv_bn(self.input_channel, 384) 86 | 87 | self.layer6 = self._make_layer(6, 160, 3, 2) 88 | self.layer7 = self._make_layer(6, 320, 1, 1) 89 | 90 | self.last_layer = pointwise_conv_bn(self.input_channel, self.last_channel) 91 | 92 | # the demensions of output feature map channels 93 | self.layers_out_filters = [128, 384, 1280] 94 | self._initialize_weights() 95 | 96 | def _initialize_weights(self): 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | m.weight.data.normal_(0, math.sqrt(2. / n)) 101 | if m.bias is not None: 102 | m.bias.data.zero_() 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(1) 105 | m.bias.data.zero_() 106 | elif isinstance(m, nn.Linear): 107 | n = m.weight.size(1) 108 | m.weight.data.normal_(0, 0.01) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, t, c, n, s): 112 | layers = [] 113 | output_channel = int(c * self.width_mult) 114 | for i in range(n): 115 | if i == 0: 116 | layers.append(("IRB_{}".format(i), InvertedResidual(self.input_channel, output_channel, s, expand_ratio=t))) 117 | else: 118 | layers.append(("IRB_{}".format(i), InvertedResidual(self.input_channel, output_channel, 1, expand_ratio=t))) 119 | self.input_channel = output_channel 120 | return nn.Sequential(OrderedDict(layers)) 121 | 122 | 123 | def forward(self, x): 124 | x = self.layer0(x) 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | out52 = self.layer3_output(x) 129 | x = self.layer4(x) 130 | x = self.layer5(x) 131 | out26 = self.layer5_output(x) 132 | x = self.layer6(x) 133 | x = self.layer7(x) 134 | out13 = self.last_layer(x) 135 | return out52, out26, out13 136 | 137 | 138 | def mobilenetv2(pretrained, **kwargs): 139 | """Constructs a mobilenetv2 model. 140 | """ 141 | model = Mobilenetv2() 142 | if pretrained: 143 | if isinstance(pretrained, str): 144 | model.load_state_dict(torch.load(pretrained)) 145 | else: 146 | raise Exception("darknet request a pretrained path. 
got [{}]".format(pretrained)) 147 | return model 148 | 149 | 150 | if __name__ == "__main__": 151 | config = {"model_params": {"backbone_name": "darknet_53"}} 152 | m = Mobilenetv2() 153 | x = torch.randn(1, 3, 416, 416) 154 | y0, y1, y2 = m(x) 155 | print(y0.size()) 156 | print(y1.size()) 157 | print(y2.size()) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from matplotlib.colors import rgb_to_hsv, hsv_to_rgb 4 | from tensorboardX import SummaryWriter 5 | import cv2 6 | 7 | 8 | def letterbox_image(image, size): 9 | '''resize image with unchanged aspect ratio using padding''' 10 | iw, ih = image.shape[1], image.shape[0] 11 | w, h = size 12 | scale = min(w/iw, h/ih) 13 | nw = int(iw*scale) 14 | nh = int(ih*scale) 15 | image = cv2.resize(image, (nw, nh), cv2.INTER_CUBIC) 16 | image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 17 | new_image = Image.new('RGB', size, (128,128,128)) 18 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 19 | return new_image 20 | 21 | 22 | def rand(a=0, b=1): 23 | return np.random.rand()*(b-a) + a 24 | 25 | 26 | def get_random_data(annotation_line, input_shape, random=True, max_boxes=30, jitter=.3, hue=.1, sat=1.5, val=1.5, proc_img=True): 27 | '''random preprocessing for real-time data augmentation''' 28 | line = annotation_line.split() 29 | image = Image.open(line[0]) 30 | iw, ih = image.size 31 | h, w = input_shape 32 | box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]) 33 | 34 | if not random: 35 | # resize image 36 | scale = min(w/iw, h/ih) 37 | nw = int(iw*scale) 38 | nh = int(ih*scale) 39 | dx = (w-nw)//2 40 | dy = (h-nh)//2 41 | image_data = 0 42 | 43 | if proc_img: 44 | image = image.resize((nw,nh), Image.BICUBIC) 45 | new_image = Image.new('RGB', (w,h), (128,128,128)) 46 | new_image.paste(image, (dx, dy)) 47 | image_data = np.array(new_image)/255. 48 | 49 | # correct boxes 50 | box_data = np.zeros((max_boxes,5)) 51 | if len(box)>0: 52 | np.random.shuffle(box) 53 | if len(box)>max_boxes: box = box[:max_boxes] 54 | box[:, [0,2]] = box[:, [0,2]]*scale + dx 55 | box[:, [1,3]] = box[:, [1,3]]*scale + dy 56 | box_data[:len(box)] = box 57 | 58 | return image_data, box_data 59 | 60 | 61 | # resize image 62 | new_ar = w/h * rand(1-jitter, 1+jitter)/rand(1-jitter, 1+jitter) 63 | scale = rand(.25, 2) 64 | if new_ar < 1: 65 | nh = int(scale*h) 66 | nw = int(nh*new_ar) 67 | else: 68 | nw = int(scale*w) 69 | nh = int(nw/new_ar) 70 | image = image.resize((nw, nh), Image.BICUBIC) 71 | 72 | # place image 73 | dx = int(rand(0, w-nw)) 74 | dy = int(rand(0, h-nh)) 75 | new_image = Image.new('RGB', (w, h), (128, 128, 128)) 76 | new_image.paste(image, (dx, dy)) 77 | image = new_image 78 | 79 | # flip image or not 80 | flip = rand() < .5 81 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 82 | 83 | # distort image 84 | hue = rand(-hue, hue) 85 | sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat) 86 | val = rand(1, val) if rand()<.5 else 1/rand(1, val) 87 | x = rgb_to_hsv(np.array(image)/255.) 
88 | x[..., 0] += hue 89 | x[..., 0][x[..., 0]>1] -= 1 90 | x[..., 0][x[..., 0]<0] += 1 91 | x[..., 1] *= sat 92 | x[..., 2] *= val 93 | x[x>1] = 1 94 | x[x<0] = 0 95 | image_data = hsv_to_rgb(x) # numpy array, 0 to 1 96 | 97 | # correct boxes 98 | box_data = np.zeros((max_boxes,5)) 99 | if len(box)>0: 100 | np.random.shuffle(box) 101 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 102 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 103 | if flip: box[:, [0,2]] = w - box[:, [2,0]] 104 | box[:, 0:2][box[:, 0:2]<0] = 0 105 | box[:, 2][box[:, 2]>w] = w 106 | box[:, 3][box[:, 3]>h] = h 107 | box_w = box[:, 2] - box[:, 0] 108 | box_h = box[:, 3] - box[:, 1] 109 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box 110 | if len(box)>max_boxes: box = box[:max_boxes] 111 | box_data[:len(box)] = box 112 | 113 | return image_data, box_data 114 | 115 | 116 | def preprocess_true_boxes(true_boxes, input_shape, anchors, num_classes): 117 | '''Preprocess true boxes to training input format 118 | 119 | Parameters 120 | ---------- 121 | true_boxes: array, shape=(m, T, 5) 122 | Absolute x_min, y_min, x_max, y_max, class_id relative to input_shape. 123 | input_shape: array-like, hw, multiples of 32 124 | anchors: array, shape=(N, 2), wh 125 | num_classes: integer 126 | 127 | Returns 128 | ------- 129 | y_true: list of array, shape like yolo_outputs, xywh are reletive value 130 | 131 | ''' 132 | assert (true_boxes[..., 4] < num_classes).all(), 'class id must be less than num_classes' 133 | num_layers = len(anchors) // 3 # default setting 134 | anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if num_layers == 3 else [[3, 4, 5], [0, 1, 2]] 135 | 136 | true_boxes = np.array(true_boxes, dtype='float32') 137 | input_shape = np.array(input_shape, dtype='int32') 138 | boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2 # center x, y 139 | boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] 140 | true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] 141 | true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] 142 | 143 | m = true_boxes.shape[0] 144 | grid_shapes = [input_shape // {0: 32, 1: 16, 2: 8}[l] for l in range(num_layers)] 145 | y_true = [np.zeros((m, grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), 5 + num_classes), 146 | dtype='float32') for l in range(num_layers)] 147 | 148 | # Expand dim to apply broadcasting. 149 | anchors = np.expand_dims(anchors, 0) 150 | anchor_maxes = anchors / 2. 151 | anchor_mins = -anchor_maxes 152 | 153 | valid_mask = np.bitwise_and(boxes_wh[..., 0] > 0, boxes_wh[..., 1] > 0) 154 | 155 | for b in range(m): 156 | # Discard zero rows. 157 | wh = boxes_wh[b, valid_mask[b]] 158 | if len(wh) == 0: continue 159 | # Expand dim to apply broadcasting. 160 | wh = np.expand_dims(wh, -2) 161 | box_maxes = wh / 2. 162 | box_mins = -box_maxes 163 | 164 | intersect_mins = np.maximum(box_mins, anchor_mins) 165 | intersect_maxes = np.minimum(box_maxes, anchor_maxes) 166 | intersect_wh = np.maximum(intersect_maxes - intersect_mins, 0.) 
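# The intersection terms above combine into the IoU between each ground-truth
# width/height and each anchor (both centred at the origin); the best-matching
# anchor then decides, via anchor_mask, which detection layer and which anchor
# slot the box is written into below.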
167 | intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] 168 | box_area = wh[..., 0] * wh[..., 1] 169 | anchor_area = anchors[..., 0] * anchors[..., 1] 170 | iou = intersect_area / (box_area + anchor_area - intersect_area) 171 | 172 | # Find best anchor for each true box 173 | best_anchor = np.argmax(iou, axis=-1) 174 | 175 | for t, n in enumerate(best_anchor): 176 | for l in range(num_layers): 177 | if n in anchor_mask[l]: 178 | i = np.floor(true_boxes[b, t, 0] * grid_shapes[l][1]).astype('int32') 179 | j = np.floor(true_boxes[b, t, 1] * grid_shapes[l][0]).astype('int32') 180 | k = anchor_mask[l].index(n) 181 | c = true_boxes[b, t, 4].astype('int32') 182 | y_true[l][b, j, i, k, 0:4] = true_boxes[b, t, 0:4] 183 | y_true[l][b, j, i, k, 4] = 1 184 | y_true[l][b, j, i, k, 5 + c] = 1 185 | 186 | return y_true 187 | 188 | 189 | def non_max_suppression(dets, thresh): 190 | """Pure Python NMS baseline.""" 191 | # x1, y1, x2, y2, score 192 | x1 = dets[:, 0] 193 | y1 = dets[:, 1] 194 | x2 = dets[:, 2] 195 | y2 = dets[:, 3] 196 | scores = dets[:, 4] 197 | 198 | # area of each bounding box 199 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 200 | 201 | # order the score decending 202 | order = scores.argsort()[::-1] 203 | 204 | keep = [] 205 | while order.size > 0: 206 | i = order[0] 207 | keep.append(i) 208 | #计算当前概率最大矩形框与其他矩形框的相交框的坐标 209 | xx1 = np.maximum(x1[i], x1[order[1:]]) 210 | yy1 = np.maximum(y1[i], y1[order[1:]]) 211 | xx2 = np.minimum(x2[i], x2[order[1:]]) 212 | yy2 = np.minimum(y2[i], y2[order[1:]]) 213 | 214 | #计算相交框的面积 215 | w = np.maximum(0.0, xx2 - xx1 + 1) 216 | h = np.maximum(0.0, yy2 - yy1 + 1) 217 | inter = w * h 218 | #计算重叠度IOU:重叠面积/(面积1+面积2-重叠面积) 219 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 220 | 221 | #找到重叠度不高于阈值的矩形框索引 222 | inds = np.where(ovr <= thresh)[0] 223 | #将order序列更新,由于前面得到的矩形框索引要比矩形框在原order序列中的索引小1,所以要把这个1加回来 224 | order = order[inds + 1] 225 | return keep 226 | 227 | 228 | def add_logger(losses, logger, step, split): 229 | losses_name = ["loss", "xy", "wh", "conf", "clss"] 230 | for i, name in enumerate(losses_name): 231 | logger.add_scalar('DET_' + split + '/' + name, losses[i].sum().item() / losses[i].numel(), step) 232 | 233 | return losses[0].sum() / losses[0].numel() 234 | 235 | 236 | def set_tb_logger(log_dir, exp_name): 237 | """ Set up tensorboard logger""" 238 | log_dir = log_dir + '/' + exp_name 239 | # remove previous log with the same name, if not resume 240 | # if not resume and os.path.exists(log_dir): 241 | # import shutil 242 | # try: 243 | # shutil.rmtree(log_dir) 244 | # except: 245 | # warnings.warn('Experiment existed in TensorBoard, but failed to remove') 246 | return SummaryWriter(log_dir=log_dir) 247 | 248 | 249 | def update_values(dict_from, dict_to): 250 | for key, value in dict_from.items(): 251 | if isinstance(value, dict): 252 | update_values(dict_from[key], dict_to[key]) 253 | elif value is not None: 254 | dict_to[key] = dict_from[key] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import math 4 | import os 5 | import time 6 | from six.moves import cPickle # six.moves is used to self-adjust the change of python 2 and python 3 7 | import yaml 8 | import itertools 9 | from multiprocessing.dummy import Pool as ThreadPool 10 | 11 | import torch 12 | import torch.optim as optim 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 14 | 15 | 
from model import yolo_layer 16 | from utils.utils import get_random_data, preprocess_true_boxes, add_logger, update_values, set_tb_logger 17 | 18 | 19 | class Train(object): 20 | def __init__(self, args, threadpool): 21 | 22 | # initial parameters and model 23 | self.args = args 24 | self.annotation_path = args.annotation_path 25 | self.input_shape = (args.height, args.width) 26 | 27 | self.anchors = args.anchors 28 | self.class_names = args.classes_names 29 | self.num_classes = len(self.class_names) 30 | self.model, self.infos = self._create_model() 31 | self.optimizer = self._get_optimizer(self.args, self.model) 32 | self.threadpool = threadpool 33 | 34 | # mkdir log file 35 | log_name = str(time.strftime("%Y%m%d%H%M%S", time.localtime())) + '_' + '_' + 'bs_' + str(self.args.batch_size) 36 | print("logging to %s ..." % (log_name)) 37 | self.logger = set_tb_logger('logs', log_name) 38 | self._train_pipeline(self.annotation_path) 39 | self.init_lr = args.lr 40 | 41 | def _train_pipeline(self, sample_path): 42 | ''' 43 | The training pipeline including training and validation samples generation, loss backward 44 | and checkpoints saving. 45 | ''' 46 | # random split the data to training and validation samples 47 | data_gen, data_gen_val, max_batch_ind, max_val_batch_ind = self._train_data_generation(sample_path) 48 | 49 | self.args.n_iter = self.infos.get('n_iter', self.args.n_iter) 50 | start_epoch = self.infos.get('epoch', self.args.start_epoch) 51 | best_val_loss = self.infos.get('best_val_loss', 100000) 52 | 53 | # Learning rate self-adjusting depends on the validation loss 54 | scheduler = ReduceLROnPlateau(self.optimizer, factor=0.1, patience=2, verbose=True, eps=1e-12) 55 | # scheduler = StepLR(self.optimizer, step_size=3, gamma=0.6) 56 | 57 | # training approach 58 | for epoch in range(start_epoch, self.args.max_epochs): 59 | self._train(data_gen, max_batch_ind, epoch) 60 | val_loss = self._validation(epoch, data_gen_val, max_val_batch_ind) 61 | scheduler.step(val_loss) 62 | best_flag = False 63 | 64 | if best_val_loss is None or val_loss < best_val_loss: 65 | best_val_loss = val_loss 66 | best_flag = True 67 | 68 | # checkpoints saving 69 | self.checkpoint_save(self.infos, epoch, best_val_loss, best_flag) 70 | 71 | def cos_lr(self, epoch_start, epoch_max, base_lr): 72 | curr_lr = base_lr * (1 + math.cos(math.pi*epoch_start / epoch_max)) / 2 73 | return curr_lr 74 | 75 | def _train(self, data_gen, max_batch_ind, epoch): 76 | ''' 77 | :param data_gen: generator for training data 78 | :param max_batch_ind: maximum slice for training data 79 | :param epoch: 80 | ''' 81 | 82 | torch.set_grad_enabled(True) 83 | self.model.train() 84 | tmp_losses = 0 85 | batch_sum = 0 86 | start = time.time() 87 | for batch_ind in range(max_batch_ind): 88 | data = next(data_gen) 89 | img, label0, label1, label2 = data[0], data[1], data[2], data[3] 90 | img = torch.from_numpy(img).float().cuda() 91 | img = img.view(img.shape[0], img.shape[1], img.shape[2], img.shape[3]).permute(0, 3, 1, 2).contiguous() 92 | label0 = torch.from_numpy(label0) 93 | label1 = torch.from_numpy(label1) 94 | label2 = torch.from_numpy(label2) 95 | 96 | if self.args.use_cuda: 97 | img = img.cuda() 98 | label0 = label0.cuda() 99 | label1 = label1.cuda() 100 | label2 = label2.cuda() 101 | 102 | losses = self.model(img, label0, label1, label2) 103 | loss = add_logger(losses, self.logger, batch_ind + max_batch_ind * epoch, 'train') 104 | loss = loss.sum() / loss.numel() 105 | tmp_losses = tmp_losses + loss.item() 106 | 107 | 
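# Standard update step: clear the accumulated gradients, backpropagate the scalar
# loss, and apply the optimizer, which uses the per-parameter-group learning rates
# set in _get_optimizer (backbone_lr for backbone weights, lr for the detection head).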
self.optimizer.zero_grad() 108 | loss.backward() 109 | self.optimizer.step() 110 | 111 | if batch_ind % self.args.display_interval == 0 and batch_ind != 0: 112 | end = time.time() 113 | batch_sum = batch_sum + self.args.display_interval 114 | tmp_losses_show = tmp_losses / batch_sum 115 | print("step {}/{} (epoch {}), loss: {:f} , lr:{:f}, time/batch = {:.3f}" 116 | .format(batch_ind, max_batch_ind, epoch, tmp_losses_show, self.optimizer.param_groups[-1]['lr'], 117 | (end - start)/self.args.display_interval)) 118 | start = end 119 | 120 | def _validation(self, epoch, data_gen, max_val_batch_ind): 121 | torch.set_grad_enabled(False) 122 | self.model.eval() 123 | tmp_losses = 0 124 | 125 | for batch_idx in range(max_val_batch_ind): 126 | data = next(data_gen) 127 | img, label0, label1, label2 = data[0], data[1], data[2], data[3] 128 | img = torch.from_numpy(img).float() 129 | img = img.view(img.shape[0], img.shape[1], img.shape[2], img.shape[3]).permute(0, 3, 1, 2).contiguous() 130 | 131 | label0 = torch.from_numpy(label0) 132 | label1 = torch.from_numpy(label1) 133 | label2 = torch.from_numpy(label2) 134 | 135 | if self.args.use_cuda: 136 | img = img.cuda() 137 | label0 = label0.cuda() 138 | label1 = label1.cuda() 139 | label2 = label2.cuda() 140 | 141 | losses = self.model(img, label0, label1, label2) 142 | loss = add_logger(losses, self.logger, self.args.n_iter, 'val') 143 | tmp_losses = tmp_losses + loss.item() 144 | 145 | tmp_losses = tmp_losses / max_val_batch_ind 146 | print("============================================") 147 | print("Evaluation Loss (epoch {}), TOTAL_LOSS: {:.3f}".format(epoch, tmp_losses)) 148 | print("============================================") 149 | return tmp_losses 150 | 151 | @staticmethod 152 | def _get_optimizer(args, net): 153 | params = [] 154 | for key, value in dict(net.named_parameters()).items(): 155 | if value.requires_grad: 156 | if 'backbone' in key: 157 | params += [{'params':[value], 'lr':args.backbone_lr}] 158 | else: 159 | params += [{'params':[value], 'lr':args.lr}] 160 | 161 | # Initialize optimizer class 162 | if args.optimizer == "adam": 163 | optimizer = optim.Adam(params, weight_decay=args.weight_decay) 164 | elif args.optimizer == "rmsprop": 165 | optimizer = optim.RMSprop(params, weight_decay=args.weight_decay) 166 | else: 167 | # Default to sgd 168 | optimizer = optim.SGD(params, momentum=0.9, weight_decay=args.weight_decay, 169 | nesterov=(args.optimizer == "nesterov")) 170 | return optimizer 171 | 172 | def _create_model(self): 173 | model = yolo_layer.yolov3layer(self.args) 174 | infos = {} 175 | if self.args.start_from != '': 176 | if self.args.load_best_score == 1: 177 | model_path = os.path.join(self.args.start_from, 'model-best.pth') 178 | info_path = os.path.join(self.args.start_from, 'infos-best.pkl') 179 | else: 180 | model_path = os.path.join(self.args.start_from, 'model.pth') 181 | info_path = os.path.join(self.args.start_from, 'infos.pkl') 182 | 183 | if os.path.exists(info_path): 184 | with open(info_path, 'rb') as f: 185 | infos = cPickle.load(f) 186 | 187 | print('Loading the model from %s ...' 
%(model_path)) 188 | model.load_state_dict(torch.load(model_path)) 189 | 190 | if self.args.use_cuda: 191 | if self.args.mGPUs: 192 | model = torch.nn.DataParallel(model).cuda() 193 | else: 194 | model = model.cuda() 195 | return model, infos 196 | 197 | def _data_generator(self, annotation_lines, batch_size, input_shape, anchors, num_classes): 198 | '''data generator for fit_generator''' 199 | n = len(annotation_lines) 200 | i = 0 201 | while True: 202 | if i + batch_size > n: 203 | np.random.shuffle(annotation_lines) 204 | i = 0 205 | 206 | # Separate all image reader into different threads can speed up the program 207 | output = self.threadpool.starmap(get_random_data, zip(annotation_lines[i:i + batch_size], 208 | itertools.repeat(input_shape, batch_size))) 209 | image_data = list(zip(*output))[0] 210 | box_data = list(zip(*output))[1] 211 | i = i + batch_size 212 | image_data = np.array(image_data) 213 | box_data = np.array(box_data) 214 | y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes) 215 | 216 | yield [image_data, *y_true] 217 | 218 | def _data_generator_wrapper(self, annotation_lines, batch_size, input_shape, anchors, num_classes): 219 | n = len(annotation_lines) 220 | if n == 0 or batch_size <= 0: return None 221 | return self._data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes) 222 | 223 | def _train_data_generation(self, sample_path): 224 | 225 | with open(sample_path) as f: 226 | lines = f.readlines() 227 | np.random.seed(10101) 228 | np.random.shuffle(lines) 229 | np.random.seed(None) 230 | val_split = self.args.val_split 231 | num_val = int(len(lines) * val_split) 232 | num_train = len(lines) - num_val 233 | 234 | batch_size = self.args.batch_size 235 | data_gen = self._data_generator_wrapper(lines[:num_train], batch_size, self.input_shape, self.anchors, self.num_classes) 236 | data_gen_val = self._data_generator_wrapper(lines[num_train:], batch_size, self.input_shape, self.anchors, self.num_classes) 237 | max_batch_ind = int(num_train / batch_size) 238 | max_val_batch_ind = int(num_val / batch_size) 239 | 240 | return data_gen, data_gen_val, max_batch_ind, max_val_batch_ind 241 | 242 | def checkpoint_save(self, infos, epoch, best_val_loss, best_flag=False): 243 | 244 | checkpoint_path = os.path.join(self.args.save_path, self.args.backbones_network) 245 | 246 | if not os.path.exists(checkpoint_path): 247 | os.makedirs(checkpoint_path) 248 | 249 | if self.args.mGPUs > 1: 250 | torch.save(self.model.module.state_dict(), os.path.join(checkpoint_path, 'model.pth')) 251 | else: 252 | torch.save(self.model.state_dict(), os.path.join(checkpoint_path, 'model.pth')) 253 | 254 | print("model saved to {}".format(checkpoint_path)) 255 | infos['n_iter'] = self.args.n_iter 256 | infos['epoch'] = epoch 257 | infos['best_val_loss'] = best_val_loss 258 | infos['opt'] = self.args 259 | 260 | with open(os.path.join(checkpoint_path, 'infos.pkl'), 'wb') as f: 261 | cPickle.dump(infos, f) 262 | 263 | if best_flag: 264 | if self.args.mGPUs > 1: 265 | torch.save(self.model.module.state_dict(), os.path.join(checkpoint_path, 'model-best.pth')) 266 | else: 267 | torch.save(self.model.state_dict(), os.path.join(checkpoint_path, 'model-best.pth')) 268 | 269 | print("model saved to {} with best total loss {:.3f}".format(os.path.join(checkpoint_path, \ 270 | 'model-best.pth'), best_val_loss)) 271 | 272 | with open(os.path.join(checkpoint_path, 'infos-best.pkl'), 'wb') as f: 273 | cPickle.dump(infos, f) 274 | 275 | 276 | def main(): 277 | args = 
make_args() 278 | with open(args.cfg_path, 'r') as handle: 279 | options_yaml = yaml.load(handle) 280 | update_values(options_yaml, vars(args)) 281 | 282 | # set random seed to cpu and gpu 283 | if args.seed: 284 | torch.manual_seed(args.seed) 285 | if args.use_cuda: 286 | torch.cuda.manual_seed(args.seed) 287 | 288 | try: 289 | threadpool = ThreadPool(args.batch_size) 290 | except Exception as e: 291 | print(e) 292 | exit(1) 293 | 294 | Train(args, threadpool) 295 | 296 | 297 | def make_args(): 298 | # load the optional parameters and update new arguments 299 | parser = argparse.ArgumentParser() 300 | # # Data input settings 301 | parser.add_argument('--cfg_path', type=str, default='cfgs/Yolo_train.yml', help='load config') 302 | parser.add_argument('--use_cuda', type=bool, default=True, help='whether use gpu.') 303 | parser.add_argument('--mGPUs', type=bool, default=False, help='whether use mgpu.') 304 | return parser.parse_args() 305 | 306 | 307 | if __name__ == '__main__': 308 | 309 | main() 310 | 311 | -------------------------------------------------------------------------------- /model/yolo_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from .backbone import backbone_fn 5 | from collections import OrderedDict 6 | from utils.utils import non_max_suppression 7 | 8 | 9 | class yolov3layer(nn.Module): 10 | ''' 11 | Detection Decoder followed yolo v3. 12 | ''' 13 | 14 | def __init__(self, args): 15 | super(yolov3layer, self).__init__() 16 | self.args = args 17 | 18 | self.backbone = backbone_fn(args) 19 | _out_filters = self.backbone.layers_out_filters 20 | self.num_classes = len(args.classes_names) 21 | final_out_filter0 = 3 * (5 + self.num_classes) 22 | 23 | self.embedding0 = self._make_embedding([512, 1024], _out_filters[-1], final_out_filter0) 24 | # embedding1 25 | final_out_filter1 = 3 * (5 + self.num_classes) 26 | self.embedding1_cbl = self._make_cbl(512, 256, 1) 27 | # self.embedding1_upsample = nn.Upsample(scale_factor=2, mode='nearest') 28 | self.embedding1 = self._make_embedding([256, 512], _out_filters[-2] + 256, final_out_filter1) 29 | # embedding2 30 | final_out_filter2 = 3 * (5 + self.num_classes) 31 | self.embedding2_cbl = self._make_cbl(256, 128, 1) 32 | # self.embedding2_upsample = nn.Upsample(scale_factor=2, mode='nearest') 33 | self.embedding2 = self._make_embedding([128, 256], _out_filters[-3] + 128, final_out_filter2) 34 | 35 | self.anchors = np.array(args.anchors) 36 | self.num_layers = len(self.anchors) // 3 37 | 38 | # initlize the loss function here. 
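# yolo_loss (defined at the bottom of this file) compares the three detection heads
# against the three y_true tensors and returns the total loss together with its
# xy / wh / confidence / class components.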
39 | self.loss = yolo_loss(args) 40 | 41 | def _make_cbl(self, _in, _out, ks): 42 | ''' cbl = conv + batch_norm + leaky_relu 43 | ''' 44 | pad = (ks - 1) // 2 if ks else 0 45 | return nn.Sequential(OrderedDict([ 46 | ("conv", nn.Conv2d(_in, _out, kernel_size=ks, stride=1, padding=pad, bias=False)), 47 | ("bn", nn.BatchNorm2d(_out)), 48 | ("relu", nn.LeakyReLU(0.1)), 49 | ])) 50 | 51 | def _make_embedding(self, filters_list, in_filters, out_filter): 52 | m = nn.ModuleList([ 53 | self._make_cbl(in_filters, filters_list[0], 1), 54 | self._make_cbl(filters_list[0], filters_list[1], 3), 55 | self._make_cbl(filters_list[1], filters_list[0], 1), 56 | self._make_cbl(filters_list[0], filters_list[1], 3), 57 | self._make_cbl(filters_list[1], filters_list[0], 1), 58 | self._make_cbl(filters_list[0], filters_list[1], 3)]) 59 | m.add_module("conv_out", nn.Conv2d(filters_list[1], out_filter, kernel_size=1, 60 | stride=1, padding=0, bias=True)) 61 | return m 62 | 63 | def _branch(self, _embedding, _in): 64 | for i, e in enumerate(_embedding): 65 | _in = e(_in) 66 | if i == 4: 67 | out_branch = _in 68 | return _in, out_branch 69 | 70 | def forward(self, img, label0, label1, label2): 71 | 72 | if self.args.backbone_lr == 0: 73 | with torch.no_grad(): 74 | x2, x1, x0 = self.backbone(img) 75 | else: 76 | 77 | x2, x1, x0 = self.backbone(img) 78 | 79 | out0, out0_branch = self._branch(self.embedding0, x0) 80 | # yolo branch 1 81 | x1_in = self.embedding1_cbl(out0_branch) 82 | x1_in = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x1_in) 83 | x1_in = torch.cat([x1_in, x1], 1) 84 | out1, out1_branch = self._branch(self.embedding1, x1_in) 85 | # yolo branch 2 86 | x2_in = self.embedding2_cbl(out1_branch) 87 | # x2_in = self.embedding2_upsample(x2_in) 88 | x2_in = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x2_in) 89 | x2_in = torch.cat([x2_in, x2], 1) 90 | out2, out2_branch = self._branch(self.embedding2, x2_in) 91 | 92 | loss = self.loss((out0, out1, out2), (label0, label1, label2)) 93 | 94 | return loss 95 | 96 | def detect(self, img, ori_shape): 97 | 98 | with torch.no_grad(): 99 | x2, x1, x0 = self.backbone(img) 100 | # forward the decoder block 101 | out0, out0_branch = self._branch(self.embedding0, x0) 102 | # yolo branch 1 103 | x1_in = self.embedding1_cbl(out0_branch) 104 | x1_in = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x1_in) 105 | x1_in = torch.cat([x1_in, x1], 1) 106 | out1, out1_branch = self._branch(self.embedding1, x1_in) 107 | # yolo branch 2 108 | x2_in = self.embedding2_cbl(out1_branch) 109 | x2_in = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)(x2_in) 110 | x2_in = torch.cat([x2_in, x2], 1) 111 | out2, out2_branch = self._branch(self.embedding2, x2_in) 112 | 113 | dets_, images_, classes_= yolo_eval((out0, out1, out2), self.anchors, self.num_classes, ori_shape) 114 | 115 | return dets_, images_, classes_ 116 | 117 | 118 | def yolo_boxes_and_scores(feats, anchors, num_classes, input_shape, image_shape): 119 | '''Process Conv layer output''' 120 | 121 | box_xy, box_wh, box_confidence, box_class_probs = yolo_head(feats, 122 | anchors, num_classes, input_shape) 123 | 124 | boxes = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape) 125 | boxes = boxes.view([-1, 4]) 126 | 127 | box_scores = box_confidence * box_class_probs 128 | box_scores = box_scores.view(-1, num_classes) 129 | return boxes.view(feats.size(0), -1, 4), box_scores.view(feats.size(0), -1, num_classes) 130 | 131 | 132 | def 
yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape): 133 | '''Get corrected boxes''' 134 | image_shape = image_shape.cpu() 135 | box_yx = torch.stack((box_xy[..., 1], box_xy[..., 0]), dim=4) 136 | box_hw = torch.stack((box_wh[..., 1], box_wh[..., 0]), dim=4) 137 | 138 | new_shape = torch.round(image_shape * torch.min(input_shape / image_shape)) 139 | offset = (input_shape - new_shape) / 2. / input_shape 140 | scale = input_shape / new_shape 141 | box_yx = (box_yx - offset) * scale 142 | box_hw *= scale 143 | 144 | box_mins = box_yx - (box_hw / 2.) 145 | box_maxes = box_yx + (box_hw / 2.) 146 | 147 | boxes = torch.stack([ 148 | box_mins[..., 0], # y_min 149 | box_mins[..., 1], # x_min 150 | box_maxes[..., 0], # y_max 151 | box_maxes[..., 1] # x_max 152 | ], dim=4) 153 | 154 | # Scale boxes back to original image shape. 155 | boxes *= torch.cat([image_shape, image_shape]).view(1, 1, 1, 1, 4) 156 | return boxes 157 | 158 | 159 | def yolo_eval(yolo_outputs, 160 | anchors, 161 | num_classes, 162 | image_shape, 163 | score_threshold=.2, 164 | nms_threshold=.3): 165 | """Evaluate YOLO model on given input and return filtered boxes.""" 166 | yolo_outputs = yolo_outputs 167 | num_layers = len(yolo_outputs) 168 | max_per_image = 100 169 | anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if num_layers == 3 else [[3, 4, 5], [1, 2, 3]] # default setting 170 | input_shape = torch.Tensor([yolo_outputs[0].shape[2] * 32, yolo_outputs[0].shape[3] * 32]).type_as(yolo_outputs[0]) 171 | input_shape = input_shape.cpu() 172 | boxes = [] 173 | box_scores = [] 174 | 175 | # output all the boxes and scores in two lists 176 | for l in range(num_layers): 177 | _boxes, _box_scores = yolo_boxes_and_scores(yolo_outputs[l], 178 | anchors[anchor_mask[l]], num_classes, input_shape, image_shape) 179 | boxes.append(_boxes.cpu()) 180 | box_scores.append(_box_scores.cpu()) 181 | # concatenate data based on batch size 182 | boxes = torch.cat(boxes, dim=1) # torch.Size([1, 10647, 4]) 183 | box_scores = torch.cat(box_scores, dim=1) # torch.Size([1, 10647, num_classes]) 184 | dets_ = [] 185 | classes_ = [] 186 | images_ = [] 187 | for i in range(boxes.size(0)): 188 | mask = box_scores[i] >= score_threshold 189 | img_dets = [] 190 | img_classes = [] 191 | img_images = [] 192 | for c in range(num_classes): 193 | # tf.boolean_mask(boxes, mask[:, c]) 194 | class_boxes = boxes[i][mask[:, c]] 195 | if len(class_boxes) == 0: 196 | continue 197 | class_box_scores = box_scores[i][:, c][mask[:, c]] 198 | _, order = torch.sort(class_box_scores, 0, True) 199 | # do nms here. 
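# Per-class NMS: build [y_min, x_min, y_max, x_max, score] rows, sort them by score,
# then non_max_suppression (utils/utils.py) greedily keeps boxes whose overlap with
# an already-kept box stays below nms_threshold.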
200 | cls_dets = torch.cat((class_boxes, class_box_scores.view(-1, 1)), 1) 201 | cls_dets = cls_dets[order] 202 | keep = non_max_suppression(cls_dets.cpu().numpy(), nms_threshold) 203 | keep = torch.from_numpy(np.array(keep)) 204 | cls_dets = cls_dets[keep.view(-1).long()] 205 | 206 | img_dets.append(cls_dets) 207 | img_classes.append(torch.ones(cls_dets.size(0)) * c) 208 | img_images.append(torch.ones(cls_dets.size(0)) * i) 209 | # Limit to max_per_image detections *over all classes* 210 | if len(img_dets) > 0: 211 | img_dets = torch.cat(img_dets, dim=0) 212 | img_classes = torch.cat(img_classes, dim=0) 213 | img_images = torch.cat(img_images, dim=0) 214 | 215 | if max_per_image > 0: 216 | if img_dets.size(0) > max_per_image: 217 | _, order = torch.sort(img_dets[:, 4], 0, True) 218 | keep = order[:max_per_image] 219 | img_dets = img_dets[keep] 220 | img_classes = img_classes[keep] 221 | img_images = img_images[keep] 222 | 223 | dets_.append(img_dets) 224 | classes_.append(img_classes) 225 | images_.append(img_images) 226 | 227 | if not dets_: 228 | return torch.Tensor(dets_), torch.Tensor(classes_), torch.Tensor(images_) 229 | dets_ = torch.cat(dets_, dim=0) 230 | images_ = torch.cat(images_, dim=0) 231 | classes_ = torch.cat(classes_, dim=0) 232 | 233 | return dets_, images_, classes_ 234 | 235 | 236 | def box_iou(b1, b2): 237 | '''Return iou tensor 238 | 239 | Parameters 240 | ---------- 241 | b1: tensor, shape=(i1,...,iN, 4), xywh 242 | b2: tensor, shape=(j, 4), xywh 243 | 244 | Returns 245 | ------- 246 | iou: tensor, shape=(i1,...,iN, j) 247 | 248 | ''' 249 | 250 | # Expand dim to apply broadcasting. 251 | b1 = b1.unsqueeze(3) 252 | 253 | b1_xy = b1[..., :2] 254 | b1_wh = b1[..., 2:4] 255 | b1_wh_half = b1_wh / 2. 256 | b1_mins = b1_xy - b1_wh_half 257 | b1_maxes = b1_xy + b1_wh_half 258 | 259 | # if b2 is an empty tensor: then iou is empty 260 | if b2.shape[0] == 0: 261 | iou = torch.zeros(b1.shape[0:4]).type_as(b1) 262 | else: 263 | b2 = b2.view(1, 1, 1, b2.size(0), b2.size(1)) 264 | # Expand dim to apply broadcasting. 265 | b2_xy = b2[..., :2] 266 | b2_wh = b2[..., 2:4] 267 | b2_wh_half = b2_wh / 2. 268 | b2_mins = b2_xy - b2_wh_half 269 | b2_maxes = b2_xy + b2_wh_half 270 | 271 | intersect_mins = torch.max(b1_mins, b2_mins) 272 | intersect_maxes = torch.min(b1_maxes, b2_maxes) 273 | intersect_wh = torch.clamp(intersect_maxes - intersect_mins, min=0) 274 | 275 | intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] 276 | b1_area = b1_wh[..., 0] * b1_wh[..., 1] 277 | b2_area = b2_wh[..., 0] * b2_wh[..., 1] 278 | iou = intersect_area / (b1_area + b2_area - intersect_area) 279 | 280 | return iou 281 | 282 | 283 | def yolo_head(feats, anchors, num_classes, input_shape, calc_loss=False): 284 | if not calc_loss: 285 | feats = feats.cpu() 286 | input_shape = input_shape.cpu() 287 | num_anchors = len(anchors) 288 | anchors_tensor = torch.from_numpy(anchors).view(1, 1, 1, num_anchors, 2).type_as(feats) 289 | grid_shape = (feats.shape[2:4]) 290 | 291 | grid_y = torch.arange(0, grid_shape[0]).view(-1, 1, 1, 1).expand(grid_shape[0], grid_shape[0], 1, 1) 292 | grid_x = torch.arange(0, grid_shape[1]).view(1, -1, 1, 1).expand(grid_shape[1], grid_shape[1], 1, 1) 293 | 294 | grid = torch.cat([grid_x, grid_y], dim=3).unsqueeze(0).type_as(feats) 295 | 296 | feats = feats.view(-1, num_anchors, num_classes + 5, grid_shape[0], \ 297 | grid_shape[1]).permute(0, 3, 4, 1, 2).contiguous() 298 | 299 | # Adjust preditions to each spatial grid point and anchor size. 
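# YOLOv3 decode, per grid cell and per anchor:
#   box_xy = (sigmoid(t_xy) + grid_offset) / grid_size    -> centre, normalised to [0, 1]
#   box_wh = exp(t_wh) * anchor_wh / input_shape           -> size, relative to the network input
# The normalised boxes are rescaled to the original image later in yolo_correct_boxes.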
300 |     box_xy = (torch.sigmoid(feats[..., :2]) + grid) / torch.tensor(grid_shape).view(1, 1, 1, 1, 2).type_as(feats)
301 |     box_wh = torch.exp(feats[..., 2:4]) * anchors_tensor / input_shape.view(1, 1, 1, 1, 2)
302 | 
303 |     box_confidence = torch.sigmoid(feats[..., 4:5])
304 |     box_class_probs = torch.sigmoid(feats[..., 5:])
305 | 
306 |     if calc_loss:
307 |         return grid, feats, box_xy, box_wh
308 |     return box_xy, box_wh, box_confidence, box_class_probs
309 | 
310 | 
311 | class yolo_loss(nn.Module):
312 |     def __init__(self, args):
313 |         super(yolo_loss, self).__init__()
314 | 
315 |         self.args = args
316 |         self.anchors = np.array(args.anchors)
317 |         self.num_layers = len(self.anchors) // 3
318 |         self.anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if self.num_layers == 3 else [[3, 4, 5], [1, 2, 3]]
319 |         self.num_classes = len(args.classes_names)
320 |         self.ignore_thresh = 0.5
321 | 
322 |         self.mse_loss = nn.MSELoss(reduce=False)
323 |         self.bce_loss = nn.BCEWithLogitsLoss(reduce=False)
324 | 
325 |     def forward(self, yolo_outputs, y_true):
326 |         input_shape = torch.Tensor([yolo_outputs[0].shape[2] * 32, yolo_outputs[0].shape[3] * 32]).type_as(yolo_outputs[0])
327 |         grid_shapes = [torch.Tensor([output.shape[2], output.shape[3]]).type_as(yolo_outputs[0]) for output in yolo_outputs]
328 | 
329 |         bs = yolo_outputs[0].size(0)
330 |         loss_xy = 0
331 |         loss_wh = 0
332 |         loss_conf = 0
333 |         loss_clss = 0
334 | 
335 |         for l in range(self.num_layers):
336 |             object_mask = y_true[l][..., 4:5]
337 |             true_class_probs = y_true[l][..., 5:]
338 |             grid, raw_pred, pred_xy, pred_wh = yolo_head(yolo_outputs[l], self.anchors[self.anchor_mask[l]],
339 |                                                          self.num_classes, input_shape, calc_loss=True)
340 |             pred_box = torch.cat([pred_xy, pred_wh], dim=4)
341 |             # Darknet raw box to calculate loss.
342 |             raw_true_xy = y_true[l][..., :2] * grid_shapes[l].view(1, 1, 1, 1, 2) - grid
343 |             raw_true_wh = torch.log(y_true[l][..., 2:4] / torch.Tensor(self.anchors[self.anchor_mask[l]]).
344 |                                     type_as(pred_box).view(1, 1, 1, len(self.anchor_mask[l]), 2) *
345 |                                     input_shape.view(1, 1, 1, 1, 2))
346 |             raw_true_wh.masked_fill_(object_mask.expand_as(raw_true_wh) == 0, 0)
347 |             box_loss_scale = 2 - y_true[l][..., 2:3] * y_true[l][..., 3:4]
348 | 
349 |             # Find ignore mask, iterate over each of batch.
350 |             # ignore_mask = tf.TensorArray(K.dtype(y_true[0]), size=1, dynamic_size=True)
351 |             best_ious = []
352 |             for b in range(bs):
353 |                 true_box = y_true[l][b, ..., 0:4][object_mask[b, ..., 0] == 1]
354 |                 iou = box_iou(pred_box[b], true_box)
355 |                 best_iou, _ = torch.max(iou, dim=3)
356 |                 best_ious.append(best_iou)
357 | 
358 |             best_ious = torch.stack(best_ious, dim=0).unsqueeze(4)
359 |             ignore_mask = (best_ious < self.ignore_thresh).float()
360 | 
361 |             # binary_crossentropy is helpful to avoid exp overflow.
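            # In the loss terms below, object_mask keeps only the anchor/cell pairs that
            # own a ground-truth box, box_loss_scale = 2 - w*h up-weights small boxes,
            # and the confidence loss also penalises background predictions except where
            # ignore_mask zeroes out anchors whose best IoU with any ground truth is
            # already >= ignore_thresh.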
362 |             xy_loss = torch.sum(object_mask * box_loss_scale * self.bce_loss(raw_pred[..., 0:2], raw_true_xy)) / bs
363 |             wh_loss = torch.sum(object_mask * box_loss_scale * self.mse_loss(raw_pred[..., 2:4], raw_true_wh)) / bs
364 |             confidence_loss = (torch.sum(self.bce_loss(raw_pred[..., 4:5], object_mask) * object_mask +
365 |                                          (1 - object_mask) * self.bce_loss(raw_pred[..., 4:5], object_mask) * ignore_mask)) / bs
366 |             class_loss = torch.sum(object_mask * self.bce_loss(raw_pred[..., 5:], true_class_probs)) / bs
367 | 
368 |             loss_xy += xy_loss
369 |             loss_wh += wh_loss
370 |             loss_conf += confidence_loss
371 |             loss_clss += class_loss
372 | 
373 |         loss = loss_xy + loss_wh + loss_conf + loss_clss
374 |         # print('loss %.3f, xy %.3f, wh %.3f, conf %.3f, class_loss: %.3f'
375 |         #       %(loss.item(), xy_loss.item(), wh_loss.item(), confidence_loss.item(), class_loss.item()))
376 | 
377 |         return loss.unsqueeze(0), loss_xy.unsqueeze(0), loss_wh.unsqueeze(0), loss_conf.unsqueeze(0), \
378 |                loss_clss.unsqueeze(0)
--------------------------------------------------------------------------------
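The snippet below is a minimal, self-contained sketch of how yolo_eval from model/yolo_layer.py can be exercised. It is illustrative only: the three feature maps are random stand-ins for the outputs of a real backbone plus detection head, the import path is assumed from the directory tree above, and the original image size (720x1280) is arbitrary; the anchors and class names are the ones from cfgs/yolo_detect.yml.

# Hypothetical usage sketch; `model.yolo_layer` import path and random feature maps
# are assumptions, not part of the repository's documented API.
import numpy as np
import torch

from model.yolo_layer import yolo_eval

anchors = np.array([[10, 13], [16, 30], [33, 23],
                    [30, 61], [62, 45], [59, 119],
                    [116, 90], [156, 198], [373, 326]])
classes_names = ['g', 'r', 'y', 'n']
num_classes = len(classes_names)

# Three feature maps for a 416x416 input at strides 32, 16 and 8,
# each with 3 anchors * (5 + num_classes) channels.
yolo_outputs = [torch.randn(1, 3 * (5 + num_classes), s, s) for s in (13, 26, 52)]
image_shape = torch.Tensor([720., 1280.])  # original image (height, width)

with torch.no_grad():
    dets, images, classes = yolo_eval(yolo_outputs, anchors, num_classes, image_shape,
                                      score_threshold=0.2, nms_threshold=0.3)

# Each row of dets is [y_min, x_min, y_max, x_max, score] in original-image pixels;
# classes holds the class index and images the batch index of every kept box.
for det, cls in zip(dets[:5], classes[:5]):
    print(classes_names[int(cls)], float(det[4]), det[:4].tolist())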