├── GettingStarted.md
├── README.md
├── Single_test_inference.ipynb
├── blindness.py
├── images
│   ├── gui1.JPG
│   ├── gui2.JPG
│   ├── gui3.JPG
│   ├── mat.png
│   ├── sms.JPG
│   ├── vis.gif
│   └── visual1.JPG
├── inference.ipynb
├── model.py
├── requirements.txt
├── sampleimages
│   ├── eye1.png
│   ├── eye10.jpg
│   ├── eye11.png
│   ├── eye12.jpg
│   ├── eye13.jpg
│   ├── eye14.jpg
│   ├── eye15.jpg
│   ├── eye16.png
│   ├── eye17.png
│   ├── eye18.png
│   ├── eye19.png
│   ├── eye2.png
│   ├── eye20.png
│   ├── eye3.png
│   ├── eye4.jpg
│   ├── eye5.jpg
│   ├── eye6.jpg
│   ├── eye7.jpg
│   ├── eye8.jpg
│   └── eye9.jpg
├── send_sms.py
└── training.ipynb
/GettingStarted.md:
--------------------------------------------------------------------------------
1 | ## Installation :
2 | Tip: make sure to install [Numpy](https://pypi.org/project/numpy/), [Pandas](https://pypi.org/project/pandas/) and [Matplotlib](https://pypi.org/project/matplotlib/) first, then proceed with the steps below.
3 | * [Torch package](https://pytorch.org/get-started/locally/)
4 | * [Tkinter](https://tkdocs.com/tutorial/install.html)
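5 | * [HeidiSQL](https://www.heidisql.com/download.php) (a GUI client; you also need a running MySQL server, plus the `mysql-connector-python` and `Pillow` packages that `blindness.py` imports)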
6 | Grab a cup of coffee, as these installs will take some time!
7 | * [Click here](https://support.hypernode.com/knowledgebase/use-heidisql/#Download_HeidiSQL) to see how to start a server session in HeidiSQL and configure it with a username and password.
8 | ## Get, set and go :
9 | * Download the complete project files with the following command from Git Bash, cmd or another terminal:
10 | ```
11 | git clone https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch
12 |
13 | ```
14 | [Note: you mainly need these four files to get started:]
15 | > * [blindness.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/blindness.py)
16 | > * [model.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/model.py)
17 | > * [classifier.pt](https://www.kaggle.com/souravs17031999/blindness-detection-pretrained-weights-pytorch) (pretrained weights)
18 | > * [send_sms.py](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/send_sms.py)
19 |
20 | * Create a new database and, in it, a table named `THEGREAT` with the columns `USERNAME`, `PASSWORD` and `PREDICT` (the names `blindness.py` queries).
21 | * Then open `blindness.py` and change the connection settings to match your database:
22 | ```python
23 | connection = sk.connect(
24 |     host="localhost",
25 |     user="root",
26 |     password="********",
27 |     database="********"
28 | )
29 | ```
30 | * Your DB server should now be connected.
31 | * Finally, you need the `classifier.pt` file, which contains the saved checkpoint (the model's and optimizer's state dictionaries) needed to load the model.
32 | [Download it here](https://www.kaggle.com/souravs17031999/blindness-detection-pretrained-weights-pytorch), put the file in the project directory and then update the path in `model.py` accordingly:
33 | ```python
34 | model = load_model('../Desktop/classifier.pt')
35 |
36 | ```
37 | * Run `blindness.py` and the GUI should start (it is recommended to launch it from a terminal and to keep all project files in the same directory).
38 | * Sign up, log in, then upload an image to get your prediction.
39 |
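The table step above does not give a schema, so here is a minimal sketch of one, plus the three queries the GUI performs (sign up, log in, store a prediction). Only the table and column names `THEGREAT`, `USERNAME`, `PASSWORD` and `PREDICT` come from `blindness.py`; the column types and the helper names `signup`, `login` and `save_prediction` are hypothetical. To stay runnable without a server, it uses Python's built-in `sqlite3` as a stand-in for MySQL. With `mysql.connector` the statements are the same, but the placeholders are `%s` instead of `?`. Passing values as parameters, rather than formatting them into the SQL string, lets the driver escape them and prevents SQL injection.

```python
import sqlite3

# Stand-in schema: only the column names come from blindness.py;
# the types and sizes are assumptions.
SCHEMA = """
CREATE TABLE IF NOT EXISTS THEGREAT (
    USERNAME VARCHAR(50) PRIMARY KEY,
    PASSWORD VARCHAR(100) NOT NULL,
    PREDICT  VARCHAR(10)
)
"""


def signup(conn, username, password):
    """Register a user; return False if the username already exists."""
    cur = conn.cursor()
    cur.execute("SELECT 1 FROM THEGREAT WHERE USERNAME = ?", (username,))
    if cur.fetchone() is not None:
        return False
    # Values are passed as parameters, never formatted into the SQL string.
    # (Passwords are kept in plain text here, as in blindness.py.)
    cur.execute("INSERT INTO THEGREAT (USERNAME, PASSWORD) VALUES (?, ?)",
                (username, password))
    conn.commit()
    return True


def login(conn, username, password):
    """Return True only if the username exists and its password matches."""
    cur = conn.cursor()
    cur.execute("SELECT PASSWORD FROM THEGREAT WHERE USERNAME = ?", (username,))
    row = cur.fetchone()
    return row is not None and row[0] == password


def save_prediction(conn, username, value):
    """Store the predicted label for a user, as text."""
    conn.execute("UPDATE THEGREAT SET PREDICT = ? WHERE USERNAME = ?",
                 (str(value), username))
    conn.commit()


conn = sqlite3.connect(":memory:")  # in-memory database for the demo
conn.execute(SCHEMA)
```

For example, `signup(conn, "alice", "s3cret")` returns `True` the first time and `False` for a repeated username, and a crafted username such as `alice' OR '1'='1` cannot log in because it is treated as plain data.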
40 | ## Optional :
41 | * If you want to receive your predictions by SMS on your mobile, create an account on [Twilio](http://twilio.com/) and verify your number.
42 | * Next, copy your credentials from the Twilio dashboard and fill them in `send_sms.py`, which makes the API request.
43 | [Note: `send_sms.py` is currently commented out, so uncomment everything before using it.]
44 | * Then, uncomment the following lines in `blindness.py`:
45 |   * `#from send_sms import *`
46 |   * `#send(value, classes)`
47 | * Your messaging service should now start, and you will see the Message_Id printed in your terminal.
48 |
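The contents of `send_sms.py` are not shown here, so the following is only a sketch of what a `send(value, classes)` function could look like with the official `twilio` Python package (`twilio.rest.Client` and `client.messages.create`). The environment-variable names and the message wording are hypothetical; reading the credentials from the environment keeps them out of the source file. The returned `message.sid` is the message ID printed in the terminal.

```python
import os


def build_message(value, classes):
    """Compose the SMS text for one prediction (hypothetical wording)."""
    return f"Your retinal scan report: predicted label {value}, class {classes}."


def send(value, classes):
    """Send the prediction by SMS through Twilio and return the message SID.

    Assumes the `twilio` package is installed and that the hypothetical
    environment variables below hold your Twilio dashboard credentials.
    """
    # Imported lazily so build_message() works even without twilio installed.
    from twilio.rest import Client

    client = Client(os.environ["TWILIO_ACCOUNT_SID"], os.environ["TWILIO_AUTH_TOKEN"])
    message = client.messages.create(
        body=build_message(value, classes),
        from_=os.environ["TWILIO_FROM_NUMBER"],  # your Twilio phone number
        to=os.environ["PATIENT_PHONE_NUMBER"],   # a verified recipient number
    )
    print(message.sid)  # the message ID
    return message.sid
```

On a Twilio trial account, messages can only be sent to numbers you have verified, which is why the verification step is needed.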
49 |
50 | [Note: you can test the system with the sample images in the [sampleimages](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/tree/master/sampleimages) folder, which are taken from the original test dataset.]
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project Name : Retinal Blindness (Diabetic Retinopathy) detection
2 |
3 | # Problem statement :
4 | Diabetic Retinopathy (DR) is a disease of increasing prevalence and the main cause of blindness among the working-age population. The risk of severe vision loss can be significantly reduced by timely diagnosis and treatment. Systematic screening for DR has been identified as a cost-effective way to save health-service resources. Automatic retinal image analysis is emerging as an important screening tool for early DR detection, as it can reduce the workload associated with manual grading and save diagnosis costs and time. In recent years, many research efforts have been devoted to developing automated tools that help detect and evaluate DR lesions.
5 | We are interested in automating this prediction using deep learning models.
6 |
7 | # Dataset : [APTOS 2019 Kaggle Blindness Detection dataset](https://www.kaggle.com/c/aptos2019-blindness-detection)
8 |
9 | # Solution :
10 | I propose a deep learning classification technique that uses the pretrained CNN model [resnet152](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) to classify the severity of DR on a scale from 0 (No DR) to 4 (Proliferative DR).
11 | This is a collaborative project by a team of three; my main work is developing, training and testing various CNN models, along with some secondary tasks.
12 | Deep learning looks promising because CNNs have already been applied to many kinds of image classification tasks, so we can rely on pretrained DL models, or modify some of their layers if we wish to :)
13 | A GUI-based system has been built with Tkinter, and HeidiSQL is used to maintain and store a list of predictions with each patient's ID and name (which is very risky, for reasons we will get to later).
14 | The Twilio API is used to send predictions to patients by SMS in case they cannot be contacted or reached otherwise (email could also be used in that case).
15 |
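The five severity grades are the ones defined by the APTOS 2019 dataset (0 No DR, 1 Mild, 2 Moderate, 3 Severe, 4 Proliferative DR). A minimal sketch of turning the model's five log-probabilities (the `LogSoftmax` head in `inference.ipynb`) into a grade and its name, mirroring the notebook's `torch.max(outputs, 1)` step in plain Python. `GRADES` and `predict_grade` are hypothetical helpers, and the class names returned by `main()` in `model.py` may be worded differently.

```python
import math

# APTOS 2019 severity grades.
GRADES = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR",
}


def predict_grade(log_probs):
    """Return (grade, name) for one image's five log-probabilities.

    Picks the index with the highest score (argmax); because log is
    monotonic, this is the same as picking the highest probability.
    """
    if len(log_probs) != len(GRADES):
        raise ValueError(f"expected {len(GRADES)} scores, got {len(log_probs)}")
    grade = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return grade, GRADES[grade]


# Example: a log-softmax output that favours grade 2
scores = [math.log(p) for p in (0.05, 0.10, 0.70, 0.10, 0.05)]
print(predict_grade(scores))  # (2, 'Moderate')
```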
16 | # Summary of Technologies used in this project :
17 | | Dev Env. | Framework/ library/ languages |
18 | | ------------- | ------------- |
19 | | Backend development | PyTorch (Deep learning framework) |
20 | | Frontend development | Tkinter (Python GUI toolkit) |
21 | | Database connectivity | HeidiSQL (MySQL client) with a MySQL server |
22 | | Programming Languages | Python, SQL |
23 | | API | Twilio cloud API|
24 |
25 | # Data visualization :
26 | The raw input data looks like this:
27 | 
28 |
29 | # Resnet152 model summary :
30 | Only the main layers of ResNet are shown below; each of 'layer1', 'layer2', 'layer3' and 'layer4' contains many more layers.
31 |
32 | 
33 |
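The "152" in the name can be checked with a little arithmetic. Assuming the standard ResNet-152 configuration (as in torchvision), `layer1` to `layer4` contain 3, 8, 36 and 3 bottleneck blocks, each with three convolutions (1x1, 3x3, 1x1); adding the stem `conv1` and the final `fc` layer gives the depth. By convention, the 1x1 shortcut (downsample) convolutions are not counted. The `fc` layer is the one `inference.ipynb` replaces with a 512-unit hidden layer and 5 outputs. `resnet_depth` is a hypothetical helper for illustration.

```python
# Bottleneck blocks per stage in ResNet-152.
BLOCKS_PER_STAGE = {"layer1": 3, "layer2": 8, "layer3": 36, "layer4": 3}
CONVS_PER_BOTTLENECK = 3  # 1x1 -> 3x3 -> 1x1


def resnet_depth(blocks_per_stage, convs_per_block=CONVS_PER_BOTTLENECK):
    """Count weighted layers: stem conv1 + all block convolutions + final fc."""
    return 1 + convs_per_block * sum(blocks_per_stage.values()) + 1


print(resnet_depth(BLOCKS_PER_STAGE))  # 152
```

The same formula gives 50 for ResNet-50 (3, 4, 6, 3 blocks) and 101 for ResNet-101 (3, 4, 23, 3 blocks), which is why a lighter variant is a natural candidate for the planned web app.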
34 | # Visualization of complete system :
35 | 
36 |
37 |
38 | # Getting started :
39 | [Click here](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/GettingStarted.md) to get started locally on your system.
40 |
41 | ## Some snaps :
42 | 
43 | 
44 | 
45 | 
46 |
47 |
48 | # Future Prospect :
49 | * My next goal is to turn this into a web app (probably using a lightweight model, as ResNet models are heavy).
50 | * The goal after that is to use encryption techniques to achieve not only high accuracy but also a high level of privacy (for example, differential privacy), using techniques such as federated learning and secure multi-party computation for privacy-preserving deep learning classification.
51 | By the way, I have already built a project that uses federated learning for a classification task; [check it out here](https://github.com/souravs17031999/Federatedencryption-showcase).
52 | Achieving privacy is especially important for medical datasets, so that trust can be established between the different stakeholders using the system.
53 | * Concurrency control still has to be implemented properly, using locks provided by MySQL, so that multiple users can use the system at the same time once it is deployed on the web.
54 | (Locally, you can already run the executable file several times to open multiple GUI instances, and it works fine.)
55 | * Reducing Type II errors (false negatives), as this metric is critical in the healthcare domain: a missed DR case may go untreated.
56 |
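To track Type II errors on this five-grade task, one option (an assumption, not something the project defines) is to treat any grade of 1 or higher as "DR present" and measure how often such eyes are predicted as grade 0. `false_negative_rate` is a hypothetical helper; its value equals 1 minus the recall of the "DR present" class.

```python
def false_negative_rate(y_true, y_pred, positive_from=1):
    """Share of DR-positive eyes (true grade >= positive_from) predicted as negative.

    Returns 0.0 when there are no positive eyes in y_true.
    """
    positives = [(t, p) for t, p in zip(y_true, y_pred) if t >= positive_from]
    if not positives:
        return 0.0
    misses = sum(1 for _, p in positives if p < positive_from)
    return misses / len(positives)


# 6 eyes have DR (grade >= 1); 2 of them are predicted as grade 0.
y_true = [0, 1, 2, 3, 4, 0, 2, 4]
y_pred = [0, 0, 2, 2, 4, 1, 0, 4]
print(round(false_negative_rate(y_true, y_pred), 3))  # 0.333
```

Note that predicting grade 3 for a grade-4 eye does not count as a miss here; only a "no DR" prediction for a diseased eye does, which matches the screening use case.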
57 | # Navigating the project :
58 | * [Check out the training code here](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/training.ipynb)
59 | * [Check out testing done on unseen image](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/Single_test_inference.ipynb)
60 | * [Check out the executable file (for running GUI)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/blindness.py)
61 | * [Check out the model file (for loading the model to run inference locally)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/model.py)
62 | * [Check out the Twilio API file (to receive predictions by SMS)](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/send_sms.py)
63 | * [For the pre-trained weights of this model, check out the Getting Started guide](https://github.com/souravs17031999/Retinal_blindness_detection_Pytorch/blob/master/GettingStarted.md)
64 |
65 | [Note: the training notebook in this repo only shows the final training run, since reaching 97% accuracy took more than 100 epochs and a lot of compute power and time.]
66 |
67 |
68 | ⭐️ this project if you liked it!
69 |
--------------------------------------------------------------------------------
/blindness.py:
--------------------------------------------------------------------------------
1 | # Importing all packages
2 | from tkinter import *
3 | from tkinter.ttk import *
4 | from tkinter import messagebox
5 | from PIL import Image
6 | import os
7 |
8 | import mysql.connector
9 | from tkinter.filedialog import askopenfilename, asksaveasfilename
10 | import mysql.connector as sk
11 | from model import *
12 | #from send_sms import *
13 | print('GUI SYSTEM STARTED...')
14 | #---------------------------------------------------------------------------------
15 |
16 | def LogIn():
17 |     global y
18 |     username = box1.get()
19 |
20 |     u = 1
21 |
22 |     if len(username) == 0:
23 |         u = 0
24 |         messagebox.showinfo("Error", "You must enter something Sir")
25 |
26 |     if u:
27 |         password = box2.get()
28 |
29 |         if len(password):
30 |             # Parameterized query: the driver escapes the username (prevents SQL injection)
31 |             query = "SELECT USERNAME, PASSWORD FROM THEGREAT WHERE USERNAME = %s"
32 |             sql.execute(query, (username,))
33 |             data = sql.fetchall()
34 |
35 |             g = 0
36 |             b = 0
37 |
38 |             for i in data:
39 |                 if i[0] == username:
40 |                     g = 1
41 |                     if i[1] == password:
42 |                         b = 1
43 |
44 |             # Grant access (y = True) only when the username AND password match
45 |             if g and b:
46 |                 y = True
47 |                 messagebox.showinfo('Hello Sir', 'Welcome to the System')
48 |             else:
49 |                 y = False
50 |                 messagebox.showinfo('Sorry', 'Wrong Username or Password')
51 |
52 |         else:
53 |             messagebox.showinfo("Error", "You must enter a password Sir!!")
54 |
55 | def OpenFile():
56 |     username = box1.get()
57 |     if y:
58 |         try:
59 |             a = askopenfilename()
60 |             print(a)
61 |             value, classes = main(a)
62 |             messagebox.showinfo("your report", f"Predicted Label is {value}\nPredicted Class is {classes}")
63 |
64 |             # Parameterized query; str(value) keeps storing the label as text, as before
65 |             query = 'UPDATE THEGREAT SET PREDICT = %s WHERE USERNAME = %s'
66 |             sql.execute(query, (str(value), username))
67 |             #print(query)
68 |             connection.commit()
69 |
70 |             #------********************Only use when required to send message
71 |             #send(value, classes)
72 |             #------*********************************************************
73 |             image = Image.open(a)
74 |             # plotting image
75 |             file = image.convert('RGB')
76 |             plt.imshow(np.array(file))
77 |             plt.title(f'your report is label : {value} class : {classes}')
78 |             plt.show()
79 |             #print(image)
80 |             print('Thanks for using the system !')
81 |             #fn, text = os.path.splitext(a) #fn stands for filename
82 |         except Exception as error:
83 |             print("Prediction failed (was a file selected?), please try again:", error)
84 |
85 |
86 |     else:
87 |         messagebox.showinfo("Hello Sir", "You need to Login first")
88 |
89 |
90 | x = 0
91 | y = False
92 |
93 |
94 | def Signup():
95 |     username = box1.get()
96 |     password = box2.get()
97 |
98 |     u = 1
99 |
100 |     if len(username) == 0 or len(password) == 0:
101 |         u = 0
102 |         messagebox.showinfo("Error", "You must enter something Sir")
103 |
104 |     if u:
105 |         query1 = "SELECT USERNAME FROM THEGREAT WHERE USERNAME = %s"
106 |         sql.execute(query1, (username,))
107 |
108 |         data = sql.fetchall()
109 |
110 |         z = 1
111 |
112 |         for i in data:
113 |             if i[0] == username:
114 |                 messagebox.showinfo("Sorry Sir", "This username is already registered, try a new one")
115 |                 z = 0
116 |
117 |         if z:
118 |             # Parameterized INSERT: the driver escapes the values (prevents SQL injection)
119 |             sql.execute("INSERT INTO THEGREAT (USERNAME, PASSWORD) VALUES (%s, %s)", (username, password))
120 |             connection.commit()
121 |             messagebox.showinfo("signed up", f"Hi {username}\nNow you can login with your credentials !")
122 |
123 |
124 | #-----------------------------------------------------------------------------------------
125 |
126 |
127 | connection = sk.connect(
128 |     host="localhost",
129 |     user="root",
130 |     password="********",  # replace with your MySQL password
131 |     database="batch_db_new"
132 | )
133 |
134 | sql = connection.cursor()
135 |
136 | root = Tk()
137 |
138 | root.geometry('700x400')
139 | root.title("SK's Blindness Detection System")
140 | root.configure(bg='pale turquoise')
141 |
142 |
143 | label1 = Label(root, text="Demo for BDS", font=('Arial', 30))
144 | label1.grid(padx=30, pady=30, row=0, column=0, sticky='W')
145 |
146 | label2 = Label(root, text="Enter your username: ", font=('Arial', 20))
147 | label2.grid(padx=10, pady=10, row=1, column=0, sticky='W')
148 |
149 | label3 = Label(root, text="Enter your password: ", font=('Arial', 20))
150 | label3.grid(padx=10, pady=20, row=2, column=0, sticky='W')
151 |
152 | box1 = Entry(root)
153 | box1.grid(row=1, column=1)
154 |
155 | box2 = Entry(root, show='*')
156 | box2.grid(row=2, column=1)
157 |
158 | button3 = Button(root, text="Signup", command=Signup)
159 | button3.grid(padx=10, pady=20, row=3, column=1)
160 |
161 | button1 = Button(root, text="LogIn", command=LogIn)
162 | button1.grid(padx=10, pady=20, row=3, column=2)
163 |
164 | button2 = Button(root, text="Upload Image", command=OpenFile)
165 | button2.grid(padx=10, pady=20, row=2, column=3)
166 |
167 | # concurrency control in InnoDB
168 | # Locking reads (FOR SHARE / FOR UPDATE) are useful when another user tries to update the same row, allocated to a different user, at the same time
169 | #SELECT * FROM t1, t2 FOR SHARE OF t1 FOR UPDATE OF t2;
170 | # START TRANSACTION;
171 | # SELECT * FROM your_table WHERE state != 'PROCESSING'
172 | # ORDER BY date_added ASC LIMIT 1 FOR UPDATE;
173 | # if (rows_selected = 0) { //finished processing the queue, abort}
174 | # else {
175 | # UPDATE your_table SET state = 'PROCESSING' WHERE id = $row.id;
176 | # COMMIT;
177 | #
178 | # // row is processed here, outside of the transaction, and it can take as much time as we want
179 | #
180 | # // once we finish:
181 | # DELETE FROM your_table WHERE id = $row.id and state = 'PROCESSING' LIMIT 1;
182 | # }
183 |
184 | root.mainloop()
185 |
--------------------------------------------------------------------------------
/images/gui1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui1.JPG
--------------------------------------------------------------------------------
/images/gui2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui2.JPG
--------------------------------------------------------------------------------
/images/gui3.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/gui3.JPG
--------------------------------------------------------------------------------
/images/mat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/mat.png
--------------------------------------------------------------------------------
/images/sms.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/sms.JPG
--------------------------------------------------------------------------------
/images/vis.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/vis.gif
--------------------------------------------------------------------------------
/images/visual1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/souravs17031999/Retinal_blindness_detection_Pytorch/39817154741dd10ede24b3b3fdfbc96b305f0fa9/images/visual1.JPG
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.6"
21 | },
22 | "colab": {
23 | "name": "inference.ipynb",
24 | "provenance": [],
25 | "include_colab_link": true
26 | }
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "view-in-github",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | ""
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "metadata": {
42 | "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
43 | "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a",
44 | "id": "Rt29Iy94yMO2",
45 | "colab_type": "code",
46 | "outputId": "fb0fa31a-c39c-429e-8128-dc050d121773",
47 | "colab": {}
48 | },
49 | "source": [
50 | "# Imports here\n",
51 | "from __future__ import print_function, division\n",
52 | "import numpy as np\n",
53 | "import matplotlib.pyplot as plt\n",
54 | "from torch.utils import data\n",
55 | "import torch\n",
56 | "from torch import nn\n",
57 | "from torch import optim\n",
58 | "import torchvision\n",
59 | "import torch.nn.functional as F\n",
60 | "from torchvision import datasets, transforms, models\n",
61 | "import torchvision.models as models\n",
62 | "from torch.utils.data.sampler import SubsetRandomSampler\n",
63 | "from torch.utils.data import Dataset, DataLoader\n",
64 | "from skimage import io, transform\n",
65 | "import torch.utils.data as data_utils\n",
66 | "from PIL import Image, ImageFile\n",
67 | "import json\n",
68 | "from torch.optim import lr_scheduler\n",
69 | "import time\n",
70 | "import os\n",
71 | "import argparse\n",
72 | "import copy\n",
73 | "import pandas as pd\n",
74 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
75 | "import cv2\n",
76 | "# Import useful sklearn functions\n",
77 | "import sklearn\n",
78 | "from sklearn.metrics import cohen_kappa_score, accuracy_score\n",
79 | "import time\n",
80 | "from tqdm import tqdm_notebook\n",
81 | "\n",
82 | "import os\n",
83 | "print(os.listdir(\"../input\"))\n",
84 | "base_dir = \"../input/aptos2019-blindness-detection/\""
85 | ],
86 | "execution_count": 0,
87 | "outputs": [
88 | {
89 | "output_type": "stream",
90 | "text": [
91 | "['aptos2019-blindness-detection', 'kernel4f121f3247']\n"
92 | ],
93 | "name": "stdout"
94 | }
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "metadata": {
100 | "id": "1XaCclZIyMPJ",
101 | "colab_type": "code",
102 | "outputId": "4d03fd95-4cc4-41a6-ab64-66383aae9ced",
103 | "colab": {}
104 | },
105 | "source": [
106 | "print(os.listdir(\"../input/kernel4f121f3247\"))\n"
107 | ],
108 | "execution_count": 0,
109 | "outputs": [
110 | {
111 | "output_type": "stream",
112 | "text": [
113 | "['__output__.json', '__results___files', 'custom.css', '__results__.html', '__notebook__.ipynb', 'classifier.pt']\n"
114 | ],
115 | "name": "stdout"
116 | }
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "metadata": {
122 | "id": "CfXBcJeKyMPV",
123 | "colab_type": "code",
124 | "colab": {}
125 | },
126 | "source": [
127 | "class CreateDataset(Dataset):\n",
128 | "    def __init__(self, df_data, data_dir = '../input/', transform=None):\n",
129 | "        super().__init__()\n",
130 | "        self.df = df_data.values\n",
131 | "        self.data_dir = data_dir\n",
132 | "        self.transform = transform\n",
133 | "\n",
134 | "    def __len__(self):\n",
135 | "        return len(self.df)\n",
136 | "    \n",
137 | "    def __getitem__(self, index):\n",
138 | "        img_name,label = self.df[index]\n",
139 | "        img_path = os.path.join(self.data_dir, img_name+'.png')\n",
140 | "        image = cv2.imread(img_path)\n",
141 | "        if self.transform is not None:\n",
142 | "            image = self.transform(image)\n",
143 | "        return image, label"
144 | ],
145 | "execution_count": 0,
146 | "outputs": []
147 | },
148 | {
149 | "cell_type": "code",
150 | "metadata": {
151 | "id": "yYOZ68A2yMPf",
152 | "colab_type": "code",
153 | "colab": {}
154 | },
155 | "source": [
156 | "test_csv = pd.read_csv('../input/aptos2019-blindness-detection/test.csv')\n"
157 | ],
158 | "execution_count": 0,
159 | "outputs": []
160 | },
161 | {
162 | "cell_type": "code",
163 | "metadata": {
164 | "id": "14T6uoF2yMPm",
165 | "colab_type": "code",
166 | "colab": {}
167 | },
168 | "source": [
169 | "test_path = \"../input/aptos2019-blindness-detection/test_images/\"\n"
170 | ],
171 | "execution_count": 0,
172 | "outputs": []
173 | },
174 | {
175 | "cell_type": "code",
176 | "metadata": {
177 | "id": "3CuePz9tyMPq",
178 | "colab_type": "code",
179 | "colab": {}
180 | },
181 | "source": [
182 | "test_transforms = torchvision.transforms.Compose([\n",
183 | " torchvision.transforms.ToPILImage(),\n",
184 | " torchvision.transforms.Resize((224, 224)),\n",
185 | " #torchvision.transforms.ColorJitter(brightness=2, contrast=2),\n",
186 | " torchvision.transforms.RandomHorizontalFlip(p=0.5),\n",
187 | " torchvision.transforms.ToTensor(),\n",
188 | " torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
189 | "])"
190 | ],
191 | "execution_count": 0,
192 | "outputs": []
193 | },
194 | {
195 | "cell_type": "code",
196 | "metadata": {
197 | "id": "p9ekp4mByMPu",
198 | "colab_type": "code",
199 | "colab": {}
200 | },
201 | "source": [
202 | "test_csv['diagnosis'] = -1\n"
203 | ],
204 | "execution_count": 0,
205 | "outputs": []
206 | },
207 | {
208 | "cell_type": "code",
209 | "metadata": {
210 | "id": "fuDzK_jqyMP5",
211 | "colab_type": "code",
212 | "colab": {}
213 | },
214 | "source": [
215 | "test_data = CreateDataset(df_data=test_csv, data_dir=test_path, transform=test_transforms)\n",
216 | "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)"
217 | ],
218 | "execution_count": 0,
219 | "outputs": []
220 | },
221 | {
222 | "cell_type": "code",
223 | "metadata": {
224 | "id": "-bx2Z1wVyMP_",
225 | "colab_type": "code",
226 | "colab": {}
227 | },
228 | "source": [
229 | "def round_off_preds(preds, coef=[0.5, 1.5, 2.5, 3.5]):\n",
230 | "    for i, pred in enumerate(preds):\n",
231 | "        if pred < coef[0]:\n",
232 | "            preds[i] = 0\n",
233 | "        elif pred >= coef[0] and pred < coef[1]:\n",
234 | "            preds[i] = 1\n",
235 | "        elif pred >= coef[1] and pred < coef[2]:\n",
236 | "            preds[i] = 2\n",
237 | "        elif pred >= coef[2] and pred < coef[3]:\n",
238 | "            preds[i] = 3\n",
239 | "        else:\n",
240 | "            preds[i] = 4\n",
241 | "    return preds"
242 | ],
243 | "execution_count": 0,
244 | "outputs": []
245 | },
246 | {
247 | "cell_type": "code",
248 | "metadata": {
249 | "id": "sw2LHzS3yMQC",
250 | "colab_type": "code",
251 | "colab": {}
252 | },
253 | "source": [
254 | "def predict(testloader):\n",
255 | "    '''Function used to make predictions on the test set'''\n",
256 | "    model.eval()\n",
257 | "    preds = []\n",
258 | "    for batch_i, (data, target) in enumerate(testloader):\n",
259 | "        data, target = data.cuda(), target.cuda()\n",
260 | "        output = model(data)\n",
261 | "        pr = output.detach().cpu().numpy()\n",
262 | "        for i in pr:\n",
263 | "            preds.append(i.item())\n",
264 | "        \n",
265 | "    return preds"
266 | ],
267 | "execution_count": 0,
268 | "outputs": []
269 | },
270 | {
271 | "cell_type": "code",
272 | "metadata": {
273 | "id": "7lWcz7t8yMQH",
274 | "colab_type": "code",
275 | "colab": {}
276 | },
277 | "source": [
278 | "def load_model(path):\n",
279 | "    checkpoint = torch.load(path)\n",
280 | "    model.load_state_dict(checkpoint['model_state_dict'])\n",
281 | "    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
282 | "    return model"
283 | ],
284 | "execution_count": 0,
285 | "outputs": []
286 | },
287 | {
288 | "cell_type": "code",
289 | "metadata": {
290 | "id": "ohMGUHUHyMQM",
291 | "colab_type": "code",
292 | "colab": {}
293 | },
294 | "source": [
295 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
296 | "\n",
297 | "model = models.resnet152(pretrained=False) \n",
298 | "\n",
299 | "num_ftrs = model.fc.in_features \n",
300 | "out_ftrs = 5 \n",
301 | " \n",
302 | "model.fc = nn.Sequential(nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Linear(512,out_ftrs),nn.LogSoftmax(dim=1))\n",
303 | "\n",
304 | "criterion = nn.NLLLoss()\n",
305 | "optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.00001) \n",
306 | "\n",
307 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)\n",
308 | "model.to(device);"
309 | ],
310 | "execution_count": 0,
311 | "outputs": []
312 | },
313 | {
314 | "cell_type": "code",
315 | "metadata": {
316 | "id": "g8uMWj8lyMQT",
317 | "colab_type": "code",
318 | "outputId": "b19053b3-e48f-4cdd-d0a8-d562929bb33a",
319 | "colab": {}
320 | },
321 | "source": [
322 | "# to unfreeze more layers \n",
323 | "for name,child in model.named_children():\n",
324 | "    if name in ['layer2','layer3','layer4','fc']:\n",
325 | "        print(name + 'is unfrozen')\n",
326 | "        for param in child.parameters():\n",
327 | "            param.requires_grad = True\n",
328 | "    else:\n",
329 | "        print(name + 'is frozen')\n",
330 | "        for param in child.parameters():\n",
331 | "            param.requires_grad = False"
332 | ],
333 | "execution_count": 0,
334 | "outputs": [
335 | {
336 | "output_type": "stream",
337 | "text": [
338 | "conv1is frozen\n",
339 | "bn1is frozen\n",
340 | "reluis frozen\n",
341 | "maxpoolis frozen\n",
342 | "layer1is frozen\n",
343 | "layer2is unfrozen\n",
344 | "layer3is unfrozen\n",
345 | "layer4is unfrozen\n",
346 | "avgpoolis frozen\n",
347 | "fcis unfrozen\n"
348 | ],
349 | "name": "stdout"
350 | }
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "metadata": {
356 | "id": "2dNbnpITyMQc",
357 | "colab_type": "code",
358 | "colab": {}
359 | },
360 | "source": [
361 | "optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()) , lr = 0.000001) \n",
362 | "scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)"
363 | ],
364 | "execution_count": 0,
365 | "outputs": []
366 | },
367 | {
368 | "cell_type": "code",
369 | "metadata": {
370 | "id": "2tNkZ6otyMQh",
371 | "colab_type": "code",
372 | "colab": {}
373 | },
374 | "source": [
375 | "model = load_model(\"../input/kernel4f121f3247/classifier.pt\")\n"
376 | ],
377 | "execution_count": 0,
378 | "outputs": []
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "id": "3CyUoakYyMQp",
384 | "colab_type": "code",
385 | "colab": {}
386 | },
387 | "source": [
388 | "test_dir = \"../input/aptos2019-blindness-detection/test_images/\""
389 | ],
390 | "execution_count": 0,
391 | "outputs": []
392 | },
393 | {
394 | "cell_type": "code",
395 | "metadata": {
396 | "id": "Yk7nwGWyyMQu",
397 | "colab_type": "code",
398 | "outputId": "10af9635-db30-4c58-eab3-b66e2b632228",
399 | "colab": {}
400 | },
401 | "source": [
402 | "with torch.no_grad():\n",
403 | "    model.eval()\n",
404 | "    p_labels = []\n",
405 | "    img_ids = []\n",
406 | "    i = 0\n",
407 | "    for inputs, labels in test_loader:\n",
408 | "        i += 1\n",
409 | "        if i % 10 == 0:\n",
410 | "            print(f'{i} pass step')\n",
411 | "        inputs = inputs.to(device)\n",
412 | "        labels = labels.to(device)\n",
413 | "        outputs = model(inputs)\n",
414 | "        _, preds = torch.max(outputs, 1)\n",
415 | "        p_labels.append(preds)\n",
416 | "        # getting ids of file images "
417 | ],
418 | "execution_count": 0,
419 | "outputs": [
420 | {
421 | "output_type": "stream",
422 | "text": [
423 | "10 pass step\n",
424 | "20 pass step\n",
425 | "30 pass step\n"
426 | ],
427 | "name": "stdout"
428 | }
429 | ]
430 | },
431 | {
432 | "cell_type": "code",
433 | "metadata": {
434 | "id": "-LH6LE9YyMQ1",
435 | "colab_type": "code",
436 | "outputId": "9523a4f6-15ea-4211-8005-a332f6fd3684",
437 | "colab": {}
438 | },
439 | "source": [
440 | "p_labels"
441 | ],
442 | "execution_count": 0,
443 | "outputs": [
444 | {
445 | "output_type": "execute_result",
446 | "data": {
447 | "text/plain": [
448 | "[tensor([2, 2, 2, 2, 2, 2, 2, 1, 3, 0, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,\n",
449 | " 2, 0, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 0, 3, 2, 2, 2, 2, 1, 2, 2, 3,\n",
450 | " 2, 2, 3, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 3, 2], device='cuda:0'),\n",
451 | " tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 0, 0, 2, 1, 1, 2, 2, 2, 2,\n",
452 | " 2, 1, 2, 1, 2, 2, 2, 2, 1, 0, 2, 2, 1, 0, 0, 2, 2, 2, 3, 2, 2, 0, 2, 3,\n",
453 | " 2, 1, 2, 2, 2, 0, 2, 2, 0, 0, 3, 2, 3, 2, 2, 2], device='cuda:0'),\n",
454 | " tensor([0, 2, 0, 2, 2, 1, 2, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2, 2, 0, 1, 0, 2, 1, 2,\n",
455 | " 2, 0, 2, 1, 2, 2, 2, 0, 2, 2, 2, 2, 2, 4, 0, 2, 3, 0, 0, 2, 4, 2, 2, 2,\n",
456 | " 2, 3, 0, 0, 2, 2, 2, 2, 2, 2, 0, 0, 2, 0, 2, 2], device='cuda:0'),\n",
457 | " tensor([0, 2, 2, 2, 4, 1, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 2, 0, 2,\n",
458 | " 2, 2, 2, 2, 2, 0, 0, 2, 4, 2, 2, 2, 0, 2, 2, 2, 1, 1, 3, 2, 2, 2, 2, 0,\n",
459 | " 0, 2, 2, 0, 0, 2, 2, 4, 0, 0, 0, 1, 0, 2, 2, 0], device='cuda:0'),\n",
460 | " tensor([0, 1, 2, 3, 2, 2, 0, 0, 2, 2, 3, 0, 3, 2, 1, 0, 1, 0, 2, 2, 2, 4, 2, 2,\n",
461 | " 2, 3, 2, 2, 2, 0, 0, 0, 2, 2, 0, 2, 2, 1, 2, 0, 2, 2, 2, 2, 0, 2, 2, 2,\n",
462 | " 2, 2, 2, 2, 0, 2, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2], device='cuda:0'),\n",
463 | " tensor([0, 0, 1, 0, 2, 3, 2, 1, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 4, 2, 0, 2, 0, 2,\n",
464 | " 2, 1, 0, 2, 0, 2, 2, 2, 0, 2, 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 0, 1, 2,\n",
465 | " 2, 2, 2, 2, 0, 2, 2, 0, 4, 2, 2, 2, 2, 2, 1, 2], device='cuda:0'),\n",
466 | " tensor([2, 2, 1, 2, 0, 0, 2, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
467 | " 2, 0, 0, 1, 1, 3, 2, 1, 0, 1, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 2,\n",
468 | " 2, 1, 2, 2, 4, 2, 0, 0, 2, 1, 0, 2, 2, 2, 0, 2], device='cuda:0'),\n",
469 | " tensor([2, 0, 1, 0, 2, 2, 0, 2, 2, 2, 2, 3, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0,\n",
470 | " 2, 0, 0, 3, 2, 1, 2, 0, 1, 2, 2, 2, 2, 2, 3, 0, 2, 2, 2, 0, 2, 3, 2, 2,\n",
471 | " 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 0, 2, 0], device='cuda:0'),\n",
472 | " tensor([2, 2, 0, 2, 2, 2, 1, 2, 2, 0, 2, 1, 0, 2, 2, 1, 2, 2, 2, 2, 0, 3, 2, 0,\n",
473 | " 2, 2, 1, 2, 2, 0, 3, 2, 2, 2, 4, 2, 3, 2, 1, 2, 2, 2, 2, 1, 3, 2, 3, 2,\n",
474 | " 2, 2, 2, 0, 2, 0, 4, 2, 2, 2, 2, 0, 2, 2, 0, 2], device='cuda:0'),\n",
475 | " tensor([2, 2, 0, 0, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 4, 0, 2, 0, 2, 2, 2, 2, 0,\n",
476 | " 2, 4, 2, 0, 2, 0, 2, 2, 2, 0, 2, 2, 2, 0, 2, 0, 0, 0, 0, 2, 2, 2, 0, 2,\n",
477 | " 2, 2, 2, 2, 2, 0, 2, 0, 3, 2, 0, 0, 2, 3, 2, 1], device='cuda:0'),\n",
478 | " tensor([2, 1, 2, 2, 2, 1, 0, 2, 2, 3, 2, 2, 2, 0, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,\n",
479 | " 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 2, 1, 0, 2, 0, 2, 1, 0, 2, 2, 2, 2,\n",
480 | " 2, 2, 2, 2, 3, 2, 0, 2, 2, 0, 2, 1, 1, 0, 0, 2], device='cuda:0'),\n",
481 | " tensor([2, 2, 2, 2, 2, 2, 1, 1, 0, 2, 2, 0, 2, 0, 1, 1, 2, 2, 0, 0, 2, 2, 2, 0,\n",
482 | " 1, 2, 4, 1, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 1, 3, 2, 2, 2, 0, 2, 2, 0, 2,\n",
483 | " 2, 2, 2, 2, 4, 0, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n",
484 | " tensor([0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 3, 1, 2,\n",
485 | " 2, 0, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 3, 0, 0, 0, 0, 2, 2,\n",
486 | " 2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n",
487 | " tensor([2, 2, 2, 2, 2, 2, 1, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 3, 0, 2, 2, 1,\n",
488 | " 2, 1, 2, 2, 1, 2, 0, 2, 1, 4, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 3, 0, 2,\n",
489 | " 2, 2, 0, 2, 2, 0, 2, 3, 2, 2, 2, 0, 4, 2, 4, 2], device='cuda:0'),\n",
490 | " tensor([0, 2, 2, 2, 2, 1, 2, 2, 2, 2, 3, 1, 4, 1, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2,\n",
491 | " 0, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 0, 2,\n",
492 | " 2, 2, 0, 0, 0, 2, 2, 0, 2, 4, 2, 1, 2, 2, 2, 2], device='cuda:0'),\n",
493 | " tensor([2, 1, 3, 2, 2, 3, 2, 0, 2, 2, 1, 2, 2, 2, 2, 2, 2, 3, 3, 1, 3, 2, 2, 2,\n",
494 | " 2, 0, 2, 1, 4, 0, 2, 2, 2, 2, 2, 2, 3, 0, 2, 0, 2, 4, 2, 2, 2, 0, 2, 2,\n",
495 | " 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2], device='cuda:0'),\n",
496 | " tensor([2, 0, 0, 2, 2, 1, 2, 2, 0, 3, 2, 2, 0, 2, 2, 3, 0, 2, 2, 0, 0, 2, 2, 0,\n",
497 | " 2, 2, 2, 2, 2, 2, 2, 0, 1, 2, 0, 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2,\n",
498 | " 2, 3, 2, 2, 4, 0, 2, 1, 2, 2, 2, 0, 2, 2, 4, 2], device='cuda:0'),\n",
499 | " tensor([2, 2, 0, 4, 3, 2, 2, 2, 0, 2, 2, 2, 4, 2, 3, 2, 0, 2, 2, 2, 2, 2, 3, 2,\n",
500 | " 0, 2, 2, 2, 0, 2, 2, 2, 2, 2, 3, 3, 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 0, 3,\n",
501 | " 0, 0, 4, 2, 1, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2], device='cuda:0'),\n",
502 | " tensor([1, 2, 2, 2, 0, 2, 4, 2, 2, 4, 2, 2, 2, 2, 0, 2, 1, 1, 2, 2, 2, 0, 3, 2,\n",
503 | " 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 2, 0, 4, 2, 2,\n",
504 | " 2, 2, 2, 2, 3, 2, 3, 0, 2, 2, 0, 0, 0, 2, 1, 0], device='cuda:0'),\n",
505 | " tensor([0, 2, 2, 4, 0, 2, 0, 0, 2, 3, 2, 2, 0, 0, 2, 2, 1, 0, 2, 2, 2, 0, 2, 1,\n",
506 | " 0, 2, 2, 2, 2, 0, 2, 1, 2, 2, 0, 2, 0, 2, 2, 3, 0, 4, 3, 0, 2, 2, 1, 4,\n",
507 | " 2, 2, 0, 0, 2, 2, 4, 2, 1, 2, 2, 1, 2, 0, 2, 0], device='cuda:0'),\n",
508 | " tensor([0, 1, 2, 2, 2, 3, 2, 2, 1, 0, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
509 | " 2, 0, 2, 0, 2, 2, 2, 3, 2, 2, 2, 0, 2, 0, 1, 0, 2, 2, 2, 4, 2, 2, 2, 0,\n",
510 | " 1, 0, 1, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0], device='cuda:0'),\n",
511 | " tensor([2, 2, 3, 2, 2, 2, 4, 4, 2, 2, 2, 0, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
512 | " 2, 2, 0, 2, 0, 2, 2, 0, 0, 0, 2, 2, 2, 4, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2,\n",
513 | " 2, 3, 0, 2, 0, 3, 2, 3, 0, 1, 2, 2, 2, 2, 0, 2], device='cuda:0'),\n",
514 | " tensor([2, 2, 0, 2, 0, 2, 2, 2, 2, 0, 3, 0, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2,\n",
515 | " 2, 2, 2, 4, 2, 0, 3, 4, 2, 2, 2, 3, 1, 4, 2, 0, 0, 2, 2, 1, 2, 2, 1, 0,\n",
516 | " 2, 1, 1, 2, 2, 0, 2, 2, 2, 0, 2, 2, 0, 2, 2, 3], device='cuda:0'),\n",
517 | " tensor([0, 2, 3, 0, 2, 0, 2, 0, 1, 2, 2, 2, 2, 0, 2, 2, 2, 0, 1, 0, 2, 3, 2, 4,\n",
518 | " 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 2, 0, 2, 1, 0, 2, 0, 2, 3, 2, 2,\n",
519 | " 2, 2, 1, 2, 2, 2, 3, 0, 3, 2, 2, 2, 2, 2, 2, 3], device='cuda:0'),\n",
520 | " tensor([2, 1, 2, 0, 2, 2, 1, 0, 2, 2, 4, 2, 2, 1, 2, 3, 4, 0, 2, 0, 2, 2, 2, 2,\n",
521 | " 2, 4, 1, 2, 2, 2, 2, 3, 1, 2, 2, 2, 2, 0, 2, 0, 2, 2, 0, 2, 2, 4, 2, 3,\n",
522 | " 0, 1, 3, 2, 2, 2, 2, 2, 1, 2, 2, 1, 3, 2, 2, 1], device='cuda:0'),\n",
523 | " tensor([2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 4, 2, 2, 2, 2, 0, 2, 2, 2, 0, 3, 1, 1, 0,\n",
524 | " 2, 1, 2, 0, 2, 2, 0, 0, 2, 2, 3, 2, 4, 2, 0, 0, 0, 2, 2, 1, 2, 0, 2, 2,\n",
525 | " 2, 3, 2, 2, 0, 2, 2, 0, 0, 3, 2, 2, 2, 0, 4, 2], device='cuda:0'),\n",
526 | " tensor([0, 0, 2, 2, 2, 1, 0, 2, 4, 2, 2, 0, 1, 0, 4, 2, 2, 4, 2, 2, 2, 0, 1, 2,\n",
527 | " 0, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 1, 1, 2, 0, 0, 2, 2, 2, 2, 3, 0, 2, 2,\n",
528 | " 2, 2, 0, 2, 0, 0, 2, 0, 2, 2, 0, 2, 0, 0, 2, 2], device='cuda:0'),\n",
529 | " tensor([0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 1, 2,\n",
530 | " 2, 2, 0, 3, 2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 1, 2, 0, 2, 0, 0,\n",
531 | " 2, 2, 0, 0, 2, 4, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'),\n",
532 | " tensor([3, 3, 2, 0, 0, 2, 4, 3, 2, 2, 2, 0, 3, 1, 2, 2, 0, 2, 3, 2, 2, 0, 2, 0,\n",
533 | " 2, 2, 0, 4, 2, 3, 2, 2, 2, 2, 0, 0, 2, 2, 2, 0, 2, 2, 3, 2, 4, 4, 2, 2,\n",
534 | " 2, 1, 2, 2, 0, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 2], device='cuda:0'),\n",
535 | " tensor([2, 1, 0, 2, 0, 2, 2, 3, 1, 2, 3, 2, 1, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2,\n",
536 | " 4, 0, 2, 2, 3, 2, 2, 2, 0, 0, 0, 2, 2, 2, 4, 2, 2, 2, 0, 0, 0, 2, 2, 2,\n",
537 | " 1, 2, 2, 0, 2, 2, 0, 3, 2, 2, 2, 2, 2, 1, 2, 2], device='cuda:0'),\n",
538 | " tensor([2, 2, 2, 0, 2, 2, 3, 0], device='cuda:0')]"
539 | ]
540 | },
541 | "metadata": {
542 | "tags": []
543 | },
544 | "execution_count": 19
545 | }
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "metadata": {
551 | "id": "nGcEgjL7yMRE",
552 | "colab_type": "code",
553 | "colab": {}
554 | },
555 | "source": [
556 | "pred_labels = []\n",
557 | "for l in p_labels:\n",
558 | " for l1 in l:\n",
559 | " pred_labels.append(l1.item())"
560 | ],
561 | "execution_count": 0,
562 | "outputs": []
563 | },
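The nested loop in the cell above flattens the per-batch prediction tensors in `p_labels` into one flat list of Python integers. Below is a minimal sketch of the same flattening written as a single comprehension. It assumes, as the printed output suggests, that `p_labels` is a list of 1-D integer tensors; the helper name `flatten_predictions` is hypothetical. For CUDA tensors, `torch.cat(p_labels).cpu().tolist()` would do the same thing with one device-to-host copy instead of one per element.

```python
def flatten_predictions(p_labels):
    """Flatten a list of per-batch label sequences into one list of Python ints.

    Works for plain lists of ints as well as 1-D PyTorch tensors: iterating a
    1-D tensor yields 0-d tensors, and int() converts each one to a Python int.
    """
    return [int(label) for batch in p_labels for label in batch]


# Plain lists stand in here for the per-batch tensors printed above.
p_labels = [[0, 2, 2], [4, 1]]
pred_labels = flatten_predictions(p_labels)
print(pred_labels)  # [0, 2, 2, 4, 1]
```

The result preserves batch order, which matters for the next step: the flat list is later matched row by row against the submission file.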
564 | {
565 | "cell_type": "code",
566 | "metadata": {
567 | "id": "aeP6MAfUyMSP",
568 | "colab_type": "code",
569 | "outputId": "055465d8-db0e-40d5-fab1-8631dddab286",
570 | "colab": {}
571 | },
572 | "source": [
573 | "pred_labels"
574 | ],
575 | "execution_count": 0,
576 | "outputs": [
577 | {
578 | "output_type": "execute_result",
579 | "data": {
580 | "text/plain": [
581 | "[2,\n",
582 | " 2,\n",
583 | " 2,\n",
584 | " 2,\n",
585 | " 2,\n",
586 | " 2,\n",
587 | " 2,\n",
588 | " 1,\n",
589 | " 3,\n",
590 | " 0,\n",
591 | " 4,\n",
592 | " 2,\n",
593 | " 2,\n",
594 | " 2,\n",
595 | " 2,\n",
596 | " 2,\n",
597 | " 2,\n",
598 | " 2,\n",
599 | " 2,\n",
600 | " 1,\n",
601 | " 2,\n",
602 | " 2,\n",
603 | " 2,\n",
604 | " 2,\n",
605 | " 2,\n",
606 | " 0,\n",
607 | " 2,\n",
608 | " 2,\n",
609 | " 2,\n",
610 | " 2,\n",
611 | " 0,\n",
612 | " 2,\n",
613 | " 2,\n",
614 | " 0,\n",
615 | " 2,\n",
616 | " 2,\n",
617 | " 2,\n",
618 | " 2,\n",
619 | " 0,\n",
620 | " 3,\n",
621 | " 2,\n",
622 | " 2,\n",
623 | " 2,\n",
624 | " 2,\n",
625 | " 1,\n",
626 | " 2,\n",
627 | " 2,\n",
628 | " 3,\n",
629 | " 2,\n",
630 | " 2,\n",
631 | " 3,\n",
632 | " 2,\n",
633 | " 2,\n",
634 | " 0,\n",
635 | " 2,\n",
636 | " 0,\n",
637 | " 2,\n",
638 | " 2,\n",
639 | " 2,\n",
640 | " 2,\n",
641 | " 2,\n",
642 | " 2,\n",
643 | " 3,\n",
644 | " 2,\n",
645 | " 2,\n",
646 | " 2,\n",
647 | " 2,\n",
648 | " 2,\n",
649 | " 2,\n",
650 | " 2,\n",
651 | " 2,\n",
652 | " 2,\n",
653 | " 2,\n",
654 | " 0,\n",
655 | " 2,\n",
656 | " 2,\n",
657 | " 2,\n",
658 | " 2,\n",
659 | " 0,\n",
660 | " 0,\n",
661 | " 0,\n",
662 | " 2,\n",
663 | " 1,\n",
664 | " 1,\n",
665 | " 2,\n",
666 | " 2,\n",
667 | " 2,\n",
668 | " 2,\n",
669 | " 2,\n",
670 | " 1,\n",
671 | " 2,\n",
672 | " 1,\n",
673 | " 2,\n",
674 | " 2,\n",
675 | " 2,\n",
676 | " 2,\n",
677 | " 1,\n",
678 | " 0,\n",
679 | " 2,\n",
680 | " 2,\n",
681 | " 1,\n",
682 | " 0,\n",
683 | " 0,\n",
684 | " 2,\n",
685 | " 2,\n",
686 | " 2,\n",
687 | " 3,\n",
688 | " 2,\n",
689 | " 2,\n",
690 | " 0,\n",
691 | " 2,\n",
692 | " 3,\n",
693 | " 2,\n",
694 | " 1,\n",
695 | " 2,\n",
696 | " 2,\n",
697 | " 2,\n",
698 | " 0,\n",
699 | " 2,\n",
700 | " 2,\n",
701 | " 0,\n",
702 | " 0,\n",
703 | " 3,\n",
704 | " 2,\n",
705 | " 3,\n",
706 | " 2,\n",
707 | " 2,\n",
708 | " 2,\n",
709 | " 0,\n",
710 | " 2,\n",
711 | " 0,\n",
712 | " 2,\n",
713 | " 2,\n",
714 | " 1,\n",
715 | " 2,\n",
716 | " 2,\n",
717 | " 2,\n",
718 | " 2,\n",
719 | " 0,\n",
720 | " 2,\n",
721 | " 2,\n",
722 | " 1,\n",
723 | " 2,\n",
724 | " 2,\n",
725 | " 2,\n",
726 | " 2,\n",
727 | " 0,\n",
728 | " 1,\n",
729 | " 0,\n",
730 | " 2,\n",
731 | " 1,\n",
732 | " 2,\n",
733 | " 2,\n",
734 | " 0,\n",
735 | " 2,\n",
736 | " 1,\n",
737 | " 2,\n",
738 | " 2,\n",
739 | " 2,\n",
740 | " 0,\n",
741 | " 2,\n",
742 | " 2,\n",
743 | " 2,\n",
744 | " 2,\n",
745 | " 2,\n",
746 | " 4,\n",
747 | " 0,\n",
748 | " 2,\n",
749 | " 3,\n",
750 | " 0,\n",
751 | " 0,\n",
752 | " 2,\n",
753 | " 4,\n",
754 | " 2,\n",
755 | " 2,\n",
756 | " 2,\n",
757 | " 2,\n",
758 | " 3,\n",
759 | " 0,\n",
760 | " 0,\n",
761 | " 2,\n",
762 | " 2,\n",
763 | " 2,\n",
764 | " 2,\n",
765 | " 2,\n",
766 | " 2,\n",
767 | " 0,\n",
768 | " 0,\n",
769 | " 2,\n",
770 | " 0,\n",
771 | " 2,\n",
772 | " 2,\n",
773 | " 0,\n",
774 | " 2,\n",
775 | " 2,\n",
776 | " 2,\n",
777 | " 4,\n",
778 | " 1,\n",
779 | " 2,\n",
780 | " 2,\n",
781 | " 2,\n",
782 | " 2,\n",
783 | " 2,\n",
784 | " 2,\n",
785 | " 0,\n",
786 | " 0,\n",
787 | " 2,\n",
788 | " 2,\n",
789 | " 0,\n",
790 | " 2,\n",
791 | " 0,\n",
792 | " 2,\n",
793 | " 2,\n",
794 | " 2,\n",
795 | " 0,\n",
796 | " 2,\n",
797 | " 2,\n",
798 | " 2,\n",
799 | " 2,\n",
800 | " 2,\n",
801 | " 2,\n",
802 | " 0,\n",
803 | " 0,\n",
804 | " 2,\n",
805 | " 4,\n",
806 | " 2,\n",
807 | " 2,\n",
808 | " 2,\n",
809 | " 0,\n",
810 | " 2,\n",
811 | " 2,\n",
812 | " 2,\n",
813 | " 1,\n",
814 | " 1,\n",
815 | " 3,\n",
816 | " 2,\n",
817 | " 2,\n",
818 | " 2,\n",
819 | " 2,\n",
820 | " 0,\n",
821 | " 0,\n",
822 | " 2,\n",
823 | " 2,\n",
824 | " 0,\n",
825 | " 0,\n",
826 | " 2,\n",
827 | " 2,\n",
828 | " 4,\n",
829 | " 0,\n",
830 | " 0,\n",
831 | " 0,\n",
832 | " 1,\n",
833 | " 0,\n",
834 | " 2,\n",
835 | " 2,\n",
836 | " 0,\n",
837 | " 0,\n",
838 | " 1,\n",
839 | " 2,\n",
840 | " 3,\n",
841 | " 2,\n",
842 | " 2,\n",
843 | " 0,\n",
844 | " 0,\n",
845 | " 2,\n",
846 | " 2,\n",
847 | " 3,\n",
848 | " 0,\n",
849 | " 3,\n",
850 | " 2,\n",
851 | " 1,\n",
852 | " 0,\n",
853 | " 1,\n",
854 | " 0,\n",
855 | " 2,\n",
856 | " 2,\n",
857 | " 2,\n",
858 | " 4,\n",
859 | " 2,\n",
860 | " 2,\n",
861 | " 2,\n",
862 | " 3,\n",
863 | " 2,\n",
864 | " 2,\n",
865 | " 2,\n",
866 | " 0,\n",
867 | " 0,\n",
868 | " 0,\n",
869 | " 2,\n",
870 | " 2,\n",
871 | " 0,\n",
872 | " 2,\n",
873 | " 2,\n",
874 | " 1,\n",
875 | " 2,\n",
876 | " 0,\n",
877 | " 2,\n",
878 | " 2,\n",
879 | " 2,\n",
880 | " 2,\n",
881 | " 0,\n",
882 | " 2,\n",
883 | " 2,\n",
884 | " 2,\n",
885 | " 2,\n",
886 | " 2,\n",
887 | " 2,\n",
888 | " 2,\n",
889 | " 0,\n",
890 | " 2,\n",
891 | " 2,\n",
892 | " 2,\n",
893 | " 2,\n",
894 | " 0,\n",
895 | " 2,\n",
896 | " 2,\n",
897 | " 1,\n",
898 | " 2,\n",
899 | " 2,\n",
900 | " 2,\n",
901 | " 0,\n",
902 | " 0,\n",
903 | " 1,\n",
904 | " 0,\n",
905 | " 2,\n",
906 | " 3,\n",
907 | " 2,\n",
908 | " 1,\n",
909 | " 2,\n",
910 | " 2,\n",
911 | " 2,\n",
912 | " 2,\n",
913 | " 2,\n",
914 | " 2,\n",
915 | " 0,\n",
916 | " 2,\n",
917 | " 2,\n",
918 | " 2,\n",
919 | " 4,\n",
920 | " 2,\n",
921 | " 0,\n",
922 | " 2,\n",
923 | " 0,\n",
924 | " 2,\n",
925 | " 2,\n",
926 | " 1,\n",
927 | " 0,\n",
928 | " 2,\n",
929 | " 0,\n",
930 | " 2,\n",
931 | " 2,\n",
932 | " 2,\n",
933 | " 0,\n",
934 | " 2,\n",
935 | " 2,\n",
936 | " 1,\n",
937 | " 2,\n",
938 | " 2,\n",
939 | " 0,\n",
940 | " 1,\n",
941 | " 2,\n",
942 | " 2,\n",
943 | " 2,\n",
944 | " 2,\n",
945 | " 2,\n",
946 | " 0,\n",
947 | " 1,\n",
948 | " 2,\n",
949 | " 2,\n",
950 | " 2,\n",
951 | " 2,\n",
952 | " 2,\n",
953 | " 0,\n",
954 | " 2,\n",
955 | " 2,\n",
956 | " 0,\n",
957 | " 4,\n",
958 | " 2,\n",
959 | " 2,\n",
960 | " 2,\n",
961 | " 2,\n",
962 | " 2,\n",
963 | " 1,\n",
964 | " 2,\n",
965 | " 2,\n",
966 | " 2,\n",
967 | " 1,\n",
968 | " 2,\n",
969 | " 0,\n",
970 | " 0,\n",
971 | " 2,\n",
972 | " 2,\n",
973 | " 0,\n",
974 | " 2,\n",
975 | " 2,\n",
976 | " 0,\n",
977 | " 2,\n",
978 | " 2,\n",
979 | " 2,\n",
980 | " 2,\n",
981 | " 2,\n",
982 | " 2,\n",
983 | " 2,\n",
984 | " 2,\n",
985 | " 2,\n",
986 | " 2,\n",
987 | " 2,\n",
988 | " 2,\n",
989 | " 2,\n",
990 | " 0,\n",
991 | " 0,\n",
992 | " 1,\n",
993 | " 1,\n",
994 | " 3,\n",
995 | " 2,\n",
996 | " 1,\n",
997 | " 0,\n",
998 | " 1,\n",
999 | " 2,\n",
1000 | " 2,\n",
1001 | " 2,\n",
1002 | " 2,\n",
1003 | " 2,\n",
1004 | " 0,\n",
1005 | " 2,\n",
1006 | " 2,\n",
1007 | " 0,\n",
1008 | " 2,\n",
1009 | " 2,\n",
1010 | " 2,\n",
1011 | " 0,\n",
1012 | " 2,\n",
1013 | " 2,\n",
1014 | " 1,\n",
1015 | " 2,\n",
1016 | " 2,\n",
1017 | " 4,\n",
1018 | " 2,\n",
1019 | " 0,\n",
1020 | " 0,\n",
1021 | " 2,\n",
1022 | " 1,\n",
1023 | " 0,\n",
1024 | " 2,\n",
1025 | " 2,\n",
1026 | " 2,\n",
1027 | " 0,\n",
1028 | " 2,\n",
1029 | " 2,\n",
1030 | " 0,\n",
1031 | " 1,\n",
1032 | " 0,\n",
1033 | " 2,\n",
1034 | " 2,\n",
1035 | " 0,\n",
1036 | " 2,\n",
1037 | " 2,\n",
1038 | " 2,\n",
1039 | " 2,\n",
1040 | " 3,\n",
1041 | " 0,\n",
1042 | " 2,\n",
1043 | " 2,\n",
1044 | " 2,\n",
1045 | " 2,\n",
1046 | " 2,\n",
1047 | " 2,\n",
1048 | " 0,\n",
1049 | " 2,\n",
1050 | " 2,\n",
1051 | " 2,\n",
1052 | " 0,\n",
1053 | " 2,\n",
1054 | " 0,\n",
1055 | " 0,\n",
1056 | " 3,\n",
1057 | " 2,\n",
1058 | " 1,\n",
1059 | " 2,\n",
1060 | " 0,\n",
1061 | " 1,\n",
1062 | " 2,\n",
1063 | " 2,\n",
1064 | " 2,\n",
1065 | " 2,\n",
1066 | " 2,\n",
1067 | " 3,\n",
1068 | " 0,\n",
1069 | " 2,\n",
1070 | " 2,\n",
1071 | " 2,\n",
1072 | " 0,\n",
1073 | " 2,\n",
1074 | " 3,\n",
1075 | " 2,\n",
1076 | " 2,\n",
1077 | " 2,\n",
1078 | " 3,\n",
1079 | " 2,\n",
1080 | " 2,\n",
1081 | " 2,\n",
1082 | " 2,\n",
1083 | " 2,\n",
1084 | " 2,\n",
1085 | " 2,\n",
1086 | " 2,\n",
1087 | " 2,\n",
1088 | " 4,\n",
1089 | " 2,\n",
1090 | " 0,\n",
1091 | " 2,\n",
1092 | " 0,\n",
1093 | " 2,\n",
1094 | " 2,\n",
1095 | " 0,\n",
1096 | " 2,\n",
1097 | " 2,\n",
1098 | " 2,\n",
1099 | " 1,\n",
1100 | " 2,\n",
1101 | " 2,\n",
1102 | " 0,\n",
1103 | " 2,\n",
1104 | " 1,\n",
1105 | " 0,\n",
1106 | " 2,\n",
1107 | " 2,\n",
1108 | " 1,\n",
1109 | " 2,\n",
1110 | " 2,\n",
1111 | " 2,\n",
1112 | " 2,\n",
1113 | " 0,\n",
1114 | " 3,\n",
1115 | " 2,\n",
1116 | " 0,\n",
1117 | " 2,\n",
1118 | " 2,\n",
1119 | " 1,\n",
1120 | " 2,\n",
1121 | " 2,\n",
1122 | " 0,\n",
1123 | " 3,\n",
1124 | " 2,\n",
1125 | " 2,\n",
1126 | " 2,\n",
1127 | " 4,\n",
1128 | " 2,\n",
1129 | " 3,\n",
1130 | " 2,\n",
1131 | " 1,\n",
1132 | " 2,\n",
1133 | " 2,\n",
1134 | " 2,\n",
1135 | " 2,\n",
1136 | " 1,\n",
1137 | " 3,\n",
1138 | " 2,\n",
1139 | " 3,\n",
1140 | " 2,\n",
1141 | " 2,\n",
1142 | " 2,\n",
1143 | " 2,\n",
1144 | " 0,\n",
1145 | " 2,\n",
1146 | " 0,\n",
1147 | " 4,\n",
1148 | " 2,\n",
1149 | " 2,\n",
1150 | " 2,\n",
1151 | " 2,\n",
1152 | " 0,\n",
1153 | " 2,\n",
1154 | " 2,\n",
1155 | " 0,\n",
1156 | " 2,\n",
1157 | " 2,\n",
1158 | " 2,\n",
1159 | " 0,\n",
1160 | " 0,\n",
1161 | " 2,\n",
1162 | " 1,\n",
1163 | " 2,\n",
1164 | " 2,\n",
1165 | " 2,\n",
1166 | " 3,\n",
1167 | " 2,\n",
1168 | " 2,\n",
1169 | " 2,\n",
1170 | " 2,\n",
1171 | " 2,\n",
1172 | " 4,\n",
1173 | " 0,\n",
1174 | " 2,\n",
1175 | " 0,\n",
1176 | " 2,\n",
1177 | " 2,\n",
1178 | " 2,\n",
1179 | " 2,\n",
1180 | " 0,\n",
1181 | " 2,\n",
1182 | " 4,\n",
1183 | " 2,\n",
1184 | " 0,\n",
1185 | " 2,\n",
1186 | " 0,\n",
1187 | " 2,\n",
1188 | " 2,\n",
1189 | " 2,\n",
1190 | " 0,\n",
1191 | " 2,\n",
1192 | " 2,\n",
1193 | " 2,\n",
1194 | " 0,\n",
1195 | " 2,\n",
1196 | " 0,\n",
1197 | " 0,\n",
1198 | " 0,\n",
1199 | " 0,\n",
1200 | " 2,\n",
1201 | " 2,\n",
1202 | " 2,\n",
1203 | " 0,\n",
1204 | " 2,\n",
1205 | " 2,\n",
1206 | " 2,\n",
1207 | " 2,\n",
1208 | " 2,\n",
1209 | " 2,\n",
1210 | " 0,\n",
1211 | " 2,\n",
1212 | " 0,\n",
1213 | " 3,\n",
1214 | " 2,\n",
1215 | " 0,\n",
1216 | " 0,\n",
1217 | " 2,\n",
1218 | " 3,\n",
1219 | " 2,\n",
1220 | " 1,\n",
1221 | " 2,\n",
1222 | " 1,\n",
1223 | " 2,\n",
1224 | " 2,\n",
1225 | " 2,\n",
1226 | " 1,\n",
1227 | " 0,\n",
1228 | " 2,\n",
1229 | " 2,\n",
1230 | " 3,\n",
1231 | " 2,\n",
1232 | " 2,\n",
1233 | " 2,\n",
1234 | " 0,\n",
1235 | " 2,\n",
1236 | " 2,\n",
1237 | " 2,\n",
1238 | " 2,\n",
1239 | " 1,\n",
1240 | " 2,\n",
1241 | " 2,\n",
1242 | " 2,\n",
1243 | " 2,\n",
1244 | " 2,\n",
1245 | " 0,\n",
1246 | " 2,\n",
1247 | " 2,\n",
1248 | " 2,\n",
1249 | " 2,\n",
1250 | " 2,\n",
1251 | " 2,\n",
1252 | " 2,\n",
1253 | " 2,\n",
1254 | " 2,\n",
1255 | " 3,\n",
1256 | " 0,\n",
1257 | " 2,\n",
1258 | " 1,\n",
1259 | " 0,\n",
1260 | " 2,\n",
1261 | " 0,\n",
1262 | " 2,\n",
1263 | " 1,\n",
1264 | " 0,\n",
1265 | " 2,\n",
1266 | " 2,\n",
1267 | " 2,\n",
1268 | " 2,\n",
1269 | " 2,\n",
1270 | " 2,\n",
1271 | " 2,\n",
1272 | " 2,\n",
1273 | " 3,\n",
1274 | " 2,\n",
1275 | " 0,\n",
1276 | " 2,\n",
1277 | " 2,\n",
1278 | " 0,\n",
1279 | " 2,\n",
1280 | " 1,\n",
1281 | " 1,\n",
1282 | " 0,\n",
1283 | " 0,\n",
1284 | " 2,\n",
1285 | " 2,\n",
1286 | " 2,\n",
1287 | " 2,\n",
1288 | " 2,\n",
1289 | " 2,\n",
1290 | " 2,\n",
1291 | " 1,\n",
1292 | " 1,\n",
1293 | " 0,\n",
1294 | " 2,\n",
1295 | " 2,\n",
1296 | " 0,\n",
1297 | " 2,\n",
1298 | " 0,\n",
1299 | " 1,\n",
1300 | " 1,\n",
1301 | " 2,\n",
1302 | " 2,\n",
1303 | " 0,\n",
1304 | " 0,\n",
1305 | " 2,\n",
1306 | " 2,\n",
1307 | " 2,\n",
1308 | " 0,\n",
1309 | " 1,\n",
1310 | " 2,\n",
1311 | " 4,\n",
1312 | " 1,\n",
1313 | " 2,\n",
1314 | " 2,\n",
1315 | " 2,\n",
1316 | " 2,\n",
1317 | " 2,\n",
1318 | " 2,\n",
1319 | " 0,\n",
1320 | " 0,\n",
1321 | " 2,\n",
1322 | " 2,\n",
1323 | " 1,\n",
1324 | " 3,\n",
1325 | " 2,\n",
1326 | " 2,\n",
1327 | " 2,\n",
1328 | " 0,\n",
1329 | " 2,\n",
1330 | " 2,\n",
1331 | " 0,\n",
1332 | " 2,\n",
1333 | " 2,\n",
1334 | " 2,\n",
1335 | " 2,\n",
1336 | " 2,\n",
1337 | " 4,\n",
1338 | " 0,\n",
1339 | " 2,\n",
1340 | " 2,\n",
1341 | " 4,\n",
1342 | " 2,\n",
1343 | " 2,\n",
1344 | " 2,\n",
1345 | " 2,\n",
1346 | " 2,\n",
1347 | " 2,\n",
1348 | " 2,\n",
1349 | " 0,\n",
1350 | " 1,\n",
1351 | " 2,\n",
1352 | " 2,\n",
1353 | " 2,\n",
1354 | " 2,\n",
1355 | " 2,\n",
1356 | " 2,\n",
1357 | " 2,\n",
1358 | " 2,\n",
1359 | " 2,\n",
1360 | " 2,\n",
1361 | " 0,\n",
1362 | " 2,\n",
1363 | " 2,\n",
1364 | " 2,\n",
1365 | " 2,\n",
1366 | " 2,\n",
1367 | " 2,\n",
1368 | " 2,\n",
1369 | " 2,\n",
1370 | " 3,\n",
1371 | " 1,\n",
1372 | " 2,\n",
1373 | " 2,\n",
1374 | " 0,\n",
1375 | " 2,\n",
1376 | " 2,\n",
1377 | " 2,\n",
1378 | " 3,\n",
1379 | " 2,\n",
1380 | " 2,\n",
1381 | " 2,\n",
1382 | " 2,\n",
1383 | " 2,\n",
1384 | " 2,\n",
1385 | " 2,\n",
1386 | " 2,\n",
1387 | " 3,\n",
1388 | " 2,\n",
1389 | " 1,\n",
1390 | " 3,\n",
1391 | " 0,\n",
1392 | " 0,\n",
1393 | " 0,\n",
1394 | " 0,\n",
1395 | " 2,\n",
1396 | " 2,\n",
1397 | " 2,\n",
1398 | " 2,\n",
1399 | " 2,\n",
1400 | " 2,\n",
1401 | " 2,\n",
1402 | " 2,\n",
1403 | " 2,\n",
1404 | " 3,\n",
1405 | " 2,\n",
1406 | " 3,\n",
1407 | " 2,\n",
1408 | " 2,\n",
1409 | " 2,\n",
1410 | " 2,\n",
1411 | " 2,\n",
1412 | " 2,\n",
1413 | " 2,\n",
1414 | " 2,\n",
1415 | " 2,\n",
1416 | " 2,\n",
1417 | " 2,\n",
1418 | " 2,\n",
1419 | " 1,\n",
1420 | " 2,\n",
1421 | " 3,\n",
1422 | " 2,\n",
1423 | " 2,\n",
1424 | " 2,\n",
1425 | " 2,\n",
1426 | " 2,\n",
1427 | " 2,\n",
1428 | " 2,\n",
1429 | " 2,\n",
1430 | " 0,\n",
1431 | " 2,\n",
1432 | " 3,\n",
1433 | " 0,\n",
1434 | " 2,\n",
1435 | " 2,\n",
1436 | " 1,\n",
1437 | " 2,\n",
1438 | " 1,\n",
1439 | " 2,\n",
1440 | " 2,\n",
1441 | " 1,\n",
1442 | " 2,\n",
1443 | " 0,\n",
1444 | " 2,\n",
1445 | " 1,\n",
1446 | " 4,\n",
1447 | " 2,\n",
1448 | " 2,\n",
1449 | " 2,\n",
1450 | " 2,\n",
1451 | " 2,\n",
1452 | " 2,\n",
1453 | " 2,\n",
1454 | " 2,\n",
1455 | " 0,\n",
1456 | " 2,\n",
1457 | " 2,\n",
1458 | " 3,\n",
1459 | " 0,\n",
1460 | " 2,\n",
1461 | " 2,\n",
1462 | " 2,\n",
1463 | " 0,\n",
1464 | " 2,\n",
1465 | " 2,\n",
1466 | " 0,\n",
1467 | " 2,\n",
1468 | " 3,\n",
1469 | " 2,\n",
1470 | " 2,\n",
1471 | " 2,\n",
1472 | " 0,\n",
1473 | " 4,\n",
1474 | " 2,\n",
1475 | " 4,\n",
1476 | " 2,\n",
1477 | " 0,\n",
1478 | " 2,\n",
1479 | " 2,\n",
1480 | " 2,\n",
1481 | " 2,\n",
1482 | " 1,\n",
1483 | " 2,\n",
1484 | " 2,\n",
1485 | " 2,\n",
1486 | " 2,\n",
1487 | " 3,\n",
1488 | " 1,\n",
1489 | " 4,\n",
1490 | " 1,\n",
1491 | " 0,\n",
1492 | " 2,\n",
1493 | " 2,\n",
1494 | " 2,\n",
1495 | " 2,\n",
1496 | " 2,\n",
1497 | " 2,\n",
1498 | " 0,\n",
1499 | " 2,\n",
1500 | " 2,\n",
1501 | " 0,\n",
1502 | " 3,\n",
1503 | " 2,\n",
1504 | " 2,\n",
1505 | " 2,\n",
1506 | " 2,\n",
1507 | " 2,\n",
1508 | " 2,\n",
1509 | " 2,\n",
1510 | " 2,\n",
1511 | " 2,\n",
1512 | " 2,\n",
1513 | " 2,\n",
1514 | " 2,\n",
1515 | " 2,\n",
1516 | " 2,\n",
1517 | " 2,\n",
1518 | " 2,\n",
1519 | " 2,\n",
1520 | " 1,\n",
1521 | " 1,\n",
1522 | " 2,\n",
1523 | " 0,\n",
1524 | " 2,\n",
1525 | " 2,\n",
1526 | " 2,\n",
1527 | " 0,\n",
1528 | " 0,\n",
1529 | " 0,\n",
1530 | " 2,\n",
1531 | " 2,\n",
1532 | " 0,\n",
1533 | " 2,\n",
1534 | " 4,\n",
1535 | " 2,\n",
1536 | " 1,\n",
1537 | " 2,\n",
1538 | " 2,\n",
1539 | " 2,\n",
1540 | " 2,\n",
1541 | " 2,\n",
1542 | " 1,\n",
1543 | " 3,\n",
1544 | " 2,\n",
1545 | " 2,\n",
1546 | " 3,\n",
1547 | " 2,\n",
1548 | " 0,\n",
1549 | " 2,\n",
1550 | " 2,\n",
1551 | " 1,\n",
1552 | " 2,\n",
1553 | " 2,\n",
1554 | " 2,\n",
1555 | " 2,\n",
1556 | " 2,\n",
1557 | " 2,\n",
1558 | " 3,\n",
1559 | " 3,\n",
1560 | " 1,\n",
1561 | " 3,\n",
1562 | " 2,\n",
1563 | " 2,\n",
1564 | " 2,\n",
1565 | " 2,\n",
1566 | " 0,\n",
1567 | " 2,\n",
1568 | " 1,\n",
1569 | " 4,\n",
1570 | " 0,\n",
1571 | " 2,\n",
1572 | " 2,\n",
1573 | " 2,\n",
1574 | " 2,\n",
1575 | " 2,\n",
1576 | " 2,\n",
1577 | " 3,\n",
1578 | " 0,\n",
1579 | " 2,\n",
1580 | " 0,\n",
1581 | " ...]"
1582 | ]
1583 | },
1584 | "metadata": {
1585 | "tags": []
1586 | },
1587 | "execution_count": 21
1588 | }
1589 | ]
1590 | },
1591 | {
1592 | "cell_type": "code",
1593 | "metadata": {
1594 | "id": "GTyoqsENyMSj",
1595 | "colab_type": "code",
1596 | "colab": {}
1597 | },
1598 | "source": [
1599 | "sample_sub = pd.read_csv('../input/aptos2019-blindness-detection/sample_submission.csv')\n"
1600 | ],
1601 | "execution_count": 0,
1602 | "outputs": []
1603 | },
1604 | {
1605 | "cell_type": "code",
1606 | "metadata": {
1607 | "id": "z805HlDiyMSp",
1608 | "colab_type": "code",
1609 | "colab": {}
1610 | },
1611 | "source": [
1612 | "sample_sub.diagnosis = pred_labels\n"
1613 | ],
1614 | "execution_count": 0,
1615 | "outputs": []
1616 | },
1617 | {
1618 | "cell_type": "code",
1619 | "metadata": {
1620 | "id": "MQvuj9dCyMSu",
1621 | "colab_type": "code",
1622 | "outputId": "c9245967-8905-49ac-b449-c7fd4f420da7",
1623 | "colab": {}
1624 | },
1625 | "source": [
1626 | "sample_sub.head()"
1627 | ],
1628 | "execution_count": 0,
1629 | "outputs": [
1630 | {
1631 | "output_type": "execute_result",
1632 | "data": {
1633 | "text/html": [
|   | id_code | diagnosis |
|---|---|---|
| 0 | 0005cfc8afb6 | 2 |
| 1 | 003f0afdcd15 | 2 |
| 2 | 006efc72b638 | 2 |
| 3 | 00836aaacf06 | 2 |
| 4 | 009245722fa4 | 2 |