├── .gitignore ├── README.md ├── pseudo_label-DL.ipynb └── pseudo_label-Logistic_reg.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pseudo Labeling to deal with small datasets 2 | 3 | Accompanying blog : 4 | 5 | Dataset: https://www.kaggle.com/oddrationale/mnist-in-csv 6 | (Download to the 'data' folder) 7 | 8 | ## Credits and References: 9 | 10 | 1. Dong-Hyun Lee. "Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks" ICML 2013 Workshop : Challenges in Representation Learning (WREPL), Atlanta, Georgia, USA, 2013 (http://deeplearning.net/wp-content/uploads/2013/03/pseudo_label_final.pdf) 11 | 2. https://github.com/peimengsui/semi_supervised_mnist 12 | 3. https://www.analyticsvidhya.com/blog/2017/09/pseudo-labelling-semi-supervised-learning-technique/ 13 | 14 | -------------------------------------------------------------------------------- /pseudo_label-DL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "Collapsed": "false" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%matplotlib inline\n", 12 | "from MulticoreTSNE import MulticoreTSNE as TSNE\n", 13 | "from matplotlib import pyplot as plt\n", 14 | "import torch\n", 15 | "from torchvision import datasets, transforms\n", 16 | "from torch import nn\n", 17 | "import torch.nn.functional as F\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "torch.manual_seed(42)\n", 21 | "np.random.seed(42)\n", 22 | "torch.backends.cudnn.deterministic = True\n", 23 | "torch.backends.cudnn.benchmark = False" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": { 30 | "Collapsed": "false" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import pandas as pd \n", 35 | "\n", 36 | "UNLABELED_BS = 256\n", 37 | "TRAIN_BS = 32\n", 38 | "TEST_BS = 1024\n", 39 | "\n", 40 | "num_train_samples = 1000\n", 41 | "samples_per_class = int(num_train_samples/9)\n", 42 | "\n", 43 | "x = pd.read_csv('data/mnist_train.csv')\n", 44 | "y = x['label']\n", 45 | "x.drop(['label'], inplace = True, axis = 1)\n", 46 | "\n", 47 | "x_test = pd.read_csv('data/mnist_test.csv')\n", 48 | "y_test = x_test['label']\n", 49 | "x_test.drop(['label'], inplace = True, axis = 1)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "Collapsed": "false" 56 | }, 57 | "source": [ 58 | "Now, lets divide the dataset into train and unlabeled sets. For the train set we'll make sure that we have equal samples for all the 10 classes. (class-balancing)\n", 59 | "\n", 60 | "We wont use the labels for the unlabeled set." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": { 67 | "Collapsed": "false" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "x_train, x_unlabeled = x[y.values == 0].values[:samples_per_class], x[y.values == 0].values[samples_per_class:]\n", 72 | "y_train = y[y.values == 0].values[:samples_per_class]\n", 73 | "\n", 74 | "for i in range(1,10):\n", 75 | " x_train = np.concatenate([x_train, x[y.values == i].values[:samples_per_class]], axis = 0)\n", 76 | " y_train = np.concatenate([y_train, y[y.values == i].values[:samples_per_class]], axis = 0)\n", 77 | " \n", 78 | " x_unlabeled = np.concatenate([x_unlabeled, x[y.values == i].values[samples_per_class:]], axis = 0)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": { 84 | "Collapsed": "false" 85 | }, 86 | "source": [ 87 | "Next, we'll normalize the data, convert it into tensors and create the dataloaders for train, unlabeled and test sets. " 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": { 94 | "Collapsed": "false" 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "from sklearn.preprocessing import Normalizer\n", 99 | "\n", 100 | "normalizer = Normalizer()\n", 101 | "x_train = normalizer.fit_transform(x_train)\n", 102 | "x_unlabeled = normalizer.transform(x_unlabeled)\n", 103 | "x_test = normalizer.transform(x_test.values)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": { 110 | "Collapsed": "false" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "x_train = torch.from_numpy(x_train).type(torch.FloatTensor)\n", 115 | "y_train = torch.from_numpy(y_train).type(torch.LongTensor) \n", 116 | "\n", 117 | "x_test = torch.from_numpy(x_test).type(torch.FloatTensor)\n", 118 | "y_test = torch.from_numpy(y_test.values).type(torch.LongTensor) " 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 7, 124 | "metadata": { 125 | "Collapsed": "false" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "train = torch.utils.data.TensorDataset(x_train, y_train)\n", 130 | "test = torch.utils.data.TensorDataset(x_test, y_test)\n", 131 | "\n", 132 | "train_loader = torch.utils.data.DataLoader(train, batch_size = TRAIN_BS, shuffle = True, num_workers = 8)\n", 133 | "\n", 134 | "unlabeled_train = torch.from_numpy(x_unlabeled).type(torch.FloatTensor)\n", 135 | "\n", 136 | "unlabeled = torch.utils.data.TensorDataset(unlabeled_train)\n", 137 | "unlabeled_loader = torch.utils.data.DataLoader(unlabeled, batch_size = UNLABELED_BS, shuffle = True, num_workers = 8)\n", 138 | "\n", 139 | "test_loader = torch.utils.data.DataLoader(test, batch_size = TEST_BS, shuffle = True, num_workers = 8)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": { 145 | "Collapsed": "false" 146 | }, 147 | "source": [ 148 | "### Network Architecture\n", 149 | "\n", 150 | "We'll use a simple 2 layer Conv + 2 FC layer network with dropouts." 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": { 157 | "Collapsed": "false" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "# Architecture from : https://github.com/peimengsui/semi_supervised_mnist\n", 162 | "class Net(nn.Module):\n", 163 | " def __init__(self):\n", 164 | " super(Net, self).__init__()\n", 165 | " self.conv1 = nn.Conv2d(1, 20, kernel_size=5)\n", 166 | " self.conv2 = nn.Conv2d(20, 40, kernel_size=5)\n", 167 | " self.conv2_drop = nn.Dropout2d()\n", 168 | " self.fc1 = nn.Linear(640, 150)\n", 169 | " self.fc2 = nn.Linear(150, 10)\n", 170 | " self.log_softmax = nn.LogSoftmax(dim = 1)\n", 171 | "\n", 172 | " def forward(self, x):\n", 173 | " x = x.view(-1,1,28,28)\n", 174 | " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 175 | " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", 176 | " x = x.view(-1, 640)\n", 177 | " x = F.relu(self.fc1(x))\n", 178 | " x = F.dropout(x, training=self.training)\n", 179 | " x = F.relu(self.fc2(x))\n", 180 | " x = self.log_softmax(x)\n", 181 | " return x\n", 182 | " \n", 183 | "net = Net().cuda()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": { 189 | "Collapsed": "false" 190 | }, 191 | "source": [ 192 | "Now let's define a function to evaluate the network and get loss and accuracy values. " 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 9, 198 | "metadata": { 199 | "Collapsed": "false" 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "def evaluate(model, test_loader):\n", 204 | " model.eval()\n", 205 | " correct = 0 \n", 206 | " loss = 0\n", 207 | " with torch.no_grad():\n", 208 | " for data, labels in test_loader:\n", 209 | " data = data.cuda()\n", 210 | " output = model(data)\n", 211 | " predicted = torch.max(output,1)[1]\n", 212 | " correct += (predicted == labels.cuda()).sum()\n", 213 | " loss += F.nll_loss(output, labels.cuda()).item()\n", 214 | "\n", 215 | " return (float(correct)/len(test)) *100, (loss/len(test_loader))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "Collapsed": "false" 222 | }, 223 | "source": [ 224 | "First, let's train the model on the labeled set for 300 epochs" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 10, 230 | "metadata": { 231 | "Collapsed": "false" 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "from tqdm import tqdm_notebook\n", 236 | "def train_supervised(model, train_loader, test_loader):\n", 237 | " optimizer = torch.optim.SGD( model.parameters(), lr = 0.1)\n", 238 | " EPOCHS = 100\n", 239 | " model.train()\n", 240 | " for epoch in tqdm_notebook(range(EPOCHS)):\n", 241 | " correct = 0\n", 242 | " running_loss = 0\n", 243 | " for batch_idx, (X_batch, y_batch) in enumerate(train_loader):\n", 244 | " X_batch, y_batch = X_batch.cuda(), y_batch.cuda()\n", 245 | " \n", 246 | " output = model(X_batch)\n", 247 | " labeled_loss = F.nll_loss(output, y_batch)\n", 248 | " \n", 249 | " optimizer.zero_grad()\n", 250 | " labeled_loss.backward()\n", 251 | " optimizer.step()\n", 252 | " running_loss += labeled_loss.item()\n", 253 | " \n", 254 | " if epoch %10 == 0:\n", 255 | " test_acc, test_loss = evaluate(model, test_loader)\n", 256 | " print('Epoch: {} : Train Loss : {:.5f} | Test Acc : {:.5f} | Test Loss : {:.3f} '.format(epoch, running_loss/(10 * len(train)), test_acc, test_loss))\n", 257 | " model.train()\n", 258 | " " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 11, 264 | "metadata": { 265 | "Collapsed": "false" 266 | }, 267 | "outputs": [ 268 | { 269 | "data": { 270 | "application/vnd.jupyter.widget-view+json": { 271 | "model_id": "ca16c55be32944fb873de7476cd7877f", 272 | "version_major": 2, 273 | "version_minor": 0 274 | }, 275 | "text/plain": [ 276 | "HBox(children=(IntProgress(value=0), HTML(value='')))" 277 | ] 278 | }, 279 | "metadata": {}, 280 | "output_type": "display_data" 281 | }, 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "Epoch: 0 : Train Loss : 0.00726 | Test Acc : 12.74000 | Test Loss : 2.302 \n", 287 | "Epoch: 10 : Train Loss : 0.00725 | Test Acc : 29.54000 | Test Loss : 2.297 \n", 288 | "Epoch: 20 : Train Loss : 0.00528 | Test Acc : 55.26000 | Test Loss : 1.436 \n", 289 | "Epoch: 30 : Train Loss : 0.00154 | Test Acc : 87.87000 | Test Loss : 0.410 \n", 290 | "Epoch: 40 : Train Loss : 0.00102 | Test Acc : 92.41000 | Test Loss : 0.246 \n", 291 | "Epoch: 50 : Train Loss : 0.00076 | Test Acc : 93.89000 | Test Loss : 0.207 \n", 292 | "Epoch: 60 : Train Loss : 0.00060 | Test Acc : 93.95000 | Test Loss : 0.194 \n", 293 | "Epoch: 70 : Train Loss : 0.00048 | Test Acc : 94.25000 | Test Loss : 0.189 \n", 294 | "Epoch: 80 : Train Loss : 0.00038 | Test Acc : 94.42000 | Test Loss : 0.188 \n", 295 | "Epoch: 90 : Train Loss : 0.00034 | Test Acc : 94.61000 | Test Loss : 0.190 \n", 296 | "\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "train_supervised(net, train_loader, test_loader)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 13, 307 | "metadata": { 308 | "Collapsed": "false" 309 | }, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "Test Acc : 94.91000 | Test Loss : 0.188 \n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "test_acc, test_loss = evaluate(net, test_loader)\n", 321 | "print('Test Acc : {:.5f} | Test Loss : {:.3f} '.format(test_acc, test_loss))\n", 322 | "torch.save(net.state_dict(), 'saved_models/supervised_weights')" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 11, 328 | "metadata": { 329 | "Collapsed": "false" 330 | }, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "text/plain": [ 335 | "" 336 | ] 337 | }, 338 | "execution_count": 11, 339 | "metadata": {}, 340 | "output_type": "execute_result" 341 | } 342 | ], 343 | "source": [ 344 | "net.load_state_dict(torch.load('saved_models/supervised_weights'))" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 14, 350 | "metadata": { 351 | "Collapsed": "false" 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "T1 = 100\n", 356 | "T2 = 700\n", 357 | "af = 3\n", 358 | "\n", 359 | "def alpha_weight(epoch):\n", 360 | " if epoch < T1:\n", 361 | " return 0.0\n", 362 | " elif epoch > T2:\n", 363 | " return af\n", 364 | " else:\n", 365 | " return ((epoch-T1) / (T2-T1))*af" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 15, 371 | "metadata": { 372 | "Collapsed": "false" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "# Concept from : https://github.com/peimengsui/semi_supervised_mnist\n", 377 | "\n", 378 | "from tqdm import tqdm_notebook\n", 379 | "\n", 380 | "acc_scores = []\n", 381 | "unlabel = []\n", 382 | "pseudo_label = []\n", 383 | "\n", 384 | "alpha_log = []\n", 385 | "test_acc_log = []\n", 386 | "test_loss_log = []\n", 387 | "def semisup_train(model, train_loader, unlabeled_loader, test_loader):\n", 388 | " optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)\n", 389 | " EPOCHS = 150\n", 390 | " \n", 391 | " # Instead of using current epoch we use a \"step\" variable to calculate alpha_weight\n", 392 | " # This helps the model converge faster\n", 393 | " step = 100 \n", 394 | " \n", 395 | " model.train()\n", 396 | " for epoch in tqdm_notebook(range(EPOCHS)):\n", 397 | " for batch_idx, x_unlabeled in enumerate(unlabeled_loader):\n", 398 | " \n", 399 | " \n", 400 | " # Forward Pass to get the pseudo labels\n", 401 | " x_unlabeled = x_unlabeled[0].cuda()\n", 402 | " model.eval()\n", 403 | " output_unlabeled = model(x_unlabeled)\n", 404 | " _, pseudo_labeled = torch.max(output_unlabeled, 1)\n", 405 | " model.train()\n", 406 | " \n", 407 | " \n", 408 | " \"\"\" ONLY FOR VISUALIZATION\"\"\"\n", 409 | " if (batch_idx < 3) and (epoch % 10 == 0):\n", 410 | " unlabel.append(x_unlabeled.cpu())\n", 411 | " pseudo_label.append(pseudo_labeled.cpu())\n", 412 | " \"\"\" ********************** \"\"\"\n", 413 | " \n", 414 | " # Now calculate the unlabeled loss using the pseudo label\n", 415 | " output = model(x_unlabeled)\n", 416 | " unlabeled_loss = alpha_weight(step) * F.nll_loss(output, pseudo_labeled) \n", 417 | " \n", 418 | " # Backpropogate\n", 419 | " optimizer.zero_grad()\n", 420 | " unlabeled_loss.backward()\n", 421 | " optimizer.step()\n", 422 | " \n", 423 | " \n", 424 | " # For every 50 batches train one epoch on labeled data \n", 425 | " if batch_idx % 50 == 0:\n", 426 | " \n", 427 | " # Normal training procedure\n", 428 | " for batch_idx, (X_batch, y_batch) in enumerate(train_loader):\n", 429 | " X_batch = X_batch.cuda()\n", 430 | " y_batch = y_batch.cuda()\n", 431 | " output = model(X_batch)\n", 432 | " labeled_loss = F.nll_loss(output, y_batch)\n", 433 | "\n", 434 | " optimizer.zero_grad()\n", 435 | " labeled_loss.backward()\n", 436 | " optimizer.step()\n", 437 | " \n", 438 | " # Now we increment step by 1\n", 439 | " step += 1\n", 440 | " \n", 441 | "\n", 442 | " test_acc, test_loss =evaluate(model, test_loader)\n", 443 | " print('Epoch: {} : Alpha Weight : {:.5f} | Test Acc : {:.5f} | Test Loss : {:.3f} '.format(epoch, alpha_weight(step), test_acc, test_loss))\n", 444 | " \n", 445 | " \"\"\" LOGGING VALUES \"\"\"\n", 446 | " alpha_log.append(alpha_weight(step))\n", 447 | " test_acc_log.append(test_acc/100)\n", 448 | " test_loss_log.append(test_loss)\n", 449 | " \"\"\" ************** \"\"\"\n", 450 | " model.train()\n", 451 | " " 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 16, 457 | "metadata": { 458 | "Collapsed": "false" 459 | }, 460 | "outputs": [ 461 | { 462 | "data": { 463 | "application/vnd.jupyter.widget-view+json": { 464 | "model_id": "7a0e527451814223a49a8601ebc009d2", 465 | "version_major": 2, 466 | "version_minor": 0 467 | }, 468 | "text/plain": [ 469 | "HBox(children=(IntProgress(value=0, max=150), HTML(value='')))" 470 | ] 471 | }, 472 | "metadata": {}, 473 | "output_type": "display_data" 474 | }, 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "Epoch: 0 : Alpha Weight : 0.02500 | Test Acc : 95.05000 | Test Loss : 0.192 \n", 480 | "Epoch: 1 : Alpha Weight : 0.05000 | Test Acc : 95.07000 | Test Loss : 0.189 \n", 481 | "Epoch: 2 : Alpha Weight : 0.07500 | Test Acc : 94.99000 | Test Loss : 0.197 \n", 482 | "Epoch: 3 : Alpha Weight : 0.10000 | Test Acc : 95.08000 | Test Loss : 0.192 \n", 483 | "Epoch: 4 : Alpha Weight : 0.12500 | Test Acc : 95.22000 | Test Loss : 0.192 \n", 484 | "Epoch: 5 : Alpha Weight : 0.15000 | Test Acc : 95.24000 | Test Loss : 0.183 \n", 485 | "Epoch: 6 : Alpha Weight : 0.17500 | Test Acc : 95.35000 | Test Loss : 0.187 \n", 486 | "Epoch: 7 : Alpha Weight : 0.20000 | Test Acc : 95.25000 | Test Loss : 0.181 \n", 487 | "Epoch: 8 : Alpha Weight : 0.22500 | Test Acc : 95.39000 | Test Loss : 0.180 \n", 488 | "Epoch: 9 : Alpha Weight : 0.25000 | Test Acc : 95.42000 | Test Loss : 0.175 \n", 489 | "Epoch: 10 : Alpha Weight : 0.27500 | Test Acc : 95.63000 | Test Loss : 0.171 \n", 490 | "Epoch: 11 : Alpha Weight : 0.30000 | Test Acc : 95.71000 | Test Loss : 0.171 \n", 491 | "Epoch: 12 : Alpha Weight : 0.32500 | Test Acc : 95.98000 | Test Loss : 0.158 \n", 492 | "Epoch: 13 : Alpha Weight : 0.35000 | Test Acc : 95.95000 | Test Loss : 0.159 \n", 493 | "Epoch: 14 : Alpha Weight : 0.37500 | Test Acc : 95.84000 | Test Loss : 0.163 \n", 494 | "Epoch: 15 : Alpha Weight : 0.40000 | Test Acc : 95.97000 | Test Loss : 0.154 \n", 495 | "Epoch: 16 : Alpha Weight : 0.42500 | Test Acc : 96.05000 | Test Loss : 0.158 \n", 496 | "Epoch: 17 : Alpha Weight : 0.45000 | Test Acc : 96.12000 | Test Loss : 0.155 \n", 497 | "Epoch: 18 : Alpha Weight : 0.47500 | Test Acc : 96.22000 | Test Loss : 0.154 \n", 498 | "Epoch: 19 : Alpha Weight : 0.50000 | Test Acc : 96.29000 | Test Loss : 0.147 \n", 499 | "Epoch: 20 : Alpha Weight : 0.52500 | Test Acc : 96.15000 | Test Loss : 0.150 \n", 500 | "Epoch: 21 : Alpha Weight : 0.55000 | Test Acc : 96.38000 | Test Loss : 0.145 \n", 501 | "Epoch: 22 : Alpha Weight : 0.57500 | Test Acc : 96.31000 | Test Loss : 0.145 \n", 502 | "Epoch: 23 : Alpha Weight : 0.60000 | Test Acc : 96.35000 | Test Loss : 0.141 \n", 503 | "Epoch: 24 : Alpha Weight : 0.62500 | Test Acc : 96.55000 | Test Loss : 0.137 \n", 504 | "Epoch: 25 : Alpha Weight : 0.65000 | Test Acc : 96.59000 | Test Loss : 0.134 \n", 505 | "Epoch: 26 : Alpha Weight : 0.67500 | Test Acc : 96.72000 | Test Loss : 0.135 \n", 506 | "Epoch: 27 : Alpha Weight : 0.70000 | Test Acc : 96.67000 | Test Loss : 0.132 \n", 507 | "Epoch: 28 : Alpha Weight : 0.72500 | Test Acc : 96.56000 | Test Loss : 0.135 \n", 508 | "Epoch: 29 : Alpha Weight : 0.75000 | Test Acc : 96.76000 | Test Loss : 0.129 \n", 509 | "Epoch: 30 : Alpha Weight : 0.77500 | Test Acc : 96.73000 | Test Loss : 0.133 \n", 510 | "Epoch: 31 : Alpha Weight : 0.80000 | Test Acc : 96.73000 | Test Loss : 0.134 \n", 511 | "Epoch: 32 : Alpha Weight : 0.82500 | Test Acc : 96.71000 | Test Loss : 0.128 \n", 512 | "Epoch: 33 : Alpha Weight : 0.85000 | Test Acc : 96.70000 | Test Loss : 0.135 \n", 513 | "Epoch: 34 : Alpha Weight : 0.87500 | Test Acc : 96.86000 | Test Loss : 0.127 \n", 514 | "Epoch: 35 : Alpha Weight : 0.90000 | Test Acc : 96.55000 | Test Loss : 0.132 \n", 515 | "Epoch: 36 : Alpha Weight : 0.92500 | Test Acc : 96.59000 | Test Loss : 0.132 \n", 516 | "Epoch: 37 : Alpha Weight : 0.95000 | Test Acc : 96.82000 | Test Loss : 0.126 \n", 517 | "Epoch: 38 : Alpha Weight : 0.97500 | Test Acc : 97.22000 | Test Loss : 0.116 \n", 518 | "Epoch: 39 : Alpha Weight : 1.00000 | Test Acc : 97.10000 | Test Loss : 0.118 \n", 519 | "Epoch: 40 : Alpha Weight : 1.02500 | Test Acc : 96.48000 | Test Loss : 0.144 \n", 520 | "Epoch: 41 : Alpha Weight : 1.05000 | Test Acc : 96.91000 | Test Loss : 0.131 \n", 521 | "Epoch: 42 : Alpha Weight : 1.07500 | Test Acc : 97.08000 | Test Loss : 0.115 \n", 522 | "Epoch: 43 : Alpha Weight : 1.10000 | Test Acc : 97.15000 | Test Loss : 0.117 \n", 523 | "Epoch: 44 : Alpha Weight : 1.12500 | Test Acc : 97.18000 | Test Loss : 0.112 \n", 524 | "Epoch: 45 : Alpha Weight : 1.15000 | Test Acc : 97.15000 | Test Loss : 0.115 \n", 525 | "Epoch: 46 : Alpha Weight : 1.17500 | Test Acc : 97.26000 | Test Loss : 0.110 \n", 526 | "Epoch: 47 : Alpha Weight : 1.20000 | Test Acc : 97.14000 | Test Loss : 0.110 \n", 527 | "Epoch: 48 : Alpha Weight : 1.22500 | Test Acc : 97.27000 | Test Loss : 0.110 \n", 528 | "Epoch: 49 : Alpha Weight : 1.25000 | Test Acc : 97.47000 | Test Loss : 0.108 \n", 529 | "Epoch: 50 : Alpha Weight : 1.27500 | Test Acc : 97.10000 | Test Loss : 0.118 \n", 530 | "Epoch: 51 : Alpha Weight : 1.30000 | Test Acc : 97.23000 | Test Loss : 0.115 \n", 531 | "Epoch: 52 : Alpha Weight : 1.32500 | Test Acc : 97.32000 | Test Loss : 0.105 \n", 532 | "Epoch: 53 : Alpha Weight : 1.35000 | Test Acc : 97.17000 | Test Loss : 0.112 \n", 533 | "Epoch: 54 : Alpha Weight : 1.37500 | Test Acc : 97.47000 | Test Loss : 0.107 \n", 534 | "Epoch: 55 : Alpha Weight : 1.40000 | Test Acc : 97.44000 | Test Loss : 0.110 \n", 535 | "Epoch: 56 : Alpha Weight : 1.42500 | Test Acc : 97.36000 | Test Loss : 0.106 \n", 536 | "Epoch: 57 : Alpha Weight : 1.45000 | Test Acc : 97.54000 | Test Loss : 0.107 \n", 537 | "Epoch: 58 : Alpha Weight : 1.47500 | Test Acc : 97.52000 | Test Loss : 0.106 \n", 538 | "Epoch: 59 : Alpha Weight : 1.50000 | Test Acc : 97.52000 | Test Loss : 0.102 \n", 539 | "Epoch: 60 : Alpha Weight : 1.52500 | Test Acc : 97.63000 | Test Loss : 0.100 \n", 540 | "Epoch: 61 : Alpha Weight : 1.55000 | Test Acc : 97.54000 | Test Loss : 0.106 \n", 541 | "Epoch: 62 : Alpha Weight : 1.57500 | Test Acc : 97.58000 | Test Loss : 0.101 \n", 542 | "Epoch: 63 : Alpha Weight : 1.60000 | Test Acc : 97.25000 | Test Loss : 0.111 \n", 543 | "Epoch: 64 : Alpha Weight : 1.62500 | Test Acc : 97.56000 | Test Loss : 0.100 \n", 544 | "Epoch: 65 : Alpha Weight : 1.65000 | Test Acc : 97.29000 | Test Loss : 0.102 \n", 545 | "Epoch: 66 : Alpha Weight : 1.67500 | Test Acc : 97.69000 | Test Loss : 0.097 \n", 546 | "Epoch: 67 : Alpha Weight : 1.70000 | Test Acc : 97.72000 | Test Loss : 0.097 \n", 547 | "Epoch: 68 : Alpha Weight : 1.72500 | Test Acc : 97.72000 | Test Loss : 0.097 \n", 548 | "Epoch: 69 : Alpha Weight : 1.75000 | Test Acc : 97.63000 | Test Loss : 0.094 \n", 549 | "Epoch: 70 : Alpha Weight : 1.77500 | Test Acc : 97.74000 | Test Loss : 0.092 \n", 550 | "Epoch: 71 : Alpha Weight : 1.80000 | Test Acc : 97.73000 | Test Loss : 0.096 \n", 551 | "Epoch: 72 : Alpha Weight : 1.82500 | Test Acc : 96.21000 | Test Loss : 0.138 \n", 552 | "Epoch: 73 : Alpha Weight : 1.85000 | Test Acc : 97.75000 | Test Loss : 0.090 \n", 553 | "Epoch: 74 : Alpha Weight : 1.87500 | Test Acc : 95.25000 | Test Loss : 0.157 \n", 554 | "Epoch: 75 : Alpha Weight : 1.90000 | Test Acc : 97.92000 | Test Loss : 0.089 \n", 555 | "Epoch: 76 : Alpha Weight : 1.92500 | Test Acc : 98.01000 | Test Loss : 0.088 \n", 556 | "Epoch: 77 : Alpha Weight : 1.95000 | Test Acc : 97.84000 | Test Loss : 0.089 \n", 557 | "Epoch: 78 : Alpha Weight : 1.97500 | Test Acc : 97.66000 | Test Loss : 0.096 \n", 558 | "Epoch: 79 : Alpha Weight : 2.00000 | Test Acc : 97.91000 | Test Loss : 0.088 \n", 559 | "Epoch: 80 : Alpha Weight : 2.02500 | Test Acc : 97.95000 | Test Loss : 0.082 \n", 560 | "Epoch: 81 : Alpha Weight : 2.05000 | Test Acc : 97.73000 | Test Loss : 0.090 \n", 561 | "Epoch: 82 : Alpha Weight : 2.07500 | Test Acc : 97.83000 | Test Loss : 0.091 \n", 562 | "Epoch: 83 : Alpha Weight : 2.10000 | Test Acc : 97.98000 | Test Loss : 0.083 \n", 563 | "Epoch: 84 : Alpha Weight : 2.12500 | Test Acc : 97.85000 | Test Loss : 0.088 \n", 564 | "Epoch: 85 : Alpha Weight : 2.15000 | Test Acc : 97.94000 | Test Loss : 0.085 \n", 565 | "Epoch: 86 : Alpha Weight : 2.17500 | Test Acc : 97.54000 | Test Loss : 0.102 \n", 566 | "Epoch: 87 : Alpha Weight : 2.20000 | Test Acc : 98.01000 | Test Loss : 0.083 \n", 567 | "Epoch: 88 : Alpha Weight : 2.22500 | Test Acc : 97.99000 | Test Loss : 0.087 \n", 568 | "Epoch: 89 : Alpha Weight : 2.25000 | Test Acc : 97.89000 | Test Loss : 0.089 \n", 569 | "Epoch: 90 : Alpha Weight : 2.27500 | Test Acc : 97.97000 | Test Loss : 0.083 \n", 570 | "Epoch: 91 : Alpha Weight : 2.30000 | Test Acc : 97.89000 | Test Loss : 0.091 \n", 571 | "Epoch: 92 : Alpha Weight : 2.32500 | Test Acc : 98.02000 | Test Loss : 0.085 \n", 572 | "Epoch: 93 : Alpha Weight : 2.35000 | Test Acc : 98.05000 | Test Loss : 0.082 \n", 573 | "Epoch: 94 : Alpha Weight : 2.37500 | Test Acc : 96.84000 | Test Loss : 0.133 \n", 574 | "Epoch: 95 : Alpha Weight : 2.40000 | Test Acc : 97.94000 | Test Loss : 0.087 \n", 575 | "Epoch: 96 : Alpha Weight : 2.42500 | Test Acc : 97.99000 | Test Loss : 0.086 \n", 576 | "Epoch: 97 : Alpha Weight : 2.45000 | Test Acc : 98.00000 | Test Loss : 0.083 \n", 577 | "Epoch: 98 : Alpha Weight : 2.47500 | Test Acc : 98.12000 | Test Loss : 0.080 \n", 578 | "Epoch: 99 : Alpha Weight : 2.50000 | Test Acc : 97.95000 | Test Loss : 0.091 \n", 579 | "Epoch: 100 : Alpha Weight : 2.52500 | Test Acc : 98.08000 | Test Loss : 0.083 \n", 580 | "Epoch: 101 : Alpha Weight : 2.55000 | Test Acc : 96.73000 | Test Loss : 0.132 \n", 581 | "Epoch: 102 : Alpha Weight : 2.57500 | Test Acc : 98.05000 | Test Loss : 0.080 \n", 582 | "Epoch: 103 : Alpha Weight : 2.60000 | Test Acc : 97.83000 | Test Loss : 0.087 \n", 583 | "Epoch: 104 : Alpha Weight : 2.62500 | Test Acc : 98.13000 | Test Loss : 0.082 \n", 584 | "Epoch: 105 : Alpha Weight : 2.65000 | Test Acc : 98.12000 | Test Loss : 0.078 \n", 585 | "Epoch: 106 : Alpha Weight : 2.67500 | Test Acc : 98.08000 | Test Loss : 0.082 \n", 586 | "Epoch: 107 : Alpha Weight : 2.70000 | Test Acc : 97.73000 | Test Loss : 0.094 \n", 587 | "Epoch: 108 : Alpha Weight : 2.72500 | Test Acc : 97.92000 | Test Loss : 0.086 \n", 588 | "Epoch: 109 : Alpha Weight : 2.75000 | Test Acc : 98.17000 | Test Loss : 0.082 \n", 589 | "Epoch: 110 : Alpha Weight : 2.77500 | Test Acc : 98.21000 | Test Loss : 0.081 \n", 590 | "Epoch: 111 : Alpha Weight : 2.80000 | Test Acc : 98.21000 | Test Loss : 0.075 \n", 591 | "Epoch: 112 : Alpha Weight : 2.82500 | Test Acc : 98.08000 | Test Loss : 0.081 \n", 592 | "Epoch: 113 : Alpha Weight : 2.85000 | Test Acc : 74.14000 | Test Loss : 0.813 \n", 593 | "Epoch: 114 : Alpha Weight : 2.87500 | Test Acc : 98.17000 | Test Loss : 0.078 \n", 594 | "Epoch: 115 : Alpha Weight : 2.90000 | Test Acc : 98.08000 | Test Loss : 0.075 \n", 595 | "Epoch: 116 : Alpha Weight : 2.92500 | Test Acc : 98.05000 | Test Loss : 0.083 \n", 596 | "Epoch: 117 : Alpha Weight : 2.95000 | Test Acc : 97.94000 | Test Loss : 0.088 \n", 597 | "Epoch: 118 : Alpha Weight : 2.97500 | Test Acc : 98.19000 | Test Loss : 0.075 \n", 598 | "Epoch: 119 : Alpha Weight : 3.00000 | Test Acc : 98.17000 | Test Loss : 0.079 \n", 599 | "Epoch: 120 : Alpha Weight : 3.00000 | Test Acc : 98.02000 | Test Loss : 0.080 \n", 600 | "Epoch: 121 : Alpha Weight : 3.00000 | Test Acc : 98.16000 | Test Loss : 0.077 \n", 601 | "Epoch: 122 : Alpha Weight : 3.00000 | Test Acc : 98.10000 | Test Loss : 0.080 \n", 602 | "Epoch: 123 : Alpha Weight : 3.00000 | Test Acc : 97.52000 | Test Loss : 0.112 \n", 603 | "Epoch: 124 : Alpha Weight : 3.00000 | Test Acc : 98.26000 | Test Loss : 0.077 \n", 604 | "Epoch: 125 : Alpha Weight : 3.00000 | Test Acc : 98.04000 | Test Loss : 0.083 \n", 605 | "Epoch: 126 : Alpha Weight : 3.00000 | Test Acc : 98.19000 | Test Loss : 0.081 \n", 606 | "Epoch: 127 : Alpha Weight : 3.00000 | Test Acc : 98.20000 | Test Loss : 0.074 \n", 607 | "Epoch: 128 : Alpha Weight : 3.00000 | Test Acc : 98.14000 | Test Loss : 0.076 \n", 608 | "Epoch: 129 : Alpha Weight : 3.00000 | Test Acc : 98.24000 | Test Loss : 0.075 \n", 609 | "Epoch: 130 : Alpha Weight : 3.00000 | Test Acc : 97.48000 | Test Loss : 0.110 \n", 610 | "Epoch: 131 : Alpha Weight : 3.00000 | Test Acc : 98.32000 | Test Loss : 0.075 \n", 611 | "Epoch: 132 : Alpha Weight : 3.00000 | Test Acc : 98.33000 | Test Loss : 0.074 \n", 612 | "Epoch: 133 : Alpha Weight : 3.00000 | Test Acc : 98.38000 | Test Loss : 0.075 \n", 613 | "Epoch: 134 : Alpha Weight : 3.00000 | Test Acc : 98.24000 | Test Loss : 0.080 \n", 614 | "Epoch: 135 : Alpha Weight : 3.00000 | Test Acc : 98.30000 | Test Loss : 0.074 \n", 615 | "Epoch: 136 : Alpha Weight : 3.00000 | Test Acc : 98.27000 | Test Loss : 0.079 \n", 616 | "Epoch: 137 : Alpha Weight : 3.00000 | Test Acc : 98.30000 | Test Loss : 0.078 \n", 617 | "Epoch: 138 : Alpha Weight : 3.00000 | Test Acc : 98.35000 | Test Loss : 0.073 \n", 618 | "Epoch: 139 : Alpha Weight : 3.00000 | Test Acc : 98.29000 | Test Loss : 0.078 \n", 619 | "Epoch: 140 : Alpha Weight : 3.00000 | Test Acc : 98.31000 | Test Loss : 0.078 \n", 620 | "Epoch: 141 : Alpha Weight : 3.00000 | Test Acc : 98.29000 | Test Loss : 0.075 \n", 621 | "Epoch: 142 : Alpha Weight : 3.00000 | Test Acc : 98.35000 | Test Loss : 0.075 \n", 622 | "Epoch: 143 : Alpha Weight : 3.00000 | Test Acc : 98.32000 | Test Loss : 0.075 \n", 623 | "Epoch: 144 : Alpha Weight : 3.00000 | Test Acc : 98.07000 | Test Loss : 0.083 \n", 624 | "Epoch: 145 : Alpha Weight : 3.00000 | Test Acc : 98.21000 | Test Loss : 0.080 \n", 625 | "Epoch: 146 : Alpha Weight : 3.00000 | Test Acc : 98.39000 | Test Loss : 0.074 \n", 626 | "Epoch: 147 : Alpha Weight : 3.00000 | Test Acc : 98.34000 | Test Loss : 0.072 \n", 627 | "Epoch: 148 : Alpha Weight : 3.00000 | Test Acc : 98.40000 | Test Loss : 0.072 \n", 628 | "Epoch: 149 : Alpha Weight : 3.00000 | Test Acc : 98.02000 | Test Loss : 0.077 \n", 629 | "\n" 630 | ] 631 | } 632 | ], 633 | "source": [ 634 | "semisup_train(net, train_loader, unlabeled_loader, test_loader)" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": 17, 640 | "metadata": { 641 | "Collapsed": "false" 642 | }, 643 | "outputs": [ 644 | { 645 | "name": "stdout", 646 | "output_type": "stream", 647 | "text": [ 648 | "Test Acc : 98.02000 | Test Loss : 0.077 \n" 649 | ] 650 | } 651 | ], 652 | "source": [ 653 | "test_acc, test_loss = evaluate(net, test_loader)\n", 654 | "print('Test Acc : {:.5f} | Test Loss : {:.3f} '.format(test_acc, test_loss))\n", 655 | "torch.save(net.state_dict(), 'saved_models/semi_supervised_weights')" 656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": { 661 | "Collapsed": "false" 662 | }, 663 | "source": [ 664 | "## Visualizations" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 15, 670 | "metadata": { 671 | "Collapsed": "false" 672 | }, 673 | "outputs": [], 674 | "source": [ 675 | "unlabel = np.concatenate([u.cpu().numpy() for u in unlabel])\n", 676 | "pseudo_label = np.concatenate([u.cpu().numpy() for u in pseudo_label])" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 17, 682 | "metadata": { 683 | "Collapsed": "false" 684 | }, 685 | "outputs": [], 686 | "source": [ 687 | "x = pd.read_csv('data/mnist_train.csv')\n", 688 | "y = x['label']\n", 689 | "x.drop(['label'], inplace = True, axis = 1)\n", 690 | "\n", 691 | "x = normalizer.transform(x.values)\n", 692 | "\n", 693 | "tsne_x = np.concatenate([x, x_train, unlabel])\n", 694 | "tsne_y = np.concatenate([y.values, y_train, pseudo_label])\n", 695 | "\n", 696 | "embeddings = TSNE(perplexity = 30, n_jobs=-1, verbose = 1, n_iter = 500).fit_transform(tsne_x)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 21, 702 | "metadata": { 703 | "Collapsed": "false" 704 | }, 705 | "outputs": [ 706 | { 707 | "name": "stdout", 708 | "output_type": "stream", 709 | "text": [ 710 | "Using matplotlib backend: GTK3Agg\n" 711 | ] 712 | }, 713 | { 714 | "data": { 715 | "application/vnd.jupyter.widget-view+json": { 716 | "model_id": "7bf931bf279b493388e547b56aae4b26", 717 | "version_major": 2, 718 | "version_minor": 0 719 | }, 720 | "text/plain": [ 721 | "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" 722 | ] 723 | }, 724 | "metadata": {}, 725 | "output_type": "display_data" 726 | }, 727 | { 728 | "name": "stdout", 729 | "output_type": "stream", 730 | "text": [ 731 | "\n" 732 | ] 733 | } 734 | ], 735 | "source": [ 736 | "from tqdm import tqdm_notebook\n", 737 | "%matplotlib\n", 738 | "plt.figure(figsize=(15,10))\n", 739 | "\n", 740 | "step_size = UNLABELED_BS * 3\n", 741 | "base_index = x.shape[0]\n", 742 | "epoch = 0\n", 743 | "for i in tqdm_notebook(range(0,unlabel.shape[0], step_size)):\n", 744 | " plt.scatter(embeddings[:base_index, 0], embeddings[:base_index, 1], c=tsne_y[:base_index], cmap=plt.cm.get_cmap(\"jet\", 10), marker='s', alpha = 0.002, s = 14**2)\n", 745 | " a = base_index\n", 746 | " b = base_index + num_train_samples\n", 747 | " plt.scatter(embeddings[a:b, 0], embeddings[a:b, 1], c=tsne_y[a:b], cmap=plt.cm.get_cmap(\"jet\", 10), marker='o', alpha = 0.3, s = 90**1)\n", 748 | " a = base_index + num_train_samples + i\n", 749 | " b = base_index + num_train_samples + i + step_size\n", 750 | " plt.scatter(embeddings[a:b, 0], embeddings[a:b, 1], c=tsne_y[a:b], cmap=plt.cm.get_cmap(\"jet\", 10), marker='*', s = 150**1)\n", 751 | " plt.colorbar(ticks=range(10))\n", 752 | " plt.clim(-0.5, 9.5)\n", 753 | " plt.title('Epoch : ' + str(epoch) +' Test Acc : {:.2f}%'.format(test_acc_log[epoch]*100), fontsize = 20)\n", 754 | " plt.savefig('imgs/tsne' + str(i) + '.png')\n", 755 | " plt.draw()\n", 756 | " plt.pause(5)\n", 757 | " plt.clf()\n", 758 | " epoch += 10\n" 759 | ] 760 | } 761 | ], 762 | "metadata": { 763 | "kernelspec": { 764 | "display_name": "Python 3", 765 | "language": "python", 766 | "name": "python3" 767 | }, 768 | "language_info": { 769 | "codemirror_mode": { 770 | "name": "ipython", 771 | "version": 3 772 | }, 773 | "file_extension": ".py", 774 | "mimetype": "text/x-python", 775 | "name": "python", 776 | "nbconvert_exporter": "python", 777 | "pygments_lexer": "ipython3", 778 | "version": "3.7.3" 779 | } 780 | }, 781 | "nbformat": 4, 782 | "nbformat_minor": 4 783 | } 784 | -------------------------------------------------------------------------------- /pseudo_label-Logistic_reg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "Collapsed": "false" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd \n", 12 | "import numpy as np\n", 13 | "\n", 14 | "num_train_samples = 1000\n", 15 | "samples_per_class = int(num_train_samples/10)\n", 16 | "\n", 17 | "x = pd.read_csv('data/mnist_train.csv').sample(frac = 1)\n", 18 | "y = x['label']\n", 19 | "x.drop(['label'], inplace = True, axis = 1)\n", 20 | "\n", 21 | "x_test = pd.read_csv('data/mnist_test.csv')\n", 22 | "y_test = x_test['label']\n", 23 | "x_test.drop(['label'], inplace = True, axis = 1)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": { 30 | "Collapsed": "false" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "x_train, x_unlabeled = x[y.values == 0].values[:samples_per_class], x[y.values == 0].values[samples_per_class: ]\n", 35 | "y_train, y_unlabeled = y[y.values == 0].values[:samples_per_class], y[y.values == 0].values[samples_per_class: ]\n", 36 | "\n", 37 | "\n", 38 | "for i in range(1,10):\n", 39 | " x_train = np.concatenate([x_train, x[y.values == i].values[:samples_per_class]], axis = 0)\n", 40 | " y_train = np.concatenate([y_train, y[y.values == i].values[:samples_per_class]], axis = 0)\n", 41 | " \n", 42 | " x_unlabeled = np.concatenate([x_unlabeled, x[y.values == i].values[samples_per_class: ]], axis = 0)\n", 43 | " y_unlabeled = np.concatenate([y_unlabeled, y[y.values == i].values[samples_per_class: ]], axis = 0)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "Collapsed": "false" 50 | }, 51 | "source": [ 52 | "Shuffle the data" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": { 59 | "Collapsed": "false" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "p = np.random.permutation(x_train.shape[0])\n", 64 | "x_train, y_train = x_train[p], y_train[p]\n", 65 | "\n", 66 | "p = np.random.permutation(x_unlabeled.shape[0])\n", 67 | "x_unlabeled, y_unlabeled = x_unlabeled[p], y_unlabeled[p]" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "Collapsed": "false" 74 | }, 75 | "source": [ 76 | "## Feature Engineering" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": { 83 | "Collapsed": "false" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "from sklearn.preprocessing import Normalizer\n", 88 | "scaler = Normalizer()\n", 89 | "x_train = scaler.fit_transform(x_train)\n", 90 | "x_test = scaler.transform(x_test)\n", 91 | "x_unlabeled = scaler.transform(x_unlabeled)\n", 92 | "\n" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": { 99 | "Collapsed": "false" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "from sklearn.decomposition import PCA\n", 104 | "\n", 105 | "pca = PCA(n_components = 50)\n", 106 | "x_train_pca = pca.fit_transform(x_train)\n", 107 | "x_test_pca = pca.transform(x_test)\n", 108 | "x_unlabeled_pca = pca.transform(x_unlabeled)\n", 109 | "\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": { 116 | "Collapsed": "false" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "from sklearn.preprocessing import PolynomialFeatures\n", 121 | "\n", 122 | "poly = PolynomialFeatures()\n", 123 | "x_train_poly = poly.fit_transform(x_train_pca)\n", 124 | "x_test_poly = poly.transform(x_test_pca)\n", 125 | "x_unlabeled_poly = poly.transform(x_unlabeled_pca)\n", 126 | "\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "Collapsed": "false" 133 | }, 134 | "source": [ 135 | "## Effect of Increasing Data" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 29, 141 | "metadata": { 142 | "Collapsed": "false" 143 | }, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3deXwddb3/8dcnJ1vbpGvSfaUrpaUU0gKCskNVuK0I2iKyiCJeweXihj9ELnpVrl64XuGqWAEFAZGrbUWkogRQBLrQFtqSdIW26ZJ0SZM2bbbz+f0xk3JIT5LTNCcny/v5eJzHmX0+M5nM58z3O/Mdc3dEREQaS0t1ACIi0jEpQYiISFxKECIiEpcShIiIxKUEISIicSlBiIhIXEoQcoSZfdfMdpvZzrD/I2a21cwOmNn0FMbVIeJIBTO708weTXUcqWRm3zSz+amOoztSguhGzOxtMzsUnmgbPveF40YCtwKT3X1wOMuPgJvdPcfdVxzHet3Mxh1H6M3G0QbL75TM7Fwzi8b8LbeZ2ZNmNuMYltEuCeh41uPu33P3T7d1TNIyJYju57LwRNvwuTkcPhLY4+6lMdOOAta0f4hHafc4zCy9Pdd3HLa7ew6QC5wBFAF/N7MLUhuWdAVKEIKZXQg8BwwNf4k+bmYHgAiwysw2htMNNbP/M7MyM9tsZl+IWUYkLArYaGaVZrbczEaY2UvhJKvCZX88zvrTzOx2M3vHzErN7Ndm1sfMsuLF0cK23Bn+iv51GMcaMyuIGT/CzH4fbsOemCuo68zsZTO718z2AHeGwz9lZm+Z2T4zW2xmo2KW9eOw6Ksi3N73x4ybaWbLwnG7zOyemHFnmNk/zazczFaZ2bkx48aY2Yth7M8BeS1tM4AHtrn7HcB84O6W4jSzWcA3gY+Hf5tV4fDrw22uNLNNZvbZmGXlmdnTYex7zezvZpYWjot7fDS1njh/u6+bWUm43uKGJBd79WFm9zW6Aq4zszubW78cB3fXp5t8gLeBC5sYdy6wrdEwB8aF3WnAcuAOIBM4AdgEXBKO/yrwJjARMGAaMKDxcppY96eADeEyc4DfA4/Ei6OJ+WPjvBM4DHyIILF8H3g1HBcBVgH3Ar2AbODscNx1QB1wC5AO9ABmh3GdGA67HfhnzHqvBgaE424FdgLZ4bhXgE+G3TnAGWH3MGBPGF8acFHYnx8z3z1AFvABoBJ4NNG/WTj8fCAK9EogzjsbLx/4MDA2/DueA1QBp4bjvg/8DMgIP+8Pp2vp+DhqPY3WORHYCgwN+0cDY5ubFzgFKAOmt7R+fVp5zkh1APq04x87SBAHgPKYz2fCcUedbHjvifd0YEuj8bcBD4XdxcDsJtbb0gn+b8C/xvRPBGqB9ATnb5wg/hozbjJwKOw+MzyhpMdZxnVxtu/PwA0x/WnhyXJUE3HsA6aF3S8B/w7kNZrm68Qkv3DYYuBagmK+OsITezjusaZOrPH+ZuHwSeE+GZZAnM2euMNpFgBfDLvvAhY2/nskcHw0ux5gHFAKXAhkNBp31LxAfng8z01k/fq07qMipu5njrv3jfn8IsH5RhEUQZU3fAiKDQaF40cALRYBNWEo8E5M/zsEv3YHxZ+8RTtjuquA7LBOYQTwjrvXNTHf1kb9o4Afx2zvXoJfy8MAzOwrYVHM/nB8H94tEroBmAAUmdlSM7s0ZplXNtqPZwNDCPbDPnc/GBND7H5J1DCCBFGeQJxHMbMPmtmrYRFSOcHVTsP0PyS4qvpLWPz0jZjtau74aJa7bwC+RJAMSs3sCTMb2kR8GcBTwGPu/kRbrF/i6ywVcZJ6W4HN7j6+mfFjgdWtWPZ2gn/wBg2/pHe1YlnN2QqMNLP0JpJE46aNtwL/4e6/aTxhWI7/NeACYI27R81sH0ECwd3XA/PC8vnLgafMbEC4zEfc/TNxljkK6GdmvWKSxMg4cbXkI8Dr7n6wpTgbL9vMsoD/A64BFrp7rZktiNmuSoJiqlvNbArwvJktpeXjo8VtcPfHgMfMrDfwc4J6lE/GmfQnQAVBkV+DltYvraArCEnUEqAyrEjsYUGl9BR795bK+cB3zGy8BU4OT4gQnOhPaGbZjwNfDitoc4DvAb9t5pf+8WzDDuAHZtbLzLLN7Kxmpv8ZcJuZnQRgQcX5leG4XIIkVgakm9kdQO+GGc3sajPLd/co4S95gnqBR4HLzOyScB9mW3C76nB3fwdYBvy7mWWa2dnAZYlsWLjPh5nZt4FPE/x6bjFOgr/N6IaKZoLy+6xw+joz+yBwccx6LjWzcWZmwH6gPtyulo6PxutpHP9EMzs/TFCHgUPhchtP91mCepFPhPu2QUvrl1ZQguh+/tjoLpA/JDKTu9cDlxJUDG4GdhMkhT7hJPcATwJ/Ifh190uCil4Iig1+FV76fyzO4h8EHiEot99McIK4pRXblsg2XEZQ3r0F2AYcdVdVzPR/IPgV+4SZVRBcHX0wHL0YeBZYR1AMdJj3FlHNAtZYcBfWjwnKyg+5+1aCyu9vEpyEtxJU8Df8L15FUJ6+F/g28OsWNmtouI4DwFJgKnCuu/8lwTh/F37vMbPXwyuELxD8LfeF8SyKmX488Ndwfa8A/+vuhQkcH+9ZT5ztyAJ+EM63ExhIUIfQ2DyCHxvbY47hbyawfmkFCytzRERE3kNXECIiEpcShIiIxKUEISIicSlBiIhIXF3mOYi8vDwfPXp0qsMQEelUli9fvtvd8+ON6zIJYvTo0SxbtizVYYiIdCpm1uTT+ipiEhGRuJQgREQkLiUIERGJSwlCRETiUoIQEZG4lCBERCQuJQgREYmryzwHISLSVdVHnYpDtew/VEt5+N3wqThUS7+emVx1+sg2X68ShIhIO3N3NpQeoGhn5XtO9OVV7z35NwyvrG7+3VnTR/ZVghAR6ayqaur454Y9vLCulMKiMkrKD71nfFZ6Gn16ZBz5DOmTzaTBufQO+/v2zHjP+IZP7x4ZZGdEkhKzEoSISBK4O5t3H6SwuIwXikt5bdNeauqj9MyMcNa4PP71vLGcNqof/Xpm0ieJJ/njoQQhItJGDtfW88qmPbxYXEZhcSnv7KkCYGx+L645cxTnThzIjDH9yErveMkgHiUIEZHjsGVPVVhsVMo/N+6hui5KdkYa7xubx6fPHsO5Ewcyon/PVIfZKklNEGY2i+CF7RFgvrv/oNH4UQQvrM8neEn71e6+LRx3LXB7OOl33f1XyYxVRDqGaNQpKT9E8c5KindVUrSzkm37quiZGaF3dga52en0zg7K3ntnp5Mb0907LJPPzU4nJzOdtDQ7rljqo05tfZTquig1dVFq64PvrfuqeCG8SthUdhCAUQN6Mm/mSM6dmM8ZJwzokEVGxyppCcLMIsD9wEXANmCpmS1y97Uxk/0I+LW7/8rMzge+D3zSzPoD3wYKAAeWh/PuS1a8ItL+9h6soWhnBetiksG6nZUcrKk/Ms2wvj0YndeTw7VRSisOUHG4lsrDdVTFTBOPGeRmpYcJI0ggPTMj1EX9qBN+w3dNzHdtvVMf9SaXn5mexhknDODq00dx3qSBjMnr1Wb7paNI5hXETGCDu28CMLMngNlAbIKYDPxb2F0ILAi7LwGec/e94bzPAbOAx5MYr4gkyaGaetaXBgmgeGcl68JkUFZZfWSafj0zmDg4lysLRjBhUC4TB+cyYVAOudkZcZdZWx+l8nAdFYdqjySNhu6KQ3Vxh5UdqCYjkkZmJI3c7HQyI2lBf3rwyYikkZWeRkbEjvRnpgfTN3xnRNIYkJPJzDH96ZnZtUvpk7l1w4CtMf3bgNMbTbMKuJygGOojQK6ZDWhi3mGNV2BmNwI3Aowc2fb3AIvIsTlcW8+msoOsL61kQ+kB1u0KEsI7e6vw8Md4dkYa4wfmcs6EfCYNDhLBxEG55OdmYZZ4kVBGJI3+vTLp3yszSVsjqU5/XwHuM7PrgJeAEqD568YY7v4A8ABAQUFB09eCItKmDtXUs7EsSADrSw+wftcBNpRWsmVvFQ2lMpE0Y9SAnkwe2ps504eFyaA3I/v3JHKcdQPSPpKZIEqAETH9w8NhR7j7doIrCMwsB/iou5ebWQlwbqN5X0hirCISx4HqOjaUHmD9ruCKYH3pAdaXVrJt36EjVwQZEWNMXi9OGtqH2acMY/ygHMYPzGVMXi8y09XcW2eWzASxFBhvZmMIEsNc4KrYCcwsD9jr7lHgNoI7mgAWA98zs35h/8XheBFJUE1dlKqaOg7W1FNV/e53VU09B2vC79j+6nqqauqpqgnK79/eXfWep30zI2mckN+LU0b048rTRjB+YA7jB+UwakAvMiJKBF1R0hKEu9eZ2c0EJ/sI8KC7rzGzu4Bl7r6I4Crh+2bmBEVMnw/n3Wtm3yFIMgB3NVRYi8i7yiqrWbN9P2u2V7C6ZD9v7ahgX1UtVTV11NYnXuqamZ5Gr8wIPTPT6ZUVfM8Y3Y+rBo1k3MAcxg/MYWT/nqQrEXQr5t41iu4LCgp82bJlqQ5DJCncnZ0Vh1ldEiSCNdv3s7qkgp0Vh49MMzos7x+Ym03PzAi9stLpkRE5csI/8p2ZTs+sCL0y0+mRGaFnZkRXAN2YmS1394J441JdSS0ijbg7W/ceYvX2/awu2c/q7RWsKdnPnoM1AKQZjM3P4cyxAzhpaG+mDOvD5KG96d3E7aAiraUEIXIc9h6s4e09B4lGg4eqog5Rb+j2sDsYFg3H1x/pfne6uqjz9u6DwRXC9v1UHg6ad05PMyYMyuWCEwcyZVgfThramxOH9O7y999Lx6CjTKSV3tpRwdwHXmX/odo2WV5mehonDs7lsmlDmTK0D1OH9WH8oJwu0WSDdE5KECKtsKH0AFfPf40eGRF+eMXJ9MiMkGYWfoJnANLSgv6IGRYOi6QF49OsobthOsjLyVJdgHQoShAix2jLnio+Mf9VzOA3nzmdsfk5qQ5JJCmUIESOwfbyQ1w1/1Wq66I8ceMZSg7Spel6ViRBpZWH+cT819hfVcuvPzWTSYN7pzokkaTSFYRIAvYdrOGT85ewc/9hHrlhJicP75vqkESSTglCpAUVh2u55sElbN5zkIeum0HB6P6pDkmkXaiISaQZB6vruP6hpRTtrOBnV5/KWePyUh2SSLtRghBpwuHaej7z62Ws2LKPH8+dzvmTBqU6JJF2pSImkThq6qJ87tHlvLJpD/d8bBofmjok1SGJtDtdQYg0Ulcf5YtPrKCwuIz/mDOVj0wfnuqQRFJCCUIkRjTqfPWpN/jz6p1869LJXHW6XmUr3ZcShEjI3fl/C1bzhxUlfOXiCdxw9phUhySSUkoQIgTJ4a6n1/L4ki3867ljufn88akOSSTllCBEgB/9pZiHXn6b688azVcvmZjqcEQ6BCUI6fbuL9zA/YUbmTdzBHdcOhkzS3VIIh2CEoR0a7/8x2Z+uLiYOacM5btzpio5iMRQgpBu6/ElW/jO02uZddJgfnTlNCJpSg4isfSgnHQ70ajz2JItfGvhas6bmM//zJtOul7UI3IUJQjpVl7fso+7/riWlVvLOXtcHj+9+jQy05UcROJRgpBuYXv5Ie5+toiFK7eTn5vFD684mY+eOpw0FSuJNEkJQrq0qpo6fv7iJn7+0kaiDjefN47PnTuWXlk69EVaov8S6ZKiUWfhqhLu/nMxOysOc+nJQ/j6rEmM6N8z1aGJdBpKENLlxNYzTB3Wh59cNZ0ZesmPyDFTgpAuI7aeYWBuFj+6chqXTx+megaRVlKCkE6vqqaOn724iQde2og73HL+OG46R/UMIsdL/0HSaUWjzoKVJdz9bBG7Kqq5bNpQvj5rIsP7qZ5BpC0oQUintPydfdz19FpWbS1n2vA+3H/VqRSonkGkTSlBSLt4+OXNPPraFjIiaWSlp5GZHnxnpUfIyojpTk8L+8Pu9DSyMiJkRYLhmZE0/rx6J4tWbWdQ7yzu+dg05pyiegaRZFCCkKR79NV3uPOPa5k+si/5OVlU10WprqvnQHUdew7UUF1XHw6LUl37bndTstLT+ML54/is6hlEkkr/XZJUC1eW8K2Fq7nwxIH89OrTyEiwzSN3p7be4yaPgblZDMjJSnLkIqIEIUnzfNEubn1yFTNH9+e+q05NODkAmBmZ6UZmehq5SYxRRJqmVsokKV7dtIfPPfo6k4f2Zv61BWRnRFIdkogcIyUIaXNvbCvn079axsj+PXn4+pnkZmekOiQRaQUlCGlT63dVcu2DS+jbM4NHbjid/r0yUx2SiLRSUhOEmc0ys2Iz22Bm34gzfqSZFZrZCjN7w8w+FA4fbWaHzGxl+PlZMuOUtrF1bxWf/OUS0iNpPHrD6Qzuk53qkETkOCStktrMIsD9wEXANmCpmS1y97Uxk90OPOnuPzWzycAzwOhw3EZ3PyVZ8UnbKq08zNW/fI1DtfX89rNnMDqvV6pDEpHjlMwriJnABnff5O41wBPA7EbTONA77O4DbE9iPJIk+6tqueaXSyirrOah62cwaXDvlmcSkQ4vmQliGLA1pn9bOCzWncDVZraN4OrhlphxY8KipxfN7P3xVmBmN5rZMjNbVlZW1oahS6IOVtdx3cNL2FR2kAc+WcCpI/ulOiQRaSOprqSeBzzs7sOBDwGPmFkasAMY6e7TgX8DHjOzo36WuvsD7l7g7gX5+fntGrhAdV09Nz26nFVby/mfedM5e3xeqkMSkTaUzARRAoyI6R8eDot1A/AkgLu/AmQDee5e7e57wuHLgY3AhCTGKseorj7KFx9fyd/X7+Y/r5jGrCmDUx2SiLSxZCaIpcB4MxtjZpnAXGBRo2m2ABcAmNmJBAmizMzyw0puzOwEYDywKYmxyjGIRp1v/P5Nnl2zkzsuncwVpw1PdUgikgRJu4vJ3evM7GZgMRABHnT3NWZ2F7DM3RcBtwK/MLMvE1RYX+fubmYfAO4ys1ogCtzk7nuTFaskzt357p/e4qnl2/jSheP51NljUh2SiCSJuXuqY2gTBQUFvmzZslSH0eX9+K/rufev67j+rNHccelkzNTMtkhnZmbL3b0g3rhUV1JLJ/LgPzZz71/XccVpw/nWh5UcRLo6JQhJyFPLt3HX02uZddJgfnD5VL2gR6QbUIKQFj27eidfe2oVZ4/L48fzTiH9GJrtFpHOS++DkCbV1Uf53fJtfHvhGqaN6MvPP3kaWelqtluku1CCkKO4O4vX7OSHi4vZWHaQGaP7Mf+aGXq9p0g3o/94eY9XNu7h7meLWLm1nLH5vfjZ1adxyUmDVCEt0g0pQQgAa7bv5z+fLebFdWUM7p3N3R+dykdPHa76BpFuTAmim9uyp4r/eq6YhSu306dHBrd9cBLXvm+0XhEqIkoQ3VVZZTX3Pb+ex5ZsIZJmfO7csdx0zlj69NDrQUUkoATRzVQeruUXf9/M/L9vorouyscKRvClC8czqLfe/iYi76UE0U1U19Xz6KtbuL9wA3sP1vDhqUP4t4snMDY/J9WhiUgHpQTRxdVHnQUrSrjnuXWUlB/irHED+Nolk5g2om+qQxORDk4Jootyd54vKuU/ny2meFclU4b15gcfncr7x+vFSiKSGCWILmj3gWpu+/2bPLd2F6MH9OQn86bz4alD1H6SiBwTJYgu5tnVO/jmH1ZzoLqOb35oEtefNYYMPcsgIq2gBNFF7D9Uy52L1vCHFSVMGdabez52ChMG5aY6LBHpxJQguoCX1pXxtafeoOxANV+8YDw3nz9OVw0ictyUIDqxqpo6vvfMWzz66hbGDczhgWtO4+ThujtJRNqGEkQnteztvdz6u1Vs2VvFp88ew1cumajmMUSkTbWYIMzsFuBRd9/XDvFIC6rr6rn3ufU88NJGhvbtweOfOYMzThiQ6rBEpAtK5ApiELDUzF4HHgQWu7snNyyJZ832/fzbb1dRvKuSeTNH8P8+PJkcvaNBRJKkxZpMd78dGA/8ErgOWG9m3zOzsUmOTUJ19VHue349s+97mX1VNTx03Qy+f/nJSg4iklQJnWHc3c1sJ7ATqAP6AU+Z2XPu/rVkBtjdbSw7wK1PrmLl1nIuPXkI35k9hX69MlMdloh0A4nUQXwRuAbYDcwHvurutWaWBqwHlCCSIBp1fvXK29z9bBHZGRF+Mm86l00bmuqwRKQbSeQKoj9wubu/EzvQ3aNmdmlywuretu2r4qu/e4NXNu3hvIn53P3Rkxmo5rhFpJ0lkiD+DOxt6DGz3sCJ7v6au7+VtMi6qY1lB5hz38tE3fnB5VP5+IwReh+0iKREIgnip8CpMf0H4gyTNvLEki0crqvnuS+fw+i8XqkOR0S6sUTaY7DY21rdPYoesEuK+qizcOV2zpkwUMlBRFIukQSxycy+YGYZ4eeLwKZkB9YdvbppD6WV1cyZrspoEUm9RBLETcD7gBJgG3A6cGMyg+quFqwoIScrnQtPHJTqUEREWi4qcvdSYG47xNKtHa6t59nVO5k1ZbDaVBKRDiGR5yCygRuAk4Aj91q6+6eSGFe383xRKZXVdcw5ZViqQxERARIrYnoEGAxcArwIDAcqkxlUd7RgRQkDc7M4c6wa3hORjiGRBDHO3b8FHHT3XwEfJqiHkDZSXlVDYXEpl00bSkTvjRaRDiKRBFEbfpeb2RSgDzAweSF1P8+8uZPaelfxkoh0KIk8z/CAmfUDbgcWATnAt5IaVTezYGUJY/N7MWVY71SHIiJyRLMJImyQryJ8WdBLwAntElU3UlJ+iCWb93LrRRPUpIaIdCjNFjGFT023urVWM5tlZsVmtsHMvhFn/EgzKzSzFWb2hpl9KGbcbeF8xWZ2SWtj6OgWrdwOwGwVL4lIB5NIHcRfzewrZjbCzPo3fFqaycwiwP3AB4HJwDwzm9xostuBJ919OsGzFv8bzjs57D8JmAX8b7i8LmfhyhJOHdmXkQN6pjoUEZH3SCRBfBz4PEER0/LwsyyB+WYCG9x9k7vXAE8AsxtN40BDwXsfYHvYPRt4wt2r3X0zsCFcXpfy1o4KinZWMme6rh5EpONJ5EnqMa1c9jBga0x/QzMdse4E/mJmtwC9gAtj5n210bxHnUXN7EbCZj9GjhzZyjBTZ8HKEtLTjA9PHZLqUEREjpLIk9TXxBvu7r9ug/XPAx529/8yszOBR8JbaRPi7g8ADwAUFBR4C5N3KNGo88eV2/nAhHwG5GSlOhwRkaMkcpvrjJjubOAC4HWgpQRRAoyI6R8eDot1A0EdA+7+StisR16C83ZqS97ey/b9h/n6ByelOhQRkbgSKWK6JbbfzPoS1Ce0ZCkw3szGEJzc5wJXNZpmC0HCedjMTiRIQGUEz1s8Zmb3AEOB8cCSBNbZaSxcWULPzAgXTVbLrSLSMbXmxT8HgRbrJdy9zsxuBhYDEeBBd19jZncBy9x9EXAr8Asz+zJBhfV14cuJ1pjZk8BaoA74vLvXtyLWDqm6rp4/vbGDS04aTM9MvXtJRDqmROog/khw8obgrqfJwJOJLNzdnwGeaTTsjpjutcBZTcz7H8B/JLKezqawqIyKw3XMPkUvBhKRjiuRn68/iumuA95x921JiqdbWLiyhLycTM4el5fqUEREmpRIgtgC7HD3wwBm1sPMRrv720mNrIuqOFzL34pKuWrmSNIjiTyGIiKSGomcoX4HRGP668Nh0grPvrmTmrqoHo4TkQ4vkQSRHj4JDUDYnZm8kLq2BStLGD2gJ9OG90l1KCIizUokQZSZ2b809JjZbGB38kLqunbuP8wrm/Yw+5RharlVRDq8ROogbgJ+Y2b3hf3bgLhPV0vzFq0qwR0VL4lIp5DIg3IbgTPMLCfsP5D0qLqoBSu2M21EX8bk9Up1KCIiLWqxiMnMvmdmfd39gLsfMLN+Zvbd9giuK1m/q5K1OyqYo2cfRKSTSKQO4oPuXt7QE75d7kPNTC9xLFhZQiTNuPRkJQgR6RwSSRARMzvS3KiZ9QDU/OgxiEadhSu3c9a4PPJztetEpHNIpJL6N8DfzOwhwIDrgF8lM6iuZvmWfWzbd4hbL56Q6lBERBKWSCX13Wa2iuBlPk7Q+N6oZAfWlSxYUUKPjAgXTx6c6lBERBKWaFsPuwiSw5XA+cBbSYuoi6mpi/KnN3dw0eRB9MpSy60i0nk0ecYyswkEb3ybR/Bg3G8Bc/fz2im2LuGldWWUV9UyZ7oqp0Wkc2nuJ20R8HfgUnffABC+t0GOwYKVJfTvlcn7x+enOhQRkWPSXBHT5cAOoNDMfmFmFxBUUkuCKg/X8tzaXXx46hAy1HKriHQyTZ613H2Bu88FJgGFwJeAgWb2UzO7uL0C7MwWr9lFtVpuFZFOqsWfte5+0N0fc/fLgOHACuDrSY+sC1i4soSR/Xty6si+qQ5FROSYHVO5h7vvc/cH3P2CZAXUVZRWHublDbuZfcpQtdwqIp2SCsaT5I+rdhB1mH2KipdEpHNSgkiShStLmDKsN+MG5qQ6FBGRVlGCSIKNZQd4Y9t+5ujqQUQ6MSWIJFi4ooQ0g3+ZpofjRKTzUoJoY+7OgpXbed/YPAb2zk51OCIiraYE0cZWbC1ny94qZuvFQCLSySlBtLGFK0rISk9j1hS13CoinZsSRBuqrY/y9Bs7uPDEQeRmZ6Q6HBGR46IE0Yb+sX43ew7WqGkNEekSlCDa0IKVJfTtmcE5E9Ryq4h0fkoQbaSqpo6/rNnFh6YOITNdu1VEOj+dydrIS+t2c6i2nktPHpLqUERE2oQSRBt5obiU3Ox0Zozun+pQRETahBJEG3B3CotL+cD4fL0YSES6DJ3N2sBbOyrZVVHNuRNVOS0iXYcSRBsoLC4F4BwlCBHpQpQg2kBhUSlTh/VhYK7aXhKRrkMJ4jiVV9Xw+pZ9nKerBxHpYpKaIMxslpkVm9kGM/tGnPH3mtnK8LPOzMpjxtXHjFuUzDiPx0vrdxN1OG/SwFSHIiLSptKTtWAziwD3AxcB24ClZrbI3dc2TOPuX46Z/hZgeswiDrn7KcmKr628UFRK/16ZnDy8b6pDERFpU8m8gpgJbHD3Te5eAzwBzG5m+nnA40mMp81Fo84L68o4Z0I+kTRLdTgiIm0qmQliGLA1pn9bOOwoZjYKGAM8HzM427ZisekAAA12SURBVMyWmdmrZjanifluDKdZVlZW1lZxJ2zVtnL2HqzR7a0i0iV1lErqucBT7l4fM2yUuxcAVwH/bWZjG8/k7g+4e4G7F+Tnt/9JurC4jDRDjfOJSJeUzARRAoyI6R8eDotnLo2Kl9y9JPzeBLzAe+snOoQXiks5dWQ/+vbMTHUoIiJtLpkJYikw3szGmFkmQRI46m4kM5sE9ANeiRnWz8yywu484CxgbeN5U6msspo3tu3X3Usi0mUl7S4md68zs5uBxUAEeNDd15jZXcAyd29IFnOBJ9zdY2Y/Efi5mUUJktgPYu9+6gheXBfUeaj+QUS6qqQlCAB3fwZ4ptGwOxr13xlnvn8CU5MZ2/EqLCplYG4Wk4f0TnUoIiJJ0VEqqTuV2vooL60v47yJAzHT7a0i0jUpQbTC6+/so/JwHedNUvGSiHRdShCtUFhcRkbEOGtcXqpDERFJGiWIVnihuJQZo/uTm52R6lBERJJGCeIYlZQfomhnJedN1O2tItK1KUEcoxfClwOp/kFEujoliGNUWFTG8H49GJufk+pQRESSSgniGFTX1fPyht2cP0m3t4pI16cEcQyWbN7Lodp61T+ISLegBHEMni8qJSs9jTNOGJDqUEREkk4J4hi8UFzGmWMH0CMzkupQRESSTgkiQZt3H2Tz7oMqXhKRbkMJIkFHbm9VghCRbkIJIkGFxWWMze/FyAE9Ux2KiEi7UIJIQFVNHa9u2qOrBxHpVpQgEvDPDXuoqYvq7XEi0q0oQSSgsLiUXpkRCkb3S3UoIiLtRgmiBe7OC8VlnDUuj6x03d4qIt2HEkQL1pceoKT8EOereElEuhkliBY8XxTc3nquKqhFpJtRgmhBYVEpJw7pzeA+2akORUSkXSlBNKPicC3L3tnHeRP17gcR6X6UIJrxj/W7qY+6bm8VkW5JCaIZhUWl9OmRwfQRfVMdiohIu1OCaEI06rywrowPTMgnPaLdJCLdj858TVizvYKyymrVP4hIt6UE0YTC4lLM4AMTlCBEpHtSgmhCYXEpJw/vS15OVqpDERFJCSWIOPYerGHl1nLO18NxItKNKUHE8dK6MtzhvEkqXhKR7ksJIo7ni0rJy8lkytA+qQ5FRCRllCAaqY86L64r45wJA0lLs1SHIyKSMkoQjazcuo/9h2pVvCQi3Z4SRCOFRWVE0oz3j1eCEJHuTQmikcLiUk4b1Y8+PTJSHYqISEopQcTYVXGYNdsrOE+3t4qIKEHEeqE4eDmQ6h9ERJKcIMxslpkVm9kGM/tGnPH3mtnK8LPOzMpjxl1rZuvDz7XJjLNBYVEZQ/pkM3FQbnusTkSkQ0tP1oLNLALcD1wEbAOWmtkid1/bMI27fzlm+luA6WF3f+DbQAHgwPJw3n3JiremLso/NuzmsmlDMdPtrSIiybyCmAlscPdN7l4DPAHMbmb6ecDjYfclwHPuvjdMCs8Bs5IYK8ve2cuB6jrO18uBRESA5CaIYcDWmP5t4bCjmNkoYAzw/LHMa2Y3mtkyM1tWVlZ2XMEWFpWSGUnjfWMHHNdyRES6io5SST0XeMrd649lJnd/wN0L3L0gP//4KpYLi8s4/YT+9MpKWqmbiEinkswEUQKMiOkfHg6LZy7vFi8d67zHbeveKjaUHuBc3d4qInJEMhPEUmC8mY0xs0yCJLCo8URmNgnoB7wSM3gxcLGZ9TOzfsDF4bCkOHJ7q94eJyJyRNLKU9y9zsxuJjixR4AH3X2Nmd0FLHP3hmQxF3jC3T1m3r1m9h2CJANwl7vvTVashcVljB7QkxPyc5K1ChGRTiepBe7u/gzwTKNhdzTqv7OJeR8EHkxacKHDtfX8c+Nu5s4YmexViYh0Kh2lkjplKg7VcvHkwVxy0uBUhyIi0qF0+1t2BvbO5n/mTU91GCIiHU63v4IQEZH4lCBERCQuJQgREYlLCUJEROJSghARkbiUIEREJC4lCBERiUsJQkRE4rKYJpA6NTMrA95JdRxtJA/YneogOjDtn+Zp/zRN++Zoo9w9bkulXSZBdCVmtszdC1IdR0el/dM87Z+mad8cGxUxiYhIXEoQIiISlxJEx/RAqgPo4LR/mqf90zTtm2OgOggREYlLVxAiIhKXEoSIiMSlBJECZjbCzArNbK2ZrTGzL4bD+5vZc2a2PvzuFw43M/sfM9tgZm+Y2amp3YLkM7OIma0ws6fD/jFm9lq4D35rZpnh8Kywf0M4fnQq424PZtbXzJ4ysyIze8vMztSx8y4z+3L4f7XazB43s2wdP62jBJEadcCt7j4ZOAP4vJlNBr4B/M3dxwN/C/sBPgiMDz83Aj9t/5Db3ReBt2L67wbudfdxwD7ghnD4DcC+cPi94XRd3Y+BZ919EjCNYD/p2AHMbBjwBaDA3acAEWAuOn5ax931SfEHWAhcBBQDQ8JhQ4DisPvnwLyY6Y9M1xU/wHCCk9z5wNOAETz9mh6OPxNYHHYvBs4Mu9PD6SzV25DEfdMH2Nx4G3XsHNm+YcBWoH94PDwNXKLjp3UfXUGkWHhJOx14DRjk7jvCUTuBQWF3w0HfYFs4rKv6b+BrQDTsHwCUu3td2B+7/Uf2TTh+fzh9VzUGKAMeCovg5ptZL3TsAODuJcCPgC3ADoLjYTk6flpFCSKFzCwH+D/gS+5eETvOg5803e4eZDO7FCh19+WpjqWDSgdOBX7q7tOBg7xbnAR032MHIKx7mU2QSIcCvYBZKQ2qE1OCSBEzyyBIDr9x99+Hg3eZ2ZBw/BCgNBxeAoyImX14OKwrOgv4FzN7G3iCoJjpx0BfM0sPp4nd/iP7JhzfB9jTngG3s23ANnd/Lex/iiBh6NgJXAhsdvcyd68Ffk9wTOn4aQUliBQwMwN+Cbzl7vfEjFoEXBt2X0tQN9Ew/JrwjpQzgP0xxQldirvf5u7D3X00QeXi8+7+CaAQuCKcrPG+adhnV4TTd9lfz+6+E9hqZhPDQRcAa9Gx02ALcIaZ9Qz/zxr2j46fVtCT1ClgZmcDfwfe5N1y9m8S1EM8CYwkaLr8Y+6+NzzQ7yO4VK4Crnf3Ze0eeDszs3OBr7j7pWZ2AsEVRX9gBXC1u1ebWTbwCEE9zl5grrtvSlXM7cHMTgHmA5nAJuB6gh97OnYAM/t34OMEdwuuAD5NUNeg4+cYKUGIiEhcKmISEZG4lCBERCQuJQgREYlLCUJEROJSghARkbiUIKRTMDM3s/+K6f+Kmd3ZRst+2MyuaHnK417PlWHrq4WNhqeFLa6uNrM3zWypmY1Jcixvm1leMtchnZ8ShHQW1cDlHe2kFvN0biJuAD7j7uc1Gv5xgmYhTnb3qcBHgPI2ClGk1ZQgpLOoI3if8Jcbj2h8BWBmB8Lvc83sRTNbaGabzOwHZvYJM1sS/lIfG7OYC81smZmtC9uDangnxQ/DX/RvmNlnY5b7dzNbRPCUbuN45oXLX21md4fD7gDOBn5pZj9sNMsQYIe7RwHcfZu77wvn+2kY15rwAbCGdbxtZt83s5Xh+FPNbLGZbTSzm2LifMnM/mRmxWb2MzM76n/ezK4O98lKM/t5uN2RcL82XNUctd+l6zuWXz8iqXY/8IaZ/ecxzDMNOJHgKdlNwHx3n2nBS5puAb4UTjcamAmMBQrNbBxwDUHTFDPMLAt42cz+Ek5/KjDF3TfHrszMhhK8U+A0gvcO/MXM5rj7XWZ2PsGT4Y2fZH4S+IeZvZ+gmfNH3X1FOO7/hU9ER4C/mdnJ7v5GOG6Lu59iZvcCDxO0OZQNrAZ+Fk4zE5hM8HT1s8DlBO03NcR7IsEVzFnuXmtm/wt8AlgDDPPgnQqYWd8E9rV0MbqCkE4jbPH21wQvhEnUUnff4e7VwEag4QT/JkFSaPCku0fdfT1BIpkEXEzQjtFKgmZQBhC8eAdgSePkEJoBvBA2FlcH/Ab4QAvbtQ2YCNxG0PTK38zsgnD0x8zsdYLmIU4iONk3WBSzLa+5e6W7lwHVMSf0Je6+yd3rgccJrmJiXUCQzJaG23kBcEK4D04ws5+Y2SygAul2dAUhnc1/A68DD8UMqyP8sRMWoWTGjKuO6Y7G9Ed57/HfuM0ZJ3hR0S3uvjh2RNhG1MHWhR9fmMD+DPzZzHYBc8xsE/AVYIa77zOzhwmuEBrEbkvj7WzYtnjbFcuAX7n7bY1jMrNpBC/buQn4GPCpY90u6dx0BSGdirvvJSiSuSFm8NsEv4IB/gXIaMWirwzvJhpL8Au6mOBtY5+zoGl2zGyCBS/nac4S4BwzywuLheYBLzY3Q1h/MDTsTgNOJigS6k2QiPab2SCC14ceq5kWvI85jaAo6R+Nxv8NuMLMBobr729mo8KbAdLc/f+A2wmK1KSb0RWEdEb/Bdwc0/8LYKGZrSIoZ2/Nr/stBCf33sBN7n7YzOYTFEO9bmZG8Ca3Oc0txN13mNk3CJqXNuBP7r6wuXmAgcAvwnoOwjjuC2NYARQRvPXs5VZs11KC1lzHhTH9oVG8a83sdoK6kjSgFvg8cIjgrXUNPyKPusKQrk+tuYp0URbTXHqqY5HOSUVMIiISl64gREQkLl1BiIhIXEoQIiISlxKEiIjEpQQhIiJxKUGIiEhc/x8UrSbZqNoF5AAAAABJRU5ErkJggg==\n", 148 | "text/plain": [ 149 | "
" 150 | ] 151 | }, 152 | "metadata": { 153 | "needs_background": "light" 154 | }, 155 | "output_type": "display_data" 156 | } 157 | ], 158 | "source": [ 159 | "%matplotlib inline \n", 160 | "import matplotlib.pyplot as plt\n", 161 | "\n", 162 | "num_samples = list(range(100, x_train_poly.shape[0], 50))\n", 163 | "plt.plot(num_samples, accuracy_log)\n", 164 | "plt.xlabel('Number of Samples')\n", 165 | "plt.ylabel('Accuracy')\n", 166 | "plt.title('Effect of Increased Dataset size')\n", 167 | "plt.show()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": { 173 | "Collapsed": "false" 174 | }, 175 | "source": [ 176 | "## Baseline Estimates" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 8, 182 | "metadata": { 183 | "Collapsed": "false" 184 | }, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "Test Accuracy: 90.86%\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "from sklearn.linear_model import SGDClassifier\n", 196 | "from sklearn.metrics import accuracy_score\n", 197 | "\n", 198 | "\n", 199 | "test_acc = []\n", 200 | "for _ in range(10):\n", 201 | " log_reg = SGDClassifier(loss = 'log', n_jobs = -1, alpha = 1e-5)\n", 202 | " log_reg.fit(x_train_poly, y_train)\n", 203 | " y_test_pred = log_reg.predict(x_test_poly)\n", 204 | " test_acc.append(accuracy_score(y_test_pred, y_test))\n", 205 | " \n", 206 | " \n", 207 | "print('Test Accuracy: {:.2f}%'.format(np.array(test_acc).mean()*100))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": { 213 | "Collapsed": "false" 214 | }, 215 | "source": [ 216 | "## Semi-Supervised Training" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 21, 222 | "metadata": { 223 | "Collapsed": "false" 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "# Concept similar to : https://www.analyticsvidhya.com/blog/2017/09/pseudo-labelling-semi-supervised-learning-technique/\n", 228 | "\n", 229 | "class pseudo_labeling():\n", 230 | "\n", 231 | " \n", 232 | " def __init__(self, model, unlabelled_data, sample_rate=0.01, upper_threshold = 0.6, lower_threshold = 0.4, verbose = False):\n", 233 | " \n", 234 | " self.sample_rate = sample_rate\n", 235 | " self.model = model\n", 236 | " self.unlabelled_data = unlabelled_data\n", 237 | " self.verbose = verbose\n", 238 | " self.upper_threshold = upper_threshold\n", 239 | " self.lower_threshold = lower_threshold\n", 240 | " \n", 241 | " # create a list of all the indices \n", 242 | " self.unlabelled_indices = list(range(unlabelled_data.shape[0])) \n", 243 | " \n", 244 | " # Number of rows to sample in each iteration\n", 245 | " self.sample_size = int(unlabelled_data.shape[0] * self.sample_rate)\n", 246 | " \n", 247 | " # Shuffle the indices\n", 248 | " np.random.shuffle(self.unlabelled_indices)\n", 249 | "\n", 250 | " \n", 251 | " \n", 252 | " def __pop_rows(self):\n", 253 | " \"\"\"\n", 254 | " Function to sample indices without replacement\n", 255 | " \"\"\"\n", 256 | " chosen_rows = self.unlabelled_indices[:self.sample_size]\n", 257 | " \n", 258 | " # Remove the chosen rows from the list of indicies (We are sampling w/o replacement)\n", 259 | " self.unlabelled_indices = self.unlabelled_indices[self.sample_size:]\n", 260 | " return chosen_rows\n", 261 | " \n", 262 | " \n", 263 | " def fit(self, X, y):\n", 264 | " \n", 265 | " \"\"\"\n", 266 | " Perform pseudo labelling\n", 267 | " \n", 268 | " X: train features\n", 269 | " y: train targets\n", 270 | " \n", 271 | " \"\"\"\n", 272 | " \n", 273 | " num_iters = int(len(self.unlabelled_indices)/self.sample_size)\n", 274 | "\n", 275 | " for _ in (tqdm_notebook(range(num_iters)) if self.verbose else range(num_iters)):\n", 276 | " \n", 277 | " # Get the samples\n", 278 | " chosen_rows = self.__pop_rows()\n", 279 | "\n", 280 | " # Fit to data\n", 281 | " self.model.fit(X, y.ravel())\n", 282 | " \n", 283 | " chosen_unlabelled_rows = self.unlabelled_data[chosen_rows,:]\n", 284 | " pseudo_labels_prob = self.model.predict_proba(chosen_unlabelled_rows)\n", 285 | " \n", 286 | " \n", 287 | " # We have 10 classes this means `predict_proba` returns an array of 10 probabilities per datapoint\n", 288 | " # We will first find the maximum probability and then find the rows which are within our threshold values\n", 289 | " label_probability = np.max(pseudo_labels_prob, axis = 1)\n", 290 | " labels_within_threshold = np.where((label_probability < self.lower_threshold) | (label_probability > self.upper_threshold))[0]\n", 291 | " \n", 292 | " \n", 293 | " # Use argmax to find the class with the highest probability\n", 294 | " pseudo_labels = np.argmax(pseudo_labels_prob[labels_within_threshold], axis = 1)\n", 295 | " chosen_unlabelled_rows = chosen_unlabelled_rows[labels_within_threshold]\n", 296 | "\n", 297 | " # Combine data\n", 298 | " X = np.vstack((chosen_unlabelled_rows, X))\n", 299 | " y = np.vstack((pseudo_labels.reshape(-1,1), np.array(y).reshape(-1,1)))\n", 300 | "\n", 301 | " # Shuffle \n", 302 | " indices = list(range(X.shape[0]))\n", 303 | " np.random.shuffle(indices)\n", 304 | "\n", 305 | " X = X[indices]\n", 306 | " y = y[indices] \n", 307 | " \n", 308 | " def predict(self, X):\n", 309 | " return self.model.predict(X)\n", 310 | " \n", 311 | " def predict_proba(self, X):\n", 312 | " return self.model.predict_proba(X)\n", 313 | " \n", 314 | " \n", 315 | " def decision_function(self, X):\n", 316 | " return self.model.decision_function(X)\n", 317 | "\n", 318 | " " 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 22, 324 | "metadata": { 325 | "Collapsed": "false" 326 | }, 327 | "outputs": [], 328 | "source": [ 329 | "from sklearn.linear_model import SGDClassifier \n", 330 | "from tqdm import tqdm_notebook\n", 331 | "\n", 332 | "log_reg = SGDClassifier(loss = 'log', n_jobs = -1, alpha = 1e-5)\n", 333 | "\n", 334 | "pseudo_labeller = pseudo_labeling(\n", 335 | " log_reg,\n", 336 | " x_unlabeled_poly,\n", 337 | " sample_rate = 0.04,\n", 338 | " verbose = True\n", 339 | " )" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 23, 345 | "metadata": { 346 | "Collapsed": "false" 347 | }, 348 | "outputs": [ 349 | { 350 | "data": { 351 | "application/vnd.jupyter.widget-view+json": { 352 | "model_id": "365653575c194ba381a169219a33f956", 353 | "version_major": 2, 354 | "version_minor": 0 355 | }, 356 | "text/plain": [ 357 | "HBox(children=(IntProgress(value=0, max=25), HTML(value='')))" 358 | ] 359 | }, 360 | "metadata": {}, 361 | "output_type": "display_data" 362 | }, 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "\n" 368 | ] 369 | } 370 | ], 371 | "source": [ 372 | "pseudo_labeller.fit(x_train_poly, y_train)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 24, 378 | "metadata": { 379 | "Collapsed": "false" 380 | }, 381 | "outputs": [ 382 | { 383 | "name": "stdout", 384 | "output_type": "stream", 385 | "text": [ 386 | "Test Accuracy: 92.42%\n" 387 | ] 388 | } 389 | ], 390 | "source": [ 391 | "from sklearn.metrics import accuracy_score\n", 392 | "y_test_pred = pseudo_labeller.predict(x_test_poly)\n", 393 | "print('Test Accuracy: {:.2f}%'.format(accuracy_score(y_test_pred, y_test)*100))" 394 | ] 395 | } 396 | ], 397 | "metadata": { 398 | "kernelspec": { 399 | "display_name": "Python 3", 400 | "language": "python", 401 | "name": "python3" 402 | }, 403 | "language_info": { 404 | "codemirror_mode": { 405 | "name": "ipython", 406 | "version": 3 407 | }, 408 | "file_extension": ".py", 409 | "mimetype": "text/x-python", 410 | "name": "python", 411 | "nbconvert_exporter": "python", 412 | "pygments_lexer": "ipython3", 413 | "version": "3.7.3" 414 | } 415 | }, 416 | "nbformat": 4, 417 | "nbformat_minor": 4 418 | } 419 | --------------------------------------------------------------------------------