├── .gitignore ├── README.md ├── config.py ├── data └── README.md ├── data_loader ├── __init__.py ├── bird_dataset.py ├── coco_dataset.py ├── data_transform_base.py ├── dataset_base.py ├── switch_dataset.py ├── udacity_dataset.py ├── voc_dataset.py └── wheat_dataset.py ├── dataset_params.py ├── docs └── images │ ├── labels.gif │ └── test_result.jpg ├── environment.yml ├── estimate_priors_size_dataset.py ├── export_onnx.py ├── inference_onnx.py ├── models ├── __init__.py ├── functions │ ├── __init__.py │ ├── detect.py │ ├── mish.py │ ├── new_types.py │ └── utils.py ├── layers │ ├── __init__.py │ └── backbone.py ├── yolo_layer.py └── yolo_v3.py ├── requirements.txt ├── scripts ├── download_bird_dataset.py ├── download_darknet_weight.py ├── download_darknet_weight.sh ├── download_switch_dataset.py ├── download_udacity_dataset.sh └── download_voc_dataset.sh ├── test.py ├── test ├── get_dataset_size_distribution.py ├── test_model.py └── test_visualization.py ├── train.py ├── trainer ├── __init__.py ├── trainer.py └── trainer_base.py └── utils ├── __init__.py ├── download_utility.py ├── kmeans_bboxes.py ├── priors_bboxes.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | *.pyc 10 | .ipynb_checkpoints 11 | *~ 12 | *# 13 | build* 14 | 15 | # Packages # 16 | ################### 17 | # it's better to unpack these files and commit the raw source 18 | # git has its own built in compression methods 19 | *.7z 20 | *.dmg 21 | *.gz 22 | *.iso 23 | *.jar 24 | *.rar 25 | *.tar 26 | *.zip 27 | 28 | # Logs and databases # 29 | ###################### 30 | *.log 31 | *.sql 32 | *.sqlite 33 | 34 | # OS generated files # 35 | ###################### 36 | .DS_Store 37 | .DS_Store? 
38 | ._* 39 | .Spotlight-V100 40 | .Trashes 41 | ehthumbs.db 42 | Thumbs.db 43 | 44 | # Images 45 | ###################### 46 | *.jpg 47 | *.gif 48 | *.png 49 | *.svg 50 | *.ico 51 | 52 | # Video 53 | ###################### 54 | *.wmv 55 | *.mpg 56 | *.mpeg 57 | *.mp4 58 | *.mov 59 | *.flv 60 | *.avi 61 | *.ogv 62 | *.ogg 63 | *.webm 64 | 65 | # Audio 66 | ###################### 67 | *.wav 68 | *.mp3 69 | *.wma 70 | 71 | # Fonts 72 | ###################### 73 | Fonts 74 | *.eot 75 | *.ttf 76 | *.woff 77 | 78 | # Format 79 | ###################### 80 | CPPLINT.cfg 81 | .clang-format 82 | 83 | # Gtags 84 | ###################### 85 | GPATH 86 | GRTAGS 87 | GSYMS 88 | GTAGS 89 | 90 | data/ 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | 3 | # Pytorch Implementation of [Yolov3](https://pjreddie.com/media/files/papers/YOLOv3.pdf) For Bird Detection 4 | *** 5 | 6 | This project provides a dataset of wild birds and a YOLOv3 implementation in PyTorch for training on that dataset. This bird detection dataset is special in that it also provides dense labels for birds in flocks. 7 | The bird images were collected from the internet, partly by crawling. Label samples are shown below. 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
Label Samples
17 | 18 | ## TODO ## 19 | *** 20 | 21 | - [x] Train on Bird Dataset 22 | - [x] Export onnx weights and test inference with the onnx weights 23 | - [x] Train on multiple scales 24 | - [x] Mish activation 25 | - [x] Onnx Model 26 | 27 | ## Preparation ## 28 | *** 29 | 30 | ```bash 31 | python3 -m pip install -r requirements.txt 32 | ``` 33 | 34 | - Download the darknet53 backbone pretrained on the ImageNet dataset 35 | ```bash 36 | python3 scripts/download_darknet_weight.py 37 | ``` 38 | 39 | After running this script, the darknet53.conv.74 weights will be saved inside the saved_models directory. 40 | 41 | - Download the bird dataset 42 | ```bash 43 | python3 scripts/download_bird_dataset.py 44 | ``` 45 | 46 | The bird dataset will be saved and extracted into the data directory. 47 | 48 | ## Scripts ## 49 | *** 50 | 51 | - Training (see the train.py script for details on the parameters) 52 | ```bash 53 | python3 train.py --dataset bird_dataset --backbone_weight_path ./saved_models/darknet53.conv.74 54 | ``` 55 | 56 | Weights will be saved inside the saved_models directory. 57 | 58 | - Testing 59 | ```bash 60 | python3 test.py --dataset bird_dataset --snapshot [path/to/snapshot] --image_path [path/to/image] --conf_thresh [confidence/thresh] --nms_thresh [nms/thresh] 61 | ``` 62 | 63 | A sample trained weight file can be downloaded from [HERE](https://drive.google.com/file/d/1DkxLsReA-vEjjWG5jTtzL2gb_kQEZe6b/view?usp=sharing) 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 |
Test Result
73 | 74 | - Export to onnx model 75 | ```bash 76 | python3 export_onnx.py --dataset bird_dataset --snapshot [path/to/weight/snapshot] --batch_size [batch/size] --onnx_weight_file [output/onnx/file] 77 | ``` 78 | 79 | - Inferece with onnx 80 | ```bash 81 | python3 inference_onnx.py --dataset bird_dataset --img_h [img/input/height] --img_w [img/input/width] --image_path [image/path] --onnx_weight_file [onnx/weight] --conf_thresh [confidence/threshold] --nms_thresh [nms_threshold] 82 | ``` 83 | 84 | ## References ## 85 | *** 86 | 87 | - [Yolov3](https://pjreddie.com/media/files/papers/YOLOv3.pdf) 88 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import logging 5 | from dataset_params import dataset_params 6 | from data_loader import BirdDataset, UdacityDataset, VOCDataset, COCODataset, SwitchDataset, WheatDataset 7 | 8 | 9 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | class Config(object): 13 | DATASETS = { 14 | "bird_dataset": BirdDataset, 15 | "udacity_dataset": UdacityDataset, 16 | "voc_dataset": VOCDataset, 17 | "coco_dataset": COCODataset, 18 | "switch_dataset": SwitchDataset, 19 | "wheat_dataset": WheatDataset, 20 | } 21 | 22 | DATASET_PARAMS = dataset_params 23 | 24 | def __init__(self): 25 | self.CURRENT_DIR = _CURRENT_DIR 26 | 27 | self.DATA_PATH = os.path.abspath(os.path.join(_CURRENT_DIR, "data")) 28 | 29 | self.SAVED_MODEL_PATH = os.path.join(self.CURRENT_DIR, "saved_models") 30 | if not os.path.isdir(self.SAVED_MODEL_PATH): 31 | os.system("mkdir -p {}".format(self.SAVED_MODEL_PATH)) 32 | 33 | self.LOG_PATH = os.path.join(self.CURRENT_DIR, "logs") 34 | if not os.path.isdir(self.LOG_PATH): 35 | os.system("mkdir -p {}".format(self.LOG_PATH)) 36 | _config_logging(log_file=os.path.join(self.LOG_PATH, "log.txt")) 37 | 38 | def 
display(self): 39 | """ 40 | Display Configuration values. 41 | """ 42 | print("\nConfigurations:") 43 | for a in dir(self): 44 | if not a.startswith("__") and not callable(getattr(self, a)): 45 | print("{:30} {}".format(a, getattr(self, a))) 46 | print("\n") 47 | 48 | 49 | def _config_logging(log_file, log_level=logging.DEBUG): 50 | import sys 51 | 52 | format_line = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)" 53 | custom_formatter = CustomFormatter(format_line) 54 | stream_handler = logging.StreamHandler() 55 | file_handler = logging.FileHandler(log_file) 56 | 57 | stream_handler.setFormatter(custom_formatter) 58 | 59 | logging.basicConfig(handlers=[file_handler, stream_handler], level=log_level, format=format_line) 60 | 61 | 62 | class CustomFormatter(logging.Formatter): 63 | def format(self, record, *args, **kwargs): 64 | import copy 65 | 66 | LOG_COLORS = { 67 | logging.INFO: "\x1b[33m", 68 | logging.DEBUG: "\x1b[36m", 69 | logging.WARNING: "\x1b[31m", 70 | logging.ERROR: "\x1b[31;1m", 71 | logging.CRITICAL: "\x1b[35m", 72 | } 73 | 74 | new_record = copy.copy(record) 75 | if new_record.levelno in LOG_COLORS: 76 | new_record.levelname = "{color_begin}{level}{color_end}".format( 77 | level=new_record.levelname, color_begin=LOG_COLORS[new_record.levelno], color_end="\x1b[0m", 78 | ) 79 | return super(CustomFormatter, self).format(new_record, *args, **kwargs) 80 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/data/README.md -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from .udacity_dataset import 
UdacityDataset 4 | from .bird_dataset import BirdDataset 5 | from .voc_dataset import VOCDataset 6 | from .coco_dataset import COCODataset 7 | from .switch_dataset import SwitchDataset 8 | from .wheat_dataset import WheatDataset 9 | from .data_transform_base import DataTransformBase 10 | -------------------------------------------------------------------------------- /data_loader/bird_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import numpy as np 6 | import json 7 | from .dataset_base import DatasetBase, DatasetConfigBase 8 | 9 | 10 | class BirdDatasetConfig(DatasetConfigBase): 11 | def __init__(self): 12 | super(BirdDatasetConfig, self).__init__() 13 | 14 | self.CLASSES = [ 15 | "bird", 16 | ] 17 | 18 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes) 19 | 20 | 21 | _bird_config = BirdDatasetConfig() 22 | 23 | 24 | class BirdDataset(DatasetBase): 25 | __name__ = "bird_dataset" 26 | 27 | def __init__( 28 | self, 29 | data_path, 30 | classes=_bird_config.CLASSES, 31 | colors=_bird_config.COLORS, 32 | phase="train", 33 | transform=None, 34 | shuffle=True, 35 | random_seed=2000, 36 | normalize_bbox=False, 37 | bbox_transformer=None, 38 | multiscale=False, 39 | resize_after_batch_num=10, 40 | ): 41 | super(BirdDataset, self).__init__( 42 | data_path, 43 | classes=classes, 44 | colors=colors, 45 | phase=phase, 46 | transform=transform, 47 | shuffle=shuffle, 48 | normalize_bbox=normalize_bbox, 49 | bbox_transformer=bbox_transformer, 50 | multiscale=multiscale, 51 | resize_after_batch_num=resize_after_batch_num, 52 | ) 53 | 54 | assert os.path.isdir(data_path) 55 | assert phase in ("train", "val", "test") 56 | 57 | self._data_path = os.path.join(data_path, "bird_dataset") 58 | assert os.path.isdir(self._data_path) 59 | 60 | self._phase = phase 61 | self._transform = transform 62 | 63 | if self._phase == "test": 64 | 
self._image_path_base = os.path.join(self._data_path, "test") 65 | self._image_paths = sorted( 66 | [os.path.join(self._image_path_base, image_path) for image_path in os.listdir(self._image_path_base)] 67 | ) 68 | else: 69 | self._train_path_base = os.path.join(self._data_path, "train") 70 | self._val_path_base = os.path.join(self._data_path, "val") 71 | 72 | trainval_dict = {"train": {"path": self._train_path_base}, "val": {"path": self._val_path_base}} 73 | 74 | data_path = trainval_dict[self._phase]["path"] 75 | all_image_paths = DatasetBase.get_all_files_with_format_from_path(data_path, ".jpg") 76 | all_json_paths = DatasetBase.get_all_files_with_format_from_path(data_path, ".json") 77 | 78 | assert len(all_image_paths) == len(all_json_paths) 79 | 80 | self._image_paths = [os.path.join(data_path, elem) for elem in all_image_paths] 81 | self._targets = [self._load_one_json(os.path.join(data_path, elem)) for elem in all_json_paths] 82 | 83 | def _load_one_json(self, json_path): 84 | bboxes = [] 85 | labels = [] 86 | 87 | p_json = json.load(open(json_path, "r")) 88 | 89 | for obj in p_json["objects"]: 90 | bboxes.append(obj["boundingbox"]) 91 | label_text = obj["label"] 92 | 93 | labels.append(0) 94 | 95 | return [bboxes, labels] 96 | -------------------------------------------------------------------------------- /data_loader/coco_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | from .dataset_base import DatasetBase, DatasetConfigBase 5 | from pycocotools.coco import COCO 6 | 7 | 8 | class COCODatasetConfig(DatasetConfigBase): 9 | def __init__(self): 10 | super(COCODatasetConfig, self).__init__() 11 | 12 | self.CLASSES = [ 13 | "person", 14 | "bicycle", 15 | "car", 16 | "motorbike", 17 | "aeroplane", 18 | "bus", 19 | "train", 20 | "truck", 21 | "boat", 22 | "traffic light", 23 | "fire hydrant", 24 | "stop sign", 25 | "parking meter", 26 | "bench", 
27 | "bird", 28 | "cat", 29 | "dog", 30 | "horse", 31 | "sheep", 32 | "cow", 33 | "elephant", 34 | "bear", 35 | "zebra", 36 | "giraffe", 37 | "backpack", 38 | "umbrella", 39 | "handbag", 40 | "tie", 41 | "suitcase", 42 | "frisbee", 43 | "skis", 44 | "snowboard", 45 | "sports ball", 46 | "kite", 47 | "baseball bat", 48 | "baseball glove", 49 | "skateboard", 50 | "surfboard", 51 | "tennis racket", 52 | "bottle", 53 | "wine glass", 54 | "cup", 55 | "fork", 56 | "knife", 57 | "spoon", 58 | "bowl", 59 | "banana", 60 | "apple", 61 | "sandwich", 62 | "orange", 63 | "broccoli", 64 | "carrot", 65 | "hot dog", 66 | "pizza", 67 | "donut", 68 | "cake", 69 | "chair", 70 | "sofa", 71 | "pottedplant", 72 | "bed", 73 | "diningtable", 74 | "toilet", 75 | "tvmonitor", 76 | "laptop", 77 | "mouse", 78 | "remote", 79 | "keyboard", 80 | "cell phone", 81 | "microwave", 82 | "oven", 83 | "toaster", 84 | "sink", 85 | "refrigerator", 86 | "book", 87 | "clock", 88 | "vase", 89 | "scissors", 90 | "teddy bear", 91 | "hair drier", 92 | "toothbrush", 93 | ] 94 | 95 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes) 96 | 97 | 98 | _coco_config = COCODatasetConfig() 99 | 100 | 101 | class COCODataset(DatasetBase): 102 | __name__ = "coco_dataset" 103 | 104 | def __init__( 105 | self, 106 | data_path, 107 | data_path_suffix="coco_dataset", 108 | classes=_coco_config.CLASSES, 109 | colors=_coco_config.COLORS, 110 | phase="train", 111 | transform=None, 112 | shuffle=True, 113 | input_size=None, 114 | random_seed=2000, 115 | normalize_bbox=False, 116 | normalize_image=False, 117 | bbox_transformer=None, 118 | ): 119 | super(COCODataset, self).__init__( 120 | data_path, 121 | classes=classes, 122 | colors=colors, 123 | phase=phase, 124 | transform=transform, 125 | shuffle=shuffle, 126 | normalize_bbox=normalize_bbox, 127 | bbox_transformer=bbox_transformer, 128 | ) 129 | 130 | self._input_size = input_size 131 | self._normalize_image = normalize_image 132 | self._min_size = 1 133 | 
134 | assert phase in ("train", "val") 135 | self._data_path = os.path.join(data_path, data_path_suffix) 136 | 137 | self._train_img_path = os.path.join(self._data_path, "train2017") 138 | self._val_img_path = os.path.join(self._data_path, "val2017") 139 | self._annotation_path = os.path.join(self._data_path, "annotations_trainval2017", "annotations") 140 | 141 | assert os.path.isdir(self._train_img_path) 142 | assert os.path.isdir(self._val_img_path) 143 | assert os.path.isdir(self._annotation_path) 144 | self._train_annotation_file = os.path.join(self._annotation_path, "instances_train2017.json") 145 | self._val_annotation_file = os.path.join(self._annotation_path, "instances_val2017.json") 146 | 147 | self._train_coco = COCO(self._train_annotation_file) 148 | self._val_coco = COCO(self._val_annotation_file) 149 | self._train_ids = self._train_coco.getImgIds() 150 | self._val_ids = self._train_coco.getImgIds() 151 | self._class_ids = sorted(self._train_coco.getCatIds()) 152 | 153 | self._map_info = { 154 | "train": {"img_path": self._train_img_path, "coco": self._train_coco, "ids": self._train_ids,}, 155 | "val": {"img_path": self._val_img_path, "coco": self._val_coco, "ids": self._val_ids,}, 156 | } 157 | 158 | self._image_paths = [] 159 | self._targets = [] 160 | 161 | for idx in self._map_info[self._phase]["ids"]: 162 | anno_ids = self._map_info[self._phase]["coco"].getAnnIds(imgIds=[int(idx)], iscrowd=None) 163 | annotations = self._map_info[self._phase]["coco"].loadAnns(anno_ids) 164 | 165 | image_path = os.path.join(self._map_info[self._phase]["img_path"], "{:012}".format(idx) + ".jpg",) 166 | self._image_paths.append(image_path) 167 | 168 | cur_boxes = [] 169 | cur_labels = [] 170 | 171 | for anno in annotations: 172 | xmin, ymin, width, height = anno["bbox"] 173 | if width < self._min_size or height < self._min_size: 174 | continue 175 | xmax = xmin + width 176 | ymax = ymin + height 177 | label = self._class_ids.index(anno["category_id"]) 178 | 
cur_boxes.append([xmin, ymin, xmax, ymax]) 179 | cur_labels.append(label) 180 | 181 | self._targets.append([cur_boxes, cur_labels]) 182 | -------------------------------------------------------------------------------- /data_loader/data_transform_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from albumentations import * 4 | from albumentations.pytorch import ToTensor 5 | import random 6 | 7 | _SEED = 100 8 | random.seed(_SEED) 9 | 10 | 11 | class DataTransformBase(object): 12 | def __init__( 13 | self, 14 | transforms=[HorizontalFlip(p=0.5), GaussNoise(p=0.5), RandomBrightnessContrast(p=0.5),], 15 | input_size=None, 16 | normalize=False, 17 | ): 18 | self._input_size = input_size 19 | if self._input_size is not None: 20 | height, width = self._input_size 21 | self._height_offset = height // 32 - 8 22 | self._width_offset = width // 32 - 8 23 | assert self._height_offset > 0 and self._width_offset > 0 24 | 25 | self._normalize = normalize 26 | 27 | self._transform_dict = {"train": {}, "val": {}, "test": None} 28 | 29 | self._transform_dict["train"]["normal"] = transforms 30 | self._transform_dict["val"]["normal"] = [] 31 | 32 | self._bbox_params = BboxParams( 33 | format="pascal_voc", min_area=0.001, min_visibility=0.001, label_fields=["category_ids"], 34 | ) 35 | self._initialize_transform_dict() 36 | 37 | def _get_all_transforms_of_phase(self, phase): 38 | assert phase in ("train", "val") 39 | cur_transform = [] 40 | cur_transform.extend(self._transform_dict[phase]["normal"]) 41 | cur_transform.append(self._transform_dict[phase]["resize"]) 42 | cur_transform.append(self._transform_dict[phase]["normalize"]) 43 | 44 | return cur_transform 45 | 46 | def _initialize_transform_dict(self): 47 | if self._input_size is not None: 48 | height, width = self._input_size 49 | self._transform_dict["train"]["resize"] = Resize(height, width, always_apply=True) 50 | 
self._transform_dict["val"]["resize"] = Resize(height, width, always_apply=True) 51 | 52 | if self._normalize: 53 | self._transform_dict["train"]["normalize"] = Normalize(always_apply=True) 54 | self._transform_dict["val"]["normalize"] = Normalize(always_apply=True) 55 | else: 56 | self._transform_dict["train"]["normalize"] = ToTensor() 57 | self._transform_dict["val"]["normalize"] = ToTensor() 58 | 59 | self._transform_dict["train"]["all"] = self._get_all_transforms_of_phase("train") 60 | self._transform_dict["val"]["all"] = self._get_all_transforms_of_phase("val") 61 | 62 | self._transform_dict["test"] = self._transform_dict["val"]["all"] 63 | 64 | def update_size(self): 65 | if self._input_size is not None: 66 | random_offset = random.randint(0, 9) 67 | new_height = (random_offset + self._height_offset) * 32 68 | new_width = (random_offset + self._width_offset) * 32 69 | 70 | self._transform_dict["train"]["resize"] = Resize(new_height, new_width, always_apply=True) 71 | self._transform_dict["train"]["all"] = self._get_all_transforms_of_phase("train") 72 | 73 | def __call__(self, image, bboxes=None, labels=None, phase=None): 74 | if phase is None: 75 | transformer = Compose(self._transform_dict["test"]) 76 | return transformer(image=image) 77 | 78 | assert phase in ("train", "val") 79 | assert bboxes is not None 80 | assert labels is not None 81 | 82 | transformed_image = image 83 | transformed_bboxes = bboxes 84 | transformed_category_ids = labels 85 | for transform in self._transform_dict[phase]["all"]: 86 | annotations = { 87 | "image": transformed_image, 88 | "bboxes": transformed_bboxes, 89 | "category_ids": transformed_category_ids, 90 | } 91 | transformer = Compose([transform], bbox_params=self._bbox_params) 92 | augmented = transformer(**annotations) 93 | 94 | while len(augmented["bboxes"]) == 0: 95 | augmented = transformer(**annotations) 96 | 97 | transformed_image = augmented["image"] 98 | transformed_bboxes = augmented["bboxes"] 99 | 
transformed_category_ids = augmented["category_ids"] 100 | 101 | if not self._normalize: 102 | transformed_image = transformed_image.permute(2, 1, 0) 103 | 104 | return (transformed_image, transformed_bboxes, transformed_category_ids) 105 | -------------------------------------------------------------------------------- /data_loader/dataset_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import numpy as np 6 | import tqdm 7 | import multiprocessing as mp 8 | from ctypes import c_int32 9 | from abc import abstractmethod 10 | 11 | 12 | _counter = mp.Value(c_int32) 13 | _counter_lock = mp.Lock() 14 | 15 | 16 | class DatasetConfigBase(object): 17 | def __init__(self): 18 | self.CLASSES = [] 19 | 20 | @property 21 | def num_classes(self): 22 | return len(self.CLASSES) 23 | 24 | @staticmethod 25 | def generate_color_chart(num_classes, seed=3700): 26 | assert num_classes > 0 27 | np.random.seed(seed) 28 | 29 | colors = np.random.randint(0, 255, size=(num_classes, 3), dtype="uint8") 30 | colors = np.vstack([colors]).astype("uint8") 31 | colors = [tuple(color) for color in list(colors)] 32 | colors = [tuple(int(e) for e in color) for color in colors] 33 | 34 | return colors 35 | 36 | 37 | class DatasetBase(object): 38 | def __init__( 39 | self, 40 | data_path, 41 | classes, 42 | colors, 43 | phase="train", 44 | transform=None, 45 | shuffle=True, 46 | normalize_bbox=False, 47 | bbox_transformer=None, 48 | multiscale=False, 49 | resize_after_batch_num=10, 50 | ): 51 | super(DatasetBase, self).__init__() 52 | assert os.path.isdir(data_path) 53 | assert phase in ("train", "val", "test") 54 | 55 | self._data_path = data_path 56 | self._classes = classes 57 | self._colors = colors 58 | self._phase = phase 59 | self._transform = transform 60 | self._shuffle = shuffle 61 | self._normalize_bbox = normalize_bbox 62 | self._bbox_transformer = 
bbox_transformer 63 | self._image_paths = [] 64 | self._targets = [] 65 | 66 | self._batch_count = 0 67 | self._multiscale = multiscale 68 | self._resize_after_batch_num = resize_after_batch_num 69 | 70 | @property 71 | def classes(self): 72 | return self._classes 73 | 74 | @property 75 | def colors(self): 76 | return self._colors 77 | 78 | @property 79 | def num_classes(self): 80 | return len(self._classes) 81 | 82 | def __getitem__(self, idx): 83 | """ 84 | X: np.array (batch_size, height, width, channel) 85 | y: list of length batch size 86 | y: (batch_size, [[number of bboxes, x_min, y_min, x_max, y_max, labels]]) 87 | """ 88 | assert idx < self.__len__() 89 | 90 | image, targets = self._data_generation(idx) 91 | 92 | if self._bbox_transformer is not None: 93 | targets = self._bbox_transformer(targets) 94 | 95 | return image, targets 96 | 97 | def __len__(self): 98 | return len(self._image_paths) 99 | 100 | def visualize_one_image(self, idx): 101 | assert not self._normalize_bbox 102 | assert idx < self.__len__() 103 | image, targets = self.__getitem__(idx) 104 | 105 | all_bboxes = targets[:, :-1] 106 | all_category_ids = targets[:, -1] 107 | 108 | all_bboxes = all_bboxes.astype(np.int64) 109 | all_category_ids = all_category_ids.astype(np.int64) 110 | 111 | return DatasetBase.visualize_one_image_util(image, self._classes, self._colors, all_bboxes, all_category_ids) 112 | 113 | @staticmethod 114 | def visualize_one_image_util(image, classes, colors, all_bboxes, all_category_ids): 115 | for (bbox, label) in zip(all_bboxes, all_category_ids): 116 | x_min, y_min, x_max, y_max = bbox 117 | 118 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), colors[label], 2) 119 | 120 | label_text = classes[label] 121 | label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1) 122 | 123 | cv2.rectangle( 124 | image, 125 | (x_min, y_min), 126 | (x_min + label_size[0][0], y_min + int(1.3 * label_size[0][1])), 127 | colors[label], 128 | -1, 129 | ) 130 | 
cv2.putText( 131 | image, 132 | label_text, 133 | org=(x_min, y_min + label_size[0][1]), 134 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 135 | fontScale=0.35, 136 | color=(255, 255, 255), 137 | lineType=cv2.LINE_AA, 138 | ) 139 | 140 | return image 141 | 142 | def _data_generation(self, idx): 143 | abs_image_path = self._image_paths[idx] 144 | o_img = cv2.imread(abs_image_path) 145 | o_height, o_width, _ = o_img.shape 146 | 147 | o_bboxes, o_category_ids = self._targets[idx] 148 | 149 | o_bboxes = [DatasetBase.authentize_bbox(o_height, o_width, bbox) for bbox in o_bboxes] 150 | 151 | img = o_img 152 | bboxes = o_bboxes 153 | category_ids = o_category_ids 154 | height, width = o_height, o_width 155 | if self._transform: 156 | if isinstance(self._transform, list): 157 | for transform in self._transform: 158 | img, bboxes, category_ids = transform( 159 | image=img, bboxes=bboxes, labels=category_ids, phase=self._phase 160 | ) 161 | else: 162 | img, bboxes, category_ids = self._transform( 163 | image=img, bboxes=bboxes, labels=category_ids, phase=self._phase 164 | ) 165 | 166 | # use the height, width after transformation for normalization 167 | height, width, _ = img.shape 168 | 169 | # if number of boxes is 0, use original image 170 | # see data transform for more details 171 | 172 | if self._normalize_bbox: 173 | bboxes = [ 174 | [float(bbox[0]) / width, float(bbox[1]) / height, float(bbox[2]) / width, float(bbox[3]) / height,] 175 | for bbox in bboxes 176 | ] 177 | 178 | bboxes = np.array(bboxes) 179 | category_ids = np.array(category_ids).reshape(-1, 1) 180 | targets = np.concatenate((bboxes, category_ids), axis=-1) 181 | 182 | return img, targets 183 | 184 | def _process_one_image(self, idx): 185 | abs_image_path = self._image_paths[idx] 186 | o_img = cv2.imread(abs_image_path) 187 | o_height, o_width, _ = o_img.shape 188 | 189 | o_bboxes, _ = self._targets[idx] 190 | 191 | o_bboxes = [DatasetBase.authentize_bbox(o_height, o_width, bbox) for bbox in o_bboxes] 192 | 193 
| o_bboxes = np.array(o_bboxes) 194 | widths = (o_bboxes[:, 2] - o_bboxes[:, 0]) / o_width 195 | heights = (o_bboxes[:, 3] - o_bboxes[:, 1]) / o_height 196 | normalized_dimensions = [[w, h] for w, h in zip(widths, heights)] 197 | 198 | with _counter_lock: 199 | _counter.value += 1 200 | 201 | return normalized_dimensions 202 | 203 | def get_all_normalized_boxes(self, num_processes=mp.cpu_count()): 204 | import functools 205 | 206 | _process_len = self.__len__() 207 | 208 | pbar = tqdm.tqdm(total=_process_len) 209 | 210 | result_bboxes = None 211 | with mp.Pool(num_processes) as p: 212 | future = p.map_async(self._process_one_image, range(_process_len)) 213 | while not future.ready(): 214 | if _counter.value != 0: 215 | with _counter_lock: 216 | increment = _counter.value 217 | _counter.value = 0 218 | pbar.update(n=increment) 219 | 220 | result_bboxes = future.get() 221 | result_bboxes = functools.reduce(lambda x, y: x + y, result_bboxes) 222 | 223 | pbar.close() 224 | return np.array(result_bboxes) 225 | 226 | def _process_one_image_to_get_size(self, idx): 227 | abs_image_path = self._image_paths[idx] 228 | o_img = cv2.imread(abs_image_path) 229 | o_height, o_width, _ = o_img.shape 230 | o_size = o_height * o_width 231 | o_bboxes, _ = self._targets[idx] 232 | 233 | cur_sizes = [] 234 | for (x1, y1, x2, y2) in o_bboxes: 235 | cur_sizes.append(np.sqrt((x2 - x1) * (y2 - y1) * 1.0 / o_size)) 236 | 237 | return cur_sizes 238 | 239 | def size_distribution(self, num_processes=mp.cpu_count()): 240 | 241 | import functools 242 | 243 | _process_len = self.__len__() 244 | 245 | pbar = tqdm.tqdm(total=_process_len) 246 | 247 | result_bboxes = None 248 | with mp.Pool(num_processes) as p: 249 | future = p.map_async(self._process_one_image_to_get_size, range(_process_len)) 250 | while not future.ready(): 251 | if _counter.value != 0: 252 | with _counter_lock: 253 | increment = _counter.value 254 | _counter.value = 0 255 | pbar.update(n=increment) 256 | 257 | result_bboxes = 
future.get() 258 | result_bboxes = functools.reduce(lambda x, y: x + y, result_bboxes) 259 | 260 | pbar.close() 261 | return np.array(result_bboxes) 262 | 263 | @staticmethod 264 | def authentize_bbox(o_height, o_width, bbox): 265 | bbox_type = type(bbox) 266 | 267 | x_min, y_min, x_max, y_max = bbox 268 | if x_min > x_max: 269 | x_min, x_max = x_max, x_min 270 | if y_min > y_max: 271 | y_min, y_max = y_max, y_min 272 | 273 | x_min = max(x_min, 0) 274 | x_max = min(x_max, o_width) 275 | y_min = max(y_min, 0) 276 | y_max = min(y_max, o_height) 277 | 278 | return bbox_type([x_min, y_min, x_max, y_max]) 279 | 280 | @staticmethod 281 | def color_to_color_idx_dict(colors): 282 | color_idx_dict = {} 283 | 284 | for i, color in enumerate(colors): 285 | color_idx_dict[color] = i 286 | 287 | return color_idx_dict 288 | 289 | @staticmethod 290 | def class_to_class_idx_dict(classes): 291 | class_idx_dict = {} 292 | 293 | for i, class_name in enumerate(classes): 294 | class_idx_dict[class_name] = i 295 | 296 | return class_idx_dict 297 | 298 | @staticmethod 299 | def human_sort(s): 300 | """Sort list the way humans do 301 | """ 302 | import re 303 | 304 | pattern = r"([0-9]+)" 305 | return [int(c) if c.isdigit() else c.lower() for c in re.split(pattern, s)] 306 | 307 | @staticmethod 308 | def get_all_files_with_format_from_path(dir_path, suffix_format, use_human_sort=True): 309 | import os 310 | 311 | all_files = [elem for elem in os.listdir(dir_path) if elem.endswith(suffix_format)] 312 | all_files.sort(key=DatasetBase.human_sort) 313 | 314 | return all_files 315 | 316 | def od_collate_fn(self, batch): 317 | import torch 318 | import numpy as np 319 | 320 | def _xywh_to_cxcywh(bbox): 321 | bbox[..., 0] += bbox[..., 2] / 2 322 | bbox[..., 1] += bbox[..., 3] / 2 323 | return bbox 324 | 325 | def _xyxy_to_cxcywh(bbox): 326 | bbox[..., 2] -= bbox[..., 0] 327 | bbox[..., 3] -= bbox[..., 1] 328 | return _xywh_to_cxcywh(bbox) 329 | 330 | if ( 331 | self._multiscale 332 | and 
(self._batch_count + 1) % self._resize_after_batch_num == 0 333 | and self._transform is not None 334 | ): 335 | if isinstance(self._transform, list): 336 | for i in range(len(self._transform)): 337 | self._transform[i].update_size() 338 | else: 339 | self._transform.update_size() 340 | 341 | images = [] 342 | labels = [] 343 | lengths = [] 344 | labels_with_tail = [] 345 | max_num_obj = 0 346 | 347 | for image, label in batch: 348 | image = np.transpose(image, (2, 1, 0)) 349 | image = np.expand_dims(image, axis=0) 350 | images.append(image) 351 | 352 | # xmin,ymin,xmax,ymax to xcenter,ycenter,width,height 353 | labels.append(_xyxy_to_cxcywh(label)) 354 | 355 | length = label.shape[0] 356 | lengths.append(length) 357 | max_num_obj = max(max_num_obj, length) 358 | 359 | for label in labels: 360 | num_obj = label.shape[0] 361 | zero_tail = np.zeros((max_num_obj - num_obj, label.shape[1]), dtype=float) 362 | label_with_tail = np.concatenate([label, zero_tail], axis=0) 363 | labels_with_tail.append(torch.FloatTensor(label_with_tail)) 364 | 365 | images = np.concatenate(images, axis=0) 366 | 367 | image_tensor = torch.FloatTensor(images) 368 | label_tensor = torch.stack(labels_with_tail) 369 | length_tensor = torch.tensor(lengths) 370 | 371 | self._batch_count += 1 372 | 373 | return image_tensor, label_tensor, length_tensor 374 | -------------------------------------------------------------------------------- /data_loader/switch_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import random 6 | import numpy as np 7 | import json 8 | from .dataset_base import DatasetBase, DatasetConfigBase 9 | 10 | 11 | class SwitchDatasetConfig(DatasetConfigBase): 12 | def __init__(self): 13 | super(SwitchDatasetConfig, self).__init__() 14 | 15 | self.CLASSES = [ 16 | "switch-unknown", 17 | "switch-right", 18 | "switch-left", 19 | ] 20 | 21 | self.COLORS = 
DatasetConfigBase.generate_color_chart(self.num_classes) 22 | 23 | 24 | _switch_config = SwitchDatasetConfig() 25 | 26 | 27 | class SwitchDataset(DatasetBase): 28 | __name__ = "switch_dataset" 29 | 30 | def __init__( 31 | self, 32 | data_path, 33 | classes=_switch_config.CLASSES, 34 | colors=_switch_config.COLORS, 35 | phase="train", 36 | transform=None, 37 | shuffle=True, 38 | random_seed=2000, 39 | normalize_bbox=False, 40 | bbox_transformer=None, 41 | multiscale=False, 42 | resize_after_batch_num=10, 43 | train_val_ratio=0.9, 44 | ): 45 | super(SwitchDataset, self).__init__( 46 | data_path, 47 | classes=classes, 48 | colors=colors, 49 | phase=phase, 50 | transform=transform, 51 | shuffle=shuffle, 52 | normalize_bbox=normalize_bbox, 53 | bbox_transformer=bbox_transformer, 54 | multiscale=multiscale, 55 | resize_after_batch_num=resize_after_batch_num, 56 | ) 57 | 58 | assert os.path.isdir(data_path) 59 | assert phase in ("train", "val") 60 | 61 | self._data_path = os.path.join(data_path, "switch_detection/data") 62 | assert os.path.isdir(self._data_path) 63 | 64 | self._phase = phase 65 | self._transform = transform 66 | 67 | _imgs_path = os.path.join(self._data_path, "imgs") 68 | _jsons_path = os.path.join(self._data_path, "labels") 69 | 70 | _all_image_paths = DatasetBase.get_all_files_with_format_from_path(_imgs_path, ".jpg") 71 | _all_json_paths = DatasetBase.get_all_files_with_format_from_path(_jsons_path, ".json") 72 | 73 | assert len(_all_image_paths) == len(_all_json_paths) 74 | 75 | _image_paths = [os.path.join(_imgs_path, elem) for elem in _all_image_paths] 76 | _targets = [self._load_one_json(os.path.join(_jsons_path, elem)) for elem in _all_json_paths] 77 | 78 | zipped = list(zip(_image_paths, _targets)) 79 | random.seed(random_seed) 80 | random.shuffle(zipped) 81 | _image_paths, _targets = zip(*zipped) 82 | 83 | _train_len = int(train_val_ratio * len(_image_paths)) 84 | if self._phase == "train": 85 | self._image_paths = _image_paths[:_train_len] 86 | 
self._targets = _targets[:_train_len] 87 | else: 88 | self._image_paths = _image_paths[_train_len:] 89 | self._targets = _targets[_train_len:] 90 | 91 | def _load_one_json(self, json_path): 92 | bboxes = [] 93 | labels = [] 94 | 95 | p_json = json.load(open(json_path, "r")) 96 | 97 | for obj in p_json["objects"]: 98 | if "boundingbox" in obj: 99 | x_min, y_min, x_max, y_max = obj["boundingbox"] 100 | label_text = obj["label"] 101 | label_idx = self._classes.index(label_text) 102 | 103 | bboxes.append([x_min, y_min, x_max, y_max]) 104 | labels.append(label_idx) 105 | 106 | return [bboxes, labels] 107 | -------------------------------------------------------------------------------- /data_loader/udacity_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import numpy as np 6 | import random 7 | from .dataset_base import DatasetBase, DatasetConfigBase 8 | 9 | 10 | class UdacityDatasetConfig(DatasetConfigBase): 11 | def __init__(self): 12 | super(UdacityDatasetConfig, self).__init__() 13 | 14 | self.CLASSES = ["car", "truck", "pedestrian", "biker", "trafficLight"] 15 | 16 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes) 17 | 18 | 19 | _udacity_config = UdacityDatasetConfig() 20 | 21 | 22 | class UdacityDataset(DatasetBase): 23 | __name__ = "udacity_dataset" 24 | 25 | def __init__( 26 | self, 27 | data_path, 28 | classes=_udacity_config.CLASSES, 29 | colors=_udacity_config.COLORS, 30 | phase="train", 31 | transform=None, 32 | shuffle=True, 33 | random_seed=2000, 34 | normalize_bbox=False, 35 | bbox_transformer=None, 36 | train_val_ratio=0.9, 37 | ): 38 | super(UdacityDataset, self).__init__( 39 | data_path, 40 | classes=classes, 41 | colors=colors, 42 | phase=phase, 43 | transform=transform, 44 | shuffle=shuffle, 45 | normalize_bbox=normalize_bbox, 46 | bbox_transformer=bbox_transformer, 47 | ) 48 | 49 | assert phase in 
("train", "val") 50 | 51 | assert os.path.isdir(data_path) 52 | self._data_path = os.path.join(data_path, "udacity/object-dataset") 53 | assert os.path.isdir(self._data_path) 54 | 55 | self._annotation_file = os.path.join(self._data_path, "labels.csv") 56 | lines = [line.rstrip("\n") for line in open(self._annotation_file, "r")] 57 | lines = [line.split(" ") for line in lines] 58 | image_dict = {} 59 | class_idx_dict = self.class_to_class_idx_dict(self._classes) 60 | 61 | for line in lines: 62 | if line[0] not in image_dict.keys(): 63 | image_dict[line[0]] = [[], []] 64 | 65 | image_dict[line[0]][0].append([int(e) for e in line[1:5]]) 66 | label_name = line[6][1:][:-1] 67 | image_dict[line[0]][1].append(class_idx_dict[label_name]) 68 | 69 | self._image_paths = image_dict.keys() 70 | self._image_paths = [os.path.join(self._data_path, elem) for elem in self._image_paths] 71 | self._targets = image_dict.values() 72 | 73 | zipped = list(zip(self._image_paths, self._targets)) 74 | random.seed(random_seed) 75 | random.shuffle(zipped) 76 | self._image_paths, self._targets = zip(*zipped) 77 | 78 | train_len = int(train_val_ratio * len(self._image_paths)) 79 | if self._phase == "train": 80 | self._image_paths = self._image_paths[:train_len] 81 | self._targets = self._targets[:train_len] 82 | else: 83 | self._image_paths = self._image_paths[train_len:] 84 | self._targets = self._targets[train_len:] 85 | -------------------------------------------------------------------------------- /data_loader/voc_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import numpy as np 6 | import xml.etree.ElementTree as ET 7 | from .dataset_base import DatasetBase, DatasetConfigBase 8 | 9 | 10 | class VOCDatasetConfig(DatasetConfigBase): 11 | def __init__(self): 12 | super(VOCDatasetConfig, self).__init__() 13 | 14 | self.CLASSES = [ 15 | "aeroplane", 16 | 
"bicycle", 17 | "bird", 18 | "boat", 19 | "bottle", 20 | "bus", 21 | "car", 22 | "cat", 23 | "chair", 24 | "cow", 25 | "diningtable", 26 | "dog", 27 | "horse", 28 | "motorbike", 29 | "person", 30 | "pottedplant", 31 | "sheep", 32 | "sofa", 33 | "train", 34 | "tvmonitor", 35 | ] 36 | 37 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes) 38 | 39 | 40 | _voc_config = VOCDatasetConfig() 41 | 42 | 43 | class VOCDataset(DatasetBase): 44 | __name__ = "voc_dataset" 45 | 46 | def __init__( 47 | self, 48 | data_path, 49 | classes=_voc_config.CLASSES, 50 | colors=_voc_config.COLORS, 51 | phase="train", 52 | transform=None, 53 | shuffle=True, 54 | input_size=None, 55 | random_seed=2000, 56 | normalize_bbox=False, 57 | normalize_image=False, 58 | bbox_transformer=None, 59 | ): 60 | super(VOCDataset, self).__init__( 61 | data_path, 62 | classes=classes, 63 | colors=colors, 64 | phase=phase, 65 | transform=transform, 66 | shuffle=shuffle, 67 | normalize_bbox=normalize_bbox, 68 | bbox_transformer=bbox_transformer, 69 | ) 70 | 71 | self._input_size = input_size 72 | self._normalize_image = normalize_image 73 | 74 | assert phase in ("train", "val") 75 | self._data_path = os.path.join(self._data_path, "voc/VOCdevkit/VOC2012") 76 | self._image_path_file = os.path.join(self._data_path, "ImageSets/Main/{}.txt".format(self._phase),) 77 | 78 | lines = [line.rstrip("\n") for line in open(self._image_path_file)] 79 | 80 | image_path_base = os.path.join(self._data_path, "JPEGImages") 81 | anno_path_base = os.path.join(self._data_path, "Annotations") 82 | 83 | self._image_paths = [] 84 | anno_paths = [] 85 | 86 | for line in lines: 87 | self._image_paths.append(os.path.join(image_path_base, "{}.jpg".format(line))) 88 | anno_paths.append(os.path.join(anno_path_base, "{}.xml".format(line))) 89 | 90 | self._targets = [self._parse_one_xml(xml_path) for xml_path in anno_paths] 91 | 92 | np.random.seed(random_seed) 93 | 94 | def _parse_one_xml(self, xml_path): 95 | bboxes = [] 
96 | labels = [] 97 | xml = ET.parse(xml_path).getroot() 98 | 99 | for obj in xml.iter("object"): 100 | difficult = int(obj.find("difficult").text) 101 | if difficult == 1: 102 | continue 103 | 104 | bndbox = [] 105 | 106 | name = obj.find("name").text.lower().strip() 107 | bbox = obj.find("bndbox") 108 | 109 | pts = ["xmin", "ymin", "xmax", "ymax"] 110 | 111 | for pt in pts: 112 | cur_pixel = int(bbox.find(pt).text) - 1 113 | bndbox.append(cur_pixel) 114 | 115 | label_idx = self._classes.index(name) 116 | 117 | bboxes.append(bndbox) 118 | labels.append(label_idx) 119 | 120 | return [bboxes, labels] 121 | -------------------------------------------------------------------------------- /data_loader/wheat_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import cv2 4 | import random 5 | import numpy as np 6 | import json 7 | from .dataset_base import DatasetBase, DatasetConfigBase 8 | 9 | 10 | class WheatDatasetConfig(DatasetConfigBase): 11 | def __init__(self): 12 | super(WheatDatasetConfig, self).__init__() 13 | 14 | self.CLASSES = [ 15 | "wheat", 16 | ] 17 | 18 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes) 19 | 20 | 21 | _wheat_config = WheatDatasetConfig() 22 | 23 | 24 | class WheatDataset(DatasetBase): 25 | __name__ = "wheat_dataset" 26 | 27 | def __init__( 28 | self, 29 | data_path, 30 | classes=_wheat_config.CLASSES, 31 | colors=_wheat_config.COLORS, 32 | phase="train", 33 | transform=None, 34 | shuffle=True, 35 | random_seed=2000, 36 | normalize_bbox=False, 37 | bbox_transformer=None, 38 | multiscale=False, 39 | resize_after_batch_num=10, 40 | train_val_ratio=0.9, 41 | ): 42 | super(WheatDataset, self).__init__( 43 | data_path, 44 | classes=classes, 45 | colors=colors, 46 | phase=phase, 47 | transform=transform, 48 | shuffle=shuffle, 49 | normalize_bbox=normalize_bbox, 50 | bbox_transformer=bbox_transformer, 51 | multiscale=multiscale, 52 | 
resize_after_batch_num=resize_after_batch_num, 53 | ) 54 | 55 | assert os.path.isdir(data_path) 56 | assert phase in ("train", "val") 57 | 58 | self._wheat_data_path = os.path.join(data_path, "global-wheat-detection") 59 | assert os.path.isdir(self._wheat_data_path) 60 | self._train_csv_path = os.path.join(self._wheat_data_path, "train.csv") 61 | self._image_prefix_path = os.path.join(self._wheat_data_path, "train") 62 | 63 | self._phase = phase 64 | self._transform = transform 65 | 66 | self._image_paths, self._targets = self._process_train_csv(self._train_csv_path) 67 | self._image_paths = [os.path.join(self._image_prefix_path, elem) + ".jpg" for elem in self._image_paths] 68 | 69 | zipped = list(zip(self._image_paths, self._targets)) 70 | random.seed(random_seed) 71 | random.shuffle(zipped) 72 | self._image_paths, self._targets = zip(*zipped) 73 | 74 | train_len = int(train_val_ratio * len(self._image_paths)) 75 | if self._phase == "train": 76 | self._image_paths = self._image_paths[:train_len] 77 | self._targets = self._targets[:train_len] 78 | else: 79 | self._image_paths = self._image_paths[train_len:] 80 | self._targets = self._targets[train_len:] 81 | 82 | def _process_train_csv(self, train_csv_path): 83 | image_dict = {} 84 | 85 | import dask.dataframe as dd 86 | 87 | df = dd.read_csv(train_csv_path) 88 | for idx, row in df.iterrows(): 89 | image_id = row["image_id"] 90 | if image_id not in image_dict.keys(): 91 | image_dict[image_id] = [[], []] 92 | 93 | source = row["source"] 94 | width = row["width"] 95 | height = row["height"] 96 | 97 | bbox = row["bbox"].strip("][").split(", ") 98 | bbox = [float(elem) for elem in bbox] 99 | 100 | xmin, ymin, width, height = bbox 101 | xmax = xmin + width 102 | ymax = ymin + height 103 | 104 | bbox = [int(xmin), int(ymin), int(xmax), int(ymax)] 105 | 106 | image_dict[image_id][0].append(bbox) 107 | image_dict[image_id][1].append(0) 108 | 109 | return image_dict.keys(), image_dict.values() 110 | 
-------------------------------------------------------------------------------- /dataset_params.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | __all__ = ["dataset_params"] 6 | 7 | 8 | dataset_params = { 9 | "coco_dataset": { 10 | "anchors": [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326],], 11 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 12 | "num_classes": 80, 13 | "img_h": 416, 14 | "img_w": 416, 15 | }, 16 | "udacity_dataset": { 17 | "anchors": [ 18 | [0.014583333333333334, 0.03333333333333333], 19 | [0.017708333333333333, 0.05], 20 | [0.022916666666666665, 0.07333333333333333], 21 | [0.03333333333333333, 0.043333333333333335], 22 | [0.035416666666666666, 0.11833333333333333], 23 | [0.051041666666666666, 0.06], 24 | [0.08020833333333334, 0.08666666666666667], 25 | [0.12604166666666666, 0.15666666666666668], 26 | [0.2677083333333333, 0.3466666666666667], 27 | ], 28 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 29 | "num_classes": 5, 30 | "img_h": 416, 31 | "img_w": 416, 32 | }, 33 | "bird_dataset": { 34 | "anchors": [ 35 | [0.013333333333333334, 0.013666666666666667], 36 | [0.016923076923076923, 0.027976190476190477], 37 | [0.022203947368421052, 0.044827615904163357], 38 | [0.025833333333333333, 0.016710875331564987], 39 | [0.034375, 0.028125], 40 | [0.038752362948960305, 0.07455104993043463], 41 | [0.05092592592592592, 0.04683129325109843], 42 | [0.06254458977407848, 0.0764872521246459], 43 | [0.07689655172413794, 0.14613778705636743], 44 | [0.11500570776255709, 0.09082682291666666], 45 | [0.162109375, 0.18448023426061494], 46 | [0.26129166666666664, 0.3815], 47 | ], 48 | "anchor_masks": [[8, 9, 10, 11], [4, 5, 6, 7], [0, 1, 2, 3]], 49 | "num_classes": 1, 50 | "img_h": 608, 51 | "img_w": 608, 52 | }, 53 | "voc_dataset": { 54 | "anchors": [ 55 | [0.052, 0.08], 56 | [0.07, 0.18018018018018017], 57 | 
[0.138, 0.30930930930930933], 58 | [0.162, 0.138], 59 | [0.224, 0.5350840978593272], 60 | [0.336, 0.29333333333333333], 61 | [0.41, 0.7173333333333334], 62 | [0.686, 0.466], 63 | [0.848, 0.88], 64 | ], 65 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 66 | "num_classes": 20, 67 | "img_h": 416, 68 | "img_w": 416, 69 | }, 70 | "switch_dataset": { 71 | "anchors": [ 72 | [0.009895833333333333, 0.006481481481481481], 73 | [0.016145833333333335, 0.011111111111111112], 74 | [0.021875, 0.016666666666666666], 75 | [0.030208333333333334, 0.021296296296296296], 76 | [0.033854166666666664, 0.032407407407407406], 77 | [0.04635416666666667, 0.026851851851851852], 78 | [0.05572916666666667, 0.04351851851851852], 79 | [0.07916666666666666, 0.06574074074074074], 80 | [0.13619791666666664, 0.11481481481481481], 81 | ], 82 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 83 | "num_classes": 3, 84 | "img_h": 608, 85 | "img_w": 608, 86 | }, 87 | "wheat_dataset": { 88 | "anchors": [ 89 | [0.044921875, 0.080078125], 90 | [0.0537109375, 0.052734375], 91 | [0.05859375, 0.033203125], 92 | [0.0703125, 0.06640625], 93 | [0.0771484375, 0.091796875], 94 | [0.0908203125, 0.0458984375], 95 | [0.0927734375, 0.1328125], 96 | [0.1064453125, 0.0703125], 97 | [0.142578125, 0.1025390625], 98 | ], 99 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 100 | "num_classes": 1, 101 | "img_h": 608, 102 | "img_w": 608, 103 | }, 104 | } 105 | -------------------------------------------------------------------------------- /docs/images/labels.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/docs/images/labels.gif -------------------------------------------------------------------------------- /docs/images/test_result.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/docs/images/test_result.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bird 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - python >=3.6 8 | - cudatoolkit=10.0 9 | - cudnn=7.6.5 10 | - pip 11 | - pip: 12 | - -r requirements.txt 13 | -------------------------------------------------------------------------------- /estimate_priors_size_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | def main(args): 6 | import numpy as np 7 | from config import Config 8 | from utils import estimate_priors_sizes 9 | 10 | total_config = Config() 11 | if not args.dataset or args.dataset not in total_config.DATASETS.keys(): 12 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys()))) 13 | 14 | DatasetClass = total_config.DATASETS[args.dataset] 15 | 16 | dataset = DatasetClass(data_path=total_config.DATA_PATH) 17 | clusters = estimate_priors_sizes(dataset, args.k) 18 | clusters = [list(elem) for elem in sorted(clusters, key=lambda x: x[0], reverse=False)] 19 | print("Anchors:") 20 | print(clusters) 21 | 22 | anchor_masks = list(np.arange(args.k)[::-1].reshape(3, -1)) 23 | anchor_masks = [list(elem[::-1]) for elem in anchor_masks] 24 | print("Anchor masks:") 25 | print(anchor_masks) 26 | 27 | 28 | if __name__ == "__main__": 29 | import argparse 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--k", type=int, default=9, help="number of clusters") 33 | parser.add_argument("--dataset", type=str, required=True, help="name of the dataset to use") 34 | parsed_args = parser.parse_args() 35 | 36 | main(parsed_args) 37 | 
-------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from models import YoloNet 4 | import torch.onnx 5 | 6 | 7 | def main(args): 8 | import os 9 | from config import Config 10 | 11 | total_config = Config() 12 | if not args.dataset or args.dataset not in total_config.DATASETS.keys(): 13 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys()))) 14 | if not args.snapshot or not os.path.isfile(args.snapshot): 15 | raise Exception("invalid snapshot") 16 | 17 | dataset = args.dataset 18 | dataset_class = total_config.DATASETS[dataset] 19 | dataset_params = total_config.DATASET_PARAMS[dataset] 20 | model = YoloNet(dataset_config=dataset_params) 21 | model.load_state_dict(torch.load(args.snapshot)["state_dict"]) 22 | model.eval() 23 | 24 | if args.batch_size: 25 | batch_size = args.batch_size 26 | else: 27 | batch_size = 1 28 | 29 | x = torch.randn(batch_size, 3, dataset_params["img_h"], dataset_params["img_w"]) 30 | torch.onnx.export( 31 | model, 32 | x, 33 | args.onnx_weight_file, 34 | verbose=True, 35 | input_names=["input"], 36 | output_names=["output"], 37 | do_constant_folding=True, 38 | operator_export_type=torch.onnx.OperatorExportTypes.ONNX, 39 | opset_version=11, 40 | ) 41 | 42 | if args.batch_size: 43 | return 44 | 45 | import onnx 46 | 47 | mp = onnx.load(args.onnx_weight_file) 48 | mp.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None" 49 | mp.graph.output[0].type.tensor_type.shape.dim[0].dim_param = "None" 50 | onnx.save(mp, "output.onnx") 51 | 52 | 53 | if __name__ == "__main__": 54 | import argparse 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--snapshot", required=True, type=str) 58 | parser.add_argument( 59 | "--dataset", type=str, required=True, help="name of the dataset to use", 60 | ) 
61 | parser.add_argument("--onnx_weight_file", type=str, default="output.onnx") 62 | parser.add_argument("--batch_size", type=int) 63 | parsed_args = parser.parse_args() 64 | 65 | main(parsed_args) 66 | -------------------------------------------------------------------------------- /inference_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from models import DetectHandler 4 | import torch 5 | 6 | 7 | def main(args): 8 | import onnxruntime as rt 9 | import cv2 10 | import numpy as np 11 | from config import Config 12 | 13 | total_config = Config() 14 | dataset = args.dataset 15 | dataset_class = total_config.DATASETS[dataset] 16 | dataset_params = total_config.DATASET_PARAMS[dataset] 17 | dataset_instance = dataset_class(data_path=total_config.DATA_PATH) 18 | 19 | img = cv2.imread(args.image_path) 20 | assert img is not None 21 | 22 | ori_h, ori_w = img.shape[:2] 23 | 24 | h_ratio = ori_h / args.img_h 25 | w_ratio = ori_w / args.img_w 26 | 27 | processed_img = cv2.resize(img, (args.img_w, args.img_h)) 28 | processed_img = processed_img / 255.0 29 | input_x = processed_img.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32) 30 | 31 | sess = rt.InferenceSession(args.onnx_weight_file) 32 | 33 | assert len(sess.get_inputs()) == 1 34 | assert len(sess.get_outputs()) == 1 35 | 36 | input_name = sess.get_inputs()[0].name 37 | output_names = [elem.name for elem in sess.get_outputs()] 38 | predictions = sess.run(output_names, {input_name: input_x})[0] 39 | 40 | detect_handler = DetectHandler( 41 | num_classes=3, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh, h_ratio=h_ratio, w_ratio=w_ratio, 42 | ) 43 | bboxes, scores, classes = detect_handler(predictions) 44 | 45 | result = dataset_class.visualize_one_image_util( 46 | img, dataset_instance.classes, dataset_instance.colors, bboxes, classes, 47 | ) 48 | cv2.imshow("result", result) 49 | cv2.waitKey(0) 50 | 
cv2.destroyAllWindows() 51 | 52 | 53 | if __name__ == "__main__": 54 | import argparse 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--img_h", type=int, required=True, help="height size of the input") 58 | parser.add_argument("--img_w", type=int, required=True, help="width size of the input") 59 | parser.add_argument( 60 | "--dataset", type=str, required=True, help="name of the dataset to use", 61 | ) 62 | parser.add_argument("--image_path", type=str, required=True, help="path to image") 63 | parser.add_argument("--onnx_weight_file", type=str, required=True) 64 | parser.add_argument("--conf_thresh", type=float, default=0.1) 65 | parser.add_argument("--nms_thresh", type=float, default=0.5) 66 | parsed_args = parser.parse_args() 67 | 68 | main(parsed_args) 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from .functions import * 4 | from .layers import * 5 | from .yolo_layer import YoloLayer 6 | from .yolo_v3 import YoloNet 7 | -------------------------------------------------------------------------------- /models/functions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from .new_types import * 4 | from .utils import * 5 | from .detect import * 6 | from .mish import * 7 | -------------------------------------------------------------------------------- /models/functions/detect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from .utils import nms 8 | from .new_types import BboxType 9 | 10 | 11 | __all__ = ["DetectHandler"] 12 | 13 | 14 | class 
DetectHandler(Function): 15 | def __init__( 16 | self, num_classes: int, conf_thresh: float, nms_thresh: float, h_ratio: float = None, w_ratio: float = None 17 | ): 18 | super(DetectHandler, self).__init__() 19 | self.num_classes = num_classes 20 | self.conf_thresh = conf_thresh 21 | self.nms_thresh = nms_thresh 22 | 23 | self.h_ratio = h_ratio 24 | self.w_ratio = w_ratio 25 | 26 | def __call__(self, predictions: torch.Tensor): 27 | if isinstance(predictions, np.ndarray): 28 | predictions = torch.FloatTensor(predictions) 29 | 30 | bboxes = predictions[..., :4].squeeze_(dim=0) 31 | scores = predictions[..., 4].squeeze_(dim=0) 32 | classes_one_hot = predictions[..., 5:].squeeze_(dim=0) 33 | classes = torch.argmax(classes_one_hot, dim=1) 34 | 35 | bboxes, scores, classes = nms( 36 | bboxes, 37 | scores, 38 | classes, 39 | num_classes=self.num_classes, 40 | bbox_mode=BboxType.CXCYWH, 41 | conf_thresh=self.conf_thresh, 42 | nms_thresh=self.nms_thresh, 43 | ) 44 | 45 | if self.h_ratio is not None and self.w_ratio is not None: 46 | bboxes[..., [0, 2]] *= self.w_ratio 47 | bboxes[..., [1, 3]] *= self.h_ratio 48 | 49 | return bboxes, scores, classes 50 | -------------------------------------------------------------------------------- /models/functions/mish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | __all__ = ["Mish"] 9 | 10 | 11 | @torch.jit.script 12 | def _mish(input): 13 | """ 14 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) 15 | """ 16 | return input * torch.tanh(F.softplus(input)) 17 | 18 | 19 | class Mish(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def forward(self, input): 24 | return _mish(input) 25 | -------------------------------------------------------------------------------- /models/functions/new_types.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import enum 4 | 5 | 6 | __all__ = ["BboxType"] 7 | 8 | 9 | class BboxType(enum.Enum): 10 | """ 11 | XYWH: xmin, ymin, width, height 12 | CXCYWH: xcenter, ycenter, width, height 13 | XYXY: xmin, ymin, xmax, ymax 14 | """ 15 | 16 | XYWH = 0 17 | CXCYWH = 1 18 | XYXY = 2 19 | -------------------------------------------------------------------------------- /models/functions/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import torch 5 | from .new_types import BboxType 6 | 7 | 8 | __all__ = ["bbox_iou", "nms", "transform_bbox"] 9 | 10 | 11 | def bbox_iou( 12 | bboxes1: torch.Tensor, bboxes2: torch.Tensor, bbox_mode: BboxType = BboxType.XYXY, epsilon: float = 1e-16 13 | ) -> torch.Tensor: 14 | """ 15 | Args: 16 | bboxes1: (num_boxes_1, 4) 17 | bboxes2: (num_boxes_2, 4) 18 | 19 | Return: 20 | ious: (num_boxes_1, num_boxes_2) 21 | """ 22 | if bbox_mode == BboxType.XYXY: 23 | b1_x1, b1_y1, b1_x2, b1_y2 = ( 24 | bboxes1[..., 0], 25 | bboxes1[..., 1], 26 | bboxes1[..., 2], 27 | bboxes1[..., 3], 28 | ) 29 | b2_x1, b2_y1, b2_x2, b2_y2 = ( 30 | bboxes2[..., 0], 31 | bboxes2[..., 1], 32 | bboxes2[..., 2], 33 | bboxes2[..., 3], 34 | ) 35 | elif bbox_mode == BboxType.CXCYWH: 36 | b1_x1, b1_x2 = bboxes1[..., 0] - bboxes1[..., 2] / 2, bboxes1[..., 0] + bboxes1[..., 2] / 2 37 | b1_y1, b1_y2 = bboxes1[..., 1] - bboxes1[..., 3] / 2, bboxes1[..., 1] + bboxes1[..., 3] / 2 38 | b2_x1, b2_x2 = bboxes2[..., 0] - bboxes2[..., 2] / 2, bboxes2[..., 0] + bboxes2[..., 2] / 2 39 | b2_y1, b2_y2 = bboxes2[..., 1] - bboxes2[..., 3] / 2, bboxes2[..., 1] + bboxes2[..., 3] / 2 40 | elif bbox_mode == BboxType.XYWH: 41 | b1_x1, b1_y1 = bboxes1[..., 0], bboxes1[..., 1] 42 | b2_x1, b2_y1 = bboxes2[..., 0], bboxes2[..., 1] 43 | b1_x2, b1_y2 = 
bboxes1[..., 0] + bboxes1[..., 2], bboxes1[..., 1] + bboxes1[..., 3] 44 | b2_x2, b2_y2 = bboxes2[..., 0] + bboxes2[..., 2], bboxes2[..., 1] + bboxes2[..., 3] 45 | else: 46 | raise Exception("not supported bbox type\n") 47 | 48 | num_b1 = bboxes1.shape[0] 49 | num_b2 = bboxes2.shape[0] 50 | 51 | inter_x1 = torch.max(b1_x1.unsqueeze(1).repeat(1, num_b2), b2_x1) 52 | inter_y1 = torch.max(b1_y1.unsqueeze(1).repeat(1, num_b2), b2_y1) 53 | inter_x2 = torch.min(b1_x2.unsqueeze(1).repeat(1, num_b2), b2_x2) 54 | inter_y2 = torch.min(b1_y2.unsqueeze(1).repeat(1, num_b2), b2_y2) 55 | 56 | inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0) 57 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) 58 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) 59 | union_area = b1_area.unsqueeze(1).repeat(1, num_b2) + b2_area.unsqueeze(0).repeat(num_b1, 1) - inter_area + epsilon 60 | 61 | iou = inter_area / union_area 62 | return iou 63 | 64 | 65 | def nms( 66 | bboxes: torch.Tensor, 67 | scores: torch.Tensor, 68 | classes: torch.Tensor, 69 | num_classes: int, 70 | conf_thresh: float = 0.8, 71 | nms_thresh: float = 0.5, 72 | bbox_mode: BboxType = BboxType.CXCYWH, 73 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor): 74 | """ 75 | Args: 76 | bboxes: The location predictions for the img, Shape: [num_anchors,4]. 77 | scores: The class prediction scores for the img, Shape:[num_anchors]. 78 | classes: The label (non-one-hot) representation of the classes of the objects, 79 | Shape: [num_anchors]. 80 | num_classes: number of classes 81 | conf_thresh: threshold where all the detections below this value will be ignored. 82 | nms_thresh: overlap thresh for suppressing unnecessary boxes. 
83 | 84 | Return: 85 | (bboxes, scores, classes) after nms suppression with bboxes in BboxType.XYXY box mode 86 | """ 87 | 88 | assert bboxes.shape[0] == scores.shape[0] == classes.shape[0] 89 | assert conf_thresh > 0 90 | 91 | num_anchors = bboxes.shape[0] 92 | 93 | if num_anchors == 0: 94 | return bboxes, scores, classes 95 | 96 | conf_index = torch.nonzero(torch.ge(scores, conf_thresh)).squeeze() 97 | bboxes = bboxes.index_select(0, conf_index) 98 | scores = scores.index_select(0, conf_index) 99 | classes = classes.index_select(0, conf_index) 100 | 101 | grouped_indices = _group_same_class_object(classes, one_hot=False, num_classes=num_classes) 102 | selected_indices_final = [] 103 | 104 | for class_id, member_idx in enumerate(grouped_indices): 105 | member_idx_tensor = bboxes.new_tensor(member_idx, dtype=torch.long) 106 | bboxes_one_class = bboxes.index_select(dim=0, index=member_idx_tensor) 107 | scores_one_class = scores.index_select(dim=0, index=member_idx_tensor) 108 | scores_one_class, sorted_indices = torch.sort(scores_one_class, descending=False) 109 | 110 | selected_indices = [] 111 | 112 | while sorted_indices.size(0) != 0: 113 | picked_index = sorted_indices[-1] 114 | selected_indices.append(picked_index) 115 | picked_bbox = bboxes_one_class[picked_index] 116 | 117 | picked_bbox.unsqueeze_(dim=0) 118 | 119 | ious = bbox_iou(picked_bbox, bboxes_one_class[sorted_indices[:-1]], bbox_mode=bbox_mode) 120 | ious.squeeze_(dim=0) 121 | 122 | under_indices = torch.nonzero(ious <= nms_thresh).squeeze() 123 | sorted_indices = sorted_indices.index_select(dim=0, index=under_indices) 124 | 125 | selected_indices_final.extend([member_idx[i] for i in selected_indices]) 126 | 127 | selected_indices_final = bboxes.new_tensor(selected_indices_final, dtype=torch.long) 128 | bboxes_result = bboxes.index_select(dim=0, index=selected_indices_final) 129 | scores_result = scores.index_select(dim=0, index=selected_indices_final) 130 | classes_result = 
classes.index_select(dim=0, index=selected_indices_final) 131 | 132 | return transform_bbox(bboxes_result, orig_mode=bbox_mode, target_mode=BboxType.XYXY), scores_result, classes_result 133 | 134 | 135 | def transform_bbox( 136 | bboxes: torch.Tensor, orig_mode: BboxType = BboxType.CXCYWH, target_mode: BboxType = BboxType.XYXY 137 | ) -> torch.Tensor: 138 | assert orig_mode != target_mode 139 | assert bboxes.shape[1] == 4 140 | 141 | if orig_mode == BboxType.CXCYWH: 142 | if target_mode == BboxType.XYXY: 143 | return torch.cat((bboxes[:, :2] - bboxes[:, 2:] / 2, bboxes[:, :2] + bboxes[:, 2:] / 2), dim=-1,) 144 | elif target_mode == BboxType.XYWH: 145 | return torch.cat((bboxes[:, :2] - bboxes[:, 2:] / 2, bboxes[:, :2]), dim=-1,) 146 | else: 147 | raise Exception("not supported conversion\n") 148 | elif orig_mode == BboxType.XYWH: 149 | if target_mode == BboxType.XYXY: 150 | return torch.cat((bboxes[:, :2], bboxes[:, 2:] + bboxes[:, :2]), dim=-1,) 151 | elif target_mode == BboxType.CXCYWH: 152 | return torch.cat((bboxes[:, :2] + bboxes[:, 2:] / 2, bboxes[:, 2:]), dim=-1,) 153 | else: 154 | raise Exception("not supported conversion\n") 155 | elif orig_mode == BboxType.XYXY: 156 | if target_mode == BboxType.CXCYWH: 157 | return torch.cat(((bboxes[:, :2] + bboxes[:, 2:]) / 2, bboxes[:, 2:] - bboxes[:, :2]), dim=-1,) 158 | if target_mode == BboxType.XYWH: 159 | return torch.cat((bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]), dim=-1,) 160 | else: 161 | raise Exception("not supported conversion\n") 162 | else: 163 | raise Exception("not supported original bbox mode\n") 164 | 165 | 166 | def _group_same_class_object(obj_classes: torch.Tensor, one_hot: bool = True, num_classes: int = -1): 167 | """ 168 | Given a list of class results, group the object with the same class into a list. 169 | Returns a list with the length of num_classes, where each bucket has the objects with the same class. 170 | 171 | Args: 172 | obj_classes: The representation of classes of object. 
173 | It can be either one-hot or label (non-one-hot). 174 | If it is one-hot, the shape should be: [num_objects, num_classes] 175 | If it is label (non-non-hot), the shape should be: [num_objects, ] 176 | one_hot: A flag telling the function whether obj_classes is one-hot representation. 177 | num_classes: The max number of classes if obj_classes is represented as non-one-hot format. 178 | 179 | Returns: 180 | a list of of a list, where for the i-th list, 181 | the elements in such list represent the indices of the objects in class i. 182 | """ 183 | 184 | if one_hot: 185 | num_classes = obj_classes.shape[-1] 186 | else: 187 | assert num_classes != -1 188 | grouped_index = [[] for _ in range(num_classes)] 189 | if one_hot: 190 | for idx, class_one_hot in enumerate(obj_classes): 191 | grouped_index[torch.argmax(class_one_hot)].append(idx) 192 | else: 193 | for idx, obj_class_ in enumerate(obj_classes): 194 | grouped_index[obj_class_].append(idx) 195 | 196 | return grouped_index 197 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from .backbone import * 4 | -------------------------------------------------------------------------------- /models/layers/backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | from models.functions import Mish 6 | 7 | 8 | __all__ = ["ConvBnAct", "Darknet", "PreDetectionConvGroup", "UpsampleGroup"] 9 | 10 | 11 | def _activation_func(activation): 12 | import sys 13 | import copy 14 | 15 | try: 16 | return copy.deepcopy( 17 | nn.ModuleDict( 18 | [ 19 | ["relu", nn.ReLU(inplace=True)], 20 | ["leaky_relu", nn.LeakyReLU(negative_slope=0.1, inplace=True)], 21 | ["selu", nn.SELU(inplace=True)], 22 | ["mish", 
Mish()], 23 | ["identiy", nn.Identity()], 24 | ] 25 | )[activation] 26 | ) 27 | except Exception as e: 28 | print("no activation {}".format(activation)) 29 | sys.exit(-1) 30 | 31 | 32 | class ConvBnAct(nn.Module): 33 | def __init__( 34 | self, nin, nout, ks, s=1, pad="SAME", padding=0, use_bn=True, act="mish", 35 | ): 36 | super(ConvBnAct, self).__init__() 37 | 38 | self.use_bn = use_bn 39 | self.bn = None 40 | if use_bn: 41 | self.bn = nn.BatchNorm2d(nout) 42 | 43 | if pad == "SAME": 44 | padding = (ks - 1) // 2 45 | 46 | self.conv = nn.Conv2d( 47 | in_channels=nin, out_channels=nout, kernel_size=ks, stride=s, padding=padding, bias=not use_bn 48 | ) 49 | self.activation = _activation_func(act) 50 | 51 | def forward(self, x): 52 | out = self.conv(x) 53 | if self.use_bn: 54 | out = self.bn(out) 55 | 56 | return self.activation(out) 57 | 58 | 59 | class ResResidualBlock(nn.Module): 60 | def __init__(self, nin): 61 | super(ResResidualBlock, self).__init__() 62 | self.conv1 = ConvBnAct(nin, nin // 2, ks=1) 63 | self.conv2 = ConvBnAct(nin // 2, nin, ks=3) 64 | 65 | def forward(self, x): 66 | return x + self.conv2(self.conv1(x)) 67 | 68 | 69 | def _map_2_cfgdict(module_list): 70 | from collections import OrderedDict 71 | 72 | idx = 0 73 | mdict = OrderedDict() 74 | for i, m in enumerate(module_list): 75 | if isinstance(m, ResResidualBlock): 76 | mdict[idx] = None 77 | mdict[idx + 1] = None 78 | idx += 2 79 | mdict[idx] = i 80 | idx += 1 81 | return mdict 82 | 83 | 84 | def _make_res_stack(nin, num_blk): 85 | return nn.ModuleList([ConvBnAct(nin, nin * 2, 3, s=2)] + [ResResidualBlock(nin * 2) for n in range(num_blk)]) 86 | 87 | 88 | class Darknet(nn.Module): 89 | def __init__(self, blk_list, nout=32): 90 | super(Darknet, self).__init__() 91 | 92 | self.module_list = nn.ModuleList() 93 | self.module_list += [ConvBnAct(3, nout, 3)] 94 | for i, nb in enumerate(blk_list): 95 | self.module_list += _make_res_stack(nout * (2 ** i), nb) 96 | 97 | self.map2yolocfg = 
_map_2_cfgdict(self.module_list) 98 | self.cached_out_dict = dict() 99 | 100 | def forward(self, x): 101 | for i, m in enumerate(self.module_list): 102 | x = m(x) 103 | if i in self.cached_out_dict: 104 | self.cached_out_dict[i] = x 105 | return x 106 | 107 | # mode - normal -- direct index to module_list 108 | # - yolocfg -- index follow the sequences of the cfg file from https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg 109 | def add_cached_out(self, idx, mode="yolocfg"): 110 | if mode == "yolocfg": 111 | idxs = self.map2yolocfg[idx] 112 | self.cached_out_dict[idxs] = None 113 | 114 | def get_cached_out(self, idx, mode="yolocfg"): 115 | if mode == "yolocfg": 116 | idxs = self.map2yolocfg[idx] 117 | return self.cached_out_dict[idxs] 118 | 119 | def load_weight(self, weights_path): 120 | wm = WeightLoader(self) 121 | wm.load_weight(weights_path) 122 | 123 | 124 | class PreDetectionConvGroup(nn.Module): 125 | def __init__(self, nin, nout, num_classes, num_anchors, num_conv=3): 126 | super(PreDetectionConvGroup, self).__init__() 127 | self.module_list = nn.ModuleList() 128 | 129 | for i in range(num_conv): 130 | self.module_list += [ConvBnAct(nin, nout, ks=1)] 131 | self.module_list += [ConvBnAct(nout, nout * 2, ks=3)] 132 | if i == 0: 133 | nin = nout * 2 134 | 135 | self.module_list += [nn.Conv2d(nin, (num_classes + 5) * num_anchors, 1)] 136 | self.map2yolocfg = _map_2_cfgdict(self.module_list) 137 | self.cached_out_dict = dict() 138 | 139 | def forward(self, x): 140 | for i, m in enumerate(self.module_list): 141 | x = m(x) 142 | if i in self.cached_out_dict: 143 | self.cached_out_dict[i] = x 144 | return x 145 | 146 | def add_cached_out(self, idx, mode="yolocfg"): 147 | if mode == "yolocfg": 148 | idx = self.get_idx_from_yolo_idx(idx) 149 | elif idx < 0: 150 | idx = len(self.module_list) - idx 151 | 152 | self.cached_out_dict[idx] = None 153 | 154 | def get_cached_out(self, idx, mode="yolocfg"): 155 | if mode == "yolocfg": 156 | idx = 
self.get_idx_from_yolo_idx(idx) 157 | elif idx < 0: 158 | idx = len(self.module_list) - idx 159 | return self.cached_out_dict[idx] 160 | 161 | def get_idx_from_yolo_idx(self, idx): 162 | if idx < 0: 163 | return len(self.map2yolocfg) + idx 164 | else: 165 | return self.map2yolocfg[idx] 166 | 167 | 168 | class UpsampleGroup(nn.Module): 169 | def __init__(self, nin): 170 | super(UpsampleGroup, self).__init__() 171 | self.conv = ConvBnAct(nin, nin // 2, ks=1) 172 | 173 | def forward(self, route_head, route_tail): 174 | out = self.conv(route_head) 175 | out = nn.functional.interpolate(out, scale_factor=2, mode="nearest") 176 | return torch.cat((out, route_tail), 1) 177 | 178 | 179 | class WeightLoader: 180 | def __init__(self, model): 181 | super(WeightLoader, self).__init__() 182 | self._conv_list = self._find_conv_layers(model) 183 | 184 | def load_weight(self, weight_path): 185 | ptr = 0 186 | weights = self._read_file(weight_path) 187 | for m in self._conv_list: 188 | if type(m) == ConvBnAct: 189 | ptr = self._load_conv_bn_relu(m, weights, ptr) 190 | elif type(m) == nn.Conv2d: 191 | ptr = self._load_conv2d(m, weights, ptr) 192 | return ptr 193 | 194 | def _read_file(self, file): 195 | import numpy as np 196 | 197 | with open(file, "rb") as fp: 198 | header = np.fromfile(fp, dtype=np.int32, count=5) 199 | self.header = torch.from_numpy(header) 200 | self.seen = self.header[3] 201 | weights = np.fromfile(fp, dtype=np.float32) 202 | return weights 203 | 204 | def _copy_weight_to_model_parameters(self, param, weights, ptr): 205 | num_el = param.numel() 206 | param.data.copy_(torch.from_numpy(weights[ptr : ptr + num_el]).view_as(param.data)) 207 | return ptr + num_el 208 | 209 | def _load_conv_bn_relu(self, m, weights, ptr): 210 | ptr = self._copy_weight_to_model_parameters(m.bn.bias, weights, ptr) 211 | ptr = self._copy_weight_to_model_parameters(m.bn.weight, weights, ptr) 212 | ptr = self._copy_weight_to_model_parameters(m.bn.running_mean, weights, ptr) 213 | ptr = 
self._copy_weight_to_model_parameters(m.bn.running_var, weights, ptr) 214 | ptr = self._copy_weight_to_model_parameters(m.conv.weight, weights, ptr) 215 | return ptr 216 | 217 | def _load_conv2d(self, m, weights, ptr): 218 | ptr = self._copy_weight_to_model_parameters(m.bias, weights, ptr) 219 | ptr = self._copy_weight_to_model_parameters(m.weight, weights, ptr) 220 | return ptr 221 | 222 | def _find_conv_layers(self, mod): 223 | module_list = [] 224 | for m in mod.children(): 225 | if type(m) == ConvBnAct: 226 | module_list += [m] 227 | elif type(m) == nn.Conv2d: 228 | module_list += [m] 229 | elif isinstance(m, (nn.ModuleList, nn.Module)): 230 | module_list += self._find_conv_layers(m) 231 | elif type(m) == ResResidualBlock: 232 | module_list += self._find_conv_layers(m) 233 | return module_list 234 | -------------------------------------------------------------------------------- /models/yolo_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from .functions import bbox_iou, BboxType 7 | 8 | 9 | __all__ = ["YoloLayer"] 10 | 11 | 12 | class YoloLayer(nn.Module): 13 | def __init__( 14 | self, 15 | anchors_all, 16 | anchors_mask, 17 | num_classes, 18 | lambda_xy=1, 19 | lambda_wh=1, 20 | lambda_conf=1, 21 | lambda_cls=1, 22 | obj_scale=1, 23 | noobj_scale=1, 24 | ignore_thres=0.7, 25 | epsilon=1e-16, 26 | ): 27 | super(YoloLayer, self).__init__() 28 | 29 | assert num_classes > 0 30 | 31 | self._anchors_all = anchors_all 32 | self._anchors_mask = anchors_mask 33 | 34 | self._num_classes = num_classes 35 | self._bbox_attrib = 5 + num_classes 36 | 37 | self._lambda_xy = lambda_xy 38 | self._lambda_wh = lambda_wh 39 | self._lambda_conf = lambda_conf 40 | 41 | if self._num_classes == 1: 42 | self._lambda_cls = 0 43 | else: 44 | self._lambda_cls = lambda_cls 45 | 46 | self._obj_scale = obj_scale 47 | 
self._noobj_scale = noobj_scale 48 | self._ignore_thres = ignore_thres 49 | 50 | self._epsilon = epsilon 51 | 52 | self._mseloss = nn.MSELoss(reduction="sum") 53 | self._bceloss = nn.BCELoss(reduction="sum") 54 | self._bceloss_average = nn.BCELoss(reduction="elementwise_mean") 55 | 56 | def forward(self, x: torch.Tensor, img_dim: tuple, target=None): 57 | # x : batch_size * nA * (5 + num_classes) * H * W 58 | 59 | device = x.device 60 | if target is not None: 61 | assert target.device == x.device 62 | 63 | nB = x.shape[0] 64 | nA = len(self._anchors_mask) 65 | nH, nW = x.shape[2], x.shape[3] 66 | stride = img_dim[1] / nH 67 | anchors_all = torch.FloatTensor(self._anchors_all) / stride 68 | anchors = anchors_all[self._anchors_mask] 69 | 70 | # Reshape predictions from [B x [A * (5 + num_classes)] x H x W] to [B x A x H x W x (5 + num_classes)] 71 | preds = x.view(nB, nA, self._bbox_attrib, nH, nW).permute(0, 1, 3, 4, 2).contiguous() 72 | 73 | # tx, ty, tw, wh 74 | preds_xy = preds[..., :2].sigmoid() 75 | preds_wh = preds[..., 2:4] 76 | preds_conf = preds[..., 4].sigmoid() 77 | preds_cls = preds[..., 5:].sigmoid() 78 | 79 | # calculate cx, cy, anchor mesh 80 | mesh_y, mesh_x = torch.meshgrid([torch.arange(nH, device=device), torch.arange(nW, device=device)]) 81 | mesh_xy = torch.stack((mesh_x, mesh_y), 2).float() 82 | 83 | mesh_anchors = anchors.view(1, nA, 1, 1, 2).repeat(1, 1, nH, nW, 1).to(device) 84 | 85 | # pred_boxes holds bx,by,bw,bh 86 | pred_boxes = torch.FloatTensor(preds[..., :4].shape) 87 | pred_boxes[..., :2] = preds_xy + mesh_xy 88 | pred_boxes[..., 2:4] = preds_wh.exp() * mesh_anchors 89 | 90 | if target is not None: 91 | ( 92 | obj_mask, 93 | noobj_mask, 94 | box_coord_mask, 95 | tconf, 96 | tcls, 97 | tx, 98 | ty, 99 | tw, 100 | th, 101 | nCorrect, 102 | nGT, 103 | ) = self.build_target_tensor( 104 | pred_boxes, target, anchors_all, anchors, (nH, nW), self._num_classes, self._ignore_thres, 105 | ) 106 | 107 | # masks for loss calculations 108 | 
obj_mask, noobj_mask = obj_mask.to(device), noobj_mask.to(device) 109 | box_coord_mask = box_coord_mask.to(device) 110 | cls_mask = obj_mask == 1 111 | tconf, tcls = tconf.to(device), tcls.to(device) 112 | tx, ty, tw, th = tx.to(device), ty.to(device), tw.to(device), th.to(device) 113 | 114 | loss_x = self._lambda_xy * self._mseloss(preds_xy[..., 0] * box_coord_mask, tx * box_coord_mask) / 2 115 | loss_y = self._lambda_xy * self._mseloss(preds_xy[..., 1] * box_coord_mask, ty * box_coord_mask) / 2 116 | loss_w = self._lambda_wh * self._mseloss(preds_wh[..., 0] * box_coord_mask, tw * box_coord_mask) / 2 117 | loss_h = self._lambda_wh * self._mseloss(preds_wh[..., 1] * box_coord_mask, th * box_coord_mask) / 2 118 | 119 | loss_conf = ( 120 | self._lambda_conf 121 | * ( 122 | self._obj_scale * self._bceloss(preds_conf * obj_mask, obj_mask) 123 | + self._noobj_scale * self._bceloss(preds_conf * noobj_mask, noobj_mask * 0) 124 | ) 125 | / 1 126 | ) 127 | loss_cls = self._lambda_cls * self._bceloss(preds_cls[cls_mask], tcls[cls_mask]) / 1 128 | loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls 129 | 130 | return ( 131 | loss, 132 | loss.item() / nB, 133 | loss_x.item() / nB, 134 | loss_y.item() / nB, 135 | loss_w.item() / nB, 136 | loss_h.item() / nB, 137 | loss_conf.item() / nB, 138 | loss_cls.item() / nB, 139 | nCorrect, 140 | nGT, 141 | ) 142 | 143 | out = torch.cat((pred_boxes.to(device) * stride, preds_conf.to(device).unsqueeze(4), preds_cls.to(device),), 4,) 144 | 145 | # Reshape predictions from [B x A x H x W x (5 + num_classes)] to [B x [A x H x W] x (5 + num_classes)] 146 | out = out.permute(0, 2, 3, 1, 4).contiguous().view(nB, nA * nH * nW, self._bbox_attrib) 147 | 148 | return out 149 | 150 | def build_target_tensor( 151 | self, pred_boxes, target, anchors_all, anchors, inp_dim, num_classes, ignore_thres, 152 | ): 153 | nB = target.shape[0] 154 | nA = len(anchors) 155 | nH, nW = inp_dim[0], inp_dim[1] 156 | nCorrect = 0 157 | nGT = 0 158 | target 
= target.float() 159 | 160 | obj_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False) 161 | noobj_mask = torch.ones(nB, nA, nH, nW, requires_grad=False) 162 | box_coord_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False) 163 | tconf = torch.zeros(nB, nA, nH, nW, requires_grad=False) 164 | tcls = torch.zeros(nB, nA, nH, nW, num_classes, requires_grad=False) 165 | tx = torch.zeros(nB, nA, nH, nW, requires_grad=False) 166 | ty = torch.zeros(nB, nA, nH, nW, requires_grad=False) 167 | tw = torch.zeros(nB, nA, nH, nW, requires_grad=False) 168 | th = torch.zeros(nB, nA, nH, nW, requires_grad=False) 169 | 170 | for b in range(nB): 171 | for t in range(target.shape[1]): 172 | 173 | # ignore padded labels 174 | if target[b, t].sum() == 0: 175 | break 176 | 177 | gx = target[b, t, 0] * nW 178 | gy = target[b, t, 1] * nH 179 | gw = target[b, t, 2] * nW 180 | gh = target[b, t, 3] * nH 181 | gi = int(gx) 182 | gj = int(gy) 183 | 184 | # pred_boxes - [A x H x W x 4] 185 | # Do not train for objectness(noobj) if anchor iou > threshold. 
186 | tmp_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0) 187 | tmp_pred_boxes = pred_boxes[b].view(-1, 4) 188 | tmp_ious, _ = torch.max(bbox_iou(tmp_pred_boxes, tmp_gt_boxes, bbox_mode=BboxType.CXCYWH), 1) 189 | ignore_idx = (tmp_ious > ignore_thres).view(nA, nH, nW) 190 | noobj_mask[b][ignore_idx] = 0 191 | 192 | # find best fit anchor for each ground truth box 193 | tmp_gt_boxes = torch.FloatTensor([[0, 0, gw, gh]]) 194 | tmp_anchor_boxes = torch.cat((torch.zeros(len(anchors_all), 2), anchors_all), 1) 195 | tmp_ious = bbox_iou(tmp_anchor_boxes, tmp_gt_boxes, bbox_mode=BboxType.CXCYWH) 196 | best_anchor = torch.argmax(tmp_ious, 0).item() 197 | 198 | # If the best_anchor belongs to this yolo_layer 199 | if best_anchor in self._anchors_mask: 200 | best_anchor = self._anchors_mask.index(best_anchor) 201 | # find iou for best fit anchor prediction box against the ground truth box 202 | tmp_gt_box = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0) 203 | tmp_pred_box = pred_boxes[b, best_anchor, gj, gi].view(-1, 4) 204 | tmp_iou = bbox_iou(tmp_gt_box, tmp_pred_box, bbox_mode=BboxType.CXCYWH) 205 | 206 | if tmp_iou > 0.5: 207 | nCorrect += 1 208 | 209 | # larger gradient for small objects 210 | box_coord_mask[b, best_anchor, gj, gi] = math.sqrt(2 - target[b, t, 2] * target[b, t, 3]) 211 | 212 | obj_mask[b, best_anchor, gj, gi] = 1 213 | tconf[b, best_anchor, gj, gi] = 1 214 | tcls[b, best_anchor, gj, gi, int(target[b, t, 4])] = 1 215 | tx[b, best_anchor, gj, gi] = gx - gi 216 | ty[b, best_anchor, gj, gi] = gy - gj 217 | tw[b, best_anchor, gj, gi] = torch.log(gw / anchors[best_anchor, 0] + self._epsilon) 218 | th[b, best_anchor, gj, gi] = torch.log(gh / anchors[best_anchor, 1] + self._epsilon) 219 | 220 | nGT += 1 221 | return ( 222 | obj_mask, 223 | noobj_mask, 224 | box_coord_mask, 225 | tconf, 226 | tcls, 227 | tx, 228 | ty, 229 | tw, 230 | th, 231 | nCorrect, 232 | nGT, 233 | ) 234 | 
-------------------------------------------------------------------------------- /models/yolo_v3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .yolo_layer import YoloLayer 7 | from .layers import Darknet, PreDetectionConvGroup, UpsampleGroup 8 | 9 | 10 | __all__ = ["YoloNet"] 11 | 12 | _YOLOLAYER_PARAMS = { 13 | "lambda_xy": 1, 14 | "lambda_wh": 1, 15 | "lambda_conf": 1, 16 | "lambda_cls": 1, 17 | "obj_scale": 1, 18 | "noobj_scale": 1, 19 | "ignore_thres": 0.7, 20 | } 21 | 22 | 23 | class YoloNet(nn.Module): 24 | def __init__(self, dataset_config, yololayer_params=_YOLOLAYER_PARAMS): 25 | super(YoloNet, self).__init__() 26 | 27 | self._yololayer_params = yololayer_params 28 | self._all_anchors = [ 29 | [int(w_e * dataset_config["img_w"]), int(h_e * dataset_config["img_h"])] 30 | for (w_e, h_e) in dataset_config["anchors"] 31 | ] 32 | self._anchors_masks = dataset_config["anchor_masks"] 33 | assert len(self._anchors_masks) > 0 34 | self._num_anchors_each_layer = len(self._anchors_masks[0]) 35 | 36 | self._num_classes = dataset_config["num_classes"] 37 | 38 | self.stat_keys = [ 39 | "loss", 40 | "loss_x", 41 | "loss_y", 42 | "loss_w", 43 | "loss_h", 44 | "loss_conf", 45 | "loss_cls", 46 | "nCorrect", 47 | "nGT", 48 | "recall", 49 | ] 50 | 51 | self.feature = Darknet([1, 2, 8, 8, 4]) 52 | self.feature.add_cached_out(61) 53 | self.feature.add_cached_out(36) 54 | 55 | self.pre_det1 = PreDetectionConvGroup( 56 | 1024, 512, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer 57 | ) 58 | self.pre_det1.add_cached_out(-3) 59 | 60 | self.up1 = UpsampleGroup(512) 61 | self.pre_det2 = PreDetectionConvGroup( 62 | 768, 256, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer 63 | ) 64 | self.pre_det2.add_cached_out(-3) 65 | 66 | self.up2 = UpsampleGroup(256) 67 | self.pre_det3 = 
PreDetectionConvGroup( 68 | 384, 128, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer 69 | ) 70 | 71 | self.yolo_layers = [ 72 | YoloLayer( 73 | anchors_all=self._all_anchors, 74 | anchors_mask=anchors_mask, 75 | num_classes=self._num_classes, 76 | lambda_xy=self._yololayer_params["lambda_xy"], 77 | lambda_wh=self._yololayer_params["lambda_wh"], 78 | lambda_conf=self._yololayer_params["lambda_conf"], 79 | lambda_cls=self._yololayer_params["lambda_cls"], 80 | obj_scale=self._yololayer_params["obj_scale"], 81 | noobj_scale=self._yololayer_params["noobj_scale"], 82 | ignore_thres=self._yololayer_params["ignore_thres"], 83 | ) 84 | for anchors_mask in self._anchors_masks 85 | ] 86 | 87 | def forward(self, x: torch.Tensor, target=None): 88 | img_dim = (x.shape[3], x.shape[2]) 89 | out = self.feature(x) 90 | dets = [] 91 | 92 | # Detection layer 1 93 | out = self.pre_det1(out) 94 | dets.append(self.yolo_layers[0](out, img_dim, target)) 95 | 96 | # Upsample 1 97 | r_head1 = self.pre_det1.get_cached_out(-3) 98 | r_tail1 = self.feature.get_cached_out(61) 99 | out = self.up1(r_head1, r_tail1) 100 | 101 | # Detection layer 2 102 | out = self.pre_det2(out) 103 | dets.append(self.yolo_layers[1](out, img_dim, target)) 104 | 105 | # Upsample 2 106 | r_head2 = self.pre_det2.get_cached_out(-3) 107 | r_tail2 = self.feature.get_cached_out(36) 108 | out = self.up2(r_head2, r_tail2) 109 | 110 | # Detection layer 3 111 | out = self.pre_det3(out) 112 | dets.append(self.yolo_layers[2](out, img_dim, target)) 113 | 114 | if target is not None: 115 | loss, *out = [sum(det) for det in zip(dets[0], dets[1], dets[2])] 116 | 117 | self.stats = dict(zip(self.stat_keys, out)) 118 | self.stats["recall"] = self.stats["nCorrect"] / self.stats["nGT"] if self.stats["nGT"] else 0 119 | return loss 120 | else: 121 | return torch.cat(dets, 1) 122 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | torch==1.5.0 3 | torchvision==0.6.0 4 | onnx==1.7.0 5 | onnxruntime==1.3.0 6 | tensorboardX==2.0 7 | pycocotools 8 | tqdm 9 | -------------------------------------------------------------------------------- /scripts/download_bird_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | 6 | 7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | sys.path.append(os.path.join(_CURRENT_DIR, "..")) 9 | try: 10 | from utils import download_file_from_google_drive 11 | except Exception as e: 12 | print(e) 13 | exit(0) 14 | 15 | 16 | def main(): 17 | data_path = os.path.join(_CURRENT_DIR, "../data") 18 | file_id = "16IQjiGu-jl2oTqr5wsp9MmJxtQiuyIWq" 19 | destination = os.path.join(data_path, "bird_dataset.zip") 20 | if not os.path.isfile(destination) and not os.path.isdir(os.path.join(data_path, "bird_dataset")): 21 | download_file_from_google_drive(file_id, destination) 22 | os.system("cd {} && unzip bird_dataset.zip".format(data_path)) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /scripts/download_darknet_weight.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | 6 | 7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | sys.path.append(os.path.join(_CURRENT_DIR, "..")) 9 | try: 10 | from utils import download_file_from_google_drive 11 | except Exception as e: 12 | print(e) 13 | sys.exit(-1) 14 | 15 | 16 | def main(): 17 | data_path = os.path.join(_CURRENT_DIR, "../saved_models") 18 | os.system("mkdir -p {}".format(data_path)) 19 | file_id = "1_-FQFU1i79WySBehqdUXAdbI_-RSvFqb" 20 | destination = 
os.path.join(data_path, "darknet53.conv.74") 21 | if not os.path.isfile(destination): 22 | download_file_from_google_drive(file_id, destination) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /scripts/download_darknet_weight.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | readonly WEIGHT_PATH="https://pjreddie.com/media/files/darknet53.conv.74" 4 | readonly CURRENT_DIR=$(dirname $(realpath $0)) 5 | readonly DATA_PATH=$(realpath ${CURRENT_DIR}/../saved_models) 6 | 7 | function validate_url { 8 | wget --spider $1 &> /dev/null; 9 | } 10 | 11 | if ! validate_url $WEIGHT_PATH; then 12 | echo "Invalid url to download darknet weight"; 13 | exit; 14 | fi 15 | 16 | echo "start downloading darknet weights" 17 | if [ ! -d ${DATA_PATH} ]; then 18 | mkdir -p ${DATA_PATH} 19 | fi 20 | 21 | if [ ! -f ${DATA_PATH}/darknet53.conv.74 ]; then 22 | wget -c ${WEIGHT_PATH} -P ${DATA_PATH} 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/download_switch_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | 6 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 7 | sys.path.append(os.path.join(_CURRENT_DIR, "..")) 8 | try: 9 | from utils import download_file_from_google_drive 10 | except Exception as e: 11 | print(e) 12 | exit(0) 13 | 14 | 15 | def main(): 16 | data_path = os.path.join(_CURRENT_DIR, "../data") 17 | file_id = "1EWEhmvDaYYm0SsydUEGWUDrnBzkLEQc_" 18 | destination = os.path.join(data_path, "switch_detection.zip") 19 | if not os.path.isfile(destination) and not os.path.isdir(os.path.join(data_path, "switch_detection")): 20 | download_file_from_google_drive(file_id, destination) 21 | os.system("cd {} && unzip switch_detection.zip".format(data_path)) 22 | 23 
| 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /scripts/download_udacity_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # https://github.com/BeSlower/Udacity_object_dataset 4 | readonly CURRENT_DIR=$(dirname $(realpath $0)) 5 | readonly DATA_PATH_BASE=$(realpath ${CURRENT_DIR}/../data) 6 | readonly DATA_PATH=${DATA_PATH_BASE}/udacity 7 | 8 | echo "start downloading udacity dataset" 9 | if [ ! -d ${DATA_PATH} ]; then 10 | mkdir -p ${DATA_PATH} 11 | fi 12 | 13 | if [ ! -f ${DATA_PATH}/object-dataset.tar.gz ]; then 14 | wget -c https://s3.amazonaws.com/udacity-sdc/annotations/object-dataset.tar.gz -P ${DATA_PATH} 15 | fi 16 | 17 | tar -xvf ${DATA_PATH}/object-dataset.tar.gz -C ${DATA_PATH} 18 | -------------------------------------------------------------------------------- /scripts/download_voc_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | readonly CURRENT_DIR=$(dirname $(realpath $0)) 4 | readonly DATA_PATH_BASE=$(realpath ${CURRENT_DIR}/../data) 5 | readonly DATA_PATH=${DATA_PATH_BASE}/voc 6 | if [ ! -d ${DATA_PATH} ]; then 7 | mkdir -p ${DATA_PATH} 8 | fi 9 | 10 | DOWNLOAD_FILES=' 11 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 12 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 13 | ' 14 | 15 | for download_file in ${DOWNLOAD_FILES}; do 16 | file_name=$(basename $download_file) 17 | if [ ! 
-f $DATA_PATH/$file_name ]; then 18 | wget ${download_file} -P $DATA_PATH 19 | extension="${file_name##*.}" 20 | if [[ ${extension} == "zip" ]]; then 21 | unzip $DATA_PATH/${file_name} -d ${DATA_PATH} 22 | fi 23 | 24 | if [[ ${extension} == "tar" ]]; then 25 | tar xvf $DATA_PATH/${file_name} -C ${DATA_PATH} 26 | fi 27 | fi 28 | done 29 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | from models import YoloNet 5 | from data_loader import DataTransformBase 6 | from models import DetectHandler 7 | 8 | 9 | def test_one_image( 10 | args, total_config, dataset_class, dataset_params, 11 | ): 12 | import cv2 13 | import numpy as np 14 | 15 | model_path = args.snapshot 16 | dataset = args.dataset 17 | dataset_params = total_config.DATASET_PARAMS[dataset] 18 | input_size = (dataset_params["img_h"], dataset_params["img_w"]) 19 | 20 | dataset_instance = dataset_class(data_path=total_config.DATA_PATH) 21 | num_classes = dataset_instance.num_classes 22 | model = YoloNet(dataset_config=dataset_params) 23 | model.load_state_dict(torch.load(model_path)["state_dict"]) 24 | model.eval() 25 | 26 | img = cv2.imread(args.image_path) 27 | orig_img = np.copy(img) 28 | 29 | ori_h, ori_w = img.shape[:2] 30 | h_ratio = ori_h / dataset_params["img_h"] 31 | w_ratio = ori_w / dataset_params["img_w"] 32 | 33 | img = cv2.resize(img, input_size) 34 | 35 | img = img / 255.0 36 | input_x = torch.tensor(img.transpose(2, 0, 1)[np.newaxis, :]).float() 37 | 38 | predictions = model(input_x) 39 | 40 | detect_handler = DetectHandler( 41 | num_classes=dataset_params["num_classes"], 42 | conf_thresh=args.conf_thresh, 43 | nms_thresh=args.nms_thresh, 44 | h_ratio=h_ratio, 45 | w_ratio=w_ratio, 46 | ) 47 | 48 | bboxes, scores, classes = detect_handler(predictions) 49 | 50 | result = 
dataset_class.visualize_one_image_util( 51 | orig_img, dataset_instance.classes, dataset_instance.colors, bboxes, classes, 52 | ) 53 | 54 | return orig_img 55 | 56 | 57 | def main(args): 58 | import cv2 59 | import os 60 | from config import Config 61 | 62 | total_config = Config() 63 | if not args.dataset or args.dataset not in total_config.DATASETS.keys(): 64 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys()))) 65 | if not args.snapshot or not os.path.isfile(args.snapshot): 66 | raise Exception("invalid snapshot") 67 | if not args.image_path or not os.path.isfile(args.image_path): 68 | raise Exception("invalid image path") 69 | 70 | dataset = args.dataset 71 | dataset_class = total_config.DATASETS[dataset] 72 | dataset_params = total_config.DATASET_PARAMS[dataset] 73 | 74 | result = test_one_image(args, total_config, dataset_class, dataset_params,) 75 | 76 | cv2.imshow("result", result) 77 | cv2.waitKey(0) 78 | cv2.destroyAllWindows() 79 | cv2.imwrite("result.jpg", result) 80 | 81 | 82 | if __name__ == "__main__": 83 | import argparse 84 | 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--snapshot", required=True, type=str) 87 | parser.add_argument( 88 | "--dataset", type=str, required=True, help="name of the dataset to use", 89 | ) 90 | parser.add_argument("--image_path", required=True, type=str, help="path to the test image") 91 | parser.add_argument("--conf_thresh", type=float, default=0.1) 92 | parser.add_argument("--nms_thresh", type=float, default=0.5) 93 | parsed_args = parser.parse_args() 94 | 95 | main(parsed_args) 96 | -------------------------------------------------------------------------------- /test/get_dataset_size_distribution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | 6 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 7 | try: 8 | 
sys.path.append(os.path.join(_CURRENT_DIR, "..")) 9 | from config import Config 10 | except Exception as e: 11 | print(e) 12 | exit(1) 13 | 14 | 15 | def main(args): 16 | import numpy as np 17 | from matplotlib import pyplot as plt 18 | 19 | total_config = Config() 20 | if not args.dataset or args.dataset not in total_config.DATASETS.keys(): 21 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys()))) 22 | 23 | DatasetClass = total_config.DATASETS[args.dataset] 24 | 25 | dataset = DatasetClass(data_path=total_config.DATA_PATH) 26 | size_list = dataset.size_distribution() 27 | hist, bins = np.histogram(size_list, bins=args.num_bins) 28 | 29 | print("min size {}".format(min(size_list))) 30 | print("max size {}".format(max(size_list))) 31 | print(hist, bins) 32 | plt.hist(size_list, bins=args.num_bins) 33 | plt.title("size/distance histogram") 34 | # plt.show() 35 | plt.savefig("size_distance_hist.png") 36 | 37 | 38 | if __name__ == "__main__": 39 | import argparse 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--dataset", type=str, required=True, help="name of the dataset to use") 43 | parser.add_argument("--num_bins", type=int, default=18, help="number of bins in histogram") 44 | parsed_args = parser.parse_args() 45 | 46 | main(parsed_args) 47 | -------------------------------------------------------------------------------- /test/test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import sys 5 | 6 | 7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | try: 9 | sys.path.append(os.path.join(_CURRENT_DIR, "..")) 10 | from models import YoloNet 11 | except Exception as e: 12 | print(e) 13 | sys.exit(-1) 14 | 15 | 16 | def main(args): 17 | import torch 18 | 19 | input_tensor = torch.randn(1, 3, args.img_h, args.img_w) 20 | 21 | params = { 22 | "anchors": [ 23 | 
[0.016923076923076923, 0.027196652719665274], 24 | [0.018, 0.013855213023900243], 25 | [0.02355072463768116, 0.044977511244377814], 26 | [0.033722163308589605, 0.025525525525525526], 27 | [0.049479166666666664, 0.049575070821529746], 28 | [0.05290373906125696, 0.08290488431876607], 29 | [0.09375, 0.098], 30 | [0.150390625, 0.1838283227241353], 31 | [0.26125, 0.36540185240513895], 32 | ], 33 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 34 | "num_classes": args.num_classes, 35 | "img_w": args.img_w, 36 | "img_h": args.img_h, 37 | } 38 | 39 | model = YoloNet(dataset_config=params) 40 | output = model(input_tensor) 41 | for idx, p in enumerate(output): 42 | print("branch_{}: {}".format(idx, p.size())) 43 | 44 | 45 | if __name__ == "__main__": 46 | import argparse 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--img_h", type=int, default=608) 50 | parser.add_argument("--img_w", type=int, default=608) 51 | parser.add_argument("--num_classes", type=int, default=80) 52 | parsed_args = parser.parse_args() 53 | 54 | main(parsed_args) 55 | -------------------------------------------------------------------------------- /test/test_visualization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import cv2 5 | import sys 6 | import argparse 7 | 8 | 9 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 10 | try: 11 | sys.path.append(os.path.join(_CURRENT_DIR, "..")) 12 | from config import Config 13 | except Exception as e: 14 | print(e) 15 | sys.exit(-1) 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--idx", type=int, default=0, help="index of the images") 20 | parser.add_argument("--dataset", type=str, help="name of the dataset to use") 21 | parsed_args = parser.parse_args() 22 | 23 | 24 | def main(args): 25 | dt_config = Config() 26 | if not args.dataset or args.dataset not in dt_config.DATASETS.keys(): 
27 | raise Exception("specify one of the datasets to use in {}".format(list(dt_config.DATASETS.keys()))) 28 | 29 | DatasetClass = dt_config.DATASETS[args.dataset] 30 | 31 | dataset = DatasetClass(data_path=dt_config.DATA_PATH) 32 | print("length of the dataset: {}".format(len(dataset))) 33 | 34 | assert args.idx < len(dataset) 35 | img = dataset.visualize_one_image(args.idx) 36 | cv2.imshow("visualized_bboxes", cv2.resize(img, (1000, 1000))) 37 | cv2.waitKey(0) 38 | cv2.imwrite("visualized_bboxes.png", img) 39 | 40 | 41 | if __name__ == "__main__": 42 | main(parsed_args) 43 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | from torch.utils.data import DataLoader 5 | import torch.optim as optim 6 | import torch.optim.lr_scheduler as lr_scheduler 7 | import torch 8 | from tensorboardX import SummaryWriter 9 | from albumentations import * 10 | 11 | from models import YoloNet 12 | from config import Config 13 | from trainer import Trainer 14 | from data_loader import DataTransformBase 15 | 16 | 17 | def train_process(args, total_config, dataset_class, data_transform_class, params): 18 | 19 | # --------------------------------------------------------------------------# 20 | # prepare dataset 21 | # --------------------------------------------------------------------------# 22 | 23 | def _worker_init_fn_(): 24 | import random 25 | import numpy as np 26 | import torch 27 | 28 | torch.manual_seed(args.random_seed) 29 | np.random.seed(args.random_seed) 30 | random.seed(args.random_seed) 31 | if torch.cuda.is_available(): 32 | torch.cuda.manual_seed(args.random_seed) 33 | 34 | input_size = (params["img_h"], params["img_w"]) 35 | 36 | transforms = [ 37 | OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.5), 38 | OneOf([MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3), 
MotionBlur(blur_limit=3),], p=0.1,), 39 | RandomGamma(gamma_limit=(80, 120), p=0.5), 40 | RandomBrightnessContrast(p=0.5), 41 | HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=10, p=0.5), 42 | ChannelShuffle(p=0.5), 43 | HorizontalFlip(p=0.5), 44 | Cutout(num_holes=5, max_w_size=40, max_h_size=40, p=0.5), 45 | Rotate(limit=20, p=0.5, border_mode=0), 46 | ] 47 | 48 | data_transform = data_transform_class(transforms=transforms, input_size=input_size) 49 | train_dataset = dataset_class( 50 | data_path=total_config.DATA_PATH, 51 | phase="train", 52 | normalize_bbox=True, 53 | transform=[data_transform], 54 | multiscale=args.multiscale, 55 | resize_after_batch_num=args.resize_after_batch_num, 56 | ) 57 | 58 | val_dataset = dataset_class( 59 | data_path=total_config.DATA_PATH, phase="val", normalize_bbox=True, transform=data_transform, 60 | ) 61 | 62 | train_data_loader = DataLoader( 63 | train_dataset, 64 | batch_size=args.batch_size, 65 | shuffle=True, 66 | collate_fn=train_dataset.od_collate_fn, 67 | num_workers=args.num_workers, 68 | drop_last=True, 69 | worker_init_fn=_worker_init_fn_(), 70 | ) 71 | val_data_loader = DataLoader( 72 | val_dataset, 73 | batch_size=args.batch_size, 74 | shuffle=False, 75 | collate_fn=val_dataset.od_collate_fn, 76 | num_workers=args.num_workers, 77 | drop_last=True, 78 | ) 79 | data_loaders_dict = {"train": train_data_loader, "val": val_data_loader} 80 | 81 | # --------------------------------------------------------------------------# 82 | # configuration for training 83 | # --------------------------------------------------------------------------# 84 | 85 | tblogger = SummaryWriter(total_config.LOG_PATH) 86 | 87 | model = YoloNet(dataset_config=params) 88 | if args.backbone_weight_path: 89 | model.feature.load_weight(args.backbone_weight_path) 90 | 91 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 92 | criterion = None 93 | 94 | base_lr_rate = args.lr_rate / (args.batch_size * 
args.batch_multiplier) 95 | base_weight_decay = args.weight_decay * (args.batch_size * args.batch_multiplier) 96 | steps = [float(v.strip()) for v in args.steps.split(",")] 97 | scales = [float(v.strip()) for v in args.scales.split(",")] 98 | 99 | def adjust_learning_rate(optimizer, processed_batch): 100 | lr = base_lr_rate 101 | for i in range(len(steps)): 102 | scale = scales[i] if i < len(scales) else 1 103 | if processed_batch >= steps[i]: 104 | lr = lr * scale 105 | if processed_batch == steps[i]: 106 | break 107 | else: 108 | break 109 | for param_group in optimizer.param_groups: 110 | param_group["lr"] = lr / args.batch_size 111 | return lr 112 | 113 | optimizer = torch.optim.SGD( 114 | model.parameters(), lr=base_lr_rate, momentum=args.momentum, weight_decay=base_weight_decay, 115 | ) 116 | 117 | trainer = Trainer( 118 | model=model, 119 | criterion=criterion, 120 | metric_func=None, 121 | optimizer=optimizer, 122 | num_epochs=args.num_epoch, 123 | save_period=args.save_period, 124 | config=total_config, 125 | data_loaders_dict=data_loaders_dict, 126 | device=device, 127 | dataset_name_base=train_dataset.__name__, 128 | batch_multiplier=args.batch_multiplier, 129 | adjust_lr_callback=adjust_learning_rate, 130 | logger=tblogger, 131 | ) 132 | 133 | if args.snapshot and os.path.isfile(args.snapshot): 134 | trainer.resume_checkpoint(args.snapshot) 135 | 136 | with torch.autograd.set_detect_anomaly(True): 137 | trainer.train() 138 | 139 | tblogger.close() 140 | 141 | 142 | def main(args): 143 | total_config = Config() 144 | total_config.display() 145 | if not args.dataset or args.dataset not in total_config.DATASETS.keys(): 146 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys()))) 147 | 148 | dataset = args.dataset 149 | dataset_class = total_config.DATASETS[dataset] 150 | data_transform_class = DataTransformBase 151 | params = total_config.DATASET_PARAMS[dataset] 152 | train_process(args, total_config, 
dataset_class, data_transform_class, params) 153 | 154 | 155 | if __name__ == "__main__": 156 | import argparse 157 | 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("--batch_size", type=int, default=2) 160 | parser.add_argument("--num_epoch", type=int, default=300) 161 | parser.add_argument("--lr_rate", type=float, default=1e-3) 162 | parser.add_argument("--batch_multiplier", type=int, default=1) 163 | parser.add_argument("--momentum", default=0.9, type=float) 164 | parser.add_argument("--weight_decay", default=5e-4, type=float) 165 | parser.add_argument("--burn_in", default=1000, type=int) 166 | parser.add_argument("--steps", default="40000,45000", type=str) 167 | parser.add_argument("--scales", default=".1,.1", type=str) 168 | parser.add_argument("--gamma", default=0.1, type=float) 169 | parser.add_argument("--milestones", default="120, 220", type=str) 170 | parser.add_argument("--save_period", type=int, default=1) 171 | parser.add_argument("--backbone_weight_path", type=str) 172 | parser.add_argument("--multiscale", type=bool, default=True) 173 | parser.add_argument("--resize_after_batch_num", type=int, default=10) 174 | parser.add_argument("--snapshot", type=str, help="path to snapshot weights") 175 | parser.add_argument( 176 | "--dataset", 177 | required=True, 178 | type=str, 179 | help="name of the dataset to use", 180 | choices=["bird_dataset", "switch_dataset", "wheat_dataset"], 181 | ) 182 | parser.add_argument("--random_seed", type=int, default=12) 183 | parser.add_argument("--num_workers", type=int, default=4) 184 | parsed_args = parser.parse_args() 185 | 186 | main(parsed_args) 187 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from .trainer import Trainer 4 | -------------------------------------------------------------------------------- 
# ------------------------------------------------------------------------------
# /trainer/trainer.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import torch
import torch.nn as nn
import logging
from .trainer_base import TrainerBase


_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(_CURRENT_DIR, ".."))
try:
    from utils import inf_loop
except Exception as e:
    print(e)
    sys.exit(-1)


class Trainer(TrainerBase):
    """Epoch-based trainer with gradient accumulation and a per-batch LR callback.

    The model is called as ``model(data, target)`` and must return a scalar loss
    tensor; after each call, ``model.stats[key]`` is read for every key in
    ``self._stat_keys`` to build the running-loss log.
    """

    def __init__(
        self,
        model,
        criterion,
        metric_func,
        optimizer,
        num_epochs,
        save_period,
        config,
        data_loaders_dict,
        scheduler=None,
        device=None,
        len_epoch=None,
        dataset_name_base="",
        batch_multiplier=1,
        logger=None,
        processed_batch=0,
        adjust_lr_callback=None,
        print_after_batch_num=10,
    ):
        """
        Args:
            data_loaders_dict: mapping with a "train" loader and an optional
                "val" loader; a missing or None "val" disables validation.
            scheduler: optional LR scheduler stepped once per epoch.
            len_epoch: if given, the train loader is cycled forever and each
                epoch stops after exactly this many batches.
            processed_batch: initial value of the global batch counter passed
                to ``adjust_lr_callback``.
            adjust_lr_callback: ``f(optimizer, processed_batch)`` called before
                every training batch.
            print_after_batch_num: log running losses every this many batches.
            Remaining arguments are forwarded to ``TrainerBase``.
        """
        super(Trainer, self).__init__(
            model,
            criterion,
            metric_func,
            optimizer,
            num_epochs,
            save_period,
            config,
            device,
            dataset_name_base,
            batch_multiplier,
            logger,
        )

        self.train_data_loader = data_loaders_dict["train"]
        self.val_data_loader = data_loaders_dict.get("val")

        # Computed before the optional inf_loop wrapping, which drops `.dataset`.
        self.num_train_imgs = len(self.train_data_loader.dataset)
        self.num_val_imgs = len(self.val_data_loader.dataset) if self.val_data_loader is not None else 0

        self.processed_batch = processed_batch
        self.adjust_lr_callback = adjust_lr_callback

        if len_epoch is None:
            self._len_epoch = len(self.train_data_loader)
        else:
            self.train_data_loader = inf_loop(self.train_data_loader)
            self._len_epoch = len_epoch

        self._do_validation = self.val_data_loader is not None
        self._scheduler = scheduler

        self._print_after_batch_num = print_after_batch_num

        self._stat_keys = ["loss", "loss_x", "loss_y", "loss_w", "loss_h", "loss_conf", "loss_cls"]

    def _train_epoch(self, epoch):
        """Train for one epoch, then validate if a validation loader exists.

        Gradients are accumulated over ``batch_multiplier`` batches per optimizer
        step; a trailing partial window is applied at the end of the epoch so no
        gradient leaks into the next epoch.

        Returns:
            (train_loss, val_loss): summed batch losses divided by the number of
            train / val images; a value is NaN when there is nothing to average.
        """
        self._model.train()
        self._optimizer.zero_grad()

        epoch_train_loss = 0.0
        pending = 0  # batches whose gradients are accumulated but not yet applied
        running_losses = dict.fromkeys(self._stat_keys, 0.0)

        for batch_idx, (data, target, length_tensor) in enumerate(self.train_data_loader):

            if self.adjust_lr_callback is not None:
                self.adjust_lr_callback(self._optimizer, self.processed_batch)
            self.processed_batch += 1

            data = data.to(self._device)
            target = target.to(self._device)
            length_tensor = length_tensor.to(self._device)

            with torch.set_grad_enabled(True):
                # One forward pass: the previous extra `self._model(data)` call
                # was unused and doubled the forward cost.
                train_loss = self._model(data, target)
                for key in self._stat_keys:
                    running_losses[key] += self._model.stats[key]

                # Scale so the accumulated gradient is the mean over the window.
                (train_loss / self._batch_multiplier).backward()
            pending += 1

            if pending == self._batch_multiplier:
                self._optimizer.step()
                self._optimizer.zero_grad()
                pending = 0

            if (batch_idx + 1) % self._print_after_batch_num == 0:
                logging.info(
                    "\n epoch: {}/{} || iter: {}/{} || [Losses: total: {}, loss_x: {}, loss_y: {}, loss_w: {}, loss_h: {}, loss_conf: {}, loss_cls: {} || lr_rate: {}".format(
                        epoch,
                        self._num_epochs,
                        batch_idx + 1,
                        # _len_epoch also works when the loader is an inf_loop
                        # generator, which has no len().
                        self._len_epoch,
                        running_losses["loss"] / self._print_after_batch_num,
                        running_losses["loss_x"] / self._print_after_batch_num,
                        running_losses["loss_y"] / self._print_after_batch_num,
                        running_losses["loss_w"] / self._print_after_batch_num,
                        running_losses["loss_h"] / self._print_after_batch_num,
                        running_losses["loss_conf"] / self._print_after_batch_num,
                        running_losses["loss_cls"] / self._print_after_batch_num,
                        self._optimizer.param_groups[0]["lr"],
                    )
                )
                running_losses = dict.fromkeys(self._stat_keys, 0.0)

            epoch_train_loss += train_loss.item()

            # Stop after exactly _len_epoch batches (needed for inf_loop mode).
            if batch_idx + 1 >= self._len_epoch:
                break

        if pending > 0:
            self._optimizer.step()
            self._optimizer.zero_grad()

        epoch_val_loss = 0.0
        if self._do_validation:
            epoch_val_loss = self._valid_epoch(epoch)

        if self._scheduler is not None:
            self._scheduler.step()

        nan = float("nan")
        return (
            epoch_train_loss / self.num_train_imgs if self.num_train_imgs else nan,
            epoch_val_loss / self.num_val_imgs if self._do_validation and self.num_val_imgs else nan,
        )

    def _valid_epoch(self, epoch):
        """Return the summed (not averaged) validation loss over the val loader."""
        print("start validation...")
        self._model.eval()

        epoch_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, (data, target, length_tensor) in enumerate(self.val_data_loader):

                data = data.to(self._device)
                target = target.to(self._device)
                length_tensor = length_tensor.to(self._device)

                val_loss = self._model(data, target)
                epoch_val_loss += val_loss.item()

        return epoch_val_loss


# ------------------------------------------------------------------------------
# /trainer/trainer_base.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import os
import sys
import torch
from abc import abstractmethod

_CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(_CURRENT_DIR, ".."))


class TrainerBase:
    """Shared epoch loop, checkpoint saving and checkpoint resuming.

    Subclasses implement ``_train_epoch(epoch)`` returning ``(train_loss, val_loss)``.
    """

    def __init__(
        self,
        model,
        criterion,
        metric_func,
        optimizer,
        num_epochs,
        save_period,
        config,
        device=None,
        dataset_name_base="",
        batch_multiplier=1,
        logger=None,
    ):
        self._model = model
        self.criterion = criterion
        self._metric_func = metric_func
        self._optimizer = optimizer
        self._config = config
        self._dataset_name_base = dataset_name_base
        self._batch_multiplier = batch_multiplier
        self._logger = logger  # optional tensorboard-style writer with add_scalar

        if device is None:
            self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self._device = device

        self._model = self._model.to(self._device)

        self._checkpoint_dir = self._config.SAVED_MODEL_PATH

        self._start_epoch = 1  # epochs are 1-based

        self._num_epochs = num_epochs

        self._save_period = save_period

    @property
    def model(self):
        return self._model

    # NOTE: TrainerBase does not use ABCMeta, so @abstractmethod is not enforced
    # at instantiation; the NotImplementedError is what guards it.
    @abstractmethod
    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _save_checkpoint(self, epoch, save_best=False):
        """Save model and optimizer state for ``epoch`` into the checkpoint dir.

        ``save_best`` is accepted for API compatibility but currently unused.
        """
        arch = type(self._model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self._model.state_dict(),
            "optimizer": self._optimizer.state_dict(),
        }
        # Persist the global batch counter when the subclass tracks one, so a
        # resumed run continues its LR schedule instead of restarting it.
        if hasattr(self, "processed_batch"):
            state["processed_batch"] = self.processed_batch

        output_file = "checkpoint_{}_epoch_{}.pth".format(arch, epoch)
        if self._dataset_name_base and isinstance(self._dataset_name_base, str):
            output_file = "{}_{}".format(self._dataset_name_base, output_file)

        if self._checkpoint_dir:
            os.makedirs(self._checkpoint_dir, exist_ok=True)
        filename = os.path.join(self._checkpoint_dir, output_file)
        torch.save(state, filename)

    def resume_checkpoint(self, resume_path):
        """Restore model/optimizer state (and the batch counter, if saved) from a checkpoint."""
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
        checkpoint = torch.load(str(resume_path), map_location=self._device)
        self._start_epoch = checkpoint["epoch"] + 1

        self._model.load_state_dict(checkpoint["state_dict"])

        self._optimizer.load_state_dict(checkpoint["optimizer"])

        if "processed_batch" in checkpoint and hasattr(self, "processed_batch"):
            self.processed_batch = checkpoint["processed_batch"]

    def train(self):
        """Run epochs ``_start_epoch..num_epochs``; return a list of per-epoch loss dicts."""
        logging.info("========================================")
        logging.info("Start training {}".format(type(self._model).__name__))
        logging.info("========================================")
        logs = []

        for epoch in range(self._start_epoch, self._num_epochs + 1):
            train_loss, val_loss = self._train_epoch(epoch)

            log_epoch = {
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
            }

            logging.info(
                "\n----------------------------------------------------\n"
                + "epoch: {}, train_loss: {: .4f}, val_loss: {: .4f}".format(epoch, train_loss, val_loss)
                + "\n----------------------------------------------------\n"
            )
            logs.append(log_epoch)
            if self._logger:
                self._logger.add_scalar("train/train_loss", train_loss, epoch)
                self._logger.add_scalar("val/val_loss", val_loss, epoch)

            # Epochs are 1-based: save at save_period, 2*save_period, ... (the old
            # `(epoch + 1) % period` saved one epoch early), plus the final epoch.
            if epoch % self._save_period == 0 or epoch == self._num_epochs:
                self._save_checkpoint(epoch, save_best=True)

        return logs


# ------------------------------------------------------------------------------
# /utils/__init__.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from .kmeans_bboxes import jaccard, avg_iou, kmeans_bboxes, estimate_avg_ious
from .priors_bboxes import estimate_priors_sizes
from .utils import *
from .download_utility import *


# ------------------------------------------------------------------------------
# /utils/download_utility.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import requests
import tqdm

__all__ = ["download_file_from_google_drive"]


def download_file_from_google_drive(id, destination):
    """Download the Google Drive file ``id`` into the local path ``destination``.

    If the first response sets a ``download_warning*`` cookie (presumably Drive's
    large-file confirmation step), the request is repeated with that value as
    ``confirm``. ``id`` shadows the builtin but is kept for keyword callers.

    Raises:
        requests.HTTPError: if the final response has an error status, instead
            of silently saving the error page as the downloaded file.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes the session's pooled connections.
    with requests.Session() as session:
        response = session.get(URL, params={"id": id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {"id": id, "confirm": token}
            response = session.get(URL, params=params, stream=True)

        response.raise_for_status()
        save_response_content(response, destination)


def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with "download_warning", else None."""
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value

    return None


def save_response_content(response, destination):
    """Stream ``response`` to ``destination`` in 32 KiB chunks with a progress bar."""
    CHUNK_SIZE = 32768
    with open(destination, "wb") as f:
        for chunk in tqdm.tqdm(response.iter_content(CHUNK_SIZE)):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


# ------------------------------------------------------------------------------
# /utils/kmeans_bboxes.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ref: https://lars76.github.io/object-detection/k-means-anchor-boxes/
import numpy as np
import tqdm


def jaccard(bboxes, clusters):
    """IoU between one (w, h) box and each (w, h) row of ``clusters``.

    Boxes are compared as if sharing a corner, so only sizes matter.
    ``bboxes`` is a single box (indexed [0], [1]); ``clusters`` has shape (k, 2).

    Raises:
        ValueError: if the box or a cluster has zero width or height.
    """
    x = np.minimum(clusters[:, 0], bboxes[0])
    y = np.minimum(clusters[:, 1], bboxes[1])
    if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:
        raise ValueError("Box has no area")

    intersection = x * y
    box_area = bboxes[0] * bboxes[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]

    iou = intersection / (box_area + cluster_area - intersection)

    return iou


def avg_iou(bboxes, clusters):
    """Mean over boxes of the best IoU each box achieves with any cluster."""
    return np.mean([np.max(jaccard(bboxes[i], clusters)) for i in range(bboxes.shape[0])])


def kmeans_bboxes(bboxes, k=5, metric_dist=np.median, seed=100):
    """Cluster (w, h) boxes with k-means using ``1 - IoU`` as the distance.

    Args:
        bboxes: array of shape (n, 2) of box widths/heights, n >= k.
        k: number of clusters (>= 2).
        metric_dist: reducer used to recompute a centre, called as
            ``metric_dist(members, axis=0)``.
        seed: seed for choosing the initial centres.

    Returns:
        array of shape (k, 2) with the cluster centres.
    """
    assert k >= 2
    rows = bboxes.shape[0]

    distances = np.empty((rows, k))
    # -1 never equals a real cluster index, so the first assignment can't be
    # mistaken for convergence (a zero-filled array could end the loop before
    # any centre was updated).
    last_clusters = np.full((rows,), -1)

    # Private RNG: draws the same indices as np.random.seed(seed) followed by
    # np.random.choice, without resetting the global NumPy random state.
    rng = np.random.RandomState(seed)
    clusters = bboxes[rng.choice(rows, k, replace=False)]

    while True:
        for row in range(rows):
            distances[row] = 1 - jaccard(bboxes[row], clusters)

        nearest_clusters = np.argmin(distances, axis=1)

        if (last_clusters == nearest_clusters).all():
            break

        for cluster in range(k):
            members = bboxes[nearest_clusters == cluster]
            # An empty cluster keeps its previous centre; reducing an empty
            # array would yield NaN and poison every later IoU.
            if len(members) > 0:
                clusters[cluster] = metric_dist(members, axis=0)

        last_clusters = nearest_clusters

    return clusters


def estimate_avg_ious(bboxes, k=5, metric_dist=np.median, seed=100):
    """Average IoU of k-means anchors for every cluster count 2..k.

    Returns:
        float array of length k - 1; entry i corresponds to i + 2 clusters.
    """
    assert k >= 2

    num_vals = k - 1
    avg_ious = np.empty((num_vals,), dtype=float)
    for idx, num_clusters in enumerate(tqdm.tqdm(range(2, k + 1))):
        clusters = kmeans_bboxes(bboxes, num_clusters, metric_dist, seed)
        avg_ious[idx] = avg_iou(bboxes, clusters)

    return avg_ious


# ------------------------------------------------------------------------------
# /utils/priors_bboxes.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from .kmeans_bboxes import kmeans_bboxes


def estimate_priors_sizes(dataset, k=5):
    """Run ``kmeans_bboxes`` with ``k`` clusters on ``dataset.get_all_normalized_boxes()``.

    Presumably the boxes are (w, h) pairs normalised by image size; confirm
    against the dataset base class.
    """
    bboxes = dataset.get_all_normalized_boxes()
    return kmeans_bboxes(bboxes, k)


# ------------------------------------------------------------------------------
# /utils/utils.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import numpy as np


__all__ = ["inf_loop"]


def inf_loop(data_loader):
    """Yield items from ``data_loader`` forever, restarting it each time it is exhausted."""
    from itertools import repeat

    for loader in repeat(data_loader):
        yield from loader