├── .gitignore ├── GAIN-pretty-tqdm.ipynb ├── GAIN.ipynb ├── LICENSE ├── Letter.csv ├── README.md └── Spam.csv /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | -------------------------------------------------------------------------------- /GAIN-pretty-tqdm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Creator: Dhanajit Brahma\n", 8 | "\n", 9 | "Adapted from the original TensorFlow implementation available here: https://github.com/jsyoon0823/GAIN\n", 10 | "\n", 11 | "Generative Adversarial Imputation Networks (GAIN) implementation on the Letter and Spam datasets\n", 12 | "\n", 13 | "Reference: J. Yoon, J. Jordon, M. van der Schaar, \"GAIN: Missing Data Imputation using Generative Adversarial Nets,\" ICML, 2018." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "#%% Packages\n", 23 | "import torch\n", 24 | "import numpy as np\n", 25 | "# from tqdm import tqdm\n", 26 | "from tqdm.notebook import tqdm_notebook as tqdm\n", 27 | "import torch.nn.functional as F" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "dataset_file = 'Spam.csv' # 'Letter.csv' for the Letter dataset and 'Spam.csv' for the Spam dataset\n", 37 | "use_gpu = False # set to True to use the GPU, False to use the CPU\n", 38 | "\n", 39 | "if use_gpu:\n", 40 | " torch.cuda.set_device(0)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "#%% System Parameters\n", 50 | "# 1. Mini batch size\n", 51 | "mb_size = 128\n", 52 | "# 2. Missing rate\n", 53 | "p_miss = 0.2\n", 54 | "# 3. Hint rate\n", 55 | "p_hint = 0.9\n", 56 | "# 4. Loss hyperparameter\n", 57 | "alpha = 10\n", 58 | "# 5. 
Train Rate\n", 59 | "train_rate = 0.8\n", 60 | "\n", 61 | "#%% Data\n", 62 | "\n", 63 | "# Data generation\n", 64 | "Data = np.loadtxt(dataset_file, delimiter=\",\",skiprows=1)\n", 65 | "\n", 66 | "# Parameters\n", 67 | "No = len(Data)\n", 68 | "Dim = len(Data[0,:])\n", 69 | "\n", 70 | "# Hidden state dimensions\n", 71 | "H_Dim1 = Dim\n", 72 | "H_Dim2 = Dim\n", 73 | "\n", 74 | "# Normalization (0 to 1)\n", 75 | "Min_Val = np.zeros(Dim)\n", 76 | "Max_Val = np.zeros(Dim)\n", 77 | "\n", 78 | "for i in range(Dim):\n", 79 | " Min_Val[i] = np.min(Data[:,i])\n", 80 | " Data[:,i] = Data[:,i] - np.min(Data[:,i])\n", 81 | " Max_Val[i] = np.max(Data[:,i])\n", 82 | " Data[:,i] = Data[:,i] / (np.max(Data[:,i]) + 1e-6) \n", 83 | "\n", 84 | "#%% Introduce missingness (mask entries: 1 = observed, 0 = missing)\n", 85 | "p_miss_vec = p_miss * np.ones((Dim,1)) \n", 86 | " \n", 87 | "Missing = np.zeros((No,Dim))\n", 88 | "\n", 89 | "for i in range(Dim):\n", 90 | " A = np.random.uniform(0., 1., size = [len(Data),])\n", 91 | " B = A > p_miss_vec[i]\n", 92 | " Missing[:,i] = 1.*B\n", 93 | "\n", 94 | " \n", 95 | "#%% Train Test Division \n", 96 | " \n", 97 | "idx = np.random.permutation(No)\n", 98 | "\n", 99 | "Train_No = int(No * train_rate)\n", 100 | "Test_No = No - Train_No\n", 101 | " \n", 102 | "# Train / Test Features\n", 103 | "trainX = Data[idx[:Train_No],:]\n", 104 | "testX = Data[idx[Train_No:],:]\n", 105 | "\n", 106 | "# Train / Test Missing Indicators\n", 107 | "trainM = Missing[idx[:Train_No],:]\n", 108 | "testM = Missing[idx[Train_No:],:]\n", 109 | "\n", 110 | "#%% Necessary Functions\n", 111 | "\n", 112 | "# 1. Xavier Initialization Definition\n", 113 | "# def xavier_init(size):\n", 114 | "# in_dim = size[0]\n", 115 | "# xavier_stddev = 1. / tf.sqrt(in_dim / 2.)\n", 116 | "# return tf.random_normal(shape = size, stddev = xavier_stddev)\n", 117 | "def xavier_init(size):\n", 118 | " in_dim = size[0]\n", 119 | " xavier_stddev = 1. / np.sqrt(in_dim / 2.)\n", 120 | " return np.random.normal(size = size, scale = xavier_stddev)\n", 121 | " \n", 122 | "# Hint vector generation (each entry is 1 with probability 1-p)\n", 123 | "def sample_M(m, n, p):\n", 124 | " A = np.random.uniform(0., 1., size = [m, n])\n", 125 | " B = A > p\n", 126 | " C = 1.*B\n", 127 | " return C\n", 128 | " " 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### GAIN Architecture \n", 136 | "GAIN consists of 3 components:\n", 137 | "- Generator\n", 138 | "- Discriminator\n", 139 | "- Hint Mechanism" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 4, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "#%% 1. 
Discriminator\n", 149 | "if use_gpu is True:\n", 150 | " D_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True, device=\"cuda\") # Data + Hint as inputs\n", 151 | " D_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True, device=\"cuda\")\n", 152 | "\n", 153 | " D_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True, device=\"cuda\")\n", 154 | " D_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True, device=\"cuda\")\n", 155 | "\n", 156 | " D_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True, device=\"cuda\")\n", 157 | " D_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True, device=\"cuda\") # Output is multi-variate\n", 158 | "else:\n", 159 | " D_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True) # Data + Hint as inputs\n", 160 | " D_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True)\n", 161 | "\n", 162 | " D_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True)\n", 163 | " D_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True)\n", 164 | "\n", 165 | " D_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True)\n", 166 | " D_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True) # Output is multi-variate\n", 167 | "\n", 168 | "theta_D = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]\n", 169 | "\n", 170 | "#%% 2. Generator\n", 171 | "if use_gpu is True:\n", 172 | " G_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True, device=\"cuda\") # Data + Mask as inputs (Random Noises are in Missing Components)\n", 173 | " G_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True, device=\"cuda\")\n", 174 | "\n", 175 | " G_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True, device=\"cuda\")\n", 176 | " G_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True, device=\"cuda\")\n", 177 | "\n", 178 | " G_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True, device=\"cuda\")\n", 179 | " G_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True, device=\"cuda\")\n", 180 | "else:\n", 181 | " G_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True) # Data + Mask as inputs (Random Noises are in Missing Components)\n", 182 | " G_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True)\n", 183 | "\n", 184 | " G_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True)\n", 185 | " G_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True)\n", 186 | "\n", 187 | " G_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True)\n", 188 | " G_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True)\n", 189 | "\n", 190 | "theta_G = [G_W1, G_W2, G_W3, G_b1, G_b2, G_b3]" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "## GAIN Functions" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 5, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "#%% 1. Generator\n", 207 | "def generator(new_x,m):\n", 208 | " inputs = torch.cat(dim = 1, tensors = [new_x,m]) # Mask + Data Concatenate\n", 209 | " G_h1 = F.relu(torch.matmul(inputs, G_W1) + G_b1)\n", 210 | " G_h2 = F.relu(torch.matmul(G_h1, G_W2) + G_b2) \n", 211 | " G_prob = torch.sigmoid(torch.matmul(G_h2, G_W3) + G_b3) # [0,1] normalized Output\n", 212 | " \n", 213 | " return G_prob\n", 214 | "\n", 215 | "#%% 2. 
Discriminator\n", 216 | "def discriminator(new_x, h):\n", 217 | " inputs = torch.cat(dim = 1, tensors = [new_x,h]) # Hint + Data Concatenate\n", 218 | " D_h1 = F.relu(torch.matmul(inputs, D_W1) + D_b1) \n", 219 | " D_h2 = F.relu(torch.matmul(D_h1, D_W2) + D_b2)\n", 220 | " D_logit = torch.matmul(D_h2, D_W3) + D_b3\n", 221 | " D_prob = torch.sigmoid(D_logit) # [0,1] Probability Output\n", 222 | " \n", 223 | " return D_prob\n", 224 | "\n", 225 | "#%% 3. Other functions\n", 226 | "# Random sample generator for Z\n", 227 | "def sample_Z(m, n):\n", 228 | " return np.random.uniform(0., 0.01, size = [m, n]) \n", 229 | "\n", 230 | "# Mini-batch generation\n", 231 | "def sample_idx(m, n):\n", 232 | " A = np.random.permutation(m)\n", 233 | " idx = A[:n]\n", 234 | " return idx" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "## GAIN Losses" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 6, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "def discriminator_loss(M, New_X, H):\n", 251 | " # Generator\n", 252 | " G_sample = generator(New_X,M)\n", 253 | " # Combine with original data\n", 254 | " Hat_New_X = New_X * M + G_sample * (1-M)\n", 255 | "\n", 256 | " # Discriminator\n", 257 | " D_prob = discriminator(Hat_New_X, H)\n", 258 | "\n", 259 | " #%% Loss\n", 260 | " D_loss = -torch.mean(M * torch.log(D_prob + 1e-8) + (1-M) * torch.log(1. - D_prob + 1e-8))\n", 261 | " return D_loss\n", 262 | "\n", 263 | "def generator_loss(X, M, New_X, H):\n", 264 | " #%% Structure\n", 265 | " # Generator\n", 266 | " G_sample = generator(New_X,M)\n", 267 | "\n", 268 | " # Combine with original data\n", 269 | " Hat_New_X = New_X * M + G_sample * (1-M)\n", 270 | "\n", 271 | " # Discriminator\n", 272 | " D_prob = discriminator(Hat_New_X, H)\n", 273 | "\n", 274 | " #%% Loss\n", 275 | " G_loss1 = -torch.mean((1-M) * torch.log(D_prob + 1e-8))\n", 276 | " MSE_train_loss = torch.mean((M * New_X - M * G_sample)**2) / torch.mean(M)\n", 277 | "\n", 278 | " G_loss = G_loss1 + alpha * MSE_train_loss \n", 279 | "\n", 280 | " #%% MSE Performance metric\n", 281 | " MSE_test_loss = torch.mean(((1-M) * X - (1-M)*G_sample)**2) / torch.mean(1-M)\n", 282 | " return G_loss, MSE_train_loss, MSE_test_loss\n", 283 | " \n", 284 | "def test_loss(X, M, New_X):\n", 285 | " #%% Structure\n", 286 | " # Generator\n", 287 | " G_sample = generator(New_X,M)\n", 288 | "\n", 289 | " #%% MSE Performance metric\n", 290 | " MSE_test_loss = torch.mean(((1-M) * X - (1-M)*G_sample)**2) / torch.mean(1-M)\n", 291 | " return MSE_test_loss, G_sample" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "## Optimizers" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 7, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "optimizer_D = torch.optim.Adam(params=theta_D)\n", 308 | "optimizer_G = torch.optim.Adam(params=theta_G)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "## Training" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 8, 321 | "metadata": { 322 | "scrolled": false 323 | }, 324 | "outputs": [ 325 | { 326 | "data": { 327 | "application/vnd.jupyter.widget-view+json": { 328 | "model_id": "1fd9356d77d842f0afd2cd32c64e5e7f", 329 | "version_major": 2, 330 | "version_minor": 0 331 | }, 332 | "text/plain": [ 333 | "HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))" 334 
| ] 335 | }, 336 | "metadata": {}, 337 | "output_type": "display_data" 338 | }, 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "Iter: 0\tTrain_loss: 0.5377\tTest_loss: 0.534\n", 344 | "Iter: 100\tTrain_loss: 0.05745\tTest_loss: 0.06301\n", 345 | "Iter: 200\tTrain_loss: 0.05156\tTest_loss: 0.05043\n", 346 | "Iter: 300\tTrain_loss: 0.06761\tTest_loss: 0.07604\n", 347 | "Iter: 400\tTrain_loss: 0.05879\tTest_loss: 0.05206\n", 348 | "Iter: 500\tTrain_loss: 0.05116\tTest_loss: 0.06352\n", 349 | "Iter: 600\tTrain_loss: 0.05932\tTest_loss: 0.05469\n", 350 | "Iter: 700\tTrain_loss: 0.04748\tTest_loss: 0.05652\n", 351 | "Iter: 800\tTrain_loss: 0.05378\tTest_loss: 0.06295\n", 352 | "Iter: 900\tTrain_loss: 0.04738\tTest_loss: 0.05832\n", 353 | "Iter: 1000\tTrain_loss: 0.05263\tTest_loss: 0.05414\n", 354 | "Iter: 1100\tTrain_loss: 0.04671\tTest_loss: 0.06805\n", 355 | "Iter: 1200\tTrain_loss: 0.04553\tTest_loss: 0.04829\n", 356 | "Iter: 1300\tTrain_loss: 0.04856\tTest_loss: 0.04031\n", 357 | "Iter: 1400\tTrain_loss: 0.04177\tTest_loss: 0.04624\n", 358 | "Iter: 1500\tTrain_loss: 0.03908\tTest_loss: 0.05898\n", 359 | "Iter: 1600\tTrain_loss: 0.03954\tTest_loss: 0.05035\n", 360 | "Iter: 1700\tTrain_loss: 0.04188\tTest_loss: 0.06045\n", 361 | "Iter: 1800\tTrain_loss: 0.03277\tTest_loss: 0.05487\n", 362 | "Iter: 1900\tTrain_loss: 0.0325\tTest_loss: 0.0554\n", 363 | "Iter: 2000\tTrain_loss: 0.03252\tTest_loss: 0.0577\n", 364 | "Iter: 2100\tTrain_loss: 0.03432\tTest_loss: 0.06288\n", 365 | "Iter: 2200\tTrain_loss: 0.04173\tTest_loss: 0.05097\n", 366 | "Iter: 2300\tTrain_loss: 0.03515\tTest_loss: 0.0504\n", 367 | "Iter: 2400\tTrain_loss: 0.03183\tTest_loss: 0.05901\n", 368 | "Iter: 2500\tTrain_loss: 0.03237\tTest_loss: 0.04286\n", 369 | "Iter: 2600\tTrain_loss: 0.03039\tTest_loss: 0.06045\n", 370 | "Iter: 2700\tTrain_loss: 0.03477\tTest_loss: 0.04825\n", 371 | "Iter: 2800\tTrain_loss: 0.03078\tTest_loss: 0.05256\n", 372 | "Iter: 2900\tTrain_loss: 0.03695\tTest_loss: 0.0402\n", 373 | "Iter: 3000\tTrain_loss: 0.02848\tTest_loss: 0.04715\n", 374 | "Iter: 3100\tTrain_loss: 0.02706\tTest_loss: 0.04177\n", 375 | "Iter: 3200\tTrain_loss: 0.0361\tTest_loss: 0.05011\n", 376 | "Iter: 3300\tTrain_loss: 0.03275\tTest_loss: 0.05108\n", 377 | "Iter: 3400\tTrain_loss: 0.02734\tTest_loss: 0.05991\n", 378 | "Iter: 3500\tTrain_loss: 0.02801\tTest_loss: 0.03673\n", 379 | "Iter: 3600\tTrain_loss: 0.03044\tTest_loss: 0.05335\n", 380 | "Iter: 3700\tTrain_loss: 0.0308\tTest_loss: 0.03785\n", 381 | "Iter: 3800\tTrain_loss: 0.0245\tTest_loss: 0.05691\n", 382 | "Iter: 3900\tTrain_loss: 0.02657\tTest_loss: 0.05048\n", 383 | "Iter: 4000\tTrain_loss: 0.02559\tTest_loss: 0.05374\n", 384 | "Iter: 4100\tTrain_loss: 0.0234\tTest_loss: 0.05291\n", 385 | "Iter: 4200\tTrain_loss: 0.02384\tTest_loss: 0.04727\n", 386 | "Iter: 4300\tTrain_loss: 0.02709\tTest_loss: 0.05202\n", 387 | "Iter: 4400\tTrain_loss: 0.03058\tTest_loss: 0.05541\n", 388 | "Iter: 4500\tTrain_loss: 0.02499\tTest_loss: 0.05595\n", 389 | "Iter: 4600\tTrain_loss: 0.02364\tTest_loss: 0.05992\n", 390 | "Iter: 4700\tTrain_loss: 0.02053\tTest_loss: 0.05415\n", 391 | "Iter: 4800\tTrain_loss: 0.02321\tTest_loss: 0.04843\n", 392 | "Iter: 4900\tTrain_loss: 0.02571\tTest_loss: 0.05179\n", 393 | "\n" 394 | ] 395 | } 396 | ], 397 | "source": [ 398 | "#%% Start Iterations\n", 399 | "for it in tqdm(range(5000)): \n", 400 | " \n", 401 | " #%% Inputs\n", 402 | " mb_idx = sample_idx(Train_No, mb_size)\n", 403 | " X_mb = trainX[mb_idx,:] \n", 404 | " \n", 405 | 
" Z_mb = sample_Z(mb_size, Dim) \n", 406 | " M_mb = trainM[mb_idx,:] \n", 407 | " H_mb1 = sample_M(mb_size, Dim, 1-p_hint)\n", 408 | " H_mb = M_mb * H_mb1\n", 409 | " \n", 410 | " New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb # Missing Data Introduce\n", 411 | " \n", 412 | " if use_gpu is True:\n", 413 | " X_mb = torch.tensor(X_mb, device=\"cuda\")\n", 414 | " M_mb = torch.tensor(M_mb, device=\"cuda\")\n", 415 | " H_mb = torch.tensor(H_mb, device=\"cuda\")\n", 416 | " New_X_mb = torch.tensor(New_X_mb, device=\"cuda\")\n", 417 | " else:\n", 418 | " X_mb = torch.tensor(X_mb)\n", 419 | " M_mb = torch.tensor(M_mb)\n", 420 | " H_mb = torch.tensor(H_mb)\n", 421 | " New_X_mb = torch.tensor(New_X_mb)\n", 422 | " \n", 423 | " optimizer_D.zero_grad()\n", 424 | " D_loss_curr = discriminator_loss(M=M_mb, New_X=New_X_mb, H=H_mb)\n", 425 | " D_loss_curr.backward()\n", 426 | " optimizer_D.step()\n", 427 | " \n", 428 | " optimizer_G.zero_grad()\n", 429 | " G_loss_curr, MSE_train_loss_curr, MSE_test_loss_curr = generator_loss(X=X_mb, M=M_mb, New_X=New_X_mb, H=H_mb)\n", 430 | " G_loss_curr.backward()\n", 431 | " optimizer_G.step() \n", 432 | " \n", 433 | " #%% Intermediate Losses\n", 434 | " if it % 100 == 0:\n", 435 | " print('Iter: {}'.format(it),end='\\t')\n", 436 | " print('Train_loss: {:.4}'.format(np.sqrt(MSE_train_loss_curr.item())),end='\\t')\n", 437 | " print('Test_loss: {:.4}'.format(np.sqrt(MSE_test_loss_curr.item())))" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "## Testing" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 9, 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "Final Test RMSE: 0.058903747078809826\n" 457 | ] 458 | } 459 | ], 460 | "source": [ 461 | "Z_mb = sample_Z(Test_No, Dim) \n", 462 | "M_mb = testM\n", 463 | "X_mb = testX\n", 464 | " \n", 465 | "New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb # Missing Data Introduce\n", 466 | "\n", 467 | "if use_gpu is True:\n", 468 | " X_mb = torch.tensor(X_mb, device='cuda')\n", 469 | " M_mb = torch.tensor(M_mb, device='cuda')\n", 470 | " New_X_mb = torch.tensor(New_X_mb, device='cuda')\n", 471 | "else:\n", 472 | " X_mb = torch.tensor(X_mb)\n", 473 | " M_mb = torch.tensor(M_mb)\n", 474 | " New_X_mb = torch.tensor(New_X_mb)\n", 475 | " \n", 476 | "MSE_final, Sample = test_loss(X=X_mb, M=M_mb, New_X=New_X_mb)\n", 477 | " \n", 478 | "print('Final Test RMSE: ' + str(np.sqrt(MSE_final.item())))" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 10, 484 | "metadata": {}, 485 | "outputs": [ 486 | { 487 | "name": "stdout", 488 | "output_type": "stream", 489 | "text": [ 490 | "Imputed test data:\n", 491 | "[[0. 0.00617567 0.04509803 ... 0.0011094 0.00220264 0.0302399 ]\n", 492 | " [0. 0. 0.03044493 ... 0.00039673 0.00254678 0.00138889]\n", 493 | " [0.01101321 0.0210084 0.07843136 ... 0.0025525 0.00951141 0.03705808]\n", 494 | " ...\n", 495 | " [0.03524228 0. 0.02368228 ... 0.00707947 0.00180216 0.01508838]\n", 496 | " [0. 0.01680672 0. ... 0.0080581 0.02342811 0.07039141]\n", 497 | " [0. 0. 0. ... 
0.0005547 0.00151935 0.00176768]]\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "imputed_data = M_mb * X_mb + (1-M_mb) * Sample\n", 503 | "print(\"Imputed test data:\")\n", 504 | "np.set_printoptions(formatter={'float': lambda x: \"{0:0.8f}\".format(x)})\n", 505 | "\n", 506 | "if use_gpu is True:\n", 507 | " print(imputed_data.cpu().detach().numpy())\n", 508 | "else:\n", 509 | " print(imputed_data.detach().numpy())" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [] 518 | } 519 | ], 520 | "metadata": { 521 | "kernelspec": { 522 | "display_name": "Python [conda env:pt]", 523 | "language": "python", 524 | "name": "conda-env-pt-py" 525 | }, 526 | "language_info": { 527 | "codemirror_mode": { 528 | "name": "ipython", 529 | "version": 3 530 | }, 531 | "file_extension": ".py", 532 | "mimetype": "text/x-python", 533 | "name": "python", 534 | "nbconvert_exporter": "python", 535 | "pygments_lexer": "ipython3", 536 | "version": "3.6.10" 537 | } 538 | }, 539 | "nbformat": 4, 540 | "nbformat_minor": 2 541 | } 542 | -------------------------------------------------------------------------------- /GAIN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Creator: Dhanajit Brahma\n", 8 | "\n", 9 | "Adapted from the original TensorFlow implementation available here: https://github.com/jsyoon0823/GAIN\n", 10 | "\n", 11 | "Generative Adversarial Imputation Networks (GAIN) implementation on the Letter and Spam datasets\n", 12 | "\n", 13 | "Reference: J. Yoon, J. Jordon, M. van der Schaar, \"GAIN: Missing Data Imputation using Generative Adversarial Nets,\" ICML, 2018." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "#%% Packages\n", 23 | "import torch\n", 24 | "import numpy as np\n", 25 | "from tqdm import tqdm\n", 26 | "import torch.nn.functional as F" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "dataset_file = 'Spam.csv' # 'Letter.csv' for the Letter dataset and 'Spam.csv' for the Spam dataset\n", 36 | "use_gpu = False # set to True to use the GPU, False to use the CPU\n", 37 | "if use_gpu:\n", 38 | " torch.cuda.set_device(0)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#%% System Parameters\n", 48 | "# 1. Mini batch size\n", 49 | "mb_size = 128\n", 50 | "# 2. Missing rate\n", 51 | "p_miss = 0.2\n", 52 | "# 3. Hint rate\n", 53 | "p_hint = 0.9\n", 54 | "# 4. Loss hyperparameter\n", 55 | "alpha = 10\n", 56 | "# 5. 
Train Rate\n", 57 | "train_rate = 0.8\n", 58 | "\n", 59 | "#%% Data\n", 60 | "\n", 61 | "# Data generation\n", 62 | "Data = np.loadtxt(dataset_file, delimiter=\",\",skiprows=1)\n", 63 | "\n", 64 | "# Parameters\n", 65 | "No = len(Data)\n", 66 | "Dim = len(Data[0,:])\n", 67 | "\n", 68 | "# Hidden state dimensions\n", 69 | "H_Dim1 = Dim\n", 70 | "H_Dim2 = Dim\n", 71 | "\n", 72 | "# Normalization (0 to 1)\n", 73 | "Min_Val = np.zeros(Dim)\n", 74 | "Max_Val = np.zeros(Dim)\n", 75 | "\n", 76 | "for i in range(Dim):\n", 77 | " Min_Val[i] = np.min(Data[:,i])\n", 78 | " Data[:,i] = Data[:,i] - np.min(Data[:,i])\n", 79 | " Max_Val[i] = np.max(Data[:,i])\n", 80 | " Data[:,i] = Data[:,i] / (np.max(Data[:,i]) + 1e-6) \n", 81 | "\n", 82 | "#%% Introduce missingness (mask entries: 1 = observed, 0 = missing)\n", 83 | "p_miss_vec = p_miss * np.ones((Dim,1)) \n", 84 | " \n", 85 | "Missing = np.zeros((No,Dim))\n", 86 | "\n", 87 | "for i in range(Dim):\n", 88 | " A = np.random.uniform(0., 1., size = [len(Data),])\n", 89 | " B = A > p_miss_vec[i]\n", 90 | " Missing[:,i] = 1.*B\n", 91 | "\n", 92 | " \n", 93 | "#%% Train Test Division \n", 94 | " \n", 95 | "idx = np.random.permutation(No)\n", 96 | "\n", 97 | "Train_No = int(No * train_rate)\n", 98 | "Test_No = No - Train_No\n", 99 | " \n", 100 | "# Train / Test Features\n", 101 | "trainX = Data[idx[:Train_No],:]\n", 102 | "testX = Data[idx[Train_No:],:]\n", 103 | "\n", 104 | "# Train / Test Missing Indicators\n", 105 | "trainM = Missing[idx[:Train_No],:]\n", 106 | "testM = Missing[idx[Train_No:],:]\n", 107 | "\n", 108 | "#%% Necessary Functions\n", 109 | "\n", 110 | "# 1. Xavier Initialization Definition\n", 111 | "# def xavier_init(size):\n", 112 | "# in_dim = size[0]\n", 113 | "# xavier_stddev = 1. / tf.sqrt(in_dim / 2.)\n", 114 | "# return tf.random_normal(shape = size, stddev = xavier_stddev)\n", 115 | "def xavier_init(size):\n", 116 | " in_dim = size[0]\n", 117 | " xavier_stddev = 1. / np.sqrt(in_dim / 2.)\n", 118 | " return np.random.normal(size = size, scale = xavier_stddev)\n", 119 | " \n", 120 | "# Hint vector generation (each entry is 1 with probability 1-p)\n", 121 | "def sample_M(m, n, p):\n", 122 | " A = np.random.uniform(0., 1., size = [m, n])\n", 123 | " B = A > p\n", 124 | " C = 1.*B\n", 125 | " return C\n", 126 | " " 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "### GAIN Architecture \n", 134 | "GAIN consists of 3 components:\n", 135 | "- Generator\n", 136 | "- Discriminator\n", 137 | "- Hint Mechanism" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 4, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "#%% 1. 
Discriminator\n", 147 | "if use_gpu is True:\n", 148 | " D_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True, device=\"cuda\") # Data + Hint as inputs\n", 149 | " D_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True, device=\"cuda\")\n", 150 | "\n", 151 | " D_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True, device=\"cuda\")\n", 152 | " D_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True, device=\"cuda\")\n", 153 | "\n", 154 | " D_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True, device=\"cuda\")\n", 155 | " D_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True, device=\"cuda\") # Output is multi-variate\n", 156 | "else:\n", 157 | " D_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True) # Data + Hint as inputs\n", 158 | " D_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True)\n", 159 | "\n", 160 | " D_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True)\n", 161 | " D_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True)\n", 162 | "\n", 163 | " D_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True)\n", 164 | " D_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True) # Output is multi-variate\n", 165 | "\n", 166 | "theta_D = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]\n", 167 | "\n", 168 | "#%% 2. Generator\n", 169 | "if use_gpu is True:\n", 170 | " G_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True, device=\"cuda\") # Data + Mask as inputs (Random Noises are in Missing Components)\n", 171 | " G_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True, device=\"cuda\")\n", 172 | "\n", 173 | " G_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True, device=\"cuda\")\n", 174 | " G_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True, device=\"cuda\")\n", 175 | "\n", 176 | " G_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True, device=\"cuda\")\n", 177 | " G_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True, device=\"cuda\")\n", 178 | "else:\n", 179 | " G_W1 = torch.tensor(xavier_init([Dim*2, H_Dim1]),requires_grad=True) # Data + Mask as inputs (Random Noises are in Missing Components)\n", 180 | " G_b1 = torch.tensor(np.zeros(shape = [H_Dim1]),requires_grad=True)\n", 181 | "\n", 182 | " G_W2 = torch.tensor(xavier_init([H_Dim1, H_Dim2]),requires_grad=True)\n", 183 | " G_b2 = torch.tensor(np.zeros(shape = [H_Dim2]),requires_grad=True)\n", 184 | "\n", 185 | " G_W3 = torch.tensor(xavier_init([H_Dim2, Dim]),requires_grad=True)\n", 186 | " G_b3 = torch.tensor(np.zeros(shape = [Dim]),requires_grad=True)\n", 187 | "\n", 188 | "theta_G = [G_W1, G_W2, G_W3, G_b1, G_b2, G_b3]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## GAIN Functions" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 5, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "#%% 1. Generator\n", 205 | "def generator(new_x,m):\n", 206 | " inputs = torch.cat(dim = 1, tensors = [new_x,m]) # Mask + Data Concatenate\n", 207 | " G_h1 = F.relu(torch.matmul(inputs, G_W1) + G_b1)\n", 208 | " G_h2 = F.relu(torch.matmul(G_h1, G_W2) + G_b2) \n", 209 | " G_prob = torch.sigmoid(torch.matmul(G_h2, G_W3) + G_b3) # [0,1] normalized Output\n", 210 | " \n", 211 | " return G_prob\n", 212 | "\n", 213 | "#%% 2. 
Discriminator\n", 214 | "def discriminator(new_x, h):\n", 215 | " inputs = torch.cat(dim = 1, tensors = [new_x,h]) # Hint + Data Concatenate\n", 216 | " D_h1 = F.relu(torch.matmul(inputs, D_W1) + D_b1) \n", 217 | " D_h2 = F.relu(torch.matmul(D_h1, D_W2) + D_b2)\n", 218 | " D_logit = torch.matmul(D_h2, D_W3) + D_b3\n", 219 | " D_prob = torch.sigmoid(D_logit) # [0,1] Probability Output\n", 220 | " \n", 221 | " return D_prob\n", 222 | "\n", 223 | "#%% 3. Other functions\n", 224 | "# Random sample generator for Z\n", 225 | "def sample_Z(m, n):\n", 226 | " return np.random.uniform(0., 0.01, size = [m, n]) \n", 227 | "\n", 228 | "# Mini-batch generation\n", 229 | "def sample_idx(m, n):\n", 230 | " A = np.random.permutation(m)\n", 231 | " idx = A[:n]\n", 232 | " return idx" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "## GAIN Losses" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 6, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "def discriminator_loss(M, New_X, H):\n", 249 | " # Generator\n", 250 | " G_sample = generator(New_X,M)\n", 251 | " # Combine with original data\n", 252 | " Hat_New_X = New_X * M + G_sample * (1-M)\n", 253 | "\n", 254 | " # Discriminator\n", 255 | " D_prob = discriminator(Hat_New_X, H)\n", 256 | "\n", 257 | " #%% Loss\n", 258 | " D_loss = -torch.mean(M * torch.log(D_prob + 1e-8) + (1-M) * torch.log(1. - D_prob + 1e-8))\n", 259 | " return D_loss\n", 260 | "\n", 261 | "def generator_loss(X, M, New_X, H):\n", 262 | " #%% Structure\n", 263 | " # Generator\n", 264 | " G_sample = generator(New_X,M)\n", 265 | "\n", 266 | " # Combine with original data\n", 267 | " Hat_New_X = New_X * M + G_sample * (1-M)\n", 268 | "\n", 269 | " # Discriminator\n", 270 | " D_prob = discriminator(Hat_New_X, H)\n", 271 | "\n", 272 | " #%% Loss\n", 273 | " G_loss1 = -torch.mean((1-M) * torch.log(D_prob + 1e-8))\n", 274 | " MSE_train_loss = torch.mean((M * New_X - M * G_sample)**2) / torch.mean(M)\n", 275 | "\n", 276 | " G_loss = G_loss1 + alpha * MSE_train_loss \n", 277 | "\n", 278 | " #%% MSE Performance metric\n", 279 | " MSE_test_loss = torch.mean(((1-M) * X - (1-M)*G_sample)**2) / torch.mean(1-M)\n", 280 | " return G_loss, MSE_train_loss, MSE_test_loss\n", 281 | " \n", 282 | "def test_loss(X, M, New_X):\n", 283 | " #%% Structure\n", 284 | " # Generator\n", 285 | " G_sample = generator(New_X,M)\n", 286 | "\n", 287 | " #%% MSE Performance metric\n", 288 | " MSE_test_loss = torch.mean(((1-M) * X - (1-M)*G_sample)**2) / torch.mean(1-M)\n", 289 | " return MSE_test_loss, G_sample" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "## Optimizers" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 7, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "optimizer_D = torch.optim.Adam(params=theta_D)\n", 306 | "optimizer_G = torch.optim.Adam(params=theta_G)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "## Training" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "scrolled": false 321 | }, 322 | "outputs": [ 323 | { 324 | "name": "stderr", 325 | "output_type": "stream", 326 | "text": [ 327 | " 0%| | 17/5000 [00:00<01:11, 69.32it/s]" 328 | ] 329 | }, 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "Iter: 0\n", 335 | "Train_loss: 0.5529\n", 336 | "Test_loss: 
0.5574\n", 337 | "\n" 338 | ] 339 | }, 340 | { 341 | "name": "stderr", 342 | "output_type": "stream", 343 | "text": [ 344 | " 2%|▏ | 112/5000 [00:01<00:46, 104.41it/s]" 345 | ] 346 | }, 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "Iter: 100\n", 352 | "Train_loss: 0.07032\n", 353 | "Test_loss: 0.05963\n", 354 | "\n" 355 | ] 356 | }, 357 | { 358 | "name": "stderr", 359 | "output_type": "stream", 360 | "text": [ 361 | " 4%|▍ | 211/5000 [00:02<00:45, 105.64it/s]" 362 | ] 363 | }, 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "Iter: 200\n", 369 | "Train_loss: 0.06312\n", 370 | "Test_loss: 0.06026\n", 371 | "\n" 372 | ] 373 | }, 374 | { 375 | "name": "stderr", 376 | "output_type": "stream", 377 | "text": [ 378 | " 6%|▌ | 310/5000 [00:02<00:44, 104.49it/s]" 379 | ] 380 | }, 381 | { 382 | "name": "stdout", 383 | "output_type": "stream", 384 | "text": [ 385 | "Iter: 300\n", 386 | "Train_loss: 0.06422\n", 387 | "Test_loss: 0.0617\n", 388 | "\n" 389 | ] 390 | }, 391 | { 392 | "name": "stderr", 393 | "output_type": "stream", 394 | "text": [ 395 | " 8%|▊ | 415/5000 [00:03<00:40, 112.52it/s]" 396 | ] 397 | }, 398 | { 399 | "name": "stdout", 400 | "output_type": "stream", 401 | "text": [ 402 | "Iter: 400\n", 403 | "Train_loss: 0.06123\n", 404 | "Test_loss: 0.05874\n", 405 | "\n" 406 | ] 407 | }, 408 | { 409 | "name": "stderr", 410 | "output_type": "stream", 411 | "text": [ 412 | " 10%|█ | 518/5000 [00:04<00:36, 121.83it/s]" 413 | ] 414 | }, 415 | { 416 | "name": "stdout", 417 | "output_type": "stream", 418 | "text": [ 419 | "Iter: 500\n", 420 | "Train_loss: 0.05244\n", 421 | "Test_loss: 0.07433\n", 422 | "\n" 423 | ] 424 | }, 425 | { 426 | "name": "stderr", 427 | "output_type": "stream", 428 | "text": [ 429 | " 12%|█▏ | 622/5000 [00:05<00:37, 118.12it/s]" 430 | ] 431 | }, 432 | { 433 | "name": "stdout", 434 | "output_type": "stream", 435 | "text": [ 436 | "Iter: 600\n", 437 | "Train_loss: 0.05774\n", 438 | "Test_loss: 0.04573\n", 439 | "\n" 440 | ] 441 | }, 442 | { 443 | "name": "stderr", 444 | "output_type": "stream", 445 | "text": [ 446 | " 14%|█▍ | 722/5000 [00:06<00:37, 114.42it/s]" 447 | ] 448 | }, 449 | { 450 | "name": "stdout", 451 | "output_type": "stream", 452 | "text": [ 453 | "Iter: 700\n", 454 | "Train_loss: 0.04833\n", 455 | "Test_loss: 0.04788\n", 456 | "\n" 457 | ] 458 | }, 459 | { 460 | "name": "stderr", 461 | "output_type": "stream", 462 | "text": [ 463 | " 16%|█▋ | 815/5000 [00:07<00:38, 109.62it/s]" 464 | ] 465 | }, 466 | { 467 | "name": "stdout", 468 | "output_type": "stream", 469 | "text": [ 470 | "Iter: 800\n", 471 | "Train_loss: 0.04508\n", 472 | "Test_loss: 0.06754\n", 473 | "\n" 474 | ] 475 | }, 476 | { 477 | "name": "stderr", 478 | "output_type": "stream", 479 | "text": [ 480 | " 18%|█▊ | 917/5000 [00:08<00:41, 98.30it/s] " 481 | ] 482 | }, 483 | { 484 | "name": "stdout", 485 | "output_type": "stream", 486 | "text": [ 487 | "Iter: 900\n", 488 | "Train_loss: 0.05552\n", 489 | "Test_loss: 0.06502\n", 490 | "\n" 491 | ] 492 | }, 493 | { 494 | "name": "stderr", 495 | "output_type": "stream", 496 | "text": [ 497 | " 20%|██ | 1019/5000 [00:09<00:39, 100.36it/s]" 498 | ] 499 | }, 500 | { 501 | "name": "stdout", 502 | "output_type": "stream", 503 | "text": [ 504 | "Iter: 1000\n", 505 | "Train_loss: 0.05365\n", 506 | "Test_loss: 0.07008\n", 507 | "\n" 508 | ] 509 | }, 510 | { 511 | "name": "stderr", 512 | "output_type": "stream", 513 | "text": [ 514 | " 22%|██▏ | 1113/5000 [00:10<00:35, 109.98it/s]" 515 | ] 
516 | }, 517 | { 518 | "name": "stdout", 519 | "output_type": "stream", 520 | "text": [ 521 | "Iter: 1100\n", 522 | "Train_loss: 0.04365\n", 523 | "Test_loss: 0.05524\n", 524 | "\n" 525 | ] 526 | }, 527 | { 528 | "name": "stderr", 529 | "output_type": "stream", 530 | "text": [ 531 | " 24%|██▍ | 1220/5000 [00:11<00:35, 106.01it/s]" 532 | ] 533 | }, 534 | { 535 | "name": "stdout", 536 | "output_type": "stream", 537 | "text": [ 538 | "Iter: 1200\n", 539 | "Train_loss: 0.04645\n", 540 | "Test_loss: 0.04994\n", 541 | "\n" 542 | ] 543 | }, 544 | { 545 | "name": "stderr", 546 | "output_type": "stream", 547 | "text": [ 548 | " 26%|██▌ | 1311/5000 [00:12<00:35, 103.89it/s]" 549 | ] 550 | }, 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "Iter: 1300\n", 556 | "Train_loss: 0.04116\n", 557 | "Test_loss: 0.05477\n", 558 | "\n" 559 | ] 560 | }, 561 | { 562 | "name": "stderr", 563 | "output_type": "stream", 564 | "text": [ 565 | " 28%|██▊ | 1412/5000 [00:13<00:36, 98.95it/s] " 566 | ] 567 | }, 568 | { 569 | "name": "stdout", 570 | "output_type": "stream", 571 | "text": [ 572 | "Iter: 1400\n", 573 | "Train_loss: 0.04603\n", 574 | "Test_loss: 0.05895\n", 575 | "\n" 576 | ] 577 | }, 578 | { 579 | "name": "stderr", 580 | "output_type": "stream", 581 | "text": [ 582 | " 30%|███ | 1512/5000 [00:13<00:32, 106.19it/s]" 583 | ] 584 | }, 585 | { 586 | "name": "stdout", 587 | "output_type": "stream", 588 | "text": [ 589 | "Iter: 1500\n", 590 | "Train_loss: 0.04387\n", 591 | "Test_loss: 0.05071\n", 592 | "\n" 593 | ] 594 | }, 595 | { 596 | "name": "stderr", 597 | "output_type": "stream", 598 | "text": [ 599 | " 32%|███▏ | 1614/5000 [00:14<00:33, 101.73it/s]" 600 | ] 601 | }, 602 | { 603 | "name": "stdout", 604 | "output_type": "stream", 605 | "text": [ 606 | "Iter: 1600\n", 607 | "Train_loss: 0.03953\n", 608 | "Test_loss: 0.06814\n", 609 | "\n" 610 | ] 611 | }, 612 | { 613 | "name": "stderr", 614 | "output_type": "stream", 615 | "text": [ 616 | " 34%|███▍ | 1713/5000 [00:15<00:30, 108.21it/s]" 617 | ] 618 | }, 619 | { 620 | "name": "stdout", 621 | "output_type": "stream", 622 | "text": [ 623 | "Iter: 1700\n", 624 | "Train_loss: 0.04072\n", 625 | "Test_loss: 0.04313\n", 626 | "\n" 627 | ] 628 | }, 629 | { 630 | "name": "stderr", 631 | "output_type": "stream", 632 | "text": [ 633 | " 36%|███▋ | 1813/5000 [00:16<00:29, 106.27it/s]" 634 | ] 635 | }, 636 | { 637 | "name": "stdout", 638 | "output_type": "stream", 639 | "text": [ 640 | "Iter: 1800\n", 641 | "Train_loss: 0.03748\n", 642 | "Test_loss: 0.04535\n", 643 | "\n" 644 | ] 645 | }, 646 | { 647 | "name": "stderr", 648 | "output_type": "stream", 649 | "text": [ 650 | " 38%|███▊ | 1915/5000 [00:17<00:28, 109.63it/s]" 651 | ] 652 | }, 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "Iter: 1900\n", 658 | "Train_loss: 0.03597\n", 659 | "Test_loss: 0.04755\n", 660 | "\n" 661 | ] 662 | }, 663 | { 664 | "name": "stderr", 665 | "output_type": "stream", 666 | "text": [ 667 | " 40%|████ | 2016/5000 [00:18<00:30, 97.41it/s] " 668 | ] 669 | }, 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "Iter: 2000\n", 675 | "Train_loss: 0.03811\n", 676 | "Test_loss: 0.06289\n", 677 | "\n" 678 | ] 679 | }, 680 | { 681 | "name": "stderr", 682 | "output_type": "stream", 683 | "text": [ 684 | " 42%|████▏ | 2112/5000 [00:19<00:28, 103.07it/s]" 685 | ] 686 | }, 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "Iter: 2100\n", 692 | "Train_loss: 
0.03472\n", 693 | "Test_loss: 0.04887\n", 694 | "\n" 695 | ] 696 | }, 697 | { 698 | "name": "stderr", 699 | "output_type": "stream", 700 | "text": [ 701 | " 44%|████▍ | 2213/5000 [00:20<00:28, 99.27it/s] " 702 | ] 703 | }, 704 | { 705 | "name": "stdout", 706 | "output_type": "stream", 707 | "text": [ 708 | "Iter: 2200\n", 709 | "Train_loss: 0.03489\n", 710 | "Test_loss: 0.05582\n", 711 | "\n" 712 | ] 713 | }, 714 | { 715 | "name": "stderr", 716 | "output_type": "stream", 717 | "text": [ 718 | " 46%|████▋ | 2315/5000 [00:21<00:25, 103.86it/s]" 719 | ] 720 | }, 721 | { 722 | "name": "stdout", 723 | "output_type": "stream", 724 | "text": [ 725 | "Iter: 2300\n", 726 | "Train_loss: 0.03552\n", 727 | "Test_loss: 0.05143\n", 728 | "\n" 729 | ] 730 | }, 731 | { 732 | "name": "stderr", 733 | "output_type": "stream", 734 | "text": [ 735 | " 48%|████▊ | 2417/5000 [00:23<00:36, 69.82it/s] " 736 | ] 737 | }, 738 | { 739 | "name": "stdout", 740 | "output_type": "stream", 741 | "text": [ 742 | "Iter: 2400\n", 743 | "Train_loss: 0.0326\n", 744 | "Test_loss: 0.05533\n", 745 | "\n" 746 | ] 747 | }, 748 | { 749 | "name": "stderr", 750 | "output_type": "stream", 751 | "text": [ 752 | " 50%|█████ | 2506/5000 [00:24<00:50, 49.05it/s]" 753 | ] 754 | }, 755 | { 756 | "name": "stdout", 757 | "output_type": "stream", 758 | "text": [ 759 | "Iter: 2500\n", 760 | "Train_loss: 0.02871\n", 761 | "Test_loss: 0.05063\n", 762 | "\n" 763 | ] 764 | }, 765 | { 766 | "name": "stderr", 767 | "output_type": "stream", 768 | "text": [ 769 | " 52%|█████▏ | 2608/5000 [00:26<00:49, 48.11it/s]" 770 | ] 771 | }, 772 | { 773 | "name": "stdout", 774 | "output_type": "stream", 775 | "text": [ 776 | "Iter: 2600\n", 777 | "Train_loss: 0.02966\n", 778 | "Test_loss: 0.04797\n", 779 | "\n" 780 | ] 781 | }, 782 | { 783 | "name": "stderr", 784 | "output_type": "stream", 785 | "text": [ 786 | " 54%|█████▍ | 2706/5000 [00:28<00:46, 49.48it/s]" 787 | ] 788 | }, 789 | { 790 | "name": "stdout", 791 | "output_type": "stream", 792 | "text": [ 793 | "Iter: 2700\n", 794 | "Train_loss: 0.03089\n", 795 | "Test_loss: 0.05014\n", 796 | "\n" 797 | ] 798 | }, 799 | { 800 | "name": "stderr", 801 | "output_type": "stream", 802 | "text": [ 803 | " 56%|█████▌ | 2809/5000 [00:30<00:42, 51.80it/s]" 804 | ] 805 | }, 806 | { 807 | "name": "stdout", 808 | "output_type": "stream", 809 | "text": [ 810 | "Iter: 2800\n", 811 | "Train_loss: 0.0266\n", 812 | "Test_loss: 0.05604\n", 813 | "\n" 814 | ] 815 | }, 816 | { 817 | "name": "stderr", 818 | "output_type": "stream", 819 | "text": [ 820 | " 58%|█████▊ | 2906/5000 [00:32<00:37, 55.97it/s]" 821 | ] 822 | }, 823 | { 824 | "name": "stdout", 825 | "output_type": "stream", 826 | "text": [ 827 | "Iter: 2900\n", 828 | "Train_loss: 0.02701\n", 829 | "Test_loss: 0.04171\n", 830 | "\n" 831 | ] 832 | }, 833 | { 834 | "name": "stderr", 835 | "output_type": "stream", 836 | "text": [ 837 | " 60%|██████ | 3014/5000 [00:34<00:27, 72.81it/s]" 838 | ] 839 | }, 840 | { 841 | "name": "stdout", 842 | "output_type": "stream", 843 | "text": [ 844 | "Iter: 3000\n", 845 | "Train_loss: 0.02792\n", 846 | "Test_loss: 0.0569\n", 847 | "\n" 848 | ] 849 | }, 850 | { 851 | "name": "stderr", 852 | "output_type": "stream", 853 | "text": [ 854 | " 62%|██████▏ | 3114/5000 [00:35<00:25, 72.75it/s]" 855 | ] 856 | }, 857 | { 858 | "name": "stdout", 859 | "output_type": "stream", 860 | "text": [ 861 | "Iter: 3100\n", 862 | "Train_loss: 0.03158\n", 863 | "Test_loss: 0.05757\n", 864 | "\n" 865 | ] 866 | }, 867 | { 868 | "name": "stderr", 869 | "output_type": 
"stream", 870 | "text": [ 871 | " 64%|██████▍ | 3211/5000 [00:37<00:25, 70.25it/s]" 872 | ] 873 | }, 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "Iter: 3200\n", 879 | "Train_loss: 0.02715\n", 880 | "Test_loss: 0.04837\n", 881 | "\n" 882 | ] 883 | }, 884 | { 885 | "name": "stderr", 886 | "output_type": "stream", 887 | "text": [ 888 | " 66%|██████▌ | 3309/5000 [00:38<00:22, 74.20it/s]" 889 | ] 890 | }, 891 | { 892 | "name": "stdout", 893 | "output_type": "stream", 894 | "text": [ 895 | "Iter: 3300\n", 896 | "Train_loss: 0.02615\n", 897 | "Test_loss: 0.06774\n", 898 | "\n" 899 | ] 900 | }, 901 | { 902 | "name": "stderr", 903 | "output_type": "stream", 904 | "text": [ 905 | " 68%|██████▊ | 3405/5000 [00:39<00:30, 52.71it/s]" 906 | ] 907 | }, 908 | { 909 | "name": "stdout", 910 | "output_type": "stream", 911 | "text": [ 912 | "Iter: 3400\n", 913 | "Train_loss: 0.02437\n", 914 | "Test_loss: 0.05418\n", 915 | "\n" 916 | ] 917 | }, 918 | { 919 | "name": "stderr", 920 | "output_type": "stream", 921 | "text": [ 922 | " 70%|███████ | 3509/5000 [00:41<00:21, 68.08it/s]" 923 | ] 924 | }, 925 | { 926 | "name": "stdout", 927 | "output_type": "stream", 928 | "text": [ 929 | "Iter: 3500\n", 930 | "Train_loss: 0.02578\n", 931 | "Test_loss: 0.06015\n", 932 | "\n" 933 | ] 934 | }, 935 | { 936 | "name": "stderr", 937 | "output_type": "stream", 938 | "text": [ 939 | " 72%|███████▏ | 3610/5000 [00:43<00:25, 54.86it/s]" 940 | ] 941 | }, 942 | { 943 | "name": "stdout", 944 | "output_type": "stream", 945 | "text": [ 946 | "Iter: 3600\n", 947 | "Train_loss: 0.0286\n", 948 | "Test_loss: 0.04191\n", 949 | "\n" 950 | ] 951 | }, 952 | { 953 | "name": "stderr", 954 | "output_type": "stream", 955 | "text": [ 956 | " 74%|███████▍ | 3711/5000 [00:45<00:20, 62.87it/s]" 957 | ] 958 | }, 959 | { 960 | "name": "stdout", 961 | "output_type": "stream", 962 | "text": [ 963 | "Iter: 3700\n", 964 | "Train_loss: 0.02791\n", 965 | "Test_loss: 0.06541\n", 966 | "\n" 967 | ] 968 | }, 969 | { 970 | "name": "stderr", 971 | "output_type": "stream", 972 | "text": [ 973 | " 76%|███████▌ | 3807/5000 [00:46<00:22, 53.56it/s]" 974 | ] 975 | }, 976 | { 977 | "name": "stdout", 978 | "output_type": "stream", 979 | "text": [ 980 | "Iter: 3800\n", 981 | "Train_loss: 0.02706\n", 982 | "Test_loss: 0.06418\n", 983 | "\n" 984 | ] 985 | }, 986 | { 987 | "name": "stderr", 988 | "output_type": "stream", 989 | "text": [ 990 | " 78%|███████▊ | 3906/5000 [00:48<00:19, 55.09it/s]" 991 | ] 992 | }, 993 | { 994 | "name": "stdout", 995 | "output_type": "stream", 996 | "text": [ 997 | "Iter: 3900\n", 998 | "Train_loss: 0.02297\n", 999 | "Test_loss: 0.06778\n", 1000 | "\n" 1001 | ] 1002 | }, 1003 | { 1004 | "name": "stderr", 1005 | "output_type": "stream", 1006 | "text": [ 1007 | " 79%|███████▊ | 3932/5000 [00:48<00:18, 57.62it/s]" 1008 | ] 1009 | } 1010 | ], 1011 | "source": [ 1012 | "#%% Start Iterations\n", 1013 | "for it in tqdm(range(5000)): \n", 1014 | " \n", 1015 | " #%% Inputs\n", 1016 | " mb_idx = sample_idx(Train_No, mb_size)\n", 1017 | " X_mb = trainX[mb_idx,:] \n", 1018 | " \n", 1019 | " Z_mb = sample_Z(mb_size, Dim) \n", 1020 | " M_mb = trainM[mb_idx,:] \n", 1021 | " H_mb1 = sample_M(mb_size, Dim, 1-p_hint)\n", 1022 | " H_mb = M_mb * H_mb1\n", 1023 | " \n", 1024 | " New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb # Missing Data Introduce\n", 1025 | " \n", 1026 | " if use_gpu is True:\n", 1027 | " X_mb = torch.tensor(X_mb, device=\"cuda\")\n", 1028 | " M_mb = torch.tensor(M_mb, device=\"cuda\")\n", 1029 | " H_mb = 
torch.tensor(H_mb, device=\"cuda\")\n", 1030 | " New_X_mb = torch.tensor(New_X_mb, device=\"cuda\")\n", 1031 | " else:\n", 1032 | " X_mb = torch.tensor(X_mb)\n", 1033 | " M_mb = torch.tensor(M_mb)\n", 1034 | " H_mb = torch.tensor(H_mb)\n", 1035 | " New_X_mb = torch.tensor(New_X_mb)\n", 1036 | " \n", 1037 | " optimizer_D.zero_grad()\n", 1038 | " D_loss_curr = discriminator_loss(M=M_mb, New_X=New_X_mb, H=H_mb)\n", 1039 | " D_loss_curr.backward()\n", 1040 | " optimizer_D.step()\n", 1041 | " \n", 1042 | " optimizer_G.zero_grad()\n", 1043 | " G_loss_curr, MSE_train_loss_curr, MSE_test_loss_curr = generator_loss(X=X_mb, M=M_mb, New_X=New_X_mb, H=H_mb)\n", 1044 | " G_loss_curr.backward()\n", 1045 | " optimizer_G.step() \n", 1046 | " \n", 1047 | " #%% Intermediate Losses\n", 1048 | " if it % 100 == 0:\n", 1049 | " print('Iter: {}'.format(it))\n", 1050 | " print('Train_loss: {:.4}'.format(np.sqrt(MSE_train_loss_curr.item())))\n", 1051 | " print('Test_loss: {:.4}'.format(np.sqrt(MSE_test_loss_curr.item())))\n", 1052 | " print()" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "markdown", 1057 | "metadata": {}, 1058 | "source": [ 1059 | "## Testing" 1060 | ] 1061 | }, 1062 | { 1063 | "cell_type": "code", 1064 | "execution_count": null, 1065 | "metadata": {}, 1066 | "outputs": [], 1067 | "source": [ 1068 | "Z_mb = sample_Z(Test_No, Dim) \n", 1069 | "M_mb = testM\n", 1070 | "X_mb = testX\n", 1071 | " \n", 1072 | "New_X_mb = M_mb * X_mb + (1-M_mb) * Z_mb # Missing Data Introduce\n", 1073 | "\n", 1074 | "if use_gpu is True:\n", 1075 | " X_mb = torch.tensor(X_mb, device='cuda')\n", 1076 | " M_mb = torch.tensor(M_mb, device='cuda')\n", 1077 | " New_X_mb = torch.tensor(New_X_mb, device='cuda')\n", 1078 | "else:\n", 1079 | " X_mb = torch.tensor(X_mb)\n", 1080 | " M_mb = torch.tensor(M_mb)\n", 1081 | " New_X_mb = torch.tensor(New_X_mb)\n", 1082 | " \n", 1083 | "MSE_final, Sample = test_loss(X=X_mb, M=M_mb, New_X=New_X_mb)\n", 1084 | " \n", 1085 | "print('Final Test RMSE: ' + str(np.sqrt(MSE_final.item())))" 1086 | ] 1087 | }, 1088 | { 1089 | "cell_type": "code", 1090 | "execution_count": null, 1091 | "metadata": {}, 1092 | "outputs": [], 1093 | "source": [ 1094 | "imputed_data = M_mb * X_mb + (1-M_mb) * Sample\n", 1095 | "print(\"Imputed test data:\")\n", 1096 | "# np.set_printoptions(formatter={'float': lambda x: \"{0:0.8f}\".format(x)})\n", 1097 | "\n", 1098 | "if use_gpu is True:\n", 1099 | " print(imputed_data.cpu().detach().numpy())\n", 1100 | "else:\n", 1101 | " print(imputed_data.detach().numpy())" 1102 | ] 1103 | } 1104 | ], 1105 | "metadata": { 1106 | "kernelspec": { 1107 | "display_name": "Python [conda env:pt]", 1108 | "language": "python", 1109 | "name": "conda-env-pt-py" 1110 | }, 1111 | "language_info": { 1112 | "codemirror_mode": { 1113 | "name": "ipython", 1114 | "version": 3 1115 | }, 1116 | "file_extension": ".py", 1117 | "mimetype": "text/x-python", 1118 | "name": "python", 1119 | "nbconvert_exporter": "python", 1120 | "pygments_lexer": "ipython3", 1121 | "version": "3.6.10" 1122 | } 1123 | }, 1124 | "nbformat": 4, 1125 | "nbformat_minor": 2 1126 | } 1127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dhanajit Brahma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Adversarial Imputation Networks (GAIN) PyTorch Implementation 2 | PyTorch implementation of the paper *GAIN: Missing Data Imputation using Generative Adversarial Nets* by Jinsung Yoon, James Jordon, and Mihaela van der Schaar. 3 | 4 | Reference: J. Yoon, J. Jordon, M. van der Schaar, "GAIN: Missing Data Imputation using Generative Adversarial Nets," International Conference on Machine Learning (ICML), 2018. 5 | 6 | The notebooks are a PyTorch adaptation of the original TensorFlow code available here: https://github.com/jsyoon0823/GAIN 7 | 8 | This repo is tested on Python 3.6 and PyTorch 1.4.0. 9 | 10 | The code can be run on either GPU or CPU (set the `use_gpu` flag accordingly). 11 | 12 | The results are the same as those of the original implementation. 13 | 14 | The datasets are taken from the original implementation. 15 | 
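16 | ## Quick sketch of the data pipeline
17 | 
18 | The snippet below is a minimal, self-contained sketch of the masking and hint mechanism that both notebooks build; it is illustrative and not part of the notebooks. Variable names mirror the notebook code, while the toy matrix `X`, its 4x3 shape, and the RNG seed are assumptions made for this example.
19 | 
20 | ```python
21 | import numpy as np
22 | 
23 | rng = np.random.default_rng(0)
24 | X = rng.random((4, 3))  # toy data, already scaled to [0, 1] like the notebooks' Data
25 | p_miss, p_hint = 0.2, 0.9
26 | 
27 | # Mask M: 1 = observed, 0 = missing (mirrors the notebooks' Missing matrix)
28 | M = (rng.uniform(size=X.shape) > p_miss).astype(float)
29 | 
30 | # Noise-fill the missing entries (mirrors sample_Z and New_X_mb)
31 | Z = rng.uniform(0.0, 0.01, size=X.shape)
32 | New_X = M * X + (1 - M) * Z
33 | 
34 | # Hint H: each mask entry is revealed to the discriminator with probability p_hint
35 | # (mirrors sample_M(mb_size, Dim, 1 - p_hint) followed by H = M * B)
36 | B = (rng.uniform(size=X.shape) > (1 - p_hint)).astype(float)
37 | H = M * B
38 | 
39 | # After training, imputation keeps observed values and fills in the rest with
40 | # the generator's output: imputed = M * X + (1 - M) * G_sample
41 | print(New_X.round(3))
42 | ```
43 | 
--------------------------------------------------------------------------------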