├── README.md
├── app.py
└── DR_classification.ipynb


/README.md:
--------------------------------------------------------------------------------
**BVAC (Blood Vessel Attention - Conditional) Generative Adversarial Network (GAN) for Diabetic Retinopathy (DR) diagnosis**
![WhatsApp Image 2025-05-05 at 5 28 04 PM](https://github.com/user-attachments/assets/83f676f2-4166-45d1-a3cf-327a8a0f6cfb)
![WhatsApp Image 2025-05-05 at 5 28 07 PM (2)](https://github.com/user-attachments/assets/2e45dd15-12f0-48be-9a1c-d9bb885ff9ee)

THE APP WAS TESTED AND DEPLOYED IN A HOSPITAL

![APP Screenshot](https://github.com/user-attachments/assets/28ee70b0-40fb-47c7-8c6f-d5ed49bcf8b5)

Screenshot of the user-friendly app

**File Descriptions and Order**

**Fundus_OCTA_cGAN_v1.ipynb** - Python code for the BVAC GAN.

**DR_classification.ipynb** - Python code to diagnose DR using the fundus images and DR labels from dataset [3], with the BVAC GAN synthesized OCTA pairs for those fundus images (the output of Fundus_OCTA_cGAN_v1.ipynb) as supplementary input.

**app.py** - Flask-based application with a user-friendly interface for doctors to upload a fundus image and obtain its OCTA equivalent.

The dataset is available in [1], and a similar research work that generates synthetic OCTA images using dataset [1] is available in [2].

**ORDER OF EXECUTION**
FIRST RUN **Fundus_OCTA_cGAN_v1.ipynb** to train the BVAC GAN to synthesize paired OCTA images from the fundus images using dataset [1] (also used in [2]).
NEXT RUN **DR_classification.ipynb** with the dataset in [3]. This code generates OCTA pairs for the fundus images in [3] and classifies DR using the DR labels available in [3].
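
As a usage illustration for **app.py**, the sketch below shows one way a client could send a fundus image to the app and save the returned OCTA image. The `/generate` route and the `image` form field come from app.py; the base URL `http://127.0.0.1:5000` assumes Flask's default local development address, and the helper names `build_multipart` and `request_octa` are hypothetical. It uses only the Python standard library and is a minimal sketch, not part of the repository.

```python
import urllib.request
import uuid


def build_multipart(field_name, filename, payload, content_type="image/png"):
    """Encode one file as a multipart/form-data body.

    Returns (body_bytes, content_type_header).
    """
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + payload + tail, f"multipart/form-data; boundary={boundary}"


def request_octa(fundus_path, out_path, base_url="http://127.0.0.1:5000"):
    """POST a fundus image to /generate and save the PNG that comes back.

    base_url is an assumption (Flask's default local address); adjust as needed.
    """
    with open(fundus_path, "rb") as fh:
        body, header = build_multipart("image", "fundus.png", fh.read())
    req = urllib.request.Request(
        f"{base_url}/generate",
        data=body,
        headers={"Content-Type": header},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp, open(out_path, "wb") as out:
        out.write(resp.read())
```

With the app running locally, `request_octa("fundus.png", "octa.png")` would write the synthesized image to `octa.png`. A failed request raises `urllib.error.HTTPError`, whose body carries the JSON error message returned by the app.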

**Proposed Methodology:**
Fundus imaging carries no depth information about the retinal vasculature, and OCTA gives no direct visualization of the optic disc or macula. The two modalities therefore hold complementary features for DR diagnosis. We propose a first-of-its-kind model that combines the two modalities, (i) original fundus images and (ii) GAN-synthesized OCT-A images, to diagnose DR. We feed the fundus images and the BVAC-synthesized OCT-A images to two ResNet-50 CNNs with a merged fully connected layer for DR diagnosis.

The following steps are executed to generate a synthetic OCT-A image pair from a fundus image and diagnose DR.
1. Pre-processing steps
2. Pipeline to create the train and test datasets of fundus and OCT-A images
3. Generator: improved U-Net with an entropy-based Squeeze & Excitation (SE) block
   Encoder stage of U-Net: customized threshold-based Squeeze and Excitation (SE) block
4. Loss: generator loss, discriminator loss
5. Training the model
6. Develop a web-based application (app) to predict the OCT-A pair from a fundus image

The figure below shows a fundus image sample, the OCTA synthesized by a conventional GAN, the OCTA synthesized by our BVAC GAN, and the ground-truth OCT-A.
![image](https://github.com/user-attachments/assets/946a4a99-937b-449c-ace7-4c8d172f2cfa)

![WhatsApp Image 2025-05-09 at 2 26 27 PM (1)](https://github.com/user-attachments/assets/45058898-d3f1-49ac-b7c5-3b5499269d8b)

DR Diagnosis

![WhatsApp Image 2025-05-09 at 2 26 26 PM](https://github.com/user-attachments/assets/35c36d6f-771e-4003-b908-4aeff1e5d1f1)

Entropy based Squeeze & Excitation

![image](https://github.com/user-attachments/assets/38140e5e-d1c8-47ec-a8ad-4811438eaba8)

OUTCOMES OF OUR BVAC GAN

[ IN FIG: FUNDUS - OCTA predicted by cGAN - BVAC GAN - Ground truth ]

References:
1. https://zenodo.org/records/6476639
2. 
Coronado I, Pachade S, Trucco E, Abdelkhaleq R, Yan J, Salazar-Marioni S, Jagolino-Cole A, Bahrainian M, Channa R, Sheth SA, Giancardo L. Synthetic OCT-A blood vessel maps using fundus images and generative adversarial networks. Sci Rep 2023;13:15325. https://doi.org/10.1038/s41598-023-42062-9.
3. https://www.kaggle.com/datasets/benjaminwarner/resized-2015-2019-blindness-detection-images

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import os
from flask import Flask, request, jsonify, send_file
from PIL import Image
import numpy as np
import tensorflow as tf
import io

# --- Custom layers required to load the trained generator ---
from tensorflow.keras import layers

class InstanceNormalization(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale",
                                     shape=(input_shape[-1],),
                                     initializer="ones",
                                     trainable=True)
        self.offset = self.add_weight(name="offset",
                                      shape=(input_shape[-1],),
                                      initializer="zeros",
                                      trainable=True)
        super(InstanceNormalization, self).build(input_shape)

    def call(self, inputs):
        # Normalize each sample over its spatial axes, per channel
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.scale * normalized + self.offset

    def get_config(self):
        config = super(InstanceNormalization, self).get_config()
        config.update({"epsilon": self.epsilon})
        return config

class ThresholdSEBlock(tf.keras.layers.Layer):
    def __init__(self, channels, reduction=16, threshold=0.5, **kwargs):
        super(ThresholdSEBlock, self).__init__(**kwargs)
        self.channels = channels
        self.reduction = 
reduction
        self.threshold = threshold
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.fc1 = layers.Dense(channels // reduction, activation='relu')
        self.fc2 = layers.Dense(channels, activation='sigmoid')

    def call(self, inputs):
        x = self.global_avg_pool(inputs)
        x = self.fc1(x)
        x = self.fc2(x)
        # Zero out channel weights that do not exceed the threshold
        x = tf.where(x > self.threshold, x, tf.zeros_like(x))
        input_shape = tf.shape(inputs)
        reshape_shape = [input_shape[0], 1, 1, self.channels]
        x = tf.reshape(x, reshape_shape)
        return inputs * x

    def get_config(self):
        config = super(ThresholdSEBlock, self).get_config()
        config.update({
            "channels": self.channels,
            "reduction": self.reduction,
            "threshold": self.threshold
        })
        return config

# --- Load model ---
custom_objects = {
    'InstanceNormalization': InstanceNormalization,
    'ThresholdSEBlock': ThresholdSEBlock
}
# Raw string so the Windows backslashes are not read as escape sequences.
# Adjust this path to where generator_g.h5 lives on your machine.
model_path = r'C:\Personal\Research\IEEE-IES-GenAI-Hackathon2025\coding\Fundus_to_OCTA\generator_g.h5'
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at: {model_path}")

loaded_generator_g = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

# --- Flask App ---
app = Flask(__name__)

IMG_HEIGHT = 256
IMG_WIDTH = 256
CHANNELS = 3

def preprocess_image(image_bytes):
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    image = image.resize((IMG_WIDTH, IMG_HEIGHT))
    image_array = np.array(image).astype(np.float32) / 127.5 - 1.0  # normalize to [-1, 1]
    return np.expand_dims(image_array, axis=0)

def postprocess_image(output_tensor):
    output_array = (output_tensor[0] * 0.5 + 0.5) * 255.0  # scale back to [0, 255]
    output_array = np.clip(output_array, 0, 255).astype(np.uint8)
    output_image = Image.fromarray(output_array)
    return output_image

@app.route('/generate', methods=['POST'])
def generate():
    if 'image' not in request.files:
        return jsonify({'error': 'No image file provided.'}), 400

    image_file = request.files['image']
    try:
        input_tensor = preprocess_image(image_file.read())
        prediction = loaded_generator_g.predict(input_tensor)
        output_image = postprocess_image(prediction)

        # Return the generated image as an in-memory PNG
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/')
def index():
    return '''
AI Image Generator
AI-Powered Image Generator
No file chosen
Generated image will appear below:
Uploaded Image
Generated Image
'''


if __name__ == '__main__':
    # Debug mode exposes the interactive Werkzeug debugger; keep it off outside local development.
    app.run(debug=False)

--------------------------------------------------------------------------------
/DR_classification.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Classify Diabetic Retinopathy (DR) into five classes:\n",
        "\n",
        "0 - No DR,\n",
        "1 - Mild,\n",
        "2 - Moderate,\n",
        "3 - Severe,\n",
        "4 - Proliferative DR"
      ],
      "metadata": {
        "id": "TtIHiwampJqm"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Filter the labels .csv file so that it keeps only the rows whose image file is present in the image folder; all other rows are dropped."
      ],
      "metadata": {
        "id": "Xx2ba3AZrpd3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import pandas as pd\n",
        "\n",
        "# Configuration\n",
        "csv_file_path = '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/trainLabels19.csv'  # Path to your CSV file\n",
        "image_folder = '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19-samples'  # Path to your image dataset\n",
        "output_csv_path = 'updated_trainLabels19.csv'  # Path for the updated CSV\n",
        "\n",
        "# Step 1: Load the CSV file\n",
        "df = pd.read_csv(csv_file_path)\n",
        "\n",
        "# Step 2: Get list of actual image files in the dataset\n",
        "# Assuming images are in .png or .jpg format - adjust extensions if needed\n",
        "image_extensions = ('.png', '.jpg', '.jpeg')\n",
        "image_files = set()\n",
        "for file in os.listdir(image_folder):\n",
        "    if file.lower().endswith(image_extensions):\n",
        "        # Remove extension for matching with CSV\n",
        "        base_name = os.path.splitext(file)[0]\n",
        "        image_files.add(base_name)\n",
        "\n",
        "# Step 3: Filter the DataFrame to keep only rows with existing images\n",
        "filtered_df = df[df['id_code'].isin(image_files)]\n",
        "\n",
        "# Step 4: Save the filtered DataFrame to a new CSV\n",
        "filtered_df.to_csv(output_csv_path, index=False)\n",
        "\n",
        "print(f\"Original rows: {len(df)}, Filtered rows: {len(filtered_df)}\")\n",
        "print(f\"Updated CSV saved to: {output_csv_path}\")"
      ],
      "metadata": {
        "id": "fsSE9JpdXQi2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f18395a1-aa71-46ed-a18f-8120bd61c367"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Original rows: 3662, Filtered rows: 728\n",
            "Updated CSV saved to: updated_trainLabels19.csv\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "The following script contains:\n",
        "(i) GAN-based OCT-A generation\n",
        "\n",
        "(ii) Training the ResNet50Fusion model\n",
        "\n",
        "(iii) Saving the model\n",
        "\n",
        "(iv) Running inference on test data\n",
        "\n",
        "(v) Saving predictions to CSV"
      ],
      "metadata": {
        "id": "pYOzsceSjYs4"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import pandas as pd\n",
        "from PIL import Image\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from torchvision import models, transforms\n",
        "import tensorflow as tf\n",
        "\n",
        "# ----------------------------\n",
        "# Load BVAC GAN generator\n",
        "# ----------------------------\n",
        "class InstanceNormalization(tf.keras.layers.Layer):\n",
        "    def __init__(self, epsilon=1e-5, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        self.epsilon = epsilon\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.scale = self.add_weight(name=\"scale\", shape=(input_shape[-1],), initializer=\"ones\", trainable=True)\n",
        "        self.offset = self.add_weight(name=\"offset\", shape=(input_shape[-1],), initializer=\"zeros\", trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        mean, var = tf.nn.moments(inputs, [1, 2], keepdims=True)\n",
        "        normalized = (inputs - mean) / tf.sqrt(var + self.epsilon)\n",
        "        return self.scale * normalized + self.offset\n",
        "\n",
        "class ThresholdSEBlock(tf.keras.layers.Layer):\n",
        "    def __init__(self, channels, reduction=16, threshold=0.5, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        self.channels = channels\n",
        "        self.reduction = reduction\n",
        "        self.threshold = threshold\n",
        "        
self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D()\n",
        "        self.fc1 = tf.keras.layers.Dense(channels // reduction, activation='relu')\n",
        "        self.fc2 = tf.keras.layers.Dense(channels, activation='sigmoid')\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.global_avg_pool(inputs)\n",
        "        x = self.fc1(x)\n",
        "        x = self.fc2(x)\n",
        "        x = tf.where(x > self.threshold, x, tf.zeros_like(x))\n",
        "        x = tf.reshape(x, [-1, 1, 1, self.channels])\n",
        "        return inputs * x\n",
        "\n",
        "generator = tf.keras.models.load_model(\n",
        "    '/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/generator_g.h5',\n",
        "    custom_objects={'InstanceNormalization': InstanceNormalization, 'ThresholdSEBlock': ThresholdSEBlock}\n",
        ")\n",
        "\n",
        "# ----------------------------\n",
        "# Image Pre/Post-Processing\n",
        "# ----------------------------\n",
        "def preprocess_tf(img_path):\n",
        "    img = Image.open(img_path).convert(\"RGB\").resize((256, 256))\n",
        "    arr = np.array(img).astype(np.float32) / 127.5 - 1.0\n",
        "    return np.expand_dims(arr, 0)\n",
        "\n",
        "def postprocess_tf(tensor):\n",
        "    tensor = (tensor[0] * 0.5 + 0.5) * 255.0\n",
        "    return Image.fromarray(np.clip(tensor, 0, 255).astype(np.uint8))\n",
        "\n",
        "def generate_synthetic_oct(images_dir, synthetic_dir):\n",
        "    os.makedirs(synthetic_dir, exist_ok=True)\n",
        "    for img_name in os.listdir(images_dir):\n",
        "        if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
        "            continue\n",
        "        img_path = os.path.join(images_dir, img_name)\n",
        "        input_tensor = preprocess_tf(img_path)\n",
        "        output_tensor = generator.predict(input_tensor)\n",
        "        output_img = postprocess_tf(output_tensor)\n",
        "        output_img.save(os.path.join(synthetic_dir, img_name))\n",
        "\n",
        "# ----------------------------\n",
        "# Dataset Classes\n",
        "# ----------------------------\n",
        "class FusedDataset(Dataset):\n",
        "    def __init__(self, csv_file, fundus_dir, synthetic_dir):\n",
        "        self.data = pd.read_csv(csv_file)\n",
        "        self.fundus_dir = fundus_dir\n",
        "        self.synthetic_dir = synthetic_dir\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.Resize((224, 224)),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "        ])\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        row = self.data.iloc[idx]\n",
        "        img_id = row['id_code']\n",
        "        label = int(row['diagnosis'])\n",
        "        extensions = ['.jpg', '.jpeg', '.png']\n",
        "        for ext in extensions:\n",
        "            fp = os.path.join(self.fundus_dir, img_id + ext)\n",
        "            sp = os.path.join(self.synthetic_dir, img_id + ext)\n",
        "            if os.path.exists(fp) and os.path.exists(sp):\n",
        "                fundus = Image.open(fp).convert(\"RGB\")\n",
        "                synthetic = Image.open(sp).convert(\"RGB\")\n",
        "                fundus = self.transform(fundus)\n",
        "                synthetic = self.transform(synthetic)\n",
        "                fused = torch.cat((fundus, synthetic), dim=0)\n",
        "                return fused, label\n",
        "        raise FileNotFoundError(f\"{img_id} not found.\")\n",
        "\n",
        "class TestDataset(Dataset):\n",
        "    def __init__(self, csv_file, fundus_dir, synthetic_dir):\n",
        "        self.data = pd.read_csv(csv_file)\n",
        "        self.fundus_dir = fundus_dir\n",
        "        self.synthetic_dir = synthetic_dir\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.Resize((224, 224)),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "        ])\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        row = 
self.data.iloc[idx]\n",
        "        img_id = row['id_code']\n",
        "        extensions = ['.jpg', '.jpeg', '.png']\n",
        "        for ext in extensions:\n",
        "            fp = os.path.join(self.fundus_dir, img_id + ext)\n",
        "            sp = os.path.join(self.synthetic_dir, img_id + ext)\n",
        "            if os.path.exists(fp) and os.path.exists(sp):\n",
        "                fundus = Image.open(fp).convert(\"RGB\")\n",
        "                synthetic = Image.open(sp).convert(\"RGB\")\n",
        "                fundus = self.transform(fundus)\n",
        "                synthetic = self.transform(synthetic)\n",
        "                fused = torch.cat((fundus, synthetic), dim=0)\n",
        "                return fused, img_id\n",
        "        raise FileNotFoundError(f\"{img_id} not found.\")\n",
        "\n",
        "# ----------------------------\n",
        "# Model Definition\n",
        "# ----------------------------\n",
        "class ResNet50Fusion(nn.Module):\n",
        "    def __init__(self, num_classes=5):\n",
        "        super().__init__()\n",
        "        base = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)\n",
        "        self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
        "        self.conv1.weight.data[:, :3] = base.conv1.weight.data\n",
        "        self.conv1.weight.data[:, 3:] = base.conv1.weight.data.clone()\n",
        "        base.conv1 = self.conv1\n",
        "        base.fc = nn.Linear(2048, num_classes)\n",
        "        self.model = base\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n",
        "\n",
        "# ----------------------------\n",
        "# Training Function\n",
        "# ----------------------------\n",
        "def train_model():\n",
        "    fundus_dir = \"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19-samples\"\n",
        "    synthetic_dir = \"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/synthetic_octa\"\n",
        "    csv_file = \"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/updated_trainLabels19.csv\"\n",
        "\n",
        "    generate_synthetic_oct(fundus_dir, synthetic_dir)\n",
        "\n",
        "    dataset = FusedDataset(csv_file, fundus_dir, synthetic_dir)\n",
        "    loader = DataLoader(dataset, batch_size=16, shuffle=True)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model = ResNet50Fusion(num_classes=5).to(device)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    for epoch in range(10):\n",
        "        model.train()\n",
        "        total_loss, correct = 0, 0\n",
        "        for inputs, labels in loader:\n",
        "            inputs, labels = inputs.to(device), labels.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model(inputs)\n",
        "            loss = criterion(outputs, labels)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            total_loss += loss.item()\n",
        "            correct += (outputs.argmax(1) == labels).sum().item()\n",
        "        print(f\"Epoch {epoch+1}, Loss: {total_loss:.4f}, Acc: {correct/len(dataset):.4f}\")\n",
        "\n",
        "    torch.save(model.state_dict(), \"resnet50_fused_model.pth\")\n",
        "    print(\"Model saved.\")\n",
        "\n",
        "# ----------------------------\n",
        "# Inference Function\n",
        "# ----------------------------\n",
        "def predict(model, dataloader, device):\n",
        "    model.eval()\n",
        "    results = []\n",
        "    with torch.no_grad():\n",
        "        for inputs, img_ids in dataloader:\n",
        "            inputs = inputs.to(device)\n",
        "            outputs = model(inputs)\n",
        "            preds = torch.argmax(outputs, dim=1).cpu().numpy()\n",
        "            for img_id, pred in zip(img_ids, preds):\n",
        "                results.append((img_id, pred))\n",
        "    return results\n",
        "\n",
        "def run_inference():\n",
        "    test_csv = \"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/trainLabels19_test.csv\"\n",
        "    fundus_dir = \"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/ResizedTrain19_test\"\n",
        "    synthetic_dir = 
\"/content/drive/MyDrive/Colab Notebooks/Fundus_OCTA_GAN/synthetic_test_octa\"\n",
        "\n",
        "    generate_synthetic_oct(fundus_dir, synthetic_dir)\n",
        "\n",
        "    test_dataset = TestDataset(test_csv, fundus_dir, synthetic_dir)\n",
        "    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model = ResNet50Fusion(num_classes=5).to(device)\n",
        "    # map_location lets weights saved on a GPU load on a CPU-only machine\n",
        "    model.load_state_dict(torch.load(\"resnet50_fused_model.pth\", map_location=device))\n",
        "\n",
        "    predictions = predict(model, test_loader, device)\n",
        "    pd.DataFrame(predictions, columns=[\"id_code\", \"predicted_label\"]).to_csv(\"DR_predictions.csv\", index=False)\n",
        "    print(\"Predictions saved to DR_predictions.csv\")\n",
        "\n",
        "# ----------------------------\n",
        "# Run Training and Inference\n",
        "# ----------------------------\n",
        "if __name__ == \"__main__\":\n",
        "    train_model()\n",
        "    run_inference()\n"
      ],
      "metadata": {
        "id": "lRy_xu-Sj3Mw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
--------------------------------------------------------------------------------
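
To make the gating in `ThresholdSEBlock` (defined in both app.py and the notebook above) easier to follow, here is a dependency-free sketch of its three steps on a toy feature map: squeeze by global average pooling, threshold the channel weights, and rescale the channels. The two Dense layers are not reproduced; the `weights` list is a hypothetical stand-in for their sigmoid output, and the helper names `squeeze`, `threshold_gate` and `apply_gate` are illustrative only, not taken from the repository.

```python
def squeeze(feature_map):
    """Global average pooling: one mean per channel of an H x W x C nested list."""
    h, w, c = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    return [
        sum(feature_map[i][j][k] for i in range(h) for j in range(w)) / (h * w)
        for k in range(c)
    ]


def threshold_gate(weights, threshold=0.5):
    """Keep a channel weight only if it is strictly above the threshold, else zero it.

    Mirrors tf.where(x > threshold, x, 0) from ThresholdSEBlock.
    """
    return [w if w > threshold else 0.0 for w in weights]


def apply_gate(feature_map, gate):
    """Rescale every pixel's channels by the per-channel gate values."""
    return [
        [[v * g for v, g in zip(pixel, gate)] for pixel in row]
        for row in feature_map
    ]


# Toy 1 x 2 feature map with 3 channels.
fmap = [[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]]

# In the real block these come from fc2(fc1(squeeze(fmap))); here they are
# hypothetical sigmoid outputs chosen to exercise the threshold.
weights = [0.9, 0.5, 0.2]

gate = threshold_gate(weights)   # channel 0 survives, channels 1 and 2 are zeroed
gated = apply_gate(fmap, gate)
```

Because the comparison is strict, a weight exactly equal to the threshold (0.5 here) is zeroed along with the weaker channel, so only the channel the block rates most informative passes through.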