├── Example notebook.ipynb ├── Final_Report__EBM___ML.pdf ├── ebm ├── .gitignore ├── README.md ├── __init__.py ├── config.py ├── models.py └── train.py ├── readme.md └── toy_examples ├── Example notebook.ipynb ├── README.md └── ebm_toy ├── .gitignore ├── README.md ├── __init__.py ├── config.py ├── models.py ├── train.py └── utils.py /Example notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "sTwReW7Zcd9H" 7 | }, 8 | "source": [ 9 | "# Setup\n", 10 | "Remember to properly set the global variables in `config.py` " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from ebm.config import *\n", 20 | "gpu_device = \"cuda:1\"\n", 21 | "\n", 22 | "# For deterministic training\n", 23 | "set_seed(0)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Tensorboard\n", 31 | "Doc @ https://pytorch.org/docs/1.7.1/tensorboard.html?highlight=tensorboard \n", 32 | "I don't use it form within notebook. \n", 33 | "To correctly visualize the runs names, open in TB the parent folder of the runs folders!" 34 | ] 35 | }, 36 | { 37 | "cell_type": "raw", 38 | "metadata": {}, 39 | "source": [ 40 | "# Load the TensorBoard notebook extension\n", 41 | "%load_ext tensorboard\n", 42 | "\n", 43 | "# If alerady loaded\n", 44 | "#%reload_ext tensorboard\n", 45 | "\n", 46 | "# Launch from \"terminal\": not working here \n", 47 | "#%tensorboard --logdir /mnt/workspace/EBM_proj/saved_models --host localhost #--port=8889\n", 48 | "\n", 49 | "from tensorboard import notebook\n", 50 | "# View open TensorBoard instances\n", 51 | "notebook.list() \n", 52 | "\n", 53 | "# Display tensorboard in this notebook\n", 54 | "#notebook.display(port=6007, height=1000)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Import & install libs" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "colab": { 69 | "base_uri": "https://localhost:8080/" 70 | }, 71 | "id": "d7BD4iiu9iO7", 72 | "outputId": "5384e70b-7397-433b-b61f-c5e1b20cf2ea" 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "%load_ext autoreload\n", 77 | "%autoreload 2\n", 78 | "\n", 79 | "# Standard libraries\n", 80 | "import numpy as np \n", 81 | "from tqdm.notebook import tqdm\n", 82 | "\n", 83 | "## Imports for plotting\n", 84 | "import matplotlib.pyplot as plt\n", 85 | "from matplotlib import cm\n", 86 | "%matplotlib inline \n", 87 | "\n", 88 | "## PyTorch\n", 89 | "import torch\n", 90 | "import torch.nn as nn\n", 91 | "import torch.nn.functional as F\n", 92 | "import torch.utils.data as data\n", 93 | "import torch.optim as optim\n", 94 | "import torch.autograd as autograd\n", 95 | "from torch.utils.tensorboard import SummaryWriter\n", 96 | "\n", 97 | "# Torchvision\n", 98 | "import torchvision\n", 99 | "from torchvision.datasets import MNIST\n", 100 | "from torchvision import transforms\n", 101 | "from torchvision.utils import make_grid\n", 102 | "\n", 103 | "print(\"Torch version: \" + torch.__version__)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "device = torch.device(gpu_device) if torch.cuda.is_available() else torch.device(\"cpu\")\n", 113 | "print(\"Currenly using the device:\", device)" 114 | ] 115 | }, 
116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "id": "xu-1QkBe9zVE" 120 | }, 121 | "source": [ 122 | "# Dataset" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "colab": { 130 | "base_uri": "https://localhost:8080/" 131 | }, 132 | "id": "WzlZOO1O9wot", 133 | "outputId": "ebd870b1-2934-403e-ebdb-5140e53c6365" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# Create dataset folder if not exists\n", 138 | "if not os.path.exists(DATASET_PATH):\n", 139 | " os.mkdir(DATASET_PATH)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "id": "IP5zRaf296Vl" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "# Transformations applied on each image => make them a tensor and normalize between -1 and 1\n", 151 | "transform = transforms.Compose(\n", 152 | " [transforms.ToTensor(),\n", 153 | " transforms.Normalize((0.5, ), (0.5, ))])\n", 154 | "\n", 155 | "# Loading the training dataset. We need to split it into a training and validation part\n", 156 | "train_set = MNIST(root=DATASET_PATH,\n", 157 | " train=True,\n", 158 | " transform=transform,\n", 159 | " download=True)\n", 160 | "\n", 161 | "# Loading the test set\n", 162 | "test_set = MNIST(root=DATASET_PATH,\n", 163 | " train=False,\n", 164 | " transform=transform,\n", 165 | " download=True)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": { 171 | "id": "77EPPgBZB8uU" 172 | }, 173 | "source": [ 174 | "# Trainer classes" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": { 181 | "code_folding": [], 182 | "id": "sKPtGyZOB8Q3" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "from ebm.train import EBMLangVanilla, EBMLang2Ord\n", 187 | "from ebm.models import CNNModel, LeNet" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Test trainer" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# Test 1\n", 204 | "MODEL_NAME = \"langVanilla_test\"\n", 205 | "MODEL_DESCRIPTION = \"This is a debug run\"\n", 206 | "MODEL_FAMILY = \"test\"\n", 207 | "\n", 208 | "EBMTrain = EBMLangVanilla(img_shape=(1, 28, 28),\n", 209 | " cnn=LeNet,\n", 210 | " batch_size=256,\n", 211 | " lr=5e-3,\n", 212 | " weight_decay=1e-3,\n", 213 | " mcmc_step_size=5e-6,\n", 214 | " mcmc_steps=2,\n", 215 | " model_name=MODEL_NAME,\n", 216 | " model_description=MODEL_DESCRIPTION,\n", 217 | " model_family=MODEL_FAMILY,\n", 218 | " overwrite=True,\n", 219 | " device=gpu_device)\n", 220 | "EBMTrain.setup()\n", 221 | "EBMTrain.prepare_data(train_set, test_set)\n", 222 | "\n", 223 | "try:\n", 224 | " # Train the model for N epochs\n", 225 | " EBMTrain.fit(2)\n", 226 | "finally:\n", 227 | " # Clear\n", 228 | " EBMTrain.clear()" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "## Reload trained model" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "### Same *name* and *hyperparams*" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "MODEL_NAME = \"langVanilla_test\"\n", 252 | "MODEL_DESCRIPTION = \"This is a debug run\"\n", 253 | "MODEL_FAMILY = \"test\"\n", 254 | "\n", 255 | "EBMTrain = EBMLangVanilla(img_shape=(1, 
28, 28),\n", 256 | " cnn=LeNet,\n", 257 | " batch_size=256,\n", 258 | " lr=5e-3,\n", 259 | " weight_decay=1e-3,\n", 260 | " mcmc_step_size=5e-6,\n", 261 | " mcmc_steps=2,\n", 262 | " model_name=MODEL_NAME,\n", 263 | " model_description=MODEL_DESCRIPTION,\n", 264 | " model_family=MODEL_FAMILY,\n", 265 | " overwrite=True,\n", 266 | " reload_model=True,\n", 267 | " device=gpu_device)\n", 268 | "EBMTrain.setup()\n", 269 | "EBMTrain.prepare_data(train_set, test_set)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | "Generate some samples from pretrained" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "mcmc_iter = 20\n", 286 | "EBMTrain.final_sampled_images = EBMTrain.tb_mcmc_images(\n", 287 | " batch_size=64, mcmc_steps=mcmc_iter, name=\"final_images_sample\", evaluation=True)\n", 288 | "# Plot them\n", 289 | "print(\"Final sample after %d mcmc iterations:\" % mcmc_iter)\n", 290 | "fig, ax = plt.subplots(figsize=(10, 10))\n", 291 | "ax.imshow(EBMTrain.final_sampled_images.permute(1, 2, 0))\n", 292 | "plt.show()" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "# Clear\n", 302 | "EBMTrain.clear()" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "### Reload from given path\n", 310 | "Hyperparams to be explicitely set:\n", 311 | "- mcmc_step_size\n", 312 | "- gpu_device\n", 313 | "- cnn \n", 314 | "\n", 315 | "\n", 316 | "They have to be the same used during training (except fot GPU dev)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": { 323 | "scrolled": true 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "model_root = \"saved_models/MNIST/...\"\n", 328 | "EBMTrain = EBMLangVanilla(mcmc_step_size=1e-3,\n", 329 | " cnn=CNNModel,\n", 330 | " reload_model=model_root,\n", 331 | " device=gpu_device)\n", 332 | "EBMTrain.setup()" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "Generate some samples from pretrained" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "mcmc_iter = 500\n", 349 | "EBMTrain.final_sampled_images = EBMTrain.tb_mcmc_images(\n", 350 | " batch_size=64, mcmc_steps=mcmc_iter, name=\"final_images_sample\", evaluation=True)\n", 351 | "# Plot them\n", 352 | "print(\"Final sample after %d mcmc iterations:\" % mcmc_iter)\n", 353 | "fig, ax = plt.subplots(figsize=(10, 10))\n", 354 | "ax.imshow(EBMTrain.final_sampled_images.permute(1, 2, 0))\n", 355 | "plt.show()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# Clear\n", 365 | "EBMTrain.clear()" 366 | ] 367 | } 368 | ], 369 | "metadata": { 370 | "kernelspec": { 371 | "display_name": "torch", 372 | "language": "python", 373 | "name": "torch" 374 | }, 375 | "language_info": { 376 | "codemirror_mode": { 377 | "name": "ipython", 378 | "version": 3 379 | }, 380 | "file_extension": ".py", 381 | "mimetype": "text/x-python", 382 | "name": "python", 383 | "nbconvert_exporter": "python", 384 | "pygments_lexer": "ipython3", 385 | "version": "3.7.10" 386 | }, 387 | "latex_envs": { 388 | "LaTeX_envs_menu_present": true, 389 | 
"autoclose": false, 390 | "autocomplete": true, 391 | "bibliofile": "biblio.bib", 392 | "cite_by": "apalike", 393 | "current_citInitial": 1, 394 | "eqLabelWithNumbers": true, 395 | "eqNumInitial": 1, 396 | "hotkeys": { 397 | "equation": "Ctrl-E", 398 | "itemize": "Ctrl-I" 399 | }, 400 | "labels_anchors": false, 401 | "latex_user_defs": false, 402 | "report_style_numbering": false, 403 | "user_envs_cfg": false 404 | }, 405 | "toc": { 406 | "base_numbering": 1, 407 | "nav_menu": {}, 408 | "number_sections": true, 409 | "sideBar": true, 410 | "skip_h1_title": false, 411 | "title_cell": "Table of Contents", 412 | "title_sidebar": "Contents", 413 | "toc_cell": false, 414 | "toc_position": {}, 415 | "toc_section_display": true, 416 | "toc_window_display": false 417 | } 418 | }, 419 | "nbformat": 4, 420 | "nbformat_minor": 2 421 | } 422 | -------------------------------------------------------------------------------- /Final_Report__EBM___ML.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matbun/EBM--Generative-Energy-Based-Modeling/4b8a079982e968a9e7d7f2e0b97b33ba70772864/Final_Report__EBM___ML.pdf -------------------------------------------------------------------------------- /ebm/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints 3 | -------------------------------------------------------------------------------- /ebm/README.md: -------------------------------------------------------------------------------- 1 | This folder is a package for EBM training. 2 | 3 | # train.py 4 | Trainer classes for different approaches in EBM training (MCMC sampling from model): 5 | - `EBMLangVanilla` to use first order Langevin dynamics. 6 | - `EBMLang2Ord` to use second order Langevin dynamics. 7 | 8 | # models.py 9 | CNN models for MNIST dataset: 10 | - LeNet 11 | - Custom CNN model: `CNNModel` 12 | - DenseNet 13 | 14 | # config.py 15 | Is a configuration module that contains global variables that are useful in all modules and in the training notebook. 16 | In the training notebook import it as `from ebm.config import *`, before everything else concerning EBM. 17 | -------------------------------------------------------------------------------- /ebm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matbun/EBM--Generative-Energy-Based-Modeling/4b8a079982e968a9e7d7f2e0b97b33ba70772864/ebm/__init__.py -------------------------------------------------------------------------------- /ebm/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | import subprocess 6 | import sys 7 | 8 | def install(package): 9 | subprocess.check_call([sys.executable, "-m", "pip", "install", package]) 10 | 11 | try: 12 | import shutil 13 | except: 14 | install("pytest-shutil") 15 | import shutil 16 | 17 | 18 | HOME_DIR = os.path.expanduser("~") 19 | 20 | PROJECT_ROOT = os.path.join(HOME_DIR, "Projects/EBM_proj") 21 | os.chdir(PROJECT_ROOT) 22 | 23 | # Path to the folder where the datasets are/should be downloaded (e.g. 
CIFAR10)
24 | DATASET_PATH = PROJECT_ROOT + "/data"
25 | 
26 | # Path to the folder where the pretrained models are saved
27 | CHECKPOINT_PATH = PROJECT_ROOT + "/saved_models/MNIST/"
28 | 
29 | 
30 | # Set random seed
31 | def set_seed(seed: int=0, deterministic: bool=True, benchmark: bool=False):
32 |     np.random.seed(seed)
33 |     torch.manual_seed(seed)
34 | 
35 |     # Ensure that all operations are deterministic on GPU (if used) for reproducibility
36 |     torch.backends.cudnn.deterministic = deterministic
37 |     torch.backends.cudnn.benchmark = benchmark
--------------------------------------------------------------------------------
/ebm/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | ## PyTorch
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | def init_layers(m):
9 |     if type(m) == nn.Conv2d:
10 |         nn.init.kaiming_normal_(m.weight)
11 |         nn.init.zeros_(m.bias)
12 |     elif type(m) == nn.Linear:
13 |         m.weight.data.normal_(0, 0.01)
14 |         m.bias.data.normal_(0, 0.01)
15 | 
16 | def sq_activation(x):
17 |     return 0.5 * torch.pow(x,2)
18 | 
19 | 
20 | class Swish(nn.Module):
21 |     def forward(self, x):
22 |         return x * torch.sigmoid(x)
23 | 
24 | class LeNet(nn.Module):
25 |     """ Adapted LeNet
26 |     - Swish activation function
27 |     - padding=2 in the first conv layer (instead of 0)
28 |     """
29 |     def __init__(self, out_dim=1, **kwargs):
30 |         super().__init__()
31 |         self.cnn_layers = nn.Sequential(
32 |             nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),  #(28x28)
33 |             nn.SiLU(),
34 |             nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  #(14x14)
35 |             nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),  #(10x10)
36 |             nn.SiLU(),
37 |             nn.AvgPool2d(kernel_size=2, stride=2, padding=0),  #(5x5)
38 |             nn.Flatten(),
39 |             nn.Linear(5*5*16, 64),
40 |             nn.SiLU(),
41 |             nn.Linear(64, out_dim)
42 |         )
43 |         self.cnn_layers.apply(init_layers)
44 | 
45 |     def forward(self, x):
46 |         o = self.cnn_layers(x).squeeze(dim=-1)
47 |         return sq_activation(o)
48 | 
49 | 
50 | class CNNModel(nn.Module):
51 |     def __init__(self, hidden_features=32, out_dim=1, beta=0, gamma=0, **kwargs):
52 |         """CNNModel
53 |         beta: quadratic energies weight. If not specified, the default is 0. If None, learnable parameter.
54 |         gamma: Langevin "weight decay". If not specified, the default is 0. If None, learnable parameter.
55 |         """
56 |         super().__init__()
57 |         # We increase the hidden dimension over layers. Here pre-calculated for simplicity.
58 | c_hid1 = hidden_features // 2 59 | c_hid2 = hidden_features 60 | c_hid3 = hidden_features * 2 61 | 62 | # Series of convolutions and Swish activation functions 63 | self.cnn_layers = nn.Sequential( 64 | nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, 65 | padding=4), # [16x16] 66 | Swish(), 67 | nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, 68 | padding=1), # [8x8] 69 | Swish(), 70 | nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, 71 | padding=1), # [4x4] 72 | Swish(), 73 | nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, 74 | padding=1), # [2x2] 75 | Swish(), 76 | nn.Flatten(), 77 | nn.Linear(c_hid3 * 4, c_hid3), 78 | 79 | Swish(), 80 | nn.Linear(c_hid3, out_dim) 81 | ) 82 | self.cnn_layers.apply(init_layers) 83 | 84 | def forward(self, x): 85 | o = self.cnn_layers(x).squeeze(dim=-1) 86 | return sq_activation(o) 87 | 88 | 89 | ############################################################## 90 | ### DenseNet ################################################# 91 | ############################################################## 92 | 93 | import torch 94 | 95 | import torch.nn as nn 96 | import torch.optim as optim 97 | 98 | import torch.nn.functional as F 99 | from torch.autograd import Variable 100 | 101 | import torchvision.datasets as dset 102 | import torchvision.transforms as transforms 103 | from torch.utils.data import DataLoader 104 | 105 | import torchvision.models as models 106 | 107 | import sys 108 | import math 109 | 110 | class Bottleneck(nn.Module): 111 | def __init__(self, nChannels, growthRate): 112 | super(Bottleneck, self).__init__() 113 | interChannels = 4*growthRate 114 | # self.bn1 = nn.BatchNorm2d(nChannels, affine=False) 115 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 116 | bias=False) 117 | # self.bn2 = nn.BatchNorm2d(interChannels, affine=False) 118 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 119 | padding=1, bias=False) 120 | 121 | def forward(self, x): 122 | # out = self.conv1(F.relu(self.bn1(x))) 123 | # out = self.conv2(F.relu(self.bn2(out))) 124 | out = self.conv1(F.relu(x)) 125 | out = self.conv2(F.relu(out)) 126 | out = torch.cat((x, out), 1) 127 | return out 128 | 129 | class SingleLayer(nn.Module): 130 | def __init__(self, nChannels, growthRate): 131 | super(SingleLayer, self).__init__() 132 | # self.bn1 = nn.BatchNorm2d(nChannels, affine=False) 133 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 134 | padding=1, bias=False) 135 | 136 | def forward(self, x): 137 | # out = self.conv1(F.relu(self.bn1(x))) 138 | out = self.conv1(F.relu(x)) 139 | out = torch.cat((x, out), 1) 140 | return out 141 | 142 | class Transition(nn.Module): 143 | def __init__(self, nChannels, nOutChannels): 144 | super(Transition, self).__init__() 145 | # self.bn1 = nn.BatchNorm2d(nChannels, affine=False) 146 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 147 | bias=False) 148 | 149 | def forward(self, x): 150 | out = self.conv1(F.relu(x)) #self.bn1(x) 151 | out = F.avg_pool2d(out, 2) 152 | return out 153 | 154 | class DenseNet(nn.Module): 155 | """For EBM on MNIST: no batchnorm, input is 1 x 28 x 28""" 156 | def __init__(self, growthRate, reduction, nClasses=1, bottleneck=True, depth=11, gamma=0, activ_type=None, **kwargs): 157 | super(DenseNet, self).__init__() 158 | 159 | def erf_activ(x): 160 | return torch.erf(x)*np.sqrt(np.pi)/2 161 | 162 | # Trash 163 | self.cnn_layers = [] 164 | self.beta = torch.tensor(0) 165 | self.gamma = torch.tensor(gamma) 166 | self.activ = erf_activ if 
activ_type=="erf" else None
167 |         if activ_type is not None:
168 |             print("CNN: Initial activation:", activ_type)
169 |         if gamma > 0:
170 |             print("CNN: Using penalty")
171 | 
172 |         # Parabola params
173 |         self.a = nn.Parameter(torch.tensor(1.), requires_grad=True)
174 |         self.b = nn.Parameter(torch.tensor(1.), requires_grad=True)
175 |         self.c = nn.Parameter(torch.tensor(0.1), requires_grad=True)
176 | 
177 |         # Good...
178 |         nDenseBlocks = 4
179 |         if bottleneck:
180 |             nDenseBlocks //= 2
181 | 
182 |         nChannels = 2*growthRate
183 |         self.conv1 = nn.Conv2d(1, nChannels, kernel_size=3, padding=1,
184 |                                bias=False)
185 |         self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
186 |         nChannels += nDenseBlocks*growthRate
187 |         nOutChannels = int(math.floor(nChannels*reduction))
188 |         self.trans1 = Transition(nChannels, nOutChannels)
189 | 
190 |         nChannels = nOutChannels
191 |         self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
192 |         nChannels += nDenseBlocks*growthRate
193 |         nOutChannels = int(math.floor(nChannels*reduction))
194 |         self.trans2 = Transition(nChannels, nOutChannels)
195 | 
196 |         nChannels = nOutChannels
197 |         self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
198 |         nChannels += nDenseBlocks*growthRate
199 | 
200 |         # self.bn1 = nn.BatchNorm2d(nChannels, affine=False)
201 |         self.fc = nn.Linear(nChannels, nClasses)
202 | 
203 |         for m in self.modules():
204 |             if isinstance(m, nn.Conv2d):
205 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
206 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
207 |             # elif isinstance(m, nn.BatchNorm2d):
208 |             #     m.weight.data.fill_(1)
209 |             #     m.bias.data.zero_()
210 |             elif isinstance(m, nn.Linear):
211 |                 m.bias.data.zero_()
212 | 
213 |     def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
214 |         layers = []
215 |         for i in range(int(nDenseBlocks)):
216 |             if bottleneck:
217 |                 layers.append(Bottleneck(nChannels, growthRate))
218 |             else:
219 |                 layers.append(SingleLayer(nChannels, growthRate))
220 |             nChannels += growthRate
221 |         return nn.Sequential(*layers)
222 | 
223 |     def forward(self, x):
224 |         # if self.activ is not None:
225 |         #     x1 = self.activ(x)
226 |         # else:
227 |         #     x1=x
228 |         x1 = x  #torch.atan(x)
229 |         out = self.conv1(x1)
230 |         # print("conv", out.shape)
231 |         out = self.trans1(self.dense1(out))
232 |         # print("D1 + T1", out.shape)
233 |         out = self.trans2(self.dense2(out))
234 |         # print("D2 + T2", out.shape)
235 |         out = self.dense3(out)
236 |         # print("D3", out.shape)
237 |         out = torch.squeeze(F.avg_pool2d(F.relu(out), 7))  #self.bn1(out)
238 |         # print("Final", out.shape)
239 |         # out = F.log_softmax(self.fc(out))
240 |         out = self.fc(out)
241 |         # print("FC", out.shape)
242 | 
243 |         # Penalize pixels outside [-1, 1]:
244 |         # Warning: when trying to increase the energy, the model might
245 |         # intentionally push pixels outside this interval!!
246 |         # Risk of instability
247 |         #penalty = self.gamma * (torch.exp(torch.pow(x, 4) / 4) - 1).sum() if self.gamma.item() > 0 else torch.tensor(0)
248 | 
249 |         #parabola = .5*torch.pow(self.a, 2) * out**2 + self.b * out + self.c  #(self.b**2 / (4 * self.a))
250 | 
251 |         return -0.5*torch.pow(out, 2)  #+ out  #parabola  #+ penalty
252 | 
--------------------------------------------------------------------------------
/ebm/train.py:
--------------------------------------------------------------------------------
1 | ## Standard libraries
2 | import os
3 | import numpy as np
4 | from tqdm.notebook import tqdm
5 | from IPython.display import clear_output
6 | import shutil
7 | 
8 | ## Imports for plotting
9 | import matplotlib.pyplot as plt
10 | from matplotlib import cm
11 | 
12 | ## PyTorch
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torch.autograd as autograd
17 | import torch.utils.data as data
18 | from torch.utils.tensorboard import SummaryWriter
19 | 
20 | # Torchvision
21 | import torchvision
22 | from torchvision.utils import make_grid
23 | 
24 | # Custom modules
25 | from ebm.config import *
26 | from ebm.models import CNNModel
27 | 
28 | 
29 | class DeepEnergyModel:
30 |     """
31 |     model_name (str) - Any name to visually recognize the model, like the #run.
32 |     model_description (str) - Will be logged by tensorboard as "text".
33 |     model_family (str) - When running multiple experiments, it may be useful to divide
34 |         the models and their logged results into families (subdirs of the checkpoint path).
35 |         This param can have the form of a path/to/subfolder.
36 |     overwrite (bool) - If the logs folder already exists and this is True, "overwrite" it (namely,
37 |         add the new logs without removing the old ones).
38 |     reload_model [bool, str] - If True (bool), reloads the pretrained model with identical
39 |         hyperparameters and name (if it exists in the CHECKPOINT folder). If a string, reloads
40 |         the model at the path indicated by the string.
41 |     """
42 |     def __init__(self,
43 |                  cnn=None,
44 |                  img_shape=(1, 28, 28),
45 |                  batch_size=64,
46 |                  lr=1e-4,
47 |                  weight_decay=1e-4,
48 |                  mcmc_step_size=1e-5,
49 |                  mcmc_steps=250,
50 |                  model_name="unnamed",
51 |                  model_description="",
52 |                  model_family="Langevin_vanilla",
53 |                  mcmc_init_type="gaussian",
54 |                  device="cuda:1",
55 |                  overwrite=False,
56 |                  start_epoch=0,
57 |                  reload_model=False,
58 |                  **CNN_args):
59 |         super().__init__()
60 | 
61 |         # Model
62 |         assert cnn is not None, "CNN model has to be specified!"
63 | self.img_shape = img_shape 64 | self.device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu") 65 | print("Running on device:", self.device) 66 | self.cnn = cnn(**CNN_args).to(self.device) 67 | self.reload_model = reload_model 68 | 69 | # Optimizers 70 | self.lr = lr 71 | self.weight_decay = weight_decay 72 | 73 | # Dataset 74 | self.batch_size = batch_size 75 | 76 | # MCMC 77 | self.mcmc_step_size = mcmc_step_size 78 | self.mcmc_steps = mcmc_steps 79 | self.mcmc_persistent_data = (2 * torch.rand((10000,) + img_shape, device=self.device) - 1) 80 | self.mcmc_init_type = mcmc_init_type 81 | 82 | 83 | # Logging 84 | # General purpose: add new element each iteration (batch) 85 | self.log_dict = dict() 86 | # MCMC sampling: add element each MCMC iteration 87 | self.mcmc_evolution_logs = dict() 88 | # Final sample of generated images 89 | self.final_sampled_images = None 90 | 91 | # Tensorboard 92 | self.model_name = model_name 93 | self.model_description = model_description 94 | self.model_family = model_family 95 | self.overwrite = overwrite 96 | 97 | # The following global variables are employed in different functions (in 98 | # different ways) to compute the SummaryWriter global_step. 99 | self.epoch_n = 0 100 | self.tot_batches = 0 101 | self.iter_n = 0 102 | self.start_epoch = start_epoch 103 | 104 | # Setup flag: check the model has been properly set up before starting 105 | self.is_setup = False 106 | 107 | def setup(self): 108 | """Setup the optimizers, setup the Tensorboard SummaryWriter, process hyperparams dict.""" 109 | # Optimizers 110 | self.configure_optimizers() 111 | 112 | # Hyperparams dict 113 | self.hparams_dict = { 114 | 'cnn': self.cnn.__class__.__name__, 115 | 'mcmc_step_size': self.mcmc_step_size, 116 | 'mcmc_steps': self.mcmc_steps, 117 | 'lr': self.lr, 118 | 'weight_decay': self.weight_decay, 119 | 'batch_size': self.batch_size, 120 | 'optimizer': self.optimizer.__class__.__name__ 121 | } 122 | 123 | # Tensorboard logs 124 | hparams_str = "__".join(["=".join( 125 | [str(el) for el in dict_entry]) for dict_entry in self.hparams_dict.items()]) 126 | full_name = self.model_name + "__" + hparams_str 127 | self.ckpt_path = os.path.join(CHECKPOINT_PATH, self.model_family, full_name) 128 | # Reload existing model? 129 | if os.path.exists(self.ckpt_path) and self.reload_model is True: 130 | # Reload model with the same hyperparams and name 131 | path = os.path.join(self.ckpt_path, "model_state_dict.pt") 132 | self.cnn.load_state_dict(torch.load(path)) 133 | print("Loaded pretrained existsing model") 134 | elif isinstance(self.reload_model, str): 135 | # Reload model from a given path 136 | path = os.path.join(self.reload_model, "model_state_dict.pt") 137 | self.cnn.load_state_dict(torch.load(path)) 138 | print("Loaded pretrained existsing model at %s" % self.reload_model) 139 | 140 | ## Ovrewrite existsing model ## 141 | # Unautorized overwrite 142 | elif os.path.exists(self.ckpt_path) and not self.overwrite: 143 | print("Model path: " + self.ckpt_path) 144 | raise NameError("Model already exists! 
Set self.overwrite=True to overwrite it.") 145 | # Autorized overwrite 146 | elif os.path.exists(self.ckpt_path) and self.overwrite: 147 | # Remove existsing folder 148 | shutil.rmtree(self.ckpt_path) 149 | print("Overwriting existing logs") 150 | 151 | # Create writer 152 | self.tb_writer = SummaryWriter(self.ckpt_path) 153 | # Add docstring to interpret logs 154 | self.tb_writer.add_text('logs_documentation', self.tb_logs_doc()) 155 | descr = "Model description:\n" + self.model_description + "\n\n\n" 156 | self.tb_writer.add_text("model_description", descr, 100) 157 | 158 | # Set is_setup flag to True 159 | self.is_setup = True 160 | 161 | def clear(self): 162 | # Tensorboard writer 163 | self.tb_writer.close() 164 | 165 | 166 | def configure_optimizers(self): 167 | # Optimize only the layers that require grad 168 | self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.cnn.parameters()), 169 | lr=self.lr, 170 | weight_decay=self.weight_decay) 171 | 172 | def prepare_data(self, train_set, test_set): 173 | self.train_loader = data.DataLoader(train_set, 174 | batch_size=self.batch_size, 175 | shuffle=True, 176 | drop_last=True, 177 | num_workers=2, 178 | pin_memory=True) 179 | self.test_loader = data.DataLoader(test_set, 180 | batch_size=self.batch_size, 181 | shuffle=False, 182 | drop_last=False, 183 | num_workers=2) 184 | 185 | ###################################################### 186 | ################ Training section #################### 187 | ###################################################### 188 | 189 | def training_step(self, batch): 190 | # Train mode 191 | self.cnn.train() 192 | 193 | real_imgs, _ = batch 194 | real_imgs = real_imgs.to(self.device) 195 | 196 | # Obtain samples 197 | fake_imgs = self.generate_samples() 198 | 199 | # Predict energy score for all images 200 | inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0) 201 | real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0) 202 | cdiv_loss = real_out.mean() - fake_out.mean() 203 | 204 | # Free memory 205 | del real_imgs, fake_imgs, inp_imgs 206 | 207 | # Optimize 208 | self.optimizer.zero_grad() 209 | cdiv_loss.backward() 210 | self.optimizer.step() 211 | 212 | # Logging 213 | self.log('loss_cdiv', cdiv_loss) 214 | self.log('energy_avg_real', real_out.mean()) 215 | self.log('energy_avg_fake', fake_out.mean()) 216 | 217 | # Log layers weigth / bias norms 218 | mod = 1 219 | for m in self.cnn.modules(): 220 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 221 | if m.weight is not None: 222 | self.log(f"{m.__class__.__name__} #{str(mod)}: weight norm", 223 | torch.norm( 224 | m.weight).clone().detach().cpu().numpy(), 225 | printable=False) 226 | if m.bias is not None: 227 | self.log(f"{m.__class__.__name__} #{str(mod)}: bias norm", 228 | torch.norm( 229 | m.bias).clone().detach().cpu().numpy(), 230 | printable=False) 231 | 232 | mod +=1 233 | 234 | 235 | def fit(self, n_epochs=None): 236 | 237 | assert self.is_setup, "Model is not properly setup. Call .setup() before running!" 
238 | 239 | if self.train_loader is None: 240 | print("Train data not loaded") 241 | return 242 | 243 | # Epochs 244 | self.tot_batches = len(self.train_loader) 245 | self.tot_epochs = n_epochs - self.start_epoch 246 | epochs_bar = tqdm(range(self.start_epoch, n_epochs), total=self.tot_epochs, leave=True) 247 | epochs_bar.set_description("Epochs") 248 | for self.epoch_n in epochs_bar: 249 | # Iterations 250 | self.log_active = True 251 | iters_bar = tqdm(enumerate(self.train_loader), 252 | total=self.tot_batches, 253 | position=0, 254 | leave=False) 255 | iters_bar.set_description("Batches (iterations)") 256 | for self.iter_n, batch in iters_bar: 257 | 258 | self.training_step(batch) 259 | 260 | 261 | ############## Tensorboard ############### 262 | # Evolution thoughout a mcmc simulation 263 | self.tb_mcmc_simulation() 264 | 265 | # Log a generaed sample of images 266 | self.tb_mcmc_images(batch_size=25, evaluation=True) 267 | 268 | # Force tensorboard to write to disk (to be sure) 269 | self.tb_writer.flush() 270 | 271 | ############# Other logs ################ 272 | # Print logged measures 273 | #self.flush_logs() 274 | 275 | # Save model state dict (params) 276 | self.save_model() 277 | 278 | # TB: Log a final batch of images sampled form the model 279 | mcmc_iter = 1000 280 | self.final_sampled_images = self.tb_mcmc_images( 281 | batch_size=64, mcmc_steps=mcmc_iter, name="final_images_sample", evaluation=True) 282 | # Plot them 283 | print("Final sample after %d mcmc iterations:" % mcmc_iter) 284 | fig, ax = plt.subplots(figsize=(10, 10)) 285 | ax.imshow(self.final_sampled_images.permute(1, 2, 0)) 286 | plt.show() 287 | 288 | ###################################################### 289 | ################ Langevin dynamics ################### 290 | ###################################################### 291 | 292 | 293 | def generate_samples(self, 294 | evaluation=False, 295 | batch_size=None, 296 | mcmc_steps=None): 297 | """ 298 | Draw samples using Langevin dynamics 299 | evaluation: if True, avoids logging mcmc stats. It means we're sampling 300 | from the model with arbitrary batchsize/mcmc_steps and it isn't related to training. 301 | noise_scale: Optional. float. 
If None, set to np.sqrt(step_size * 2) 302 | """ 303 | batch_size = self.batch_size if batch_size is None else batch_size 304 | mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps 305 | 306 | is_training = self.cnn.training 307 | self.cnn.eval() 308 | 309 | # Initial batch of noise / images: starting point of mcmc chain 310 | def sample_s_t_0(): 311 | if self.mcmc_init_type == 'persistent' and not evaluation: 312 | rand_inds = torch.randperm(self.mcmc_persistent_data.shape[0])[0:batch_size] 313 | return self.mcmc_persistent_data[rand_inds], rand_inds 314 | elif self.mcmc_init_type == 'data' and not evaluation: 315 | raise RuntimeError("EBM train: Not implmented error") 316 | #return torch.Tensor(self.train_set.sample_toy_data(batch_size)), None 317 | elif self.mcmc_init_type == 'uniform' or evaluation: 318 | return (2 * torch.rand((batch_size,) + self.img_shape, device=self.device) - 1) , None 319 | elif self.mcmc_init_type == 'gaussian' and not evaluation: 320 | return torch.randn((batch_size,) + self.img_shape, device=self.device) , None 321 | else: 322 | raise RuntimeError('Invalid method for "init_type" (use "persistent", "data", "uniform", or "gaussian")') 323 | 324 | x, rand_inds = sample_s_t_0() 325 | x = torch.autograd.Variable(x.clone(), requires_grad=True).to(self.device) 326 | original_x = x.clone().detach() 327 | 328 | noise_scale = np.sqrt(self.mcmc_step_size * 2) 329 | 330 | # Pre-allocate additive noise (for Langevin step) 331 | noise = torch.randn_like(x, device=self.device) 332 | 333 | 334 | def append_norm(in_tensor, array): 335 | return np.append( 336 | array, 337 | torch.norm(in_tensor, 338 | dim=[2, 3]).mean().clone().detach().cpu().numpy()) 339 | 340 | grad_norms = np.array([]) 341 | data_norms = np.array([]) 342 | 343 | # To study the evolution within an mcmc simulation 344 | distances = np.array([]) 345 | prev_distances = np.array([]) 346 | time_window = 50 347 | 348 | for _ in range(mcmc_steps): 349 | 350 | if self.iter_n < time_window: 351 | #Used to compute prev_distances items 352 | old_x = x.clone().detach() 353 | 354 | # Re-init noise tensor 355 | noise.normal_(mean=0.0, std=noise_scale) 356 | out = self.cnn(x) 357 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0] 358 | # grad is in "device" by default 359 | 360 | # Avoid NaN gradients 361 | # if torch.any(torch.isnan(grad)): 362 | # self.tb_writer.flush() 363 | # raise RuntimeError("Langevin grad has some NaN values!") 364 | 365 | x = x - self.mcmc_step_size * grad + noise 366 | 367 | # Save stats 368 | grad_norms = append_norm(grad, grad_norms) 369 | data_norms = append_norm(x, data_norms) 370 | 371 | if self.iter_n < time_window: 372 | prev_distances = append_norm(x - old_x, prev_distances) 373 | distances = append_norm(x - original_x, distances) 374 | 375 | self.cnn.train(is_training) 376 | 377 | ####### Evolution within Langevin dynamics ###### 378 | # If at the beginning of an epoch, save the evolution of 379 | # grad and img norms, for a time window of width K. 
380 | # These quantities will be logged within fit() function 381 | 382 | def append_mcmc_logs(prop_name, prop_array): 383 | full_name = "%s_epoch_%d" % (prop_name, self.epoch_n + 1) 384 | entry = self.mcmc_evolution_logs.get(full_name, None) 385 | if entry is None: 386 | self.mcmc_evolution_logs[full_name] = prop_array 387 | else: 388 | self.mcmc_evolution_logs[full_name] = np.vstack( 389 | (entry, prop_array)) 390 | return 391 | 392 | if not evaluation: 393 | # Beginning of epoch e 394 | # 'langevin_evolution_' metrics describe the evolution 395 | # within a mcmc sampling process. Computed over a time_window of iterations. 396 | if self.iter_n < time_window: 397 | # Gradient norm 398 | append_mcmc_logs("langevin_evolution_grad_norm", grad_norms) 399 | 400 | # Data norm 401 | append_mcmc_logs("langevin_evolution_img_norm", data_norms) 402 | 403 | # Distance from previous point 404 | append_mcmc_logs("langevin_evolution_distance2prevstep", prev_distances) 405 | 406 | # Distance from starting point 407 | append_mcmc_logs("langevin_evolution_distance2start", distances) 408 | 409 | 410 | # Always log the avg 411 | # 'langevin_avg_' metrics describe the avg value of a measure 412 | # within a mcmc sampling process. Computed at each iteration. 413 | self.log('langevin_avg_grad_norm', np.mean(grad_norms)) 414 | self.log('langevin_avg_img_norm', np.mean(data_norms)) 415 | e2e_distances = torch.norm( 416 | x - original_x, dim=[2, 3]).mean().clone().detach().cpu().numpy() 417 | self.log('langevin_avg_distance_start2end', e2e_distances) 418 | 419 | return x.detach() 420 | 421 | ###################################################### 422 | #################### Utilities ####################### 423 | ###################################################### 424 | 425 | def save_model(self): 426 | """Saves the state dict of the model""" 427 | torch.save(self.cnn.state_dict(), self.ckpt_path + "/model_state_dict.pt") 428 | 429 | def tb_mcmc_simulation(self): 430 | """ 431 | This function writes to tensorboard the evolution of a 432 | measure duing MCMC simulation. We have an array of misurations, 433 | each one obtained at an iteration of the mcmc method. 434 | K measurments are collected and the resulting arrays are vertically 435 | stacked, to obtain a matrix. For this reason, the mean is obtained by 436 | averaging on the 0 axis. 437 | """ 438 | # In this dict there are only 2D arrays! 439 | for name, array in self.mcmc_evolution_logs.items(): 440 | if array.ndim != 2: 441 | raise NameError("expected 2-dimensional array here!") 442 | array = array.mean(axis=0) 443 | for i in range(array.shape[0]): 444 | self.tb_writer.add_scalar(name, array[i], i) 445 | # Free 446 | del self.mcmc_evolution_logs 447 | self.mcmc_evolution_logs = dict() 448 | 449 | def tb_mcmc_images(self, name=None, batch_size=None, **MCMC_args): 450 | """ 451 | Generate B images from the currently learned model and add them as 452 | images grid to tensorboard. 
453 | """ 454 | img_name = "sample_images_epoch_%d" % (self.epoch_n + 455 | 1) if name is None else name 456 | batch_size = self.batch_size if batch_size is None else batch_size 457 | fake_imgs = self.generate_samples(batch_size=batch_size, **MCMC_args) 458 | grid_img = make_grid(fake_imgs.clone().detach().cpu(), 459 | nrow=int(np.sqrt(batch_size)), 460 | normalize=True, 461 | range=(0, 1)) 462 | g_step = self.epoch_n * self.tot_batches + self.iter_n 463 | self.tb_writer.add_image(img_name, grid_img, g_step) 464 | return grid_img 465 | 466 | 467 | def log(self, name, val, printable=True): 468 | """ 469 | name: string name of the property to log 470 | val: value 471 | print: whether to print this quantity or not. If false the quantity is just for "intermediate" use by another function. 472 | """ 473 | if not self.log_active: 474 | return 475 | 476 | # Parse the value to log 477 | if isinstance(val, torch.Tensor): 478 | if val.dim() == 0: 479 | # Single element tensor (e.g. loss) 480 | payload = val.item() 481 | else: 482 | # Mupliple dimensions tensor (e.g. vector) 483 | payload = val.numpy( 484 | ) # Fine also for 1 element tensors, instead of .item() 485 | else: 486 | payload = val 487 | 488 | # Add the value to the logs list 489 | if self.log_dict.get(name, None) is None: 490 | self.log_dict[name] = ([payload], printable) 491 | else: 492 | self.log_dict[name][0].append(payload) 493 | 494 | # Add to tensorboard 495 | global_step = self.epoch_n * self.tot_batches + self.iter_n 496 | self.tb_writer.add_scalar(name, payload, global_step=global_step) 497 | 498 | def flush_logs(self): 499 | """ 500 | Called each epoch. 501 | Print the average of the current logged measure and remove it from the dict 502 | """ 503 | for name, (measures_list, printable) in self.log_dict.items(): 504 | if printable: 505 | print(f"{name}: {np.mean(np.array(measures_list)):.3f}") 506 | print() 507 | 508 | # Clean the active logs dictionary 509 | del self.log_dict 510 | self.log_dict = dict() 511 | 512 | def tb_logs_doc(self): 513 | return """ 514 | Documentation of Tensorboard logs 515 | 516 | 'langevin_evolution_' metrics describe the evolution within 517 | a mcmc sampling process. 518 | E.g. the norm of the generated images at each mcmc step: it's 519 | an array. 520 | Computed over a `time_window` of first K iterations of an epoch. 521 | 522 | 'langevin_avg_' metrics describe the avg value of a measure 523 | within a mcmc sampling process. 524 | E.g. the *avg* norm of the generated images at each mcmc step: 525 | it's a scalar. 526 | Computed at each iteration. 527 | 528 | 'energy_avg_': avg energy of real/fakes images at current iteration. 529 | 530 | 'loss': can be `loss`, `loss_cdiv`, `loss_reg` (regularization loss, weigthed 531 | by alpha hparam). 
532 | 533 | 'layer_': norm of weights/biases of a given layer 534 | """ 535 | 536 | 537 | 538 | 539 | 540 | class EBMLangVanilla(DeepEnergyModel): 541 | """"Vanilla Langevin Dynamics""" 542 | def __init__(self, **kwargs): 543 | super().__init__(**kwargs) 544 | 545 | 546 | class EBMLang2Ord(DeepEnergyModel): 547 | """SGHMC: Second order Langevin Dynamics, with leapfrog""" 548 | def __init__(self, C=2, mass=1, **kwargs): 549 | super().__init__(**kwargs) 550 | self.C = C 551 | self.hparams_dict['C'] = C 552 | self.mass = mass 553 | self.hparams_dict['mass'] = mass 554 | 555 | def generate_samples(self, 556 | evaluation=False, 557 | batch_size=None, 558 | mcmc_steps=None): 559 | """ 560 | Draw samples using Langevin dynamics 561 | evaluation: if True, avoids logging mcmc stats. It means we're sampling 562 | from the model with arbitrary batchsize/mcmc_steps and it isn't related to training. 563 | noise_scale: Optional. float. If None, set to np.sqrt(step_size * 2) 564 | """ 565 | batch_size = self.batch_size if batch_size is None else batch_size 566 | mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps 567 | 568 | is_training = self.cnn.training 569 | self.cnn.eval() 570 | 571 | # Init images with RND normal noise: x_i ~ N(0,1) 572 | x = torch.randn((batch_size, ) + self.img_shape, device=self.device) 573 | original_x = x.clone().detach() 574 | x.requires_grad = True 575 | 576 | # Init momentum 577 | #momentum = torch.randn((batch_size, ) + self.img_shape, device=self.device) 578 | momentum = torch.zeros_like(x, device=self.device) 579 | noise_scale = np.sqrt(self.mcmc_step_size * 2 * self.C) 580 | 581 | # Pre-allocate additive noise (for Langevin step) 582 | noise = torch.randn_like(x, device=self.device) 583 | 584 | 585 | def append_norm(in_tensor, array): 586 | return np.append( 587 | array, 588 | torch.norm(in_tensor, 589 | dim=[2, 3]).mean().clone().detach().cpu().numpy()) 590 | 591 | grad_norms = np.array([]) 592 | data_norms = np.array([]) 593 | momentum_norms = np.array([]) 594 | 595 | # To study the evolution within an mcmc simulation 596 | distances = np.array([]) 597 | prev_distances = np.array([]) 598 | time_window = 50 599 | 600 | for _ in range(mcmc_steps): 601 | 602 | if self.iter_n < time_window: 603 | #Used to compute prev_distances items 604 | old_x = x.clone().detach() 605 | 606 | # Re-init noise tensor 607 | noise.normal_(mean=0.0, std=noise_scale) 608 | out = self.cnn(x) 609 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0] 610 | 611 | x = x + self.mcmc_step_size * self.mass * momentum 612 | momentum = momentum - self.mass * momentum * self.mcmc_step_size * self.C - self.mcmc_step_size * grad + noise 613 | 614 | 615 | # Save stats 616 | grad_norms = append_norm(grad, grad_norms) 617 | data_norms = append_norm(x, data_norms) 618 | momentum_norms = append_norm(momentum, momentum_norms) 619 | 620 | if self.iter_n < time_window: 621 | prev_distances = append_norm(x - old_x, prev_distances) 622 | distances = append_norm(x - original_x, distances) 623 | 624 | self.cnn.train(is_training) 625 | 626 | ####### Evolution within Langevin dynamics ###### 627 | # If at the beginning of an epoch, save the evolution of 628 | # grad and img norms, for a time window of width K. 
629 | # These quantities will be logged within fit() function 630 | 631 | def append_mcmc_logs(prop_name, prop_array): 632 | full_name = "%s_epoch_%d" % (prop_name, self.epoch_n + 1) 633 | entry = self.mcmc_evolution_logs.get(full_name, None) 634 | if entry is None: 635 | self.mcmc_evolution_logs[full_name] = prop_array 636 | else: 637 | self.mcmc_evolution_logs[full_name] = np.vstack( 638 | (entry, prop_array)) 639 | return 640 | 641 | if not evaluation: 642 | # Beginning of epoch e 643 | # 'langevin_evolution_' metrics describe the evolution 644 | # within a mcmc sampling process. Computed over a time_window of iterations. 645 | if self.iter_n < time_window: 646 | # Gradient norm 647 | append_mcmc_logs("langevin_evolution_grad_norm", grad_norms) 648 | 649 | # Data norm 650 | append_mcmc_logs("langevin_evolution_img_norm", data_norms) 651 | 652 | # Momentum norm 653 | append_mcmc_logs("langevin_evolution_momentum_norm", momentum_norms) 654 | 655 | # Distance from previous point 656 | append_mcmc_logs("langevin_evolution_distance2prevstep", prev_distances) 657 | 658 | # Distance from starting point 659 | append_mcmc_logs("langevin_evolution_distance2start", distances) 660 | 661 | 662 | # Always log the avg 663 | # 'langevin_avg_' metrics describe the avg value of a measure 664 | # within a mcmc sampling process. Computed at each iteration. 665 | self.log('langevin_avg_grad_norm', np.mean(grad_norms)) 666 | self.log('langevin_avg_img_norm', np.mean(data_norms)) 667 | self.log('langevin_avg_momentum_norm', np.mean(momentum_norms)) 668 | e2e_distances = torch.norm( 669 | x - original_x, dim=[2, 3]).mean().clone().detach().cpu().numpy() 670 | self.log('langevin_avg_distance_start2end', e2e_distances) 671 | 672 | return x.detach() -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Energy Based Models 2 | EBMs are a family of models currently under research. Their remarkable advantage with respect to VAEs is that they do not make any assumption on the form of the probability density they fit. These models are also a potential competitor of GANs. In this work (project at EURECOM University) I implement them as generative models with *Maximum Likelihood* estimation, aimed at generating MNIST images. 3 | Here you can also find the final report of my project: `Final_Report__EBM___ML.pdf` 4 | For more info: https://arxiv.org/abs/2101.03288 5 | 6 | # EBM PyTorch training packages 7 | These packages offer key utilities to train an Energy Based Model with ML estimation. MCMC sampling from the model can be carried out with Langevin dynamics (SGLD) or Stochastic Gradient Hamiltonian Monte Carlo (SGHMC). 8 | 9 | Two python packages: 10 | - `ebm`: train on MNIST dataset 11 | - `ebm_toy` in `toy_examples` folder: train on **gmm** (gaussian mixture model) or **circles** 2D datasets, where the ground truth distribution is known and supervised metrics (e.g. Kolmogorov-Smirnov distance) can be computed. 12 | 13 | Here I provide a sample notebook to show how to use the `ebm` package. 14 | 15 | -------------------------------------------------------------------------------- /toy_examples/README.md: -------------------------------------------------------------------------------- 1 | `ebm_toy` is a package for EBM training on toy 2D datasets taken from https://github.com/point0bar1/ebm-anatomy. 2 | I also provide a sample notebook to show how to use the package. 
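
As a quick orientation (complementary to the notebook), here is a minimal usage sketch of the toy trainer. The trainer calls (`setup`, `prepare_data`, `fit`, `clear`) and the constructor keywords mirror `ebm_toy/train.py`; the `ToyDataset` constructor arguments are illustrative assumptions only — check `utils.py` for the actual signature.

```python
from ebm_toy.config import *          # sets PROJECT_ROOT / CHECKPOINT_PATH and provides set_seed
from ebm_toy.train import EBMLangVanilla
from ebm_toy.utils import ToyDataset

set_seed(0)

# NOTE: the ToyDataset arguments below are assumptions for illustration;
# see utils.py for the real constructor (datasets: "gmm" or "circles").
train_set = ToyDataset(toy_type="gmm")
test_set = ToyDataset(toy_type="gmm")

trainer = EBMLangVanilla(img_shape=(2, 1, 1),   # 2D points stored as (C x H x W) = (2 x 1 x 1)
                         batch_size=256,
                         lr=1e-4,
                         weight_decay=1e-4,
                         mcmc_step_size=1e-3,
                         mcmc_steps=100,
                         optim="adam",
                         mcmc_init_type="persistent",
                         model_name="toy_gmm_run",
                         model_family="test",
                         overwrite=True,
                         device="cuda:0")
trainer.setup()                        # builds the optimizer and the TensorBoard writer
trainer.prepare_data(train_set, test_set)
try:
    trainer.fit(10)                    # train for 10 epochs
finally:
    trainer.clear()                    # close the TensorBoard writer
```

As in the MNIST package, `mcmc_init_type="persistent"` initializes the Langevin chains from a buffer of past samples, which usually allows shorter chains per training iteration.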
3 | 
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints
3 | 
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/README.md:
--------------------------------------------------------------------------------
1 | This folder is a package for EBM training on toy datasets taken from https://github.com/point0bar1/ebm-anatomy.
2 | 
3 | # train.py
4 | Trainer classes for different approaches to EBM training (MCMC sampling from the model):
5 | - `EBMLangVanilla` to use first order Langevin dynamics (SGLD).
6 | - `EBMLang2Ord` to use second order Langevin dynamics (SGHMC).
7 | 
8 | # models.py
9 | CNN model for 2D examples of shape `(C x H x W) = (2 x 1 x 1)`.
10 | 
11 | # utils.py
12 | - `ToyDataset` class
13 | - `ksDist` and `ks2d2s`: 2D Kolmogorov-Smirnov test to compute the distance between two samples' distributions.
14 | - Discrete KL divergence on a 2D grid
15 | 
16 | # config.py
17 | A configuration module with global variables used across all modules and in the training notebook.
18 | In the training notebook, import it as `from ebm_toy.config import *` before anything else related to the EBM package.
19 | 
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matbun/EBM--Generative-Energy-Based-Modeling/4b8a079982e968a9e7d7f2e0b97b33ba70772864/toy_examples/ebm_toy/__init__.py
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | 
5 | import subprocess
6 | import sys
7 | 
8 | def install(package):
9 |     subprocess.check_call([sys.executable, "-m", "pip", "install", package])
10 | 
11 | try:
12 |     import shutil
13 | except:
14 |     install("pytest-shutil")
15 |     import shutil
16 | 
17 | 
18 | HOME_DIR = os.path.expanduser("~")
19 | 
20 | PROJECT_ROOT = os.path.join(HOME_DIR, "Projects/EBM_proj/toy_examples")
21 | os.chdir(PROJECT_ROOT)
22 | 
23 | # Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
24 | DATASET_PATH = PROJECT_ROOT + "/data"
25 | 
26 | # Path to the folder where the pretrained models are saved
27 | CHECKPOINT_PATH = PROJECT_ROOT + "/saved_models/MNIST/"
28 | 
29 | # Set random seed
30 | def set_seed(seed: int=0, deterministic: bool=True, benchmark: bool=False):
31 |     np.random.seed(seed)
32 |     torch.manual_seed(seed)
33 | 
34 |     # Ensure that all operations are deterministic on GPU (if used) for reproducibility
35 |     torch.backends.cudnn.deterministic = deterministic
36 |     torch.backends.cudnn.benchmark = benchmark
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/models.py:
--------------------------------------------------------------------------------
1 | ## PyTorch
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | def init_layers(m):
7 |     if type(m) == nn.Conv2d:
8 |         nn.init.kaiming_normal_(m.weight)
9 |         nn.init.zeros_(m.bias)
10 |     elif type(m) == nn.Linear:
11 |         m.weight.data.normal_(0, 0.01)
12 |         m.bias.data.normal_(0, 0.01)
13 | 
14 | 
15 | class Swish(nn.Module):
16 |     def forward(self, x):
17 |         return x * torch.sigmoid(x)
18 | 
19 | 
20 | 
21 | #################################
22 | # ## TOY NETWORK FOR 2D DATA ## #
23 | #################################
24 | 
25 | class ToyNet(nn.Module):
26 |     def __init__(self, dim=2, n_f=32, leak=0.05):
27 |         super().__init__()
28 |         self.cnn_layers = nn.Sequential(
29 |             nn.Conv2d(dim, n_f, 1, 1, 0),
30 |             nn.LeakyReLU(leak),
31 |             nn.Conv2d(n_f, n_f * 2, 1, 1, 0),
32 |             nn.LeakyReLU(leak),
33 |             nn.Conv2d(n_f * 2, n_f * 2, 1, 1, 0),
34 |             nn.LeakyReLU(leak),
35 |             nn.Conv2d(n_f * 2, n_f * 2, 1, 1, 0),
36 |             nn.LeakyReLU(leak),
37 |             nn.Conv2d(n_f * 2, 1, 1, 1, 0))
38 | 
39 |         #self.cnn_layers.apply(init_layers)
40 | 
41 |     def forward(self, x):
42 |         e = self.cnn_layers(x).squeeze()
43 |         return 0.5 * e**2
--------------------------------------------------------------------------------
/toy_examples/ebm_toy/train.py:
--------------------------------------------------------------------------------
1 | ## Standard libraries
2 | import os
3 | import numpy as np
4 | from tqdm.notebook import tqdm
5 | from IPython.display import clear_output
6 | import shutil
7 | 
8 | ## Imports for plotting
9 | import matplotlib.pyplot as plt
10 | from matplotlib import cm
11 | 
12 | ## PyTorch
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torch.autograd as autograd
17 | import torch.utils.data as data
18 | from torch.utils.tensorboard import SummaryWriter
19 | 
20 | # Torchvision
21 | import torchvision
22 | from torchvision.utils import make_grid
23 | 
24 | # Custom modules
25 | from ebm_toy.config import *
26 | from ebm_toy.models import ToyNet
27 | 
28 | 
29 | 
30 | class DeepEnergyModel:
31 |     """
32 |     model_name (str) - Any name to visually recognize the model, like the #run.
33 |     model_description (str) - Will be logged by tensorboard as "text".
34 |     model_family (str) - When running multiple experiments, it may be useful to divide
35 |         the models and their logged results into families (subdirs of the checkpoint path).
36 |         This param can have the form of a path/to/subfolder.
37 |     overwrite (bool) - If the logs folder already exists and this is True, "overwrite" it (namely,
38 |         add the new logs without removing the old ones).
39 |     mcmc_init_type (str) - One of: persistent, gaussian, uniform, data.
40 | """ 41 | def __init__(self, 42 | batch_size=64, 43 | img_shape=(1,28,28), 44 | lr=1e-4, 45 | weight_decay=1e-4, 46 | mcmc_step_size=1e-5, 47 | mcmc_steps=250, 48 | optim='sgd', 49 | mcmc_init_type="persistent", 50 | model_name="unnamed", 51 | model_description="", 52 | model_family="Langevin_vanilla", 53 | device="cuda:1", 54 | overwrite=False, 55 | reload_model=False, 56 | log_every_n_epochs=1, 57 | **CNN_args): 58 | super().__init__() 59 | 60 | # Model 61 | self.img_shape = img_shape 62 | self.device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu") 63 | print("Running on device:", self.device) 64 | self.cnn = ToyNet().to(self.device) 65 | self.reload_model = reload_model 66 | 67 | # Optimizers 68 | self.lr = lr 69 | self.weight_decay = weight_decay 70 | self.optim_name = optim 71 | 72 | # Dataset 73 | self.batch_size = batch_size 74 | 75 | # MCMC 76 | self.mcmc_step_size = mcmc_step_size 77 | self.mcmc_steps = mcmc_steps 78 | self.mcmc_persistent_data = 2 * torch.rand([10000, 2, 1, 1]).to(self.device) - 1 79 | self.mcmc_init_type = mcmc_init_type 80 | 81 | # Logging 82 | # General purpose: add new element each iteration (batch) 83 | self.log_dict = dict() 84 | # MCMC sampling: add element each MCMC iteration 85 | self.mcmc_evolution_logs = dict() 86 | # Final sample of generated images 87 | self.final_sampled_images = None 88 | self.log_every_n_epochs = log_every_n_epochs 89 | 90 | # Tensorboard 91 | self.model_name = model_name 92 | self.model_description = model_description 93 | self.model_family = model_family 94 | self.overwrite = overwrite 95 | 96 | # The following global variables are employed in different functions (in 97 | # different ways) to compute the SummaryWriter global_step. 98 | self.epoch_n = 0 99 | self.tot_batches = 0 100 | self.iter_n = 0 101 | 102 | # Convert mcmc_steps to string if it's a list of tuples (veriable mcmc steps) 103 | if not isinstance(mcmc_steps, int): 104 | # Write in the hparams dict a brief string. 105 | conv_mcmc_steps = "sched" 106 | else: 107 | conv_mcmc_steps = mcmc_steps 108 | # Hyperparams dict 109 | self.hparams_dict = { 110 | 'mc_lr': self.mcmc_step_size, 111 | 'mc_steps': conv_mcmc_steps, 112 | 'lr': self.lr, 113 | 'w_dec': self.weight_decay, 114 | 'b_size': self.batch_size 115 | } 116 | 117 | # Setup flag: check the model has been properly set up before starting 118 | self.is_setup = False 119 | 120 | def setup(self): 121 | """Setup the optimizers, setup the Tensorboard SummaryWriter, process hyperparams dict.""" 122 | # Optimizers 123 | self.configure_optimizers() 124 | self.hparams_dict['opt'] = self.optimizer.__class__.__name__ 125 | 126 | # Tensorboard logs 127 | hparams_str = "__".join(["=".join( 128 | [str(el) for el in dict_entry]) for dict_entry in self.hparams_dict.items()]) 129 | full_name = self.model_name + "__" + hparams_str 130 | self.ckpt_path = os.path.join(CHECKPOINT_PATH, self.model_family, full_name) 131 | # Reload existing model? 132 | if os.path.exists(self.ckpt_path) and self.reload_model: 133 | path = os.path.join(self.ckpt_path, "model_state_dict.pt") 134 | self.cnn.load_state_dict(torch.load(path)) 135 | print("Loaded pretrained existsing model") 136 | ## Ovrewrite existsing model ## 137 | # Unautorized overwrite 138 | elif os.path.exists(self.ckpt_path) and not self.overwrite: 139 | print("Model path: " + self.ckpt_path) 140 | raise NameError("Model already exists! 
141 |         # Authorized overwrite
142 |         elif os.path.exists(self.ckpt_path) and self.overwrite:
143 |             # Remove existing folder
144 |             shutil.rmtree(self.ckpt_path)
145 |             print("Overwriting existing logs")
146 |         # Create writer
147 |         self.tb_writer = SummaryWriter(self.ckpt_path)
148 |         # Add some textual notes
149 |         # 1. Add docstring to interpret logs
150 |         self.tb_writer.add_text('logs_documentation', self.tb_logs_doc())
151 |         # 2. Add mcmc steps scheduler, if present
152 |         mcmc_steps_schedule = ""
153 |         if not isinstance(self.mcmc_steps, int):
154 |             mcmc_steps_schedule = "Schedule of MCMC steps:\n\n"
155 |             mcmc_steps_schedule += "-".join([f"{it}:{st}" for it,st in self.mcmc_steps])
156 |         descr = "Model description:\n" + self.model_description + "\n\n\n"
157 |         self.tb_writer.add_text("model_description", descr + mcmc_steps_schedule, 100)
158 | 
159 |         # Set is_setup flag to True
160 |         self.is_setup = True
161 | 
162 |     def clear(self):
163 |         # Tensorboard writer
164 |         self.tb_writer.close()
165 | 
166 | 
167 |     def configure_optimizers(self):
168 |         # Energy models can have issues with momentum, as the loss surface changes with the model parameters.
169 |         # Hence, momentum is left at 0 (the SGD default).
170 |         # Optimize only the layers that require grad
171 |         if self.optim_name.lower() == "sgd":
172 |             chosen_optim = optim.SGD
173 |         elif self.optim_name.lower() == "adam":
174 |             chosen_optim = optim.Adam
175 |         else:
176 |             print("Optimizer name:", self.optim_name)
177 |             raise RuntimeError("Optimizer name not understood!")
178 | 
179 |         self.optimizer = chosen_optim(filter(lambda p: p.requires_grad, self.cnn.parameters()),
180 |                                       lr=self.lr,
181 |                                       weight_decay=self.weight_decay)
182 | 
183 |         # scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97) # Exponential decay over epochs
184 |         pass
185 | 
186 |     def prepare_data(self, train_set, test_set):
187 |         self.train_set = train_set
188 |         # Prepare data
189 |         self.train_loader = data.DataLoader(train_set,
190 |                                             batch_size=self.batch_size,
191 |                                             shuffle=True,
192 |                                             drop_last=True,
193 |                                             num_workers=2,
194 |                                             pin_memory=True)
195 |         self.test_loader = data.DataLoader(test_set,
196 |                                            batch_size=self.batch_size,
197 |                                            shuffle=False,
198 |                                            drop_last=False,
199 |                                            num_workers=2)
200 |         pass
201 | 
202 |     ######################################################
203 |     ################ Training section ####################
204 |     ######################################################
205 | 
206 |     def training_step(self, batch):
207 | 
208 |         # Train mode
209 |         self.cnn.train()
210 | 
211 |         # Real (positive) samples from the data batch (no extra noise is added to them here)
212 |         real_imgs, _ = batch
213 |         real_imgs = real_imgs.to(self.device)
214 | 
215 |         # Obtain samples
216 |         fake_imgs = self.generate_samples()
217 | 
218 |         # Predict energy score for all images
219 |         inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
220 |         real_out, fake_out = self.cnn(inp_imgs.float()).chunk(2, dim=0)
221 | 
222 |         # Calculate losses
223 |         loss = real_out.mean() - fake_out.mean()
224 | 
225 |         # Optimize
226 |         self.optimizer.zero_grad()
227 |         loss.backward()
228 |         self.optimizer.step()
229 | 
230 |         # Logging
231 |         self.log('cdiv_loss', loss)
232 |         self.log('energy_avg_real', real_out.mean())
233 |         self.log('energy_avg_fake', fake_out.mean())
234 | 
235 |         # Log CNN beta and gamma (may be learnable)
236 |         if hasattr(self.cnn, 'beta'):
237 |             self.log('beta', self.cnn.beta.clone().detach().cpu())
238 |         if hasattr(self.cnn, 'gamma'):
239 |             self.log('gamma', self.cnn.gamma.clone().detach().cpu())
240 | 
241 |         # Log layers weight / bias norms
242 |         for layer_id, layer in enumerate(self.cnn.cnn_layers):
243 |             if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
244 |                 self.log('%s_layer%d_weight_norm' % (layer.__class__.__name__, layer_id),
245 |                          torch.norm(
246 |                              layer.weight).clone().detach().cpu().numpy())
247 |                 self.log('%s_layer%d_bias_norm' % (layer.__class__.__name__, layer_id),
248 |                          torch.norm(layer.bias).clone().detach().cpu().numpy())
249 | 
250 |         # Free memory
251 |         del real_imgs, fake_imgs, inp_imgs
252 | 
253 |     def fit(self, n_epochs=None):
254 | 
255 |         assert self.is_setup, "Model is not properly set up. Call .setup() before running!"
256 | 
257 |         if self.train_loader is None:
258 |             print("Train data not loaded")
259 |             return
260 | 
261 |         # Epochs
262 |         self.tot_batches = len(self.train_loader)
263 |         epochs_bar = tqdm(range(n_epochs), total=n_epochs, leave=True)
264 |         epochs_bar.set_description("Epochs")
265 |         for self.epoch_n in epochs_bar:
266 | 
267 |             # Iterations
268 |             self.log_active = True
269 |             iters_bar = tqdm(enumerate(self.train_loader),
270 |                              total=self.tot_batches,
271 |                              position=0,
272 |                              leave=False)
273 |             iters_bar.set_description("Batches (iterations)")
274 |             for self.iter_n, batch in iters_bar:
275 | 
276 |                 self.training_step(batch)
277 | 
278 |             if (self.epoch_n + 1) % self.log_every_n_epochs == 0:
279 |                 ############## Tensorboard ###############
280 |                 # Evolution throughout an MCMC simulation
281 |                 self.tb_mcmc_simulation()
282 | 
283 |                 # Log a generated sample of images
284 |                 #self.tb_mcmc_images(batch_size=25, evaluation=True)
285 | 
286 |                 # Log currently learned density
287 |                 learned_density = self.train_set.ebm_learned_density(f=self.cnn,
288 |                                                                      epsilon=0.0 #np.sqrt(2*self.mcmc_step_size)
289 |                                                                      )
290 |                 self.tb_add_figure(learned_density, "Epoch %d: learned density" % (self.epoch_n + 1))
291 |                 self.tb_add_figure(np.log(learned_density + 1e-10), "Epoch %d: learned log-density" % (self.epoch_n + 1))
292 | 
293 |                 # Force tensorboard to write to disk (to be sure)
294 |                 self.tb_writer.flush()
295 | 
296 |             # Save model state dict (params)
297 |             self.save_model()
298 | 
299 |     ######################################################
300 |     ################ Langevin dynamics ###################
301 |     ######################################################
302 | 
303 | 
304 |     def generate_samples(self,
305 |                          evaluation=False,
306 |                          batch_size=None,
307 |                          mcmc_steps=None):
308 |         """
309 |         Draw samples using Langevin dynamics.
310 |         evaluation: if True, avoids logging mcmc stats. It means we're sampling
311 |             from the model with arbitrary batch_size/mcmc_steps and it isn't related to training.
312 |         The Langevin noise std is fixed to np.sqrt(2 * self.mcmc_step_size).
313 |         """
314 |         batch_size = self.batch_size if batch_size is None else batch_size
315 |         mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps
316 |         if not isinstance(mcmc_steps, int):
317 |             # mcmc steps not fixed, but using a scheduler of the form
318 |             # [(up_to_iter_i, mcmc_steps), (up_to_iter_j, mcmc_steps), ...]
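            # (as consumed by the loop below, the last entry whose iteration bound does not exceed
            #  the current global iteration wins: e.g. a hypothetical schedule [(0, 50), (1000, 100)]
            #  uses 50 steps for the first 1000 iterations and 100 steps afterwards)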
319 | global_iter = self.epoch_n * self.tot_batches + self.iter_n 320 | curr_steps = None 321 | for max_iter,steps in mcmc_steps: 322 | if max_iter > global_iter: 323 | break 324 | mcmc_steps = steps 325 | 326 | is_training = self.cnn.training 327 | self.cnn.eval() 328 | 329 | # Initial batch of noise / images: starting point of mcmc chain 330 | def sample_s_t_0(): 331 | if self.mcmc_init_type == 'persistent': 332 | rand_inds = torch.randperm(self.mcmc_persistent_data.shape[0])[0:batch_size] 333 | return self.mcmc_persistent_data[rand_inds], rand_inds 334 | elif self.mcmc_init_type == 'data': 335 | return torch.Tensor(self.train_set.sample_toy_data(batch_size)), None 336 | elif self.mcmc_init_type == 'uniform': 337 | noise_init_factor = 2 338 | return noise_init_factor * (2 * torch.rand([batch_size, 2, 1, 1]) - 1), None 339 | elif self.mcmc_init_type == 'gaussian': 340 | noise_init_factor = 2 341 | return noise_init_factor * torch.randn([batch_size, 2, 1, 1]), None 342 | else: 343 | raise RuntimeError('Invalid method for "init_type" (use "persistent", "data", "uniform", or "gaussian")') 344 | 345 | x, rand_inds = sample_s_t_0() 346 | x = torch.autograd.Variable(x.clone(), requires_grad=True).to(self.device) 347 | original_x = x.clone().detach() 348 | #x.requires_grad = True 349 | 350 | noise_scale = np.sqrt(self.mcmc_step_size * 2) 351 | 352 | # Pre-allocate additive noise (for Langevin step) 353 | noise = torch.randn_like(x, device=self.device) 354 | 355 | r_s_t = torch.zeros(1).to(self.device) # variable r_s_t (Section 3.2) to record average gradient magnitude 356 | 357 | def append_norm(in_tensor, array): 358 | return np.append( 359 | array, 360 | torch.norm(in_tensor, 361 | dim=[2, 3]).mean().clone().detach().cpu().numpy()) 362 | 363 | grad_norms = np.array([]) 364 | #grad_avg_norm = np.array([]) 365 | data_norms = np.array([]) 366 | 367 | # To study the evolution within an mcmc simulation 368 | distances = np.array([]) 369 | prev_distances = np.array([]) 370 | time_window = 50 371 | 372 | for _ in range(mcmc_steps): 373 | 374 | if self.iter_n < time_window: 375 | #Used to compute prev_distances items 376 | old_x = x.clone().detach() 377 | 378 | # Re-init noise tensor 379 | noise.normal_(mean=0.0, std=noise_scale) 380 | out = self.cnn(x) 381 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0] 382 | # grad is in "device" by default 383 | 384 | x = x - self.mcmc_step_size * grad + noise 385 | 386 | # avg grad norm 387 | r_s_t += grad.view(grad.shape[0], -1).norm(dim=1).mean() 388 | 389 | # Save stats 390 | grad_norms = append_norm(grad, grad_norms) 391 | data_norms = append_norm(x, data_norms) 392 | #grad_avg_norm = np.append(grad_avg_norm, r_s_t) 393 | 394 | if self.iter_n < time_window: 395 | prev_distances = append_norm(x - old_x, prev_distances) 396 | distances = append_norm(x - original_x, distances) 397 | 398 | self.cnn.train(is_training) 399 | 400 | ####### Evolution within Langevin dynamics ###### 401 | # If at the beginning of an epoch, save the evolution of 402 | # grad and img norms, for a time window of width K. 
403 | # These quantities will be logged within fit() function 404 | 405 | def append_mcmc_logs(prop_name, prop_array): 406 | full_name = "%s_epoch_%d" % (prop_name, self.epoch_n + 1) 407 | entry = self.mcmc_evolution_logs.get(full_name, None) 408 | if entry is None: 409 | self.mcmc_evolution_logs[full_name] = prop_array 410 | else: 411 | self.mcmc_evolution_logs[full_name] = np.vstack( 412 | (entry, prop_array)) 413 | return 414 | 415 | if not evaluation: 416 | # Beginning of epoch e 417 | # 'langevin_evolution_' metrics describe the evolution 418 | # within a mcmc sampling process. Computed over a time_window of iterations. 419 | if self.iter_n < time_window: 420 | # Gradient norm 421 | append_mcmc_logs("langevin_evolution_grad_norm", grad_norms) 422 | 423 | # Data norm 424 | append_mcmc_logs("langevin_evolution_img_norm", data_norms) 425 | 426 | # Distance from previous point 427 | append_mcmc_logs("langevin_evolution_distance2prevstep", prev_distances) 428 | 429 | # Distance from starting point 430 | append_mcmc_logs("langevin_evolution_distance2start", distances) 431 | 432 | # Avg gradient norm 433 | #append_mcmc_logs("langevin_evolution_avgGradNorm", grad_avg_norm) 434 | 435 | 436 | # Always log the avg 437 | # 'langevin_avg_' metrics describe the avg value of a measure 438 | # within a mcmc sampling process. Computed at each iteration. 439 | self.log('langevin_avg_grad_norm', np.mean(grad_norms)) 440 | self.log('langevin_movingAvg_grad_norm', r_s_t.detach().cpu() / self.mcmc_steps) 441 | self.log('langevin_avg_img_norm', np.mean(data_norms)) 442 | e2e_distances = torch.norm( 443 | x - original_x, dim=[2, 3]).mean().clone().detach().cpu().numpy() 444 | self.log('langevin_avg_distance_start2end', e2e_distances) 445 | 446 | 447 | if self.mcmc_init_type == 'persistent' and not evaluation: 448 | # update persistent state bank 449 | self.mcmc_persistent_data.data[rand_inds] = x.detach().data.clone() 450 | 451 | return x.detach().float() 452 | 453 | ###################################################### 454 | #################### Utilities ####################### 455 | ###################################################### 456 | 457 | def save_model(self): 458 | """Saves the state dict of the model""" 459 | torch.save(self.cnn.state_dict(), self.ckpt_path + "/model_state_dict.pt") 460 | 461 | def tb_add_figure(self, X, title): 462 | """Add a figure to tensorboard""" 463 | fig, ax = plt.subplots(figsize=plt.figaspect(X)) 464 | fig.subplots_adjust(0,0,1,1) 465 | ax.imshow(X, cmap='viridis') 466 | ax.axis('off') 467 | 468 | g_step = self.epoch_n * self.tot_batches + self.iter_n 469 | self.tb_writer.add_figure(title, fig, global_step=g_step) 470 | 471 | def tb_mcmc_simulation(self): 472 | """ 473 | This function writes to tensorboard the evolution of a 474 | measure duing MCMC simulation. We have an array of misurations, 475 | each one obtained at an iteration of the mcmc method. 476 | K measurments are collected and the resulting arrays are vertically 477 | stacked, to obtain a matrix. For this reason, the mean is obtained by 478 | averaging on the 0 axis. 479 | """ 480 | # In this dict there are only 2D arrays! 
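        # (each row of such an array is the trace of one training iteration over the MCMC steps;
        #  the mean over axis 0 below is the average trace, logged as one scalar per MCMC step)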
481 | for name, array in self.mcmc_evolution_logs.items(): 482 | if array.ndim != 2: 483 | raise NameError("expected 2-dimensional array here!") 484 | array = array.mean(axis=0) 485 | for i in range(array.shape[0]): 486 | self.tb_writer.add_scalar(name, array[i], i) 487 | # Free 488 | del self.mcmc_evolution_logs 489 | self.mcmc_evolution_logs = dict() 490 | 491 | def tb_mcmc_images(self, name=None, batch_size=None, **MCMC_args): 492 | """ 493 | Generate B images from the currently learned model and add them as 494 | images grid to tensorboard. 495 | """ 496 | img_name = "sample_images_epoch_%d" % (self.epoch_n + 497 | 1) if name is None else name 498 | batch_size = self.batch_size if batch_size is None else batch_size 499 | fake_imgs = self.generate_samples(batch_size=batch_size, **MCMC_args) 500 | grid_img = make_grid(fake_imgs.clone().detach().cpu(), 501 | nrow=int(np.sqrt(batch_size)), 502 | normalize=True, 503 | range=(0, 1)) 504 | g_step = self.epoch_n * self.tot_batches + self.iter_n 505 | self.tb_writer.add_image(img_name, grid_img, g_step) 506 | return grid_img 507 | 508 | 509 | 510 | def log(self, name, val): 511 | """ 512 | name: string name of the property to log 513 | val: value 514 | print: whether to print this quantity or not. If false the quantity is just for "intermediate" use by another function. 515 | """ 516 | if not self.log_active: 517 | return 518 | 519 | # Parse the value to log 520 | if isinstance(val, torch.Tensor): 521 | if val.dim() == 0: 522 | # Single element tensor (e.g. loss) 523 | payload = val.item() 524 | else: 525 | # Mupliple dimensions tensor (e.g. vector) 526 | payload = val.numpy( 527 | ) # Fine also for 1 element tensors, instead of .item() 528 | else: 529 | payload = val 530 | 531 | # Add to tensorboard 532 | global_step = self.epoch_n * self.tot_batches + self.iter_n 533 | self.tb_writer.add_scalar(name, payload, global_step=global_step) 534 | 535 | 536 | def tb_logs_doc(self): 537 | return """ 538 | Documentation of Tensorboard logs 539 | 540 | 'langevin_evolution_' metrics describe the evolution within 541 | a mcmc sampling process. 542 | E.g. the norm of the generated images at each mcmc step: it's 543 | an array. 544 | Computed over a `time_window` of first K iterations of an epoch. 545 | 546 | 'langevin_avg_' metrics describe the avg value of a measure 547 | within a mcmc sampling process. 548 | E.g. the *avg* norm of the generated images at each mcmc step: 549 | it's a scalar. 550 | Computed at each iteration. 551 | 552 | 'energy_avg_': avg energy of real/fakes images at current iteration. 553 | 554 | 'loss': can be `loss`, `loss_cdiv`, `loss_reg` (regularization loss, weigthed 555 | by alpha hparam). 556 | 557 | 'layer_': norm of weights/biases of a given layer 558 | """ 559 | 560 | 561 | 562 | class EBMLangVanilla(DeepEnergyModel): 563 | """"Vanilla Langevin Dynamics""" 564 | def __init__(self, **kwargs): 565 | super().__init__(**kwargs) 566 | 567 | 568 | class EBMLangMomentum(DeepEnergyModel): 569 | """Second order Langevin Dynamics, with leapfrog""" 570 | def __init__(self, C=2, mass=1, **kwargs): 571 | super().__init__(**kwargs) 572 | self.C = C 573 | self.hparams_dict['C'] = C 574 | self.mass = mass 575 | self.hparams_dict['m'] = mass 576 | 577 | def generate_samples(self, 578 | evaluation=False, 579 | batch_size=None, 580 | mcmc_steps=None): 581 | """ 582 | Draw samples using Langevin dynamics 583 | evaluation: if True, avoids logging mcmc stats. 
It means we're sampling 584 | from the model with arbitrary batchsize/mcmc_steps and it isn't related to training. 585 | noise_scale: Optional. float. If None, set to np.sqrt(step_size * 2) 586 | """ 587 | batch_size = self.batch_size if batch_size is None else batch_size 588 | mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps 589 | if not isinstance(mcmc_steps, int): 590 | # mcmc steps not fixed, but using a scheduler of the form 591 | # [(up_to_iter_i, mcmc_steps), (up_to_iter_j, mcmc_steps), ...] 592 | global_iter = self.epoch_n * self.tot_batches + self.iter_n 593 | curr_steps = None 594 | for max_iter,steps in mcmc_steps: 595 | if max_iter > global_iter: 596 | break 597 | mcmc_steps = steps 598 | 599 | is_training = self.cnn.training 600 | self.cnn.eval() 601 | 602 | # Initial batch of noise / images: starting point of mcmc chain 603 | def sample_s_t_0(): 604 | if self.mcmc_init_type == 'persistent': 605 | rand_inds = torch.randperm(self.mcmc_persistent_data.shape[0])[0:batch_size] 606 | return self.mcmc_persistent_data[rand_inds], rand_inds 607 | elif self.mcmc_init_type == 'data': 608 | return torch.Tensor(self.train_set.sample_toy_data(batch_size)), None 609 | elif self.mcmc_init_type == 'uniform': 610 | noise_init_factor = 2 611 | return noise_init_factor * (2 * torch.rand([batch_size, 2, 1, 1]) - 1), None 612 | elif self.mcmc_init_type == 'gaussian': 613 | noise_init_factor = 2 614 | return noise_init_factor * torch.randn([batch_size, 2, 1, 1]), None 615 | else: 616 | raise RuntimeError('Invalid method for "init_type" (use "persistent", "data", "uniform", or "gaussian")') 617 | 618 | x, rand_inds = sample_s_t_0() 619 | x = torch.autograd.Variable(x.clone(), requires_grad=True).to(self.device) 620 | original_x = x.clone().detach() 621 | #x.requires_grad = True 622 | 623 | # Momentum 624 | momentum = torch.randn_like(x, device=self.device) 625 | 626 | noise_scale = np.sqrt(self.mcmc_step_size * 2) 627 | 628 | # Pre-allocate additive noise (for Langevin step) 629 | noise = torch.randn_like(x, device=self.device) 630 | 631 | r_s_t = torch.zeros(1).to(self.device) # variable r_s_t (Section 3.2) to record average gradient magnitude 632 | 633 | def append_norm(in_tensor, array): 634 | return np.append( 635 | array, 636 | torch.norm(in_tensor, 637 | dim=[2, 3]).mean().clone().detach().cpu().numpy()) 638 | 639 | grad_norms = np.array([]) 640 | data_norms = np.array([]) 641 | momentum_norms = np.array([]) 642 | 643 | # To study the evolution within an mcmc simulation 644 | distances = np.array([]) 645 | prev_distances = np.array([]) 646 | time_window = 50 647 | 648 | for _ in range(mcmc_steps): 649 | 650 | if self.iter_n < time_window: 651 | #Used to compute prev_distances items 652 | old_x = x.clone().detach() 653 | 654 | 655 | # Re-init noise tensor 656 | noise.normal_(mean=0.0, std=noise_scale) 657 | out = self.cnn(x) 658 | grad = autograd.grad(out.sum(), x, only_inputs=True)[0] 659 | # grad is in "device" by default 660 | 661 | momentum = momentum - self.mass * momentum * self.mcmc_step_size * self.C - self.mcmc_step_size * grad + noise 662 | x = x + self.mcmc_step_size * self.mass * momentum 663 | 664 | # avg grad norm 665 | r_s_t += grad.view(grad.shape[0], -1).norm(dim=1).mean() 666 | 667 | # Save stats 668 | grad_norms = append_norm(grad, grad_norms) 669 | data_norms = append_norm(x, data_norms) 670 | momentum_norms = append_norm(momentum, momentum_norms) 671 | 672 | if self.iter_n < time_window: 673 | prev_distances = append_norm(x - old_x, prev_distances) 674 | 
distances = append_norm(x - original_x, distances) 675 | 676 | self.cnn.train(is_training) 677 | 678 | ####### Evolution within Langevin dynamics ###### 679 | # If at the beginning of an epoch, save the evolution of 680 | # grad and img norms, for a time window of width K. 681 | # These quantities will be logged within fit() function 682 | 683 | def append_mcmc_logs(prop_name, prop_array): 684 | full_name = "%s_epoch_%d" % (prop_name, self.epoch_n + 1) 685 | entry = self.mcmc_evolution_logs.get(full_name, None) 686 | if entry is None: 687 | self.mcmc_evolution_logs[full_name] = prop_array 688 | else: 689 | self.mcmc_evolution_logs[full_name] = np.vstack( 690 | (entry, prop_array)) 691 | return 692 | 693 | if not evaluation: 694 | # Beginning of epoch e 695 | # 'langevin_evolution_' metrics describe the evolution 696 | # within a mcmc sampling process. Computed over a time_window of iterations. 697 | if self.iter_n < time_window: 698 | # Gradient norm 699 | append_mcmc_logs("langevin_evolution_grad_norm", grad_norms) 700 | 701 | # Data norm 702 | append_mcmc_logs("langevin_evolution_img_norm", data_norms) 703 | 704 | # Distance from previous point 705 | append_mcmc_logs("langevin_evolution_distance2prevstep", prev_distances) 706 | 707 | # Distance from starting point 708 | append_mcmc_logs("langevin_evolution_distance2start", distances) 709 | 710 | # Momentum norm 711 | append_mcmc_logs("langevin_evolution_momentum_norm", momentum_norms) 712 | 713 | 714 | # Always log the avg 715 | # 'langevin_avg_' metrics describe the avg value of a measure 716 | # within a mcmc sampling process. Computed at each iteration. 717 | self.log('langevin_avg_grad_norm', np.mean(grad_norms)) 718 | self.log('langevin_movingAvg_grad_norm', r_s_t.detach().cpu() / self.mcmc_steps) 719 | self.log('langevin_avg_img_norm', np.mean(data_norms)) 720 | self.log('langevin_avg_momentum_norm', np.mean(momentum_norms)) 721 | e2e_distances = torch.norm( 722 | x - original_x, dim=[2, 3]).mean().clone().detach().cpu().numpy() 723 | self.log('langevin_avg_distance_start2end', e2e_distances) 724 | 725 | 726 | if self.mcmc_init_type == 'persistent' and not evaluation: 727 | # update persistent state bank 728 | self.mcmc_persistent_data.data[rand_inds] = x.detach().data.clone() 729 | 730 | return x.detach().float() 731 | -------------------------------------------------------------------------------- /toy_examples/ebm_toy/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from numpy import random 5 | from scipy.stats import kstwobign, pearsonr 6 | 7 | import torch as t 8 | import torchvision as tv 9 | from torch.utils.data import Dataset 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | from mpl_toolkits.mplot3d import Axes3D 13 | from matplotlib import cm 14 | from matplotlib.ticker import LinearLocator, FormatStrFormatter 15 | 16 | 17 | 18 | ################## 19 | # ## PLOTTING ## # 20 | ################## 21 | 22 | # visualize negative samples synthesized from energy 23 | def plot_ims(p, x): tv.utils.save_image(t.clamp(x, -1., 1.), p, normalize=True, nrow=int(x.shape[0] ** 0.5)) 24 | 25 | # plot diagnostics for learning 26 | def plot_diagnostics(batch, en_diffs, grad_mags, exp_dir, fontsize=10): 27 | # axis tick size 28 | matplotlib.rc('xtick', labelsize=6) 29 | matplotlib.rc('ytick', labelsize=6) 30 | fig = plt.figure() 31 | 32 | def plot_en_diff_and_grad_mag(): 33 | # energy difference 34 | ax = fig.add_subplot(221) 35 | 
ax.plot(en_diffs[0:(batch+1)].data.cpu().numpy()) 36 | ax.axhline(y=0, ls='--', c='k') 37 | ax.set_title('Energy Difference', fontsize=fontsize) 38 | ax.set_xlabel('batch', fontsize=fontsize) 39 | ax.set_ylabel('$d_{s_t}$', fontsize=fontsize) 40 | # mean langevin gradient 41 | ax = fig.add_subplot(222) 42 | ax.plot(grad_mags[0:(batch+1)].data.cpu().numpy()) 43 | ax.set_title('Average Langevin Gradient Magnitude', fontsize=fontsize) 44 | ax.set_xlabel('batch', fontsize=fontsize) 45 | ax.set_ylabel('$r_{s_t}$', fontsize=fontsize) 46 | 47 | def plot_crosscorr_and_autocorr(t_gap_max=2000, max_lag=15, b_w=0.35): 48 | t_init = max(0, batch + 1 - t_gap_max) 49 | t_end = batch + 1 50 | t_gap = t_end - t_init 51 | max_lag = min(max_lag, t_gap - 1) 52 | # rescale energy diffs to unit mean square but leave uncentered 53 | en_rescale = en_diffs[t_init:t_end] / t.sqrt(t.sum(en_diffs[t_init:t_end] * en_diffs[t_init:t_end])/(t_gap-1)) 54 | # normalize gradient magnitudes 55 | grad_rescale = (grad_mags[t_init:t_end]-t.mean(grad_mags[t_init:t_end]))/t.std(grad_mags[t_init:t_end]) 56 | # cross-correlation and auto-correlations 57 | cross_corr = np.correlate(en_rescale.cpu().numpy(), grad_rescale.cpu().numpy(), 'full') / (t_gap - 1) 58 | en_acorr = np.correlate(en_rescale.cpu().numpy(), en_rescale.cpu().numpy(), 'full') / (t_gap - 1) 59 | grad_acorr = np.correlate(grad_rescale.cpu().numpy(), grad_rescale.cpu().numpy(), 'full') / (t_gap - 1) 60 | # x values and indices for plotting 61 | x_corr = np.linspace(-max_lag, max_lag, 2 * max_lag + 1) 62 | x_acorr = np.linspace(0, max_lag, max_lag + 1) 63 | t_0_corr = int((len(cross_corr) - 1) / 2 - max_lag) 64 | t_0_acorr = int((len(cross_corr) - 1) / 2) 65 | 66 | # plot cross-correlation 67 | ax = fig.add_subplot(223) 68 | ax.bar(x_corr, cross_corr[t_0_corr:(t_0_corr + 2 * max_lag + 1)]) 69 | ax.axhline(y=0, ls='--', c='k') 70 | ax.set_title('Cross Correlation of Energy Difference\nand Gradient Magnitude', fontsize=fontsize) 71 | ax.set_xlabel('lag', fontsize=fontsize) 72 | ax.set_ylabel('correlation', fontsize=fontsize) 73 | # plot auto-correlation 74 | ax = fig.add_subplot(224) 75 | ax.bar(x_acorr-b_w/2, en_acorr[t_0_acorr:(t_0_acorr + max_lag + 1)], b_w, label='en. diff. $d_{s_t}$') 76 | ax.bar(x_acorr+b_w/2, grad_acorr[t_0_acorr:(t_0_acorr + max_lag + 1)], b_w, label='grad. mag. 
$r_{s_t}}$') 77 | ax.axhline(y=0, ls='--', c='k') 78 | ax.set_title('Auto-Correlation of Energy Difference\nand Gradient Magnitude', fontsize=fontsize) 79 | ax.set_xlabel('lag', fontsize=fontsize) 80 | ax.set_ylabel('correlation', fontsize=fontsize) 81 | ax.legend(loc='upper right', fontsize=fontsize-4) 82 | 83 | # make diagnostic plots 84 | plot_en_diff_and_grad_mag() 85 | plot_crosscorr_and_autocorr() 86 | # save figure 87 | plt.subplots_adjust(hspace=0.6, wspace=0.6) 88 | plt.savefig(os.path.join(exp_dir, 'diagnosis_plot.pdf'), format='pdf') 89 | plt.close() 90 | 91 | 92 | ##################### 93 | # ## TOY DATASET ## # 94 | ##################### 95 | 96 | class ToyDataset(Dataset): 97 | def __init__(self, dataset_len=60000, toy_type='gmm', toy_groups=8, toy_sd=0.15, toy_radius=1, viz_res=500, kde_bw=0.05): 98 | # import helper functions 99 | from scipy.stats import gaussian_kde 100 | from scipy.stats import multivariate_normal 101 | self.gaussian_kde = gaussian_kde 102 | self.mvn = multivariate_normal 103 | 104 | # dataset class property 105 | self.dataset_len = dataset_len 106 | 107 | # toy dataset parameters 108 | self.toy_type = toy_type 109 | self.toy_groups = toy_groups 110 | self.toy_sd = toy_sd 111 | self.toy_radius = toy_radius 112 | self.weights = np.ones(toy_groups) / toy_groups 113 | if toy_type == 'gmm': 114 | means_x = np.cos(2*np.pi*np.linspace(0, (toy_groups-1)/toy_groups, toy_groups)).reshape(toy_groups, 1, 1, 1) 115 | means_y = np.sin(2*np.pi*np.linspace(0, (toy_groups-1)/toy_groups, toy_groups)).reshape(toy_groups, 1, 1, 1) 116 | self.means = toy_radius * np.concatenate((means_x, means_y), axis=1) 117 | else: 118 | self.means = None 119 | 120 | # ground truth density 121 | if self.toy_type == 'gmm': 122 | def true_density(x): 123 | density = 0 124 | for k in range(toy_groups): 125 | density += self.weights[k]*self.mvn.pdf(np.array([x[1], x[0]]), mean=self.means[k].squeeze(), 126 | cov=(self.toy_sd**2)*np.eye(2)) 127 | return density 128 | elif self.toy_type == 'rings': 129 | def true_density(x): 130 | radius = np.sqrt((x[1] ** 2) + (x[0] ** 2)) 131 | density = 0 132 | for k in range(toy_groups): 133 | density += self.weights[k] * self.mvn.pdf(radius, mean=self.toy_radius * (k + 1), 134 | cov=(self.toy_sd**2))/(2*np.pi*self.toy_radius*(k+1)) 135 | return density 136 | else: 137 | raise RuntimeError('Invalid option for toy_type (use "gmm" or "rings")') 138 | self.true_density = true_density 139 | 140 | # viz parameters 141 | self.viz_res = viz_res 142 | self.kde_bw = kde_bw 143 | if toy_type == 'rings': 144 | self.plot_val_max = toy_groups * toy_radius + 4 * toy_sd 145 | else: 146 | self.plot_val_max = toy_radius + 4 * toy_sd 147 | 148 | # save values for plotting groundtruth landscape 149 | self.xy_plot = np.linspace(-self.plot_val_max, self.plot_val_max, self.viz_res) 150 | self.z_true_density = np.zeros(self.viz_res**2).reshape(self.viz_res, self.viz_res) 151 | for x_ind in range(len(self.xy_plot)): 152 | for y_ind in range(len(self.xy_plot)): 153 | self.z_true_density[x_ind, y_ind] = self.true_density([self.xy_plot[x_ind], self.xy_plot[y_ind]]) 154 | 155 | @property 156 | def tile_side(self): 157 | return self.xy_plot[1] - self.xy_plot[0] 158 | @property 159 | def plot_side(self): 160 | return np.abs(2*self.plot_val_max) 161 | 162 | def __len__(self): 163 | return self.dataset_len 164 | 165 | def __getitem__(self, idx): 166 | return (self.sample_toy_data(1).squeeze(axis=0), 0) # (example, label) 167 | 168 | def sample_toy_data(self, num_samples): 169 | toy_sample = 
np.zeros(0).reshape(0, 2, 1, 1) 170 | sample_group_sz = np.random.multinomial(num_samples, self.weights) 171 | if self.toy_type == 'gmm': 172 | for i in range(self.toy_groups): 173 | sample_group = self.means[i] + self.toy_sd * np.random.randn(2*sample_group_sz[i]).reshape(-1, 2, 1, 1) 174 | toy_sample = np.concatenate((toy_sample, sample_group), axis=0) 175 | elif self.toy_type == 'rings': 176 | for i in range(self.toy_groups): 177 | sample_radii = self.toy_radius*(i+1) + self.toy_sd * np.random.randn(sample_group_sz[i]) 178 | sample_thetas = 2 * np.pi * np.random.random(sample_group_sz[i]) 179 | sample_x = sample_radii.reshape(-1, 1) * np.cos(sample_thetas).reshape(-1, 1) 180 | sample_y = sample_radii.reshape(-1, 1) * np.sin(sample_thetas).reshape(-1, 1) 181 | sample_group = np.concatenate((sample_x, sample_y), axis=1) 182 | toy_sample = np.concatenate((toy_sample, sample_group.reshape(-1, 2, 1, 1)), axis=0) 183 | else: 184 | raise RuntimeError('Invalid option for toy_type ("gmm" or "rings")') 185 | 186 | return toy_sample 187 | 188 | def ebm_learned_energy(self, f): 189 | xy_plot_torch = t.Tensor(self.xy_plot).view(-1, 1, 1, 1).to(next(f.parameters()).device) 190 | # y values for learned energy landscape of descriptor network 191 | z_learned_energy = np.zeros([self.viz_res, self.viz_res]) 192 | for i in range(len(self.xy_plot)): 193 | y_vals = float(self.xy_plot[i]) * t.ones_like(xy_plot_torch) 194 | vals = t.cat((xy_plot_torch, y_vals), 1) 195 | z_learned_energy[i] = f(vals).data.cpu().numpy() 196 | 197 | return z_learned_energy 198 | 199 | def plot_learned_energy_surf(self, f, mcmc_lr): 200 | # Learned energy 201 | z_learned_energy = self.ebm_learned_energy(f) 202 | 203 | fig = plt.figure(figsize=(10,10)) 204 | ax = fig.gca(projection='3d') 205 | 206 | # Make data. 207 | X = self.xy_plot 208 | Y = self.xy_plot 209 | X, Y = np.meshgrid(X, Y) 210 | Z = z_learned_energy 211 | 212 | # Plot the surface. 213 | surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, 214 | linewidth=0, antialiased=False, alpha=0.7) 215 | 216 | # Customize the z axis. 217 | ax.zaxis.set_major_locator(LinearLocator(10)) 218 | ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f')) 219 | 220 | # Rotate plot 221 | ax.view_init(30, 30) # Rotation of the 3d plot 222 | ax.set_title(f"Energy landscape. $\eta={mcmc_lr:.0e}$") 223 | 224 | 225 | # Add a color bar which maps values to colors. 
226 | fig.colorbar(surf, shrink=0.5, aspect=5) 227 | 228 | plt.show() 229 | 230 | 231 | def ebm_learned_density(self, f, epsilon=0.0): 232 | z_learned_energy = self.ebm_learned_energy(f) 233 | 234 | # transform learned energy into learned density 235 | z_learned_density_unnormalized = np.exp(- (z_learned_energy - np.min(z_learned_energy))) 236 | bin_area = (self.xy_plot[1] - self.xy_plot[0]) ** 2 237 | z_learned_density = z_learned_density_unnormalized / (bin_area * np.sum(z_learned_density_unnormalized)) 238 | 239 | return z_learned_density 240 | 241 | def ebm_kl_divergence(self, f): 242 | """Compute KL[p || q]""" 243 | p = self.z_true_density 244 | q = self.ebm_learned_density(f) 245 | bin_area = (self.xy_plot[1] - self.xy_plot[0]) ** 2 246 | return bin_area * np.sum(np.where(p != 0, p * np.log(p / q), 0)) 247 | 248 | 249 | def plot_toy_density(self, plot_truth=False, f=None, epsilon=0.0, x_s_t=None, save_path='toy.pdf'): 250 | num_plots = 0 251 | if plot_truth: 252 | num_plots += 1 253 | 254 | # density of learned EBM 255 | if f is not None: 256 | num_plots += 1 257 | xy_plot_torch = t.Tensor(self.xy_plot).view(-1, 1, 1, 1).to(next(f.parameters()).device) 258 | # y values for learned energy landscape of descriptor network 259 | z_learned_energy = np.zeros([self.viz_res, self.viz_res]) 260 | for i in range(len(self.xy_plot)): 261 | y_vals = float(self.xy_plot[i]) * t.ones_like(xy_plot_torch) 262 | vals = t.cat((xy_plot_torch, y_vals), 1) 263 | z_learned_energy[i] = f(vals).data.cpu().numpy() 264 | 265 | # transform learned energy into learned density 266 | z_learned_density_unnormalized = np.exp(- (z_learned_energy - np.min(z_learned_energy))) 267 | bin_area = (self.xy_plot[1] - self.xy_plot[0]) ** 2 268 | z_learned_density = z_learned_density_unnormalized / (bin_area * np.sum(z_learned_density_unnormalized)) 269 | 270 | # kernel density estimate of shortrun samples 271 | if x_s_t is not None: 272 | num_plots += 1 273 | density_estimate = self.gaussian_kde(x_s_t.squeeze().cpu().numpy().transpose(), bw_method=self.kde_bw) 274 | z_kde_density = np.zeros([self.viz_res, self.viz_res]) 275 | for i in range(len(self.xy_plot)): 276 | for j in range(len(self.xy_plot)): 277 | z_kde_density[i, j] = density_estimate((self.xy_plot[j], self.xy_plot[i])) 278 | 279 | # plot results 280 | plot_ind = 0 281 | fig = plt.figure() 282 | 283 | # true density 284 | if plot_truth: 285 | plot_ind += 1 286 | ax = fig.add_subplot(2, num_plots, plot_ind) 287 | ax.set_title('True density') 288 | plt.imshow(self.z_true_density, cmap='viridis') 289 | plt.axis('off') 290 | ax = fig.add_subplot(2, num_plots, plot_ind + num_plots) 291 | ax.set_title('True log-density') 292 | plt.imshow(np.log(self.z_true_density + 1e-10), cmap='viridis') 293 | plt.axis('off') 294 | # learned ebm 295 | if f is not None: 296 | plot_ind += 1 297 | ax = fig.add_subplot(2, num_plots, plot_ind) 298 | ax.set_title('EBM density') 299 | plt.imshow(z_learned_density, cmap='viridis') 300 | plt.axis('off') 301 | ax = fig.add_subplot(2, num_plots, plot_ind + num_plots) 302 | ax.set_title('EBM log-density') 303 | plt.imshow(np.log(z_learned_density + 1e-10), cmap='viridis') 304 | plt.axis('off') 305 | # shortrun kde 306 | if x_s_t is not None: 307 | plot_ind += 1 308 | ax = fig.add_subplot(2, num_plots, plot_ind) 309 | ax.set_title('Short-run KDE') 310 | plt.imshow(z_kde_density, cmap='viridis') 311 | plt.axis('off') 312 | ax = fig.add_subplot(2, num_plots, plot_ind + num_plots) 313 | ax.set_title('Short-run log-KDE') 314 | 
plt.imshow(np.log(z_kde_density + 1e-10), cmap='viridis') 315 | plt.axis('off') 316 | 317 | plt.tight_layout() 318 | if save_path is None: 319 | plt.show() 320 | else: 321 | plt.savefig(save_path, bbox_inches='tight', format='pdf') 322 | 323 | plt.close() 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | ################################### 332 | ### Kolmogorov Smirnov distance ### 333 | ################################### 334 | 335 | def ksDist(ebmModel, trainSet, n_samples=1000, benchmark=False): 336 | """Computes Kolmogorov-Smirnov distance. Wrapper of ks2d2s 337 | If benchmark = True, computes p-value and D from two samples of 338 | the true dist. 339 | Returns: 340 | - Two-tailed (approximated) p-value. 341 | - KS statistic (dist). 342 | - pval, KS dist of two samples of the true dist: optional. 343 | 344 | Small p-values means that the two samples are significantly different. 345 | Note that the p-value is only an approximation as the analytic distribution is unkonwn. 346 | The approximation is accurate enough when N > ~20 and p-value < ~0.20 or so. 347 | When p-value > 0.20, the value may not be accurate, but it certainly implies that the two 348 | samples are not significantly different. (cf. Press 2007) 349 | """ 350 | # Sample from fitted density 351 | negative_samples = ebmModel.generate_samples(evaluation=True, batch_size=n_samples) 352 | neg_samples = negative_samples.cpu().numpy().squeeze(-1).squeeze(-1) 353 | 354 | # Sample from ground truth density 355 | positive_samples = trainSet.sample_toy_data(n_samples).squeeze(-1).squeeze(-1) 356 | 357 | 358 | pval, d = ks2d2s(positive_samples[:, 0], 359 | positive_samples[:, 1], 360 | neg_samples[:, 0], 361 | neg_samples[:, 1], 362 | extra=True) 363 | 364 | if benchmark: 365 | # Recycle the name "neg_samples" 366 | neg_samples = trainSet.sample_toy_data(n_samples).squeeze(-1).squeeze(-1) 367 | pval_b, d_b = ks2d2s(positive_samples[:, 0], 368 | positive_samples[:, 1], 369 | neg_samples[:, 0], 370 | neg_samples[:, 1], 371 | extra=True) 372 | return pval, d, pval_b, d_b 373 | # if not benchmark 374 | return pval, d 375 | 376 | 377 | def ks2d2s(x1, y1, x2, y2, extra=False): 378 | '''Two-dimensional Kolmogorov-Smirnov test on two samples. 379 | Parameters 380 | ---------- 381 | x1, y1 : ndarray, shape (n1, ) 382 | Data of sample 1. 383 | x2, y2 : ndarray, shape (n2, ) 384 | Data of sample 2. Size of two samples can be different. 385 | extra: bool, optional 386 | If True, KS statistic is also returned. Default is False. 387 | Returns 388 | ------- 389 | p : float 390 | Two-tailed p-value. 391 | D : float, optional 392 | KS statistic. Returned if keyword `extra` is True. 393 | Notes 394 | ----- 395 | This is the two-sided K-S test. Small p-values means that the two samples are significantly different. 396 | Note that the p-value is only an approximation as the analytic distribution is unkonwn. 397 | The approximation is accurate enough when N > ~20 and p-value < ~0.20 or so. 398 | When p-value > 0.20, the value may not be accurate, but it certainly implies that the two 399 | samples are not significantly different. (cf. Press 2007) 400 | References 401 | ---------- 402 | https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=&ved=2ahUKEwj8j8nm7NfwAhWJ2hQKHcdSAkoQFjAAegQIAxAD&url=https%3A%2F%2Faip.scitation.org%2Fdoi%2Fpdf%2F10.1063%2F1.4822753&usg=AOvVaw0MJ3m8vCKG1h3RzVmqOuKT 403 | Peacock, J.A. 1983, Two-Dimensional Goodness-of-Fit Testing in Astronomy, 404 | Monthly Notices of the Royal Astronomical Society, vol. 202, pp. 
615-627 405 | Fasano, G. and Franceschini, A. 1987, A Multidimensional Version of the Kolmogorov-Smirnov Test, 406 | Monthly Notices of the Royal Astronomical Society, vol. 225, pp. 155-170 407 | Press, W.H. et al. 2007, Numerical Recipes, section 14.8 408 | ''' 409 | assert (len(x1) == len(y1)) and (len(x2) == len(y2)) 410 | n1, n2 = len(x1), len(x2) 411 | D = avgmaxdist(x1, y1, x2, y2) 412 | 413 | sqen = np.sqrt(n1 * n2 / (n1 + n2)) 414 | r1 = pearsonr(x1, y1)[0] 415 | r2 = pearsonr(x2, y2)[0] 416 | r = np.sqrt(1. - 0.5 * (r1**2 + r2**2)) 417 | d = D * sqen / (1. + r * (0.25 - 0.75 / sqen)) 418 | p = kstwobign.sf(d) 419 | 420 | if extra: 421 | return p, D 422 | else: 423 | return p 424 | 425 | 426 | def avgmaxdist(x1, y1, x2, y2): 427 | D1 = maxdist(x1, y1, x2, y2) 428 | D2 = maxdist(x2, y2, x1, y1) 429 | return (D1 + D2) / 2 430 | 431 | 432 | def maxdist(x1, y1, x2, y2): 433 | n1 = len(x1) 434 | D1 = 0.0 435 | for i in range(n1): 436 | a1, b1, c1, d1 = quadct(x1[i], y1[i], x1, y1) 437 | a2, b2, c2, d2 = quadct(x1[i], y1[i], x2, y2) 438 | D1 = np.max([D1, np.abs(a1-a2), np.abs(b1-b2), np.abs(c1-c2), np.abs(d1-d2)]) 439 | return D1 440 | 441 | def quadct(x, y, xx, yy): 442 | n = len(xx) 443 | ix1, ix2 = yy >= y, xx >= x 444 | a = np.sum(ix1 & ix2) / n 445 | b = np.sum(ix1 & ~ix2) / n 446 | c = np.sum(~ix1 & ix2) / n 447 | d = 1 - a - b - c 448 | return a, b, c, d 449 | 450 | 451 | 452 | 453 | ####################################################### 454 | def scheduler_stats(model, mcmc_steps_schedule, train_set_len, target_iters): 455 | batches_per_epoch = int(train_set_len / model.batch_size) 456 | print(f"batches_per_epoch: {batches_per_epoch}") 457 | target_iterations = target_iters 458 | epochs_2_target = int(np.ceil(target_iterations / batches_per_epoch)) 459 | print(f"epochs_2_target: <= {epochs_2_target}") 460 | effective_tot_iters = epochs_2_target * batches_per_epoch 461 | print(f"effective_tot_iters: {effective_tot_iters}") 462 | area = 0 463 | for i in range(len(mcmc_steps_schedule)): 464 | if i == 0: 465 | prev = 0 466 | else: 467 | prev = mcmc_steps_schedule[i-1][0] 468 | area += (mcmc_steps_schedule[i][0] - prev) * mcmc_steps_schedule[i][1] 469 | area += (effective_tot_iters - mcmc_steps_schedule[i][0]) * mcmc_steps_schedule[i][1] 470 | avg_mcmc_steps_per_iter = area / effective_tot_iters 471 | print(f"avg_mcmc_steps_per_iter: {avg_mcmc_steps_per_iter:.1f}") 472 | return epochs_2_target 473 | --------------------------------------------------------------------------------
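A minimal usage sketch (not one of the repository files above), showing how the toy pieces fit together: ToyDataset provides the 2D ground-truth distribution, EBMLangVanilla trains the energy model with Langevin sampling, and ksDist evaluates the fit. It assumes the `toy_examples` folder is on the Python path, that `CHECKPOINT_PATH` in `ebm_toy/config.py` points to a writable directory, and that a CUDA device may be available (the model falls back to CPU otherwise); the run name and the epoch count are illustrative.

    from ebm_toy.config import set_seed
    from ebm_toy.train import EBMLangVanilla
    from ebm_toy.utils import ToyDataset, ksDist

    set_seed(0)

    # 2D toy data: 8-component Gaussian mixture with means on a circle
    train_set = ToyDataset(dataset_len=60000, toy_type='gmm')
    test_set = ToyDataset(dataset_len=10000, toy_type='gmm')

    # Energy-based model trained with vanilla Langevin dynamics
    model = EBMLangVanilla(batch_size=64,
                           lr=1e-4,
                           mcmc_step_size=1e-5,
                           mcmc_steps=250,
                           mcmc_init_type="persistent",
                           model_name="toy_run",
                           device="cuda:0",
                           overwrite=True)
    model.prepare_data(train_set, test_set)
    model.setup()          # creates the optimizer and the Tensorboard writer under CHECKPOINT_PATH
    model.fit(n_epochs=1)  # also saves model_state_dict.pt at the end of each epoch

    # Two-sample Kolmogorov-Smirnov test between model samples and fresh data samples
    pval, ks_dist = ksDist(model, train_set, n_samples=1000)
    print("KS p-value:", pval, "KS distance:", ks_dist)

    model.clear()          # close the Tensorboard writer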