├── .gitignore
├── README.md
├── config.py
├── data
│   └── README.md
├── data_loader
│   ├── __init__.py
│   ├── bird_dataset.py
│   ├── coco_dataset.py
│   ├── data_transform_base.py
│   ├── dataset_base.py
│   ├── switch_dataset.py
│   ├── udacity_dataset.py
│   ├── voc_dataset.py
│   └── wheat_dataset.py
├── dataset_params.py
├── docs
│   └── images
│       ├── labels.gif
│       └── test_result.jpg
├── environment.yml
├── estimate_priors_size_dataset.py
├── export_onnx.py
├── inference_onnx.py
├── models
│   ├── __init__.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── detect.py
│   │   ├── mish.py
│   │   ├── new_types.py
│   │   └── utils.py
│   ├── layers
│   │   ├── __init__.py
│   │   └── backbone.py
│   ├── yolo_layer.py
│   └── yolo_v3.py
├── requirements.txt
├── scripts
│   ├── download_bird_dataset.py
│   ├── download_darknet_weight.py
│   ├── download_darknet_weight.sh
│   ├── download_switch_dataset.py
│   ├── download_udacity_dataset.sh
│   └── download_voc_dataset.sh
├── test.py
├── test
│   ├── get_dataset_size_distribution.py
│   ├── test_model.py
│   └── test_visualization.py
├── train.py
├── trainer
│   ├── __init__.py
│   ├── trainer.py
│   └── trainer_base.py
└── utils
    ├── __init__.py
    ├── download_utility.py
    ├── kmeans_bboxes.py
    ├── priors_bboxes.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled source #
2 | ###################
3 | *.com
4 | *.class
5 | *.dll
6 | *.exe
7 | *.o
8 | *.so
9 | *.pyc
10 | .ipynb_checkpoints
11 | *~
12 | *#
13 | build*
14 |
15 | # Packages #
16 | ###################
17 | # it's better to unpack these files and commit the raw source
18 | # git has its own built in compression methods
19 | *.7z
20 | *.dmg
21 | *.gz
22 | *.iso
23 | *.jar
24 | *.rar
25 | *.tar
26 | *.zip
27 |
28 | # Logs and databases #
29 | ######################
30 | *.log
31 | *.sql
32 | *.sqlite
33 |
34 | # OS generated files #
35 | ######################
36 | .DS_Store
37 | .DS_Store?
38 | ._*
39 | .Spotlight-V100
40 | .Trashes
41 | ehthumbs.db
42 | Thumbs.db
43 |
44 | # Images
45 | ######################
46 | *.jpg
47 | *.gif
48 | *.png
49 | *.svg
50 | *.ico
51 |
52 | # Video
53 | ######################
54 | *.wmv
55 | *.mpg
56 | *.mpeg
57 | *.mp4
58 | *.mov
59 | *.flv
60 | *.avi
61 | *.ogv
62 | *.ogg
63 | *.webm
64 |
65 | # Audio
66 | ######################
67 | *.wav
68 | *.mp3
69 | *.wma
70 |
71 | # Fonts
72 | ######################
73 | Fonts
74 | *.eot
75 | *.ttf
76 | *.woff
77 |
78 | # Format
79 | ######################
80 | CPPLINT.cfg
81 | .clang-format
82 |
83 | # Gtags
84 | ######################
85 | GPATH
86 | GRTAGS
87 | GSYMS
88 | GTAGS
89 |
90 | data/
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [MIT License](https://opensource.org/licenses/MIT)
2 |
3 | # Pytorch Implementation of [Yolov3](https://pjreddie.com/media/files/papers/YOLOv3.pdf) For Bird Detection
4 | ***
5 |
6 | This project provides a dataset of wild birds and a PyTorch implementation of YOLOv3 for training on it. The dataset is special in that it also provides dense labels for birds in flocks.
7 | The bird images were collected from the internet, partly by crawling. Label samples are shown below.
8 |
9 | | Label Samples |
10 | | :---: |
11 | | ![Label Samples](docs/images/labels.gif) |
12 |
18 | ## TODO ##
19 | ***
20 |
21 | - [x] Train on Bird Dataset
22 | - [x] Export onnx weights and test inference with them
23 | - [x] Train on multiple scales
24 | - [x] Mish activation
25 | - [x] Onnx Model
26 |
27 | ## Preparation ##
28 | ***
29 |
30 | ```bash
31 | python3 -m pip install -r requirements.txt
32 | ```
33 |
34 | - Download the darknet53 backbone pretrained on the ImageNet dataset
35 | ```bash
36 | python3 scripts/download_darknet_weight.py
37 | ```
38 |
39 | After running this script, the darknet53.conv.74 weights will be saved inside the saved_models directory.
40 |
41 | - Download bird dataset
42 | ```bash
43 | python3 scripts/download_bird_dataset.py
44 | ```
45 |
46 | The bird dataset will be downloaded and extracted into the data directory.
47 |
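To quickly verify that the data is in place, the dataset can also be loaded programmatically. A minimal sketch (it follows the same access pattern as estimate_priors_size_dataset.py and assumes the download above has finished):

```python
from config import Config

total_config = Config()
DatasetClass = total_config.DATASETS["bird_dataset"]

# phase can be "train", "val" or "test"
dataset = DatasetClass(data_path=total_config.DATA_PATH, phase="train")
print(len(dataset), "training images,", dataset.num_classes, "class(es)")

# returns a BGR image with the ground-truth boxes drawn on it
image = dataset.visualize_one_image(0)
```
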
48 | ## Scripts ##
49 | ***
50 |
51 | - Training (for parameter details, see the train.py script)
52 | ```bash
53 | python3 train.py --dataset bird_dataset --backbone_weight_path ./saved_models/darknet53.conv.74
54 | ```
55 |
56 | Weights will be saved inside the saved_models directory.
57 |
58 | - Testing
59 | ```bash
60 | python3 test.py --dataset bird_dataset --snapshot [path/to/snapshot] --image_path [path/to/image] --conf_thresh [confidence/thresh] --nms_thresh [nms/thresh]
61 | ```
62 |
63 | A sample trained weight can be downloaded from [HERE](https://drive.google.com/file/d/1DkxLsReA-vEjjWG5jTtzL2gb_kQEZe6b/view?usp=sharing)
64 |
65 | | Test Result |
66 | | :---: |
67 | | ![Test Result](docs/images/test_result.jpg) |
68 |
74 | - Export to onnx model
75 | ```bash
76 | python3 export_onnx.py --dataset bird_dataset --snapshot [path/to/weight/snapshot] --batch_size [batch/size] --onnx_weight_file [output/onnx/file]
77 | ```
78 |
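The exported graph can be sanity-checked with the onnx package (a minimal sketch; output.onnx is the default file name used by export_onnx.py):

```python
import onnx

model = onnx.load("output.onnx")  # default output name of export_onnx.py
onnx.checker.check_model(model)   # raises if the exported graph is malformed
print(onnx.helper.printable_graph(model.graph))
```
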
79 | - Inference with onnx
80 | ```bash
81 | python3 inference_onnx.py --dataset bird_dataset --img_h [img/input/height] --img_w [img/input/width] --image_path [image/path] --onnx_weight_file [onnx/weight] --conf_thresh [confidence/threshold] --nms_thresh [nms_threshold]
82 | ```
83 |
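Internally, inference_onnx.py resizes the image to the requested input size, scales pixel values to [0, 1] and feeds a CHW float32 batch to onnxruntime. A stripped-down sketch of the same steps (the image path is a placeholder; 608x608 matches the bird_dataset input size):

```python
import cv2
import numpy as np
import onnxruntime as rt

# same preprocessing as inference_onnx.py: resize, scale to [0, 1], HWC -> CHW, add batch dim
img = cv2.imread("path/to/image.jpg")  # placeholder path
x = (cv2.resize(img, (608, 608)) / 255.0).transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)

sess = rt.InferenceSession("output.onnx")
predictions = sess.run(None, {sess.get_inputs()[0].name: x})[0]
# the raw predictions are then decoded by models.DetectHandler (confidence filtering + NMS)
```
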
84 | ## References ##
85 | ***
86 |
87 | - [Yolov3](https://pjreddie.com/media/files/papers/YOLOv3.pdf)
88 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import logging
5 | from dataset_params import dataset_params
6 | from data_loader import BirdDataset, UdacityDataset, VOCDataset, COCODataset, SwitchDataset, WheatDataset
7 |
8 |
9 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
10 |
11 |
12 | class Config(object):
13 | DATASETS = {
14 | "bird_dataset": BirdDataset,
15 | "udacity_dataset": UdacityDataset,
16 | "voc_dataset": VOCDataset,
17 | "coco_dataset": COCODataset,
18 | "switch_dataset": SwitchDataset,
19 | "wheat_dataset": WheatDataset,
20 | }
21 |
22 | DATASET_PARAMS = dataset_params
23 |
24 | def __init__(self):
25 | self.CURRENT_DIR = _CURRENT_DIR
26 |
27 | self.DATA_PATH = os.path.abspath(os.path.join(_CURRENT_DIR, "data"))
28 |
29 | self.SAVED_MODEL_PATH = os.path.join(self.CURRENT_DIR, "saved_models")
30 | if not os.path.isdir(self.SAVED_MODEL_PATH):
31 | os.system("mkdir -p {}".format(self.SAVED_MODEL_PATH))
32 |
33 | self.LOG_PATH = os.path.join(self.CURRENT_DIR, "logs")
34 | if not os.path.isdir(self.LOG_PATH):
35 | os.system("mkdir -p {}".format(self.LOG_PATH))
36 | _config_logging(log_file=os.path.join(self.LOG_PATH, "log.txt"))
37 |
38 | def display(self):
39 | """
40 | Display Configuration values.
41 | """
42 | print("\nConfigurations:")
43 | for a in dir(self):
44 | if not a.startswith("__") and not callable(getattr(self, a)):
45 | print("{:30} {}".format(a, getattr(self, a)))
46 | print("\n")
47 |
48 |
49 | def _config_logging(log_file, log_level=logging.DEBUG):
50 | import sys
51 |
52 | format_line = "%(asctime)s - %(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)"
53 | custom_formatter = CustomFormatter(format_line)
54 | stream_handler = logging.StreamHandler()
55 | file_handler = logging.FileHandler(log_file)
56 |
57 | stream_handler.setFormatter(custom_formatter)
58 |
59 | logging.basicConfig(handlers=[file_handler, stream_handler], level=log_level, format=format_line)
60 |
61 |
62 | class CustomFormatter(logging.Formatter):
63 | def format(self, record, *args, **kwargs):
64 | import copy
65 |
66 | LOG_COLORS = {
67 | logging.INFO: "\x1b[33m",
68 | logging.DEBUG: "\x1b[36m",
69 | logging.WARNING: "\x1b[31m",
70 | logging.ERROR: "\x1b[31;1m",
71 | logging.CRITICAL: "\x1b[35m",
72 | }
73 |
74 | new_record = copy.copy(record)
75 | if new_record.levelno in LOG_COLORS:
76 | new_record.levelname = "{color_begin}{level}{color_end}".format(
77 | level=new_record.levelname, color_begin=LOG_COLORS[new_record.levelno], color_end="\x1b[0m",
78 | )
79 | return super(CustomFormatter, self).format(new_record, *args, **kwargs)
80 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/data/README.md
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .udacity_dataset import UdacityDataset
4 | from .bird_dataset import BirdDataset
5 | from .voc_dataset import VOCDataset
6 | from .coco_dataset import COCODataset
7 | from .switch_dataset import SwitchDataset
8 | from .wheat_dataset import WheatDataset
9 | from .data_transform_base import DataTransformBase
10 |
--------------------------------------------------------------------------------
/data_loader/bird_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import numpy as np
6 | import json
7 | from .dataset_base import DatasetBase, DatasetConfigBase
8 |
9 |
10 | class BirdDatasetConfig(DatasetConfigBase):
11 | def __init__(self):
12 | super(BirdDatasetConfig, self).__init__()
13 |
14 | self.CLASSES = [
15 | "bird",
16 | ]
17 |
18 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
19 |
20 |
21 | _bird_config = BirdDatasetConfig()
22 |
23 |
24 | class BirdDataset(DatasetBase):
25 | __name__ = "bird_dataset"
26 |
27 | def __init__(
28 | self,
29 | data_path,
30 | classes=_bird_config.CLASSES,
31 | colors=_bird_config.COLORS,
32 | phase="train",
33 | transform=None,
34 | shuffle=True,
35 | random_seed=2000,
36 | normalize_bbox=False,
37 | bbox_transformer=None,
38 | multiscale=False,
39 | resize_after_batch_num=10,
40 | ):
41 | super(BirdDataset, self).__init__(
42 | data_path,
43 | classes=classes,
44 | colors=colors,
45 | phase=phase,
46 | transform=transform,
47 | shuffle=shuffle,
48 | normalize_bbox=normalize_bbox,
49 | bbox_transformer=bbox_transformer,
50 | multiscale=multiscale,
51 | resize_after_batch_num=resize_after_batch_num,
52 | )
53 |
54 | assert os.path.isdir(data_path)
55 | assert phase in ("train", "val", "test")
56 |
57 | self._data_path = os.path.join(data_path, "bird_dataset")
58 | assert os.path.isdir(self._data_path)
59 |
60 | self._phase = phase
61 | self._transform = transform
62 |
63 | if self._phase == "test":
64 | self._image_path_base = os.path.join(self._data_path, "test")
65 | self._image_paths = sorted(
66 | [os.path.join(self._image_path_base, image_path) for image_path in os.listdir(self._image_path_base)]
67 | )
68 | else:
69 | self._train_path_base = os.path.join(self._data_path, "train")
70 | self._val_path_base = os.path.join(self._data_path, "val")
71 |
72 | trainval_dict = {"train": {"path": self._train_path_base}, "val": {"path": self._val_path_base}}
73 |
74 | data_path = trainval_dict[self._phase]["path"]
75 | all_image_paths = DatasetBase.get_all_files_with_format_from_path(data_path, ".jpg")
76 | all_json_paths = DatasetBase.get_all_files_with_format_from_path(data_path, ".json")
77 |
78 | assert len(all_image_paths) == len(all_json_paths)
79 |
80 | self._image_paths = [os.path.join(data_path, elem) for elem in all_image_paths]
81 | self._targets = [self._load_one_json(os.path.join(data_path, elem)) for elem in all_json_paths]
82 |
83 | def _load_one_json(self, json_path):
84 | bboxes = []
85 | labels = []
86 |
87 | p_json = json.load(open(json_path, "r"))
88 |
89 | for obj in p_json["objects"]:
90 | bboxes.append(obj["boundingbox"])
91 | label_text = obj["label"]
92 |
93 | labels.append(0)
94 |
95 | return [bboxes, labels]
96 |
--------------------------------------------------------------------------------
/data_loader/coco_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | from .dataset_base import DatasetBase, DatasetConfigBase
5 | from pycocotools.coco import COCO
6 |
7 |
8 | class COCODatasetConfig(DatasetConfigBase):
9 | def __init__(self):
10 | super(COCODatasetConfig, self).__init__()
11 |
12 | self.CLASSES = [
13 | "person",
14 | "bicycle",
15 | "car",
16 | "motorbike",
17 | "aeroplane",
18 | "bus",
19 | "train",
20 | "truck",
21 | "boat",
22 | "traffic light",
23 | "fire hydrant",
24 | "stop sign",
25 | "parking meter",
26 | "bench",
27 | "bird",
28 | "cat",
29 | "dog",
30 | "horse",
31 | "sheep",
32 | "cow",
33 | "elephant",
34 | "bear",
35 | "zebra",
36 | "giraffe",
37 | "backpack",
38 | "umbrella",
39 | "handbag",
40 | "tie",
41 | "suitcase",
42 | "frisbee",
43 | "skis",
44 | "snowboard",
45 | "sports ball",
46 | "kite",
47 | "baseball bat",
48 | "baseball glove",
49 | "skateboard",
50 | "surfboard",
51 | "tennis racket",
52 | "bottle",
53 | "wine glass",
54 | "cup",
55 | "fork",
56 | "knife",
57 | "spoon",
58 | "bowl",
59 | "banana",
60 | "apple",
61 | "sandwich",
62 | "orange",
63 | "broccoli",
64 | "carrot",
65 | "hot dog",
66 | "pizza",
67 | "donut",
68 | "cake",
69 | "chair",
70 | "sofa",
71 | "pottedplant",
72 | "bed",
73 | "diningtable",
74 | "toilet",
75 | "tvmonitor",
76 | "laptop",
77 | "mouse",
78 | "remote",
79 | "keyboard",
80 | "cell phone",
81 | "microwave",
82 | "oven",
83 | "toaster",
84 | "sink",
85 | "refrigerator",
86 | "book",
87 | "clock",
88 | "vase",
89 | "scissors",
90 | "teddy bear",
91 | "hair drier",
92 | "toothbrush",
93 | ]
94 |
95 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
96 |
97 |
98 | _coco_config = COCODatasetConfig()
99 |
100 |
101 | class COCODataset(DatasetBase):
102 | __name__ = "coco_dataset"
103 |
104 | def __init__(
105 | self,
106 | data_path,
107 | data_path_suffix="coco_dataset",
108 | classes=_coco_config.CLASSES,
109 | colors=_coco_config.COLORS,
110 | phase="train",
111 | transform=None,
112 | shuffle=True,
113 | input_size=None,
114 | random_seed=2000,
115 | normalize_bbox=False,
116 | normalize_image=False,
117 | bbox_transformer=None,
118 | ):
119 | super(COCODataset, self).__init__(
120 | data_path,
121 | classes=classes,
122 | colors=colors,
123 | phase=phase,
124 | transform=transform,
125 | shuffle=shuffle,
126 | normalize_bbox=normalize_bbox,
127 | bbox_transformer=bbox_transformer,
128 | )
129 |
130 | self._input_size = input_size
131 | self._normalize_image = normalize_image
132 | self._min_size = 1
133 |
134 | assert phase in ("train", "val")
135 | self._data_path = os.path.join(data_path, data_path_suffix)
136 |
137 | self._train_img_path = os.path.join(self._data_path, "train2017")
138 | self._val_img_path = os.path.join(self._data_path, "val2017")
139 | self._annotation_path = os.path.join(self._data_path, "annotations_trainval2017", "annotations")
140 |
141 | assert os.path.isdir(self._train_img_path)
142 | assert os.path.isdir(self._val_img_path)
143 | assert os.path.isdir(self._annotation_path)
144 | self._train_annotation_file = os.path.join(self._annotation_path, "instances_train2017.json")
145 | self._val_annotation_file = os.path.join(self._annotation_path, "instances_val2017.json")
146 |
147 | self._train_coco = COCO(self._train_annotation_file)
148 | self._val_coco = COCO(self._val_annotation_file)
149 | self._train_ids = self._train_coco.getImgIds()
150 | self._val_ids = self._val_coco.getImgIds()
151 | self._class_ids = sorted(self._train_coco.getCatIds())
152 |
153 | self._map_info = {
154 | "train": {"img_path": self._train_img_path, "coco": self._train_coco, "ids": self._train_ids,},
155 | "val": {"img_path": self._val_img_path, "coco": self._val_coco, "ids": self._val_ids,},
156 | }
157 |
158 | self._image_paths = []
159 | self._targets = []
160 |
161 | for idx in self._map_info[self._phase]["ids"]:
162 | anno_ids = self._map_info[self._phase]["coco"].getAnnIds(imgIds=[int(idx)], iscrowd=None)
163 | annotations = self._map_info[self._phase]["coco"].loadAnns(anno_ids)
164 |
165 | image_path = os.path.join(self._map_info[self._phase]["img_path"], "{:012}".format(idx) + ".jpg",)
166 | self._image_paths.append(image_path)
167 |
168 | cur_boxes = []
169 | cur_labels = []
170 |
171 | for anno in annotations:
172 | xmin, ymin, width, height = anno["bbox"]
173 | if width < self._min_size or height < self._min_size:
174 | continue
175 | xmax = xmin + width
176 | ymax = ymin + height
177 | label = self._class_ids.index(anno["category_id"])
178 | cur_boxes.append([xmin, ymin, xmax, ymax])
179 | cur_labels.append(label)
180 |
181 | self._targets.append([cur_boxes, cur_labels])
182 |
--------------------------------------------------------------------------------
/data_loader/data_transform_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from albumentations import *
4 | from albumentations.pytorch import ToTensor
5 | import random
6 |
7 | _SEED = 100
8 | random.seed(_SEED)
9 |
10 |
11 | class DataTransformBase(object):
12 | def __init__(
13 | self,
14 | transforms=[HorizontalFlip(p=0.5), GaussNoise(p=0.5), RandomBrightnessContrast(p=0.5),],
15 | input_size=None,
16 | normalize=False,
17 | ):
18 | self._input_size = input_size
19 | if self._input_size is not None:
20 | height, width = self._input_size
21 | self._height_offset = height // 32 - 8
22 | self._width_offset = width // 32 - 8
23 | assert self._height_offset > 0 and self._width_offset > 0
24 |
25 | self._normalize = normalize
26 |
27 | self._transform_dict = {"train": {}, "val": {}, "test": None}
28 |
29 | self._transform_dict["train"]["normal"] = transforms
30 | self._transform_dict["val"]["normal"] = []
31 |
32 | self._bbox_params = BboxParams(
33 | format="pascal_voc", min_area=0.001, min_visibility=0.001, label_fields=["category_ids"],
34 | )
35 | self._initialize_transform_dict()
36 |
37 | def _get_all_transforms_of_phase(self, phase):
38 | assert phase in ("train", "val")
39 | cur_transform = []
40 | cur_transform.extend(self._transform_dict[phase]["normal"])
41 | cur_transform.append(self._transform_dict[phase]["resize"])
42 | cur_transform.append(self._transform_dict[phase]["normalize"])
43 |
44 | return cur_transform
45 |
46 | def _initialize_transform_dict(self):
47 | if self._input_size is not None:
48 | height, width = self._input_size
49 | self._transform_dict["train"]["resize"] = Resize(height, width, always_apply=True)
50 | self._transform_dict["val"]["resize"] = Resize(height, width, always_apply=True)
51 |
52 | if self._normalize:
53 | self._transform_dict["train"]["normalize"] = Normalize(always_apply=True)
54 | self._transform_dict["val"]["normalize"] = Normalize(always_apply=True)
55 | else:
56 | self._transform_dict["train"]["normalize"] = ToTensor()
57 | self._transform_dict["val"]["normalize"] = ToTensor()
58 |
59 | self._transform_dict["train"]["all"] = self._get_all_transforms_of_phase("train")
60 | self._transform_dict["val"]["all"] = self._get_all_transforms_of_phase("val")
61 |
62 | self._transform_dict["test"] = self._transform_dict["val"]["all"]
63 |
64 | def update_size(self):
65 | if self._input_size is not None:
66 | random_offset = random.randint(0, 9)
67 | new_height = (random_offset + self._height_offset) * 32
68 | new_width = (random_offset + self._width_offset) * 32
69 |
70 | self._transform_dict["train"]["resize"] = Resize(new_height, new_width, always_apply=True)
71 | self._transform_dict["train"]["all"] = self._get_all_transforms_of_phase("train")
72 |
73 | def __call__(self, image, bboxes=None, labels=None, phase=None):
74 | if phase is None:
75 | transformer = Compose(self._transform_dict["test"])
76 | return transformer(image=image)
77 |
78 | assert phase in ("train", "val")
79 | assert bboxes is not None
80 | assert labels is not None
81 |
82 | transformed_image = image
83 | transformed_bboxes = bboxes
84 | transformed_category_ids = labels
85 | for transform in self._transform_dict[phase]["all"]:
86 | annotations = {
87 | "image": transformed_image,
88 | "bboxes": transformed_bboxes,
89 | "category_ids": transformed_category_ids,
90 | }
91 | transformer = Compose([transform], bbox_params=self._bbox_params)
92 | augmented = transformer(**annotations)
93 |
94 | while len(augmented["bboxes"]) == 0:
95 | augmented = transformer(**annotations)
96 |
97 | transformed_image = augmented["image"]
98 | transformed_bboxes = augmented["bboxes"]
99 | transformed_category_ids = augmented["category_ids"]
100 |
101 | if not self._normalize:
102 | transformed_image = transformed_image.permute(2, 1, 0)
103 |
104 | return (transformed_image, transformed_bboxes, transformed_category_ids)
105 |
--------------------------------------------------------------------------------
/data_loader/dataset_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import numpy as np
6 | import tqdm
7 | import multiprocessing as mp
8 | from ctypes import c_int32
9 | from abc import abstractmethod
10 |
11 |
12 | _counter = mp.Value(c_int32)
13 | _counter_lock = mp.Lock()
14 |
15 |
16 | class DatasetConfigBase(object):
17 | def __init__(self):
18 | self.CLASSES = []
19 |
20 | @property
21 | def num_classes(self):
22 | return len(self.CLASSES)
23 |
24 | @staticmethod
25 | def generate_color_chart(num_classes, seed=3700):
26 | assert num_classes > 0
27 | np.random.seed(seed)
28 |
29 | colors = np.random.randint(0, 255, size=(num_classes, 3), dtype="uint8")
30 | colors = np.vstack([colors]).astype("uint8")
31 | colors = [tuple(color) for color in list(colors)]
32 | colors = [tuple(int(e) for e in color) for color in colors]
33 |
34 | return colors
35 |
36 |
37 | class DatasetBase(object):
38 | def __init__(
39 | self,
40 | data_path,
41 | classes,
42 | colors,
43 | phase="train",
44 | transform=None,
45 | shuffle=True,
46 | normalize_bbox=False,
47 | bbox_transformer=None,
48 | multiscale=False,
49 | resize_after_batch_num=10,
50 | ):
51 | super(DatasetBase, self).__init__()
52 | assert os.path.isdir(data_path)
53 | assert phase in ("train", "val", "test")
54 |
55 | self._data_path = data_path
56 | self._classes = classes
57 | self._colors = colors
58 | self._phase = phase
59 | self._transform = transform
60 | self._shuffle = shuffle
61 | self._normalize_bbox = normalize_bbox
62 | self._bbox_transformer = bbox_transformer
63 | self._image_paths = []
64 | self._targets = []
65 |
66 | self._batch_count = 0
67 | self._multiscale = multiscale
68 | self._resize_after_batch_num = resize_after_batch_num
69 |
70 | @property
71 | def classes(self):
72 | return self._classes
73 |
74 | @property
75 | def colors(self):
76 | return self._colors
77 |
78 | @property
79 | def num_classes(self):
80 | return len(self._classes)
81 |
82 | def __getitem__(self, idx):
83 | """
84 | X: np.array (batch_size, height, width, channel)
85 | y: list of length batch size
86 | y: (batch_size, [[number of bboxes, x_min, y_min, x_max, y_max, labels]])
87 | """
88 | assert idx < self.__len__()
89 |
90 | image, targets = self._data_generation(idx)
91 |
92 | if self._bbox_transformer is not None:
93 | targets = self._bbox_transformer(targets)
94 |
95 | return image, targets
96 |
97 | def __len__(self):
98 | return len(self._image_paths)
99 |
100 | def visualize_one_image(self, idx):
101 | assert not self._normalize_bbox
102 | assert idx < self.__len__()
103 | image, targets = self.__getitem__(idx)
104 |
105 | all_bboxes = targets[:, :-1]
106 | all_category_ids = targets[:, -1]
107 |
108 | all_bboxes = all_bboxes.astype(np.int64)
109 | all_category_ids = all_category_ids.astype(np.int64)
110 |
111 | return DatasetBase.visualize_one_image_util(image, self._classes, self._colors, all_bboxes, all_category_ids)
112 |
113 | @staticmethod
114 | def visualize_one_image_util(image, classes, colors, all_bboxes, all_category_ids):
115 | for (bbox, label) in zip(all_bboxes, all_category_ids):
116 | x_min, y_min, x_max, y_max = bbox
117 |
118 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), colors[label], 2)
119 |
120 | label_text = classes[label]
121 | label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
122 |
123 | cv2.rectangle(
124 | image,
125 | (x_min, y_min),
126 | (x_min + label_size[0][0], y_min + int(1.3 * label_size[0][1])),
127 | colors[label],
128 | -1,
129 | )
130 | cv2.putText(
131 | image,
132 | label_text,
133 | org=(x_min, y_min + label_size[0][1]),
134 | fontFace=cv2.FONT_HERSHEY_SIMPLEX,
135 | fontScale=0.35,
136 | color=(255, 255, 255),
137 | lineType=cv2.LINE_AA,
138 | )
139 |
140 | return image
141 |
142 | def _data_generation(self, idx):
143 | abs_image_path = self._image_paths[idx]
144 | o_img = cv2.imread(abs_image_path)
145 | o_height, o_width, _ = o_img.shape
146 |
147 | o_bboxes, o_category_ids = self._targets[idx]
148 |
149 | o_bboxes = [DatasetBase.authentize_bbox(o_height, o_width, bbox) for bbox in o_bboxes]
150 |
151 | img = o_img
152 | bboxes = o_bboxes
153 | category_ids = o_category_ids
154 | height, width = o_height, o_width
155 | if self._transform:
156 | if isinstance(self._transform, list):
157 | for transform in self._transform:
158 | img, bboxes, category_ids = transform(
159 | image=img, bboxes=bboxes, labels=category_ids, phase=self._phase
160 | )
161 | else:
162 | img, bboxes, category_ids = self._transform(
163 | image=img, bboxes=bboxes, labels=category_ids, phase=self._phase
164 | )
165 |
166 | # use the height, width after transformation for normalization
167 | height, width, _ = img.shape
168 |
169 | # if number of boxes is 0, use original image
170 | # see data transform for more details
171 |
172 | if self._normalize_bbox:
173 | bboxes = [
174 | [float(bbox[0]) / width, float(bbox[1]) / height, float(bbox[2]) / width, float(bbox[3]) / height,]
175 | for bbox in bboxes
176 | ]
177 |
178 | bboxes = np.array(bboxes)
179 | category_ids = np.array(category_ids).reshape(-1, 1)
180 | targets = np.concatenate((bboxes, category_ids), axis=-1)
181 |
182 | return img, targets
183 |
184 | def _process_one_image(self, idx):
185 | abs_image_path = self._image_paths[idx]
186 | o_img = cv2.imread(abs_image_path)
187 | o_height, o_width, _ = o_img.shape
188 |
189 | o_bboxes, _ = self._targets[idx]
190 |
191 | o_bboxes = [DatasetBase.authentize_bbox(o_height, o_width, bbox) for bbox in o_bboxes]
192 |
193 | o_bboxes = np.array(o_bboxes)
194 | widths = (o_bboxes[:, 2] - o_bboxes[:, 0]) / o_width
195 | heights = (o_bboxes[:, 3] - o_bboxes[:, 1]) / o_height
196 | normalized_dimensions = [[w, h] for w, h in zip(widths, heights)]
197 |
198 | with _counter_lock:
199 | _counter.value += 1
200 |
201 | return normalized_dimensions
202 |
203 | def get_all_normalized_boxes(self, num_processes=mp.cpu_count()):
204 | import functools
205 |
206 | _process_len = self.__len__()
207 |
208 | pbar = tqdm.tqdm(total=_process_len)
209 |
210 | result_bboxes = None
211 | with mp.Pool(num_processes) as p:
212 | future = p.map_async(self._process_one_image, range(_process_len))
213 | while not future.ready():
214 | if _counter.value != 0:
215 | with _counter_lock:
216 | increment = _counter.value
217 | _counter.value = 0
218 | pbar.update(n=increment)
219 |
220 | result_bboxes = future.get()
221 | result_bboxes = functools.reduce(lambda x, y: x + y, result_bboxes)
222 |
223 | pbar.close()
224 | return np.array(result_bboxes)
225 |
226 | def _process_one_image_to_get_size(self, idx):
227 | abs_image_path = self._image_paths[idx]
228 | o_img = cv2.imread(abs_image_path)
229 | o_height, o_width, _ = o_img.shape
230 | o_size = o_height * o_width
231 | o_bboxes, _ = self._targets[idx]
232 |
233 | cur_sizes = []
234 | for (x1, y1, x2, y2) in o_bboxes:
235 | cur_sizes.append(np.sqrt((x2 - x1) * (y2 - y1) * 1.0 / o_size))
236 |
237 | return cur_sizes
238 |
239 | def size_distribution(self, num_processes=mp.cpu_count()):
240 |
241 | import functools
242 |
243 | _process_len = self.__len__()
244 |
245 | pbar = tqdm.tqdm(total=_process_len)
246 |
247 | result_bboxes = None
248 | with mp.Pool(num_processes) as p:
249 | future = p.map_async(self._process_one_image_to_get_size, range(_process_len))
250 | while not future.ready():
251 | if _counter.value != 0:
252 | with _counter_lock:
253 | increment = _counter.value
254 | _counter.value = 0
255 | pbar.update(n=increment)
256 |
257 | result_bboxes = future.get()
258 | result_bboxes = functools.reduce(lambda x, y: x + y, result_bboxes)
259 |
260 | pbar.close()
261 | return np.array(result_bboxes)
262 |
263 | @staticmethod
264 | def authentize_bbox(o_height, o_width, bbox):
265 | bbox_type = type(bbox)
266 |
267 | x_min, y_min, x_max, y_max = bbox
268 | if x_min > x_max:
269 | x_min, x_max = x_max, x_min
270 | if y_min > y_max:
271 | y_min, y_max = y_max, y_min
272 |
273 | x_min = max(x_min, 0)
274 | x_max = min(x_max, o_width)
275 | y_min = max(y_min, 0)
276 | y_max = min(y_max, o_height)
277 |
278 | return bbox_type([x_min, y_min, x_max, y_max])
279 |
280 | @staticmethod
281 | def color_to_color_idx_dict(colors):
282 | color_idx_dict = {}
283 |
284 | for i, color in enumerate(colors):
285 | color_idx_dict[color] = i
286 |
287 | return color_idx_dict
288 |
289 | @staticmethod
290 | def class_to_class_idx_dict(classes):
291 | class_idx_dict = {}
292 |
293 | for i, class_name in enumerate(classes):
294 | class_idx_dict[class_name] = i
295 |
296 | return class_idx_dict
297 |
298 | @staticmethod
299 | def human_sort(s):
300 | """Sort list the way humans do
301 | """
302 | import re
303 |
304 | pattern = r"([0-9]+)"
305 | return [int(c) if c.isdigit() else c.lower() for c in re.split(pattern, s)]
306 |
307 | @staticmethod
308 | def get_all_files_with_format_from_path(dir_path, suffix_format, use_human_sort=True):
309 | import os
310 |
311 | all_files = [elem for elem in os.listdir(dir_path) if elem.endswith(suffix_format)]
312 | all_files.sort(key=DatasetBase.human_sort)
313 |
314 | return all_files
315 |
316 | def od_collate_fn(self, batch):
317 | import torch
318 | import numpy as np
319 |
320 | def _xywh_to_cxcywh(bbox):
321 | bbox[..., 0] += bbox[..., 2] / 2
322 | bbox[..., 1] += bbox[..., 3] / 2
323 | return bbox
324 |
325 | def _xyxy_to_cxcywh(bbox):
326 | bbox[..., 2] -= bbox[..., 0]
327 | bbox[..., 3] -= bbox[..., 1]
328 | return _xywh_to_cxcywh(bbox)
329 |
330 | if (
331 | self._multiscale
332 | and (self._batch_count + 1) % self._resize_after_batch_num == 0
333 | and self._transform is not None
334 | ):
335 | if isinstance(self._transform, list):
336 | for i in range(len(self._transform)):
337 | self._transform[i].update_size()
338 | else:
339 | self._transform.update_size()
340 |
341 | images = []
342 | labels = []
343 | lengths = []
344 | labels_with_tail = []
345 | max_num_obj = 0
346 |
347 | for image, label in batch:
348 | image = np.transpose(image, (2, 1, 0))
349 | image = np.expand_dims(image, axis=0)
350 | images.append(image)
351 |
352 | # xmin,ymin,xmax,ymax to xcenter,ycenter,width,height
353 | labels.append(_xyxy_to_cxcywh(label))
354 |
355 | length = label.shape[0]
356 | lengths.append(length)
357 | max_num_obj = max(max_num_obj, length)
358 |
359 | for label in labels:
360 | num_obj = label.shape[0]
361 | zero_tail = np.zeros((max_num_obj - num_obj, label.shape[1]), dtype=float)
362 | label_with_tail = np.concatenate([label, zero_tail], axis=0)
363 | labels_with_tail.append(torch.FloatTensor(label_with_tail))
364 |
365 | images = np.concatenate(images, axis=0)
366 |
367 | image_tensor = torch.FloatTensor(images)
368 | label_tensor = torch.stack(labels_with_tail)
369 | length_tensor = torch.tensor(lengths)
370 |
371 | self._batch_count += 1
372 |
373 | return image_tensor, label_tensor, length_tensor
374 |
--------------------------------------------------------------------------------
/data_loader/switch_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import random
6 | import numpy as np
7 | import json
8 | from .dataset_base import DatasetBase, DatasetConfigBase
9 |
10 |
11 | class SwitchDatasetConfig(DatasetConfigBase):
12 | def __init__(self):
13 | super(SwitchDatasetConfig, self).__init__()
14 |
15 | self.CLASSES = [
16 | "switch-unknown",
17 | "switch-right",
18 | "switch-left",
19 | ]
20 |
21 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
22 |
23 |
24 | _switch_config = SwitchDatasetConfig()
25 |
26 |
27 | class SwitchDataset(DatasetBase):
28 | __name__ = "switch_dataset"
29 |
30 | def __init__(
31 | self,
32 | data_path,
33 | classes=_switch_config.CLASSES,
34 | colors=_switch_config.COLORS,
35 | phase="train",
36 | transform=None,
37 | shuffle=True,
38 | random_seed=2000,
39 | normalize_bbox=False,
40 | bbox_transformer=None,
41 | multiscale=False,
42 | resize_after_batch_num=10,
43 | train_val_ratio=0.9,
44 | ):
45 | super(SwitchDataset, self).__init__(
46 | data_path,
47 | classes=classes,
48 | colors=colors,
49 | phase=phase,
50 | transform=transform,
51 | shuffle=shuffle,
52 | normalize_bbox=normalize_bbox,
53 | bbox_transformer=bbox_transformer,
54 | multiscale=multiscale,
55 | resize_after_batch_num=resize_after_batch_num,
56 | )
57 |
58 | assert os.path.isdir(data_path)
59 | assert phase in ("train", "val")
60 |
61 | self._data_path = os.path.join(data_path, "switch_detection/data")
62 | assert os.path.isdir(self._data_path)
63 |
64 | self._phase = phase
65 | self._transform = transform
66 |
67 | _imgs_path = os.path.join(self._data_path, "imgs")
68 | _jsons_path = os.path.join(self._data_path, "labels")
69 |
70 | _all_image_paths = DatasetBase.get_all_files_with_format_from_path(_imgs_path, ".jpg")
71 | _all_json_paths = DatasetBase.get_all_files_with_format_from_path(_jsons_path, ".json")
72 |
73 | assert len(_all_image_paths) == len(_all_json_paths)
74 |
75 | _image_paths = [os.path.join(_imgs_path, elem) for elem in _all_image_paths]
76 | _targets = [self._load_one_json(os.path.join(_jsons_path, elem)) for elem in _all_json_paths]
77 |
78 | zipped = list(zip(_image_paths, _targets))
79 | random.seed(random_seed)
80 | random.shuffle(zipped)
81 | _image_paths, _targets = zip(*zipped)
82 |
83 | _train_len = int(train_val_ratio * len(_image_paths))
84 | if self._phase == "train":
85 | self._image_paths = _image_paths[:_train_len]
86 | self._targets = _targets[:_train_len]
87 | else:
88 | self._image_paths = _image_paths[_train_len:]
89 | self._targets = _targets[_train_len:]
90 |
91 | def _load_one_json(self, json_path):
92 | bboxes = []
93 | labels = []
94 |
95 | p_json = json.load(open(json_path, "r"))
96 |
97 | for obj in p_json["objects"]:
98 | if "boundingbox" in obj:
99 | x_min, y_min, x_max, y_max = obj["boundingbox"]
100 | label_text = obj["label"]
101 | label_idx = self._classes.index(label_text)
102 |
103 | bboxes.append([x_min, y_min, x_max, y_max])
104 | labels.append(label_idx)
105 |
106 | return [bboxes, labels]
107 |
--------------------------------------------------------------------------------
/data_loader/udacity_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import numpy as np
6 | import random
7 | from .dataset_base import DatasetBase, DatasetConfigBase
8 |
9 |
10 | class UdacityDatasetConfig(DatasetConfigBase):
11 | def __init__(self):
12 | super(UdacityDatasetConfig, self).__init__()
13 |
14 | self.CLASSES = ["car", "truck", "pedestrian", "biker", "trafficLight"]
15 |
16 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
17 |
18 |
19 | _udacity_config = UdacityDatasetConfig()
20 |
21 |
22 | class UdacityDataset(DatasetBase):
23 | __name__ = "udacity_dataset"
24 |
25 | def __init__(
26 | self,
27 | data_path,
28 | classes=_udacity_config.CLASSES,
29 | colors=_udacity_config.COLORS,
30 | phase="train",
31 | transform=None,
32 | shuffle=True,
33 | random_seed=2000,
34 | normalize_bbox=False,
35 | bbox_transformer=None,
36 | train_val_ratio=0.9,
37 | ):
38 | super(UdacityDataset, self).__init__(
39 | data_path,
40 | classes=classes,
41 | colors=colors,
42 | phase=phase,
43 | transform=transform,
44 | shuffle=shuffle,
45 | normalize_bbox=normalize_bbox,
46 | bbox_transformer=bbox_transformer,
47 | )
48 |
49 | assert phase in ("train", "val")
50 |
51 | assert os.path.isdir(data_path)
52 | self._data_path = os.path.join(data_path, "udacity/object-dataset")
53 | assert os.path.isdir(self._data_path)
54 |
55 | self._annotation_file = os.path.join(self._data_path, "labels.csv")
56 | lines = [line.rstrip("\n") for line in open(self._annotation_file, "r")]
57 | lines = [line.split(" ") for line in lines]
58 | image_dict = {}
59 | class_idx_dict = self.class_to_class_idx_dict(self._classes)
60 |
61 | for line in lines:
62 | if line[0] not in image_dict.keys():
63 | image_dict[line[0]] = [[], []]
64 |
65 | image_dict[line[0]][0].append([int(e) for e in line[1:5]])
66 | label_name = line[6][1:][:-1]
67 | image_dict[line[0]][1].append(class_idx_dict[label_name])
68 |
69 | self._image_paths = image_dict.keys()
70 | self._image_paths = [os.path.join(self._data_path, elem) for elem in self._image_paths]
71 | self._targets = image_dict.values()
72 |
73 | zipped = list(zip(self._image_paths, self._targets))
74 | random.seed(random_seed)
75 | random.shuffle(zipped)
76 | self._image_paths, self._targets = zip(*zipped)
77 |
78 | train_len = int(train_val_ratio * len(self._image_paths))
79 | if self._phase == "train":
80 | self._image_paths = self._image_paths[:train_len]
81 | self._targets = self._targets[:train_len]
82 | else:
83 | self._image_paths = self._image_paths[train_len:]
84 | self._targets = self._targets[train_len:]
85 |
--------------------------------------------------------------------------------
/data_loader/voc_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import numpy as np
6 | import xml.etree.ElementTree as ET
7 | from .dataset_base import DatasetBase, DatasetConfigBase
8 |
9 |
10 | class VOCDatasetConfig(DatasetConfigBase):
11 | def __init__(self):
12 | super(VOCDatasetConfig, self).__init__()
13 |
14 | self.CLASSES = [
15 | "aeroplane",
16 | "bicycle",
17 | "bird",
18 | "boat",
19 | "bottle",
20 | "bus",
21 | "car",
22 | "cat",
23 | "chair",
24 | "cow",
25 | "diningtable",
26 | "dog",
27 | "horse",
28 | "motorbike",
29 | "person",
30 | "pottedplant",
31 | "sheep",
32 | "sofa",
33 | "train",
34 | "tvmonitor",
35 | ]
36 |
37 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
38 |
39 |
40 | _voc_config = VOCDatasetConfig()
41 |
42 |
43 | class VOCDataset(DatasetBase):
44 | __name__ = "voc_dataset"
45 |
46 | def __init__(
47 | self,
48 | data_path,
49 | classes=_voc_config.CLASSES,
50 | colors=_voc_config.COLORS,
51 | phase="train",
52 | transform=None,
53 | shuffle=True,
54 | input_size=None,
55 | random_seed=2000,
56 | normalize_bbox=False,
57 | normalize_image=False,
58 | bbox_transformer=None,
59 | ):
60 | super(VOCDataset, self).__init__(
61 | data_path,
62 | classes=classes,
63 | colors=colors,
64 | phase=phase,
65 | transform=transform,
66 | shuffle=shuffle,
67 | normalize_bbox=normalize_bbox,
68 | bbox_transformer=bbox_transformer,
69 | )
70 |
71 | self._input_size = input_size
72 | self._normalize_image = normalize_image
73 |
74 | assert phase in ("train", "val")
75 | self._data_path = os.path.join(self._data_path, "voc/VOCdevkit/VOC2012")
76 | self._image_path_file = os.path.join(self._data_path, "ImageSets/Main/{}.txt".format(self._phase),)
77 |
78 | lines = [line.rstrip("\n") for line in open(self._image_path_file)]
79 |
80 | image_path_base = os.path.join(self._data_path, "JPEGImages")
81 | anno_path_base = os.path.join(self._data_path, "Annotations")
82 |
83 | self._image_paths = []
84 | anno_paths = []
85 |
86 | for line in lines:
87 | self._image_paths.append(os.path.join(image_path_base, "{}.jpg".format(line)))
88 | anno_paths.append(os.path.join(anno_path_base, "{}.xml".format(line)))
89 |
90 | self._targets = [self._parse_one_xml(xml_path) for xml_path in anno_paths]
91 |
92 | np.random.seed(random_seed)
93 |
94 | def _parse_one_xml(self, xml_path):
95 | bboxes = []
96 | labels = []
97 | xml = ET.parse(xml_path).getroot()
98 |
99 | for obj in xml.iter("object"):
100 | difficult = int(obj.find("difficult").text)
101 | if difficult == 1:
102 | continue
103 |
104 | bndbox = []
105 |
106 | name = obj.find("name").text.lower().strip()
107 | bbox = obj.find("bndbox")
108 |
109 | pts = ["xmin", "ymin", "xmax", "ymax"]
110 |
111 | for pt in pts:
112 | cur_pixel = int(bbox.find(pt).text) - 1
113 | bndbox.append(cur_pixel)
114 |
115 | label_idx = self._classes.index(name)
116 |
117 | bboxes.append(bndbox)
118 | labels.append(label_idx)
119 |
120 | return [bboxes, labels]
121 |
--------------------------------------------------------------------------------
/data_loader/wheat_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import cv2
4 | import random
5 | import numpy as np
6 | import json
7 | from .dataset_base import DatasetBase, DatasetConfigBase
8 |
9 |
10 | class WheatDatasetConfig(DatasetConfigBase):
11 | def __init__(self):
12 | super(WheatDatasetConfig, self).__init__()
13 |
14 | self.CLASSES = [
15 | "wheat",
16 | ]
17 |
18 | self.COLORS = DatasetConfigBase.generate_color_chart(self.num_classes)
19 |
20 |
21 | _wheat_config = WheatDatasetConfig()
22 |
23 |
24 | class WheatDataset(DatasetBase):
25 | __name__ = "wheat_dataset"
26 |
27 | def __init__(
28 | self,
29 | data_path,
30 | classes=_wheat_config.CLASSES,
31 | colors=_wheat_config.COLORS,
32 | phase="train",
33 | transform=None,
34 | shuffle=True,
35 | random_seed=2000,
36 | normalize_bbox=False,
37 | bbox_transformer=None,
38 | multiscale=False,
39 | resize_after_batch_num=10,
40 | train_val_ratio=0.9,
41 | ):
42 | super(WheatDataset, self).__init__(
43 | data_path,
44 | classes=classes,
45 | colors=colors,
46 | phase=phase,
47 | transform=transform,
48 | shuffle=shuffle,
49 | normalize_bbox=normalize_bbox,
50 | bbox_transformer=bbox_transformer,
51 | multiscale=multiscale,
52 | resize_after_batch_num=resize_after_batch_num,
53 | )
54 |
55 | assert os.path.isdir(data_path)
56 | assert phase in ("train", "val")
57 |
58 | self._wheat_data_path = os.path.join(data_path, "global-wheat-detection")
59 | assert os.path.isdir(self._wheat_data_path)
60 | self._train_csv_path = os.path.join(self._wheat_data_path, "train.csv")
61 | self._image_prefix_path = os.path.join(self._wheat_data_path, "train")
62 |
63 | self._phase = phase
64 | self._transform = transform
65 |
66 | self._image_paths, self._targets = self._process_train_csv(self._train_csv_path)
67 | self._image_paths = [os.path.join(self._image_prefix_path, elem) + ".jpg" for elem in self._image_paths]
68 |
69 | zipped = list(zip(self._image_paths, self._targets))
70 | random.seed(random_seed)
71 | random.shuffle(zipped)
72 | self._image_paths, self._targets = zip(*zipped)
73 |
74 | train_len = int(train_val_ratio * len(self._image_paths))
75 | if self._phase == "train":
76 | self._image_paths = self._image_paths[:train_len]
77 | self._targets = self._targets[:train_len]
78 | else:
79 | self._image_paths = self._image_paths[train_len:]
80 | self._targets = self._targets[train_len:]
81 |
82 | def _process_train_csv(self, train_csv_path):
83 | image_dict = {}
84 |
85 | import dask.dataframe as dd
86 |
87 | df = dd.read_csv(train_csv_path)
88 | for idx, row in df.iterrows():
89 | image_id = row["image_id"]
90 | if image_id not in image_dict.keys():
91 | image_dict[image_id] = [[], []]
92 |
93 | source = row["source"]
94 | width = row["width"]
95 | height = row["height"]
96 |
97 | bbox = row["bbox"].strip("][").split(", ")
98 | bbox = [float(elem) for elem in bbox]
99 |
100 | xmin, ymin, width, height = bbox
101 | xmax = xmin + width
102 | ymax = ymin + height
103 |
104 | bbox = [int(xmin), int(ymin), int(xmax), int(ymax)]
105 |
106 | image_dict[image_id][0].append(bbox)
107 | image_dict[image_id][1].append(0)
108 |
109 | return image_dict.keys(), image_dict.values()
110 |
--------------------------------------------------------------------------------
/dataset_params.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | __all__ = ["dataset_params"]
6 |
7 |
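# NOTE: the "coco_dataset" anchors are the original YOLOv3 anchors in pixels (for a 416x416 input);
# the anchors of the other datasets are (width, height) pairs normalized to [0, 1] by the image size,
# presumably generated with estimate_priors_size_dataset.py.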
8 | dataset_params = {
9 | "coco_dataset": {
10 | "anchors": [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326],],
11 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
12 | "num_classes": 80,
13 | "img_h": 416,
14 | "img_w": 416,
15 | },
16 | "udacity_dataset": {
17 | "anchors": [
18 | [0.014583333333333334, 0.03333333333333333],
19 | [0.017708333333333333, 0.05],
20 | [0.022916666666666665, 0.07333333333333333],
21 | [0.03333333333333333, 0.043333333333333335],
22 | [0.035416666666666666, 0.11833333333333333],
23 | [0.051041666666666666, 0.06],
24 | [0.08020833333333334, 0.08666666666666667],
25 | [0.12604166666666666, 0.15666666666666668],
26 | [0.2677083333333333, 0.3466666666666667],
27 | ],
28 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
29 | "num_classes": 5,
30 | "img_h": 416,
31 | "img_w": 416,
32 | },
33 | "bird_dataset": {
34 | "anchors": [
35 | [0.013333333333333334, 0.013666666666666667],
36 | [0.016923076923076923, 0.027976190476190477],
37 | [0.022203947368421052, 0.044827615904163357],
38 | [0.025833333333333333, 0.016710875331564987],
39 | [0.034375, 0.028125],
40 | [0.038752362948960305, 0.07455104993043463],
41 | [0.05092592592592592, 0.04683129325109843],
42 | [0.06254458977407848, 0.0764872521246459],
43 | [0.07689655172413794, 0.14613778705636743],
44 | [0.11500570776255709, 0.09082682291666666],
45 | [0.162109375, 0.18448023426061494],
46 | [0.26129166666666664, 0.3815],
47 | ],
48 | "anchor_masks": [[8, 9, 10, 11], [4, 5, 6, 7], [0, 1, 2, 3]],
49 | "num_classes": 1,
50 | "img_h": 608,
51 | "img_w": 608,
52 | },
53 | "voc_dataset": {
54 | "anchors": [
55 | [0.052, 0.08],
56 | [0.07, 0.18018018018018017],
57 | [0.138, 0.30930930930930933],
58 | [0.162, 0.138],
59 | [0.224, 0.5350840978593272],
60 | [0.336, 0.29333333333333333],
61 | [0.41, 0.7173333333333334],
62 | [0.686, 0.466],
63 | [0.848, 0.88],
64 | ],
65 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
66 | "num_classes": 20,
67 | "img_h": 416,
68 | "img_w": 416,
69 | },
70 | "switch_dataset": {
71 | "anchors": [
72 | [0.009895833333333333, 0.006481481481481481],
73 | [0.016145833333333335, 0.011111111111111112],
74 | [0.021875, 0.016666666666666666],
75 | [0.030208333333333334, 0.021296296296296296],
76 | [0.033854166666666664, 0.032407407407407406],
77 | [0.04635416666666667, 0.026851851851851852],
78 | [0.05572916666666667, 0.04351851851851852],
79 | [0.07916666666666666, 0.06574074074074074],
80 | [0.13619791666666664, 0.11481481481481481],
81 | ],
82 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
83 | "num_classes": 3,
84 | "img_h": 608,
85 | "img_w": 608,
86 | },
87 | "wheat_dataset": {
88 | "anchors": [
89 | [0.044921875, 0.080078125],
90 | [0.0537109375, 0.052734375],
91 | [0.05859375, 0.033203125],
92 | [0.0703125, 0.06640625],
93 | [0.0771484375, 0.091796875],
94 | [0.0908203125, 0.0458984375],
95 | [0.0927734375, 0.1328125],
96 | [0.1064453125, 0.0703125],
97 | [0.142578125, 0.1025390625],
98 | ],
99 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
100 | "num_classes": 1,
101 | "img_h": 608,
102 | "img_w": 608,
103 | },
104 | }
105 |
--------------------------------------------------------------------------------
/docs/images/labels.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/docs/images/labels.gif
--------------------------------------------------------------------------------
/docs/images/test_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmba15/yolov3_pytorch/993d9bf966965cb2f7800da2cb3b88ce1ea17f51/docs/images/test_result.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: bird
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - defaults
6 | dependencies:
7 | - python >=3.6
8 | - cudatoolkit=10.0
9 | - cudnn=7.6.5
10 | - pip
11 | - pip:
12 | - -r requirements.txt
13 |
--------------------------------------------------------------------------------
/estimate_priors_size_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
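# Estimates anchor-box priors for the chosen dataset by clustering its normalized
# ground-truth box sizes (k-means; see utils/kmeans_bboxes.py), then prints the anchors
# and anchor masks in the format used by dataset_params.py.
# Example invocation (assumed): python3 estimate_priors_size_dataset.py --dataset bird_dataset --k 12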
5 | def main(args):
6 | import numpy as np
7 | from config import Config
8 | from utils import estimate_priors_sizes
9 |
10 | total_config = Config()
11 | if not args.dataset or args.dataset not in total_config.DATASETS.keys():
12 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys())))
13 |
14 | DatasetClass = total_config.DATASETS[args.dataset]
15 |
16 | dataset = DatasetClass(data_path=total_config.DATA_PATH)
17 | clusters = estimate_priors_sizes(dataset, args.k)
18 | clusters = [list(elem) for elem in sorted(clusters, key=lambda x: x[0], reverse=False)]
19 | print("Anchors:")
20 | print(clusters)
21 |
22 | anchor_masks = list(np.arange(args.k)[::-1].reshape(3, -1))
23 | anchor_masks = [list(elem[::-1]) for elem in anchor_masks]
24 | print("Anchor masks:")
25 | print(anchor_masks)
26 |
27 |
28 | if __name__ == "__main__":
29 | import argparse
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--k", type=int, default=9, help="number of clusters")
33 | parser.add_argument("--dataset", type=str, required=True, help="name of the dataset to use")
34 | parsed_args = parser.parse_args()
35 |
36 | main(parsed_args)
37 |
--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from models import YoloNet
4 | import torch.onnx
5 |
6 |
7 | def main(args):
8 | import os
9 | from config import Config
10 |
11 | total_config = Config()
12 | if not args.dataset or args.dataset not in total_config.DATASETS.keys():
13 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys())))
14 | if not args.snapshot or not os.path.isfile(args.snapshot):
15 | raise Exception("invalid snapshot")
16 |
17 | dataset = args.dataset
18 | dataset_class = total_config.DATASETS[dataset]
19 | dataset_params = total_config.DATASET_PARAMS[dataset]
20 | model = YoloNet(dataset_config=dataset_params)
21 | model.load_state_dict(torch.load(args.snapshot)["state_dict"])
22 | model.eval()
23 |
24 | if args.batch_size:
25 | batch_size = args.batch_size
26 | else:
27 | batch_size = 1
28 |
29 | x = torch.randn(batch_size, 3, dataset_params["img_h"], dataset_params["img_w"])
30 | torch.onnx.export(
31 | model,
32 | x,
33 | args.onnx_weight_file,
34 | verbose=True,
35 | input_names=["input"],
36 | output_names=["output"],
37 | do_constant_folding=True,
38 | operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
39 | opset_version=11,
40 | )
41 |
42 | if args.batch_size:
43 | return
44 |
45 | import onnx
46 |
47 | mp = onnx.load(args.onnx_weight_file)
48 | mp.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
49 | mp.graph.output[0].type.tensor_type.shape.dim[0].dim_param = "None"
50 | onnx.save(mp, "output.onnx")
51 |
52 |
53 | if __name__ == "__main__":
54 | import argparse
55 |
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--snapshot", required=True, type=str)
58 | parser.add_argument(
59 | "--dataset", type=str, required=True, help="name of the dataset to use",
60 | )
61 | parser.add_argument("--onnx_weight_file", type=str, default="output.onnx")
62 | parser.add_argument("--batch_size", type=int)
63 | parsed_args = parser.parse_args()
64 |
65 | main(parsed_args)
66 |
--------------------------------------------------------------------------------
/inference_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from models import DetectHandler
4 | import torch
5 |
6 |
7 | def main(args):
8 | import onnxruntime as rt
9 | import cv2
10 | import numpy as np
11 | from config import Config
12 |
13 | total_config = Config()
14 | dataset = args.dataset
15 | dataset_class = total_config.DATASETS[dataset]
16 | dataset_params = total_config.DATASET_PARAMS[dataset]
17 | dataset_instance = dataset_class(data_path=total_config.DATA_PATH)
18 |
19 | img = cv2.imread(args.image_path)
20 | assert img is not None
21 |
22 | ori_h, ori_w = img.shape[:2]
23 |
24 | h_ratio = ori_h / args.img_h
25 | w_ratio = ori_w / args.img_w
26 |
27 | processed_img = cv2.resize(img, (args.img_w, args.img_h))
28 | processed_img = processed_img / 255.0
29 | input_x = processed_img.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
30 |
31 | sess = rt.InferenceSession(args.onnx_weight_file)
32 |
33 | assert len(sess.get_inputs()) == 1
34 | assert len(sess.get_outputs()) == 1
35 |
36 | input_name = sess.get_inputs()[0].name
37 | output_names = [elem.name for elem in sess.get_outputs()]
38 | predictions = sess.run(output_names, {input_name: input_x})[0]
39 |
40 | detect_handler = DetectHandler(
41 | num_classes=dataset_params["num_classes"], conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh, h_ratio=h_ratio, w_ratio=w_ratio,
42 | )
43 | bboxes, scores, classes = detect_handler(predictions)
44 |
45 | result = dataset_class.visualize_one_image_util(
46 | img, dataset_instance.classes, dataset_instance.colors, bboxes, classes,
47 | )
48 | cv2.imshow("result", result)
49 | cv2.waitKey(0)
50 | cv2.destroyAllWindows()
51 |
52 |
53 | if __name__ == "__main__":
54 | import argparse
55 |
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--img_h", type=int, required=True, help="height size of the input")
58 | parser.add_argument("--img_w", type=int, required=True, help="width size of the input")
59 | parser.add_argument(
60 | "--dataset", type=str, required=True, help="name of the dataset to use",
61 | )
62 | parser.add_argument("--image_path", type=str, required=True, help="path to image")
63 | parser.add_argument("--onnx_weight_file", type=str, required=True)
64 | parser.add_argument("--conf_thresh", type=float, default=0.1)
65 | parser.add_argument("--nms_thresh", type=float, default=0.5)
66 | parsed_args = parser.parse_args()
67 |
68 | main(parsed_args)
69 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .functions import *
4 | from .layers import *
5 | from .yolo_layer import YoloLayer
6 | from .yolo_v3 import YoloNet
7 |
--------------------------------------------------------------------------------
/models/functions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .new_types import *
4 | from .utils import *
5 | from .detect import *
6 | from .mish import *
7 |
--------------------------------------------------------------------------------
/models/functions/detect.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | from .utils import nms
8 | from .new_types import BboxType
9 |
10 |
11 | __all__ = ["DetectHandler"]
12 |
13 |
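# Decodes raw network predictions of shape (1, num_boxes, 5 + num_classes), where each row is
# [cx, cy, w, h, confidence, per-class scores...] (see the slicing in __call__), into filtered
# (bboxes, scores, classes) via confidence thresholding and NMS.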
14 | class DetectHandler(Function):
15 | def __init__(
16 | self, num_classes: int, conf_thresh: float, nms_thresh: float, h_ratio: float = None, w_ratio: float = None
17 | ):
18 | super(DetectHandler, self).__init__()
19 | self.num_classes = num_classes
20 | self.conf_thresh = conf_thresh
21 | self.nms_thresh = nms_thresh
22 |
23 | self.h_ratio = h_ratio
24 | self.w_ratio = w_ratio
25 |
26 | def __call__(self, predictions: torch.Tensor):
27 | if isinstance(predictions, np.ndarray):
28 | predictions = torch.FloatTensor(predictions)
29 |
30 | bboxes = predictions[..., :4].squeeze_(dim=0)
31 | scores = predictions[..., 4].squeeze_(dim=0)
32 | classes_one_hot = predictions[..., 5:].squeeze_(dim=0)
33 | classes = torch.argmax(classes_one_hot, dim=1)
34 |
35 | bboxes, scores, classes = nms(
36 | bboxes,
37 | scores,
38 | classes,
39 | num_classes=self.num_classes,
40 | bbox_mode=BboxType.CXCYWH,
41 | conf_thresh=self.conf_thresh,
42 | nms_thresh=self.nms_thresh,
43 | )
44 |
45 | if self.h_ratio is not None and self.w_ratio is not None:
46 | bboxes[..., [0, 2]] *= self.w_ratio
47 | bboxes[..., [1, 3]] *= self.h_ratio
48 |
49 | return bboxes, scores, classes
50 |
--------------------------------------------------------------------------------
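A minimal usage sketch for DetectHandler (not part of the repo; it assumes the repository root is on PYTHONPATH and the torch version pinned in requirements.txt). The three hand-made predictions below use the raw layout (cx, cy, w, h, conf, class scores...), and num_classes=3 is a placeholder:

import torch
from models.functions import DetectHandler

# two overlapping class-0 boxes and one class-1 box, shape [batch, num_anchors, 5 + num_classes]
predictions = torch.tensor([[
    [50.0, 50.0, 20.0, 20.0, 0.9, 0.8, 0.1, 0.1],
    [52.0, 51.0, 20.0, 20.0, 0.6, 0.7, 0.2, 0.1],
    [120.0, 80.0, 30.0, 40.0, 0.8, 0.1, 0.9, 0.0],
]])
handler = DetectHandler(num_classes=3, conf_thresh=0.5, nms_thresh=0.5)
# NMS drops one of the two overlapping class-0 boxes; bboxes come back in XYXY mode
bboxes, scores, classes = handler(predictions)
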
/models/functions/mish.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 |
8 | __all__ = ["Mish"]
9 |
10 |
11 | @torch.jit.script
12 | def _mish(input):
13 | """
14 | mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
15 | """
16 | return input * torch.tanh(F.softplus(input))
17 |
18 |
19 | class Mish(nn.Module):
20 | def __init__(self):
21 | super().__init__()
22 |
23 | def forward(self, input):
24 | return _mish(input)
25 |
--------------------------------------------------------------------------------
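A quick sanity check (not part of the repo) that Mish matches the closed form in its docstring; assumes the repository root is on PYTHONPATH:

import torch
from models.functions import Mish

x = torch.randn(4)
# mish(x) = x * tanh(ln(1 + exp(x))); log1p(exp(x)) is the softplus used above
assert torch.allclose(Mish()(x), x * torch.tanh(torch.log1p(torch.exp(x))), atol=1e-6)
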
/models/functions/new_types.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import enum
4 |
5 |
6 | __all__ = ["BboxType"]
7 |
8 |
9 | class BboxType(enum.Enum):
10 | """
11 | XYWH: xmin, ymin, width, height
12 | CXCYWH: xcenter, ycenter, width, height
13 | XYXY: xmin, ymin, xmax, ymax
14 | """
15 |
16 | XYWH = 0
17 | CXCYWH = 1
18 | XYXY = 2
19 |
--------------------------------------------------------------------------------
/models/functions/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import numpy as np
4 | import torch
5 | from .new_types import BboxType
6 |
7 |
8 | __all__ = ["bbox_iou", "nms", "transform_bbox"]
9 |
10 |
11 | def bbox_iou(
12 | bboxes1: torch.Tensor, bboxes2: torch.Tensor, bbox_mode: BboxType = BboxType.XYXY, epsilon: float = 1e-16
13 | ) -> torch.Tensor:
14 | """
15 | Args:
16 | bboxes1: (num_boxes_1, 4)
17 | bboxes2: (num_boxes_2, 4)
18 |
19 | Return:
20 | ious: (num_boxes_1, num_boxes_2)
21 | """
22 | if bbox_mode == BboxType.XYXY:
23 | b1_x1, b1_y1, b1_x2, b1_y2 = (
24 | bboxes1[..., 0],
25 | bboxes1[..., 1],
26 | bboxes1[..., 2],
27 | bboxes1[..., 3],
28 | )
29 | b2_x1, b2_y1, b2_x2, b2_y2 = (
30 | bboxes2[..., 0],
31 | bboxes2[..., 1],
32 | bboxes2[..., 2],
33 | bboxes2[..., 3],
34 | )
35 | elif bbox_mode == BboxType.CXCYWH:
36 | b1_x1, b1_x2 = bboxes1[..., 0] - bboxes1[..., 2] / 2, bboxes1[..., 0] + bboxes1[..., 2] / 2
37 | b1_y1, b1_y2 = bboxes1[..., 1] - bboxes1[..., 3] / 2, bboxes1[..., 1] + bboxes1[..., 3] / 2
38 | b2_x1, b2_x2 = bboxes2[..., 0] - bboxes2[..., 2] / 2, bboxes2[..., 0] + bboxes2[..., 2] / 2
39 | b2_y1, b2_y2 = bboxes2[..., 1] - bboxes2[..., 3] / 2, bboxes2[..., 1] + bboxes2[..., 3] / 2
40 | elif bbox_mode == BboxType.XYWH:
41 | b1_x1, b1_y1 = bboxes1[..., 0], bboxes1[..., 1]
42 | b2_x1, b2_y1 = bboxes2[..., 0], bboxes2[..., 1]
43 | b1_x2, b1_y2 = bboxes1[..., 0] + bboxes1[..., 2], bboxes1[..., 1] + bboxes1[..., 3]
44 | b2_x2, b2_y2 = bboxes2[..., 0] + bboxes2[..., 2], bboxes2[..., 1] + bboxes2[..., 3]
45 | else:
46 | raise Exception("not supported bbox type\n")
47 |
48 | num_b1 = bboxes1.shape[0]
49 | num_b2 = bboxes2.shape[0]
50 |
51 | inter_x1 = torch.max(b1_x1.unsqueeze(1).repeat(1, num_b2), b2_x1)
52 | inter_y1 = torch.max(b1_y1.unsqueeze(1).repeat(1, num_b2), b2_y1)
53 | inter_x2 = torch.min(b1_x2.unsqueeze(1).repeat(1, num_b2), b2_x2)
54 | inter_y2 = torch.min(b1_y2.unsqueeze(1).repeat(1, num_b2), b2_y2)
55 |
56 | inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
57 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
58 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
59 | union_area = b1_area.unsqueeze(1).repeat(1, num_b2) + b2_area.unsqueeze(0).repeat(num_b1, 1) - inter_area + epsilon
60 |
61 | iou = inter_area / union_area
62 | return iou
63 |
64 |
65 | def nms(
66 | bboxes: torch.Tensor,
67 | scores: torch.Tensor,
68 | classes: torch.Tensor,
69 | num_classes: int,
70 | conf_thresh: float = 0.8,
71 | nms_thresh: float = 0.5,
72 | bbox_mode: BboxType = BboxType.CXCYWH,
73 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
74 | """
75 | Args:
76 | bboxes: The location predictions for the img, Shape: [num_anchors,4].
77 | scores: The class prediction scores for the img, Shape:[num_anchors].
78 | classes: The label (non-one-hot) representation of the classes of the objects,
79 | Shape: [num_anchors].
80 | num_classes: number of classes
81 | conf_thresh: threshold where all the detections below this value will be ignored.
82 | nms_thresh: overlap thresh for suppressing unnecessary boxes.
83 |
84 | Return:
85 | (bboxes, scores, classes) after nms suppression with bboxes in BboxType.XYXY box mode
86 | """
87 |
88 | assert bboxes.shape[0] == scores.shape[0] == classes.shape[0]
89 | assert conf_thresh > 0
90 |
91 | num_anchors = bboxes.shape[0]
92 |
93 | if num_anchors == 0:
94 | return bboxes, scores, classes
95 |
96 | conf_index = torch.nonzero(torch.ge(scores, conf_thresh)).squeeze()
97 | bboxes = bboxes.index_select(0, conf_index)
98 | scores = scores.index_select(0, conf_index)
99 | classes = classes.index_select(0, conf_index)
100 |
101 | grouped_indices = _group_same_class_object(classes, one_hot=False, num_classes=num_classes)
102 | selected_indices_final = []
103 |
104 | for class_id, member_idx in enumerate(grouped_indices):
105 | member_idx_tensor = bboxes.new_tensor(member_idx, dtype=torch.long)
106 | bboxes_one_class = bboxes.index_select(dim=0, index=member_idx_tensor)
107 | scores_one_class = scores.index_select(dim=0, index=member_idx_tensor)
108 | scores_one_class, sorted_indices = torch.sort(scores_one_class, descending=False)
109 |
110 | selected_indices = []
111 |
112 | while sorted_indices.size(0) != 0:
113 | picked_index = sorted_indices[-1]
114 | selected_indices.append(picked_index)
115 | picked_bbox = bboxes_one_class[picked_index]
116 |
117 | picked_bbox.unsqueeze_(dim=0)
118 |
119 | ious = bbox_iou(picked_bbox, bboxes_one_class[sorted_indices[:-1]], bbox_mode=bbox_mode)
120 | ious.squeeze_(dim=0)
121 |
122 | under_indices = torch.nonzero(ious <= nms_thresh).squeeze()
123 | sorted_indices = sorted_indices.index_select(dim=0, index=under_indices)
124 |
125 | selected_indices_final.extend([member_idx[i] for i in selected_indices])
126 |
127 | selected_indices_final = bboxes.new_tensor(selected_indices_final, dtype=torch.long)
128 | bboxes_result = bboxes.index_select(dim=0, index=selected_indices_final)
129 | scores_result = scores.index_select(dim=0, index=selected_indices_final)
130 | classes_result = classes.index_select(dim=0, index=selected_indices_final)
131 |
132 | return transform_bbox(bboxes_result, orig_mode=bbox_mode, target_mode=BboxType.XYXY), scores_result, classes_result
133 |
134 |
135 | def transform_bbox(
136 | bboxes: torch.Tensor, orig_mode: BboxType = BboxType.CXCYWH, target_mode: BboxType = BboxType.XYXY
137 | ) -> torch.Tensor:
138 | assert orig_mode != target_mode
139 | assert bboxes.shape[1] == 4
140 |
141 | if orig_mode == BboxType.CXCYWH:
142 | if target_mode == BboxType.XYXY:
143 | return torch.cat((bboxes[:, :2] - bboxes[:, 2:] / 2, bboxes[:, :2] + bboxes[:, 2:] / 2), dim=-1,)
144 | elif target_mode == BboxType.XYWH:
145 |             return torch.cat((bboxes[:, :2] - bboxes[:, 2:] / 2, bboxes[:, 2:]), dim=-1,)
146 | else:
147 | raise Exception("not supported conversion\n")
148 | elif orig_mode == BboxType.XYWH:
149 | if target_mode == BboxType.XYXY:
150 | return torch.cat((bboxes[:, :2], bboxes[:, 2:] + bboxes[:, :2]), dim=-1,)
151 | elif target_mode == BboxType.CXCYWH:
152 | return torch.cat((bboxes[:, :2] + bboxes[:, 2:] / 2, bboxes[:, 2:]), dim=-1,)
153 | else:
154 | raise Exception("not supported conversion\n")
155 | elif orig_mode == BboxType.XYXY:
156 | if target_mode == BboxType.CXCYWH:
157 | return torch.cat(((bboxes[:, :2] + bboxes[:, 2:]) / 2, bboxes[:, 2:] - bboxes[:, :2]), dim=-1,)
158 |         elif target_mode == BboxType.XYWH:
159 | return torch.cat((bboxes[:, :2], bboxes[:, 2:] - bboxes[:, :2]), dim=-1,)
160 | else:
161 | raise Exception("not supported conversion\n")
162 | else:
163 | raise Exception("not supported original bbox mode\n")
164 |
165 |
166 | def _group_same_class_object(obj_classes: torch.Tensor, one_hot: bool = True, num_classes: int = -1):
167 | """
168 | Given a list of class results, group the object with the same class into a list.
169 | Returns a list with the length of num_classes, where each bucket has the objects with the same class.
170 |
171 | Args:
172 | obj_classes: The representation of classes of object.
173 | It can be either one-hot or label (non-one-hot).
174 |                      If it is one-hot, the shape should be: [num_objects, num_classes]
175 |                      If it is label (non-one-hot), the shape should be: [num_objects, ]
176 | one_hot: A flag telling the function whether obj_classes is one-hot representation.
177 | num_classes: The max number of classes if obj_classes is represented as non-one-hot format.
178 |
179 | Returns:
180 |         a list of lists, where the i-th list holds
181 |         the indices of the objects belonging to class i.
182 | """
183 |
184 | if one_hot:
185 | num_classes = obj_classes.shape[-1]
186 | else:
187 | assert num_classes != -1
188 | grouped_index = [[] for _ in range(num_classes)]
189 | if one_hot:
190 | for idx, class_one_hot in enumerate(obj_classes):
191 | grouped_index[torch.argmax(class_one_hot)].append(idx)
192 | else:
193 | for idx, obj_class_ in enumerate(obj_classes):
194 | grouped_index[obj_class_].append(idx)
195 |
196 | return grouped_index
197 |
--------------------------------------------------------------------------------
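A small illustration of bbox_iou and transform_bbox (not part of the repo; the box coordinates are arbitrary example values and the repository root is assumed to be on PYTHONPATH):

import torch
from models.functions import BboxType, bbox_iou, transform_bbox

# two boxes given as (cx, cy, w, h)
boxes_cxcywh = torch.tensor([[50.0, 50.0, 20.0, 20.0], [55.0, 55.0, 20.0, 20.0]])
boxes_xyxy = transform_bbox(boxes_cxcywh, orig_mode=BboxType.CXCYWH, target_mode=BboxType.XYXY)
print(boxes_xyxy)  # [[40, 40, 60, 60], [45, 45, 65, 65]]

ious = bbox_iou(boxes_xyxy, boxes_xyxy, bbox_mode=BboxType.XYXY)  # (2, 2) pairwise IoU matrix
print(ious)  # diagonal is 1.0, off-diagonal is 225 / 575 ≈ 0.39
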
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .backbone import *
4 |
--------------------------------------------------------------------------------
/models/layers/backbone.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | from models.functions import Mish
6 |
7 |
8 | __all__ = ["ConvBnAct", "Darknet", "PreDetectionConvGroup", "UpsampleGroup"]
9 |
10 |
11 | def _activation_func(activation):
12 | import sys
13 | import copy
14 |
15 | try:
16 | return copy.deepcopy(
17 | nn.ModuleDict(
18 | [
19 | ["relu", nn.ReLU(inplace=True)],
20 | ["leaky_relu", nn.LeakyReLU(negative_slope=0.1, inplace=True)],
21 | ["selu", nn.SELU(inplace=True)],
22 | ["mish", Mish()],
23 |                     ["identity", nn.Identity()],
24 | ]
25 | )[activation]
26 | )
27 | except Exception as e:
28 | print("no activation {}".format(activation))
29 | sys.exit(-1)
30 |
31 |
32 | class ConvBnAct(nn.Module):
33 | def __init__(
34 | self, nin, nout, ks, s=1, pad="SAME", padding=0, use_bn=True, act="mish",
35 | ):
36 | super(ConvBnAct, self).__init__()
37 |
38 | self.use_bn = use_bn
39 | self.bn = None
40 | if use_bn:
41 | self.bn = nn.BatchNorm2d(nout)
42 |
43 | if pad == "SAME":
44 | padding = (ks - 1) // 2
45 |
46 | self.conv = nn.Conv2d(
47 | in_channels=nin, out_channels=nout, kernel_size=ks, stride=s, padding=padding, bias=not use_bn
48 | )
49 | self.activation = _activation_func(act)
50 |
51 | def forward(self, x):
52 | out = self.conv(x)
53 | if self.use_bn:
54 | out = self.bn(out)
55 |
56 | return self.activation(out)
57 |
58 |
59 | class ResResidualBlock(nn.Module):
60 | def __init__(self, nin):
61 | super(ResResidualBlock, self).__init__()
62 | self.conv1 = ConvBnAct(nin, nin // 2, ks=1)
63 | self.conv2 = ConvBnAct(nin // 2, nin, ks=3)
64 |
65 | def forward(self, x):
66 | return x + self.conv2(self.conv1(x))
67 |
68 |
69 | def _map_2_cfgdict(module_list):
70 | from collections import OrderedDict
71 |
72 | idx = 0
73 | mdict = OrderedDict()
74 | for i, m in enumerate(module_list):
75 | if isinstance(m, ResResidualBlock):
76 | mdict[idx] = None
77 | mdict[idx + 1] = None
78 | idx += 2
79 | mdict[idx] = i
80 | idx += 1
81 | return mdict
82 |
83 |
84 | def _make_res_stack(nin, num_blk):
85 | return nn.ModuleList([ConvBnAct(nin, nin * 2, 3, s=2)] + [ResResidualBlock(nin * 2) for n in range(num_blk)])
86 |
87 |
88 | class Darknet(nn.Module):
89 | def __init__(self, blk_list, nout=32):
90 | super(Darknet, self).__init__()
91 |
92 | self.module_list = nn.ModuleList()
93 | self.module_list += [ConvBnAct(3, nout, 3)]
94 | for i, nb in enumerate(blk_list):
95 | self.module_list += _make_res_stack(nout * (2 ** i), nb)
96 |
97 | self.map2yolocfg = _map_2_cfgdict(self.module_list)
98 | self.cached_out_dict = dict()
99 |
100 | def forward(self, x):
101 | for i, m in enumerate(self.module_list):
102 | x = m(x)
103 | if i in self.cached_out_dict:
104 | self.cached_out_dict[i] = x
105 | return x
106 |
107 | # mode - normal -- direct index to module_list
108 | # - yolocfg -- index follow the sequences of the cfg file from https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg
109 | def add_cached_out(self, idx, mode="yolocfg"):
110 | if mode == "yolocfg":
111 | idxs = self.map2yolocfg[idx]
112 | self.cached_out_dict[idxs] = None
113 |
114 | def get_cached_out(self, idx, mode="yolocfg"):
115 | if mode == "yolocfg":
116 | idxs = self.map2yolocfg[idx]
117 | return self.cached_out_dict[idxs]
118 |
119 | def load_weight(self, weights_path):
120 | wm = WeightLoader(self)
121 | wm.load_weight(weights_path)
122 |
123 |
124 | class PreDetectionConvGroup(nn.Module):
125 | def __init__(self, nin, nout, num_classes, num_anchors, num_conv=3):
126 | super(PreDetectionConvGroup, self).__init__()
127 | self.module_list = nn.ModuleList()
128 |
129 | for i in range(num_conv):
130 | self.module_list += [ConvBnAct(nin, nout, ks=1)]
131 | self.module_list += [ConvBnAct(nout, nout * 2, ks=3)]
132 | if i == 0:
133 | nin = nout * 2
134 |
135 | self.module_list += [nn.Conv2d(nin, (num_classes + 5) * num_anchors, 1)]
136 | self.map2yolocfg = _map_2_cfgdict(self.module_list)
137 | self.cached_out_dict = dict()
138 |
139 | def forward(self, x):
140 | for i, m in enumerate(self.module_list):
141 | x = m(x)
142 | if i in self.cached_out_dict:
143 | self.cached_out_dict[i] = x
144 | return x
145 |
146 | def add_cached_out(self, idx, mode="yolocfg"):
147 | if mode == "yolocfg":
148 | idx = self.get_idx_from_yolo_idx(idx)
149 | elif idx < 0:
150 |             idx = len(self.module_list) + idx
151 |
152 | self.cached_out_dict[idx] = None
153 |
154 | def get_cached_out(self, idx, mode="yolocfg"):
155 | if mode == "yolocfg":
156 | idx = self.get_idx_from_yolo_idx(idx)
157 | elif idx < 0:
158 |             idx = len(self.module_list) + idx
159 | return self.cached_out_dict[idx]
160 |
161 | def get_idx_from_yolo_idx(self, idx):
162 | if idx < 0:
163 | return len(self.map2yolocfg) + idx
164 | else:
165 | return self.map2yolocfg[idx]
166 |
167 |
168 | class UpsampleGroup(nn.Module):
169 | def __init__(self, nin):
170 | super(UpsampleGroup, self).__init__()
171 | self.conv = ConvBnAct(nin, nin // 2, ks=1)
172 |
173 | def forward(self, route_head, route_tail):
174 | out = self.conv(route_head)
175 | out = nn.functional.interpolate(out, scale_factor=2, mode="nearest")
176 | return torch.cat((out, route_tail), 1)
177 |
178 |
179 | class WeightLoader:
180 | def __init__(self, model):
181 | super(WeightLoader, self).__init__()
182 | self._conv_list = self._find_conv_layers(model)
183 |
184 | def load_weight(self, weight_path):
185 | ptr = 0
186 | weights = self._read_file(weight_path)
187 | for m in self._conv_list:
188 | if type(m) == ConvBnAct:
189 | ptr = self._load_conv_bn_relu(m, weights, ptr)
190 | elif type(m) == nn.Conv2d:
191 | ptr = self._load_conv2d(m, weights, ptr)
192 | return ptr
193 |
194 | def _read_file(self, file):
195 | import numpy as np
196 |
197 | with open(file, "rb") as fp:
198 | header = np.fromfile(fp, dtype=np.int32, count=5)
199 | self.header = torch.from_numpy(header)
200 | self.seen = self.header[3]
201 | weights = np.fromfile(fp, dtype=np.float32)
202 | return weights
203 |
204 | def _copy_weight_to_model_parameters(self, param, weights, ptr):
205 | num_el = param.numel()
206 | param.data.copy_(torch.from_numpy(weights[ptr : ptr + num_el]).view_as(param.data))
207 | return ptr + num_el
208 |
209 | def _load_conv_bn_relu(self, m, weights, ptr):
210 | ptr = self._copy_weight_to_model_parameters(m.bn.bias, weights, ptr)
211 | ptr = self._copy_weight_to_model_parameters(m.bn.weight, weights, ptr)
212 | ptr = self._copy_weight_to_model_parameters(m.bn.running_mean, weights, ptr)
213 | ptr = self._copy_weight_to_model_parameters(m.bn.running_var, weights, ptr)
214 | ptr = self._copy_weight_to_model_parameters(m.conv.weight, weights, ptr)
215 | return ptr
216 |
217 | def _load_conv2d(self, m, weights, ptr):
218 | ptr = self._copy_weight_to_model_parameters(m.bias, weights, ptr)
219 | ptr = self._copy_weight_to_model_parameters(m.weight, weights, ptr)
220 | return ptr
221 |
222 | def _find_conv_layers(self, mod):
223 | module_list = []
224 | for m in mod.children():
225 | if type(m) == ConvBnAct:
226 | module_list += [m]
227 | elif type(m) == nn.Conv2d:
228 | module_list += [m]
229 | elif isinstance(m, (nn.ModuleList, nn.Module)):
230 | module_list += self._find_conv_layers(m)
231 | elif type(m) == ResResidualBlock:
232 | module_list += self._find_conv_layers(m)
233 | return module_list
234 |
--------------------------------------------------------------------------------
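A minimal sketch of Darknet and its cached routes as wired up by YoloNet (not part of the repo; the 416x416 input and the printed shapes are illustrative expectations, and the repository root is assumed to be on PYTHONPATH):

import torch
from models.layers import Darknet

net = Darknet([1, 2, 8, 8, 4])        # darknet-53 block layout used by YoloNet
net.add_cached_out(36)                # route feeding the 52x52 detection branch
net.add_cached_out(61)                # route feeding the 26x26 detection branch

out = net(torch.randn(1, 3, 416, 416))
print(out.shape)                      # expected: [1, 1024, 13, 13]
print(net.get_cached_out(61).shape)   # expected: [1, 512, 26, 26]
print(net.get_cached_out(36).shape)   # expected: [1, 256, 52, 52]
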
/models/yolo_layer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from .functions import bbox_iou, BboxType
7 |
8 |
9 | __all__ = ["YoloLayer"]
10 |
11 |
12 | class YoloLayer(nn.Module):
13 | def __init__(
14 | self,
15 | anchors_all,
16 | anchors_mask,
17 | num_classes,
18 | lambda_xy=1,
19 | lambda_wh=1,
20 | lambda_conf=1,
21 | lambda_cls=1,
22 | obj_scale=1,
23 | noobj_scale=1,
24 | ignore_thres=0.7,
25 | epsilon=1e-16,
26 | ):
27 | super(YoloLayer, self).__init__()
28 |
29 | assert num_classes > 0
30 |
31 | self._anchors_all = anchors_all
32 | self._anchors_mask = anchors_mask
33 |
34 | self._num_classes = num_classes
35 | self._bbox_attrib = 5 + num_classes
36 |
37 | self._lambda_xy = lambda_xy
38 | self._lambda_wh = lambda_wh
39 | self._lambda_conf = lambda_conf
40 |
41 | if self._num_classes == 1:
42 | self._lambda_cls = 0
43 | else:
44 | self._lambda_cls = lambda_cls
45 |
46 | self._obj_scale = obj_scale
47 | self._noobj_scale = noobj_scale
48 | self._ignore_thres = ignore_thres
49 |
50 | self._epsilon = epsilon
51 |
52 | self._mseloss = nn.MSELoss(reduction="sum")
53 | self._bceloss = nn.BCELoss(reduction="sum")
54 |         self._bceloss_average = nn.BCELoss(reduction="mean")
55 |
56 | def forward(self, x: torch.Tensor, img_dim: tuple, target=None):
57 | # x : batch_size * nA * (5 + num_classes) * H * W
58 |
59 | device = x.device
60 | if target is not None:
61 | assert target.device == x.device
62 |
63 | nB = x.shape[0]
64 | nA = len(self._anchors_mask)
65 | nH, nW = x.shape[2], x.shape[3]
66 | stride = img_dim[1] / nH
67 | anchors_all = torch.FloatTensor(self._anchors_all) / stride
68 | anchors = anchors_all[self._anchors_mask]
69 |
70 | # Reshape predictions from [B x [A * (5 + num_classes)] x H x W] to [B x A x H x W x (5 + num_classes)]
71 | preds = x.view(nB, nA, self._bbox_attrib, nH, nW).permute(0, 1, 3, 4, 2).contiguous()
72 |
73 |         # tx, ty, tw, th
74 | preds_xy = preds[..., :2].sigmoid()
75 | preds_wh = preds[..., 2:4]
76 | preds_conf = preds[..., 4].sigmoid()
77 | preds_cls = preds[..., 5:].sigmoid()
78 |
79 | # calculate cx, cy, anchor mesh
80 | mesh_y, mesh_x = torch.meshgrid([torch.arange(nH, device=device), torch.arange(nW, device=device)])
81 | mesh_xy = torch.stack((mesh_x, mesh_y), 2).float()
82 |
83 | mesh_anchors = anchors.view(1, nA, 1, 1, 2).repeat(1, 1, nH, nW, 1).to(device)
84 |
85 | # pred_boxes holds bx,by,bw,bh
86 | pred_boxes = torch.FloatTensor(preds[..., :4].shape)
87 | pred_boxes[..., :2] = preds_xy + mesh_xy
88 | pred_boxes[..., 2:4] = preds_wh.exp() * mesh_anchors
89 |
90 | if target is not None:
91 | (
92 | obj_mask,
93 | noobj_mask,
94 | box_coord_mask,
95 | tconf,
96 | tcls,
97 | tx,
98 | ty,
99 | tw,
100 | th,
101 | nCorrect,
102 | nGT,
103 | ) = self.build_target_tensor(
104 | pred_boxes, target, anchors_all, anchors, (nH, nW), self._num_classes, self._ignore_thres,
105 | )
106 |
107 | # masks for loss calculations
108 | obj_mask, noobj_mask = obj_mask.to(device), noobj_mask.to(device)
109 | box_coord_mask = box_coord_mask.to(device)
110 | cls_mask = obj_mask == 1
111 | tconf, tcls = tconf.to(device), tcls.to(device)
112 | tx, ty, tw, th = tx.to(device), ty.to(device), tw.to(device), th.to(device)
113 |
114 | loss_x = self._lambda_xy * self._mseloss(preds_xy[..., 0] * box_coord_mask, tx * box_coord_mask) / 2
115 | loss_y = self._lambda_xy * self._mseloss(preds_xy[..., 1] * box_coord_mask, ty * box_coord_mask) / 2
116 | loss_w = self._lambda_wh * self._mseloss(preds_wh[..., 0] * box_coord_mask, tw * box_coord_mask) / 2
117 | loss_h = self._lambda_wh * self._mseloss(preds_wh[..., 1] * box_coord_mask, th * box_coord_mask) / 2
118 |
119 | loss_conf = (
120 | self._lambda_conf
121 | * (
122 | self._obj_scale * self._bceloss(preds_conf * obj_mask, obj_mask)
123 | + self._noobj_scale * self._bceloss(preds_conf * noobj_mask, noobj_mask * 0)
124 | )
125 | / 1
126 | )
127 | loss_cls = self._lambda_cls * self._bceloss(preds_cls[cls_mask], tcls[cls_mask]) / 1
128 | loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
129 |
130 | return (
131 | loss,
132 | loss.item() / nB,
133 | loss_x.item() / nB,
134 | loss_y.item() / nB,
135 | loss_w.item() / nB,
136 | loss_h.item() / nB,
137 | loss_conf.item() / nB,
138 | loss_cls.item() / nB,
139 | nCorrect,
140 | nGT,
141 | )
142 |
143 | out = torch.cat((pred_boxes.to(device) * stride, preds_conf.to(device).unsqueeze(4), preds_cls.to(device),), 4,)
144 |
145 | # Reshape predictions from [B x A x H x W x (5 + num_classes)] to [B x [A x H x W] x (5 + num_classes)]
146 | out = out.permute(0, 2, 3, 1, 4).contiguous().view(nB, nA * nH * nW, self._bbox_attrib)
147 |
148 | return out
149 |
150 | def build_target_tensor(
151 | self, pred_boxes, target, anchors_all, anchors, inp_dim, num_classes, ignore_thres,
152 | ):
153 | nB = target.shape[0]
154 | nA = len(anchors)
155 | nH, nW = inp_dim[0], inp_dim[1]
156 | nCorrect = 0
157 | nGT = 0
158 | target = target.float()
159 |
160 | obj_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False)
161 | noobj_mask = torch.ones(nB, nA, nH, nW, requires_grad=False)
162 | box_coord_mask = torch.zeros(nB, nA, nH, nW, requires_grad=False)
163 | tconf = torch.zeros(nB, nA, nH, nW, requires_grad=False)
164 | tcls = torch.zeros(nB, nA, nH, nW, num_classes, requires_grad=False)
165 | tx = torch.zeros(nB, nA, nH, nW, requires_grad=False)
166 | ty = torch.zeros(nB, nA, nH, nW, requires_grad=False)
167 | tw = torch.zeros(nB, nA, nH, nW, requires_grad=False)
168 | th = torch.zeros(nB, nA, nH, nW, requires_grad=False)
169 |
170 | for b in range(nB):
171 | for t in range(target.shape[1]):
172 |
173 | # ignore padded labels
174 | if target[b, t].sum() == 0:
175 | break
176 |
177 | gx = target[b, t, 0] * nW
178 | gy = target[b, t, 1] * nH
179 | gw = target[b, t, 2] * nW
180 | gh = target[b, t, 3] * nH
181 | gi = int(gx)
182 | gj = int(gy)
183 |
184 | # pred_boxes - [A x H x W x 4]
185 | # Do not train for objectness(noobj) if anchor iou > threshold.
186 | tmp_gt_boxes = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0)
187 | tmp_pred_boxes = pred_boxes[b].view(-1, 4)
188 | tmp_ious, _ = torch.max(bbox_iou(tmp_pred_boxes, tmp_gt_boxes, bbox_mode=BboxType.CXCYWH), 1)
189 | ignore_idx = (tmp_ious > ignore_thres).view(nA, nH, nW)
190 | noobj_mask[b][ignore_idx] = 0
191 |
192 | # find best fit anchor for each ground truth box
193 | tmp_gt_boxes = torch.FloatTensor([[0, 0, gw, gh]])
194 | tmp_anchor_boxes = torch.cat((torch.zeros(len(anchors_all), 2), anchors_all), 1)
195 | tmp_ious = bbox_iou(tmp_anchor_boxes, tmp_gt_boxes, bbox_mode=BboxType.CXCYWH)
196 | best_anchor = torch.argmax(tmp_ious, 0).item()
197 |
198 | # If the best_anchor belongs to this yolo_layer
199 | if best_anchor in self._anchors_mask:
200 | best_anchor = self._anchors_mask.index(best_anchor)
201 | # find iou for best fit anchor prediction box against the ground truth box
202 | tmp_gt_box = torch.FloatTensor([gx, gy, gw, gh]).unsqueeze(0)
203 | tmp_pred_box = pred_boxes[b, best_anchor, gj, gi].view(-1, 4)
204 | tmp_iou = bbox_iou(tmp_gt_box, tmp_pred_box, bbox_mode=BboxType.CXCYWH)
205 |
206 | if tmp_iou > 0.5:
207 | nCorrect += 1
208 |
209 | # larger gradient for small objects
210 | box_coord_mask[b, best_anchor, gj, gi] = math.sqrt(2 - target[b, t, 2] * target[b, t, 3])
211 |
212 | obj_mask[b, best_anchor, gj, gi] = 1
213 | tconf[b, best_anchor, gj, gi] = 1
214 | tcls[b, best_anchor, gj, gi, int(target[b, t, 4])] = 1
215 | tx[b, best_anchor, gj, gi] = gx - gi
216 | ty[b, best_anchor, gj, gi] = gy - gj
217 | tw[b, best_anchor, gj, gi] = torch.log(gw / anchors[best_anchor, 0] + self._epsilon)
218 | th[b, best_anchor, gj, gi] = torch.log(gh / anchors[best_anchor, 1] + self._epsilon)
219 |
220 | nGT += 1
221 | return (
222 | obj_mask,
223 | noobj_mask,
224 | box_coord_mask,
225 | tconf,
226 | tcls,
227 | tx,
228 | ty,
229 | tw,
230 | th,
231 | nCorrect,
232 | nGT,
233 | )
234 |
--------------------------------------------------------------------------------
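An inference-only sketch of YoloLayer (no target), not part of the repo; the pixel anchors below are the standard yolov3.cfg values used as placeholders, and the repository root is assumed to be on PYTHONPATH:

import torch
from models import YoloLayer

anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
           [59, 119], [116, 90], [156, 198], [373, 326]]
layer = YoloLayer(anchors_all=anchors, anchors_mask=[6, 7, 8], num_classes=3)

x = torch.randn(1, 3 * (5 + 3), 13, 13)   # B x [A * (5 + num_classes)] x H x W
out = layer(x, img_dim=(416, 416))        # without a target the layer returns decoded predictions
print(out.shape)                          # expected: [1, 3 * 13 * 13, 5 + 3]
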
/models/yolo_v3.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .yolo_layer import YoloLayer
7 | from .layers import Darknet, PreDetectionConvGroup, UpsampleGroup
8 |
9 |
10 | __all__ = ["YoloNet"]
11 |
12 | _YOLOLAYER_PARAMS = {
13 | "lambda_xy": 1,
14 | "lambda_wh": 1,
15 | "lambda_conf": 1,
16 | "lambda_cls": 1,
17 | "obj_scale": 1,
18 | "noobj_scale": 1,
19 | "ignore_thres": 0.7,
20 | }
21 |
22 |
23 | class YoloNet(nn.Module):
24 | def __init__(self, dataset_config, yololayer_params=_YOLOLAYER_PARAMS):
25 | super(YoloNet, self).__init__()
26 |
27 | self._yololayer_params = yololayer_params
28 | self._all_anchors = [
29 | [int(w_e * dataset_config["img_w"]), int(h_e * dataset_config["img_h"])]
30 | for (w_e, h_e) in dataset_config["anchors"]
31 | ]
32 | self._anchors_masks = dataset_config["anchor_masks"]
33 | assert len(self._anchors_masks) > 0
34 | self._num_anchors_each_layer = len(self._anchors_masks[0])
35 |
36 | self._num_classes = dataset_config["num_classes"]
37 |
38 | self.stat_keys = [
39 | "loss",
40 | "loss_x",
41 | "loss_y",
42 | "loss_w",
43 | "loss_h",
44 | "loss_conf",
45 | "loss_cls",
46 | "nCorrect",
47 | "nGT",
48 | "recall",
49 | ]
50 |
51 | self.feature = Darknet([1, 2, 8, 8, 4])
52 | self.feature.add_cached_out(61)
53 | self.feature.add_cached_out(36)
54 |
55 | self.pre_det1 = PreDetectionConvGroup(
56 | 1024, 512, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer
57 | )
58 | self.pre_det1.add_cached_out(-3)
59 |
60 | self.up1 = UpsampleGroup(512)
61 | self.pre_det2 = PreDetectionConvGroup(
62 | 768, 256, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer
63 | )
64 | self.pre_det2.add_cached_out(-3)
65 |
66 | self.up2 = UpsampleGroup(256)
67 | self.pre_det3 = PreDetectionConvGroup(
68 | 384, 128, num_classes=self._num_classes, num_anchors=self._num_anchors_each_layer
69 | )
70 |
71 | self.yolo_layers = [
72 | YoloLayer(
73 | anchors_all=self._all_anchors,
74 | anchors_mask=anchors_mask,
75 | num_classes=self._num_classes,
76 | lambda_xy=self._yololayer_params["lambda_xy"],
77 | lambda_wh=self._yololayer_params["lambda_wh"],
78 | lambda_conf=self._yololayer_params["lambda_conf"],
79 | lambda_cls=self._yololayer_params["lambda_cls"],
80 | obj_scale=self._yololayer_params["obj_scale"],
81 | noobj_scale=self._yololayer_params["noobj_scale"],
82 | ignore_thres=self._yololayer_params["ignore_thres"],
83 | )
84 | for anchors_mask in self._anchors_masks
85 | ]
86 |
87 | def forward(self, x: torch.Tensor, target=None):
88 | img_dim = (x.shape[3], x.shape[2])
89 | out = self.feature(x)
90 | dets = []
91 |
92 | # Detection layer 1
93 | out = self.pre_det1(out)
94 | dets.append(self.yolo_layers[0](out, img_dim, target))
95 |
96 | # Upsample 1
97 | r_head1 = self.pre_det1.get_cached_out(-3)
98 | r_tail1 = self.feature.get_cached_out(61)
99 | out = self.up1(r_head1, r_tail1)
100 |
101 | # Detection layer 2
102 | out = self.pre_det2(out)
103 | dets.append(self.yolo_layers[1](out, img_dim, target))
104 |
105 | # Upsample 2
106 | r_head2 = self.pre_det2.get_cached_out(-3)
107 | r_tail2 = self.feature.get_cached_out(36)
108 | out = self.up2(r_head2, r_tail2)
109 |
110 | # Detection layer 3
111 | out = self.pre_det3(out)
112 | dets.append(self.yolo_layers[2](out, img_dim, target))
113 |
114 | if target is not None:
115 | loss, *out = [sum(det) for det in zip(dets[0], dets[1], dets[2])]
116 |
117 | self.stats = dict(zip(self.stat_keys, out))
118 | self.stats["recall"] = self.stats["nCorrect"] / self.stats["nGT"] if self.stats["nGT"] else 0
119 | return loss
120 | else:
121 | return torch.cat(dets, 1)
122 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.4.3
2 | torch==1.5.0
3 | torchvision==0.6.0
4 | onnx==1.7.0
5 | onnxruntime==1.3.0
6 | tensorboardX==2.0
7 | pycocotools
8 | tqdm
9 |
--------------------------------------------------------------------------------
/scripts/download_bird_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 |
6 |
7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
8 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
9 | try:
10 | from utils import download_file_from_google_drive
11 | except Exception as e:
12 | print(e)
13 |     sys.exit(-1)
14 |
15 |
16 | def main():
17 | data_path = os.path.join(_CURRENT_DIR, "../data")
18 | file_id = "16IQjiGu-jl2oTqr5wsp9MmJxtQiuyIWq"
19 | destination = os.path.join(data_path, "bird_dataset.zip")
20 | if not os.path.isfile(destination) and not os.path.isdir(os.path.join(data_path, "bird_dataset")):
21 | download_file_from_google_drive(file_id, destination)
22 | os.system("cd {} && unzip bird_dataset.zip".format(data_path))
23 |
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
--------------------------------------------------------------------------------
/scripts/download_darknet_weight.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 |
6 |
7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
8 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
9 | try:
10 | from utils import download_file_from_google_drive
11 | except Exception as e:
12 | print(e)
13 | sys.exit(-1)
14 |
15 |
16 | def main():
17 | data_path = os.path.join(_CURRENT_DIR, "../saved_models")
18 | os.system("mkdir -p {}".format(data_path))
19 | file_id = "1_-FQFU1i79WySBehqdUXAdbI_-RSvFqb"
20 | destination = os.path.join(data_path, "darknet53.conv.74")
21 | if not os.path.isfile(destination):
22 | download_file_from_google_drive(file_id, destination)
23 |
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
--------------------------------------------------------------------------------
/scripts/download_darknet_weight.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | readonly WEIGHT_PATH="https://pjreddie.com/media/files/darknet53.conv.74"
4 | readonly CURRENT_DIR=$(dirname $(realpath $0))
5 | readonly DATA_PATH=$(realpath ${CURRENT_DIR}/../saved_models)
6 |
7 | function validate_url {
8 | wget --spider $1 &> /dev/null;
9 | }
10 |
11 | if ! validate_url $WEIGHT_PATH; then
12 | echo "Invalid url to download darknet weight";
13 | exit;
14 | fi
15 |
16 | echo "start downloading darknet weights"
17 | if [ ! -d ${DATA_PATH} ]; then
18 | mkdir -p ${DATA_PATH}
19 | fi
20 |
21 | if [ ! -f ${DATA_PATH}/darknet53.conv.74 ]; then
22 | wget -c ${WEIGHT_PATH} -P ${DATA_PATH}
23 | fi
24 |
--------------------------------------------------------------------------------
/scripts/download_switch_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 |
5 |
6 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
7 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
8 | try:
9 | from utils import download_file_from_google_drive
10 | except Exception as e:
11 | print(e)
12 |     sys.exit(-1)
13 |
14 |
15 | def main():
16 | data_path = os.path.join(_CURRENT_DIR, "../data")
17 | file_id = "1EWEhmvDaYYm0SsydUEGWUDrnBzkLEQc_"
18 | destination = os.path.join(data_path, "switch_detection.zip")
19 | if not os.path.isfile(destination) and not os.path.isdir(os.path.join(data_path, "switch_detection")):
20 | download_file_from_google_drive(file_id, destination)
21 | os.system("cd {} && unzip switch_detection.zip".format(data_path))
22 |
23 |
24 | if __name__ == "__main__":
25 | main()
26 |
--------------------------------------------------------------------------------
/scripts/download_udacity_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # https://github.com/BeSlower/Udacity_object_dataset
4 | readonly CURRENT_DIR=$(dirname $(realpath $0))
5 | readonly DATA_PATH_BASE=$(realpath ${CURRENT_DIR}/../data)
6 | readonly DATA_PATH=${DATA_PATH_BASE}/udacity
7 |
8 | echo "start downloading udacity dataset"
9 | if [ ! -d ${DATA_PATH} ]; then
10 | mkdir -p ${DATA_PATH}
11 | fi
12 |
13 | if [ ! -f ${DATA_PATH}/object-dataset.tar.gz ]; then
14 | wget -c https://s3.amazonaws.com/udacity-sdc/annotations/object-dataset.tar.gz -P ${DATA_PATH}
15 | fi
16 |
17 | tar -xvf ${DATA_PATH}/object-dataset.tar.gz -C ${DATA_PATH}
18 |
--------------------------------------------------------------------------------
/scripts/download_voc_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | readonly CURRENT_DIR=$(dirname $(realpath $0))
4 | readonly DATA_PATH_BASE=$(realpath ${CURRENT_DIR}/../data)
5 | readonly DATA_PATH=${DATA_PATH_BASE}/voc
6 | if [ ! -d ${DATA_PATH} ]; then
7 | mkdir -p ${DATA_PATH}
8 | fi
9 |
10 | DOWNLOAD_FILES='
11 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
12 | http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
13 | '
14 |
15 | for download_file in ${DOWNLOAD_FILES}; do
16 | file_name=$(basename $download_file)
17 | if [ ! -f $DATA_PATH/$file_name ]; then
18 | wget ${download_file} -P $DATA_PATH
19 | extension="${file_name##*.}"
20 | if [[ ${extension} == "zip" ]]; then
21 | unzip $DATA_PATH/${file_name} -d ${DATA_PATH}
22 | fi
23 |
24 | if [[ ${extension} == "tar" ]]; then
25 | tar xvf $DATA_PATH/${file_name} -C ${DATA_PATH}
26 | fi
27 | fi
28 | done
29 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | from models import YoloNet
5 | from data_loader import DataTransformBase
6 | from models import DetectHandler
7 |
8 |
9 | def test_one_image(
10 | args, total_config, dataset_class, dataset_params,
11 | ):
12 | import cv2
13 | import numpy as np
14 |
15 | model_path = args.snapshot
16 | dataset = args.dataset
17 | dataset_params = total_config.DATASET_PARAMS[dataset]
18 |     input_size = (dataset_params["img_w"], dataset_params["img_h"])  # cv2.resize expects (width, height)
19 |
20 | dataset_instance = dataset_class(data_path=total_config.DATA_PATH)
21 | num_classes = dataset_instance.num_classes
22 | model = YoloNet(dataset_config=dataset_params)
23 | model.load_state_dict(torch.load(model_path)["state_dict"])
24 | model.eval()
25 |
26 | img = cv2.imread(args.image_path)
27 | orig_img = np.copy(img)
28 |
29 | ori_h, ori_w = img.shape[:2]
30 | h_ratio = ori_h / dataset_params["img_h"]
31 | w_ratio = ori_w / dataset_params["img_w"]
32 |
33 | img = cv2.resize(img, input_size)
34 |
35 | img = img / 255.0
36 | input_x = torch.tensor(img.transpose(2, 0, 1)[np.newaxis, :]).float()
37 |
38 | predictions = model(input_x)
39 |
40 | detect_handler = DetectHandler(
41 | num_classes=dataset_params["num_classes"],
42 | conf_thresh=args.conf_thresh,
43 | nms_thresh=args.nms_thresh,
44 | h_ratio=h_ratio,
45 | w_ratio=w_ratio,
46 | )
47 |
48 | bboxes, scores, classes = detect_handler(predictions)
49 |
50 | result = dataset_class.visualize_one_image_util(
51 | orig_img, dataset_instance.classes, dataset_instance.colors, bboxes, classes,
52 | )
53 |
54 |     return result
55 |
56 |
57 | def main(args):
58 | import cv2
59 | import os
60 | from config import Config
61 |
62 | total_config = Config()
63 | if not args.dataset or args.dataset not in total_config.DATASETS.keys():
64 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys())))
65 | if not args.snapshot or not os.path.isfile(args.snapshot):
66 | raise Exception("invalid snapshot")
67 | if not args.image_path or not os.path.isfile(args.image_path):
68 | raise Exception("invalid image path")
69 |
70 | dataset = args.dataset
71 | dataset_class = total_config.DATASETS[dataset]
72 | dataset_params = total_config.DATASET_PARAMS[dataset]
73 |
74 | result = test_one_image(args, total_config, dataset_class, dataset_params,)
75 |
76 | cv2.imshow("result", result)
77 | cv2.waitKey(0)
78 | cv2.destroyAllWindows()
79 | cv2.imwrite("result.jpg", result)
80 |
81 |
82 | if __name__ == "__main__":
83 | import argparse
84 |
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument("--snapshot", required=True, type=str)
87 | parser.add_argument(
88 | "--dataset", type=str, required=True, help="name of the dataset to use",
89 | )
90 | parser.add_argument("--image_path", required=True, type=str, help="path to the test image")
91 | parser.add_argument("--conf_thresh", type=float, default=0.1)
92 | parser.add_argument("--nms_thresh", type=float, default=0.5)
93 | parsed_args = parser.parse_args()
94 |
95 | main(parsed_args)
96 |
--------------------------------------------------------------------------------
/test/get_dataset_size_distribution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 |
6 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
7 | try:
8 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
9 | from config import Config
10 | except Exception as e:
11 | print(e)
12 | exit(1)
13 |
14 |
15 | def main(args):
16 | import numpy as np
17 | from matplotlib import pyplot as plt
18 |
19 | total_config = Config()
20 | if not args.dataset or args.dataset not in total_config.DATASETS.keys():
21 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys())))
22 |
23 | DatasetClass = total_config.DATASETS[args.dataset]
24 |
25 | dataset = DatasetClass(data_path=total_config.DATA_PATH)
26 | size_list = dataset.size_distribution()
27 | hist, bins = np.histogram(size_list, bins=args.num_bins)
28 |
29 | print("min size {}".format(min(size_list)))
30 | print("max size {}".format(max(size_list)))
31 | print(hist, bins)
32 | plt.hist(size_list, bins=args.num_bins)
33 | plt.title("size/distance histogram")
34 | # plt.show()
35 | plt.savefig("size_distance_hist.png")
36 |
37 |
38 | if __name__ == "__main__":
39 | import argparse
40 |
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--dataset", type=str, required=True, help="name of the dataset to use")
43 | parser.add_argument("--num_bins", type=int, default=18, help="number of bins in histogram")
44 | parsed_args = parser.parse_args()
45 |
46 | main(parsed_args)
47 |
--------------------------------------------------------------------------------
/test/test_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 |
6 |
7 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
8 | try:
9 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
10 | from models import YoloNet
11 | except Exception as e:
12 | print(e)
13 | sys.exit(-1)
14 |
15 |
16 | def main(args):
17 | import torch
18 |
19 | input_tensor = torch.randn(1, 3, args.img_h, args.img_w)
20 |
21 | params = {
22 | "anchors": [
23 | [0.016923076923076923, 0.027196652719665274],
24 | [0.018, 0.013855213023900243],
25 | [0.02355072463768116, 0.044977511244377814],
26 | [0.033722163308589605, 0.025525525525525526],
27 | [0.049479166666666664, 0.049575070821529746],
28 | [0.05290373906125696, 0.08290488431876607],
29 | [0.09375, 0.098],
30 | [0.150390625, 0.1838283227241353],
31 | [0.26125, 0.36540185240513895],
32 | ],
33 | "anchor_masks": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
34 | "num_classes": args.num_classes,
35 | "img_w": args.img_w,
36 | "img_h": args.img_h,
37 | }
38 |
39 | model = YoloNet(dataset_config=params)
40 | output = model(input_tensor)
41 | for idx, p in enumerate(output):
42 | print("branch_{}: {}".format(idx, p.size()))
43 |
44 |
45 | if __name__ == "__main__":
46 | import argparse
47 |
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("--img_h", type=int, default=608)
50 | parser.add_argument("--img_w", type=int, default=608)
51 | parser.add_argument("--num_classes", type=int, default=80)
52 | parsed_args = parser.parse_args()
53 |
54 | main(parsed_args)
55 |
--------------------------------------------------------------------------------
/test/test_visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import cv2
5 | import sys
6 | import argparse
7 |
8 |
9 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
10 | try:
11 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
12 | from config import Config
13 | except Exception as e:
14 | print(e)
15 | sys.exit(-1)
16 |
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--idx", type=int, default=0, help="index of the images")
20 | parser.add_argument("--dataset", type=str, help="name of the dataset to use")
21 | parsed_args = parser.parse_args()
22 |
23 |
24 | def main(args):
25 | dt_config = Config()
26 | if not args.dataset or args.dataset not in dt_config.DATASETS.keys():
27 | raise Exception("specify one of the datasets to use in {}".format(list(dt_config.DATASETS.keys())))
28 |
29 | DatasetClass = dt_config.DATASETS[args.dataset]
30 |
31 | dataset = DatasetClass(data_path=dt_config.DATA_PATH)
32 | print("length of the dataset: {}".format(len(dataset)))
33 |
34 | assert args.idx < len(dataset)
35 | img = dataset.visualize_one_image(args.idx)
36 | cv2.imshow("visualized_bboxes", cv2.resize(img, (1000, 1000)))
37 | cv2.waitKey(0)
38 | cv2.imwrite("visualized_bboxes.png", img)
39 |
40 |
41 | if __name__ == "__main__":
42 | main(parsed_args)
43 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | import torch.optim.lr_scheduler as lr_scheduler
7 | import torch
8 | from tensorboardX import SummaryWriter
9 | from albumentations import *
10 |
11 | from models import YoloNet
12 | from config import Config
13 | from trainer import Trainer
14 | from data_loader import DataTransformBase
15 |
16 |
17 | def train_process(args, total_config, dataset_class, data_transform_class, params):
18 |
19 | # --------------------------------------------------------------------------#
20 | # prepare dataset
21 | # --------------------------------------------------------------------------#
22 |
23 | def _worker_init_fn_():
24 | import random
25 | import numpy as np
26 | import torch
27 |
28 | torch.manual_seed(args.random_seed)
29 | np.random.seed(args.random_seed)
30 | random.seed(args.random_seed)
31 | if torch.cuda.is_available():
32 | torch.cuda.manual_seed(args.random_seed)
33 |
34 | input_size = (params["img_h"], params["img_w"])
35 |
36 | transforms = [
37 | OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.5),
38 | OneOf([MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3), MotionBlur(blur_limit=3),], p=0.1,),
39 | RandomGamma(gamma_limit=(80, 120), p=0.5),
40 | RandomBrightnessContrast(p=0.5),
41 | HueSaturationValue(hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=10, p=0.5),
42 | ChannelShuffle(p=0.5),
43 | HorizontalFlip(p=0.5),
44 | Cutout(num_holes=5, max_w_size=40, max_h_size=40, p=0.5),
45 | Rotate(limit=20, p=0.5, border_mode=0),
46 | ]
47 |
48 | data_transform = data_transform_class(transforms=transforms, input_size=input_size)
49 | train_dataset = dataset_class(
50 | data_path=total_config.DATA_PATH,
51 | phase="train",
52 | normalize_bbox=True,
53 | transform=[data_transform],
54 | multiscale=args.multiscale,
55 | resize_after_batch_num=args.resize_after_batch_num,
56 | )
57 |
58 | val_dataset = dataset_class(
59 | data_path=total_config.DATA_PATH, phase="val", normalize_bbox=True, transform=data_transform,
60 | )
61 |
62 | train_data_loader = DataLoader(
63 | train_dataset,
64 | batch_size=args.batch_size,
65 | shuffle=True,
66 | collate_fn=train_dataset.od_collate_fn,
67 | num_workers=args.num_workers,
68 | drop_last=True,
69 | worker_init_fn=_worker_init_fn_(),
70 | )
71 | val_data_loader = DataLoader(
72 | val_dataset,
73 | batch_size=args.batch_size,
74 | shuffle=False,
75 | collate_fn=val_dataset.od_collate_fn,
76 | num_workers=args.num_workers,
77 | drop_last=True,
78 | )
79 | data_loaders_dict = {"train": train_data_loader, "val": val_data_loader}
80 |
81 | # --------------------------------------------------------------------------#
82 | # configuration for training
83 | # --------------------------------------------------------------------------#
84 |
85 | tblogger = SummaryWriter(total_config.LOG_PATH)
86 |
87 | model = YoloNet(dataset_config=params)
88 | if args.backbone_weight_path:
89 | model.feature.load_weight(args.backbone_weight_path)
90 |
91 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
92 | criterion = None
93 |
94 | base_lr_rate = args.lr_rate / (args.batch_size * args.batch_multiplier)
95 | base_weight_decay = args.weight_decay * (args.batch_size * args.batch_multiplier)
96 | steps = [float(v.strip()) for v in args.steps.split(",")]
97 | scales = [float(v.strip()) for v in args.scales.split(",")]
98 |
99 | def adjust_learning_rate(optimizer, processed_batch):
100 | lr = base_lr_rate
101 | for i in range(len(steps)):
102 | scale = scales[i] if i < len(scales) else 1
103 | if processed_batch >= steps[i]:
104 | lr = lr * scale
105 | if processed_batch == steps[i]:
106 | break
107 | else:
108 | break
109 | for param_group in optimizer.param_groups:
110 | param_group["lr"] = lr / args.batch_size
111 | return lr
112 |
113 | optimizer = torch.optim.SGD(
114 | model.parameters(), lr=base_lr_rate, momentum=args.momentum, weight_decay=base_weight_decay,
115 | )
116 |
117 | trainer = Trainer(
118 | model=model,
119 | criterion=criterion,
120 | metric_func=None,
121 | optimizer=optimizer,
122 | num_epochs=args.num_epoch,
123 | save_period=args.save_period,
124 | config=total_config,
125 | data_loaders_dict=data_loaders_dict,
126 | device=device,
127 | dataset_name_base=train_dataset.__name__,
128 | batch_multiplier=args.batch_multiplier,
129 | adjust_lr_callback=adjust_learning_rate,
130 | logger=tblogger,
131 | )
132 |
133 | if args.snapshot and os.path.isfile(args.snapshot):
134 | trainer.resume_checkpoint(args.snapshot)
135 |
136 | with torch.autograd.set_detect_anomaly(True):
137 | trainer.train()
138 |
139 | tblogger.close()
140 |
141 |
142 | def main(args):
143 | total_config = Config()
144 | total_config.display()
145 | if not args.dataset or args.dataset not in total_config.DATASETS.keys():
146 | raise Exception("specify one of the datasets to use in {}".format(list(total_config.DATASETS.keys())))
147 |
148 | dataset = args.dataset
149 | dataset_class = total_config.DATASETS[dataset]
150 | data_transform_class = DataTransformBase
151 | params = total_config.DATASET_PARAMS[dataset]
152 | train_process(args, total_config, dataset_class, data_transform_class, params)
153 |
154 |
155 | if __name__ == "__main__":
156 | import argparse
157 |
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument("--batch_size", type=int, default=2)
160 | parser.add_argument("--num_epoch", type=int, default=300)
161 | parser.add_argument("--lr_rate", type=float, default=1e-3)
162 | parser.add_argument("--batch_multiplier", type=int, default=1)
163 | parser.add_argument("--momentum", default=0.9, type=float)
164 | parser.add_argument("--weight_decay", default=5e-4, type=float)
165 | parser.add_argument("--burn_in", default=1000, type=int)
166 | parser.add_argument("--steps", default="40000,45000", type=str)
167 | parser.add_argument("--scales", default=".1,.1", type=str)
168 | parser.add_argument("--gamma", default=0.1, type=float)
169 | parser.add_argument("--milestones", default="120, 220", type=str)
170 | parser.add_argument("--save_period", type=int, default=1)
171 | parser.add_argument("--backbone_weight_path", type=str)
172 | parser.add_argument("--multiscale", type=bool, default=True)
173 | parser.add_argument("--resize_after_batch_num", type=int, default=10)
174 | parser.add_argument("--snapshot", type=str, help="path to snapshot weights")
175 | parser.add_argument(
176 | "--dataset",
177 | required=True,
178 | type=str,
179 | help="name of the dataset to use",
180 | choices=["bird_dataset", "switch_dataset", "wheat_dataset"],
181 | )
182 | parser.add_argument("--random_seed", type=int, default=12)
183 | parser.add_argument("--num_workers", type=int, default=4)
184 | parsed_args = parser.parse_args()
185 |
186 | main(parsed_args)
187 |
--------------------------------------------------------------------------------
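For reference (not part of the repo), the step/scale schedule implemented by adjust_learning_rate above can be restated as a standalone function; the extra division by batch_size applied when writing into the optimizer is omitted here:

def lr_at(processed_batch, base_lr_rate=5e-4, steps=(40000, 45000), scales=(0.1, 0.1)):
    # defaults mirror --lr_rate 1e-3 with batch_size 2, --steps "40000,45000", --scales ".1,.1"
    lr = base_lr_rate
    for i, step in enumerate(steps):
        scale = scales[i] if i < len(scales) else 1
        if processed_batch >= step:
            lr = lr * scale
        else:
            break
    return lr


print(lr_at(0), lr_at(40000), lr_at(45000))  # 0.0005, 5e-05, ~5e-06
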
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .trainer import Trainer
4 |
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import logging
8 | from .trainer_base import TrainerBase
9 |
10 |
11 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
12 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
13 | try:
14 | from utils import inf_loop
15 | except Exception as e:
16 | print(e)
17 | sys.exit(-1)
18 |
19 |
20 | class Trainer(TrainerBase):
21 | def __init__(
22 | self,
23 | model,
24 | criterion,
25 | metric_func,
26 | optimizer,
27 | num_epochs,
28 | save_period,
29 | config,
30 | data_loaders_dict,
31 | scheduler=None,
32 | device=None,
33 | len_epoch=None,
34 | dataset_name_base="",
35 | batch_multiplier=1,
36 | logger=None,
37 | processed_batch=0,
38 | adjust_lr_callback=None,
39 | print_after_batch_num=10,
40 | ):
41 | super(Trainer, self).__init__(
42 | model,
43 | criterion,
44 | metric_func,
45 | optimizer,
46 | num_epochs,
47 | save_period,
48 | config,
49 | device,
50 | dataset_name_base,
51 | batch_multiplier,
52 | logger,
53 | )
54 |
55 | self.train_data_loader = data_loaders_dict["train"]
56 | self.val_data_loader = data_loaders_dict["val"]
57 |
58 | self.num_train_imgs = len(self.train_data_loader.dataset)
59 | self.num_val_imgs = len(self.val_data_loader.dataset)
60 |
61 | self.processed_batch = processed_batch
62 | self.adjust_lr_callback = adjust_lr_callback
63 |
64 | if len_epoch is None:
65 | self._len_epoch = len(self.train_data_loader)
66 | else:
67 | self.train_data_loader = inf_loop(self.train_data_loader)
68 | self._len_epoch = len_epoch
69 |
70 | self._do_validation = self.val_data_loader is not None
71 | self._scheduler = scheduler
72 |
73 | self._print_after_batch_num = print_after_batch_num
74 |
75 | self._stat_keys = ["loss", "loss_x", "loss_y", "loss_w", "loss_h", "loss_conf", "loss_cls"]
76 |
77 | def _train_epoch(self, epoch):
78 | self._model.train()
79 |
80 | batch_size = self.train_data_loader.batch_size
81 |
82 | epoch_train_loss = 0.0
83 | count = self._batch_multiplier
84 | running_losses = dict(zip(self._stat_keys, [0.0] * len(self._stat_keys)))
85 |
86 | for batch_idx, (data, target, length_tensor) in enumerate(self.train_data_loader):
87 |
88 | if self.adjust_lr_callback is not None:
89 | self.adjust_lr_callback(self._optimizer, self.processed_batch)
90 | self.processed_batch += 1
91 |
92 | data = data.to(self._device)
93 | target = target.to(self._device)
94 | length_tensor = length_tensor.to(self._device)
95 |
96 | if count == 0:
97 | self._optimizer.step()
98 | self._optimizer.zero_grad()
99 | count = self._batch_multiplier
100 |
101 |             with torch.set_grad_enabled(True):
102 |                 # a single forward pass with the targets returns the training loss;
103 |                 # per-branch statistics are stored on the model as self._model.stats
104 |                 train_loss = self._model(data, target)
105 | for key in self._stat_keys:
106 | running_losses[key] += self._model.stats[key]
107 |
108 | total_loss = train_loss / self._batch_multiplier
109 | total_loss.backward()
110 | count -= 1
111 |
112 | if (batch_idx + 1) % self._print_after_batch_num == 0:
113 | logging.info(
114 | "\n epoch: {}/{} || iter: {}/{} || [Losses: total: {}, loss_x: {}, loss_y: {}, loss_w: {}, loss_h: {}, loss_conf: {}, loss_cls: {} || lr_rate: {}".format(
115 | epoch,
116 | self._num_epochs,
117 | batch_idx,
118 | len(self.train_data_loader),
119 | running_losses["loss"] / self._print_after_batch_num,
120 | running_losses["loss_x"] / self._print_after_batch_num,
121 | running_losses["loss_y"] / self._print_after_batch_num,
122 | running_losses["loss_w"] / self._print_after_batch_num,
123 | running_losses["loss_h"] / self._print_after_batch_num,
124 | running_losses["loss_conf"] / self._print_after_batch_num,
125 | running_losses["loss_cls"] / self._print_after_batch_num,
126 | self._optimizer.param_groups[0]["lr"],
127 | )
128 | )
129 | running_losses = dict(zip(self._stat_keys, [0.0] * len(self._stat_keys)))
130 |
131 | epoch_train_loss += total_loss.item() * self._batch_multiplier
132 |
133 | if batch_idx == self._len_epoch:
134 | break
135 |
136 | if self._do_validation:
137 | epoch_val_loss = self._valid_epoch(epoch)
138 |
139 | if self._scheduler is not None:
140 | self._scheduler.step()
141 |
142 | return (
143 | epoch_train_loss / self.num_train_imgs,
144 | epoch_val_loss / self.num_val_imgs,
145 | )
146 |
147 | def _valid_epoch(self, epoch):
148 | print("start validation...")
149 | self._model.eval()
150 |
151 | epoch_val_loss = 0.0
152 | with torch.no_grad():
153 | for batch_idx, (data, target, length_tensor) in enumerate(self.val_data_loader):
154 |
155 | data = data.to(self._device)
156 | target = target.to(self._device)
157 | length_tensor = length_tensor.to(self._device)
158 |
159 | val_loss = self._model(data, target)
160 | epoch_val_loss += val_loss.item()
161 |
162 | return epoch_val_loss
163 |
--------------------------------------------------------------------------------
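The count/_batch_multiplier bookkeeping in Trainer._train_epoch is plain gradient accumulation; a self-contained sketch with a toy model and random data (placeholders, not repo objects):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch_multiplier = 4
count = batch_multiplier

for _ in range(8):
    data, target = torch.randn(2, 4), torch.randn(2, 1)
    if count == 0:
        optimizer.step()       # apply gradients accumulated over batch_multiplier mini-batches
        optimizer.zero_grad()
        count = batch_multiplier
    loss = nn.functional.mse_loss(model(data), target) / batch_multiplier
    loss.backward()            # gradients keep accumulating until the next optimizer.step()
    count -= 1
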
/trainer/trainer_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import logging
4 | import os
5 | import sys
6 | import torch
7 | from abc import abstractmethod
8 |
9 | _CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
10 | sys.path.append(os.path.join(_CURRENT_DIR, ".."))
11 |
12 |
13 | class TrainerBase:
14 | def __init__(
15 | self,
16 | model,
17 | criterion,
18 | metric_func,
19 | optimizer,
20 | num_epochs,
21 | save_period,
22 | config,
23 | device=None,
24 | dataset_name_base="",
25 | batch_multiplier=1,
26 | logger=None,
27 | ):
28 | self._model = model
29 | self.criterion = criterion
30 | self._metric_func = metric_func
31 | self._optimizer = optimizer
32 | self._config = config
33 | self._dataset_name_base = dataset_name_base
34 | self._batch_multiplier = batch_multiplier
35 | self._logger = logger
36 |
37 | if device is None:
38 | self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39 | else:
40 | self._device = device
41 |
42 | self._model = self._model.to(self._device)
43 |
44 |         # the checkpoint directory comes from the experiment config
45 |
46 | self._checkpoint_dir = self._config.SAVED_MODEL_PATH
47 |
48 | self._start_epoch = 1
49 |
50 | self._num_epochs = num_epochs
51 |
52 | self._save_period = save_period
53 |
54 | @property
55 | def model(self):
56 | return self._model
57 |
58 | @abstractmethod
59 | def _train_epoch(self, epoch):
60 | raise NotImplementedError
61 |
62 | def _save_checkpoint(self, epoch, save_best=False):
63 | arch = type(self._model).__name__
64 | state = {
65 | "arch": arch,
66 | "epoch": epoch,
67 | "state_dict": self._model.state_dict(),
68 | "optimizer": self._optimizer.state_dict(),
69 | }
70 | output_file = "checkpoint_{}_epoch_{}.pth".format(arch, epoch)
71 |         if isinstance(self._dataset_name_base, str) and self._dataset_name_base:
72 | output_file = "{}_{}".format(self._dataset_name_base, output_file)
73 |
74 | filename = os.path.join(self._checkpoint_dir, output_file)
75 | torch.save(state, filename)
76 |
77 | # if save_best:
78 | # best_path = os.path.join(self._checkpoint_dir, "model_best.pth")
79 | # torch.save(state, best_path)
80 |
81 | def resume_checkpoint(self, resume_path):
82 | resume_path = str(resume_path)
83 |
84 |         checkpoint = torch.load(resume_path, map_location=self._device)
85 | self._start_epoch = checkpoint["epoch"] + 1
86 |
87 | self._model.load_state_dict(checkpoint["state_dict"])
88 |
89 | self._optimizer.load_state_dict(checkpoint["optimizer"])
90 |
91 | def train(self):
92 | logging.info("========================================")
93 | logging.info("Start training {}".format(type(self._model).__name__))
94 | logging.info("========================================")
95 | logs = []
96 |
97 | for epoch in range(self._start_epoch, self._num_epochs + 1):
98 | train_loss, val_loss = self._train_epoch(epoch)
99 |
100 | log_epoch = {
101 | "epoch": epoch,
102 | "train_loss": train_loss,
103 | "val_loss": val_loss,
104 | }
105 |
106 | logging.info(
107 | "\n----------------------------------------------------\n"
108 | + "epoch: {}, train_loss: {: .4f}, val_loss: {: .4f}".format(epoch, train_loss, val_loss)
109 | + "\n----------------------------------------------------\n"
110 | )
111 | logs.append(log_epoch)
112 | if self._logger:
113 | self._logger.add_scalar("train/train_loss", train_loss, epoch)
114 | self._logger.add_scalar("val/val_loss", val_loss, epoch)
115 |
116 |             if epoch % self._save_period == 0:
117 | self._save_checkpoint(epoch, save_best=True)
118 |
119 | return logs
120 |
--------------------------------------------------------------------------------
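The checkpoint written by `_save_checkpoint` and consumed by `resume_checkpoint` above is a plain dict of architecture name, epoch, model weights, and optimizer state. A small self-contained sketch of the same round trip (the `nn.Linear` model and the file name are illustrative stand-ins, not the project's YOLOv3 or its saved paths):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save in the same layout that TrainerBase._save_checkpoint uses.
state = {
    "arch": type(model).__name__,
    "epoch": 3,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(state, "checkpoint_demo.pth")

# Resume as TrainerBase.resume_checkpoint does: restore weights and optimizer
# state, then continue from the epoch after the saved one.
checkpoint = torch.load("checkpoint_demo.pth", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1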
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .kmeans_bboxes import jaccard, avg_iou, kmeans_bboxes, estimate_avg_ious
4 | from .priors_bboxes import estimate_priors_sizes
5 | from .utils import *
6 | from .download_utility import *
7 |
--------------------------------------------------------------------------------
/utils/download_utility.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import requests
4 | import tqdm
5 |
6 | __all__ = ["download_file_from_google_drive"]
7 |
8 |
9 | def download_file_from_google_drive(id, destination):
10 | URL = "https://docs.google.com/uc?export=download"
11 |
12 | session = requests.Session()
13 |
14 | response = session.get(URL, params={"id": id}, stream=True)
15 | token = get_confirm_token(response)
16 |
17 | if token:
18 | params = {"id": id, "confirm": token}
19 | response = session.get(URL, params=params, stream=True)
20 |
21 | save_response_content(response, destination)
22 |
23 |
24 | def get_confirm_token(response):
25 | for key, value in response.cookies.items():
26 | if key.startswith("download_warning"):
27 | return value
28 |
29 | return None
30 |
31 |
32 | def save_response_content(response, destination):
33 | CHUNK_SIZE = 32768
34 | with open(destination, "wb") as f:
35 | for chunk in tqdm.tqdm(response.iter_content(CHUNK_SIZE)):
36 | if chunk: # filter out keep-alive new chunks
37 | f.write(chunk)
38 |
--------------------------------------------------------------------------------
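Usage of the helper above is a single call; the file id and destination below are placeholders for illustration, not a real dataset id:

from utils.download_utility import download_file_from_google_drive

# Hypothetical Google Drive file id and local path.
download_file_from_google_drive("YOUR_FILE_ID", "data/bird_dataset.zip")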
/utils/kmeans_bboxes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # ref: https://lars76.github.io/object-detection/k-means-anchor-boxes/
4 | import numpy as np
5 | import tqdm
6 |
7 | def jaccard(bboxes, clusters):
8 | x = np.minimum(clusters[:, 0], bboxes[0])
9 | y = np.minimum(clusters[:, 1], bboxes[1])
10 | if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:
11 | raise ValueError("Box has no area")
12 |
13 | intersection = x * y
14 | box_area = bboxes[0] * bboxes[1]
15 | cluster_area = clusters[:, 0] * clusters[:, 1]
16 |
17 | iou = intersection / (box_area + cluster_area - intersection)
18 |
19 | return iou
20 |
21 |
22 | def avg_iou(bboxes, clusters):
23 | return np.mean([np.max(jaccard(bboxes[i], clusters)) for i in range(bboxes.shape[0])])
24 |
25 |
26 | def kmeans_bboxes(bboxes, k=5, metric_dist=np.median, seed=100):
27 | assert k >= 2
28 | rows = bboxes.shape[0]
29 |
30 | distances = np.empty((rows, k))
31 | last_clusters = np.zeros((rows,))
32 |
33 | np.random.seed(seed)
34 |
35 | clusters = bboxes[np.random.choice(rows, k, replace=False)]
36 |
37 | while True:
38 | for row in range(rows):
39 | distances[row] = 1 - jaccard(bboxes[row], clusters)
40 |
41 | nearest_clusters = np.argmin(distances, axis=1)
42 |
43 | if (last_clusters == nearest_clusters).all():
44 | break
45 |
46 | for cluster in range(k):
47 | clusters[cluster] = metric_dist(bboxes[nearest_clusters == cluster], axis=0)
48 |
49 | last_clusters = nearest_clusters
50 |
51 | return clusters
52 |
53 |
54 | def estimate_avg_ious(bboxes, k=5, metric_dist=np.median, seed=100):
55 | assert k >= 2
56 |
57 | num_vals = k - 1
58 | avg_ious = np.empty((num_vals,), dtype=float)
59 |     for idx, num_clusters in enumerate(tqdm.tqdm(range(2, k + 1))):
60 |         clusters = kmeans_bboxes(bboxes, num_clusters, metric_dist, seed)
61 | avg_ious[idx] = avg_iou(bboxes, clusters)
62 |
63 | return avg_ious
64 |
--------------------------------------------------------------------------------
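The routines above cluster normalized (width, height) pairs with IoU as the similarity measure, the usual way to estimate YOLO anchor priors. A self-contained example on synthetic box sizes (the data here is random, purely for illustration):

import numpy as np
from utils.kmeans_bboxes import kmeans_bboxes, avg_iou

# Synthetic normalized (width, height) pairs standing in for dataset boxes.
rng = np.random.RandomState(0)
bboxes = rng.uniform(0.02, 0.4, size=(500, 2))

clusters = kmeans_bboxes(bboxes, k=5)
print("anchor priors (w, h):", clusters)
print("mean IoU against priors:", avg_iou(bboxes, clusters))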
/utils/priors_bboxes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from .kmeans_bboxes import kmeans_bboxes
4 |
5 |
6 | def estimate_priors_sizes(dataset, k=5):
7 | bboxes = dataset.get_all_normalized_boxes()
8 | return kmeans_bboxes(bboxes, k)
9 |
--------------------------------------------------------------------------------
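`estimate_priors_sizes` only requires the dataset to expose `get_all_normalized_boxes()` and returns clusters in normalized coordinates, so pixel anchors follow by scaling with the network input size. A sketch with a toy dataset stub and an assumed 608x608 input (both are illustrative, not necessarily the project's defaults):

import numpy as np
from utils.priors_bboxes import estimate_priors_sizes

class ToyDataset:
    """Stand-in exposing the single method estimate_priors_sizes relies on."""
    def get_all_normalized_boxes(self):
        rng = np.random.RandomState(1)
        return rng.uniform(0.05, 0.5, size=(300, 2))

priors = estimate_priors_sizes(ToyDataset(), k=5)
print("pixel anchors at an assumed 608x608 input:", priors * 608)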
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import numpy as np
5 |
6 |
7 | __all__ = ["inf_loop"]
8 |
9 |
10 | def inf_loop(data_loader):
11 | from itertools import repeat
12 |
13 | for loader in repeat(data_loader):
14 | yield from loader
15 |
--------------------------------------------------------------------------------
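`inf_loop` turns a finite DataLoader into an endless stream; it pairs naturally with the `_len_epoch` early break in `Trainer._train_epoch` to support iteration-based rather than epoch-based training. A minimal sketch with a toy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from utils.utils import inf_loop

loader = DataLoader(TensorDataset(torch.arange(10.0)), batch_size=4)

# Draw more batches than the loader holds; inf_loop restarts it transparently.
for step, (batch,) in zip(range(7), inf_loop(loader)):
    print(step, batch.tolist())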