├── images
    ├── yolo_ex.png
    ├── darknet_53_table.png
    ├── yolo_grid_image.png
    └── yolo_architecture.png
├── requirements.txt
├── README.md
├── loss.py
├── train.py
├── config.py
├── dataset.py
├── model.py
├── model_with_weights2.py
├── utils.py
└── Implementing and training YOLOv3 - Medium.ipynb


/images/yolo_ex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SannaPersson/YOLOv3-PyTorch/HEAD/images/yolo_ex.png
--------------------------------------------------------------------------------
/images/darknet_53_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SannaPersson/YOLOv3-PyTorch/HEAD/images/darknet_53_table.png
--------------------------------------------------------------------------------
/images/yolo_grid_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SannaPersson/YOLOv3-PyTorch/HEAD/images/yolo_grid_image.png
--------------------------------------------------------------------------------
/images/yolo_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SannaPersson/YOLOv3-PyTorch/HEAD/images/yolo_architecture.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy>=1.19.2
matplotlib>=3.3.4
torch>=1.7.1
tqdm>=4.56.0
torchvision>=0.8.2
albumentations>=0.5.2
pandas>=1.2.1
Pillow>=8.1.0

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# YOLOv3 in PyTorch
A minimal implementation of YOLOv3 in PyTorch, spanning only around 800 lines of code (not including plot functions etc.), with support for training and evaluation and complete with helper functions for inference. Pretrained weights are currently available for Pascal VOC, with MS COCO coming up. With minimal changes to the model's output format, the original weights can also be loaded seamlessly.

## Installation

### Clone and install requirements
```bash
$ git clone https://github.com/SannaPersson/YOLOv3-PyTorch.git
$ cd YOLOv3-PyTorch
$ pip install -r requirements.txt
```
### Download pretrained weights for Pascal VOC
Pretrained weights for Pascal VOC can be downloaded from this page: https://www.kaggle.com/sannapersson/yolov3-weights-for-pascal-voc-with-781-map

### Download original weights
Download the YOLOv3 weights from https://pjreddie.com/media/files/yolov3.weights and convert them to PyTorch format by running the model_with_weights2.py file.
Then change the import in train.py to use model_with_weights2.py instead of model.py, since the original output format is slightly different.

### Download Pascal VOC dataset
Download the preprocessed dataset from [link](https://www.kaggle.com/aladdinpersson/pascal-voc-yolo-works-with-albumentations) and unzip it in the main directory.
The dataset consists of a folder of images, a folder of corresponding text files containing the bounding boxes and class targets for each image, and two CSV files listing the subsets of the data used for training and testing.

### Training
Edit config.py to match the setup you want to use, then run train.py.

### Results
| Model                   | mAP @ 50 IoU      |
| ----------------------- |:-----------------:|
| YOLOv3 (Pascal VOC)     | 78.2              |
| YOLOv3 (MS-COCO)        | Not done yet      |

The model was evaluated with a confidence threshold of 0.2 and an IoU threshold of 0.45 for NMS.
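The evaluation settings (confidence 0.2, IoU 0.45) describe a two-step post-processing. First, predictions below the confidence threshold are discarded. Then greedy non-max suppression keeps the highest-scoring box among overlapping boxes of the same class. The plain-Python sketch below illustrates this under stated assumptions. It is not the repository's `utils.non_max_suppression`. It assumes each box is a list `[class, score, x, y, w, h]` in midpoint format, and the helper names `iou_midpoint` and `nms` are hypothetical.

```python
def iou_midpoint(a, b):
    """IoU of two boxes given as [x, y, w, h], where (x, y) is the box center."""
    ax1, ay1 = a[0] - a[2] / 2, a[1] - a[3] / 2
    ax2, ay2 = a[0] + a[2] / 2, a[1] + a[3] / 2
    bx1, by1 = b[0] - b[2] / 2, b[1] - b[3] / 2
    bx2, by2 = b[0] + b[2] / 2, b[1] + b[3] / 2
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h
    union = a[2] * a[3] + b[2] * b[3] - intersection
    return intersection / (union + 1e-6)  # epsilon guards against zero-area boxes


def nms(bboxes, iou_threshold=0.45, conf_threshold=0.2):
    """Confidence filtering followed by greedy, class-aware NMS.

    Each box is assumed to be [class, score, x, y, w, h] (hypothetical layout).
    """
    # 1. Drop low-confidence predictions, then sort by score (highest first).
    candidates = sorted(
        (box for box in bboxes if box[1] > conf_threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []
    while candidates:
        best = candidates.pop(0)
        kept.append(best)
        # 2. Suppress remaining boxes of the same class that overlap `best` too much.
        candidates = [
            box
            for box in candidates
            if box[0] != best[0] or iou_midpoint(box[2:], best[2:]) < iou_threshold
        ]
    return kept


boxes = [
    [0, 0.9, 0.50, 0.5, 0.4, 0.4],
    [0, 0.8, 0.52, 0.5, 0.4, 0.4],  # same class, IoU ~0.9 with the first box: suppressed
    [1, 0.7, 0.52, 0.5, 0.4, 0.4],  # different class at the same place: kept
    [0, 0.1, 0.10, 0.1, 0.1, 0.1],  # below the 0.2 confidence threshold: dropped
]
print(nms(boxes))
```

Suppressing only within a class is the usual choice for detection evaluation. A class-agnostic variant would simply drop the `box[0] != best[0]` check.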

## YOLOv3 paper
The implementation is based on the following paper:
### YOLOv3: An Incremental Improvement
by Joseph Redmon, Ali Farhadi

#### Abstract
We present some updates to YOLO! We made a bunch of little design changes to make it better. We also trained this new network that’s pretty swell. It’s a little bigger than last time but more accurate. It’s still fast though, don’t worry. At 320 × 320 YOLOv3 runs in 22 ms at 28.2 mAP, as accurate as SSD but three times faster. When we look at the old .5 IOU mAP detection metric YOLOv3 is quite good. It achieves 57.9 AP50 in 51 ms on a Titan X, compared to 57.5 AP50 in 198 ms by RetinaNet, similar performance but 3.8× faster. As always, all the code is online at https://pjreddie.com/yolo/.

```
@article{yolov3,
  title={YOLOv3: An Incremental Improvement},
  author={Redmon, Joseph and Farhadi, Ali},
  journal = {arXiv},
  year={2018}
}
```

--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
"""
Implementation of a YOLO loss function similar to the one in the YOLOv3 paper.
As far as I can tell, the main difference is that CrossEntropy is used for the
class predictions instead of BinaryCrossEntropy.
"""
import random
import torch
import torch.nn as nn

from utils import intersection_over_union


class YoloLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        # Constants signifying how much to pay for each respective part of the loss
        self.lambda_class = 1
        self.lambda_noobj = 10
        self.lambda_obj = 1
        self.lambda_box = 10

    def forward(self, predictions, target, anchors):
        # Check where obj and noobj (we ignore if target == -1)
        obj = target[..., 0] == 1  # in paper this is Iobj_i
        noobj = target[..., 0] == 0  # in paper this is Inoobj_i

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #

        no_object_loss = self.bce(
            (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #

        anchors = anchors.reshape(1, 3, 1, 1, 2)
        box_preds = torch.cat(
            [self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors],
            dim=-1,
        )
        ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
        object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #

        predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3])  # x,y coordinates
        target[..., 3:5] = torch.log(
            (1e-16 + target[..., 3:5] / anchors)
        )  # width, height coordinates
        box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #

        class_loss = self.entropy(
            (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
        )

        # print("__________________________________")
        # print(self.lambda_box * box_loss)
        # print(self.lambda_obj * object_loss)
        # print(self.lambda_noobj * no_object_loss)
        # print(self.lambda_class * class_loss)
        # print("\n")

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss
        )

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
Main file for training the YOLOv3 model on the Pascal VOC and COCO datasets
"""

import config
import torch
import torch.optim as optim

from model import YOLOv3
from tqdm import tqdm
from utils import (
    mean_average_precision,
    cells_to_bboxes,
    get_evaluation_bboxes,
    save_checkpoint,
    load_checkpoint,
    check_class_accuracy,
    get_loaders,
    plot_couple_examples
)
from loss import YoloLoss

torch.backends.cudnn.benchmark = True


def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    loop = tqdm(train_loader, leave=True)
    losses = []
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y0, y1, y2 = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        with torch.cuda.amp.autocast():
            out = model(x)
            loss = (
                loss_fn(out[0], y0, scaled_anchors[0])
                + loss_fn(out[1], y1, scaled_anchors[1])
                + loss_fn(out[2], y2, scaled_anchors[2])
            )

        losses.append(loss.item())
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        mean_loss = sum(losses) / len(losses)
        loop.set_postfix(loss=mean_loss)


def main():
    model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    train_loader, test_loader, train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
    )

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    for epoch in range(config.NUM_EPOCHS):
        #plot_couple_examples(model, test_loader, 0.6, 0.5, scaled_anchors)
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        if config.SAVE_MODEL:
            # Save to the same file that LOAD_MODEL reads from
            save_checkpoint(model, optimizer, filename=config.CHECKPOINT_FILE)

        #print(f"Currently epoch {epoch}")
        #print("On Train Eval loader:")
        #check_class_accuracy(model, train_eval_loader, threshold=config.CONF_THRESHOLD)
        #print("On Train loader:")
        #check_class_accuracy(model, train_loader, threshold=config.CONF_THRESHOLD)

        if epoch % 10 == 0 and epoch > 0:
            print("On Test loader:")
            check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)

            pred_boxes, true_boxes = get_evaluation_bboxes(
                test_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=config.NUM_CLASSES,
            )
            print(f"MAP: {mapval.item()}")


if __name__ == "__main__":
    main()

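The `scaled_anchors` expression in `main()` rescales the normalized anchors from config.py into grid-cell units. Each scale's three (w, h) pairs are multiplied by that scale's grid size S, which is 13, 26 and 52 for a 416 x 416 input. loss.py then encodes target widths and heights as `log(1e-16 + w / anchor)` and decodes predictions as `exp(t) * anchor`. The sketch below reproduces that arithmetic without PyTorch, assuming the anchor values from config.py. The helper names `scale_anchors`, `encode_wh` and `decode_wh` are hypothetical and not part of the repository.

```python
import math

# Normalized anchors copied from config.py (fractions of the image side).
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],  # coarsest grid (13x13), large objects
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],  # finest grid (52x52), small objects
]
IMAGE_SIZE = 416
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]  # [13, 26, 52]


def scale_anchors(anchors, grid_sizes):
    """Hypothetical helper: express each anchor in grid-cell units of its scale.

    Same arithmetic as
    torch.tensor(anchors) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2),
    which broadcasts one grid size per scale over 3 anchors x 2 coordinates.
    """
    return [
        [(w * s, h * s) for (w, h) in scale]
        for scale, s in zip(anchors, grid_sizes)
    ]


def encode_wh(wh_cell, anchor, eps=1e-16):
    """Hypothetical helper: box-target encoding as in loss.py, log(eps + w / anchor)."""
    return tuple(math.log(eps + v / a) for v, a in zip(wh_cell, anchor))


def decode_wh(t_wh, anchor):
    """Hypothetical helper: inverse mapping used for predictions, exp(t) * anchor."""
    return tuple(math.exp(t) * a for t, a in zip(t_wh, anchor))


scaled = scale_anchors(ANCHORS, S)
for s, row in zip(S, scaled):
    print(s, [(round(w, 2), round(h, 2)) for w, h in row])
```

Multiplying the first scaled anchor (3.64 x 2.86 cells) by the 32-pixel stride of the coarsest grid gives about 116 x 92 pixels. This suggests, but does not confirm, that the normalized values are the original YOLOv3 anchors divided by 416 and rounded to two decimals.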
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import albumentations as A
import cv2
import torch

from albumentations.pytorch import ToTensorV2
from utils import seed_everything

DATASET = 'PASCAL_VOC'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# seed_everything()  # If you want deterministic behavior
NUM_WORKERS = 4
BATCH_SIZE = 32
IMAGE_SIZE = 416
NUM_CLASSES = 20  # 20 for Pascal VOC (PASCAL_CLASSES); use 80 for MS COCO (COCO_LABELS)
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 100
CONF_THRESHOLD = 0.6
MAP_IOU_THRESH = 0.5
NMS_IOU_THRESH = 0.45
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
PIN_MEMORY = True
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_FILE = "checkpoint.pth.tar"
IMG_DIR = DATASET + "/images/"
LABEL_DIR = DATASET + "/labels/"

ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]  # Note these have been rescaled to be between [0, 1]


scale = 1.1
train_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
        A.PadIfNeeded(
            min_height=int(IMAGE_SIZE * scale),
            min_width=int(IMAGE_SIZE * scale),
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
        A.OneOf(
            [
                A.ShiftScaleRotate(
                    rotate_limit=10, p=0.4, border_mode=cv2.BORDER_CONSTANT
                ),
                A.IAAAffine(shear=10, p=0.4, mode="constant"),
            ],
            p=1.0,
        ),
        A.HorizontalFlip(p=0.5),
        A.Blur(p=0.1),
        A.CLAHE(p=0.1),
        A.Posterize(p=0.1),
        A.ToGray(p=0.1),
        A.ChannelShuffle(p=0.05),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
)
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
)

PASCAL_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
    "train", "tvmonitor"
]

COCO_LABELS = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
"""
Creates a Pytorch dataset to load the Pascal VOC & MS COCO datasets
"""

import config
import numpy as np
import os
import pandas as pd
import torch

from PIL import Image, ImageFile
from torch.utils.data import Dataset, DataLoader
from utils import (
    cells_to_bboxes,
    iou_width_height as iou,
    non_max_suppression as nms,
    plot_image
)

ImageFile.LOAD_TRUNCATED_IMAGES = True


class YOLODataset(Dataset):
    def __init__(
        self,
        csv_file,
        img_dir,
        label_dir,
        anchors,
        image_size=416,
        S=[13, 26, 52],
        C=20,
        transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.transform = transform
        self.S = S
        self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])  # for all 3 scales
        self.num_anchors = self.anchors.shape[0]
        self.num_anchors_per_scale = self.num_anchors // 3
        self.C = C
        self.ignore_iou_thresh = 0.5

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        # Label rows are [class, x, y, w, h]; rolling by 4 moves the class to the end
        bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = np.array(Image.open(img_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=image, bboxes=bboxes)
            image = augmentations["image"]
            bboxes = augmentations["bboxes"]

        # Below assumes 3 scale predictions (as paper) and same num of anchors per scale
        targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
        for box in bboxes:
            iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)
            x, y, width, height, class_label = box
            has_anchor = [False] * 3  # each scale should have one anchor
            for anchor_idx in anchor_indices:
                scale_idx = anchor_idx // self.num_anchors_per_scale
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                S = self.S[scale_idx]
                i, j = int(S * y), int(S * x)  # which cell
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
                if not anchor_taken and not has_anchor[scale_idx]:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1
                    x_cell, y_cell = S * x - j, S * y - i  # both between [0,1]
                    width_cell, height_cell = (
                        width * S,
                        height * S,
                    )  # can be greater than 1 since it's relative to cell
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    )
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True

                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1  # ignore prediction

        return image, tuple(targets)


def test():
    anchors = config.ANCHORS

    transform = config.test_transforms

    dataset = YOLODataset(
        "COCO/train.csv",
        "COCO/images/images/",
        "COCO/labels/labels_new/",
        S=[13, 26, 52],
        anchors=anchors,
        transform=transform,
    )
    S = [13, 26, 52]
    scaled_anchors = torch.tensor(anchors) / (
        1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    )
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
    for x, y in loader:
        boxes = []

        for i in range(len(y)):  # one target tensor per scale
            anchor = scaled_anchors[i]
            print(anchor.shape)
            print(y[i].shape)
            boxes += cells_to_bboxes(
                y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
            )[0]
        boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
        print(boxes)
        plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)


if __name__ == "__main__":
    test()

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
"""
Implementation of YOLOv3 architecture
"""

import torch
import torch.nn as nn

"""
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride)
Every conv is a same convolution.
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]


class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            ]

        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            if self.use_residual:
                x = x + layer(x)
            else:
                x = layer(x)

        return x


class ScalePrediction(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLOv3(nn.Module):
    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []  # for each scale
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2),)
                    in_channels = in_channels * 3

        return layers


if __name__ == "__main__":
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    out = model(x)  # run the forward pass once and reuse the outputs
    assert out[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
    assert out[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
    assert out[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
    print("Success!")

--------------------------------------------------------------------------------
/model_with_weights2.py:
--------------------------------------------------------------------------------
"""
Implementation of Yolo (v3) architecture

paper (it's srsly hilarious):

"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

"""
Information about architecture config:
Tuple is structured by (filters, kernel_size, stride)
Every conv is a same convolution.
List is structured by "B" indicating a residual block followed by the number of repeats
"S" is for scale prediction block and computing the yolo loss
"U" is for upsampling the feature map and concatenating with a previous layer
"""
config = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]


class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_residual=True, num_repeats=1):
        super(ResidualBlock, self).__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            ]

        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for layer in self.layers:
            if self.use_residual:
                x = x + layer(x)
            else:
                x = layer(x)

        return x


class ScalePrediction(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(ScalePrediction, self).__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(
                2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
            ),
        )
        self.num_classes = num_classes

    def forward(self, x):
        return (
            self.pred(x)
            .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
            .permute(0, 1, 3, 4, 2)
        )


class YOLOv3(nn.Module):
    def __init__(self, in_channels=3, num_classes=80):
        super(YOLOv3, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        outputs = []
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2),)
                    in_channels += in_channels * 2

        return layers

    def load_CNN_weights(self, ptr, block):

        conv_layer = block.conv
        if block.use_bn_act:
            # Load BN bias, weights, running mean and running variance
            bn_layer = block.bn
            num_b = bn_layer.bias.numel()  # Number of biases
            # Bias
            bn_b = torch.from_numpy(self.weights[ptr : ptr + num_b]).view_as(
                bn_layer.bias
            )
            bn_layer.bias.data.copy_(bn_b)
            ptr += num_b
            # Weight
            bn_w = torch.from_numpy(self.weights[ptr : ptr + num_b]).view_as(
                bn_layer.weight
            )
            bn_layer.weight.data.copy_(bn_w)
            ptr += num_b
            # Running Mean
            bn_rm = torch.from_numpy(self.weights[ptr : ptr + num_b]).view_as(
                bn_layer.running_mean
            )
            bn_layer.running_mean.data.copy_(bn_rm)
            ptr += num_b
            # Running Var
            bn_rv = torch.from_numpy(self.weights[ptr : ptr + num_b]).view_as(
                bn_layer.running_var
            )
            bn_layer.running_var.data.copy_(bn_rv)
            ptr += num_b
        else:
            # Load conv. bias
            num_b = conv_layer.bias.numel()

            conv_b = torch.from_numpy(self.weights[ptr : ptr + num_b]).view_as(
                conv_layer.bias
            )
            conv_layer.bias.data.copy_(conv_b)
            ptr += num_b
        # Load conv. weights
        num_w = conv_layer.weight.numel()
        conv_w = torch.from_numpy(self.weights[ptr : ptr + num_w]).view_as(
            conv_layer.weight
        )
        conv_layer.weight.data.copy_(conv_w)
        ptr += num_w
        return ptr

    def load_darknet_weights(self, weights_path):
        """Parses and loads the weights stored in 'weights_path'"""

        # Open the weights file
        with open(weights_path, "rb") as f:
            header = np.fromfile(
                f, dtype=np.int32, count=5
            )  # First five are header values
            self.header_info = header  # Needed to write header when saving weights
            self.seen = header[3]  # number of images seen during training
            self.weights = np.fromfile(f, dtype=np.float32)  # The rest are weights

        ptr = 0
        for idx, layer in enumerate(self.layers):
            if isinstance(layer, CNNBlock):
                ptr = self.load_CNN_weights(ptr, layer)

            elif isinstance(layer, ResidualBlock):
                for i in range(layer.num_repeats):
                    ptr = self.load_CNN_weights(ptr, layer.layers[i][0])
                    ptr = self.load_CNN_weights(ptr, layer.layers[i][1])

            elif isinstance(layer, ScalePrediction):
                # print("Starting scale prediction route")
                cnn_block = layer.pred[0]
                last_block = layer.pred[1]
                ptr = self.load_CNN_weights(ptr, cnn_block)
                ptr = self.load_CNN_weights(ptr, last_block)

                # ptr = self.load_CNN_weights(ptr, cnn_block)
                # print("Scale prediction ")

        print(ptr)


if __name__ == "__main__":

    model = YOLOv3()
    # File name as downloaded from https://pjreddie.com/media/files/yolov3.weights
    model.load_darknet_weights(weights_path="yolov3.weights")
    model.layers[15].pred[1] = CNNBlock(1024, 25 * 3, bn_act=False, kernel_size=1)
    model.layers[22].pred[1] = CNNBlock(512, 25 * 3, bn_act=False, kernel_size=1)
    model.layers[29].pred[1] = CNNBlock(256, 25 * 3, bn_act=False, kernel_size=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    from utils import save_checkpoint

    # utils.save_checkpoint builds the checkpoint dict itself and expects
    # (model, optimizer, filename), so pass the objects directly
    save_checkpoint(model, optimizer)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import config
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os
import random
import torch

from collections import Counter
from torch.utils.data import DataLoader
from tqdm import tqdm


def iou_width_height(boxes1, boxes2):
    """
    Parameters:
        boxes1 (tensor): width and height of the first bounding boxes
        boxes2 (tensor): width and height of the second bounding boxes
    Returns:
        tensor: Intersection over union of the corresponding boxes, computed
        as if the boxes shared the same center (only width and height matter)
    """
    intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
        boxes1[..., 1], boxes2[..., 1]
    )
    union = (
        boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
    )
    return intersection / union


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    This function calculates intersection over union (iou) given pred boxes
    and target boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
        box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)

    Returns:
        tensor: Intersection over union for all examples
    """

    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) handles boxes that do not intersect
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    return intersection / (box1_area + box2_area - intersection + 1e-6)


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Does Non Max Suppression given bboxes

    Parameters:
        bboxes (list): list of lists containing all bboxes with each bbox
            specified as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): a box of the same class as the chosen box is
            suppressed if its IoU with the chosen box is >= this value
        threshold (float): boxes with prob_score <= threshold are removed
            before NMS (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: bboxes after performing NMS given a specific IoU threshold
    """

    assert isinstance(bboxes, list)

    bboxes = [box for box in bboxes if box[1] > threshold]
    bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
    bboxes_after_nms = []

    while bboxes:
        chosen_box = bboxes.pop(0)

        # keep boxes of other classes, and boxes of the same class that
        # overlap the chosen box less than iou_threshold
        bboxes = [
            box
            for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Video explanation of this function:
    https://youtu.be/FppOzcDvaDI

    This function calculates mean average precision (mAP)

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bbox
            specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): same format as pred_boxes, but containing the
            ground-truth boxes
        iou_threshold (float): IoU with a ground-truth box above which a
            predicted bbox counts as correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        float: mAP value across all classes given a specific IoU threshold
    """

    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        detections = []
        ground_truths = []

        # Go through all predictions and targets,
        # and only add the ones that belong to the
        # current class c
        for detection in pred_boxes:
            if detection[1] == c:
                detections.append(detection)

        for true_box in true_boxes:
            if true_box[1] == c:
                ground_truths.append(true_box)

        # find the amount of bboxes for each training example
        # Counter here finds how many ground truth bboxes we get
        # for each training example, so let's say img 0 has 3,
        # img 1 has 5 then we will obtain a dictionary with:
        # amount_bboxes = {0:3, 1:5}
        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        # We then go through each key, val in this dictionary
        # and convert to the following (w.r.t same example):
        # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # sort by box probabilities which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))
        total_true_bboxes = len(ground_truths)

        # If none exists for this class then we can safely skip
        if total_true_bboxes == 0:
            continue

        for detection_idx, detection in enumerate(detections):
            # Only take out the ground_truths that have the same
            # training idx as detection
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold:
                # only detect ground truth detection once
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # true positive and add this bounding box to seen
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    FP[detection_idx] = 1

            # if IOU is lower then the detection is a false positive
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)


def plot_image(image, boxes):
    """Plots predicted bounding boxes on the image"""
    cmap = plt.get_cmap("tab20b")
    class_labels = config.COCO_LABELS if config.DATASET == "COCO" else config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # after dropping class and score below:
    # box[0] is x midpoint, box[2] is width
    # box[1] is y midpoint, box[3] is height

    # Create a Rectangle patch
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_pred = box[0]
        box = box[2:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=2,
            edgecolor=colors[int(class_pred)],
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)
        plt.text(
            upper_left_x * width,
            upper_left_y * height,
            s=class_labels[int(class_pred)],
            color="white",
            verticalalignment="top",
            bbox={"color": colors[int(class_pred)], "pad": 0},
        )

    plt.show()


def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
    device="cuda",
):
    # make sure model is in eval before get bboxes
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []
    for batch_idx, (x, labels) in enumerate(tqdm(loader)):
        x = x.to(device)

        with torch.no_grad():
            predictions = model(x)

        batch_size = x.shape[0]
        bboxes = [[] for _ in range(batch_size)]
        for i in range(3):
            S = predictions[i].shape[2]
            anchor = torch.tensor([*anchors[i]]).to(device) * S
            boxes_scale_i = cells_to_bboxes(
                predictions[i], anchor, S=S, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

        # we just want one bbox for each label, not one for each scale
        true_bboxes = cells_to_bboxes(
            labels[2], anchor, S=S, is_preds=False
        )

        for idx in range(batch_size):
            nms_boxes = non_max_suppression(
                bboxes[idx],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )

            for nms_box in nms_boxes:
                all_pred_boxes.append([train_idx] + nms_box)

            for box in true_bboxes[idx]:
                if box[1] > threshold:
                    all_true_boxes.append([train_idx] + box)

            train_idx += 1

    model.train()
    return all_pred_boxes, all_true_boxes


def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Scales the predictions coming from the model to
    be relative to the entire image such that they can, for example, later
    be plotted or used for evaluation.
    INPUT:
    predictions: tensor of size (N, 3, S, S, num_classes+5), or (N, 3, S, S, 6)
                 for targets
    anchors: the anchors used for the predictions, in grid-cell units
    S: the number of cells the image is divided in on the width (and height)
    is_preds: whether the input is predictions or the true bounding boxes
    OUTPUT:
    converted_bboxes: list of size (N, num_anchors * S * S, 6) where each box is
                      [class index, object score, x_mid, y_mid, width, height],
                      with coordinates relative to the image
    """
    BATCH_SIZE = predictions.shape[0]
    num_anchors = len(anchors)
    # clone so the in-place sigmoid/exp below do not modify the caller's tensor
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # cell_indices[n, a, i, j] == j, i.e. the column index of each cell
    cell_indices = (
        torch.arange(S)
        .repeat(predictions.shape[0], num_anchors, S, 1)
        .unsqueeze(-1)
        .to(predictions.device)
    )
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
        BATCH_SIZE, num_anchors * S * S, 6
    )
    return converted_bboxes.tolist()


def check_class_accuracy(model, loader, threshold):
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0

    for idx, (x, y) in enumerate(tqdm(loader)):
        # only the first 100 batches are checked to keep this fast
        if idx == 100:
            break
        x = x.to(config.DEVICE)
        with torch.no_grad():
            out = model(x)

        for i in range(3):
            y[i] = y[i].to(config.DEVICE)
            obj = y[i][..., 0] == 1  # in paper this is Iobj_i
            noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i

            correct_class += torch.sum(
                torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
            )
            tot_class_preds += torch.sum(obj)

            obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
            correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
            tot_obj += torch.sum(obj)
            correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
            tot_noobj += torch.sum(noobj)

    print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:.2f}%")
    print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:.2f}%")
    print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:.2f}%")
    model.train()


def get_mean_std(loader):
    # var[X] = E[X**2] - E[X]**2
    # note: averaging per-batch means is exact only if all batches have the same size
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(loader):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5

    return mean, std


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def get_loaders(train_csv_path, test_csv_path):
    from dataset import YOLODataset

    IMAGE_SIZE = config.IMAGE_SIZE
    train_dataset = YOLODataset(
        train_csv_path,
        transform=config.train_transforms,
        S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
    )
    test_dataset = YOLODataset(
        test_csv_path,
        transform=config.test_transforms,
        S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
    )
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=True,
        drop_last=False,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=False,
        drop_last=False,
    )

    # training images with test-time transforms, for evaluating on the train set
    train_eval_dataset = YOLODataset(
        train_csv_path,
        transform=config.test_transforms,
        S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
    )
    train_eval_loader = DataLoader(
        dataset=train_eval_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=False,
        drop_last=False,
    )

    return train_loader, test_loader, train_eval_loader


def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    # anchors must already be scaled to each grid size (one tensor per scale)
    model.eval()
    x, y = next(iter(loader))
    x = x.to(config.DEVICE)
    with torch.no_grad():
        out = model(x)
        bboxes = [[] for _ in range(x.shape[0])]
        for i in range(3):
            batch_size, A, S, _, _ = out[i].shape
            anchor = anchors[i]
            boxes_scale_i = cells_to_bboxes(
                out[i], anchor, S=S, is_preds=True
            )
            for idx, box in enumerate(boxes_scale_i):
                bboxes[idx] += box

        model.train()

    for i in range(batch_size):
        nms_boxes = non_max_suppression(
            bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
        )
        plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)


def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

--------------------------------------------------------------------------------
/Implementing and training YOLOv3 - Medium.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "explicit-netscape",
   "metadata": {},
   "source": [
    "# Implementing and training YOLOv3 from scratch in PyTorch\n",
    "\n",
    "\n",
    "For such a popular paper, there are still few implementations of the YOLOv3 architecture that are explained completely from scratch. I'll do my best to add something useful to the list. The code is written together with [Aladdin Persson](https://www.youtube.com/channel/UCkzW5JSFwvKRjXABI-UTAkQ) and can be found on [github](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/object_detection/YOLOv3). You can also download pretrained weights for Pascal-VOC that obtain 78.1 mAP [here](https://www.kaggle.com/sannapersson/yolov3-weights-for-pascal-voc-with-781-map) for the implementation below.\n",
    "\n",
    "### Prerequisites:\n",
    "* Understanding the major parts of YOLOv1\n",
    "* Coding in PyTorch\n",
    "* Familiarity with convolutional networks and their training\n",
    "\n",
    "With this article I hope to convey:\n",
    "* Understanding of the key ideas necessary for implementing and training YOLOv3 from scratch in PyTorch\n",
    "* Complete code to use for training of YOLOv3\n",
    "* The relevant details of the algorithm to succeed if you choose to make your own implementation of YOLOv3\n",
    "\n",
    "The code is completely runnable if you download the utils.py and config.py files from the GitHub repository above; they contain a few supporting functions and constants that are not specific to the YOLOv3 model.\n",
    "\n",
    "_Disclaimer: there are minor differences between this implementation and the original and I will point them out when we get to them._\n",
    "\n",
    "## Understanding the model\n",
    "Let's begin by understanding the fundamentals of the model. The YOLO (You Only Look Once) algorithm is based on the idea that we divide the image into a grid with side $S$. The grid size depends on which YOLO version we are implementing as well as on the input image size, but the details will become clearer when we implement it. Each grid cell is responsible for predicting bounding boxes. You may then wonder what happens if an object covers several grid cells: will all of them predict a bounding box for the object? YOLO solves this by making only the cell containing the object's midpoint responsible for predicting its bounding box. This means that only one grid cell is responsible for each object's bounding box in the image. One drawback of this is that there can only be one bounding box in each grid cell. From YOLOv2 onward, this issue is mitigated by making several bounding box predictions in the same grid cell. YOLOv2 also introduces anchor boxes, an idea seen in earlier object detection papers such as Faster R-CNN.\n",
    "![alt text](images/yolo_ex.png \"Title\")\n",
    "\n",
    "An anchor box is essentially a width and height pair chosen to represent a segment of the training data. For example, a tall rectangle may suit a standing person while a wide rectangle is a better fit for a car. Using anchor boxes is a way of encoding knowledge about the training data into the model to help it make appropriate predictions. It has been debated whether this is actually desirable, and there are more recent end-to-end approaches that do not use anchor boxes. The question is then how to choose the anchors. An early approach was to hand-design the anchor boxes by studying the training data; however, the authors of YOLOv2 found that generating them with K-means clustering yielded better results. The anchors allow the model to anchor its prediction to a predetermined box, so the model predicts how much the true bounding box is offset in comparison with the anchor. This is one of the major differences from the original YOLO model. Each grid cell has several anchor boxes, and each anchor box can make one bounding box prediction. Each bounding box prediction is also coupled with an object score as well as class predictions. The object score should reflect the product of the probability that there is an object in the bounding box and the intersection over union between the predicted bounding box and the actual object. This means that if there is no object in the grid cell for that specific anchor, the target is zero; otherwise, it is the intersection over union between the predicted box and the target bounding box.\n",
    "\n",
    "The predictions from the model $t_i$ are offsets to the anchors and are converted to bounding boxes according to the following equations:\n",
    "$$\n",
    "\\begin{array}{l}\n",
    "b_{x}=\\sigma\\left(t_{x}\\right)+c_{x} \\\\\n",
    "b_{y}=\\sigma\\left(t_{y}\\right)+c_{y} \\\\\n",
    "b_{w}=p_{w} e^{t_{w}} \\\\\n",
    "b_{h}=p_{h} e^{t_{h}}.\n",
    "\\end{array}\n",
    "$$\n",
    "where $c_x$ and $c_y$ are the offsets of the grid cell from the top-left corner of the image (in grid cells), $p_w$ and $p_h$ are the width and height of the corresponding anchor box, and $\\{b_x, b_y, b_w, b_h\\}$ is the resulting bounding box.\n",
    "\n",
    "In YOLOv3 the backbone network is Darknet-53, and its structure can be seen in the following table. This network was pretrained on ImageNet and is used as a feature extractor in the YOLOv3 model. The paper, however, completely skips detailing the 53 convolutional layers that follow in the YOLOv3 model, where the actual prediction of bounding boxes takes place.\n",
    "![alt text](images/darknet_53_table.png \"Title\")\n",
    "\n",
    "Bounding boxes are predicted at three different places in the network, on three different scales. In this context a scale means the grid size, $S$, that we divide the image into. In YOLOv3 we predict bounding boxes on three different grid sizes. The intuition behind this is that larger objects can more easily be detected on a coarser grid, and smaller objects on a finer grid. We therefore also divide the anchor boxes we have found such that the smallest anchors are assigned to the last and finest scale and the largest anchors to the coarsest grid. In YOLOv3 the grid sizes used are [13, 26, 52] for an image size of 416x416. If you use another image size, the first grid size is the image size divided by 32, and each following grid size is twice the previous one. The details of the model will become clear when we implement it, but the following image by [Ayoosh Kathuria](https://medium.com/@ayoosh) (check out his Medium) gives great insight into the model architecture.\n",
    "\n",
    "![alt text](images/yolo_architecture.png \"Title\")\n",
    "\n",
    "The backbone network is a standard convolutional network quite similar to previous Darknet versions, with the addition of residual connections. It is really after layer 53 that the interesting parts happen. As the architecture image shows, there are three downward paths corresponding to predictions on three different grid scales. After each prediction path, the network continues forward from the point where the path branched off. After the first and second scale prediction paths, an upsampling layer doubles the spatial size of the feature map, which is then concatenated along the channel dimension with a route from a previous layer. The image details which convolutional layers the routes come from, but we will instead use a trick to find them in our implementation.\n",
    "\n",
    "We are now ready to start actually coding the model. All model details can be found in the configuration file for YOLOv3 on [Joseph Redmon's Github](https://github.com/pjreddie); he is the author of the paper.\n",
    "\n",
    "## Coding the model\n",
    "This is the part of the YOLOv3 implementation that I spent both the least and the most time debugging. I found it manageable to make the model work, but it took some time to get the details right so that the original weights could be loaded.\n",
    "Everything in this section will be in a model.py file on [Github](https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/object_detection/YOLOv3/model.py). Let's start with the imports:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "theoretical-asset",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "technological-trail",
   "metadata": {},
   "source": [
    "First we define the architecture building blocks in a list. This is a way of parsing the original config file that greatly increases the readability and overview of the complete model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "systematic-advisory",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Information about architecture config:\n",
    "Tuple signifies a convolutional block and is structured as (filters, kernel_size, stride)\n",
    "Every convolutional layer is a same convolution.\n",
    "List signifies a residual block and is structured as [\"B\", number of repeats]\n",
    "\"S\" is for a scale prediction block and computing the yolo loss\n",
    "\"U\" is for upsampling the feature map\n",
    "\"\"\"\n",
    "config = [\n",
    "    (32, 3, 1),\n",
    "    (64, 3, 2),\n",
    "    [\"B\", 1],\n",
    "    (128, 3, 2),\n",
    "    [\"B\", 2],\n",
    "    (256, 3, 2),\n",
    "    [\"B\", 8],\n",
    "    # first route from the end of the previous block\n",
    "    (512, 3, 2),\n",
    "    [\"B\", 8],\n",
    "    # second route from the end of the previous block\n",
    "    (1024, 3, 2),\n",
    "    [\"B\", 4],\n",
    "    # until here is Darknet-53\n",
    "    (512, 1, 1),\n",
    "    (1024, 3, 1),\n",
    "    \"S\",\n",
    "    (256, 1, 1),\n",
    "    \"U\",\n",
    "    (256, 1, 1),\n",
    "    (512, 3, 1),\n",
    "    \"S\",\n",
    "    (128, 1, 1),\n",
    "    \"U\",\n",
    "    (128, 1, 1),\n",
    "    (256, 3, 1),\n",
    "    \"S\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "comparative-guyana",
   "metadata": {},
   "source": [
    "#### Defining the building blocks\n",
    "We will now define the most common building blocks of the architecture as separate classes to avoid repeating code over and over again. Each tuple in the config signifies a convolutional block with batch normalization and a leaky ReLU added to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "surface-diana",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNNBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):\n",
    "        super(CNNBlock, self).__init__()\n",
    "        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)\n",
    "        self.bn = nn.BatchNorm2d(out_channels)\n",
    "        self.leaky = nn.LeakyReLU(0.1)\n",
    "        self.use_bn_act = bn_act\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.use_bn_act:\n",
    "            return self.leaky(self.bn(self.conv(x)))\n",
    "        else:\n",
    "            return self.conv(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "stopped-bridges",
   "metadata": {},
   "source": [
    "This block also allows us to set `bn_act` to `False` to skip the batch normalization and the activation function, which we will use in the last layer before the output. When batch normalization is used, the bias term of the convolutional layer has no effect (batch normalization cancels it when subtracting the mean) and would only occupy VRAM, which is why the convolution is created with `bias=not bn_act`.\n",
    "\n",
    "We then define the residual block, which is essentially a combination of two convolutional blocks with a residual connection. The number of channels is halved in the first convolutional layer and then doubled again in the second, so the input shape is maintained through the residual block. As in the `CNNBlock`, we have an argument that allows us to skip the residual connection, which we will use in parts of the architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lightweight-possession",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, channels, use_residual=True, num_repeats=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.layers = nn.ModuleList()\n",
    "        for repeat in range(num_repeats):\n",
    "            self.layers += [\n",
    "                nn.Sequential(\n",
    "                    CNNBlock(channels, channels // 2, kernel_size=1),\n",
    "                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),\n",
    "                )\n",
    "            ]\n",
    "\n",
    "        self.use_residual = use_residual\n",
    "        self.num_repeats = num_repeats\n",
    "\n",
    "    def forward(self, x):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x) + self.use_residual * x\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "departmental-tribune",
   "metadata": {},
   "source": [
    "The last predefined block we will use is `ScalePrediction`, which consists of the last two convolutional layers leading up to the prediction for each scale. Here the architecture image above is actually slightly incorrect: this block includes the downward path except for the loss function. We reshape the output so that it has the shape (batch size, anchors per scale, grid size, grid size, 5 + number of classes), where 5 refers to the object score and the four bounding box coordinates. To obtain this shape we have to permute the output such that the object score, box coordinates and class predictions end up in the last dimension."
204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "pregnant-injury", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "class ScalePrediction(nn.Module):\n", 214 | " def __init__(self, in_channels, num_classes, anchors_per_scale=3): # YOLOv3 uses 3 anchors per scale\n", 215 | " super(ScalePrediction, self).__init__()\n", 216 | " self.pred = nn.Sequential(\n", 217 | " CNNBlock(in_channels, 2*in_channels, kernel_size=3, padding=1),\n", 218 | " CNNBlock(2*in_channels, (num_classes + 5) * anchors_per_scale, bn_act=False, kernel_size=1),\n", 219 | " )\n", 220 | " self.num_classes = num_classes\n", 221 | " self.anchors_per_scale = anchors_per_scale\n", 222 | "\n", 223 | " def forward(self, x):\n", 224 | " return (\n", 225 | " self.pred(x)\n", 226 | " .reshape(x.shape[0], self.anchors_per_scale, self.num_classes + 5, x.shape[2], x.shape[3])\n", 227 | " .permute(0, 1, 3, 4, 2)\n", 228 | " )" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "id": "intelligent-township", 234 | "metadata": {}, 235 | "source": [ 236 | "### Putting it together in YOLOv3\n", 237 | "We will now put it all together into the YOLOv3 model for the detection task. Most of the action takes place in the `_create_conv_layers` function, where we build the model using the blocks defined above. Essentially, we just loop through the config list that we created above and add the blocks in the correct order. 
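
The channel and grid-size bookkeeping in `_create_conv_layers` and `forward` can be traced without building any layers. The `CONFIG` below is an assumption: it follows the Darknet-53 layout and uses the entry types the code handles, which are `(out_channels, kernel_size, stride)` tuples, `['B', num_repeats]` lists, `'S'` and `'U'`. Substitute the config list defined earlier in the notebook if it differs. The function `trace_shapes` is hypothetical. It mimics how `x` moves through the layers, including the route connections saved after the residual blocks with 8 repeats.

```python
# Assumed config in the format handled by _create_conv_layers:
# (out_channels, kernel_size, stride), ['B', num_repeats], 'S' (scale prediction), 'U' (upsample)
CONFIG = [
    (32, 3, 1), (64, 3, 2), ['B', 1], (128, 3, 2), ['B', 2],
    (256, 3, 2), ['B', 8], (512, 3, 2), ['B', 8], (1024, 3, 2), ['B', 4],
    (512, 1, 1), (1024, 3, 1), 'S',
    (256, 1, 1), 'U', (256, 1, 1), (512, 3, 1), 'S',
    (128, 1, 1), 'U', (128, 1, 1), (256, 3, 1), 'S',
]


def trace_shapes(config, img_size=416, in_channels=3):
    # Follows channels and grid size the way YOLOv3.forward moves x through the layers
    channels, size = in_channels, img_size
    routes = []        # (channels, size) after residual blocks with 8 repeats
    predictions = []   # (grid size, input channels of ScalePrediction)
    for module in config:
        if isinstance(module, tuple):
            out_channels, kernel_size, stride = module
            padding = 1 if kernel_size == 3 else 0
            size = (size + 2 * padding - kernel_size) // stride + 1
            channels = out_channels
        elif isinstance(module, list):
            # A residual block keeps channels and size; 8 repeats marks a route
            if module[1] == 8:
                routes.append((channels, size))
        elif module == 'S':
            # ResidualBlock without skip, then a 1x1 CNNBlock halving the channels;
            # the ScalePrediction output is stored and x continues from here
            channels //= 2
            predictions.append((size, channels))
        elif module == 'U':
            size *= 2
            route_channels, route_size = routes.pop()
            assert route_size == size, 'route and upsampled map must match'
            channels += route_channels  # concatenation along dim=1
    return predictions


for grid, ch in trace_shapes(CONFIG):
    print(grid, ch)
```

With 416x416 inputs this prints grids of 13, 26 and 52, with 512, 256 and 128 channels entering each `ScalePrediction`. After each `'U'`, the channel count is exactly three times the count before upsampling, which matches the `in_channels * 3` bookkeeping in `_create_conv_layers`.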
" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "homeless-retro", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "class YOLOv3(nn.Module):\n", 248 | " def __init__(self, in_channels=3, num_classes=80):\n", 249 | " super(YOLOv3, self).__init__()\n", 250 | " self.num_classes = num_classes\n", 251 | " self.in_channels = in_channels\n", 252 | " self.layers = self._create_conv_layers()\n", 253 | "\n", 254 | " def forward(self, x):\n", 255 | " outputs = []\n", 256 | " route_connections = []\n", 257 | " for layer in self.layers:\n", 258 | " if isinstance(layer, ScalePrediction):\n", 259 | " outputs.append(layer(x))\n", 260 | " continue\n", 261 | "\n", 262 | " x = layer(x)\n", 263 | "\n", 264 | " if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:\n", 265 | " route_connections.append(x)\n", 266 | "\n", 267 | " elif isinstance(layer, nn.Upsample):\n", 268 | " x = torch.cat([x, route_connections[-1]], dim=1)\n", 269 | " route_connections.pop()\n", 270 | "\n", 271 | " return outputs\n", 272 | "\n", 273 | "\n", 274 | " def _create_conv_layers(self):\n", 275 | " layers = nn.ModuleList()\n", 276 | " in_channels = self.in_channels\n", 277 | "\n", 278 | " for module in config:\n", 279 | " if isinstance(module, tuple):\n", 280 | " out_channels, kernel_size, stride = module\n", 281 | " layers.append(\n", 282 | " CNNBlock(\n", 283 | " in_channels,\n", 284 | " out_channels,\n", 285 | " kernel_size=kernel_size,\n", 286 | " stride=stride,\n", 287 | " padding=1 if kernel_size == 3 else 0,\n", 288 | " )\n", 289 | " )\n", 290 | " in_channels = out_channels\n", 291 | "\n", 292 | " elif isinstance(module, list):\n", 293 | " num_repeats = module[1]\n", 294 | " layers.append(\n", 295 | " ResidualBlock(\n", 296 | " in_channels,\n", 297 | " num_repeats=num_repeats,\n", 298 | " )\n", 299 | " )\n", 300 | "\n", 301 | " elif isinstance(module, str):\n", 302 | " if module == \"S\":\n", 303 | " layers += [\n", 304 | " 
ResidualBlock(in_channels, use_residual=False, num_repeats=1),\n", 305 | " CNNBlock(in_channels, in_channels // 2, kernel_size=1),\n", 306 | " ScalePrediction(in_channels // 2, num_classes=self.num_classes),\n", 307 | " ]\n", 308 | " in_channels = in_channels // 2\n", 309 | "\n", 310 | " elif module == \"U\":\n", 311 | " layers.append(\n", 312 | " nn.Upsample(scale_factor=2),\n", 313 | " )\n", 314 | " in_channels = in_channels * 3\n", 315 | "\n", 316 | " return layers" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "id": "military-greene", 322 | "metadata": {}, 323 | "source": [ 324 | "The trickiest part here is the case where there is an `\"S\"` in the config list, which means that we are at the last layers before a prediction on a specific scale. In these cases we have three convolutional layers (a residual block with two convolutions and one convolutional block) following the same pattern on all prediction scales. To avoid creating a mess in the config list, it is easiest to just add them here before the `ScalePrediction`. \n", 325 | "\n", 326 | "It should also be noted that we triple the `in_channels` after we add the upsampling layer. This is because the route that we concatenate in the forward propagation has twice as many channels as the output from the upsampling layer. \n", 327 | "\n", 328 | "This leads us into the structure of the forward function. In the first if-statement we check if the layer is a `ScalePrediction` block, and in that case we append its output to a list and later compute the loss for each of the predictions separately. The output of the `ScalePrediction` is not passed on; the forward pass continues from the tensor that was fed into it.\n", 329 | "\n", 330 | "I earlier mentioned that we will use a trick to find the layers that are routed forward. The second if-statement will take care of this and find the route layers specified in the image of the architecture above without us keeping track of unnecessarily complicated indices. 
The two routes will be the outputs from the residual blocks in the config list that have 8 repeats, which we found by carefully reading the original model configuration. When we encounter an upsampling layer, we concatenate its output with the last route found so far, following the image of the architecture above, and remove that route from the list. \n", 331 | "\n", 332 | "Before we move on to the data loading, I'll add a test function below that acts as a sanity check that the model at least outputs the correct shapes. " 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "id": "regular-label", 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "def test():\n", 343 | " num_classes = 20\n", 344 | " model = YOLOv3(num_classes=num_classes)\n", 345 | " img_size = 416\n", 346 | " x = torch.randn((2, 3, img_size, img_size))\n", 347 | " out = model(x)\n", 348 | " assert out[0].shape == (2, 3, img_size//32, img_size//32, 5 + num_classes)\n", 349 | " assert out[1].shape == (2, 3, img_size//16, img_size//16, 5 + num_classes)\n", 350 | " assert out[2].shape == (2, 3, img_size//8, img_size//8, 5 + num_classes)\n", 351 | "\n", 352 | "test()" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "infinite-helena", 358 | "metadata": {}, 359 | "source": [ 360 | "## Loading the data\n", 361 | "In the dataset class we will load an image and the corresponding bounding boxes, perform augmentation using the [Albumentations library](https://albumentations.ai/) and then create the matrix form of the target that will be used to compute the loss. If you are not familiar with Albumentations, it is an augmentation library with official support for PyTorch that can be used for detection, segmentation and other tasks that require the augmentations to be applied both to the image and to the target. 
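
A horizontal flip, the simplest case, shows why augmentations must be applied to the image and the target together. In the midpoint format used here, where coordinates are relative to the image, flipping mirrors the x coordinate of each box. The y coordinate, width and height stay unchanged. The helper below is hypothetical and works on a nested-list image. It only illustrates what a detection-aware library such as Albumentations takes care of for you.

```python
def hflip_image_and_boxes(image, bboxes):
    # image: list of rows (H x W); bboxes: [x, y, width, height, class_label]
    # with (x, y) the box midpoint relative to the image size
    flipped_image = [list(reversed(row)) for row in image]
    flipped_boxes = [[1.0 - x, y, w, h, label] for x, y, w, h, label in bboxes]
    return flipped_image, flipped_boxes


image = [[1, 2, 3, 4],
         [5, 6, 7, 8]]
boxes = [[0.25, 0.5, 0.5, 1.0, 14]]
print(hflip_image_and_boxes(image, boxes))
# ([[4, 3, 2, 1], [8, 7, 6, 5]], [[0.75, 0.5, 0.5, 1.0, 14]])
```

Flipping only the image would silently leave every box on the wrong side. Plotting augmented examples with their boxes quickly catches this kind of bug.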
\n", 362 | "\n", 363 | "We mentioned earlier that each scale will have anchor boxes associated with it, and in the data loading we will compute which cell and which anchor should be responsible for each target bounding box. Everything in this section will be in a dataset.py file. \n", 364 | "\n", 365 | "\n", 366 | "### Imports\n", 367 | "Most of the imports we are using are standard for a dataset class in PyTorch, with the additional Albumentations package for the data augmentation. The imports from utils, however, require some additional explanation. In a utils.py file that you can find on [Github](https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/object_detection/YOLOv3/utils.py) we store some functions for handling bounding box conversions, [non-max suppression](https://youtu.be/YDkjWEN8jNA) and [mean average precision](https://youtu.be/FppOzcDvaDI). The only function that we will use in the data loading is the [intersection over union](https://www.youtube.com/watch?v=XXYG5ZWtjj0) function, which takes as input two tensors with the widths and heights of bounding boxes and outputs the corresponding intersection over union. The other functions we import from utils are only for checking that the data loading actually works. Plotting images and bounding boxes each time you modify the dataset class or augmentations can save you a lot of debugging time. 
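
The width-and-height intersection over union compares only the shapes of two boxes. Both boxes are treated as if they shared the same midpoint, so the intersection is the smaller width times the smaller height. Below is a minimal scalar sketch of that formula. The pure-Python `iou_wh` is a hypothetical stand-in; the notebook's `iou_width_height` in utils.py works on tensors.

```python
def iou_wh(box_wh, anchor_wh):
    # Both boxes are assumed to share the same midpoint, so only shapes matter
    intersection = min(box_wh[0], anchor_wh[0]) * min(box_wh[1], anchor_wh[1])
    union = box_wh[0] * box_wh[1] + anchor_wh[0] * anchor_wh[1] - intersection
    return intersection / union


print(iou_wh((2.0, 2.0), (1.0, 1.0)))  # 0.25: small box fully inside the large one
print(iou_wh((1.0, 2.0), (2.0, 1.0)))  # 0.333...: same area, different aspect ratio
```

This split keeps anchor matching independent of where the object lies in the image. The grid cell handles position, and the anchor handles shape.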
" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "loose-detection", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "import config\n", 378 | "import numpy as np\n", 379 | "import os\n", 380 | "import pandas as pd\n", 381 | "import torch\n", 382 | "\n", 383 | "from PIL import Image, ImageFile\n", 384 | "from torch.utils.data import Dataset, DataLoader\n", 385 | "from utils import (\n", 386 | " cells_to_bboxes, # only for testing\n", 387 | " iou_width_height as iou,\n", 388 | " non_max_suppression as nms, # only for testing\n", 389 | " plot_image #only for testing\n", 390 | ")\n", 391 | "\n", 392 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "roman-graphic", 398 | "metadata": {}, 399 | "source": [ 400 | "### Data format\n", 401 | "\n", 402 | "The part of the data loading that is different from image classification is the way we process the bounding boxes and format them such that they can be inputted to the model. The data loading below assumes that the data is formatted such that you have a folder with all images, a folder with a text file for each image detailing the bounding boxes and one or several csv files for the train, development and test set. The text file for an image should be formatted such that each row corresponds to a bounding box of the image with class label, x coordinate y coordinate, width, height in that specific order. The bounding box coordinates should be relative to the image such that if an object has midpoint in the middle of the image and covers it in half in both width and height we would specify: class label 0.5 0.5 0.5 0.5, on a row in the text file. In the csv file you want to specify the image file name and the text file name in two different columns. 
\n", 403 | "\n", 404 | "If you just want to get started without having to format the data, you can download the Pascal-VOC dataset from Kaggle [here](https://www.kaggle.com/aladdinpersson/pascal-voc-yolo-works-with-albumentations), where the data is already formatted. \n", 405 | "\n", 406 | "Even if your dataset is not formatted this way, it should be manageable to modify the data loading such that you can still build the training labels the same way. \n", 407 | "\n", 408 | "### Dataset class overview\n", 409 | "In a PyTorch dataset there are three building blocks: the init-method, the dataset length and the \\_\\_getitem\\_\\_-method. \n", 410 | "\n", 411 | "\n", 412 | "The important part of the dataset class is how we handle the anchor boxes. We will specify the anchor boxes in the following manner:\n", 413 | "```python\n", 414 | "ANCHORS = [\n", 415 | " [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],\n", 416 | " [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],\n", 417 | " [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],\n", 418 | "] \n", 419 | "```\n", 420 | "where each tuple corresponds to the width and the height of an anchor box relative to the image size, and each list grouping three tuples together corresponds to the anchors used on a specific prediction scale. The first list contains the largest anchor boxes, which will be used for prediction on the coarsest grid, where it is presumably easier to predict larger bounding boxes. Following the same reasoning, the lists with medium and small anchor boxes will be used for the medium and the finest grid. The anchors above are the ones used in the original paper but have been scaled to be relative to the image. 
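
The nine anchors in the YOLOv3 paper are given in pixels. Dividing them by 416 and rounding to two decimals reproduces the `ANCHORS` list above; 416 as the reference size is an assumption chosen because it matches these numbers. The sketch also converts the relative anchors to grid-cell units by multiplying with the grid size of their scale. That is what the `scaled_anchors` in the data loading test further down compute. The helper names are hypothetical.

```python
PAPER_ANCHORS_PX = [
    [(116, 90), (156, 198), (373, 326)],  # coarsest grid
    [(30, 61), (62, 45), (59, 119)],      # medium grid
    [(10, 13), (16, 30), (33, 23)],       # finest grid
]


def to_relative(anchors_px, img_size=416, ndigits=2):
    # Width and height relative to the image, rounded like the ANCHORS list
    return [[(round(w / img_size, ndigits), round(h / img_size, ndigits)) for w, h in scale]
            for scale in anchors_px]


def to_cell_units(relative_anchors, grid_sizes=(13, 26, 52)):
    # Multiply by the grid size S of each scale: anchor size measured in cells
    return [[(w * s, h * s) for w, h in scale]
            for scale, s in zip(relative_anchors, grid_sizes)]


ANCHORS = to_relative(PAPER_ANCHORS_PX)
print(ANCHORS[0])                    # [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)]
print(to_cell_units(ANCHORS)[0][0])  # roughly (3.64, 2.86) cells on the 13x13 grid
```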
\n", 421 | "\n", 422 | "Even if you are training on another dataset, these anchors will probably work quite well. However, if your dataset is very different from MS COCO, you may want to generate your own anchor boxes, and then it is probably wise to assign them to the different scales by size, as was done in the paper. In this case you would collect the widths and heights of the bounding boxes in your dataset and run them through K-means clustering with one minus the intersection over union as the distance measure. The resulting centroids would be your anchor boxes. \n", 423 | "\n", 424 | "Below is the complete dataset class. We will load an image and its bounding boxes and perform augmentations on both. We then assign each bounding box to the grid cell that contains its midpoint and decide which anchor is responsible for it by finding the anchor with which the bounding box has the highest intersection over union. Exactly how we build the targets is explained in more depth below the code. 
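
The anchor clustering mentioned above can be sketched in a few lines of plain Python. This is a minimal K-means with one minus the width-height intersection over union as the distance. The simplifications are our own: a deterministic initialization from boxes spread over the size-sorted list, centroids updated as the mean width and height, and a fixed iteration cap. It illustrates the idea and is not necessarily how the paper's anchors were obtained.

```python
def iou_wh(a, b):
    # Width-height IoU: both boxes assumed to share the same midpoint
    intersection = min(a[0], b[0]) * min(a[1], b[1])
    return intersection / (a[0] * a[1] + b[0] * b[1] - intersection)


def kmeans_anchors(boxes, k, max_iterations=100):
    # boxes: list of (width, height) pairs relative to the image size
    by_area = sorted(boxes, key=lambda b: b[0] * b[1])
    step = len(by_area) / k
    centroids = [by_area[int(c * step)] for c in range(k)]  # deterministic init
    for _ in range(max_iterations):
        clusters = [[] for _ in range(k)]
        for box in boxes:
            # smallest distance 1 - IoU is the same as the largest IoU
            nearest = max(range(k), key=lambda c: iou_wh(box, centroids[c]))
            clusters[nearest].append(box)
        new_centroids = [
            (sum(w for w, _ in members) / len(members), sum(h for _, h in members) / len(members))
            if members else centroids[c]  # keep an empty cluster's centroid
            for c, members in enumerate(clusters)
        ]
        if new_centroids == centroids:
            break
        centroids = new_centroids
    # largest anchors first, matching the order of the ANCHORS list
    return sorted(centroids, key=lambda b: b[0] * b[1], reverse=True)


boxes = [(0.04, 0.05), (0.05, 0.04), (0.06, 0.05), (0.5, 0.5), (0.45, 0.55), (0.55, 0.45)]
print(kmeans_anchors(boxes, k=2))  # one large and one small anchor
```

With `k=9` on a real dataset, the sorted centroids would be split into three groups of three, one group per prediction scale.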
" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "id": "criminal-wonder", 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "\"\"\"\n", 435 | "Creates a Pytorch dataset to load the Pascal VOC & MS COCO datasets\n", 436 | "\"\"\"\n", 437 | "class YOLODataset(Dataset):\n", 438 | " def __init__(\n", 439 | " self,\n", 440 | " csv_file,\n", 441 | " img_dir,\n", 442 | " label_dir,\n", 443 | " anchors,\n", 444 | " image_size=416,\n", 445 | " S=[13, 26, 52],\n", 446 | " C=20,\n", 447 | " transform=None,\n", 448 | " ):\n", 449 | " self.annotations = pd.read_csv(csv_file)\n", 450 | " self.img_dir = img_dir\n", 451 | " self.label_dir = label_dir\n", 452 | " self.image_size = image_size\n", 453 | " self.transform = transform\n", 454 | " self.S = S\n", 455 | " self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2]) # for all 3 scales\n", 456 | " self.num_anchors = self.anchors.shape[0]\n", 457 | " self.num_anchors_per_scale = self.num_anchors // 3\n", 458 | " self.C = C\n", 459 | " self.ignore_iou_thresh = 0.5\n", 460 | "\n", 461 | " def __len__(self):\n", 462 | " return len(self.annotations)\n", 463 | "\n", 464 | " def __getitem__(self, index):\n", 465 | " label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])\n", 466 | " bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=\" \", ndmin=2), 4, axis=1).tolist()\n", 467 | " img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])\n", 468 | " image = np.array(Image.open(img_path).convert(\"RGB\"))\n", 469 | "\n", 470 | " # apply augmentations with albumentations \n", 471 | " if self.transform:\n", 472 | " augmentations = self.transform(image=image, bboxes=bboxes)\n", 473 | " image = augmentations[\"image\"]\n", 474 | " bboxes = augmentations[\"bboxes\"]\n", 475 | " \n", 476 | " # Building the targets below:\n", 477 | " # Below assumes 3 scale predictions (as paper) and same num of anchors per scale\n", 478 | " targets = 
[torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]\n", 479 | " for box in bboxes:\n", 480 | " iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)\n", 481 | " anchor_indices = iou_anchors.argsort(descending=True, dim=0)\n", 482 | " x, y, width, height, class_label = box\n", 483 | " has_anchor = [False] * 3 # each scale should have one anchor\n", 484 | " for anchor_idx in anchor_indices:\n", 485 | " scale_idx = anchor_idx // self.num_anchors_per_scale\n", 486 | " anchor_on_scale = anchor_idx % self.num_anchors_per_scale\n", 487 | " S = self.S[scale_idx]\n", 488 | " i, j = int(S * y), int(S * x) # which cell\n", 489 | " anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]\n", 490 | " if not anchor_taken and not has_anchor[scale_idx]:\n", 491 | " targets[scale_idx][anchor_on_scale, i, j, 0] = 1\n", 492 | " x_cell, y_cell = S * x - j, S * y - i # both between [0,1]\n", 493 | " width_cell, height_cell = (\n", 494 | " width * S,\n", 495 | " height * S,\n", 496 | " ) # can be greater than 1 since it's relative to cell\n", 497 | " box_coordinates = torch.tensor(\n", 498 | " [x_cell, y_cell, width_cell, height_cell]\n", 499 | " )\n", 500 | " targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates\n", 501 | " targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)\n", 502 | " has_anchor[scale_idx] = True\n", 503 | "\n", 504 | " elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:\n", 505 | " targets[scale_idx][anchor_on_scale, i, j, 0] = -1 # ignore prediction\n", 506 | "\n", 507 | " return image, tuple(targets)\n", 508 | "\n", 509 | "\n" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "id": "enormous-generator", 515 | "metadata": {}, 516 | "source": [ 517 | "In the init-metod we will just combine the list above to a tensor of shape (9,2) by `self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])` corresponding to each anchor box on all scales. 
We will also specify an ignore-threshold, which will be used when building the targets as explained below. \n", 518 | "\n", 519 | "The second challenging part of the data loading is the getitem-method, where we load the image and the corresponding text file with the bounding boxes and process them such that we can feed them to the model. For data augmentation we use the Albumentations library, which requires the image to be a numpy array (the bounding boxes are passed as a list). The bounding boxes are also expected to be in the format \\[x, y, width, height, class label\\], which is different from how we have formatted them in the text file, and we therefore use `np.roll` to change this. The reason for this inconsistency is that the text files are structured the same way as in the original implementation; if you are formatting a custom dataset and are also using Albumentations, you may consider changing this. \n", 520 | "\n", 521 | "Here it should be noted that if you download the Pascal-VOC or MS COCO dataset from the official sites or from Joseph Redmon's website, you may run into out-of-range issues when using Albumentations, depending on how you convert the labels to the format x, y, width, height, where (x, y) signifies the object's midpoint. If you do, make sure you have converted the labels as specified in this [Github issue](https://github.com/albumentations-team/albumentations/issues/459), and you will save a couple of hours of debugging. \n", 522 | "\n", 523 | "### Building targets\n", 524 | "When we load the labels for a specific image, we only get an array with all the bounding boxes, and to be able to calculate the loss we want to format the targets similarly to the model output. The model outputs predictions on three different scales, so we will also build three different targets. 
Each target for a particular scale and image will have shape (number of anchors // 3, grid size, grid size, 6), where 6 corresponds to the object score, the four bounding box coordinates and the class label. We make two assumptions: that there is only one label per bounding box and that there is an equal number of anchors on each scale. We start by initializing the three different target tensors to zeros with \n", 525 | "`targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]`, where `self.S` is a list with the different grid sizes, e.g. for an image size of 416x416 we have `S=[13, 26, 52]`, or more generally `S = [image_size//32, image_size//16, image_size//8]`, since at the prediction stage the feature map has been downscaled by the factors in the denominators.\n", 526 | "\n", 527 | "The next step is to loop through all the bounding boxes in this particular image. If you have a lot of bounding boxes this will be quite expensive, but I haven't yet figured out a way to remove this step without taking shortcuts when assigning the anchor boxes. Let me know if you have any ideas on how to optimize this! We then compute the intersection over union between the target's width and height and all the anchor boxes, and sort the result such that the index of the anchor with the largest intersection over union with the target box appears first in the list. \n", 528 | "```python\n", 529 | "iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)\n", 530 | "anchor_indices = iou_anchors.argsort(descending=True, dim=0)\n", 531 | "```\n", 532 | "We then loop through the nine indices to assign the target to the best anchors. Our goal is to assign each target bounding box to one anchor on each scale, i.e. in total to assign each target to one anchor in each of the target tensors we initialized above. 
In addition, we check whether an anchor that is not the most suitable for the bounding box still has an intersection over union higher than 0.5, as specified by `ignore_iou_thresh`; if so, we mark this target such that no loss is incurred for the prediction of this anchor box. From my understanding, the reasoning behind this is that during inference this anchor could also make valid predictions on similar objects, and non-max suppression will remove surplus bounding boxes. During training we therefore do not want to force this particular anchor to predict that there is no object. We first compute which cell the bounding box belongs to by `i, j = int(S * y), int(S * x)` and then check whether the anchor we are currently at is taken in this cell by `anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]`. As you can probably imagine, in most datasets it is relatively uncommon for two objects that are so similar in size that they fit the same anchor box to also have their midpoints in the same cell; however, if you run this through a couple of hundred examples, you'll notice that it occurs several times on, for example, the Pascal-VOC dataset. In addition to checking whether the particular anchor is taken, we also check whether the current bounding box already has an anchor on this particular prediction scale. We only want one target anchor on each scale to allow for specialization between the anchor boxes, such that they focus on predicting different kinds of objects. \n", 533 | "\n", 534 | "If we find an anchor that is unoccupied and our current bounding box does not yet have an anchor on the scale that this anchor belongs to, we assign this anchor to the bounding box. First we set the object score of this anchor to 1 with `targets[scale_idx][anchor_on_scale, i, j, 0] = 1` to indicate that there is an object in this cell. 
We then compute the box coordinates relative to the cell, such that the midpoint (x, y) states where in the cell the object is and the width and height correspond to how many cells the bounding box covers. This is computed by: \n", 535 | "```python\n", 536 | "x_cell, y_cell = S * x - j, S * y - i # both between [0,1]\n", 537 | "width_cell, height_cell = width * S, height * S # can be greater than 1 since it's relative to the cell\n", 538 | "```\n", 539 | "We then add the bounding box coordinates as well as the class label to the cell (`i`, `j`) of the anchor box `anchor_on_scale`. Lastly, we update the flag `has_anchor[scale_idx]` to True to indicate that this prediction scale now has an anchor for the bounding box. \n", 540 | "\n", 541 | "Doing only the steps above would be sufficient for the data loading. In the YOLOv3 paper, however, they also check whether the current anchor, even though it is not assigned to the bounding box, has an intersection over union greater than `ignore_iou_thresh = 0.5`, and in that case they do not incur loss for this anchor's prediction. We do this by setting the object score of the anchor in the object's cell to -1, i.e. `targets[scale_idx][anchor_on_scale, i, j, 0] = -1`. In the loss function we will later make sure that no loss is incurred for these anchors. \n", 542 | "\n", 543 | "To make sure that the data loading works, it is beneficial to plot a few augmented examples together with their bounding boxes. The code below should do the trick, possibly with some modifications depending on how you structure the data. 
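
Besides plotting, the target-building logic can also be checked in isolation. The dependency-free sketch below mirrors the loop in `YOLODataset.__getitem__`, using nested Python lists in place of tensors. The function name `build_targets`, the list layout and the scalar `iou_wh` stand-in for `iou_width_height` are our own. As in the tensor version, an anchor whose object score is already 1 or -1 counts as taken.

```python
def iou_wh(a, b):
    # Width-height IoU, a scalar stand-in for iou_width_height from utils.py
    intersection = min(a[0], b[0]) * min(a[1], b[1])
    return intersection / (a[0] * a[1] + b[0] * b[1] - intersection)


def build_targets(bboxes, anchors, S=(13, 26, 52), ignore_iou_thresh=0.5):
    # bboxes: [x, y, width, height, class_label] relative to the image
    # anchors: flat list of (w, h), coarsest scale first (anchors[0] + anchors[1] + anchors[2])
    per_scale = len(anchors) // len(S)
    # targets[scale][anchor][i][j] = [objectness, x_cell, y_cell, w_cell, h_cell, class]
    targets = [[[[[0.0] * 6 for _ in range(s)] for _ in range(s)] for _ in range(per_scale)]
               for s in S]
    for x, y, width, height, class_label in bboxes:
        ious = [iou_wh((width, height), anchor) for anchor in anchors]
        ranked = sorted(range(len(anchors)), key=lambda a: ious[a], reverse=True)
        has_anchor = [False] * len(S)
        for anchor_idx in ranked:
            scale_idx, anchor_on_scale = divmod(anchor_idx, per_scale)
            s = S[scale_idx]
            i, j = int(s * y), int(s * x)  # row and column of the cell
            cell = targets[scale_idx][anchor_on_scale][i][j]
            anchor_taken = cell[0] != 0  # 1 and -1 both count, as in the tensor version
            if not anchor_taken and not has_anchor[scale_idx]:
                cell[:] = [1.0, s * x - j, s * y - i, width * s, height * s, int(class_label)]
                has_anchor[scale_idx] = True
            elif not anchor_taken and ious[anchor_idx] > ignore_iou_thresh:
                cell[0] = -1.0  # ignored: no loss for this anchor
    return targets


ANCHORS = [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78),
           (0.07, 0.15), (0.15, 0.11), (0.14, 0.29),
           (0.02, 0.03), (0.04, 0.07), (0.08, 0.06)]
targets = build_targets([[0.5, 0.5, 0.3, 0.3, 7]], ANCHORS)
print(targets[0][0][6][6])  # roughly [1.0, 0.5, 0.5, 3.9, 3.9, 7] on the 13x13 grid
```

Each scale ends up with exactly one anchor marked 1 per box. A second anchor on the same scale is marked -1 only when its IoU with the box exceeds the threshold.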
" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "id": "popular-circuit", 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "def test():\n", 554 | " anchors = config.ANCHORS\n", 555 | "\n", 556 | " transform = config.train_transforms\n", 557 | "\n", 558 | " dataset = YOLODataset(\n", 559 | " config.DATASET+'/train',\n", 560 | " config.IMG_DIR,\n", 561 | " config.LABEL_DIR,\n", 562 | " S=[13, 26, 52],\n", 563 | " anchors=anchors,\n", 564 | " transform=transform,\n", 565 | " )\n", 566 | " S = [13, 26, 52]\n", 567 | " scaled_anchors = torch.tensor(anchors) / (\n", 568 | " 1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)\n", 569 | " )\n", 570 | " loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)\n", 571 | " for x, y in loader:\n", 572 | " boxes = []\n", 573 | "\n", 574 | " for i in range(y[0].shape[1]):\n", 575 | " anchor = scaled_anchors[i]\n", 576 | " boxes += cells_to_bboxes(\n", 577 | " y[i], is_preds=False, S=y[i].shape[2], anchors=anchor\n", 578 | " )[0]\n", 579 | " boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format=\"midpoint\")\n", 580 | " print(boxes)\n", 581 | " plot_image(x[0].permute(1, 2, 0).to(\"cpu\"), boxes)\n", 582 | "\n", 583 | "\n", 584 | "if __name__ == \"__main__\":\n", 585 | " test()" 586 | ] 587 | }, 588 | { 589 | "cell_type": "markdown", 590 | "id": "associate-houston", 591 | "metadata": {}, 592 | "source": [ 593 | "## YOLOv3 loss function\n", 594 | "In the original YOLO paper the author states the loss function and the same expression can be found in articles on YOLOv2 or v3 which is at best a simplification compared to the actual implementation. If you are familiar with the original YOLO loss you will recognize all parts below but they are tweaked to match the idea with the anchor boxes. The loss function can be divided into four parts and I will go through each separately and then combine them in the end. 
\n", 595 | "\n", 596 | "First we form two boolean tensors indicating which anchors in which cells have an object assigned to them and which do not. \n", 597 | "```python \n", 598 | "obj = target[..., 0] == 1 \n", 599 | "noobj = target[..., 0] == 0 \n", 600 | "```\n", 601 | "We need two masks, rather than taking `noobj` as the complement of `obj`, because in the data loading we set the object score of the anchors that should be ignored to -1. Indexing only with these two masks in all parts of the loss function makes sure that we do not incur any loss on these anchors. I will also state all parts of the loss as mathematical formulas based on the way they are implemented in the code. They are just translations of the code for those who find it easier to understand the loss in that format, so don't worry if they're not your cup of tea. \n", 602 | "\n", 603 | "### No object loss\n", 604 | "For the anchors in all cells that do not have an object assigned to them, i.e. all indices that are True in `noobj`, we want to incur loss only for their object score. The target will be all zeros since we want these anchors to predict an object score of zero; we apply a sigmoid function to the network outputs and use a binary cross-entropy loss. In code we have \n", 605 | "```python \n", 606 | "no_object_loss = self.bce(\n", 607 | " (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),\n", 608 | " )\n", 609 | "```\n", 610 | "where `self.bce` refers to an instance of PyTorch's `BCEWithLogitsLoss()`, which applies the sigmoid function and computes the binary cross-entropy loss in one numerically stable step. 
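
The scalar sketch below makes the masking and the no object loss concrete. The helper names are hypothetical. `bce_with_logits` uses the standard numerically stable rewrite of the binary cross-entropy of a sigmoid. `no_object_loss` builds the `noobj` selection from target scores of 1, 0 and -1 and averages over the selected anchors, like the default mean reduction.

```python
import math


def bce_with_logits(logit, target):
    # Stable form of -[y*log(sigmoid(t)) + (1 - y)*log(1 - sigmoid(t))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def no_object_loss(logits, target_scores):
    # Target score 1: object, 0: no object, -1: ignored (selected by neither mask)
    noobj = [score == 0 for score in target_scores]
    losses = [bce_with_logits(t, 0.0) for t, selected in zip(logits, noobj) if selected]
    return sum(losses) / len(losses)  # mean over the selected anchors


# The ignored anchor (-1) contributes nothing even though its logit is large
print(no_object_loss([0.0, 4.0, 4.0], [0, 1, -1]))  # log(2), about 0.6931
```

The stable form avoids evaluating the logarithm of a sigmoid that has underflowed to 0 or rounded to 1 for large logits.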
\n", 611 | "\n", 612 | "In mathematical notation we have \n", 613 | "$$\n", 614 | "\\begin{aligned} \n", 615 | "L_{noobj} &= \\frac{1}{N \\sum_{a, i,j} \\mathbb{1}_{a\\ i\\ j}^{\\text {noobj }}} \\sum_{n=1}^N \\sum_{a,i,j \\in \\mathbb{1}_{a\\ i \\ j}^{\\text {noobj }}} BCE \\left ( y_{n,a,i,j}^{obj}, \\sigma\\left(t_{n,a,i,j}^{obj}\\right)\\right) \\\\\n", 616 | "&=\\frac{1}{N \\sum_{a, i,j} \\mathbb{1}_{a\\ i\\ j}^{\\text {noobj }}} \\sum_{n=1}^N \\sum_{a,i,j \\in \\mathbb{1}_{a\\ i \\ j}^{\\text {noobj }}} -\\left[y_{n,a,i,j}^{obj} \\cdot \\log \\sigma\\left(t_{n,a,i,j}^{obj}\\right)+\\left(1-y_{n,a,i,j}^{obj}\\right) \\cdot \\log \\left(1-\\sigma(t_{n,a,i,j}^{obj})\\right)\\right]\n", 617 | "\\end{aligned}\n", 618 | "$$\n", 619 | "\n", 620 | "where $N$ is the batch size, $i,\\ j$ index the cell, $a$ is the anchor index and $\\mathbb{1}_{a\\ i\\ j}^{\\text {noobj }}$ is a binary tensor with ones on the anchors not assigned to an object. The output of the network is denoted $t$, the target $y$, and $\\sigma$ is the sigmoid function given by \n", 621 | "\n", 622 | "$$\n", 623 | "\\sigma(x) = \\frac{1}{1+e^{-x}}.\n", 624 | "$$\n", 625 | "\n", 626 | "### Object loss\n", 627 | "For the anchors that have an object assigned to them, we want the network to predict an appropriate bounding box for the object. When building the target tensors, we set the object score of these anchors to 1. One idea is then to do as in the no object loss and train the network to output large values for the cells and anchors to which we have assigned a target bounding box. This would, however, mean that no matter how poor a bounding box prediction the network makes, it would still try to predict a high object score. During inference we rely on the object score when choosing which bounding boxes to output, and with this approach the object score would not reflect how likely it is that there actually is an object in the predicted bounding box. 
The idea in the YOLOv3 paper is instead that the object score predicted by the model should reflect the intersection over union between the prediction and the target bounding box. It is slightly unclear how this was originally implemented, and I have seen several different versions in others' code. In our implementation, we calculate the intersection over union between the target bounding boxes and the predicted bounding boxes at training time and use it as the target for the object score. This does not seem to slow down training noticeably. \n", 628 | "\n", 629 | "In the code we convert the model predictions to bounding boxes according to the formulas in the paper, leaving out the cell offsets since we work relative to the cell:\n", 630 | "$$\n", 631 | "\\begin{array}{l}\n", 632 | "b_{x}=\\sigma\\left(t_{x}\\right) \\\\\n", 633 | "b_{y}=\\sigma\\left(t_{y}\\right) \\\\\n", 634 | "b_{w}=p_{w} e^{t_{w}} \\\\\n", 635 | "b_{h}=p_{h} e^{t_{h}},\n", 636 | "\\end{array}\n", 637 | "$$\n", 638 | "where $p_w$ and $p_h$ are the anchor box dimensions and ($b_x, b_y, b_w, b_h$) is the resulting bounding box relative to the cell. We then calculate the intersection over union with the target that we defined in the dataset class and lastly, as in the no object loss above, apply the binary cross-entropy loss between the object score predictions and the calculated intersection over union. Note that the loss is only applied to the anchors assigned to a target bounding box, which is ensured by indexing with `obj`. 
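
For a single anchor, the decoding and the IoU target can be sketched with scalars. The helpers below are hypothetical. `decode_box` applies the formulas above with the anchor size given in grid-cell units. `iou_midpoint` computes the intersection over union of two boxes in (x, y, width, height) midpoint format. The resulting IoU, rather than 1, becomes the binary cross-entropy target for the object score. In the PyTorch snippet below, `.detach()` keeps this IoU a constant target.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def decode_box(t_x, t_y, t_w, t_h, anchor_w, anchor_h):
    # b_x = sigmoid(t_x), b_y = sigmoid(t_y), b_w = p_w * exp(t_w), b_h = p_h * exp(t_h)
    return sigmoid(t_x), sigmoid(t_y), anchor_w * math.exp(t_w), anchor_h * math.exp(t_h)


def iou_midpoint(box1, box2):
    # Boxes as (x, y, width, height) with (x, y) the midpoint, in the same units
    def corners(b):
        return b[0] - b[2] / 2, b[1] - b[3] / 2, b[0] + b[2] / 2, b[1] + b[3] / 2
    ax1, ay1, ax2, ay2 = corners(box1)
    bx1, by1, bx2, by2 = corners(box2)
    inter_w = max(min(ax2, bx2) - max(ax1, bx1), 0.0)
    inter_h = max(min(ay2, by2) - max(ay1, by1), 0.0)
    intersection = inter_w * inter_h
    union = box1[2] * box1[3] + box2[2] * box2[3] - intersection
    return intersection / union


# All-zero raw outputs decode to the cell centre with exactly the anchor's size
anchor = (3.64, 2.86)          # hypothetical anchor in cell units
pred = decode_box(0.0, 0.0, 0.0, 0.0, *anchor)
target = (0.5, 0.5, 3.9, 3.9)  # hypothetical target box in cell units
print(iou_midpoint(pred, target))  # about 0.684: the target for the object score
```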
\n", 639 | "\n", 640 | "```python\n", 641 | "anchors = anchors.reshape(1, 3, 1, 1, 2) # reshaping for broadcasting \n", 642 | "box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)\n", 643 | "ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()\n", 644 | "object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))\n", 645 | "```\n", 646 | "\n", 647 | "The mathematical formula will be similar to the one above \n", 648 | "$$\n", 649 | "L_{obj}= \\frac{1}{N \\sum_{a, i,j} \\mathbb{1}_{a\\ i\\ j}^{\\text {obj }}} \\sum_{n=1}^N \\sum_{a,i,j \\in \\mathbb{1}_{a\\ i \\ j}^{\\text {obj }}} BCE \\left ( \\hat{y}_{n,a,i,j}^{obj}, \\sigma\\left(t_{n,a,i,j}^{obj}\\right)\\right)\n", 650 | "$$\n", 651 | "with \n", 652 | "$$ \\hat{y} = IOU(y^{box}, b) $$\n", 653 | "where $b$ is the bounding box computed above and $ \\mathbb{1}_{a\\ i\\ j}^{\\text {obj }}$ corresponds to the binary tensor with ones for the anchors assigned to a target bounding box. \n", 654 | "\n", 655 | "### Box coordinates loss\n", 656 | "For the box coordinates we will simply use a mean squared error loss in the positions where there actually are objects. All predictions where there is no corresponding target bounding box will be ignored. We will apply a sigmoid function to the $x$\n", 657 | "and $y$ coordinates to make sure that they are between \\[0,1\\] but instead of converting the widths and heights as above we want to compute the ground truth value $\\hat{t}$ that the network should predict. We find it by inverting the formula above for the bounding boxes\n", 658 | "$$ \n", 659 | "\\begin{aligned}\n", 660 | "\\hat{t}_w &= \\log (y_w / p_w) \\\\\n", 661 | "\\hat{t}_h &= \\log (y_h / p_h)\n", 662 | "\\end{aligned}\n", 663 | "$$\n", 664 | "where the $y_w$ and $y_h$ are the target width and height. We will then apply the mean squared error loss between the targets and predictions. 
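
The inversion can be checked numerically: encoding a target width with the logarithm and decoding it again with the exponential recovers the original value. The scalar sketch below uses hypothetical helper names. Its small epsilon mirrors the `1e-16` guard in the PyTorch snippet that follows and keeps the logarithm finite for empty target entries.

```python
import math


def encode_wh(target_w, target_h, anchor_w, anchor_h, eps=1e-16):
    # t_hat = log(y / p): the raw value the network should output
    return math.log(eps + target_w / anchor_w), math.log(eps + target_h / anchor_h)


def decode_wh(t_w, t_h, anchor_w, anchor_h):
    # Inverse mapping: b = p * exp(t)
    return anchor_w * math.exp(t_w), anchor_h * math.exp(t_h)


def box_mse(pred, target):
    # Mean squared error over the four box values, like MSELoss with mean reduction
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


t_w, t_h = encode_wh(3.9, 3.9, 3.64, 2.86)
print(decode_wh(t_w, t_h, 3.64, 2.86))  # approximately (3.9, 3.9)
```

Regressing in log space penalizes a box twice too large and a box half as large equally, since log 2 and log 1/2 are the same distance from 0.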
\n", 665 | "```python\n", 666 | "predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) # x, y coordinates\n", 667 | "target[..., 3:5] = torch.log(1e-16 + target[..., 3:5] / anchors) # convert target width and height\n", 668 | "box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj]) # index by obj to only apply loss for objects\n", 669 | "```\n", 670 | "The equivalent formula (up to a constant factor, since `nn.MSELoss` also averages over the four coordinates) is\n", 671 | "$$\n", 672 | "L_{box} = \\frac{1}{N \\sum_{a, i,j} \\mathbb{1}_{a\\ i\\ j}^{\\text {obj }}} \\sum_{n=1}^N \\sum_{a,i,j \\in \\mathbb{1}_{a\\ i \\ j}^{\\text {obj }}} \\Big[ \\left(\\sigma(t^x_{n,a,i,j}) - y^x_{n,a,i,j} \\right)^2 +\n", 673 | "\\left(\\sigma(t^y_{n,a,i,j}) - y^y_{n,a,i,j} \\right)^2 +\n", 674 | "\\left(t^w_{n,a,i,j} - \\hat{t}^w_{n,a,i,j} \\right)^2 +\n", 675 | "\\left(t^h_{n,a,i,j} - \\hat{t}^h_{n,a,i,j} \\right)^2 \\Big], \n", 676 | "$$\n", 677 | "where $\\hat{t}^*$ are the ground-truth values that the model should predict.\n", 678 | "\n", 679 | "### Class loss\n", 680 | "We only incur loss for the class predictions where there actually is an object. Our implementation differs slightly from the paper here: we use a cross entropy loss to compute the class loss, which assumes that each bounding box has only one label. The YOLOv3 paper avoids this limitation and instead uses binary cross entropy, so that several labels can be assigned to a single object, e.g. woman and person. \n", 681 | "```python\n", 682 | "class_loss = self.entropy(predictions[..., 5:][obj], target[..., 5][obj].long())\n", 683 | "```\n", 684 | "where `self.entropy` refers to an instance of PyTorch's `CrossEntropyLoss()`, which combines the log-softmax function and the negative log-likelihood loss. 
This corresponds to\n", 685 | "\n", 686 | "$$ \n", 687 | "L_{class} = \\frac{1}{N \\sum_{a, i,j} \\mathbb{1}_{a\\ i\\ j}^{\\text {obj }}} \\sum_{n=1}^N \\sum_{a,i,j \\in \\mathbb{1}_{a\\ i \\ j}^{\\text {obj }}} -\\log \\left(\\frac{\\exp (t_{n, a, i,j}^{c})}{\\sum_{k} \\exp (t_{n, a, i,j}^k)}\\right),\n", 688 | "$$\n", 689 | "where $t_{n, a, i,j}^{c}$ is the prediction (logit) for the correct class $c$. \n", 690 | "\n", 691 | "### Total loss\n", 692 | "I will not attempt to put the entire loss function in a single formula, as this only creates an unnecessarily complicated expression when each part can be understood and computed separately. The total loss is computed by \n", 693 | "```python\n", 694 | "total_loss = (self.lambda_box * box_loss\n", 695 | "    + self.lambda_obj * object_loss\n", 696 | "    + self.lambda_noobj * no_object_loss\n", 697 | "    + self.lambda_class * class_loss)\n", 698 | "```\n", 699 | "or equivalently \n", 700 | "$$ L= \\lambda_{noobj} L_{noobj} + \\lambda_{obj} L_{obj} + \\lambda_{box} L_{box} + \\lambda_{class} L_{class} $$\n", 701 | "where each $\\lambda_*$ is a constant signifying the importance of each part of the loss. It seems that the original implementation uses $\\lambda_* = 1$ for all constants, but during training we found better convergence by modifying them. \n", 702 | "\n", 703 | "The complete code for the loss function is found below; it is placed in a separate loss.py file. 
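
As a side check of the class-loss formula, this short sketch uses random, made-up logits and labels. It confirms that `nn.CrossEntropyLoss` with its default mean reduction computes the mean negative log-softmax probability of the correct class. The last lines contrast this with the paper's approach of independent binary cross entropy on multi-hot targets. The sizes (5 boxes, 20 classes) are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_boxes, num_classes = 5, 20                        # e.g. 5 assigned anchors, 20 Pascal VOC classes
logits = torch.randn(num_boxes, num_classes)          # class scores t^k for each assigned anchor
labels = torch.randint(0, num_classes, (num_boxes,))  # one correct class index c per box

# Built-in loss, as used for self.entropy (mean over boxes by default)
ce = nn.CrossEntropyLoss()(logits, labels)

# The L_class formula written out: mean over boxes of -log(exp(t^c) / sum_k exp(t^k))
log_probs = torch.log_softmax(logits, dim=-1)
manual = -log_probs.gather(1, labels.unsqueeze(1)).mean()
print(torch.allclose(ce, manual))  # True

# For contrast, the paper's multi-label variant: independent sigmoids scored with BCE
# against a multi-hot target, so a box could be labeled both woman and person
multi_hot = torch.zeros(num_boxes, num_classes)
multi_hot[torch.arange(num_boxes), labels] = 1.0
bce = nn.BCEWithLogitsLoss()(logits, multi_hot)
```

Because the softmax couples the class scores, raising one class lowers the others. This is why the cross entropy choice assumes a single label per box.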
" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "id": "consistent-turning", 710 | "metadata": {}, 711 | "outputs": [], 712 | "source": [ 713 | "\"\"\"\n", 714 | "Implementation of Yolo Loss Function similar to the one in Yolov3 paper,\n", 715 | "the difference, as far as I can tell, is that I use CrossEntropy for the classes\n", 716 | "instead of BinaryCrossEntropy.\n", 717 | "\"\"\"\n", 718 | "import random\n", 719 | "import torch\n", 720 | "import torch.nn as nn\n", 721 | "from utils import intersection_over_union\n", 722 | "\n", 723 | "\n", 724 | "class YoloLoss(nn.Module):\n", 725 | " def __init__(self):\n", 726 | " super().__init__()\n", 727 | " self.mse = nn.MSELoss()\n", 728 | " self.bce = nn.BCEWithLogitsLoss()\n", 729 | " self.entropy = nn.CrossEntropyLoss()\n", 730 | " self.sigmoid = nn.Sigmoid()\n", 731 | "\n", 732 | " # Constants signifying how much to pay for each respective part of the loss\n", 733 | " self.lambda_class = 1\n", 734 | " self.lambda_noobj = 10\n", 735 | " self.lambda_obj = 1\n", 736 | " self.lambda_box = 10\n", 737 | "\n", 738 | " def forward(self, predictions, target, anchors):\n", 739 | " \"\"\"\n", 740 | " :param predictions: output from model of shape: (batch size, anchors on scale, grid size, grid size, 5 + num classes)\n", 741 | " :param target: targets on particular scale of shape: (batch size, anchors on scale, grid size, grid size, 6)\n", 742 | " :param anchors: anchor boxes on the particular scale of shape (anchors on scale, 2)\n", 743 | " :return: returns the loss on the particular scale\n", 744 | " \"\"\"\n", 745 | "\n", 746 | " # Check where obj and noobj (we ignore if target == -1)\n", 747 | " # Here we check where in the label matrix there is an object or not\n", 748 | " obj = target[..., 0] == 1 # in paper this is Iobj_i\n", 749 | " noobj = target[..., 0] == 0 # in paper this is Inoobj_i\n", 750 | "\n", 751 | " # ======================= #\n", 752 | " # FOR NO OBJECT LOSS #\n", 753 | "
# ======================= #\n", 754 | " # The indexing noobj refers to the fact that we only apply the loss where there is no object\n", 755 | " no_object_loss = self.bce(\n", 756 | " (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),\n", 757 | " )\n", 758 | "\n", 759 | " # ==================== #\n", 760 | " # FOR OBJECT LOSS #\n", 761 | " # ==================== #\n", 762 | " # Here we compute the loss for the cells and anchor boxes that contain an object\n", 763 | " # Reshape anchors to allow for broadcasting in multiplication below\n", 764 | " anchors = anchors.reshape(1, 3, 1, 1, 2)\n", 765 | " # Convert outputs from model to bounding boxes according to formulas in paper\n", 766 | " box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)\n", 767 | " # Targets for the object prediction should be the iou of the predicted bounding box and the target bounding box\n", 768 | " ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()\n", 769 | " # Only incur loss for the cells where there is an object, signified by indexing with obj\n", 770 | " object_loss = self.bce((predictions[..., 0:1][obj]), (ious * target[..., 0:1][obj]))\n", 771 | "\n", 772 | " # ======================== #\n", 773 | " # FOR BOX COORDINATES #\n", 774 | " # ======================== #\n", 775 | " # apply sigmoid to x, y coordinates to convert to bounding boxes\n", 776 | " predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) \n", 777 | " # to improve gradient flow we convert targets' width and height to the same format as predictions\n", 778 | " target[..., 3:5] = torch.log(\n", 779 | " (1e-16 + target[..., 3:5] / anchors)\n", 780 | " ) \n", 781 | " # compute mse loss for boxes\n", 782 | " box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])\n", 783 | "\n", 784 | " # ================== #\n", 785 | " # FOR CLASS LOSS #\n", 786 | " # ================== #\n", 787 | " # here we just apply
cross entropy loss as is customary with classification problems\n", 788 | " class_loss = self.entropy(\n", 789 | " (predictions[..., 5:][obj]), (target[..., 5][obj].long()),\n", 790 | " )\n", 791 | " \n", 792 | " return (\n", 793 | " self.lambda_box * box_loss\n", 794 | " + self.lambda_obj * object_loss\n", 795 | " + self.lambda_noobj * no_object_loss\n", 796 | " + self.lambda_class * class_loss\n", 797 | " )" 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "id": "banned-chicago", 803 | "metadata": {}, 804 | "source": [ 805 | "## Training the model \n", 806 | "The training configuration is completely contained in the config.py file that can be found on [GitHub](https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/object_detection/YOLOv3). This is where we specify the image size, dataset paths, augmentations, learning rate and all other constants. I will not include it here; if you implement YOLOv3 you can copy it from the link above or write your own training configuration. \n", 807 | "\n", 808 | "What we will focus on instead is building the training loop, which should be quite straightforward. Everything from here on will be placed in a train.py file, which we can then run to train the model. First we define the imports: our previously defined modules and, in addition, a couple of helper functions from the utils.py file you can find on [GitHub](https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/object_detection/YOLOv3/utils.py)."
809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "id": "german-mouse", 815 | "metadata": {}, 816 | "outputs": [], 817 | "source": [ 818 | "import config\n", 819 | "import torch\n", 820 | "import torch.optim as optim\n", 821 | "\n", 822 | "from model import YOLOv3\n", 823 | "from tqdm import tqdm\n", 824 | "from utils import (\n", 825 | " mean_average_precision,\n", 826 | " cells_to_bboxes,\n", 827 | " get_evaluation_bboxes,\n", 828 | " save_checkpoint,\n", 829 | " load_checkpoint,\n", 830 | " check_class_accuracy,\n", 831 | " get_loaders,\n", 832 | " plot_couple_examples\n", 833 | ")\n", 834 | "from loss import YoloLoss\n", 835 | "\n", 836 | "torch.backends.cudnn.benchmark = True" 837 | ] 838 | }, 839 | { 840 | "cell_type": "markdown", 841 | "id": "generic-hardwood", 842 | "metadata": {}, 843 | "source": [ 844 | "We then define a training function that trains the network for one epoch. It takes as input the data loader, the model, the optimizer, the loss function, a scaler for mixed precision training, and scaled anchors such that each anchor is relative to the prediction scale. Originally the anchors are relative to the entire image, but the loss function expects them relative to the cell; this is accomplished by scaling them with the grid size of the prediction scale. \n", 845 | "\n", 846 | "We calculate the total loss as the sum of the losses for the three prediction scales, and we use mixed precision training to train the model. 
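
The anchor scaling can be sketched in isolation. The anchor values and the 416-pixel image size below are assumptions for illustration; they resemble a typical config.py for this project, but check your own file. The multiplication is the same expression that the main function later uses to build `scaled_anchors`.

```python
import torch

# Assumed values for illustration: anchors as fractions of a 416 x 416 image,
# one row of three (w, h) pairs per prediction scale, coarsest scale first.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]
S = [416 // 32, 416 // 16, 416 // 8]  # grid sizes 13, 26, 52

# Same expression as in main(): the grid sizes are expanded to shape (3, 3, 2)
scaled_anchors = torch.tensor(ANCHORS) * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
print(scaled_anchors.shape)  # torch.Size([3, 3, 2])

# Plain broadcasting gives the same result without the explicit repeat
alt = torch.tensor(ANCHORS) * torch.tensor(S, dtype=torch.float32).view(3, 1, 1)
print(torch.allclose(scaled_anchors, alt))  # True

# The first anchor covers 0.28 x 0.22 of the image, i.e. about 3.64 x 2.86 cells on the 13 x 13 grid
print(scaled_anchors[0, 0])
```

The `.view(3, 1, 1)` variant relies on ordinary broadcasting. The `repeat` call simply materializes the same (3, 3, 2) tensor explicitly.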
" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": null, 852 | "id": "laden-speaker", 853 | "metadata": {}, 854 | "outputs": [], 855 | "source": [ 856 | "def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):\n", 857 | " loop = tqdm(train_loader, leave=True)\n", 858 | " losses = []\n", 859 | " for batch_idx, (x, y) in enumerate(loop):\n", 860 | " x = x.to(config.DEVICE)\n", 861 | " y0, y1, y2 = (\n", 862 | " y[0].to(config.DEVICE),\n", 863 | " y[1].to(config.DEVICE),\n", 864 | " y[2].to(config.DEVICE),\n", 865 | " )\n", 866 | "\n", 867 | " with torch.cuda.amp.autocast():\n", 868 | " out = model(x)\n", 869 | " loss = (\n", 870 | " loss_fn(out[0], y0, scaled_anchors[0])\n", 871 | " + loss_fn(out[1], y1, scaled_anchors[1])\n", 872 | " + loss_fn(out[2], y2, scaled_anchors[2])\n", 873 | " )\n", 874 | "\n", 875 | " losses.append(loss.item())\n", 876 | " optimizer.zero_grad()\n", 877 | " scaler.scale(loss).backward()\n", 878 | " scaler.step(optimizer)\n", 879 | " scaler.update()\n", 880 | "\n", 881 | " # update progress bar\n", 882 | " mean_loss = sum(losses) / len(losses)\n", 883 | " loop.set_postfix(loss=mean_loss)\n" 884 | ] 885 | }, 886 | { 887 | "cell_type": "markdown", 888 | "id": "preceding-force", 889 | "metadata": {}, 890 | "source": [ 891 | "We have now come to the part where we are ready to actually train the model. The main function takes care of setting up the model, loss function, data loaders, etc., and in each epoch we run the train function defined above. Every ten epochs we evaluate the model by checking the mean average precision on the test loader. Note that this can be costly if the model performs poorly, because there may be many false positives that the non-max suppression and mean average precision functions have to loop through."
892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": null, 897 | "id": "dried-visiting", 898 | "metadata": {}, 899 | "outputs": [], 900 | "source": [ 901 | "def main():\n", 902 | " model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)\n", 903 | " optimizer = optim.Adam(\n", 904 | " model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY\n", 905 | " )\n", 906 | " loss_fn = YoloLoss()\n", 907 | " scaler = torch.cuda.amp.GradScaler()\n", 908 | "\n", 909 | " train_loader, test_loader, train_eval_loader = get_loaders(\n", 910 | " train_csv_path=config.DATASET + \"/train.csv\", test_csv_path=config.DATASET + \"/test.csv\"\n", 911 | " )\n", 912 | "\n", 913 | " if config.LOAD_MODEL:\n", 914 | " load_checkpoint(\n", 915 | " config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE\n", 916 | " )\n", 917 | " \n", 918 | " #Scale anchors to each prediction scale\n", 919 | " scaled_anchors = (\n", 920 | " torch.tensor(config.ANCHORS)\n", 921 | " * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)\n", 922 | " ).to(config.DEVICE)\n", 923 | "\n", 924 | " for epoch in range(config.NUM_EPOCHS):\n", 925 | " train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)\n", 926 | "\n", 927 | " if config.SAVE_MODEL:\n", 928 | " save_checkpoint(model, optimizer, filename=f\"checkpoint.pth.tar\")\n", 929 | "\n", 930 | " if epoch % 10 == 0 and epoch > 0:\n", 931 | " print(\"On Test loader:\")\n", 932 | " check_class_accuracy(model, test_loader, threshold=config.CONF_THRESHOLD)\n", 933 | " # Run model on test set and convert outputs to bounding boxes relative to image\n", 934 | " pred_boxes, true_boxes = get_evaluation_bboxes(\n", 935 | " test_loader,\n", 936 | " model,\n", 937 | " iou_threshold=config.NMS_IOU_THRESH,\n", 938 | " anchors=config.ANCHORS,\n", 939 | " threshold=config.CONF_THRESHOLD,\n", 940 | " )\n", 941 | " # Compute mean average precision \n", 942 | " mapval = mean_average_precision(\n", 
943 | " pred_boxes,\n", 944 | " true_boxes,\n", 945 | " iou_threshold=config.MAP_IOU_THRESH,\n", 946 | " box_format=\"midpoint\",\n", 947 | " num_classes=config.NUM_CLASSES,\n", 948 | " )\n", 949 | " print(f\"MAP: {mapval.item()}\")\n", 950 | "\n", 951 | "\n", 952 | "\n", 953 | "if __name__ == \"__main__\":\n", 954 | " main()" 955 | ] 956 | }, 957 | { 958 | "cell_type": "markdown", 959 | "id": "found-psychology", 960 | "metadata": {}, 961 | "source": [ 962 | "We have now reached the end of this YOLOv3 implementation, and if you feel that everything is crystal clear, then wow, I've really outdone myself. It is more likely that you will have to revisit this and possibly others' implementations if your goal is to implement YOLOv3 yourself. Anyhow, I hope you take away some key implementation details of YOLOv3 from this article, and if you have any lingering thoughts, leave a comment! " 963 | ] 964 | } 965 | ], 966 | "metadata": { 967 | "kernelspec": { 968 | "display_name": "Python 3", 969 | "language": "python", 970 | "name": "python3" 971 | }, 972 | "language_info": { 973 | "codemirror_mode": { 974 | "name": "ipython", 975 | "version": 3 976 | }, 977 | "file_extension": ".py", 978 | "mimetype": "text/x-python", 979 | "name": "python", 980 | "nbconvert_exporter": "python", 981 | "pygments_lexer": "ipython3", 982 | "version": "3.8.5" 983 | } 984 | }, 985 | "nbformat": 4, 986 | "nbformat_minor": 5 987 | } 988 | --------------------------------------------------------------------------------