├── .github └── ISSUE_TEMPLATE │ ├── general-questions.md │ ├── help-wanted.md │ └── bug_report.md ├── requirements.txt ├── configs └── sample.yaml ├── airogs_dataset.py ├── early_stopping.py ├── README.md ├── notebooks ├── Testing Ensemble-AIROGS.ipynb ├── Testing Ensemble-RIM-ONE DL.ipynb ├── Testing RIM ONE DL.ipynb └── central_crop.ipynb ├── run.py └── run_fold.py /.github/ISSUE_TEMPLATE/general-questions.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Questions 3 | about: For any general questions. 4 | title: '' 5 | labels: question 6 | assignees: ahmed1996said, rmuhtaseb 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | grad-cam==1.3.1 2 | ipython==7.27.0 3 | jupyter-client==7.0.2 4 | jupyter-core==4.7.1 5 | jupyter-server==1.11.0 6 | jupyterlab==3.1.12 7 | jupyterlab-server==2.8.1 8 | matplotlib==3.5.1 9 | numpy==1.21.4 10 | opencv-python==4.5.3.56 11 | opencv-python-headless==4.5.3.56 12 | pandas==1.3.3 13 | Pillow==8.3.2 14 | scikit-image==0.18.3 15 | scikit-learn==0.24.2 16 | scipy==1.7.1 17 | timm==0.5.4 18 | torch==1.9.0 19 | torch-summary==1.4.5 20 | torchmetrics==0.5.1 21 | torchvision==0.10.0 22 | tqdm==4.62.2 23 | wandb==0.12.2 24 | -------------------------------------------------------------------------------- /configs/sample.yaml: -------------------------------------------------------------------------------- 1 | exp_id: 12345 2 | fold_num: 0 3 | resize: 256 4 | epochs: 50 5 | batch_size: 64 6 | num_workers: 64 7 | data_dir: /PATH/TO/DATA/PARENT/FOLDER 8 | images_dir_name: /PATH/TO/DATA 9 | run_test: True 10 | model_name: resnet18 # any model supported by timm (https://github.com/rwightman/pytorch-image-models/tree/master/timm/models) 11 | pretrained: True 12 | optimizer_name: Adam 13 | momentum: 0.1 14 | scheduler: null 15 | lr_step_period: 5 16 | lr: 1e-3 17 | patience: 5 18 | apply_augs: False 19 | apply_clahe: False 20 | dropout: 0.0 21 | apply_scaling: False 22 | polar_transform: False 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help-wanted.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Help Wanted 3 | about: Having issues running the code? Let us know! 4 | title: '' 5 | labels: help wanted 6 | assignees: ahmed1996said, rmuhtaseb 7 | 8 | --- 9 | 10 | Describe the issue you're facing 11 | A clear and concise description of what the bug is. 12 | 13 | To Reproduce 14 | Steps to reproduce the behavior: 15 | 16 | Go to '...' 17 | Click on '....' 18 | Scroll down to '....' 19 | See error 20 | Expected behavior 21 | A clear and concise description of what you expected to happen. 22 | 23 | Screenshots 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | OS (please complete the following information): 27 | 28 | Linux, windows or MacOS? 29 | Additional context 30 | Add any other context about the problem here. 
31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: ahmed1996said, rmuhtaseb 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **OS (please complete the following information):** 27 | - Linux, Windows or macOS? 28 | 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /airogs_dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import pandas as pd 3 | import glob 4 | import os 5 | from PIL import Image 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | from skimage.exposure import equalize_adapthist 10 | from skimage.transform import warp_polar 11 | 12 | def polar(image): 13 | return warp_polar(image, radius=(max(image.shape) // 2), multichannel=True) 14 | 15 | class Airogs(torchvision.datasets.VisionDataset): 16 | 17 | def __init__(self, split='train', path='', images_dir_name='train',transforms=None,polar_transforms=False,apply_clahe=False): 18 | self.split = split 19 | self.path = path 20 | self.images_dir_name = images_dir_name 21 | self.df_files = pd.read_csv(os.path.join(self.path, self.split + ".csv")) ## columns = ['challenge_id', 'class', 'referable', 'gradable'] 22 | self.transforms = transforms 23 | self.polar_transforms = polar_transforms 24 | self.apply_clahe = apply_clahe 25 | print("{} size: {}".format(split, len(self.df_files))) 26 | 27 | def __getitem__(self, index): 28 | file_name = self.df_files.loc[index, 'challenge_id'] 29 | path_mask = os.path.join(self.path, self.images_dir_name,"*" ,file_name + '.jpg') 30 | image_path = glob.glob(path_mask)[0] 31 | image = Image.open(image_path) 32 | 33 | label = self.df_files.loc[index, 'class'] 34 | label = 0 if label == 'NRG' else 1 35 | 36 | if self.polar_transforms: 37 | image = np.array(image, dtype=np.float64) 38 | image = polar(image) 39 | 40 | if self.apply_clahe: 41 | image = np.array(image, dtype=np.float64) / 255.0 42 | image = equalize_adapthist(image) 43 | image = (image*255).astype('uint8') 44 | 45 | assert self.transforms is not None 46 | image = self.transforms(image) 47 | return image, label 48 | 49 | def __len__(self): 50 | return len(self.df_files) 51 | 52 | -------------------------------------------------------------------------------- /early_stopping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved.
10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function. 18 | Default: print 19 | """ 20 | self.patience = patience 21 | self.verbose = verbose 22 | self.counter = 0 23 | self.best_score = None 24 | self.early_stop = False 25 | self.val_loss_min = np.Inf 26 | self.delta = delta 27 | self.path = path 28 | self.trace_func = trace_func 29 | def __call__(self, val_loss, model): 30 | 31 | score = -val_loss 32 | 33 | if self.best_score is None: 34 | self.best_score = score 35 | self.save_checkpoint(val_loss, model) 36 | elif score < self.best_score + self.delta: 37 | self.counter += 1 38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 39 | if self.counter >= self.patience: 40 | self.early_stop = True 41 | else: 42 | self.best_score = score 43 | self.save_checkpoint(val_loss, model) 44 | self.counter = 0 45 | 46 | def save_checkpoint(self, val_loss, model): 47 | '''Saves the model when the validation loss decreases.''' 48 | if self.verbose: 49 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 50 | torch.save(model.state_dict(), self.path) 51 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2205.12902-.svg)](https://arxiv.org/abs/2205.12902) 2 | 3 | # GARDNet: Robust Multi-View Network for Glaucoma Classification in Color Fundus Images 4 | **Authors:** Ahmed Al Mahrooqi, Dmitrii Medvedev, Rand Muhtaseb, Mohammad Yaqub 5 | 6 | **Institution:** Mohamed bin Zayed University of Artificial Intelligence 7 | 8 | ## :page_facing_up: Abstract 9 | Glaucoma is one of the most severe eye diseases, characterized by rapid progression and leading to irreversible blindness. Diagnosis is often carried out only after one’s sight has already significantly degraded, due to the lack of noticeable symptoms at an early stage of the disease. Regular glaucoma screenings of the population should improve early-stage detection; however, the desirable frequency of ophthalmological checkups is often not feasible due to the excessive load imposed by manual diagnostics on a limited number of specialists. Since the basic methodology for detecting glaucoma is to analyze fundus images for the optic-disc-to-optic-cup ratio, Machine Learning algorithms can offer sophisticated methods for image processing and classification. In our work, we propose an advanced image pre-processing technique combined with a multi-view network of deep classification models to categorize glaucoma. Our Glaucoma Automated Retinal Detection Network (GARDNet) has been successfully tested on the Rotterdam EyePACS AIROGS dataset with an AUC of 0.92, and then additionally fine-tuned and tested on the RIM-ONE DL dataset with an AUC of 0.9308, outperforming the state-of-the-art of 0.9272. 10 | 11 | This work has been accepted at the MICCAI 2022 workshop [OMIA9](https://sites.google.com/view/omia9). 12 | ### :key: Keywords 13 | Glaucoma Classification, Color Fundus Images,
Computer Aided Diagnosis 14 | 15 | 16 | ## :open_file_folder: File Structure 17 | 18 | / 19 | ├── configs # contains experiment configuration .yaml files required to train 20 | ├── notebooks # contains .ipynb notebooks for preprocessing and testing 21 | ├── GradCAM.ipynb # GradCAM Visualization script 22 | ├── Testing Ensemble-AIROGS.ipynb # Testing script for the ensemble model on AIROGS 23 | ├── Testing Ensemble-RIM-ONE DL.ipynb # Testing script for the ensemble model on RIM-ONE DL 24 | ├── Testing RIM ONE DL.ipynb # Fine-tuning and testing script on RIM-ONE DL 25 | ├── bbox_crop.ipynb # Preprocessing script for cropping AIROGS using bounding box coords 26 | ├── central_crop.ipynb # Preprocessing script for cropping AIROGS using central crop 27 | └── Optic Disc Segmentation and Crop.ipynb # Optic disc segmentation and bounding box coords for preprocessing 28 | ├── README.md 29 | ├── airogs_dataset.py # contains dataset class for AIROGS 30 | ├── early_stopping.py # contains script for early stopping 31 | ├── requirements.txt # contains packages and libraries needed to run our code 32 | ├── run.py # contains training script 33 | └── run_fold.py # contains training script with cross-validation 34 | ## :framed_picture: Data 35 | - Download the training images from the [AIROGS Challenge Website](https://airogs.grand-challenge.org/data-and-challenge/) 36 | - Download the training and testing images for the [RIM-ONE DL Dataset](https://bit.ly/rim-one-dl-images) 37 | - Download the CSVs and checkpoints from this [Google Drive link](https://drive.google.com/drive/folders/1i9y8IZfKJkNtcxeIJ10EU2Z25eeMwFKe?usp=sharing) 38 | ## :package: Requirements 39 | You can install all requirements using `pip` by running this command: 40 | 41 | ``` pip install -r requirements.txt``` 42 | 43 | Generally speaking, our code uses the following core packages: 44 | - PyTorch 1.9.0 45 | - [wandb](https://wandb.ai): you need to create an account for logging purposes 46 | 47 | ## :arrow_forward: Training 48 | For training, you can run the following command: 49 | 50 | ``` python run_fold.py configs/sample.yaml``` 51 | 52 | ## :question: Questions? 53 | For all code-related questions, please create a GitHub Issue above and our team will respond to you as soon as possible.
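## :mag_right: Quick Inference Sketch

If you want to sanity-check a trained checkpoint outside the notebooks, here is a minimal single-image inference sketch. It mirrors the preprocessing in `airogs_dataset.py` (CLAHE via `equalize_adapthist`, then `ToTensor` and a 256x256 resize) and loads an `efficientnet_b0` checkpoint the same way the testing notebooks do. The checkpoint and image paths below are placeholders, not files shipped with the repo; point them at your own downloads.

```python
import numpy as np
import timm
import torch
import torchvision.transforms as T
from PIL import Image
from skimage.exposure import equalize_adapthist

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same backbone/head layout as the testing notebooks (2 classes: 0 = NRG, 1 = RG)
model = timm.create_model("efficientnet_b0", num_classes=2)
checkpoint = torch.load("PATH_TO_CKPTS/airogs_1.pt", map_location=device)  # placeholder path
model.load_state_dict(checkpoint["state_dict"])  # 'state_dict' key as saved by run.py / run_fold.py
model = model.to(device).eval()

# CLAHE preprocessing as in airogs_dataset.py, then ToTensor + 256x256 resize
image = np.array(Image.open("PATH_TO_IMAGE.jpg"), dtype=np.float64) / 255.0  # placeholder path
image = (equalize_adapthist(image) * 255).astype("uint8")
inp = T.Compose([T.ToTensor(), T.Resize((256, 256))])(image).unsqueeze(0).to(device)

with torch.no_grad():
    probs = torch.softmax(model(inp), dim=1)[0]
print("P(NRG) = %.4f, P(RG) = %.4f" % (probs[0].item(), probs[1].item()))
```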
54 | 55 | -------------------------------------------------------------------------------- /notebooks/Testing Ensemble-AIROGS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d26271aa", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import math\n", 12 | "import torch\n", 13 | "import torchvision\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "from torch.nn import CrossEntropyLoss\n", 17 | "from tqdm import tqdm\n", 18 | "from torchvision.models import resnet18\n", 19 | "import timm\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "from skimage.io import imread\n", 22 | "import sklearn\n", 23 | "from sklearn import metrics\n", 24 | "from sklearn.metrics import f1_score\n", 25 | "from sklearn.utils import class_weight\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import torchvision.transforms as transforms\n", 29 | "from PIL import Image\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import random\n", 32 | "from early_stopping import EarlyStopping\n", 33 | "import os\n", 34 | "from airogs_dataset import Airogs\n", 35 | "import wandb\n", 36 | "import sys\n", 37 | "import sklearn.metrics\n", 38 | "import yaml\n", 39 | "torch.multiprocessing.set_sharing_strategy('file_system')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "98fe9b51", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "############ CONFIGS ############\n", 50 | "\n", 51 | "num_workers = 32\n", 52 | "batch_size = 8\n", 53 | "\n", 54 | "\n", 55 | "#original\n", 56 | "model_0 = timm.create_model('efficientnet_b0',num_classes=2)\n", 57 | "model_0.load_state_dict(torch.load('PATH_TO_CKPTS/airogs_1.pt')['state_dict'])\n", 58 | "\n", 59 | "#polar\n", 60 | "model_1 = timm.create_model('efficientnet_b0',num_classes=2)\n", 61 | "model_1.load_state_dict(torch.load('PATH_TO_CKPTS/airogs_2.pt')['state_dict'])\n", 62 | "\n", 63 | "#cropped\n", 64 | "model_2 = timm.create_model('efficientnet_b0',num_classes=2)\n", 65 | "model_2.load_state_dict(torch.load('PATH_TO_CKPTS/airogs_3.pt')['state_dict'])\n", 66 | "\n", 67 | "models=[model_0,model_1,model_2]\n", 68 | "\n", 69 | "\n", 70 | "transforms = [\n", 71 | " torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Resize((256,256))]),\n", 72 | " torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Resize((256,256))]),\n", 73 | " torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Resize((256,256))]),\n", 74 | "]\n", 75 | "\n", 76 | "apply_clahe = [True,True,True]\n", 77 | "path = ['PATH_TO_DATA_DIR',\n", 78 | " 'PATH_TO_DATA_DIR',\n", 79 | " 'PATH_TO_DATA_DIR'\n", 80 | " ]\n", 81 | "\n", 82 | "images_dir_name = ['PATH_TO_UNCROPPED_IMG_FOLDER',\n", 83 | " 'PATH_TO_CROPPED_IMG_FOLDER',\n", 84 | " 'PATH_TO_CROPPED_IMG_FOLDER'\n", 85 | "]\n", 86 | "test_datasets = [Airogs(path=path[0],images_dir_name=images_dir_name[0],split=\"test\",transforms=transforms[0],apply_clahe=apply_clahe[0]),\n", 87 | " Airogs(path=path[1],images_dir_name=images_dir_name[1],split=\"test\",transforms=transforms[1],apply_clahe=apply_clahe[1],polar_transforms=True),\n", 88 | " Airogs(path=path[2],images_dir_name=images_dir_name[2],split=\"test\",transforms=transforms[2],apply_clahe=apply_clahe[2]),\n", 89 | " ]\n", 90 | " \n", 91 | 
"\n", 92 | "test_loader = [\n", 93 | " DataLoader(test_datasets[0], batch_size=batch_size,shuffle=False,num_workers=num_workers),\n", 94 | " DataLoader(test_datasets[1], batch_size=batch_size,shuffle=False,num_workers=num_workers),\n", 95 | " DataLoader(test_datasets[2], batch_size=batch_size,shuffle=False,num_workers=num_workers)\n", 96 | "]" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "944003ef", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "labels = {0: [], 1: [], 2: []}\n", 107 | "predictions = {0: [], 1: [], 2: []}\n", 108 | "probs = {0: [], 1: [], 2: []}\n", 109 | "\n", 110 | "with torch.no_grad():\n", 111 | " for i in range(3):\n", 112 | " models[i].eval()\n", 113 | " models[i] = models[i].cuda()\n", 114 | " for (inp, target) in tqdm(test_loader[i]):\n", 115 | " labels[i] += target\n", 116 | " batch_prediction = models[i](inp.cuda())\n", 117 | " probs[i] += torch.softmax(batch_prediction,dim=1)\n", 118 | " _, batch_prediction = torch.max(batch_prediction, dim=1)\n", 119 | " predictions[i] += batch_prediction.detach().tolist()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "d5c3a0e1", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "_probs = {}\n", 130 | "_labels = {}\n", 131 | "\n", 132 | "_probs[0] = np.asarray(list(map(lambda item: item.cpu().numpy(), probs[0])))\n", 133 | "_probs[1] = np.asarray(list(map(lambda item: item.cpu().numpy(), probs[1])))\n", 134 | "_probs[2] = np.asarray(list(map(lambda item: item.cpu().numpy(), probs[2])))\n", 135 | "\n", 136 | "_labels[0] = np.asarray(list(map(lambda item: item.cpu().numpy(), labels[0])))\n", 137 | "_labels[1] = np.asarray(list(map(lambda item: item.cpu().numpy(), labels[1])))\n", 138 | "_labels[2] = np.asarray(list(map(lambda item: item.cpu().numpy(), labels[2])))" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "3d54aa55", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "w_1 = 2\n", 149 | "w_2 = .5\n", 150 | "w_3 = .5\n", 151 | "avg_probs = (w_1*_probs[0] + w_2*_probs[1] + w_3*_probs[2])/3" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "5091c546", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "avg_probs.shape" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "9d5ba5ce", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "preds = np.argmax(avg_probs,axis=1)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "63ef0c4f", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "gt = _labels[0]" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "f4fdc71d", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "sklearn.metrics.f1_score(gt, preds, average=\"macro\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "58ebebb3", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "sklearn.metrics.roc_auc_score(gt, preds)" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": 
"text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.9.13" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 5 226 | } 227 | -------------------------------------------------------------------------------- /notebooks/Testing Ensemble-RIM-ONE DL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d26271aa", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import math\n", 12 | "import torch\n", 13 | "import torchvision\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "from torch.nn import CrossEntropyLoss\n", 17 | "from tqdm import tqdm\n", 18 | "from torchvision.models import resnet18\n", 19 | "import timm\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "from skimage.io import imread\n", 22 | "import sklearn\n", 23 | "from sklearn import metrics\n", 24 | "from sklearn.metrics import f1_score\n", 25 | "from sklearn.utils import class_weight\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import torchvision.transforms as transforms\n", 29 | "from PIL import Image\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import random\n", 32 | "from early_stopping import EarlyStopping\n", 33 | "import os\n", 34 | "from airogs_dataset import Airogs\n", 35 | "import wandb\n", 36 | "import sys\n", 37 | "from torchvision.datasets import ImageFolder\n", 38 | "import sklearn.metrics\n", 39 | "import yaml\n", 40 | "torch.multiprocessing.set_sharing_strategy('file_system')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "d9cc4f68", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from skimage.exposure import equalize_adapthist\n", 51 | "from skimage.transform import warp_polar\n", 52 | "\n", 53 | "class CLAHE(torch.nn.Module):\n", 54 | " def forward(self, img):\n", 55 | " image = np.array(img, dtype=np.float64) / 255.0\n", 56 | " image = equalize_adapthist(image)\n", 57 | " image = (image*255).astype('uint8')\n", 58 | "\n", 59 | " return image\n", 60 | "\n", 61 | "class POLAR(torch.nn.Module):\n", 62 | " def polar(self,image):\n", 63 | " return warp_polar(image, radius=(max(image.shape) // 2), multichannel=True)\n", 64 | " \n", 65 | " def forward(self, image):\n", 66 | " image = np.array(image, dtype=np.float64)\n", 67 | " image = self.polar(image)\n", 68 | " return image\n", 69 | "\n", 70 | "def set_seed(s):\n", 71 | " torch.manual_seed(s)\n", 72 | " torch.cuda.manual_seed_all(s)\n", 73 | " torch.backends.cudnn.deterministic = True\n", 74 | " torch.backends.cudnn.benchmark = False\n", 75 | " np.random.seed(s)\n", 76 | " random.seed(s)\n", 77 | " os.environ['PYTHONHASHSEED'] = str(s)\n", 78 | "set_seed(0)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "98fe9b51", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "############ CONFIGS ############\n", 89 | "\n", 90 | "num_workers = 32\n", 91 | "batch_size = 8\n", 92 | "\n", 93 | "\n", 94 | "#original\n", 95 | "model_0 = timm.create_model('efficientnet_b0',num_classes=2)\n", 96 | "model_0.load_state_dict(torch.load('PATH_TO_CHEKPOINTS/rimonedl_1.pt')['state_dict'])\n", 97 | "\n", 98 | "\n", 99 | "\n", 100 | "#polar\n", 101 | "model_1 = timm.create_model('efficientnet_b0',num_classes=2)\n", 102 | 
"model_1.load_state_dict(torch.load('PATH_TO_CHEKPOINTS/rimonedl_2.pt')['state_dict'])\n", 103 | "\n", 104 | "models=[model_0,model_1]\n", 105 | "\n", 106 | "\n", 107 | "transforms = [\n", 108 | " torchvision.transforms.Compose([CLAHE(),torchvision.transforms.ToTensor(),torchvision.transforms.Resize((256,256))]),\n", 109 | " torchvision.transforms.Compose([POLAR(),CLAHE(),torchvision.transforms.ToTensor(),torchvision.transforms.Resize((256,256))]),\n", 110 | "]\n", 111 | "\n", 112 | "path = ['PATH_TO_DATASET/rim_one_dl/partitioned_by_hospital/test_set',\n", 113 | " 'PATH_TO_DATASET/rim_one_dl/partitioned_by_hospital/test_set',\n", 114 | " ]\n", 115 | "\n", 116 | "\n", 117 | "test_datasets = [\n", 118 | " ImageFolder(path[0], transform=transforms[0]),\n", 119 | " ImageFolder(path[1], transform=transforms[1]),\n", 120 | " ]\n", 121 | " \n", 122 | "\n", 123 | "test_loader = [\n", 124 | " DataLoader(test_datasets[0], batch_size=batch_size,shuffle=False,num_workers=num_workers),\n", 125 | " DataLoader(test_datasets[1], batch_size=batch_size,shuffle=False,num_workers=num_workers),\n", 126 | "]" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "944003ef", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "labels = {0: [], 1: []}\n", 137 | "predictions = {0: [], 1: []}\n", 138 | "probs = {0: [], 1: []}\n", 139 | "\n", 140 | "with torch.no_grad():\n", 141 | " for i in range(2):\n", 142 | " models[i].eval()\n", 143 | " models[i] = models[i].cuda()\n", 144 | " for (inp, target) in tqdm(test_loader[i]):\n", 145 | " labels[i] += target\n", 146 | " batch_prediction = models[i](inp.cuda())\n", 147 | " probs[i] += torch.softmax(batch_prediction,dim=1)\n", 148 | " _, batch_prediction = torch.max(batch_prediction, dim=1)\n", 149 | " predictions[i] += batch_prediction.detach().tolist()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "d5c3a0e1", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "_probs = {}\n", 160 | "_labels = {}\n", 161 | "\n", 162 | "_probs[0] = np.asarray(list(map(lambda item: item.cpu().numpy(), probs[0])))\n", 163 | "_probs[1] = np.asarray(list(map(lambda item: item.cpu().numpy(), probs[1])))\n", 164 | "\n", 165 | "_labels[0] = np.asarray(list(map(lambda item: item.cpu().numpy(), labels[0])))\n", 166 | "_labels[1] = np.asarray(list(map(lambda item: item.cpu().numpy(), labels[1])))\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "3d54aa55", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "w_1 = 1\n", 177 | "w_2 = 1\n", 178 | "avg_probs = ((w_1*_probs[0]) + (w_2*_probs[1]))/2" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "9d5ba5ce", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "preds = np.argmax(avg_probs,axis=1)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "63ef0c4f", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "gt = _labels[0]" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "f4fdc71d", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "sklearn.metrics.f1_score(gt, preds, average=\"macro\")" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "58ebebb3", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | 
"sklearn.metrics.roc_auc_score(gt, preds)" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.9.13" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 5 243 | } 244 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.nn import CrossEntropyLoss 8 | from tqdm import tqdm 9 | from torchvision.models import resnet18 10 | from torch.utils.data import DataLoader 11 | from skimage.io import imread 12 | import sklearn 13 | from sklearn import metrics 14 | from sklearn.metrics import f1_score 15 | from sklearn.utils import class_weight 16 | import pandas as pd 17 | import numpy as np 18 | import torchvision.transforms as transforms 19 | from PIL import Image 20 | 21 | from airogs_dataset import Airogs 22 | import wandb 23 | 24 | 25 | def main(): 26 | resize = 224 27 | epochs = 50 28 | lr = 0.01 29 | lr_step_period = None 30 | momentum = 0.1 31 | batch_size = 64 32 | num_workers = 16 33 | 34 | data_dir = "/l/users/20020052/data/airogs/" 35 | images_dir_name = "train" 36 | output_dir = "output" 37 | run_test = True 38 | pretrained = True 39 | model_name = "resnet18" 40 | optimizer_name = "sgd" 41 | name = f"exp1_{model_name}_{resize}R" 42 | 43 | wandb.init(name=name, project="airogs_final", entity="airogs") 44 | 45 | os.makedirs(output_dir, exist_ok=True) 46 | 47 | if torch.cuda.is_available(): 48 | device = torch.device("cuda:0") 49 | else: 50 | device = torch.device("cpu") 51 | 52 | wandb.config.update ({ 53 | "epochs": epochs, 54 | "lr": lr, 55 | "lr_step_period": lr_step_period, 56 | "momentun": momentum, 57 | "batch_size": batch_size, 58 | "num_workers": num_workers, 59 | "data_dir": data_dir, 60 | "images_dir_name": images_dir_name, 61 | "output_dir": output_dir, 62 | "run_test": run_test, 63 | "pretrained": pretrained, 64 | "model": model_name, 65 | "optimizer": optimizer_name, 66 | "device": device.type, 67 | "resize": resize 68 | }) 69 | 70 | 71 | transform = None 72 | polar_transform = None 73 | 74 | if resize != None: 75 | transform = torchvision.transforms.Compose({ 76 | transforms.ToTensor(), 77 | transforms.Resize((resize, resize)) 78 | }) 79 | else: 80 | transform = torchvision.transforms.Compose([ 81 | transforms.ToTensor(), 82 | ]) 83 | 84 | 85 | train_dataset = Airogs( 86 | path=data_dir, 87 | images_dir_name=images_dir_name, 88 | split="train", 89 | transforms=transform, 90 | polar_transforms=polar_transform 91 | ) 92 | val_dataset = Airogs( 93 | path=data_dir, 94 | images_dir_name=images_dir_name, 95 | split="val", 96 | transforms=transform 97 | ) 98 | 99 | test_dataset = Airogs( 100 | path=data_dir, 101 | images_dir_name=images_dir_name, 102 | split="test", 103 | transforms=test_transform 104 | ) 105 | 106 | train_loader = DataLoader(train_dataset, 107 | batch_size=batch_size, 108 | shuffle=True, 109 | num_workers=num_workers, 110 | ) 111 | val_loader = DataLoader(val_dataset, 112 | batch_size=batch_size, 113 | 
shuffle=True, 114 | num_workers=num_workers, 115 | ) 116 | 117 | test_loader = DataLoader(test_dataset, 118 | batch_size=batch_size, 119 | shuffle=False, 120 | num_workers=num_workers, 121 | ) 122 | csv_data = pd.read_csv(os.path.join(data_dir, "train.csv")) 123 | labels_referable = csv_data['referable'] 124 | weight_referable = class_weight.compute_class_weight(class_weight='balanced', classes = np.unique(labels_referable), y=labels_referable).astype('float32') 125 | print("Class Weights: ", weight_referable) 126 | 127 | 128 | wandb.config.update({ 129 | "train_count": len(train_dataset), 130 | "val_count": len(val_dataset), 131 | "class_weights": ", ".join(map(lambda x: str(x), weight_referable)) 132 | }) 133 | 134 | 135 | if model_name == "resnet18": 136 | model = resnet18(pretrained=pretrained) 137 | model.fc = nn.Linear(in_features=model.fc.in_features, out_features=2, bias=True) 138 | model = model.to(device) 139 | 140 | wandb.watch(model) 141 | 142 | criterion = CrossEntropyLoss(weight=torch.from_numpy(weight_referable).to(device)) 143 | 144 | if optimizer_name == "sgd": 145 | optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum) 146 | 147 | if lr_step_period is None: 148 | lr_step_period = math.inf 149 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step_period) 150 | 151 | with open(os.path.join(output_dir, "log.csv"), "a") as f: 152 | f.write("Train Dataset size: {}".format(len(train_dataset))) 153 | f.write("Validation Dataset size: {}".format(len(val_dataset))) 154 | 155 | epoch_resume = 0 156 | best_f1 = 0.0 157 | try: 158 | # Attempt to load checkpoint 159 | checkpoint = torch.load(os.path.join(output_dir, "checkpoint.pt")) 160 | model.load_state_dict(checkpoint['state_dict']) 161 | optimizer.load_state_dict(checkpoint['opt_dict']) 162 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 163 | epoch_resume = checkpoint["epoch"] + 1 164 | best_f1 = checkpoint["best_f1"] 165 | f.write("Resuming from epoch {}\n".format(epoch_resume)) 166 | f.flush() 167 | except FileNotFoundError: 168 | f.write("Starting run from scratch\n") 169 | 170 | # Train 171 | if epoch_resume < epochs: 172 | f.write("Resuming training\n") 173 | for epoch in range(epoch_resume, epochs): 174 | for split in ['Train', 'Val']: 175 | if split == "Train": 176 | model.train() 177 | else: 178 | model.eval() 179 | 180 | epoch_total_loss = 0 181 | labels = [] 182 | predictions = [] 183 | loader = train_loader if split == "Train" else val_loader 184 | for batch_num, (inp, target) in enumerate(tqdm(loader)): 185 | labels+=target 186 | optimizer.zero_grad() 187 | output = model(inp.to(device)) 188 | _, batch_prediction = torch.max(output, dim=1) 189 | predictions += batch_prediction.detach().tolist() 190 | batch_loss = criterion(output, target.to(device)) 191 | epoch_total_loss += batch_loss.item() 192 | 193 | if split == "Train": 194 | batch_loss.backward() 195 | optimizer.step() 196 | 197 | avrg_loss = epoch_total_loss / loader.dataset.__len__() 198 | accuracy = metrics.accuracy_score(labels, predictions) 199 | confusion = metrics.confusion_matrix(labels, predictions) 200 | _f1_score = f1_score(labels, predictions, average="macro") 201 | auc = sklearn.metrics.roc_auc_score(labels, predictions) 202 | print("%s Epoch %d - loss=%0.4f AUC=%0.4f F1=%0.4f Accuracy=%0.4f" % (split, epoch, avrg_loss, auc, _f1_score, accuracy)) 203 | f.write("%s Epoch {} - loss={} AUC={} F1={} Accuracy={}\n".format(split, epoch, avrg_loss, auc, _f1_score, accuracy)) 204 | f.flush() 205 | 206 | if split == "Train": 
207 | wandb.log({"epoch": epoch, "train loss": avrg_loss, "train acc": accuracy, "train f1": _f1_score, "train auc": auc}) 208 | else: 209 | wandb.log({"epoch": epoch, "val loss": avrg_loss, "val acc": accuracy, "val f1": _f1_score, "val auc": auc}) 210 | 211 | scheduler.step() 212 | 213 | # save model 214 | checkpoint = { 215 | 'epoch': epoch, 216 | 'best_f1': best_f1, 217 | 'f1': _f1_score, 218 | 'auc': auc, 219 | 'loss': avrg_loss, 220 | 'state_dict': model.state_dict(), 221 | 'opt_dict': optimizer.state_dict(), 222 | 'scheduler_dict': scheduler.state_dict() 223 | } 224 | 225 | torch.save(checkpoint, os.path.join(output_dir, "checkpoint.pt")) 226 | if _f1_score > best_f1: 227 | best_f1 = _f1_score 228 | checkpoint["best_f1"] = best_f1 229 | torch.save(checkpoint, os.path.join(output_dir, "best.pt")) 230 | 231 | #print(confusion) 232 | #f.write("%s {} - Confusion={}\n".format(split, confusion)) 233 | else: 234 | print("Skipping training\n") 235 | f.write("Skipping training\n") 236 | 237 | # Testing 238 | if run_test: 239 | checkpoint = torch.load(os.path.join(output_dir, "best.pt")) 240 | model.load_state_dict(checkpoint['state_dict']) 241 | f.write("Best F1 {} from epoch {}\n".format(checkpoint["best_f1"], checkpoint["epoch"])) 242 | f.flush() 243 | print("Best F1 {} from epoch {}\n".format(checkpoint["best_f1"], checkpoint["epoch"])) 244 | 245 | model.eval() 246 | labels = [] 247 | predictions = [] 248 | for (inp, target) in tqdm(test_loader): 249 | labels+=target 250 | batch_prediction = model(inp.to(device)) 251 | _, batch_prediction = torch.max(batch_prediction, dim=1) 252 | predictions += batch_prediction.detach().tolist() 253 | accuracy = metrics.accuracy_score(labels, predictions) 254 | f.write("Test Accuracy = {}\n".format(accuracy)) 255 | print("Test Accuracy = %0.2f" % (accuracy)) 256 | confusion = metrics.confusion_matrix(labels, predictions) 257 | f.write("Test Confusion Matrix = {}\n".format(confusion)) 258 | print(confusion) 259 | _f1_score = f1_score(labels, predictions, average="macro") 260 | f.write("Test F1 Score = {}\n".format(_f1_score)) 261 | print("Test F1 = %0.2f" % (_f1_score)) 262 | auc = sklearn.metrics.roc_auc_score(labels, predictions) 263 | f.write("Test AUC = {}\n".format(auc)) 264 | print("Test AUC = %0.2f" % (auc)) 265 | f.flush() 266 | 267 | if __name__ == "__main__": 268 | main() 269 | -------------------------------------------------------------------------------- /notebooks/Testing RIM ONE DL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ec31a4d8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torchvision\n", 12 | "import pandas as pd\n", 13 | "import glob\n", 14 | "import os\n", 15 | "from PIL import Image\n", 16 | "import cv2\n", 17 | "import numpy as np\n", 18 | "from PIL import Image\n", 19 | "import random\n", 20 | "from pathlib import Path\n", 21 | "import matplotlib.pyplot as plt" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "dc003fc1", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "def set_seed(s):\n", 32 | " torch.manual_seed(s)\n", 33 | " torch.cuda.manual_seed_all(s)\n", 34 | " torch.backends.cudnn.deterministic = True\n", 35 | " torch.backends.cudnn.benchmark = False\n", 36 | " np.random.seed(s)\n", 37 | " random.seed(s)\n", 38 | " os.environ['PYTHONHASHSEED'] = str(s)\n", 39 | "set_seed(0)" 40 | ] 41 | 
}, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "e0993a76", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from skimage.exposure import equalize_adapthist\n", 50 | "from skimage.transform import warp_polar\n", 51 | "\n", 52 | "class CLAHE(torch.nn.Module):\n", 53 | " def forward(self, img):\n", 54 | " image = np.array(img, dtype=np.float64) / 255.0\n", 55 | " image = equalize_adapthist(image)\n", 56 | " image = (image*255).astype('uint8')\n", 57 | "\n", 58 | " return image\n", 59 | "\n", 60 | "class POLAR(torch.nn.Module):\n", 61 | " def polar(self,image):\n", 62 | " return warp_polar(image, radius=(max(image.shape) // 2), multichannel=True)\n", 63 | " \n", 64 | " def forward(self, image):\n", 65 | " image = np.array(image, dtype=np.float64)\n", 66 | " image = self.polar(image)\n", 67 | " return image" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "c9f297cf", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "from torchvision.datasets import ImageFolder\n", 78 | "from torch.utils.data import DataLoader\n", 79 | "import torchvision.transforms as transforms\n", 80 | "\n", 81 | "split = \"test\"\n", 82 | "batch_size = 32\n", 83 | "num_workers = 32\n", 84 | "train_path = f\"PATH_TO_DATA/partitioned_by_hospital/training_set\" # path to dataset training set\n", 85 | "path = f\"PATH_TO_DATA/partitioned_by_hospital/{split}_set\" # path to dataset folder\n", 86 | "output_dir = \"PATH_TO_SAVE_OUTPUTS\" # path to save checkpoints\n", 87 | "\n", 88 | "train_transform = torchvision.transforms.Compose([\n", 89 | " CLAHE(),\n", 90 | " transforms.ToTensor(),\n", 91 | " transforms.Resize(256),\n", 92 | " transforms.RandomVerticalFlip(),\n", 93 | " transforms.RandomHorizontalFlip(),\n", 94 | " transforms.RandomAffine(0,scale=(1.0,1.3))\n", 95 | " ])\n", 96 | "transform = torchvision.transforms.Compose([\n", 97 | " CLAHE(),\n", 98 | " transforms.ToTensor(),\n", 99 | " transforms.Resize(256)\n", 100 | " ])\n", 101 | "train_dataset = ImageFolder(train_path, transform=train_transform)\n", 102 | "num = int(np.floor(len(train_dataset) * 1))\n", 103 | "indices = np.random.choice(len(train_dataset), num, replace=False)\n", 104 | "train_dataset = torch.utils.data.Subset(train_dataset, indices)\n", 105 | "train_loader = DataLoader(train_dataset, \n", 106 | " batch_size=batch_size, \n", 107 | " shuffle=True,\n", 108 | " num_workers=num_workers,\n", 109 | " )\n", 110 | "test_dataset = ImageFolder(path, transform=transform)\n", 111 | "test_loader = DataLoader(test_dataset, \n", 112 | " batch_size=batch_size, \n", 113 | " shuffle=True,\n", 114 | " num_workers=num_workers,\n", 115 | " )\n", 116 | "\n", 117 | "print(len(train_dataset))\n", 118 | "print(len(test_dataset))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "0725981d", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "_labels = []\n", 129 | "for j in range(len(train_dataset)):\n", 130 | " _labels.append(train_dataset[j][1])\n", 131 | "_labels = np.asarray(_labels)\n", 132 | "np.unique(_labels)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "18daca84", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "import timm\n", 143 | "\n", 144 | "model_name = \"efficientnet_b0\"\n", 145 | "pretrained = True\n", 146 | "dropout = 0.2\n", 147 | "lr = 0.0005\n", 148 | "#momentum = 0.1\n", 149 | "epochs = 20\n", 150 | "\n", 151 | "if 
torch.cuda.is_available():\n", 152 | " device = torch.device(\"cuda:0\")\n", 153 | "else:\n", 154 | " device = torch.device(\"cpu\")\n", 155 | "\n", 156 | "model = timm.create_model(model_name, pretrained=pretrained, num_classes=2, drop_rate=dropout)\n", 157 | "model = model.to(device)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "350aaa38", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "path = \"PATH_TO_CKPTS/rimonedl_1.pt\"\n", 168 | "checkpoint = torch.load(path)\n", 169 | "model.load_state_dict(checkpoint['state_dict'])\n", 170 | "print(\"Best F1 {} from epoch {}\\n\".format(checkpoint[\"best_f1\"], checkpoint[\"epoch\"]))" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "id": "dd683741", 177 | "metadata": { 178 | "scrolled": false 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "import os\n", 183 | "from sklearn.utils import class_weight\n", 184 | "from torch.nn import CrossEntropyLoss\n", 185 | "import torch.optim as optim\n", 186 | "from tqdm import tqdm\n", 187 | "import sklearn\n", 188 | "from sklearn import metrics\n", 189 | "from sklearn.metrics import f1_score\n", 190 | "\n", 191 | "\n", 192 | "if not os.path.exists(output_dir):\n", 193 | " os.makedirs(output_dir)\n", 194 | "\n", 195 | "weight_referable = class_weight.compute_class_weight(class_weight='balanced', classes = np.unique(_labels), y=_labels).astype('float32') \n", 196 | "weight_referable = np.array([weight_referable[0], weight_referable[1]])\n", 197 | "criterion = CrossEntropyLoss(weight=torch.from_numpy(weight_referable).to(device))\n", 198 | "print(weight_referable)\n", 199 | "\n", 200 | "optimizer = optim.Adam(model.parameters(),lr=lr)\n", 201 | "\n", 202 | "epoch_resume = 0\n", 203 | "best_f1 = 0.0\n", 204 | "\n", 205 | "\n", 206 | "# Train\n", 207 | "if epoch_resume < epochs:\n", 208 | " print(\"Resuming training\\n\")\n", 209 | " for epoch in range(epoch_resume, epochs):\n", 210 | " for split in ['Train']:\n", 211 | " if split == \"Train\":\n", 212 | " model.train()\n", 213 | " else:\n", 214 | " model.eval()\n", 215 | "\n", 216 | " epoch_total_loss = 0\n", 217 | " labels = []\n", 218 | " predictions = []\n", 219 | " loader = train_loader if split == \"Train\" else val_loader\n", 220 | " for batch_num, (inp, target) in enumerate(tqdm(loader)):\n", 221 | " labels+=(target)\n", 222 | " optimizer.zero_grad()\n", 223 | " output = model(inp.to(device))\n", 224 | " _, batch_prediction = torch.max(output, dim=1)\n", 225 | " predictions += batch_prediction.detach().tolist()\n", 226 | " batch_loss = criterion(output, (target).to(device))\n", 227 | " epoch_total_loss += batch_loss.item()\n", 228 | "\n", 229 | " if split == \"Train\":\n", 230 | " batch_loss.backward()\n", 231 | " optimizer.step()\n", 232 | "\n", 233 | " avrg_loss = epoch_total_loss / loader.dataset.__len__()\n", 234 | " accuracy = metrics.accuracy_score(labels, predictions)\n", 235 | " confusion = metrics.confusion_matrix(labels, predictions)\n", 236 | " _f1_score = f1_score(labels, predictions, average=\"macro\")\n", 237 | " auc = sklearn.metrics.roc_auc_score(labels, predictions)\n", 238 | " print(\"%s Epoch %d - loss=%0.4f AUC=%0.4f F1=%0.4f Accuracy=%0.4f\" % (split, epoch, avrg_loss, auc, _f1_score, accuracy))\n", 239 | "\n", 240 | "\n", 241 | " # save model\n", 242 | " checkpoint = {\n", 243 | " 'epoch': epoch,\n", 244 | " 'best_f1': best_f1,\n", 245 | " 'f1': _f1_score,\n", 246 | " 'auc': auc,\n", 247 | " 'loss': avrg_loss,\n", 248 
| " 'state_dict': model.state_dict(),\n", 249 | " 'opt_dict': optimizer.state_dict(),\n", 250 | " #'scheduler_dict': scheduler.state_dict()\n", 251 | " }\n", 252 | "\n", 253 | " torch.save(checkpoint, os.path.join(output_dir, f\"checkpoint_{epoch}.pt\"))\n", 254 | " if _f1_score > best_f1:\n", 255 | " best_f1 = _f1_score\n", 256 | " checkpoint[\"best_f1\"] = best_f1\n", 257 | " torch.save(checkpoint, os.path.join(output_dir, \"best.pt\"))\n", 258 | "else:\n", 259 | " print(\"Skipping training\\n\")" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "id": "1d143e34", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "path = f\"./{output_dir}/CKPT_TO_TEST.pt\"\n", 270 | "checkpoint = torch.load(path)\n", 271 | "model.load_state_dict(checkpoint['state_dict'])\n", 272 | "print(\"Best F1 {} from epoch {}\\n\".format(checkpoint[\"best_f1\"], checkpoint[\"epoch\"]))" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "30e62fe9", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "import torch\n", 283 | "from tqdm import tqdm\n", 284 | "import sklearn\n", 285 | "from sklearn import metrics\n", 286 | "from sklearn.metrics import f1_score\n", 287 | "\n", 288 | "model.eval()\n", 289 | "labels = []\n", 290 | "predictions = []\n", 291 | "with torch.no_grad():\n", 292 | " for (inp, target) in tqdm(test_loader):\n", 293 | " labels+=(target)\n", 294 | " batch_prediction = model(inp.to(device))\n", 295 | " _, batch_prediction = torch.max(batch_prediction, dim=1)\n", 296 | " predictions += batch_prediction.detach().tolist()\n", 297 | "accuracy = metrics.accuracy_score(labels, predictions)\n", 298 | "print(\"Test Accuracy = %0.5f\" % (accuracy))\n", 299 | "\n", 300 | "confusion = metrics.confusion_matrix(labels, predictions)\n", 301 | "print(confusion)\n", 302 | "\n", 303 | "_f1_score = f1_score(labels, predictions, average=\"macro\")\n", 304 | "print(\"Test F1 = %0.5f\" % (_f1_score))\n", 305 | "\n", 306 | "auc = sklearn.metrics.roc_auc_score(labels, predictions)\n", 307 | "print(\"Test AUC = %0.5f\" % (auc))" 308 | ] 309 | } 310 | ], 311 | "metadata": { 312 | "kernelspec": { 313 | "display_name": "Python 3", 314 | "language": "python", 315 | "name": "python3" 316 | }, 317 | "language_info": { 318 | "codemirror_mode": { 319 | "name": "ipython", 320 | "version": 3 321 | }, 322 | "file_extension": ".py", 323 | "mimetype": "text/x-python", 324 | "name": "python", 325 | "nbconvert_exporter": "python", 326 | "pygments_lexer": "ipython3", 327 | "version": "3.9.13" 328 | } 329 | }, 330 | "nbformat": 4, 331 | "nbformat_minor": 5 332 | } 333 | -------------------------------------------------------------------------------- /notebooks/central_crop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import torchvision.transforms as transforms\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "from glob import glob\n", 13 | "import os\n", 14 | "from PIL import Image\n", 15 | "import cv2\n", 16 | "from tqdm import tqdm" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/html": [ 27 | "
\n", 28 | "\n", 41 | "\n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | "
challenge_idclassreferablegradablex1y1x2y2
0TRAIN000011NRG00121.00000084.000000192.000000167.0
1TRAIN000013NRG0064.000000103.000000114.000000170.0
2TRAIN000023NRG0036.00000067.000000101.000000135.0
3TRAIN000030NRG00122.00000085.000000199.000000157.0
4TRAIN000031NRG00-1.000000-1.000000-1.000000-1.0
...........................
20283TRAIN101402NRG00-1.000000-1.000000-1.000000-1.0
20284TRAIN101406NRG00-1.000000-1.000000-1.000000-1.0
20285TRAIN101410NRG00182.00000084.000000249.000000155.0
20286TRAIN101417NRG0073.800003179.800003192.199997255.0
20287TRAIN101428NRG00-1.000000-1.000000-1.000000-1.0
\n", 179 | "

20288 rows × 8 columns

\n", 180 | "
" 181 | ], 182 | "text/plain": [ 183 | " challenge_id class referable gradable x1 y1 \\\n", 184 | "0 TRAIN000011 NRG 0 0 121.000000 84.000000 \n", 185 | "1 TRAIN000013 NRG 0 0 64.000000 103.000000 \n", 186 | "2 TRAIN000023 NRG 0 0 36.000000 67.000000 \n", 187 | "3 TRAIN000030 NRG 0 0 122.000000 85.000000 \n", 188 | "4 TRAIN000031 NRG 0 0 -1.000000 -1.000000 \n", 189 | "... ... ... ... ... ... ... \n", 190 | "20283 TRAIN101402 NRG 0 0 -1.000000 -1.000000 \n", 191 | "20284 TRAIN101406 NRG 0 0 -1.000000 -1.000000 \n", 192 | "20285 TRAIN101410 NRG 0 0 182.000000 84.000000 \n", 193 | "20286 TRAIN101417 NRG 0 0 73.800003 179.800003 \n", 194 | "20287 TRAIN101428 NRG 0 0 -1.000000 -1.000000 \n", 195 | "\n", 196 | " x2 y2 \n", 197 | "0 192.000000 167.0 \n", 198 | "1 114.000000 170.0 \n", 199 | "2 101.000000 135.0 \n", 200 | "3 199.000000 157.0 \n", 201 | "4 -1.000000 -1.0 \n", 202 | "... ... ... \n", 203 | "20283 -1.000000 -1.0 \n", 204 | "20284 -1.000000 -1.0 \n", 205 | "20285 249.000000 155.0 \n", 206 | "20286 192.199997 255.0 \n", 207 | "20287 -1.000000 -1.0 \n", 208 | "\n", 209 | "[20288 rows x 8 columns]" 210 | ] 211 | }, 212 | "execution_count": 2, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "df = pd.read_csv('PATH_TO_DATA/train_256_bbs_final.csv')\n", 219 | "df" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 3, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/html": [ 230 | "
\n", 231 | "\n", 244 | "\n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | "
challenge_idclassreferablegradablex1y1x2y2
4TRAIN000031NRG00-1.0-1.0-1.0-1.0
8TRAIN000060RG10-1.0-1.0-1.0-1.0
17TRAIN000103NRG00-1.0-1.0-1.0-1.0
19TRAIN000106RG10-1.0-1.0-1.0-1.0
20TRAIN000119NRG00-1.0-1.0-1.0-1.0
...........................
20268TRAIN101342NRG00-1.0-1.0-1.0-1.0
20276TRAIN101366NRG00-1.0-1.0-1.0-1.0
20283TRAIN101402NRG00-1.0-1.0-1.0-1.0
20284TRAIN101406NRG00-1.0-1.0-1.0-1.0
20287TRAIN101428NRG00-1.0-1.0-1.0-1.0
\n", 382 | "

3933 rows × 8 columns

\n", 383 | "
" 384 | ], 385 | "text/plain": [ 386 | " challenge_id class referable gradable x1 y1 x2 y2\n", 387 | "4 TRAIN000031 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 388 | "8 TRAIN000060 RG 1 0 -1.0 -1.0 -1.0 -1.0\n", 389 | "17 TRAIN000103 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 390 | "19 TRAIN000106 RG 1 0 -1.0 -1.0 -1.0 -1.0\n", 391 | "20 TRAIN000119 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 392 | "... ... ... ... ... ... ... ... ...\n", 393 | "20268 TRAIN101342 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 394 | "20276 TRAIN101366 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 395 | "20283 TRAIN101402 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 396 | "20284 TRAIN101406 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 397 | "20287 TRAIN101428 NRG 0 0 -1.0 -1.0 -1.0 -1.0\n", 398 | "\n", 399 | "[3933 rows x 8 columns]" 400 | ] 401 | }, 402 | "execution_count": 3, 403 | "metadata": {}, 404 | "output_type": "execute_result" 405 | } 406 | ], 407 | "source": [ 408 | "# Select images without bounding box only\n", 409 | "\n", 410 | "df_noBox = df[df['x1'] == -1]\n", 411 | "df_noBox" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 4, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stderr", 421 | "output_type": "stream", 422 | "text": [ 423 | "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3933/3933 [19:43<00:00, 3.32it/s]\n" 424 | ] 425 | } 426 | ], 427 | "source": [ 428 | "for img in tqdm(df_noBox['challenge_id']):\n", 429 | " \n", 430 | " path = glob(os.path.join('PATH_TO_DATA','AIROGS_train','*',img+'*'))[-1] \n", 431 | " new_path ='PATH_TO_SAVE_CROPPED_IMAGES'\n", 432 | "\n", 433 | " os.makedirs(new_path.replace(os.path.basename(new_path),''),exist_ok=True)\n", 434 | " \n", 435 | " im = plt.imread(path) \n", 436 | " w,h,_ = im.shape\n", 437 | " \n", 438 | " transform = transforms.Compose([\n", 439 | " transforms.ToTensor(), \n", 440 | " transforms.CenterCrop(w*0.85),\n", 441 | " transforms.Resize((512, 512))\n", 442 | " ])\n", 443 | "\n", 444 | " _img = transform(im)\n", 445 | " _img = (_img.moveaxis(0, -1).numpy() * 255).astype('uint8')\n", 446 | " \n", 447 | " assert cv2.imwrite(f'PATH_TO_SAVE_CROPPED_IMAGES/{img}.jpg', cv2.cvtColor(_img, cv2.COLOR_RGB2BGR))\n" 448 | ] 449 | } 450 | ], 451 | "metadata": { 452 | "interpreter": { 453 | "hash": "0a35368ac066db472e96c8cb9ae044ce9a84e7aa82e93512bf3e5098ca556c0f" 454 | }, 455 | "kernelspec": { 456 | "display_name": "Python 3", 457 | "language": "python", 458 | "name": "python3" 459 | }, 460 | "language_info": { 461 | "codemirror_mode": { 462 | "name": "ipython", 463 | "version": 3 464 | }, 465 | "file_extension": ".py", 466 | "mimetype": "text/x-python", 467 | "name": "python", 468 | "nbconvert_exporter": "python", 469 | "pygments_lexer": "ipython3", 470 | "version": "3.9.13" 471 | } 472 | }, 473 | "nbformat": 4, 474 | "nbformat_minor": 2 475 | } 476 | -------------------------------------------------------------------------------- /run_fold.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.nn import CrossEntropyLoss 8 | from tqdm import tqdm 9 | from torchvision.models import resnet18 10 | import timm 11 | from torch.utils.data import DataLoader 12 | from skimage.io import imread 13 | import sklearn 14 | from sklearn import metrics 15 | from sklearn.metrics import 
f1_score 16 | from sklearn.utils import class_weight 17 | import pandas as pd 18 | import numpy as np 19 | import torchvision.transforms as transforms 20 | from PIL import Image 21 | import random 22 | from early_stopping import EarlyStopping 23 | import os 24 | from airogs_dataset import Airogs 25 | import wandb 26 | import sys 27 | import yaml 28 | torch.multiprocessing.set_sharing_strategy('file_system') 29 | 30 | def main(path_to_config,folds=5): 31 | with open(path_to_config,'r') as file_: 32 | config = yaml.safe_load(file_) 33 | 34 | fold_num = config['fold_num'] 35 | exp_id = config['exp_id'] 36 | resize = config['resize'] 37 | epochs = config['epochs'] 38 | lr = float(config['lr']) 39 | lr_step_period = config['lr_step_period'] 40 | momentum = config['momentum'] 41 | batch_size = config['batch_size'] 42 | num_workers = config['num_workers'] 43 | data_dir = config['data_dir'] 44 | images_dir_name = config['images_dir_name'] 45 | run_test = config['run_test'] 46 | pretrained = config['pretrained'] 47 | model_name = config['model_name'] 48 | optimizer_name = config['optimizer_name'] 49 | scheduler = config['scheduler'] 50 | patience = config['patience'] 51 | apply_augs = config['apply_augs'] 52 | apply_clahe = config['apply_clahe'] 53 | dropout = config['dropout'] 54 | try: 55 | apply_scaling = config['apply_scaling'] 56 | except KeyError: 57 | apply_scaling = False 58 | 59 | try: 60 | polar_transform = config['polar_transform'] 61 | except KeyError: 62 | polar_transform = False 63 | 64 | 65 | output_dir = f"output/{exp_id}_{fold_num}" 66 | assert fold_num in range(5), "Fold number has to be between 0 and 4." 67 | assert not os.path.exists(output_dir), 'Fold already exists!' 68 | os.makedirs(output_dir, exist_ok=True) 69 | 70 | if torch.cuda.is_available(): 71 | device = torch.device("cuda:0") 72 | else: 73 | device = torch.device("cpu") 74 | 75 | transform = None 76 | 77 | if resize is not None and resize != 512: 78 | print(f"Using resized image size {resize}x{resize}") 79 | transform = torchvision.transforms.Compose([ 80 | transforms.ToTensor(), 81 | transforms.Resize(resize), 82 | ]) 83 | augs = [ 84 | transforms.ToTensor(), 85 | transforms.Resize(resize), 86 | ] 87 | if apply_augs: 88 | augs.append(transforms.RandomVerticalFlip()) 89 | augs.append(transforms.RandomHorizontalFlip()) 90 | augs.append(transforms.RandomRotation(10)) 91 | 92 | train_transform = torchvision.transforms.Compose(augs) 93 | 94 | else: 95 | print('Using original image size 512x512') 96 | transform = torchvision.transforms.Compose([ 97 | transforms.ToTensor() 98 | ]) 99 | augs = [ 100 | transforms.ToTensor(), 101 | ] 102 | if apply_augs: 103 | augs.append(transforms.RandomVerticalFlip()) 104 | augs.append(transforms.RandomHorizontalFlip()) 105 | augs.append(transforms.RandomRotation(10)) 106 | if apply_scaling: 107 | augs.append(transforms.RandomAffine(0,scale=(1,1.5))) 108 | train_transform = torchvision.transforms.Compose(augs) 109 | 110 | early_stopping = EarlyStopping(patience=patience, verbose=True) 111 | 112 | f1s, aucs, accuracies, losses = [], [], [], [] 113 | 114 | for k in range(folds): 115 | if k != fold_num: 116 | continue 117 | 118 | name = f"exp{exp_id}_{k}fold_{model_name}_{resize}" 119 | os.environ["WANDB_RUN_GROUP"] = f"exp{exp_id}_{model_name}_{resize}" 120 | wandb.init(project="airogs_final", entity="airogs", name=name, reinit=True) 121 | 122 | train_dataset = Airogs( 123 | path=data_dir, 124 | images_dir_name=images_dir_name, 125 | split=f"train_{k}", 126 | transforms=train_transform, 127 |
polar_transforms=polar_transform, 128 | apply_clahe = apply_clahe 129 | ) 130 | val_dataset = Airogs( 131 | path=data_dir, 132 | images_dir_name=images_dir_name, 133 | split=f"val_{k}", 134 | transforms=transform, 135 | polar_transforms=polar_transform, 136 | apply_clahe = apply_clahe 137 | ) 138 | 139 | train_loader = DataLoader(train_dataset, 140 | batch_size=batch_size, 141 | shuffle=True, 142 | num_workers=num_workers, 143 | ) 144 | val_loader = DataLoader(val_dataset, 145 | batch_size=batch_size, 146 | shuffle=True, 147 | num_workers=num_workers, 148 | ) 149 | 150 | csv_data = pd.read_csv(os.path.join(data_dir, "train.csv")) 151 | labels_referable = csv_data['referable'] 152 | weight_referable = class_weight.compute_class_weight(class_weight='balanced', classes = np.unique(labels_referable), y=labels_referable).astype('float32') 153 | print("Class Weights: ", weight_referable) 154 | 155 | wandb.config.update({ 156 | "epochs": epochs, 157 | "lr": lr, 158 | "lr_step_period": lr_step_period, 159 | "momentum": momentum, 160 | "batch_size": batch_size, 161 | "num_workers": num_workers, 162 | "data_dir": data_dir, 163 | "images_dir_name": images_dir_name, 164 | "output_dir": output_dir, 165 | "run_test": run_test, 166 | "pretrained": pretrained, 167 | "model": model_name, 168 | "optimizer": optimizer_name, 169 | "device": device.type, 170 | "resize": resize, 171 | "train_count": len(train_dataset), 172 | "val_count": len(val_dataset), 173 | "patience" : patience, 174 | "scheduler": scheduler, 175 | "class_weights": ", ".join(map(str, weight_referable)), 176 | "apply_augs": apply_augs, 177 | "apply_clahe": apply_clahe, 178 | "dropout": dropout, 179 | "fold_num": fold_num, 180 | "apply_scaling": apply_scaling, 181 | "polar_transform": polar_transform 182 | }) 183 | 184 | if scheduler is None: 185 | print("Scheduler is not set") 186 | 187 | 188 | if model_name in timm.list_models(model_name): 189 | model = timm.create_model(model_name,pretrained=pretrained,num_classes=2,drop_rate=dropout) 190 | else: 191 | assert False, f"Model {model_name} not recognized" 192 | model = model.to(device) 193 | 194 | print(f"Using Model: {model_name}") 195 | 196 | wandb.watch(model) 197 | 198 | criterion = CrossEntropyLoss(weight=torch.from_numpy(weight_referable).to(device)) 199 | 200 | if optimizer_name == "sgd": 201 | optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum) 202 | elif optimizer_name == 'Adam': 203 | optimizer = optim.Adam(model.parameters(),lr=lr) 204 | else: 205 | assert False, f"Optimizer {optimizer_name} not recognized" 206 | 207 | if lr_step_period is None: 208 | lr_step_period = math.inf 209 | 210 | if scheduler == 'step': 211 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step_period) 212 | elif scheduler == 'cosine': 213 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # T_max was previously undefined; one full cosine cycle over all epochs is assumed here 214 | 215 | with open(os.path.join(output_dir, f"log_{exp_id}_{fold_num}.csv"), "a") as f: 216 | f.write("Train Dataset size: {}\n".format(len(train_dataset))) 217 | f.write("Validation Dataset size: {}\n".format(len(val_dataset))) 218 | 219 | epoch_resume = 0 220 | best_f1 = 0.0 221 | try: 222 | # Attempt to resume from the per-fold checkpoint saved during training 223 | checkpoint = torch.load(os.path.join(output_dir, f"checkpoint_{fold_num}.pt")) 224 | model.load_state_dict(checkpoint['state_dict']) 225 | optimizer.load_state_dict(checkpoint['opt_dict']) 226 | if scheduler is not None: 227 | scheduler.load_state_dict(checkpoint['scheduler_dict']) 228 | epoch_resume = checkpoint["epoch"] + 1 229 | best_f1 = checkpoint["best_f1"] 230 |
f.write("Resuming from epoch {}\n".format(epoch_resume)) 231 | f.flush() 232 | except FileNotFoundError: 233 | f.write("Starting run from scratch\n") 234 | 235 | # Train 236 | if epoch_resume < epochs: 237 | print(f"---------\nTraining Fold {str(k)}\n") 238 | f.write("--------\nTraining Fold {}\n".format(str(k))) 239 | for epoch in range(epoch_resume, epochs): 240 | for split in ['Train', 'Val']: 241 | if split == "Train": 242 | model.train() 243 | epoch_total_loss = 0 244 | labels = [] 245 | predictions = [] 246 | loader = train_loader 247 | for batch_num, (inp, target) in enumerate(tqdm(loader)): 248 | labels+=target 249 | optimizer.zero_grad() 250 | output = model(inp.to(device)) 251 | _, batch_prediction = torch.max(output, dim=1) 252 | predictions += batch_prediction.detach().tolist() 253 | batch_loss = criterion(output, target.to(device)) 254 | epoch_total_loss += batch_loss.item() 255 | batch_loss.backward() 256 | optimizer.step() 257 | else: 258 | model.eval() 259 | with torch.no_grad(): 260 | epoch_total_loss = 0 261 | labels = [] 262 | predictions = [] 263 | loader = val_loader 264 | for batch_num, (inp, target) in enumerate(tqdm(loader)): 265 | labels+=target 266 | output = model(inp.to(device)) 267 | _, batch_prediction = torch.max(output, dim=1) 268 | predictions += batch_prediction.detach().tolist() 269 | batch_loss = criterion(output, target.to(device)) 270 | epoch_total_loss += batch_loss.item() 271 | 272 | avrg_loss = epoch_total_loss / loader.dataset.__len__() 273 | accuracy = metrics.accuracy_score(labels, predictions) 274 | confusion = metrics.confusion_matrix(labels, predictions) 275 | _f1_score = f1_score(labels, predictions, average="macro") 276 | auc = sklearn.metrics.roc_auc_score(labels, predictions) 277 | print("%s Epoch %d - loss=%0.4f AUC=%0.4f F1=%0.4f Accuracy=%0.4f" % (split, epoch, avrg_loss, auc, _f1_score, accuracy)) 278 | f.write("{} Epoch {} - loss={} AUC={} F1={} Accuracy={}\n".format(split, epoch, avrg_loss, auc, _f1_score, accuracy)) 279 | f.flush() 280 | 281 | if split == 'Train': 282 | wandb.log({"epoch": epoch, "train loss": avrg_loss, "train acc": accuracy, "train f1": _f1_score, "train auc": auc}) 283 | else: 284 | wandb.log({"epoch": epoch, "val loss": avrg_loss, "val acc": accuracy, "val f1": _f1_score, "val auc": auc}) 285 | f1s.append(_f1_score) 286 | aucs.append(auc) 287 | accuracies.append(accuracy) 288 | losses.append(avrg_loss) 289 | 290 | early_stopping(avrg_loss, model) 291 | if early_stopping.early_stop: 292 | print("Early stopping") 293 | break 294 | 295 | 296 | 297 | if k == (folds - 1) and split == "Val": 298 | f1_mean = np.mean(f1s) 299 | auc_mean = np.mean(aucs) 300 | acc_mean = np.mean(accuracies) 301 | loss_mean = np.mean(losses) 302 | 303 | print("--------\n%s Epoch %d - mean_loss=%0.4f mean_AUC=%0.4f mean_F1=%0.4f mean_Accuracy=%0.4f" % (split, epoch, loss_mean, auc_mean, f1_mean, acc_mean)) 304 | f.write("--------\n{} Epoch {} - mean_loss={} mean_AUC={} mean_F1={} mean_Accuracy={}\n".format(split, epoch, loss_mean, auc_mean, f1_mean, acc_mean)) 305 | wandb.log({"epoch": epoch, "val mean loss": loss_mean, "val mean acc": acc_mean, "val mean f1": f1_mean, "val mean auc": auc_mean}) 306 | f.flush() 307 | 308 | if scheduler != None: 309 | scheduler.step() 310 | 311 | # save model 312 | checkpoint = { 313 | 'epoch': epoch, 314 | 'best_f1': best_f1, 315 | 'f1': _f1_score, 316 | 'auc': auc, 317 | 'loss': avrg_loss, 318 | 'state_dict': model.state_dict(), 319 | 'opt_dict': optimizer.state_dict() 320 | } 321 | 322 | if scheduler != 
None: 323 | checkpoint["scheduler_dict"] = scheduler.state_dict() 324 | 325 | torch.save(checkpoint, os.path.join(output_dir, f"checkpoint_{fold_num}.pt")) 326 | if _f1_score > best_f1: 327 | best_f1 = _f1_score 328 | checkpoint["best_f1"] = best_f1 329 | torch.save(checkpoint, os.path.join(output_dir, f"best_{fold_num}.pt")) 330 | 331 | else: 332 | print("Skipping training\n") 333 | f.write("Skipping training\n") 334 | f.flush() 335 | 336 | # Testing 337 | if run_test: 338 | with open(os.path.join(output_dir, f"log_{exp_id}_{fold_num}.csv"), "a") as f: 339 | # Best F1 Score 340 | # max_f1 = max(f1s) 341 | # max_index = f1s.index(max_f1) 342 | 343 | checkpoint = torch.load(os.path.join(output_dir, f"best_{fold_num}.pt")) 344 | model.load_state_dict(checkpoint['state_dict']) 345 | f.write("---------\nBest F1 {} for fold {} epoch {}\n".format(checkpoint["best_f1"], str(fold_num), checkpoint["epoch"])) 346 | f.flush() 347 | print("-----------\nBest F1 {} for fold {} epoch {}\n".format(checkpoint["best_f1"], str(fold_num), checkpoint["epoch"])) 348 | 349 | test_dataset = Airogs( 350 | path=data_dir, 351 | images_dir_name=images_dir_name, 352 | split="test", 353 | transforms=transform, 354 | polar_transforms=polar_transform, 355 | apply_clahe = apply_clahe 356 | ) 357 | test_loader = DataLoader(test_dataset, 358 | batch_size=batch_size, 359 | shuffle=True, 360 | num_workers=num_workers, 361 | ) 362 | 363 | 364 | model.eval() 365 | labels = [] 366 | predictions = [] 367 | with torch.no_grad(): 368 | for (inp, target) in tqdm(test_loader): 369 | labels+=target 370 | batch_prediction = model(inp.to(device)) 371 | _, batch_prediction = torch.max(batch_prediction, dim=1) 372 | predictions += batch_prediction.detach().tolist() 373 | accuracy = metrics.accuracy_score(labels, predictions) 374 | f.write("Test Accuracy = {}\n".format(accuracy)) 375 | print("Test Accuracy = %0.2f" % (accuracy)) 376 | _f1_score = f1_score(labels, predictions, average="macro") 377 | f.write("Test F1 Score = {}\n".format(_f1_score)) 378 | print("Test F1 = %0.2f" % (_f1_score)) 379 | auc = sklearn.metrics.roc_auc_score(labels, predictions) 380 | f.write("Test AUC = {}\n".format(auc)) 381 | print("Test AUC = %0.2f" % (auc)) 382 | confusion = metrics.confusion_matrix(labels, predictions) 383 | f.write("Confusion Matrix = {}\n".format(confusion)) 384 | print(confusion) 385 | f.flush() 386 | 387 | wandb.log({"test acc": accuracy, "test f1": _f1_score, "test auc": auc}) 388 | 389 | 390 | 391 | if __name__ == "__main__": 392 | path_to_config = sys.argv[1] 393 | main(path_to_config) 394 | --------------------------------------------------------------------------------