├── README.md └── avo-poisson.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # AVO Pytorch 2 | 3 | Implementation of [Adversarial Variational Optimization](https://arxiv.org/pdf/1707.07113.pdf) to solve likelihood-free inference problems 4 | -------------------------------------------------------------------------------- /avo-poisson.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AVORIM for a Poisson simulator" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "%matplotlib inline\n", 21 | "import copy\n", 22 | "\n", 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "import torch.nn.functional as F\n", 26 | "from torch.nn.parameter import Parameter\n", 27 | "import torch.autograd as autograd\n", 28 | "\n", 29 | "from IPython.display import clear_output\n", 30 | "from tqdm import tqdm" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Constants" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### Simulator parameters" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": { 51 | "collapsed": true 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "lambda_poisson = 7\n", 56 | "nb_samples = 20000" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### Network parameters" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": { 70 | "collapsed": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "batch_size = 64\n", 75 | "lambda_gp = 1\n", 76 | "gamma_entropy = 100" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": { 83 | "collapsed": true 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "critic_params = {\"lr\": 1e-3,\n", 88 | " \"nb_steps\": 1,\n", 89 | " \"nb_hiddens\": 10\n", 90 | " }" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": { 97 | "collapsed": true 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "proposal_params = {\"lr\": 1e-3,\n", 102 | " \"beta1\": 0.5,\n", 103 | " \"beta2\": 0.9\n", 104 | " }" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "gen_params = {\"nb_steps\": 1}" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## Generation of the data" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### Build the simulator" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "metadata": { 136 | "collapsed": true 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "def simulator(theta):\n", 141 | " # lambda = exp(theta) (for positivity)\n", 142 | " return np.random.poisson(np.exp(theta))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### Generate the data" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 8, 155 | "metadata": { 156 | "collapsed": true 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "X_obs = simulator([np.log(lambda_poisson) for i in range(nb_samples)])\n", 161 | "X_obs = X_obs[:,np.newaxis]\n", 162 | "X_obs = torch.Tensor(X_obs)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Definition of the optimizer" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "### RMSprop" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 24, 182 | "metadata": { 183 | "collapsed": true 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "class RmsPropOptimizer:\n", 188 | " def __init__(self, size_input, lr=0.01, gamma=1., eps=10**-8):\n", 189 | " self.size_input = size_input\n", 190 | " self.lr = lr\n", 191 | " self.gamma = gamma\n", 192 | " self.eps = eps\n", 193 | "\n", 194 | " self.reset()\n", 195 | "\n", 196 | " def reset(self):\n", 197 | " self.avg_sq_grad = torch.ones(self.size_input)\n", 198 | "\n", 199 | " def step(self, grad_approx, num_iters=1):\n", 200 | " v = 0\n", 201 | " for i in range(num_iters):\n", 202 | " self.avg_sq_grad = self.avg_sq_grad * self.gamma + grad_approx**2 * (1 - self.gamma)\n", 203 | " v = v + self.lr * grad_approx / torch.sqrt(self.avg_sq_grad + self.eps)\n", 204 | " v /= num_iters\n", 205 | " print(\"step: \", v)\n", 206 | " return v" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 25, 212 | "metadata": { 213 | "collapsed": true 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "rms = RmsPropOptimizer(1, lr=0.1)\n", 218 | "x = torch.tensor([1.], requires_grad=True)\n", 219 | "y = x**2" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 58, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "step: tensor(1.00000e-03 *\n", 232 | " [ 3.6029])\n" 233 | ] 234 | }, 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "tensor(1.00000e-02 *\n", 239 | " [ 1.4412])" 240 | ] 241 | }, 242 | "execution_count": 58, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "x.grad = torch.tensor([0.])\n", 249 | "y.backward(retain_graph=True)\n", 250 | "x.data.sub_(rms.step(x.grad))" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "## Definition of the proposal distribution $q$" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 27, 263 | "metadata": { 264 | "collapsed": true 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "class Proposal(nn.Module):\n", 269 | " def __init__(self, lr=0.01):\n", 270 | " super(Proposal, self).__init__()\n", 271 | " \n", 272 | " self.psi = Parameter(torch.tensor([0., -1.]), requires_grad=True) # mean and log(std)\n", 273 | "\n", 274 | " self.optimizer = RmsPropOptimizer(self.psi.shape[0], lr=lr)\n", 275 | " \n", 276 | " def forward(self, size):\n", 277 | " return self.psi[0] + torch.exp(self.psi[1]) * torch.normal(torch.zeros((size, 1)), torch.ones((size, 1)))\n", 278 | " \n", 279 | " def backprop(self, grad_approx):\n", 280 | " step = self.optimizer.step(grad_approx)\n", 281 | " self.psi[0].data.sub_(step[0])\n", 282 | " \n", 283 | " def pdf(self, theta):\n", 284 | " return 1. / torch.sqrt(2 * np.pi * torch.exp(self.psi[1])**2) * torch.exp(-(torch.exp(theta) - self.psi[0])**2 / (2 * torch.exp(self.psi[1])))" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 28, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "p = Proposal()" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 29, 301 | "metadata": { 302 | "collapsed": true 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "p.zero_grad()\n", 307 | "l = torch.log(p.pdf(torch.tensor(0.)))\n", 308 | "l.backward()" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 30, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "tensor([ 2.7183, 0.3591])" 320 | ] 321 | }, 322 | "execution_count": 30, 323 | "metadata": {}, 324 | "output_type": "execute_result" 325 | } 326 | ], 327 | "source": [ 328 | "p.psi.grad" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 31, 334 | "metadata": {}, 335 | "outputs": [ 336 | { 337 | "data": { 338 | "text/plain": [ 339 | "tensor(0.2062)" 340 | ] 341 | }, 342 | "execution_count": 31, 343 | "metadata": {}, 344 | "output_type": "execute_result" 345 | } 346 | ], 347 | "source": [ 348 | "p.pdf(torch.tensor(0.1))" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "## Building the networks" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### Losses" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 32, 368 | "metadata": { 369 | "collapsed": true 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "def wgan_critic_loss(C_gen, C_real):\n", 374 | " return torch.mean(C_gen) - torch.mean(C_real)\n", 375 | "\n", 376 | "def wgan_generator_loss(C_gen):\n", 377 | " return -torch.mean(C_gen)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 33, 383 | "metadata": { 384 | "collapsed": true 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "def ogan_critic_loss(C_gen, C_real, eps=1e-8):\n", 389 | " return torch.mean(- torch.log(C_real+eps) - torch.log(1 - C_gen+eps))\n", 390 | "\n", 391 | "def ogan_generator_loss(C_gen, eps=1e-8):\n", 392 | " return -torch.mean(torch.log(C_gen+eps))" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 34, 398 | "metadata": { 399 | "collapsed": true 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "def gradient_penalty(critic, X_real):\n", 404 | " X_real.requires_grad_(True)\n", 405 | " C_real = critic.forward(X_real)\n", 406 | " out_grad = torch.ones(X_real.shape[0], 1) # used to define the size of the grad (grad_C(C) = [1,1,...])\n", 407 | " gradients = autograd.grad(outputs=C_real, inputs=X_real,\n", 408 | " grad_outputs=out_grad, create_graph=True, retain_graph=True,\n", 409 | " only_inputs=True)[0]\n", 410 | " return ((gradients.norm(2, dim=1))**2).mean()" 411 | ] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": {}, 416 | "source": [ 417 | "### Critic" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 35, 423 | "metadata": { 424 | "collapsed": true 425 | }, 426 | "outputs": [], 427 | "source": [ 428 | "class Critic(nn.Module):\n", 429 | " def __init__(self):\n", 430 | " super(Critic, self).__init__()\n", 431 | " self.fc1 = nn.Linear(1,critic_params[\"nb_hiddens\"])\n", 432 | " self.fc2 = nn.Linear(critic_params[\"nb_hiddens\"],critic_params[\"nb_hiddens\"])\n", 433 | " self.fc3 = nn.Linear(critic_params[\"nb_hiddens\"],1)\n", 434 | " \n", 435 | " self.optimizer = torch.optim.RMSprop(self.parameters(), lr=critic_params[\"lr\"])\n", 436 | " \n", 437 | " def forward(self, x):\n", 438 | " out = F.relu(self.fc1(x))\n", 439 | " out = F.relu(self.fc2(out))\n", 440 | " out = F.sigmoid(self.fc3(out))\n", 441 | " \n", 442 | " return out\n", 443 | " \n", 444 | " def loss(self, X_gen, X_real):\n", 445 | " C_gen = self.forward(X_gen)\n", 446 | " C_real = self.forward(X_real)\n", 447 | "# print(\"C_gen\", C_gen)\n", 448 | "# print(\"C_real\", C_real)\n", 449 | " \n", 450 | " return ogan_critic_loss(C_gen, C_real) #+ lambda_gp * gradient_penalty(self, X_real)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": {}, 456 | "source": [ 457 | "### Generator" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 60, 463 | "metadata": { 464 | "collapsed": true 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "class Generator(nn.Module):\n", 469 | " def __init__(self):\n", 470 | " super(Generator, self).__init__()\n", 471 | " \n", 472 | " def forward(self, theta):\n", 473 | " theta_np = theta.detach().numpy()\n", 474 | " x_np = simulator(theta_np)\n", 475 | " return torch.FloatTensor(x_np)\n", 476 | " \n", 477 | " def loss(self, C_gen):\n", 478 | " return ogan_generator_loss(C_gen)\n", 479 | " \n", 480 | " def grad_approx(self, proposal, theta, C_gen, eps=1e-8):\n", 481 | "# log_q_batch = torch.log(proposal.pdf(theta))\n", 482 | " \n", 483 | "# proposal.zero_grad()\n", 484 | "# log_q_batch = torch.log(proposal.pdf(theta.detach()))\n", 485 | "# log_q_batch.backward(-C_gen)\n", 486 | "# grad_batch = proposal.psi.grad\n", 487 | "\n", 488 | " grad_u = 0\n", 489 | " for i in range(len(theta)):\n", 490 | " proposal.zero_grad()\n", 491 | " log_q = torch.log(proposal.pdf(theta[i]) + eps)\n", 492 | " log_q.backward()\n", 493 | " grad_u += - C_gen[i] * proposal.psi.grad\n", 494 | " print(\"pdf\", proposal.pdf(theta[i].detach()))\n", 495 | "\n", 496 | " grad_u /= theta.shape[0]\n", 497 | " \n", 498 | " return grad_u\n", 499 | " \n", 500 | " def entropy(self, proposal, theta, eps=1e-8):\n", 501 | " grad_h = 0\n", 502 | " for i in range(len(theta)):\n", 503 | " proposal.zero_grad()\n", 504 | " q = proposal.pdf(theta[i].detach())\n", 505 | " H = q * torch.log(q + eps)\n", 506 | " H.backward()\n", 507 | " grad_h += proposal.psi.grad\n", 508 | "\n", 509 | " grad_h /= theta.shape[0]\n", 510 | " \n", 511 | " return grad_h\n", 512 | " \n", 513 | " def backprop(self, proposal, theta, C_gen):\n", 514 | " grad_u = self.grad_approx(proposal, theta.detach(), C_gen)\n", 515 | " grad_h = self.entropy(proposal, theta)\n", 516 | " print(\"grad_u\", grad_u)\n", 517 | " print(\"grad_h\", gamma_entropy * grad_h)\n", 518 | " proposal.backprop(grad_u + 0 * grad_h)" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": {}, 524 | "source": [ 525 | "## Main loop" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 61, 531 | "metadata": { 532 | "collapsed": true 533 | }, 534 | "outputs": [], 535 | "source": [ 536 | "i = 0\n", 537 | "C_loss_list = []\n", 538 | "G_loss_list = []\n", 539 | "\n", 540 | "proposal = Proposal(lr=0.01)\n", 541 | "critic = Critic()\n", 542 | "generator = Generator()" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 62, 548 | "metadata": {}, 549 | "outputs": [ 550 | { 551 | "name": "stdout", 552 | "output_type": "stream", 553 | "text": [ 554 | "pdf tensor(1.00000e-03 *\n", 555 | " [ 1.2718])\n", 556 | "grad_u tensor([-0.4120, -0.5302])\n", 557 | "grad_h tensor([-0.8154, -0.5042])\n", 558 | "step: tensor(1.00000e-03 *\n", 559 | " [-4.1201, -5.3018])\n", 560 | "Psi: Parameter containing:\n", 561 | "tensor([ 2.1869, -1.0000])\n" 562 | ] 563 | }, 564 | { 565 | "ename": "KeyboardInterrupt", 566 | "evalue": "", 567 | "output_type": "error", 568 | "traceback": [ 569 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 570 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 571 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mC_gen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcritic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtheta_sample\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mG_loss_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mC_gen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mclear_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mgenerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackprop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproposal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtheta_sample\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mC_gen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 572 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 573 | ] 574 | } 575 | ], 576 | "source": [ 577 | "from time import sleep\n", 578 | "# print(proposal.psi)\n", 579 | "critic_params[\"nb_steps\"] = 5\n", 580 | "gen_params[\"nb_steps\"] = 2\n", 581 | "\n", 582 | "nb_iter = 100\n", 583 | "for i in range(i,i+nb_iter):\n", 584 | " critic.optimizer.zero_grad()\n", 585 | " \n", 586 | " for i_critic in range(critic_params[\"nb_steps\"]):\n", 587 | " X_real_sample = X_obs[np.random.choice(len(X_obs), batch_size)]\n", 588 | " theta_sample = proposal.forward(batch_size)\n", 589 | " X_gen_sample = torch.FloatTensor(simulator(theta_sample.detach().numpy()))\n", 590 | " \n", 591 | " C_loss = critic.loss(X_gen_sample, X_real_sample)\n", 592 | " C_loss.backward()\n", 593 | " critic.optimizer.step()\n", 594 | " \n", 595 | " # validation\n", 596 | " X_real_sample = X_obs[np.random.choice(len(X_obs), batch_size)]\n", 597 | " theta_sample = proposal.forward(batch_size)\n", 598 | " X_gen_sample = torch.FloatTensor(simulator(theta_sample.detach().numpy()))\n", 599 | " C_loss = critic.loss(X_gen_sample, X_real_sample) \n", 600 | " C_loss_list.append(C_loss)\n", 601 | " \n", 602 | " for i_gen in range(gen_params[\"nb_steps\"]):\n", 603 | " theta_sample = proposal.forward(batch_size)\n", 604 | " C_gen = critic(generator.forward(theta_sample.detach()))\n", 605 | " G_loss_list.append(generator.loss(C_gen))\n", 606 | " sleep(0.1)\n", 607 | " clear_output(wait=True)\n", 608 | " generator.backprop(proposal, theta_sample, C_gen)\n", 609 | " print(\"Psi: \", proposal.psi)\n", 610 | "\n", 611 | "print(\"Critic Loss: \", C_loss)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 59, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/plain": [ 622 | "Text(0,0.5,'Generator loss')" 623 | ] 624 | }, 625 | "execution_count": 59, 626 | "metadata": {}, 627 | "output_type": "execute_result" 628 | }, 629 | { 630 | "data": { 631 | "image/png": "\n", 632 | "text/plain": [ 633 | "
" 634 | ] 635 | }, 636 | "metadata": {}, 637 | "output_type": "display_data" 638 | } 639 | ], 640 | "source": [ 641 | "plt.rcParams[\"figure.figsize\"] = (20,7)\n", 642 | "\n", 643 | "plt.subplot(1,2,1)\n", 644 | "\n", 645 | "plt.plot(C_loss_list)\n", 646 | "plt.xlabel(\"Number of iterations\", size=15)\n", 647 | "plt.ylabel(\"Critic loss\", size=15)\n", 648 | "\n", 649 | "plt.subplot(1,2,2)\n", 650 | "\n", 651 | "plt.plot(G_loss_list)\n", 652 | "plt.xlabel(\"Number of iterations\", size=15)\n", 653 | "plt.ylabel(\"Generator loss\", size=15)" 654 | ] 655 | }, 656 | { 657 | "cell_type": "markdown", 658 | "metadata": {}, 659 | "source": [ 660 | "## Visualization" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 459, 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "data": { 670 | "text/plain": [ 671 | "" 672 | ] 673 | }, 674 | "execution_count": 459, 675 | "metadata": {}, 676 | "output_type": "execute_result" 677 | }, 678 | { 679 | "data": { 680 | "image/png": "\n", 681 | "text/plain": [ 682 | "
" 683 | ] 684 | }, 685 | "metadata": {}, 686 | "output_type": "display_data" 687 | } 688 | ], 689 | "source": [ 690 | "theta_sample = proposal.forward(20000)\n", 691 | "X_gen_sample = torch.FloatTensor(simulator(theta_sample.detach().numpy()))\n", 692 | "\n", 693 | "hist1 = plt.hist(X_gen_sample[:,0], label=\"Generate distribution\", density=True, alpha=0.8)\n", 694 | "hist2 = plt.hist(X_obs[:,0], label=\"Real distribution\", density=True, alpha=0.8)\n", 695 | "\n", 696 | "cm = plt.cm.coolwarm\n", 697 | "\n", 698 | "x_min, x_max = 0, 25\n", 699 | "y_min, y_max = 0, 0.6\n", 700 | "h = .1\n", 701 | "xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))\n", 702 | "levels = np.linspace(0,1,100)\n", 703 | "eps = 1\n", 704 | "Z = critic.forward(torch.FloatTensor(xx.reshape(xx.shape[0]*xx.shape[1], 1))).detach()\n", 705 | "Z = Z.reshape(xx.shape[0], xx.shape[1])\n", 706 | "plt.contourf(xx, yy, Z, levels, cmap=cm, alpha=0.8)\n", 707 | "\n", 708 | "plt.xlim(0,25)\n", 709 | "plt.ylim(0, 0.5)\n", 710 | "plt.colorbar()\n", 711 | "plt.legend()" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 118, 717 | "metadata": {}, 718 | "outputs": [ 719 | { 720 | "data": { 721 | "text/plain": [ 722 | "array([4505])" 723 | ] 724 | }, 725 | "execution_count": 118, 726 | "metadata": {}, 727 | "output_type": "execute_result" 728 | } 729 | ], 730 | "source": [ 731 | "simulator(max(theta_sample.detach()))" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 119, 737 | "metadata": {}, 738 | "outputs": [ 739 | { 740 | "data": { 741 | "text/plain": [ 742 | "tensor([ 8.3894])" 743 | ] 744 | }, 745 | "execution_count": 119, 746 | "metadata": {}, 747 | "output_type": "execute_result" 748 | } 749 | ], 750 | "source": [ 751 | "max(theta_sample.detach())" 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": null, 757 | "metadata": { 758 | "collapsed": true 759 | }, 760 | "outputs": [], 761 | "source": [ 762 | "xx.reshape(xx.shape[0]*xx.shape[1], 1)" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "metadata": { 769 | "collapsed": true 770 | }, 771 | "outputs": [], 772 | "source": [ 773 | "np.log(xx.reshape(xx.shape[0]*xx.shape[1], 1) + eps)" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "metadata": { 780 | "collapsed": true 781 | }, 782 | "outputs": [], 783 | "source": [ 784 | "np.log(xx.reshape(xx.shape[0]*xx.shape[1], 1) + eps)" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": 104, 790 | "metadata": {}, 791 | "outputs": [ 792 | { 793 | "data": { 794 | "text/plain": [ 795 | "403.4287934927351" 796 | ] 797 | }, 798 | "execution_count": 104, 799 | "metadata": {}, 800 | "output_type": "execute_result" 801 | } 802 | ], 803 | "source": [ 804 | "np.exp(6)" 805 | ] 806 | }, 807 | { 808 | "cell_type": "code", 809 | "execution_count": null, 810 | "metadata": { 811 | "collapsed": true 812 | }, 813 | "outputs": [], 814 | "source": [] 815 | } 816 | ], 817 | "metadata": { 818 | "kernelspec": { 819 | "display_name": "Python [conda env:rim-avo]", 820 | "language": "python", 821 | "name": "conda-env-rim-avo-py" 822 | }, 823 | "language_info": { 824 | "codemirror_mode": { 825 | "name": "ipython", 826 | "version": 3 827 | }, 828 | "file_extension": ".py", 829 | "mimetype": "text/x-python", 830 | "name": "python", 831 | "nbconvert_exporter": "python", 832 | "pygments_lexer": "ipython3", 833 | "version": "3.6.5" 834 | } 835 | }, 836 | "nbformat": 4, 837 | "nbformat_minor": 2 838 | } 839 | --------------------------------------------------------------------------------