├── LICENSE ├── README.md └── pytorch-resnet50-starter.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Transfer Learning Image Classification 2 | This is my sample kernel for the Kaggle competition iMet Collection 2019 - FGVC6 (Recognize artwork attributes from The Metropolitan Museum of Art). 3 | ### Highlights of this project: 4 | * PyTorch 5 | * PyTorch custom `Dataset` class 6 | * Transfer learning (ResNet-50) 7 | * Multi-label classification 8 | * 1,103 label categories 9 | -------------------------------------------------------------------------------- /pytorch-resnet50-starter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%matplotlib inline\n", 13 | "\n", 14 | "import os\n", 15 | "import time\n", 16 | "import copy\n", 17 | "import pandas as
pd\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "from random import seed\n", 21 | "from random import randint\n", 22 | "import random\n", 23 | "\n", 24 | "import torch\n", 25 | "from torch import nn\n", 26 | "from torch import optim\n", 27 | "import torch.nn.functional as F\n", 28 | "from torchvision import datasets, transforms, models\n", 29 | "\n", 30 | "from PIL import Image\n", 31 | "from matplotlib import pyplot as plt\n", 32 | "\n", 33 | "import warnings\n", 34 | "warnings.filterwarnings(\"ignore\")\n", 35 | "\n", 36 | "from tqdm import tqdm_notebook as tqdm\n", 37 | "\n", 38 | "\n", 39 | "input_dir = os.path.join('..','input','imet-2019-fgvc6')\n", 40 | "train_dir = os.path.join(input_dir,'train')\n", 41 | "test_dir = os.path.join(input_dir,'test')\n", 42 | "labels_csv= os.path.join(input_dir,'labels.csv')\n", 43 | "train_csv = os.path.join(input_dir,'train.csv')\n", 44 | "resnet_weights_path = os.path.join('..','input','resnet50','resnet50.pth')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "def seed_all(seed=27):\n", 54 | " \"\"\"https://pytorch.org/docs/stable/notes/randomness.html\"\"\"\n", 55 | " random.seed(seed)\n", 56 | " os.environ['PYTHONHASHSEED'] = str(seed)\n", 57 | " np.random.seed(seed)\n", 58 | " torch.manual_seed(seed)\n", 59 | " torch.cuda.manual_seed(seed)\n", 60 | " torch.backends.cudnn.deterministic = True\n", 61 | "seed_all(27)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "BATCH_SIZE = 128\n", 71 | "NUM_EPOCHS = 20\n", 72 | "PERCENTILE = 99.7\n", 73 | "LEARNING_RATE = 0.0001\n", 74 | "DISABLE_TQDM = True" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" 84 | ] 85 | }, 86 | { 87 | 
"cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "df = pd.read_csv(labels_csv)\n", 93 | "attribute_dict = dict(zip(df.attribute_id,df.attribute_name))\n", 94 | "del df,labels_csv" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "total_categories: 1103\n", 107 | "tag_categories: 705 \n", 108 | "culture_categories: 398 \n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "tag_count = 0 \n", 114 | "culture_count = 0\n", 115 | "for idx,data in attribute_dict.items():\n", 116 | " if data.split(\"::\")[0] == 'tag':\n", 117 | " tag_count+=1\n", 118 | " if data.split(\"::\")[0] == 'culture':\n", 119 | " culture_count+=1\n", 120 | "print('total_categories: {0}\\ntag_categories: {1} \\nculture_categories: {2} ' \\\n", 121 | " .format(len(attribute_dict),tag_count,culture_count))\n", 122 | "#cross check your results\n", 123 | "assert tag_count+culture_count == len(attribute_dict)\n", 124 | "output_dim = len(attribute_dict) " 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 7, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "df = pd.read_csv(train_csv)\n", 134 | "labels_dict = dict(zip(df.id,df.attribute_ids))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 8, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "culture::assyrian\n", 147 | "tag::decorative elements\n" 148 | ] 149 | }, 150 | { 151 | "data": { 152 | "image/png": "\n", 153 | "text/plain": [ 154 | "
" 155 | ] 156 | }, 157 | "metadata": {}, 158 | "output_type": "display_data" 159 | } 160 | ], 161 | "source": [ 162 | "idx = len(os.listdir(train_dir))\n", 163 | "number = randint(0,idx)\n", 164 | "image_name = os.listdir(train_dir)[number]\n", 165 | "def imshow(image):\n", 166 | " plt.figure(figsize=(6, 6))\n", 167 | " plt.imshow(image)\n", 168 | " plt.show()\n", 169 | "# Example image\n", 170 | "x = Image.open(os.path.join(train_dir,image_name))\n", 171 | "for i in labels_dict[os.listdir(train_dir)[number].split('.')[0]].split():\n", 172 | " print(attribute_dict[int(i)])\n", 173 | "np.array(x).shape\n", 174 | "imshow(x)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 9, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "# need to add more transforms here\n", 184 | "data_transforms = transforms.Compose([\n", 185 | " transforms.Resize((224,224)),\n", 186 | " transforms.ToTensor(),\n", 187 | " ])" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Custom Dataset class" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 10, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "from torch.utils import data\n", 204 | "class ImageData(data.Dataset):\n", 205 | " def __init__(self,df,dirpath,transform,test = False):\n", 206 | " self.df = df\n", 207 | " self.test = test\n", 208 | " self.dirpath = dirpath\n", 209 | " self.conv_to_tensor = transform\n", 210 | " #image data \n", 211 | " if not self.test:\n", 212 | " self.image_arr = np.asarray(str(self.dirpath)+'/'+self.df.iloc[:, 0]+'.png')\n", 213 | " else:\n", 214 | " self.image_arr = np.asarray(str(self.dirpath)+'/'+self.df.iloc[:, 0])\n", 215 | " \n", 216 | " #labels data\n", 217 | " if not self.test:\n", 218 | " self.label_df = self.df.iloc[:,1]\n", 219 | " \n", 220 | " # Calculate length of df\n", 221 | " self.data_len = len(self.df.index)\n", 222 | "\n", 223 | " def 
__len__(self):\n", 224 | " return self.data_len\n", 225 | " \n", 226 | " def __getitem__(self, idx):\n", 227 | " image_name = self.image_arr[idx]\n", 228 | " img = Image.open(image_name)\n", 229 | " img_tensor = self.conv_to_tensor(img)\n", 230 | " if not self.test:\n", 231 | " image_labels = self.label_df[idx]\n", 232 | " label_tensor = torch.zeros((1, output_dim))\n", 233 | " for label in image_labels.split():\n", 234 | " label_tensor[0, int(label)] = 1\n", 235 | " image_label = torch.tensor(label_tensor,dtype= torch.float32)\n", 236 | " return (img_tensor,image_label.squeeze())\n", 237 | " return (img_tensor)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 11, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "Validation_Data Length: 21848\n", 250 | " Train_Data Length: 87389\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "df = pd.read_csv(train_csv)\n", 256 | "# if you want to run on less data to quickly check\n", 257 | "#df = pd.read_csv(train_csv).head(5000)\n", 258 | "from sklearn.model_selection import train_test_split\n", 259 | "train_df,val_df = train_test_split(df, test_size=0.20)\n", 260 | "train_df = train_df.reset_index(drop=True)\n", 261 | "val_df = val_df.reset_index(drop=True)\n", 262 | "print(f\"Validation_Data Length: {len(val_df)}\\n Train_Data Length: {len(train_df)}\")" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 12, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "# Train dataset\n", 272 | "train_dataset = ImageData(train_df,train_dir,data_transforms)\n", 273 | "train_loader = data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=False)\n", 274 | "\n", 275 | "# validation dataset\n", 276 | "val_dataset = ImageData(val_df,train_dir,data_transforms)\n", 277 | "val_loader = data.DataLoader(dataset=val_dataset,batch_size=BATCH_SIZE,shuffle=False)\n", 278 | "\n", 279 | "# 
test dataset\n", 280 | "test_df = pd.DataFrame(os.listdir(test_dir))\n", 281 | "test_dataset = ImageData(test_df,test_dir,data_transforms,test = True)\n", 282 | "test_loader = data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)\n", 283 | "\n", 284 | "dataloaders_dict = {'train':train_loader, 'val':val_loader}" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 13, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "Train Features: torch.Size([128, 3, 224, 224])\n", 297 | "Train Labels: torch.Size([128, 1103])\n", 298 | "\n", 299 | "Validation Features: torch.Size([128, 3, 224, 224])\n", 300 | "Validation Labels: torch.Size([128, 1103])\n", 301 | "\n", 302 | "Test Features: torch.Size([128, 3, 224, 224])\n", 303 | "\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "features, labels = next(iter(train_loader))\n", 309 | "print(f'Train Features: {features.shape}\\nTrain Labels: {labels.shape}')\n", 310 | "print()\n", 311 | "features, labels = next(iter(val_loader))\n", 312 | "print(f'Validation Features: {features.shape}\\nValidation Labels: {labels.shape}')\n", 313 | "print()\n", 314 | "features = next(iter(test_loader))\n", 315 | "print(f'Test Features: {features.shape}\\n')" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "# Model Using Resnet50" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 14, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "resnet_cls = models.resnet50()\n", 332 | "resnet_cls.load_state_dict(torch.load(resnet_weights_path))\n", 333 | "\n", 334 | "class AvgPool(nn.Module):\n", 335 | " def forward(self, x):\n", 336 | " return F.avg_pool2d(x, x.shape[2:])\n", 337 | " \n", 338 | "class ResNet50(nn.Module):\n", 339 | " def __init__(self,num_outputs):\n", 340 | " super(ResNet50,self).__init__()\n", 341 | " self.resnet = 
resnet_cls\n", 342 | " layer4 = self.resnet.layer4\n", 343 | " self.resnet.layer4 = nn.Sequential(\n", 344 | " nn.Dropout(0.5),\n", 345 | " layer4\n", 346 | " )\n", 347 | " self.resnet.avgpool = AvgPool()\n", 348 | " self.resnet.fc = nn.Linear(2048, num_outputs)\n", 349 | " for param in self.resnet.parameters():\n", 350 | " param.requires_grad = False\n", 351 | "\n", 352 | " for param in self.resnet.layer4.parameters():\n", 353 | " param.requires_grad = True\n", 354 | "\n", 355 | " for param in self.resnet.fc.parameters():\n", 356 | " param.requires_grad = True\n", 357 | " \n", 358 | " def forward(self,x):\n", 359 | " out = self.resnet(x)\n", 360 | " return out\n", 361 | " \n", 362 | "NeuralNet = ResNet50(num_outputs = output_dim) " 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 15, 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/plain": [ 373 | "ResNet50(\n", 374 | " (resnet): ResNet(\n", 375 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 376 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 377 | " (relu): ReLU(inplace)\n", 378 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 379 | " (layer1): Sequential(\n", 380 | " (0): Bottleneck(\n", 381 | " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 382 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 383 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 384 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 385 | " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 386 | " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 387 | " (relu): ReLU(inplace)\n", 388 | " (downsample): Sequential(\n", 389 | " 
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 390 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 391 | " )\n", 392 | " )\n", 393 | " (1): Bottleneck(\n", 394 | " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 395 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 396 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 397 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 398 | " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 399 | " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 400 | " (relu): ReLU(inplace)\n", 401 | " )\n", 402 | " (2): Bottleneck(\n", 403 | " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 404 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 405 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 406 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 407 | " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 408 | " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 409 | " (relu): ReLU(inplace)\n", 410 | " )\n", 411 | " )\n", 412 | " (layer2): Sequential(\n", 413 | " (0): Bottleneck(\n", 414 | " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 415 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 416 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 417 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 418 | " (conv3): Conv2d(128, 512, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", 419 | " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 420 | " (relu): ReLU(inplace)\n", 421 | " (downsample): Sequential(\n", 422 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 423 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 424 | " )\n", 425 | " )\n", 426 | " (1): Bottleneck(\n", 427 | " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 428 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 429 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 430 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 431 | " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 432 | " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 433 | " (relu): ReLU(inplace)\n", 434 | " )\n", 435 | " (2): Bottleneck(\n", 436 | " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 437 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 438 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 439 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 440 | " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 441 | " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 442 | " (relu): ReLU(inplace)\n", 443 | " )\n", 444 | " (3): Bottleneck(\n", 445 | " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 446 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 447 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False)\n", 448 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 449 | " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 450 | " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 451 | " (relu): ReLU(inplace)\n", 452 | " )\n", 453 | " )\n", 454 | " (layer3): Sequential(\n", 455 | " (0): Bottleneck(\n", 456 | " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 457 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 458 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 459 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 460 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 461 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 462 | " (relu): ReLU(inplace)\n", 463 | " (downsample): Sequential(\n", 464 | " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 465 | " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 466 | " )\n", 467 | " )\n", 468 | " (1): Bottleneck(\n", 469 | " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 470 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 471 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 472 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 473 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 474 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 475 | " (relu): ReLU(inplace)\n", 476 | " )\n", 477 | " (2): Bottleneck(\n", 478 | " (conv1): 
Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 479 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 480 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 481 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 482 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 483 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 484 | " (relu): ReLU(inplace)\n", 485 | " )\n", 486 | " (3): Bottleneck(\n", 487 | " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 488 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 489 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 490 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 491 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 492 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 493 | " (relu): ReLU(inplace)\n", 494 | " )\n", 495 | " (4): Bottleneck(\n", 496 | " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 497 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 498 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 499 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 500 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 501 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 502 | " (relu): ReLU(inplace)\n", 503 | " )\n", 504 | " (5): Bottleneck(\n", 505 | " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", 506 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 507 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 508 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 509 | " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 510 | " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 511 | " (relu): ReLU(inplace)\n", 512 | " )\n", 513 | " )\n", 514 | " (layer4): Sequential(\n", 515 | " (0): Dropout(p=0.5)\n", 516 | " (1): Sequential(\n", 517 | " (0): Bottleneck(\n", 518 | " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 519 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 520 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 521 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 522 | " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 523 | " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 524 | " (relu): ReLU(inplace)\n", 525 | " (downsample): Sequential(\n", 526 | " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 527 | " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 528 | " )\n", 529 | " )\n", 530 | " (1): Bottleneck(\n", 531 | " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 532 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 533 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 534 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 535 | " (conv3): 
Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 536 | " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 537 | " (relu): ReLU(inplace)\n", 538 | " )\n", 539 | " (2): Bottleneck(\n", 540 | " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 541 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 542 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 543 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 544 | " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 545 | " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 546 | " (relu): ReLU(inplace)\n", 547 | " )\n", 548 | " )\n", 549 | " )\n", 550 | " (avgpool): AvgPool()\n", 551 | " (fc): Linear(in_features=2048, out_features=1103, bias=True)\n", 552 | " )\n", 553 | ")" 554 | ] 555 | }, 556 | "execution_count": 15, 557 | "metadata": {}, 558 | "output_type": "execute_result" 559 | } 560 | ], 561 | "source": [ 562 | "NeuralNet" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": 16, 568 | "metadata": {}, 569 | "outputs": [ 570 | { 571 | "name": "stdout", 572 | "output_type": "stream", 573 | "text": [ 574 | "25,768,079 total parameters.\n", 575 | "17,224,783 training parameters.\n" 576 | ] 577 | } 578 | ], 579 | "source": [ 580 | "total_params = sum(p.numel() for p in NeuralNet.parameters())\n", 581 | "print(f'{total_params:,} total parameters.')\n", 582 | "total_trainable_params = sum(p.numel() for p in NeuralNet.parameters() if p.requires_grad)\n", 583 | "print(f'{total_trainable_params:,} training parameters.')" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 17, 589 | "metadata": {}, 590 | "outputs": [ 591 | { 592 | "name": "stdout", 593 | "output_type": "stream", 594 | 
"text": [ 595 | "TRAINING\n", 596 | "training examples: 87389\n", 597 | "batch size: 128\n", 598 | "batches available: 683\n", 599 | "\n", 600 | "VALIDATION\n", 601 | "validation examples: 21848\n", 602 | "batch size: 128\n", 603 | "batches available: 171\n", 604 | "\n", 605 | "TESTING\n", 606 | "testing examples: 7443\n", 607 | "batch size: 128\n", 608 | "batches available: 59\n" 609 | ] 610 | } 611 | ], 612 | "source": [ 613 | "print(\"TRAINING\")\n", 614 | "print(\"training examples: \",len(train_dataset))\n", 615 | "print(\"batch size: \",BATCH_SIZE)\n", 616 | "print(\"batches available: \",len(train_loader))\n", 617 | "print()\n", 618 | "print(\"VALIDATION\")\n", 619 | "print(\"validation examples: \",len(val_dataset))\n", 620 | "print(\"batch size: \",BATCH_SIZE)\n", 621 | "print(\"batches available: \",len(val_loader))\n", 622 | "print()\n", 623 | "print(\"TESTING\")\n", 624 | "print(\"testing examples: \",len(test_dataset))\n", 625 | "print(\"batch size: \",BATCH_SIZE)\n", 626 | "print(\"batches available: \",len(test_loader))" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": {}, 632 | "source": [ 633 | "# Train the Model" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 18, 639 | "metadata": {}, 640 | "outputs": [ 641 | { 642 | "name": "stdout", 643 | "output_type": "stream", 644 | "text": [ 645 | "Phase: train | Epoch: 1/20 | train_loss:0.02759083 | Time: 1020.3523s\n", 646 | "model val_loss Improved from inf to 0.01205008\n", 647 | "Phase: val | Epoch: 1/20 | val_loss:0.01205008 | Time: 242.5033s\n", 648 | "Phase: train | Epoch: 2/20 | train_loss:0.01108181 | Time: 1019.1168s\n", 649 | "model val_loss Improved from 0.01205008 to 0.01037972\n", 650 | "Phase: val | Epoch: 2/20 | val_loss:0.01037972 | Time: 235.0539s\n", 651 | "Phase: train | Epoch: 3/20 | train_loss:0.00987623 | Time: 974.0747s\n", 652 | "model val_loss Improved from 0.01037972 to 0.00965139\n", 653 | "Phase: val | Epoch: 3/20 | 
val_loss:0.00965139 | Time: 229.2483s\n", 654 | "Phase: train | Epoch: 4/20 | train_loss:0.00913306 | Time: 985.1822s\n", 655 | "model val_loss Improved from 0.00965139 to 0.00927821\n", 656 | "Phase: val | Epoch: 4/20 | val_loss:0.00927821 | Time: 230.8872s\n", 657 | "Phase: train | Epoch: 5/20 | train_loss:0.00857142 | Time: 981.3410s\n", 658 | "model val_loss Improved from 0.00927821 to 0.00903705\n", 659 | "Phase: val | Epoch: 5/20 | val_loss:0.00903705 | Time: 228.4492s\n", 660 | "Phase: train | Epoch: 6/20 | train_loss:0.00809647 | Time: 977.3055s\n", 661 | "model val_loss Improved from 0.00903705 to 0.00891332\n", 662 | "Phase: val | Epoch: 6/20 | val_loss:0.00891332 | Time: 232.7572s\n", 663 | "Phase: train | Epoch: 7/20 | train_loss:0.00767191 | Time: 978.3559s\n", 664 | "model val_loss Improved from 0.00891332 to 0.00883641\n", 665 | "Phase: val | Epoch: 7/20 | val_loss:0.00883641 | Time: 230.9421s\n", 666 | "Phase: train | Epoch: 8/20 | train_loss:0.00728203 | Time: 970.8857s\n", 667 | "model val_loss Improved from 0.00883641 to 0.00880080\n", 668 | "Phase: val | Epoch: 8/20 | val_loss:0.00880080 | Time: 228.9912s\n", 669 | "Phase: train | Epoch: 9/20 | train_loss:0.00691778 | Time: 960.9457s\n", 670 | "Phase: val | Epoch: 9/20 | val_loss:0.00885041 | Time: 228.8472s\n", 671 | "Phase: train | Epoch: 10/20 | train_loss:0.00656370 | Time: 961.6056s\n", 672 | "Phase: val | Epoch: 10/20 | val_loss:0.00894096 | Time: 226.4677s\n", 673 | "Phase: train | Epoch: 11/20 | train_loss:0.00623539 | Time: 978.5004s\n", 674 | "Phase: val | Epoch: 11/20 | val_loss:0.00897319 | Time: 225.8588s\n", 675 | "Phase: train | Epoch: 12/20 | train_loss:0.00572551 | Time: 966.0341s\n", 676 | "model val_loss Improved from 0.00880080 to 0.00869350\n", 677 | "Phase: val | Epoch: 12/20 | val_loss:0.00869350 | Time: 224.9203s\n", 678 | "Phase: train | Epoch: 13/20 | train_loss:0.00553604 | Time: 960.7631s\n", 679 | "Phase: val | Epoch: 13/20 | val_loss:0.00869544 | Time: 224.7910s\n", 
680 | "Phase: train | Epoch: 14/20 | train_loss:0.00544243 | Time: 966.4869s\n", 681 | "Phase: val | Epoch: 14/20 | val_loss:0.00870903 | Time: 227.0596s\n", 682 | "Phase: train | Epoch: 15/20 | train_loss:0.00536984 | Time: 965.9445s\n", 683 | "Phase: val | Epoch: 15/20 | val_loss:0.00872658 | Time: 228.9651s\n", 684 | "Phase: train | Epoch: 16/20 | train_loss:0.00528060 | Time: 972.3011s\n", 685 | "model val_loss Improved from 0.00869350 to 0.00868408\n", 686 | "Phase: val | Epoch: 16/20 | val_loss:0.00868408 | Time: 226.7299s\n", 687 | "Phase: train | Epoch: 17/20 | train_loss:0.00525970 | Time: 961.6806s\n", 688 | "model val_loss Improved from 0.00868408 to 0.00868126\n", 689 | "Phase: val | Epoch: 17/20 | val_loss:0.00868126 | Time: 223.1173s\n", 690 | "Phase: train | Epoch: 18/20 | train_loss:0.00524516 | Time: 959.7015s\n", 691 | "model val_loss Improved from 0.00868126 to 0.00868042\n", 692 | "Phase: val | Epoch: 18/20 | val_loss:0.00868042 | Time: 224.9494s\n", 693 | "Phase: train | Epoch: 19/20 | train_loss:0.00523335 | Time: 962.3714s\n", 694 | "model val_loss Improved from 0.00868042 to 0.00867484\n", 695 | "Phase: val | Epoch: 19/20 | val_loss:0.00867484 | Time: 227.1843s\n", 696 | "Phase: train | Epoch: 20/20 | train_loss:0.00522770 | Time: 981.1657s\n", 697 | "Phase: val | Epoch: 20/20 | val_loss:0.00868109 | Time: 228.8864s\n" 698 | ] 699 | } 700 | ], 701 | "source": [ 702 | "NeuralNet = NeuralNet.to(device)\n", 703 | "optimizer = optim.Adam(NeuralNet.parameters(),lr = LEARNING_RATE)\n", 704 | "loss_func = torch.nn.BCEWithLogitsLoss()\n", 705 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience = 2)\n", 706 | "best_loss = np.inf\n", 707 | "for epoch in range(NUM_EPOCHS):\n", 708 | " for phase in ['train', 'val']:\n", 709 | " start_time = time.time()\n", 710 | " if phase == 'train':\n", 711 | " NeuralNet.train()\n", 712 | " else:\n", 713 | " NeuralNet.eval()\n", 714 | " \n", 715 | " running_loss = 0.0\n", 716 | " for images_batch, 
labels_batch in tqdm(dataloaders_dict[phase],disable = DISABLE_TQDM):\n", 717 | " images_batch = images_batch.to(device)\n", 718 | " labels_batch = labels_batch.to(device)\n", 719 | " \n", 720 | " optimizer.zero_grad()\n", 721 | " \n", 722 | " with torch.set_grad_enabled(phase == 'train'):\n", 723 | " pred_batch = NeuralNet(images_batch)\n", 724 | " loss = loss_func(pred_batch,labels_batch)\n", 725 | " \n", 726 | " if phase == 'train':\n", 727 | " loss.backward()\n", 728 | " optimizer.step()\n", 729 | " \n", 730 | " running_loss += loss.item() * images_batch.size(0) \n", 731 | " epoch_loss = running_loss / len(dataloaders_dict[phase].dataset) \n", 732 | "\n", 733 | " if phase == 'val' and epoch_loss < best_loss: \n", 734 | " print(\"model val_loss Improved from {:.8f} to {:.8f}\".format(best_loss,epoch_loss))\n", 735 | " best_loss = epoch_loss\n", 736 | " best_model_wts = copy.deepcopy(NeuralNet.state_dict())\n", 737 | " \n", 738 | " if phase == 'val':\n", 739 | " scheduler.step(epoch_loss)\n", 740 | " \n", 741 | " elapsed_time = time.time()-start_time\n", 742 | " print(\"Phase: {} | Epoch: {}/{} | {}_loss:{:.8f} | Time: {:.4f}s\".format(phase,\n", 743 | " epoch+1,\n", 744 | " NUM_EPOCHS,\n", 745 | " phase,\n", 746 | " epoch_loss,\n", 747 | " elapsed_time))\n", 748 | "NeuralNet.load_state_dict(best_model_wts)" 749 | ] 750 | }, 751 | { 752 | "cell_type": "markdown", 753 | "metadata": {}, 754 | "source": [ 755 | "# Predictions from the model" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 19, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [ 764 | "NeuralNet.eval()\n", 765 | "predictions = np.zeros((len(test_dataset), output_dim))\n", 766 | "i = 0\n", 767 | "for test_batch in tqdm(test_loader,disable = DISABLE_TQDM):\n", 768 | " test_batch = test_batch.to(device)\n", 769 | " batch_prediction = NeuralNet(test_batch).detach().cpu().numpy()\n", 770 | " predictions[i * BATCH_SIZE:(i+1) * BATCH_SIZE, :] = batch_prediction\n", 771 | 
" i+=1" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "metadata": {}, 777 | "source": [ 778 | "# Generating submission " 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 20, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "predicted_class_idx = []\n", 788 | "for i in range(len(predictions)): \n", 789 | " idx_list = np.where(predictions[i] > np.percentile(predictions[i],PERCENTILE)) \n", 790 | " predicted_class_idx.append(idx_list[0])" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": 21, 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "data": { 800 | "text/html": [ 801 | "
\n", 802 | "\n", 815 | "\n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | "
idattribute_ids
0b7bb1254bf81c71e79 304 487 1062
1a6689c35fce7a4ea13 79 121 1039
2d0801eed43ebf742121 161 800 1092
38cbcbba676d26bec554 612 671 780
4b899b628ae6db63413 156 813 1092
\n", 851 | "
" 852 | ], 853 | "text/plain": [ 854 | " id attribute_ids\n", 855 | "0 b7bb1254bf81c71e 79 304 487 1062\n", 856 | "1 a6689c35fce7a4ea 13 79 121 1039\n", 857 | "2 d0801eed43ebf742 121 161 800 1092\n", 858 | "3 8cbcbba676d26bec 554 612 671 780\n", 859 | "4 b899b628ae6db634 13 156 813 1092" 860 | ] 861 | }, 862 | "execution_count": 21, 863 | "metadata": {}, 864 | "output_type": "execute_result" 865 | } 866 | ], 867 | "source": [ 868 | "test_df['attribute_ids'] = predicted_class_idx\n", 869 | "test_df['attribute_ids'] = test_df['attribute_ids'].apply(lambda x : ' '.join(map(str,list(x))))\n", 870 | "test_df = test_df.rename(columns={0: 'id'})\n", 871 | "test_df['id'] = test_df['id'].apply(lambda x : x.split('.')[0])\n", 872 | "test_df.head()" 873 | ] 874 | }, 875 | { 876 | "cell_type": "code", 877 | "execution_count": 22, 878 | "metadata": {}, 879 | "outputs": [], 880 | "source": [ 881 | "test_df.to_csv('submission.csv',index = False)" 882 | ] 883 | } 884 | ], 885 | "metadata": { 886 | "kernelspec": { 887 | "display_name": "Python 3", 888 | "language": "python", 889 | "name": "python3" 890 | }, 891 | "language_info": { 892 | "codemirror_mode": { 893 | "name": "ipython", 894 | "version": 3 895 | }, 896 | "file_extension": ".py", 897 | "mimetype": "text/x-python", 898 | "name": "python", 899 | "nbconvert_exporter": "python", 900 | "pygments_lexer": "ipython3", 901 | "version": "3.6.4" 902 | } 903 | }, 904 | "nbformat": 4, 905 | "nbformat_minor": 1 906 | } 907 | --------------------------------------------------------------------------------