├── .gitignore
├── CONTRIBUTING.md
├── notebooks
│   ├── README.md
│   ├── pytorch_mine.ipynb
│   ├── comparison.ipynb
│   ├── pytorch_mine_vs_nce.ipynb
│   └── pytorch_mine_vs_nwj.ipynb
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | .DS_Store
3 | __pycache__
4 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## CONTRIBUTING
2 | You can contribute to this repository in the following ways:
3 | 
4 | 1. Add relevant research papers
5 | 2. Code implementations in any framework
6 | 3. Links to other blog posts or videos
7 | 4. New ideas and approaches in the field
8 | 
9 | Open a pull request with your changes for review.
10 | 
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | ## Implementation Details
2 | 
3 | The notebooks in this folder estimate mutual information (MI) using the different lower-bound objectives described in the papers listed in this repository.
4 | 
5 | These objectives include Mutual Information Neural Estimation (MINE), which maximizes the Donsker-Varadhan (DV) bound, and the Nguyen-Wainwright-Jordan (NWJ), Jensen-Shannon divergence (JS) and Noise Contrastive Estimation (NCE) objectives. A small critic network is trained with PyTorch, and its MI estimates are plotted against a reference value computed directly from the definition of MI, I(X;Y) = E[log p(y|x) / p(y)], which is possible here because the densities of the synthetic data are known.
6 | 
7 | These lower bounds are compared in detail by Poole et al. in [On variational lower bounds of mutual information](http://bayesiandeeplearning.org/2018/papers/136.pdf), presented at NeurIPS 2018.
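The bounds named above can be compared without any training on the notebooks' synthetic data (x = ±1 with equal probability, y = x + Gaussian noise with variance 0.2), because for this toy model the optimal critic has a closed form: T*(x, y) = log p(y|x) - log p(y) = log 2 - softplus(-2xy/0.2). The NumPy sketch below is not code from this repository. The function names are hypothetical, the closed-form critic is derived for this particular model only, and the InfoNCE variant uses an N x N score matrix with a per-sample log-mean-exp rather than the `ret_nce` expression used in the notebooks.

```python
import numpy as np

VAR = 0.2  # noise variance used by gen_y() in the notebooks


def sample(n, rng, var=VAR):
    """Draw n pairs: x = +/-1 with equal probability, y = x + N(0, var)."""
    x = np.sign(rng.normal(size=n))
    y = x + rng.normal(scale=np.sqrt(var), size=n)
    return x, y


def optimal_critic(x, y, var=VAR):
    """log p(y|x) - log p(y) for this toy model (assumed closed form).

    Derived from p(y) = 0.5 p(y|x) + 0.5 p(y|-x), which gives
    log 2 - softplus(-2xy/var); softplus is computed stably via logaddexp.
    """
    return np.log(2.0) - np.logaddexp(0.0, -2.0 * x * y / var)


def logmeanexp(a, axis=None):
    """Numerically stable log(mean(exp(a)))."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.mean(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def dv_bound(t_joint, t_marg):
    """Donsker-Varadhan bound maximized by MINE: E_P[T] - log E_Q[e^T]."""
    return np.mean(t_joint) - logmeanexp(t_marg)


def nwj_bound(t_joint, t_marg):
    """NWJ bound (MINE-f): E_P[T] - E_Q[e^(T - 1)]."""
    return np.mean(t_joint) - np.mean(np.exp(t_marg - 1.0))


def infonce_bound(scores):
    """InfoNCE on scores[i, j] = T(x_i, y_j); each term is at most log N."""
    return np.mean(np.diag(scores) - logmeanexp(scores, axis=1))


def reference_mi(var=VAR, lo=-8.0, hi=10.0, step=1e-3):
    """MI by numerical integration; by symmetry I = integral of p(y|x=1) T*(1, y) dy."""
    y = np.arange(lo, hi, step)
    p_y_given_1 = np.exp(-(y - 1.0) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)
    return np.sum(p_y_given_1 * optimal_critic(1.0, y, var)) * step


rng = np.random.default_rng(0)
x, y = sample(200_000, rng)
y_shuffled = rng.permutation(y)  # samples from p(x)p(y), as in the notebooks

t_joint = optimal_critic(x, y)
t_marg = optimal_critic(x, y_shuffled)

mi_ref = reference_mi()
mi_mc = np.mean(t_joint)  # the notebooks' "traditional method"
mi_dv = dv_bound(t_joint, t_marg)
mi_nwj = nwj_bound(t_joint + 1.0, t_marg + 1.0)  # NWJ's optimal critic is T* + 1

xb, yb = sample(512, rng)
mi_nce = infonce_bound(optimal_critic(xb[:, None], yb[None, :]))

print(f"reference {mi_ref:.4f}  MC {mi_mc:.4f}  DV {mi_dv:.4f}  "
      f"NWJ {mi_nwj:.4f}  InfoNCE {mi_nce:.4f}")
```

With the optimal critic plugged in, the DV and NWJ values should land close to the integrated reference (close to the 0.658 printed by the notebooks), while InfoNCE can never exceed log N, about 6.24 for N = 512. In expectation a trained critic cannot beat these plug-in values, so they act as a ceiling for the learned curves.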
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Information Theory in Deep Learning
2 | This repository contains implementations (mostly in PyTorch), relevant resources and lessons on the information theory of deep learning. The aim is to provide a single source of material on the topic.
3 | 
4 | ## Resources
5 | 
6 | ### Videos and Talks
7 | 
8 | - [Stanford Seminar - Information Theory of Deep Learning, Naftali Tishby, 2018](https://www.youtube.com/watch?v=XL07WEc2TRI)
9 | - [Information Theory of Deep Learning (Yandex) - Naftali Tishby, 2017](https://www.youtube.com/watch?v=dPhsU0bu4LY)
10 | 
11 | ### Research Papers
12 | 
13 | [1] [Naftali Tishby and Noga Zaslavsky, "Deep learning and the information bottleneck principle", IEEE Information Theory Workshop (ITW), 2015](https://arxiv.org/pdf/1503.02406.pdf)
14 | 
15 | [2] [Ravid Shwartz-Ziv and Naftali Tishby, "Opening the Black Box of Deep Neural Networks via Information", ICRI-CI, 2017](https://arxiv.org/pdf/1703.00810.pdf)
16 | 
17 | [3] [Naftali Tishby, Fernando C. Pereira, and William Bialek, "The information bottleneck method", arXiv, 2000](https://arxiv.org/pdf/physics/0004057.pdf)
18 | 
19 | [4] [Mohamed Ishmael Belghazi, Aristide Baratin, Sai Rajeswar, Sherjil Ozair, Yoshua Bengio, Aaron Courville, R Devon Hjelm, "Mutual Information Neural Estimation", ICML, 2018](https://arxiv.org/abs/1801.04062)
20 | 
21 | [5] [Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi, George Tucker, "On variational lower bounds of mutual information", NeurIPS, 2018](http://bayesiandeeplearning.org/2018/papers/136.pdf)
22 | 
23 | [6] [R Devon Hjelm, Alex Fedorov, Samuel Lavoie-Marchildon, Karan Grewal, Phil Bachman, Adam Trischler, Yoshua Bengio, "Learning deep representations by mutual information estimation and maximization", ICLR, 2019](https://arxiv.org/abs/1808.06670)
24 | 
25 | ### Blog Posts and Articles
26 | 
27 | - [Lilian Weng’s blog post: Anatomize Deep Learning with Information Theory, Sep 2017](https://lilianweng.github.io/lil-log/2017/09/28/anatomize-deep-learning-with-information-theory.html)
28 | - [My own blog post on Information Theory of Deep Learning, Dec 2018](https://adityashrm21.github.io/Information-Theory-In-Deep-Learning/)
29 | 
--------------------------------------------------------------------------------
/notebooks/pytorch_mine.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "The code is adapted from [Masanori Yamada's MINE implementation](https://github.com/MasanoriYamada/Mine_pytorch)."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "colab": {}, 15 | "colab_type": "code", 16 | "id": "BS2h2kglwPa_" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch.autograd import Variable\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "from torch.utils.data import Dataset, DataLoader\n", 25 | "import numpy as np\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from tqdm import tqdm" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": { 34 | "colab": {}, 35 | "colab_type": "code", 36 | "id": "Ze6Kk6SsQKiE" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "# data\n", 41 | "np.random.seed(1234)\n", 42 | "var = 0.2\n", 43 | "def func(x):\n", 44 | " return x\n", 45 | "\n", 46 | "def gen_x():\n", 47 | " return np.sign(np.random.normal(0.,1.,[data_size,1]))\n", 48 | "\n", 49 | "def gen_y(x):\n", 50 | " return func(x)+np.random.normal(0.,np.sqrt(var),[data_size,1])" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "np.random.seed(1234)\n", 60 | "data_size = 1000000\n", 61 | "x = gen_x()\n", 62 | "y = gen_y(x)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/", 71 | "height": 34 72 | }, 73 | "colab_type": "code", 74 | "id": "C_wM4ty1Qldx", 75 | "outputId": "2892779e-b4dc-4f29-9858-28afb707d11d" 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Mutual information calculated through traditional method is: 0.6584537102332939\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "x = gen_x()\n", 88 | "y = gen_y(x)\n", 89 | "p_y_x = np.exp(-(y - x)**2 / (2 * var))\n", 90 | "p_y_x_minus = np.exp(-(y + 1)**2 / (2 * var))\n", 91 | "p_y_x_plus = np.exp(-(y - 1)**2 / (2 * var))\n", 92 
| "mi = np.average(np.log(p_y_x / (0.5 * p_y_x_minus + 0.5 * p_y_x_plus)))\n", 93 | "#mi = mutual_information(x, y)\n", 94 | "print(\"Mutual information calculated through traditional method is:\", mi)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": { 101 | "colab": {}, 102 | "colab_type": "code", 103 | "id": "Y9qfiZYqS-0R" 104 | }, 105 | "outputs": [ 106 | { 107 | "name": "stderr", 108 | "output_type": "stream", 109 | "text": [ 110 | "100%|██████████| 500/500 [00:05<00:00, 84.49it/s]\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "np.random.seed(1234)\n", 116 | "H=10\n", 117 | "n_epoch = 500\n", 118 | "data_size = 20000\n", 119 | "\n", 120 | "class Net(nn.Module):\n", 121 | " def __init__(self):\n", 122 | " super(Net, self).__init__()\n", 123 | " self.fc1 = nn.Linear(1, H)\n", 124 | " self.fc2 = nn.Linear(1, H)\n", 125 | " self.fc3 = nn.Linear(H, 1)\n", 126 | "\n", 127 | " def forward(self, x, y):\n", 128 | " h1 = F.relu(self.fc1(x)+self.fc2(y))\n", 129 | " h2 = self.fc3(h1)\n", 130 | " return h2 \n", 131 | "\n", 132 | "model = Net()\n", 133 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", 134 | "plot_loss = []\n", 135 | "for epoch in tqdm(range(n_epoch)):\n", 136 | " x_sample=gen_x()\n", 137 | " y_sample=gen_y(x_sample)\n", 138 | " y_shuffle=np.random.permutation(y_sample)\n", 139 | " \n", 140 | " x_sample = Variable(torch.from_numpy(x_sample).type(torch.FloatTensor), requires_grad = True)\n", 141 | " y_sample = Variable(torch.from_numpy(y_sample).type(torch.FloatTensor), requires_grad = True)\n", 142 | " y_shuffle = Variable(torch.from_numpy(y_shuffle).type(torch.FloatTensor), requires_grad = True) \n", 143 | " \n", 144 | " pred_xy = model(x_sample, y_sample)\n", 145 | " pred_x_y = model(x_sample, y_shuffle)\n", 146 | "\n", 147 | " ret = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))\n", 148 | " loss = - ret # maximize\n", 149 | " plot_loss.append(loss.data.numpy())\n", 150 | " 
model.zero_grad()\n", 151 | " loss.backward()\n", 152 | " optimizer.step()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "image/png": "\n", 163 | "text/plain": [ 164 | "
" 165 | ] 166 | }, 167 | "metadata": {}, 168 | "output_type": "display_data" 169 | } 170 | ], 171 | "source": [ 172 | "plot_x = np.arange(len(plot_loss))\n", 173 | "plot_y = np.array(plot_loss).reshape(-1,)\n", 174 | "\n", 175 | "plt.plot(plot_x, -plot_y)\n", 176 | "plt.plot(plot_x, mi*np.ones(len(plot_loss)))\n", 177 | "plt.show()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [] 186 | } 187 | ], 188 | "metadata": { 189 | "accelerator": "GPU", 190 | "colab": { 191 | "name": "pytorch_mine.ipynb", 192 | "provenance": [], 193 | "version": "0.3.2" 194 | }, 195 | "kernelspec": { 196 | "display_name": "Python 3", 197 | "language": "python", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "codemirror_mode": { 202 | "name": "ipython", 203 | "version": 3 204 | }, 205 | "file_extension": ".py", 206 | "mimetype": "text/x-python", 207 | "name": "python", 208 | "nbconvert_exporter": "python", 209 | "pygments_lexer": "ipython3", 210 | "version": "3.6.5" 211 | } 212 | }, 213 | "nbformat": 4, 214 | "nbformat_minor": 1 215 | } 216 | -------------------------------------------------------------------------------- /notebooks/comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook compares most of the objective functions from [On variational lower bounds of mutual information](http://bayesiandeeplearning.org/2018/papers/136.pdf) by [Ben Poole](https://cs.stanford.edu/~poole/) et al. (NeurIPS 2018), using a PyTorch implementation to estimate mutual information."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from torch.autograd import Variable\n", 18 | "import torch.nn as nn\n", 19 | "import torch.nn.functional as F\n", 20 | "from torch.utils.data import Dataset, DataLoader\n", 21 | "import numpy as np\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from tqdm import tqdm" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# data\n", 33 | "np.random.seed(1234)\n", 34 | "var = 0.2\n", 35 | "def func(x):\n", 36 | " return x\n", 37 | "\n", 38 | "def gen_x():\n", 39 | " return np.sign(np.random.normal(0.,1.,[data_size,1]))\n", 40 | "\n", 41 | "def gen_y(x):\n", 42 | " return func(x)+np.random.normal(0.,np.sqrt(var),[data_size,1])" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "np.random.seed(1234)\n", 52 | "data_size = 1000000\n", 53 | "x = gen_x()\n", 54 | "y = gen_y(x)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Mutual information calculated through traditional method is: 0.6584537102332939\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "x = gen_x()\n", 72 | "y = gen_y(x)\n", 73 | "p_y_x = np.exp(-(y - x)**2 / (2 * var))\n", 74 | "p_y_x_minus = np.exp(-(y + 1)**2 / (2 * var))\n", 75 | "p_y_x_plus = np.exp(-(y - 1)**2 / (2 * var))\n", 76 | "mi = np.average(np.log(p_y_x / (0.5 * p_y_x_minus + 0.5 * p_y_x_plus)))\n", 77 | "#mi = mutual_information(x, y)\n", 78 | "print(\"Mutual information calculated through traditional method is:\", mi)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stderr", 88 | 
"output_type": "stream", 89 | "text": [ 90 | "100%|██████████| 500/500 [00:11<00:00, 43.71it/s]\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "np.random.seed(1234)\n", 96 | "H=10\n", 97 | "n_epoch = 500\n", 98 | "data_size = 20000\n", 99 | "e = 2.71828\n", 100 | "\n", 101 | "class Net(nn.Module):\n", 102 | " def __init__(self):\n", 103 | " super(Net, self).__init__()\n", 104 | " self.fc1 = nn.Linear(1, H)\n", 105 | " self.fc2 = nn.Linear(1, H)\n", 106 | " self.fc3 = nn.Linear(H, 1)\n", 107 | "\n", 108 | " def forward(self, x, y):\n", 109 | " h1 = F.relu(self.fc1(x)+self.fc2(y))\n", 110 | " h2 = self.fc3(h1)\n", 111 | " return h2 \n", 112 | "\n", 113 | "model = Net()\n", 114 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", 115 | "\n", 116 | "plot_loss_nwj = []\n", 117 | "plot_loss_jsd = []\n", 118 | "plot_loss_mine = []\n", 119 | "plot_loss_nce = []\n", 120 | "\n", 121 | "\n", 122 | "for epoch in tqdm(range(n_epoch)):\n", 123 | " x_sample=gen_x()\n", 124 | " y_sample=gen_y(x_sample)\n", 125 | " y_shuffle=np.random.permutation(y_sample)\n", 126 | " \n", 127 | " x_sample = Variable(torch.from_numpy(x_sample).type(torch.FloatTensor), requires_grad = True)\n", 128 | " y_sample = Variable(torch.from_numpy(y_sample).type(torch.FloatTensor), requires_grad = True)\n", 129 | " y_shuffle = Variable(torch.from_numpy(y_shuffle).type(torch.FloatTensor), requires_grad = True) \n", 130 | " \n", 131 | " pred_xy = model(x_sample, y_sample)\n", 132 | " pred_x_y = model(x_sample, y_shuffle)\n", 133 | "\n", 134 | " # Nguyen Wainwright Jordan (NWJ)\n", 135 | " ret_nwj = torch.mean(pred_xy) - (1/e) * torch.mean(torch.exp(pred_x_y))\n", 136 | " # Jensen Shannon Divergence (JS) loss\n", 137 | " ret_jsd = torch.mean(-torch.log(1 + torch.exp(-pred_xy))) - torch.mean(torch.log(1 + torch.exp(pred_x_y))) + e - 1\n", 138 | " #ret_jsd = torch.mean(torch.log(1 + pred_xy)) - torch.mean(torch.log(pred_x_y))\n", 139 | " # Mutual Information Neural Estimation (MINE)/ Donsker Varadhan 
(DV) loss\n", 140 | " ret_mine = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))\n", 141 | " # Noise Contrastive Estimation(NCE) loss\n", 142 | " #ret_nce = torch.mean(pred_xy - torch.log(1 + pred_xy + torch.sum(torch.exp(pred_x_y))))\n", 143 | " ret_nce = torch.mean(pred_xy - torch.mean(torch.log(torch.sum(torch.exp(pred_x_y)))))\n", 144 | " \n", 145 | " loss_nwj = - ret_nwj # maximize\n", 146 | " loss_jsd = - ret_jsd\n", 147 | " loss_mine = - ret_mine # maximize\n", 148 | " loss_nce = - ret_nce # maximize\n", 149 | " \n", 150 | " plot_loss_mine.append(loss_mine.data.numpy())\n", 151 | " plot_loss_jsd.append(loss_jsd.data.numpy())\n", 152 | " plot_loss_nwj.append(loss_nwj.data.numpy())\n", 153 | " plot_loss_nce.append(loss_nce.data.numpy())\n", 154 | " \n", 155 | " model.zero_grad()\n", 156 | " \n", 157 | " loss_nwj.backward(retain_graph=True)\n", 158 | " loss_jsd.backward(retain_graph=True)\n", 159 | " loss_mine.backward(retain_graph=True)\n", 160 | " loss_nce.backward(retain_graph=True)\n", 161 | " \n", 162 | " optimizer.step()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 6, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "plot_x = np.arange(len(plot_loss_mine))\n", 172 | "plot_y_nwj = np.array(plot_loss_nwj).reshape(-1,)\n", 173 | "plot_y_jsd = np.array(plot_loss_jsd).reshape(-1,)\n", 174 | "plot_y_mine = np.array(plot_loss_mine).reshape(-1,)\n", 175 | "plot_y_nce = np.array(plot_loss_nce).reshape(-1,) - np.log(data_size)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "image/png": "\n", 186 | "text/plain": [ 187 | "
" 188 | ] 189 | }, 190 | "metadata": {}, 191 | "output_type": "display_data" 192 | } 193 | ], 194 | "source": [ 195 | "# plotting all together\n", 196 | "plt.plot(plot_x, -plot_y_nwj)\n", 197 | "plt.plot(plot_x, -plot_y_jsd, c = \"green\")\n", 198 | "plt.plot(plot_x, -plot_y_mine, c = \"yellow\")\n", 199 | "plt.plot(plot_x, -plot_y_nce, c = \"red\")\n", 200 | "plt.plot(plot_x, mi*np.ones(len(plot_loss_mine)))\n", 201 | "plt.legend([\"MI_NWJ\", \"MI_JSD\", \"MI_MINE\", \"MI_NCE\", \"MI_STD\"])\n", 202 | "plt.show()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "The MINE and NCE curves coincide, so the yellow MINE line is hidden under the red NCE line. This is expected: `ret_nce` equals the MINE objective minus log(data_size), and that constant is removed before plotting. Note also that all four objectives share one critic (their gradients are summed in each optimizer step), so each curve shows its bound evaluated on that shared critic." 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.6.5" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | -------------------------------------------------------------------------------- /notebooks/pytorch_mine_vs_nce.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "The code is adapted from [Masanori Yamada's MINE implementation](https://github.com/MasanoriYamada/Mine_pytorch)."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": { 14 | "colab": {}, 15 | "colab_type": "code", 16 | "id": "BS2h2kglwPa_" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch.autograd import Variable\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "from torch.utils.data import Dataset, DataLoader\n", 25 | "import numpy as np\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from tqdm import tqdm" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": { 34 | "colab": {}, 35 | "colab_type": "code", 36 | "id": "Ze6Kk6SsQKiE" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "# data\n", 41 | "np.random.seed(1234)\n", 42 | "var = 0.2\n", 43 | "def func(x):\n", 44 | " return x\n", 45 | "\n", 46 | "def gen_x():\n", 47 | " return np.sign(np.random.normal(0.,1.,[data_size,1]))\n", 48 | "\n", 49 | "def gen_y(x):\n", 50 | " return func(x)+np.random.normal(0.,np.sqrt(var),[data_size,1])" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "np.random.seed(1234)\n", 60 | "data_size = 1000000\n", 61 | "x = gen_x()\n", 62 | "y = gen_y(x)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/", 71 | "height": 34 72 | }, 73 | "colab_type": "code", 74 | "id": "C_wM4ty1Qldx", 75 | "outputId": "2892779e-b4dc-4f29-9858-28afb707d11d" 76 | }, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Mutual information calculated through traditional method is: 0.6584537102332939\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "x = gen_x()\n", 88 | "y = gen_y(x)\n", 89 | "p_y_x = np.exp(-(y - x)**2 / (2 * var))\n", 90 | "p_y_x_minus = np.exp(-(y + 1)**2 / (2 * var))\n", 91 | "p_y_x_plus = np.exp(-(y - 1)**2 / (2 * var))\n", 92 
| "mi = np.average(np.log(p_y_x / (0.5 * p_y_x_minus + 0.5 * p_y_x_plus)))\n", 93 | "#mi = mutual_information(x, y)\n", 94 | "print(\"Mutual information calculated through traditional method is:\", mi)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stderr", 104 | "output_type": "stream", 105 | "text": [ 106 | "100%|██████████| 500/500 [00:07<00:00, 66.77it/s]\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "np.random.seed(1234)\n", 112 | "H=10\n", 113 | "n_epoch = 500\n", 114 | "data_size = 20000\n", 115 | "\n", 116 | "class Net(nn.Module):\n", 117 | " def __init__(self):\n", 118 | " super(Net, self).__init__()\n", 119 | " self.fc1 = nn.Linear(1, H)\n", 120 | " self.fc2 = nn.Linear(1, H)\n", 121 | " self.fc3 = nn.Linear(H, 1)\n", 122 | "\n", 123 | " def forward(self, x, y):\n", 124 | " h1 = F.relu(self.fc1(x)+self.fc2(y))\n", 125 | " h2 = self.fc3(h1)\n", 126 | " return h2 \n", 127 | "\n", 128 | "model = Net()\n", 129 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", 130 | "plot_loss_mine = []\n", 131 | "plot_loss_nce = []\n", 132 | "for epoch in tqdm(range(n_epoch)):\n", 133 | " x_sample=gen_x()\n", 134 | " y_sample=gen_y(x_sample)\n", 135 | " y_shuffle=np.random.permutation(y_sample)\n", 136 | " \n", 137 | " x_sample = Variable(torch.from_numpy(x_sample).type(torch.FloatTensor), requires_grad = True)\n", 138 | " y_sample = Variable(torch.from_numpy(y_sample).type(torch.FloatTensor), requires_grad = True)\n", 139 | " y_shuffle = Variable(torch.from_numpy(y_shuffle).type(torch.FloatTensor), requires_grad = True) \n", 140 | " \n", 141 | " pred_xy = model(x_sample, y_sample)\n", 142 | " pred_x_y = model(x_sample, y_shuffle)\n", 143 | "\n", 144 | " # Noise Contrastive Estimation(NCE) loss\n", 145 | " #ret_nce = torch.mean(pred_xy - torch.log(1 + pred_xy + torch.sum(torch.exp(pred_x_y))))\n", 146 | " ret_nce = torch.mean(pred_xy - 
torch.mean(torch.log(torch.sum(torch.exp(pred_x_y)))))\n", 147 | " # Mutual Information Neural Estimation (MINE) loss\n", 148 | " ret_mine = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))\n", 149 | " loss_mine = - ret_mine # maximize\n", 150 | " loss_nce = - ret_nce # maximize\n", 151 | " plot_loss_mine.append(loss_mine.data.numpy())\n", 152 | " plot_loss_nce.append(loss_nce.data.numpy())\n", 153 | " model.zero_grad()\n", 154 | " loss_mine.backward(retain_graph=True)\n", 155 | " loss_nce.backward(retain_graph=True)\n", 156 | " optimizer.step()" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 7, 162 | "metadata": { 163 | "colab": {}, 164 | "colab_type": "code", 165 | "id": "0HJFmVYJQ4uf" 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "plot_x = np.arange(len(plot_loss_mine))\n", 170 | "plot_y_mine = np.array(plot_loss_mine).reshape(-1,)\n", 171 | "plot_y_nce = np.array(plot_loss_nce).reshape(-1,) - np.log(data_size)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "The MINE and NCE curves overlap almost exactly, so they are plotted separately below." 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 8, 184 | "metadata": { 185 | "colab": { 186 | "base_uri": "https://localhost:8080/", 187 | "height": 349 188 | }, 189 | "colab_type": "code", 190 | "id": "1-ezg88MRVEH", 191 | "outputId": "94ee42a7-9e19-4296-eff7-57e5077860c9" 192 | }, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "image/png": "\n", 197 | "text/plain": [ 198 | "
" 199 | ] 200 | }, 201 | "metadata": {}, 202 | "output_type": "display_data" 203 | } 204 | ], 205 | "source": [ 206 | "# plotting MI using MINE\n", 207 | "plt.plot(plot_x, -plot_y_mine)\n", 208 | "plt.plot(plot_x, mi*np.ones(len(plot_loss_mine)))\n", 209 | "plt.show()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 9, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "image/png": "\n", 220 | "text/plain": [ 221 | "
" 222 | ] 223 | }, 224 | "metadata": {}, 225 | "output_type": "display_data" 226 | } 227 | ], 228 | "source": [ 229 | "# plotting MI using NCE\n", 230 | "plt.plot(plot_x, -plot_y_nce, c = \"red\")\n", 231 | "plt.plot(plot_x, mi*np.ones(len(plot_loss_mine)))\n", 232 | "plt.show()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 10, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "True" 244 | ] 245 | }, 246 | "execution_count": 10, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "np.allclose(plot_y_mine,plot_y_nce, atol=1e-04)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 11, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "data": { 262 | "text/plain": [ 263 | "False" 264 | ] 265 | }, 266 | "execution_count": 11, 267 | "metadata": {}, 268 | "output_type": "execute_result" 269 | } 270 | ], 271 | "source": [ 272 | "np.allclose(plot_y_mine,plot_y_nce, atol=1e-05)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "The two estimates agree to within an absolute tolerance of 1e-4, but not 1e-5. This agreement is expected rather than an independent check: `ret_nce` equals the MINE objective minus log(data_size), and `plot_y_nce` removes that offset, so both arrays hold the same quantity up to floating-point error."
280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [] 288 | } 289 | ], 290 | "metadata": { 291 | "accelerator": "GPU", 292 | "colab": { 293 | "name": "pytorch_mine.ipynb", 294 | "provenance": [], 295 | "version": "0.3.2" 296 | }, 297 | "kernelspec": { 298 | "display_name": "Python 3", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.6.5" 313 | } 314 | }, 315 | "nbformat": 4, 316 | "nbformat_minor": 1 317 | } 318 | -------------------------------------------------------------------------------- /notebooks/pytorch_mine_vs_nwj.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Dividing the second term by the fixed constant `e`, i.e. using E_Q[e^(T-1)] in place of log E_Q[e^T], gives the $I_{NWJ}$ bound of Nguyen et al. (2010), also known as MINE-f in Belghazi et al. (2018)."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from torch.autograd import Variable\n", 18 | "import torch.nn as nn\n", 19 | "import torch.nn.functional as F\n", 20 | "from torch.utils.data import Dataset, DataLoader\n", 21 | "import numpy as np\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from tqdm import tqdm" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# data\n", 33 | "np.random.seed(1234)\n", 34 | "var = 0.2\n", 35 | "def func(x):\n", 36 | " return x\n", 37 | "\n", 38 | "def gen_x():\n", 39 | " return np.sign(np.random.normal(0.,1.,[data_size,1]))\n", 40 | "\n", 41 | "def gen_y(x):\n", 42 | " return func(x)+np.random.normal(0.,np.sqrt(var),[data_size,1])" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "np.random.seed(1234)\n", 52 | "data_size = 1000000\n", 53 | "x = gen_x()\n", 54 | "y = gen_y(x)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Mutual information calculated through traditional method is: 0.6584537102332939\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "x = gen_x()\n", 72 | "y = gen_y(x)\n", 73 | "p_y_x = np.exp(-(y - x)**2 / (2 * var))\n", 74 | "p_y_x_minus = np.exp(-(y + 1)**2 / (2 * var))\n", 75 | "p_y_x_plus = np.exp(-(y - 1)**2 / (2 * var))\n", 76 | "mi = np.average(np.log(p_y_x / (0.5 * p_y_x_minus + 0.5 * p_y_x_plus)))\n", 77 | "#mi = mutual_information(x, y)\n", 78 | "print(\"Mutual information calculated through traditional method is:\", mi)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stderr", 88 | 
"output_type": "stream", 89 | "text": [ 90 | "100%|██████████| 500/500 [00:07<00:00, 67.17it/s]\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "np.random.seed(1234)\n", 96 | "H=10\n", 97 | "n_epoch = 500\n", 98 | "data_size = 20000\n", 99 | "\n", 100 | "class Net(nn.Module):\n", 101 | " def __init__(self):\n", 102 | " super(Net, self).__init__()\n", 103 | " self.fc1 = nn.Linear(1, H)\n", 104 | " self.fc2 = nn.Linear(1, H)\n", 105 | " self.fc3 = nn.Linear(H, 1)\n", 106 | "\n", 107 | " def forward(self, x, y):\n", 108 | " h1 = F.relu(self.fc1(x)+self.fc2(y))\n", 109 | " h2 = self.fc3(h1)\n", 110 | " return h2 \n", 111 | "\n", 112 | "model = Net()\n", 113 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", 114 | "plot_loss_mine = []\n", 115 | "plot_loss_nwj = []\n", 116 | "for epoch in tqdm(range(n_epoch)):\n", 117 | " x_sample=gen_x()\n", 118 | " y_sample=gen_y(x_sample)\n", 119 | " y_shuffle=np.random.permutation(y_sample)\n", 120 | " \n", 121 | " x_sample = Variable(torch.from_numpy(x_sample).type(torch.FloatTensor), requires_grad = True)\n", 122 | " y_sample = Variable(torch.from_numpy(y_sample).type(torch.FloatTensor), requires_grad = True)\n", 123 | " y_shuffle = Variable(torch.from_numpy(y_shuffle).type(torch.FloatTensor), requires_grad = True) \n", 124 | " \n", 125 | " pred_xy = model(x_sample, y_sample)\n", 126 | " pred_x_y = model(x_sample, y_shuffle)\n", 127 | "\n", 128 | " # Nguyen Wainwright Jordan (NWJ) loss\n", 129 | " ret_nwj = torch.mean(pred_xy) - (1/2.71828) * torch.mean(torch.exp(pred_x_y))\n", 130 | " # Mutual Information Neural Estimation (MINE) loss\n", 131 | " ret_mine = torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y)))\n", 132 | " loss_mine = - ret_mine # maximize\n", 133 | " loss_nwj = - ret_nwj # maximize\n", 134 | " plot_loss_mine.append(loss_mine.data.numpy())\n", 135 | " plot_loss_nwj.append(loss_nwj.data.numpy())\n", 136 | " model.zero_grad()\n", 137 | " loss_mine.backward(retain_graph=True)\n", 138 | " 
loss_nwj.backward(retain_graph=True)\n", 139 | " optimizer.step()" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "plot_x = np.arange(len(plot_loss_mine))\n", 149 | "plot_y_mine = np.array(plot_loss_mine).reshape(-1,)\n", 150 | "plot_y_nwj = np.array(plot_loss_nwj).reshape(-1,)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 7, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "image/png": "\n", 161 | "text/plain": [ 162 | "
" 163 | ] 164 | }, 165 | "metadata": {}, 166 | "output_type": "display_data" 167 | } 168 | ], 169 | "source": [ 170 | "# plotting MI using MINE\n", 171 | "plt.plot(plot_x, -plot_y_mine)\n", 172 | "plt.plot(plot_x, mi*np.ones(len(plot_loss_mine)))\n", 173 | "plt.show()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 8, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "image/png": "[base64 PNG data omitted: NWJ MI estimate (red) plotted against the true MI line]\n", 184 | "text/plain": [ 185 | "
" 186 | ] 187 | }, 188 | "metadata": {}, 189 | "output_type": "display_data" 190 | } 191 | ], 192 | "source": [ 193 | "# plotting MI using NWJ\n", 194 | "plt.plot(plot_x, -plot_y_nwj, c=\"red\")\n", 195 | "plt.plot(plot_x, mi*np.ones(len(plot_x)))\n", 196 | "plt.show()" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 9, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "image/png": "\n", 207 | "text/plain": [ 208 | "
" 209 | ] 210 | }, 211 | "metadata": {}, 212 | "output_type": "display_data" 213 | } 214 | ], 215 | "source": [ 216 | "# plotting both together\n", 217 | "plt.plot(plot_x, -plot_y_mine)\n", 218 | "plt.plot(plot_x, -plot_y_nwj, c=\"red\")\n", 219 | "plt.plot(plot_x, mi*np.ones(len(plot_x)))\n", 220 | "plt.show()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "Again, the MINE and NWJ estimates of mutual information are consistent with each other." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [] 236 | } 237 | ], 238 | "metadata": { 239 | "kernelspec": { 240 | "display_name": "Python 3", 241 | "language": "python", 242 | "name": "python3" 243 | }, 244 | "language_info": { 245 | "codemirror_mode": { 246 | "name": "ipython", 247 | "version": 3 248 | }, 249 | "file_extension": ".py", 250 | "mimetype": "text/x-python", 251 | "name": "python", 252 | "nbconvert_exporter": "python", 253 | "pygments_lexer": "ipython3", 254 | "version": "3.6.5" 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 2 259 | } 260 | --------------------------------------------------------------------------------
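A minimal, self-contained sketch of the two objectives compared in the notebook above, under stated assumptions. The MINE objective is taken to be the Donsker-Varadhan (DV) bound E_P[T] - log E_Q[e^T], and the NWJ objective is taken to be E_P[f] - E_Q[e^(f-1)], the form analysed by Poole et al. The toy data (a standard bivariate Gaussian with correlation `rho`, whose MI is -0.5 log(1 - rho^2) nats) and all helper names (`log_mean_exp`, `dv_bound`, `nwj_bound`, `log_density_ratio`, `sample_pairs`) are hypothetical and not taken from the notebook. Rather than training a network, the sketch plugs in the known optimal critics (the true log density ratio for DV, and 1 plus that ratio for NWJ), so both values should land near the closed-form MI, which plays the role of the notebook's horizontal `mi` line. The plotting cells negate the losses (`-plot_y_mine`, `-plot_y_nwj`), which suggests each training loss is the negative of such a bound.

```python
import math
import random


def log_mean_exp(values):
    """Numerically stable log((1/n) * sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def dv_bound(t_joint, t_marginal):
    """Donsker-Varadhan (MINE) bound: E_P[T] - log E_Q[exp(T)]."""
    return sum(t_joint) / len(t_joint) - log_mean_exp(t_marginal)


def nwj_bound(f_joint, f_marginal):
    """NWJ bound: E_P[f] - E_Q[exp(f - 1)]."""
    mean_joint = sum(f_joint) / len(f_joint)
    mean_exp = sum(math.exp(f - 1.0) for f in f_marginal) / len(f_marginal)
    return mean_joint - mean_exp


def log_density_ratio(x, y, rho):
    """log p(x, y) - log p(x) - log p(y) for a standard bivariate Gaussian."""
    one_minus = 1.0 - rho * rho
    return (-0.5 * math.log(one_minus)
            - (x * x - 2.0 * rho * x * y + y * y) / (2.0 * one_minus)
            + (x * x + y * y) / 2.0)


def sample_pairs(n, rho, rng):
    """Joint samples, plus marginal samples made by shuffling the y values."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [rho * x + math.sqrt(1.0 - rho * rho) * rng.gauss(0.0, 1.0) for x in xs]
    ys_shuffled = ys[:]
    rng.shuffle(ys_shuffled)
    return list(zip(xs, ys)), list(zip(xs, ys_shuffled))


# Hypothetical toy setup (not the notebook's data): correlated Gaussian pair.
rho = 0.5
true_mi = -0.5 * math.log(1.0 - rho * rho)  # closed-form MI in nats
rng = random.Random(0)
joint, marginal = sample_pairs(20000, rho, rng)

# Critic values: the true log density ratio, which is optimal for DV.
t_joint = [log_density_ratio(x, y, rho) for x, y in joint]
t_marg = [log_density_ratio(x, y, rho) for x, y in marginal]

dv = dv_bound(t_joint, t_marg)
# NWJ's optimal critic is the same ratio shifted up by 1.
nwj = nwj_bound([1.0 + t for t in t_joint], [1.0 + t for t in t_marg])
print(f"true MI = {true_mi:.4f}, DV = {dv:.4f}, NWJ = {nwj:.4f}")
```

For any fixed critic the DV value is never below the NWJ value, because log z <= z/e for z > 0; the gap vanishes only when E_Q[e^T] = e, which is why the optimal NWJ critic is the log density ratio shifted by 1.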