├── README.md └── ResNet ├── CIFAR10_ResNet50.ipynb └── ResNet.py /README.md: -------------------------------------------------------------------------------- 1 | # ResNet-PyTorch 2 | Implementation of ResNet 50, 101, 152 in PyTorch based on paper "Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 3 | 4 | Currently working on implementing the ResNet 18 and 34 architectures as well, which do not include the Bottleneck in the residual block. 5 | 6 | A baseline run of ResNet50 on the CIFAR-10 dataset is given as well; with the standard setup proposed by the paper, it already achieves around 85.6% accuracy. However, this can definitely be brought up to at least 92% accuracy with some further slight optimization. 7 | 8 | References: 9 | 1. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 10 | Deep Residual Learning for Image Recognition 11 | https://arxiv.org/pdf/1512.03385.pdf 12 | 13 | 2. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 14 | 15 | 3. 
https://github.com/AladdinPerzon/Machine-Learning-Collection/blob/master/ML/Pytorch/CNN_architectures/pytorch_resnet.py 16 | -------------------------------------------------------------------------------- /ResNet/CIFAR10_ResNet50.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 32, 6 | "metadata": { 7 | "colab": {}, 8 | "colab_type": "code", 9 | "id": "oPSy3VPHL5Uq" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import torch\n", 14 | "import torchvision\n", 15 | "import torchvision.transforms as transforms\n", 16 | "import torch.nn as nn\n", 17 | "import torch.nn.functional as F\n", 18 | "import torch.optim as optim\n", 19 | "\n", 20 | "from ResNet import Bottleneck, ResNet, ResNet50" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 33, 26 | "metadata": { 27 | "colab": {}, 28 | "colab_type": "code", 29 | "id": "yFFakdIfNEEZ" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "transform_train = transforms.Compose([\n", 34 | " transforms.RandomHorizontalFlip(),\n", 35 | " transforms.RandomCrop(32, padding=4),\n", 36 | " transforms.ToTensor(),\n", 37 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", 38 | "])\n", 39 | "\n", 40 | "transform_test = transforms.Compose([\n", 41 | " transforms.ToTensor(),\n", 42 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", 43 | "])" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 34, 49 | "metadata": { 50 | "colab": { 51 | "base_uri": "https://localhost:8080/", 52 | "height": 51 53 | }, 54 | "colab_type": "code", 55 | "id": "V2SCe0hDNeaV", 56 | "outputId": "02ea31cc-a3e8-4b9e-da02-1f2821ca4e52" 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "Files already downloaded and verified\n", 64 | "Files already downloaded and verified\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "train = 
torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n", 70 | "\n", 71 | "trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)\n", 72 | "\n", 73 | "test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n", 74 | "\n", 75 | "testloader = torch.utils.data.DataLoader(test, batch_size=128,shuffle=False, num_workers=2)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 35, 81 | "metadata": { 82 | "colab": {}, 83 | "colab_type": "code", 84 | "id": "NVrCGU9jNkJb" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 37, 94 | "metadata": { 95 | "colab": {}, 96 | "colab_type": "code", 97 | "id": "obf7QfWYOBFT" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "net = ResNet50(10).to('cuda')\n", 102 | "\n", 103 | "criterion = nn.CrossEntropyLoss()\n", 104 | "optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)\n", 105 | "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 38, 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/", 114 | "height": 1000 115 | }, 116 | "colab_type": "code", 117 | "id": "Sp3I3vApPK2w", 118 | "outputId": "60d8be96-76cf-4ca8-8451-9ef862d1ef19" 119 | }, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Loss [1, 100](epoch, minibatch): 8.839240016937255\n", 126 | "Loss [1, 200](epoch, minibatch): 2.8543614387512206\n", 127 | "Loss [1, 300](epoch, minibatch): 2.5090506088733675\n", 128 | "Loss [2, 100](epoch, minibatch): 2.1302514731884004\n", 129 | "Loss [2, 200](epoch, minibatch): 1.9953344476222992\n", 130 | "Loss 
[2, 300](epoch, minibatch): 1.9380652523040771\n", 131 | "Loss [3, 100](epoch, minibatch): 1.855226961374283\n", 132 | "Loss [3, 200](epoch, minibatch): 1.7927296078205108\n", 133 | "Loss [3, 300](epoch, minibatch): 1.743425933122635\n", 134 | "Loss [4, 100](epoch, minibatch): 1.7187504696846008\n", 135 | "Loss [4, 200](epoch, minibatch): 1.687288955450058\n", 136 | "Loss [4, 300](epoch, minibatch): 1.665401849746704\n", 137 | "Loss [5, 100](epoch, minibatch): 1.6413379180431367\n", 138 | "Loss [5, 200](epoch, minibatch): 1.6188793969154358\n", 139 | "Loss [5, 300](epoch, minibatch): 1.591441843509674\n", 140 | "Loss [6, 100](epoch, minibatch): 1.582551372051239\n", 141 | "Loss [6, 200](epoch, minibatch): 1.5571859669685364\n", 142 | "Loss [6, 300](epoch, minibatch): 1.5410805189609527\n", 143 | "Loss [7, 100](epoch, minibatch): 1.5283722984790802\n", 144 | "Loss [7, 200](epoch, minibatch): 1.4911496341228485\n", 145 | "Loss [7, 300](epoch, minibatch): 1.4876199412345885\n", 146 | "Loss [8, 100](epoch, minibatch): 1.4896726572513581\n", 147 | "Loss [8, 200](epoch, minibatch): 1.444840190410614\n", 148 | "Loss [8, 300](epoch, minibatch): 1.4478385579586028\n", 149 | "Loss [9, 100](epoch, minibatch): 1.4383381915092468\n", 150 | "Loss [9, 200](epoch, minibatch): 1.3851233530044555\n", 151 | "Loss [9, 300](epoch, minibatch): 1.3953999078273773\n", 152 | "Loss [10, 100](epoch, minibatch): 1.3797350168228149\n", 153 | "Loss [10, 200](epoch, minibatch): 1.3668984162807465\n", 154 | "Loss [10, 300](epoch, minibatch): 1.3529738533496856\n", 155 | "Loss [11, 100](epoch, minibatch): 1.3527108430862427\n", 156 | "Loss [11, 200](epoch, minibatch): 1.3190882837772369\n", 157 | "Loss [11, 300](epoch, minibatch): 1.289210125207901\n", 158 | "Loss [12, 100](epoch, minibatch): 1.3086059510707855\n", 159 | "Loss [12, 200](epoch, minibatch): 1.274652873277664\n", 160 | "Loss [12, 300](epoch, minibatch): 1.2602075088024138\n", 161 | "Loss [13, 100](epoch, minibatch): 
1.2712362825870513\n", 162 | "Loss [13, 200](epoch, minibatch): 1.2355617022514342\n", 163 | "Loss [13, 300](epoch, minibatch): 1.2457362633943558\n", 164 | "Loss [14, 100](epoch, minibatch): 1.2231375360488892\n", 165 | "Loss [14, 200](epoch, minibatch): 1.2079989618062974\n", 166 | "Loss [14, 300](epoch, minibatch): 1.2034609174728395\n", 167 | "Loss [15, 100](epoch, minibatch): 1.1854187709093094\n", 168 | "Loss [15, 200](epoch, minibatch): 1.1755084204673767\n", 169 | "Loss [15, 300](epoch, minibatch): 1.157737416625023\n", 170 | "Loss [16, 100](epoch, minibatch): 1.192123634815216\n", 171 | "Loss [16, 200](epoch, minibatch): 1.1370069837570191\n", 172 | "Loss [16, 300](epoch, minibatch): 1.1442907297611236\n", 173 | "Loss [17, 100](epoch, minibatch): 1.107102950811386\n", 174 | "Loss [17, 200](epoch, minibatch): 1.1141330301761627\n", 175 | "Loss [17, 300](epoch, minibatch): 1.0901018267869949\n", 176 | "Loss [18, 100](epoch, minibatch): 1.0941310381889344\n", 177 | "Loss [18, 200](epoch, minibatch): 1.0507519060373307\n", 178 | "Loss [18, 300](epoch, minibatch): 1.071660562157631\n", 179 | "Loss [19, 100](epoch, minibatch): 1.0489300626516342\n", 180 | "Loss [19, 200](epoch, minibatch): 1.0265987980365754\n", 181 | "Loss [19, 300](epoch, minibatch): 1.0444538742303848\n", 182 | "Loss [20, 100](epoch, minibatch): 1.0053205507993699\n", 183 | "Loss [20, 200](epoch, minibatch): 1.0068520921468735\n", 184 | "Loss [20, 300](epoch, minibatch): 0.9833028954267502\n", 185 | "Loss [21, 100](epoch, minibatch): 0.9830259472131729\n", 186 | "Loss [21, 200](epoch, minibatch): 0.9682340639829635\n", 187 | "Loss [21, 300](epoch, minibatch): 0.9769683623313904\n", 188 | "Loss [22, 100](epoch, minibatch): 0.9624786764383316\n", 189 | "Loss [22, 200](epoch, minibatch): 0.9280662304162979\n", 190 | "Loss [22, 300](epoch, minibatch): 0.9312695723772049\n", 191 | "Loss [23, 100](epoch, minibatch): 0.9080375480651856\n", 192 | "Loss [23, 200](epoch, minibatch): 
0.9153354549407959\n", 193 | "Loss [23, 300](epoch, minibatch): 0.9088288271427154\n", 194 | "Loss [24, 100](epoch, minibatch): 0.8874812722206116\n", 195 | "Loss [24, 200](epoch, minibatch): 0.8856264269351959\n", 196 | "Loss [24, 300](epoch, minibatch): 0.8913706356287002\n", 197 | "Loss [25, 100](epoch, minibatch): 0.8780943483114243\n", 198 | "Loss [25, 200](epoch, minibatch): 0.8746073889732361\n", 199 | "Loss [25, 300](epoch, minibatch): 0.8545712912082672\n", 200 | "Loss [26, 100](epoch, minibatch): 0.833527010679245\n", 201 | "Loss [26, 200](epoch, minibatch): 0.862846450805664\n", 202 | "Loss [26, 300](epoch, minibatch): 0.8561604201793671\n", 203 | "Loss [27, 100](epoch, minibatch): 0.8465841692686081\n", 204 | "Loss [27, 200](epoch, minibatch): 0.8506513500213623\n", 205 | "Loss [27, 300](epoch, minibatch): 0.8320120567083359\n", 206 | "Loss [28, 100](epoch, minibatch): 0.8219020199775696\n", 207 | "Loss [28, 200](epoch, minibatch): 0.8109362137317657\n", 208 | "Loss [28, 300](epoch, minibatch): 0.7991736209392548\n", 209 | "Loss [29, 100](epoch, minibatch): 0.7946899974346161\n", 210 | "Loss [29, 200](epoch, minibatch): 0.8021908444166184\n", 211 | "Loss [29, 300](epoch, minibatch): 0.7956095051765442\n", 212 | "Loss [30, 100](epoch, minibatch): 0.7780683839321136\n", 213 | "Loss [30, 200](epoch, minibatch): 0.7675377279520035\n", 214 | "Loss [30, 300](epoch, minibatch): 0.7864366567134857\n", 215 | "Loss [31, 100](epoch, minibatch): 0.7492583286762238\n", 216 | "Loss [31, 200](epoch, minibatch): 0.7654818654060364\n", 217 | "Loss [31, 300](epoch, minibatch): 0.7475564414262772\n", 218 | "Loss [32, 100](epoch, minibatch): 0.7464428848028183\n", 219 | "Loss [32, 200](epoch, minibatch): 0.74605344414711\n", 220 | "Loss [32, 300](epoch, minibatch): 0.7464539271593094\n", 221 | "Loss [33, 100](epoch, minibatch): 0.7113471245765686\n", 222 | "Loss [33, 200](epoch, minibatch): 0.7323727625608444\n", 223 | "Loss [33, 300](epoch, minibatch): 
0.7354452228546142\n", 224 | "Loss [34, 100](epoch, minibatch): 0.7045288586616516\n", 225 | "Loss [34, 200](epoch, minibatch): 0.7239412397146225\n", 226 | "Loss [34, 300](epoch, minibatch): 0.702357049882412\n", 227 | "Loss [35, 100](epoch, minibatch): 0.6966931942105293\n", 228 | "Loss [35, 200](epoch, minibatch): 0.6861495634913445\n", 229 | "Loss [35, 300](epoch, minibatch): 0.6926117312908172\n", 230 | "Loss [36, 100](epoch, minibatch): 0.6820511701703071\n", 231 | "Loss [36, 200](epoch, minibatch): 0.6807230353355408\n", 232 | "Loss [36, 300](epoch, minibatch): 0.6995807924866676\n", 233 | "Loss [37, 100](epoch, minibatch): 0.6681192058324814\n", 234 | "Loss [37, 200](epoch, minibatch): 0.6784629571437836\n", 235 | "Loss [37, 300](epoch, minibatch): 0.6783250313997269\n", 236 | "Loss [38, 100](epoch, minibatch): 0.6557750165462494\n", 237 | "Loss [38, 200](epoch, minibatch): 0.6511869436502457\n", 238 | "Loss [38, 300](epoch, minibatch): 0.6703853836655617\n", 239 | "Loss [39, 100](epoch, minibatch): 0.6443643879890442\n", 240 | "Loss [39, 200](epoch, minibatch): 0.6630928012728691\n", 241 | "Loss [39, 300](epoch, minibatch): 0.646601778268814\n", 242 | "Loss [40, 100](epoch, minibatch): 0.6537199735641479\n", 243 | "Loss [40, 200](epoch, minibatch): 0.6422764873504638\n", 244 | "Loss [40, 300](epoch, minibatch): 0.6430440771579743\n", 245 | "Loss [41, 100](epoch, minibatch): 0.6208404710888863\n", 246 | "Loss [41, 200](epoch, minibatch): 0.6468847548961639\n", 247 | "Loss [41, 300](epoch, minibatch): 0.6267846083641052\n", 248 | "Loss [42, 100](epoch, minibatch): 0.6086450031399727\n", 249 | "Loss [42, 200](epoch, minibatch): 0.6202251797914505\n", 250 | "Loss [42, 300](epoch, minibatch): 0.6238853070139885\n", 251 | "Loss [43, 100](epoch, minibatch): 0.6185173097252846\n", 252 | "Loss [43, 200](epoch, minibatch): 0.6032044097781182\n", 253 | "Loss [43, 300](epoch, minibatch): 0.6149302759766578\n", 254 | "Loss [44, 100](epoch, minibatch): 
0.6147406467795372\n", 255 | "Loss [44, 200](epoch, minibatch): 0.6051844453811646\n", 256 | "Loss [44, 300](epoch, minibatch): 0.6186407184600831\n", 257 | "Loss [45, 100](epoch, minibatch): 0.5977083477377891\n", 258 | "Loss [45, 200](epoch, minibatch): 0.5842293477058411\n", 259 | "Loss [45, 300](epoch, minibatch): 0.5979399380087852\n", 260 | "Loss [46, 100](epoch, minibatch): 0.5913382676243782\n", 261 | "Loss [46, 200](epoch, minibatch): 0.5902311527729034\n", 262 | "Loss [46, 300](epoch, minibatch): 0.5711671343445778\n", 263 | "Loss [47, 100](epoch, minibatch): 0.5773072516918183\n", 264 | "Loss [47, 200](epoch, minibatch): 0.5879658138751984\n", 265 | "Loss [47, 300](epoch, minibatch): 0.5736844995617867\n", 266 | "Loss [48, 100](epoch, minibatch): 0.5691047406196594\n", 267 | "Loss [48, 200](epoch, minibatch): 0.5797811582684517\n", 268 | "Loss [48, 300](epoch, minibatch): 0.5650003811717034\n", 269 | "Loss [49, 100](epoch, minibatch): 0.5642372065782547\n", 270 | "Loss [49, 200](epoch, minibatch): 0.5789714482426643\n", 271 | "Loss [49, 300](epoch, minibatch): 0.5606344413757324\n", 272 | "Loss [50, 100](epoch, minibatch): 0.5656562167406082\n", 273 | "Loss [50, 200](epoch, minibatch): 0.5659125879406929\n", 274 | "Loss [50, 300](epoch, minibatch): 0.5555062511563301\n", 275 | "Loss [51, 100](epoch, minibatch): 0.5589963096380234\n", 276 | "Loss [51, 200](epoch, minibatch): 0.5562647691369057\n", 277 | "Loss [51, 300](epoch, minibatch): 0.5620793667435646\n", 278 | "Loss [52, 100](epoch, minibatch): 0.5478817278146744\n", 279 | "Loss [52, 200](epoch, minibatch): 0.5548448315262795\n", 280 | "Loss [52, 300](epoch, minibatch): 0.551018553674221\n", 281 | "Loss [53, 100](epoch, minibatch): 0.5372573563456535\n", 282 | "Loss [53, 200](epoch, minibatch): 0.532047936618328\n", 283 | "Loss [53, 300](epoch, minibatch): 0.5470021218061447\n", 284 | "Loss [54, 100](epoch, minibatch): 0.5318462827801704\n", 285 | "Loss [54, 200](epoch, minibatch): 
0.531562380194664\n", 286 | "Loss [54, 300](epoch, minibatch): 0.5409415599703788\n", 287 | "Loss [55, 100](epoch, minibatch): 0.5360717245936394\n", 288 | "Loss [55, 200](epoch, minibatch): 0.5359487247467041\n", 289 | "Loss [55, 300](epoch, minibatch): 0.5337685203552246\n", 290 | "Loss [56, 100](epoch, minibatch): 0.5278202265501022\n", 291 | "Loss [56, 200](epoch, minibatch): 0.5313336753845215\n", 292 | "Loss [56, 300](epoch, minibatch): 0.5186288148164749\n", 293 | "Loss [57, 100](epoch, minibatch): 0.5166278117895127\n", 294 | "Loss [57, 200](epoch, minibatch): 0.5004483360052109\n", 295 | "Loss [57, 300](epoch, minibatch): 0.5322226944565773\n", 296 | "Loss [58, 100](epoch, minibatch): 0.5155257326364517\n", 297 | "Loss [58, 200](epoch, minibatch): 0.5150638368725776\n", 298 | "Loss [58, 300](epoch, minibatch): 0.527855603992939\n", 299 | "Loss [59, 100](epoch, minibatch): 0.5135599303245545\n", 300 | "Loss [59, 200](epoch, minibatch): 0.5028944450616837\n", 301 | "Loss [59, 300](epoch, minibatch): 0.5169031661748886\n", 302 | "Loss [60, 100](epoch, minibatch): 0.49445574909448625\n", 303 | "Loss [60, 200](epoch, minibatch): 0.5134569630026817\n", 304 | "Loss [60, 300](epoch, minibatch): 0.5036155226826667\n", 305 | "Loss [61, 100](epoch, minibatch): 0.49519878506660464\n", 306 | "Loss [61, 200](epoch, minibatch): 0.5081535249948501\n", 307 | "Loss [61, 300](epoch, minibatch): 0.5032998052239418\n", 308 | "Loss [62, 100](epoch, minibatch): 0.497781480550766\n", 309 | "Loss [62, 200](epoch, minibatch): 0.4825281220674515\n", 310 | "Loss [62, 300](epoch, minibatch): 0.5051576554775238\n", 311 | "Loss [63, 100](epoch, minibatch): 0.4705116418004036\n", 312 | "Loss [63, 200](epoch, minibatch): 0.4794814148545265\n", 313 | "Loss [63, 300](epoch, minibatch): 0.5205192571878433\n", 314 | "Loss [64, 100](epoch, minibatch): 0.4849816679954529\n", 315 | "Loss [64, 200](epoch, minibatch): 0.4939527302980423\n", 316 | "Loss [64, 300](epoch, minibatch): 
0.48096194803714754\n", 317 | "Loss [65, 100](epoch, minibatch): 0.47678621411323546\n", 318 | "Loss [65, 200](epoch, minibatch): 0.49045946389436723\n", 319 | "Loss [65, 300](epoch, minibatch): 0.47963920176029207\n", 320 | "Loss [66, 100](epoch, minibatch): 0.48344717651605607\n", 321 | "Loss [66, 200](epoch, minibatch): 0.4736376956105232\n", 322 | "Loss [66, 300](epoch, minibatch): 0.48270602971315385\n", 323 | "Loss [67, 100](epoch, minibatch): 0.46687799513339995\n", 324 | "Loss [67, 200](epoch, minibatch): 0.4763159981369972\n", 325 | "Loss [67, 300](epoch, minibatch): 0.4797471961379051\n", 326 | "Loss [68, 100](epoch, minibatch): 0.4562823221087456\n", 327 | "Loss [68, 200](epoch, minibatch): 0.4788463106751442\n", 328 | "Loss [68, 300](epoch, minibatch): 0.4773494863510132\n", 329 | "Loss [69, 100](epoch, minibatch): 0.47677563488483427\n", 330 | "Loss [69, 200](epoch, minibatch): 0.4595793172717094\n", 331 | "Loss [69, 300](epoch, minibatch): 0.4738587388396263\n", 332 | "Loss [70, 100](epoch, minibatch): 0.459033143222332\n", 333 | "Loss [70, 200](epoch, minibatch): 0.4538933762907982\n", 334 | "Loss [70, 300](epoch, minibatch): 0.454161247164011\n", 335 | "Loss [71, 100](epoch, minibatch): 0.45955162674188615\n", 336 | "Loss [71, 200](epoch, minibatch): 0.4669695857167244\n", 337 | "Loss [71, 300](epoch, minibatch): 0.46416772991418837\n", 338 | "Loss [72, 100](epoch, minibatch): 0.4398690718412399\n", 339 | "Loss [72, 200](epoch, minibatch): 0.4655371907353401\n", 340 | "Loss [72, 300](epoch, minibatch): 0.44217260509729384\n", 341 | "Loss [73, 100](epoch, minibatch): 0.4422738966345787\n", 342 | "Loss [73, 200](epoch, minibatch): 0.4478651860356331\n", 343 | "Loss [73, 300](epoch, minibatch): 0.4469786891341209\n", 344 | "Loss [74, 100](epoch, minibatch): 0.452652185857296\n", 345 | "Loss [74, 200](epoch, minibatch): 0.44992151618003845\n", 346 | "Loss [74, 300](epoch, minibatch): 0.4409087851643562\n", 347 | "Loss [75, 100](epoch, minibatch): 
0.45018809378147123\n", 348 | "Loss [75, 200](epoch, minibatch): 0.4415216711163521\n", 349 | "Loss [75, 300](epoch, minibatch): 0.45369074791669844\n", 350 | "Loss [76, 100](epoch, minibatch): 0.4358271759748459\n", 351 | "Loss [76, 200](epoch, minibatch): 0.44987971723079684\n", 352 | "Loss [76, 300](epoch, minibatch): 0.46197822034358976\n", 353 | "Loss [77, 100](epoch, minibatch): 0.4495241305232048\n", 354 | "Loss [77, 200](epoch, minibatch): 0.4328723394870758\n", 355 | "Loss [77, 300](epoch, minibatch): 0.44710842669010165\n", 356 | "Loss [78, 100](epoch, minibatch): 0.430138566493988\n", 357 | "Loss [78, 200](epoch, minibatch): 0.45061751931905747\n", 358 | "Loss [78, 300](epoch, minibatch): 0.44067456185817716\n", 359 | "Loss [79, 100](epoch, minibatch): 0.4371441785991192\n", 360 | "Loss [79, 200](epoch, minibatch): 0.42857502430677413\n", 361 | "Loss [79, 300](epoch, minibatch): 0.4354091404378414\n", 362 | "Loss [80, 100](epoch, minibatch): 0.4266621361672878\n", 363 | "Loss [80, 200](epoch, minibatch): 0.44032411873340604\n", 364 | "Loss [80, 300](epoch, minibatch): 0.43038433641195295\n", 365 | "Loss [81, 100](epoch, minibatch): 0.4152362994849682\n", 366 | "Loss [81, 200](epoch, minibatch): 0.4299368330836296\n", 367 | "Loss [81, 300](epoch, minibatch): 0.43515898123383523\n", 368 | "Loss [82, 100](epoch, minibatch): 0.4057521885633469\n", 369 | "Loss [82, 200](epoch, minibatch): 0.4278752975165844\n", 370 | "Loss [82, 300](epoch, minibatch): 0.4446417504549027\n", 371 | "Loss [83, 100](epoch, minibatch): 0.4321592256426811\n", 372 | "Loss [83, 200](epoch, minibatch): 0.4217809358239174\n", 373 | "Loss [83, 300](epoch, minibatch): 0.41628495454788206\n", 374 | "Loss [84, 100](epoch, minibatch): 0.418860921561718\n", 375 | "Loss [84, 200](epoch, minibatch): 0.4078247061371803\n", 376 | "Loss [84, 300](epoch, minibatch): 0.4230827637016773\n", 377 | "Loss [85, 100](epoch, minibatch): 0.4184363022446632\n", 378 | "Loss [85, 200](epoch, minibatch): 
0.42853747457265856\n", 379 | "Loss [85, 300](epoch, minibatch): 0.421194339543581\n", 380 | "Loss [86, 100](epoch, minibatch): 0.3945658066868782\n", 381 | "Loss [86, 200](epoch, minibatch): 0.4248090210556984\n", 382 | "Loss [86, 300](epoch, minibatch): 0.4014836198091507\n", 383 | "Loss [87, 100](epoch, minibatch): 0.3953758604824543\n", 384 | "Loss [87, 200](epoch, minibatch): 0.40270566433668137\n", 385 | "Loss [87, 300](epoch, minibatch): 0.4346407979726791\n", 386 | "Loss [88, 100](epoch, minibatch): 0.39131177216768265\n", 387 | "Loss [88, 200](epoch, minibatch): 0.41342190772294996\n", 388 | "Loss [88, 300](epoch, minibatch): 0.42055966943502426\n", 389 | "Loss [89, 100](epoch, minibatch): 0.3993154127895832\n", 390 | "Loss [89, 200](epoch, minibatch): 0.40027943134307864\n", 391 | "Loss [89, 300](epoch, minibatch): 0.4079010629653931\n", 392 | "Loss [90, 100](epoch, minibatch): 0.41024216786026957\n", 393 | "Loss [90, 200](epoch, minibatch): 0.39797315165400504\n", 394 | "Loss [90, 300](epoch, minibatch): 0.41257811307907105\n", 395 | "Loss [91, 100](epoch, minibatch): 0.39576814725995063\n", 396 | "Loss [91, 200](epoch, minibatch): 0.4014496797323227\n", 397 | "Loss [91, 300](epoch, minibatch): 0.42784968852996824\n", 398 | "Loss [92, 100](epoch, minibatch): 0.39781673535704615\n", 399 | "Loss [92, 200](epoch, minibatch): 0.3933528487384319\n", 400 | "Loss [92, 300](epoch, minibatch): 0.417411085665226\n", 401 | "Loss [93, 100](epoch, minibatch): 0.39207191064953806\n", 402 | "Loss [93, 200](epoch, minibatch): 0.3904254674911499\n", 403 | "Loss [93, 300](epoch, minibatch): 0.4065232607722282\n", 404 | "Loss [94, 100](epoch, minibatch): 0.38698487982153895\n", 405 | "Loss [94, 200](epoch, minibatch): 0.40696744948625563\n", 406 | "Loss [94, 300](epoch, minibatch): 0.4168149581551552\n", 407 | "Loss [95, 100](epoch, minibatch): 0.3864050799608231\n", 408 | "Loss [95, 200](epoch, minibatch): 0.38968835815787317\n", 409 | "Loss [95, 300](epoch, minibatch): 
0.40488335996866226\n", 410 | "Loss [96, 100](epoch, minibatch): 0.3776258008182049\n", 411 | "Loss [96, 200](epoch, minibatch): 0.399998320043087\n", 412 | "Loss [96, 300](epoch, minibatch): 0.41234890550374986\n", 413 | "Loss [97, 100](epoch, minibatch): 0.3748104391992092\n", 414 | "Loss [97, 200](epoch, minibatch): 0.3953130042552948\n", 415 | "Loss [97, 300](epoch, minibatch): 0.38361811742186547\n", 416 | "Loss [98, 100](epoch, minibatch): 0.3799448338150978\n", 417 | "Loss [98, 200](epoch, minibatch): 0.40081761837005614\n", 418 | "Loss [98, 300](epoch, minibatch): 0.39133093401789665\n", 419 | "Loss [99, 100](epoch, minibatch): 0.38404966324567796\n", 420 | "Loss [99, 200](epoch, minibatch): 0.387141085267067\n", 421 | "Loss [99, 300](epoch, minibatch): 0.3873719447851181\n", 422 | "Loss [100, 100](epoch, minibatch): 0.3959707449376583\n", 423 | "Loss [100, 200](epoch, minibatch): 0.39313841179013254\n", 424 | "Loss [100, 300](epoch, minibatch): 0.39064713671803475\n", 425 | "Loss [101, 100](epoch, minibatch): 0.38259925991296767\n", 426 | "Loss [101, 200](epoch, minibatch): 0.37727161556482314\n", 427 | "Loss [101, 300](epoch, minibatch): 0.39863600984215736\n", 428 | "Loss [102, 100](epoch, minibatch): 0.36940001249313353\n", 429 | "Loss [102, 200](epoch, minibatch): 0.3683442968130112\n", 430 | "Loss [102, 300](epoch, minibatch): 0.3892579409480095\n", 431 | "Loss [103, 100](epoch, minibatch): 0.379208252876997\n", 432 | "Loss [103, 200](epoch, minibatch): 0.3744164763391018\n", 433 | "Loss [103, 300](epoch, minibatch): 0.39919144541025164\n", 434 | "Loss [104, 100](epoch, minibatch): 0.3775774529576302\n", 435 | "Loss [104, 200](epoch, minibatch): 0.3805718170106411\n", 436 | "Loss [104, 300](epoch, minibatch): 0.3764234222471714\n", 437 | "Loss [105, 100](epoch, minibatch): 0.36263732209801675\n", 438 | "Loss [105, 200](epoch, minibatch): 0.3814633898437023\n", 439 | "Loss [105, 300](epoch, minibatch): 0.38640366077423094\n", 440 | "Loss [106, 
100](epoch, minibatch): 0.361251565515995\n", 441 | "Loss [106, 200](epoch, minibatch): 0.37794685557484625\n", 442 | "Loss [106, 300](epoch, minibatch): 0.38354177072644235\n", 443 | "Loss [107, 100](epoch, minibatch): 0.35719128370285036\n", 444 | "Loss [107, 200](epoch, minibatch): 0.3681615675985813\n", 445 | "Loss [107, 300](epoch, minibatch): 0.3835552254319191\n", 446 | "Loss [108, 100](epoch, minibatch): 0.3459990732371807\n", 447 | "Loss [108, 200](epoch, minibatch): 0.37363224148750307\n", 448 | "Loss [108, 300](epoch, minibatch): 0.3908988712728024\n", 449 | "Loss [109, 100](epoch, minibatch): 0.3808039338886738\n", 450 | "Loss [109, 200](epoch, minibatch): 0.37834601536393164\n", 451 | "Loss [109, 300](epoch, minibatch): 0.3620477768778801\n", 452 | "Loss [110, 100](epoch, minibatch): 0.35704113021492956\n", 453 | "Loss [110, 200](epoch, minibatch): 0.37036731615662577\n", 454 | "Loss [110, 300](epoch, minibatch): 0.37806847020983697\n", 455 | "Loss [111, 100](epoch, minibatch): 0.3648793002963066\n", 456 | "Loss [111, 200](epoch, minibatch): 0.3695460321009159\n", 457 | "Loss [111, 300](epoch, minibatch): 0.36696819990873336\n", 458 | "Loss [112, 100](epoch, minibatch): 0.3497805692255497\n", 459 | "Loss [112, 200](epoch, minibatch): 0.3715791854262352\n", 460 | "Loss [112, 300](epoch, minibatch): 0.3761456596851349\n", 461 | "Loss [113, 100](epoch, minibatch): 0.36030389592051504\n", 462 | "Loss [113, 200](epoch, minibatch): 0.3580723369121552\n", 463 | "Loss [113, 300](epoch, minibatch): 0.37366754561662674\n", 464 | "Loss [114, 100](epoch, minibatch): 0.355857448130846\n", 465 | "Loss [114, 200](epoch, minibatch): 0.3717909619212151\n", 466 | "Loss [114, 300](epoch, minibatch): 0.3745942084491253\n", 467 | "Loss [115, 100](epoch, minibatch): 0.35502728521823884\n", 468 | "Loss [115, 200](epoch, minibatch): 0.3801086536049843\n", 469 | "Loss [115, 300](epoch, minibatch): 0.3543571825325489\n", 470 | "Loss [116, 100](epoch, minibatch): 
0.3592769005894661\n", 471 | "Loss [116, 200](epoch, minibatch): 0.36883870139718056\n", 472 | "Loss [116, 300](epoch, minibatch): 0.3670287993550301\n", 473 | "Loss [117, 100](epoch, minibatch): 0.35530189946293833\n", 474 | "Loss [117, 200](epoch, minibatch): 0.3661247684061527\n", 475 | "Loss [117, 300](epoch, minibatch): 0.35167143180966376\n", 476 | "Loss [118, 100](epoch, minibatch): 0.3455443613231182\n", 477 | "Loss [118, 200](epoch, minibatch): 0.3425072169303894\n", 478 | "Loss [118, 300](epoch, minibatch): 0.3530725271999836\n", 479 | "Loss [119, 100](epoch, minibatch): 0.35234869197010993\n", 480 | "Loss [119, 200](epoch, minibatch): 0.3528152620792389\n", 481 | "Loss [119, 300](epoch, minibatch): 0.3658104398846626\n", 482 | "Loss [120, 100](epoch, minibatch): 0.35267854034900664\n", 483 | "Loss [120, 200](epoch, minibatch): 0.36587018370628355\n", 484 | "Loss [120, 300](epoch, minibatch): 0.34983618319034576\n", 485 | "Loss [121, 100](epoch, minibatch): 0.3462674055993557\n", 486 | "Loss [121, 200](epoch, minibatch): 0.3647799214720726\n", 487 | "Loss [121, 300](epoch, minibatch): 0.362434226423502\n", 488 | "Loss [122, 100](epoch, minibatch): 0.3419904267787933\n", 489 | "Loss [122, 200](epoch, minibatch): 0.34956473514437675\n", 490 | "Loss [122, 300](epoch, minibatch): 0.36206165090203285\n", 491 | "Loss [123, 100](epoch, minibatch): 0.34519577994942663\n", 492 | "Loss [123, 200](epoch, minibatch): 0.34915505900979044\n", 493 | "Loss [123, 300](epoch, minibatch): 0.3660018076002598\n", 494 | "Loss [124, 100](epoch, minibatch): 0.3389536565542221\n", 495 | "Loss [124, 200](epoch, minibatch): 0.3529126699268818\n", 496 | "Loss [124, 300](epoch, minibatch): 0.3584320928156376\n", 497 | "Loss [125, 100](epoch, minibatch): 0.2745757547020912\n", 498 | "Loss [125, 200](epoch, minibatch): 0.22790059849619865\n", 499 | "Loss [125, 300](epoch, minibatch): 0.20956130720674992\n", 500 | "Loss [126, 100](epoch, minibatch): 0.18237497799098493\n", 501 | "Loss 
[126, 200](epoch, minibatch): 0.1933388452231884\n", 502 | "Loss [126, 300](epoch, minibatch): 0.18526836052536966\n", 503 | "Loss [127, 100](epoch, minibatch): 0.17400821894407273\n", 504 | "Loss [127, 200](epoch, minibatch): 0.16326623141765595\n", 505 | "Loss [127, 300](epoch, minibatch): 0.16480547986924649\n", 506 | "Loss [128, 100](epoch, minibatch): 0.1624736401066184\n", 507 | "Loss [128, 200](epoch, minibatch): 0.1507112606242299\n", 508 | "Loss [128, 300](epoch, minibatch): 0.1602012763172388\n", 509 | "Loss [129, 100](epoch, minibatch): 0.14371952280402184\n", 510 | "Loss [129, 200](epoch, minibatch): 0.14584138613194228\n", 511 | "Loss [129, 300](epoch, minibatch): 0.15714668460190295\n", 512 | "Loss [130, 100](epoch, minibatch): 0.14310971003025771\n", 513 | "Loss [130, 200](epoch, minibatch): 0.15047285255044698\n", 514 | "Loss [130, 300](epoch, minibatch): 0.14869120560586452\n", 515 | "Loss [131, 100](epoch, minibatch): 0.13591380145400764\n", 516 | "Loss [131, 200](epoch, minibatch): 0.12753482565283775\n", 517 | "Loss [131, 300](epoch, minibatch): 0.13426522023975848\n", 518 | "Loss [132, 100](epoch, minibatch): 0.1255357664451003\n", 519 | "Loss [132, 200](epoch, minibatch): 0.12333543673157692\n", 520 | "Loss [132, 300](epoch, minibatch): 0.12769552011042834\n", 521 | "Loss [133, 100](epoch, minibatch): 0.1199782094731927\n", 522 | "Loss [133, 200](epoch, minibatch): 0.11861515365540981\n", 523 | "Loss [133, 300](epoch, minibatch): 0.12944098114967345\n", 524 | "Loss [134, 100](epoch, minibatch): 0.11577606782317161\n", 525 | "Loss [134, 200](epoch, minibatch): 0.1202187130972743\n", 526 | "Loss [134, 300](epoch, minibatch): 0.1212599566206336\n", 527 | "Loss [135, 100](epoch, minibatch): 0.1121910022571683\n", 528 | "Loss [135, 200](epoch, minibatch): 0.10954646151512862\n", 529 | "Loss [135, 300](epoch, minibatch): 0.11485624104738236\n", 530 | "Loss [136, 100](epoch, minibatch): 0.1138050353527069\n", 531 | "Loss [136, 200](epoch, minibatch): 
0.11469013143330813\n", 532 | "Loss [136, 300](epoch, minibatch): 0.1149770862981677\n", 533 | "Loss [137, 100](epoch, minibatch): 0.10496257197111845\n", 534 | "Loss [137, 200](epoch, minibatch): 0.11100531533360482\n", 535 | "Loss [137, 300](epoch, minibatch): 0.10968413643538952\n", 536 | "Loss [138, 100](epoch, minibatch): 0.10927507575601339\n", 537 | "Loss [138, 200](epoch, minibatch): 0.10147064968943596\n", 538 | "Loss [138, 300](epoch, minibatch): 0.09852944605052472\n", 539 | "Loss [139, 100](epoch, minibatch): 0.10008480526506901\n", 540 | "Loss [139, 200](epoch, minibatch): 0.09973813395947217\n", 541 | "Loss [139, 300](epoch, minibatch): 0.10994329303503036\n", 542 | "Loss [140, 100](epoch, minibatch): 0.09728149063885212\n", 543 | "Loss [140, 200](epoch, minibatch): 0.10047981210052967\n", 544 | "Loss [140, 300](epoch, minibatch): 0.09900152161717415\n", 545 | "Loss [141, 100](epoch, minibatch): 0.09882153782993555\n", 546 | "Loss [141, 200](epoch, minibatch): 0.09600843213498593\n", 547 | "Loss [141, 300](epoch, minibatch): 0.08927489314228296\n", 548 | "Loss [142, 100](epoch, minibatch): 0.08919163968414068\n", 549 | "Loss [142, 200](epoch, minibatch): 0.0945032512396574\n", 550 | "Loss [142, 300](epoch, minibatch): 0.0904152081720531\n", 551 | "Loss [143, 100](epoch, minibatch): 0.08809526942670345\n", 552 | "Loss [143, 200](epoch, minibatch): 0.09453827252611519\n", 553 | "Loss [143, 300](epoch, minibatch): 0.08901693891733885\n", 554 | "Loss [144, 100](epoch, minibatch): 0.08728643577545882\n", 555 | "Loss [144, 200](epoch, minibatch): 0.07892006158828735\n", 556 | "Loss [144, 300](epoch, minibatch): 0.08468541529029608\n", 557 | "Loss [145, 100](epoch, minibatch): 0.08363243386149406\n", 558 | "Loss [145, 200](epoch, minibatch): 0.08623464576900006\n", 559 | "Loss [145, 300](epoch, minibatch): 0.0881262375600636\n", 560 | "Loss [146, 100](epoch, minibatch): 0.08025822123512626\n", 561 | "Loss [146, 200](epoch, minibatch): 0.08137236028909683\n", 
562 | "Loss [146, 300](epoch, minibatch): 0.08294845636934042\n", 563 | "Loss [147, 100](epoch, minibatch): 0.07447089921683073\n", 564 | "Loss [147, 200](epoch, minibatch): 0.08077028729021549\n", 565 | "Loss [147, 300](epoch, minibatch): 0.08693202156573535\n", 566 | "Loss [148, 100](epoch, minibatch): 0.07026688635349274\n", 567 | "Loss [148, 200](epoch, minibatch): 0.08575814723968506\n", 568 | "Loss [148, 300](epoch, minibatch): 0.07528316034004093\n", 569 | "Loss [149, 100](epoch, minibatch): 0.07796191483736038\n", 570 | "Loss [149, 200](epoch, minibatch): 0.07794814396649599\n", 571 | "Loss [149, 300](epoch, minibatch): 0.07503192286938429\n", 572 | "Loss [150, 100](epoch, minibatch): 0.07355786092579365\n", 573 | "Loss [150, 200](epoch, minibatch): 0.07198565602302551\n", 574 | "Loss [150, 300](epoch, minibatch): 0.07614338904619217\n", 575 | "Loss [151, 100](epoch, minibatch): 0.0696952274441719\n", 576 | "Loss [151, 200](epoch, minibatch): 0.07664596512913704\n", 577 | "Loss [151, 300](epoch, minibatch): 0.07621198937296868\n", 578 | "Loss [152, 100](epoch, minibatch): 0.07049904264509678\n", 579 | "Loss [152, 200](epoch, minibatch): 0.07448908409103751\n", 580 | "Loss [152, 300](epoch, minibatch): 0.0722444773092866\n", 581 | "Loss [153, 100](epoch, minibatch): 0.07408497285097837\n", 582 | "Loss [153, 200](epoch, minibatch): 0.07429932378232479\n", 583 | "Loss [153, 300](epoch, minibatch): 0.07627773340791463\n", 584 | "Loss [154, 100](epoch, minibatch): 0.07251217653974891\n", 585 | "Loss [154, 200](epoch, minibatch): 0.06886578345671296\n", 586 | "Loss [154, 300](epoch, minibatch): 0.07596122696995736\n", 587 | "Loss [155, 100](epoch, minibatch): 0.06036315029487014\n", 588 | "Loss [155, 200](epoch, minibatch): 0.06643758732825518\n", 589 | "Loss [155, 300](epoch, minibatch): 0.0678966423124075\n", 590 | "Loss [156, 100](epoch, minibatch): 0.06482726924121379\n", 591 | "Loss [156, 200](epoch, minibatch): 0.06206686893478036\n", 592 | "Loss [156, 
300](epoch, minibatch): 0.06706166461110115\n", 593 | "Loss [157, 100](epoch, minibatch): 0.061307161506265404\n", 594 | "Loss [157, 200](epoch, minibatch): 0.05830227831378579\n", 595 | "Loss [157, 300](epoch, minibatch): 0.06689968667924404\n", 596 | "Loss [158, 100](epoch, minibatch): 0.058500857651233674\n", 597 | "Loss [158, 200](epoch, minibatch): 0.0662197027541697\n", 598 | "Loss [158, 300](epoch, minibatch): 0.06511821905151009\n", 599 | "Loss [159, 100](epoch, minibatch): 0.056493834406137464\n", 600 | "Loss [159, 200](epoch, minibatch): 0.05936595628038049\n", 601 | "Loss [159, 300](epoch, minibatch): 0.06405909527093172\n", 602 | "Loss [160, 100](epoch, minibatch): 0.0620882386341691\n", 603 | "Loss [160, 200](epoch, minibatch): 0.0671468859165907\n", 604 | "Loss [160, 300](epoch, minibatch): 0.06453811233863234\n", 605 | "Loss [161, 100](epoch, minibatch): 0.05737098846584558\n", 606 | "Loss [161, 200](epoch, minibatch): 0.05518507042899728\n", 607 | "Loss [161, 300](epoch, minibatch): 0.059649915415793654\n", 608 | "Loss [162, 100](epoch, minibatch): 0.055831551291048526\n", 609 | "Loss [162, 200](epoch, minibatch): 0.05499050464481115\n", 610 | "Loss [162, 300](epoch, minibatch): 0.0628994096070528\n", 611 | "Loss [163, 100](epoch, minibatch): 0.05985814977437258\n", 612 | "Loss [163, 200](epoch, minibatch): 0.060354581326246264\n", 613 | "Loss [163, 300](epoch, minibatch): 0.0627272029966116\n", 614 | "Loss [164, 100](epoch, minibatch): 0.05714571725577116\n", 615 | "Loss [164, 200](epoch, minibatch): 0.054981582798063755\n", 616 | "Loss [164, 300](epoch, minibatch): 0.05704518230631948\n", 617 | "Loss [165, 100](epoch, minibatch): 0.05408035064116121\n", 618 | "Loss [165, 200](epoch, minibatch): 0.05964760657399893\n", 619 | "Loss [165, 300](epoch, minibatch): 0.053540941476821896\n", 620 | "Loss [166, 100](epoch, minibatch): 0.056790421810001136\n", 621 | "Loss [166, 200](epoch, minibatch): 0.050547503493726255\n", 622 | "Loss [166, 300](epoch, 
minibatch): 0.05958049297332764\n", 623 | "Loss [167, 100](epoch, minibatch): 0.05472408337518573\n", 624 | "Loss [167, 200](epoch, minibatch): 0.05718009475618601\n", 625 | "Loss [167, 300](epoch, minibatch): 0.05491965893656015\n", 626 | "Loss [168, 100](epoch, minibatch): 0.051263665817677974\n", 627 | "Loss [168, 200](epoch, minibatch): 0.05783386977389455\n", 628 | "Loss [168, 300](epoch, minibatch): 0.04919152554124594\n", 629 | "Loss [169, 100](epoch, minibatch): 0.052700967714190486\n", 630 | "Loss [169, 200](epoch, minibatch): 0.05077950274571776\n", 631 | "Loss [169, 300](epoch, minibatch): 0.05193569347262383\n", 632 | "Loss [170, 100](epoch, minibatch): 0.05093814089894295\n", 633 | "Loss [170, 200](epoch, minibatch): 0.05232907233759761\n", 634 | "Loss [170, 300](epoch, minibatch): 0.058492430597543714\n", 635 | "Loss [171, 100](epoch, minibatch): 0.04809644397348165\n", 636 | "Loss [171, 200](epoch, minibatch): 0.04822698749601841\n", 637 | "Loss [171, 300](epoch, minibatch): 0.04737004151567817\n", 638 | "Loss [172, 100](epoch, minibatch): 0.04959821948781609\n", 639 | "Loss [172, 200](epoch, minibatch): 0.048613924365490675\n", 640 | "Loss [172, 300](epoch, minibatch): 0.04811608199030161\n", 641 | "Loss [173, 100](epoch, minibatch): 0.05222764510661364\n", 642 | "Loss [173, 200](epoch, minibatch): 0.05422562401741743\n", 643 | "Loss [173, 300](epoch, minibatch): 0.05661549609154463\n", 644 | "Loss [174, 100](epoch, minibatch): 0.05212319299578667\n", 645 | "Loss [174, 200](epoch, minibatch): 0.04867717148736119\n", 646 | "Loss [174, 300](epoch, minibatch): 0.051791660711169245\n", 647 | "Loss [175, 100](epoch, minibatch): 0.04901155410334468\n", 648 | "Loss [175, 200](epoch, minibatch): 0.05254359154030681\n", 649 | "Loss [175, 300](epoch, minibatch): 0.051361130569130185\n", 650 | "Loss [176, 100](epoch, minibatch): 0.056696265991777184\n", 651 | "Loss [176, 200](epoch, minibatch): 0.057765098493546246\n", 652 | "Loss [176, 300](epoch, minibatch): 
0.05033251633867621\n", 653 | "Loss [177, 100](epoch, minibatch): 0.047747475747019055\n", 654 | "Loss [177, 200](epoch, minibatch): 0.05219324320554733\n", 655 | "Loss [177, 300](epoch, minibatch): 0.0493961656652391\n", 656 | "Loss [178, 100](epoch, minibatch): 0.04632460195571184\n", 657 | "Loss [178, 200](epoch, minibatch): 0.04858155118301511\n", 658 | "Loss [178, 300](epoch, minibatch): 0.053830772526562216\n", 659 | "Loss [179, 100](epoch, minibatch): 0.04543822238221765\n", 660 | "Loss [179, 200](epoch, minibatch): 0.047149957194924354\n", 661 | "Loss [179, 300](epoch, minibatch): 0.050803377628326415\n", 662 | "Loss [180, 100](epoch, minibatch): 0.049232794865965844\n", 663 | "Loss [180, 200](epoch, minibatch): 0.04929339682683349\n", 664 | "Loss [180, 300](epoch, minibatch): 0.04897596085444093\n", 665 | "Loss [181, 100](epoch, minibatch): 0.046268546655774116\n", 666 | "Loss [181, 200](epoch, minibatch): 0.042781891915947196\n", 667 | "Loss [181, 300](epoch, minibatch): 0.05060765234753489\n", 668 | "Loss [182, 100](epoch, minibatch): 0.04476797264069319\n", 669 | "Loss [182, 200](epoch, minibatch): 0.05307487161830068\n", 670 | "Loss [182, 300](epoch, minibatch): 0.04850761391222477\n", 671 | "Loss [183, 100](epoch, minibatch): 0.043931541200727224\n", 672 | "Loss [183, 200](epoch, minibatch): 0.04511226551607251\n", 673 | "Loss [183, 300](epoch, minibatch): 0.04752522405236959\n", 674 | "Loss [184, 100](epoch, minibatch): 0.04832927925512195\n", 675 | "Loss [184, 200](epoch, minibatch): 0.05141777416691184\n", 676 | "Loss [184, 300](epoch, minibatch): 0.04611195098608732\n", 677 | "Loss [185, 100](epoch, minibatch): 0.04489896455779672\n", 678 | "Loss [185, 200](epoch, minibatch): 0.05029143312945962\n", 679 | "Loss [185, 300](epoch, minibatch): 0.04505758104845881\n", 680 | "Loss [186, 100](epoch, minibatch): 0.043485564403235914\n", 681 | "Loss [186, 200](epoch, minibatch): 0.04173792021349072\n", 682 | "Loss [186, 300](epoch, minibatch): 
0.05178328309208155\n", 683 | "Loss [187, 100](epoch, minibatch): 0.04555150557309389\n", 684 | "Loss [187, 200](epoch, minibatch): 0.045824325457215306\n", 685 | "Loss [187, 300](epoch, minibatch): 0.04687584307044745\n", 686 | "Loss [188, 100](epoch, minibatch): 0.04189886808395386\n", 687 | "Loss [188, 200](epoch, minibatch): 0.044915573969483376\n", 688 | "Loss [188, 300](epoch, minibatch): 0.046888482999056576\n", 689 | "Loss [189, 100](epoch, minibatch): 0.042703730929642914\n", 690 | "Loss [189, 200](epoch, minibatch): 0.04844313383102417\n", 691 | "Loss [189, 300](epoch, minibatch): 0.04827386787161231\n", 692 | "Loss [190, 100](epoch, minibatch): 0.03814472258090973\n", 693 | "Loss [190, 200](epoch, minibatch): 0.04559313138946891\n", 694 | "Loss [190, 300](epoch, minibatch): 0.04017418723553419\n", 695 | "Loss [191, 100](epoch, minibatch): 0.04222352247685194\n", 696 | "Loss [191, 200](epoch, minibatch): 0.0527523024007678\n", 697 | "Loss [191, 300](epoch, minibatch): 0.04720606388524175\n", 698 | "Loss [192, 100](epoch, minibatch): 0.04322879239916801\n", 699 | "Loss [192, 200](epoch, minibatch): 0.04275385139510036\n", 700 | "Loss [192, 300](epoch, minibatch): 0.0427305543795228\n", 701 | "Loss [193, 100](epoch, minibatch): 0.043951349519193175\n", 702 | "Loss [193, 200](epoch, minibatch): 0.037608045563101766\n", 703 | "Loss [193, 300](epoch, minibatch): 0.04689977327361703\n", 704 | "Loss [194, 100](epoch, minibatch): 0.04425131002441049\n", 705 | "Loss [194, 200](epoch, minibatch): 0.043790823165327314\n", 706 | "Loss [194, 300](epoch, minibatch): 0.04787272963672876\n", 707 | "Loss [195, 100](epoch, minibatch): 0.045910639762878416\n", 708 | "Loss [195, 200](epoch, minibatch): 0.04503726653754711\n", 709 | "Loss [195, 300](epoch, minibatch): 0.044547636974602935\n", 710 | "Loss [196, 100](epoch, minibatch): 0.038648189175873995\n", 711 | "Loss [196, 200](epoch, minibatch): 0.04546283291652799\n", 712 | "Loss [196, 300](epoch, minibatch): 
0.046378096230328084\n", 713 | "Loss [197, 100](epoch, minibatch): 0.03882626427337527\n", 714 | "Loss [197, 200](epoch, minibatch): 0.03689012542366982\n", 715 | "Loss [197, 300](epoch, minibatch): 0.030634157378226518\n", 716 | "Loss [198, 100](epoch, minibatch): 0.030805828757584097\n", 717 | "Loss [198, 200](epoch, minibatch): 0.029564176592975853\n", 718 | "Loss [198, 300](epoch, minibatch): 0.027018439136445523\n", 719 | "Loss [199, 100](epoch, minibatch): 0.023948805034160615\n", 720 | "Loss [199, 200](epoch, minibatch): 0.030192875675857066\n", 721 | "Loss [199, 300](epoch, minibatch): 0.025430371947586537\n", 722 | "Loss [200, 100](epoch, minibatch): 0.023097231648862362\n", 723 | "Loss [200, 200](epoch, minibatch): 0.022299093399196862\n", 724 | "Loss [200, 300](epoch, minibatch): 0.023658006470650434\n", 725 | "Training Done\n" 726 | ] 727 | } 728 | ], 729 | "source": [ 730 | "EPOCHS = 200\n", 731 | "for epoch in range(EPOCHS):\n", 732 | " losses = []\n", 733 | " running_loss = 0\n", 734 | " for i, inp in enumerate(trainloader):\n", 735 | " inputs, labels = inp\n", 736 | " inputs, labels = inputs.to('cuda'), labels.to('cuda')\n", 737 | " optimizer.zero_grad()\n", 738 | " \n", 739 | " outputs = net(inputs)\n", 740 | " loss = criterion(outputs, labels)\n", 741 | " losses.append(loss.item())\n", 742 | "\n", 743 | " loss.backward()\n", 744 | " optimizer.step()\n", 745 | " \n", 746 | " running_loss += loss.item()\n", 747 | " \n", 748 | " if i%100 == 0 and i > 0:\n", 749 | " print(f'Loss [{epoch+1}, {i}](epoch, minibatch): ', running_loss / 100)\n", 750 | " running_loss = 0.0\n", 751 | "\n", 752 | " avg_loss = sum(losses)/len(losses)\n", 753 | " scheduler.step(avg_loss)\n", 754 | " \n", 755 | "print('Training Done')" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 39, 761 | "metadata": { 762 | "colab": { 763 | "base_uri": "https://localhost:8080/", 764 | "height": 34 765 | }, 766 | "colab_type": "code", 767 | "id": "eY3rPw7wPVEe", 
768 | "outputId": "213b5a1c-d9f4-4659-f6dc-b6ff9d0d1a44" 769 | }, 770 | "outputs": [ 771 | { 772 | "name": "stdout", 773 | "output_type": "stream", 774 | "text": [ 775 | "Accuracy on 10,000 test images: 85.66 %\n" 776 | ] 777 | } 778 | ], 779 | "source": [ 780 | "correct = 0\n", 781 | "total = 0\n", 782 | "\n", 783 | "with torch.no_grad():\n", 784 | " for data in testloader:\n", 785 | " images, labels = data\n", 786 | " images, labels = images.to('cuda'), labels.to('cuda')\n", 787 | " outputs = net(images)\n", 788 | " \n", 789 | " _, predicted = torch.max(outputs.data, 1)\n", 790 | " total += labels.size(0)\n", 791 | " correct += (predicted == labels).sum().item()\n", 792 | "print('Accuracy on 10,000 test images: ', 100*(correct/total), '%')" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 39, 798 | "metadata": { 799 | "colab": {}, 800 | "colab_type": "code", 801 | "id": "oeVRGN56hkNZ" 802 | }, 803 | "outputs": [], 804 | "source": [] 805 | } 806 | ], 807 | "metadata": { 808 | "accelerator": "GPU", 809 | "colab": { 810 | "name": "CIFAR10-ResNet50_85%.ipynb", 811 | "provenance": [] 812 | }, 813 | "kernelspec": { 814 | "display_name": "Python 3", 815 | "language": "python", 816 | "name": "python3" 817 | }, 818 | "language_info": { 819 | "codemirror_mode": { 820 | "name": "ipython", 821 | "version": 3 822 | }, 823 | "file_extension": ".py", 824 | "mimetype": "text/x-python", 825 | "name": "python", 826 | "nbconvert_exporter": "python", 827 | "pygments_lexer": "ipython3", 828 | "version": "3.7.4" 829 | } 830 | }, 831 | "nbformat": 4, 832 | "nbformat_minor": 1 833 | } 834 | -------------------------------------------------------------------------------- /ResNet/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | def __init__(self, in_channels, out_channels, 
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) used by ResNet-50/101/152.

    The block outputs ``out_channels * expansion`` channels; spatial
    downsampling (when ``stride > 1``) happens in the middle 3x3 conv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: base channel count of the block (pre-expansion).
        i_downsample: optional module that projects the identity path to the
            block's output shape (used when stride or channel count changes).
        stride: stride of the 3x3 conv (2 for the first block of a stage).
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()

        # bias=False on every conv: each conv is immediately followed by a
        # BatchNorm whose affine shift makes a conv bias redundant (and this
        # matches the bias=False convention already used by Block below).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.relu(self.batch_norm2(self.conv2(x)))
        # No ReLU before the residual addition (post-activation ResNet).
        x = self.batch_norm3(self.conv3(x))

        # Project the identity when the main path changed shape.
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        x += identity
        x = self.relu(x)

        return x


class Block(nn.Module):
    """Basic residual block (3x3 -> 3x3) used by ResNet-18/34.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: output channel count (``expansion`` is 1 here).
        i_downsample: optional projection for the identity path.
        stride: stride of the *first* 3x3 conv only.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        # Fix: conv2 must always use stride=1. With stride=stride on both
        # convs, a stride-2 block downsampled the main path twice while the
        # shortcut was downsampled once, so `x += identity` raised a
        # size-mismatch error for every stage after the first.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x.clone()

        # Fix: conv1's output goes through batch_norm1 (the original applied
        # batch_norm2 here, leaving batch_norm1 unused and normalizing conv1
        # with the wrong statistics). Debug print() calls removed.
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        x += identity
        x = self.relu(x)

        return x


class ResNet(nn.Module):
    """ResNet backbone: 7x7 stem, four residual stages, global pool, FC head.

    Args:
        ResBlock: block class to stack (``Bottleneck`` or ``Block``).
        layer_list: number of blocks per stage, e.g. [3, 4, 6, 3].
        num_classes: size of the final classification layer.
        num_channels: input image channels (default 3 for RGB).
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 1 keeps the stem's resolution; stages 2-4 each halve it.
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        # Adaptive pooling makes the head independent of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResBlock.expansion, num_classes)

    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block may downsample / change channels; it receives a
        1x1-conv + BN projection for the identity path when needed.
        """
        ii_downsample = None
        layers = []

        if stride != 1 or self.in_channels != planes * ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, planes * ResBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * ResBlock.expansion)
            )

        layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
        self.in_channels = planes * ResBlock.expansion

        # Remaining blocks of the stage keep shape, so no projection needed.
        for _ in range(blocks - 1):
            layers.append(ResBlock(self.in_channels, planes))

        return nn.Sequential(*layers)


def ResNet50(num_classes, channels=3):
    """Construct a ResNet-50 (Bottleneck blocks, [3, 4, 6, 3] layout)."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)
def _bottleneck_resnet(layer_list, num_classes, channels):
    # Shared constructor for the Bottleneck-based variants below.
    return ResNet(Bottleneck, layer_list, num_classes, channels)


def ResNet101(num_classes, channels=3):
    """Construct a ResNet-101 (Bottleneck blocks, [3, 4, 23, 3] layout)."""
    return _bottleneck_resnet([3, 4, 23, 3], num_classes, channels)


def ResNet152(num_classes, channels=3):
    """Construct a ResNet-152 (Bottleneck blocks, [3, 8, 36, 3] layout)."""
    return _bottleneck_resnet([3, 8, 36, 3], num_classes, channels)