├── .deepsource.toml
├── .gitignore
├── README.md
├── model.hdf5
├── reports
│   ├── Experiments _ rotated-object-detection – Weights & Biases.pdf
│   ├── FPN-summary ordered table.png
│   └── FPN_torchsummary.txt
├── requirements.txt
├── setup.py
├── src
│   ├── __init__.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── logging.py
│   │   └── model_checkpoint.py
│   ├── checkpoints
│   │   ├── __init__.py
│   │   ├── model_93_ap.pt
│   │   ├── runX.pt
│   │   └── runX_optimizer.pt
│   ├── dataloader.py
│   ├── dataset.py
│   ├── loss.py
│   ├── main.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── baseline.py
│   │   ├── detect_orn.py
│   │   ├── detector.py
│   │   ├── detector_fpn.py
│   │   └── mish.py
│   ├── train.py
│   └── trainer.py
└── tests
    ├── __init__.py
    ├── test_data.py
    ├── test_metrics.py
    └── test_models.py

--------------------------------------------------------------------------------
/.deepsource.toml:
--------------------------------------------------------------------------------
1 | version = 1
2 | 
3 | [[analyzers]]
4 | name = "python"
5 | enabled = true
6 | 
7 | [analyzers.meta]
8 | runtime_version = "3.x.x"
9 | 
10 | [[analyzers]]
11 | name = "test-coverage"
12 | enabled = true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # dot files
2 | .vscode
3 | 
4 | # cache
5 | __pycache__/
6 | .pytest_cache
7 | 
8 | # packaging
9 | *.egg-info/
10 | 
11 | # logs
12 | wandb/
13 | 
14 | rotated_ship_data.py
15 | 
16 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rotated-Object-Detection
2 | Novel ResNet-inspired Tiny-FPN network (<2M params) for Rotated Object Detection using a 5-parameter Modulated Rotation Loss
3 | 
4 | ### Crux
5 | * **Architecture**: FPN with classification and regression heads, ~1.9M parameters
6 | * **Loss Function**: 5-Parameter Modulated Rotation Loss
7 | * **Activation**: Mish
8 | * **Model Summary** - *reports/FPN_torchsummary.txt* (reports/ also contains an alternative summary with named layers in a table)
9 | * **Training Script** - *src/train.py*
10 | * **Final Model Weights** - *src/checkpoints/model_93_ap.pt*
11 | * **Python Deps. and version** - *requirements.txt*
12 | * **Evaluation** - *src/main.py*
13 | 
14 | 
15 | ### Method
16 | * The reported results use ResNet-inspired building-block modules and an FPN.
17 | * Separate classification and regression subnets (a single FC layer each) are used.
18 | * The feature map from the top of the pyramid, which has the best semantic representation, is used for classification.
19 | * The finer feature map at the bottom of the pyramid, which has the best global representation, is used for regressing the rotated bounding box. Finer details can be found as comments in the code: *src/models/detector_fpn.py*
20 | 
21 | * The whole implementation is from scratch, in PyTorch. Only the method for calculating AP from PR curves is borrowed and referenced (*src/metrics.py/compute_ap*).
22 | 
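23 | For reference, this is the 5-parameter modulated rotation loss as implemented in *src/loss.py* (Eqn. 2 of [arXiv:1911.08299](https://arxiv.org/pdf/1911.08299.pdf)), for a prediction $(x_1, y_1, \theta_1, w_1, h_1)$ and a target $(x_2, y_2, \theta_2, w_2, h_2)$:
24 | 
25 | $$L_{cp} = |x_1 - x_2| + |y_1 - y_2|$$
26 | 
27 | $$L_{mr5p} = \min\left(L_{cp} + |w_1 - w_2| + |h_1 - h_2| + |\theta_1 - \theta_2|,\ L_{cp} + |w_1 - h_2| + |h_1 - w_2| + \left|90 - |\theta_1 - \theta_2|\right|\right)$$
28 | 
29 | The second argument of the min handles the interchangeability of width and height: a box rotated by 90° with swapped w/h describes the same rectangle, so it incurs no loss.
30 | 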
31 | ### Approach
32 | 1. Random data generator that creates images with high noise and rotated objects (shapes) at random scales and orientations. (Private)
33 | 2. Compare reusing generated samples for each epoch vs. generating and loading them online
34 | 3. Implement the modulated rotation loss and other metrics
35 | 4. Experiment with loss functions and activations
36 | 5. Try replacing standard convolutional layers with ORN (Oriented Response Network) layers that use rotated filters to learn orientation (could not be integrated due to technical challenges)
37 | 6. Improve the basic model to use different heads for classification and regression
38 | 7. Try variations that remove the 512-dimensional filters, as they take up the most parameters (~1M)
39 | 8. Add a feature pyramid and experiment with different building blocks and convolutional parameters (kernel size and stride in the first layer play a big role)
40 | 9. Streamline parameters in the building blocks and the prediction heads to stay below 2M
41 | 
42 | * **Please find the rest of the report, with details on experiments and analysis, in** *reports/Experiments _ rotated-object-detection – Weights & Biases.pdf*
43 | 
44 | ### Opportunities to improve
45 | 1. Use the rest of the pyramid layers for prediction (takes more parameters) and add better logic to pick the best detection
46 | 2. Integrate ORN layers into the FPN
47 | 3. Use DenseNets with compact convolution layer configurations
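48 | 
49 | ### Evaluation example
50 | A minimal usage sketch (illustrative, not part of the original scripts) for loading the final weights, assuming the repo root is on PYTHONPATH; equivalent to running `python -m src.main`:
51 | 
52 | ```python
53 | import torch
54 | 
55 | import src.models as models
56 | 
57 | model = models.Detector_FPN()
58 | state = torch.load("src/checkpoints/model_93_ap.pt", map_location="cpu")
59 | model.load_state_dict(state["model_state_dict"])
60 | model.eval()
61 | 
62 | with torch.no_grad():
63 |     # one single-channel 200x200 image -> [p(ship), x, y, yaw, w, h]
64 |     pred = model(torch.rand(1, 1, 200, 200))
65 | print(pred.shape)  # torch.Size([1, 6])
66 | ```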
--------------------------------------------------------------------------------
/model.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/model.hdf5
--------------------------------------------------------------------------------
/reports/Experiments _ rotated-object-detection – Weights & Biases.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/reports/Experiments _ rotated-object-detection – Weights & Biases.pdf
--------------------------------------------------------------------------------
/reports/FPN-summary ordered table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/reports/FPN-summary ordered table.png
--------------------------------------------------------------------------------
/reports/FPN_torchsummary.txt:
--------------------------------------------------------------------------------
1 | 
2 | ----------------------------------------------------------------
3 | Layer (type) Output Shape Param #
4 | ================================================================
5 | Conv2d-1 [-1, 8, 200, 200] 72
6 | BatchNorm2d-2 [-1, 8, 200, 200] 16
7 | MaxPool2d-3 [-1, 8, 100, 100] 0
8 | Conv2d-4 [-1, 16, 100, 100] 1,152
9 | BatchNorm2d-5 [-1, 16, 100, 100] 32
10 | MaxPool2d-6 [-1, 16, 50, 50] 0
11 | Conv2d-7 [-1, 32, 25, 25] 4,608
12 | BatchNorm2d-8 [-1, 32, 25, 25] 64
13 | Mish-9 [-1, 32, 25, 25] 0
14 | Conv2d-10 [-1, 32, 25, 25] 9,216
15 | BatchNorm2d-11 [-1, 32, 25, 25] 64
16 | Conv2d-12 [-1, 32, 25, 25] 512
17 | BatchNorm2d-13 [-1, 32, 25, 25] 64
18 | Mish-14 [-1, 32, 25, 25] 0
19 | ConvBlock-15 [-1, 32, 25, 25] 0
20 | Conv2d-16 [-1, 64, 13, 13] 18,432
21 | BatchNorm2d-17 [-1, 64, 13, 13] 128
22 | Mish-18 [-1, 64, 13, 13] 0
23 | Conv2d-19 [-1, 64, 13, 13] 36,864
24 | BatchNorm2d-20 [-1, 64, 13, 13] 128
25 | Conv2d-21 [-1, 64, 13, 13] 2,048
26 | BatchNorm2d-22 [-1, 64, 13, 13] 128
27 | Mish-23 [-1, 64, 13, 13] 0
28 | ConvBlock-24 [-1, 64, 13, 13] 0
29 | Conv2d-25 [-1, 128, 7, 7] 73,728
30 | BatchNorm2d-26 [-1, 128, 7, 7] 256
31 | Mish-27 [-1, 128, 7, 7] 0
32 | Conv2d-28 [-1, 128, 7, 7] 147,456
33 | BatchNorm2d-29 [-1, 128, 7, 7] 256
34 | Conv2d-30 [-1, 128, 7, 7] 8,192
35 | BatchNorm2d-31 [-1, 128, 7, 7] 256
36 | Mish-32 [-1, 128, 7, 7] 0
37 | ConvBlock-33 [-1, 128, 7, 7] 0
38 | Conv2d-34 [-1, 256, 4, 4] 294,912
39 | BatchNorm2d-35 [-1, 256, 4, 4] 512
40 | Mish-36 [-1, 256, 4, 4] 0
41 | Conv2d-37 [-1, 256, 4, 4] 589,824
42 | BatchNorm2d-38 [-1, 256, 4, 4] 512
43 | Conv2d-39 [-1, 256, 4, 4] 32,768
44 | BatchNorm2d-40 [-1, 256, 4, 4] 512
45 | Mish-41 [-1, 256, 4, 4] 0
46 | ConvBlock-42 [-1, 256, 4, 4] 0
47 | Conv2d-43 [-1, 256, 4, 4] 65,792
48 | Conv2d-44 [-1, 256, 7, 7] 33,024
49 | Conv2d-45 [-1, 256, 13, 13] 16,640
50 | Conv2d-46 [-1, 256, 25, 25] 8,448
51 | Conv2d-47 [-1, 256, 25, 25] 590,080
52 | AdaptiveAvgPool2d-48 [-1, 256, 1, 1] 0
53 | AdaptiveAvgPool2d-49 [-1, 256, 1, 1] 0
54 | Flatten-50 [-1, 256] 0
55 | Linear-51 [-1, 1] 257
56 | Flatten-52 [-1, 256] 0
57 | Linear-53 [-1, 5] 1,285
58 | ================================================================
59 | Total params: 1,938,238
60 | Trainable params: 1,938,238
61 | Non-trainable params: 0
62 | ----------------------------------------------------------------
63 | Input size (MB): 0.15
64 | Forward/backward pass size (MB): 13.97
65 | Params size (MB): 7.39
66 | Estimated Total Size (MB): 21.52
67 | ----------------------------------------------------------------
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.1
2 | numpy==1.18.5
3 | tqdm==4.48.2
4 | Shapely==1.7.0
5 | scikit_image==0.16.2
6 | torch==1.6.0
7 | torchvision==0.7.0
8 | tensorflow==2.3.0
9 | Pillow==7.2.0
10 | pytest==6.0.2
11 | torchsummary==1.5.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | setuptools.setup(
4 |     name="bsridatta",
5 |     version="0.0.1",
6 |     author="Sri Datta Budaraju",
7 |     author_email="b.sridatta@gmail.com",
8 |     packages=setuptools.find_packages(),
9 |     python_requires=">=3.6",
10 | )
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/src/__init__.py
--------------------------------------------------------------------------------
/src/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CallbackList, Callback
2 | from .model_checkpoint import ModelCheckpoint
3 | from .logging import Logging
4 | 
5 | 
6 | __all__ = [
7 |     "CallbackList",
8 |     "ModelCheckpoint",
9 |     "Logging",
10 | ]
--------------------------------------------------------------------------------
/src/callbacks/base.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | Callback inspirations from PyTorch Lightning - https://github.com/PyTorchLightning/PyTorch-Lightning
4 | and https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb
5 | """
6 | 
7 | import abc
8 | 
9 | 
10 | class Callback(abc.ABC):
11 |     def setup(self, **kwargs):
12 |         """Called before the training procedure"""
13 |         pass
14 | 
15 |     def teardown(self, **kwargs):
16 |         """Called after training procedure"""
17 |         pass
18 | 
19 |     def on_epoch_start(self, **kwargs):
20 |         """Called when epoch begins"""
21 |         pass
22 | 
23 |     def on_epoch_end(self, **kwargs):
24 |         """Called when epoch terminates"""
25 |         pass
26 | 
27 |     def on_train_batch_start(self, **kwargs):
28 | 
"""Called when training step begins""" 29 | pass 30 | 31 | def on_train_batch_end(self, **kwargs): 32 | """Called when training step ends""" 33 | pass 34 | 35 | def on_validation_batch_start(self, **kwargs): 36 | """Called when validation step begins""" 37 | pass 38 | 39 | def on_validation_batch_end(self, **kwargs): 40 | """Called when validation step ends""" 41 | pass 42 | 43 | def on_test_batch_start(self, **kwargs): 44 | """Called when test batch begins""" 45 | pass 46 | 47 | def on_test_batch_end(self, **kwargs): 48 | """Called when test batch ends""" 49 | pass 50 | 51 | def on_train_start(self, **kwargs): 52 | """Called when training loop begins""" 53 | pass 54 | 55 | def on_train_end(self, **kwargs): 56 | """Called when training loop ends""" 57 | pass 58 | 59 | def on_validation_start(self, **kwargs): 60 | """Called when validation loop begins""" 61 | pass 62 | 63 | def on_validation_end(self, **kwargs): 64 | """Called when validation loop ends""" 65 | pass 66 | 67 | def on_test_start(self, **kwargs): 68 | """Called when test loop begins""" 69 | pass 70 | 71 | def on_test_end(self, **kwargs): 72 | """Called when test loop ends""" 73 | pass 74 | 75 | 76 | class CallbackList(Callback): 77 | def __init__(self, callbacks): 78 | self.callbacks = callbacks 79 | 80 | def setup(self, **kwargs): 81 | """Called before the training procedure""" 82 | for callback in self.callbacks: 83 | callback.setup(**kwargs) 84 | 85 | def teardown(self, **kwargs): 86 | """Called after training procedure""" 87 | for callback in self.callbacks: 88 | callback.teardown(**kwargs) 89 | 90 | def on_epoch_start(self, **kwargs): 91 | """Called when epoch begins""" 92 | for callback in self.callbacks: 93 | callback.on_epoch_start(**kwargs) 94 | 95 | def on_epoch_end(self, **kwargs): 96 | """Called when epoch terminates""" 97 | for callback in self.callbacks: 98 | callback.on_epoch_end(**kwargs) 99 | 100 | def on_train_batch_start(self, **kwargs): 101 | """Called when training step begins""" 102 | for callback in self.callbacks: 103 | callback.on_train_batch_start(**kwargs) 104 | 105 | def on_train_batch_end(self, **kwargs): 106 | """Called when training step ends""" 107 | for callback in self.callbacks: 108 | callback.on_train_batch_end(**kwargs) 109 | 110 | def on_validation_batch_start(self, **kwargs): 111 | """Called when validation step begins""" 112 | for callback in self.callbacks: 113 | callback.on_validation_batch_start(**kwargs) 114 | 115 | def on_validation_batch_end(self, **kwargs): 116 | """Called when validation step ends""" 117 | for callback in self.callbacks: 118 | callback.on_validation_batch_end(**kwargs) 119 | 120 | def on_test_batch_start(self, **kwargs): 121 | """Called when test batch begins""" 122 | for callback in self.callbacks: 123 | callback.on_test_batch_start(**kwargs) 124 | 125 | def on_test_batch_end(self, **kwargs): 126 | """Called when test batch ends""" 127 | for callback in self.callbacks: 128 | callback.on_test_batch_end(**kwargs) 129 | 130 | def on_train_start(self, **kwargs): 131 | """Called when training loop begins""" 132 | for callback in self.callbacks: 133 | callback.on_train_start(**kwargs) 134 | 135 | def on_train_end(self, **kwargs): 136 | """Called when training loop ends""" 137 | for callback in self.callbacks: 138 | callback.on_train_end(**kwargs) 139 | 140 | def on_validation_start(self, **kwargs): 141 | """Called when validation loop begins""" 142 | for callback in self.callbacks: 143 | callback.on_validation_start(**kwargs) 144 | 145 | def on_validation_end(self, 
**kwargs):
146 |         """Called when validation loop ends"""
147 |         for callback in self.callbacks:
148 |             callback.on_validation_end(**kwargs)
149 | 
150 |     def on_test_start(self, **kwargs):
151 |         """Called when test loop begins"""
152 |         for callback in self.callbacks:
153 |             callback.on_test_start(**kwargs)
154 | 
155 |     def on_test_end(self, **kwargs):
156 |         """Called when test loop ends"""
157 |         for callback in self.callbacks:
158 |             callback.on_test_end(**kwargs)
159 | 
--------------------------------------------------------------------------------
/src/callbacks/logging.py:
--------------------------------------------------------------------------------
1 | from .base import Callback
2 | import torch
3 | 
4 | 
5 | class Logging(Callback):
6 |     """Logging and printing metrics"""
7 | 
8 |     def setup(self, opt, model, **kwargs):
9 |         print(f"[INFO]: Start training procedure using device: {opt.device}")
10 |         # log gradients and parameters of the model during training
11 |         if opt.use_wandb:
12 |             opt.logger.watch(model, log="all")
13 | 
14 |     def on_train_batch_end(
15 |         self, opt, batch_idx, batch, dataloader, output, l_ship, l_bbox, **kwargs
16 |     ):
17 |         batch_len = len(batch["input"])
18 |         dataset_len = len(dataloader.dataset)
19 |         n_batches = len(dataloader)
20 | 
21 |         # print to console
22 |         print(
23 |             "Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}".format(
24 |                 opt.epoch,
25 |                 batch_idx * batch_len,
26 |                 dataset_len,
27 |                 100.0 * batch_idx / n_batches,
28 |                 output,
29 |                 l_ship,
30 |                 l_bbox,
31 |             ),
32 |             end="\n",
33 |         )
34 | 
35 |         # log to wandb
36 |         if opt.use_wandb:
37 |             opt.logger.log({"train_loss": output, "l_ship": l_ship, "l_bbox": l_bbox})
38 | 
39 |     def on_validation_end(self, opt, output, metrics, l_ship, l_bbox, **kwargs):
40 |         # print and log metrics and loss after validation epoch
41 |         print(
42 |             "Validation - Loss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}".format(
43 |                 output, l_ship, l_bbox
44 |             ),
45 |             end="\t",
46 |         )
47 | 
48 |         for k in metrics.keys():
49 |             print(f"{k}: {metrics[k]}", end="\t")
50 |         if opt.use_wandb:
51 |             opt.logger.log(metrics, commit=False)
52 |             opt.logger.log(
53 |                 {
54 |                     "val_loss": output,
55 |                     "epoch": opt.epoch,
56 |                     "val_l_ship": l_ship,
57 |                     "val_l_bbox": l_bbox,
58 |                 }
59 |             )
60 |         print("")
61 | 
62 |     def on_epoch_end(self, opt, optimizer, **kwargs):
63 |         lr = optimizer.param_groups[0]["lr"]
64 |         if opt.use_wandb:
65 |             opt.logger.log({"LR": lr})
66 |         print("lr @ ", lr)
--------------------------------------------------------------------------------
/src/callbacks/model_checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | 
5 | from callbacks.base import Callback
6 | 
7 | 
8 | class ModelCheckpoint(Callback):
9 |     def __init__(self):
10 |         self.val_loss_min = float("inf")
11 | 
12 |     def setup(self, opt, model, optimizer, **kwargs):
13 |         # Save model code to wandb; models/ lives one level up, in src/
14 |         if opt.use_wandb:
15 |             opt.logger.save(f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}/models/*")
16 | 
17 |         # Resume training
18 |         if opt.resume_run != "None":
19 |             state = torch.load(
20 |                 f"{opt.save_dir}/{opt.resume_run}.pt", map_location=opt.device
21 |             )
22 |             print(
23 |                 f'[INFO] Loaded Checkpoint {opt.resume_run}: @ epoch {state["epoch"]}'
24 |             )
25 |             model.load_state_dict(state["model_state_dict"])
26 | 
27 |             # Optimizers
28 |             optimizer_state_dict = torch.load(
29 |                 f"{opt.save_dir}/{opt.resume_run}_optimizer.pt", map_location=opt.device
30 |             )
31 |             
optimizer.load_state_dict(optimizer_state_dict)
32 | 
33 |     def on_epoch_end(self, opt, val_loss, model, optimizer, epoch, **kwargs):
34 |         # track val loss and save model when it decreases
35 |         if val_loss < self.val_loss_min and opt.device != "cpu":
36 |             self.val_loss_min = val_loss
37 | 
38 |             try:
39 |                 state_dict = model.module.state_dict()
40 |             except AttributeError:
41 |                 state_dict = model.state_dict()
42 | 
43 |             state = {
44 |                 "epoch": epoch,
45 |                 "val_loss": val_loss,
46 |                 "model_state_dict": state_dict,
47 |             }
48 | 
49 |             # model
50 |             torch.save(state, f"{opt.save_dir}/{opt.run_name}.pt")
51 |             if opt.use_wandb:
52 |                 opt.logger.save(f"{opt.save_dir}/{opt.run_name}.pt")
53 |             print(f"[INFO] Saved pt: {opt.save_dir}/{opt.run_name}.pt")
54 | 
55 |             del state
56 | 
57 |             # Optimizer
58 |             torch.save(
59 |                 optimizer.state_dict(), f"{opt.save_dir}/{opt.run_name}_optimizer.pt"
60 |             )
61 |             if opt.use_wandb:
62 |                 opt.logger.save(f"{opt.save_dir}/{opt.run_name}_optimizer.pt")
63 |             print(f"[INFO] Saved pt: {opt.save_dir}/{opt.run_name}_optimizer.pt")
--------------------------------------------------------------------------------
/src/checkpoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/src/checkpoints/__init__.py
--------------------------------------------------------------------------------
/src/checkpoints/model_93_ap.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/src/checkpoints/model_93_ap.pt
--------------------------------------------------------------------------------
/src/checkpoints/runX.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/src/checkpoints/runX.pt
--------------------------------------------------------------------------------
/src/checkpoints/runX_optimizer.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/src/checkpoints/runX_optimizer.pt
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | 
3 | from torch.utils.data import DataLoader
4 | 
5 | from src.dataset import Ships
6 | 
7 | 
8 | def train_dataloader(opt: Namespace) -> DataLoader:
9 |     print("[INFO]: Train dataloader called")
10 |     dataset = Ships(n_samples=opt.train_len)
11 |     sampler = None
12 |     shuffle = True
13 |     loader = DataLoader(
14 |         dataset=dataset,
15 |         batch_size=opt.batch_size,
16 |         num_workers=opt.num_workers,
17 |         pin_memory=opt.pin_memory,
18 |         sampler=sampler,
19 |         shuffle=shuffle,
20 |     )
21 |     print("samples - ", len(dataset))
22 |     return loader
23 | 
24 | 
25 | def val_dataloader(opt: Namespace) -> DataLoader:
26 |     print("[INFO]: Validation dataloader called")
27 |     dataset = Ships(n_samples=opt.val_len)
28 |     sampler = None
29 |     shuffle = True
30 |     loader = DataLoader(
31 |         dataset=dataset,
32 |         batch_size=opt.batch_size,
33 |         num_workers=opt.num_workers,
34 |         pin_memory=opt.pin_memory,
35 |         sampler=sampler,
36 |         shuffle=shuffle,
37 |     )
38 |     print("samples 
- ", len(dataset)) 39 | return loader 40 | 41 | 42 | def test_dataloader(opt: Namespace) -> DataLoader: 43 | print("[INFO]: Test dataloader called") 44 | dataset = Ships(n_samples=opt.test_len) 45 | sampler = None 46 | shuffle = True 47 | loader = DataLoader( 48 | dataset=dataset, 49 | batch_size=opt.batch_size, 50 | num_workers=opt.num_workers, 51 | pin_memory=opt.pin_memory, 52 | sampler=sampler, 53 | shuffle=shuffle, 54 | ) 55 | print("samples - ", len(dataset)) 56 | return loader 57 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import torch 5 | from torch.functional import Tensor 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | 9 | from src.rotated_ship_data import make_data 10 | 11 | 12 | class Ships(Dataset): 13 | """ship datasets with has ship labels 14 | Keyword Arguments: 15 | n_samples {int} -- items in dataset, here, items per epoch (default: {1000}) 16 | pre_load {bool} -- to make all items at once and query for each step (default: {False}) 17 | 18 | Returns: 19 | sample {Tenosr} -- p_ship, x, y, yaw, h, w 20 | """ 21 | 22 | def __init__(self, n_samples: int = 1000, pre_load: bool = False): 23 | self.n_samples = n_samples 24 | self.pre_load = pre_load 25 | if pre_load: 26 | images, labels = make_batch(n_samples) 27 | # row, col -> n_channel,row,col 28 | inp = torch.tensor(images, dtype=torch.float32) 29 | self.inps = inp[:, None, :, :] 30 | 31 | # x,y,yaw,h,w -> p(ship),x,y,yaw,h,w 32 | target = torch.tensor(labels, dtype=torch.float32) 33 | has_ship = (~torch.isnan(target[:, 0])).float().reshape(-1, 1) 34 | self.targets = torch.cat((has_ship, target), dim=1) 35 | 36 | def __len__(self): 37 | return self.n_samples 38 | 39 | def __getitem__(self, idx: int) -> Dict[str, Tensor]: 40 | if self.pre_load: 41 | inp = self.inps[idx] 42 | target = self.targets[idx] 43 | else: 44 | image, label = make_data() 45 | 46 | # row, col -> n_channel,row,col 47 | inp = torch.tensor(image, dtype=torch.float32) 48 | inp = inp[None, :, :] 49 | 50 | # x,y,yaw,h,w -> p(ship),x,y,yaw,h,w 51 | target = torch.tensor(label, dtype=torch.float32) 52 | has_ship = (~torch.isnan(target[0])).float().reshape(1) 53 | target = torch.cat((has_ship, target), dim=0) 54 | 55 | sample = {"input": inp, "target": target} 56 | 57 | return sample 58 | 59 | 60 | # Used for simple experiment 61 | 62 | 63 | def make_batch(batch_size: int): 64 | """Used only when pre_load = True 65 | 66 | Arguments: 67 | batch_size {int} -- number of samples to generate 68 | 69 | Returns: 70 | images, labels -- images with/without ship, label with has_ship 71 | """ 72 | imgs, labels = zip(*[make_data() for _ in tqdm(range(batch_size))]) 73 | imgs = np.stack(imgs) 74 | labels = np.stack(labels) 75 | return imgs, labels 76 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def compute_loss(pred: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor]: 8 | """Compute loss handling no ships 9 | 10 | Arguments: 11 | pred {Tensor Batch} -- p(ship), x, y, yaw, w, h 12 | target {Tensor Batch} -- p(ship), x, y, yaw, w, h 13 | 14 | Returns: 15 | loss -- list of all - not averaged 16 | """ 17 | assert pred.shape[-1] == 6 18 | 
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | from torch import Tensor
5 | 
6 | 
7 | def compute_loss(pred: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
8 |     """Compute loss, handling samples with no ship
9 | 
10 |     Arguments:
11 |         pred {Tensor Batch} -- p(ship), x, y, yaw, w, h
12 |         target {Tensor Batch} -- p(ship), x, y, yaw, w, h
13 | 
14 |     Returns:
15 |         loss -- per-sample losses, not averaged
16 |     """
17 |     assert pred.shape[-1] == 6
18 |     assert target.shape[-1] == 6
19 | 
20 |     # instances with no ship have no bbox to predict
21 |     idx_no_ship = torch.nonzero(target[:, 0] == 0, as_tuple=True)
22 | 
23 |     l_bbox = lmr5p(pred[:, 1:], target[:, 1:])
24 |     l_bbox[idx_no_ship] = 0
25 | 
26 |     l_ship = torch.nn.functional.binary_cross_entropy_with_logits(
27 |         pred[:, 0], target[:, 0], reduction="none"
28 |     )
29 | 
30 |     loss = l_ship + l_bbox
31 | 
32 |     return loss, l_ship, l_bbox
33 | 
34 | 
35 | def lmr5p(pred: Tensor, target: Tensor) -> Tensor:
36 |     """5-parameter modulated rotation loss
37 | 
38 |     Arguments:
39 |         pred {Tensor Batch} -- x, y, yaw, w, h
40 |         target {Tensor Batch} -- x, y, yaw, w, h
41 | 
42 |     * X and Y position (centre of the bounding box)
43 |     * Yaw (direction of heading)
44 |     * Width (size tangential to the direction of yaw)
45 |     * Height (size along the direction of yaw)
46 | 
47 |     Returns:
48 |         loss for each pred, target pair without sum
49 | 
50 |     Reference: Eqn(2) https://arxiv.org/pdf/1911.08299.pdf
51 |     """
52 |     assert pred.shape[-1] == 5
53 |     assert target.shape[-1] == 5
54 | 
55 |     x1, x2 = pred[:, 0], target[:, 0]
56 |     y1, y2 = pred[:, 1], target[:, 1]
57 |     yaw1, yaw2 = pred[:, 2], target[:, 2]
58 |     w1, w2 = pred[:, 3], target[:, 3]
59 |     h1, h2 = pred[:, 4], target[:, 4]
60 | 
61 |     # center point loss
62 |     lcp = torch.abs(x1 - x2) + torch.abs(y1 - y2)
63 | 
64 |     lmr5p_ = torch.min(
65 |         lcp + torch.abs(w1 - w2) + torch.abs(h1 - h2) + torch.abs(yaw1 - yaw2),
66 |         lcp
67 |         + torch.abs(w1 - h2)
68 |         + torch.abs(h1 - w2)
69 |         + torch.abs(90 - torch.abs(yaw1 - yaw2)),
70 |     )
71 | 
72 |     return lmr5p_
73 | 
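74 | 
75 | 
76 | # Hedged sanity check, not part of the original file: the modulated loss
77 | # treats a box and its 90-degree-rotated, width/height-swapped twin as identical.
78 | if __name__ == "__main__":
79 |     a = torch.tensor([[10.0, 20.0, 30.0, 4.0, 2.0]])  # x, y, yaw, w, h
80 |     b = torch.tensor([[10.0, 20.0, 120.0, 2.0, 4.0]])  # same box, w/h swapped
81 |     print(lmr5p(a, b))  # tensor([0.])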
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | import src.dataloader as loader
8 | import src.models as models
9 | from src.metrics import compute_metrics
10 | 
11 | 
12 | 
13 | def main():
14 |     # Experiment configuration, opt, is distributed to all the other modules
15 |     opt = _do_setup()
16 | 
17 |     test_loader = loader.test_dataloader(opt)
18 | 
19 |     model = models.Detector_FPN()
20 |     model.to(opt.device)
21 |     state = torch.load(
22 |         f"{os.path.dirname(os.path.abspath(__file__))}/checkpoints/model_93_ap.pt",
23 |         map_location=opt.device,
24 |     )
25 |     model.load_state_dict(state["model_state_dict"])
26 | 
27 |     # snippet from src/trainer.py/validation_epoch()
28 |     model.eval()
29 | 
30 |     ap = []
31 |     with torch.no_grad():
32 |         for batch_idx, batch in enumerate(test_loader):
33 |             for key in batch.keys():
34 |                 batch[key] = batch[key].to(opt.device)
35 | 
36 |             # validation step
37 |             input, target = batch["input"], batch["target"]
38 |             output = model(input)
39 | 
40 |             _prec, _rec, _f1, _ap, _iou = compute_metrics(output, target)
41 |             ap.append(_ap)
42 | 
43 |     avg_ap = sum(ap) / len(ap)
44 | 
45 |     print(f"\n AP on {len(test_loader.dataset)} samples: {avg_ap}")
46 | 
47 | 
48 | def _do_setup():
49 |     parser = _get_argparser()
50 |     opt = parser.parse_args()
51 | 
52 |     # fix seed for reproducibility
53 |     torch.manual_seed(opt.seed)
54 |     np.random.seed(opt.seed)
55 | 
56 |     # GPU setup
57 |     use_cuda = opt.cuda and torch.cuda.is_available()
58 |     device = torch.device("cuda" if use_cuda else "cpu")
59 |     opt.device = device  # Adding device to opt, not already in argparse
60 |     opt.num_workers = 4  # to tune per device
61 |     return opt
62 | 
63 | 
64 | def _get_argparser():
65 | 
66 |     parser = ArgumentParser()
67 |     # training specific
68 |     # fmt: off
69 |     parser.add_argument("--batch_size", default=256, type=int,
70 |                         help="number of samples per step, have more than one for batch norm")
71 |     parser.add_argument("--resume_run", default="None", type=str,
72 |                         help="auto load ckpt")
73 |     # data
74 |     parser.add_argument("--test_len", default=8000, type=int,
75 |                         help="number of samples for testing")
76 |     # device
77 |     parser.add_argument("--cuda", default=True, type=lambda x: (str(x).lower() == "true"),
78 |                         help="enable cuda if available")
79 |     parser.add_argument("--pin_memory", default=False, type=lambda x: (str(x).lower() == "true"),
80 |                         help="pin memory to device")
81 |     parser.add_argument("--seed", default=400, type=int,
82 |                         help="random seed")
83 |     # fmt: on
84 |     return parser
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     main()
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | 
3 | import numpy as np
4 | import torch
5 | from shapely.geometry import Polygon
6 | from torch.tensor import Tensor
7 | 
8 | from src.rotated_ship_data import _make_box_pts
9 | 
10 | 
11 | def compute_metrics(
12 |     pred: torch.Tensor,
13 |     target: torch.Tensor,
14 |     iou_threshold: float = 0.7,
15 |     pr_score: float = 0.5,
16 | ) -> Tuple[Tensor, Tensor, Tensor, float, Tensor]:
17 |     """Compute IOU, AP, Precision, Recall, F1 score
18 | 
19 |     Arguments:
20 |         pred {Tensor Batch} -- b, [p(ship), x, y, yaw, w, h]
21 |         target {Tensor Batch} -- b, [p(ship), x, y, yaw, w, h]
22 | 
23 |     Keyword Arguments:
24 |         iou_threshold {float} -- predicted bbox is correct if IOU > this value (default: {0.7})
25 |         pr_score {float} -- object conf. 
threshold to sample precision and recall (default: {0.5})
26 | 
27 |     Returns:
28 |         precision, recall, F1 @ pr_score, AP @ iou_threshold and mean IOU
29 | 
30 |     Reference:
31 |         https://github.com/ultralytics/yolov3/blob/e0a5a6b411cca45f0d64aa932abffbf3c99b92b3/test.py
32 |     """
33 |     pred = pred.detach()
34 |     target = target.detach()
35 |     conf = pred[:, 0]
36 | 
37 |     # TPs based on IOU threshold
38 |     ious = torch.zeros((pred.shape[0], 1), dtype=float, device=pred.device)
39 |     tp = torch.zeros((pred.shape[0], 1), dtype=bool, device=pred.device)
40 | 
41 |     # get all IOUs if there is a bbox in the target
42 |     for i in torch.nonzero(target[:, 0], as_tuple=False):
43 |         # enable conversion to numpy in _make_box_pts
44 |         t = Polygon(_make_box_pts(*target[i, 1:].flatten().cpu()))
45 |         p = Polygon(_make_box_pts(*pred[i, 1:].flatten().cpu()))
46 |         iou = t.intersection(p).area / t.union(p).area
47 |         ious[i] = iou
48 | 
49 |     # is TP if IOU > threshold
50 |     tp[ious > iou_threshold] = True
51 | 
52 |     mean_iou = torch.mean(ious)
53 | 
54 |     # Calculate Precision, Recall, F1, AP
55 |     # sort by conf
56 |     sorted_idx = torch.argsort(conf, dim=0, descending=True)
57 |     tp, conf = tp[sorted_idx], conf[sorted_idx]
58 | 
59 |     tp = tp * 1.0  # boolean to float
60 |     # TP, FP cumulative
61 |     tpc = torch.cumsum(tp, dim=0)
62 |     fpc = torch.cumsum(1 - tp, dim=0)
63 | 
64 |     # TP + FN = N(Target=1), constant
65 |     eps = 1e-20
66 |     sum_tp_fn = (target[:, 0]).sum() + eps
67 |     prec = tpc / (tpc + fpc)
68 |     rec = tpc / sum_tp_fn
69 | 
70 |     # One P, R at the conf threshold
71 |     # -1 as conf decreases along x
72 |     p = torch.tensor(np.interp(-pr_score, -conf.cpu(), prec[:, 0].cpu()))
73 |     r = torch.tensor(np.interp(-pr_score, -conf.cpu(), rec[:, 0].cpu()))
74 | 
75 |     ap = compute_ap(list(rec), list(prec))
76 | 
77 |     f1 = 2 * p * r / (p + r + eps)
78 | 
79 |     return p, r, f1, ap, mean_iou
80 | 
81 | 
82 | def compute_ap(recall: List[Tensor], precision: List[Tensor]) -> float:
83 |     """Compute the average precision, given the recall and precision curves.
84 | 
85 |     Code Source:
86 |         unmodified - https://github.com/rbgirshick/py-faster-rcnn.
87 | 
88 |     Reference:
89 |         https://github.com/ultralytics/yolov3/blob/e0a5a6b411cca45f0d64aa932abffbf3c99b92b3/test.py
90 | 
91 |     # Arguments
92 |         recall: The recall curve (list).
93 |         precision: The precision curve (list).
94 |     # Returns
95 |         The average precision as computed in py-faster-rcnn. 
96 | """ 97 | 98 | # Append sentinel values to beginning and end 99 | mrec = np.concatenate(([0.0], recall, [min(recall[-1] + 1e-3, 1.0)])).astype( 100 | "float" 101 | ) 102 | mpre = np.concatenate(([0.0], precision, [0.0])) 103 | 104 | # Compute the precision envelope 105 | mpre = np.flip(np.maximum.accumulate(np.flip(mpre))).astype("float") 106 | 107 | # Integrate area under curve 108 | method = "interp" # methods: 'continuous', 'interp' 109 | if method == "interp": 110 | x = np.linspace(0, 1, 101) # 101-point interp (COCO) 111 | ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate 112 | else: # 'continuous' 113 | # points where x axis (recall) changes 114 | i = np.where(mrec[1:] != mrec[:-1])[0] 115 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve 116 | 117 | return ap 118 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import Baseline 2 | from .detector import Detector 3 | from .detector_fpn import Detector_FPN 4 | 5 | # from .detector_orn import Detector_ORN # can not run on cpu 6 | __all__ = ["Baseline", "Detector", "Detector_FPN"] # , 'Detector_ORN'] 7 | -------------------------------------------------------------------------------- /src/models/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Baseline(nn.Module): 6 | """ 7 | PyTorch equivalent of the given baseline model 8 | """ 9 | 10 | def __init__(self): 11 | super(Baseline, self).__init__() 12 | self.image_size = 200 13 | self.n_filters = [x * 8 for x in [1, 2, 4, 8, 16, 32, 64]] 14 | self.features = self._build_features(self.n_filters) 15 | self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.n_filters[-1], 5)) 16 | 17 | def _build_features(self, n_filter): 18 | """Generate feature/backbone network 19 | 20 | Arguments: 21 | n_filter {list} -- number of filter for each conv block 22 | 23 | Returns: 24 | feature extraction module 25 | """ 26 | layers = nn.ModuleList() 27 | 28 | i_channels = 1 29 | for i in n_filter: 30 | o_channels = i 31 | layers.append( 32 | nn.Conv2d( 33 | i_channels, 34 | o_channels, 35 | kernel_size=3, 36 | stride=1, 37 | padding=1, 38 | bias=False, 39 | ) 40 | ) 41 | layers.append(nn.BatchNorm2d(num_features=o_channels)) 42 | layers.append(nn.ReLU()) 43 | layers.append(nn.MaxPool2d(2)) 44 | i_channels = o_channels 45 | 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | x = self.features(x) 50 | x = self.classifier(x) 51 | return x 52 | 53 | 54 | # Run file to see summary 55 | if __name__ == "__main__": 56 | from torchsummary import summary 57 | 58 | inp = torch.rand((2, 1, 200, 200)) 59 | net = Baseline() 60 | out = net(inp) 61 | 62 | summary(net, inp.shape[1:]) 63 | print(out.shape) 64 | -------------------------------------------------------------------------------- /src/models/detect_orn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.CUDA.ORN.orn.functions import oraligned1d 4 | from src.CUDA.ORN.orn.modules import ORConv2d 5 | from src.models.mish import Mish 6 | 7 | 8 | class Detector_ORN(nn.Module): 9 | """Incomplete implementation of Oriented Response Networks, runs only on gpu 10 | Implicitly learns orientation of objects using ARF(Active Rotation Filters) 11 | Advatages - better IOU, fewer 
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .baseline import Baseline
2 | from .detector import Detector
3 | from .detector_fpn import Detector_FPN
4 | 
5 | # from .detector_orn import Detector_ORN  # cannot run on CPU
6 | __all__ = ["Baseline", "Detector", "Detector_FPN"]  # , 'Detector_ORN']
--------------------------------------------------------------------------------
/src/models/baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class Baseline(nn.Module):
6 |     """
7 |     PyTorch equivalent of the given baseline model
8 |     """
9 | 
10 |     def __init__(self):
11 |         super(Baseline, self).__init__()
12 |         self.image_size = 200
13 |         self.n_filters = [x * 8 for x in [1, 2, 4, 8, 16, 32, 64]]
14 |         self.features = self._build_features(self.n_filters)
15 |         self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.n_filters[-1], 5))
16 | 
17 |     def _build_features(self, n_filter):
18 |         """Generate the feature/backbone network
19 | 
20 |         Arguments:
21 |             n_filter {list} -- number of filters for each conv block
22 | 
23 |         Returns:
24 |             feature extraction module
25 |         """
26 |         layers = nn.ModuleList()
27 | 
28 |         i_channels = 1
29 |         for i in n_filter:
30 |             o_channels = i
31 |             layers.append(
32 |                 nn.Conv2d(
33 |                     i_channels,
34 |                     o_channels,
35 |                     kernel_size=3,
36 |                     stride=1,
37 |                     padding=1,
38 |                     bias=False,
39 |                 )
40 |             )
41 |             layers.append(nn.BatchNorm2d(num_features=o_channels))
42 |             layers.append(nn.ReLU())
43 |             layers.append(nn.MaxPool2d(2))
44 |             i_channels = o_channels
45 | 
46 |         return nn.Sequential(*layers)
47 | 
48 |     def forward(self, x):
49 |         x = self.features(x)
50 |         x = self.classifier(x)
51 |         return x
52 | 
53 | 
54 | # Run file to see summary
55 | if __name__ == "__main__":
56 |     from torchsummary import summary
57 | 
58 |     inp = torch.rand((2, 1, 200, 200))
59 |     net = Baseline()
60 |     out = net(inp)
61 | 
62 |     summary(net, inp.shape[1:])
63 |     print(out.shape)
--------------------------------------------------------------------------------
/src/models/detect_orn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from src.CUDA.ORN.orn.functions import oraligned1d
4 | from src.CUDA.ORN.orn.modules import ORConv2d
5 | from src.models.mish import Mish
6 | 
7 | 
8 | class Detector_ORN(nn.Module):
9 |     """Incomplete implementation of Oriented Response Networks; runs only on GPU
10 |     Implicitly learns the orientation of objects using ARF (Active Rotation Filters)
11 |     Advantages - better IOU, fewer parameters, faster convergence; should be ideal for the task
12 |     ORN paper - https://arxiv.org/pdf/1701.01833.pdf
13 |     """
14 | 
15 |     def __init__(self):
16 |         super(Detector_ORN, self).__init__()
17 |         self.image_size = 200
18 | 
19 |         # Note: filters are not multiplied by 8
20 |         self.n_filters = [x for x in [1, 2, 4, 8, 16, 32, 64]]
21 | 
22 |         # self.activ = nn.ReLU()
23 |         self.activ = Mish()
24 |         self.n_orientation = 8
25 | 
26 |         self.features = self._build_features(
27 |             self.n_filters, self.activ, self.n_orientation
28 |         )
29 | 
30 |         self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.n_filters[-1], 1))
31 | 
32 |         self.regressor = nn.Sequential(
33 |             nn.Flatten(),
34 |             nn.Linear(self.n_filters[-1], self.n_filters[-1]),
35 |             self.activ,
36 |             nn.Dropout(),
37 |             nn.Linear(self.n_filters[-1], 5),
38 |         )
39 | 
40 |     def _build_features(self, n_filter, activ, n_orientation):
41 |         """Generate the feature/backbone network
42 | 
43 |         Arguments:
44 |             n_filter {list} -- number of filters for each conv block
45 |             activ {nn.Module} -- activation function to be used
46 |             n_orientations {int} -- orientations for ARF
47 | 
48 |         Returns:
49 |             feature extraction module
50 |         """
51 |         layers = nn.ModuleList()
52 |         i_channels = 1
53 |         for i in n_filter:
54 |             o_channels = i
55 | 
56 |             if i_channels == 1:
57 |                 arf_config_ = (1, n_orientation)
58 |             else:
59 |                 arf_config_ = n_orientation
60 | 
61 |             layers.append(
62 |                 ORConv2d(
63 |                     i_channels,
64 |                     o_channels,
65 |                     arf_config=arf_config_,
66 |                     kernel_size=3,
67 |                     stride=1,
68 |                     padding=1,
69 |                     bias=False,
70 |                 )
71 |             )
72 | 
73 |             if i != n_filter[-1]:  # mimicking the paper
74 |                 layers.append(nn.BatchNorm2d(num_features=o_channels))
75 |                 layers.append(activ)
76 |                 layers.append(nn.MaxPool2d(2))
77 |             else:
78 |                 # last layer of the feature network
79 |                 layers.append(activ)
80 | 
81 |             i_channels = o_channels
82 | 
83 |         return nn.Sequential(*layers)
84 | 
85 |     def forward(self, x):
86 |         x = self.features(x)
87 |         # orn pooling
88 |         x = oraligned1d(x, self.n_orientation)
89 | 
90 |         classification = self.classifier(x)
91 |         regression = self.regressor(x)
92 | 
93 |         return torch.cat((classification, regression), dim=1)
94 | 
95 | 
96 | # Run file to see summary
97 | if __name__ == "__main__":
98 |     from torchsummary import summary
99 | 
100 |     inp = torch.rand((2, 1, 200, 200))
101 |     net = Detector_ORN()
102 |     out = net(inp)
103 | 
104 |     summary(net, inp.shape[1:])
105 |     print(out.shape)
--------------------------------------------------------------------------------
/src/models/detector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from .mish import Mish
5 | 
6 | 
7 | class Detector(nn.Module):
8 |     """Equivalent to the baseline architecture, with the addition of
9 |     classification and regression heads to output 6 attributes: 
p_ship,x,y,yaw,h,w
10 |     """
11 | 
12 |     def __init__(self):
13 |         super(Detector, self).__init__()
14 |         self.image_size = 200
15 |         self.n_filters = [x * 8 for x in [1, 2, 4, 8, 16, 32, 64]]
16 | 
17 |         # self.activ = nn.ReLU()
18 |         self.activ = Mish()
19 | 
20 |         self.features = self._build_features(self.n_filters, self.activ)
21 | 
22 |         self.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.n_filters[-1], 1))
23 | 
24 |         self.regressor = nn.Sequential(
25 |             nn.Flatten(),
26 |             # nn.Linear(self.n_filters[-1], self.n_filters[-1]),
27 |             # self.activ,
28 |             # nn.Dropout(),
29 |             nn.Linear(self.n_filters[-1], 5),
30 |         )
31 | 
32 |     def _build_features(self, n_filter, activ):
33 |         """Generate the feature/backbone network
34 | 
35 |         Arguments:
36 |             n_filter {list} -- number of filters for each conv block
37 |             activ {nn.Module} -- activation function to be used
38 | 
39 |         Returns:
40 |             feature extraction module
41 |         """
42 |         layers = nn.ModuleList()
43 | 
44 |         i_channels = 1
45 |         for i in n_filter:
46 |             o_channels = i
47 | 
48 |             layers.append(
49 |                 nn.Conv2d(
50 |                     i_channels,
51 |                     o_channels,
52 |                     kernel_size=3,
53 |                     stride=1,
54 |                     padding=1,
55 |                     bias=False,
56 |                 )
57 |             )
58 |             layers.append(nn.BatchNorm2d(num_features=o_channels))
59 |             layers.append(activ)
60 |             layers.append(nn.MaxPool2d(2))
61 | 
62 |             i_channels = o_channels
63 | 
64 |         return nn.Sequential(*layers)
65 | 
66 |     def forward(self, x):
67 |         x = self.features(x)
68 |         classification = self.classifier(x)
69 |         regression = self.regressor(x)
70 | 
71 |         return torch.cat((classification, regression), dim=1)
72 | 
73 | 
74 | # Run file to see summary
75 | if __name__ == "__main__":
76 |     from torchsummary import summary
77 | 
78 |     inp = torch.rand((2, 1, 200, 200))
79 |     net = Detector()
80 |     out = net(inp)
81 | 
82 |     summary(net, inp.shape[1:])
83 |     print(out.shape)
--------------------------------------------------------------------------------
/src/models/detector_fpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .mish import Mish
6 | 
7 | 
8 | class ConvBlock(nn.Module):
9 |     """Similar to the building block in ResNet https://arxiv.org/abs/1512.03385
10 |     2*conv+bn layers with a residual connection. 
11 |     Represents each 'stage' from which the feature pyramid is built
12 | 
13 |     Arguments:
14 |         i_channels {int} -- input channels to the block
15 |         o_channels {int} -- output channels from the block
16 | 
17 |     Keyword Arguments:
18 |         stride {int} -- replace pooling with stride (default: {2})
19 |         padding {int} -- preserve feature map dims (default: {1})
20 |     """
21 | 
22 |     def __init__(self, i_channels, o_channels, stride=2, padding=1):
23 |         super(ConvBlock, self).__init__()
24 |         self.conv1 = nn.Conv2d(
25 |             i_channels,
26 |             o_channels,
27 |             kernel_size=3,
28 |             stride=stride,
29 |             padding=padding,
30 |             bias=False,
31 |         )
32 |         self.bn1 = nn.BatchNorm2d(num_features=o_channels)
33 |         self.activ = Mish()
34 |         self.conv2 = nn.Conv2d(
35 |             o_channels, o_channels, kernel_size=3, padding=1, bias=False
36 |         )
37 |         self.bn2 = nn.BatchNorm2d(num_features=o_channels)
38 | 
39 |         self.downsample = None
40 |         if stride != 1:
41 |             self.downsample = nn.Sequential(
42 |                 nn.Conv2d(i_channels, o_channels, kernel_size=1, stride=2, bias=False),
43 |                 nn.BatchNorm2d(o_channels),
44 |             )
45 | 
46 |     def forward(self, x):
47 |         residual = x
48 |         x = self.bn1(self.conv1(x))
49 |         x = self.activ(x)
50 |         x = self.bn2(self.conv2(x))
51 | 
52 |         # downsample residual to match conv output
53 |         if self.downsample is not None:
54 |             residual = self.downsample(residual)
55 | 
56 |         x = x + residual
57 |         x = self.activ(x)
58 |         return x
59 | 
60 | 
61 | class Detector_FPN(nn.Module):
62 |     """ResNet(18)-inspired architecture with a Feature Pyramid Network
63 |     Classification from the top of the pyramid and regression from the bottom
64 | 
65 |     References:
66 |         FPN for Object Detection - https://arxiv.org/pdf/1612.03144.pdf
67 |     Code References:
68 |         https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
69 |         https://github.com/kuangliu/pytorch-fpn/blob/master/fpn.py
70 |         https://keras.io/examples/vision/retinanet/
71 |     """
72 | 
73 |     def __init__(self):
74 |         super(Detector_FPN, self).__init__()
75 |         self.image_size = 200
76 |         self.activ = Mish()
77 | 
78 |         # output filters at each 'stage'
79 |         filters = [16, 32, 64, 128, 256]
80 | 
81 |         # Bottom-Up pathway
82 |         # Extremely important not to have a bigger stride in the first layers;
83 |         # the intuition is to keep precise information about the ship vertices
84 |         self.conv_c1 = nn.Sequential(
85 |             nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False),
86 |             nn.BatchNorm2d(8),
87 |             nn.MaxPool2d(2),
88 |             nn.Conv2d(8, filters[0], kernel_size=3, stride=1, padding=1, bias=False),
89 |             nn.BatchNorm2d(filters[0]),
90 |             nn.MaxPool2d(2),
91 |         )
92 | 
93 |         # Stages used to build the pyramid
94 |         self.conv_c2 = ConvBlock(filters[0], filters[1], stride=2, padding=1)
95 |         self.conv_c3 = ConvBlock(filters[1], filters[2], stride=2, padding=1)
96 |         self.conv_c4 = ConvBlock(filters[2], filters[3], stride=2, padding=1)
97 |         self.conv_c5 = ConvBlock(filters[3], filters[4], stride=2, padding=1)
98 | 
99 |         # pyramid channels fixed to 256 - as mentioned in the cited paper
100 |         py_chs = 256
101 | 
102 |         # Top-Down pathway
103 |         # The rest of the pyramid is built by upsampling from the pyramid top
104 |         self.conv_pyramid_top = nn.Conv2d(filters[4], py_chs, 1)
105 | 
106 |         # Lateral Connections
107 |         # reduce bottom-up channels to match the pyramid
108 |         self.conv_c2_red = nn.Conv2d(filters[1], py_chs, kernel_size=1)
109 |         self.conv_c3_red = nn.Conv2d(filters[2], py_chs, kernel_size=1)
110 |         self.conv_c4_red = nn.Conv2d(filters[3], py_chs, kernel_size=1)
111 | 
112 |         # smooth pyramid levels to reduce aliasing effects from upsampling
113 |         
self.conv_p2_smooth = nn.Conv2d(py_chs, py_chs, kernel_size=3, padding=1)
114 |         self.conv_p3_smooth = nn.Conv2d(py_chs, py_chs, kernel_size=3, padding=1)
115 |         self.conv_p4_smooth = nn.Conv2d(py_chs, py_chs, kernel_size=3, padding=1)
116 | 
117 |         # average pooling to flatten features for the cls. and reg. heads
118 |         self.avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
119 | 
120 |         # classification and regression subnets
121 |         self.classifier = nn.Sequential(
122 |             # conserve params
123 |             # nn.Conv2d(py_chs, 1,
124 |             #           kernel_size=3, stride=2, bias=False),
125 |             nn.Flatten(),
126 |             # nn.Linear(py_chs, py_chs),
127 |             # self.activ,
128 |             nn.Linear(py_chs, 1),
129 |             # nn.Sigmoid()  ## using BCE with logits
130 |         )
131 | 
132 |         self.regressor = nn.Sequential(
133 |             # conserve params
134 |             # nn.Conv2d(py_chs, 5,
135 |             #           kernel_size=3, stride=2, bias=False),
136 |             nn.Flatten(),
137 |             # nn.Linear(py_chs, py_chs),
138 |             # self.activ,
139 |             nn.Linear(py_chs, 5),
140 |         )
141 | 
142 |     def forward(self, x):
143 | 
144 |         # Bottom-Up pathway
145 |         c1 = self.conv_c1(x)
146 |         c2 = self.conv_c2(c1)
147 |         c3 = self.conv_c3(c2)
148 |         c4 = self.conv_c4(c3)
149 |         c5 = self.conv_c5(c4)
150 | 
151 |         # Top-Down pathway
152 |         p5 = self.conv_pyramid_top(c5)
153 |         # add lateral connections from the reduced bottom-up maps to the inner pyramid
154 |         p4 = self._upsample_add(p5, self.conv_c4_red(c4))
155 |         p3 = self._upsample_add(p4, self.conv_c3_red(c3))
156 |         p2 = self._upsample_add(p3, self.conv_c2_red(c2))
157 | 
158 |         # smoothing the pyramid
159 |         # p5 only goes through a 1x1 conv, so no need to smooth
160 |         # conserve parameters with the assumption that finer resolution can predict better boxes;
161 |         # else take the highest-confidence prediction or something more complex
162 |         # p4 = self.conv_p4_smooth(p4)
163 |         # p3 = self.conv_p3_smooth(p3)
164 |         p2 = self.conv_p2_smooth(p2)
165 | 
166 |         # the top of the pyramid has the best semantic features
167 |         # and the lowest or finest layer has the best global features
168 |         cls_feat = self.avg_pooling(p5)
169 |         reg_feat = self.avg_pooling(p2)
170 | 
171 |         classification = self.classifier(cls_feat)
172 |         regression = self.regressor(reg_feat)
173 | 
174 |         p_ship = classification.view(x.shape[0], 1)
175 |         bbox = regression.view(x.shape[0], 5)
176 | 
177 |         return torch.cat((p_ship, bbox), dim=1)
178 | 
179 |     def _upsample_add(self, p_prev, lc):
180 |         """Takes a pyramid layer, upsamples it by a factor of 2 and adds the corresponding 
lateral connections
181 | 
182 |         Arguments:
183 |             p_prev {tensor} -- coarser feature map
184 |             lc {tensor} -- lateral connection
185 | 
186 |         Returns:
187 |             finer feature map, the lower pyramid layer
188 |         """
189 |         p = F.interpolate(p_prev, size=(lc.shape[-2:]), mode="nearest")
190 |         return p + lc
191 | 
192 | 
193 | # Run file to see summary
194 | if __name__ == "__main__":
195 |     from torchsummary import summary
196 | 
197 |     inp = torch.rand((2, 1, 200, 200))
198 |     net = Detector_FPN()
199 |     out = net(inp)
200 | 
201 |     # print(out.shape)
202 |     summary(net, inp.shape[1:])
203 |     # print(net)
--------------------------------------------------------------------------------
/src/models/mish.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | @torch.jit.script
6 | def mish(input):
7 |     """
8 |     Source: https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
9 | 
10 |     mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
11 |     """
12 |     return input * torch.tanh(F.softplus(input))
13 | 
14 | 
15 | class Mish(torch.nn.Module):
16 |     """
17 |     Source: https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
18 | 
19 |     Applies the mish function element-wise:
20 |     Shape:
21 |         - Input: (N, *) where * means, any number of additional
22 |           dimensions
23 |         - Output: (N, *), same shape as the input
24 |     """
25 | 
26 |     def __init__(self):
27 |         super().__init__()
28 | 
29 |     def forward(self, input):
30 |         return mish(input)
31 | 
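32 | 
33 | 
34 | # Hedged sanity check, not part of the original file: mish(0) = 0 and
35 | # positive inputs stay near-linear, while large negative inputs go to 0.
36 | if __name__ == "__main__":
37 |     x = torch.tensor([-1.0, 0.0, 1.0])
38 |     print(Mish()(x))  # tensor([-0.3034, 0.0000, 0.8651])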
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import gc
3 | import os
4 | from argparse import ArgumentParser, Namespace
5 | 
6 | import numpy as np
7 | import torch
8 | 
9 | import src.dataloader as loader
10 | import src.models as models
11 | from src.callbacks import CallbackList, Logging, ModelCheckpoint
12 | from src.trainer import training_epoch, validation_epoch
13 | 
14 | 
15 | def main():
16 |     # Experiment configuration, opt, is distributed to all the other modules
17 |     opt = _do_setup()
18 | 
19 |     train_loader = loader.train_dataloader(opt)
20 |     val_loader = loader.val_dataloader(opt)
21 | 
22 |     model = models.Detector_FPN()
23 |     optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
24 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
25 |         optimizer, mode="min", verbose=True
26 |     )
27 | 
28 |     # data parallel
29 |     if torch.cuda.device_count() > 1:
30 |         print(f"[INFO]: Using {torch.cuda.device_count()} GPUs")
31 |         model = torch.nn.DataParallel(model)
32 | 
33 |     model.to(opt.device)
34 | 
35 |     # custom callbacks
36 |     cb = CallbackList([Logging(), ModelCheckpoint()])
37 | 
38 |     # required info for - checkpoint cb
39 |     cb.setup(opt=opt, model=model, optimizer=optimizer)
40 | 
41 |     # Train and Val
42 |     for epoch in range(1, opt.epochs + 1):
43 |         opt.epoch = epoch
44 |         training_epoch(cb, opt, model, train_loader, optimizer)
45 |         val_loss = validation_epoch(cb, opt, model, val_loader)
46 |         scheduler.step(val_loss)
47 | 
48 |         # required info for - checkpoint cb
49 |         cb.on_epoch_end(
50 |             opt=opt, val_loss=val_loss, model=model, optimizer=optimizer, epoch=epoch
51 |         )
52 | 
53 |         del val_loss
54 |         gc.collect()
55 | 
56 |     # sync opt with wandb for easy experiment comparison
57 |     if opt.use_wandb:
58 |         wandb = opt.logger
59 |         opt.logger = None  # wandb can't have objects in its config
60 |         wandb.config.update(opt)
61 | 
62 | 
63 | def _do_setup():
64 |     parser = _get_argparser()
65 |     opt = parser.parse_args()
66 | 
67 |     # fix seed for reproducibility
68 |     torch.manual_seed(opt.seed)
69 |     np.random.seed(opt.seed)
70 | 
71 |     # GPU setup
72 |     use_cuda = opt.cuda and torch.cuda.is_available()
73 |     device = torch.device("cuda" if use_cuda else "cpu")
74 |     opt.device = device  # Adding device to opt, not already in argparse
75 |     opt.num_workers = 4  # to tune per device
76 |     opt.run_name = "runX"
77 | 
78 |     # wandb for experiment monitoring
79 |     os.environ["WANDB_NOTES"] = "test"
80 |     if opt.use_wandb:
81 |         import wandb
82 | 
83 |         if not use_cuda:
84 |             # os.environ['WANDB_MODE'] = 'dryrun'  # ignore when debugging on cpu
85 |             os.environ["WANDB_TAGS"] = "CPU"
86 |             wandb.init(
87 |                 anonymous="allow", project="rotated-object-detection", config=opt
88 |             )
89 |         else:
90 |             wandb.init(
91 |                 anonymous="allow", project="rotated-object-detection", config=opt
92 |             )
93 | 
94 |         opt.logger = wandb
95 |         # opt.logger.run.save()
96 |         opt.run_name = opt.logger.run.name  # handle name change in wandb
97 |         atexit.register(_sync_before_exit, opt, wandb)
98 | 
99 |     return opt
100 | 
101 | 
102 | def _sync_before_exit(opt: Namespace, wandb):
103 |     print("[INFO]: Sync wandb before terminating")
104 |     opt.logger = None  # wandb can't have objects in its config
105 |     wandb.config.update(opt)
106 | 
107 | 
108 | def _get_argparser():
109 | 
110 |     parser = ArgumentParser()
111 |     # fmt: off
112 |     # training specific
113 |     parser.add_argument("--epochs", default=150, type=int,
114 |                         help="number of epochs to train")
115 |     parser.add_argument("--batch_size", default=256, type=int,
116 |                         help="number of samples per step, have more than one for batch norm")
117 |     parser.add_argument("--learning_rate", default=1e-3, type=float,
118 |                         help="learning rate for all optimizers")
119 |     parser.add_argument("--resume_run", default="None", type=str,
120 |                         help="auto load ckpt")
121 |     # data
122 |     parser.add_argument("--train_len", default=3, type=int,
123 |                         help="number of samples for training")
124 |     parser.add_argument("--val_len", default=3, type=int,
125 |                         help="number of samples for validation")
126 |     parser.add_argument("--test_len", default=2, type=int,
127 |                         help="number of samples for testing")
128 |     # output
129 |     parser.add_argument("--use_wandb", default=False, type=lambda x: (str(x).lower() == "true"),
130 |                         help="use wandb to monitor training")
131 |     parser.add_argument("--save_dir", default=f"{os.path.dirname(os.path.abspath(__file__))}/checkpoints", type=str,
132 |                         help="path to save checkpoints")
133 |     # device
134 |     parser.add_argument("--cuda", default=True, type=lambda x: (str(x).lower() == "true"),
135 |                         help="enable cuda if available")
136 |     parser.add_argument("--pin_memory", default=False, type=lambda x: (str(x).lower() == "true"),
137 |                         help="pin memory to device")
138 |     parser.add_argument("--seed", default=400, type=int, help="random seed")
139 |     # fmt: on
140 | 
141 |     return parser
142 | 
143 | 
144 | if __name__ == "__main__":
145 |     main()
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from argparse import Namespace
3 | 
4 | import torch
5 | from torch.nn.modules.module import Module
6 | from torch.optim import Optimizer
7 | from torch.utils.data import DataLoader
8 | 
9 | from src.callbacks.base import CallbackList
10 | from src.loss import compute_loss
11 | from src.metrics import compute_metrics
12 
| 
13 | 
14 | 
15 | def training_epoch(
16 |     cb: CallbackList,
17 |     opt: Namespace,
18 |     model: Module,
19 |     train_loader: DataLoader,
20 |     optimizer: Optimizer,
21 | ) -> None:
22 |     """logic for each training epoch"""
23 |     model.train()
24 | 
25 |     for batch_idx, batch in enumerate(train_loader):
26 |         for key in batch.keys():
27 |             batch[key] = batch[key].to(opt.device)
28 | 
29 |         optimizer.zero_grad()
30 | 
31 |         # training step
32 |         input, target = batch["input"], batch["target"]
33 |         output = model(input)
34 |         loss, _l_ship, _l_bbox = compute_loss(output, target)
35 |         loss = loss.mean()
36 |         loss.backward()
37 |         optimizer.step()
38 | 
39 |         # required info for - logging cb
40 |         cb.on_train_batch_end(
41 |             opt=opt,
42 |             batch_idx=batch_idx,
43 |             batch=batch,
44 |             dataloader=train_loader,
45 |             output=loss.item(),
46 |             l_ship=_l_ship.mean().item(),
47 |             l_bbox=_l_bbox.mean().item(),
48 |         )
49 | 
50 |         del loss
51 |         del batch
52 |         gc.collect()
53 | 
54 | 
55 | def validation_epoch(
56 |     cb: CallbackList,
57 |     opt: Namespace,
58 |     model: Module,
59 |     val_loader: DataLoader,
60 | ) -> torch.Tensor:
61 |     """logic for each validation epoch"""
62 |     model.eval()
63 | 
64 |     # metrics to return
65 |     losses = []
66 |     prec = []
67 |     rec = []
68 |     f1 = []
69 |     ap = []
70 |     iou = []
71 |     l_ship = []
72 |     l_bbox = []
73 | 
74 |     with torch.no_grad():
75 |         for batch_idx, batch in enumerate(val_loader):
76 |             for key in batch.keys():
77 |                 batch[key] = batch[key].to(opt.device)
78 | 
79 |             # validation step
80 |             input, target = batch["input"], batch["target"]
81 |             output = model(input)
82 | 
83 |             loss, _l_ship, _l_bbox = compute_loss(output, target)
84 |             _prec, _rec, _f1, _ap, _iou = compute_metrics(output, target)
85 | 
86 |             # append in case analysis of the distribution is of interest
87 |             losses.append(loss)
88 |             l_ship.append(_l_ship)
89 |             l_bbox.append(_l_bbox)
90 |             prec.append(_prec)
91 |             rec.append(_rec)
92 |             f1.append(_f1)
93 |             ap.append(_ap)
94 |             iou.append(_iou)
95 | 
96 |     loss_avg = torch.mean(torch.cat(losses))
97 |     l_ship = torch.mean(torch.cat(l_ship))
98 |     l_bbox = torch.mean(torch.cat(l_bbox))
99 | 
100 |     metrics = {}
101 |     for k, m in zip(["prec", "rec", "f1", "ap", "iou"], [prec, rec, f1, ap, iou]):
102 |         m = sum(m) / len(m)
103 |         metrics[k] = m
104 | 
105 |     cb.on_validation_end(
106 |         opt=opt, output=loss_avg, metrics=metrics, l_ship=l_ship, l_bbox=l_bbox
107 |     )
108 | 
109 |     return loss_avg
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bsridatta/Rotated-Object-Detection/4981b4bd5d352475a244508a43f2c01b47780fdc/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | 
4 | from src.dataset import Ships
5 | 
6 | 
7 | @pytest.fixture
8 | def random():
9 |     torch.manual_seed(0)
10 | 
11 | 
12 | def test_label_has_ship():
13 |     dataset = Ships(15)
14 |     for i in range(len(dataset)):
15 |         sample = dataset[i]["target"]
16 |         if sample[0] == 1:
17 |             assert ~torch.isnan(sample[1]).item()
18 |         elif sample[0] == 0:
19 |             assert torch.isnan(sample[1]).item()
20 |         else:
21 |             assert False
--------------------------------------------------------------------------------
/tests/test_metrics.py: 
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | 
4 | from src.loss import compute_loss, lmr5p
5 | from src.metrics import compute_metrics
6 | 
7 | 
8 | @pytest.fixture
9 | def random():
10 |     torch.manual_seed(0)
11 | 
12 | 
13 | nan = float("nan")
14 | 
15 | 
16 | # pred, target
17 | # True, True    TP  # IOU = 1
18 | # False, False  TN  # IOU = 1
19 | # False, True   FN  # IOU = 1
20 | # True, False   FP  # IOU = 1
21 | # True, True    FP  # IOU = 0 TP -> FP IOU