├── .gitignore
├── assets
│   ├── image.jpg
│   └── fold_results.png
├── requirements.txt
├── src
│   ├── config.py
│   ├── models.py
│   ├── data.py
│   ├── augmentations.py
│   └── trainer.py
├── LICENSE
├── README.md
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode
--------------------------------------------------------------------------------
/assets/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanaymeh/alien-signal-detection/HEAD/assets/image.jpg
--------------------------------------------------------------------------------
/assets/fold_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tanaymeh/alien-signal-detection/HEAD/assets/fold_results.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
pandas
matplotlib
seaborn
tqdm
opencv-contrib-python
timm
torch
torchvision
albumentations
scikit-learn
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
class Config:
    N_SPLITS = 3
    model_name = 'vit_base_patch16_224'
    resize = (224, 224)
    TRAIN_BS = 32
    VALID_BS = 16
    num_workers = 8
    NB_EPOCHS = 10
    LABELS = 1
    FILE = "/input/train_labels.csv"
    FOLDER = "/input/train"
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
import timm
import torch.nn as nn

from .config import Config

class VITModel(nn.Module):
    """
    Model class for the ViT model
    """
    def __init__(self, model_name=Config.model_name, pretrained=True):
        super(VITModel, self).__init__()
        # in_chans=1: the spectrogram inputs are single-channel
        self.backbone = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
        # Swap the classifier head for a single-logit output
        self.backbone.head = nn.Linear(self.backbone.head.in_features, Config.LABELS)

    def forward(self, x):
        x = self.backbone(x)
        return x

class MLPMixer(nn.Module):
    """
    Model class for the MLP-Mixer model.
    Note: pass an MLP-Mixer architecture name (e.g. 'mixer_b16_224');
    the default Config.model_name is a ViT.
    """
    def __init__(self, model_name=Config.model_name, pretrained=True):
        super(MLPMixer, self).__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
        self.backbone.head = nn.Linear(self.backbone.head.in_features, Config.LABELS)

    def forward(self, x):
        x = self.backbone(x)
        return x
--------------------------------------------------------------------------------
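A quick smoke test for the wrappers above — an illustration, not repository code. It assumes the package is importable as src and uses random data in place of real spectrograms:

import torch
from src.models import VITModel

model = VITModel(pretrained=False)    # skip the weight download for a pure shape check
dummy = torch.randn(2, 1, 224, 224)   # batch of two single-channel 224x224 inputs
logits = model(dummy)                 # shape (2, 1): raw logits, no sigmoid applied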
/src/data.py:
--------------------------------------------------------------------------------
import numpy as np
from torch.utils.data import Dataset


class SETIData(Dataset):
    def __init__(self, images, targets, is_test=False, augmentations=None):
        self.images = images
        self.targets = targets
        self.is_test = is_test
        self.augmentations = augmentations

    def __getitem__(self, index):
        img, target = self.images[index], self.targets[index]

        # Each sample is a .npy cadence snippet: stack its frames vertically into
        # one 2D array, transpose it, and add a trailing channel axis (HWC)
        # so albumentations can consume it
        img = np.load(img)
        img = np.vstack(img)
        img = img.transpose(1, 0)
        img = img.astype("float")[..., np.newaxis]

        if self.augmentations:
            img = self.augmentations(image=img)['image']

        if self.is_test:
            return img

        return img, target

    def __len__(self):
        return len(self.images)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Tanay Mehta

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/augmentations.py:
--------------------------------------------------------------------------------
from albumentations import (
    HorizontalFlip, VerticalFlip, ShiftScaleRotate, RandomResizedCrop,
    Compose, Resize
)

from albumentations.pytorch import ToTensorV2

from .config import Config

class Augments:
    """
    Contains the train and validation augmentations
    """
    train_augments = Compose([
        Resize(*Config.resize, p=1.0),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5, shift_limit=0.2, scale_limit=0.2, rotate_limit=20,
                         border_mode=0, value=0, mask_value=0),
        RandomResizedCrop(*Config.resize, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)

    valid_augments = Compose([
        Resize(*Config.resize, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.)
--------------------------------------------------------------------------------
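For reference, the shape a sample takes after this dataset + augmentation pipeline — a minimal sketch, not repository code; paths and labels are hypothetical stand-ins for the .npy file paths and their matching 0/1 targets:

from src.data import SETIData
from src.augmentations import Augments

ds = SETIData(paths, labels, augmentations=Augments.valid_augments)
img, target = ds[0]
print(img.shape)  # torch.Size([1, 224, 224]): Resize to Config.resize, then ToTensorV2 (HWC -> CHW)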
/src/trainer.py:
--------------------------------------------------------------------------------
from tqdm.notebook import tqdm

import torch

from sklearn.metrics import roc_auc_score

def train_one_epoch(model, device, optimizer, dataloader, loss_fn, scheduler=None):
    """Trains the given model for one epoch on the given data

    Args:
        model: Main model
        device: Device on which the model will be trained
        optimizer: Optimizer used during training
        dataloader: Training dataloader
        loss_fn: Training loss function (the quantity being optimized)
        scheduler (optional): Learning-rate scheduler. Defaults to None.
    """
    prog_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    model.train()
    running_loss = 0
    for idx, (img, target) in prog_bar:
        img = img.to(device, torch.float)
        target = target.to(device, torch.float)

        output = model(img).view(-1)
        loss = loss_fn(output, target)

        # .item() synchronizes with the GPU and copies the scalar to the CPU,
        # which is relatively expensive, so do it only once per step
        loss_item = loss.item()
        prog_bar.set_description('loss: {:.2f}'.format(loss_item))

        loss.backward()
        optimizer.step()

        if scheduler:
            scheduler.step()

        optimizer.zero_grad(set_to_none=True)

        running_loss += loss_item

    return running_loss / len(dataloader)

@torch.no_grad()
def valid_one_epoch(model, device, dataloader, loss_fn):
    """Validates the model on all batches of the validation set

    Args:
        model: Main model
        device: Device on which the model will be validated
        dataloader: Validation dataloader
        loss_fn: Validation loss function (NOT optimized)
    """
    prog_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    all_targets, all_predictions = [], []
    running_loss = 0
    model.eval()
    for idx, (img, target) in prog_bar:
        img = img.to(device, torch.float)
        target = target.to(device, torch.float)

        output = model(img).view(-1)

        loss = loss_fn(output, target)
        loss_item = loss.item()

        prog_bar.set_description('val_loss: {:.2f}'.format(loss_item))

        all_targets.extend(target.cpu().detach().numpy().tolist())
        # The model emits raw logits, so apply a sigmoid before scoring
        all_predictions.extend(torch.sigmoid(output).cpu().detach().numpy().tolist())

        running_loss += loss_item

    val_roc_auc = roc_auc_score(all_targets, all_predictions)
    return val_roc_auc, running_loss / len(dataloader)
--------------------------------------------------------------------------------
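The tree lists a top-level train.py, but its contents are not included in this dump. Purely to illustrate how the pieces above compose — a hedged sketch, not the author's script: the 'id'/'target' CSV columns, the '{id[0]}/{id}.npy' file layout, and the Adam / BCEWithLogitsLoss / lr=1e-4 choices are all assumptions — a fold loop could look like:

import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader

from src.config import Config
from src.data import SETIData
from src.augmentations import Augments
from src.models import VITModel
from src.trainer import train_one_epoch, valid_one_epoch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv(Config.FILE)  # assumed columns: 'id', 'target'
df["path"] = df["id"].apply(lambda i: f"{Config.FOLDER}/{i[0]}/{i}.npy")  # assumed layout

skf = StratifiedKFold(n_splits=Config.N_SPLITS, shuffle=True, random_state=42)
for fold, (train_idx, valid_idx) in enumerate(skf.split(df, df["target"])):
    train_df, valid_df = df.iloc[train_idx], df.iloc[valid_idx]
    train_dl = DataLoader(
        SETIData(train_df["path"].values, train_df["target"].values,
                 augmentations=Augments.train_augments),
        batch_size=Config.TRAIN_BS, shuffle=True, num_workers=Config.num_workers)
    valid_dl = DataLoader(
        SETIData(valid_df["path"].values, valid_df["target"].values,
                 augmentations=Augments.valid_augments),
        batch_size=Config.VALID_BS, shuffle=False, num_workers=Config.num_workers)

    model = VITModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.BCEWithLogitsLoss()  # matches the raw-logit outputs of the models

    for epoch in range(Config.NB_EPOCHS):
        train_loss = train_one_epoch(model, device, optimizer, train_dl, loss_fn)
        val_auc, val_loss = valid_one_epoch(model, device, valid_dl, loss_fn)
        print(f"Fold {fold} epoch {epoch}: loss={train_loss:.4f} "
              f"val_loss={val_loss:.4f} val_auc={val_auc:.4f}")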
/README.md:
--------------------------------------------------------------------------------