├── GettingStarted.md
├── README.md
├── Single_test_inference.ipynb
├── blindness.py
├── images
│   ├── gui1.JPG
│   ├── gui2.JPG
│   ├── gui3.JPG
│   ├── mat.png
│   ├── sms.JPG
│   ├── vis.gif
│   └── visual1.JPG
├── inference.ipynb
├── model.py
├── requirements.txt
├── sampleimages
│   ├── eye1.png
│   ├── eye10.jpg
│   ├── eye11.png
│   ├── eye12.jpg
│   ├── eye13.jpg
│   ├── eye14.jpg
│   ├── eye15.jpg
│   ├── eye16.png
│   ├── eye17.png
│   ├── eye18.png
│   ├── eye19.png
│   ├── eye2.png
│   ├── eye20.png
│   ├── eye3.png
│   ├── eye4.jpg
│   ├── eye5.jpg
│   ├── eye6.jpg
│   ├── eye7.jpg
│   ├── eye8.jpg
│   └── eye9.jpg
├── send_sms.py
└── training.ipynb

/GettingStarted.md:
--------------------------------------------------------------------------------
1 | ## Installation :
2 | Tip : Make sure to install [Numpy](https://pypi.org/project/numpy/), [Pandas](https://pypi.org/project/pandas/) and [Matplotlib](https://pypi.org/project/matplotlib/) first, then proceed to the packages below.
3 | * [Torch package](https://pytorch.org/get-started/locally/)
4 | * [Tkinter](https://tkdocs.com/tutorial/install.html)
5 | * [HeidiSQL](https://www.heidisql.com/download.php)
6 | Grab a cup of coffee, as these installs will take some time!
7 | * [Click here](https://support.hypernode.com/knowledgebase/use-heidisql/#Download_HeidiSQL) to start the server in HeidiSQL and configure it by setting a username and password.
8 | ## Get, set and go :
9 | * Download the complete project files using the following command from Git Bash / cmd (terminal):
10 | ```
11 | git clone https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch
12 | 
13 | ```
14 | [Note: you mainly need these four things to get started:]
15 | > [blindness.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/blindness.py)
16 | > [model.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/model.py)
17 | > [classifier.pt](#)
18 | > [send_sms.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/send_sms.py)
19 | 
20 | * Create a new database and table accordingly (a minimal table sketch is shown after these steps).
21 | * Then, go to the 'blindness.py' file and change the configuration settings according to your database:
22 | ```
23 | connection = sk.connect(
24 |     host="localhost",
25 |     user="root",
26 |     password="********",
27 |     database="********"
28 | )
29 | ```
30 | * Your DB server should now be connected.
31 | * You also need the 'classifier.pt' file, which contains the model's state dictionary required when the model is loaded.
32 | [Download it here](https://www.kaggle.com/souravs17031999/blindness-detection-pretrained-weights-pytorch), put the file in the same directory and then modify the path accordingly in the 'model.py' file:
33 | ```
34 | model = load_model('../Desktop/classifier.pt')
35 | 
36 | ```
37 | * Finally, run the 'blindness.py' file and the GUI should start (it is recommended to run it from your terminal and to keep all project files in the same directory).
38 | * Upload an image and get your predictions.
39 | 
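The GUI code in 'blindness.py' expects a table named THEGREAT whose first two columns hold the username and password, plus a PREDICT column for the latest result. A minimal sketch for creating it is below; the table and column names come from the queries in 'blindness.py', while the database name, column types and sizes are my assumptions, so adjust them to your setup.

```python
# Hypothetical setup helper: creates the table that blindness.py queries.
# Table/column names (THEGREAT, USERNAME, PASSWORD, PREDICT) mirror blindness.py;
# the database name, column types and sizes are assumptions.
import mysql.connector as sk

connection = sk.connect(host="localhost", user="root", password="********")
cursor = connection.cursor()
cursor.execute("CREATE DATABASE IF NOT EXISTS batch_db_new")
cursor.execute("USE batch_db_new")
cursor.execute(
    "CREATE TABLE IF NOT EXISTS THEGREAT ("
    "  USERNAME VARCHAR(100) PRIMARY KEY,"
    "  PASSWORD VARCHAR(100),"
    "  PREDICT  VARCHAR(50)"
    ")"
)
connection.commit()
```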
40 | ## Optional :
41 | * If you want to receive an SMS on your mobile with your predictions, create an account on [Twilio](http://twilio.com/) by verifying your number.
42 | * Next, get your credentials from the Twilio dashboard and replace the placeholder credentials in 'send_sms.py', which makes the API request.
43 | [Note: currently 'send_sms.py' is commented out, so uncomment everything before using it.]
44 | * Then, uncomment the following lines in 'blindness.py':
45 | ```#from send_sms import *```
46 | ```#send(value, classes)```
47 | * Your messaging service should now work, and you will see the Message_Id printed on your terminal.
48 | 
49 | 
50 | [Note: you can use the sample images in the [sampleimages](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/tree/master/sampleimages) folder, which are taken from the original test dataset, to test the system.]
51 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project Name : Retinal Blindness (Diabetic Retinopathy) detection
2 | 
3 | # Problem statement :
4 | Diabetic Retinopathy is a disease with an increasing prevalence and the main cause of blindness among the working-age population. The risk of severe vision loss can be significantly reduced by timely diagnosis and treatment. Systematic screening for DR has been identified as a cost-effective way to save health service resources. Automatic retinal image analysis is emerging as an important screening tool for early DR detection, which can reduce the workload associated with manual grading as well as save diagnosis costs and time. Many research efforts in recent years have been devoted to developing automated tools to help in the detection and evaluation of DR lesions.
5 | We are interested in automating this prediction using deep learning models.
6 | 
7 | # Dataset : [APTOS Kaggle Blindness dataset](https://www.kaggle.com/c/aptos2019-blindness-detection)
8 | 
9 | # Solution :
10 | I am proposing a deep learning classification approach using the pretrained CNN model [resnet152](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) to classify severity levels of DR ranging from 0 (No DR) to 4 (Proliferative DR).
11 | This is a collaborative project by a team of three, where my main work was developing, training and testing various CNN models, along with some secondary tasks.
12 | Deep learning looks promising because many kinds of image classification tasks have already been solved by various CNNs, so we can rely on pretrained models or modify some layers if we wish to :)
13 | A GUI-based system has been built using Tkinter, and HeidiSQL is used to maintain and store a list of predictions with the patient id and name (which is very risky, a point we will get to later).
14 | The Twilio API has been used to make SMS connectivity to patients possible in case they are not otherwise contactable or accessible (in that case we can also use mail).
15 | 
16 | # Summary of Technologies used in this project :
17 | | Dev Env. | Framework/ library/ languages |
18 | | ------------- | ------------- |
19 | | Backend development | PyTorch (Deep learning framework) |
20 | | Frontend development | Tkinter (Python GUI toolkit) |
21 | | Database connectivity | HeidiSQL (MySQL server) |
22 | | Programming Languages | Python, SQL |
23 | | API | Twilio cloud API |
24 | 
25 | # Data visualization :
26 | The raw input data looks like this:
27 | ![visual1](images/visual1.JPG)
28 | 
29 | # Resnet152 model summary :
30 | Below I have only shown the main layers of ResNet; each of 'layer1', 'layer2', 'layer3' and 'layer4' contains many more sub-layers. The classifier head that replaces the final fully connected layer is shown in the sketch below.
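A condensed sketch of how the backbone and head are put together, mirroring what model.py and inference.ipynb do:

```python
# Condensed from model.py / inference.ipynb: a ResNet-152 backbone whose final
# fully connected layer is replaced by a small custom classifier head.
import torch.nn as nn
import torchvision.models as models

model = models.resnet152(pretrained=False)   # weights are restored later from classifier.pt
num_ftrs = model.fc.in_features              # feature size coming out of the ResNet-152 backbone
out_ftrs = 5                                 # severity levels 0 (No DR) to 4 (Proliferative DR)
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Linear(512, out_ftrs),
    nn.LogSoftmax(dim=1),                    # log-probabilities, paired with NLLLoss during training
)
```

During training, only 'layer2', 'layer3', 'layer4' and 'fc' are left unfrozen (see model.py and inference.ipynb), which is why the summary image below focuses on those main blocks.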
31 | 
32 | ![mat](images/mat.png)
33 | 
34 | # Visualization of complete system :
35 | ![visual](images/vis.gif)
36 | 
37 | 
38 | # Getting started :
39 | [Click here](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/GettingStarted.md) to get started locally on your system.
40 | 
41 | ## Some snaps :
42 | ![images/gui1.JPG](images/gui1.JPG)
43 | ![images/gui2.JPG](images/gui2.JPG)
44 | ![images/gui3.JPG](images/gui3.JPG)
45 | ![images/sms.JPG](images/sms.JPG)
46 | 
47 | 
48 | # Future Prospect :
49 | * My next goal is to develop this into a web app (probably using some lighter-weight model, as ResNet models are heavy).
50 | * Another goal is to use encryption techniques to achieve not only high accuracy but also a high level of privacy on a differentially private basis, and to use techniques such as federated learning and secure multi-party computation for privacy-preserving deep learning classification.
51 | By the way, I have already made a project using federated learning on a classification task; [check it out here](https://github.com/souravs17031999/Federatedencryption-showcase).
52 | Achieving a level of privacy is also a very important task for medical datasets, so that a factor of trust can be established between the different stakeholders using the system.
53 | * Concurrency control has to be implemented properly using some kind of locks defined in MySQL so that multiple users can use the system at the same time when it is deployed on the web.
54 | (Locally, you can run the executable file multiple times to open and run the GUI, and it works fine.)
55 | * Reducing Type II error (false negatives), as this metric is really important in the healthcare domain.
56 | 
57 | # Navigating the project :
58 | * [Check out the training code here](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/training.ipynb)
59 | * [Check out testing done on an unseen image](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/Single_test_inference.ipynb)
60 | * [Check out the executable file (for running the GUI)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/blindness.py)
61 | * [Check out the model executable file (for loading the model to get inference locally)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/model.py); a minimal usage sketch is shown at the end of this README
62 | * [Check out the Twilio API executable file (to get an SMS for the inference)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/send_sms.py)
63 | * [For the pre-trained weights for this model, check out the Getting Started section here](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/GettingStarted.md)
64 | 
65 | [Note : The training files in this repo are only shown after final training, as it took more than 100 epochs, along with a lot of compute power and time, to reach 97% accuracy.]
66 | 
67 | 
68 | ⭐️ this project if you liked it !
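A minimal sketch of the local, non-GUI inference path mentioned in the navigation list above, assuming 'classifier.pt' has been downloaded and the path inside model.py has been updated as described in GettingStarted.md (the sample image path is just an example):

```python
# Sketch only: importing model.py builds the ResNet-152 and loads classifier.pt
# (adjust its path first); main(path) returns the severity value and class name.
from model import main

value, severity = main('sampleimages/eye1.png')   # example image from the sampleimages folder
print(f'Predicted label: {value} ({severity})')
```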
69 | -------------------------------------------------------------------------------- /blindness.py: -------------------------------------------------------------------------------- 1 | # Importing all packages 2 | from tkinter import * 3 | from tkinter.ttk import * 4 | from tkinter import messagebox 5 | from PIL import Image 6 | import os 7 | 8 | import mysql.connector 9 | from tkinter.filedialog import askopenfilename, asksaveasfilename 10 | import mysql.connector as sk 11 | from model import * 12 | #from send_sms import * 13 | print('GUI SYSTEM STARTED...') 14 | #--------------------------------------------------------------------------------- 15 | 16 | def LogIn(): 17 | username = box1.get() 18 | 19 | u = 1 20 | 21 | if len(username) == 0: 22 | u = 0 23 | messagebox.showinfo("Error", "You must enter something Sir") 24 | 25 | if u: 26 | password = box2.get() 27 | 28 | if len(password): 29 | query = "SELECT * FROM THEGREAT" 30 | 31 | sql.execute(query) 32 | 33 | data = sql.fetchall() 34 | 35 | g = 0 36 | b = 0 37 | 38 | for i in data: 39 | if i[0] == username: 40 | g = 1 41 | if i[1] == password: 42 | b = 1 43 | 44 | 45 | if g and b: 46 | messagebox.showinfo('Hello Sir', 'Welcome to the System') 47 | else: 48 | messagebox.showinfo('Sorry', 'Wrong Username or Password') 49 | 50 | global y 51 | y = True 52 | else: 53 | messagebox.showinfo("Error", "You must enter a password Sir!!") 54 | 55 | def OpenFile(): 56 | username = box1.get() 57 | if y: 58 | try: 59 | a = askopenfilename() 60 | print(a) 61 | value, classes = main(a) 62 | messagebox.showinfo("your report", ("Predicted Label is ", value, "\nPredicted Class is ", classes)) 63 | 64 | query = 'UPDATE THEGREAT SET PREDICT = "%s" WHERE USERNAME = "%s"'%(value, username) 65 | 66 | sql.execute(query) 67 | #print(query) 68 | connection.commit() 69 | 70 | #------********************Only use when required to send message 71 | #send(value, classes) 72 | #------********************************************************* 73 | image = Image.open(a) 74 | # plotting image 75 | file = image.convert('RGB') 76 | plt.imshow(np.array(file)) 77 | plt.title(f'your report is label : {value} class : {classes}') 78 | plt.show() 79 | #print(image) 80 | print('Thanks for using the system !') 81 | #fn, text = os.path.splitext(a) #fn stands for filename 82 | except Exception as error: 83 | print("File not selected ! 
Exiting..., Please try again") 84 | 85 | 86 | else: 87 | messagebox.showinfo("Hello Sir", "You need to Login first") 88 | 89 | 90 | x = 0 91 | y = False 92 | 93 | 94 | def Signup(): 95 | username = box1.get() 96 | password = box2.get() 97 | 98 | u = 1 99 | 100 | if len(username) == 0 or len(password) == 0: 101 | u = 0 102 | messagebox.showinfo("Error", "You must enter something Sir") 103 | 104 | if u: 105 | query1 = "SELECT * FROM THEGREAT" 106 | sql.execute(query1) 107 | 108 | data = sql.fetchall() 109 | 110 | z = 1 111 | 112 | for i in data: 113 | if i[0] == username: 114 | messagebox.showinfo("Sorry Sir", "This username is already registered, try a new one") 115 | z = 0 116 | 117 | if z: 118 | query = "INSERT INTO THEGREAT (USERNAME, PASSWORD) VALUES('%s', '%s')" % (username, password) 119 | messagebox.showinfo("signed up", ("Hi ",username ,"\n Now you can login with your credentials !")) 120 | sql.execute(query) 121 | connection.commit() 122 | 123 | 124 | #----------------------------------------------------------------------------------------- 125 | 126 | 127 | connection = sk.connect( 128 | host="localhost", 129 | user="root", 130 | password="SOURAVs99@", 131 | database="batch_db_new" 132 | ) 133 | 134 | sql = connection.cursor() 135 | 136 | root = Tk() 137 | 138 | root.geometry('700x400') 139 | root.title("SK's Blindness Detection System") 140 | root.configure(bg='pale turquoise') 141 | 142 | 143 | label1 = Label(root, text="Demo for BDS", font=('Arial', 30)) 144 | label1.grid(padx=30, pady=30, row=0, column=0, sticky='W') 145 | 146 | label2 = Label(root, text="Enter your username: ", font=('Arial', 20)) 147 | label2.grid(padx=10, pady=10, row=1, column=0, sticky='W') 148 | 149 | label3 = Label(root, text="Enter your password: ", font=('Arial', 20)) 150 | label3.grid(padx=10, pady=20, row=2, column=0, sticky='W') 151 | 152 | box1 = Entry(root) 153 | box1.grid(row=1, column=1) 154 | 155 | box2 = Entry(root, show='*') 156 | box2.grid(row=2, column=1) 157 | 158 | button3 = Button(root, text="Signup", command=Signup) 159 | button3.grid(padx=10, pady=20, row=3, column=1) 160 | 161 | button1 = Button(root, text="LogIn", command=LogIn) 162 | button1.grid(padx=10, pady=20, row=3, column=2) 163 | 164 | button2 = Button(root, text="Upload Image", command=OpenFile) 165 | button2.grid(padx=10, pady=20, row=2, column=3) 166 | 167 | # concurrency control in InnoDB 168 | # Read_locks useful when locks another user trying to update the value in the same row which is allocated for another user , both at the same time 169 | #SELECT * FROM t1, t2 FOR SHARE OF t1 FOR UPDATE OF t2; 170 | # START TRANSACTION; 171 | # SELECT * FROM your_table WHERE state != 'PROCESSING' 172 | # ORDER BY date_added ASC LIMIT 1 FOR UPDATE; 173 | # if (rows_selected = 0) { //finished processing the queue, abort} 174 | # else { 175 | # UPDATE your_table WHERE id = $row.id SET state = 'PROCESSING' 176 | # COMMIT; 177 | # 178 | # // row is processed here, outside of the transaction, and it can take as much time as we want 179 | # 180 | # // once we finish: 181 | # DELETE FROM your_table WHERE id = $row.id and state = 'PROCESSING' LIMIT 1; 182 | # } 183 | 184 | root.mainloop() 185 | -------------------------------------------------------------------------------- /images/gui1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui1.JPG 
-------------------------------------------------------------------------------- /images/gui2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui2.JPG -------------------------------------------------------------------------------- /images/gui3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui3.JPG -------------------------------------------------------------------------------- /images/mat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/mat.png -------------------------------------------------------------------------------- /images/sms.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/sms.JPG -------------------------------------------------------------------------------- /images/vis.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/vis.gif -------------------------------------------------------------------------------- /images/visual1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/visual1.JPG -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "inference.ipynb", 24 | "provenance": [], 25 | "include_colab_link": true 26 | } 27 | }, 28 | "cells": [ 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "view-in-github", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "\"Open" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "metadata": { 42 | "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0", 43 | "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a", 44 | "id": "Rt29Iy94yMO2", 45 | "colab_type": "code", 46 | "outputId": "fb0fa31a-c39c-429e-8128-dc050d121773", 47 | "colab": {} 48 | }, 49 | "source": [ 50 | "# Imports here\n", 51 | "from __future__ import print_function, division\n", 52 | "import numpy as np\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "from torch.utils import data\n", 55 | "import torch\n", 56 | "from torch import nn\n", 57 | "from torch import optim\n", 58 | "import torchvision\n", 59 | "import torch.nn.functional as 
F\n", 60 | "from torchvision import datasets, transforms, models\n", 61 | "import torchvision.models as models\n", 62 | "from torch.utils.data.sampler import SubsetRandomSampler\n", 63 | "from torch.utils.data import Dataset, DataLoader\n", 64 | "from skimage import io, transform\n", 65 | "import torch.utils.data as data_utils\n", 66 | "from PIL import Image, ImageFile\n", 67 | "import json\n", 68 | "from torch.optim import lr_scheduler\n", 69 | "import time\n", 70 | "import os\n", 71 | "import argparse\n", 72 | "import copy\n", 73 | "import pandas as pd\n", 74 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n", 75 | "import cv2\n", 76 | "# Import useful sklearn functions\n", 77 | "import sklearn\n", 78 | "from sklearn.metrics import cohen_kappa_score, accuracy_score\n", 79 | "import time\n", 80 | "from tqdm import tqdm_notebook\n", 81 | "\n", 82 | "import os\n", 83 | "print(os.listdir(\"../input\"))\n", 84 | "base_dir = \"../input/aptos2019-blindness-detection/\"" 85 | ], 86 | "execution_count": 0, 87 | "outputs": [ 88 | { 89 | "output_type": "stream", 90 | "text": [ 91 | "['aptos2019-blindness-detection', 'kernel4f121f3247']\n" 92 | ], 93 | "name": "stdout" 94 | } 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "metadata": { 100 | "id": "1XaCclZIyMPJ", 101 | "colab_type": "code", 102 | "outputId": "4d03fd95-4cc4-41a6-ab64-66383aae9ced", 103 | "colab": {} 104 | }, 105 | "source": [ 106 | "print(os.listdir(\"../input/kernel4f121f3247\"))\n" 107 | ], 108 | "execution_count": 0, 109 | "outputs": [ 110 | { 111 | "output_type": "stream", 112 | "text": [ 113 | "['__output__.json', '__results___files', 'custom.css', '__results__.html', '__notebook__.ipynb', 'classifier.pt']\n" 114 | ], 115 | "name": "stdout" 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "CfXBcJeKyMPV", 123 | "colab_type": "code", 124 | "colab": {} 125 | }, 126 | "source": [ 127 | "class CreateDataset(Dataset):\n", 128 | " def __init__(self, df_data, data_dir = '../input/', transform=None):\n", 129 | " super().__init__()\n", 130 | " self.df = df_data.values\n", 131 | " self.data_dir = data_dir\n", 132 | " self.transform = transform\n", 133 | "\n", 134 | " def __len__(self):\n", 135 | " return len(self.df)\n", 136 | " \n", 137 | " def __getitem__(self, index):\n", 138 | " img_name,label = self.df[index]\n", 139 | " img_path = os.path.join(self.data_dir, img_name+'.png')\n", 140 | " image = cv2.imread(img_path)\n", 141 | " if self.transform is not None:\n", 142 | " image = self.transform(image)\n", 143 | " return image, label" 144 | ], 145 | "execution_count": 0, 146 | "outputs": [] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "metadata": { 151 | "id": "yYOZ68A2yMPf", 152 | "colab_type": "code", 153 | "colab": {} 154 | }, 155 | "source": [ 156 | "test_csv = pd.read_csv('../input/aptos2019-blindness-detection/test.csv')\n" 157 | ], 158 | "execution_count": 0, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "id": "14T6uoF2yMPm", 165 | "colab_type": "code", 166 | "colab": {} 167 | }, 168 | "source": [ 169 | "test_path = \"../input/aptos2019-blindness-detection/test_images/\"\n" 170 | ], 171 | "execution_count": 0, 172 | "outputs": [] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "metadata": { 177 | "id": "3CuePz9tyMPq", 178 | "colab_type": "code", 179 | "colab": {} 180 | }, 181 | "source": [ 182 | "test_transforms = torchvision.transforms.Compose([\n", 183 | " torchvision.transforms.ToPILImage(),\n", 184 | " 
torchvision.transforms.Resize((224, 224)),\n", 185 | " #torchvision.transforms.ColorJitter(brightness=2, contrast=2),\n", 186 | " torchvision.transforms.RandomHorizontalFlip(p=0.5),\n", 187 | " torchvision.transforms.ToTensor(),\n", 188 | " torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n", 189 | "])" 190 | ], 191 | "execution_count": 0, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "p9ekp4mByMPu", 198 | "colab_type": "code", 199 | "colab": {} 200 | }, 201 | "source": [ 202 | "test_csv['diagnosis'] = -1\n" 203 | ], 204 | "execution_count": 0, 205 | "outputs": [] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "metadata": { 210 | "id": "fuDzK_jqyMP5", 211 | "colab_type": "code", 212 | "colab": {} 213 | }, 214 | "source": [ 215 | "test_data = CreateDataset(df_data=test_csv, data_dir=test_path, transform=test_transforms)\n", 216 | "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)" 217 | ], 218 | "execution_count": 0, 219 | "outputs": [] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "metadata": { 224 | "id": "-bx2Z1wVyMP_", 225 | "colab_type": "code", 226 | "colab": {} 227 | }, 228 | "source": [ 229 | "def round_off_preds(preds, coef=[0.5, 1.5, 2.5, 3.5]):\n", 230 | " for i, pred in enumerate(preds):\n", 231 | " if pred < coef[0]:\n", 232 | " preds[i] = 0\n", 233 | " elif pred >= coef[0] and pred < coef[1]:\n", 234 | " preds[i] = 1\n", 235 | " elif pred >= coef[1] and pred < coef[2]:\n", 236 | " preds[i] = 2\n", 237 | " elif pred >= coef[2] and pred < coef[3]:\n", 238 | " preds[i] = 3\n", 239 | " else:\n", 240 | " preds[i] = 4\n", 241 | " return preds" 242 | ], 243 | "execution_count": 0, 244 | "outputs": [] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "metadata": { 249 | "id": "sw2LHzS3yMQC", 250 | "colab_type": "code", 251 | "colab": {} 252 | }, 253 | "source": [ 254 | "def predict(testloader):\n", 255 | " '''Function used to make predictions on the test set'''\n", 256 | " model.eval()\n", 257 | " preds = []\n", 258 | " for batch_i, (data, target) in enumerate(testloader):\n", 259 | " data, target = data.cuda(), target.cuda()\n", 260 | " output = model(data)\n", 261 | " pr = output.detach().cpu().numpy()\n", 262 | " for i in pr:\n", 263 | " preds.append(i.item())\n", 264 | " \n", 265 | " return preds" 266 | ], 267 | "execution_count": 0, 268 | "outputs": [] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "metadata": { 273 | "id": "7lWcz7t8yMQH", 274 | "colab_type": "code", 275 | "colab": {} 276 | }, 277 | "source": [ 278 | "def load_model(path):\n", 279 | " checkpoint = torch.load(path)\n", 280 | " model.load_state_dict(checkpoint['model_state_dict'])\n", 281 | " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", 282 | " return model" 283 | ], 284 | "execution_count": 0, 285 | "outputs": [] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "metadata": { 290 | "id": "ohMGUHUHyMQM", 291 | "colab_type": "code", 292 | "colab": {} 293 | }, 294 | "source": [ 295 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 296 | "\n", 297 | "model = models.resnet152(pretrained=False) \n", 298 | "\n", 299 | "num_ftrs = model.fc.in_features \n", 300 | "out_ftrs = 5 \n", 301 | " \n", 302 | "model.fc = nn.Sequential(nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Linear(512,out_ftrs),nn.LogSoftmax(dim=1))\n", 303 | "\n", 304 | "criterion = nn.NLLLoss()\n", 305 | "optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.00001) 
\n", 306 | "\n", 307 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)\n", 308 | "model.to(device);" 309 | ], 310 | "execution_count": 0, 311 | "outputs": [] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "metadata": { 316 | "id": "g8uMWj8lyMQT", 317 | "colab_type": "code", 318 | "outputId": "b19053b3-e48f-4cdd-d0a8-d562929bb33a", 319 | "colab": {} 320 | }, 321 | "source": [ 322 | "# to unfreeze more layers \n", 323 | "for name,child in model.named_children():\n", 324 | " if name in ['layer2','layer3','layer4','fc']:\n", 325 | " print(name + 'is unfrozen')\n", 326 | " for param in child.parameters():\n", 327 | " param.requires_grad = True\n", 328 | " else:\n", 329 | " print(name + 'is frozen')\n", 330 | " for param in child.parameters():\n", 331 | " param.requires_grad = False" 332 | ], 333 | "execution_count": 0, 334 | "outputs": [ 335 | { 336 | "output_type": "stream", 337 | "text": [ 338 | "conv1is frozen\n", 339 | "bn1is frozen\n", 340 | "reluis frozen\n", 341 | "maxpoolis frozen\n", 342 | "layer1is frozen\n", 343 | "layer2is unfrozen\n", 344 | "layer3is unfrozen\n", 345 | "layer4is unfrozen\n", 346 | "avgpoolis frozen\n", 347 | "fcis unfrozen\n" 348 | ], 349 | "name": "stdout" 350 | } 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "metadata": { 356 | "id": "2dNbnpITyMQc", 357 | "colab_type": "code", 358 | "colab": {} 359 | }, 360 | "source": [ 361 | "optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.000001) \n", 362 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)" 363 | ], 364 | "execution_count": 0, 365 | "outputs": [] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "metadata": { 370 | "id": "2tNkZ6otyMQh", 371 | "colab_type": "code", 372 | "colab": {} 373 | }, 374 | "source": [ 375 | "model = load_model(\"../input/kernel4f121f3247/classifier.pt\")\n" 376 | ], 377 | "execution_count": 0, 378 | "outputs": [] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "3CyUoakYyMQp", 384 | "colab_type": "code", 385 | "colab": {} 386 | }, 387 | "source": [ 388 | "test_dir = \"../input/aptos2019-blindness-detection/test_images/\"" 389 | ], 390 | "execution_count": 0, 391 | "outputs": [] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "metadata": { 396 | "id": "Yk7nwGWyyMQu", 397 | "colab_type": "code", 398 | "outputId": "10af9635-db30-4c58-eab3-b66e2b632228", 399 | "colab": {} 400 | }, 401 | "source": [ 402 | "with torch.no_grad():\n", 403 | " model.eval()\n", 404 | " p_labels = []\n", 405 | " img_ids = []\n", 406 | " i = 0\n", 407 | " for inputs, labels in test_loader:\n", 408 | " i += 1\n", 409 | " if i % 10 == 0:\n", 410 | " print(f'{i} pass step')\n", 411 | " inputs = inputs.to(device)\n", 412 | " labels = labels.to(device)\n", 413 | " outputs = model(inputs)\n", 414 | " _, preds = torch.max(outputs, 1)\n", 415 | " p_labels.append(preds)\n", 416 | " # getting ids of file images " 417 | ], 418 | "execution_count": 0, 419 | "outputs": [ 420 | { 421 | "output_type": "stream", 422 | "text": [ 423 | "10 pass step\n", 424 | "20 pass step\n", 425 | "30 pass step\n" 426 | ], 427 | "name": "stdout" 428 | } 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "metadata": { 434 | "id": "-LH6LE9YyMQ1", 435 | "colab_type": "code", 436 | "outputId": "9523a4f6-15ea-4211-8005-a332f6fd3684", 437 | "colab": {} 438 | }, 439 | "source": [ 440 | "p_labels" 441 | ], 442 | "execution_count": 0, 443 | "outputs": [ 444 | { 445 | "output_type": "execute_result", 446 | "data": { 447 | 
"text/plain": [ 448 | "[tensor([2, 2, 2, 2, 2, 2, 2, 1, 3, 0, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,\n", 449 | " 2, 0, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 3, 2, 2, 2, 2, 1, 2, 2, 3,\n", 450 | " 2, 2, 3, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 3, 2], device='cuda:0'),\n", 451 | " tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 0, 2, 1, 1, 2, 2, 2, 2,\n", 452 | " 2, 1, 2, 1, 2, 2, 2, 2, 1, 0, 2, 2, 1, 0, 0, 2, 2, 2, 3, 2, 2, 0, 2, 3,\n", 453 | " 2, 1, 2, 2, 2, 0, 2, 2, 0, 0, 3, 2, 3, 2, 2, 2], device='cuda:0'),\n", 454 | " tensor([0, 2, 0, 2, 2, 1, 2, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2, 2, 0, 1, 0, 2, 1, 2,\n", 455 | " 2, 0, 2, 1, 2, 2, 2, 0, 2, 2, 2, 2, 2, 4, 0, 2, 3, 0, 0, 2, 4, 2, 2, 2,\n", 456 | " 2, 3, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 2, 0, 2, 2], device='cuda:0'),\n", 457 | " tensor([0, 2, 2, 2, 4, 1, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 2, 0, 2,\n", 458 | " 2, 2, 2, 2, 2, 0, 0, 2, 4, 2, 2, 2, 0, 2, 2, 2, 1, 1, 3, 2, 2, 2, 2, 0,\n", 459 | " 0, 2, 2, 0, 0, 2, 2, 4, 0, 0, 0, 1, 0, 2, 2, 0], device='cuda:0'),\n", 460 | " tensor([0, 1, 2, 3, 2, 2, 0, 0, 2, 2, 3, 0, 3, 2, 1, 0, 1, 0, 2, 2, 2, 4, 2, 2,\n", 461 | " 2, 3, 2, 2, 2, 0, 0, 0, 2, 2, 0, 2, 2, 1, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2,\n", 462 | " 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2], device='cuda:0'),\n", 463 | " tensor([0, 0, 1, 0, 2, 3, 2, 1, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 4, 2, 0, 2, 0, 2,\n", 464 | " 2, 1, 0, 2, 0, 2, 2, 2, 0, 2, 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 0, 1, 2,\n", 465 | " 2, 2, 2, 2, 0, 2, 2, 0, 4, 2, 2, 2, 2, 2, 1, 2], device='cuda:0'),\n", 466 | " tensor([2, 2, 1, 2, 0, 0, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", 467 | " 2, 0, 0, 1, 1, 3, 2, 1, 0, 1, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 2,\n", 468 | " 2, 1, 2, 2, 4, 2, 0, 0, 2, 1, 0, 2, 2, 2, 0, 2], device='cuda:0'),\n", 469 | " tensor([2, 0, 1, 0, 2, 2, 0, 2, 2, 2, 2, 3, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0,\n", 470 | " 2, 0, 0, 3, 2, 1, 2, 0, 1, 2, 2, 2, 2, 2, 3, 0, 2, 2, 2, 0, 2, 3, 2, 2,\n", 471 | " 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 0, 2, 0], device='cuda:0'),\n", 472 | " tensor([2, 2, 0, 2, 2, 2, 1, 2, 2, 0, 2, 1, 0, 2, 2, 1, 2, 2, 2, 2, 0, 3, 2, 0,\n", 473 | " 2, 2, 1, 2, 2, 0, 3, 2, 2, 2, 4, 2, 3, 2, 1, 2, 2, 2, 2, 1, 3, 2, 3, 2,\n", 474 | " 2, 2, 2, 0, 2, 0, 4, 2, 2, 2, 2, 0, 2, 2, 0, 2], device='cuda:0'),\n", 475 | " tensor([2, 2, 0, 0, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 4, 0, 2, 0, 2, 2, 2, 2, 0,\n", 476 | " 2, 4, 2, 0, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0, 2, 0, 0, 0, 0, 2, 2, 2, 0, 2,\n", 477 | " 2, 2, 2, 2, 2, 0, 2, 0, 3, 2, 0, 0, 2, 3, 2, 1], device='cuda:0'),\n", 478 | " tensor([2, 1, 2, 2, 2, 1, 0, 2, 2, 3, 2, 2, 2, 0, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,\n", 479 | " 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 2, 1, 0, 2, 0, 2, 1, 0, 2, 2, 2, 2,\n", 480 | " 2, 2, 2, 2, 3, 2, 0, 2, 2, 0, 2, 1, 1, 0, 0, 2], device='cuda:0'),\n", 481 | " tensor([2, 2, 2, 2, 2, 2, 1, 1, 0, 2, 2, 0, 2, 0, 1, 1, 2, 2, 0, 0, 2, 2, 2, 0,\n", 482 | " 1, 2, 4, 1, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 1, 3, 2, 2, 2, 0, 2, 2, 0, 2,\n", 483 | " 2, 2, 2, 2, 4, 0, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n", 484 | " tensor([0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 3, 1, 2,\n", 485 | " 2, 0, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 3, 0, 0, 0, 0, 2, 2,\n", 486 | " 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n", 487 | " tensor([2, 2, 2, 2, 2, 2, 1, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 3, 0, 2, 2, 1,\n", 488 | " 2, 1, 2, 2, 1, 2, 0, 2, 1, 4, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 3, 0, 2,\n", 489 | " 2, 2, 0, 2, 2, 0, 2, 3, 2, 2, 2, 0, 4, 
2, 4, 2], device='cuda:0'),\n", 490 | " tensor([0, 2, 2, 2, 2, 1, 2, 2, 2, 2, 3, 1, 4, 1, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2,\n", 491 | " 0, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2,\n", 492 | " 2, 2, 0, 0, 0, 2, 2, 0, 2, 4, 2, 1, 2, 2, 2, 2], device='cuda:0'),\n", 493 | " tensor([2, 1, 3, 2, 2, 3, 2, 0, 2, 2, 1, 2, 2, 2, 2, 2, 2, 3, 3, 1, 3, 2, 2, 2,\n", 494 | " 2, 0, 2, 1, 4, 0, 2, 2, 2, 2, 2, 2, 3, 0, 2, 0, 2, 4, 2, 2, 2, 0, 2, 2,\n", 495 | " 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2], device='cuda:0'),\n", 496 | " tensor([2, 0, 0, 2, 2, 1, 2, 2, 0, 3, 2, 2, 0, 2, 2, 3, 0, 2, 2, 0, 0, 2, 2, 0,\n", 497 | " 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 0, 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2,\n", 498 | " 2, 3, 2, 2, 4, 0, 2, 1, 2, 2, 2, 0, 2, 2, 4, 2], device='cuda:0'),\n", 499 | " tensor([2, 2, 0, 4, 3, 2, 2, 2, 0, 2, 2, 2, 4, 2, 3, 2, 0, 2, 2, 2, 2, 2, 3, 2,\n", 500 | " 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 3, 3, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 0, 3,\n", 501 | " 0, 0, 4, 2, 1, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2], device='cuda:0'),\n", 502 | " tensor([1, 2, 2, 2, 0, 2, 4, 2, 2, 4, 2, 2, 2, 2, 0, 2, 1, 1, 2, 2, 2, 0, 3, 2,\n", 503 | " 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 2, 0, 4, 2, 2,\n", 504 | " 2, 2, 2, 2, 3, 2, 3, 0, 2, 2, 0, 0, 0, 2, 1, 0], device='cuda:0'),\n", 505 | " tensor([0, 2, 2, 4, 0, 2, 0, 0, 2, 3, 2, 2, 0, 0, 2, 2, 1, 0, 2, 2, 2, 0, 2, 1,\n", 506 | " 0, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 2, 0, 2, 2, 3, 0, 4, 3, 0, 2, 2, 1, 4,\n", 507 | " 2, 2, 0, 0, 2, 2, 4, 2, 1, 2, 2, 1, 2, 0, 2, 0], device='cuda:0'),\n", 508 | " tensor([0, 1, 2, 2, 2, 3, 2, 2, 1, 0, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", 509 | " 2, 0, 2, 0, 2, 2, 2, 3, 2, 2, 2, 0, 2, 0, 1, 0, 2, 2, 2, 4, 2, 2, 2, 0,\n", 510 | " 1, 0, 1, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0], device='cuda:0'),\n", 511 | " tensor([2, 2, 3, 2, 2, 2, 4, 4, 2, 2, 2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", 512 | " 2, 2, 0, 2, 0, 2, 2, 0, 0, 0, 2, 2, 2, 4, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2,\n", 513 | " 2, 3, 0, 2, 0, 3, 2, 3, 0, 1, 2, 2, 2, 2, 0, 2], device='cuda:0'),\n", 514 | " tensor([2, 2, 0, 2, 0, 2, 2, 2, 2, 0, 3, 0, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2,\n", 515 | " 2, 2, 2, 4, 2, 0, 3, 4, 2, 2, 2, 3, 1, 4, 2, 0, 0, 2, 2, 1, 2, 2, 1, 0,\n", 516 | " 2, 1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 2, 0, 2, 2, 3], device='cuda:0'),\n", 517 | " tensor([0, 2, 3, 0, 2, 0, 2, 0, 1, 2, 2, 2, 2, 0, 2, 2, 2, 0, 1, 0, 2, 3, 2, 4,\n", 518 | " 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 1, 0, 2, 0, 2, 3, 2, 2,\n", 519 | " 2, 2, 1, 2, 2, 2, 3, 0, 3, 2, 2, 2, 2, 2, 2, 3], device='cuda:0'),\n", 520 | " tensor([2, 1, 2, 0, 2, 2, 1, 0, 2, 2, 4, 2, 2, 1, 2, 3, 4, 0, 2, 0, 2, 2, 2, 2,\n", 521 | " 2, 4, 1, 2, 2, 2, 2, 3, 1, 2, 2, 2, 2, 0, 2, 0, 2, 2, 0, 2, 2, 4, 2, 3,\n", 522 | " 0, 1, 3, 2, 2, 2, 2, 2, 1, 2, 2, 1, 3, 2, 2, 1], device='cuda:0'),\n", 523 | " tensor([2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 4, 2, 2, 2, 2, 0, 2, 2, 2, 0, 3, 1, 1, 0,\n", 524 | " 2, 1, 2, 0, 2, 2, 0, 0, 2, 2, 3, 2, 4, 2, 0, 0, 0, 2, 2, 1, 2, 0, 2, 2,\n", 525 | " 2, 3, 2, 2, 0, 2, 2, 0, 0, 3, 2, 2, 2, 0, 4, 2], device='cuda:0'),\n", 526 | " tensor([0, 0, 2, 2, 2, 1, 0, 2, 4, 2, 2, 0, 1, 0, 4, 2, 2, 4, 2, 2, 2, 0, 1, 2,\n", 527 | " 0, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 1, 1, 2, 0, 0, 2, 2, 2, 2, 3, 0, 2, 2,\n", 528 | " 2, 2, 0, 2, 0, 0, 2, 0, 2, 2, 0, 2, 0, 0, 2, 2], device='cuda:0'),\n", 529 | " tensor([0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 1, 2,\n", 530 | " 2, 2, 0, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 0, 2, 0, 0,\n", 531 | " 2, 2, 0, 0, 2, 4, 0, 0, 
2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n", 532 | " tensor([3, 3, 2, 0, 0, 2, 4, 3, 2, 2, 2, 0, 3, 1, 2, 2, 0, 2, 3, 2, 2, 0, 2, 0,\n", 533 | " 2, 2, 0, 4, 2, 3, 2, 2, 2, 2, 0, 0, 2, 2, 2, 0, 2, 2, 3, 2, 4, 4, 2, 2,\n", 534 | " 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 2], device='cuda:0'),\n", 535 | " tensor([2, 1, 0, 2, 0, 2, 2, 3, 1, 2, 3, 2, 1, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2,\n", 536 | " 4, 0, 2, 2, 3, 2, 2, 2, 0, 0, 0, 2, 2, 2, 4, 2, 2, 2, 0, 0, 0, 2, 2, 2,\n", 537 | " 1, 2, 2, 0, 2, 2, 0, 3, 2, 2, 2, 2, 2, 1, 2, 2], device='cuda:0'),\n", 538 | " tensor([2, 2, 2, 0, 2, 2, 3, 0], device='cuda:0')]" 539 | ] 540 | }, 541 | "metadata": { 542 | "tags": [] 543 | }, 544 | "execution_count": 19 545 | } 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "nGcEgjL7yMRE", 552 | "colab_type": "code", 553 | "colab": {} 554 | }, 555 | "source": [ 556 | "pred_labels = []\n", 557 | "for l in p_labels:\n", 558 | " for l1 in l:\n", 559 | " pred_labels.append(l1.item())" 560 | ], 561 | "execution_count": 0, 562 | "outputs": [] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "metadata": { 567 | "id": "aeP6MAfUyMSP", 568 | "colab_type": "code", 569 | "outputId": "055465d8-db0e-40d5-fab1-8631dddab286", 570 | "colab": {} 571 | }, 572 | "source": [ 573 | "pred_labels" 574 | ], 575 | "execution_count": 0, 576 | "outputs": [ 577 | { 578 | "output_type": "execute_result", 579 | "data": { 580 | "text/plain": [ 581 | "[2,\n", 582 | " 2,\n", 583 | " 2,\n", 584 | " 2,\n", 585 | " 2,\n", 586 | " 2,\n", 587 | " 2,\n", 588 | " 1,\n", 589 | " 3,\n", 590 | " 0,\n", 591 | " 4,\n", 592 | " 2,\n", 593 | " 2,\n", 594 | " 2,\n", 595 | " 2,\n", 596 | " 2,\n", 597 | " 2,\n", 598 | " 2,\n", 599 | " 2,\n", 600 | " 1,\n", 601 | " 2,\n", 602 | " 2,\n", 603 | " 2,\n", 604 | " 2,\n", 605 | " 2,\n", 606 | " 0,\n", 607 | " 2,\n", 608 | " 2,\n", 609 | " 2,\n", 610 | " 2,\n", 611 | " 0,\n", 612 | " 2,\n", 613 | " 2,\n", 614 | " 0,\n", 615 | " 2,\n", 616 | " 2,\n", 617 | " 2,\n", 618 | " 2,\n", 619 | " 0,\n", 620 | " 3,\n", 621 | " 2,\n", 622 | " 2,\n", 623 | " 2,\n", 624 | " 2,\n", 625 | " 1,\n", 626 | " 2,\n", 627 | " 2,\n", 628 | " 3,\n", 629 | " 2,\n", 630 | " 2,\n", 631 | " 3,\n", 632 | " 2,\n", 633 | " 2,\n", 634 | " 0,\n", 635 | " 2,\n", 636 | " 0,\n", 637 | " 2,\n", 638 | " 2,\n", 639 | " 2,\n", 640 | " 2,\n", 641 | " 2,\n", 642 | " 2,\n", 643 | " 3,\n", 644 | " 2,\n", 645 | " 2,\n", 646 | " 2,\n", 647 | " 2,\n", 648 | " 2,\n", 649 | " 2,\n", 650 | " 2,\n", 651 | " 2,\n", 652 | " 2,\n", 653 | " 2,\n", 654 | " 0,\n", 655 | " 2,\n", 656 | " 2,\n", 657 | " 2,\n", 658 | " 2,\n", 659 | " 0,\n", 660 | " 0,\n", 661 | " 0,\n", 662 | " 2,\n", 663 | " 1,\n", 664 | " 1,\n", 665 | " 2,\n", 666 | " 2,\n", 667 | " 2,\n", 668 | " 2,\n", 669 | " 2,\n", 670 | " 1,\n", 671 | " 2,\n", 672 | " 1,\n", 673 | " 2,\n", 674 | " 2,\n", 675 | " 2,\n", 676 | " 2,\n", 677 | " 1,\n", 678 | " 0,\n", 679 | " 2,\n", 680 | " 2,\n", 681 | " 1,\n", 682 | " 0,\n", 683 | " 0,\n", 684 | " 2,\n", 685 | " 2,\n", 686 | " 2,\n", 687 | " 3,\n", 688 | " 2,\n", 689 | " 2,\n", 690 | " 0,\n", 691 | " 2,\n", 692 | " 3,\n", 693 | " 2,\n", 694 | " 1,\n", 695 | " 2,\n", 696 | " 2,\n", 697 | " 2,\n", 698 | " 0,\n", 699 | " 2,\n", 700 | " 2,\n", 701 | " 0,\n", 702 | " 0,\n", 703 | " 3,\n", 704 | " 2,\n", 705 | " 3,\n", 706 | " 2,\n", 707 | " 2,\n", 708 | " 2,\n", 709 | " 0,\n", 710 | " 2,\n", 711 | " 0,\n", 712 | " 2,\n", 713 | " 2,\n", 714 | " 1,\n", 715 | " 2,\n", 716 | " 2,\n", 717 | " 2,\n", 718 | " 2,\n", 719 | " 0,\n", 720 | " 2,\n", 
721 | " 2,\n", 722 | " 1,\n", 723 | " 2,\n", 724 | " 2,\n", 725 | " 2,\n", 726 | " 2,\n", 727 | " 0,\n", 728 | " 1,\n", 729 | " 0,\n", 730 | " 2,\n", 731 | " 1,\n", 732 | " 2,\n", 733 | " 2,\n", 734 | " 0,\n", 735 | " 2,\n", 736 | " 1,\n", 737 | " 2,\n", 738 | " 2,\n", 739 | " 2,\n", 740 | " 0,\n", 741 | " 2,\n", 742 | " 2,\n", 743 | " 2,\n", 744 | " 2,\n", 745 | " 2,\n", 746 | " 4,\n", 747 | " 0,\n", 748 | " 2,\n", 749 | " 3,\n", 750 | " 0,\n", 751 | " 0,\n", 752 | " 2,\n", 753 | " 4,\n", 754 | " 2,\n", 755 | " 2,\n", 756 | " 2,\n", 757 | " 2,\n", 758 | " 3,\n", 759 | " 0,\n", 760 | " 0,\n", 761 | " 2,\n", 762 | " 2,\n", 763 | " 2,\n", 764 | " 2,\n", 765 | " 2,\n", 766 | " 2,\n", 767 | " 0,\n", 768 | " 0,\n", 769 | " 2,\n", 770 | " 0,\n", 771 | " 2,\n", 772 | " 2,\n", 773 | " 0,\n", 774 | " 2,\n", 775 | " 2,\n", 776 | " 2,\n", 777 | " 4,\n", 778 | " 1,\n", 779 | " 2,\n", 780 | " 2,\n", 781 | " 2,\n", 782 | " 2,\n", 783 | " 2,\n", 784 | " 2,\n", 785 | " 0,\n", 786 | " 0,\n", 787 | " 2,\n", 788 | " 2,\n", 789 | " 0,\n", 790 | " 2,\n", 791 | " 0,\n", 792 | " 2,\n", 793 | " 2,\n", 794 | " 2,\n", 795 | " 0,\n", 796 | " 2,\n", 797 | " 2,\n", 798 | " 2,\n", 799 | " 2,\n", 800 | " 2,\n", 801 | " 2,\n", 802 | " 0,\n", 803 | " 0,\n", 804 | " 2,\n", 805 | " 4,\n", 806 | " 2,\n", 807 | " 2,\n", 808 | " 2,\n", 809 | " 0,\n", 810 | " 2,\n", 811 | " 2,\n", 812 | " 2,\n", 813 | " 1,\n", 814 | " 1,\n", 815 | " 3,\n", 816 | " 2,\n", 817 | " 2,\n", 818 | " 2,\n", 819 | " 2,\n", 820 | " 0,\n", 821 | " 0,\n", 822 | " 2,\n", 823 | " 2,\n", 824 | " 0,\n", 825 | " 0,\n", 826 | " 2,\n", 827 | " 2,\n", 828 | " 4,\n", 829 | " 0,\n", 830 | " 0,\n", 831 | " 0,\n", 832 | " 1,\n", 833 | " 0,\n", 834 | " 2,\n", 835 | " 2,\n", 836 | " 0,\n", 837 | " 0,\n", 838 | " 1,\n", 839 | " 2,\n", 840 | " 3,\n", 841 | " 2,\n", 842 | " 2,\n", 843 | " 0,\n", 844 | " 0,\n", 845 | " 2,\n", 846 | " 2,\n", 847 | " 3,\n", 848 | " 0,\n", 849 | " 3,\n", 850 | " 2,\n", 851 | " 1,\n", 852 | " 0,\n", 853 | " 1,\n", 854 | " 0,\n", 855 | " 2,\n", 856 | " 2,\n", 857 | " 2,\n", 858 | " 4,\n", 859 | " 2,\n", 860 | " 2,\n", 861 | " 2,\n", 862 | " 3,\n", 863 | " 2,\n", 864 | " 2,\n", 865 | " 2,\n", 866 | " 0,\n", 867 | " 0,\n", 868 | " 0,\n", 869 | " 2,\n", 870 | " 2,\n", 871 | " 0,\n", 872 | " 2,\n", 873 | " 2,\n", 874 | " 1,\n", 875 | " 2,\n", 876 | " 0,\n", 877 | " 2,\n", 878 | " 2,\n", 879 | " 2,\n", 880 | " 2,\n", 881 | " 0,\n", 882 | " 2,\n", 883 | " 2,\n", 884 | " 2,\n", 885 | " 2,\n", 886 | " 2,\n", 887 | " 2,\n", 888 | " 2,\n", 889 | " 0,\n", 890 | " 2,\n", 891 | " 2,\n", 892 | " 2,\n", 893 | " 2,\n", 894 | " 0,\n", 895 | " 2,\n", 896 | " 2,\n", 897 | " 1,\n", 898 | " 2,\n", 899 | " 2,\n", 900 | " 2,\n", 901 | " 0,\n", 902 | " 0,\n", 903 | " 1,\n", 904 | " 0,\n", 905 | " 2,\n", 906 | " 3,\n", 907 | " 2,\n", 908 | " 1,\n", 909 | " 2,\n", 910 | " 2,\n", 911 | " 2,\n", 912 | " 2,\n", 913 | " 2,\n", 914 | " 2,\n", 915 | " 0,\n", 916 | " 2,\n", 917 | " 2,\n", 918 | " 2,\n", 919 | " 4,\n", 920 | " 2,\n", 921 | " 0,\n", 922 | " 2,\n", 923 | " 0,\n", 924 | " 2,\n", 925 | " 2,\n", 926 | " 1,\n", 927 | " 0,\n", 928 | " 2,\n", 929 | " 0,\n", 930 | " 2,\n", 931 | " 2,\n", 932 | " 2,\n", 933 | " 0,\n", 934 | " 2,\n", 935 | " 2,\n", 936 | " 1,\n", 937 | " 2,\n", 938 | " 2,\n", 939 | " 0,\n", 940 | " 1,\n", 941 | " 2,\n", 942 | " 2,\n", 943 | " 2,\n", 944 | " 2,\n", 945 | " 2,\n", 946 | " 0,\n", 947 | " 1,\n", 948 | " 2,\n", 949 | " 2,\n", 950 | " 2,\n", 951 | " 2,\n", 952 | " 2,\n", 953 | " 0,\n", 954 | " 2,\n", 955 | " 2,\n", 956 | " 0,\n", 957 | " 4,\n", 
958 | " 2,\n", 959 | " 2,\n", 960 | " 2,\n", 961 | " 2,\n", 962 | " 2,\n", 963 | " 1,\n", 964 | " 2,\n", 965 | " 2,\n", 966 | " 2,\n", 967 | " 1,\n", 968 | " 2,\n", 969 | " 0,\n", 970 | " 0,\n", 971 | " 2,\n", 972 | " 2,\n", 973 | " 0,\n", 974 | " 2,\n", 975 | " 2,\n", 976 | " 0,\n", 977 | " 2,\n", 978 | " 2,\n", 979 | " 2,\n", 980 | " 2,\n", 981 | " 2,\n", 982 | " 2,\n", 983 | " 2,\n", 984 | " 2,\n", 985 | " 2,\n", 986 | " 2,\n", 987 | " 2,\n", 988 | " 2,\n", 989 | " 2,\n", 990 | " 0,\n", 991 | " 0,\n", 992 | " 1,\n", 993 | " 1,\n", 994 | " 3,\n", 995 | " 2,\n", 996 | " 1,\n", 997 | " 0,\n", 998 | " 1,\n", 999 | " 2,\n", 1000 | " 2,\n", 1001 | " 2,\n", 1002 | " 2,\n", 1003 | " 2,\n", 1004 | " 0,\n", 1005 | " 2,\n", 1006 | " 2,\n", 1007 | " 0,\n", 1008 | " 2,\n", 1009 | " 2,\n", 1010 | " 2,\n", 1011 | " 0,\n", 1012 | " 2,\n", 1013 | " 2,\n", 1014 | " 1,\n", 1015 | " 2,\n", 1016 | " 2,\n", 1017 | " 4,\n", 1018 | " 2,\n", 1019 | " 0,\n", 1020 | " 0,\n", 1021 | " 2,\n", 1022 | " 1,\n", 1023 | " 0,\n", 1024 | " 2,\n", 1025 | " 2,\n", 1026 | " 2,\n", 1027 | " 0,\n", 1028 | " 2,\n", 1029 | " 2,\n", 1030 | " 0,\n", 1031 | " 1,\n", 1032 | " 0,\n", 1033 | " 2,\n", 1034 | " 2,\n", 1035 | " 0,\n", 1036 | " 2,\n", 1037 | " 2,\n", 1038 | " 2,\n", 1039 | " 2,\n", 1040 | " 3,\n", 1041 | " 0,\n", 1042 | " 2,\n", 1043 | " 2,\n", 1044 | " 2,\n", 1045 | " 2,\n", 1046 | " 2,\n", 1047 | " 2,\n", 1048 | " 0,\n", 1049 | " 2,\n", 1050 | " 2,\n", 1051 | " 2,\n", 1052 | " 0,\n", 1053 | " 2,\n", 1054 | " 0,\n", 1055 | " 0,\n", 1056 | " 3,\n", 1057 | " 2,\n", 1058 | " 1,\n", 1059 | " 2,\n", 1060 | " 0,\n", 1061 | " 1,\n", 1062 | " 2,\n", 1063 | " 2,\n", 1064 | " 2,\n", 1065 | " 2,\n", 1066 | " 2,\n", 1067 | " 3,\n", 1068 | " 0,\n", 1069 | " 2,\n", 1070 | " 2,\n", 1071 | " 2,\n", 1072 | " 0,\n", 1073 | " 2,\n", 1074 | " 3,\n", 1075 | " 2,\n", 1076 | " 2,\n", 1077 | " 2,\n", 1078 | " 3,\n", 1079 | " 2,\n", 1080 | " 2,\n", 1081 | " 2,\n", 1082 | " 2,\n", 1083 | " 2,\n", 1084 | " 2,\n", 1085 | " 2,\n", 1086 | " 2,\n", 1087 | " 2,\n", 1088 | " 4,\n", 1089 | " 2,\n", 1090 | " 0,\n", 1091 | " 2,\n", 1092 | " 0,\n", 1093 | " 2,\n", 1094 | " 2,\n", 1095 | " 0,\n", 1096 | " 2,\n", 1097 | " 2,\n", 1098 | " 2,\n", 1099 | " 1,\n", 1100 | " 2,\n", 1101 | " 2,\n", 1102 | " 0,\n", 1103 | " 2,\n", 1104 | " 1,\n", 1105 | " 0,\n", 1106 | " 2,\n", 1107 | " 2,\n", 1108 | " 1,\n", 1109 | " 2,\n", 1110 | " 2,\n", 1111 | " 2,\n", 1112 | " 2,\n", 1113 | " 0,\n", 1114 | " 3,\n", 1115 | " 2,\n", 1116 | " 0,\n", 1117 | " 2,\n", 1118 | " 2,\n", 1119 | " 1,\n", 1120 | " 2,\n", 1121 | " 2,\n", 1122 | " 0,\n", 1123 | " 3,\n", 1124 | " 2,\n", 1125 | " 2,\n", 1126 | " 2,\n", 1127 | " 4,\n", 1128 | " 2,\n", 1129 | " 3,\n", 1130 | " 2,\n", 1131 | " 1,\n", 1132 | " 2,\n", 1133 | " 2,\n", 1134 | " 2,\n", 1135 | " 2,\n", 1136 | " 1,\n", 1137 | " 3,\n", 1138 | " 2,\n", 1139 | " 3,\n", 1140 | " 2,\n", 1141 | " 2,\n", 1142 | " 2,\n", 1143 | " 2,\n", 1144 | " 0,\n", 1145 | " 2,\n", 1146 | " 0,\n", 1147 | " 4,\n", 1148 | " 2,\n", 1149 | " 2,\n", 1150 | " 2,\n", 1151 | " 2,\n", 1152 | " 0,\n", 1153 | " 2,\n", 1154 | " 2,\n", 1155 | " 0,\n", 1156 | " 2,\n", 1157 | " 2,\n", 1158 | " 2,\n", 1159 | " 0,\n", 1160 | " 0,\n", 1161 | " 2,\n", 1162 | " 1,\n", 1163 | " 2,\n", 1164 | " 2,\n", 1165 | " 2,\n", 1166 | " 3,\n", 1167 | " 2,\n", 1168 | " 2,\n", 1169 | " 2,\n", 1170 | " 2,\n", 1171 | " 2,\n", 1172 | " 4,\n", 1173 | " 0,\n", 1174 | " 2,\n", 1175 | " 0,\n", 1176 | " 2,\n", 1177 | " 2,\n", 1178 | " 2,\n", 1179 | " 2,\n", 1180 | " 0,\n", 1181 | " 2,\n", 1182 | " 
4,\n", 1183 | " 2,\n", 1184 | " 0,\n", 1185 | " 2,\n", 1186 | " 0,\n", 1187 | " 2,\n", 1188 | " 2,\n", 1189 | " 2,\n", 1190 | " 0,\n", 1191 | " 2,\n", 1192 | " 2,\n", 1193 | " 2,\n", 1194 | " 0,\n", 1195 | " 2,\n", 1196 | " 0,\n", 1197 | " 0,\n", 1198 | " 0,\n", 1199 | " 0,\n", 1200 | " 2,\n", 1201 | " 2,\n", 1202 | " 2,\n", 1203 | " 0,\n", 1204 | " 2,\n", 1205 | " 2,\n", 1206 | " 2,\n", 1207 | " 2,\n", 1208 | " 2,\n", 1209 | " 2,\n", 1210 | " 0,\n", 1211 | " 2,\n", 1212 | " 0,\n", 1213 | " 3,\n", 1214 | " 2,\n", 1215 | " 0,\n", 1216 | " 0,\n", 1217 | " 2,\n", 1218 | " 3,\n", 1219 | " 2,\n", 1220 | " 1,\n", 1221 | " 2,\n", 1222 | " 1,\n", 1223 | " 2,\n", 1224 | " 2,\n", 1225 | " 2,\n", 1226 | " 1,\n", 1227 | " 0,\n", 1228 | " 2,\n", 1229 | " 2,\n", 1230 | " 3,\n", 1231 | " 2,\n", 1232 | " 2,\n", 1233 | " 2,\n", 1234 | " 0,\n", 1235 | " 2,\n", 1236 | " 2,\n", 1237 | " 2,\n", 1238 | " 2,\n", 1239 | " 1,\n", 1240 | " 2,\n", 1241 | " 2,\n", 1242 | " 2,\n", 1243 | " 2,\n", 1244 | " 2,\n", 1245 | " 0,\n", 1246 | " 2,\n", 1247 | " 2,\n", 1248 | " 2,\n", 1249 | " 2,\n", 1250 | " 2,\n", 1251 | " 2,\n", 1252 | " 2,\n", 1253 | " 2,\n", 1254 | " 2,\n", 1255 | " 3,\n", 1256 | " 0,\n", 1257 | " 2,\n", 1258 | " 1,\n", 1259 | " 0,\n", 1260 | " 2,\n", 1261 | " 0,\n", 1262 | " 2,\n", 1263 | " 1,\n", 1264 | " 0,\n", 1265 | " 2,\n", 1266 | " 2,\n", 1267 | " 2,\n", 1268 | " 2,\n", 1269 | " 2,\n", 1270 | " 2,\n", 1271 | " 2,\n", 1272 | " 2,\n", 1273 | " 3,\n", 1274 | " 2,\n", 1275 | " 0,\n", 1276 | " 2,\n", 1277 | " 2,\n", 1278 | " 0,\n", 1279 | " 2,\n", 1280 | " 1,\n", 1281 | " 1,\n", 1282 | " 0,\n", 1283 | " 0,\n", 1284 | " 2,\n", 1285 | " 2,\n", 1286 | " 2,\n", 1287 | " 2,\n", 1288 | " 2,\n", 1289 | " 2,\n", 1290 | " 2,\n", 1291 | " 1,\n", 1292 | " 1,\n", 1293 | " 0,\n", 1294 | " 2,\n", 1295 | " 2,\n", 1296 | " 0,\n", 1297 | " 2,\n", 1298 | " 0,\n", 1299 | " 1,\n", 1300 | " 1,\n", 1301 | " 2,\n", 1302 | " 2,\n", 1303 | " 0,\n", 1304 | " 0,\n", 1305 | " 2,\n", 1306 | " 2,\n", 1307 | " 2,\n", 1308 | " 0,\n", 1309 | " 1,\n", 1310 | " 2,\n", 1311 | " 4,\n", 1312 | " 1,\n", 1313 | " 2,\n", 1314 | " 2,\n", 1315 | " 2,\n", 1316 | " 2,\n", 1317 | " 2,\n", 1318 | " 2,\n", 1319 | " 0,\n", 1320 | " 0,\n", 1321 | " 2,\n", 1322 | " 2,\n", 1323 | " 1,\n", 1324 | " 3,\n", 1325 | " 2,\n", 1326 | " 2,\n", 1327 | " 2,\n", 1328 | " 0,\n", 1329 | " 2,\n", 1330 | " 2,\n", 1331 | " 0,\n", 1332 | " 2,\n", 1333 | " 2,\n", 1334 | " 2,\n", 1335 | " 2,\n", 1336 | " 2,\n", 1337 | " 4,\n", 1338 | " 0,\n", 1339 | " 2,\n", 1340 | " 2,\n", 1341 | " 4,\n", 1342 | " 2,\n", 1343 | " 2,\n", 1344 | " 2,\n", 1345 | " 2,\n", 1346 | " 2,\n", 1347 | " 2,\n", 1348 | " 2,\n", 1349 | " 0,\n", 1350 | " 1,\n", 1351 | " 2,\n", 1352 | " 2,\n", 1353 | " 2,\n", 1354 | " 2,\n", 1355 | " 2,\n", 1356 | " 2,\n", 1357 | " 2,\n", 1358 | " 2,\n", 1359 | " 2,\n", 1360 | " 2,\n", 1361 | " 0,\n", 1362 | " 2,\n", 1363 | " 2,\n", 1364 | " 2,\n", 1365 | " 2,\n", 1366 | " 2,\n", 1367 | " 2,\n", 1368 | " 2,\n", 1369 | " 2,\n", 1370 | " 3,\n", 1371 | " 1,\n", 1372 | " 2,\n", 1373 | " 2,\n", 1374 | " 0,\n", 1375 | " 2,\n", 1376 | " 2,\n", 1377 | " 2,\n", 1378 | " 3,\n", 1379 | " 2,\n", 1380 | " 2,\n", 1381 | " 2,\n", 1382 | " 2,\n", 1383 | " 2,\n", 1384 | " 2,\n", 1385 | " 2,\n", 1386 | " 2,\n", 1387 | " 3,\n", 1388 | " 2,\n", 1389 | " 1,\n", 1390 | " 3,\n", 1391 | " 0,\n", 1392 | " 0,\n", 1393 | " 0,\n", 1394 | " 0,\n", 1395 | " 2,\n", 1396 | " 2,\n", 1397 | " 2,\n", 1398 | " 2,\n", 1399 | " 2,\n", 1400 | " 2,\n", 1401 | " 2,\n", 1402 | " 2,\n", 1403 | " 2,\n", 1404 | " 
3,\n", 1405 | " 2,\n", 1406 | " 3,\n", 1407 | " 2,\n", 1408 | " 2,\n", 1409 | " 2,\n", 1410 | " 2,\n", 1411 | " 2,\n", 1412 | " 2,\n", 1413 | " 2,\n", 1414 | " 2,\n", 1415 | " 2,\n", 1416 | " 2,\n", 1417 | " 2,\n", 1418 | " 2,\n", 1419 | " 1,\n", 1420 | " 2,\n", 1421 | " 3,\n", 1422 | " 2,\n", 1423 | " 2,\n", 1424 | " 2,\n", 1425 | " 2,\n", 1426 | " 2,\n", 1427 | " 2,\n", 1428 | " 2,\n", 1429 | " 2,\n", 1430 | " 0,\n", 1431 | " 2,\n", 1432 | " 3,\n", 1433 | " 0,\n", 1434 | " 2,\n", 1435 | " 2,\n", 1436 | " 1,\n", 1437 | " 2,\n", 1438 | " 1,\n", 1439 | " 2,\n", 1440 | " 2,\n", 1441 | " 1,\n", 1442 | " 2,\n", 1443 | " 0,\n", 1444 | " 2,\n", 1445 | " 1,\n", 1446 | " 4,\n", 1447 | " 2,\n", 1448 | " 2,\n", 1449 | " 2,\n", 1450 | " 2,\n", 1451 | " 2,\n", 1452 | " 2,\n", 1453 | " 2,\n", 1454 | " 2,\n", 1455 | " 0,\n", 1456 | " 2,\n", 1457 | " 2,\n", 1458 | " 3,\n", 1459 | " 0,\n", 1460 | " 2,\n", 1461 | " 2,\n", 1462 | " 2,\n", 1463 | " 0,\n", 1464 | " 2,\n", 1465 | " 2,\n", 1466 | " 0,\n", 1467 | " 2,\n", 1468 | " 3,\n", 1469 | " 2,\n", 1470 | " 2,\n", 1471 | " 2,\n", 1472 | " 0,\n", 1473 | " 4,\n", 1474 | " 2,\n", 1475 | " 4,\n", 1476 | " 2,\n", 1477 | " 0,\n", 1478 | " 2,\n", 1479 | " 2,\n", 1480 | " 2,\n", 1481 | " 2,\n", 1482 | " 1,\n", 1483 | " 2,\n", 1484 | " 2,\n", 1485 | " 2,\n", 1486 | " 2,\n", 1487 | " 3,\n", 1488 | " 1,\n", 1489 | " 4,\n", 1490 | " 1,\n", 1491 | " 0,\n", 1492 | " 2,\n", 1493 | " 2,\n", 1494 | " 2,\n", 1495 | " 2,\n", 1496 | " 2,\n", 1497 | " 2,\n", 1498 | " 0,\n", 1499 | " 2,\n", 1500 | " 2,\n", 1501 | " 0,\n", 1502 | " 3,\n", 1503 | " 2,\n", 1504 | " 2,\n", 1505 | " 2,\n", 1506 | " 2,\n", 1507 | " 2,\n", 1508 | " 2,\n", 1509 | " 2,\n", 1510 | " 2,\n", 1511 | " 2,\n", 1512 | " 2,\n", 1513 | " 2,\n", 1514 | " 2,\n", 1515 | " 2,\n", 1516 | " 2,\n", 1517 | " 2,\n", 1518 | " 2,\n", 1519 | " 2,\n", 1520 | " 1,\n", 1521 | " 1,\n", 1522 | " 2,\n", 1523 | " 0,\n", 1524 | " 2,\n", 1525 | " 2,\n", 1526 | " 2,\n", 1527 | " 0,\n", 1528 | " 0,\n", 1529 | " 0,\n", 1530 | " 2,\n", 1531 | " 2,\n", 1532 | " 0,\n", 1533 | " 2,\n", 1534 | " 4,\n", 1535 | " 2,\n", 1536 | " 1,\n", 1537 | " 2,\n", 1538 | " 2,\n", 1539 | " 2,\n", 1540 | " 2,\n", 1541 | " 2,\n", 1542 | " 1,\n", 1543 | " 3,\n", 1544 | " 2,\n", 1545 | " 2,\n", 1546 | " 3,\n", 1547 | " 2,\n", 1548 | " 0,\n", 1549 | " 2,\n", 1550 | " 2,\n", 1551 | " 1,\n", 1552 | " 2,\n", 1553 | " 2,\n", 1554 | " 2,\n", 1555 | " 2,\n", 1556 | " 2,\n", 1557 | " 2,\n", 1558 | " 3,\n", 1559 | " 3,\n", 1560 | " 1,\n", 1561 | " 3,\n", 1562 | " 2,\n", 1563 | " 2,\n", 1564 | " 2,\n", 1565 | " 2,\n", 1566 | " 0,\n", 1567 | " 2,\n", 1568 | " 1,\n", 1569 | " 4,\n", 1570 | " 0,\n", 1571 | " 2,\n", 1572 | " 2,\n", 1573 | " 2,\n", 1574 | " 2,\n", 1575 | " 2,\n", 1576 | " 2,\n", 1577 | " 3,\n", 1578 | " 0,\n", 1579 | " 2,\n", 1580 | " 0,\n", 1581 | " ...]" 1582 | ] 1583 | }, 1584 | "metadata": { 1585 | "tags": [] 1586 | }, 1587 | "execution_count": 21 1588 | } 1589 | ] 1590 | }, 1591 | { 1592 | "cell_type": "code", 1593 | "metadata": { 1594 | "id": "GTyoqsENyMSj", 1595 | "colab_type": "code", 1596 | "colab": {} 1597 | }, 1598 | "source": [ 1599 | "sample_sub = pd.read_csv('../input/aptos2019-blindness-detection/sample_submission.csv')\n" 1600 | ], 1601 | "execution_count": 0, 1602 | "outputs": [] 1603 | }, 1604 | { 1605 | "cell_type": "code", 1606 | "metadata": { 1607 | "id": "z805HlDiyMSp", 1608 | "colab_type": "code", 1609 | "colab": {} 1610 | }, 1611 | "source": [ 1612 | "sample_sub.diagnosis = pred_labels\n" 1613 | ], 1614 | "execution_count": 0, 1615 | 
"outputs": [] 1616 | }, 1617 | { 1618 | "cell_type": "code", 1619 | "metadata": { 1620 | "id": "MQvuj9dCyMSu", 1621 | "colab_type": "code", 1622 | "outputId": "c9245967-8905-49ac-b449-c7fd4f420da7", 1623 | "colab": {} 1624 | }, 1625 | "source": [ 1626 | "sample_sub.head()" 1627 | ], 1628 | "execution_count": 0, 1629 | "outputs": [ 1630 | { 1631 | "output_type": "execute_result", 1632 | "data": { 1633 | "text/html": [ 1634 | "
\n", 1635 | "\n", 1648 | "\n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | "
id_codediagnosis
00005cfc8afb62
1003f0afdcd152
2006efc72b6382
300836aaacf062
4009245722fa42
\n", 1684 | "
" 1685 | ], 1686 | "text/plain": [ 1687 | " id_code diagnosis\n", 1688 | "0 0005cfc8afb6 2\n", 1689 | "1 003f0afdcd15 2\n", 1690 | "2 006efc72b638 2\n", 1691 | "3 00836aaacf06 2\n", 1692 | "4 009245722fa4 2" 1693 | ] 1694 | }, 1695 | "metadata": { 1696 | "tags": [] 1697 | }, 1698 | "execution_count": 24 1699 | } 1700 | ] 1701 | }, 1702 | { 1703 | "cell_type": "code", 1704 | "metadata": { 1705 | "id": "RqeXCZ1IyMSw", 1706 | "colab_type": "code", 1707 | "colab": {} 1708 | }, 1709 | "source": [ 1710 | "sample_sub.to_csv('submission.csv', index=False)" 1711 | ], 1712 | "execution_count": 0, 1713 | "outputs": [] 1714 | } 1715 | ] 1716 | } -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Importing all packages 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from torch.utils import data 5 | import torch 6 | from torch import nn 7 | from torch import optim 8 | import torchvision 9 | import torch.nn.functional as F 10 | from torchvision import datasets, transforms, models 11 | import torchvision.models as models 12 | from PIL import Image, ImageFile 13 | import json 14 | from torch.optim import lr_scheduler 15 | import random 16 | import os 17 | import sys 18 | 19 | print('Imported packages') 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | model = models.resnet152(pretrained=False) 22 | num_ftrs = model.fc.in_features 23 | out_ftrs = 5 24 | model.fc = nn.Sequential(nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Linear(512,out_ftrs),nn.LogSoftmax(dim=1)) 25 | criterion = nn.NLLLoss() 26 | optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.00001) 27 | 28 | scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 29 | model.to(device); 30 | # to unfreeze more layers 31 | 32 | 33 | for name,child in model.named_children(): 34 | if name in ['layer2','layer3','layer4','fc']: 35 | #print(name + 'is unfrozen') 36 | for param in child.parameters(): 37 | param.requires_grad = True 38 | else: 39 | #print(name + 'is frozen') 40 | for param in child.parameters(): 41 | param.requires_grad = False 42 | optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.000001) 43 | scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 44 | 45 | def load_model(path): 46 | checkpoint = torch.load(path,map_location='cpu') 47 | model.load_state_dict(checkpoint['model_state_dict']) 48 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 49 | 50 | return model 51 | def inference(model, file, transform, classes): 52 | file = Image.open(file).convert('RGB') 53 | img = transform(file).unsqueeze(0) 54 | print('Transforming your image...') 55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 56 | model.eval() 57 | with torch.no_grad(): 58 | print('Passing your image to the model....') 59 | out = model(img.to(device)) 60 | ps = torch.exp(out) 61 | top_p, top_class = ps.topk(1, dim=1) 62 | value = top_class.item() 63 | print("Predicted Severity Value: ", value) 64 | print("class is: ", classes[value]) 65 | print('Your image is printed:') 66 | return value, classes[value] 67 | # plt.imshow(np.array(file)) 68 | # plt.show() 69 | 70 | 71 | model = load_model('../Desktop/classifier.pt') 72 | print("Model loaded Succesfully") 73 | classes = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'] 74 | test_transforms = 
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

def main(path):
    # Entry point used by the GUI: returns the severity value and class name for one image
    x, y = inference(model, path, test_transforms, classes)
    return x, y

# if __name__ == '__main__':
#     # test_dir = '../Desktop/eye'
#     # folders = os.listdir(test_dir)
#     # for num in range(len(folders)):
#     #     path = test_dir + "/" + folders[num]
#     #     print(path)
#     #     inference(model, path, test_transforms, classes)
#     l = sys.argv
#     if len(l) > 1:
#         for i in range(1, len(l)):
#             print(l[i])
#             path = l[i]
#             inference(model, path, test_transforms, classes)
#     else:
#         print('Please provide the exact path of the image!')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==3.0.2
notebook==6.4.12
numpy==1.22.0
pandas==0.23.4
Pillow>=7.1.0
python-dateutil==2.7.5
python-decouple==3.1
python-engineio==3.11.1
python-slugify==3.0.4
python-socketio==2.0.0
scikit-learn==0.21.2
torch==1.13.1
torchvision==0.3.0
--------------------------------------------------------------------------------
/sampleimages/eye1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye1.png
--------------------------------------------------------------------------------
/sampleimages/eye10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye10.jpg
--------------------------------------------------------------------------------
/sampleimages/eye11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye11.png
--------------------------------------------------------------------------------
/sampleimages/eye12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye12.jpg
--------------------------------------------------------------------------------
/sampleimages/eye13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye13.jpg
--------------------------------------------------------------------------------
/sampleimages/eye14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye14.jpg
--------------------------------------------------------------------------------
/sampleimages/eye15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye15.jpg
--------------------------------------------------------------------------------
/sampleimages/eye16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye16.png
--------------------------------------------------------------------------------
/sampleimages/eye17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye17.png
--------------------------------------------------------------------------------
/sampleimages/eye18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye18.png
--------------------------------------------------------------------------------
/sampleimages/eye19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye19.png
--------------------------------------------------------------------------------
/sampleimages/eye2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye2.png
--------------------------------------------------------------------------------
/sampleimages/eye20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye20.png
--------------------------------------------------------------------------------
/sampleimages/eye3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye3.png
--------------------------------------------------------------------------------
/sampleimages/eye4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye4.jpg
--------------------------------------------------------------------------------
/sampleimages/eye5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye5.jpg
--------------------------------------------------------------------------------
/sampleimages/eye6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye6.jpg
--------------------------------------------------------------------------------
/sampleimages/eye7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye7.jpg
--------------------------------------------------------------------------------
/sampleimages/eye8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye8.jpg
--------------------------------------------------------------------------------
/sampleimages/eye9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/sampleimages/eye9.jpg
--------------------------------------------------------------------------------
/send_sms.py:
--------------------------------------------------------------------------------
from twilio.rest import Client

#--------------------------------------------------------
# Replace the values of account_sid, auth_token, to and from_ below with the
# credentials and phone numbers from your own Twilio account, then uncomment.
#--------------------------------------------------------
# def send(value, classes):
#     # Your Account SID from twilio.com/console
#     account_sid = "**************************8"
#     # Your Auth Token from twilio.com/console
#     auth_token = "**************************"
#
#     client = Client(account_sid, auth_token)
#
#     message = client.messages.create(
#         to="+**********",
#         from_="+12*********",
#         body=f"Blindness detection system report! Severity level is: {value} and class is: {classes}")
#
#     print('Message sent successfully!')
#     print(message.sid)
--------------------------------------------------------------------------------
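For a quick end-to-end check of the files above, the following sketch (not part of the repository; it assumes 'classifier.pt' has already been downloaded and the path inside 'model.py' adjusted accordingly) runs a single prediction on one of the bundled sample images:
```
# Hypothetical smoke test using model.py's helpers; importing model runs
# load_model(), so classifier.pt must already be in place.
from model import main

severity, label = main('sampleimages/eye1.png')
print(severity, label)  # an integer from 0-4 and its class name, e.g. 'Moderate'
```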