├── github_imgs ├── from_net.png ├── mosaic.png └── realtime.jpg ├── .gitignore ├── tox.ini ├── mypy.ini ├── hubconf.py ├── LICENSE ├── sched_del.py ├── .github └── workflows │ └── main.yml ├── SplitDataset.ipynb ├── lars.py ├── video_demo.py ├── pl_model.py ├── README.md ├── utils.py ├── ranger.py ├── radam.py ├── dataset.py ├── model.py └── Training YOLOv4 .ipynb /github_imgs/from_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/HEAD/github_imgs/from_net.png -------------------------------------------------------------------------------- /github_imgs/mosaic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/HEAD/github_imgs/mosaic.png -------------------------------------------------------------------------------- /github_imgs/realtime.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/HEAD/github_imgs/realtime.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pickle 3 | *.log 4 | *.rar 5 | *.txt 6 | *.code-workspace 7 | /__pycache__ 8 | /.ipynb_checkpoints 9 | /.virtual_documents 10 | /labels 11 | /images 12 | /.mypy_cache 13 | /%USERPROFILE% 14 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | format = pylint 3 | exclude = .git,__pycache__,.idea,.gitignore,pics/,scripts/,nohup.out 4 | max-complexity = 15 5 | max-line-length = 127 6 | doctests = True 7 | ignore = E731,D104,D401,I101,I201,F401,F403,S001,D100,D101,D102,D103,D105,D106,D107,D200,D205,D400,W504,E203,D202 -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | # mypy configurations: http://bit.ly/2zEl9WI 3 | python_version = 3.7 4 | allow_redefinition = False 5 | check_untyped_defs = True 6 | disallow_any_generics = True 7 | ignore_missing_imports = True 8 | implicit_reexport = False 9 | strict_optional = True 10 | strict_equality = True 11 | no_implicit_optional = True 12 | warn_no_return = True 13 | warn_unused_ignores = True 14 | warn_redundant_casts = True 15 | warn_unused_configs = True 16 | warn_return_any = True 17 | warn_unreachable = True 18 | show_error_codes = True -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import YOLOv4 3 | 4 | dependencies = ['torch'] 5 | 6 | def yolov4(pretrained=False, n_classes=80): 7 | """ 8 | YOLOv4 model 9 | pretrained (bool): kwargs, load pretrained weights into the model 10 | n_classes(int): amount of classes 11 | """ 12 | m = YOLOv4(n_classes=n_classes) 13 | if pretrained: 14 | try: #If we change input or output layers amount, we will have an option to use pretrained weights 15 | m.load_state_dict(torch.hub.load_state_dict_from_url("https://github.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/releases/download/V1.0/yolov4.pth"), strict=False) 16 | except RuntimeError as e: 17 | 
print(f'[Warning] Ignoring {e}') 18 | 19 | return m 20 | 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 VCasecnikovs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sched_del.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR 2 | 3 | class DelayerScheduler(_LRScheduler): 4 | """ Starts with a flat lr schedule until it reaches N epochs, then applies a scheduler 5 | Args: 6 | optimizer (Optimizer): Wrapped optimizer. 7 | delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler 8 | after_scheduler: after target_epoch, use this scheduler (e.g. 
ReduceLROnPlateau) 9 | """ 10 | 11 | def __init__(self, optimizer, delay_epochs, after_scheduler): 12 | self.delay_epochs = delay_epochs 13 | self.after_scheduler = after_scheduler 14 | self.finished = False 15 | super().__init__(optimizer) 16 | 17 | def get_lr(self): 18 | if self.last_epoch >= self.delay_epochs: 19 | if not self.finished: 20 | self.after_scheduler.base_lrs = self.base_lrs 21 | self.finished = True 22 | return self.after_scheduler.get_lr() 23 | 24 | return self.base_lrs 25 | 26 | def step(self, epoch=None): 27 | if self.finished: 28 | if epoch is None: 29 | self.after_scheduler.step(None) 30 | else: 31 | self.after_scheduler.step(epoch - self.delay_epochs) 32 | else: 33 | return super(DelayerScheduler, self).step(epoch) 34 | 35 | def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs): 36 | base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs) 37 | return DelayerScheduler(optimizer, delay_epochs, base_scheduler) -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | name: Python style and tests checks 4 | on: 5 | push: 6 | branches: [master] 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.7 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 20 | pip install black flake8 pytest mypy isort 21 | - name: Format with black 22 | run: | 23 | # The GitHub editor is 127 chars wide 24 | black . -l 127 25 | - name: Sort imports 26 | run: | 27 | isort . 28 | #- name: Check with mypy 29 | # run: | 30 | # mypy 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | # - name: Run tests with pytest 38 | # run: | 39 | # pytest 40 | -------------------------------------------------------------------------------- /SplitDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "import numpy as np" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "TRAIN_PCT = 0.8" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "all_labels = []\n", 31 | "\n", 32 | "for date_path in Path(\"images\").iterdir():\n", 33 | " for augs_path in date_path.iterdir():\n", 34 | " all_labels.append(str(augs_path)) " 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "train_idx = int(len(all_labels) * TRAIN_PCT)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "random_ids = np.random.permutation(np.arange(len(all_labels)))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "train_labels = [label for i, label in enumerate(all_labels) if i in random_ids[:train_idx]]\n", 62 | "valid_labels = [label for i, label in enumerate(all_labels) if i in random_ids[train_idx:]]" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "with open(\"train.txt\", \"w\") as tf:\n", 72 | " tf.write(\"\\n\".join(train_labels))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "with open(\"valid.txt\", \"w\") as vf:\n", 82 | " vf.write(\"\\n\".join(valid_labels))" 83 | ] 84 | } 85 | ], 86 | "metadata": { 87 | "kernelspec": { 88 | "display_name": "Python 3", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.7.7-final" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 4 107 | } -------------------------------------------------------------------------------- /lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | #Taken from https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py 6 | class LARS(Optimizer): 7 | r"""Implements layer-wise adaptive rate scaling for SGD. 
8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 18 | Large Batch Training of Convolutional Networks: 19 | https://arxiv.org/abs/1708.03888 20 | """ 21 | def __init__(self, params, lr=required, momentum=.9, 22 | weight_decay=.0005, eta=0.001, max_epoch=200): 23 | if lr is not required and lr < 0.0: 24 | raise ValueError("Invalid learning rate: {}".format(lr)) 25 | if momentum < 0.0: 26 | raise ValueError("Invalid momentum value: {}".format(momentum)) 27 | if weight_decay < 0.0: 28 | raise ValueError("Invalid weight_decay value: {}" 29 | .format(weight_decay)) 30 | if eta < 0.0: 31 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 32 | 33 | self.epoch = 0 34 | defaults = dict(lr=lr, momentum=momentum, 35 | weight_decay=weight_decay, 36 | eta=eta, max_epoch=max_epoch) 37 | super(LARS, self).__init__(params, defaults) 38 | 39 | def step(self, epoch=None, closure=None): 40 | """Performs a single optimization step. 41 | Arguments: 42 | closure (callable, optional): A closure that reevaluates the model 43 | and returns the loss. 44 | epoch: current epoch to calculate polynomial LR decay schedule. 45 | if None, uses self.epoch and increments it. 46 | """ 47 | loss = None 48 | if closure is not None: 49 | loss = closure() 50 | 51 | if epoch is None: 52 | epoch = self.epoch 53 | self.epoch += 1 54 | 55 | for group in self.param_groups: 56 | weight_decay = group['weight_decay'] 57 | momentum = group['momentum'] 58 | eta = group['eta'] 59 | lr = group['lr'] 60 | max_epoch = group['max_epoch'] 61 | 62 | for p in group['params']: 63 | if p.grad is None: 64 | continue 65 | 66 | param_state = self.state[p] 67 | d_p = p.grad.data 68 | 69 | weight_norm = torch.norm(p.data) 70 | grad_norm = torch.norm(d_p) 71 | 72 | # Global LR computed on polynomial decay schedule 73 | decay = (1 - float(epoch) / max_epoch) ** 2 74 | global_lr = lr * decay 75 | 76 | # Compute local learning rate for this layer 77 | local_lr = eta * weight_norm / \ 78 | (grad_norm + weight_decay * weight_norm) 79 | 80 | # Update the momentum term 81 | actual_lr = local_lr * global_lr 82 | 83 | if 'momentum_buffer' not in param_state: 84 | buf = param_state['momentum_buffer'] = \ 85 | torch.zeros_like(p.data) 86 | else: 87 | buf = param_state['momentum_buffer'] 88 | buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data) 89 | p.data.add_(-buf) 90 | 91 | return loss -------------------------------------------------------------------------------- /video_demo.py: -------------------------------------------------------------------------------- 1 | from model import YOLOv4 2 | import cv2 3 | from torch.backends import cudnn 4 | import torch 5 | import utils 6 | import time 7 | 8 | coco_dict = {0: 'person', 9 | 1: 'bicycle', 10 | 2: 'car', 11 | 3: 'motorbike', 12 | 4: 'aeroplane', 13 | 5: 'bus', 14 | 6: 'train', 15 | 7: 'truck', 16 | 8: 'boat', 17 | 9: 'traffic light', 18 | 10: 'fire hydrant', 19 | 11: 'stop sign', 20 | 12: 'parking meter', 21 | 13: 'bench', 22 | 14: 'bird', 23 | 15: 'cat', 24 | 16: 'dog', 25 | 17: 'horse', 26 | 18: 'sheep', 27 | 19: 'cow', 28 | 20: 
'elephant', 29 | 21: 'bear', 30 | 22: 'zebra', 31 | 23: 'giraffe', 32 | 24: 'backpack', 33 | 25: 'umbrella', 34 | 26: 'handbag', 35 | 27: 'tie', 36 | 28: 'suitcase', 37 | 29: 'frisbee', 38 | 30: 'skis', 39 | 31: 'snowboard', 40 | 32: 'sports ball', 41 | 33: 'kite', 42 | 34: 'baseball bat', 43 | 35: 'baseball glove', 44 | 36: 'skateboard', 45 | 37: 'surfboard', 46 | 38: 'tennis racket', 47 | 39: 'bottle', 48 | 40: 'wine glass', 49 | 41: 'cup', 50 | 42: 'fork', 51 | 43: 'knife', 52 | 44: 'spoon', 53 | 45: 'bowl', 54 | 46: 'banana', 55 | 47: 'apple', 56 | 48: 'sandwich', 57 | 49: 'orange', 58 | 50: 'broccoli', 59 | 51: 'carrot', 60 | 52: 'hot dog', 61 | 53: 'pizza', 62 | 54: 'donut', 63 | 55: 'cake', 64 | 56: 'chair', 65 | 57: 'sofa', 66 | 58: 'pottedplant', 67 | 59: 'bed', 68 | 60: 'diningtable', 69 | 61: 'toilet', 70 | 62: 'tvmonitor', 71 | 63: 'laptop', 72 | 64: 'mouse', 73 | 65: 'remote', 74 | 66: 'keyboard', 75 | 67: 'cell phone', 76 | 68: 'microwave', 77 | 69: 'oven', 78 | 70: 'toaster', 79 | 71: 'sink', 80 | 72: 'refrigerator', 81 | 73: 'book', 82 | 74: 'clock', 83 | 75: 'vase', 84 | 76: 'scissors', 85 | 77: 'teddy bear', 86 | 78: 'hair drier', 87 | 79: 'toothbrush'} 88 | 89 | 90 | cudnn.fastest = True 91 | cudnn.benchmark = True 92 | threshold = 0.2 93 | iou_threshold = 0.2 94 | 95 | m = YOLOv4(pretrained=True, sam=False, eca=False) 96 | m.requires_grad_(False) 97 | m.eval() 98 | 99 | m = m.cuda() 100 | 101 | #To warm up JIT 102 | m(torch.zeros((1, 3, 608, 608)).cuda()) 103 | 104 | cap = cv2.VideoCapture(0) 105 | 106 | frames_n = 0 107 | start_time = time.time() 108 | 109 | while True: 110 | ret, frame = cap.read() 111 | if not ret: 112 | break 113 | 114 | 115 | sized = cv2.resize(frame, (m.img_dim, m.img_dim)) 116 | sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB) 117 | 118 | 119 | x = torch.from_numpy(sized) 120 | x = x.permute(2, 0, 1) 121 | x = x.float() 122 | x /= 255 123 | 124 | anchors, _ = m(x[None].cuda()) 125 | 126 | confidence_threshold = 0.5 127 | iou_threshold = 0.5 128 | 129 | bboxes, labels = utils.get_bboxes_from_anchors(anchors, 0.4, 0.5, coco_dict) 130 | arr = utils.get_img_with_bboxes(x.cpu(), bboxes[0].cpu(), resize=False, labels=labels[0]) 131 | arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) 132 | 133 | frames_n += 1 134 | 135 | arr = cv2.putText(arr, "FPS: " + str(frames_n / (time.time() - start_time)), (100, 100), cv2.FONT_HERSHEY_DUPLEX, 0.75, (255, 255, 255)) 136 | 137 | cv2.imshow("test", arr) 138 | if cv2.waitKey(1) & 0xFF == ord('q'): 139 | break 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /pl_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from torch.utils.data import DataLoader 4 | 5 | from dataset import ListDataset 6 | from model import YOLOv4 7 | 8 | from lars import LARS 9 | from ranger import Ranger 10 | from radam import RAdam 11 | 12 | from sched_del import DelayedCosineAnnealingLR 13 | 14 | torch.backends.cudnn.benchmark = True 15 | 16 | class YOLOv4PL(pl.LightningModule): 17 | def __init__(self, hparams): 18 | super().__init__() 19 | 20 | self.hparams = hparams 21 | 22 | self.train_ds = ListDataset(hparams.train_ds, train=True, img_extensions=hparams.img_extensions) 23 | self.valid_ds = ListDataset(hparams.valid_ds, train=False, img_extensions=hparams.img_extensions) 24 | 25 | self.model = YOLOv4(n_classes=hparams.n_classes, 26 | pretrained=hparams.pretrained, 27 | 
dropblock=hparams.Dropblock, 28 | sam=hparams.SAM, 29 | eca=hparams.ECA, 30 | ws=hparams.WS, 31 | iou_aware=hparams.iou_aware, 32 | coord=hparams.coord, 33 | hard_mish=hparams.hard_mish, 34 | asff=hparams.asff, 35 | repulsion_loss=hparams.repulsion_loss, 36 | acff=hparams.acff, 37 | bcn=hparams.bcn, 38 | mbn=hparams.mbn).cuda() 39 | 40 | def train_dataloader(self): 41 | train_dl = DataLoader(self.train_ds, batch_size=self.hparams.bs, collate_fn=self.train_ds.collate_fn, pin_memory=True, num_workers=4) 42 | return train_dl 43 | 44 | def val_dataloader(self): 45 | valid_dl = DataLoader(self.valid_ds, batch_size=self.hparams.bs, collate_fn=self.valid_ds.collate_fn, pin_memory=True, num_workers=4) 46 | return valid_dl 47 | 48 | def forward(self, x, y=None): 49 | return self.model(x, y) 50 | 51 | def basic_training_step(self, batch): 52 | filenames, images, labels = batch 53 | y_hat, loss = self(images, labels) 54 | logger_logs = {"training_loss": loss} 55 | 56 | return {"loss": loss, "log": logger_logs} 57 | 58 | def sat_fgsm_training_step(self, batch, epsilon=0.01): 59 | filenames, images, labels = batch 60 | 61 | images.requires_grad_(True) 62 | y_hat, loss = self(images, labels) 63 | loss.backward() 64 | data_grad = images.grad.data 65 | images.requires_grad_(False) 66 | images = torch.clamp(images + data_grad.sign() * epsilon, 0, 1) 67 | return self.basic_training_step((filenames, images, labels)) 68 | 69 | def sat_vanila_training_step(self, batch, epsilon=1): 70 | filenames, images, labels = batch 71 | 72 | images.requires_grad_(True) 73 | y_hat, loss = self(images, labels) 74 | loss.backward() 75 | data_grad = images.grad.data 76 | images.requires_grad_(False) 77 | images = torch.clamp(images + data_grad, 0, 1) 78 | return self.basic_training_step((filenames, images, labels)) 79 | 80 | 81 | 82 | def training_step(self, batch, batch_idx): 83 | if self.hparams.SAT == "vanila": 84 | return self.sat_vanila_training_step(batch, self.hparams.epsilon) 85 | elif self.hparams.SAT == "fgsm": 86 | return self.sat_fgsm_training_step(batch, self.hparams.epsilon) 87 | else: 88 | return self.basic_training_step(batch) 89 | 90 | def training_epoch_end(self, outputs): 91 | training_loss_mean = torch.stack([x['training_loss'] for x in outputs]).mean() 92 | return {"loss": training_loss_mean, "log": {"training_loss_epoch": training_loss_mean}} 93 | 94 | def validation_step(self, batch, batch_idx): 95 | filenames, images, labels = batch 96 | y_hat, loss = self(images, labels) 97 | return {"val_loss": loss} 98 | 99 | def validation_epoch_end(self, outputs): 100 | val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() 101 | logger_logs = {"validation_loss": val_loss_mean} 102 | 103 | return {"val_loss": val_loss_mean, "log": logger_logs} 104 | 105 | def configure_optimizers(self): 106 | # With this thing we get only params, which requires grad (weights needed to train) 107 | params = filter(lambda p: p.requires_grad, self.model.parameters()) 108 | if self.hparams.optimizer == "Ranger": 109 | self.optimizer = Ranger(params, self.hparams.lr, weight_decay=self.hparams.wd) 110 | elif self.hparams.optimizer == "SGD": 111 | self.optimizer = torch.optim.SGD(params, self.hparams.lr, momentum=self.hparams.momentum, weight_decay=self.hparams.wd) 112 | elif self.hparams.optimizer == "LARS": 113 | self.optimizer = LARS(params, lr=self.hparams.lr, momentum=self.hparams.momentum, weight_decay=self.hparams.wd, max_epoch=self.hparams.epochs) 114 | elif self.hparams.optimizer == "RAdam": 115 | self.optimizer = 
RAdam(params, lr=self.hparams.lr, weight_decay=self.hparams.wd) 116 | 117 | if self.hparams.scheduler == "Cosine Warm-up": 118 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, self.hparams.lr, epochs=self.hparams.epochs, steps_per_epoch=1, pct_start=self.hparams.pct_start) 119 | if self.hparams.scheduler == "Cosine Delayed": 120 | self.scheduler = DelayedCosineAnnealingLR(self.optimizer, self.hparams.flat_epochs, self.hparams.cosine_epochs) 121 | 122 | 123 | sched_dict = {'scheduler': self.scheduler} 124 | 125 | 126 | return [self.optimizer], [sched_dict] 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yet-Another-YOLOv4-Pytorch 2 | ![](github_imgs/from_net.png) 3 | 4 | !!! For the Jupyter notebook, please install pytorch-lightning version 0.7.6 5 | 6 | This is an implementation of the YOLOv4 object detection neural network in PyTorch. I'll try to implement all features of the original paper. 7 | 8 | - [x] Model 9 | - [x] Pretrained weights 10 | - [x] Custom classes 11 | - [x] CIoU 12 | - [x] YOLO dataset 13 | - [x] Letterbox for validation 14 | - [x] HSV transforms for train 15 | - [x] MOSAIC for train 16 | - [x] Dropblock layers for training. One in each PAN layer, but you can easily add it to each layer. (Thanks to Evgenii Zheltonozhskii for the pytorch implementation) 17 | - [x] LARS optimizer 18 | - [x] PyTorch Lightning 19 | - [x] Self-adversarial training with FGSM 20 | - [x] SAM attention block from the official YOLOv4 paper 21 | - [x] ECA attention block from https://arxiv.org/abs/1910.03151 with fastglobalavgpool from https://arxiv.org/pdf/2003.13630.pdf 22 | - [x] Weight standardization from https://arxiv.org/abs/1903.10520 (not suggested with pretrained weights, as it could lead to an input explosion; used together with track_running_stats, otherwise activations explode) 23 | - [x] Notebook with guide 24 | - [x] IoU Aware from https://arxiv.org/abs/2007.12099 25 | - [x] NMS in Depth implementation (not connected) 26 | - [x] Matrix NMS algorithm from https://arxiv.org/abs/2007.12099 (not connected) 27 | - [ ] Deformable convolutions from https://arxiv.org/abs/2007.12099 28 | - [x] Coord convolutions from https://arxiv.org/abs/2007.12099 29 | - [x] Self-adversarial training with vanilla grad 30 | - [x] Hard mish 31 | - [ ] Easy mAP for your DL 32 | - [x] ASFF from https://arxiv.org/abs/1911.09516 33 | - [x] RAdam optimizer from https://arxiv.org/abs/1908.03265 34 | - [x] Ranger optimizer (RAdam + LookAhead) from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 35 | - [x] Repulsion Loss from https://arxiv.org/abs/1711.07752v2 36 | - [ ] Soft IoU Loss from https://arxiv.org/abs/1904.00853v3 37 | - [x] Learning IoU (as in IoU, just use IoU aware) 38 | - [ ] EM-Merger (TODO: use https://github.com/eg4000/SKU110K_CVPR19 to do postprocessing util) 39 | - [ ] Elastic from https://arxiv.org/abs/1812.05262 40 | - [x] BN microbatching 41 | - [x] BCN from https://arxiv.org/pdf/1903.10520.pdf 42 | - [ ] AdamP from https://arxiv.org/abs/2006.08217v1 43 | - [x] Channel-wise feature fusion (by me) 44 | 45 | 46 | 47 | 48 | ## What you can already do 49 | You can use video_demo.py to take a look at real-time object detection with the original weights. (It gets 9 FPS on my GTX 1060 laptop!) 50 | ![](/github_imgs/realtime.jpg) 51 | 52 | You can train your own model with mosaic augmentation. Guides on how to do this are written below. 
Borders of images on some datasets are even hard to find. 53 | ![](/github_imgs/mosaic.png) 54 | 55 | 56 | You can run inference; a guide is below. 57 | 58 | 59 | ## Initialize NN 60 | 61 | #YOU CAN USE TORCH HUB 62 | m = torch.hub.load("VCasecnikovs/Yet-Another-YOLOv4-Pytorch", "yolov4", pretrained=True) 63 | 64 | import model 65 | #If you change n_classes from the pretrained model, one error will be caught; don't panic, it is OK 66 | 67 | #FROM SAVED WEIGHTS 68 | m = model.YOLOv4(n_classes=1, weights_path="weights/yolov4.pth") 69 | 70 | #AUTOMATICALLY DOWNLOAD PRETRAINED 71 | m = model.YOLOv4(n_classes=1, pretrained=True) 72 | 73 | ## Download weights 74 | You can use torch hub, 75 | or you can download the weights from this link: https://drive.google.com/open?id=12AaR4fvIQPZ468vhm0ZYZSLgWac2HBnq 76 | 77 | ## Initialize dataset 78 | 79 | import dataset 80 | d = dataset.ListDataset("train.txt", img_dir='images', labels_dir='labels', img_extensions=['.JPG'], train=True) 81 | path, img, bboxes = d[0] 82 | 83 | !!! You can use SplitDataset.ipynb to create train.txt and valid.txt 84 | 85 | "train.txt" is a file which consists of filepaths to images (images\primula\DSC02542.JPG) 86 | 87 | img_dir - Folder with images 88 | labels_dir - Folder with txt files for annotation 89 | img_extensions - extensions of images 90 | 91 | If you set train=False -> uses letterboxes 92 | If you set train=True -> HSV augmentations and mosaic 93 | 94 | The dataset has a collate function 95 | 96 | # collate func example 97 | y1 = d[0] 98 | y2 = d[1] 99 | paths_b, xb, yb = d.collate_fn((y1, y2)) 100 | # yb has 6 columns 101 | 102 | ## Y's format 103 | It is a tensor of size (B, 6), where B is the number of boxes in all batch images. 104 | 1. Index of img to which this anchor belongs (if 1, then it belongs to x[1]) 105 | 2. BBox class 106 | 3. x center 107 | 4. y center 108 | 5. width 109 | 6. height 110 | 111 | ## Forward with loss 112 | y_hat, loss = m(xb, yb) 113 | 114 | !!! y_hat contains anchors already resized to image-size bboxes 115 | 116 | ## Forward without loss 117 | y_hat, _ = m(img_batch) #_ is (0, 0, 0) 118 | 119 | ## Check if bboxes are correct 120 | import utils 121 | from PIL import Image 122 | path, img, bboxes = d[0] 123 | img_with_bboxes = utils.get_img_with_bboxes(img, bboxes[:, 2:]) #Returns a numpy array 124 | Image.fromarray(img_with_bboxes) 125 | 126 | ## Get predicted bboxes 127 | anchors, loss = m(xb.cuda(), yb.cuda()) 128 | confidence_threshold = 0.05 129 | iou_threshold = 0.5 130 | bboxes, labels = utils.get_bboxes_from_anchors(anchors, confidence_threshold, iou_threshold, coco_dict) #COCO dict is id->class dictionary (e.g. 
0->person) 131 | #For first img 132 | arr = utils.get_img_with_bboxes(xb[0].cpu(), bboxes[0].cpu(), resize=False, labels=labels[0]) 133 | Image.fromarray(arr) 134 | 135 | ## References 136 | In case if you missed:\ 137 | Paper Yolo v4: https://arxiv.org/abs/2004.10934\ 138 | Original repo: https://github.com/AlexeyAB/darknet#how-to-train-to-detect-your-custom-objects 139 | ``` 140 | @article{yolov4, 141 | title={YOLOv4: YOLOv4: Optimal Speed and Accuracy of Object Detection}, 142 | author={Alexey Bochkovskiy, Chien-Yao Wang, Hong-Yuan Mark Liao}, 143 | journal = {arXiv}, 144 | year={2020} 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | from torchvision.ops import nms 7 | 8 | def xyxy2xywh(x): 9 | # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h] 10 | y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) 11 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 12 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 13 | y[:, 2] = x[:, 2] - x[:, 0] 14 | y[:, 3] = x[:, 3] - x[:, 1] 15 | return y 16 | 17 | 18 | def xywh2xyxy(x): 19 | # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2] 20 | y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) 21 | y[:, 0] = x[:, 0] - x[:, 2] / 2 22 | y[:, 1] = x[:, 1] - x[:, 3] / 2 23 | y[:, 2] = x[:, 0] + x[:, 2] / 2 24 | y[:, 3] = x[:, 1] + x[:, 3] / 2 25 | return y 26 | 27 | def get_img_with_bboxes(img, bboxes, resize=True, labels=None, confidences= None): 28 | c, h, w = img.shape 29 | 30 | bboxes_xyxy = bboxes.clone() 31 | bboxes_xyxy[:, :4] = xywh2xyxy(bboxes[:, :4]) 32 | if resize: 33 | bboxes_xyxy[:,0] *= w 34 | bboxes_xyxy[:,1] *= h 35 | bboxes_xyxy[:,2] *= w 36 | bboxes_xyxy[:,3] *= h 37 | 38 | bboxes_xyxy[:, 0:4] = bboxes_xyxy[:,0:4].round() 39 | 40 | arr = bboxes_xyxy.numpy() 41 | 42 | img = img.permute(1, 2, 0) 43 | img = img.numpy() 44 | img = (img * 255).astype(np.uint8) 45 | 46 | #Otherwise cv2 rectangle will return UMat without paint 47 | img_ = img.copy() 48 | 49 | for i, bbox in enumerate(arr): 50 | img_ = cv2.rectangle(img_, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 3) 51 | if labels: 52 | text = labels[i] 53 | text += f" {bbox[4].item() :.2f}" 54 | 55 | img_ = cv2.putText(img_, text, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_DUPLEX, 0.75, (255, 255, 255)) 56 | return img_ 57 | 58 | 59 | def bbox_iou(box1, box2, x1y1x2y2=True, get_areas = False): 60 | """ 61 | Returns the IoU of two bounding boxes 62 | """ 63 | if not x1y1x2y2: 64 | # Transform from center and width to exact coordinates 65 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 66 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 67 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 68 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 69 | else: 70 | # Get the coordinates of bounding boxes 71 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 72 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 73 | 74 | # get the coordinates of the intersection rectangle 75 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 76 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 77 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 78 | inter_rect_y2 = torch.min(b1_y2, 
b2_y2) 79 | 80 | # Intersection area 81 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, min=0) * torch.clamp( 82 | inter_rect_y2 - inter_rect_y1, min=0 83 | ) 84 | # Union Area 85 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) 86 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) 87 | union_area = (b1_area + b2_area - inter_area + 1e-16) 88 | 89 | 90 | if get_areas: 91 | return inter_area, union_area 92 | 93 | iou = inter_area / union_area 94 | return iou 95 | 96 | def nms_with_depth(bboxes, confidence, iou_threshold, depth_layer, depth_threshold): 97 | if len(bboxes) == 0: 98 | return bboxes 99 | 100 | for i in range(bboxes.shape[0]): 101 | for j in range(i+1, bboxes.shape[0]): 102 | iou = bbox_iou(bboxes[i], bboxes[j]) 103 | if iou > iou_threshold: 104 | #Getting center depth points of both bboxes 105 | D_oi = depth_layer[(bboxes[i, 0] + bboxes[i, 2])//2, (bboxes[i, 1] + bboxes[i, 3])//2] 106 | D_oj = depth_layer[(bboxes[j, 0] + bboxes[j, 2])//2, (bboxes[j, 1] + bboxes[j, 3])//2] 107 | if D_oi - D_oj < depth_threshold: 108 | average_depth_oi = depth_layer[bboxes[i, 0] : bboxes[i, 2], bboxes[i, 1] : bboxes[i, 3]] 109 | average_depth_oj = depth_layer[bboxes[j, 0] : bboxes[j, 2], bboxes[j, 1] : bboxes[j, 3]] 110 | score_oi = confidence[i] + 1/torch.log(average_depth_oi) 111 | score_oj = confidence[j] + 1/torch.log(average_depth_oj) 112 | if score_oi > score_oj: 113 | confidence[j] = 0 114 | else: 115 | confidence[i] = 0 116 | 117 | return confidence != 0 118 | 119 | def matrix_nms(boxes, confidence, iou_threshold, batch_size, method, sigma, N): 120 | boxes = boxes.reshape(batch_size, -1) 121 | intersection = torch.mm(boxes, boxes.T) 122 | areas = boxes.sum(dim=1).expand(N, N) 123 | union = areas + areas.T - intersection 124 | ious = (intersection / union).triu(diagonal=1) 125 | 126 | ious_cmax = ious.max(0) 127 | ious_cmax = ious_cmax.expand(N, N).T 128 | 129 | if method == "gauss": # gaussian 130 | decay = torch.exp(-(ious**2 - ious_cmax**2) / sigma) 131 | else: # linear 132 | decay = (1 - ious) / (1 - ious_cmax) 133 | 134 | decay = decay.min(dim=0) 135 | return confidence * decay 136 | 137 | 138 | def get_bboxes_from_anchors(anchors, confidence_threshold, iou_threshold, labels_dict, depth_layer = None, depth_threshold = 0.1): 139 | nbatches = anchors.shape[0] 140 | batch_bboxes = [] 141 | labels = [] 142 | 143 | for nbatch in range(nbatches): 144 | img_anchor = anchors[nbatch] 145 | confidence_filter = img_anchor[:, 4] > confidence_threshold 146 | img_anchor = img_anchor[confidence_filter] 147 | if depth_layer != None: 148 | keep = nms_with_depth(xywh2xyxy(img_anchor[:, :4]), img_anchor[:, 4], iou_threshold, depth_layer, depth_threshold) 149 | else: 150 | keep = nms(xywh2xyxy(img_anchor[:, :4]), img_anchor[:, 4], iou_threshold) 151 | 152 | img_bboxes = img_anchor[keep] 153 | batch_bboxes.append(img_bboxes) 154 | if len(img_bboxes) == 0: 155 | labels.append([]) 156 | continue 157 | labels.append([labels_dict[x.item()] for x in img_bboxes[:, 5:].argmax(1)]) 158 | 159 | return batch_bboxes, labels 160 | 161 | 162 | def iou_all_to_all(a, b): 163 | area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) 164 | 165 | iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0]) 166 | ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1]) 167 | 168 | iw = torch.clamp(iw, min=0) 169 | ih = torch.clamp(ih, min=0) 170 | 171 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw 
* ih 172 | 173 | ua = torch.clamp(ua, min=1e-8) 174 | 175 | intersection = iw * ih 176 | 177 | IoU = intersection / ua 178 | 179 | return IoU 180 | 181 | def smooth_ln(x, smooth =0.5): 182 | return torch.where( 183 | torch.le(x, smooth), 184 | -torch.log(1 - x), 185 | ((x - smooth) / (1 - smooth)) - np.log(1 - smooth) 186 | ) 187 | 188 | def iog(ground_truth, prediction): 189 | 190 | inter_xmin = torch.max(ground_truth[:, 0], prediction[:, 0]) 191 | inter_ymin = torch.max(ground_truth[:, 1], prediction[:, 1]) 192 | inter_xmax = torch.min(ground_truth[:, 2], prediction[:, 2]) 193 | inter_ymax = torch.min(ground_truth[:, 3], prediction[:, 3]) 194 | Iw = torch.clamp(inter_xmax - inter_xmin, min=0) 195 | Ih = torch.clamp(inter_ymax - inter_ymin, min=0) 196 | I = Iw * Ih 197 | G = (ground_truth[:, 2] - ground_truth[:, 0]) * (ground_truth[:, 3] - ground_truth[:, 1]) 198 | return I / G 199 | -------------------------------------------------------------------------------- /ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer, required 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 35 | use_gc=True, gc_conv_only=False 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, 55 | N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | print( 76 | f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") 77 | if (self.use_gc and self.gc_gradient_threshold == 1): 78 | print(f"GC applied to both conv and fc layers") 79 | elif (self.use_gc and self.gc_gradient_threshold == 3): 80 | print(f"GC applied to conv layers only") 81 | 82 | def __setstate__(self, state): 83 | print("set state called") 84 | super(Ranger, self).__setstate__(state) 85 | 86 | def step(self, closure=None): 87 | loss = None 88 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 89 | # Uncomment if you need to use the actual closure... 90 | 91 | if closure is not None: 92 | loss = closure() 93 | 94 | # Evaluate averages and grad, update param tensors 95 | for group in self.param_groups: 96 | 97 | for p in group['params']: 98 | if p.grad is None: 99 | continue 100 | grad = p.grad.data.float() 101 | 102 | if grad.is_sparse: 103 | raise RuntimeError( 104 | 'Ranger optimizer does not support sparse gradients') 105 | 106 | p_data_fp32 = p.data.float() 107 | 108 | state = self.state[p] # get state dict for this param 109 | 110 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 111 | # if self.first_run_check==0: 112 | # self.first_run_check=1 113 | #print("Initializing slow buffer...should not see this at load from saved model!") 114 | state['step'] = 0 115 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 116 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 117 | 118 | # look ahead weight storage now in state dict 119 | state['slow_buffer'] = torch.empty_like(p.data) 120 | state['slow_buffer'].copy_(p.data) 121 | 122 | else: 123 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 124 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as( 125 | p_data_fp32) 126 | 127 | # begin computations 128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 129 | beta1, beta2 = group['betas'] 130 | 131 | # GC operation for Conv layers and FC layers 132 | if grad.dim() > self.gc_gradient_threshold: 133 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 134 | 135 | state['step'] += 1 136 | 137 | # compute variance mov avg 138 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 139 | # compute mean moving avg 140 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 141 | 142 | buffered = self.radam_buffer[int(state['step'] % 10)] 143 | 144 | if state['step'] == buffered[0]: 145 | N_sma, step_size = buffered[1], buffered[2] 146 | else: 147 | buffered[0] = state['step'] 148 | beta2_t = beta2 ** state['step'] 149 | N_sma_max = 2 / (1 - beta2) - 1 150 | N_sma = N_sma_max - 2 * \ 151 | state['step'] * beta2_t / (1 - beta2_t) 152 | buffered[1] = N_sma 153 | if N_sma > self.N_sma_threshhold: 154 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 
155 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 156 | else: 157 | step_size = 1.0 / (1 - beta1 ** state['step']) 158 | buffered[2] = step_size 159 | 160 | if group['weight_decay'] != 0: 161 | p_data_fp32.add_(-group['weight_decay'] 162 | * group['lr'], p_data_fp32) 163 | 164 | # apply lr 165 | if N_sma > self.N_sma_threshhold: 166 | denom = exp_avg_sq.sqrt().add_(group['eps']) 167 | p_data_fp32.addcdiv_(-step_size * 168 | group['lr'], exp_avg, denom) 169 | else: 170 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 171 | 172 | p.data.copy_(p_data_fp32) 173 | 174 | # integrated look ahead... 175 | # we do it at the param level instead of group level 176 | if state['step'] % group['k'] == 0: 177 | # get access to slow param tensor 178 | slow_p = state['slow_buffer'] 179 | # (fast weights - slow weights) * alpha 180 | slow_p.add_(self.alpha, p.data - slow_p) 181 | # copy interpolated weights to RAdam param tensor 182 | p.data.copy_(slow_p) 183 | 184 | return loss 185 | -------------------------------------------------------------------------------- /radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | 
else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** 
state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import torch.nn.functional as F 8 | import utils 9 | import random 10 | 11 | 12 | class ListDataset(Dataset): 13 | def __init__(self, list_path, img_dir="images", labels_dir="labels", img_extensions=[".JPG"], img_size=608, train=True, 
bbox_minsize=0.01, brightness_range=0.25, contrast_range=0.25, hue_range=0.05, saturation_range=0.25, cross_offset=0.2): 14 | with open(list_path, "r") as file: 15 | self.img_files = file.read().splitlines() 16 | 17 | self.label_files = [] 18 | for path in self.img_files: 19 | path = path.replace(img_dir, labels_dir) 20 | for ext in img_extensions: 21 | path = path.replace(ext, ".txt") 22 | 23 | self.label_files.append(path) 24 | 25 | self.img_size = img_size 26 | self.to_tensor = transforms.ToTensor() 27 | 28 | self.train = train 29 | 30 | self.bbox_minsize = bbox_minsize 31 | 32 | self.brightness_range = brightness_range 33 | self.contrast_range = contrast_range 34 | self.hue_range = hue_range 35 | self.saturation_range = saturation_range 36 | 37 | self.cross_offset = cross_offset 38 | 39 | def __getitem__(self, index): 40 | 41 | img_path = self.img_files[index % len(self.img_files)].rstrip() 42 | label_path = self.label_files[index % len(self.img_files)].rstrip() 43 | 44 | 45 | # Getting image 46 | img = Image.open(img_path).convert('RGB') 47 | width, height = img.size 48 | 49 | if os.path.exists(label_path): 50 | boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5)) 51 | else: 52 | print(label_path) 53 | 54 | # RESIZING 55 | if width > height: 56 | ratio = height/width 57 | t_width = self.img_size 58 | t_height = int(ratio * self.img_size) 59 | else: 60 | ratio = width/height 61 | t_width = int(ratio * self.img_size) 62 | t_height = self.img_size 63 | 64 | img = transforms.functional.resize(img, (t_height, t_width)) 65 | 66 | # IF TRAIN APPLY BRIGHTNESS CONTRAST HUE SATURTATION 67 | if self.train: 68 | brightness_rnd = random.uniform(1 - self.brightness_range, 1 + self.brightness_range) 69 | contrast_rnd = random.uniform(1 - self.contrast_range, 1 + self.contrast_range) 70 | hue_rnd = random.uniform(-self.hue_range, self.hue_range) 71 | saturation_rnd = random.uniform(1 - self.saturation_range, 1 + self.saturation_range) 72 | 73 | img = transforms.functional.adjust_brightness(img, brightness_rnd) 74 | img = transforms.functional.adjust_contrast(img, contrast_rnd) 75 | img = transforms.functional.adjust_hue(img, hue_rnd) 76 | img = transforms.functional.adjust_saturation(img, saturation_rnd) 77 | 78 | 79 | # CONVERTING TO TENSOR 80 | tensor_img = transforms.functional.to_tensor(img) 81 | 82 | # Handle grayscaled images 83 | if len(tensor_img.shape) != 3: 84 | tensor_img = tensor_img.unsqueeze(0) 85 | tensor_img = tensor_img.expand((3, img.shape[1:])) 86 | 87 | # !!!WARNING IN PIL IT'S WIDTH HEIGHT, WHEN IN PYTORCH IT IS HEIGHT WIDTH 88 | 89 | # Apply augmentations for train it would be mosaic 90 | if self.train: 91 | mossaic_img = torch.zeros(3, self.img_size, self.img_size) 92 | 93 | # FINDING CROSS POINT 94 | cross_x = int(random.uniform(self.img_size * self.cross_offset, self.img_size * (1 - self.cross_offset))) 95 | cross_y = int(random.uniform(self.img_size * self.cross_offset, self.img_size * (1 - self.cross_offset))) 96 | 97 | fragment_img, fragment_bbox = self.get_mosaic(0, cross_x, cross_y, tensor_img, boxes) 98 | mossaic_img[:, 0:cross_y, 0:cross_x] = fragment_img 99 | boxes = fragment_bbox 100 | 101 | for n in range(1, 4): 102 | raw_fragment_img, raw_fragment_bbox = self.get_img_for_mosaic(brightness_rnd, contrast_rnd, hue_rnd, saturation_rnd) 103 | fragment_img, fragment_bbox = self.get_mosaic(n, cross_x, cross_y, raw_fragment_img, raw_fragment_bbox) 104 | boxes = torch.cat([boxes, fragment_bbox]) 105 | 106 | if n == 1: 107 | mossaic_img[:, 0 : cross_y, cross_x : 
self.img_size] = fragment_img 108 | elif n == 2: 109 | mossaic_img[:, cross_y : self.img_size, 0 : cross_x] = fragment_img 110 | elif n == 3: 111 | mossaic_img[:, cross_y : self.img_size, cross_x : self.img_size] = fragment_img 112 | 113 | #Set mossaic to return tensor 114 | tensor_img = mossaic_img 115 | 116 | 117 | # For validation it would be letterbox 118 | else: 119 | xyxy_bboxes = utils.xywh2xyxy(boxes[:, 1:]) 120 | 121 | #IMG 122 | padding = abs((t_width - t_height))//2 123 | padded_img = torch.zeros(3, self.img_size, self.img_size) 124 | if t_width > t_height: 125 | padded_img[:, padding:padding+t_height] = tensor_img 126 | else: 127 | padded_img[:, :, padding:padding+t_width] = tensor_img 128 | 129 | tensor_img = padded_img 130 | 131 | relative_padding = padding/self.img_size 132 | #BOXES 133 | if t_width > t_height: 134 | #Change y's relative position 135 | xyxy_bboxes[:, 1] *= ratio 136 | xyxy_bboxes[:, 3] *= ratio 137 | xyxy_bboxes[:, 1] += relative_padding 138 | xyxy_bboxes[:, 3] += relative_padding 139 | else:#x's 140 | xyxy_bboxes[:, 0] *= ratio 141 | xyxy_bboxes[:, 2] *= ratio 142 | xyxy_bboxes[:, 0] += relative_padding 143 | xyxy_bboxes[:, 2] += relative_padding 144 | 145 | boxes[:, 1:] = utils.xyxy2xywh(xyxy_bboxes) 146 | 147 | 148 | 149 | 150 | targets = torch.zeros((len(boxes), 6)) 151 | targets[:, 1:] = boxes 152 | 153 | 154 | return img_path, tensor_img, targets 155 | 156 | def get_img_for_mosaic(self, brightness_rnd, contrast_rnd, hue_rnd, saturation_rnd): 157 | random_index = random.randrange(0, len(self.img_files)) 158 | img_path = self.img_files[random_index].rstrip() 159 | label_path = self.label_files[random_index].rstrip() 160 | 161 | 162 | 163 | # Getting image 164 | img = Image.open(img_path).convert('RGB') 165 | width, height = img.size 166 | 167 | if os.path.exists(label_path): 168 | boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5)) 169 | 170 | #RESIZING 171 | if width > height: 172 | ratio = height/width 173 | t_width = self.img_size 174 | t_height = int(ratio * self.img_size) 175 | 176 | else: 177 | ratio = width/height 178 | t_width = int(ratio * self.img_size) 179 | t_height = self.img_size 180 | 181 | img = transforms.functional.resize(img, (t_height, t_width)) 182 | 183 | img = transforms.functional.adjust_brightness(img, brightness_rnd) 184 | img = transforms.functional.adjust_contrast(img, contrast_rnd) 185 | img = transforms.functional.adjust_hue(img, hue_rnd) 186 | img = transforms.functional.adjust_saturation(img, saturation_rnd) 187 | 188 | #CONVERTING TO TENSOR 189 | tensor_img = transforms.functional.to_tensor(img) 190 | 191 | # Handle grayscaled images 192 | if len(tensor_img.shape) != 3: 193 | tensor_img = tensor_img.unsqueeze(0) 194 | tensor_img = tensor_img.expand((3, img.shape[1:])) 195 | 196 | return tensor_img, boxes 197 | 198 | 199 | # N is spatial parameter if 0 TOP LEFT, if 1 TOP RIGHT, if 2 BOTTOM LEFT, if 3 BOTTOM RIGHT 200 | def get_mosaic(self, n, cross_x, cross_y, tensor_img, boxes): 201 | t_height = tensor_img.shape[1] 202 | t_width = tensor_img.shape[2] 203 | 204 | xyxy_bboxes = utils.xywh2xyxy(boxes[:, 1:]) 205 | 206 | relative_cross_x = cross_x / self.img_size 207 | relative_cross_y = cross_y / self.img_size 208 | 209 | #CALCULATING TARGET WIDTH AND HEIGHT OF PICTURE 210 | if n == 0: 211 | width_of_nth_pic = cross_x 212 | height_of_nth_pic = cross_y 213 | elif n == 1: 214 | width_of_nth_pic = self.img_size - cross_x 215 | height_of_nth_pic = cross_y 216 | elif n == 2: 217 | width_of_nth_pic = cross_x 218 | 
height_of_nth_pic = self.img_size - cross_y 219 | elif n == 3: 220 | width_of_nth_pic = self.img_size - cross_x 221 | height_of_nth_pic = self.img_size - cross_y 222 | 223 | # self.img_size - width_of_1st_pic 224 | # selg.img_size - height_of_1st_pic 225 | 226 | 227 | # CHOOSING TOP LEFT CORNER (doing offset to have more than fex pixels in bbox :-) ) 228 | cut_x1 = random.randint(0, int(t_width * 0.33)) 229 | cut_y1 = random.randint(0, int(t_height * 0.33)) 230 | 231 | 232 | # Now we should find which axis should we randomly enlarge (this we do by finding out which ratio is bigger); cross x is basically width of the top left picture 233 | if (t_width - cut_x1) / width_of_nth_pic < (t_height - cut_y1) / height_of_nth_pic: 234 | cut_x2 = random.randint(cut_x1 + int(t_width * 0.67), t_width) 235 | cut_y2 = int(cut_y1 + (cut_x2-cut_x1)/width_of_nth_pic*height_of_nth_pic) 236 | 237 | else: 238 | cut_y2 = random.randint(cut_y1 + int(t_height * 0.67), t_height) 239 | cut_x2 = int(cut_x1 + (cut_y2-cut_y1)/height_of_nth_pic*width_of_nth_pic) 240 | 241 | # RESIZING AND INSERTING (TO DO 2D interpolation wants 4 dimensions, so I add and remove one by using None and squeeze) 242 | tensor_img = F.interpolate(tensor_img[:, cut_y1:cut_y2, cut_x1:cut_x2][None], (height_of_nth_pic, width_of_nth_pic)).squeeze() 243 | 244 | # BBOX 245 | relative_cut_x1 = cut_x1 / t_width 246 | relative_cut_y1 = cut_y1 / t_height 247 | relative_cropped_width = (cut_x2 - cut_x1) / t_width 248 | relative_cropped_height = (cut_y2 - cut_y1) / t_height 249 | 250 | # SHIFTING TO CUTTED IMG SO X1 Y1 WILL 0 251 | xyxy_bboxes[:, 0] = xyxy_bboxes[:, 0] - relative_cut_x1 252 | xyxy_bboxes[:, 1] = xyxy_bboxes[:, 1] - relative_cut_y1 253 | xyxy_bboxes[:, 2] = xyxy_bboxes[:, 2] - relative_cut_x1 254 | xyxy_bboxes[:, 3] = xyxy_bboxes[:, 3] - relative_cut_y1 255 | 256 | # RESIZING TO CUTTED IMG SO X2 WILL BE 1 257 | xyxy_bboxes[:, 0] /= relative_cropped_width 258 | xyxy_bboxes[:, 1] /= relative_cropped_height 259 | xyxy_bboxes[:, 2] /= relative_cropped_width 260 | xyxy_bboxes[:, 3] /= relative_cropped_height 261 | 262 | # CLAMPING BOUNDING BOXES, SO THEY DO NOT OVERCOME OUTSIDE THE IMAGE 263 | xyxy_bboxes[:, 0].clamp_(0, 1) 264 | xyxy_bboxes[:, 1].clamp_(0, 1) 265 | xyxy_bboxes[:, 2].clamp_(0, 1) 266 | xyxy_bboxes[:, 3].clamp_(0, 1) 267 | 268 | # FILTER TO THROUGH OUT ALL SMALL BBOXES 269 | filter_minbbox = (xyxy_bboxes[:, 2] - xyxy_bboxes[:, 0] > self.bbox_minsize) & (xyxy_bboxes[:, 3] - xyxy_bboxes[:, 1] > self.bbox_minsize) 270 | 271 | # RESIZING TO MOSAIC 272 | if n == 0: 273 | xyxy_bboxes[:, 0] *= relative_cross_x # 274 | xyxy_bboxes[:, 1] *= relative_cross_y #(1 - relative_cross_y) 275 | xyxy_bboxes[:, 2] *= relative_cross_x # 276 | xyxy_bboxes[:, 3] *= relative_cross_y #(1 - relative_cross_y) 277 | elif n==1: 278 | xyxy_bboxes[:, 0] *= (1 - relative_cross_x) 279 | xyxy_bboxes[:, 1] *= relative_cross_y 280 | xyxy_bboxes[:, 2] *= (1 - relative_cross_x) 281 | xyxy_bboxes[:, 3] *= relative_cross_y 282 | elif n==2: 283 | xyxy_bboxes[:, 0] *= relative_cross_x 284 | xyxy_bboxes[:, 1] *= (1 - relative_cross_y) 285 | xyxy_bboxes[:, 2] *= relative_cross_x 286 | xyxy_bboxes[:, 3] *= (1 - relative_cross_y) 287 | elif n==3: 288 | xyxy_bboxes[:, 0] *= (1 - relative_cross_x) 289 | xyxy_bboxes[:, 1] *= (1 - relative_cross_y) 290 | xyxy_bboxes[:, 2] *= (1 - relative_cross_x) 291 | xyxy_bboxes[:, 3] *= (1 - relative_cross_y) 292 | 293 | # RESIZING TO MOSAIC 294 | if n == 0: 295 | xyxy_bboxes[:, 0] = xyxy_bboxes[:, 0] # + relative_cross_x 296 | 
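            # For the top-left tile (n == 0) no translation is needed, so these are
            # identity assignments; the commented-out offsets are kept for symmetry with
            # the n == 1..3 branches below, which shift each tile into its quadrant of
            # the mosaic by relative_cross_x / relative_cross_y.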
xyxy_bboxes[:, 1] = xyxy_bboxes[:, 1] # + relative_cross_y 297 | xyxy_bboxes[:, 2] = xyxy_bboxes[:, 2] # + relative_cross_x 298 | xyxy_bboxes[:, 3] = xyxy_bboxes[:, 3] # + relative_cross_y 299 | elif n==1: 300 | xyxy_bboxes[:, 0] = xyxy_bboxes[:, 0] + relative_cross_x 301 | xyxy_bboxes[:, 1] = xyxy_bboxes[:, 1] 302 | xyxy_bboxes[:, 2] = xyxy_bboxes[:, 2] + relative_cross_x 303 | xyxy_bboxes[:, 3] = xyxy_bboxes[:, 3] 304 | elif n==2: 305 | xyxy_bboxes[:, 0] = xyxy_bboxes[:, 0] 306 | xyxy_bboxes[:, 1] = xyxy_bboxes[:, 1] + relative_cross_y 307 | xyxy_bboxes[:, 2] = xyxy_bboxes[:, 2] 308 | xyxy_bboxes[:, 3] = xyxy_bboxes[:, 3] + relative_cross_y 309 | elif n==3: 310 | xyxy_bboxes[:, 0] = xyxy_bboxes[:, 0] + relative_cross_x 311 | xyxy_bboxes[:, 1] = xyxy_bboxes[:, 1] + relative_cross_y 312 | xyxy_bboxes[:, 2] = xyxy_bboxes[:, 2] + relative_cross_x 313 | xyxy_bboxes[:, 3] = xyxy_bboxes[:, 3] + relative_cross_y 314 | 315 | boxes = boxes[filter_minbbox] 316 | boxes[:, 1:] = utils.xyxy2xywh(xyxy_bboxes)[filter_minbbox] 317 | 318 | return tensor_img, boxes 319 | 320 | def collate_fn(self, batch): 321 | paths, imgs, targets = list(zip(*batch)) 322 | # Remove empty placeholder targets 323 | targets = [boxes for boxes in targets if boxes is not None] 324 | # Add sample index to targets 325 | for i, boxes in enumerate(targets): 326 | boxes[:, 0] = i 327 | targets = torch.cat(targets, 0) 328 | 329 | return paths, torch.stack(imgs), targets 330 | 331 | def __len__(self): 332 | return len(self.img_files) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | # Need for Pi 7 | import math 8 | 9 | #Mypy 10 | import typing as ty 11 | 12 | # Model consists of 13 | # - backbone 14 | # - neck 15 | # - head 16 | # - yolo 17 | 18 | # To implement: 19 | # Mish (just download) DONE 20 | # CSP (is in architecture) DONE 21 | # MiWRC (attention_forward) 22 | # SPP block (in architecture) DONE 23 | # PAN (in architecture) DONE 24 | # Implemented with 25 | # https://lutzroeder.github.io/netron/?url=https%3A%2F%2Fraw.githubusercontent.com%2FAlexeyAB%2Fdarknet%2Fmaster%2Fcfg%2Fyolov4.cfg 26 | 27 | 28 | class BadArguments(Exception): 29 | pass 30 | 31 | import torch 32 | import torch.nn.functional as F 33 | 34 | # Mish as written in darknet speed check 35 | class darknet_mish(torch.autograd.Function): 36 | """ 37 | We can implement our own custom autograd Functions by subclassing 38 | torch.autograd.Function and implementing the forward and backward passes 39 | which operate on Tensors. 40 | """ 41 | 42 | @staticmethod 43 | def forward(ctx, input): 44 | """ 45 | In the forward pass we receive a Tensor containing the input and return 46 | a Tensor containing the output. ctx is a context object that can be used 47 | to stash information for backward computation. You can cache arbitrary 48 | objects for use in the backward pass using the ctx.save_for_backward method. 
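        This particular forward evaluates mish(x) = x * tanh(softplus(x)) through the
        algebraically equivalent form x * n / (n + 2), where n = exp(x)**2 + 2 * exp(x).
        The two branches below are the same expression factored differently and are
        switched on input <= -0.6, presumably for numerical stability, following the
        darknet "speed check" implementation mentioned above.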
49 | """ 50 | ctx.save_for_backward(input) 51 | e = torch.exp(input) 52 | n = e * e + 2 * e 53 | mask = input <= -0.6 54 | input[mask] = (input * (n / (n + 2)))[mask] 55 | input[~mask] = ((input - 2 * (input / (n + 2))))[~mask] 56 | 57 | return input 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | """ 62 | In the backward pass we receive a Tensor containing the gradient of the loss 63 | with respect to the output, and we need to compute the gradient of the loss 64 | with respect to the input. 65 | """ 66 | input, = ctx.saved_tensors 67 | 68 | sp = F.softplus(input) 69 | grad_sp = -torch.expm1(sp) 70 | 71 | tsp = F.tanh(sp) 72 | grad_tsp = (1 - tsp * tsp) * grad_sp 73 | grad = input * grad_tsp + tsp 74 | return grad 75 | 76 | 77 | class DarknetMish(nn.Module): 78 | def __init__(self): 79 | super().__init__() 80 | 81 | def forward(self, x): 82 | return darknet_mish.apply(x) 83 | 84 | # @torch.jit.script 85 | class HardMish(nn.Module): 86 | def __init__(self): 87 | super().__init__() 88 | def forward(self, x): 89 | return (x/2) * torch.clamp(x+2, min=0, max=2) 90 | 91 | 92 | # Taken from https://github.com/lessw2020/mish 93 | class Mish(nn.Module): 94 | def __init__(self): 95 | super().__init__() 96 | 97 | def forward(self, x): 98 | # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 99 | return x * torch.tanh(F.softplus(x)) 100 | 101 | 102 | # Taken from https://github.com/Randl/DropBlock-pytorch/blob/master/DropBlock.py 103 | class DropBlock2D(nn.Module): 104 | r"""Randomly zeroes spatial blocks of the input tensor. 105 | As described in the paper 106 | `DropBlock: A regularization method for convolutional networks`_ , 107 | dropping whole blocks of feature map allows to remove semantic 108 | information as compared to regular dropout. 109 | Args: 110 | keep_prob (float, optional): probability of an element to be kept. 111 | Authors recommend to linearly decrease this value from 1 to desired 112 | value. 113 | block_size (int, optional): size of the block. Block size in paper 114 | usually equals last feature map dimensions. 115 | Shape: 116 | - Input: :math:`(N, C, H, W)` 117 | - Output: :math:`(N, C, H, W)` (same shape as input) 118 | .. _DropBlock: A regularization method for convolutional networks: 119 | https://arxiv.org/abs/1810.12890 120 | """ 121 | 122 | def __init__(self, keep_prob=0.9, block_size=7): 123 | super(DropBlock2D, self).__init__() 124 | self.keep_prob = keep_prob 125 | self.block_size = block_size 126 | 127 | def forward(self, input): 128 | # print("Before: ", torch.isnan(input).sum()) 129 | if not self.training or self.keep_prob == 1: 130 | return input 131 | gamma = (1. 
- self.keep_prob) / self.block_size ** 2 132 | for sh in input.shape[2:]: 133 | gamma *= sh / (sh - self.block_size + 1) 134 | M = torch.bernoulli(torch.ones_like(input) * gamma).to(device=input.device) 135 | Msum = F.conv2d(M, 136 | torch.ones((input.shape[1], 1, self.block_size, self.block_size)).to(device=input.device, 137 | dtype=input.dtype), 138 | padding=self.block_size // 2, 139 | groups=input.shape[1]) 140 | mask = (Msum < 1).to(device=input.device, dtype=input.dtype) 141 | # print("After: ", torch.isnan(input * mask * mask.numel() /mask.sum()).sum()) 142 | return input * mask * mask.numel() / mask.sum() 143 | 144 | class SAM(nn.Module): 145 | def __init__(self, in_channels): 146 | super().__init__() 147 | self.conv = nn.Conv2d(in_channels, out_channels=1, kernel_size=1) 148 | 149 | def forward(self, x): 150 | spatial_features = self.conv(x) 151 | attention = torch.sigmoid(spatial_features) 152 | return attention.expand_as(x) * x 153 | 154 | #Got and modified from https://arxiv.org/pdf/2003.13630.pdf 155 | class FastGlobalAvgPool2d(): 156 | def __init__(self, flatten=False): 157 | self.flatten = flatten 158 | def __call__(self, x): 159 | if self.flatten: 160 | in_size = x.size() 161 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 162 | else: 163 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) 164 | 165 | #As an example was taken https://github.com/BangguWu/ECANet/blob/master/models/eca_module.py 166 | class ECA(nn.Module): 167 | def __init__(self, k_size=3): 168 | super().__init__() 169 | self.avg_pool = FastGlobalAvgPool2d(flatten=False) 170 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 171 | 172 | def forward(self, x): 173 | squized_channels = self.avg_pool(x) 174 | channel_features = self.conv(squized_channels.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 175 | attention = torch.sigmoid(channel_features) 176 | return attention.expand_as(x) * x 177 | 178 | 179 | #Taken from https://github.com/joe-siyuan-qiao/WeightStandardization modified with new std https://github.com/joe-siyuan-qiao/WeightStandardization/issues/1#issuecomment-528050344 180 | class Conv2dWS(nn.Conv2d): 181 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 182 | padding=0, dilation=1, groups=1, bias=True): 183 | super(Conv2dWS, self).__init__(in_channels, out_channels, kernel_size, stride, 184 | padding, dilation, groups, bias) 185 | 186 | def forward(self, x): 187 | # print("IN: ", (~torch.isfinite(x)).sum()) 188 | weight = self.weight 189 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, 190 | keepdim=True).mean(dim=3, keepdim=True) 191 | weight = weight - weight_mean 192 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 2e-5).view(-1, 1, 1, 1) + 1e-5 193 | weight = weight / std.expand_as(weight) 194 | return F.conv2d(x, weight, self.bias, self.stride, 195 | self.padding, self.dilation, self.groups) 196 | 197 | # From https://arxiv.org/pdf/2003.10152.pdf and https://github.com/Wizaron/coord-conv-pytorch/blob/master/coord_conv.py 198 | class AddCoordChannels(nn.Module): 199 | def __init__(self, w=9, h=9, b=1): 200 | super().__init__() 201 | self.w = w 202 | self.h = h 203 | self.b = b 204 | self.y_coords = 2.0 * torch.arange(h).unsqueeze(1).expand(h, w) / (h - 1.0) - 1.0 205 | self.x_coords = 2.0 * torch.arange(w).unsqueeze(0).expand(h, w) / (w - 1.0) - 1.0 206 | self.coords = torch.stack((self.x_coords, self.y_coords), dim=0) 207 | self.coords = 
torch.unsqueeze(self.coords, dim=0).repeat(b, 1, 1, 1) 208 | 209 | def forward(self, x): 210 | b, c, h, w = x.shape 211 | if w != self.w or h != self.h or b != self.b: 212 | self.w = w 213 | self.h = h 214 | self.b = b 215 | self.y_coords = 2.0 * torch.arange(h).unsqueeze(1).expand(h, w) / (h - 1.0) - 1.0 216 | self.x_coords = 2.0 * torch.arange(w).unsqueeze(0).expand(h, w) / (w - 1.0) - 1.0 217 | self.coords = torch.stack((self.x_coords, self.y_coords), dim=0) 218 | self.coords = torch.unsqueeze(self.coords, dim=0).repeat(b, 1, 1, 1) 219 | 220 | return torch.cat((x, self.coords.to(x.device)), dim=1) 221 | 222 | #Was taken from https://github.com/joe-siyuan-qiao/Batch-Channel-Normalization 223 | class BCNorm(nn.Module): 224 | def __init__(self, num_channels, num_groups=1, eps=1e-05, estimate=False): 225 | super(BCNorm, self).__init__() 226 | self.num_channels = num_channels 227 | self.num_groups = num_groups 228 | self.eps = eps 229 | self.weight = Parameter(torch.ones(1, num_groups, 1)) 230 | self.bias = Parameter(torch.zeros(1, num_groups, 1)) 231 | if estimate: 232 | self.bn = EstBN(num_channels) 233 | else: 234 | self.bn = nn.BatchNorm2d(num_channels) 235 | 236 | def forward(self, inp): 237 | out = self.bn(inp) 238 | out = out.view(1, inp.size(0) * self.num_groups, -1) 239 | out = torch.batch_norm(out, None, None, None, None, True, 0, self.eps, True) 240 | out = out.view(inp.size(0), self.num_groups, -1) 241 | out = self.weight * out + self.bias 242 | out = out.view_as(inp) 243 | return out 244 | 245 | class EstBN(nn.Module): 246 | 247 | def __init__(self, num_features): 248 | super(EstBN, self).__init__() 249 | self.num_features = num_features 250 | self.weight = Parameter(torch.ones(num_features)) 251 | self.bias = Parameter(torch.zeros(num_features)) 252 | self.register_buffer('running_mean', torch.zeros(num_features)) 253 | self.register_buffer('running_var', torch.ones(num_features)) 254 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 255 | self.register_buffer('estbn_moving_speed', torch.zeros(1)) 256 | 257 | def forward(self, inp): 258 | ms = self.estbn_moving_speed.item() 259 | if self.training: 260 | with torch.no_grad(): 261 | inp_t = inp.transpose(0, 1).contiguous().view(self.num_features, -1) 262 | running_mean = inp_t.mean(dim=1) 263 | inp_t = inp_t - self.running_mean.view(-1, 1) 264 | running_var = torch.mean(inp_t * inp_t, dim=1) 265 | self.running_mean.data.mul_(1 - ms).add_(ms * running_mean.data) 266 | self.running_var.data.mul_(1 - ms).add_(ms * running_var.data) 267 | out = inp - self.running_mean.view(1, -1, 1, 1) 268 | out = out / torch.sqrt(self.running_var + 1e-5).view(1, -1, 1, 1) 269 | weight = self.weight.view(1, -1, 1, 1) 270 | bias = self.bias.view(1, -1, 1, 1) 271 | out = weight * out + bias 272 | return out 273 | 274 | # Taken and modified from https://github.com/Tianxiaomo/pytorch-YOLOv4/blob/master/models.py 275 | class ConvBlock(nn.Module): 276 | def __init__(self, in_channels, out_channels, kernel_size, stride, activation, bn=True, bias=False, dropblock=False, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 277 | super().__init__() 278 | 279 | # PADDING is (ks-1)/2 280 | padding = (kernel_size - 1) // 2 281 | 282 | modules: ty.List[ty.Union[nn.Module]] = [] 283 | #Adding two more to input channels if coord 284 | if coord: 285 | in_channels += 2 286 | modules.append(AddCoordChannels()) 287 | if ws: 288 | modules.append(Conv2dWS(in_channels, out_channels, kernel_size, stride, 
padding, bias=bias)) 289 | else: 290 | modules.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)) 291 | if bn: 292 | if bcn: 293 | modules.append(BCNorm(out_channels, estimate=True)) 294 | elif mbn: 295 | modules.append(EstBN(out_channels)) 296 | else: 297 | modules.append(nn.BatchNorm2d(out_channels, track_running_stats=not ws)) #IF WE ARE NOT USING track running stats and using WS, it just explodes. 298 | if activation == "mish": 299 | if hard_mish: 300 | modules.append(HardMish()) 301 | else: 302 | modules.append(Mish()) 303 | elif activation == "relu": 304 | modules.append(nn.ReLU(inplace=True)) 305 | elif activation == "leaky": 306 | modules.append(nn.LeakyReLU(0.1, inplace=True)) 307 | elif activation == "linear": 308 | pass 309 | else: 310 | raise BadArguments("Please use one of suggested activations: mish, relu, leaky, linear.") 311 | 312 | if sam: 313 | modules.append(SAM(out_channels)) 314 | 315 | if eca: 316 | modules.append(ECA()) 317 | 318 | if dropblock: 319 | modules.append(DropBlock2D()) 320 | 321 | self.module = nn.Sequential(*modules) 322 | 323 | def forward(self, x): 324 | y = self.module(x) 325 | return y 326 | 327 | 328 | # Taken and modified from https://github.com/Tianxiaomo/pytorch-YOLOv4/blob/master/models.py 329 | class ResBlock(nn.Module): 330 | """ 331 | Sequential residual blocks each of which consists of \ 332 | two convolution layers. 333 | Args: 334 | ch (int): number of input and output channels. 335 | nblocks (int): number of residual blocks. 336 | shortcut (bool): if True, residual tensor addition is enabled. 337 | """ 338 | # Creating few conv blocks. One with kernel 3, second with kernel 1. With residual skip connection 339 | def __init__(self, ch, nblocks=1, shortcut=True, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 340 | super().__init__() 341 | self.shortcut = shortcut 342 | self.module_list = nn.ModuleList() 343 | for i in range(nblocks): 344 | resblock_one = nn.ModuleList() 345 | resblock_one.append(ConvBlock(ch, ch, 1, 1, 'mish', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn)) 346 | resblock_one.append(ConvBlock(ch, ch, 3, 1, 'mish', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn)) 347 | self.module_list.append(resblock_one) 348 | 349 | if dropblock: 350 | self.use_dropblock = True 351 | self.dropblock = DropBlock2D() 352 | else: 353 | self.use_dropblock = False 354 | 355 | def forward(self, x): 356 | for module in self.module_list: 357 | h = x 358 | for res in module: 359 | h = res(h) 360 | x = x + h if self.shortcut else h 361 | if self.use_dropblock: 362 | x = self.dropblock(x) 363 | 364 | return x 365 | 366 | 367 | class DownSampleFirst(nn.Module): 368 | """ 369 | This is first downsample of the backbone model. 
370 | It differs from the other stages, so it is written as another Module 371 | Args: 372 | in_channels (int): Amount of channels to input, if you use RGB, it should be 3 373 | """ 374 | def __init__(self, in_channels=3, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 375 | super().__init__() 376 | 377 | self.c1 = ConvBlock(in_channels, 32, 3, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 378 | self.c2 = ConvBlock(32, 64, 3, 2, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 379 | self.c3 = ConvBlock(64, 64, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 380 | self.c4 = ConvBlock(64, 32, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 381 | self.c5 = ConvBlock(32, 64, 3, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 382 | self.c6 = ConvBlock(64, 64, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 383 | 384 | # CSP Layer 385 | self.dense_c3_c6 = ConvBlock(64, 64, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish) 386 | 387 | self.c7 = ConvBlock(128, 64, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish) 388 | 389 | def forward(self, x): 390 | x1 = self.c1(x) 391 | x2 = self.c2(x1) 392 | x3 = self.c3(x2) 393 | x4 = self.c4(x3) 394 | x5 = self.c5(x4) 395 | x5 = x5 + x3 # Residual block 396 | x6 = self.c6(x5) 397 | xd6 = self.dense_c3_c6(x2) # CSP 398 | x6 = torch.cat([x6, xd6], dim=1) 399 | x7 = self.c7(x6) 400 | return x7 401 | 402 | 403 | class DownSampleBlock(nn.Module): 404 | def __init__(self, in_c, out_c, nblocks=2, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 405 | super().__init__() 406 | 407 | self.c1 = ConvBlock(in_c, out_c, 3, 2, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 408 | self.c2 = ConvBlock(out_c, in_c, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 409 | self.r3 = ResBlock(in_c, nblocks=nblocks, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 410 | self.c4 = ConvBlock(in_c, in_c, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 411 | 412 | # CSP Layer 413 | self.dense_c2_c4 = ConvBlock(out_c, in_c, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 414 | 415 | self.c5 = ConvBlock(out_c, out_c, 1, 1, "mish", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 416 | 417 | def forward(self, x): 418 | x1 = self.c1(x) 419 | x2 = self.c2(x1) 420 | x3 = self.r3(x2) 421 | x4 = self.c4(x3) 422 | xd4 = self.dense_c2_c4(x1) # CSP 423 | x4 = torch.cat([x4, xd4], dim=1) 424 | x5 = self.c5(x4) 425 | 426 | return x5 427 | 428 | 429 | class Backbone(nn.Module): 430 | def __init__(self, in_channels, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 431 | super().__init__() 432 | 433 | self.d1 = DownSampleFirst(in_channels=in_channels, 
dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 434 | self.d2 = DownSampleBlock(64, 128, nblocks=2, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 435 | self.d3 = DownSampleBlock(128, 256, nblocks=8, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 436 | self.d4 = DownSampleBlock(256, 512, nblocks=8, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 437 | self.d5 = DownSampleBlock(512, 1024, nblocks=4, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 438 | 439 | def forward(self, x): 440 | x1 = self.d1(x) 441 | x2 = self.d2(x1) 442 | x3 = self.d3(x2) 443 | x4 = self.d4(x3) 444 | x5 = self.d5(x4) 445 | return (x5, x4, x3) 446 | 447 | 448 | class PAN_Layer(nn.Module): 449 | def __init__(self, in_channels, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 450 | super().__init__() 451 | 452 | in_c = in_channels 453 | out_c = in_c // 2 454 | 455 | self.c1 = ConvBlock(in_c, out_c, 1, 1, "leaky", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 456 | self.u2 = nn.Upsample(scale_factor=2, mode="nearest") 457 | # Gets input from d4 458 | self.c2_from_upsampled = ConvBlock(in_c, out_c, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 459 | # We use stack in PAN, so 512 460 | self.c3 = ConvBlock(in_c, out_c, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 461 | self.c4 = ConvBlock(out_c, in_c, 3, 1, "leaky", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 462 | self.c5 = ConvBlock(in_c, out_c, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 463 | self.c6 = ConvBlock(out_c, in_c, 3, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 464 | self.c7 = ConvBlock(in_c, out_c, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 465 | 466 | def forward(self, x_to_upsample, x_upsampled): 467 | x1 = self.c1(x_to_upsample) 468 | x2_1 = self.u2(x1) 469 | x2_2 = self.c2_from_upsampled(x_upsampled) 470 | # First is not upsampled! 
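        # i.e. the lateral backbone map (x2_2) goes first in the concatenation and the
        # upsampled top-down map (x2_1) second; both carry out_c channels, so the
        # concatenated tensor has in_c channels again, which is what c3 expects.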
471 | x2 = torch.cat([x2_2, x2_1], dim=1) 472 | x3 = self.c3(x2) 473 | x4 = self.c4(x3) 474 | x5 = self.c5(x4) 475 | x6 = self.c6(x5) 476 | x7 = self.c7(x6) 477 | return x7 478 | 479 | #Taken and modified from https://github.com/ruinmessi/ASFF/blob/0ff0e3393675583f7da65a7b443ea467e1eaed65/models/network_blocks.py#L267-L330 480 | class ASFF(nn.Module): 481 | def __init__(self, level, rfb=False, vis=False, bcn=False, mbn=False): 482 | super(ASFF, self).__init__() 483 | self.level = level 484 | self.dim = [512, 256, 128] 485 | self.inter_dim = self.dim[self.level] 486 | if level==0: 487 | self.stride_level_1 = ConvBlock(256, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 488 | self.stride_level_2 = ConvBlock(128, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 489 | self.expand = ConvBlock(self.inter_dim, 512, 3, 1, "leaky", bcn=bcn, mbn=mbn) 490 | elif level==1: 491 | self.compress_level_0 = ConvBlock(512, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 492 | self.stride_level_2 = ConvBlock(128, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 493 | self.expand = ConvBlock(self.inter_dim, 256, 3, 1, "leaky", bcn=bcn, mbn=mbn) 494 | elif level==2: 495 | self.compress_level_0 = ConvBlock(512, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 496 | self.compress_level_1 = ConvBlock(256, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 497 | self.expand = ConvBlock(self.inter_dim, 128, 3, 1, "leaky", bcn=bcn, mbn=mbn) 498 | 499 | compress_c = 8 if rfb else 16 #when adding rfb, we use half number of channels to save memory 500 | self.weight_level_0 = ConvBlock(self.inter_dim, compress_c, 1, 1, "leaky", bcn=bcn, mbn=mbn) 501 | self.weight_level_1 = ConvBlock(self.inter_dim, compress_c, 1, 1, "leaky", bcn=bcn, mbn=mbn) 502 | self.weight_level_2 = ConvBlock(self.inter_dim, compress_c, 1, 1, "leaky", bcn=bcn, mbn=mbn) 503 | self.weight_levels = nn.Conv2d(compress_c*3, 3, kernel_size=1, stride=1, padding=0) 504 | 505 | self.vis= vis 506 | 507 | 508 | def forward(self, x_level_0, x_level_1, x_level_2): 509 | if self.level==0: 510 | level_0_resized = x_level_0 # 512 -> 512 511 | level_1_resized = self.stride_level_1(x_level_1) # 256 -> 512 512 | level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1) 513 | level_2_resized = self.stride_level_2(level_2_downsampled_inter) # 128 -> 512 514 | elif self.level==1: 515 | level_0_compressed = self.compress_level_0(x_level_0) # 512 -> 256 516 | level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') 517 | level_1_resized =x_level_1 # 256 -> 256 518 | level_2_resized =self.stride_level_2(x_level_2) # 128 -> 256 519 | elif self.level==2: 520 | level_0_compressed = self.compress_level_0(x_level_0) # 512 -> 128 521 | level_1_compressed = self.compress_level_1(x_level_1) # 256 -> 128 522 | level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') 523 | level_1_resized =F.interpolate(level_1_compressed, scale_factor=2, mode='nearest') 524 | level_2_resized =x_level_2 #128 -> 128 525 | 526 | level_0_weight_v = self.weight_level_0(level_0_resized) 527 | level_1_weight_v = self.weight_level_1(level_1_resized) 528 | level_2_weight_v = self.weight_level_2(level_2_resized) 529 | levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v),1) 530 | levels_weight = self.weight_levels(levels_weight_v) 531 | levels_weight = F.softmax(levels_weight, dim=1) 532 | 533 | fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\ 534 | level_1_resized * levels_weight[:,1:2,:,:]+\ 
535 | level_2_resized * levels_weight[:,2:,:,:] 536 | 537 | out = self.expand(fused_out_reduced) 538 | 539 | if self.vis: 540 | return out, levels_weight, fused_out_reduced.sum(dim=1) 541 | else: 542 | return out 543 | 544 | #Author: Vadims Casecnikovs creator of this repository 545 | class ACFF(nn.Module): 546 | def __init__(self, level, rfb=False, vis=False, bcn=False, mbn=False): 547 | super(ACFF, self).__init__() 548 | self.level = level 549 | self.dim = [512, 256, 128] 550 | self.inter_dim = self.dim[self.level] 551 | if level==0: 552 | self.stride_level_1 = ConvBlock(256, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 553 | self.stride_level_2 = ConvBlock(128, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 554 | self.expand = ConvBlock(self.inter_dim, 512, 3, 1, "leaky", bcn=bcn, mbn=mbn) 555 | elif level==1: 556 | self.compress_level_0 = ConvBlock(512, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 557 | self.stride_level_2 = ConvBlock(128, self.inter_dim, 3, 2, "leaky", bcn=bcn, mbn=mbn) 558 | self.expand = ConvBlock(self.inter_dim, 256, 3, 1, "leaky", bcn=bcn, mbn=mbn) 559 | elif level==2: 560 | self.compress_level_0 = ConvBlock(512, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 561 | self.compress_level_1 = ConvBlock(256, self.inter_dim, 1, 1, "leaky", bcn=bcn, mbn=mbn) 562 | self.expand = ConvBlock(self.inter_dim, 128, 3, 1, "leaky", bcn=bcn, mbn=mbn) 563 | 564 | self.avg_pool = FastGlobalAvgPool2d() 565 | self.weights_spatial = torch.nn.Parameter(torch.ones((3, self.inter_dim, 3))) 566 | self.weights_spatial = torch.nn.init.kaiming_uniform_(self.weights_spatial, a=0, mode='fan_in', nonlinearity='relu') 567 | 568 | self.vis= vis 569 | 570 | 571 | def forward(self, x_level_0, x_level_1, x_level_2): 572 | #In this part we are trying to construct the same channel features as in target level 573 | if self.level==0: 574 | level_0_resized = x_level_0 # 512 -> 512 575 | level_1_resized = self.stride_level_1(x_level_1) # 256 -> 512 576 | level_2_downsampled_inter =F.max_pool2d(x_level_2, 3, stride=2, padding=1) 577 | level_2_resized = self.stride_level_2(level_2_downsampled_inter) # 128 -> 512 578 | elif self.level==1: 579 | level_0_compressed = self.compress_level_0(x_level_0) # 512 -> 256 580 | level_0_resized =F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') 581 | level_1_resized =x_level_1 # 256 -> 256 582 | level_2_resized =self.stride_level_2(x_level_2) # 128 -> 256 583 | elif self.level==2: 584 | level_0_compressed = self.compress_level_0(x_level_0) # 512 -> 128 585 | level_1_compressed = self.compress_level_1(x_level_1) # 256 -> 128 586 | level_0_resized =F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') 587 | level_1_resized =F.interpolate(level_1_compressed, scale_factor=2, mode='nearest') 588 | level_2_resized =x_level_2 #128 -> 128 589 | 590 | #In this part we are getting mean value of channel's(featuremaps) activations 591 | level_0_flattened = self.avg_pool(level_0_resized) 592 | level_1_flattened = self.avg_pool(level_1_resized) 593 | level_2_flattened = self.avg_pool(level_2_resized) 594 | 595 | #Concatenating all 3 levels, getting (B, C, L) 596 | levels_weight = torch.cat([level_0_flattened, level_1_flattened, level_2_flattened], dim=2).squeeze(-1) 597 | 598 | #For each channel, getting 3 values, they would show how much attention should we give for each level's channel 599 | level_weight = torch.einsum("bci, kci-> bck", levels_weight, self.weights_spatial) 600 | 601 | levels_weight = torch.nn.functional.softmax(levels_weight, 
dim=2).unsqueeze(-1) 602 | 603 | fused_out_reduced = level_0_resized * levels_weight[:,:,0:1,:]+\ 604 | level_1_resized * levels_weight[:,:,1:2,:]+\ 605 | level_2_resized * levels_weight[:,:,2:,:] 606 | 607 | out = self.expand(fused_out_reduced) 608 | 609 | if self.vis: 610 | return out, levels_weight, fused_out_reduced.sum(dim=1) 611 | else: 612 | return out 613 | 614 | 615 | 616 | class Neck(nn.Module): 617 | def __init__(self, spp_kernels=(5, 9, 13), PAN_layers=[512, 256], dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, asff=False, acff=False, bcn=False, mbn=False): 618 | super().__init__() 619 | assert not(asff and acff) 620 | self.asff = asff 621 | self.acff = acff 622 | 623 | self.c1 = ConvBlock(1024, 512, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 624 | self.c2 = ConvBlock(512, 1024, 3, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 625 | self.c3 = ConvBlock(1024, 512, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 626 | 627 | # SPP block 628 | self.mp4_1 = nn.MaxPool2d(kernel_size=spp_kernels[0], stride=1, padding=spp_kernels[0] // 2) 629 | self.mp4_2 = nn.MaxPool2d(kernel_size=spp_kernels[1], stride=1, padding=spp_kernels[1] // 2) 630 | self.mp4_3 = nn.MaxPool2d(kernel_size=spp_kernels[2], stride=1, padding=spp_kernels[2] // 2) 631 | 632 | self.c5 = ConvBlock(2048, 512, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 633 | self.c6 = ConvBlock(512, 1024, 3, 1, "leaky", dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 634 | self.c7 = ConvBlock(1024, 512, 1, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 635 | 636 | self.PAN8 = PAN_Layer(PAN_layers[0], dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 637 | self.PAN9 = PAN_Layer(PAN_layers[1], dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 638 | 639 | if asff: # branch inputs biggest objects: 512, medium objects: 256, smallest objects : 128 640 | self.ASFF_0 = ASFF(0) 641 | self.ASFF_1 = ASFF(1) 642 | self.ASFF_2 = ASFF(2) 643 | 644 | if acff: 645 | self.ACFF_0 = ACFF(0) 646 | self.ACFF_1 = ACFF(1) 647 | self.ACFF_2 = ACFF(2) 648 | 649 | def forward(self, input): 650 | d5, d4, d3 = input 651 | 652 | x1 = self.c1(d5) 653 | x2 = self.c2(x1) 654 | x3 = self.c3(x2) 655 | 656 | x4_1 = self.mp4_1(x3) 657 | x4_2 = self.mp4_2(x3) 658 | x4_3 = self.mp4_3(x3) 659 | x4 = torch.cat([x4_1, x4_2, x4_3, x3], dim=1) 660 | 661 | x5 = self.c5(x4) 662 | x6 = self.c6(x5) 663 | x7 = self.c7(x6) 664 | 665 | x8 = self.PAN8(x7, d4) 666 | x9 = self.PAN9(x8, d3) 667 | 668 | if self.asff: 669 | x7ASFF = self.ASFF_0(x7, x8, x9) 670 | x8ASFF = self.ASFF_1(x7, x8, x9) 671 | x9ASFF = self.ASFF_2(x7, x8, x9) 672 | return x9ASFF, x8ASFF, x7ASFF 673 | 674 | if self.acff: 675 | x7ACFF = self.ACFF_0(x7, x8, x9) 676 | x8ACFF = self.ACFF_1(x7, x8, x9) 677 | x9ACFF = self.ACFF_2(x7, x8, x9) 678 | return x9ACFF, x8ACFF, x7ACFF 679 | 680 | 681 | return x9, x8, x7 682 | 683 | 684 | 685 | class HeadPreprocessing(nn.Module): 686 | def __init__(self, in_channels, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 687 | super().__init__() 688 | ic = 
in_channels 689 | self.c1 = ConvBlock(ic, ic*2, 3, 2, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 690 | self.c2 = ConvBlock(ic*4, ic*2, 1, 1, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 691 | self.c3 = ConvBlock(ic*2, ic*4, 3, 1, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 692 | self.c4 = ConvBlock(ic*4, ic*2, 1, 1, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 693 | self.c5 = ConvBlock(ic*2, ic*4, 3, 1, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 694 | self.c6 = ConvBlock(ic*4, ic*2, 1, 1, 'leaky', dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 695 | 696 | def forward(self, input, input_prev): 697 | x1 = self.c1(input_prev) 698 | x1 = torch.cat([x1, input], dim=1) 699 | x2 = self.c2(x1) 700 | x3 = self.c3(x2) 701 | x4 = self.c4(x3) 702 | x5 = self.c5(x4) 703 | x6 = self.c6(x5) 704 | 705 | return x6 706 | 707 | 708 | class HeadOutput(nn.Module): 709 | def __init__(self, in_channels, out_channels, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 710 | super().__init__() 711 | self.c1 = ConvBlock(in_channels, in_channels*2, 3, 1, "leaky", dropblock=False, sam=sam, eca=eca, ws=False, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 712 | self.c2 = ConvBlock(in_channels*2, out_channels, 1, 1, "linear", bn=False, bias=True, dropblock=False, sam=False, eca=False, ws=False, coord=False, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 713 | 714 | def forward(self, x): 715 | x1 = self.c1(x) 716 | x2 = self.c2(x1) 717 | return x2 718 | 719 | 720 | class Head(nn.Module): 721 | def __init__(self, output_ch, dropblock=True, sam=False, eca=False, ws=False, coord=False, hard_mish=False, bcn=False, mbn=False): 722 | super().__init__() 723 | 724 | self.ho1 = HeadOutput(128, output_ch, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 725 | 726 | self.hp2 = HeadPreprocessing(128, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 727 | self.ho2 = HeadOutput(256, output_ch, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 728 | 729 | self.hp3 = HeadPreprocessing(256, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 730 | self.ho3 = HeadOutput(512, output_ch, dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 731 | 732 | def forward(self, input): 733 | input1, input2, input3 = input 734 | 735 | x1 = self.ho1(input1) 736 | x2 = self.hp2(input2, input1) 737 | x3 = self.ho2(x2) 738 | 739 | x4 = self.hp3(input3, x2) 740 | x5 = self.ho3(x4) 741 | 742 | return (x1, x3, x5) 743 | 744 | 745 | class YOLOLayer(nn.Module): 746 | """Detection layer taken and modified from https://github.com/eriklindernoren/PyTorch-YOLOv3""" 747 | 748 | def __init__(self, anchors, num_classes, img_dim=608, grid_size=None, iou_aware=False, repulsion_loss=False): 749 | super(YOLOLayer, self).__init__() 750 | self.anchors = anchors 751 | self.num_anchors = len(anchors) 752 | self.num_classes = num_classes 753 | self.ignore_thres = 0.5 754 | self.obj_scale = 1 755 | self.noobj_scale = 100 756 | 
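        # obj_scale / noobj_scale weight the two confidence BCE terms when the total
        # confidence loss is assembled further down
        # (loss_conf = obj_scale * loss_conf_obj + noobj_scale * loss_conf_noobj),
        # so with these defaults the no-object term is weighted 100x the object term.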
self.metrics = {} 757 | self.img_dim = img_dim 758 | if grid_size: 759 | self.grid_size = grid_size 760 | self.compute_grid_offsets(self.grid_size) 761 | else: 762 | self.grid_size = 0 # grid size 763 | 764 | self.iou_aware = iou_aware 765 | self.repulsion_loss = repulsion_loss 766 | 767 | def compute_grid_offsets(self, grid_size, cuda=True): 768 | self.grid_size = grid_size 769 | g = self.grid_size 770 | FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 771 | self.stride = self.img_dim / self.grid_size 772 | # Calculate offsets for each grid 773 | self.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor) 774 | self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor) 775 | self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors]) 776 | self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1)) 777 | self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1)) 778 | 779 | def build_targets(self, pred_boxes, pred_cls, target, anchors, ignore_thres): 780 | 781 | ByteTensor = torch.cuda.BoolTensor if pred_boxes.is_cuda else torch.BoolTensor 782 | FloatTensor = torch.cuda.FloatTensor if pred_boxes.is_cuda else torch.FloatTensor 783 | 784 | nB = pred_boxes.size(0) 785 | nA = pred_boxes.size(1) 786 | nC = pred_cls.size(-1) 787 | nG = pred_boxes.size(2) 788 | 789 | # Output tensors 790 | obj_mask = ByteTensor(nB, nA, nG, nG).fill_(0) 791 | noobj_mask = ByteTensor(nB, nA, nG, nG).fill_(1) 792 | class_mask = FloatTensor(nB, nA, nG, nG).fill_(0) 793 | iou = FloatTensor(nB, nA, nG, nG).fill_(0) 794 | tx = FloatTensor(nB, nA, nG, nG).fill_(0) 795 | ty = FloatTensor(nB, nA, nG, nG).fill_(0) 796 | tw = FloatTensor(nB, nA, nG, nG).fill_(0) 797 | th = FloatTensor(nB, nA, nG, nG).fill_(0) 798 | tcls = FloatTensor(nB, nA, nG, nG, nC).fill_(0) 799 | 800 | target_boxes_grid = FloatTensor(nB, nA, nG, nG, 4).fill_(0) 801 | 802 | #If target is zero, then return 803 | if target.shape[0] == 0: 804 | tconf = obj_mask.float() 805 | # print(iou, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf, target_boxes_grid) 806 | return iou, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf, target_boxes_grid 807 | 808 | # 2 3 xy 809 | # 4 5 wh 810 | # Convert to position relative to box 811 | target_boxes = target[:, 2:6] * nG 812 | gxy = target_boxes[:, :2] 813 | gwh = target_boxes[:, 2:] 814 | 815 | # Get anchors with best iou 816 | ious = torch.stack([self.bbox_wh_iou(anchor, gwh) for anchor in anchors]) 817 | best_ious, best_n = ious.max(0) 818 | 819 | # Separate target values 820 | b, target_labels = target[:, :2].long().t() 821 | gx, gy = gxy.t() 822 | gw, gh = gwh.t() 823 | gi, gj = gxy.long().t() 824 | 825 | # Setting target boxes to big grid, it would be used to count loss 826 | target_boxes_grid[b, best_n, gj, gi] = target_boxes 827 | 828 | # Set masks 829 | obj_mask[b, best_n, gj, gi] = 1 830 | noobj_mask[b, best_n, gj, gi] = 0 831 | 832 | # Set noobj mask to zero where iou exceeds ignore threshold 833 | for i, anchor_ious in enumerate(ious.t()): 834 | noobj_mask[b[i], anchor_ious > ignore_thres, gj[i], gi[i]] = 0 835 | 836 | # Coordinates 837 | tx[b, best_n, gj, gi] = gx - gx.floor() 838 | ty[b, best_n, gj, gi] = gy - gy.floor() 839 | 840 | # Width and height 841 | tw[b, best_n, gj, gi] = torch.log(gw / anchors[best_n][:, 0] + 1e-16) 842 | th[b, best_n, gj, gi] = torch.log(gh / anchors[best_n][:, 1] + 1e-16) 843 | 844 | # One-hot encoding of 
label (WE USE LABEL SMOOTHING) 845 | tcls[b, best_n, gj, gi, target_labels] = 0.9 846 | 847 | # Compute label correctness and iou at best anchor 848 | class_mask[b, best_n, gj, gi] = (pred_cls[b, best_n, gj, gi].argmax(-1) == target_labels).float() 849 | iou[b, best_n, gj, gi] = self.bbox_iou(pred_boxes[b, best_n, gj, gi], target_boxes, x1y1x2y2=False) 850 | 851 | tconf = obj_mask.float() 852 | 853 | return iou, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf, target_boxes_grid 854 | 855 | def bbox_wh_iou(self, wh1, wh2): 856 | wh2 = wh2.t() 857 | w1, h1 = wh1[0], wh1[1] 858 | w2, h2 = wh2[0], wh2[1] 859 | inter_area = torch.min(w1, w2) * torch.min(h1, h2) 860 | union_area = (w1 * h1 + 1e-16) + w2 * h2 - inter_area 861 | return inter_area / union_area 862 | 863 | 864 | def bbox_iou(self, box1, box2, x1y1x2y2=True, get_areas = False): 865 | """ 866 | Returns the IoU of two bounding boxes 867 | """ 868 | if not x1y1x2y2: 869 | # Transform from center and width to exact coordinates 870 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 871 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 872 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 873 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 874 | else: 875 | # Get the coordinates of bounding boxes 876 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 877 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 878 | 879 | # get the coordinates of the intersection rectangle 880 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 881 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 882 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 883 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 884 | 885 | # Intersection area 886 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1, min=0) * torch.clamp( 887 | inter_rect_y2 - inter_rect_y1, min=0 888 | ) 889 | # Union Area 890 | b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) 891 | b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) 892 | union_area = (b1_area + b2_area - inter_area + 1e-16) 893 | 894 | 895 | if get_areas: 896 | return inter_area, union_area 897 | 898 | iou = inter_area / union_area 899 | return iou 900 | 901 | 902 | def smallestenclosing(self, pred_boxes, target_boxes): 903 | # Calculating smallest enclosing 904 | targetxc = target_boxes[..., 0] 905 | targetyc = target_boxes[..., 1] 906 | targetwidth = target_boxes[..., 2] 907 | targetheight = target_boxes[..., 3] 908 | 909 | predxc = pred_boxes[..., 0] 910 | predyc = pred_boxes[..., 1] 911 | predwidth = pred_boxes[..., 2] 912 | predheight = pred_boxes[..., 3] 913 | 914 | xc1 = torch.min(predxc - (predwidth/2), targetxc - (targetwidth/2)) 915 | yc1 = torch.min(predyc - (predheight/2), targetyc - (targetheight/2)) 916 | xc2 = torch.max(predxc + (predwidth/2), targetxc + (targetwidth/2)) 917 | yc2 = torch.max(predyc + (predheight/2), targetyc + (targetheight/2)) 918 | 919 | return xc1, yc1, xc2, yc2 920 | 921 | def xywh2xyxy(self, x): 922 | # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2] 923 | y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) 924 | y[:, 0] = x[:, 0] - x[:, 2] / 2 925 | y[:, 1] = x[:, 1] - x[:, 3] / 2 926 | y[:, 2] = x[:, 0] + x[:, 2] / 2 927 | y[:, 3] = x[:, 1] + x[:, 3] / 2 928 | return y 929 | 930 | def iou_all_to_all(self, a, b): 931 | #Calculates intersection over union area for each a bounding box with each b bounding box 932 | area = (b[:, 2] - 
b[:, 0]) * (b[:, 3] - b[:, 1]) 933 | 934 | iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0]) 935 | ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1]) 936 | 937 | iw = torch.clamp(iw, min=0) 938 | ih = torch.clamp(ih, min=0) 939 | 940 | ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih 941 | 942 | ua = torch.clamp(ua, min=1e-8) 943 | 944 | intersection = iw * ih 945 | 946 | IoU = intersection / ua 947 | 948 | return IoU 949 | 950 | def smooth_ln(self, x, smooth =0.5): 951 | return torch.where( 952 | torch.le(x, smooth), 953 | -torch.log(1 - x), 954 | ((x - smooth) / (1 - smooth)) - math.log(1 - smooth) 955 | ) 956 | 957 | def iog(self, ground_truth, prediction): 958 | 959 | inter_xmin = torch.max(ground_truth[:, 0], prediction[:, 0]) 960 | inter_ymin = torch.max(ground_truth[:, 1], prediction[:, 1]) 961 | inter_xmax = torch.min(ground_truth[:, 2], prediction[:, 2]) 962 | inter_ymax = torch.min(ground_truth[:, 3], prediction[:, 3]) 963 | Iw = torch.clamp(inter_xmax - inter_xmin, min=0) 964 | Ih = torch.clamp(inter_ymax - inter_ymin, min=0) 965 | I = Iw * Ih 966 | G = (ground_truth[:, 2] - ground_truth[:, 0]) * (ground_truth[:, 3] - ground_truth[:, 1]) 967 | return I / G 968 | 969 | def calculate_repullsion(self, y, y_hat): 970 | batch_size = y_hat.shape[0] 971 | RepGTS = [] 972 | RepBoxes = [] 973 | for bn in range(batch_size): 974 | #Repulsion between prediction bbox and neighboring target bbox, which are not target for this bounding box. (pred bbox <- -> 2nd/3rd/... by iou target bbox) 975 | pred_bboxes = self.xywh2xyxy(y_hat[bn, :, :4]) 976 | bn_mask = y[:, 0] == bn 977 | gt_bboxes = self.xywh2xyxy(y[bn_mask, 2:] * 608) 978 | iou_anchor_to_target = self.iou_all_to_all(pred_bboxes, gt_bboxes) 979 | val, ind = torch.topk(iou_anchor_to_target, 2) 980 | second_closest_target_index = ind[:, 1] 981 | second_closest_target = gt_bboxes[second_closest_target_index] 982 | RepGT = self.smooth_ln(self.iog(second_closest_target, pred_bboxes)).mean() 983 | RepGTS.append(RepGT) 984 | 985 | #Repulsion between pred bbox and pred bbox, which are not refering to the same target bbox. 
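            # Concretely: only predictions that actually overlap some target (val[:, 0] != 0)
            # take part, pairs whose best-IoU target is the same box are removed via
            # different_target_mask, the diagonal (a box with itself) is removed via
            # other_mask, and smooth_ln penalises the remaining pairwise IoUs so that
            # detections assigned to different objects are pushed apart.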
986 | have_target_mask = val[:, 0] != 0 987 | anchors_with_target = pred_bboxes[have_target_mask] 988 | iou_anchor_to_anchor = self.iou_all_to_all(anchors_with_target, anchors_with_target) 989 | other_mask = (torch.eye(iou_anchor_to_anchor.shape[0]) == 0).to(iou_anchor_to_anchor.device) 990 | different_target_mask = (ind[have_target_mask, 0] != ind[have_target_mask, 0].unsqueeze(1)) 991 | iou_atoa_filtered = iou_anchor_to_anchor[other_mask & different_target_mask] 992 | RepBox = self.smooth_ln(iou_atoa_filtered).sum()/iou_atoa_filtered.sum() 993 | RepBoxes.append(RepBox) 994 | return torch.stack(RepGTS).mean(), torch.stack(RepBoxes).mean() 995 | 996 | def forward(self, x : torch.Tensor, targets=None): 997 | # Tensors for cuda support 998 | FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor 999 | 1000 | num_samples = x.size(0) 1001 | grid_size = x.size(2) 1002 | 1003 | if self.iou_aware: 1004 | not_class_channels = 6 1005 | else: 1006 | not_class_channels = 5 1007 | prediction = ( 1008 | x.view(num_samples, self.num_anchors, self.num_classes + not_class_channels, grid_size, grid_size) 1009 | .permute(0, 1, 3, 4, 2) 1010 | .contiguous() 1011 | ) 1012 | 1013 | # Get outputs 1014 | x = torch.sigmoid(prediction[..., 0]) # Center x 1015 | y = torch.sigmoid(prediction[..., 1]) # Center y 1016 | w = prediction[..., 2] # Width 1017 | h = prediction[..., 3] # Height 1018 | pred_conf = torch.sigmoid(prediction[..., 4]) # Conf 1019 | if not self.iou_aware: 1020 | pred_cls = torch.sigmoid(prediction[..., 5:]) # Cls pred 1021 | else: 1022 | pred_cls = torch.sigmoid(prediction[..., 5:-1])# Cls pred 1023 | pred_iou = torch.sigmoid(prediction[..., -1]) #IoU pred 1024 | 1025 | # If grid size does not match current we compute new offsets 1026 | if grid_size != self.grid_size or self.grid_x.is_cuda != x.is_cuda: 1027 | self.compute_grid_offsets(grid_size, cuda=x.is_cuda) 1028 | 1029 | # Add offset and scale with anchors 1030 | pred_boxes = FloatTensor(prediction[..., :4].shape) 1031 | pred_boxes[..., 0] = x + self.grid_x 1032 | pred_boxes[..., 1] = y + self.grid_y 1033 | pred_boxes[..., 2] = torch.exp(w) * self.anchor_w 1034 | pred_boxes[..., 3] = torch.exp(h) * self.anchor_h 1035 | 1036 | output = torch.cat( 1037 | ( 1038 | pred_boxes.view(num_samples, -1, 4) * self.stride, 1039 | pred_conf.view(num_samples, -1, 1), 1040 | pred_cls.view(num_samples, -1, self.num_classes), 1041 | ), 1042 | -1, 1043 | ) 1044 | 1045 | # OUTPUT IS ALL BOXES WITH THEIR CONFIDENCE AND WITH CLASS 1046 | if targets is None: 1047 | return output, 0 1048 | 1049 | iou, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf, target_boxes = self.build_targets( 1050 | pred_boxes=pred_boxes, 1051 | pred_cls=pred_cls, 1052 | target=targets, 1053 | anchors=self.scaled_anchors, 1054 | ignore_thres=self.ignore_thres 1055 | ) 1056 | 1057 | # Diagonal length of the smallest enclosing box (is already squared) 1058 | xc1, yc1, xc2, yc2 = self.smallestenclosing(pred_boxes[obj_mask], target_boxes[obj_mask]) 1059 | c = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + 1e-7 1060 | 1061 | # Euclidean distance between central points 1062 | d = (tx[obj_mask] - x[obj_mask]) ** 2 + (ty[obj_mask] - y[obj_mask]) ** 2 1063 | 1064 | rDIoU = d/c 1065 | 1066 | iou_masked = iou[obj_mask] 1067 | v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(tw[obj_mask]/th[obj_mask])-torch.atan(w[obj_mask]/h[obj_mask])), 2) 1068 | 1069 | with torch.no_grad(): 1070 | S = 1 - iou_masked 1071 | alpha = v / (S + v + 1e-7) 1072 | 1073 | if num_samples != 0: 1074 
| CIoUloss = (1 - iou_masked + rDIoU + alpha * v).sum(0)/num_samples 1075 | else: 1076 | CIoUloss = 0 1077 | # print(torch.isnan(pred_conf).sum()) 1078 | 1079 | loss_conf_noobj = F.binary_cross_entropy(pred_conf[noobj_mask], tconf[noobj_mask]) 1080 | 1081 | if targets.shape[0] == 0: 1082 | loss_conf_obj = 0. 1083 | loss_cls = 0. 1084 | else: 1085 | loss_conf_obj = F.binary_cross_entropy(pred_conf[obj_mask], tconf[obj_mask]) 1086 | loss_cls = F.binary_cross_entropy(input=pred_cls[obj_mask], target=tcls[obj_mask]) 1087 | 1088 | loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj 1089 | total_loss = CIoUloss + loss_cls + loss_conf 1090 | 1091 | if self.iou_aware: 1092 | pred_iou_masked = pred_iou[obj_mask] 1093 | 1094 | # print("Pred iou", pred_iou.shape) 1095 | # print("IOU masked", iou_masked.shape) 1096 | # print("Pred iou", pred_iou) 1097 | # print("IOU masked", iou_masked) 1098 | # print("pred iou masked", pred_iou_masked.shape) 1099 | # print("pred iou masked", pred_iou_masked) 1100 | # print(F.binary_cross_entropy(pred_iou_masked, iou_masked.detach())) 1101 | total_loss += F.binary_cross_entropy(pred_iou_masked, iou_masked.detach()) 1102 | 1103 | if self.repulsion_loss: 1104 | repgt, repbox = self.calculate_repullsion(targets, output) 1105 | total_loss += 0.5 * repgt + 0.5 * repbox 1106 | 1107 | # print(f"C: {c}; D: {d}") 1108 | # print(f"Confidence is object: {loss_conf_obj}, Confidence no object: {loss_conf_noobj}") 1109 | # print(f"IoU: {iou_masked}; DIoU: {rDIoU}; alpha: {alpha}; v: {v}") 1110 | # print(f"CIoU : {CIoUloss.item()}; Confindence: {loss_conf.item()}; Class loss should be because of label smoothing: {loss_cls.item()}") 1111 | return output, total_loss 1112 | 1113 | 1114 | class YOLOv4(nn.Module): 1115 | def __init__(self, in_channels=3, n_classes=80, weights_path=None, pretrained=False, img_dim=608, anchors=None, dropblock=True, sam=False, eca=False, ws=False, iou_aware=False, coord=False, hard_mish=False, asff=False, repulsion_loss=False, acff=False, bcn=False, mbn=False): 1116 | super().__init__() 1117 | if anchors is None: 1118 | anchors = [[[10, 13], [16, 30], [33, 23]], 1119 | [[30, 61], [62, 45], [59, 119]], 1120 | [[116, 90], [156, 198], [373, 326]]] 1121 | 1122 | output_ch = (4 + 1 + n_classes) * 3 1123 | if iou_aware: 1124 | output_ch += 3 #1 for iou 1125 | 1126 | self.img_dim = img_dim 1127 | 1128 | self.backbone = Backbone(in_channels, dropblock=False, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn, mbn=mbn) 1129 | 1130 | self.neck = Neck(dropblock=dropblock, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, asff=asff, acff=acff, bcn=bcn, mbn=mbn) 1131 | 1132 | self.head = Head(output_ch, dropblock=False, sam=sam, eca=eca, ws=ws, coord=coord, hard_mish=hard_mish, bcn=bcn) 1133 | 1134 | self.yolo1 = YOLOLayer(anchors[0], n_classes, img_dim, iou_aware=iou_aware, repulsion_loss=repulsion_loss) 1135 | self.yolo2 = YOLOLayer(anchors[1], n_classes, img_dim, iou_aware=iou_aware, repulsion_loss=repulsion_loss) 1136 | self.yolo3 = YOLOLayer(anchors[2], n_classes, img_dim, iou_aware=iou_aware, repulsion_loss=repulsion_loss) 1137 | 1138 | if weights_path: 1139 | try: # If we change input or output layers amount, we will have an option to use pretrained weights 1140 | self.load_state_dict(torch.load(weights_path), strict=False) 1141 | except RuntimeError as e: 1142 | print(f'[Warning] Ignoring {e}') 1143 | elif pretrained: 1144 | try: # If we change input or output layers amount, we will have an option to use 
pretrained weights 1145 | self.load_state_dict(torch.hub.load_state_dict_from_url("https://github.com/VCasecnikovs/Yet-Another-YOLOv4-Pytorch/releases/download/V1.0/yolov4.pth"), strict=False) 1146 | except RuntimeError as e: 1147 | print(f'[Warning] Ignoring {e}') 1148 | 1149 | def forward(self, x, y=None): 1150 | b = self.backbone(x) 1151 | n = self.neck(b) 1152 | h = self.head(n) 1153 | 1154 | h1, h2, h3 = h 1155 | 1156 | out1, loss1 = self.yolo1(h1, y) 1157 | out2, loss2 = self.yolo2(h2, y) 1158 | out3, loss3 = self.yolo3(h3, y) 1159 | 1160 | out1 = out1.detach() 1161 | out2 = out2.detach() 1162 | out3 = out3.detach() 1163 | 1164 | out = torch.cat((out1, out2, out3), dim=1) 1165 | 1166 | loss = (loss1 + loss2 + loss3)/3 1167 | 1168 | return out, loss 1169 | 1170 | 1171 | if __name__ == "__main__": 1172 | import time 1173 | import numpy as np 1174 | 1175 | model = YOLOv4().cuda().eval() 1176 | x = torch.ones((1, 3, 608, 608)).cuda() 1177 | y = torch.from_numpy(np.asarray([[0, 1, 0.5, 0.5, 0.3, 0.3]])).float().cuda() 1178 | 1179 | for i in range(1): 1180 | t0 = time.time() 1181 | y_hat, loss = model(x, y) 1182 | t1 = time.time() 1183 | print(t1 - t0) 1184 | 1185 | print(loss) 1186 | -------------------------------------------------------------------------------- /Training YOLOv4 .ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pl_model import YOLOv4PL" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pytorch_lightning as pl\n", 19 | "from argparse import Namespace\n", 20 | "from pytorch_lightning.callbacks import LearningRateLogger" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "hparams = {\n", 30 | " \"n_classes\" : 170,\n", 31 | "\n", 32 | " \"pretrained\" : False,\n", 33 | " \"train_ds\" : \"train.txt\",\n", 34 | " \"valid_ds\" : \"valid.txt\",\n", 35 | " \"img_extensions\" : [\".JPG\", \".jpg\"],\n", 36 | " \"bs\" : 1,\n", 37 | " \"momentum\": 0.9,\n", 38 | " \"wd\": 0.001,\n", 39 | " \"lr\": 1e-8,\n", 40 | " \"epochs\" : 100,\n", 41 | " \"pct_start\" : 10/100,\n", 42 | " \n", 43 | " \"optimizer\" : \"Ranger\",\n", 44 | " \"flat_epochs\" : 50,\n", 45 | " \"cosine_epochs\" : 25,\n", 46 | " \"scheduler\" : \"Cosine Delayed\", \n", 47 | " \n", 48 | " \"SAT\" : False,\n", 49 | " \"epsilon\" : 0.1,\n", 50 | " \"SAM\" : False,\n", 51 | " \"ECA\" : False,\n", 52 | " \"WS\" : False,\n", 53 | " \"Dropblock\" : False,\n", 54 | " \"iou_aware\" : False,\n", 55 | " \"coord\" : False,\n", 56 | " \"hard_mish\" : False,\n", 57 | " \"asff\" : False,\n", 58 | " \"repulsion_loss\" : False,\n", 59 | " \"acff\" : True,\n", 60 | " \"bcn\" : False,\n", 61 | " \"mbn\" : False,\n", 62 | "}" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "hparams = Namespace(**hparams)\n", 72 | "m = YOLOv4PL(hparams)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 5, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "m.cpu();" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 6, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "tb_logger = pl.loggers.TensorBoardLogger('logs/', name = \"yolov4\")" 91 
| ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 7, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "checkpoint_callback = pl.callbacks.ModelCheckpoint(\n", 100 | " filepath='model_checkpoints/yolov4{epoch:02d}',\n", 101 | " verbose=True,\n", 102 | " monitor=\"training_loss_epoch\",\n", 103 | " mode='min',\n", 104 | ")" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 8, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "output_type": "stream", 114 | "name": "stderr", 115 | "text": [ 116 | "GPU available: True, used: False\n", 117 | "No environment variable for node rank defined. Set as 0.\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "t = pl.Trainer(logger = tb_logger,\n", 123 | " checkpoint_callback=checkpoint_callback,\n", 124 | " gpus=0,\n", 125 | " precision=32,\n", 126 | " benchmark=True,\n", 127 | " callbacks=[LearningRateLogger()],\n", 128 | " min_epochs=100,\n", 129 | "\n", 130 | "\n", 131 | "# resume_from_checkpoint=\"model_checkpoints/yolov4epoch=82.ckpt\",\n", 132 | " # auto_lr_find=True,\n", 133 | " # auto_scale_batch_size='binsearch',\n", 134 | " # fast_dev_run=True\n", 135 | " )\n" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 9, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "output_type": "stream", 145 | "name": "stderr", 146 | "text": [ 147 | "\n", 148 | " | Name | Type | Params\n", 149 | "--------------------------------------------------------------------------------\n", 150 | "0 | model | YOLOv4 | 69 M \n", 151 | "1 | model.backbone | Backbone | 26 M \n", 152 | "2 | model.backbone.d1 | DownSampleFirst | 61 K \n", 153 | "3 | model.backbone.d1.c1 | ConvBlock | 928 \n", 154 | "4 | model.backbone.d1.c1.module | Sequential | 928 \n", 155 | "5 | model.backbone.d1.c1.module.0 | Conv2d | 864 \n", 156 | "6 | model.backbone.d1.c1.module.1 | BatchNorm2d | 64 \n", 157 | "7 | model.backbone.d1.c1.module.2 | Mish | 0 \n", 158 | "8 | model.backbone.d1.c2 | ConvBlock | 18 K \n", 159 | "9 | model.backbone.d1.c2.module | Sequential | 18 K \n", 160 | "10 | model.backbone.d1.c2.module.0 | Conv2d | 18 K \n", 161 | "11 | model.backbone.d1.c2.module.1 | BatchNorm2d | 128 \n", 162 | "12 | model.backbone.d1.c2.module.2 | Mish | 0 \n", 163 | "13 | model.backbone.d1.c3 | ConvBlock | 4 K \n", 164 | "14 | model.backbone.d1.c3.module | Sequential | 4 K \n", 165 | "15 | model.backbone.d1.c3.module.0 | Conv2d | 4 K \n", 166 | "16 | model.backbone.d1.c3.module.1 | BatchNorm2d | 128 \n", 167 | "17 | model.backbone.d1.c3.module.2 | Mish | 0 \n", 168 | "18 | model.backbone.d1.c4 | ConvBlock | 2 K \n", 169 | "19 | model.backbone.d1.c4.module | Sequential | 2 K \n", 170 | "20 | model.backbone.d1.c4.module.0 | Conv2d | 2 K \n", 171 | "21 | model.backbone.d1.c4.module.1 | BatchNorm2d | 64 \n", 172 | "22 | model.backbone.d1.c4.module.2 | Mish | 0 \n", 173 | "23 | model.backbone.d1.c5 | ConvBlock | 18 K \n", 174 | "24 | model.backbone.d1.c5.module | Sequential | 18 K \n", 175 | "25 | model.backbone.d1.c5.module.0 | Conv2d | 18 K \n", 176 | "26 | model.backbone.d1.c5.module.1 | BatchNorm2d | 128 \n", 177 | "27 | model.backbone.d1.c5.module.2 | Mish | 0 \n", 178 | "28 | model.backbone.d1.c6 | ConvBlock | 4 K \n", 179 | "29 | model.backbone.d1.c6.module | Sequential | 4 K \n", 180 | "30 | model.backbone.d1.c6.module.0 | Conv2d | 4 K \n", 181 | "31 | model.backbone.d1.c6.module.1 | BatchNorm2d | 128 \n", 182 | "32 | model.backbone.d1.c6.module.2 | Mish | 0 \n", 183 | "33 | model.backbone.d1.dense_c3_c6 
| ConvBlock | 4 K \n", 184 | "34 | model.backbone.d1.dense_c3_c6.module | Sequential | 4 K \n", 185 | "35 | model.backbone.d1.dense_c3_c6.module.0 | Conv2d | 4 K \n", 186 | "36 | model.backbone.d1.dense_c3_c6.module.1 | BatchNorm2d | 128 \n", 187 | "37 | model.backbone.d1.dense_c3_c6.module.2 | Mish | 0 \n", 188 | "38 | model.backbone.d1.c7 | ConvBlock | 8 K \n", 189 | "39 | model.backbone.d1.c7.module | Sequential | 8 K \n", 190 | "40 | model.backbone.d1.c7.module.0 | Conv2d | 8 K \n", 191 | "41 | model.backbone.d1.c7.module.1 | BatchNorm2d | 128 \n", 192 | "42 | model.backbone.d1.c7.module.2 | Mish | 0 \n", 193 | "43 | model.backbone.d2 | DownSampleBlock | 193 K \n", 194 | "44 | model.backbone.d2.c1 | ConvBlock | 73 K \n", 195 | "45 | model.backbone.d2.c1.module | Sequential | 73 K \n", 196 | "46 | model.backbone.d2.c1.module.0 | Conv2d | 73 K \n", 197 | "47 | model.backbone.d2.c1.module.1 | BatchNorm2d | 256 \n", 198 | "48 | model.backbone.d2.c1.module.2 | Mish | 0 \n", 199 | "49 | model.backbone.d2.c2 | ConvBlock | 8 K \n", 200 | "50 | model.backbone.d2.c2.module | Sequential | 8 K \n", 201 | "51 | model.backbone.d2.c2.module.0 | Conv2d | 8 K \n", 202 | "52 | model.backbone.d2.c2.module.1 | BatchNorm2d | 128 \n", 203 | "53 | model.backbone.d2.c2.module.2 | Mish | 0 \n", 204 | "54 | model.backbone.d2.r3 | ResBlock | 82 K \n", 205 | "55 | model.backbone.d2.r3.module_list | ModuleList | 82 K \n", 206 | "56 | model.backbone.d2.r3.module_list.0 | ModuleList | 41 K \n", 207 | "57 | model.backbone.d2.r3.module_list.0.0 | ConvBlock | 4 K \n", 208 | "58 | model.backbone.d2.r3.module_list.0.0.module | Sequential | 4 K \n", 209 | "59 | model.backbone.d2.r3.module_list.0.0.module.0 | Conv2d | 4 K \n", 210 | "60 | model.backbone.d2.r3.module_list.0.0.module.1 | BatchNorm2d | 128 \n", 211 | "61 | model.backbone.d2.r3.module_list.0.0.module.2 | Mish | 0 \n", 212 | "62 | model.backbone.d2.r3.module_list.0.1 | ConvBlock | 36 K \n", 213 | "63 | model.backbone.d2.r3.module_list.0.1.module | Sequential | 36 K \n", 214 | "64 | model.backbone.d2.r3.module_list.0.1.module.0 | Conv2d | 36 K \n", 215 | "65 | model.backbone.d2.r3.module_list.0.1.module.1 | BatchNorm2d | 128 \n", 216 | "66 | model.backbone.d2.r3.module_list.0.1.module.2 | Mish | 0 \n", 217 | "67 | model.backbone.d2.r3.module_list.1 | ModuleList | 41 K \n", 218 | "68 | model.backbone.d2.r3.module_list.1.0 | ConvBlock | 4 K \n", 219 | "69 | model.backbone.d2.r3.module_list.1.0.module | Sequential | 4 K \n", 220 | "70 | model.backbone.d2.r3.module_list.1.0.module.0 | Conv2d | 4 K \n", 221 | "71 | model.backbone.d2.r3.module_list.1.0.module.1 | BatchNorm2d | 128 \n", 222 | "72 | model.backbone.d2.r3.module_list.1.0.module.2 | Mish | 0 \n", 223 | "73 | model.backbone.d2.r3.module_list.1.1 | ConvBlock | 36 K \n", 224 | "74 | model.backbone.d2.r3.module_list.1.1.module | Sequential | 36 K \n", 225 | "75 | model.backbone.d2.r3.module_list.1.1.module.0 | Conv2d | 36 K \n", 226 | "76 | model.backbone.d2.r3.module_list.1.1.module.1 | BatchNorm2d | 128 \n", 227 | "77 | model.backbone.d2.r3.module_list.1.1.module.2 | Mish | 0 \n", 228 | "78 | model.backbone.d2.c4 | ConvBlock | 4 K \n", 229 | "79 | model.backbone.d2.c4.module | Sequential | 4 K \n", 230 | "80 | model.backbone.d2.c4.module.0 | Conv2d | 4 K \n", 231 | "81 | model.backbone.d2.c4.module.1 | BatchNorm2d | 128 \n", 232 | "82 | model.backbone.d2.c4.module.2 | Mish | 0 \n", 233 | "83 | model.backbone.d2.dense_c2_c4 | ConvBlock | 8 K \n", 234 | "84 | model.backbone.d2.dense_c2_c4.module | Sequential | 
8 K \n", 235 | "85 | model.backbone.d2.dense_c2_c4.module.0 | Conv2d | 8 K \n", 236 | "86 | model.backbone.d2.dense_c2_c4.module.1 | BatchNorm2d | 128 \n", 237 | "87 | model.backbone.d2.dense_c2_c4.module.2 | Mish | 0 \n", 238 | "88 | model.backbone.d2.c5 | ConvBlock | 16 K \n", 239 | "89 | model.backbone.d2.c5.module | Sequential | 16 K \n", 240 | "90 | model.backbone.d2.c5.module.0 | Conv2d | 16 K \n", 241 | "91 | model.backbone.d2.c5.module.1 | BatchNorm2d | 256 \n", 242 | "92 | model.backbone.d2.c5.module.2 | Mish | 0 \n", 243 | "93 | model.backbone.d3 | DownSampleBlock | 1 M \n", 244 | "94 | model.backbone.d3.c1 | ConvBlock | 295 K \n", 245 | "95 | model.backbone.d3.c1.module | Sequential | 295 K \n", 246 | "96 | model.backbone.d3.c1.module.0 | Conv2d | 294 K \n", 247 | "97 | model.backbone.d3.c1.module.1 | BatchNorm2d | 512 \n", 248 | "98 | model.backbone.d3.c1.module.2 | Mish | 0 \n", 249 | "99 | model.backbone.d3.c2 | ConvBlock | 33 K \n", 250 | "100 | model.backbone.d3.c2.module | Sequential | 33 K \n", 251 | "101 | model.backbone.d3.c2.module.0 | Conv2d | 32 K \n", 252 | "102 | model.backbone.d3.c2.module.1 | BatchNorm2d | 256 \n", 253 | "103 | model.backbone.d3.c2.module.2 | Mish | 0 \n", 254 | "104 | model.backbone.d3.r3 | ResBlock | 1 M \n", 255 | "105 | model.backbone.d3.r3.module_list | ModuleList | 1 M \n", 256 | "106 | model.backbone.d3.r3.module_list.0 | ModuleList | 164 K \n", 257 | "107 | model.backbone.d3.r3.module_list.0.0 | ConvBlock | 16 K \n", 258 | "108 | model.backbone.d3.r3.module_list.0.0.module | Sequential | 16 K \n", 259 | "109 | model.backbone.d3.r3.module_list.0.0.module.0 | Conv2d | 16 K \n", 260 | "110 | model.backbone.d3.r3.module_list.0.0.module.1 | BatchNorm2d | 256 \n", 261 | "111 | model.backbone.d3.r3.module_list.0.0.module.2 | Mish | 0 \n", 262 | "112 | model.backbone.d3.r3.module_list.0.1 | ConvBlock | 147 K \n", 263 | "113 | model.backbone.d3.r3.module_list.0.1.module | Sequential | 147 K \n", 264 | "114 | model.backbone.d3.r3.module_list.0.1.module.0 | Conv2d | 147 K \n", 265 | "115 | model.backbone.d3.r3.module_list.0.1.module.1 | BatchNorm2d | 256 \n", 266 | "116 | model.backbone.d3.r3.module_list.0.1.module.2 | Mish | 0 \n", 267 | "117 | model.backbone.d3.r3.module_list.1 | ModuleList | 164 K \n", 268 | "118 | model.backbone.d3.r3.module_list.1.0 | ConvBlock | 16 K \n", 269 | "119 | model.backbone.d3.r3.module_list.1.0.module | Sequential | 16 K \n", 270 | "120 | model.backbone.d3.r3.module_list.1.0.module.0 | Conv2d | 16 K \n", 271 | "121 | model.backbone.d3.r3.module_list.1.0.module.1 | BatchNorm2d | 256 \n", 272 | "122 | model.backbone.d3.r3.module_list.1.0.module.2 | Mish | 0 \n", 273 | "123 | model.backbone.d3.r3.module_list.1.1 | ConvBlock | 147 K \n", 274 | "124 | model.backbone.d3.r3.module_list.1.1.module | Sequential | 147 K \n", 275 | "125 | model.backbone.d3.r3.module_list.1.1.module.0 | Conv2d | 147 K \n", 276 | "126 | model.backbone.d3.r3.module_list.1.1.module.1 | BatchNorm2d | 256 \n", 277 | "127 | model.backbone.d3.r3.module_list.1.1.module.2 | Mish | 0 \n", 278 | "128 | model.backbone.d3.r3.module_list.2 | ModuleList | 164 K \n", 279 | "129 | model.backbone.d3.r3.module_list.2.0 | ConvBlock | 16 K \n", 280 | "130 | model.backbone.d3.r3.module_list.2.0.module | Sequential | 16 K \n", 281 | "131 | model.backbone.d3.r3.module_list.2.0.module.0 | Conv2d | 16 K \n", 282 | "132 | model.backbone.d3.r3.module_list.2.0.module.1 | BatchNorm2d | 256 \n", 283 | "133 | model.backbone.d3.r3.module_list.2.0.module.2 | Mish | 0 \n", 284 | 
"134 | model.backbone.d3.r3.module_list.2.1 | ConvBlock | 147 K \n", 285 | "135 | model.backbone.d3.r3.module_list.2.1.module | Sequential | 147 K \n", 286 | "136 | model.backbone.d3.r3.module_list.2.1.module.0 | Conv2d | 147 K \n", 287 | "137 | model.backbone.d3.r3.module_list.2.1.module.1 | BatchNorm2d | 256 \n", 288 | "138 | model.backbone.d3.r3.module_list.2.1.module.2 | Mish | 0 \n", 289 | "139 | model.backbone.d3.r3.module_list.3 | ModuleList | 164 K \n", 290 | "140 | model.backbone.d3.r3.module_list.3.0 | ConvBlock | 16 K \n", 291 | "141 | model.backbone.d3.r3.module_list.3.0.module | Sequential | 16 K \n", 292 | "142 | model.backbone.d3.r3.module_list.3.0.module.0 | Conv2d | 16 K \n", 293 | "143 | model.backbone.d3.r3.module_list.3.0.module.1 | BatchNorm2d | 256 \n", 294 | "144 | model.backbone.d3.r3.module_list.3.0.module.2 | Mish | 0 \n", 295 | "145 | model.backbone.d3.r3.module_list.3.1 | ConvBlock | 147 K \n", 296 | "146 | model.backbone.d3.r3.module_list.3.1.module | Sequential | 147 K \n", 297 | "147 | model.backbone.d3.r3.module_list.3.1.module.0 | Conv2d | 147 K \n", 298 | "148 | model.backbone.d3.r3.module_list.3.1.module.1 | BatchNorm2d | 256 \n", 299 | "149 | model.backbone.d3.r3.module_list.3.1.module.2 | Mish | 0 \n", 300 | "150 | model.backbone.d3.r3.module_list.4 | ModuleList | 164 K \n", 301 | "151 | model.backbone.d3.r3.module_list.4.0 | ConvBlock | 16 K \n", 302 | "152 | model.backbone.d3.r3.module_list.4.0.module | Sequential | 16 K \n", 303 | "153 | model.backbone.d3.r3.module_list.4.0.module.0 | Conv2d | 16 K \n", 304 | "154 | model.backbone.d3.r3.module_list.4.0.module.1 | BatchNorm2d | 256 \n", 305 | "155 | model.backbone.d3.r3.module_list.4.0.module.2 | Mish | 0 \n", 306 | "156 | model.backbone.d3.r3.module_list.4.1 | ConvBlock | 147 K \n", 307 | "157 | model.backbone.d3.r3.module_list.4.1.module | Sequential | 147 K \n", 308 | "158 | model.backbone.d3.r3.module_list.4.1.module.0 | Conv2d | 147 K \n", 309 | "159 | model.backbone.d3.r3.module_list.4.1.module.1 | BatchNorm2d | 256 \n", 310 | "160 | model.backbone.d3.r3.module_list.4.1.module.2 | Mish | 0 \n", 311 | "161 | model.backbone.d3.r3.module_list.5 | ModuleList | 164 K \n", 312 | "162 | model.backbone.d3.r3.module_list.5.0 | ConvBlock | 16 K \n", 313 | "163 | model.backbone.d3.r3.module_list.5.0.module | Sequential | 16 K \n", 314 | "164 | model.backbone.d3.r3.module_list.5.0.module.0 | Conv2d | 16 K \n", 315 | "165 | model.backbone.d3.r3.module_list.5.0.module.1 | BatchNorm2d | 256 \n", 316 | "166 | model.backbone.d3.r3.module_list.5.0.module.2 | Mish | 0 \n", 317 | "167 | model.backbone.d3.r3.module_list.5.1 | ConvBlock | 147 K \n", 318 | "168 | model.backbone.d3.r3.module_list.5.1.module | Sequential | 147 K \n", 319 | "169 | model.backbone.d3.r3.module_list.5.1.module.0 | Conv2d | 147 K \n", 320 | "170 | model.backbone.d3.r3.module_list.5.1.module.1 | BatchNorm2d | 256 \n", 321 | "171 | model.backbone.d3.r3.module_list.5.1.module.2 | Mish | 0 \n", 322 | "172 | model.backbone.d3.r3.module_list.6 | ModuleList | 164 K \n", 323 | "173 | model.backbone.d3.r3.module_list.6.0 | ConvBlock | 16 K \n", 324 | "174 | model.backbone.d3.r3.module_list.6.0.module | Sequential | 16 K \n", 325 | "175 | model.backbone.d3.r3.module_list.6.0.module.0 | Conv2d | 16 K \n", 326 | "176 | model.backbone.d3.r3.module_list.6.0.module.1 | BatchNorm2d | 256 \n", 327 | "177 | model.backbone.d3.r3.module_list.6.0.module.2 | Mish | 0 \n", 328 | "178 | model.backbone.d3.r3.module_list.6.1 | ConvBlock | 147 K \n", 329 | "179 | 
model.backbone.d3.r3.module_list.6.1.module | Sequential | 147 K \n", 330 | "180 | model.backbone.d3.r3.module_list.6.1.module.0 | Conv2d | 147 K \n", 331 | "181 | model.backbone.d3.r3.module_list.6.1.module.1 | BatchNorm2d | 256 \n", 332 | "182 | model.backbone.d3.r3.module_list.6.1.module.2 | Mish | 0 \n", 333 | "183 | model.backbone.d3.r3.module_list.7 | ModuleList | 164 K \n", 334 | "184 | model.backbone.d3.r3.module_list.7.0 | ConvBlock | 16 K \n", 335 | "185 | model.backbone.d3.r3.module_list.7.0.module | Sequential | 16 K \n", 336 | "186 | model.backbone.d3.r3.module_list.7.0.module.0 | Conv2d | 16 K \n", 337 | "187 | model.backbone.d3.r3.module_list.7.0.module.1 | BatchNorm2d | 256 \n", 338 | "188 | model.backbone.d3.r3.module_list.7.0.module.2 | Mish | 0 \n", 339 | "189 | model.backbone.d3.r3.module_list.7.1 | ConvBlock | 147 K \n", 340 | "190 | model.backbone.d3.r3.module_list.7.1.module | Sequential | 147 K \n", 341 | "191 | model.backbone.d3.r3.module_list.7.1.module.0 | Conv2d | 147 K \n", 342 | "192 | model.backbone.d3.r3.module_list.7.1.module.1 | BatchNorm2d | 256 \n", 343 | "193 | model.backbone.d3.r3.module_list.7.1.module.2 | Mish | 0 \n", 344 | "194 | model.backbone.d3.c4 | ConvBlock | 16 K \n", 345 | "195 | model.backbone.d3.c4.module | Sequential | 16 K \n", 346 | "196 | model.backbone.d3.c4.module.0 | Conv2d | 16 K \n", 347 | "197 | model.backbone.d3.c4.module.1 | BatchNorm2d | 256 \n", 348 | "198 | model.backbone.d3.c4.module.2 | Mish | 0 \n", 349 | "199 | model.backbone.d3.dense_c2_c4 | ConvBlock | 33 K \n", 350 | "200 | model.backbone.d3.dense_c2_c4.module | Sequential | 33 K \n", 351 | "201 | model.backbone.d3.dense_c2_c4.module.0 | Conv2d | 32 K \n", 352 | "202 | model.backbone.d3.dense_c2_c4.module.1 | BatchNorm2d | 256 \n", 353 | "203 | model.backbone.d3.dense_c2_c4.module.2 | Mish | 0 \n", 354 | "204 | model.backbone.d3.c5 | ConvBlock | 66 K \n", 355 | "205 | model.backbone.d3.c5.module | Sequential | 66 K \n", 356 | "206 | model.backbone.d3.c5.module.0 | Conv2d | 65 K \n", 357 | "207 | model.backbone.d3.c5.module.1 | BatchNorm2d | 512 \n", 358 | "208 | model.backbone.d3.c5.module.2 | Mish | 0 \n", 359 | "209 | model.backbone.d4 | DownSampleBlock | 7 M \n", 360 | "210 | model.backbone.d4.c1 | ConvBlock | 1 M \n", 361 | "211 | model.backbone.d4.c1.module | Sequential | 1 M \n", 362 | "212 | model.backbone.d4.c1.module.0 | Conv2d | 1 M \n", 363 | "213 | model.backbone.d4.c1.module.1 | BatchNorm2d | 1 K \n", 364 | "214 | model.backbone.d4.c1.module.2 | Mish | 0 \n", 365 | "215 | model.backbone.d4.c2 | ConvBlock | 131 K \n", 366 | "216 | model.backbone.d4.c2.module | Sequential | 131 K \n", 367 | "217 | model.backbone.d4.c2.module.0 | Conv2d | 131 K \n", 368 | "218 | model.backbone.d4.c2.module.1 | BatchNorm2d | 512 \n", 369 | "219 | model.backbone.d4.c2.module.2 | Mish | 0 \n", 370 | "220 | model.backbone.d4.r3 | ResBlock | 5 M \n", 371 | "221 | model.backbone.d4.r3.module_list | ModuleList | 5 M \n", 372 | "222 | model.backbone.d4.r3.module_list.0 | ModuleList | 656 K \n", 373 | "223 | model.backbone.d4.r3.module_list.0.0 | ConvBlock | 66 K \n", 374 | "224 | model.backbone.d4.r3.module_list.0.0.module | Sequential | 66 K \n", 375 | "225 | model.backbone.d4.r3.module_list.0.0.module.0 | Conv2d | 65 K \n", 376 | "226 | model.backbone.d4.r3.module_list.0.0.module.1 | BatchNorm2d | 512 \n", 377 | "227 | model.backbone.d4.r3.module_list.0.0.module.2 | Mish | 0 \n", 378 | "228 | model.backbone.d4.r3.module_list.0.1 | ConvBlock | 590 K \n", 379 | "229 | 
model.backbone.d4.r3.module_list.0.1.module | Sequential | 590 K \n", 380 | "230 | model.backbone.d4.r3.module_list.0.1.module.0 | Conv2d | 589 K \n", 381 | "231 | model.backbone.d4.r3.module_list.0.1.module.1 | BatchNorm2d | 512 \n", 382 | "232 | model.backbone.d4.r3.module_list.0.1.module.2 | Mish | 0 \n", 383 | "233 | model.backbone.d4.r3.module_list.1 | ModuleList | 656 K \n", 384 | "234 | model.backbone.d4.r3.module_list.1.0 | ConvBlock | 66 K \n", 385 | "235 | model.backbone.d4.r3.module_list.1.0.module | Sequential | 66 K \n", 386 | "236 | model.backbone.d4.r3.module_list.1.0.module.0 | Conv2d | 65 K \n", 387 | "237 | model.backbone.d4.r3.module_list.1.0.module.1 | BatchNorm2d | 512 \n", 388 | "238 | model.backbone.d4.r3.module_list.1.0.module.2 | Mish | 0 \n", 389 | "239 | model.backbone.d4.r3.module_list.1.1 | ConvBlock | 590 K \n", 390 | "240 | model.backbone.d4.r3.module_list.1.1.module | Sequential | 590 K \n", 391 | "241 | model.backbone.d4.r3.module_list.1.1.module.0 | Conv2d | 589 K \n", 392 | "242 | model.backbone.d4.r3.module_list.1.1.module.1 | BatchNorm2d | 512 \n", 393 | "243 | model.backbone.d4.r3.module_list.1.1.module.2 | Mish | 0 \n", 394 | "244 | model.backbone.d4.r3.module_list.2 | ModuleList | 656 K \n", 395 | "245 | model.backbone.d4.r3.module_list.2.0 | ConvBlock | 66 K \n", 396 | "246 | model.backbone.d4.r3.module_list.2.0.module | Sequential | 66 K \n", 397 | "247 | model.backbone.d4.r3.module_list.2.0.module.0 | Conv2d | 65 K \n", 398 | "248 | model.backbone.d4.r3.module_list.2.0.module.1 | BatchNorm2d | 512 \n", 399 | "249 | model.backbone.d4.r3.module_list.2.0.module.2 | Mish | 0 \n", 400 | "250 | model.backbone.d4.r3.module_list.2.1 | ConvBlock | 590 K \n", 401 | "251 | model.backbone.d4.r3.module_list.2.1.module | Sequential | 590 K \n", 402 | "252 | model.backbone.d4.r3.module_list.2.1.module.0 | Conv2d | 589 K \n", 403 | "253 | model.backbone.d4.r3.module_list.2.1.module.1 | BatchNorm2d | 512 \n", 404 | "254 | model.backbone.d4.r3.module_list.2.1.module.2 | Mish | 0 \n", 405 | "255 | model.backbone.d4.r3.module_list.3 | ModuleList | 656 K \n", 406 | "256 | model.backbone.d4.r3.module_list.3.0 | ConvBlock | 66 K \n", 407 | "257 | model.backbone.d4.r3.module_list.3.0.module | Sequential | 66 K \n", 408 | "258 | model.backbone.d4.r3.module_list.3.0.module.0 | Conv2d | 65 K \n", 409 | "259 | model.backbone.d4.r3.module_list.3.0.module.1 | BatchNorm2d | 512 \n", 410 | "260 | model.backbone.d4.r3.module_list.3.0.module.2 | Mish | 0 \n", 411 | "261 | model.backbone.d4.r3.module_list.3.1 | ConvBlock | 590 K \n", 412 | "262 | model.backbone.d4.r3.module_list.3.1.module | Sequential | 590 K \n", 413 | "263 | model.backbone.d4.r3.module_list.3.1.module.0 | Conv2d | 589 K \n", 414 | "264 | model.backbone.d4.r3.module_list.3.1.module.1 | BatchNorm2d | 512 \n", 415 | "265 | model.backbone.d4.r3.module_list.3.1.module.2 | Mish | 0 \n", 416 | "266 | model.backbone.d4.r3.module_list.4 | ModuleList | 656 K \n", 417 | "267 | model.backbone.d4.r3.module_list.4.0 | ConvBlock | 66 K \n", 418 | "268 | model.backbone.d4.r3.module_list.4.0.module | Sequential | 66 K \n", 419 | "269 | model.backbone.d4.r3.module_list.4.0.module.0 | Conv2d | 65 K \n", 420 | "270 | model.backbone.d4.r3.module_list.4.0.module.1 | BatchNorm2d | 512 \n", 421 | "271 | model.backbone.d4.r3.module_list.4.0.module.2 | Mish | 0 \n", 422 | "272 | model.backbone.d4.r3.module_list.4.1 | ConvBlock | 590 K \n", 423 | "273 | model.backbone.d4.r3.module_list.4.1.module | Sequential | 590 K \n", 424 | "274 | 
model.backbone.d4.r3.module_list.4.1.module.0 | Conv2d | 589 K \n", 425 | "275 | model.backbone.d4.r3.module_list.4.1.module.1 | BatchNorm2d | 512 \n", 426 | "276 | model.backbone.d4.r3.module_list.4.1.module.2 | Mish | 0 \n", 427 | "277 | model.backbone.d4.r3.module_list.5 | ModuleList | 656 K \n", 428 | "278 | model.backbone.d4.r3.module_list.5.0 | ConvBlock | 66 K \n", 429 | "279 | model.backbone.d4.r3.module_list.5.0.module | Sequential | 66 K \n", 430 | "280 | model.backbone.d4.r3.module_list.5.0.module.0 | Conv2d | 65 K \n", 431 | "281 | model.backbone.d4.r3.module_list.5.0.module.1 | BatchNorm2d | 512 \n", 432 | "282 | model.backbone.d4.r3.module_list.5.0.module.2 | Mish | 0 \n", 433 | "283 | model.backbone.d4.r3.module_list.5.1 | ConvBlock | 590 K \n", 434 | "284 | model.backbone.d4.r3.module_list.5.1.module | Sequential | 590 K \n", 435 | "285 | model.backbone.d4.r3.module_list.5.1.module.0 | Conv2d | 589 K \n", 436 | "286 | model.backbone.d4.r3.module_list.5.1.module.1 | BatchNorm2d | 512 \n", 437 | "287 | model.backbone.d4.r3.module_list.5.1.module.2 | Mish | 0 \n", 438 | "288 | model.backbone.d4.r3.module_list.6 | ModuleList | 656 K \n", 439 | "289 | model.backbone.d4.r3.module_list.6.0 | ConvBlock | 66 K \n", 440 | "290 | model.backbone.d4.r3.module_list.6.0.module | Sequential | 66 K \n", 441 | "291 | model.backbone.d4.r3.module_list.6.0.module.0 | Conv2d | 65 K \n", 442 | "292 | model.backbone.d4.r3.module_list.6.0.module.1 | BatchNorm2d | 512 \n", 443 | "293 | model.backbone.d4.r3.module_list.6.0.module.2 | Mish | 0 \n", 444 | "294 | model.backbone.d4.r3.module_list.6.1 | ConvBlock | 590 K \n", 445 | "295 | model.backbone.d4.r3.module_list.6.1.module | Sequential | 590 K \n", 446 | "296 | model.backbone.d4.r3.module_list.6.1.module.0 | Conv2d | 589 K \n", 447 | "297 | model.backbone.d4.r3.module_list.6.1.module.1 | BatchNorm2d | 512 \n", 448 | "298 | model.backbone.d4.r3.module_list.6.1.module.2 | Mish | 0 \n", 449 | "299 | model.backbone.d4.r3.module_list.7 | ModuleList | 656 K \n", 450 | "300 | model.backbone.d4.r3.module_list.7.0 | ConvBlock | 66 K \n", 451 | "301 | model.backbone.d4.r3.module_list.7.0.module | Sequential | 66 K \n", 452 | "302 | model.backbone.d4.r3.module_list.7.0.module.0 | Conv2d | 65 K \n", 453 | "303 | model.backbone.d4.r3.module_list.7.0.module.1 | BatchNorm2d | 512 \n", 454 | "304 | model.backbone.d4.r3.module_list.7.0.module.2 | Mish | 0 \n", 455 | "305 | model.backbone.d4.r3.module_list.7.1 | ConvBlock | 590 K \n", 456 | "306 | model.backbone.d4.r3.module_list.7.1.module | Sequential | 590 K \n", 457 | "307 | model.backbone.d4.r3.module_list.7.1.module.0 | Conv2d | 589 K \n", 458 | "308 | model.backbone.d4.r3.module_list.7.1.module.1 | BatchNorm2d | 512 \n", 459 | "309 | model.backbone.d4.r3.module_list.7.1.module.2 | Mish | 0 \n", 460 | "310 | model.backbone.d4.c4 | ConvBlock | 66 K \n", 461 | "311 | model.backbone.d4.c4.module | Sequential | 66 K \n", 462 | "312 | model.backbone.d4.c4.module.0 | Conv2d | 65 K \n", 463 | "313 | model.backbone.d4.c4.module.1 | BatchNorm2d | 512 \n", 464 | "314 | model.backbone.d4.c4.module.2 | Mish | 0 \n", 465 | "315 | model.backbone.d4.dense_c2_c4 | ConvBlock | 131 K \n", 466 | "316 | model.backbone.d4.dense_c2_c4.module | Sequential | 131 K \n", 467 | "317 | model.backbone.d4.dense_c2_c4.module.0 | Conv2d | 131 K \n", 468 | "318 | model.backbone.d4.dense_c2_c4.module.1 | BatchNorm2d | 512 \n", 469 | "319 | model.backbone.d4.dense_c2_c4.module.2 | Mish | 0 \n", 470 | "320 | model.backbone.d4.c5 | ConvBlock | 
263 K \n", 471 | "321 | model.backbone.d4.c5.module | Sequential | 263 K \n", 472 | "322 | model.backbone.d4.c5.module.0 | Conv2d | 262 K \n", 473 | "323 | model.backbone.d4.c5.module.1 | BatchNorm2d | 1 K \n", 474 | "324 | model.backbone.d4.c5.module.2 | Mish | 0 \n", 475 | "325 | model.backbone.d5 | DownSampleBlock | 17 M \n", 476 | "326 | model.backbone.d5.c1 | ConvBlock | 4 M \n", 477 | "327 | model.backbone.d5.c1.module | Sequential | 4 M \n", 478 | "328 | model.backbone.d5.c1.module.0 | Conv2d | 4 M \n", 479 | "329 | model.backbone.d5.c1.module.1 | BatchNorm2d | 2 K \n", 480 | "330 | model.backbone.d5.c1.module.2 | Mish | 0 \n", 481 | "331 | model.backbone.d5.c2 | ConvBlock | 525 K \n", 482 | "332 | model.backbone.d5.c2.module | Sequential | 525 K \n", 483 | "333 | model.backbone.d5.c2.module.0 | Conv2d | 524 K \n", 484 | "334 | model.backbone.d5.c2.module.1 | BatchNorm2d | 1 K \n", 485 | "335 | model.backbone.d5.c2.module.2 | Mish | 0 \n", 486 | "336 | model.backbone.d5.r3 | ResBlock | 10 M \n", 487 | "337 | model.backbone.d5.r3.module_list | ModuleList | 10 M \n", 488 | "338 | model.backbone.d5.r3.module_list.0 | ModuleList | 2 M \n", 489 | "339 | model.backbone.d5.r3.module_list.0.0 | ConvBlock | 263 K \n", 490 | "340 | model.backbone.d5.r3.module_list.0.0.module | Sequential | 263 K \n", 491 | "341 | model.backbone.d5.r3.module_list.0.0.module.0 | Conv2d | 262 K \n", 492 | "342 | model.backbone.d5.r3.module_list.0.0.module.1 | BatchNorm2d | 1 K \n", 493 | "343 | model.backbone.d5.r3.module_list.0.0.module.2 | Mish | 0 \n", 494 | "344 | model.backbone.d5.r3.module_list.0.1 | ConvBlock | 2 M \n", 495 | "345 | model.backbone.d5.r3.module_list.0.1.module | Sequential | 2 M \n", 496 | "346 | model.backbone.d5.r3.module_list.0.1.module.0 | Conv2d | 2 M \n", 497 | "347 | model.backbone.d5.r3.module_list.0.1.module.1 | BatchNorm2d | 1 K \n", 498 | "348 | model.backbone.d5.r3.module_list.0.1.module.2 | Mish | 0 \n", 499 | "349 | model.backbone.d5.r3.module_list.1 | ModuleList | 2 M \n", 500 | "350 | model.backbone.d5.r3.module_list.1.0 | ConvBlock | 263 K \n", 501 | "351 | model.backbone.d5.r3.module_list.1.0.module | Sequential | 263 K \n", 502 | "352 | model.backbone.d5.r3.module_list.1.0.module.0 | Conv2d | 262 K \n", 503 | "353 | model.backbone.d5.r3.module_list.1.0.module.1 | BatchNorm2d | 1 K \n", 504 | "354 | model.backbone.d5.r3.module_list.1.0.module.2 | Mish | 0 \n", 505 | "355 | model.backbone.d5.r3.module_list.1.1 | ConvBlock | 2 M \n", 506 | "356 | model.backbone.d5.r3.module_list.1.1.module | Sequential | 2 M \n", 507 | "357 | model.backbone.d5.r3.module_list.1.1.module.0 | Conv2d | 2 M \n", 508 | "358 | model.backbone.d5.r3.module_list.1.1.module.1 | BatchNorm2d | 1 K \n", 509 | "359 | model.backbone.d5.r3.module_list.1.1.module.2 | Mish | 0 \n", 510 | "360 | model.backbone.d5.r3.module_list.2 | ModuleList | 2 M \n", 511 | "361 | model.backbone.d5.r3.module_list.2.0 | ConvBlock | 263 K \n", 512 | "362 | model.backbone.d5.r3.module_list.2.0.module | Sequential | 263 K \n", 513 | "363 | model.backbone.d5.r3.module_list.2.0.module.0 | Conv2d | 262 K \n", 514 | "364 | model.backbone.d5.r3.module_list.2.0.module.1 | BatchNorm2d | 1 K \n", 515 | "365 | model.backbone.d5.r3.module_list.2.0.module.2 | Mish | 0 \n", 516 | "366 | model.backbone.d5.r3.module_list.2.1 | ConvBlock | 2 M \n", 517 | "367 | model.backbone.d5.r3.module_list.2.1.module | Sequential | 2 M \n", 518 | "368 | model.backbone.d5.r3.module_list.2.1.module.0 | Conv2d | 2 M \n", 519 | "369 | 
model.backbone.d5.r3.module_list.2.1.module.1 | BatchNorm2d | 1 K \n", 520 | "370 | model.backbone.d5.r3.module_list.2.1.module.2 | Mish | 0 \n", 521 | "371 | model.backbone.d5.r3.module_list.3 | ModuleList | 2 M \n", 522 | "372 | model.backbone.d5.r3.module_list.3.0 | ConvBlock | 263 K \n", 523 | "373 | model.backbone.d5.r3.module_list.3.0.module | Sequential | 263 K \n", 524 | "374 | model.backbone.d5.r3.module_list.3.0.module.0 | Conv2d | 262 K \n", 525 | "375 | model.backbone.d5.r3.module_list.3.0.module.1 | BatchNorm2d | 1 K \n", 526 | "376 | model.backbone.d5.r3.module_list.3.0.module.2 | Mish | 0 \n", 527 | "377 | model.backbone.d5.r3.module_list.3.1 | ConvBlock | 2 M \n", 528 | "378 | model.backbone.d5.r3.module_list.3.1.module | Sequential | 2 M \n", 529 | "379 | model.backbone.d5.r3.module_list.3.1.module.0 | Conv2d | 2 M \n", 530 | "380 | model.backbone.d5.r3.module_list.3.1.module.1 | BatchNorm2d | 1 K \n", 531 | "381 | model.backbone.d5.r3.module_list.3.1.module.2 | Mish | 0 \n", 532 | "382 | model.backbone.d5.c4 | ConvBlock | 263 K \n", 533 | "383 | model.backbone.d5.c4.module | Sequential | 263 K \n", 534 | "384 | model.backbone.d5.c4.module.0 | Conv2d | 262 K \n", 535 | "385 | model.backbone.d5.c4.module.1 | BatchNorm2d | 1 K \n", 536 | "386 | model.backbone.d5.c4.module.2 | Mish | 0 \n", 537 | "387 | model.backbone.d5.dense_c2_c4 | ConvBlock | 525 K \n", 538 | "388 | model.backbone.d5.dense_c2_c4.module | Sequential | 525 K \n", 539 | "389 | model.backbone.d5.dense_c2_c4.module.0 | Conv2d | 524 K \n", 540 | "390 | model.backbone.d5.dense_c2_c4.module.1 | BatchNorm2d | 1 K \n", 541 | "391 | model.backbone.d5.dense_c2_c4.module.2 | Mish | 0 \n", 542 | "392 | model.backbone.d5.c5 | ConvBlock | 1 M \n", 543 | "393 | model.backbone.d5.c5.module | Sequential | 1 M \n", 544 | "394 | model.backbone.d5.c5.module.0 | Conv2d | 1 M \n", 545 | "395 | model.backbone.d5.c5.module.1 | BatchNorm2d | 2 K \n", 546 | "396 | model.backbone.d5.c5.module.2 | Mish | 0 \n", 547 | "397 | model.neck | Neck | 21 M \n", 548 | "398 | model.neck.c1 | ConvBlock | 525 K \n", 549 | "399 | model.neck.c1.module | Sequential | 525 K \n", 550 | "400 | model.neck.c1.module.0 | Conv2d | 524 K \n", 551 | "401 | model.neck.c1.module.1 | BatchNorm2d | 1 K \n", 552 | "402 | model.neck.c1.module.2 | LeakyReLU | 0 \n", 553 | "403 | model.neck.c2 | ConvBlock | 4 M \n", 554 | "404 | model.neck.c2.module | Sequential | 4 M \n", 555 | "405 | model.neck.c2.module.0 | Conv2d | 4 M \n", 556 | "406 | model.neck.c2.module.1 | BatchNorm2d | 2 K \n", 557 | "407 | model.neck.c2.module.2 | LeakyReLU | 0 \n", 558 | "408 | model.neck.c3 | ConvBlock | 525 K \n", 559 | "409 | model.neck.c3.module | Sequential | 525 K \n", 560 | "410 | model.neck.c3.module.0 | Conv2d | 524 K \n", 561 | "411 | model.neck.c3.module.1 | BatchNorm2d | 1 K \n", 562 | "412 | model.neck.c3.module.2 | LeakyReLU | 0 \n", 563 | "413 | model.neck.mp4_1 | MaxPool2d | 0 \n", 564 | "414 | model.neck.mp4_2 | MaxPool2d | 0 \n", 565 | "415 | model.neck.mp4_3 | MaxPool2d | 0 \n", 566 | "416 | model.neck.c5 | ConvBlock | 1 M \n", 567 | "417 | model.neck.c5.module | Sequential | 1 M \n", 568 | "418 | model.neck.c5.module.0 | Conv2d | 1 M \n", 569 | "419 | model.neck.c5.module.1 | BatchNorm2d | 1 K \n", 570 | "420 | model.neck.c5.module.2 | LeakyReLU | 0 \n", 571 | "421 | model.neck.c6 | ConvBlock | 4 M \n", 572 | "422 | model.neck.c6.module | Sequential | 4 M \n", 573 | "423 | model.neck.c6.module.0 | Conv2d | 4 M \n", 574 | "424 | model.neck.c6.module.1 | BatchNorm2d | 2 
K \n", 575 | "425 | model.neck.c6.module.2 | LeakyReLU | 0 \n", 576 | "426 | model.neck.c7 | ConvBlock | 525 K \n", 577 | "427 | model.neck.c7.module | Sequential | 525 K \n", 578 | "428 | model.neck.c7.module.0 | Conv2d | 524 K \n", 579 | "429 | model.neck.c7.module.1 | BatchNorm2d | 1 K \n", 580 | "430 | model.neck.c7.module.2 | LeakyReLU | 0 \n", 581 | "431 | model.neck.PAN8 | PAN_Layer | 3 M \n", 582 | "432 | model.neck.PAN8.c1 | ConvBlock | 131 K \n", 583 | "433 | model.neck.PAN8.c1.module | Sequential | 131 K \n", 584 | "434 | model.neck.PAN8.c1.module.0 | Conv2d | 131 K \n", 585 | "435 | model.neck.PAN8.c1.module.1 | BatchNorm2d | 512 \n", 586 | "436 | model.neck.PAN8.c1.module.2 | LeakyReLU | 0 \n", 587 | "437 | model.neck.PAN8.u2 | Upsample | 0 \n", 588 | "438 | model.neck.PAN8.c2_from_upsampled | ConvBlock | 131 K \n", 589 | "439 | model.neck.PAN8.c2_from_upsampled.module | Sequential | 131 K \n", 590 | "440 | model.neck.PAN8.c2_from_upsampled.module.0 | Conv2d | 131 K \n", 591 | "441 | model.neck.PAN8.c2_from_upsampled.module.1 | BatchNorm2d | 512 \n", 592 | "442 | model.neck.PAN8.c2_from_upsampled.module.2 | LeakyReLU | 0 \n", 593 | "443 | model.neck.PAN8.c3 | ConvBlock | 131 K \n", 594 | "444 | model.neck.PAN8.c3.module | Sequential | 131 K \n", 595 | "445 | model.neck.PAN8.c3.module.0 | Conv2d | 131 K \n", 596 | "446 | model.neck.PAN8.c3.module.1 | BatchNorm2d | 512 \n", 597 | "447 | model.neck.PAN8.c3.module.2 | LeakyReLU | 0 \n", 598 | "448 | model.neck.PAN8.c4 | ConvBlock | 1 M \n", 599 | "449 | model.neck.PAN8.c4.module | Sequential | 1 M \n", 600 | "450 | model.neck.PAN8.c4.module.0 | Conv2d | 1 M \n", 601 | "451 | model.neck.PAN8.c4.module.1 | BatchNorm2d | 1 K \n", 602 | "452 | model.neck.PAN8.c4.module.2 | LeakyReLU | 0 \n", 603 | "453 | model.neck.PAN8.c5 | ConvBlock | 131 K \n", 604 | "454 | model.neck.PAN8.c5.module | Sequential | 131 K \n", 605 | "455 | model.neck.PAN8.c5.module.0 | Conv2d | 131 K \n", 606 | "456 | model.neck.PAN8.c5.module.1 | BatchNorm2d | 512 \n", 607 | "457 | model.neck.PAN8.c5.module.2 | LeakyReLU | 0 \n", 608 | "458 | model.neck.PAN8.c6 | ConvBlock | 1 M \n", 609 | "459 | model.neck.PAN8.c6.module | Sequential | 1 M \n", 610 | "460 | model.neck.PAN8.c6.module.0 | Conv2d | 1 M \n", 611 | "461 | model.neck.PAN8.c6.module.1 | BatchNorm2d | 1 K \n", 612 | "462 | model.neck.PAN8.c6.module.2 | LeakyReLU | 0 \n", 613 | "463 | model.neck.PAN8.c7 | ConvBlock | 131 K \n", 614 | "464 | model.neck.PAN8.c7.module | Sequential | 131 K \n", 615 | "465 | model.neck.PAN8.c7.module.0 | Conv2d | 131 K \n", 616 | "466 | model.neck.PAN8.c7.module.1 | BatchNorm2d | 512 \n", 617 | "467 | model.neck.PAN8.c7.module.2 | LeakyReLU | 0 \n", 618 | "468 | model.neck.PAN9 | PAN_Layer | 755 K \n", 619 | "469 | model.neck.PAN9.c1 | ConvBlock | 33 K \n", 620 | "470 | model.neck.PAN9.c1.module | Sequential | 33 K \n", 621 | "471 | model.neck.PAN9.c1.module.0 | Conv2d | 32 K \n", 622 | "472 | model.neck.PAN9.c1.module.1 | BatchNorm2d | 256 \n", 623 | "473 | model.neck.PAN9.c1.module.2 | LeakyReLU | 0 \n", 624 | "474 | model.neck.PAN9.u2 | Upsample | 0 \n", 625 | "475 | model.neck.PAN9.c2_from_upsampled | ConvBlock | 33 K \n", 626 | "476 | model.neck.PAN9.c2_from_upsampled.module | Sequential | 33 K \n", 627 | "477 | model.neck.PAN9.c2_from_upsampled.module.0 | Conv2d | 32 K \n", 628 | "478 | model.neck.PAN9.c2_from_upsampled.module.1 | BatchNorm2d | 256 \n", 629 | "479 | model.neck.PAN9.c2_from_upsampled.module.2 | LeakyReLU | 0 \n", 630 | "480 | model.neck.PAN9.c3 | ConvBlock 
| 33 K \n", 631 | "481 | model.neck.PAN9.c3.module | Sequential | 33 K \n", 632 | "482 | model.neck.PAN9.c3.module.0 | Conv2d | 32 K \n", 633 | "483 | model.neck.PAN9.c3.module.1 | BatchNorm2d | 256 \n", 634 | "484 | model.neck.PAN9.c3.module.2 | LeakyReLU | 0 \n", 635 | "485 | model.neck.PAN9.c4 | ConvBlock | 295 K \n", 636 | "486 | model.neck.PAN9.c4.module | Sequential | 295 K \n", 637 | "487 | model.neck.PAN9.c4.module.0 | Conv2d | 294 K \n", 638 | "488 | model.neck.PAN9.c4.module.1 | BatchNorm2d | 512 \n", 639 | "489 | model.neck.PAN9.c4.module.2 | LeakyReLU | 0 \n", 640 | "490 | model.neck.PAN9.c5 | ConvBlock | 33 K \n", 641 | "491 | model.neck.PAN9.c5.module | Sequential | 33 K \n", 642 | "492 | model.neck.PAN9.c5.module.0 | Conv2d | 32 K \n", 643 | "493 | model.neck.PAN9.c5.module.1 | BatchNorm2d | 256 \n", 644 | "494 | model.neck.PAN9.c5.module.2 | LeakyReLU | 0 \n", 645 | "495 | model.neck.PAN9.c6 | ConvBlock | 295 K \n", 646 | "496 | model.neck.PAN9.c6.module | Sequential | 295 K \n", 647 | "497 | model.neck.PAN9.c6.module.0 | Conv2d | 294 K \n", 648 | "498 | model.neck.PAN9.c6.module.1 | BatchNorm2d | 512 \n", 649 | "499 | model.neck.PAN9.c6.module.2 | LeakyReLU | 0 \n", 650 | "500 | model.neck.PAN9.c7 | ConvBlock | 33 K \n", 651 | "501 | model.neck.PAN9.c7.module | Sequential | 33 K \n", 652 | "502 | model.neck.PAN9.c7.module.0 | Conv2d | 32 K \n", 653 | "503 | model.neck.PAN9.c7.module.1 | BatchNorm2d | 256 \n", 654 | "504 | model.neck.PAN9.c7.module.2 | LeakyReLU | 0 \n", 655 | "505 | model.neck.ACFF_0 | ACFF | 4 M \n", 656 | "506 | model.neck.ACFF_0.stride_level_1 | ConvBlock | 1 M \n", 657 | "507 | model.neck.ACFF_0.stride_level_1.module | Sequential | 1 M \n", 658 | "508 | model.neck.ACFF_0.stride_level_1.module.0 | Conv2d | 1 M \n", 659 | "509 | model.neck.ACFF_0.stride_level_1.module.1 | BatchNorm2d | 1 K \n", 660 | "510 | model.neck.ACFF_0.stride_level_1.module.2 | LeakyReLU | 0 \n", 661 | "511 | model.neck.ACFF_0.stride_level_2 | ConvBlock | 590 K \n", 662 | "512 | model.neck.ACFF_0.stride_level_2.module | Sequential | 590 K \n", 663 | "513 | model.neck.ACFF_0.stride_level_2.module.0 | Conv2d | 589 K \n", 664 | "514 | model.neck.ACFF_0.stride_level_2.module.1 | BatchNorm2d | 1 K \n", 665 | "515 | model.neck.ACFF_0.stride_level_2.module.2 | LeakyReLU | 0 \n", 666 | "516 | model.neck.ACFF_0.expand | ConvBlock | 2 M \n", 667 | "517 | model.neck.ACFF_0.expand.module | Sequential | 2 M \n", 668 | "518 | model.neck.ACFF_0.expand.module.0 | Conv2d | 2 M \n", 669 | "519 | model.neck.ACFF_0.expand.module.1 | BatchNorm2d | 1 K \n", 670 | "520 | model.neck.ACFF_0.expand.module.2 | LeakyReLU | 0 \n", 671 | "521 | model.neck.ACFF_1 | ACFF | 1 M \n", 672 | "522 | model.neck.ACFF_1.compress_level_0 | ConvBlock | 131 K \n", 673 | "523 | model.neck.ACFF_1.compress_level_0.module | Sequential | 131 K \n", 674 | "524 | model.neck.ACFF_1.compress_level_0.module.0 | Conv2d | 131 K \n", 675 | "525 | model.neck.ACFF_1.compress_level_0.module.1 | BatchNorm2d | 512 \n", 676 | "526 | model.neck.ACFF_1.compress_level_0.module.2 | LeakyReLU | 0 \n", 677 | "527 | model.neck.ACFF_1.stride_level_2 | ConvBlock | 295 K \n", 678 | "528 | model.neck.ACFF_1.stride_level_2.module | Sequential | 295 K \n", 679 | "529 | model.neck.ACFF_1.stride_level_2.module.0 | Conv2d | 294 K \n", 680 | "530 | model.neck.ACFF_1.stride_level_2.module.1 | BatchNorm2d | 512 \n", 681 | "531 | model.neck.ACFF_1.stride_level_2.module.2 | LeakyReLU | 0 \n", 682 | "532 | model.neck.ACFF_1.expand | ConvBlock | 590 K \n", 683 | 
"533 | model.neck.ACFF_1.expand.module | Sequential | 590 K \n", 684 | "534 | model.neck.ACFF_1.expand.module.0 | Conv2d | 589 K \n", 685 | "535 | model.neck.ACFF_1.expand.module.1 | BatchNorm2d | 512 \n", 686 | "536 | model.neck.ACFF_1.expand.module.2 | LeakyReLU | 0 \n", 687 | "537 | model.neck.ACFF_2 | ACFF | 247 K \n", 688 | "538 | model.neck.ACFF_2.compress_level_0 | ConvBlock | 65 K \n", 689 | "539 | model.neck.ACFF_2.compress_level_0.module | Sequential | 65 K \n", 690 | "540 | model.neck.ACFF_2.compress_level_0.module.0 | Conv2d | 65 K \n", 691 | "541 | model.neck.ACFF_2.compress_level_0.module.1 | BatchNorm2d | 256 \n", 692 | "542 | model.neck.ACFF_2.compress_level_0.module.2 | LeakyReLU | 0 \n", 693 | "543 | model.neck.ACFF_2.compress_level_1 | ConvBlock | 33 K \n", 694 | "544 | model.neck.ACFF_2.compress_level_1.module | Sequential | 33 K \n", 695 | "545 | model.neck.ACFF_2.compress_level_1.module.0 | Conv2d | 32 K \n", 696 | "546 | model.neck.ACFF_2.compress_level_1.module.1 | BatchNorm2d | 256 \n", 697 | "547 | model.neck.ACFF_2.compress_level_1.module.2 | LeakyReLU | 0 \n", 698 | "548 | model.neck.ACFF_2.expand | ConvBlock | 147 K \n", 699 | "549 | model.neck.ACFF_2.expand.module | Sequential | 147 K \n", 700 | "550 | model.neck.ACFF_2.expand.module.0 | Conv2d | 147 K \n", 701 | "551 | model.neck.ACFF_2.expand.module.1 | BatchNorm2d | 256 \n", 702 | "552 | model.neck.ACFF_2.expand.module.2 | LeakyReLU | 0 \n", 703 | "553 | model.head | Head | 21 M \n", 704 | "554 | model.head.ho1 | HeadOutput | 330 K \n", 705 | "555 | model.head.ho1.c1 | ConvBlock | 295 K \n", 706 | "556 | model.head.ho1.c1.module | Sequential | 295 K \n", 707 | "557 | model.head.ho1.c1.module.0 | Conv2d | 294 K \n", 708 | "558 | model.head.ho1.c1.module.1 | BatchNorm2d | 512 \n", 709 | "559 | model.head.ho1.c1.module.2 | LeakyReLU | 0 \n", 710 | "560 | model.head.ho1.c2 | ConvBlock | 34 K \n", 711 | "561 | model.head.ho1.c2.module | Sequential | 34 K \n", 712 | "562 | model.head.ho1.c2.module.0 | Conv2d | 34 K \n", 713 | "563 | model.head.hp2 | HeadPreprocessing | 3 M \n", 714 | "564 | model.head.hp2.c1 | ConvBlock | 295 K \n", 715 | "565 | model.head.hp2.c1.module | Sequential | 295 K \n", 716 | "566 | model.head.hp2.c1.module.0 | Conv2d | 294 K \n", 717 | "567 | model.head.hp2.c1.module.1 | BatchNorm2d | 512 \n", 718 | "568 | model.head.hp2.c1.module.2 | LeakyReLU | 0 \n", 719 | "569 | model.head.hp2.c2 | ConvBlock | 131 K \n", 720 | "570 | model.head.hp2.c2.module | Sequential | 131 K \n", 721 | "571 | model.head.hp2.c2.module.0 | Conv2d | 131 K \n", 722 | "572 | model.head.hp2.c2.module.1 | BatchNorm2d | 512 \n", 723 | "573 | model.head.hp2.c2.module.2 | LeakyReLU | 0 \n", 724 | "574 | model.head.hp2.c3 | ConvBlock | 1 M \n", 725 | "575 | model.head.hp2.c3.module | Sequential | 1 M \n", 726 | "576 | model.head.hp2.c3.module.0 | Conv2d | 1 M \n", 727 | "577 | model.head.hp2.c3.module.1 | BatchNorm2d | 1 K \n", 728 | "578 | model.head.hp2.c3.module.2 | LeakyReLU | 0 \n", 729 | "579 | model.head.hp2.c4 | ConvBlock | 131 K \n", 730 | "580 | model.head.hp2.c4.module | Sequential | 131 K \n", 731 | "581 | model.head.hp2.c4.module.0 | Conv2d | 131 K \n", 732 | "582 | model.head.hp2.c4.module.1 | BatchNorm2d | 512 \n", 733 | "583 | model.head.hp2.c4.module.2 | LeakyReLU | 0 \n", 734 | "584 | model.head.hp2.c5 | ConvBlock | 1 M \n", 735 | "585 | model.head.hp2.c5.module | Sequential | 1 M \n", 736 | "586 | model.head.hp2.c5.module.0 | Conv2d | 1 M \n", 737 | "587 | model.head.hp2.c5.module.1 | BatchNorm2d | 1 K 
\n", 738 | "588 | model.head.hp2.c5.module.2 | LeakyReLU | 0 \n", 739 | "589 | model.head.hp2.c6 | ConvBlock | 131 K \n", 740 | "590 | model.head.hp2.c6.module | Sequential | 131 K \n", 741 | "591 | model.head.hp2.c6.module.0 | Conv2d | 131 K \n", 742 | "592 | model.head.hp2.c6.module.1 | BatchNorm2d | 512 \n", 743 | "593 | model.head.hp2.c6.module.2 | LeakyReLU | 0 \n", 744 | "594 | model.head.ho2 | HeadOutput | 1 M \n", 745 | "595 | model.head.ho2.c1 | ConvBlock | 1 M \n", 746 | "596 | model.head.ho2.c1.module | Sequential | 1 M \n", 747 | "597 | model.head.ho2.c1.module.0 | Conv2d | 1 M \n", 748 | "598 | model.head.ho2.c1.module.1 | BatchNorm2d | 1 K \n", 749 | "599 | model.head.ho2.c1.module.2 | LeakyReLU | 0 \n", 750 | "600 | model.head.ho2.c2 | ConvBlock | 69 K \n", 751 | "601 | model.head.ho2.c2.module | Sequential | 69 K \n", 752 | "602 | model.head.ho2.c2.module.0 | Conv2d | 69 K \n", 753 | "603 | model.head.hp3 | HeadPreprocessing | 12 M \n", 754 | "604 | model.head.hp3.c1 | ConvBlock | 1 M \n", 755 | "605 | model.head.hp3.c1.module | Sequential | 1 M \n", 756 | "606 | model.head.hp3.c1.module.0 | Conv2d | 1 M \n", 757 | "607 | model.head.hp3.c1.module.1 | BatchNorm2d | 1 K \n", 758 | "608 | model.head.hp3.c1.module.2 | LeakyReLU | 0 \n", 759 | "609 | model.head.hp3.c2 | ConvBlock | 525 K \n", 760 | "610 | model.head.hp3.c2.module | Sequential | 525 K \n", 761 | "611 | model.head.hp3.c2.module.0 | Conv2d | 524 K \n", 762 | "612 | model.head.hp3.c2.module.1 | BatchNorm2d | 1 K \n", 763 | "613 | model.head.hp3.c2.module.2 | LeakyReLU | 0 \n", 764 | "614 | model.head.hp3.c3 | ConvBlock | 4 M \n", 765 | "615 | model.head.hp3.c3.module | Sequential | 4 M \n", 766 | "616 | model.head.hp3.c3.module.0 | Conv2d | 4 M \n", 767 | "617 | model.head.hp3.c3.module.1 | BatchNorm2d | 2 K \n", 768 | "618 | model.head.hp3.c3.module.2 | LeakyReLU | 0 \n", 769 | "619 | model.head.hp3.c4 | ConvBlock | 525 K \n", 770 | "620 | model.head.hp3.c4.module | Sequential | 525 K \n", 771 | "621 | model.head.hp3.c4.module.0 | Conv2d | 524 K \n", 772 | "622 | model.head.hp3.c4.module.1 | BatchNorm2d | 1 K \n", 773 | "623 | model.head.hp3.c4.module.2 | LeakyReLU | 0 \n", 774 | "624 | model.head.hp3.c5 | ConvBlock | 4 M \n", 775 | "625 | model.head.hp3.c5.module | Sequential | 4 M \n", 776 | "626 | model.head.hp3.c5.module.0 | Conv2d | 4 M \n", 777 | "627 | model.head.hp3.c5.module.1 | BatchNorm2d | 2 K \n", 778 | "628 | model.head.hp3.c5.module.2 | LeakyReLU | 0 \n", 779 | "629 | model.head.hp3.c6 | ConvBlock | 525 K \n", 780 | "630 | model.head.hp3.c6.module | Sequential | 525 K \n", 781 | "631 | model.head.hp3.c6.module.0 | Conv2d | 524 K \n", 782 | "632 | model.head.hp3.c6.module.1 | BatchNorm2d | 1 K \n", 783 | "633 | model.head.hp3.c6.module.2 | LeakyReLU | 0 \n", 784 | "634 | model.head.ho3 | HeadOutput | 4 M \n", 785 | "635 | model.head.ho3.c1 | ConvBlock | 4 M \n", 786 | "636 | model.head.ho3.c1.module | Sequential | 4 M \n", 787 | "637 | model.head.ho3.c1.module.0 | Conv2d | 4 M \n", 788 | "638 | model.head.ho3.c1.module.1 | BatchNorm2d | 2 K \n", 789 | "639 | model.head.ho3.c1.module.2 | LeakyReLU | 0 \n", 790 | "640 | model.head.ho3.c2 | ConvBlock | 138 K \n", 791 | "641 | model.head.ho3.c2.module | Sequential | 138 K \n", 792 | "642 | model.head.ho3.c2.module.0 | Conv2d | 138 K \n", 793 | "643 | model.yolo1 | YOLOLayer | 0 \n", 794 | "644 | model.yolo2 | YOLOLayer | 0 \n", 795 | "645 | model.yolo3 | YOLOLayer | 0 \n", 796 | "Ranger optimizer loaded. 
\n", 797 | "Gradient Centralization usage = True\n", 798 | "GC applied to both conv and fc layers\n", 799 | "Finding best initial lr: 0%| | 0/100 [00:00\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlr_find\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mm\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmin_lr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1e-10\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mearly_stop_threshold\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 2\u001b[0m \u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 810 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\lr_finder.py\u001b[0m in \u001b[0;36mlr_find\u001b[1;34m(self, model, train_dataloader, val_dataloaders, min_lr, max_lr, num_training, mode, early_stop_threshold, num_accumulation_steps)\u001b[0m\n\u001b[0;32m 168\u001b[0m self.fit(model,\n\u001b[0;32m 169\u001b[0m \u001b[0mtrain_dataloader\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 170\u001b[1;33m val_dataloaders=val_dataloaders)\n\u001b[0m\u001b[0;32m 171\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 172\u001b[0m \u001b[1;31m# Prompt if we stopped early\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 811 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\trainer.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, model, train_dataloader, val_dataloaders)\u001b[0m\n\u001b[0;32m 885\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlr_schedulers\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer_frequencies\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minit_optimizers\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 886\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 887\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_pretrain_routine\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 888\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 889\u001b[0m \u001b[1;31m# return 1 when finished\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 812 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\trainer.py\u001b[0m in \u001b[0;36mrun_pretrain_routine\u001b[1;34m(self, model)\u001b[0m\n\u001b[0;32m 1013\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1014\u001b[0m \u001b[1;31m# CORE TRAINING LOOP\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1015\u001b[1;33m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1016\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1017\u001b[0m def test(\n", 813 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\training_loop.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 345\u001b[0m \u001b[1;31m# RUN TNG EPOCH\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 346\u001b[0m \u001b[1;31m# -----------------\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 347\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_training_epoch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 348\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 349\u001b[0m \u001b[1;31m# update LR schedulers\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 814 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\training_loop.py\u001b[0m in \u001b[0;36mrun_training_epoch\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 417\u001b[0m \u001b[1;31m# RUN TRAIN STEP\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 418\u001b[0m \u001b[1;31m# ---------------\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 419\u001b[1;33m \u001b[0m_outputs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_training_batch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 420\u001b[0m \u001b[0mbatch_result\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_norm_dic\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_step_metrics\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_outputs\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 421\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 815 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\training_loop.py\u001b[0m in \u001b[0;36mrun_training_batch\u001b[1;34m(self, batch, batch_idx)\u001b[0m\n\u001b[0;32m 595\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 596\u001b[0m \u001b[1;31m# calculate loss\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 597\u001b[1;33m \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_output\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moptimizer_closure\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 598\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 599\u001b[0m \u001b[1;31m# check if loss or model weights are nan\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 816 | "\u001b[1;32mD:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\training_loop.py\u001b[0m in \u001b[0;36moptimizer_closure\u001b[1;34m()\u001b[0m\n\u001b[0;32m 559\u001b[0m opt_idx, self.hiddens)\n\u001b[0;32m 560\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 
561 output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)\n 562 \n 563 # format and reduce outputs accordingly\n",
817 | "D:\\Apps\\Anaconda\\lib\\site-packages\\pytorch_lightning\\trainer\\training_loop.py in training_forward(self, batch, batch_idx, opt_idx, hiddens)\n 735 # CPU forward\n 736 else:\n--> 737 output = self.model.training_step(*args)\n 738 \n 739 # allow any mode to define training_step_end\n",
818 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\pl_model.py in training_step(self, batch, batch_idx)\n 86 return self.sat_fgsm_training_step(batch, self.hparams.epsilon)\n 87 else:\n---> 88 return self.basic_training_step(batch)\n 89 \n 90 def training_epoch_end(self, outputs):\n",
819 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\pl_model.py in basic_training_step(self, batch)\n 51 def basic_training_step(self, batch):\n 52 filenames, images, labels = batch\n---> 53 y_hat, loss = self(images, labels)\n 54 logger_logs = {\"training_loss\": loss}\n 55 \n",
820 | "D:\\Apps\\Anaconda\\lib\\site-packages\\torch\\nn\\modules\\module.py in __call__(self, *input, **kwargs)\n 548 result = self._slow_forward(*input, **kwargs)\n 549 else:\n--> 550 result = self.forward(*input, **kwargs)\n 551 for hook in self._forward_hooks.values():\n 552 hook_result = hook(self, input, result)\n",
821 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\pl_model.py in forward(self, x, y)\n 47 \n 48 def forward(self, x, y=None):\n---> 49 return self.model(x, y)\n 50 \n 51 def basic_training_step(self, batch):\n",
822 | "D:\\Apps\\Anaconda\\lib\\site-packages\\torch\\nn\\modules\\module.py in __call__(self, *input, **kwargs)\n 548 result = self._slow_forward(*input, **kwargs)\n 549 else:\n--> 550 result = self.forward(*input, **kwargs)\n 551 for hook in self._forward_hooks.values():\n 552 hook_result = hook(self, input, result)\n",
823 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\model.py in forward(self, x, y)\n 1151 h1, h2, h3 = h\n 1152 \n-> 1153 out1, loss1 = self.yolo1(h1, y)\n 1154 out2, loss2 = self.yolo2(h2, y)\n 1155 out3, loss3 = self.yolo3(h3, y)\n",
824 | "D:\\Apps\\Anaconda\\lib\\site-packages\\torch\\nn\\modules\\module.py in __call__(self, *input, **kwargs)\n 548 result = self._slow_forward(*input, **kwargs)\n 549 else:\n--> 550 result = self.forward(*input, **kwargs)\n 551 for hook in self._forward_hooks.values():\n 552 hook_result = hook(self, input, result)\n",
825 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\model.py in forward(self, x, targets)\n 1052 target=targets,\n 1053 anchors=self.scaled_anchors,\n-> 1054 ignore_thres=self.ignore_thres\n 1055 )\n 1056 \n",
826 | "d:\\Projects\\Yet-Another-YOLOV4-Pytorch\\model.py in build_targets(self, pred_boxes, pred_cls, target, anchors, ignore_thres)\n 843 \n 844 # One-hot encoding of label (WE USE LABEL SMOOTHING)\n--> 845 tcls[b, best_n, gj, gi, target_labels] = 0.9\n 846 \n 847 # Compute label correctness and iou at best anchor\n",
827 | "IndexError: index 72 is out of bounds for dimension 4 with size 40"
828 | ]
829 | }
830 | ],
831 | "source": [
832 | "r = t.lr_find(m, min_lr=1e-10, max_lr=1, 
early_stop_threshold=None)\n", 833 | "r.plot()" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": null, 839 | "metadata": {}, 840 | "outputs": [], 841 | "source": [ 842 | "t.fit(m)" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "execution_count": null, 848 | "metadata": {}, 849 | "outputs": [], 850 | "source": [] 851 | } 852 | ], 853 | "metadata": { 854 | "kernelspec": { 855 | "display_name": "Python 3", 856 | "language": "python", 857 | "name": "python3" 858 | }, 859 | "language_info": { 860 | "codemirror_mode": { 861 | "name": "ipython", 862 | "version": 3 863 | }, 864 | "file_extension": ".py", 865 | "mimetype": "text/x-python", 866 | "name": "python", 867 | "nbconvert_exporter": "python", 868 | "pygments_lexer": "ipython3", 869 | "version": "3.7.7-final" 870 | } 871 | }, 872 | "nbformat": 4, 873 | "nbformat_minor": 4 874 | } --------------------------------------------------------------------------------