├── LICENSE
├── README.md
├── Train_Detectron2_Object_Detector_Custom_Data.ipynb
├── class.names
├── data
│   └── .gitignore
├── loss.py
├── plot_loss.py
├── predict.py
├── requirements.txt
├── train.py
├── util.py
└── visualize_data.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 computervisiondeveloper

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# train-object-detector-detectron2

Watch the video on YouTube: Train an object detector on custom data with Detectron2!
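## Getting started

A typical end-to-end run, mirroring the commands from the Colab notebook (the training flags shown are examples; see `train.py` for the full argument list and defaults):

```bash
pip install -r requirements.txt

# preview a few annotated training images
python visualize_data.py

# train the detector (use --device cpu if no GPU is available)
python train.py --device gpu --learning-rate 0.00001 --iterations 6000

# plot the train / validation loss curves (expects ./metrics.json;
# training writes it to ./output/metrics.json)
python plot_loss.py

# run inference with a saved checkpoint
python predict.py
```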

--------------------------------------------------------------------------------
/Train_Detectron2_Object_Detector_Custom_Data.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "### Connect Google Drive"
      ],
      "metadata": {
        "id": "-GIYjVIA96Ne"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "\n",
        "drive.mount('/content/gdrive')"
      ],
      "metadata": {
        "id": "-wxgALmBx6Rj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Install requirements"
      ],
      "metadata": {
        "id": "D9ZsFr6--Lxg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xWh477fIw_al"
      },
      "outputs": [],
      "source": [
        "!pip install torch==2.0.0\n",
        "!pip install torchvision==0.15.1\n",
        "!pip install opencv-python==4.6.0.66\n",
        "!pip install matplotlib==3.5.3\n",
        "!pip install git+https://github.com/facebookresearch/detectron2.git\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Change working directory"
      ],
      "metadata": {
        "id": "q1PMCpJK-QNH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%cd /content/gdrive/MyDrive/ComputerVisionEngineer/TrainDetectron2ObjectDetector/"
      ],
      "metadata": {
        "id": "La68ZGbq0Qni"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Train!"
      ],
      "metadata": {
        "id": "3chYfIWs-Sqa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python train.py --device gpu --learning-rate 0.00001 --iterations 6000"
      ],
      "metadata": {
        "id": "5xDV20jH0mzk"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}

--------------------------------------------------------------------------------
/class.names:
--------------------------------------------------------------------------------
alpaca

--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
.gitignore

--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
from detectron2.engine import HookBase
from detectron2.data import build_detection_train_loader
import detectron2.utils.comm as comm
import torch


class ValidationLoss(HookBase):
    """
    A hook that computes validation loss during training.

    Attributes:
        cfg (CfgNode): The detectron2 config node.
        _loader (iterator): An iterator over the validation dataset.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): The detectron2 config node.
        """
        super().__init__()
        self.cfg = cfg.clone()
        # Switch to the validation dataset
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        # Build the validation data loader iterator
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        """
        Computes the validation loss after each training step.
        """
        # Get the next batch of data from the validation data loader
        data = next(self._loader)
        with torch.no_grad():
            # Compute the validation loss on the current batch of data
            loss_dict = self.trainer.model(data)

            # Check for invalid losses
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            # Reduce the loss across all workers
            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            # Save the validation loss in the trainer storage
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)
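Note that `after_step` runs a full validation forward pass on every training iteration, which roughly doubles the per-step cost. A minimal sketch of a cheaper variant that only evaluates every `eval_period` steps; this class is illustrative and not part of the repo, and assumes only that `HookBase` exposes `self.trainer`:

```python
# Hypothetical variant, not part of this repo: run the validation pass only
# every `eval_period` iterations to reduce training overhead.
class PeriodicValidationLoss(ValidationLoss):

    def __init__(self, cfg, eval_period=20):
        super().__init__(cfg)
        self.eval_period = eval_period

    def after_step(self):
        # self.trainer.iter is the index of the step that just finished.
        next_iter = self.trainer.iter + 1
        if next_iter % self.eval_period == 0 or next_iter == self.trainer.max_iter:
            super().after_step()
```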
--------------------------------------------------------------------------------
/plot_loss.py:
--------------------------------------------------------------------------------
import ast

import matplotlib.pyplot as plt
import numpy as np


def moving_average(a, n=3):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


metrics_file = './metrics.json'

# Each non-empty line of metrics.json is one dict written by detectron2's
# JSON writer; the file is closed automatically by the with block.
with open(metrics_file, 'r') as f:
    metrics = [ast.literal_eval(l.strip()) for l in f if l.strip()]

train_loss = [float(v['loss_box_reg']) for v in metrics if 'loss_box_reg' in v.keys()]
val_loss = [float(v['val_loss_box_reg']) for v in metrics if 'val_loss_box_reg' in v.keys()]

# Window size of the moving average used to smooth the curves.
N = 40

train_loss_avg = moving_average(train_loss, n=N)
val_loss_avg = moving_average(val_loss, n=N)

# Metrics are written every 20 iterations (detectron2's default logging
# period), so scale the x axis back to training iterations.
plt.plot(range(20 * N - 1, 20 * len(train_loss), 20), train_loss_avg, label='train loss')
plt.plot(range(20 * N - 1, 20 * len(train_loss), 20), val_loss_avg, label='val loss')
plt.legend()
plt.grid()
plt.show()
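For reference, each line of `metrics.json` is one JSON object per logging event. The snippet below is abridged and the numbers are invented, but the keys are the ones `plot_loss.py` relies on (`loss_box_reg` from training, `val_loss_box_reg` from the `ValidationLoss` hook):

```
{"iteration": 19, "loss_cls": 0.912, "loss_box_reg": 0.431, "total_loss": 1.343, "val_loss_cls": 0.934, "val_loss_box_reg": 0.458, "total_val_loss": 1.392, "lr": 0.00025}
{"iteration": 39, "loss_cls": 0.861, "loss_box_reg": 0.402, "total_loss": 1.263, "val_loss_cls": 0.902, "val_loss_box_reg": 0.441, "total_val_loss": 1.343, "lr": 0.00025}
```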
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2 import model_zoo
import cv2


# Load config from a config file; it must match the config used during training
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file('COCO-Detection/retinanet_R_101_FPN_3x.yaml'))
cfg.MODEL.WEIGHTS = './model_0002999.pth'  # path to a checkpoint saved during training
cfg.MODEL.DEVICE = 'cpu'

# Create predictor instance
predictor = DefaultPredictor(cfg)

# Load image
image = cv2.imread("./data/val/imgs/3e115eab82413cd4.jpg")

# Perform prediction
outputs = predictor(image)

# Minimum confidence score a detection needs in order to be drawn
threshold = 0.5

# Display predictions
preds = outputs["instances"].pred_classes.tolist()
scores = outputs["instances"].scores.tolist()
bboxes = outputs["instances"].pred_boxes

for j, bbox in enumerate(bboxes):
    bbox = bbox.tolist()

    score = scores[j]
    pred = preds[j]

    if score > threshold:
        x1, y1, x2, y2 = [int(i) for i in bbox]

        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 5)

cv2.imshow('image', image)
cv2.waitKey(0)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==2.0.0
torchvision==0.15.1
opencv-python==4.6.0.66
matplotlib==3.5.3
git+https://github.com/facebookresearch/detectron2.git

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse

import util


if __name__ == "__main__":
    """
    Annotations should be provided in YOLO format, i.e. one object per line:

        class x_center y_center width height

    where all four box values are normalized to [0, 1].
    The data needs to follow this structure:

    data-dir
    ----- train
    --------- imgs
    ------------ filename0001.jpg
    ------------ filename0002.jpg
    ------------ ....
    --------- anns
    ------------ filename0001.txt
    ------------ filename0002.txt
    ------------ ....
    ----- val
    --------- imgs
    ------------ filename0001.jpg
    ------------ filename0002.jpg
    ------------ ....
    --------- anns
    ------------ filename0001.txt
    ------------ filename0002.txt
    ------------ ....

    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--class-list', default='./class.names')
    parser.add_argument('--data-dir', default='./data')
    parser.add_argument('--output-dir', default='./output')
    parser.add_argument('--device', default='cpu')
    parser.add_argument('--learning-rate', type=float, default=0.00025)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--iterations', type=int, default=10000)
    parser.add_argument('--checkpoint-period', type=int, default=500)
    parser.add_argument('--model', default='COCO-Detection/retinanet_R_101_FPN_3x.yaml')

    args = parser.parse_args()

    util.train(args.output_dir,
               args.data_dir,
               args.class_list,
               device=args.device,
               learning_rate=args.learning_rate,
               batch_size=args.batch_size,
               iterations=args.iterations,
               checkpoint_period=args.checkpoint_period,
               model=args.model)
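To make the annotation format concrete, here is a worked numeric example of the YOLO-to-absolute-XYWH conversion that `util.get_dicts` (below) performs; the annotation line and image size are made up for illustration:

```python
# Hypothetical annotation line "class xc yc w h" for a 640x480 image.
width, height = 640, 480
label, cx, cy, w, h = '0 0.5 0.5 0.25 0.5'.split(' ')

x0 = int((float(cx) - float(w) / 2) * width)    # (0.5 - 0.125) * 640 = 240
y0 = int((float(cy) - float(h) / 2) * height)   # (0.5 - 0.25)  * 480 = 120
box_w = int(float(w) * width)                   # 0.25 * 640 = 160
box_h = int(float(h) * height)                  # 0.5  * 480 = 240

print([x0, y0, box_w, box_h])  # [240, 120, 160, 240] in BoxMode.XYWH_ABS
```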
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import os

from detectron2.engine import DefaultTrainer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
from detectron2.config import get_cfg as _get_cfg
from detectron2 import model_zoo

from loss import ValidationLoss

import cv2


def get_cfg(output_dir, learning_rate, batch_size, iterations, checkpoint_period, model, device, nmr_classes):
    """
    Create a Detectron2 configuration object and set its attributes.

    Args:
        output_dir (str): The path to the output directory where the trained model and logs will be saved.
        learning_rate (float): The learning rate for the optimizer.
        batch_size (int): The batch size used during training.
        iterations (int): The maximum number of training iterations.
        checkpoint_period (int): The number of iterations between consecutive checkpoints.
        model (str): The name of the model to use, which should be one of the models available in Detectron2's model zoo.
        device (str): The device to use for training; 'cpu' forces CPU, any other value (e.g. 'gpu') keeps
            Detectron2's default device ('cuda').
        nmr_classes (int): The number of classes in the dataset.

    Returns:
        The Detectron2 configuration object.
    """
    cfg = _get_cfg()

    # Merge the model's default configuration file with the default Detectron2 configuration file.
    cfg.merge_from_file(model_zoo.get_config_file(model))

    # Set the training and validation datasets and exclude the test dataset.
    cfg.DATASETS.TRAIN = ("train",)
    cfg.DATASETS.VAL = ("val",)
    cfg.DATASETS.TEST = ()

    # Set the device to use for training; any other value keeps Detectron2's default ('cuda').
    if device == 'cpu':
        cfg.MODEL.DEVICE = 'cpu'

    # Set the number of data loader workers.
    cfg.DATALOADER.NUM_WORKERS = 2

    # Set the model weights to the ones pre-trained on the COCO dataset.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model)

    # Set the batch size used by the solver.
    cfg.SOLVER.IMS_PER_BATCH = batch_size

    # Set the checkpoint period.
    cfg.SOLVER.CHECKPOINT_PERIOD = checkpoint_period

    # Set the base learning rate.
    cfg.SOLVER.BASE_LR = learning_rate

    # Set the maximum number of training iterations.
    cfg.SOLVER.MAX_ITER = iterations

    # Set the learning rate scheduler steps to an empty list, which means the learning rate will not be decayed.
    cfg.SOLVER.STEPS = []

    # Set the batch size used by the ROI heads during training.
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

    # Set the number of classes. The default model is RetinaNet, which reads
    # MODEL.RETINANET.NUM_CLASSES; region-based models such as Faster R-CNN
    # read MODEL.ROI_HEADS.NUM_CLASSES instead, so set both to cover either case.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = nmr_classes
    cfg.MODEL.RETINANET.NUM_CLASSES = nmr_classes

    # Set the output directory.
    cfg.OUTPUT_DIR = output_dir

    return cfg
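# Illustrative example (values are train.py's defaults): the call
#
#     cfg = get_cfg('./output', 0.00025, 4, 10000, 500,
#                   'COCO-Detection/retinanet_R_101_FPN_3x.yaml', 'cpu', 1)
#
# yields a RetinaNet config with SOLVER.BASE_LR == 0.00025,
# SOLVER.MAX_ITER == 10000, SOLVER.IMS_PER_BATCH == 4, a checkpoint every
# 500 iterations, and a single output class.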
def get_dicts(img_dir, ann_dir):
    """
    Read the annotations for the dataset in YOLO format and create a list of dictionaries containing information for each
    image.

    Args:
        img_dir (str): Directory containing the images.
        ann_dir (str): Directory containing the annotations.

    Returns:
        list[dict]: A list of dictionaries containing information for each image. Each dictionary has the following keys:
            - file_name: The path to the image file.
            - image_id: The unique identifier for the image.
            - height: The height of the image in pixels.
            - width: The width of the image in pixels.
            - annotations: A list of dictionaries, one for each object in the image, containing the following keys:
                - bbox: A list of four integers [x0, y0, w, h] representing the bounding box of the object in the image,
                        where (x0, y0) is the top-left corner and (w, h) are the width and height of the bounding box,
                        respectively.
                - bbox_mode: A constant from the `BoxMode` class indicating the format of the bounding box coordinates
                             (e.g., `BoxMode.XYWH_ABS` for absolute coordinates in the format [x0, y0, w, h]).
                - category_id: The integer ID of the object's class.
    """
    dataset_dicts = []
    for idx, file in enumerate(os.listdir(ann_dir)):
        # Annotations should be provided in YOLO format; each .txt annotation
        # file is expected to have a .jpg image with the same base name.

        record = {}

        filename = os.path.join(img_dir, file[:-4] + '.jpg')
        height, width = cv2.imread(filename).shape[:2]

        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        objs = []
        with open(os.path.join(ann_dir, file)) as r:
            lines = [l.strip() for l in r.readlines()]

        for line in lines:
            if len(line) > 2:
                label, cx, cy, w_, h_ = line.split(' ')

                # Convert the normalized (center x, center y, width, height)
                # box to absolute pixel coordinates of the top-left corner
                # plus width and height.
                obj = {
                    "bbox": [int((float(cx) - (float(w_) / 2)) * width),
                             int((float(cy) - (float(h_) / 2)) * height),
                             int(float(w_) * width),
                             int(float(h_) * height)],
                    "bbox_mode": BoxMode.XYWH_ABS,
                    "category_id": int(label),
                }

                objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts


def register_datasets(root_dir, class_list_file):
    """
    Registers the train and validation datasets and returns the number of classes.

    Args:
        root_dir (str): Path to the root directory of the dataset.
        class_list_file (str): Path to the file containing the list of class names.

    Returns:
        int: The number of classes in the dataset.
    """
    # Read the list of class names from the class list file, skipping blank lines.
    with open(class_list_file, 'r') as reader:
        classes_ = [l.strip() for l in reader.readlines() if l.strip()]

    # Register the train and validation datasets.
    for d in ['train', 'val']:
        DatasetCatalog.register(d, lambda d=d: get_dicts(os.path.join(root_dir, d, 'imgs'),
                                                         os.path.join(root_dir, d, 'anns')))
        # Set the metadata for the dataset.
        MetadataCatalog.get(d).set(thing_classes=classes_)

    return len(classes_)


def train(output_dir, data_dir, class_list_file, learning_rate, batch_size, iterations, checkpoint_period, device,
          model):
    """
    Train a Detectron2 model on a custom dataset.

    Args:
        output_dir (str): Path to the directory to save the trained model and output files.
        data_dir (str): Path to the directory containing the dataset.
        class_list_file (str): Path to the file containing the list of class names in the dataset.
        learning_rate (float): Learning rate for the optimizer.
        batch_size (int): Batch size for training.
        iterations (int): Maximum number of training iterations.
        checkpoint_period (int): Number of iterations after which to save a checkpoint of the model.
        device (str): Device to use for training (e.g., 'cpu' or 'cuda').
        model (str): Name of the model configuration to use. Must be a key in the Detectron2 model zoo.

    Returns:
        None
    """

    # Register the dataset and get the number of classes
    nmr_classes = register_datasets(data_dir, class_list_file)

    # Get the configuration for the model
    cfg = get_cfg(output_dir, learning_rate, batch_size, iterations, checkpoint_period, model, device, nmr_classes)

    # Create the output directory
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # Create the trainer object
    trainer = DefaultTrainer(cfg)

    # Create a custom validation loss object
    val_loss = ValidationLoss(cfg)

    # Register the custom validation loss object as a hook to the trainer
    trainer.register_hooks([val_loss])

    # Swap the last two hooks so the ValidationLoss hook runs before the
    # PeriodicWriter; otherwise the validation loss computed in after_step
    # would only be written out on the following iteration.
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]

    # Resume training from a checkpoint or load the initial model weights
    trainer.resume_or_load(resume=False)

    # Train the model
    trainer.train()
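A quick sanity check of the registration logic; this is a sketch, not part of the repo, and assumes the `./data` layout from `train.py` and `./class.names` are in place:

```python
from detectron2.data import DatasetCatalog, MetadataCatalog

from util import register_datasets

n_classes = register_datasets('./data', './class.names')
train_dicts = DatasetCatalog.get('train')  # invokes get_dicts lazily

print(n_classes)                                   # 1
print(MetadataCatalog.get('train').thing_classes)  # ['alpaca']
print(train_dicts[0]['file_name'], len(train_dicts[0]['annotations']))
```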
--------------------------------------------------------------------------------
/visualize_data.py:
--------------------------------------------------------------------------------
import os
import random

import matplotlib.pyplot as plt
import cv2


IMGS_DIR = os.path.join('.', 'data', 'train', 'imgs')
ANNS_DIR = os.path.join('.', 'data', 'train', 'anns')

if __name__ == "__main__":
    files = os.listdir(IMGS_DIR)
    # Show a random annotated training image; close the window to see another.
    while True:
        fig = plt.figure()
        k = random.randint(0, len(files) - 1)
        img = cv2.imread(os.path.join(IMGS_DIR, files[k]))
        ann_file = os.path.join(ANNS_DIR, files[k][:-4] + '.txt')

        h_img, w_img, _ = img.shape
        with open(ann_file, 'r') as f:
            lines = [l.strip() for l in f.readlines() if len(l) > 2]
        for line in lines:
            # YOLO format: class, box center, and box size, all normalized.
            class_, xc, yc, w, h = line.split(' ')
            x1 = int((float(xc) - (float(w) / 2)) * w_img)
            y1 = int((float(yc) - (float(h) / 2)) * h_img)
            x2 = x1 + int(float(w) * w_img)
            y2 = y1 + int(float(h) * h_img)
            img = cv2.rectangle(img,
                                (x1, y1),
                                (x2, y2),
                                (0, 255, 0),
                                4)
        mng = plt.get_current_fig_manager()
        # Maximize the window; this call assumes the TkAgg matplotlib backend.
        mng.resize(*mng.window.maxsize())
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.show()
--------------------------------------------------------------------------------
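As an aside, detectron2 ships its own drawing utility. A sketch of the same ground-truth preview using `Visualizer`, which also renders class labels; it assumes the datasets have been registered as above and that at least three training images exist:

```python
import random

import cv2
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer

from util import register_datasets

register_datasets('./data', './class.names')
dicts = DatasetCatalog.get('train')
metadata = MetadataCatalog.get('train')

for d in random.sample(dicts, 3):
    img = cv2.imread(d['file_name'])
    # Visualizer expects RGB, OpenCV loads BGR, hence the channel flips.
    vis = Visualizer(img[:, :, ::-1], metadata=metadata, scale=0.5)
    out = vis.draw_dataset_dict(d)
    cv2.imshow('sample', out.get_image()[:, :, ::-1])
    cv2.waitKey(0)
```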