├── utils
│   ├── image.png
│   ├── val_logs.png
│   ├── train_logs.png
│   ├── model.py
│   ├── training_utils.py
│   ├── dataset.py
│   └── model_utils.py
├── requirements.txt
├── .gitignore
├── LICENSE
├── README.md
├── eval.py
└── train.py

/utils/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GirinChutia/FasterRCNN-Torchvision-FineTuning/HEAD/utils/image.png
--------------------------------------------------------------------------------
/utils/val_logs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GirinChutia/FasterRCNN-Torchvision-FineTuning/HEAD/utils/val_logs.png
--------------------------------------------------------------------------------
/utils/train_logs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GirinChutia/FasterRCNN-Torchvision-FineTuning/HEAD/utils/train_logs.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.7.1
2 | numpy==1.23.5
3 | opencv_python==4.7.0.72
4 | opencv-python-headless==4.6.0.66
5 | Pillow==10.0.0
6 | pycocotools==2.0.6
7 | simple_parsing==0.1.2.post1
8 | torch==2.0.0+cu117
9 | torchvision==0.15.1+cu117
10 | tqdm==4.65.0
11 | tensorboard==2.13.0
12 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /dataset
2 | /Count_wheat_spikes_using_finetuned_faster_rcnn.ipynb
3 | /wheat_spikes_using_faster_rcnn.ipynb
4 | /test.ipynb
5 | /utils/detection
6 | /__pycache__
7 | /utils/__pycache__
8 | /exp
9 | /weight_outputs
10 | /test.py
11 | /*.jpg
12 | weight_outputs_best/best_model.pth
13 | test_dect.json
14 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Girin Chutia
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
3 | from torchvision.models.detection import (
4 |     FasterRCNN_ResNet50_FPN_Weights,
5 |     FasterRCNN_ResNet50_FPN_V2_Weights,
6 | )
7 | import warnings
8 | import torch
9 | warnings.filterwarnings("ignore", category=UserWarning)
10 | 
11 | 
12 | def create_model(num_classes, checkpoint=None, device='cpu'):
13 |     """
14 |     Create a model for object detection using the Faster R-CNN architecture.
15 | 
16 |     Parameters:
17 |     - num_classes (int): The number of classes for object detection (total classes + 1 for the background class).
18 |     - checkpoint (str): checkpoint path for a pretrained custom model
19 |     - device (str): cpu / cuda
20 |     Returns:
21 |     - model (torchvision.models.detection.fasterrcnn_resnet50_fpn): The created model for object detection.
22 |     """
23 |     model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
24 |         # The deprecated `pretrained` / `pretrained_backbone` flags are not needed:
25 |         # `weights=...DEFAULT` already loads the COCO-pretrained checkpoint.
26 |         weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
27 |         # weights_backbone = 'ResNet50_Weights.DEFAULT',
28 |     )
29 |     in_features = model.roi_heads.box_predictor.cls_score.in_features
30 |     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
31 |     if checkpoint:
32 |         checkpoint = torch.load(checkpoint, map_location=device)
33 |         model.load_state_dict(checkpoint['model_state_dict'])
34 |     # move the model to the requested device whether or not a checkpoint was loaded
35 |     model = model.to(device)
36 |     return model
37 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Training code for torchvision FasterRCNN model with custom COCO dataset
2 | 
3 | ---
4 | # Faster RCNN :
5 | Faster RCNN is an object detection model introduced in the [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497) paper.
6 | 
7 | The architecture of the Faster RCNN model is shown below:
8 | ![](utils/image.png)
9 | 
10 | Faster R-CNN is composed of two modules. The first module is a deep fully convolutional network that proposes regions, and the second module is the Fast R-CNN detector that uses the proposed regions.
11 | 
12 | ---
13 | # Environment :
14 | - Python version used : 3.9.16
15 | - Create a python or conda environment using ***requirements.txt***
16 | 
17 | ---
18 | # Training Instructions :
19 | 
20 | To train the Faster RCNN model, follow the steps below:
21 | 
22 | 1. Prepare dataset :
23 |    - Prepare the dataset in COCO format. It should contain the following two items:
24 |      - Image folder
25 |      - Annotation file (JSON file) in COCO format
26 | 
27 | 2. Run :
28 | > python train.py --epochs 10 --train_image_dir <train image folder> --val_image_dir <val image folder> --train_coco_json <train annotation json> --val_coco_json <val annotation json> --batch_size 16 --exp_folder <experiment folder>
29 | 
30 | The training weights and tensorboard logs will be saved in the experiment folder.
31 | 
32 | The training and validation logs can be visualized in tensorboard as shown below :
33 | > Train logs
34 | ![Alt text](utils/train_logs.png)
35 | > Val Logs
36 | ![Alt text](utils/val_logs.png)
37 | 
38 | ---
39 | # Inference :
40 | 
41 | Instructions for running inference with a trained model are provided in the ***demo_inference.ipynb*** notebook.
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from utils.dataset import CocoDataset
2 | import torch
3 | from utils.model_utils import InferFasterRCNN,display_gt_pred
4 | from pycocotools.coco import COCO
5 | import os
6 | from pycocotools.cocoeval import COCOeval
7 | from tqdm import tqdm
8 | import json
9 | import gc
10 | 
11 | def save_json(data, file_path):
12 |     with open(file_path, 'w') as file:
13 |         json.dump(data, file)
14 | 
15 | def evaluate_model(image_dir,
16 |                    gt_ann_file,
17 |                    model_weight):
18 | 
19 |     _ds = CocoDataset(
20 |         image_folder=image_dir,
21 |         annotations_file=gt_ann_file,
22 |         height=640,
23 |         width=640,
24 |     )
25 | 
26 |     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27 | 
28 |     IF_C = InferFasterRCNN(num_classes=_ds.get_total_classes_count() + 1,
29 |                            classnames=_ds.get_classnames())
30 | 
31 |     IF_C.load_model(checkpoint=model_weight,
32 |                     device=device)
33 | 
34 |     image_dir = image_dir
35 | 
36 |     cocoGt=COCO(annotation_file=gt_ann_file)
37 |     imgIds = cocoGt.getImgIds() # all image ids
38 | 
39 |     res_id = 1
40 |     res_all = []
41 | 
42 |     for id in tqdm(imgIds,total=len(imgIds)):
43 |         id = id
44 |         img_info = cocoGt.loadImgs(id)[0]  # look the image up by its COCO id, not by list position
45 |         annIds = cocoGt.getAnnIds(imgIds=img_info['id'])
46 |         ann_info = cocoGt.loadAnns(annIds)
47 |         image_path = os.path.join(image_dir,
48 |                                   img_info['file_name'])
49 |         transform_info = CocoDataset.transform_image_for_inference(image_path,width=640,height=640)
50 |         result = IF_C.infer_image(transform_info=transform_info,
51 |                                   visualize=False)
52 | 
53 |         if len(result)>0:
54 |             pred_boxes_xyxy = result['unscaled_boxes']
55 |             pred_boxes_xywh = [[i[0],i[1],i[2]-i[0],i[3]-i[1]] for i in pred_boxes_xyxy]
56 |             pred_classes = result['pred_classes']
57 |             pred_scores = result['scores']
58 |             pred_labels = result['labels']
59 | 
60 |             for i in range(len(pred_boxes_xywh)):
61 |                 res_temp = {"id":res_id,
62 |                             "image_id":id,
63 |                             "bbox":pred_boxes_xywh[i],
64 |                             "segmentation":[],
65 |                             "iscrowd": 0,
66 |                             "category_id": int(pred_labels[i]),
67 |                             "area":pred_boxes_xywh[i][2]*pred_boxes_xywh[i][3],
68 |                             "score": float(pred_scores[i])}
69 |                 res_all.append(res_temp)
70 |                 res_id+=1
71 | 
72 |     save_json_path = 'test_dect.json'
73 |     save_json(res_all,save_json_path)
74 | 
75 |     cocoGt=COCO(gt_ann_file)
76 |     cocoDt=cocoGt.loadRes(save_json_path)
77 | 
78 |     cocoEval = COCOeval(cocoGt,cocoDt,iouType='bbox')
79 |     cocoEval.evaluate()
80 |     cocoEval.accumulate()
81 |     cocoEval.summarize()
82 | 
83 |     AP_50_95 = cocoEval.stats.tolist()[0]
84 |     AP_50 = cocoEval.stats.tolist()[1]
85 | 
86 |     del IF_C,_ds
87 |     os.remove(save_json_path)
88 | 
89 |     torch.cuda.empty_cache()
90 |     gc.collect()
91 | 
92 |     return {'AP_50_95':AP_50_95,
93 |             'AP_50':AP_50}
--------------------------------------------------------------------------------
/train.py:
-------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from time import sleep 3 | from utils.dataset import CocoDataset 4 | from utils.model import create_model 5 | from utils.training_utils import SaveBestModel,train_one_epoch,val_one_epoch,get_datasets 6 | import torch 7 | import os 8 | import time 9 | from datetime import datetime 10 | from dataclasses import dataclass 11 | from simple_parsing import ArgumentParser 12 | from eval import evaluate_model 13 | from torch import nn 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | def train( 17 | train_dataset, 18 | val_dataset, 19 | epochs=2, 20 | batch_size=8, 21 | exp_folder="exp", 22 | val_eval_freq=1, 23 | ): 24 | 25 | date_format = "%d-%m-%Y-%H-%M-%S" 26 | date_string = time.strftime(date_format) 27 | 28 | exp_folder = os.path.join("exp", "summary", date_string) 29 | writer = SummaryWriter(exp_folder) 30 | 31 | def custom_collate(data): 32 | return data 33 | 34 | # Dataloaders -- 35 | train_dl = torch.utils.data.DataLoader( 36 | train_dataset, 37 | batch_size=batch_size, 38 | shuffle=True, 39 | collate_fn=custom_collate, 40 | pin_memory=True if torch.cuda.is_available() else False, 41 | ) 42 | 43 | val_dl = torch.utils.data.DataLoader( 44 | val_dataset, 45 | batch_size=batch_size, 46 | shuffle=False, 47 | collate_fn=custom_collate, 48 | pin_memory=True if torch.cuda.is_available() else False, 49 | ) 50 | 51 | # Device -- 52 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 53 | 54 | # Model -- 55 | faster_rcnn_model = create_model(train_dataset.get_total_classes_count() + 1) 56 | faster_rcnn_model = faster_rcnn_model.to(device) 57 | 58 | # Optimizer -- 59 | pg0, pg1, pg2 = [], [], [] # optimizer parameter groups 60 | 61 | for k, v in faster_rcnn_model.named_modules(): 62 | if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter): 63 | pg2.append(v.bias) # biases 64 | if isinstance(v, nn.BatchNorm2d) or "bn" in k: 65 | pg0.append(v.weight) # no decay 66 | elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter): 67 | pg1.append(v.weight) # apply decay 68 | 69 | optimizer = torch.optim.SGD( 70 | pg0, lr=0.001, momentum=0.9, nesterov=True 71 | ) # BN 72 | 73 | optimizer.add_param_group( 74 | {"params": pg1, "weight_decay": 5e-4} 75 | ) # add pg1 with weight_decay # Weights 76 | 77 | optimizer.add_param_group({"params": pg2}) # Biases 78 | 79 | 80 | num_epochs = epochs 81 | save_best_model = SaveBestModel(output_dir=exp_folder) 82 | 83 | for epoch in range(num_epochs): 84 | 85 | faster_rcnn_model, optimizer, writer, epoch_loss = train_one_epoch( 86 | faster_rcnn_model, 87 | train_dl, 88 | optimizer, 89 | writer, 90 | epoch + 1, 91 | num_epochs, 92 | device, 93 | ) 94 | 95 | sleep(0.1) 96 | 97 | if (epoch % val_eval_freq == 0) and epoch != 0: # Do evaluation of validation set 98 | eval_result = evaluate_model(image_dir=val_dataset.image_folder, 99 | gt_ann_file=val_dataset.annotations_file, 100 | model_weight=save_best_model.model_save_path) 101 | 102 | sleep(0.1) 103 | 104 | writer.add_scalar("Val/AP_50_95", eval_result['AP_50_95'], epoch + 1) 105 | writer.add_scalar("Val/AP_50", eval_result['AP_50'], epoch + 1) 106 | 107 | else: 108 | writer, val_epoch_loss = val_one_epoch( 109 | faster_rcnn_model, 110 | val_dl, 111 | writer, 112 | epoch + 1, 113 | num_epochs, 114 | device, 115 | log=True, 116 | ) 117 | 118 | sleep(0.1) 119 | 120 | save_best_model(val_epoch_loss, 121 | epoch, 122 | faster_rcnn_model, 123 | optimizer) 124 | 
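            # SaveBestModel (utils.training_utils) writes the checkpoint to <exp_folder>/best_model.pth
            # as a dict with the keys 'epoch', 'model_state_dict' and 'optimizer_state_dict', and records
            # that path in save_best_model.model_save_path. The COCO-evaluation branch above reloads the
            # model from that same path; this works because epoch 0 always takes this loss-validation
            # branch first, so best_model.pth exists before the first evaluate_model call.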
125 | 126 | _, _ = val_one_epoch( 127 | faster_rcnn_model, val_dl, writer, epoch + 1, num_epochs, device, log=False 128 | ) 129 | 130 | writer.add_hparams( 131 | {"epochs": epochs, "batch_size": batch_size}, 132 | {"Train/total_loss": epoch_loss, "Val/total_loss": val_epoch_loss}, 133 | ) 134 | 135 | @dataclass 136 | class DatasetPaths: 137 | train_image_dir: str = r"D:\Work\work\FasterRCNN-Torchvision-FineTuning\dataset\AquariumDataset\train\images" 138 | val_image_dir: str = r"D:\Work\work\FasterRCNN-Torchvision-FineTuning\dataset\AquariumDataset\valid\images" 139 | train_coco_json: str = r"D:\Work\work\FasterRCNN-Torchvision-FineTuning\dataset\AquariumDataset\train\_annotations.coco_neg.json" 140 | val_coco_json: str = r"D:\Work\work\FasterRCNN-Torchvision-FineTuning\dataset\AquariumDataset\valid\_annotations.coco.json" 141 | 142 | @dataclass 143 | class TrainingConfig: 144 | epochs: int = 15 145 | batch_size: int = 6 146 | val_eval_freq: int = 2 147 | exp_folder: str = 'exp' 148 | 149 | if __name__ == "__main__": 150 | 151 | parser = ArgumentParser() 152 | parser.add_arguments(DatasetPaths,dest='dataset_config') 153 | parser.add_arguments(TrainingConfig,dest='training_config') 154 | args = parser.parse_args() 155 | 156 | dataset_config: DatasetPaths = args.dataset_config 157 | training_config: TrainingConfig = args.training_config 158 | 159 | train_ds, val_ds = get_datasets(train_image_dir=dataset_config.train_image_dir, 160 | train_coco_json=dataset_config.train_coco_json, 161 | val_image_dir=dataset_config.val_image_dir, 162 | val_coco_json=dataset_config.val_coco_json) 163 | train(train_ds, val_ds, 164 | epochs=training_config.epochs, 165 | batch_size=training_config.batch_size, 166 | val_eval_freq=training_config.val_eval_freq, 167 | exp_folder=training_config.exp_folder) 168 | 169 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from .dataset import CocoDataset 5 | 6 | def get_datasets(train_image_dir:str, 7 | val_image_dir:str, 8 | train_coco_json:str, 9 | val_coco_json:str): 10 | 11 | train_ds = CocoDataset( 12 | image_folder=train_image_dir, 13 | annotations_file=train_coco_json, 14 | height=640, 15 | width=640, 16 | ) 17 | 18 | val_ds = CocoDataset( 19 | image_folder=val_image_dir, 20 | annotations_file=val_coco_json, 21 | height=640, 22 | width=640, 23 | ) 24 | 25 | return train_ds, val_ds 26 | 27 | class SaveBestModel: 28 | """ 29 | Class to save the best model while training. If the current epoch's 30 | validation loss is less than the previous least less, then save the 31 | model state. 
32 | """ 33 | def __init__( 34 | self, best_valid_loss=float('inf'), output_dir = 'weight_outputs', 35 | ): 36 | self.best_valid_loss = best_valid_loss 37 | 38 | os.makedirs(output_dir,exist_ok=True) 39 | 40 | self.output_dir = output_dir 41 | 42 | def __call__( 43 | self, current_valid_loss, 44 | epoch, model, optimizer 45 | ): 46 | self.model_save_path = f'{self.output_dir}/best_model.pth' 47 | if current_valid_loss < self.best_valid_loss: 48 | self.best_valid_loss = current_valid_loss 49 | print(f"\nBest validation loss: {self.best_valid_loss}") 50 | print(f"\nSaving best model for epoch: {epoch+1}\n") 51 | torch.save({ 52 | 'epoch': epoch+1, 53 | 'model_state_dict': model.state_dict(), 54 | 'optimizer_state_dict': optimizer.state_dict(), 55 | }, self.model_save_path) 56 | 57 | @torch.inference_mode() 58 | def val_one_epoch(model, val_dl, writer, epoch_no, total_epoch, device, log=True): 59 | with tqdm(val_dl, unit="batch") as tepoch: 60 | epoch_loss = 0 61 | _classifier_loss = 0 62 | _loss_box_reg = 0 63 | _loss_rpn_box_reg = 0 64 | _loss_objectness = 0 65 | for data in tepoch: 66 | tepoch.set_description(f"Val:Epoch {epoch_no}/{total_epoch}") 67 | imgs = [] 68 | targets = [] 69 | for d in data: 70 | imgs.append(d[0].to(device)) 71 | targ = {} 72 | targ["boxes"] = d[1]["boxes"].to(device) 73 | targ["labels"] = d[1]["labels"].to(device) 74 | targets.append(targ) 75 | loss_dict = model(imgs, targets) 76 | 77 | loss = sum(v for v in loss_dict.values()) 78 | classifier_loss = loss_dict.get("loss_classifier").cpu().detach().numpy() 79 | loss_box_reg = loss_dict.get("loss_box_reg").cpu().detach().numpy() 80 | loss_objectness = loss_dict.get("loss_objectness").cpu().detach().numpy() 81 | loss_rpn_box_reg = loss_dict.get("loss_rpn_box_reg").cpu().detach().numpy() 82 | 83 | epoch_loss += loss.cpu().detach().numpy() 84 | _classifier_loss += classifier_loss 85 | _loss_box_reg += loss_box_reg 86 | _loss_objectness += loss_objectness 87 | _loss_rpn_box_reg += loss_rpn_box_reg 88 | 89 | tepoch.set_postfix( 90 | total_loss=epoch_loss, 91 | loss_classifier=_classifier_loss, 92 | boxreg_loss=_loss_box_reg, 93 | obj_loss=_loss_objectness, 94 | rpn_boxreg_loss=_loss_rpn_box_reg, 95 | ) 96 | 97 | if log: 98 | writer.add_scalar("Val/total_loss", epoch_loss, epoch_no) 99 | writer.add_scalar("Val/classifier_loss", _classifier_loss, epoch_no) 100 | writer.add_scalar("Val/box_reg_loss", _loss_box_reg, epoch_no) 101 | writer.add_scalar("Val/objectness_loss", _loss_objectness, epoch_no) 102 | writer.add_scalar("Val/rpn_box_reg_loss", _loss_rpn_box_reg, epoch_no) 103 | 104 | return writer, epoch_loss 105 | 106 | def train_one_epoch(model, train_dl, optimizer, writer, epoch_no, total_epoch, device): 107 | with tqdm(train_dl, unit="batch") as tepoch: 108 | epoch_loss = 0 109 | _classifier_loss = 0 110 | _loss_box_reg = 0 111 | _loss_rpn_box_reg = 0 112 | _loss_objectness = 0 113 | for data in tepoch: 114 | tepoch.set_description(f"Train:Epoch {epoch_no}/{total_epoch}") 115 | imgs = [] 116 | targets = [] 117 | for d in data: 118 | imgs.append(d[0].to(device)) 119 | targ = {} 120 | targ["boxes"] = d[1]["boxes"].to(device) 121 | targ["labels"] = d[1]["labels"].to(device) 122 | targets.append(targ) 123 | loss_dict = model(imgs, targets) 124 | 125 | loss = sum(v for v in loss_dict.values()) 126 | classifier_loss = loss_dict.get("loss_classifier").cpu().detach().numpy() 127 | loss_box_reg = loss_dict.get("loss_box_reg").cpu().detach().numpy() 128 | loss_objectness = loss_dict.get("loss_objectness").cpu().detach().numpy() 
129 | loss_rpn_box_reg = loss_dict.get("loss_rpn_box_reg").cpu().detach().numpy() 130 | 131 | epoch_loss += loss.cpu().detach().numpy() 132 | _classifier_loss += classifier_loss 133 | _loss_box_reg += loss_box_reg 134 | _loss_objectness += loss_objectness 135 | _loss_rpn_box_reg += loss_rpn_box_reg 136 | 137 | optimizer.zero_grad() 138 | loss.backward() 139 | optimizer.step() 140 | 141 | tepoch.set_postfix( 142 | total_loss=epoch_loss, 143 | loss_classifier=_classifier_loss, 144 | boxreg_loss=_loss_box_reg, 145 | obj_loss=_loss_objectness, 146 | rpn_boxreg_loss=_loss_rpn_box_reg, 147 | ) 148 | 149 | writer.add_scalar("Train/total_loss", epoch_loss, epoch_no) 150 | writer.add_scalar("Train/classifier_loss", _classifier_loss, epoch_no) 151 | writer.add_scalar("Train/box_reg_loss", _loss_box_reg, epoch_no) 152 | writer.add_scalar("Train/objectness_loss", _loss_objectness, epoch_no) 153 | writer.add_scalar("Train/rpn_box_reg_loss", _loss_rpn_box_reg, epoch_no) 154 | 155 | return model, optimizer, writer, epoch_loss -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import abc, cv2, glob,copy 3 | import torch, os, json 4 | from torch.utils.data import Dataset, DataLoader 5 | import numpy as np 6 | from collections import defaultdict 7 | from torchvision import ops 8 | import matplotlib.patches as patches 9 | from torchvision import transforms as T 10 | from PIL import Image 11 | import matplotlib.pyplot as plt 12 | 13 | COCOBox_base = namedtuple("COCOBox", ["xmin", "ymin", "width", "height"]) 14 | VOCBox_base = namedtuple("VOCBox", ["xmin", "ymin", "xmax", "ymax"]) 15 | 16 | class COCOBox(COCOBox_base): 17 | def area(self): 18 | return self.width * self.height 19 | 20 | 21 | class VOCBox(VOCBox_base): 22 | def area(self): 23 | return (self.xmax - self.xmin) * (self.ymax - self.ymin) 24 | 25 | 26 | # Define the abstract base class for loading datasets 27 | class DatasetLoader(metaclass=abc.ABCMeta): 28 | @abc.abstractmethod 29 | def load_images(self): 30 | pass 31 | 32 | @abc.abstractmethod 33 | def load_annotations(self): 34 | pass 35 | 36 | 37 | # the dataset class 38 | class CocoDataset(Dataset): 39 | def __init__(self, image_folder, annotations_file, width, height, transforms=None): 40 | 41 | self.transforms = transforms 42 | self.image_folder = image_folder 43 | self.annotations_file = annotations_file 44 | self.height = height 45 | self.width = width 46 | 47 | if not isinstance(self.image_folder, str): 48 | raise ValueError("image_folder should be a string") 49 | 50 | if not isinstance(annotations_file, str): 51 | raise ValueError("annotations_file should be a string") 52 | 53 | self.annotations_file = annotations_file 54 | self.image_folder = image_folder 55 | self.width = width 56 | self.height = height 57 | 58 | with open(annotations_file, "r") as f: 59 | self.annotations = json.load(f) 60 | 61 | self.image_ids = defaultdict(list) 62 | for i in self.annotations["images"]: 63 | self.image_ids[i["id"]] = i # key = image_id 64 | 65 | self.annotation_ids = defaultdict(list) 66 | for i in self.annotations["annotations"]: 67 | self.annotation_ids[i["image_id"]].append(i) # key = image_id 68 | 69 | self.cats_id2label = {} 70 | self.label_names = [] 71 | 72 | first_label_id = self.annotations["categories"][0]["id"] 73 | if first_label_id == 0: 74 | for i in self.annotations["categories"][1:]: 75 | self.cats_id2label[i["id"]] 
= i["name"] 76 | self.label_names.append(i["name"]) 77 | if first_label_id == 1: 78 | for i in self.annotations["categories"]: 79 | self.cats_id2label[i["id"]] = i["name"] 80 | self.label_names.append(i["name"]) 81 | if first_label_id > 1: 82 | raise AssertionError( 83 | "Something went wrong in categories, check the annotation file!" 84 | ) 85 | 86 | def get_total_classes_count(self): 87 | return len(self.cats_id2label) 88 | 89 | def get_classnames(self): 90 | return [v for k, v in self.cats_id2label.items()] 91 | 92 | def load_images_annotations(self, index): 93 | image_info = self.image_ids[index] 94 | image_path = os.path.join(self.image_folder, image_info["file_name"]) 95 | 96 | image = cv2.imread(image_path) 97 | rimage = cv2.cvtColor( 98 | image, cv2.COLOR_BGR2RGB 99 | ) # .astype(np.float32) # convert BGR to RGB color format 100 | rimage = cv2.resize(rimage, (self.width, self.height)) 101 | # rimage /= 255.0 102 | rimage = Image.fromarray(rimage) 103 | 104 | image_height, image_width = ( 105 | image_info["height"], 106 | image_info["width"], 107 | ) # original height & width 108 | anno_info = self.annotation_ids[index] 109 | 110 | if len(anno_info) == 0: # for negative images (Images without annotations) 111 | boxes = torch.zeros((0, 4), dtype=torch.float32) 112 | labels = torch.zeros((0, 1), dtype=torch.int64) 113 | iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64) 114 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 115 | else: 116 | boxes = [] 117 | labels_id = [] 118 | 119 | for ainfo in anno_info: 120 | xmin, ymin, w, h = ainfo["bbox"] 121 | xmax, ymax = xmin + w, ymin + h 122 | 123 | xmin_final = (xmin / image_width) * self.width 124 | xmax_final = (xmax / image_width) * self.width 125 | ymin_final = (ymin / image_height) * self.height 126 | ymax_final = (ymax / image_height) * self.height 127 | 128 | category_id = ainfo["category_id"] 129 | 130 | boxes.append([xmin_final, ymin_final, xmax_final, ymax_final]) 131 | labels_id.append(category_id) 132 | 133 | boxes = torch.as_tensor( 134 | boxes, dtype=torch.float32 135 | ) # bounding box to tensor 136 | area = (boxes[:, 3] - boxes[:, 1]) * ( 137 | boxes[:, 2] - boxes[:, 0] 138 | ) # area of the bounding boxes 139 | iscrowd = torch.zeros( 140 | (boxes.shape[0],), dtype=torch.int64 141 | ) # no crowd instances 142 | labels = torch.as_tensor(labels_id, dtype=torch.int64) # labels to tensor 143 | 144 | # final `target` dictionary 145 | target = {} 146 | target["boxes"] = boxes 147 | target["labels"] = labels 148 | target["area"] = area 149 | target["iscrowd"] = iscrowd 150 | image_id = torch.tensor([index]) 151 | target["image_id"] = image_id 152 | 153 | return { 154 | "image": rimage, 155 | "height": image_height, 156 | "width": image_width, 157 | "target": target, 158 | } 159 | 160 | @staticmethod 161 | def transform_image_for_inference(image_path,width,height): 162 | 163 | image = cv2.imread(image_path) 164 | ori_h, ori_w, _ = image.shape 165 | 166 | oimage = copy.deepcopy(image) 167 | oimage = Image.fromarray(oimage) 168 | oimage = T.ToTensor()(oimage) 169 | 170 | rimage = cv2.cvtColor( 171 | image, cv2.COLOR_BGR2RGB 172 | ) 173 | rimage = cv2.resize(rimage, (width,height)) 174 | rimage = Image.fromarray(rimage) 175 | rimage = T.ToTensor()(rimage) 176 | # rimage = torch.unsqueeze(rimage, 0) 177 | 178 | transform_info = {'original_width':ori_w, 179 | 'original_height':ori_h, 180 | 'resized_width':width, 181 | 'resized_height':height, 182 | 'resized_image':rimage, 183 | 'original_image':oimage} 184 | 
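        # transform_info is the dict consumed by InferFasterRCNN.infer_image():
        #   - 'resized_image' is the tensor that is actually fed to the model,
        #   - the 'original_*' and 'resized_*' sizes are used to rescale the predicted boxes
        #     back to the original image size (returned as 'unscaled_boxes'),
        #   - 'original_image' is only used for visualizing the rescaled detections.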
185 | return transform_info # this can directly go to model for inference 186 | 187 | @staticmethod 188 | def display_bbox( 189 | bboxes, fig, ax, classes=None, in_format="xyxy", color="y", line_width=3 190 | ): 191 | if type(bboxes) == np.ndarray: 192 | bboxes = torch.from_numpy(bboxes) 193 | if classes: 194 | assert len(bboxes) == len(classes) 195 | # convert boxes to xywh format 196 | bboxes = ops.box_convert(bboxes, in_fmt=in_format, out_fmt="xywh") 197 | c = 0 198 | for box in bboxes: 199 | x, y, w, h = box.numpy() 200 | # display bounding box 201 | rect = patches.Rectangle( 202 | (x, y), w, h, linewidth=line_width, edgecolor=color, facecolor="none" 203 | ) 204 | ax.add_patch(rect) 205 | # display category 206 | if classes: 207 | if classes[c] == "pad": 208 | continue 209 | ax.text( 210 | x + 5, y + 20, classes[c], bbox=dict(facecolor="yellow", alpha=0.5) 211 | ) 212 | c += 1 213 | 214 | return fig, ax 215 | 216 | def __getitem__(self, idx): 217 | 218 | sample = self.load_images_annotations(idx) 219 | image_resized = sample["image"] 220 | target = sample["target"] 221 | 222 | # apply the image transforms 223 | if self.transforms: 224 | sample = self.transforms( 225 | image=image_resized, bboxes=target["boxes"], labels=sample["labels"] 226 | ) 227 | image_resized = sample["image"] 228 | target["boxes"] = torch.Tensor(sample["bboxes"]) 229 | 230 | return T.ToTensor()(image_resized), target 231 | 232 | def __len__(self): 233 | return len(self.image_ids) 234 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import os 4 | from .model import create_model 5 | from .dataset import CocoDataset 6 | import torch 7 | import numpy as np 8 | import copy 9 | import matplotlib.pyplot as plt 10 | import matplotlib.patches as patches 11 | 12 | class InferFasterRCNN: 13 | def __init__(self, num_classes=None, classnames=[]): 14 | 15 | assert type(num_classes) != type(None), "Define number of classes" 16 | 17 | self.num_classes = num_classes # total_class_no + 1 (for background) 18 | 19 | self.classnames = ["__background__"] 20 | self.classnames.extend(classnames) 21 | 22 | self.colors = np.random.uniform(0, 255, size=(len(self.classnames), 3)) 23 | 24 | assert ( 25 | len(self.classnames) == self.num_classes 26 | ), f"num_classes: {self.num_classes}, len(classnames): {len(self.classnames)}.\ 27 | num_classes should be equal to count of actual classes in classnames list without background + 1" 28 | 29 | def load_model(self, checkpoint, device="cpu"): 30 | self.device = device 31 | self.model = create_model( 32 | self.num_classes, checkpoint=checkpoint, device=self.device 33 | ) 34 | self.model = self.model.eval() 35 | 36 | def infer_image(self, transform_info ,detection_threshold=0.5, visualize=False): 37 | 38 | ''' 39 | image : original unscaled image 40 | ''' 41 | 42 | display_unscaled = True 43 | h_ratio = transform_info['original_height']/transform_info['resized_height'] 44 | w_ratio = transform_info['original_width']/transform_info['resized_width'] 45 | 46 | orig_image = transform_info['resized_image'] 47 | orig_image = orig_image.cpu().numpy() 48 | orig_image = np.transpose(orig_image, (1, 2, 0)) 49 | orig_image = np.ascontiguousarray(orig_image, dtype=np.float32) 50 | image = torch.unsqueeze(transform_info['resized_image'], 0) 51 | 52 | with torch.no_grad(): 53 | self.model = self.model.to(self.device) 54 | outputs = 
self.model(image.to(self.device)) 55 | 56 | # load all detection to CPU for further operations 57 | outputs = [{k: v.to("cpu") for k, v in t.items()} for t in outputs] 58 | 59 | results = {} 60 | _f_boxes,_f_scores,_f_labels = [],[],[] 61 | 62 | # carry further only if there are detected boxes 63 | if len(outputs[0]["boxes"]) != 0: 64 | boxes = outputs[0]["boxes"].data.numpy() # xyxy 65 | scores = outputs[0]["scores"].data.numpy() 66 | labels = outputs[0]["labels"].cpu().numpy() 67 | 68 | # filter out boxes according to `detection_threshold` 69 | for i in range(len(boxes)): 70 | if scores[i] >= detection_threshold: 71 | _f_boxes.append(boxes[i]) 72 | _f_labels.append(labels[i]) 73 | _f_scores.append(scores[i]) 74 | 75 | boxes,labels,scores = _f_boxes,_f_labels,_f_scores 76 | #boxes = boxes[scores >= detection_threshold].astype(np.int32) 77 | draw_boxes = boxes.copy() 78 | 79 | # get all the predicited class names 80 | pred_classes = [ 81 | self.classnames[i] for i in labels 82 | ] 83 | 84 | results['unscaled_boxes'] = [[i[0]*w_ratio, i[1]*h_ratio, i[2]*w_ratio, i[3]*h_ratio] for i in boxes] # in original image size 85 | results['scaled_boxes'] = boxes # in resize image size 86 | results['scores'] = scores 87 | results['pred_classes'] = pred_classes 88 | results['labels'] = labels 89 | 90 | if not display_unscaled: 91 | # draw the bounding boxes and write the class name on top of it 92 | for j, box in enumerate(draw_boxes): 93 | class_name = pred_classes[j] 94 | color = self.colors[self.classnames.index(class_name)] 95 | cv2.rectangle( 96 | orig_image, 97 | (int(box[0]), int(box[1])), 98 | (int(box[2]), int(box[3])), 99 | color, 100 | 2, 101 | ) 102 | cv2.putText( 103 | orig_image, 104 | class_name, 105 | (int(box[0]), int(box[1] - 5)), 106 | cv2.FONT_HERSHEY_SIMPLEX, 107 | 0.7, 108 | color, 109 | 2, 110 | lineType=cv2.LINE_AA, 111 | ) 112 | 113 | if visualize: 114 | plt.figure(figsize=(10, 10)) 115 | plt.imshow(orig_image[:,:,::-1]) 116 | plt.show() 117 | 118 | else: 119 | # draw the bounding boxes and write the class name on top of it 120 | draw_boxes_scaled = results['unscaled_boxes'] 121 | scaled_orig_image = transform_info['original_image'] 122 | scaled_orig_image = scaled_orig_image.cpu().numpy() 123 | scaled_orig_image = np.transpose(scaled_orig_image, (1, 2, 0)) 124 | scaled_orig_image = np.ascontiguousarray(scaled_orig_image, dtype=np.float32) 125 | 126 | for j, box in enumerate(draw_boxes_scaled): 127 | class_name = pred_classes[j] 128 | color = self.colors[self.classnames.index(class_name)] 129 | cv2.rectangle( 130 | scaled_orig_image, 131 | (int(box[0]), int(box[1])), 132 | (int(box[2]), int(box[3])), 133 | color, 134 | 2, 135 | ) 136 | cv2.putText( 137 | scaled_orig_image, 138 | class_name, 139 | (int(box[0]), int(box[1] - 5)), 140 | cv2.FONT_HERSHEY_SIMPLEX, 141 | 0.7, 142 | color, 143 | 2, 144 | lineType=cv2.LINE_AA, 145 | ) 146 | 147 | if visualize: 148 | plt.figure(figsize=(10, 10)) 149 | plt.imshow(scaled_orig_image) # [:,:,::-1]) 150 | plt.show() 151 | 152 | return results 153 | 154 | def infer_image_path(self, image_path, detection_threshold=0.5, visualize=False): 155 | 156 | image = cv2.imread(image_path) 157 | orig_image = image.copy() 158 | 159 | # BGR to RGB 160 | image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB).astype(np.float32) 161 | # make the pixel range between 0 and 1 162 | image /= 255.0 163 | # bring color channels to front 164 | image = np.transpose(image, (2, 0, 1)).astype(np.float32) 165 | # convert to tensor 166 | image = torch.tensor(image, 
dtype=torch.float).cpu() 167 | 168 | # add batch dimension 169 | image = torch.unsqueeze(image, 0) 170 | with torch.no_grad(): 171 | self.model = self.model.to(self.device) 172 | outputs = self.model(image.to(self.device)) 173 | 174 | # load all detection to CPU for further operations 175 | outputs = [{k: v.to("cpu") for k, v in t.items()} for t in outputs] 176 | 177 | 178 | # carry further only if there are detected boxes 179 | if len(outputs[0]["boxes"]) != 0: 180 | boxes = outputs[0]["boxes"].data.numpy() 181 | scores = outputs[0]["scores"].data.numpy() 182 | 183 | # filter out boxes according to `detection_threshold` 184 | boxes = boxes[scores >= detection_threshold].astype(np.int32) 185 | draw_boxes = boxes.copy() 186 | 187 | # get all the predicited class names 188 | pred_classes = [ 189 | self.classnames[i] for i in outputs[0]["labels"].cpu().numpy() 190 | ] 191 | 192 | # draw the bounding boxes and write the class name on top of it 193 | for j, box in enumerate(draw_boxes): 194 | class_name = pred_classes[j] 195 | color = self.colors[self.classnames.index(class_name)] 196 | cv2.rectangle( 197 | orig_image, 198 | (int(box[0]), int(box[1])), 199 | (int(box[2]), int(box[3])), 200 | color, 201 | 2, 202 | ) 203 | cv2.putText( 204 | orig_image, 205 | class_name, 206 | (int(box[0]), int(box[1] - 5)), 207 | cv2.FONT_HERSHEY_SIMPLEX, 208 | 0.7, 209 | color, 210 | 2, 211 | lineType=cv2.LINE_AA, 212 | ) 213 | 214 | if visualize: 215 | plt.figure(figsize=(10, 10)) 216 | plt.imshow(orig_image[:, :, ::-1]) 217 | plt.show() 218 | 219 | return outputs 220 | 221 | 222 | def draw_bounding_boxes(self,image, bboxes, class_labels, figsize=(12,12)): 223 | class_labels = class_labels.cpu().numpy() 224 | bboxes = bboxes.cpu().numpy() 225 | for j, box in enumerate(bboxes): 226 | label = class_labels[j] 227 | color = self.colors[label] 228 | cv2.rectangle( 229 | image, 230 | (int(box[0]), int(box[1])), 231 | (int(box[2]), int(box[3])), 232 | (0,200,0), 233 | 1, 234 | ) 235 | cv2.putText( 236 | image, 237 | self.classnames[int(label)], 238 | (int(box[0] + 15), int(box[1] + 15)), 239 | cv2.FONT_HERSHEY_SIMPLEX, 240 | 0.5, 241 | (0,200,0), 242 | 2, 243 | lineType=cv2.LINE_AA, 244 | ) 245 | 246 | plt.figure(figsize=figsize) 247 | plt.imshow(image) 248 | plt.show() 249 | 250 | 251 | class SaveBestModel: 252 | """ 253 | Class to save the best model while training. If the current epoch's 254 | validation loss is less than the previous least less, then save the 255 | model state. 
256 |     """
257 | 
258 |     def __init__(
259 |         self, best_valid_loss=float("inf"), output_dir="weight_outputs",
260 |     ):
261 |         self.best_valid_loss = best_valid_loss
262 | 
263 |         os.makedirs(output_dir, exist_ok=True)
264 | 
265 |         self.output_dir = output_dir
266 | 
267 |     def __call__(self, current_valid_loss, epoch, model, optimizer):
268 |         if current_valid_loss < self.best_valid_loss:
269 |             self.best_valid_loss = current_valid_loss
270 |             print(f"\nBest validation loss: {self.best_valid_loss}")
271 |             print(f"\nSaving best model for epoch: {epoch+1}\n")
272 |             torch.save(
273 |                 {
274 |                     "epoch": epoch + 1,
275 |                     "model_state_dict": model.state_dict(),
276 |                     "optimizer_state_dict": optimizer.state_dict(),
277 |                 },
278 |                 f"{self.output_dir}/best_model.pth",
279 |             )
280 | 
281 | def display_gt_pred(image_path,
282 |                     gt_boxes,
283 |                     pred_boxes,
284 |                     gt_class,
285 |                     pred_class,
286 |                     pred_scores,
287 |                     box_format='xywh',
288 |                     figsize=(10,10),
289 |                     classnames = []):
290 | 
291 |     line_width = 1
292 |     gt_color = 'g'
293 |     pred_color = 'r'
294 |     img = cv2.imread(image_path)
295 |     fig, ax = plt.subplots(figsize=figsize)
296 |     ax.imshow(img[:,:,::-1])
297 | 
298 |     for gb,gc in zip(gt_boxes,gt_class):
299 | 
300 |         if box_format == 'xywh':  # compare against the box_format argument, not the builtin `format`
301 |             x, y, w, h = gb
302 | 
303 |         if box_format == 'xyxy':
304 |             x1, y1, x2, y2 = gb
305 |             x,y,w,h = x1,y1,x2-x1,y2-y1
306 | 
307 |         rect = patches.Rectangle(
308 |             (x, y), w, h, linewidth=line_width, edgecolor=gt_color, facecolor="none"
309 |         )
310 |         ax.add_patch(rect)
311 | 
312 |         if len(classnames)>0:
313 |             ax.text(x + 5, y + 20, classnames[int(gc)-1], bbox=dict(facecolor="yellow", alpha=0.5))
314 |         else:
315 |             ax.text(x + 5, y + 20, gc, bbox=dict(facecolor="yellow", alpha=0.5))
316 | 
317 |     for pb,pc,ps in zip(pred_boxes,pred_class,pred_scores):
318 |         if box_format == 'xywh':
319 |             x, y, w, h = pb
320 |         if box_format == 'xyxy':
321 |             x1, y1, x2, y2 = pb
322 |             x,y,w,h = x1,y1,x2-x1,y2-y1
323 |         rect = patches.Rectangle(
324 |             (x, y), w, h, linewidth=line_width+1, edgecolor=pred_color, facecolor="none"
325 |         )
326 |         ax.add_patch(rect)
327 |         ax.text(x + 5, y + 40, f'{pc},{round(ps*100,2)}', bbox=dict(facecolor="red", alpha=0.5))
328 | 
329 |     plt.axis('off')
330 |     plt.show()
--------------------------------------------------------------------------------
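The ***demo_inference.ipynb*** notebook referenced in the README is not part of this listing. As a rough sketch (not the notebook itself), inference with a trained checkpoint can be wired together from the utilities above as follows; the checkpoint path, image path, class list and score threshold are placeholders to be replaced with values from your own training run:

from utils.dataset import CocoDataset
from utils.model_utils import InferFasterRCNN
import torch

# Placeholders: point these at your own run, image and class names.
CHECKPOINT = "exp/summary/<run folder>/best_model.pth"   # written by SaveBestModel during training
IMAGE_PATH = "<path to an image>.jpg"
CLASSNAMES = ["class_1", "class_2"]                      # e.g. CocoDataset(...).get_classnames()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# num_classes = number of real classes + 1 (background), as expected by create_model()/InferFasterRCNN
detector = InferFasterRCNN(num_classes=len(CLASSNAMES) + 1, classnames=CLASSNAMES)
detector.load_model(checkpoint=CHECKPOINT, device=device)

# Resize the image to the 640x640 training resolution and keep the original size for rescaling boxes
transform_info = CocoDataset.transform_image_for_inference(IMAGE_PATH, width=640, height=640)

# Returns a dict with 'unscaled_boxes' (xyxy in original image coordinates), 'scaled_boxes',
# 'scores', 'labels' and 'pred_classes'; visualize=True also plots the detections.
result = detector.infer_image(transform_info=transform_info,
                              detection_threshold=0.5,
                              visualize=True)
if result:  # empty dict when nothing is detected
    print(result["pred_classes"], result["scores"])

COCO-style AP over a whole validation split can be computed the same way the training loop does it, via evaluate_model(image_dir, gt_ann_file, model_weight) from eval.py.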