├── LICENSE
├── README.md
├── Train_Detectron2_Object_Detector_Custom_Data.ipynb
├── class.names
├── data
│   └── .gitignore
├── loss.py
├── plot_loss.py
├── predict.py
├── requirements.txt
├── train.py
├── util.py
└── visualize_data.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 computervisiondeveloper
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # train-object-detector-detectron2
2 |
3 |
4 |
5 |
6 | Watch on YouTube: Train an object detector on custom data with Detectron2!
7 |
8 |
9 |
--------------------------------------------------------------------------------
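For readers who want to run things locally rather than in the Colab notebook, training can also be driven straight from Python through the repo's own `util.train`; a minimal sketch using the same default values as `train.py`:

```python
# Sketch: programmatic equivalent of `python train.py` with its default arguments.
import util

util.train('./output',            # checkpoints and metrics.json are written here
           './data',              # expects train/ and val/ splits (layout documented in train.py)
           './class.names',       # one class name per line
           device='cpu',          # anything other than 'cpu' leaves detectron2 on its default 'cuda'
           learning_rate=0.00025,
           batch_size=4,
           iterations=10000,
           checkpoint_period=500,
           model='COCO-Detection/retinanet_R_101_FPN_3x.yaml')
```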
/Train_Detectron2_Object_Detector_Custom_Data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | },
15 | "accelerator": "GPU",
16 | "gpuClass": "standard"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "### Connect Google Drive"
23 | ],
24 | "metadata": {
25 | "id": "-GIYjVIA96Ne"
26 | }
27 | },
28 | {
29 | "cell_type": "code",
30 | "source": [
31 | "from google.colab import drive\n",
32 | "\n",
33 | "drive.mount('/content/gdrive')"
34 | ],
35 | "metadata": {
36 | "id": "-wxgALmBx6Rj"
37 | },
38 | "execution_count": null,
39 | "outputs": []
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "source": [
44 | "### Install requirements"
45 | ],
46 | "metadata": {
47 | "id": "D9ZsFr6--Lxg"
48 | }
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "xWh477fIw_al"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "!pip install torch==2.0.0\n",
59 | "!pip install torchvision==0.15.1\n",
60 | "!pip install opencv-python==4.6.0.66\n",
61 | "!pip install matplotlib==3.5.3\n",
62 | "!pip install git+https://github.com/facebookresearch/detectron2.git\n"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "source": [
68 | "### Change working directory"
69 | ],
70 | "metadata": {
71 | "id": "q1PMCpJK-QNH"
72 | }
73 | },
74 | {
75 | "cell_type": "code",
76 | "source": [
77 | "%cd /content/gdrive/MyDrive/ComputerVisionEngineer/TrainDetectron2ObjectDetector/"
78 | ],
79 | "metadata": {
80 | "id": "La68ZGbq0Qni"
81 | },
82 | "execution_count": null,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "source": [
88 | "### Train!"
89 | ],
90 | "metadata": {
91 | "id": "3chYfIWs-Sqa"
92 | }
93 | },
94 | {
95 | "cell_type": "code",
96 | "source": [
97 | "!python train.py --device gpu --learning-rate 0.00001 --iterations 6000"
98 | ],
99 | "metadata": {
100 | "id": "5xDV20jH0mzk"
101 | },
102 | "execution_count": null,
103 | "outputs": []
104 | }
105 | ]
106 | }
--------------------------------------------------------------------------------
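Since `util.get_cfg` only forces CPU when `--device cpu` is passed and otherwise relies on detectron2's default CUDA device, it is worth confirming that Colab actually attached a GPU runtime before launching the training cell; a quick check:

```python
# Quick sanity check (Runtime -> Change runtime type -> GPU in Colab).
import torch

print(torch.cuda.is_available())  # True when a GPU runtime is attached
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```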
/class.names:
--------------------------------------------------------------------------------
1 | alpaca
2 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | .gitignore
2 |
3 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | from detectron2.engine import HookBase
2 | from detectron2.data import build_detection_train_loader
3 | import detectron2.utils.comm as comm
4 | import torch
5 |
6 |
7 | class ValidationLoss(HookBase):
8 | """
9 | A hook that computes validation loss during training.
10 |
11 | Attributes:
12 | cfg (CfgNode): The detectron2 config node.
13 | _loader (iterator): An iterator over the validation dataset.
14 | """
15 |
16 | def __init__(self, cfg):
17 | """
18 | Args:
19 | cfg (CfgNode): The detectron2 config node.
20 | """
21 | super().__init__()
22 | self.cfg = cfg.clone()
23 | # Switch to the validation dataset
24 | self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
25 | # Build the validation data loader iterator
26 | self._loader = iter(build_detection_train_loader(self.cfg))
27 |
28 | def after_step(self):
29 | """
30 | Computes the validation loss after each training step.
31 | """
32 | # Get the next batch of data from the validation data loader
33 | data = next(self._loader)
34 | with torch.no_grad():
35 | # Compute the validation loss on the current batch of data
36 | loss_dict = self.trainer.model(data)
37 |
38 | # Check for invalid losses
39 | losses = sum(loss_dict.values())
40 | assert torch.isfinite(losses).all(), loss_dict
41 |
42 | # Reduce the loss across all workers
43 | loss_dict_reduced = {"val_" + k: v.item() for k, v in
44 | comm.reduce_dict(loss_dict).items()}
45 | losses_reduced = sum(loss_dict_reduced.values())
46 |
47 | # Save the validation loss in the trainer storage
48 | if comm.is_main_process():
49 | self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
50 | **loss_dict_reduced)
51 |
--------------------------------------------------------------------------------
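For context, this is how the hook plugs into a trainer; a sketch mirroring what `util.train` below does, assuming `cfg` was built by `util.get_cfg` and therefore carries the custom `DATASETS.VAL` key:

```python
# Sketch: wiring ValidationLoss into a DefaultTrainer (util.train has the full version).
from detectron2.engine import DefaultTrainer

from loss import ValidationLoss

trainer = DefaultTrainer(cfg)                  # cfg must define DATASETS.VAL
trainer.register_hooks([ValidationLoss(cfg)])
trainer.resume_or_load(resume=False)
trainer.train()
```

After each training step the hook writes `total_val_loss` and the per-term `val_*` losses into the trainer's event storage, so they land in `metrics.json` next to the training losses and can be plotted by `plot_loss.py`.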
/plot_loss.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 |
8 | def moving_average(a, n=3):
9 | ret = np.cumsum(a, dtype=float)
10 | ret[n:] = ret[n:] - ret[:-n]
11 | return ret[n - 1:] / n
12 |
13 |
14 | metrics_file = './metrics.json'
15 |
16 | with open(metrics_file, 'r') as f:
17 |     # each line of metrics.json is a JSON dict written by detectron2's JSONWriter
18 |     metrics = [json.loads(l) for l in f.readlines() if l.strip()]
19 |
20 | train_loss = [float(v['loss_box_reg']) for v in metrics if 'loss_box_reg' in v.keys()]
21 | val_loss = [float(v['val_loss_box_reg']) for v in metrics if 'val_loss_box_reg' in v.keys()]
22 |
23 | N = 40
24 |
25 | train_loss_avg = moving_average(train_loss, n=N)
26 | val_loss_avg = moving_average(val_loss, n=N)
27 |
28 | # metrics are logged every 20 iterations (detectron2's default), so map sample index -> iteration
29 | plt.plot(range(20 * N - 1, 20 * len(train_loss), 20), train_loss_avg, label='train loss')
30 | plt.plot(range(20 * N - 1, 20 * len(val_loss), 20), val_loss_avg, label='val loss')
31 | plt.legend()
32 | plt.grid()
33 | plt.show()
34 |
--------------------------------------------------------------------------------
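The `moving_average` helper is a plain cumulative-sum sliding window: each output value averages `n` consecutive inputs, so the result is `n - 1` elements shorter than the input. A small worked example:

```python
import numpy as np

def moving_average(a, n=3):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]   # subtract the sum that falls out of the window
    return ret[n - 1:] / n

print(moving_average([1, 2, 3, 4, 5], n=3))  # [2. 3. 4.]
```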
/predict.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import get_cfg
2 | from detectron2.engine import DefaultPredictor
3 | from detectron2 import model_zoo
4 | import cv2
5 |
6 |
7 | # Load the model config and the trained weights
8 | cfg = get_cfg()
9 | cfg.merge_from_file(model_zoo.get_config_file('COCO-Detection/retinanet_R_101_FPN_3x.yaml'))
10 | cfg.MODEL.RETINANET.NUM_CLASSES = 1  # must match the number of classes used at training time
11 | cfg.MODEL.WEIGHTS = './model_0002999.pth'
12 | cfg.MODEL.DEVICE = 'cpu'
13 | # Create predictor instance
14 | predictor = DefaultPredictor(cfg)
15 |
16 | # Load image
17 | image = cv2.imread("./data/val/imgs/3e115eab82413cd4.jpg")
18 |
19 | # Perform prediction
20 | outputs = predictor(image)
21 |
22 | threshold = 0.5  # keep only detections scoring above this value
23 |
24 | # Display predictions
25 | preds = outputs["instances"].pred_classes.tolist()
26 | scores = outputs["instances"].scores.tolist()
27 | bboxes = outputs["instances"].pred_boxes
28 |
29 | for j, bbox in enumerate(bboxes):
30 | bbox = bbox.tolist()
31 |
32 | score = scores[j]
33 | pred = preds[j]
34 |
35 | if score > threshold:
36 | x1, y1, x2, y2 = [int(i) for i in bbox]
37 |
38 | cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 5)
39 |
40 | cv2.imshow('image', image)
41 | cv2.waitKey(0)
42 |
--------------------------------------------------------------------------------
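As an alternative to drawing rectangles by hand, detectron2 ships a `Visualizer` that renders the same `outputs["instances"]` with class labels and scores. A sketch that continues from the `image` and `outputs` variables in `predict.py`; the metadata name `'val'` and the `['alpaca']` class list are taken from this repo's setup:

```python
# Sketch: rendering the predictions with detectron2's Visualizer.
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

metadata = MetadataCatalog.get('val').set(thing_classes=['alpaca'])
v = Visualizer(image[:, :, ::-1], metadata=metadata, scale=1.0)  # Visualizer expects RGB
out = v.draw_instance_predictions(outputs['instances'].to('cpu'))
cv2.imshow('image', out.get_image()[:, :, ::-1])                 # back to BGR for cv2
cv2.waitKey(0)
```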
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchvision==0.15.1
3 | opencv-python==4.6.0.66
4 | matplotlib==3.5.3
5 | git+https://github.com/facebookresearch/detectron2.git
6 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import util
4 |
5 |
6 | if __name__ == "__main__":
7 | """
8 | Annotations must be provided in YOLO format, i.e. one line per object:
9 | class_id x_center y_center width height   (all values normalized to [0, 1])
10 | The data directory needs to follow this structure:
11 |
12 | data-dir
13 | ----- train
14 | --------- imgs
15 | ------------ filename0001.jpg
16 | ------------ filename0002.jpg
17 | ------------ ....
18 | --------- anns
19 | ------------ filename0001.txt
20 | ------------ filename0002.txt
21 | ------------ ....
22 | ----- val
23 | --------- imgs
24 | ------------ filename0001.jpg
25 | ------------ filename0002.jpg
26 | ------------ ....
27 | --------- anns
28 | ------------ filename0001.txt
29 | ------------ filename0002.txt
30 | ------------ ....
31 |
32 | """
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--class-list', default='./class.names')
36 | parser.add_argument('--data-dir', default='./data')
37 | parser.add_argument('--output-dir', default='./output')
38 | parser.add_argument('--device', default='cpu')
39 | parser.add_argument('--learning-rate', type=float, default=0.00025)
40 | parser.add_argument('--batch-size', type=int, default=4)
41 | parser.add_argument('--iterations', type=int, default=10000)
42 | parser.add_argument('--checkpoint-period', type=int, default=500)
43 | parser.add_argument('--model', default='COCO-Detection/retinanet_R_101_FPN_3x.yaml')
44 |
45 | args = parser.parse_args()
46 |
47 | util.train(args.output_dir,
48 | args.data_dir,
49 | args.class_list,
50 | device=args.device,
51 | learning_rate=args.learning_rate,
52 | batch_size=args.batch_size,
53 | iterations=args.iterations,
54 | checkpoint_period=args.checkpoint_period,
55 | model=args.model)
56 |
--------------------------------------------------------------------------------
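To make the annotation format concrete, here is a hypothetical one-object label file and the absolute-coordinate box that `util.get_dicts` derives from it (the values and the 640x480 image size are illustrative only):

```python
# Hypothetical contents of data/train/anns/filename0001.txt (class 0 = alpaca):
#   0 0.50 0.50 0.31 0.42
line = '0 0.50 0.50 0.31 0.42'
label, cx, cy, w, h = line.split(' ')

# util.get_dicts converts this to absolute XYWH (BoxMode.XYWH_ABS) via the image size:
width, height = 640, 480
x0 = int((float(cx) - float(w) / 2) * width)    # 220
y0 = int((float(cy) - float(h) / 2) * height)   # 139
bbox = [x0, y0, int(float(w) * width), int(float(h) * height)]
print(bbox)                                     # [220, 139, 198, 201]
```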
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from detectron2.engine import DefaultTrainer
4 | from detectron2.data import MetadataCatalog, DatasetCatalog
5 | from detectron2.structures import BoxMode
6 | from detectron2.config import get_cfg as _get_cfg
7 | from detectron2 import model_zoo
8 |
9 | from loss import ValidationLoss
10 |
11 | import cv2
12 |
13 |
14 | def get_cfg(output_dir, learning_rate, batch_size, iterations, checkpoint_period, model, device, nmr_classes):
15 | """
16 | Create a Detectron2 configuration object and set its attributes.
17 |
18 | Args:
19 | output_dir (str): The path to the output directory where the trained model and logs will be saved.
20 | learning_rate (float): The learning rate for the optimizer.
21 | batch_size (int): The batch size used during training.
22 | iterations (int): The maximum number of training iterations.
23 | checkpoint_period (int): The number of iterations between consecutive checkpoints.
24 | model (str): The name of the model to use, which should be one of the models available in Detectron2's model zoo.
25 | device (str): The device to use for training, which should be 'cpu' or 'cuda'.
26 | nmr_classes (int): The number of classes in the dataset.
27 |
28 | Returns:
29 | The Detectron2 configuration object.
30 | """
31 | cfg = _get_cfg()
32 |
33 | # Merge the model's default configuration file with the default Detectron2 configuration file.
34 | cfg.merge_from_file(model_zoo.get_config_file(model))
35 |
36 | # Set train/val datasets and exclude test. DATASETS.VAL is a custom key read by the ValidationLoss hook.
37 | cfg.DATASETS.TRAIN = ("train",)
38 | cfg.DATASETS.VAL = ("val",)
39 | cfg.DATASETS.TEST = ()
40 |
41 | # Set the device to use for training (detectron2 defaults to 'cuda').
42 | if device == 'cpu':
43 | cfg.MODEL.DEVICE = 'cpu'
44 |
45 | # Set the number of data loader workers.
46 | cfg.DATALOADER.NUM_WORKERS = 2
47 |
48 | # Set the model weights to the ones pre-trained on the COCO dataset.
49 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model)
50 |
51 | # Set the batch size used by the solver.
52 | cfg.SOLVER.IMS_PER_BATCH = batch_size
53 |
54 | # Set the checkpoint period.
55 | cfg.SOLVER.CHECKPOINT_PERIOD = checkpoint_period
56 |
57 | # Set the base learning rate.
58 | cfg.SOLVER.BASE_LR = learning_rate
59 |
60 | # Set the maximum number of training iterations.
61 | cfg.SOLVER.MAX_ITER = iterations
62 |
63 | # Set the learning rate scheduler steps to an empty list, which means the learning rate will not be decayed.
64 | cfg.SOLVER.STEPS = []
65 |
66 | # Set the batch size used by the ROI heads during training (R-CNN-style models only).
67 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
68 |
69 | # Set the number of classes. RetinaNet (the default model here) reads RETINANET.NUM_CLASSES.
70 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = cfg.MODEL.RETINANET.NUM_CLASSES = nmr_classes
71 |
72 | # Set the output directory.
73 | cfg.OUTPUT_DIR = output_dir
74 |
75 | return cfg
76 |
77 |
78 | def get_dicts(img_dir, ann_dir):
79 | """
80 | Read the annotations for the dataset in YOLO format and create a list of dictionaries containing information for each
81 | image.
82 |
83 | Args:
84 | img_dir (str): Directory containing the images.
85 | ann_dir (str): Directory containing the annotations.
86 |
87 | Returns:
88 | list[dict]: A list of dictionaries containing information for each image. Each dictionary has the following keys:
89 | - file_name: The path to the image file.
90 | - image_id: The unique identifier for the image.
91 | - height: The height of the image in pixels.
92 | - width: The width of the image in pixels.
93 | - annotations: A list of dictionaries, one for each object in the image, containing the following keys:
94 | - bbox: A list of four integers [x0, y0, w, h] representing the bounding box of the object in the image,
95 | where (x0, y0) is the top-left corner and (w, h) are the width and height of the bounding box,
96 | respectively.
97 | - bbox_mode: A constant from the `BoxMode` class indicating the format of the bounding box coordinates
98 | (e.g., `BoxMode.XYWH_ABS` for absolute coordinates in the format [x0, y0, w, h]).
99 | - category_id: The integer ID of the object's class.
100 | """
101 | dataset_dicts = []
102 | for idx, file in enumerate(os.listdir(ann_dir)):
103 | # annotations should be provided in yolo format
104 |
105 | record = {}
106 |
107 | filename = os.path.join(img_dir, file[:-4] + '.jpg')  # assumes .jpg images named like the .txt files
108 | height, width = cv2.imread(filename).shape[:2]
109 |
110 | record["file_name"] = filename
111 | record["image_id"] = idx
112 | record["height"] = height
113 | record["width"] = width
114 |
115 | objs = []
116 | with open(os.path.join(ann_dir, file)) as r:
117 | lines = [l.strip() for l in r.readlines()]  # strip() is robust to a missing trailing newline
118 |
119 | for _, line in enumerate(lines):
120 | if len(line) > 2:
121 | label, cx, cy, w_, h_ = line.split(' ')
122 |
123 | obj = {
124 | "bbox": [int((float(cx) - (float(w_) / 2)) * width),
125 | int((float(cy) - (float(h_) / 2)) * height),
126 | int(float(w_) * width),
127 | int(float(h_) * height)],
128 | "bbox_mode": BoxMode.XYWH_ABS,
129 | "category_id": int(label),
130 | }
131 |
132 | objs.append(obj)
133 | record["annotations"] = objs
134 | dataset_dicts.append(record)
135 | return dataset_dicts
136 |
137 |
138 | def register_datasets(root_dir, class_list_file):
139 | """
140 | Registers the train and validation datasets and returns the number of classes.
141 |
142 | Args:
143 | root_dir (str): Path to the root directory of the dataset.
144 | class_list_file (str): Path to the file containing the list of class names.
145 |
146 | Returns:
147 | int: The number of classes in the dataset.
148 | """
149 | # Read the list of class names from the class list file.
150 | with open(class_list_file, 'r') as reader:
151 | classes_ = [l.strip() for l in reader.readlines() if l.strip()]
152 |
153 | # Register the train and validation datasets.
154 | for d in ['train', 'val']:
155 | DatasetCatalog.register(d, lambda d=d: get_dicts(os.path.join(root_dir, d, 'imgs'),
156 | os.path.join(root_dir, d, 'anns')))
157 | # Set the metadata for the dataset.
158 | MetadataCatalog.get(d).set(thing_classes=classes_)
159 |
160 | return len(classes_)
161 |
162 |
163 | def train(output_dir, data_dir, class_list_file, learning_rate, batch_size, iterations, checkpoint_period, device,
164 | model):
165 | """
166 | Train a Detectron2 model on a custom dataset.
167 |
168 | Args:
169 | output_dir (str): Path to the directory to save the trained model and output files.
170 | data_dir (str): Path to the directory containing the dataset.
171 | class_list_file (str): Path to the file containing the list of class names in the dataset.
172 | learning_rate (float): Learning rate for the optimizer.
173 | batch_size (int): Batch size for training.
174 | iterations (int): Maximum number of training iterations.
175 | checkpoint_period (int): Number of iterations after which to save a checkpoint of the model.
176 | device (str): Device to use for training (e.g., 'cpu' or 'cuda').
177 | model (str): Name of the model configuration to use. Must be a key in the Detectron2 model zoo.
178 |
179 | Returns:
180 | None
181 | """
182 |
183 | # Register the dataset and get the number of classes
184 | nmr_classes = register_datasets(data_dir, class_list_file)
185 |
186 | # Get the configuration for the model
187 | cfg = get_cfg(output_dir, learning_rate, batch_size, iterations, checkpoint_period, model, device, nmr_classes)
188 |
189 | # Create the output directory
190 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
191 |
192 | # Create the trainer object
193 | trainer = DefaultTrainer(cfg)
194 |
195 | # Create a custom validation loss object
196 | val_loss = ValidationLoss(cfg)
197 |
198 | # Register the custom validation loss object as a hook to the trainer
199 | trainer.register_hooks([val_loss])
200 |
201 | # Swap the positions of the evaluation and checkpointing hooks so that the validation loss is logged correctly
202 | trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
203 |
204 | # Resume training from a checkpoint or load the initial model weights
205 | trainer.resume_or_load(resume=False)
206 |
207 | # Train the model
208 | trainer.train()
209 |
--------------------------------------------------------------------------------
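Before a long training run, the registration and YOLO parsing above can be sanity-checked in isolation; a small sketch assuming the default `./data` and `./class.names` layout:

```python
# Sketch: verify dataset registration and annotation parsing without training.
from detectron2.data import DatasetCatalog, MetadataCatalog

import util

n_classes = util.register_datasets('./data', './class.names')
print(n_classes)                                   # 1: only 'alpaca' in class.names
print(MetadataCatalog.get('train').thing_classes)  # ['alpaca']

records = DatasetCatalog.get('train')              # invokes get_dicts lazily
print(len(records))                                # number of training images
print(records[0]['file_name'], records[0]['annotations'][0]['bbox'])
```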
/visualize_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import matplotlib.pyplot as plt
5 | import cv2
6 |
7 |
8 | IMGS_DIR = os.path.join('.', 'data', 'train', 'imgs')
9 | ANNS_DIR = os.path.join('.', 'data', 'train', 'anns')
10 |
11 | if __name__ == "__main__":
12 | files = os.listdir(IMGS_DIR)
13 | while True:
14 | fig = plt.figure()
15 | k = random.randint(0, len(files) - 1)
16 | img = cv2.imread(os.path.join(IMGS_DIR, files[k]))
17 | ann_file = os.path.join(ANNS_DIR, files[k][:-4] + '.txt')
18 |
19 | h_img, w_img, _ = img.shape
20 | with open(ann_file, 'r') as f:
21 | lines = [l.strip() for l in f.readlines() if len(l.strip()) > 2]
22 | for line in lines:
23 | # YOLO format: class id, normalized box center (xc, yc) and size (w, h)
24 | class_, xc, yc, w, h = line.split(' ')
25 | x1 = int((float(xc) - (float(w) / 2)) * w_img)
26 | y1 = int((float(yc) - (float(h) / 2)) * h_img)
27 | x2 = x1 + int(float(w) * w_img)
28 | y2 = y1 + int(float(h) * h_img)
29 | img = cv2.rectangle(img,
30 | (x1, y1),
31 | (x2, y2),
32 | (0, 255, 0),
33 | 4)
34 | mng = plt.get_current_fig_manager()
35 | mng.resize(*mng.window.maxsize())  # maximize the figure window (TkAgg backend)
36 | plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
37 | plt.show()
38 |
--------------------------------------------------------------------------------