├── README.md ├── VAE_for_imbalanced_data.ipynb ├── Variational_Autoencoder_data_augmentation.ipynb └── Create_Autoencoder_Model_Basemodel_3Embeddings.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Autoencoder 2 | 3 | I use the famous iris-dataset to create an Autoencoder with PyTorch. I then show the difference between a PCA and an embedding space build by the Autoencoder. 4 | -------------------------------------------------------------------------------- /VAE_for_imbalanced_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.9.2-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3.9.2 64-bit", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "9139ca13fc640d8623238ac4ed44beace8a76f86a07bab6efe75c2506e18783d" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "cell_type": "code", 31 | "execution_count": 374, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "import torch\n", 36 | "import torch.nn as nn\n", 37 | "import torch.nn.functional as F\n", 38 | "from torch import nn, optim\n", 39 | "from torch.autograd import Variable\n", 40 | "\n", 41 | "import pandas as pd\n", 42 | "import numpy as np\n", 43 | "from sklearn import preprocessing\n", 44 | "from sklearn.model_selection import train_test_split\n", 45 | "import mlprepare as mlp \n", 46 | "from sklearn.ensemble import RandomForestClassifier\n", 47 | "from sklearn.metrics import confusion_matrix" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": {}, 54 
| "outputs": [ 55 | { 56 | "output_type": "execute_result", 57 | "data": { 58 | "text/plain": [ 59 | "device(type='cpu')" 60 | ] 61 | }, 62 | "metadata": {}, 63 | "execution_count": 2 64 | } 65 | ], 66 | "source": [ 67 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 68 | "device" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "DATA_PATH = 'data/creditcard.csv'" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 327, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "output_type": "execute_result", 87 | "data": { 88 | "text/plain": [ 89 | " Time V1 V2 V3 V4 V5 V6 V7 \\\n", 90 | "0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 \n", 91 | "1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 \n", 92 | "2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 \n", 93 | "3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 \n", 94 | "4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 \n", 95 | "\n", 96 | " V8 V9 ... V21 V22 V23 V24 V25 \\\n", 97 | "0 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 0.128539 \n", 98 | "1 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 0.167170 \n", 99 | "2 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 -0.327642 \n", 100 | "3 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 0.647376 \n", 101 | "4 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 -0.206010 \n", 102 | "\n", 103 | " V26 V27 V28 Amount Class \n", 104 | "0 -0.189115 0.133558 -0.021053 149.62 0 \n", 105 | "1 0.125895 -0.008983 0.014724 2.69 0 \n", 106 | "2 -0.139097 -0.055353 -0.059752 378.66 0 \n", 107 | "3 -0.221929 0.062723 0.061458 123.50 0 \n", 108 | "4 0.502292 0.219422 0.215153 69.99 0 \n", 109 | "\n", 110 | "[5 rows x 31 columns]" 111 | ], 112 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
TimeV1V2V3V4V5V6V7V8V9...V21V22V23V24V25V26V27V28AmountClass
00.0-1.359807-0.0727812.5363471.378155-0.3383210.4623880.2395990.0986980.363787...-0.0183070.277838-0.1104740.0669280.128539-0.1891150.133558-0.021053149.620
10.01.1918570.2661510.1664800.4481540.060018-0.082361-0.0788030.085102-0.255425...-0.225775-0.6386720.101288-0.3398460.1671700.125895-0.0089830.0147242.690
21.0-1.358354-1.3401631.7732090.379780-0.5031981.8004990.7914610.247676-1.514654...0.2479980.7716790.909412-0.689281-0.327642-0.139097-0.055353-0.059752378.660
31.0-0.966272-0.1852261.792993-0.863291-0.0103091.2472030.2376090.377436-1.387024...-0.1083000.005274-0.190321-1.1755750.647376-0.2219290.0627230.061458123.500
42.0-1.1582330.8777371.5487180.403034-0.4071930.0959210.592941-0.2705330.817739...-0.0094310.798278-0.1374580.141267-0.2060100.5022920.2194220.21515369.990
\n

5 rows × 31 columns

\n
" 113 | }, 114 | "metadata": {}, 115 | "execution_count": 327 116 | } 117 | ], 118 | "source": [ 119 | "df = pd.read_csv(DATA_PATH, sep=',')\n", 120 | "df.head()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "df_base = df.copy()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 282, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "cols = df_base.columns" 139 | ] 140 | }, 141 | { 142 | "source": [ 143 | "We need to normalize Time and Amount" 144 | ], 145 | "cell_type": "markdown", 146 | "metadata": {} 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 184, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "mean_time=df_base['Time'].mean()\n", 155 | "mean_amount=df_base['Amount'].mean()\n", 156 | "std_time=df_base['Time'].std()\n", 157 | "std_amount=df_base['Amount'].std()\n", 158 | "\n", 159 | "df_base['Time']=(df_base['Time']-mean_time)/std_time\n", 160 | "df_base['Amount']=(df_base['Amount']-mean_amount)/std_amount" 161 | ] 162 | }, 163 | { 164 | "source": [ 165 | "Class=1 means that this was indeed a fraud case, class=0 means no fraud. This dataset is highly imbalanced:" 166 | ], 167 | "cell_type": "markdown", 168 | "metadata": {} 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 185, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "output_type": "execute_result", 177 | "data": { 178 | "text/plain": [ 179 | "0 284315\n", 180 | "1 492\n", 181 | "Name: Class, dtype: int64" 182 | ] 183 | }, 184 | "metadata": {}, 185 | "execution_count": 185 186 | } 187 | ], 188 | "source": [ 189 | "df_base['Class'].value_counts()" 190 | ] 191 | }, 192 | { 193 | "source": [ 194 | "I want to create fake data based on the 492 cases, which I will then use to improve the model. Let's first train a simple RandomForest." 
195 | ], 196 | "cell_type": "markdown", 197 | "metadata": {} 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 186, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "X_train, X_test, y_train, y_test = mlp.split_df(df_base, dep_var='Class', test_size=0.3, split_mode='random')\n" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 368, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "output_type": "execute_result", 215 | "data": { 216 | "text/plain": [ 217 | "0 85286\n", 218 | "1 157\n", 219 | "Name: Class, dtype: int64" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "execution_count": 368 224 | } 225 | ], 226 | "source": [ 227 | "y_test.value_counts()" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 369, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "output_type": "execute_result", 237 | "data": { 238 | "text/plain": [ 239 | "543.2229299363057" 240 | ] 241 | }, 242 | "metadata": {}, 243 | "execution_count": 369 244 | } 245 | ], 246 | "source": [ 247 | "#Ratio of the two classes:\n", 248 | "y_test.value_counts()[0]/y_test.value_counts()[1]" 249 | ] 250 | }, 251 | { 252 | "source": [ 253 | "RandomForest with class weighting (the fraud class gets weight 543, roughly the class ratio computed above)" 254 | ], 255 | "cell_type": "markdown", 256 | "metadata": {} 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 406, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "def rf(xs, y, n_estimators=40, max_samples=500,\n", 265 | " max_features=0.5, min_samples_leaf=5, **kwargs):\n", 266 | " return RandomForestClassifier(n_jobs=-1, n_estimators=n_estimators,\n", 267 | " max_samples=max_samples, max_features=max_features,\n", 268 | " min_samples_leaf=min_samples_leaf, oob_score=True, class_weight={0:1,1:543}).fit(xs, y)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 407, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "m = rf(X_train, y_train)" 278 | ] 279 | }, 280 | {
281 | "cell_type": "code", 282 | "execution_count": 408, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "output_type": "execute_result", 287 | "data": { 288 | "text/plain": [ 289 | "array([[85278, 8],\n", 290 | " [ 118, 39]], dtype=int64)" 291 | ] 292 | }, 293 | "metadata": {}, 294 | "execution_count": 408 295 | } 296 | ], 297 | "source": [ 298 | "confusion_matrix(y_test, np.round(m.predict(X_test)))" 299 | ] 300 | }, 301 | { 302 | "source": [ 303 | "With class weighting we detect about 39 of the 157 fraud cases in the test set (with 8 false positives), although the results vary quite a lot between runs!" 304 | ], 305 | "cell_type": "markdown", 306 | "metadata": {} 307 | }, 308 | { 309 | "source": [ 310 | "# Fake Data with VAE" 311 | ], 312 | "cell_type": "markdown", 313 | "metadata": {} 314 | }, 315 | { 316 | "source": [ 317 | "We train the VAE on the fraud cases only, i.e. the rows where y_train (respectively y_test) equals 1:" 318 | ], 319 | "cell_type": "markdown", 320 | "metadata": {} 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 264, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "X_train_fraud = X_train.iloc[np.where(y_train==1)[0]]\n", 329 | "X_test_fraud = X_test.iloc[np.where(y_test==1)[0]]" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 265, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "from torch.utils.data import Dataset, DataLoader\n", 339 | "class DataBuilder(Dataset):\n", 340 | " def __init__(self, dataset):\n", 341 | " self.x = dataset.values\n", 342 | " self.x = torch.from_numpy(self.x).to(torch.float)\n", 343 | " self.len=self.x.shape[0]\n", 344 | " def __getitem__(self,index): \n", 345 | " return self.x[index]\n", 346 | " def __len__(self):\n", 347 | " return self.len\n" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 266, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "traindata_set=DataBuilder(X_train_fraud)\n", 357 | "testdata_set=DataBuilder(X_test_fraud)\n", 358 | "\n", 359 | 
"trainloader=DataLoader(dataset=traindata_set,batch_size=1024)\n", 360 | "testloader=DataLoader(dataset=testdata_set,batch_size=1024)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 267, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "class Autoencoder(nn.Module):\n", 370 | " def __init__(self,D_in,H=50,H2=12,latent_dim=3):\n", 371 | " \n", 372 | " #Encoder\n", 373 | " super(Autoencoder,self).__init__()\n", 374 | " self.linear1=nn.Linear(D_in,H)\n", 375 | " self.lin_bn1 = nn.BatchNorm1d(num_features=H)\n", 376 | " self.linear2=nn.Linear(H,H2)\n", 377 | " self.lin_bn2 = nn.BatchNorm1d(num_features=H2)\n", 378 | " self.linear3=nn.Linear(H2,H2)\n", 379 | " self.lin_bn3 = nn.BatchNorm1d(num_features=H2)\n", 380 | " \n", 381 | " # Latent vectors mu and sigma\n", 382 | " self.fc1 = nn.Linear(H2, latent_dim)\n", 383 | " self.bn1 = nn.BatchNorm1d(num_features=latent_dim)\n", 384 | " self.fc21 = nn.Linear(latent_dim, latent_dim)\n", 385 | " self.fc22 = nn.Linear(latent_dim, latent_dim)\n", 386 | "\n", 387 | " # Sampling vector\n", 388 | " self.fc3 = nn.Linear(latent_dim, latent_dim)\n", 389 | " self.fc_bn3 = nn.BatchNorm1d(latent_dim)\n", 390 | " self.fc4 = nn.Linear(latent_dim, H2)\n", 391 | " self.fc_bn4 = nn.BatchNorm1d(H2)\n", 392 | " \n", 393 | " # Decoder\n", 394 | " self.linear4=nn.Linear(H2,H2)\n", 395 | " self.lin_bn4 = nn.BatchNorm1d(num_features=H2)\n", 396 | " self.linear5=nn.Linear(H2,H)\n", 397 | " self.lin_bn5 = nn.BatchNorm1d(num_features=H)\n", 398 | " self.linear6=nn.Linear(H,D_in)\n", 399 | " self.lin_bn6 = nn.BatchNorm1d(num_features=D_in)\n", 400 | " \n", 401 | " self.relu = nn.ReLU()\n", 402 | " \n", 403 | " def encode(self, x):\n", 404 | " lin1 = self.relu(self.lin_bn1(self.linear1(x)))\n", 405 | " lin2 = self.relu(self.lin_bn2(self.linear2(lin1)))\n", 406 | " lin3 = self.relu(self.lin_bn3(self.linear3(lin2)))\n", 407 | "\n", 408 | " fc1 = F.relu(self.bn1(self.fc1(lin3)))\n", 409 | "\n", 410 | " r1 = 
self.fc21(fc1)\n", 411 | " r2 = self.fc22(fc1)\n", 412 | " \n", 413 | " return r1, r2\n", 414 | " \n", 415 | " def reparameterize(self, mu, logvar):\n", 416 | " if self.training:\n", 417 | " std = logvar.mul(0.5).exp_()\n", 418 | " eps = Variable(std.data.new(std.size()).normal_())\n", 419 | " return eps.mul(std).add_(mu)\n", 420 | " else:\n", 421 | " return mu\n", 422 | " \n", 423 | " def decode(self, z):\n", 424 | " fc3 = self.relu(self.fc_bn3(self.fc3(z)))\n", 425 | " fc4 = self.relu(self.fc_bn4(self.fc4(fc3)))\n", 426 | "\n", 427 | " lin4 = self.relu(self.lin_bn4(self.linear4(fc4)))\n", 428 | " lin5 = self.relu(self.lin_bn5(self.linear5(lin4)))\n", 429 | " return self.lin_bn6(self.linear6(lin5))\n", 430 | "\n", 431 | "\n", 432 | " \n", 433 | " def forward(self, x):\n", 434 | " mu, logvar = self.encode(x)\n", 435 | " z = self.reparameterize(mu, logvar)\n", 436 | " return self.decode(z), mu, logvar" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 268, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "class customLoss(nn.Module):\n", 446 | " def __init__(self):\n", 447 | " super(customLoss, self).__init__()\n", 448 | " self.mse_loss = nn.MSELoss(reduction=\"sum\")\n", 449 | " \n", 450 | " # x_recon ist der im forward im Model erstellte recon_batch, x ist der originale x Batch, mu ist mu und logvar ist logvar \n", 451 | " def forward(self, x_recon, x, mu, logvar):\n", 452 | " loss_MSE = self.mse_loss(x_recon, x)\n", 453 | " loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n", 454 | "\n", 455 | " return loss_MSE + loss_KLD" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 269, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "D_in = traindata_set.x.shape[1]\n", 465 | "H = 50\n", 466 | "H2 = 12\n", 467 | "model = Autoencoder(D_in, H, H2).to(device)\n", 468 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)" 469 | ] 470 | }, 471 | { 472 | "cell_type": 
"code", 473 | "execution_count": 270, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "loss_mse = customLoss()" 478 | ] 479 | }, 480 | { 481 | "source": [ 482 | "## Train Model" 483 | ], 484 | "cell_type": "markdown", 485 | "metadata": {} 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 271, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "log_interval = 50\n", 494 | "val_losses = []\n", 495 | "train_losses = []\n", 496 | "test_losses = []" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 272, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "def train(epoch):\n", 506 | " model.train()\n", 507 | " train_loss = 0\n", 508 | " for batch_idx, data in enumerate(trainloader):\n", 509 | " data = data.to(device)\n", 510 | " optimizer.zero_grad()\n", 511 | " recon_batch, mu, logvar = model(data)\n", 512 | " loss = loss_mse(recon_batch, data, mu, logvar)\n", 513 | " loss.backward()\n", 514 | " train_loss += loss.item()\n", 515 | " optimizer.step()\n", 516 | " if epoch % 200 == 0: \n", 517 | " print('====> Epoch: {} Average training loss: {:.4f}'.format(\n", 518 | " epoch, train_loss / len(trainloader.dataset)))\n", 519 | " train_losses.append(train_loss / len(trainloader.dataset))" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 273, 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "def test(epoch):\n", 529 | " with torch.no_grad():\n", 530 | " test_loss = 0\n", 531 | " for batch_idx, data in enumerate(testloader):\n", 532 | " data = data.to(device)\n", 533 | " optimizer.zero_grad()\n", 534 | " recon_batch, mu, logvar = model(data)\n", 535 | " loss = loss_mse(recon_batch, data, mu, logvar)\n", 536 | " test_loss += loss.item()\n", 537 | " if epoch % 200 == 0: \n", 538 | " print('====> Epoch: {} Average test loss: {:.4f}'.format(\n", 539 | " epoch, test_loss / len(testloader.dataset)))\n", 540 | " test_losses.append(test_loss / 
len(testloader.dataset))" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 274, 546 | "metadata": {}, 547 | "outputs": [ 548 | { 549 | "output_type": "stream", 550 | "name": "stdout", 551 | "text": [ 552 | "====> Epoch: 200 Average training loss: 706.2121\n", 553 | "====> Epoch: 200 Average test loss: 590.0016\n", 554 | "====> Epoch: 400 Average training loss: 620.5279\n", 555 | "====> Epoch: 400 Average test loss: 521.3142\n", 556 | "====> Epoch: 600 Average training loss: 566.4392\n", 557 | "====> Epoch: 600 Average test loss: 477.5008\n", 558 | "====> Epoch: 800 Average training loss: 521.7474\n", 559 | "====> Epoch: 800 Average test loss: 440.3243\n", 560 | "====> Epoch: 1000 Average training loss: 481.2092\n", 561 | "====> Epoch: 1000 Average test loss: 407.7625\n", 562 | "====> Epoch: 1200 Average training loss: 434.3898\n", 563 | "====> Epoch: 1200 Average test loss: 362.2760\n", 564 | "====> Epoch: 1400 Average training loss: 396.9551\n", 565 | "====> Epoch: 1400 Average test loss: 343.7408\n" 566 | ] 567 | } 568 | ], 569 | "source": [ 570 | "epochs = 1500\n", 571 | "for epoch in range(1, epochs + 1):\n", 572 | " train(epoch)\n", 573 | " test(epoch)" 574 | ] 575 | }, 576 | { 577 | "source": [ 578 | "We're still improving so keep going " 579 | ], 580 | "cell_type": "markdown", 581 | "metadata": {} 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 275, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "output_type": "stream", 590 | "name": "stdout", 591 | "text": [ 592 | "====> Epoch: 200 Average training loss: 343.3472\n", 593 | "====> Epoch: 200 Average test loss: 300.3575\n", 594 | "====> Epoch: 400 Average training loss: 310.5800\n", 595 | "====> Epoch: 400 Average test loss: 285.6697\n", 596 | "====> Epoch: 600 Average training loss: 281.8408\n", 597 | "====> Epoch: 600 Average test loss: 263.7150\n", 598 | "====> Epoch: 800 Average training loss: 256.1950\n", 599 | "====> Epoch: 800 Average test loss: 
244.9427\n", 600 | "====> Epoch: 1000 Average training loss: 232.6077\n", 601 | "====> Epoch: 1000 Average test loss: 236.3014\n", 602 | "====> Epoch: 1200 Average training loss: 211.2899\n", 603 | "====> Epoch: 1200 Average test loss: 217.6404\n", 604 | "====> Epoch: 1400 Average training loss: 191.3525\n", 605 | "====> Epoch: 1400 Average test loss: 205.8287\n", 606 | "====> Epoch: 1600 Average training loss: 174.0826\n", 607 | "====> Epoch: 1600 Average test loss: 189.0589\n", 608 | "====> Epoch: 1800 Average training loss: 157.4292\n", 609 | "====> Epoch: 1800 Average test loss: 175.6006\n", 610 | "====> Epoch: 2000 Average training loss: 143.2475\n", 611 | "====> Epoch: 2000 Average test loss: 177.1668\n", 612 | "====> Epoch: 2200 Average training loss: 129.9684\n", 613 | "====> Epoch: 2200 Average test loss: 160.4641\n", 614 | "====> Epoch: 2400 Average training loss: 117.6745\n", 615 | "====> Epoch: 2400 Average test loss: 150.9483\n" 616 | ] 617 | } 618 | ], 619 | "source": [ 620 | "epochs = 2500\n", 621 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", 622 | "for epoch in range(1, epochs + 1):\n", 623 | " train(epoch)\n", 624 | " test(epoch)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 278, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "output_type": "stream", 634 | "name": "stdout", 635 | "text": [ 636 | "====> Epoch: 200 Average training loss: 54.6816\n", 637 | "====> Epoch: 200 Average test loss: 129.6853\n", 638 | "====> Epoch: 400 Average training loss: 48.5159\n", 639 | "====> Epoch: 400 Average test loss: 134.4429\n" 640 | ] 641 | } 642 | ], 643 | "source": [ 644 | "epochs = 500\n", 645 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", 646 | "for epoch in range(1, epochs + 1):\n", 647 | " train(epoch)\n", 648 | " test(epoch)" 649 | ] 650 | }, 651 | { 652 | "source": [ 653 | "Let's look at the results:" 654 | ], 655 | "cell_type": "markdown", 656 | "metadata": {} 657 | }, 658 | { 659 | 
"cell_type": "code", 660 | "execution_count": 279, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "with torch.no_grad():\n", 665 | " for batch_idx, data in enumerate(testloader):\n", 666 | " data = data.to(device)\n", 667 | " optimizer.zero_grad()\n", 668 | " recon_batch, mu, logvar = model(data)" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 288, 674 | "metadata": {}, 675 | "outputs": [], 676 | "source": [ 677 | "recon_row = recon_batch[0].cpu().numpy()\n", 678 | "recon_row = np.append(recon_row, [1])\n", 679 | "real_row = testloader.dataset.x[0].cpu().numpy()\n", 680 | "real_row = np.append(real_row, [1])" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 290, 686 | "metadata": {}, 687 | "outputs": [ 688 | { 689 | "output_type": "execute_result", 690 | "data": { 691 | "text/plain": [ 692 | " Time V1 V2 V3 V4 V5 V6 \\\n", 693 | "0 -0.196971 -7.667089 5.699276 -10.15090 10.077229 -7.307253 -2.589641 \n", 694 | "1 0.910404 -5.839191 7.151532 -12.81676 7.031115 -9.651272 -2.938427 \n", 695 | "\n", 696 | " V7 V8 V9 ... V21 V22 V23 V24 \\\n", 697 | "0 -9.824335 3.019747 -7.658296 ... 1.073921 0.034662 0.247951 0.00464 \n", 698 | "1 -11.543207 4.843626 -3.494276 ... 2.462056 1.054865 0.530481 0.47267 \n", 699 | "\n", 700 | " V25 V26 V27 V28 Amount Class \n", 701 | "0 -0.037674 0.597619 0.763070 -0.609457 -0.377716 1.0 \n", 702 | "1 -0.275998 0.282435 0.104886 0.254417 0.910404 1.0 \n", 703 | "\n", 704 | "[2 rows x 31 columns]" 705 | ], 706 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
TimeV1V2V3V4V5V6V7V8V9...V21V22V23V24V25V26V27V28AmountClass
0-0.196971-7.6670895.699276-10.1509010.077229-7.307253-2.589641-9.8243353.019747-7.658296...1.0739210.0346620.2479510.00464-0.0376740.5976190.763070-0.609457-0.3777161.0
10.910404-5.8391917.151532-12.816767.031115-9.651272-2.938427-11.5432074.843626-3.494276...2.4620561.0548650.5304810.47267-0.2759980.2824350.1048860.2544170.9104041.0
\n

2 rows × 31 columns

\n
" 707 | }, 708 | "metadata": {}, 709 | "execution_count": 290 710 | } 711 | ], 712 | "source": [ 713 | "df = pd.DataFrame(np.stack((recon_row, real_row)), columns = cols)\n", 714 | "df" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 293, 720 | "metadata": {}, 721 | "outputs": [], 722 | "source": [ 723 | "sigma = torch.exp(logvar/2)" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": 294, 729 | "metadata": {}, 730 | "outputs": [ 731 | { 732 | "output_type": "execute_result", 733 | "data": { 734 | "text/plain": [ 735 | "(tensor([0.0001, 0.0163, 0.0400]), tensor([0.9976, 0.0370, 0.0381]))" 736 | ] 737 | }, 738 | "metadata": {}, 739 | "execution_count": 294 740 | } 741 | ], 742 | "source": [ 743 | "mu.mean(axis=0), sigma.mean(axis=0)" 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": 295, 749 | "metadata": {}, 750 | "outputs": [], 751 | "source": [ 752 | "# sample z from q\n", 753 | "no_samples = 20\n", 754 | "q = torch.distributions.Normal(mu.mean(axis=0), sigma.mean(axis=0))\n", 755 | "z = q.rsample(sample_shape=torch.Size([no_samples]))" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 318, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [ 764 | "with torch.no_grad():\n", 765 | " pred = model.decode(z).cpu().numpy()" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": 324, 771 | "metadata": {}, 772 | "outputs": [ 773 | { 774 | "output_type": "execute_result", 775 | "data": { 776 | "text/plain": [ 777 | " Time V1 V2 V3 V4 V5 V6 \\\n", 778 | "0 -1.014143 1.505616 -4.616234 7.718655 -0.977422 8.594662 -3.198405 \n", 779 | "1 -1.810440 -13.005595 1.212420 5.370727 2.069537 -1.141557 -3.816671 \n", 780 | "2 -1.152523 12.006341 -3.014931 4.485871 -1.155190 10.059814 -3.355832 \n", 781 | "3 0.228914 -5.935965 -1.644437 -6.354884 7.788726 -0.055751 -1.726003 \n", 782 | "4 0.180823 -3.444491 4.722339 -4.571048 4.998073 -4.543203 
-0.816252 \n", 783 | "\n", 784 | " V7 V8 V9 ... V21 V22 V23 V24 \\\n", 785 | "0 -6.944025 -5.043085 2.561653 ... 1.094700 0.510489 -1.254657 -0.085117 \n", 786 | "1 -6.958980 4.140651 -1.208175 ... 0.902933 -0.573067 1.209823 0.543091 \n", 787 | "2 -8.342437 -8.336978 2.741910 ... -0.101801 1.417866 -2.335097 0.034988 \n", 788 | "3 0.577209 1.638260 -5.880371 ... -5.350942 2.994604 -0.079382 -1.020990 \n", 789 | "4 -5.482205 3.643872 -4.685173 ... -1.748235 1.525022 0.258438 -0.465014 \n", 790 | "\n", 791 | " V25 V26 V27 V28 Amount Class \n", 792 | "0 0.283567 -0.268765 3.025049 0.929408 -79.125496 1 \n", 793 | "1 0.666637 -0.524895 0.204588 -0.074243 -380.632935 1 \n", 794 | "2 -0.466923 -0.012957 2.653872 1.081970 -163.960175 1 \n", 795 | "3 -0.090167 0.395981 -1.590370 -1.090804 9.417862 1 \n", 796 | "4 0.064509 0.277528 1.127516 0.161839 171.483337 1 \n", 797 | "\n", 798 | "[5 rows x 31 columns]" 799 | ], 800 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
TimeV1V2V3V4V5V6V7V8V9...V21V22V23V24V25V26V27V28AmountClass
0-1.0141431.505616-4.6162347.718655-0.9774228.594662-3.198405-6.944025-5.0430852.561653...1.0947000.510489-1.254657-0.0851170.283567-0.2687653.0250490.929408-79.1254961
1-1.810440-13.0055951.2124205.3707272.069537-1.141557-3.816671-6.9589804.140651-1.208175...0.902933-0.5730671.2098230.5430910.666637-0.5248950.204588-0.074243-380.6329351
2-1.15252312.006341-3.0149314.485871-1.15519010.059814-3.355832-8.342437-8.3369782.741910...-0.1018011.417866-2.3350970.034988-0.466923-0.0129572.6538721.081970-163.9601751
30.228914-5.935965-1.644437-6.3548847.788726-0.055751-1.7260030.5772091.638260-5.880371...-5.3509422.994604-0.079382-1.020990-0.0901670.395981-1.590370-1.0908049.4178621
40.180823-3.4444914.722339-4.5710484.998073-4.543203-0.816252-5.4822053.643872-4.685173...-1.7482351.5250220.258438-0.4650140.0645090.2775281.1275160.161839171.4833371
\n

5 rows × 31 columns

\n
" 801 | }, 802 | "metadata": {}, 803 | "execution_count": 324 804 | } 805 | ], 806 | "source": [ 807 | "df_fake = pd.DataFrame(pred)\n", 808 | "df_fake['Class']=1\n", 809 | "df_fake.columns = cols\n", 810 | "df_fake['Class'] = np.round(df_fake['Class']).astype(int)\n", 811 | "df_fake['Time'] = (df_fake['Time']*std_time)+mean_time\n", 812 | "df_fake['Amount'] = (df_fake['Amount']*std_amount)+mean_amount\n", 813 | "df_fake.head()" 814 | ] 815 | }, 816 | { 817 | "cell_type": "code", 818 | "execution_count": 325, 819 | "metadata": {}, 820 | "outputs": [ 821 | { 822 | "output_type": "execute_result", 823 | "data": { 824 | "text/plain": [ 825 | "121.77293" 826 | ] 827 | }, 828 | "metadata": {}, 829 | "execution_count": 325 830 | } 831 | ], 832 | "source": [ 833 | "df_fake['Amount'].mean()" 834 | ] 835 | }, 836 | { 837 | "cell_type": "code", 838 | "execution_count": 338, 839 | "metadata": {}, 840 | "outputs": [ 841 | { 842 | "output_type": "execute_result", 843 | "data": { 844 | "text/plain": [ 845 | "Class\n", 846 | "0 88.291022\n", 847 | "1 122.211321\n", 848 | "Name: Amount, dtype: float64" 849 | ] 850 | }, 851 | "metadata": {}, 852 | "execution_count": 338 853 | } 854 | ], 855 | "source": [ 856 | "df.groupby('Class').mean()['Amount']" 857 | ] 858 | }, 859 | { 860 | "source": [ 861 | "Use the fake data to oversample the fraud class for the RandomForest" 862 | ], 863 | "cell_type": "markdown", 864 | "metadata": {} 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": 344, 869 | "metadata": {}, 870 | "outputs": [ 871 | { 872 | "output_type": "execute_result", 873 | "data": { 874 | "text/plain": [ 875 | "0 199029\n", 876 | "1 335\n", 877 | "Name: Class, dtype: int64" 878 | ] 879 | }, 880 | "metadata": {}, 881 | "execution_count": 344 882 | } 883 | ], 884 | "source": [ 885 | "y_train.value_counts()" 886 | ] 887 | }, 888 | { 889 | "source": [ 890 | "So let's generate about 190,000 fake fraud cases:" 891 | ], 892 | "cell_type": "markdown", 893 | "metadata": {} 894 | }, 895 | 
{ 896 | "cell_type": "code", 897 | "execution_count": 346, 898 | "metadata": {}, 899 | "outputs": [], 900 | "source": [ 901 | "no_samples = 190_000\n", 902 | "q = torch.distributions.Normal(mu.mean(axis=0), sigma.mean(axis=0))\n", 903 | "z = q.rsample(sample_shape=torch.Size([no_samples]))" 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": 347, 909 | "metadata": {}, 910 | "outputs": [], 911 | "source": [ 912 | "with torch.no_grad():\n", 913 | " pred = model.decode(z).cpu().numpy()" 914 | ] 915 | }, 916 | { 917 | "source": [ 918 | "Concat to our X_train:" 919 | ], 920 | "cell_type": "markdown", 921 | "metadata": {} 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": 365, 926 | "metadata": {}, 927 | "outputs": [ 928 | { 929 | "output_type": "execute_result", 930 | "data": { 931 | "text/plain": [ 932 | "(389364, 30)" 933 | ] 934 | }, 935 | "metadata": {}, 936 | "execution_count": 365 937 | } 938 | ], 939 | "source": [ 940 | "X_train_augmented = np.vstack((X_train.values, pred))\n", 941 | "y_train_augmented = np.append(y_train.values, np.repeat(1,no_samples))\n", 942 | "X_train_augmented.shape" 943 | ] 944 | }, 945 | { 946 | "source": [ 947 | "We now have roughly as many fraud cases as we have non-fraud cases. 
" 948 | ], 949 | "cell_type": "markdown", 950 | "metadata": {} 951 | }, 952 | { 953 | "source": [ 954 | "## Train RandomForest" 955 | ], 956 | "cell_type": "markdown", 957 | "metadata": {} 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": 409, 962 | "metadata": {}, 963 | "outputs": [], 964 | "source": [ 965 | "def rf_aug(xs, y, n_estimators=40, max_samples=500,\n", 966 | " max_features=0.5, min_samples_leaf=5, **kwargs):\n", 967 | " return RandomForestClassifier(n_jobs=-1, n_estimators=n_estimators,\n", 968 | " max_samples=max_samples, max_features=max_features,\n", 969 | " min_samples_leaf=min_samples_leaf, oob_score=True).fit(xs, y)" 970 | ] 971 | }, 972 | { 973 | "cell_type": "code", 974 | "execution_count": 412, 975 | "metadata": {}, 976 | "outputs": [ 977 | { 978 | "output_type": "execute_result", 979 | "data": { 980 | "text/plain": [ 981 | "array([[84963, 323],\n", 982 | " [ 30, 127]], dtype=int64)" 983 | ] 984 | }, 985 | "metadata": {}, 986 | "execution_count": 412 987 | } 988 | ], 989 | "source": [ 990 | "m_aug = rf_aug(X_train_augmented, y_train_augmented)\n", 991 | "confusion_matrix(y_test, np.round(m_aug.predict(X_test)))" 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "execution_count": 413, 997 | "metadata": {}, 998 | "outputs": [ 999 | { 1000 | "output_type": "execute_result", 1001 | "data": { 1002 | "text/plain": [ 1003 | "array([[85278, 8],\n", 1004 | " [ 118, 39]], dtype=int64)" 1005 | ] 1006 | }, 1007 | "metadata": {}, 1008 | "execution_count": 413 1009 | } 1010 | ], 1011 | "source": [ 1012 | "confusion_matrix(y_test, np.round(m.predict(X_test)))" 1013 | ] 1014 | }, 1015 | { 1016 | "source": [ 1017 | "Look at that! We managed to find 127 out of 157 fraud cases (vs. 39 with class weighting alone). Keep the trade-off in mind, though: false positives rose from 8 to 323. Two further caveats before trusting this comparison: the latent `mu`/`sigma` used for sampling were computed by running the VAE over `testloader`, i.e. on the test-set fraud cases, so test information leaked into the augmented training data (compute them over `trainloader` instead); and the VAE is never switched to `model.eval()`, so BatchNorm uses batch statistics and `reparameterize` keeps adding sampling noise during evaluation and generation. 
" 1018 | ], 1019 | "cell_type": "markdown", 1020 | "metadata": {} 1021 | }, 1022 | { 1023 | "cell_type": "code", 1024 | "execution_count": null, 1025 | "metadata": {}, 1026 | "outputs": [], 1027 | "source": [] 1028 | } 1029 | ] 1030 | } -------------------------------------------------------------------------------- /Variational_Autoencoder_data_augmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "language_info": { 4 | "codemirror_mode": { 5 | "name": "ipython", 6 | "version": 3 7 | }, 8 | "file_extension": ".py", 9 | "mimetype": "text/x-python", 10 | "name": "python", 11 | "nbconvert_exporter": "python", 12 | "pygments_lexer": "ipython3", 13 | "version": "3.9.2-final" 14 | }, 15 | "orig_nbformat": 2, 16 | "kernelspec": { 17 | "name": "python3", 18 | "display_name": "Python 3.9.2 64-bit ('deeplearning_venv': venv)", 19 | "metadata": { 20 | "interpreter": { 21 | "hash": "9139ca13fc640d8623238ac4ed44beace8a76f86a07bab6efe75c2506e18783d" 22 | } 23 | } 24 | } 25 | }, 26 | "nbformat": 4, 27 | "nbformat_minor": 2, 28 | "cells": [ 29 | { 30 | "source": [ 31 | "# Variational Autoencoder" 32 | ], 33 | "cell_type": "markdown", 34 | "metadata": {} 35 | }, 36 | { 37 | "source": [ 38 | "## How to create fake tabular data to enhance machine learning algorithms" 39 | ], 40 | "cell_type": "markdown", 41 | "metadata": {} 42 | }, 43 | { 44 | "source": [ 45 | "To train deeplearning models the more data the better. When we're thinking of image data, the deeplearnig community thought about a lot of tricks how to enhance the model given a dataset of images: image enhancement. Meaning that by rotating, flipping, blurring etc the image we can create more input data and also improve our model. \n", 46 | "\n", 47 | "Hoever, when thinking about tabular data, only few of these techniques exist. In this notebook I want to show you how to create a variational autoencoder to make use of data enhancement. 
I will create fake data, which is sampled from the learned distribution of the underlying data. " 48 | ], 49 | "cell_type": "markdown", 50 | "metadata": {} 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 215, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "import torch\n", 59 | "import torch.nn as nn\n", 60 | "import torch.nn.functional as F\n", 61 | "from torch import nn, optim\n", 62 | "from torch.autograd import Variable\n", 63 | "from sklearn.decomposition import PCA\n", 64 | "\n", 65 | "import pandas as pd\n", 66 | "import numpy as np\n", 67 | "from sklearn import preprocessing\n", 68 | "from sklearn.model_selection import train_test_split" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 216, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "output_type": "execute_result", 78 | "data": { 79 | "text/plain": [ 80 | "device(type='cpu')" 81 | ] 82 | }, 83 | "metadata": {}, 84 | "execution_count": 216 85 | } 86 | ], 87 | "source": [ 88 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 89 | "device" 90 | ] 91 | }, 92 | { 93 | "source": [ 94 | "### Define path to dataset" 95 | ], 96 | "cell_type": "markdown", 97 | "metadata": {} 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 217, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "DATA_PATH = 'data/wine.csv'" 106 | ] 107 | }, 108 | { 109 | "source": [ 110 | "## Dataset Overview" 111 | ], 112 | "cell_type": "markdown", 113 | "metadata": {} 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 440, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "output_type": "execute_result", 122 | "data": { 123 | "text/plain": [ 124 | " Wine Alcohol Malic.acid Ash Acl Mg Phenols Flavanoids \\\n", 125 | "0 1 14.23 1.71 2.43 15.6 127 2.80 3.06 \n", 126 | "1 1 13.20 1.78 2.14 11.2 100 2.65 2.76 \n", 127 | "2 1 13.16 2.36 2.67 18.6 101 2.80 3.24 \n", 128 | "3 1 14.37 1.95 2.50 16.8 113 3.85 3.49 
\n", 129 | "4 1 13.24 2.59 2.87 21.0 118 2.80 2.69 \n", 130 | "\n", 131 | " Nonflavanoid.phenols Proanth Color.int Hue OD Proline \n", 132 | "0 0.28 2.29 5.64 1.04 3.92 1065 \n", 133 | "1 0.26 1.28 4.38 1.05 3.40 1050 \n", 134 | "2 0.30 2.81 5.68 1.03 3.17 1185 \n", 135 | "3 0.24 2.18 7.80 0.86 3.45 1480 \n", 136 | "4 0.39 1.82 4.32 1.04 2.93 735 " 137 | ], 138 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
WineAlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
0114.231.712.4315.61272.803.060.282.295.641.043.921065
1113.201.782.1411.21002.652.760.261.284.381.053.401050
2113.162.362.6718.61012.803.240.302.815.681.033.171185
3114.371.952.5016.81133.853.490.242.187.800.863.451480
4113.242.592.8721.01182.802.690.391.824.321.042.93735
\n
" 139 | }, 140 | "metadata": {}, 141 | "execution_count": 440 142 | } 143 | ], 144 | "source": [ 145 | "df_base = pd.read_csv(DATA_PATH, sep=',')\n", 146 | "df_base.head()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 441, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "cols = df_base.columns" 156 | ] 157 | }, 158 | { 159 | "source": [ 160 | "## Build Data Loader" 161 | ], 162 | "cell_type": "markdown", 163 | "metadata": {} 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 222, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "def load_and_standardize_data(path):\n", 172 | " # read in from csv\n", 173 | " df = pd.read_csv(path, sep=',')\n", 174 | " # replace nan with -99\n", 175 | " df = df.fillna(-99)\n", 176 | " df = df.values.reshape(-1, df.shape[1]).astype('float32')\n", 177 | " # randomly split\n", 178 | " X_train, X_test = train_test_split(df, test_size=0.3, random_state=42)\n", 179 | " # standardize values\n", 180 | " scaler = preprocessing.StandardScaler()\n", 181 | " X_train = scaler.fit_transform(X_train)\n", 182 | " X_test = scaler.transform(X_test) \n", 183 | " return X_train, X_test, scaler" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 223, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "from torch.utils.data import Dataset, DataLoader\n", 193 | "class DataBuilder(Dataset):\n", 194 | " def __init__(self, path, train=True):\n", 195 | " self.X_train, self.X_test, self.standardizer = load_and_standardize_data(DATA_PATH)\n", 196 | " if train:\n", 197 | " self.x = torch.from_numpy(self.X_train)\n", 198 | " self.len=self.x.shape[0]\n", 199 | " else:\n", 200 | " self.x = torch.from_numpy(self.X_test)\n", 201 | " self.len=self.x.shape[0]\n", 202 | " del self.X_train\n", 203 | " del self.X_test \n", 204 | " def __getitem__(self,index): \n", 205 | " return self.x[index]\n", 206 | " def __len__(self):\n", 207 | " return 
self.len" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 224, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "traindata_set=DataBuilder(DATA_PATH, train=True)\n", 217 | "testdata_set=DataBuilder(DATA_PATH, train=False)\n", 218 | "\n", 219 | "trainloader=DataLoader(dataset=traindata_set,batch_size=1024)\n", 220 | "testloader=DataLoader(dataset=testdata_set,batch_size=1024)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 225, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "output_type": "execute_result", 230 | "data": { 231 | "text/plain": [ 232 | "(torch.Tensor, torch.Tensor)" 233 | ] 234 | }, 235 | "metadata": {}, 236 | "execution_count": 225 237 | } 238 | ], 239 | "source": [ 240 | "type(trainloader.dataset.x), type(testloader.dataset.x)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 226, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "output_type": "execute_result", 250 | "data": { 251 | "text/plain": [ 252 | "(torch.Size([124, 14]), torch.Size([54, 14]))" 253 | ] 254 | }, 255 | "metadata": {}, 256 | "execution_count": 226 257 | } 258 | ], 259 | "source": [ 260 | "trainloader.dataset.x.shape, testloader.dataset.x.shape" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 227, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "output_type": "execute_result", 270 | "data": { 271 | "text/plain": [ 272 | "tensor([[ 1.3598, 0.6284, 1.0812, ..., -0.6414, -1.0709, -0.5182],\n", 273 | " [ 0.0628, -0.5409, -0.6130, ..., 0.3465, 1.3308, -0.2151],\n", 274 | " [ 0.0628, -0.7557, -1.2870, ..., 0.4324, -0.3984, 0.0420],\n", 275 | " ...,\n", 276 | " [-1.2343, 1.6904, -0.4855, ..., 1.0338, 0.5485, 2.6682],\n", 277 | " [ 0.0628, -0.3261, -0.7952, ..., 0.0029, -0.7415, -0.7983],\n", 278 | " [ 0.0628, -0.7437, 0.0428, ..., -0.6843, 1.0700, -0.9861]])" 279 | ] 280 | }, 281 | "metadata": {}, 282 | "execution_count": 227 283 | } 
284 | ], 285 | "source": [ 286 | "trainloader.dataset.x" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 228, 292 | "metadata": {}, 293 | "outputs": [ 294 | { 295 | "output_type": "execute_result", 296 | "data": { 297 | "text/plain": [ 298 | "" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "execution_count": 228 303 | } 304 | ], 305 | "source": [ 306 | "trainloader.dataset.standardizer.inverse_transform" 307 | ] 308 | }, 309 | { 310 | "source": [ 311 | "## Build model" 312 | ], 313 | "cell_type": "markdown", 314 | "metadata": {} 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 229, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "class Autoencoder(nn.Module):\n", 323 | " def __init__(self,D_in,H=50,H2=12,latent_dim=3):\n", 324 | " \n", 325 | " #Encoder\n", 326 | " super(Autoencoder,self).__init__()\n", 327 | " self.linear1=nn.Linear(D_in,H)\n", 328 | " self.lin_bn1 = nn.BatchNorm1d(num_features=H)\n", 329 | " self.linear2=nn.Linear(H,H2)\n", 330 | " self.lin_bn2 = nn.BatchNorm1d(num_features=H2)\n", 331 | " self.linear3=nn.Linear(H2,H2)\n", 332 | " self.lin_bn3 = nn.BatchNorm1d(num_features=H2)\n", 333 | " \n", 334 | "# # Latent vectors mu and sigma\n", 335 | " self.fc1 = nn.Linear(H2, latent_dim)\n", 336 | " self.bn1 = nn.BatchNorm1d(num_features=latent_dim)\n", 337 | " self.fc21 = nn.Linear(latent_dim, latent_dim)\n", 338 | " self.fc22 = nn.Linear(latent_dim, latent_dim)\n", 339 | "\n", 340 | "# # Sampling vector\n", 341 | " self.fc3 = nn.Linear(latent_dim, latent_dim)\n", 342 | " self.fc_bn3 = nn.BatchNorm1d(latent_dim)\n", 343 | " self.fc4 = nn.Linear(latent_dim, H2)\n", 344 | " self.fc_bn4 = nn.BatchNorm1d(H2)\n", 345 | " \n", 346 | "# # Decoder\n", 347 | " self.linear4=nn.Linear(H2,H2)\n", 348 | " self.lin_bn4 = nn.BatchNorm1d(num_features=H2)\n", 349 | " self.linear5=nn.Linear(H2,H)\n", 350 | " self.lin_bn5 = nn.BatchNorm1d(num_features=H)\n", 351 | " self.linear6=nn.Linear(H,D_in)\n", 352 | 
" self.lin_bn6 = nn.BatchNorm1d(num_features=D_in)\n", 353 | " \n", 354 | " self.relu = nn.ReLU()\n", 355 | " \n", 356 | " def encode(self, x):\n", 357 | " lin1 = self.relu(self.lin_bn1(self.linear1(x)))\n", 358 | " lin2 = self.relu(self.lin_bn2(self.linear2(lin1)))\n", 359 | " lin3 = self.relu(self.lin_bn3(self.linear3(lin2)))\n", 360 | "\n", 361 | " fc1 = F.relu(self.bn1(self.fc1(lin3)))\n", 362 | "\n", 363 | " r1 = self.fc21(fc1)\n", 364 | " r2 = self.fc22(fc1)\n", 365 | " \n", 366 | " return r1, r2\n", 367 | " \n", 368 | " def reparameterize(self, mu, logvar):\n", 369 | " if self.training:\n", 370 | " std = logvar.mul(0.5).exp_()\n", 371 | " eps = Variable(std.data.new(std.size()).normal_())\n", 372 | " return eps.mul(std).add_(mu)\n", 373 | " else:\n", 374 | " return mu\n", 375 | " \n", 376 | " def decode(self, z):\n", 377 | " fc3 = self.relu(self.fc_bn3(self.fc3(z)))\n", 378 | " fc4 = self.relu(self.fc_bn4(self.fc4(fc3)))\n", 379 | "\n", 380 | " lin4 = self.relu(self.lin_bn4(self.linear4(fc4)))\n", 381 | " lin5 = self.relu(self.lin_bn5(self.linear5(lin4)))\n", 382 | " return self.lin_bn6(self.linear6(lin5))\n", 383 | "\n", 384 | "\n", 385 | " \n", 386 | " def forward(self, x):\n", 387 | " mu, logvar = self.encode(x)\n", 388 | " z = self.reparameterize(mu, logvar)\n", 389 | " # self.decode(z) ist später recon_batch, mu ist mu und logvar ist logvar\n", 390 | " return self.decode(z), mu, logvar" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 230, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "class customLoss(nn.Module):\n", 400 | " def __init__(self):\n", 401 | " super(customLoss, self).__init__()\n", 402 | " self.mse_loss = nn.MSELoss(reduction=\"sum\")\n", 403 | " \n", 404 | " # x_recon ist der im forward im Model erstellte recon_batch, x ist der originale x Batch, mu ist mu und logvar ist logvar \n", 405 | " def forward(self, x_recon, x, mu, logvar):\n", 406 | " loss_MSE = self.mse_loss(x_recon, x)\n", 407 | 
" loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n", 408 | "\n", 409 | " return loss_MSE + loss_KLD" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 231, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "# takes in a module and applies the specified weight initialization\n", 419 | "def weights_init_uniform_rule(m):\n", 420 | " classname = m.__class__.__name__\n", 421 | " # for every Linear layer in a model..\n", 422 | " if classname.find('Linear') != -1:\n", 423 | " # get the number of the inputs\n", 424 | " n = m.in_features\n", 425 | " y = 1.0/np.sqrt(n)\n", 426 | " m.weight.data.uniform_(-y, y)\n", 427 | " m.bias.data.fill_(0)" 428 | ] 429 | }, 430 | { 431 | "source": [ 432 | "If you want to better understand the variational autoencoder technique, look [here](https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73).\n", 433 | "\n", 434 | "For better understanding this AutoencoderClass, let me go briefly through it. This is a variational autoencoder (VAE) with two hidden layers, which (by default, but you can change this) 50 and then 12 activations. The latent factors are set to 3 (you can change that, too). So we're first exploding our initially 14 variables to 50 activations, then condensing it to 12, then to 3. From these 3 latent factors we then sample to recreate the original 14 values. We do that by inflating the 3 latent factors back to 12, then 50 and finally 14 activations (we decode the latent factors so to speak). With this reconstructed batch (recon_batch) we compare it with the original batch, computate our loss and adjust the weights and biases via our gradient (our optimizer here will be Adam). 
" 435 | ], 436 | "cell_type": "markdown", 437 | "metadata": {} 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 232, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "D_in = data_set.x.shape[1]\n", 446 | "H = 50\n", 447 | "H2 = 12\n", 448 | "model = Autoencoder(D_in, H, H2).to(device)\n", 449 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 233, 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "loss_mse = customLoss()" 459 | ] 460 | }, 461 | { 462 | "source": [ 463 | "## Train Model" 464 | ], 465 | "cell_type": "markdown", 466 | "metadata": {} 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 234, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "epochs = 1500\n", 475 | "log_interval = 50\n", 476 | "val_losses = []\n", 477 | "train_losses = []\n", 478 | "test_losses = []" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 235, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "def train(epoch):\n", 488 | " model.train()\n", 489 | " train_loss = 0\n", 490 | " for batch_idx, data in enumerate(trainloader):\n", 491 | " data = data.to(device)\n", 492 | " optimizer.zero_grad()\n", 493 | " recon_batch, mu, logvar = model(data)\n", 494 | " loss = loss_mse(recon_batch, data, mu, logvar)\n", 495 | " loss.backward()\n", 496 | " train_loss += loss.item()\n", 497 | " optimizer.step()\n", 498 | " if epoch % 200 == 0: \n", 499 | " print('====> Epoch: {} Average training loss: {:.4f}'.format(\n", 500 | " epoch, train_loss / len(trainloader.dataset)))\n", 501 | " train_losses.append(train_loss / len(trainloader.dataset))" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 236, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "def test(epoch):\n", 511 | " with torch.no_grad():\n", 512 | " test_loss = 0\n", 513 | " for 
batch_idx, data in enumerate(testloader):\n", 514 | " data = data.to(device)\n", 515 | " optimizer.zero_grad()\n", 516 | " recon_batch, mu, logvar = model(data)\n", 517 | " loss = loss_mse(recon_batch, data, mu, logvar)\n", 518 | " test_loss += loss.item()\n", 519 | " if epoch % 200 == 0: \n", 520 | " print('====> Epoch: {} Average test loss: {:.4f}'.format(\n", 521 | " epoch, test_loss / len(testloader.dataset)))\n", 522 | " test_losses.append(test_loss / len(testloader.dataset))" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 237, 528 | "metadata": {}, 529 | "outputs": [ 530 | { 531 | "output_type": "stream", 532 | "name": "stdout", 533 | "text": [ 534 | "====> Epoch: 200 Average training loss: 12.3501\n", 535 | "====> Epoch: 200 Average test loss: 11.7777\n", 536 | "====> Epoch: 400 Average training loss: 10.1168\n", 537 | "====> Epoch: 400 Average test loss: 8.9987\n", 538 | "====> Epoch: 600 Average training loss: 9.2956\n", 539 | "====> Epoch: 600 Average test loss: 9.3548\n", 540 | "====> Epoch: 800 Average training loss: 8.9570\n", 541 | "====> Epoch: 800 Average test loss: 8.9647\n", 542 | "====> Epoch: 1000 Average training loss: 8.6688\n", 543 | "====> Epoch: 1000 Average test loss: 8.5866\n", 544 | "====> Epoch: 1200 Average training loss: 8.3341\n", 545 | "====> Epoch: 1200 Average test loss: 8.8371\n", 546 | "====> Epoch: 1400 Average training loss: 8.4063\n", 547 | "====> Epoch: 1400 Average test loss: 8.7891\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "for epoch in range(1, epochs + 1):\n", 553 | " train(epoch)\n", 554 | " test(epoch)" 555 | ] 556 | }, 557 | { 558 | "source": [ 559 | "We were able to reduce the training and test loss quite a bit. Let's have a look at what the reconstructed results look like compared to the real ones:" 560 | ], 561 | "cell_type": "markdown", 562 | "metadata": {} 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 238, 567 | "metadata": {}, 568 | "outputs": [], 569
| "source": [ 570 | "with torch.no_grad():\n", 571 | " for batch_idx, data in enumerate(testloader):\n", 572 | " data = data.to(device)\n", 573 | " optimizer.zero_grad()\n", 574 | " recon_batch, mu, logvar = model(data)" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 243, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "scaler = trainloader.dataset.standardizer\n", 584 | "recon_row = scaler.inverse_transform(recon_batch[0].cpu().numpy())\n", 585 | "real_row = scaler.inverse_transform(testloader.dataset.x[0].cpu().numpy())" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 246, 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "output_type": "execute_result", 595 | "data": { 596 | "text/plain": [ 597 | " Wine Alcohol Malic.acid Ash Acl Mg Phenols \\\n", 598 | "0 1.002792 13.535107 2.010303 2.557292 18.198132 112.606842 2.737524 \n", 599 | "1 1.000000 13.640000 3.100000 2.560000 15.200000 116.000000 2.700000 \n", 600 | "\n", 601 | " Flavanoids Nonflavanoid.phenols Proanth Color.int Hue OD \\\n", 602 | "0 2.807587 0.320866 1.738254 4.899318 1.078039 3.187276 \n", 603 | "1 3.030000 0.170000 1.660000 5.100000 0.960000 3.360000 \n", 604 | "\n", 605 | " Proline \n", 606 | "0 1013.391479 \n", 607 | "1 845.000000 " 608 | ], 609 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
WineAlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
01.00279213.5351072.0103032.55729218.198132112.6068422.7375242.8075870.3208661.7382544.8993181.0780393.1872761013.391479
11.00000013.6400003.1000002.56000015.200000116.0000002.7000003.0300000.1700001.6600005.1000000.9600003.360000845.000000
\n
" 610 | }, 611 | "metadata": {}, 612 | "execution_count": 246 613 | } 614 | ], 615 | "source": [ 616 | "df = pd.DataFrame(np.stack((recon_row, real_row)), columns = cols)\n", 617 | "df" 618 | ] 619 | }, 620 | { 621 | "source": [ 622 | "Not to bad right (the first row is the reconstructed row, the second one the real row from the data)? However, what we want is to built this row not with the real input so to speak, since right now we were giving the model the complete rows with their 14 columns, condensed it to 3 input parameters, just to blow it up again to the corresponding 14 columns. What I want to do is to create these 14 rows by giving the model 3 latent factors as input. Let's have a look at these latent variables. " 623 | ], 624 | "cell_type": "markdown", 625 | "metadata": {} 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 255, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "sigma = torch.exp(logvar/2)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 256, 639 | "metadata": {}, 640 | "outputs": [ 641 | { 642 | "output_type": "execute_result", 643 | "data": { 644 | "text/plain": [ 645 | "(tensor([-0.9960, -0.8502, -0.0043]), tensor([0.2555, 0.4801, 0.9888]))" 646 | ] 647 | }, 648 | "metadata": {}, 649 | "execution_count": 256 650 | } 651 | ], 652 | "source": [ 653 | "mu[1], sigma[1]" 654 | ] 655 | }, 656 | { 657 | "source": [ 658 | "Mu represents the mean for each of our latent factor values, logvar the log of the standard deviation. Each of these have a distribution by itself. We have 54 cases in our test data, so we have 3x54 different mu and logvar. 
We can look at the average mu and sigma of each of the 3 latent variables: " 659 | ], 660 | "cell_type": "markdown", 661 | "metadata": {} 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 257, 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "output_type": "execute_result", 670 | "data": { 671 | "text/plain": [ 672 | "(tensor([-0.0088, 0.0051, 0.0044]), tensor([0.4514, 0.3897, 0.9986]))" 673 | ] 674 | }, 675 | "metadata": {}, 676 | "execution_count": 257 677 | } 678 | ], 679 | "source": [ 680 | "mu.mean(axis=0), sigma.mean(axis=0)" 681 | ] 682 | }, 683 | { 684 | "source": [ 685 | "All of the latent variables have a mean around zero, but the last latent factor has a wider standard deviation. So when we sample values from each of these latent variables, the last value will vary much more than the other two. I assume a normal distribution for each latent factor, parameterized by the average mu and average sigma above. (This is a simplification: averaging sigma ignores how much mu itself varies between rows; sampling from the standard normal prior that the KL term pushes the encoder towards is the more common choice.)" 686 | ], 687 | "cell_type": "markdown", 688 | "metadata": {} 689 | }, 690 | { 691 | "cell_type": "code", 692 | "execution_count": 405, 693 | "metadata": {}, 694 | "outputs": [], 695 | "source": [ 696 | "# sample z from q\n", 697 | "no_samples = 20\n", 698 | "q = torch.distributions.Normal(mu.mean(axis=0), sigma.mean(axis=0))\n", 699 | "z = q.rsample(sample_shape=torch.Size([no_samples]))" 700 | ] 701 | }, 702 | { 703 | "cell_type": "code", 704 | "execution_count": 406, 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "output_type": "execute_result", 709 | "data": { 710 | "text/plain": [ 711 | "torch.Size([20, 3])" 712 | ] 713 | }, 714 | "metadata": {}, 715 | "execution_count": 406 716 | } 717 | ], 718 | "source": [ 719 | "z.shape" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 446, 725 | "metadata": {}, 726 | "outputs": [ 727 | { 728 | "output_type": "execute_result", 729 | "data": { 730 | "text/plain": [ 731 | "tensor([[ 0.5283, 0.4519, 0.6792],\n", 732 | " [ 0.3664, -0.5569, -0.1531],\n", 733 | " [-0.5802, 0.4394, 1.8406],\n", 734 | " [-1.0136,
-0.4239, 0.4524],\n", 735 | " [-0.0605, 0.3913, 0.8030]])" 736 | ] 737 | }, 738 | "metadata": {}, 739 | "execution_count": 446 740 | } 741 | ], 742 | "source": [ 743 | "z[:5]" 744 | ] 745 | }, 746 | { 747 | "source": [ 748 | "With these three latent factors we can now start and create fake data for our dataset and see how it looks like:" 749 | ], 750 | "cell_type": "markdown", 751 | "metadata": {} 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": 408, 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [ 759 | "with torch.no_grad():\n", 760 | " pred = model.decode(z).cpu().numpy()" 761 | ] 762 | }, 763 | { 764 | "cell_type": "code", 765 | "execution_count": 409, 766 | "metadata": {}, 767 | "outputs": [ 768 | { 769 | "output_type": "execute_result", 770 | "data": { 771 | "text/plain": [ 772 | "array([-0.24290268, -0.6087041 , -0.44325534, -0.7158908 , -0.15065292,\n", 773 | " -0.47845733, 0.26319185, 0.23732403, -0.22809544, 0.12187037,\n", 774 | " -0.8295655 , 0.44908378, 0.6173717 , -0.55648965], dtype=float32)" 775 | ] 776 | }, 777 | "metadata": {}, 778 | "execution_count": 409 779 | } 780 | ], 781 | "source": [ 782 | "pred[1]" 783 | ] 784 | }, 785 | { 786 | "source": [ 787 | "## Create fake data from Autoencoder" 788 | ], 789 | "cell_type": "markdown", 790 | "metadata": {} 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 420, 795 | "metadata": {}, 796 | "outputs": [ 797 | { 798 | "output_type": "execute_result", 799 | "data": { 800 | "text/plain": [ 801 | "(20, 14)" 802 | ] 803 | }, 804 | "metadata": {}, 805 | "execution_count": 420 806 | } 807 | ], 808 | "source": [ 809 | "fake_data = scaler.inverse_transform(pred)\n", 810 | "fake_data.shape" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 439, 816 | "metadata": {}, 817 | "outputs": [ 818 | { 819 | "output_type": "execute_result", 820 | "data": { 821 | "text/plain": [ 822 | " Wine Alcohol Malic.acid Ash Acl Mg Phenols \\\n", 823 | "0 3 
13.350755 3.817283 2.425754 21.229387 98.816788 1.682916 \n", 824 | "1 2 12.453159 1.916350 2.172731 18.977226 93.556114 2.444676 \n", 825 | "2 2 12.735057 2.404566 2.447556 20.400013 105.475235 1.937112 \n", 826 | "3 1 14.664644 1.517465 2.269279 12.428186 88.851791 3.354010 \n", 827 | "4 3 13.160161 3.359397 2.415784 21.050211 99.859154 1.662516 \n", 828 | "5 2 12.453159 1.916350 2.172731 18.977226 93.556114 2.444676 \n", 829 | "6 2 12.520310 2.522696 2.375254 20.435560 92.619812 1.838333 \n", 830 | "7 3 12.877177 2.746192 2.395865 20.154610 97.263092 1.744550 \n", 831 | "8 2 12.679532 2.344776 2.331834 19.901327 97.031586 1.857117 \n", 832 | "9 2 13.062141 2.719065 2.461590 19.947014 103.352890 2.070540 \n", 833 | "\n", 834 | " Flavanoids Nonflavanoid.phenols Proanth Color.int Hue OD \\\n", 835 | "0 0.910786 0.450081 1.245882 8.242197 0.667928 1.705379 \n", 836 | "1 2.246270 0.335432 1.663583 3.166457 1.063876 3.050176 \n", 837 | "2 1.657119 0.385740 1.452577 4.242754 0.928397 2.467263 \n", 838 | "3 3.997237 0.265253 2.586414 7.366968 1.275564 3.170231 \n", 839 | "4 0.929189 0.427978 1.135361 7.101127 0.708510 1.732820 \n", 840 | "5 2.246270 0.335432 1.663583 3.166457 1.063876 3.050176 \n", 841 | "6 1.361269 0.470815 1.221076 4.518130 0.906680 2.146883 \n", 842 | "7 1.187050 0.464942 1.160733 5.619783 0.836708 1.871472 \n", 843 | "8 1.495742 0.461352 1.239715 4.668478 0.934352 2.094139 \n", 844 | "9 1.566055 0.380154 1.293219 5.675068 0.852832 2.128047 \n", 845 | "\n", 846 | " Proline \n", 847 | "0 636.650818 \n", 848 | "1 568.385925 \n", 849 | "2 680.271545 \n", 850 | "3 1516.662720 \n", 851 | "4 640.412231 \n", 852 | "5 568.385925 \n", 853 | "6 583.079102 \n", 854 | "7 665.485718 \n", 855 | "8 680.778809 \n", 856 | "9 778.582825 " 857 | ], 858 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
WineAlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
0313.3507553.8172832.42575421.22938798.8167881.6829160.9107860.4500811.2458828.2421970.6679281.705379636.650818
1212.4531591.9163502.17273118.97722693.5561142.4446762.2462700.3354321.6635833.1664571.0638763.050176568.385925
2212.7350572.4045662.44755620.400013105.4752351.9371121.6571190.3857401.4525774.2427540.9283972.467263680.271545
3114.6646441.5174652.26927912.42818688.8517913.3540103.9972370.2652532.5864147.3669681.2755643.1702311516.662720
4313.1601613.3593972.41578421.05021199.8591541.6625160.9291890.4279781.1353617.1011270.7085101.732820640.412231
5212.4531591.9163502.17273118.97722693.5561142.4446762.2462700.3354321.6635833.1664571.0638763.050176568.385925
6212.5203102.5226962.37525420.43556092.6198121.8383331.3612690.4708151.2210764.5181300.9066802.146883583.079102
7312.8771772.7461922.39586520.15461097.2630921.7445501.1870500.4649421.1607335.6197830.8367081.871472665.485718
8212.6795322.3447762.33183419.90132797.0315861.8571171.4957420.4613521.2397154.6684780.9343522.094139680.778809
9213.0621412.7190652.46159019.947014103.3528902.0705401.5660550.3801541.2932195.6750680.8528322.128047778.582825
\n
" 859 | }, 860 | "metadata": {}, 861 | "execution_count": 439 862 | } 863 | ], 864 | "source": [ 865 | "df_fake = pd.DataFrame(fake_data, columns = cols)\n", 866 | "df_fake['Wine'] = np.round(df_fake['Wine']).astype(int)\n", 867 | "df_fake['Wine'] = np.where(df_fake['Wine']<1, 1, df_fake['Wine'])\n", 868 | "df_fake.head(10)" 869 | ] 870 | }, 871 | { 872 | "source": [ 873 | "For comparison the real data:" 874 | ], 875 | "cell_type": "markdown", 876 | "metadata": {} 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": 444, 881 | "metadata": {}, 882 | "outputs": [ 883 | { 884 | "output_type": "execute_result", 885 | "data": { 886 | "text/plain": [ 887 | " Wine Alcohol Malic.acid Ash Acl Mg Phenols Flavanoids \\\n", 888 | "1 1 13.20 1.78 2.14 11.2 100 2.65 2.76 \n", 889 | "35 1 13.48 1.81 2.41 20.5 100 2.70 2.98 \n", 890 | "114 2 12.08 1.39 2.50 22.5 84 2.56 2.29 \n", 891 | "149 3 13.08 3.90 2.36 21.5 113 1.41 1.39 \n", 892 | "158 3 14.34 1.68 2.70 25.0 98 2.80 1.31 \n", 893 | "9 1 13.86 1.35 2.27 16.0 98 2.98 3.15 \n", 894 | "90 2 12.08 1.83 2.32 18.5 81 1.60 1.50 \n", 895 | "47 1 13.90 1.68 2.12 16.0 101 3.10 3.39 \n", 896 | "10 1 14.10 2.16 2.30 18.0 105 2.95 3.32 \n", 897 | "31 1 13.58 1.66 2.36 19.1 106 2.86 3.19 \n", 898 | "\n", 899 | " Nonflavanoid.phenols Proanth Color.int Hue OD Proline \n", 900 | "1 0.26 1.28 4.38 1.05 3.40 1050 \n", 901 | "35 0.26 1.86 5.10 1.04 3.47 920 \n", 902 | "114 0.43 1.04 2.90 0.93 3.19 385 \n", 903 | "149 0.34 1.14 9.40 0.57 1.33 550 \n", 904 | "158 0.53 2.70 13.00 0.57 1.96 660 \n", 905 | "9 0.22 1.85 7.22 1.01 3.55 1045 \n", 906 | "90 0.52 1.64 2.40 1.08 2.27 480 \n", 907 | "47 0.21 2.14 6.10 0.91 3.33 985 \n", 908 | "10 0.22 2.38 5.75 1.25 3.17 1510 \n", 909 | "31 0.22 1.95 6.90 1.09 2.88 1515 " 910 | ], 911 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
WineAlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
1113.201.782.1411.21002.652.760.261.284.381.053.401050
35113.481.812.4120.51002.702.980.261.865.101.043.47920
114212.081.392.5022.5842.562.290.431.042.900.933.19385
149313.083.902.3621.51131.411.390.341.149.400.571.33550
158314.341.682.7025.0982.801.310.532.7013.000.571.96660
9113.861.352.2716.0982.983.150.221.857.221.013.551045
90212.081.832.3218.5811.601.500.521.642.401.082.27480
47113.901.682.1216.01013.103.390.212.146.100.913.33985
10114.102.162.3018.01052.953.320.222.385.751.253.171510
31113.581.662.3619.11062.863.190.221.956.901.092.881515
\n
" 912 | }, 913 | "metadata": {}, 914 | "execution_count": 444 915 | } 916 | ], 917 | "source": [ 918 | "df_base.sample(10)" 919 | ] 920 | }, 921 | { 922 | "source": [ 923 | "## Compare variables grouped by Wine" 924 | ], 925 | "cell_type": "markdown", 926 | "metadata": {} 927 | }, 928 | { 929 | "cell_type": "code", 930 | "execution_count": 443, 931 | "metadata": {}, 932 | "outputs": [ 933 | { 934 | "output_type": "execute_result", 935 | "data": { 936 | "text/plain": [ 937 | " Alcohol Malic.acid Ash Acl Mg Phenols \\\n", 938 | "Wine \n", 939 | "1 13.744746 2.010678 2.455593 17.037288 106.338983 2.840169 \n", 940 | "2 12.278732 1.932676 2.244789 20.238028 94.549296 2.258873 \n", 941 | "3 13.153750 3.333750 2.437083 21.416667 99.312500 1.678750 \n", 942 | "\n", 943 | " Flavanoids Nonflavanoid.phenols Proanth Color.int Hue \\\n", 944 | "Wine \n", 945 | "1 2.982373 0.290000 1.899322 5.528305 1.062034 \n", 946 | "2 2.080845 0.363662 1.630282 3.086620 1.056282 \n", 947 | "3 0.781458 0.447500 1.153542 7.396250 0.682708 \n", 948 | "\n", 949 | " OD Proline \n", 950 | "Wine \n", 951 | "1 3.157797 1115.711864 \n", 952 | "2 2.785352 519.507042 \n", 953 | "3 1.683542 629.895833 " 954 | ], 955 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
AlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
Wine
113.7447462.0106782.45559317.037288106.3389832.8401692.9823730.2900001.8993225.5283051.0620343.1577971115.711864
212.2787321.9326762.24478920.23802894.5492962.2588732.0808450.3636621.6302823.0866201.0562822.785352519.507042
313.1537503.3337502.43708321.41666799.3125001.6787500.7814580.4475001.1535427.3962500.6827081.683542629.895833
\n
" 956 | }, 957 | "metadata": {}, 958 | "execution_count": 443 959 | } 960 | ], 961 | "source": [ 962 | "df_base.groupby('Wine').mean()" 963 | ] 964 | }, 965 | { 966 | "cell_type": "code", 967 | "execution_count": 445, 968 | "metadata": {}, 969 | "outputs": [ 970 | { 971 | "output_type": "execute_result", 972 | "data": { 973 | "text/plain": [ 974 | " Alcohol Malic.acid Ash Acl Mg Phenols \\\n", 975 | "Wine \n", 976 | "1 13.812141 1.814212 2.482638 17.172688 107.468864 3.062387 \n", 977 | "2 12.560544 2.157595 2.301805 19.696327 99.324005 2.254415 \n", 978 | "3 13.170316 3.413856 2.416369 20.929930 99.028229 1.683604 \n", 979 | "\n", 980 | " Flavanoids Nonflavanoid.phenols Proanth Color.int Hue \\\n", 981 | "Wine \n", 982 | "1 3.344664 0.259955 2.162966 5.331643 1.147217 \n", 983 | "2 1.995140 0.366076 1.575015 3.791955 1.000527 \n", 984 | "3 0.964315 0.443444 1.176529 7.288512 0.718357 \n", 985 | "\n", 986 | " OD Proline \n", 987 | "Wine \n", 988 | "1 3.280716 1148.031372 \n", 989 | "2 2.741598 629.895203 \n", 990 | "3 1.745200 644.870056 " 991 | ], 992 | "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
AlcoholMalic.acidAshAclMgPhenolsFlavanoidsNonflavanoid.phenolsProanthColor.intHueODProline
Wine
113.8121411.8142122.48263817.172688107.4688643.0623873.3446640.2599552.1629665.3316431.1472173.2807161148.031372
212.5605442.1575952.30180519.69632799.3240052.2544151.9951400.3660761.5750153.7919551.0005272.741598629.895203
313.1703163.4138562.41636920.92993099.0282291.6836040.9643150.4434441.1765297.2885120.7183571.745200644.870056
\n
" 993 | }, 994 | "metadata": {}, 995 | "execution_count": 445 996 | } 997 | ], 998 | "source": [ 999 | "df_fake.groupby('Wine').mean()" 1000 | ] 1001 | }, 1002 | { 1003 | "source": [ 1004 | "That looks pretty convincing if you ask me. \n", 1005 | "\n", 1006 | "To sum up, we've built a variational autoencoder, which we trained on our trainingset. We checked whether our loss kept on improving based on the testset, which the autoencoder never saw for generating fake data. We then calculated the mean and standard deviation from our latent factors given the test data. We've then sampled from this distribution to feed it back into our decoder to create some fake data. With this approach I am now able to create as much fake data derived from the underlying distribution as a want. And I think the results look promising. \n", 1007 | "\n", 1008 | "You can take this approach to for example create data from under-represented in highly skewed datasets instead of just weighting them higher. The re-weighting approach might cause the algorithm to find relations where there are none, only because a few then overrepresented data points share this relation by random. With the shown approach, the learned distribution would take into account the high variance these features have and therefore will hopefully help the algorithm to not draw these false conclusions.\n", 1009 | "\n", 1010 | "Stay tuned for the next blogpost, where I will show the shown approach in exactly this use case." 
1011 | ], 1012 | "cell_type": "markdown", 1013 | "metadata": {} 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": null, 1018 | "metadata": {}, 1019 | "outputs": [], 1020 | "source": [] 1021 | } 1022 | ] 1023 | } -------------------------------------------------------------------------------- /Create_Autoencoder_Model_Basemodel_3Embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Variational Autoencoder on Tabular Data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I use the [wine dataset](https://archive.ics.uci.edu/ml/datasets/wine) to show how a Variational Autoencoder (VAE) built with PyTorch works on tabular data. I use the VAE to reduce the dimensionality of the dataset, in this case down to 3 variables (embeddings). I then plot the embeddings in a 3D graph to show how a VAE is similar to PCA but works in a non-linear way. 
" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# Imports" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "import torch.nn.functional as F\n", 33 | "from torch import nn, optim\n", 34 | "from torch.autograd import Variable\n", 35 | "\n", 36 | "import pandas as pd\n", 37 | "import numpy as np\n", 38 | "from sklearn import preprocessing" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/plain": [ 49 | "device(type='cpu')" 50 | ] 51 | }, 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 59 | "device" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Define Path to Dataset" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "DATA_PATH = 'Data/wine.csv'" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "# Define Functions" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "def load_data(path):\n", 92 | " # read in from csv\n", 93 | " df = pd.read_csv(path, sep=',')\n", 94 | " # replace nan with -99\n", 95 | " df = df.fillna(-99)\n", 96 | " df_base = df.iloc[:, 1:]\n", 97 | " df_wine = df.iloc[:,0].values\n", 98 | " x = df_base.values.reshape(-1, df_base.shape[1]).astype('float32')\n", 99 | " # stadardize values\n", 100 | " standardizer = preprocessing.StandardScaler()\n", 101 | " x_train = standardizer.fit_transform(x)\n", 102 | " x_train = 
torch.from_numpy(x_train).to(device)\n", 103 | " return x_train, standardizer, df_wine" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "# Build DataLoader" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "from torch.utils.data import Dataset, DataLoader\n", 120 | "class DataBuilder(Dataset):\n", 121 | " def __init__(self, path):\n", 122 | " self.x, self.standardizer, self.wine = load_data(DATA_PATH)\n", 123 | " self.len=self.x.shape[0]\n", 124 | " def __getitem__(self,index): \n", 125 | " return self.x[index]\n", 126 | " def __len__(self):\n", 127 | " return self.len" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "data_set=DataBuilder(DATA_PATH)\n", 137 | "trainloader=DataLoader(dataset=data_set,batch_size=1024)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 7, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "torch.Tensor" 149 | ] 150 | }, 151 | "execution_count": 7, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "type(trainloader.dataset.x)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "tensor([[ 1.5186, -0.5622, 0.2321, ..., 0.3622, 1.8479, 1.0130],\n", 169 | " [ 0.2463, -0.4994, -0.8280, ..., 0.4061, 1.1134, 0.9652],\n", 170 | " [ 0.1969, 0.0212, 1.1093, ..., 0.3183, 0.7886, 1.3951],\n", 171 | " ...,\n", 172 | " [ 0.3328, 1.7447, -0.3894, ..., -1.6121, -1.4854, 0.2806],\n", 173 | " [ 0.2092, 0.2277, 0.0127, ..., -1.5683, -1.4007, 0.2965],\n", 174 | " [ 1.3951, 1.5832, 1.3652, ..., -1.5244, -1.4289, -0.5952]])" 175 | ] 176 | }, 177 | "execution_count": 8, 178 
| "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "data_set.x" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "# Build Model and train it" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 9, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "class Autoencoder(nn.Module):\n", 200 | " def __init__(self,D_in,H=50,H2=12,latent_dim=3):\n", 201 | " \n", 202 | " #Encoder\n", 203 | " super(Autoencoder,self).__init__()\n", 204 | " self.linear1=nn.Linear(D_in,H)\n", 205 | " self.lin_bn1 = nn.BatchNorm1d(num_features=H)\n", 206 | " self.linear2=nn.Linear(H,H2)\n", 207 | " self.lin_bn2 = nn.BatchNorm1d(num_features=H2)\n", 208 | " self.linear3=nn.Linear(H2,H2)\n", 209 | " self.lin_bn3 = nn.BatchNorm1d(num_features=H2)\n", 210 | " \n", 211 | "# # Latent vectors mu and sigma\n", 212 | " self.fc1 = nn.Linear(H2, latent_dim)\n", 213 | "# self.bn1 = nn.BatchNorm1d(num_features=latent_dim)\n", 214 | " self.fc21 = nn.Linear(latent_dim, latent_dim)\n", 215 | " self.fc22 = nn.Linear(latent_dim, latent_dim)\n", 216 | "\n", 217 | "# # Sampling vector\n", 218 | " self.fc3 = nn.Linear(latent_dim, latent_dim)\n", 219 | "# self.fc_bn3 = nn.BatchNorm1d(latent_dim)\n", 220 | " self.fc4 = nn.Linear(latent_dim, H2)\n", 221 | "# self.fc_bn4 = nn.BatchNorm1d(H2)\n", 222 | " \n", 223 | "# # Decoder\n", 224 | " self.linear4=nn.Linear(H2,H2)\n", 225 | " self.lin_bn4 = nn.BatchNorm1d(num_features=H2)\n", 226 | " self.linear5=nn.Linear(H2,H)\n", 227 | " self.lin_bn5 = nn.BatchNorm1d(num_features=H)\n", 228 | " self.linear6=nn.Linear(H,D_in)\n", 229 | " self.lin_bn6 = nn.BatchNorm1d(num_features=D_in)\n", 230 | " \n", 231 | " self.relu = nn.ReLU()\n", 232 | " \n", 233 | " def encode(self, x):\n", 234 | " lin1 = self.relu(self.lin_bn1(self.linear1(x)))\n", 235 | " lin2 = self.relu(self.lin_bn2(self.linear2(lin1)))\n", 236 | " lin3 = 
self.relu(self.lin_bn3(self.linear3(lin2)))\n", 237 | "\n", 238 | " fc1 = F.relu(self.fc1(lin3))\n", 239 | "\n", 240 | " r1 = self.fc21(fc1)\n", 241 | " r2 = self.fc22(fc1)\n", 242 | " \n", 243 | " return r1, r2\n", 244 | " \n", 245 | " def reparameterize(self, mu, logvar):\n", 246 | " if self.training:\n", 247 | " std = logvar.mul(0.5).exp_()\n", 248 | " eps = Variable(std.data.new(std.size()).normal_())\n", 249 | " return eps.mul(std).add_(mu)\n", 250 | " else:\n", 251 | " return mu\n", 252 | " \n", 253 | " def decode(self, z):\n", 254 | " fc3 = self.relu(self.fc3(z))\n", 255 | " fc4 = self.relu(self.fc4(fc3))#.view(128, -1)\n", 256 | "\n", 257 | " lin4 = self.relu(self.lin_bn4(self.linear4(fc4)))\n", 258 | " lin5 = self.relu(self.lin_bn5(self.linear5(lin4)))\n", 259 | " return self.lin_bn6(self.linear6(lin5))\n", 260 | "\n", 261 | "\n", 262 | " \n", 263 | " def forward(self, x):\n", 264 | " mu, logvar = self.encode(x)\n", 265 | " z = self.reparameterize(mu, logvar)\n", 266 | " # self.decode(z) ist später recon_batch, mu ist mu und logvar ist logvar\n", 267 | " return self.decode(z), mu, logvar" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "class customLoss(nn.Module):\n", 277 | " def __init__(self):\n", 278 | " super(customLoss, self).__init__()\n", 279 | " self.mse_loss = nn.MSELoss(reduction=\"sum\")\n", 280 | " \n", 281 | " # x_recon ist der im forward im Model erstellte recon_batch, x ist der originale x Batch, mu ist mu und logvar ist logvar \n", 282 | " def forward(self, x_recon, x, mu, logvar):\n", 283 | " loss_MSE = self.mse_loss(x_recon, x)\n", 284 | " loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n", 285 | "\n", 286 | " return loss_MSE + loss_KLD" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 11, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "# takes in a module and applies the 
specified weight initialization\n", 296 | "def weights_init_uniform_rule(m):\n", 297 | " # only Linear layers are re-initialized; every other module is left untouched\n", 298 | " # weights ~ U(-1/sqrt(in_features), 1/sqrt(in_features)), bias = 0\n", 299 | " if isinstance(m, nn.Linear):\n", 300 | " # get the number of the inputs\n", 301 | " n = m.in_features\n", 302 | " y = 1.0/np.sqrt(n)\n", 303 | " m.weight.data.uniform_(-y, y)\n", 304 | " if m.bias is not None: m.bias.data.fill_(0)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 12, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "D_in = data_set.x.shape[1]\n", 314 | "H = 50\n", 315 | "H2 = 12\n", 316 | "torch.manual_seed(0)  # seed weight init and reparameterization noise for reproducible runs\n", 317 | "model = Autoencoder(D_in, H, H2).to(device)\n", 318 | "# model.apply(weights_init_uniform_rule)  # optional custom init, disabled by default\n", 319 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 13, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "loss_mse = customLoss()" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "# Train" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 14, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "epochs = 2000\n", 345 | "log_interval = 50\n", 346 | "val_losses = []\n", 347 | "train_losses = []" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 15, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "def train(epoch):\n", 357 | " model.train()\n", 358 | " train_loss = 0\n", 359 | " for batch_idx, data in enumerate(trainloader):\n", 360 | " data = data.to(device)\n", 361 | " optimizer.zero_grad()\n", 362 | " recon_batch, mu, logvar = model(data)\n", 363 | " loss = loss_mse(recon_batch, data, mu, logvar)\n", 364 | " loss.backward()\n", 365 | " train_loss += loss.item()\n", 366 | " 
optimizer.step()\n", 367 | "# if batch_idx % log_interval == 0:\n", 368 | "# 
print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 369 | "# epoch, batch_idx * len(data), len(trainloader.dataset),\n", 370 | "# 100. * batch_idx / len(trainloader),\n", 371 | "# loss.item() / len(data)))\n", 372 | " if epoch % 200 == 0: \n", 373 | " print('====> Epoch: {} Average loss: {:.4f}'.format(\n", 374 | " epoch, train_loss / len(trainloader.dataset)))\n", 375 | " train_losses.append(train_loss / len(trainloader.dataset))" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 16, 381 | "metadata": { 382 | "scrolled": true 383 | }, 384 | "outputs": [ 385 | { 386 | "name": "stdout", 387 | "output_type": "stream", 388 | "text": [ 389 | "====> Epoch: 200 Average loss: 11.2040\n", 390 | "====> Epoch: 400 Average loss: 9.4026\n", 391 | "====> Epoch: 600 Average loss: 8.1786\n", 392 | "====> Epoch: 800 Average loss: 7.7224\n", 393 | "====> Epoch: 1000 Average loss: 7.6587\n", 394 | "====> Epoch: 1200 Average loss: 7.4626\n", 395 | "====> Epoch: 1400 Average loss: 7.4643\n", 396 | "====> Epoch: 1600 Average loss: 7.3207\n", 397 | "====> Epoch: 1800 Average loss: 7.0685\n", 398 | "====> Epoch: 2000 Average loss: 7.2222\n" 399 | ] 400 | } 401 | ], 402 | "source": [ 403 | "for epoch in range(1, epochs + 1):\n", 404 | " train(epoch)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "# Evaluate" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 17, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "standardizer = trainloader.dataset.standardizer" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 18, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "# eval mode: BatchNorm uses its running statistics and reparameterize() returns mu\n", 430 | "model.eval()\n", 431 | "# no_grad(): use the trained weights as-is, without tracking gradients or updating them\n", 432 | "with torch.no_grad():\n", 433 | " for data in trainloader:\n", 434 | " data = 
data.to(device)\n", 435 | " recon_batch, mu, logvar = model(data)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 19, 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "text/plain": [ 446 | "array([1.34402313e+01, 1.96345377e+00, 2.62067842e+00, 1.87440109e+01,\n", 447 | " 1.07427719e+02, 2.63568044e+00, 2.69742918e+00, 3.28377843e-01,\n", 448 | " 1.67999256e+00, 4.58271646e+00, 1.10575414e+00, 3.02566600e+00,\n", 449 | " 1.01869727e+03], dtype=float32)" 450 | ] 451 | }, 452 | "execution_count": 19, 453 | "metadata": {}, 454 | "output_type": "execute_result" 455 | } 456 | ], 457 | "source": [ 458 | "standardizer.inverse_transform(recon_batch[65].cpu().numpy().reshape(1, -1))" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 20, 464 | "metadata": {}, 465 | "outputs": [ 466 | { 467 | "data": { 468 | "text/plain": [ 469 | "array([1.237e+01, 1.210e+00, 2.560e+00, 1.810e+01, 9.800e+01, 2.420e+00,\n", 470 | " 2.650e+00, 3.700e-01, 2.080e+00, 4.600e+00, 1.190e+00, 2.300e+00,\n", 471 | " 6.780e+02], dtype=float32)" 472 | ] 473 | }, 474 | "execution_count": 20, 475 | "metadata": {}, 476 | "output_type": "execute_result" 477 | } 478 | ], 479 | "source": [ 480 | "standardizer.inverse_transform(data[65].cpu().numpy().reshape(1, -1))" 481 | ] 482 | }, 483 | { 484 | "cell_type": "markdown", 485 | "metadata": {}, 486 | "source": [ 487 | "# Get Embeddings" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 21, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "mu_output = []\n", 497 | "logvar_output = []\n", 498 | "\n", 499 | "model.eval()  # BatchNorm uses running statistics; reparameterize() returns mu\n", 500 | "with torch.no_grad():\n", 501 | " for data in trainloader:\n", 502 | " data = data.to(device)\n", 503 | " recon_batch, mu, logvar = model(data)\n", 504 | " mu_output.append(mu)\n", 505 | " logvar_output.append(logvar)\n", 506 | "\n", 507 | "# concatenate once after the loop 
instead of re-concatenating on every batch\n", 508 | "mu_result = torch.cat(mu_output, dim=0)\n", 509 | "logvar_result = torch.cat(logvar_output, dim=0)\n",
510 | "\n", 511 | "# every sample should have received an embedding\n", 512 | "assert mu_result.shape[0] == len(data_set)" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 22, 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "data": { 522 | "text/plain": [ 523 | "torch.Size([178, 3])" 524 | ] 525 | }, 526 | "execution_count": 22, 527 | "metadata": {}, 528 | "output_type": "execute_result" 529 | } 530 | ], 531 | "source": [ 532 | "mu_result.shape" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 23, 538 | "metadata": {}, 539 | "outputs": [ 540 | { 541 | "data": { 542 | "text/plain": [ 543 | "tensor([[-0.0331, 1.1830, -0.0299],\n", 544 | " [-0.0395, 1.1291, -0.0380],\n", 545 | " [ 0.0049, 1.6027, 0.0214],\n", 546 | " [ 0.0227, 0.6376, 0.0077]])" 547 | ] 548 | }, 549 | "execution_count": 23, 550 | "metadata": {}, 551 | "output_type": "execute_result" 552 | } 553 | ], 554 | "source": [ 555 | "mu_result[1:5,:]" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "metadata": {}, 561 | "source": [ 562 | "# Plot Embeddings" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": 24, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "from mpl_toolkits import mplot3d\n", 572 | "\n", 573 | "%matplotlib inline\n", 574 | "import numpy as np\n", 575 | "import matplotlib.pyplot as plt" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 25, 581 | "metadata": {}, 582 | "outputs": [ 583 | { 584 | "data": { 585 | "image/png": "\n", 586 | "text/plain": [ 587 | "
" 588 | ] 589 | }, 590 | "metadata": {}, 591 | "output_type": "display_data" 592 | } 593 | ], 594 | "source": [ 595 | "ax = plt.axes(projection='3d')\n", 596 | "\n", 597 | "\n", 598 | "# Data for three-dimensional scattered points\n", 599 | "winetype = data_set.wine\n", 600 | "zdata = mu_result[:,0].numpy()\n", 601 | "xdata = mu_result[:,1].numpy()\n", 602 | "ydata = mu_result[:,2].numpy()\n", 603 | "ax.scatter3D(xdata, ydata, zdata, c=winetype);" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [] 612 | } 613 | ], 614 | "metadata": { 615 | "kernelspec": { 616 | "display_name": "Python 3", 617 | "language": "python", 618 | "name": "python3" 619 | }, 620 | "language_info": { 621 | "codemirror_mode": { 622 | "name": "ipython", 623 | "version": 3 624 | }, 625 | "file_extension": ".py", 626 | "mimetype": "text/x-python", 627 | "name": "python", 628 | "nbconvert_exporter": "python", 629 | "pygments_lexer": "ipython3", 630 | "version": "3.6.5" 631 | } 632 | }, 633 | "nbformat": 4, 634 | "nbformat_minor": 2 635 | } 636 | --------------------------------------------------------------------------------