├── .gitignore ├── LICENSE ├── README.md ├── data └── semi_supervised-vae │ └── test_train_splits │ ├── get_mnist_train_validation_split.ipynb │ ├── labeled_train_indx.npy │ ├── unlabeled_train_indx.npy │ └── validation_indx.npy ├── experiments ├── bit_vector-vae │ ├── archs.py │ ├── data.py │ ├── opts.py │ ├── test.py │ └── train.py ├── semi_supervised-vae │ ├── archs.py │ ├── data.py │ ├── opts.py │ ├── test.py │ └── train.py └── signal-game │ ├── archs.py │ ├── data.py │ ├── opts.py │ ├── test.py │ └── train.py ├── lvmhelpers ├── __init__.py ├── bernoulli.h ├── binary_topk.h ├── budget.h ├── gumbel.py ├── marg.py ├── nvil.py ├── pbernoulli.cpp ├── pbernoulli.pxd ├── pbernoulli.pyx ├── pbinary_topk.cpp ├── pbinary_topk.pyx ├── pbudget.cpp ├── pbudget.pxd ├── pbudget.pyx ├── psequence.cpp ├── psequence.pxd ├── psequence.pyx ├── sequence.h ├── sequence_binary.h ├── sfe.py ├── sparsemap.py ├── structmarg.py ├── sum_and_sample.py └── utils.py ├── requirements.txt ├── scripts ├── bit_vector │ ├── bit_vector_gs.sh │ ├── bit_vector_gs_128.sh │ ├── bit_vector_gs_st.sh │ ├── bit_vector_gs_st_128.sh │ ├── bit_vector_nvil.sh │ ├── bit_vector_nvil_128.sh │ ├── bit_vector_sfe.sh │ ├── bit_vector_sfe_128.sh │ ├── bit_vector_sfe_plus.sh │ ├── bit_vector_sfe_plus_128.sh │ ├── bit_vector_sparsemap.sh │ ├── bit_vector_sparsemap_128.sh │ ├── bit_vector_sparsemap_budget.sh │ ├── bit_vector_sparsemap_budget_128.sh │ ├── bit_vector_topksparse.sh │ └── bit_vector_topksparse_128.sh ├── signal_game │ ├── signal_game_gs_seeds.sh │ ├── signal_game_gs_st_seeds.sh │ ├── signal_game_marg_softmax_seeds.sh │ ├── signal_game_marg_sparsemax_seeds.sh │ ├── signal_game_nvil_seeds.sh │ ├── signal_game_sfe_nll_seeds.sh │ ├── signal_game_sfe_plus_seeds.sh │ └── signal_game_sfe_seeds.sh ├── ssvae │ ├── ssvae_gumbel_seeds.sh │ ├── ssvae_gumbel_st_seeds.sh │ ├── ssvae_marg_softmax_seeds.sh │ ├── ssvae_marg_sparsemax_seeds.sh │ ├── ssvae_nvil_seeds.sh │ ├── ssvae_sfe_plus_seeds.sh │ ├── ssvae_sfe_seeds.sh │ └── ssvae_sumsample_seeds.sh ├── ssvae_warm_start_softmax.sh └── ssvae_warm_start_sparsemax.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # vim swap files 7 | *.swp 8 | *.swo 9 | 10 | data/ 11 | .idea/ 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # other 113 | build/ 114 | default/ 115 | logs/ 116 | checkpoints/ 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 DeepSPIN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Marginalization of Discrete and Structured Latent Variables via Sparsity - Official PyTorch implementation of the NeurIPS 2020 paper 2 | 3 | **Gonçalo M. Correia** (Instituto de Telecomunicações), **Vlad Niculae** (IvI, University of Amsterdam), **Wilker Aziz** (ILLC, University of Amsterdam), **André F. T. Martins** (Instituto de Telecomunicações, Unbabel, LUMLIS) 4 | 5 | 6 | **Abstract**: 7 | 8 | _Training neural network models with discrete (categorical or structured) latent variables can be computationally challenging, due to the need for marginalization over large or combinatorial sets. 
To circumvent this issue, one typically resorts to sampling-based approximations of the true marginal, requiring noisy gradient estimators (e.g., score function estimator) or continuous relaxations with lower-variance reparameterized gradients (e.g., Gumbel-Softmax). In this paper, we propose a new training strategy which replaces these estimators by an exact yet efficient marginalization. To achieve this, we parameterize discrete distributions over latent assignments using differentiable sparse mappings: sparsemax and its structured counterparts. In effect, the support of these distributions is greatly reduced, which enables efficient marginalization. We report successful results in three tasks covering a range of latent variable modeling applications: a semisupervised deep generative model, a latent communication game, and a generative model with a bit-vector latent representation. In all cases, we obtain good performance while still achieving the practicality of sampling-based approximations._ 9 | 10 | ## Resources 11 | 12 | - [Paper](https://arxiv.org/abs/2007.01919) (arXiv) 13 | 14 | All material is made available under the MIT license. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made. 15 | 16 | ## Python requirements and installation 17 | 18 | This code was tested on `Python 3.7.1`. To install, follow these steps: 19 | 20 | 1. In a virtual environment, first install Cython: `pip install cython` 21 | 2. Clone the [Eigen](https://gitlab.com/libeigen/eigen) repository to your home: `git clone git@gitlab.com:libeigen/eigen.git` 22 | 3. Clone the [LP-SparseMAP](https://github.com/deep-spin/lp-sparsemap) repository to your home, and follow the installation instructions found there 23 | 4. Install PyTorch: `pip install torch` (we used version 1.6.0) 24 | 5. Install the requirements: `pip install -r requirements.txt` 25 | 6. Install the `lvm-helpers` package: `pip install .` (or in editable mode if you want to make changes: `pip install -e .`) 26 | 27 | ## Datasets 28 | 29 | MNIST and FMNIST should be downloaded automatically by running the training commands for the first time on the semi-supervised VAE and the bit-vector VAE experiments, respectively. To get the dataset for the emergent communication game, please visit: https://github.com/DianeBouchacourt/SignalingGame. After getting the data, store the `train` and `test` folders under `data/signal-game` of this repository. 
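For example, assuming you have cloned the SignalingGame repository next to this one and that the downloaded data provides `train` and `test` folders, the layout can be set up along these lines (the source paths are illustrative and depend on where you placed the data):

```
mkdir -p data/signal-game
cp -r ../SignalingGame/train data/signal-game/train
cp -r ../SignalingGame/test data/signal-game/test
```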
30 | 31 | ## Running 32 | 33 | **Training**: 34 | 35 | To get a warm start for the semi-supervised VAE experiment (use the `softmax` normalizer instead for all experiments that do not use sparsemax): 36 | 37 | ``` 38 | python experiments/semi_supervised-vae/train.py \ 39 | --n_epochs 100 \ 40 | --lr 1e-3 \ 41 | --labeled_only \ 42 | --normalizer sparsemax \ 43 | --batch_size 64 44 | ``` 45 | 46 | To train with sparsemax on the semi-supervised VAE experiment (after getting a warm start checkpoint): 47 | 48 | ``` 49 | python experiments/semi_supervised-vae/train.py \ 50 | --mode marg \ 51 | --normalizer sparsemax \ 52 | --random_seed 42 \ 53 | --lr 5e-4 \ 54 | --batch_size 64 \ 55 | --n_epochs 200 \ 56 | --latent_size 10 \ 57 | --warm_start_path /path/to/warm_start/ 58 | ``` 59 | 60 | To train with sparsemax on the emergent communication experiment: 61 | 62 | ``` 63 | python experiments/signal-game/train.py \ 64 | --mode marg \ 65 | --normalizer sparsemax \ 66 | --lr 0.005 \ 67 | --entropy_coeff 0.1 \ 68 | --batch_size 64 \ 69 | --n_epochs 500 \ 70 | --game_size 16 \ 71 | --latent_size 256 \ 72 | --embedding_size 256 \ 73 | --hidden_size 512 \ 74 | --weight_decay 0. \ 75 | --random_seed 42 76 | 77 | ``` 78 | 79 | To train with SparseMAP on the bit-vector VAE experiment, with 32 bits: 80 | 81 | ``` 82 | python experiments/bit_vector-vae/train.py \ 83 | --mode sparsemap \ 84 | --lr 0.002 \ 85 | --batch_size 64 \ 86 | --n_epochs 100 \ 87 | --latent_size 32 \ 88 | --weight_decay 0. \ 89 | --random_seed 42 90 | ``` 91 | 92 | To train with top-k sparsemax on the bit-vector VAE experiment, with 32 bits: 93 | 94 | ``` 95 | python experiments/bit_vector-vae/train.py \ 96 | --mode topksparse \ 97 | --lr 0.002 \ 98 | --batch_size 64 \ 99 | --n_epochs 100 \ 100 | --latent_size 32 \ 101 | --weight_decay 0. \ 102 | --random_seed 42 103 | ``` 104 | 105 | **Evaluating**: 106 | 107 | To evaluate any trained network against one of the test sets, run: 108 | 109 | ``` 110 | python experiments/semi_supervised-vae/test.py /path/to/checkpoint/ /path/to/hparams.yaml 111 | ``` 112 | 113 | Replace `semi_supervised-vae` with `signal-game` or `bit_vector-vae` to get test results for a different experiment. Checkpoints should be found in the appropriate folder inside the automatically generated `checkpoints` directory, and the `yaml` file should be found in the model's automatically generated directory inside `logs`. 114 | 115 | The evaluation results should match those reported in the paper. 116 | 117 | ## Citing 118 | 119 | If you use this codebase in your work, please cite: 120 | 121 | ``` 122 | @inproceedings{correia2020efficientmarg, 123 | title = {Efficient {{Marginalization}} of {{Discrete}} and {{Structured Latent Variables}} via {{Sparsity}}}, 124 | booktitle = {Proc. {{NeurIPS}}}, 125 | author = {Correia, Gon{\c c}alo M. and Niculae, Vlad and Aziz, Wilker and Martins, Andr{\'e} F. T.}, 126 | year = {2020}, 127 | url = {http://arxiv.org/abs/2007.01919} 128 | } 129 | ``` 130 | 131 | ## Acknowledgements 132 | 133 | This work was partly funded by the European Research Council (ERC StG DeepSPIN 758969), by the P2020 project MAIA (contract 045909), and by the Fundação para a Ciência e Tecnologia through contract UIDB/50008/2020. This work also received funding from the European Union’s Horizon 2020 research and innovation programme under grant agreement 825299 (GoURMET). 134 | 135 | The code in this repository was largely inspired by the structure and implementations found in [EGG](https://github.com/facebookresearch/EGG) and was built upon it. 
EGG is copyright (c) Facebook, Inc. and its affiliates. 136 | -------------------------------------------------------------------------------- /data/semi_supervised-vae/test_train_splits/get_mnist_train_validation_split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "\n", 14 | "import torchvision.datasets as dset\n", 15 | "import torchvision.transforms as transforms" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 4, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def load_mnist_data(data_dir='../mnist_data/', train=True):\n", 25 | " if not os.path.exists(data_dir):\n", 26 | " print('creaing folder: ', data_dir)\n", 27 | " os.mkdir(data_dir)\n", 28 | "\n", 29 | " def trans(x):\n", 30 | " return transforms.ToTensor()(x).bernoulli()\n", 31 | "\n", 32 | " data = dset.MNIST(\n", 33 | " root=data_dir, train=train,\n", 34 | " transform=trans, download=True)\n", 35 | "\n", 36 | " return data\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Load data" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 5, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "np.random.seed(351245)\n", 55 | "_ = torch.manual_seed(453453)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": { 62 | "collapsed": false, 63 | "tags": [] 64 | }, 65 | "outputs": [ 66 | { 67 | "output_type": "stream", 68 | "name": "stderr", 69 | "text": "0it [00:00, ?it/s]creaing folder: ../mnist_data/\nDownloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../mnist_data/MNIST/raw/train-images-idx3-ubyte.gz\n9920512it [00:05, 1736860.32it/s]\nExtracting ../mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ../mnist_data/MNIST/raw\n0it [00:00, ?it/s]Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz\n32768it [00:01, 16895.42it/s]\n0it [00:00, ?it/s]Extracting ../mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ../mnist_data/MNIST/raw\nDownloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz\n1654784it [00:01, 1073456.69it/s]\n0it [00:00, ?it/s]Extracting ../mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../mnist_data/MNIST/raw\nDownloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n8192it [00:00, 15393.64it/s]\nExtracting ../mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../mnist_data/MNIST/raw\nProcessing...\nDone!\n" 70 | } 71 | ], 72 | "source": [ 73 | "mnist_train_data = load_mnist_data(\n", 74 | " data_dir='../mnist_data/', train=True)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": { 81 | "collapsed": false 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "n_train_total = len(mnist_train_data)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 8, 91 | "metadata": { 92 | "collapsed": false, 93 | "tags": [] 94 | }, 95 | "outputs": [ 96 | { 97 | "output_type": "stream", 98 | "name": "stdout", 99 | "text": "60000\n" 100 | } 101 | ], 102 | "source": [ 103 | "all_indx = 
np.arange(n_train_total)\n", 104 | "print(len(all_indx))" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 9, 110 | "metadata": { 111 | "collapsed": false, 112 | "scrolled": true, 113 | "tags": [] 114 | }, 115 | "outputs": [ 116 | { 117 | "output_type": "stream", 118 | "name": "stdout", 119 | "text": "10000\n" 120 | } 121 | ], 122 | "source": [ 123 | "# get validation index\n", 124 | "validation_indx = np.random.choice(all_indx, 10000, replace = False) \n", 125 | "print(len(validation_indx))" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 10, 131 | "metadata": { 132 | "collapsed": false, 133 | "tags": [] 134 | }, 135 | "outputs": [ 136 | { 137 | "output_type": "stream", 138 | "name": "stdout", 139 | "text": "50000\n" 140 | } 141 | ], 142 | "source": [ 143 | "# remove validation observations \n", 144 | "train_bool = np.ones(n_train_total)\n", 145 | "train_bool[validation_indx] = 0.\n", 146 | "train_indx = all_indx[train_bool == 1]\n", 147 | "print(len(train_indx))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 11, 153 | "metadata": { 154 | "collapsed": false, 155 | "tags": [] 156 | }, 157 | "outputs": [ 158 | { 159 | "output_type": "stream", 160 | "name": "stdout", 161 | "text": "5000\n" 162 | } 163 | ], 164 | "source": [ 165 | "labeled_train_indx = np.random.choice(train_indx, 5000, replace = False)\n", 166 | "print(len(labeled_train_indx))" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 12, 172 | "metadata": { 173 | "collapsed": false 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "train_bool[labeled_train_indx] = 0." 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 13, 183 | "metadata": { 184 | "collapsed": false, 185 | "tags": [] 186 | }, 187 | "outputs": [ 188 | { 189 | "output_type": "stream", 190 | "name": "stdout", 191 | "text": "45000\n" 192 | } 193 | ], 194 | "source": [ 195 | "unlabeled_train_indx = all_indx[train_bool == 1] \n", 196 | "print(len(unlabeled_train_indx))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 14, 202 | "metadata": { 203 | "collapsed": false, 204 | "tags": [] 205 | }, 206 | "outputs": [ 207 | { 208 | "output_type": "stream", 209 | "name": "stdout", 210 | "text": "[]\n" 211 | } 212 | ], 213 | "source": [ 214 | "print(np.intersect1d(validation_indx, labeled_train_indx))\n", 215 | "assert len(np.intersect1d(validation_indx, labeled_train_indx)) == 0" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 15, 221 | "metadata": { 222 | "collapsed": false, 223 | "tags": [] 224 | }, 225 | "outputs": [ 226 | { 227 | "output_type": "stream", 228 | "name": "stdout", 229 | "text": "[]\n" 230 | } 231 | ], 232 | "source": [ 233 | "print(np.intersect1d(validation_indx, unlabeled_train_indx))\n", 234 | "assert len(np.intersect1d(validation_indx, unlabeled_train_indx)) == 0" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 16, 240 | "metadata": { 241 | "collapsed": false, 242 | "tags": [] 243 | }, 244 | "outputs": [ 245 | { 246 | "output_type": "stream", 247 | "name": "stdout", 248 | "text": "[]\n" 249 | } 250 | ], 251 | "source": [ 252 | "print(np.intersect1d(labeled_train_indx, unlabeled_train_indx))\n", 253 | "assert len(np.intersect1d(validation_indx, unlabeled_train_indx)) == 0" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 17, 259 | "metadata": { 260 | "collapsed": false 261 | }, 262 | "outputs": [], 263 | 
"source": [ 264 | "len(np.union1d(np.union1d(labeled_train_indx, unlabeled_train_indx), validation_indx))\n", 265 | "assert len(np.union1d(np.union1d(labeled_train_indx, unlabeled_train_indx), validation_indx)) == 60000" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 19, 271 | "metadata": { 272 | "collapsed": true 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "np.save('validation_indx.npy', validation_indx)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 20, 282 | "metadata": { 283 | "collapsed": true 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "np.save('unlabeled_train_indx.npy', unlabeled_train_indx)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 21, 293 | "metadata": { 294 | "collapsed": true 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "np.save('labeled_train_indx.npy', labeled_train_indx)" 299 | ] 300 | } 301 | ], 302 | "metadata": { 303 | "kernelspec": { 304 | "display_name": "Python 3.7.5 64-bit ('py371': virtualenv)", 305 | "language": "python", 306 | "name": "python37564bitpy371virtualenv6b21f57c175a4942a2e9f7565c856c98" 307 | }, 308 | "language_info": { 309 | "codemirror_mode": { 310 | "name": "ipython", 311 | "version": 3 312 | }, 313 | "file_extension": ".py", 314 | "mimetype": "text/x-python", 315 | "name": "python", 316 | "nbconvert_exporter": "python", 317 | "pygments_lexer": "ipython3", 318 | "version": "3.7.5-final" 319 | } 320 | }, 321 | "nbformat": 4, 322 | "nbformat_minor": 2 323 | } -------------------------------------------------------------------------------- /data/semi_supervised-vae/test_train_splits/labeled_train_indx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/sparse-marginalization-lvm/9d27ec0fd58ea2eb58d79bca23a0051f1b143ebe/data/semi_supervised-vae/test_train_splits/labeled_train_indx.npy -------------------------------------------------------------------------------- /data/semi_supervised-vae/test_train_splits/unlabeled_train_indx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/sparse-marginalization-lvm/9d27ec0fd58ea2eb58d79bca23a0051f1b143ebe/data/semi_supervised-vae/test_train_splits/unlabeled_train_indx.npy -------------------------------------------------------------------------------- /data/semi_supervised-vae/test_train_splits/validation_indx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/sparse-marginalization-lvm/9d27ec0fd58ea2eb58d79bca23a0051f1b143ebe/data/semi_supervised-vae/test_train_splits/validation_indx.npy -------------------------------------------------------------------------------- /experiments/bit_vector-vae/archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MLP(torch.nn.Sequential): 5 | def __init__(self, dim_in, dim_hid, dim_out, n_layers): 6 | super().__init__() 7 | Nonlin = torch.nn.ReLU 8 | self.add_module("layer0", torch.nn.Linear(dim_in, dim_hid)) 9 | self.add_module("act0", Nonlin()) 10 | for i in range(1, n_layers + 1): 11 | self.add_module(f"layer{i}", torch.nn.Linear(dim_hid, dim_hid)) 12 | self.add_module(f"act{i}", Nonlin()) 13 | self.add_module( 14 | f"layer{n_layers+1}", torch.nn.Linear(dim_hid, dim_out) 15 | ) 16 | 17 | 18 | class CategoricalGenerator(torch.nn.Module): 19 | def 
__init__(self, gen, n_features, out_rank, n_classes): 20 | super().__init__() 21 | self.gen = gen 22 | self.n_features = n_features 23 | self.out_rank = out_rank 24 | self.out = torch.nn.Linear(out_rank, n_classes) 25 | 26 | def forward(self, Z, *args): 27 | X = self.gen(Z) 28 | X = X.reshape(-1, self.n_features, self.out_rank) 29 | return self.out(X) 30 | -------------------------------------------------------------------------------- /experiments/bit_vector-vae/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def transform(pic): 5 | return torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 6 | -------------------------------------------------------------------------------- /experiments/bit_vector-vae/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def populate_experiment_params( 5 | arg_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 6 | return arg_parser 7 | -------------------------------------------------------------------------------- /experiments/bit_vector-vae/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from train import VAE 5 | 6 | 7 | def main(checkpoint_path, hparams_path): 8 | 9 | model = VAE.load_from_checkpoint( 10 | checkpoint_path=checkpoint_path, 11 | hparams_file=hparams_path, 12 | map_location=None) 13 | 14 | trainer = pl.Trainer( 15 | progress_bar_refresh_rate=1, 16 | weights_summary='full', 17 | gpus=1 if torch.cuda.is_available() else 0, 18 | deterministic=True) 19 | 20 | # test (pass in the model) 21 | trainer.test(model) 22 | 23 | 24 | if __name__ == '__main__': 25 | import sys 26 | main(sys.argv[1], sys.argv[2]) 27 | -------------------------------------------------------------------------------- /experiments/bit_vector-vae/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.distributions import Bernoulli 7 | from torchvision import datasets 8 | import pytorch_lightning as pl 9 | from pytorch_lightning import loggers as pl_loggers 10 | 11 | from lvmhelpers.structmarg import \ 12 | TopKSparsemaxWrapper, TopKSparsemaxMarg, SparseMAPWrapper, SparseMAPMarg 13 | from lvmhelpers.sfe import \ 14 | BitVectorSFEWrapper, SFEDeterministicWrapper, \ 15 | BitVectorScoreFunctionEstimator 16 | from lvmhelpers.gumbel import \ 17 | BitVectorGumbelSoftmaxWrapper, BitVectorGumbel 18 | from lvmhelpers.nvil import \ 19 | BitVectorNVILWrapper, BitVectorNVIL 20 | from lvmhelpers.utils import DeterministicWrapper, populate_common_params 21 | 22 | from data import transform 23 | from archs import MLP, CategoricalGenerator 24 | from opts import populate_experiment_params 25 | 26 | 27 | class VAE(pl.LightningModule): 28 | def __init__( 29 | self, 30 | n_features, 31 | hidden_size, 32 | out_rank, 33 | out_classes, 34 | budget, 35 | init, 36 | mode, 37 | entropy_coeff, 38 | latent_size, 39 | normalizer, 40 | topksparse, 41 | gs_tau, 42 | temperature_decay, 43 | temperature_update_freq, 44 | straight_through, 45 | baseline_type, 46 | topk, 47 | random_seed, 48 | batch_size, 49 | lr, 50 | weight_decay, 51 | optimizer): 52 | super(VAE, self).__init__() 53 | 54 | self.save_hyperparameters() 55 | 56 | inf = MLP( 57 | dim_in=n_features, 58 | dim_hid=self.hparams.hidden_size, 59 | 
dim_out=self.hparams.latent_size, 60 | n_layers=0) 61 | gen = MLP( 62 | dim_in=self.hparams.latent_size, 63 | dim_hid=self.hparams.hidden_size, 64 | dim_out=self.hparams.n_features * self.hparams.out_rank, 65 | n_layers=0) 66 | gen = CategoricalGenerator( 67 | gen, 68 | n_features=self.hparams.n_features, 69 | out_rank=self.hparams.out_rank, 70 | n_classes=256) 71 | 72 | loss_fun = reconstruction_loss 73 | 74 | if self.hparams.mode == 'sfe': 75 | inf = BitVectorSFEWrapper( 76 | inf, baseline_type=self.hparams.baseline_type) 77 | gen = SFEDeterministicWrapper(gen) 78 | lvm_method = BitVectorScoreFunctionEstimator 79 | elif self.hparams.mode == 'nvil': 80 | inf = BitVectorNVILWrapper(inf, input_size=n_features) 81 | gen = DeterministicWrapper(gen) 82 | lvm_method = BitVectorNVIL 83 | elif self.hparams.mode == 'gs': 84 | inf = BitVectorGumbelSoftmaxWrapper( 85 | inf, 86 | temperature=self.hparams.gs_tau, 87 | straight_through=self.hparams.straight_through) 88 | gen = DeterministicWrapper(gen) 89 | lvm_method = BitVectorGumbel 90 | elif self.hparams.mode == 'topksparse': 91 | inf = TopKSparsemaxWrapper(inf, k=self.hparams.topksparse) 92 | gen = DeterministicWrapper(gen) 93 | lvm_method = TopKSparsemaxMarg 94 | elif self.hparams.mode == 'sparsemap': 95 | inf = SparseMAPWrapper( 96 | inf, budget=self.hparams.budget, init=self.hparams.init) 97 | gen = DeterministicWrapper(gen) 98 | lvm_method = SparseMAPMarg 99 | else: 100 | raise RuntimeError(f"Unknown training mode: {self.hparams.mode}") 101 | 102 | self.lvm_method = lvm_method( 103 | inf, 104 | gen, 105 | loss_fun, 106 | encoder_entropy_coeff=1.0) 107 | 108 | def forward(self, _inf_input, inf_input): 109 | return self.lvm_method(_inf_input, torch.zeros_like(_inf_input), inf_input) 110 | 111 | def training_step(self, batch, batch_nb): 112 | inf_input, _ = batch 113 | _inf_input = inf_input.to(dtype=torch.float) / 255 114 | training_result = self(_inf_input, inf_input) 115 | loss = training_result['loss'] 116 | 117 | result = pl.TrainResult(minimize=loss) 118 | elbo = \ 119 | - training_result['log']['loss'] + \ 120 | training_result['log']['encoder_entropy'] + \ 121 | self.hparams.latent_size * torch.log(torch.tensor(0.5)) 122 | result.log( 123 | '-train_elbo', 124 | -elbo, 125 | prog_bar=True, logger=True) 126 | 127 | if 'support' in training_result['log'].keys(): 128 | result.log( 129 | 'train_support_median', training_result['log']['support'], 130 | reduce_fx=torch.median, on_epoch=True, on_step=False) 131 | result.log( 132 | 'train_support_mean', torch.mean(training_result['log']['support']), 133 | prog_bar=True, reduce_fx=torch.mean, on_epoch=True, on_step=False) 134 | 135 | # Update temperature if Gumbel 136 | if self.hparams.mode == 'gs': 137 | self.lvm_method.encoder.update_temperature( 138 | self.global_step, 139 | self.hparams.temperature_update_freq, 140 | self.hparams.temperature_decay) 141 | result.log( 142 | 'temperature', self.lvm_method.encoder.temperature) 143 | 144 | return result 145 | 146 | def validation_step(self, batch, batch_nb): 147 | inf_input, _ = batch 148 | _inf_input = inf_input.to(dtype=torch.float) / 255 149 | validation_result = self(_inf_input, inf_input) 150 | 151 | elbo = \ 152 | - validation_result['log']['loss'] + \ 153 | validation_result['log']['encoder_entropy'] + \ 154 | self.hparams.latent_size * torch.log(torch.tensor(0.5)) 155 | 156 | if (self.current_epoch + 1) % 20 == 0: 157 | n_importance_samples = 64 158 | else: 159 | n_importance_samples = 16 160 | 161 | logp_x_bits, logp_x_nats, 
non_supp_influence = \ 162 | self.compute_importance_sampling( 163 | validation_result['log'], 164 | inf_input, 165 | n_importance_samples) 166 | 167 | result = pl.EvalResult(checkpoint_on=-elbo) 168 | result.log('-val_elbo', -elbo, prog_bar=True) 169 | result.log('val_logp_x_bits', logp_x_bits, prog_bar=True) 170 | result.log('val_non_supp_influence', non_supp_influence) 171 | 172 | if 'support' in validation_result['log'].keys(): 173 | result.log( 174 | 'val_support_median', validation_result['log']['support'], 175 | reduce_fx=torch.median) 176 | result.log( 177 | 'val_support_mean', torch.mean(validation_result['log']['support']), 178 | reduce_fx=torch.mean, on_epoch=True) 179 | 180 | return result 181 | 182 | def test_step(self, batch, batch_nb): 183 | inf_input, _ = batch 184 | _inf_input = inf_input.to(dtype=torch.float) / 255 185 | test_result = self(_inf_input, inf_input) 186 | 187 | elbo = \ 188 | - test_result['log']['loss'] + \ 189 | test_result['log']['encoder_entropy'] + \ 190 | self.hparams.latent_size * torch.log(torch.tensor(0.5)) 191 | 192 | logp_x_bits, logp_x_nats, non_supp_influence = \ 193 | self.compute_importance_sampling( 194 | test_result['log'], 195 | inf_input, 196 | 1024) 197 | 198 | result = pl.EvalResult(checkpoint_on=-elbo) 199 | result.log('-test_elbo', -elbo, prog_bar=True) 200 | result.log('test_logp_x_bits', logp_x_bits, prog_bar=True) 201 | result.log('test_non_supp_influence', non_supp_influence) 202 | result.log('test_distortion', test_result['log']['loss']) 203 | result.log( 204 | 'test_rate', 205 | - self.hparams.latent_size * torch.log(torch.tensor(0.5)) 206 | - test_result['log']['encoder_entropy']) 207 | 208 | if 'support' in test_result['log'].keys(): 209 | result.log( 210 | 'test_support_step', test_result['log']['support'], 211 | reduce_fx=torch.median) 212 | result.log( 213 | 'test_support_mean', torch.mean(test_result['log']['support']), 214 | reduce_fx=torch.mean, on_epoch=True) 215 | 216 | return result 217 | 218 | def configure_optimizers(self): 219 | return torch.optim.Adam( 220 | self.parameters(), 221 | lr=self.hparams.lr, 222 | weight_decay=self.hparams.weight_decay) 223 | 224 | def train_dataloader(self): 225 | return torch.utils.data.DataLoader( 226 | torch.utils.data.Subset( 227 | datasets.FashionMNIST( 228 | 'data/bit_vector-vae/fmnist_data/', 229 | train=True, 230 | download=True, 231 | transform=transform), 232 | indices=range(55000)), 233 | batch_size=self.hparams.batch_size, 234 | shuffle=True, 235 | num_workers=4, 236 | pin_memory=True) 237 | 238 | def val_dataloader(self): 239 | return torch.utils.data.DataLoader( 240 | torch.utils.data.Subset( 241 | datasets.FashionMNIST( 242 | 'data/bit_vector-vae/fmnist_data/', 243 | train=True, 244 | download=True, 245 | transform=transform), 246 | indices=range(55000, 60000)), 247 | batch_size=self.hparams.batch_size, 248 | shuffle=False, 249 | num_workers=4, 250 | pin_memory=True) 251 | 252 | def test_dataloader(self): 253 | return torch.utils.data.DataLoader( 254 | datasets.FashionMNIST( 255 | 'data/bit_vector-vae/fmnist_data/', 256 | train=False, 257 | download=True, 258 | transform=transform), 259 | batch_size=self.hparams.batch_size, 260 | shuffle=False, 261 | num_workers=4, 262 | pin_memory=True) 263 | 264 | def compute_importance_sampling(self, util_dict, inf_input, n_samples): 265 | 266 | distr = util_dict['distr'] 267 | if self.hparams.mode == 'sparsemap' or self.hparams.mode == 'topksparse': 268 | sampling_distr = Bernoulli( 269 | probs=torch.full( 270 | (inf_input.size(0), 
self.hparams.latent_size), 271 | 0.5).to(inf_input.device)) 272 | else: 273 | sampling_distr = distr 274 | # importance_samples: [n_samples, batch_size, nlatents] 275 | importance_samples = sampling_distr.sample((n_samples,)) 276 | # logq_z_given_x_importance: [n_samples, batch_size] 277 | logq_z_given_x_importance = \ 278 | sampling_distr.log_prob(importance_samples).sum(dim=-1) 279 | 280 | batch_n_samples = 16 281 | logp_x_given_z_importance = [] 282 | for importance_sample_batch in importance_samples.split(batch_n_samples): 283 | # Xhat_importance: [batch_n_samples * batch_size, n_features] 284 | if self.hparams.mode == 'sfe': 285 | Xhat_importance, _, _ = self.lvm_method.decoder(importance_sample_batch) 286 | else: 287 | Xhat_importance = self.lvm_method.decoder(importance_sample_batch) 288 | # inf_input_repeat: [batch_n_samples * batch_size, n_features] 289 | inf_input_repeat = inf_input.repeat( 290 | batch_n_samples, 1, 1).view(-1, inf_input.size(-1)) 291 | # logp_x_given_z_importance: [batch_n_samples, batch_size] 292 | logp_x_given_z_importance.append( 293 | reconstruction_loss( 294 | inf_input_repeat, 295 | importance_samples, 296 | inf_input_repeat, 297 | Xhat_importance, 298 | inf_input_repeat)[0].view( 299 | batch_n_samples, inf_input.size(0))) 300 | # logp_x_given_z_importance: [n_samples, batch_size] 301 | logp_x_given_z_importance = -torch.cat(logp_x_given_z_importance, dim=0) 302 | # logp_z: [] 303 | logp_z = importance_samples.shape[-1] * torch.log(torch.tensor(0.5)) 304 | # aux will be the log of p(x,z)/q(z|x) 305 | # samples are taken from q(z|x) and then we assess this value 306 | # and average over all samples 307 | # aux: [n_samples, batch_size] 308 | aux = ( 309 | logp_x_given_z_importance 310 | + logp_z 311 | - logq_z_given_x_importance 312 | ) 313 | # logp_x_importance: [batch_size] 314 | logp_x_importance = torch.logsumexp(aux, dim=0) - torch.log( 315 | torch.tensor(float(n_samples)) 316 | ) 317 | 318 | non_supp_influence = logp_x_importance.mean(dim=0) 319 | 320 | if self.hparams.mode == 'sparsemap': 321 | # logp_x_deterministic_term: [batch_size] 322 | logp_x_deterministic_term = [] 323 | logp_x_given_z = - util_dict['loss_output'] 324 | idxs = util_dict['idxs'] 325 | for k in range(inf_input.size(0)): 326 | logp_x_deterministic_term.append( 327 | torch.logsumexp( 328 | logp_x_given_z[torch.tensor(idxs) == k] + logp_z, 329 | dim=0)) 330 | logp_x_deterministic_term = torch.stack(logp_x_deterministic_term) 331 | 332 | # logp_x_importance: [batch_size] 333 | logp_x_importance = torch.logsumexp( 334 | torch.stack([logp_x_deterministic_term, logp_x_importance]), dim=0) 335 | elif self.hparams.mode == 'topksparse': 336 | # need to 'reconstruct' logp_x_given_z since 337 | # we were just dealing with nonzeros 338 | # logp_x_deterministic_term: [batch_size] 339 | mask = distr.view(-1) > 0 340 | logp_x_given_z = -torch.ones_like(mask).to( 341 | torch.float32 342 | ) * float("inf") 343 | logp_x_given_z = logp_x_given_z.masked_scatter( 344 | mask, -util_dict['loss_output']).view(distr.shape) 345 | logp_x_deterministic_term = torch.logsumexp( 346 | logp_x_given_z + logp_z, dim=-1 347 | ) 348 | # logp_x_importance: [batch_size] 349 | logp_x_importance = torch.logsumexp( 350 | torch.stack([logp_x_deterministic_term, logp_x_importance]), dim=0) 351 | else: 352 | non_supp_influence = torch.tensor(0.0) 353 | 354 | logp_x_bits = logp_x_importance.mean(dim=0) / torch.log(torch.tensor(2.0)) 355 | logp_x_bits = - logp_x_bits / self.hparams.n_features 356 | logp_x_nats = 
logp_x_importance.mean(dim=0) 357 | 358 | return logp_x_bits, logp_x_nats, non_supp_influence 359 | 360 | 361 | def reconstruction_loss( 362 | inf_input, 363 | discrete_latent_z, 364 | _gen_input, 365 | gen_output, 366 | true_labels): 367 | Xhat_logits = gen_output.permute(0, 2, 1) 368 | lv = F.cross_entropy( 369 | Xhat_logits, true_labels.to(dtype=torch.long), reduction="none" 370 | ) 371 | return lv.sum(dim=1), {} 372 | 373 | 374 | def get_model(opt): 375 | n_features = 28 * 28 376 | hidden_size = 128 377 | out_rank = 5 378 | out_classes = 256 379 | model = VAE( 380 | n_features=n_features, 381 | hidden_size=hidden_size, 382 | out_rank=out_rank, 383 | out_classes=out_classes, 384 | budget=opt.budget, 385 | init=not opt.noinit, 386 | mode=opt.mode, 387 | entropy_coeff=opt.entropy_coeff, 388 | latent_size=opt.latent_size, 389 | normalizer=opt.normalizer, 390 | topksparse=opt.topksparse, 391 | gs_tau=opt.gs_tau, 392 | temperature_decay=opt.temperature_decay, 393 | temperature_update_freq=opt.temperature_update_freq, 394 | straight_through=opt.straight_through, 395 | baseline_type=opt.baseline_type, 396 | topk=opt.topk, 397 | random_seed=opt.random_seed, 398 | batch_size=opt.batch_size, 399 | lr=opt.lr, 400 | weight_decay=opt.weight_decay, 401 | optimizer=opt.optimizer) 402 | 403 | return model 404 | 405 | 406 | def main(params): 407 | 408 | arg_parser = argparse.ArgumentParser() 409 | arg_parser = populate_experiment_params(arg_parser) 410 | arg_parser = populate_common_params(arg_parser) 411 | opts = arg_parser.parse_args(params) 412 | 413 | # fix seed 414 | pl.seed_everything(opts.random_seed) 415 | 416 | pathlib.Path( 417 | 'data/bit_vector-vae/fmnist_data/').mkdir( 418 | parents=True, exist_ok=True) 419 | 420 | bit_vector_vae = get_model(opts) 421 | 422 | experiment_name = 'bit-vector' 423 | model_name = '%s/%s' % (experiment_name, opts.mode) 424 | other_info = [ 425 | "lr-{}".format(opts.lr), 426 | "latent_size-{}".format(opts.latent_size), 427 | ] 428 | if opts.mode == "sparsemap": 429 | if opts.budget > 0: 430 | other_info.append(f"b{opts.budget}") 431 | if opts.noinit: 432 | other_info.append("noinit") 433 | elif opts.mode == "gs": 434 | if opts.straight_through: 435 | other_info.append("straight_through") 436 | other_info.append("decay-{}".format(opts.temperature_decay)) 437 | other_info.append("updatefreq-{}".format(opts.temperature_update_freq)) 438 | elif opts.mode == 'sfe': 439 | other_info.append("baseline-{}".format(opts.baseline_type)) 440 | elif opts.mode == "topksparse": 441 | other_info.append("k-{}".format(opts.topksparse)) 442 | 443 | model_name = '%s/%s' % (model_name, '_'.join(other_info)) 444 | 445 | tb_logger = pl_loggers.TensorBoardLogger( 446 | 'logs/', 447 | name=model_name) 448 | 449 | tb_logger.log_hyperparams(opts, metrics=None) 450 | 451 | trainer = pl.Trainer( 452 | progress_bar_refresh_rate=20, 453 | logger=tb_logger, 454 | max_epochs=opts.n_epochs, 455 | weights_save_path='checkpoints/', 456 | weights_summary='full', 457 | gpus=1 if torch.cuda.is_available() else 0, 458 | resume_from_checkpoint=opts.load_from_checkpoint, 459 | deterministic=True) 460 | 461 | trainer.fit(bit_vector_vae) 462 | 463 | 464 | if __name__ == '__main__': 465 | import sys 466 | main(sys.argv[1:]) 467 | -------------------------------------------------------------------------------- /experiments/semi_supervised-vae/archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 
4 | 5 | 6 | def get_one_hot_encoding_from_int(z, n_classes): 7 | """ 8 | Convert categorical variable to one-hot enoding 9 | 10 | Parameters 11 | ---------- 12 | z : torch.LongTensor 13 | Tensor with integers corresponding to categories 14 | n_classes : Int 15 | The total number of categories 16 | 17 | Returns 18 | ---------- 19 | z_one_hot : torch.Tensor 20 | One hot encoding of z 21 | """ 22 | 23 | z_one_hot = torch.zeros(len(z), n_classes).to(z.device) 24 | z_one_hot.scatter_(1, z.view(-1, 1), 1) 25 | z_one_hot = z_one_hot.view(len(z), n_classes) 26 | 27 | return z_one_hot 28 | 29 | 30 | class MLPEncoder(nn.Module): 31 | def __init__( 32 | self, 33 | latent_dim=5, 34 | slen=28, 35 | n_classes=10): 36 | # the encoder returns the mean and variance of the latent parameters 37 | # given the image and its class (one hot encoded) 38 | 39 | super(MLPEncoder, self).__init__() 40 | 41 | # image / model parameters 42 | self.n_pixels = slen ** 2 43 | self.latent_dim = latent_dim 44 | self.slen = slen 45 | self.n_classes = n_classes 46 | 47 | # define the linear layers 48 | self.fc1 = nn.Linear(self.n_pixels + self.n_classes, 128) 49 | self.fc2 = nn.Linear(128, 128) 50 | self.fc3 = nn.Linear(128, latent_dim * 2) 51 | 52 | def forward(self, image, one_hot_label): 53 | # label should be one hot encoded 54 | assert one_hot_label.shape[1] == self.n_classes 55 | assert image.shape[0] == one_hot_label.shape[0] 56 | 57 | # feed through neural network 58 | h = image.view(-1, self.n_pixels) 59 | h = torch.cat((h, one_hot_label), dim=1) 60 | 61 | h = F.relu(self.fc1(h)) 62 | h = F.relu(self.fc2(h)) 63 | h = self.fc3(h) 64 | 65 | # get means, std, and class weights 66 | indx1 = self.latent_dim 67 | indx2 = 2 * self.latent_dim 68 | 69 | latent_means = h[:, 0:indx1] 70 | latent_std = torch.exp(h[:, indx1:indx2]) 71 | 72 | return latent_means, latent_std 73 | 74 | 75 | class Classifier(nn.Module): 76 | def __init__( 77 | self, 78 | slen=28, 79 | n_classes=10): 80 | 81 | super(Classifier, self).__init__() 82 | 83 | self.slen = slen 84 | self.n_pixels = slen ** 2 85 | self.n_classes = n_classes 86 | 87 | self.fc1 = nn.Linear(self.n_pixels, 256) 88 | self.fc2 = nn.Linear(256, 256) 89 | self.fc3 = nn.Linear(256, 256) 90 | self.fc4 = nn.Linear(256, n_classes) 91 | 92 | def forward(self, image): 93 | h = image.view(-1, self.n_pixels) 94 | 95 | h = F.relu(self.fc1(h)) 96 | h = F.relu(self.fc2(h)) 97 | h = F.relu(self.fc3(h)) 98 | h = self.fc4(h) 99 | 100 | return h 101 | 102 | 103 | class MLPDecoder(nn.Module): 104 | def __init__( 105 | self, 106 | latent_dim=5, 107 | slen=28, 108 | n_classes=10): 109 | 110 | # This takes the latent parameters and returns the 111 | # mean and variance for the image reconstruction 112 | 113 | super(MLPDecoder, self).__init__() 114 | 115 | # image/model parameters 116 | self.n_pixels = slen ** 2 117 | self.latent_dim = latent_dim 118 | self.n_classes = n_classes 119 | self.slen = slen 120 | 121 | self.fc1 = nn.Linear(latent_dim + n_classes, 128) 122 | self.fc2 = nn.Linear(128, 128) 123 | self.fc3 = nn.Linear(128, self.n_pixels) 124 | 125 | self.sigmoid = nn.Sigmoid() 126 | 127 | def forward(self, latent_params, one_hot_label): 128 | assert latent_params.shape[1] == self.latent_dim 129 | # label should be one hot encoded 130 | assert one_hot_label.shape[1] == self.n_classes 131 | assert latent_params.shape[0] == one_hot_label.shape[0] 132 | 133 | h = torch.cat((latent_params, one_hot_label), dim=1) 134 | 135 | h = F.relu(self.fc1(h)) 136 | h = F.relu(self.fc2(h)) 137 | h = self.fc3(h) 138 
| 139 | h = h.view(-1, self.slen, self.slen) 140 | 141 | image_mean = self.sigmoid(h) 142 | 143 | return image_mean 144 | 145 | 146 | class MNISTVAE(nn.Module): 147 | 148 | def __init__(self, encoder, decoder): 149 | super(MNISTVAE, self).__init__() 150 | 151 | self.encoder = encoder 152 | self.decoder = decoder 153 | 154 | assert self.encoder.latent_dim == self.decoder.latent_dim 155 | assert self.encoder.n_classes == self.decoder.n_classes 156 | assert self.encoder.slen == self.decoder.slen 157 | 158 | # save some parameters 159 | self.latent_dim = self.encoder.latent_dim 160 | self.n_classes = self.encoder.n_classes 161 | self.slen = self.encoder.slen 162 | 163 | def get_one_hot_encoding_from_label(self, label): 164 | return get_one_hot_encoding_from_int(label, self.n_classes) 165 | 166 | def forward(self, discrete_latent_z, image): 167 | 168 | if len(discrete_latent_z.size()) != 2: 169 | one_hot_label = torch.zeros( 170 | len(discrete_latent_z), self.n_classes).to(image.device) 171 | one_hot_label.scatter_( 172 | 1, discrete_latent_z.view(-1, 1), 1) 173 | one_hot_label = one_hot_label.view( 174 | len(discrete_latent_z), self.n_classes) 175 | else: 176 | one_hot_label = discrete_latent_z 177 | 178 | assert one_hot_label.shape[0] == image.shape[0] 179 | assert one_hot_label.shape[1] == self.n_classes 180 | 181 | # pass through encoder 182 | latent_means, latent_std = self.encoder(image, one_hot_label) 183 | 184 | # sample latent dimension 185 | latent_samples = torch.randn( 186 | latent_means.shape).to(latent_means.device) * \ 187 | latent_std + latent_means 188 | 189 | assert one_hot_label.shape[0] == latent_samples.shape[0] 190 | assert one_hot_label.shape[1] == self.n_classes 191 | 192 | # pass through decoder 193 | image_mean = self.decoder(latent_samples, one_hot_label) 194 | 195 | output = { 196 | 'latent_means': latent_means, 197 | 'latent_std': latent_std, 198 | 'latent_samples': latent_samples, 199 | 'image_mean': image_mean 200 | } 201 | return output 202 | -------------------------------------------------------------------------------- /experiments/semi_supervised-vae/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets import MNIST 7 | import torchvision.transforms as transforms 8 | 9 | 10 | def load_mnist_data(data_dir, train=True): 11 | if not os.path.exists(data_dir): 12 | print('creaing folder: ', data_dir) 13 | os.mkdir(data_dir) 14 | 15 | def trans(x): 16 | return transforms.ToTensor()(x).bernoulli() 17 | 18 | data = MNIST( 19 | root=data_dir, train=train, 20 | transform=trans, download=True) 21 | 22 | return data 23 | 24 | 25 | class MNISTDataSet(Dataset): 26 | 27 | def __init__( 28 | self, 29 | data_dir, 30 | propn_sample=1.0, 31 | indices=None, 32 | train_set=True): 33 | 34 | super(MNISTDataSet, self).__init__() 35 | 36 | # Load MNIST dataset 37 | # This is the full dataset 38 | self.mnist_data_set = load_mnist_data( 39 | data_dir=data_dir, 40 | train=train_set) 41 | 42 | n_image_full = len(self.mnist_data_set.targets) 43 | 44 | # we may wish to subset 45 | if indices is None: 46 | self.num_images = round(n_image_full * propn_sample) 47 | self.sample_indx = np.random.choice( 48 | n_image_full, self.num_images, 49 | replace=False) 50 | else: 51 | self.num_images = len(indices) 52 | self.sample_indx = indices 53 | 54 | def __len__(self): 55 | return self.num_images 56 | 57 | def __getitem__(self, idx): 58 | return 
{'image': self.mnist_data_set[self.sample_indx[idx]][0].squeeze(), 59 | 'label': self.mnist_data_set[self.sample_indx[idx]][1]} 60 | 61 | 62 | class CycleConcatDataset(Dataset): 63 | '''Dataset wrapping multiple train datasets 64 | Parameters 65 | ---------- 66 | *datasets : sequence of torch.utils.data.Dataset 67 | Datasets to be concatenated and cycled 68 | ''' 69 | def __init__(self, *datasets): 70 | self.datasets = datasets 71 | 72 | def __getitem__(self, i): 73 | result = [] 74 | for dataset in self.datasets: 75 | cycled_i = i % len(dataset) 76 | result.append(dataset[cycled_i]) 77 | 78 | return tuple(result) 79 | 80 | def __len__(self): 81 | return max(len(d) for d in self.datasets) 82 | 83 | 84 | def get_mnist_dataset_semisupervised( 85 | data_dir, 86 | train_test_split_folder, 87 | eval_test_set=False, 88 | n_labeled=5000, 89 | one_of_each=False): 90 | 91 | labeled_indx = np.load(train_test_split_folder + 'labeled_train_indx.npy') 92 | unlabeled_indx = np.load(train_test_split_folder + 'unlabeled_train_indx.npy') 93 | 94 | assert (n_labeled <= labeled_indx.shape[0]) 95 | 96 | if one_of_each: 97 | train_set = MNISTDataSet( 98 | data_dir=data_dir, train_set=True) 99 | 100 | one_of_each_indx = [] 101 | for digit in range(10): 102 | digit_mask = (train_set.mnist_data_set.targets[labeled_indx] == digit) 103 | digit_indx = np.where(digit_mask)[0] 104 | indx_choice = np.random.choice(digit_indx) 105 | one_of_each_indx.append(indx_choice) 106 | one_of_each_indx = np.array(one_of_each_indx) 107 | 108 | labeled_new = labeled_indx[one_of_each_indx] 109 | labeled_rest = np.delete(labeled_indx, one_of_each_indx) 110 | 111 | labeled_indx = labeled_new 112 | unlabeled_indx = np.hstack([unlabeled_indx, labeled_rest]) 113 | 114 | elif labeled_indx.shape[0] != n_labeled: 115 | labeled_new_idxs = np.random.choice( 116 | len(labeled_indx), n_labeled, replace=False) 117 | labeled_new = labeled_indx[labeled_new_idxs] 118 | labeled_rest = np.delete(labeled_indx, labeled_new_idxs) 119 | 120 | labeled_indx = labeled_new 121 | unlabeled_indx = np.hstack([unlabeled_indx, labeled_rest]) 122 | 123 | train_set_labeled = MNISTDataSet( 124 | data_dir=data_dir, 125 | indices=labeled_indx, 126 | train_set=True) 127 | train_set_unlabeled = MNISTDataSet( 128 | data_dir=data_dir, 129 | indices=unlabeled_indx, 130 | train_set=True) 131 | 132 | if eval_test_set: 133 | # get test set as usual 134 | test_set = MNISTDataSet( 135 | data_dir=data_dir, 136 | train_set=False) 137 | else: 138 | validation_indx = np.load( 139 | train_test_split_folder + 'validation_indx.npy') 140 | test_set = MNISTDataSet( 141 | data_dir=data_dir, 142 | indices=validation_indx, 143 | train_set=True) 144 | 145 | return train_set_labeled, train_set_unlabeled, test_set 146 | -------------------------------------------------------------------------------- /experiments/semi_supervised-vae/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def populate_experiment_params( 5 | arg_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 6 | 7 | arg_parser.add_argument("--labeled_only", action="store_true") 8 | arg_parser.add_argument('--warm_start_path', type=str, default='', 9 | help='Path for warm start') 10 | 11 | return arg_parser 12 | -------------------------------------------------------------------------------- /experiments/semi_supervised-vae/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
pytorch_lightning as pl 3 | 4 | from train import SSVAE 5 | 6 | 7 | def main(checkpoint_path, hparams_path): 8 | 9 | model = SSVAE.load_from_checkpoint( 10 | checkpoint_path=checkpoint_path, 11 | hparams_file=hparams_path, 12 | map_location=None) 13 | 14 | trainer = pl.Trainer( 15 | progress_bar_refresh_rate=1, 16 | weights_summary='full', 17 | gpus=1 if torch.cuda.is_available() else 0, 18 | deterministic=True) 19 | 20 | # test (pass in the model) 21 | trainer.test(model) 22 | 23 | 24 | if __name__ == '__main__': 25 | import sys 26 | main(sys.argv[1], sys.argv[2]) 27 | -------------------------------------------------------------------------------- /experiments/semi_supervised-vae/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pytorch_lightning as pl 4 | from pytorch_lightning import loggers as pl_loggers 5 | 6 | import torch 7 | from torch.nn import CrossEntropyLoss 8 | 9 | from entmax import SparsemaxLoss, Entmax15Loss 10 | 11 | from lvmhelpers.marg import \ 12 | ExplicitWrapper, Marginalizer 13 | from lvmhelpers.sum_and_sample import \ 14 | SumAndSampleWrapper, SumAndSample 15 | from lvmhelpers.sfe import \ 16 | SFEWrapper, SFEDeterministicWrapper, ScoreFunctionEstimator 17 | from lvmhelpers.nvil import \ 18 | NVILWrapper, NVIL 19 | from lvmhelpers.gumbel import \ 20 | GumbelSoftmaxWrapper, Gumbel 21 | from lvmhelpers.utils import DeterministicWrapper, populate_common_params 22 | 23 | from data import get_mnist_dataset_semisupervised, CycleConcatDataset 24 | from archs import MLPEncoder, MLPDecoder, Classifier, MNISTVAE 25 | from opts import populate_experiment_params 26 | 27 | 28 | class SSVAE(pl.LightningModule): 29 | def __init__( 30 | self, 31 | latent_dim, 32 | slen, 33 | n_classes, 34 | labeled_only, 35 | mode, 36 | entropy_coeff, 37 | vocab_size, 38 | normalizer, 39 | gs_tau, 40 | temperature_decay, 41 | temperature_update_freq, 42 | straight_through, 43 | baseline_type, 44 | topk, 45 | random_seed, 46 | batch_size, 47 | lr, 48 | weight_decay, 49 | optimizer): 50 | super(SSVAE, self).__init__() 51 | 52 | self.save_hyperparameters() 53 | 54 | inference_net = MLPEncoder( 55 | latent_dim=self.hparams.latent_dim, 56 | slen=self.hparams.slen, 57 | n_classes=self.hparams.n_classes) 58 | 59 | generative_net = MLPDecoder( 60 | latent_dim=self.hparams.latent_dim, 61 | slen=self.hparams.slen, 62 | n_classes=self.hparams.n_classes) 63 | 64 | gaussian_vae = MNISTVAE( 65 | inference_net, 66 | generative_net) 67 | 68 | classifier_net = Classifier( 69 | slen=self.hparams.slen, 70 | n_classes=self.hparams.n_classes) 71 | 72 | loss_fun = get_unsupervised_loss 73 | 74 | if self.hparams.mode == 'sfe': 75 | classifier_net = SFEWrapper( 76 | classifier_net, 77 | baseline_type=self.hparams.baseline_type) 78 | gaussian_vae = SFEDeterministicWrapper(gaussian_vae) 79 | lvm_method = ScoreFunctionEstimator 80 | elif self.hparams.mode == 'nvil': 81 | classifier_net = NVILWrapper(classifier_net, input_size=slen**2) 82 | gaussian_vae = DeterministicWrapper(gaussian_vae) 83 | lvm_method = NVIL 84 | elif self.hparams.mode == 'gs': 85 | classifier_net = GumbelSoftmaxWrapper( 86 | classifier_net, 87 | temperature=self.hparams.gs_tau, 88 | straight_through=self.hparams.straight_through) 89 | gaussian_vae = DeterministicWrapper(gaussian_vae) 90 | lvm_method = Gumbel 91 | elif self.hparams.mode == 'marg': 92 | classifier_net = ExplicitWrapper( 93 | classifier_net, normalizer=self.hparams.normalizer) 94 | gaussian_vae = 
DeterministicWrapper(gaussian_vae) 95 | lvm_method = Marginalizer 96 | elif self.hparams.mode == 'sumsample': 97 | classifier_net = SumAndSampleWrapper( 98 | classifier_net, 99 | topk=self.hparams.topk, 100 | baseline_type=self.hparams.baseline_type) 101 | gaussian_vae = DeterministicWrapper(gaussian_vae) 102 | lvm_method = SumAndSample 103 | else: 104 | raise RuntimeError(f"Unknown training mode: {self.hparams.mode}") 105 | self.lvm_method = lvm_method( 106 | classifier_net, 107 | gaussian_vae, 108 | loss_fun, 109 | encoder_entropy_coeff=1.0) 110 | 111 | def forward(self, classifier_input, vae_input, labels): 112 | return self.lvm_method(classifier_input, vae_input, labels) 113 | 114 | def training_step(self, batch, batch_nb): 115 | if not self.hparams.labeled_only: 116 | labeled_batch, unlabeled_batch = batch 117 | labeled_batch_image = labeled_batch['image'] 118 | labeled_batch_labels = labeled_batch['label'] 119 | unlabeled_batch_image = unlabeled_batch['image'] 120 | unlabeled_batch_labels = unlabeled_batch['label'] 121 | else: 122 | labeled_batch_image = batch['image'] 123 | labeled_batch_labels = batch['label'] 124 | 125 | vae = self.lvm_method.decoder 126 | if hasattr(vae, 'agent'): 127 | vae = vae.agent 128 | classifier = self.lvm_method.encoder 129 | 130 | supervised_loss = get_supervised_loss( 131 | vae, 132 | classifier, 133 | labeled_batch_image, 134 | labeled_batch_labels, 135 | self.hparams.normalizer) 136 | 137 | if not self.hparams.labeled_only: 138 | unsupervised_output = self( 139 | unlabeled_batch_image, 140 | unlabeled_batch_image, 141 | unlabeled_batch_labels) 142 | unsupervised_loss = unsupervised_output['loss'] 143 | 144 | loss = \ 145 | supervised_loss + \ 146 | unsupervised_loss * (self.num_unlabeled / self.num_labeled) 147 | else: 148 | loss = supervised_loss 149 | 150 | result = pl.TrainResult(minimize=loss) 151 | if not self.hparams.labeled_only: 152 | result.log('train_elbo', unsupervised_output['log']['loss'], prog_bar=True) 153 | result.log('train_acc', unsupervised_output['log']['acc'], prog_bar=True) 154 | 155 | if 'support' in unsupervised_output['log'].keys(): 156 | result.log( 157 | 'train_support', 158 | unsupervised_output['log']['support'], 159 | prog_bar=True) 160 | 161 | # Update temperature if Gumbel 162 | if self.hparams.mode == 'gs': 163 | self.lvm_method.encoder.update_temperature( 164 | self.global_step, 165 | self.hparams.temperature_update_freq, 166 | self.hparams.temperature_decay) 167 | result.log('temperature', self.lvm_method.encoder.temperature) 168 | 169 | return result 170 | 171 | def validation_step(self, batch, batch_nb): 172 | image = batch['image'] 173 | true_labels = batch['label'] 174 | validation_result = self(image, image, true_labels) 175 | result = pl.EvalResult(checkpoint_on=validation_result['log']['loss']) 176 | result.log('val_elbo', validation_result['log']['loss'], prog_bar=True) 177 | result.log('val_acc', validation_result['log']['acc'], prog_bar=True) 178 | 179 | if 'support' in validation_result['log'].keys(): 180 | result.log( 181 | 'val_support', 182 | validation_result['log']['support'], 183 | prog_bar=True) 184 | return result 185 | 186 | def test_step(self, batch, batch_nb): 187 | image = batch['image'] 188 | true_labels = batch['label'] 189 | test_result = self(image, image, true_labels) 190 | result = pl.EvalResult() 191 | result.log('test_elbo', test_result['log']['loss']) 192 | result.log('test_acc', test_result['log']['acc']) 193 | 194 | if 'support' in test_result['log'].keys(): 195 | result.log( 196 | 
'test_support', 197 | test_result['log']['support'], 198 | prog_bar=True) 199 | return result 200 | 201 | def configure_optimizers(self): 202 | return torch.optim.Adam( 203 | self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) 204 | 205 | def train_dataloader(self): 206 | train_labeled, train_unlabeled, _ = get_mnist_dataset_semisupervised( 207 | data_dir='data/semi_supervised-vae/mnist_data/', 208 | train_test_split_folder='data/semi_supervised-vae/test_train_splits/', 209 | n_labeled=5000, 210 | one_of_each=False) 211 | 212 | self.num_labeled = len(train_labeled) 213 | self.num_unlabeled = len(train_unlabeled) 214 | 215 | if not self.hparams.labeled_only: 216 | concat_dataset = CycleConcatDataset( 217 | train_labeled, 218 | train_unlabeled 219 | ) 220 | loader = torch.utils.data.DataLoader( 221 | concat_dataset, 222 | batch_size=self.hparams.batch_size, 223 | shuffle=True, 224 | num_workers=4, 225 | pin_memory=True 226 | ) 227 | else: 228 | self.num_labeled = len(train_labeled) 229 | self.num_unlabeled = len(train_labeled) 230 | loader = torch.utils.data.DataLoader( 231 | train_labeled, 232 | batch_size=self.hparams.batch_size, 233 | shuffle=True, 234 | num_workers=4, 235 | pin_memory=True 236 | ) 237 | 238 | return loader 239 | 240 | def val_dataloader(self): 241 | _, _, valid_set = get_mnist_dataset_semisupervised( 242 | data_dir='data/semi_supervised-vae/mnist_data/', 243 | train_test_split_folder='data/semi_supervised-vae/test_train_splits/', 244 | eval_test_set=False) 245 | return torch.utils.data.DataLoader( 246 | dataset=valid_set, 247 | batch_size=self.hparams.batch_size, 248 | shuffle=False, 249 | num_workers=4, 250 | pin_memory=True) 251 | 252 | def test_dataloader(self): 253 | _, _, test_set = get_mnist_dataset_semisupervised( 254 | data_dir='data/semi_supervised-vae/mnist_data/', 255 | train_test_split_folder='data/semi_supervised-vae/test_train_splits/', 256 | eval_test_set=True) 257 | return torch.utils.data.DataLoader( 258 | dataset=test_set, 259 | batch_size=self.hparams.batch_size, 260 | shuffle=False, 261 | num_workers=4, 262 | pin_memory=True) 263 | 264 | 265 | def get_reconstruction_loss(x_reconstructed, x): 266 | batch_size = x.shape[0] 267 | 268 | bce_loss = -x * torch.log(x_reconstructed + 1e-8) - \ 269 | (1 - x) * torch.log(1 - x_reconstructed + 1e-8) 270 | 271 | return bce_loss.view(batch_size, -1).sum(dim=1) 272 | 273 | 274 | def get_kl_divergence_loss(mean, logvar): 275 | batch_size = mean.shape[0] 276 | return (( 277 | mean**2 + logvar.exp() - 1 - logvar 278 | ) / 2).view(batch_size, -1).sum(dim=1) 279 | 280 | 281 | def get_elbo_loss(image, vae_output): 282 | latent_means = vae_output['latent_means'] 283 | latent_std = vae_output['latent_std'] 284 | image_mean = vae_output['image_mean'] 285 | reconstruction_loss = get_reconstruction_loss( 286 | image_mean, image) 287 | kl_divergence_loss = get_kl_divergence_loss( 288 | latent_means, 2 * torch.log(latent_std)) 289 | return reconstruction_loss + kl_divergence_loss 290 | 291 | 292 | def get_unsupervised_loss( 293 | _classifier_input, 294 | discrete_latent_z, 295 | _vae_input, 296 | vae_output, 297 | true_labels): 298 | vae_loss = get_elbo_loss(_classifier_input, vae_output) 299 | # classifier accuracy (for logging) 300 | if len(discrete_latent_z.size()) == 2: 301 | discrete_latent_z = discrete_latent_z.argmax(dim=-1) 302 | acc = (discrete_latent_z == true_labels).float() 303 | return vae_loss, {'acc': acc} 304 | 305 | 306 | def get_supervised_loss( 307 | vae, 308 | classifier, 309 | 
labeled_image, 310 | true_labels, 311 | normalizer): 312 | if normalizer == 'softmax': 313 | loss = CrossEntropyLoss(reduction='none') 314 | elif normalizer == 'entmax15': 315 | loss = Entmax15Loss(reduction='none') 316 | elif normalizer == 'sparsemax': 317 | loss = SparsemaxLoss(reduction='none') 318 | else: 319 | raise NameError("%s is not a valid normalizer!" % (normalizer, )) 320 | # get loss on a batch of labeled images 321 | vae_output = vae(true_labels, labeled_image) 322 | labeled_loss = get_elbo_loss(labeled_image, vae_output) 323 | # cross entropy term 324 | logits = classifier.agent.forward(labeled_image) 325 | cross_entropy = loss(logits, true_labels) 326 | 327 | return (labeled_loss + cross_entropy).mean() 328 | 329 | 330 | def get_model(opt): 331 | 332 | model = SSVAE( 333 | latent_dim=8, 334 | slen=28, 335 | n_classes=10, 336 | labeled_only=opt.labeled_only, 337 | mode=opt.mode, 338 | entropy_coeff=opt.entropy_coeff, 339 | vocab_size=opt.latent_size, 340 | normalizer=opt.normalizer, 341 | gs_tau=opt.gs_tau, 342 | temperature_decay=opt.temperature_decay, 343 | temperature_update_freq=opt.temperature_update_freq, 344 | straight_through=opt.straight_through, 345 | baseline_type=opt.baseline_type, 346 | topk=opt.topk, 347 | random_seed=opt.random_seed, 348 | batch_size=opt.batch_size, 349 | lr=opt.lr, 350 | weight_decay=opt.weight_decay, 351 | optimizer=opt.optimizer) 352 | 353 | if len(opt.warm_start_path) != 0: 354 | model = model.load_from_checkpoint( 355 | opt.warm_start_path, 356 | latent_dim=8, 357 | slen=28, 358 | n_classes=10, 359 | labeled_only=opt.labeled_only, 360 | mode=opt.mode, 361 | entropy_coeff=opt.entropy_coeff, 362 | vocab_size=opt.latent_size, 363 | normalizer=opt.normalizer, 364 | gs_tau=opt.gs_tau, 365 | temperature_decay=opt.temperature_decay, 366 | temperature_update_freq=opt.temperature_update_freq, 367 | straight_through=opt.straight_through, 368 | baseline_type=opt.baseline_type, 369 | topk=opt.topk, 370 | random_seed=opt.random_seed, 371 | batch_size=opt.batch_size, 372 | lr=opt.lr, 373 | weight_decay=opt.weight_decay, 374 | optimizer=opt.optimizer, 375 | strict=False) 376 | 377 | return model 378 | 379 | 380 | def main(params): 381 | 382 | arg_parser = argparse.ArgumentParser() 383 | arg_parser = populate_experiment_params(arg_parser) 384 | arg_parser = populate_common_params(arg_parser) 385 | opts = arg_parser.parse_args(params) 386 | 387 | # fix seed 388 | pl.seed_everything(opts.random_seed) 389 | 390 | model = get_model(opts) 391 | 392 | experiment_name = 'ssvae' 393 | if not opts.labeled_only: 394 | model_name = '%s/%s' % (experiment_name, opts.mode) 395 | else: 396 | model_name = '%s/warm_start/%s' % (experiment_name, opts.normalizer) 397 | other_info = [ 398 | "lr-{}".format(opts.lr), 399 | ] 400 | 401 | if opts.mode == "gs": 402 | if opts.straight_through: 403 | other_info.append("straight_through") 404 | other_info.append("decay-{}".format(opts.temperature_decay)) 405 | other_info.append("updatefreq-{}".format(opts.temperature_update_freq)) 406 | elif opts.mode == 'sfe': 407 | other_info.append("baseline-{}".format(opts.baseline_type)) 408 | elif opts.mode == "marg": 409 | other_info.append("norm-{}".format(opts.normalizer)) 410 | elif opts.mode == 'sumsample': 411 | other_info.append("k-{}".format(opts.topk)) 412 | other_info.append("baseline-{}".format(opts.baseline_type)) 413 | 414 | model_name = '%s/%s' % (model_name, '_'.join(other_info)) 415 | 416 | tb_logger = pl_loggers.TensorBoardLogger( 417 | 'logs/', 418 | name=model_name) 419 | 
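    # Editor's note (based on PyTorch Lightning's usual logger layout, an assumption rather
    # than code in this repo): the logger above typically writes its event files and an
    # hparams.yaml under logs/<model_name>/version_<k>/; that hparams.yaml, together with a
    # .ckpt from checkpoints/, are the two arguments expected by
    # experiments/semi_supervised-vae/test.py.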
420 | tb_logger.log_hyperparams(opts, metrics=None) 421 | 422 | trainer = pl.Trainer( 423 | progress_bar_refresh_rate=20, 424 | logger=tb_logger, 425 | max_epochs=opts.n_epochs, 426 | weights_save_path='checkpoints/', 427 | weights_summary='full', 428 | gpus=1 if torch.cuda.is_available() else 0, 429 | resume_from_checkpoint=opts.load_from_checkpoint, 430 | deterministic=True) 431 | trainer.fit(model) 432 | 433 | 434 | if __name__ == '__main__': 435 | import sys 436 | main(sys.argv[1:]) 437 | -------------------------------------------------------------------------------- /experiments/signal-game/archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Sender(nn.Module): 7 | def __init__(self, game_size, feat_size, embedding_size, hidden_size, 8 | vocab_size=100, temp=1.): 9 | super(Sender, self).__init__() 10 | self.game_size = game_size 11 | self.embedding_size = embedding_size 12 | self.hidden_size = hidden_size 13 | self.vocab_size = vocab_size 14 | self.temp = temp 15 | 16 | self.lin1 = nn.Linear(feat_size, embedding_size, bias=False) 17 | self.conv2 = nn.Conv2d(1, hidden_size, 18 | kernel_size=(1, 1), 19 | stride=(1, 1), bias=False) 20 | self.conv3 = nn.Conv2d(1, 1, 21 | kernel_size=(hidden_size, 1), 22 | stride=(hidden_size, 1), bias=False) 23 | self.lin4 = nn.Linear(embedding_size, vocab_size, bias=False) 24 | 25 | def forward(self, x, return_embeddings=False): 26 | emb = self.return_embeddings(x) 27 | 28 | # in: h of size (batch_size, 1, 1, embedding_size) 29 | # out: h of size (batch_size, hidden_size, 1, embedding_size) 30 | h = self.conv2(emb) 31 | h = torch.sigmoid(h) 32 | # in: h of size (batch_size, hidden_size, 1, embedding_size) 33 | # out: h of size (batch_size, 1, hidden_size, embedding_size) 34 | h = h.transpose(1, 2) 35 | h = self.conv3(h) 36 | # h of size (batch_size, 1, 1, embedding_size) 37 | h = torch.sigmoid(h) 38 | h = h.squeeze(dim=1) 39 | h = h.squeeze(dim=1) 40 | # h of size (batch_size, embedding_size) 41 | h = self.lin4(h) 42 | h = h.mul(1./self.temp) 43 | 44 | return h 45 | 46 | def return_embeddings(self, x): 47 | # sender only sees a single image---the one receiver needs to pick 48 | h = x[0] 49 | if len(h.size()) == 3: 50 | h = h.squeeze(dim=-1) 51 | h_i = self.lin1(h) 52 | # h_i are batch_size x embedding_size 53 | h_i = h_i.unsqueeze(dim=1) 54 | h_i = h_i.unsqueeze(dim=1) 55 | # h_i are now batch_size x 1 x 1 x embedding_size 56 | return h_i 57 | 58 | 59 | class Receiver(nn.Module): 60 | def __init__(self, game_size, feat_size, embedding_size, 61 | vocab_size, sfe): 62 | super(Receiver, self).__init__() 63 | self.game_size = game_size 64 | self.embedding_size = embedding_size 65 | 66 | self.lin1 = nn.Linear(feat_size, embedding_size, bias=False) 67 | if sfe: 68 | self.lin2 = nn.Embedding(vocab_size, embedding_size) 69 | else: 70 | self.lin2 = nn.Linear(vocab_size, embedding_size, bias=False) 71 | 72 | def forward(self, signal, x): 73 | # embed each image (left or right) 74 | emb = self.return_embeddings(x) 75 | # embed the signal 76 | if len(signal.size()) == 3: 77 | signal = signal.squeeze(dim=-1) 78 | h_s = self.lin2(signal) 79 | # h_s is of size batch_size x embedding_size 80 | h_s = h_s.unsqueeze(dim=1) 81 | # h_s is of size batch_size x 1 x embedding_size 82 | h_s = h_s.transpose(1, 2) 83 | # h_s is of size batch_size x embedding_size x 1 84 | out = torch.bmm(emb, h_s) 85 | # out is of size batch_size x game_size x 1 
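        # Editor's note: the bmm above is a batch of dot products between each of the
        # game_size candidate image embeddings and the message embedding, yielding one
        # score per candidate; the log_softmax below turns these scores into the
        # receiver's distribution over which image the sender described.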
86 | out = out.squeeze(dim=-1) 87 | # out is of size batch_size x game_size 88 | log_probs = F.log_softmax(out, dim=1) 89 | return log_probs 90 | 91 | def return_embeddings(self, x): 92 | # embed each image (left or right) 93 | embs = [] 94 | # receiver sees game_size images; only one is the one sender saw 95 | for i in range(self.game_size): 96 | h = x[i] 97 | if len(h.size()) == 3: 98 | h = h.squeeze(dim=-1) 99 | h_i = self.lin1(h) 100 | # h_i are batch_size x embedding_size 101 | h_i = h_i.unsqueeze(dim=1) 102 | # h_i are now batch_size x 1 x embedding_size 103 | embs.append(h_i) 104 | h = torch.cat(embs, dim=1) 105 | return h 106 | -------------------------------------------------------------------------------- /experiments/signal-game/data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch.utils.data as data 3 | import torch.nn.parallel 4 | import os 5 | import torch 6 | import numpy as np 7 | 8 | 9 | class _BatchIterator: 10 | def __init__(self, loader, seed=None): 11 | self.loader = loader 12 | self.random_state = np.random.RandomState(seed) 13 | 14 | def __iter__(self): 15 | return self 16 | 17 | def __next__(self): 18 | batch_data = self.get_batch() 19 | return batch_data 20 | 21 | def get_batch(self): 22 | loader = self.loader 23 | bsz = loader.bsz 24 | game_size = loader.game_size 25 | same = loader.same 26 | 27 | C = len(self.loader.dataset.obj2id.keys()) # number of concepts 28 | images_indexes_sender = np.zeros((bsz, game_size)) 29 | 30 | for b in range(bsz): 31 | if same: 32 | # randomly sample a concept 33 | concepts = self.random_state.choice(C, 1) 34 | c = concepts[0] 35 | ims = loader.dataset.obj2id[c]["ims"] 36 | idxs_sender = self.random_state.choice( 37 | ims, game_size, replace=False) 38 | images_indexes_sender[b, :] = idxs_sender 39 | else: 40 | idxs_sender = [] 41 | # randomly sample k concepts 42 | concepts = self.random_state.choice( 43 | C, game_size, replace=False) 44 | for i, c in enumerate(concepts): 45 | ims = loader.dataset.obj2id[c]["ims"] 46 | idx = self.random_state.choice(ims, 2, replace=False) 47 | idxs_sender.append(idx[0]) 48 | 49 | images_indexes_sender[b, :] = np.array(idxs_sender) 50 | 51 | images_vectors_sender = [] 52 | 53 | for i in range(game_size): 54 | x, _ = loader.dataset[images_indexes_sender[:, i]] 55 | images_vectors_sender.append(x) 56 | 57 | images_vectors_sender = torch.stack(images_vectors_sender).contiguous() 58 | y = torch.zeros(bsz).long() 59 | 60 | images_vectors_receiver = torch.zeros_like(images_vectors_sender) 61 | for i in range(bsz): 62 | permutation = torch.randperm(game_size) 63 | 64 | images_vectors_receiver[:, i, 65 | :] = images_vectors_sender[permutation, i, :] 66 | y[i] = permutation.argmin() 67 | return images_vectors_sender, images_vectors_receiver, y 68 | 69 | 70 | class ImagenetLoader(torch.utils.data.DataLoader): 71 | def __init__(self, *args, **kwargs): 72 | self.seed = kwargs.pop('seed') 73 | self.bsz = kwargs.pop('batch_size') 74 | self.game_size = kwargs.pop('game_size') 75 | self.same = kwargs.pop('same') 76 | 77 | super(ImagenetLoader, self).__init__(*args, **kwargs) 78 | 79 | def __iter__(self): 80 | if self.seed is None: 81 | seed = np.random.randint(0, 2 ** 32) 82 | else: 83 | seed = self.seed 84 | return _BatchIterator(self, seed=seed) 85 | 86 | 87 | class ImageNetFeat(data.Dataset): 88 | def __init__(self, root, train=True): 89 | import h5py 90 | 91 | self.root = os.path.expanduser(root) 92 | self.train = train # training set or test 
set 93 | 94 | # FC features 95 | fc_file = os.path.join(root, 'ours_images_single_sm0.h5') 96 | 97 | fc = h5py.File(fc_file, 'r') 98 | # There should be only 1 key 99 | key = list(fc.keys())[0] 100 | # Get the data 101 | data = torch.FloatTensor(list(fc[key])) 102 | 103 | # normalise data 104 | img_norm = torch.norm(data, p=2, dim=1, keepdim=True) 105 | normed_data = data / img_norm 106 | 107 | objects_file = os.path.join(root, 108 | 'ours_images_single_sm0.objects') 109 | with open(objects_file, "rb") as f: 110 | labels = pickle.load(f) 111 | objects_file = os.path.join(root, 112 | 'ours_images_paths_sm0.objects') 113 | with open(objects_file, "rb") as f: 114 | paths = pickle.load(f) 115 | 116 | self.create_obj2id(labels) 117 | self.data_tensor = normed_data 118 | self.labels = labels 119 | self.paths = paths 120 | 121 | def __getitem__(self, index): 122 | return self.data_tensor[index], index 123 | 124 | def __len__(self): 125 | return self.data_tensor.size(0) 126 | 127 | def create_obj2id(self, labels): 128 | self.obj2id = {} 129 | keys = {} 130 | idx_label = -1 131 | for i in range(labels.shape[0]): 132 | if not labels[i] in keys.keys(): 133 | idx_label += 1 134 | keys[labels[i]] = idx_label 135 | self.obj2id[idx_label] = {} 136 | self.obj2id[idx_label]['labels'] = labels[i] 137 | self.obj2id[idx_label]['ims'] = [] 138 | self.obj2id[idx_label]['ims'].append(i) 139 | -------------------------------------------------------------------------------- /experiments/signal-game/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def populate_experiment_params( 5 | arg_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 6 | 7 | arg_parser.add_argument('--root', default='data/signal-game/', 8 | help='data root folder') 9 | # 2-agents specific parameters 10 | arg_parser.add_argument('--tau_s', type=float, default=10.0, 11 | help='Sender Gibbs temperature') 12 | arg_parser.add_argument('--game_size', type=int, default=2, 13 | help='Number of images seen by an agent') 14 | arg_parser.add_argument('--same', type=int, default=0, 15 | help='Use same concepts') 16 | arg_parser.add_argument('--embedding_size', type=int, default=50, 17 | help='embedding size') 18 | arg_parser.add_argument('--hidden_size', type=int, default=20, 19 | help='hidden size (number of filters informed sender)') 20 | arg_parser.add_argument('--batches_per_epoch', type=int, default=100, 21 | help='Batches in a single training/validation epoch') 22 | 23 | arg_parser.add_argument('--loss_type', type=str, default='nll', 24 | help='acc or nll') 25 | 26 | return arg_parser 27 | -------------------------------------------------------------------------------- /experiments/signal-game/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from train import SignalGame 5 | 6 | 7 | def main(checkpoint_path, hparams_path): 8 | 9 | model = SignalGame.load_from_checkpoint( 10 | checkpoint_path=checkpoint_path, 11 | hparams_file=hparams_path, 12 | map_location=None) 13 | 14 | trainer = pl.Trainer( 15 | progress_bar_refresh_rate=1, 16 | weights_summary='full', 17 | limit_test_batches=10000//model.hparams.batch_size, 18 | gpus=1 if torch.cuda.is_available() else 0, 19 | deterministic=True) 20 | 21 | # test (pass in the model) 22 | trainer.test(model) 23 | 24 | 25 | if __name__ == '__main__': 26 | import sys 27 | main(sys.argv[1], sys.argv[2]) 28 | 
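# Editor's note -- usage sketch with hypothetical paths (adjust to your own run and
# version names; they are not produced verbatim by this repo's scripts):
#
#   python experiments/signal-game/test.py \
#       checkpoints/signal-game/<run>/<checkpoint>.ckpt \
#       logs/signal-game/<run>/version_0/hparams.yaml
#
# The checkpoint comes from Trainer(weights_save_path='checkpoints/') and the hparams
# file from the TensorBoard logger configured in train.py below.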
-------------------------------------------------------------------------------- /experiments/signal-game/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import loggers as pl_loggers 8 | 9 | from lvmhelpers.marg import \ 10 | ExplicitWrapper, Marginalizer 11 | from lvmhelpers.sum_and_sample import \ 12 | SumAndSampleWrapper, SumAndSample 13 | from lvmhelpers.sfe import \ 14 | SFEWrapper, SFEDeterministicWrapper, ScoreFunctionEstimator 15 | from lvmhelpers.nvil import \ 16 | NVILWrapper, NVIL 17 | from lvmhelpers.gumbel import \ 18 | GumbelSoftmaxWrapper, Gumbel 19 | from lvmhelpers.utils import DeterministicWrapper, populate_common_params 20 | 21 | from data import ImageNetFeat, ImagenetLoader 22 | from archs import Sender, Receiver 23 | from opts import populate_experiment_params 24 | 25 | 26 | class CheckpointEveryNSteps(pl.Callback): 27 | """ 28 | Save a checkpoint every N steps, instead of Lightning's default that checkpoints 29 | based on validation loss. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | save_step_frequency, 35 | prefix="N-Step-Checkpoint", 36 | use_modelcheckpoint_filename=False): 37 | """ 38 | Args: 39 | save_step_frequency: how often to save in steps 40 | prefix: add a prefix to the name, only used if 41 | use_modelcheckpoint_filename=False 42 | use_modelcheckpoint_filename: just use the ModelCheckpoint callback's 43 | default filename, don't use ours. 44 | """ 45 | self.save_step_frequency = save_step_frequency 46 | self.prefix = prefix 47 | self.use_modelcheckpoint_filename = use_modelcheckpoint_filename 48 | 49 | def on_batch_end(self, trainer: pl.Trainer, _): 50 | """ Check if we should save a checkpoint after every train batch """ 51 | epoch = trainer.current_epoch 52 | global_step = trainer.global_step 53 | if global_step % self.save_step_frequency == 0: 54 | if self.use_modelcheckpoint_filename: 55 | filename = trainer.checkpoint_callback.filename 56 | else: 57 | filename = f"{self.prefix}_epoch={epoch}_global_step={global_step}.ckpt" 58 | ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename) 59 | trainer.save_checkpoint(ckpt_path) 60 | 61 | 62 | class SignalGame(pl.LightningModule): 63 | def __init__( 64 | self, 65 | feat_size, 66 | embedding_size, 67 | hidden_size, 68 | game_size, 69 | tau_s, 70 | loss_type, 71 | root, 72 | same, 73 | mode, 74 | entropy_coeff, 75 | vocab_size, 76 | normalizer, 77 | gs_tau, 78 | temperature_decay, 79 | temperature_update_freq, 80 | straight_through, 81 | baseline_type, 82 | topk, 83 | random_seed, 84 | batch_size, 85 | lr, 86 | weight_decay, 87 | optimizer): 88 | super(SignalGame, self).__init__() 89 | 90 | self.save_hyperparameters() 91 | 92 | sender = Sender( 93 | self.hparams.game_size, 94 | self.hparams.feat_size, 95 | self.hparams.embedding_size, 96 | self.hparams.hidden_size, 97 | self.hparams.vocab_size, 98 | temp=self.hparams.tau_s) 99 | 100 | receiver = Receiver( 101 | self.hparams.game_size, 102 | self.hparams.feat_size, 103 | self.hparams.embedding_size, 104 | self.hparams.vocab_size, 105 | sfe=( 106 | self.hparams.mode == 'sfe' or 107 | self.hparams.mode == 'marg' or 108 | self.hparams.mode == 'sumsample' or 109 | self.hparams.mode == 'nvil')) 110 | 111 | loss_fun = loss_nll 112 | 113 | if self.hparams.mode == 'sfe': 114 | sender = SFEWrapper(sender, baseline_type=self.hparams.baseline_type) 115 | if 
self.hparams.loss_type == 'acc': 116 | loss_fun = loss_acc 117 | receiver = SFEWrapper( 118 | receiver, baseline_type=self.hparams.baseline_type) 119 | else: 120 | receiver = SFEDeterministicWrapper(receiver) 121 | lvm_method = ScoreFunctionEstimator 122 | elif self.hparams.mode == 'nvil': 123 | sender = NVILWrapper(sender, input_size=feat_size) 124 | receiver = DeterministicWrapper(receiver) 125 | lvm_method = NVIL 126 | elif self.hparams.mode == 'gs': 127 | sender = GumbelSoftmaxWrapper( 128 | sender, 129 | temperature=self.hparams.gs_tau, 130 | straight_through=self.hparams.straight_through) 131 | receiver = DeterministicWrapper(receiver) 132 | lvm_method = Gumbel 133 | elif self.hparams.mode == 'marg': 134 | sender = ExplicitWrapper(sender, normalizer=self.hparams.normalizer) 135 | receiver = DeterministicWrapper(receiver) 136 | lvm_method = Marginalizer 137 | elif self.hparams.mode == 'sumsample': 138 | sender = SumAndSampleWrapper( 139 | sender, topk=self.hparams.topk, baseline_type=self.hparams.baseline_type) 140 | receiver = DeterministicWrapper(receiver) 141 | lvm_method = SumAndSample 142 | else: 143 | raise RuntimeError(f"Unknown training mode: {self.hparams.mode}") 144 | 145 | self.lvm_method = lvm_method( 146 | sender, 147 | receiver, 148 | loss_fun, 149 | encoder_entropy_coeff=self.hparams.entropy_coeff, 150 | decoder_entropy_coeff=self.hparams.entropy_coeff) 151 | 152 | def forward(self, sender_input, receiver_input, labels): 153 | return self.lvm_method(sender_input, receiver_input, labels) 154 | 155 | def training_step(self, batch, batch_nb): 156 | sender_input, receiver_input, labels = batch 157 | training_result = self(sender_input, receiver_input, labels) 158 | loss = training_result['loss'] 159 | 160 | result = pl.TrainResult(minimize=loss) 161 | result.log('train_loss', training_result['log']['loss'], prog_bar=True) 162 | result.log('train_acc', training_result['log']['acc'], prog_bar=True) 163 | 164 | if 'support' in training_result['log'].keys(): 165 | result.log( 166 | 'train_support', 167 | training_result['log']['support'], 168 | prog_bar=True) 169 | 170 | # Update temperature if Gumbel 171 | if self.hparams.mode == 'gs': 172 | self.lvm_method.encoder.update_temperature( 173 | self.global_step, 174 | self.hparams.temperature_update_freq, 175 | self.hparams.temperature_decay) 176 | result.log('temperature', self.lvm_method.encoder.temperature) 177 | 178 | return result 179 | 180 | def validation_step(self, batch, batch_nb): 181 | sender_input, receiver_input, labels = batch 182 | validation_result = self(sender_input, receiver_input, labels) 183 | result = pl.EvalResult(checkpoint_on=validation_result['log']['loss']) 184 | result.log('val_loss', validation_result['log']['loss'], prog_bar=True) 185 | result.log('val_acc', validation_result['log']['acc'], prog_bar=True) 186 | 187 | if 'support' in validation_result['log'].keys(): 188 | result.log( 189 | 'val_support', 190 | validation_result['log']['support'], 191 | prog_bar=True) 192 | return result 193 | 194 | def test_step(self, batch, batch_nb): 195 | sender_input, receiver_input, labels = batch 196 | test_result = self(sender_input, receiver_input, labels) 197 | result = pl.EvalResult() 198 | result.log('test_loss', test_result['log']['loss']) 199 | result.log('test_acc', test_result['log']['acc']) 200 | 201 | argmax_sample = test_result['log']['distr'].argmax(dim=-1) 202 | z_one_hot = \ 203 | torch.zeros( 204 | len(argmax_sample), 205 | test_result['log']['distr'].size(-1)).to(argmax_sample.device) 206 | 
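        # Editor's note: scatter_ below writes a 1 at each argmax index, turning the
        # batch of message indices into one-hot rows; summing them over the batch lets
        # self.usage accumulate how often each vocabulary symbol is produced at test time.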
z_one_hot.scatter_(1, argmax_sample.view(-1, 1), 1) 207 | z_one_hot = z_one_hot.view( 208 | len(argmax_sample), test_result['log']['distr'].size(-1)) 209 | 210 | if not hasattr(self, 'usage'): 211 | self.usage = z_one_hot.sum(dim=0).cpu().numpy() 212 | else: 213 | self.usage += z_one_hot.sum(dim=0).cpu().numpy() 214 | 215 | if 'support' in test_result['log'].keys(): 216 | result.log( 217 | 'test_support', 218 | test_result['log']['support'], 219 | prog_bar=True) 220 | return result 221 | 222 | def configure_optimizers(self): 223 | return torch.optim.Adam( 224 | self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) 225 | 226 | def train_dataloader(self): 227 | data_folder = os.path.join(self.hparams.root, "train/") 228 | dataset = ImageNetFeat(root=data_folder) 229 | return ImagenetLoader( 230 | dataset, 231 | batch_size=self.hparams.batch_size, 232 | game_size=self.hparams.game_size, 233 | same=self.hparams.same, 234 | shuffle=True, 235 | seed=self.hparams.random_seed, 236 | num_workers=4, 237 | pin_memory=True) 238 | 239 | def val_dataloader(self): 240 | # fixed seed so it's always the same 1024 (32*32) pairs 241 | data_folder = os.path.join(self.hparams.root, "train/") 242 | dataset = ImageNetFeat(root=data_folder, train=False) 243 | return ImagenetLoader( 244 | dataset, 245 | batch_size=self.hparams.batch_size, 246 | game_size=self.hparams.game_size, 247 | same=self.hparams.same, 248 | shuffle=False, 249 | seed=20200724, 250 | num_workers=4, 251 | pin_memory=True) 252 | 253 | def test_dataloader(self): 254 | # fixed seed so it's always the same 1024 (32*32) pairs 255 | data_folder = os.path.join(self.hparams.root, "test/") 256 | dataset = ImageNetFeat(root=data_folder, train=False) 257 | return ImagenetLoader( 258 | dataset, 259 | batch_size=self.hparams.batch_size, 260 | game_size=self.hparams.game_size, 261 | same=self.hparams.same, 262 | shuffle=False, 263 | seed=20200725, 264 | num_workers=4, 265 | pin_memory=True) 266 | 267 | 268 | def loss_acc(_sender_input, _message, _receiver_input, receiver_output, labels): 269 | """ 270 | Accuracy loss - non-differetiable hence cannot be used with GS 271 | """ 272 | # receiver outputs are samples 273 | acc = (labels == receiver_output).float() 274 | return -acc, {'acc': acc} 275 | 276 | 277 | def loss_nll(_sender_input, _message, _receiver_input, receiver_output, labels): 278 | """ 279 | NLL loss - differentiable and can be used with both GS and SFE 280 | """ 281 | nll = F.nll_loss(receiver_output, labels, reduction="none") 282 | # receiver outputs are logits 283 | acc = (labels == receiver_output.argmax(dim=-1)).float() 284 | return nll, {'acc': acc} 285 | 286 | 287 | def get_model(opt): 288 | game = SignalGame( 289 | feat_size=4096, 290 | embedding_size=opt.embedding_size, 291 | hidden_size=opt.hidden_size, 292 | game_size=opt.game_size, 293 | tau_s=opt.tau_s, 294 | loss_type=opt.loss_type, 295 | root=opt.root, 296 | same=opt.same, 297 | mode=opt.mode, 298 | entropy_coeff=opt.entropy_coeff, 299 | vocab_size=opt.latent_size, 300 | normalizer=opt.normalizer, 301 | gs_tau=opt.gs_tau, 302 | temperature_decay=opt.temperature_decay, 303 | temperature_update_freq=opt.temperature_update_freq, 304 | straight_through=opt.straight_through, 305 | baseline_type=opt.baseline_type, 306 | topk=opt.topk, 307 | random_seed=opt.random_seed, 308 | batch_size=opt.batch_size, 309 | lr=opt.lr, 310 | weight_decay=opt.weight_decay, 311 | optimizer=opt.optimizer) 312 | 313 | return game 314 | 315 | 316 | def main(params): 317 | 318 | 
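    # Editor's note: `params` is the raw argv list (sys.argv[1:] at the bottom of this
    # file); populate_experiment_params adds the signal-game flags defined in
    # experiments/signal-game/opts.py on top of the shared options provided by
    # lvmhelpers.utils.populate_common_params before parsing.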
arg_parser = argparse.ArgumentParser() 319 | arg_parser = populate_experiment_params(arg_parser) 320 | arg_parser = populate_common_params(arg_parser) 321 | opts = arg_parser.parse_args(params) 322 | 323 | # fix seed 324 | pl.seed_everything(opts.random_seed) 325 | 326 | signal_game = get_model(opts) 327 | 328 | experiment_name = 'signal-game' 329 | model_name = '%s/%s' % (experiment_name, opts.mode) 330 | other_info = [ 331 | "lr-{}".format(opts.lr), 332 | ] 333 | 334 | other_info.append("entrcoeff-{}".format(opts.entropy_coeff)) 335 | 336 | if opts.mode == "gs": 337 | if opts.straight_through: 338 | other_info.append("straight_through") 339 | other_info.append("decay-{}".format(opts.temperature_decay)) 340 | other_info.append("updatefreq-{}".format(opts.temperature_update_freq)) 341 | elif opts.mode == 'sfe': 342 | other_info.append("baseline-{}".format(opts.baseline_type)) 343 | elif opts.mode == "marg": 344 | other_info.append("norm-{}".format(opts.normalizer)) 345 | elif opts.mode == 'sumsample': 346 | other_info.append("k-{}".format(opts.topk)) 347 | other_info.append("baseline-{}".format(opts.baseline_type)) 348 | 349 | model_name = '%s/%s' % (model_name, '_'.join(other_info)) 350 | 351 | tb_logger = pl_loggers.TensorBoardLogger( 352 | 'logs/', 353 | name=model_name) 354 | 355 | tb_logger.log_hyperparams(opts, metrics=None) 356 | 357 | trainer = pl.Trainer( 358 | progress_bar_refresh_rate=20, 359 | logger=tb_logger, 360 | callbacks=[CheckpointEveryNSteps(opts.batches_per_epoch)], 361 | max_steps=opts.batches_per_epoch*opts.n_epochs, 362 | limit_val_batches=1024/opts.batch_size, 363 | limit_test_batches=10000//opts.batch_size, 364 | val_check_interval=opts.batches_per_epoch, 365 | weights_save_path='checkpoints/', 366 | weights_summary='full', 367 | gpus=1 if torch.cuda.is_available() else 0, 368 | resume_from_checkpoint=opts.load_from_checkpoint, 369 | deterministic=True) 370 | trainer.fit(signal_game) 371 | 372 | 373 | if __name__ == '__main__': 374 | import sys 375 | main(sys.argv[1:]) 376 | -------------------------------------------------------------------------------- /lvmhelpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/sparse-marginalization-lvm/9d27ec0fd58ea2eb58d79bca23a0051f1b143ebe/lvmhelpers/__init__.py -------------------------------------------------------------------------------- /lvmhelpers/bernoulli.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ad3/GenericFactor.h" 4 | 5 | namespace AD3 { 6 | 7 | class FactorBernoulli : public GenericFactor 8 | { 9 | public: 10 | FactorBernoulli() {} 11 | virtual ~FactorBernoulli() { ClearActiveSet(); } 12 | 13 | // Obtain the best configuration. 14 | void Maximize(const vector& variable_log_potentials, 15 | const vector&, 16 | Configuration& configuration, 17 | double* value) 18 | override 19 | { 20 | *value = 0; 21 | vector* cfg = 22 | static_cast*>(configuration); 23 | for (int i = 0; i < length_; ++i) 24 | { 25 | if (variable_log_potentials[i] > variable_log_potentials[length_ + i]) 26 | { 27 | (*cfg)[i] = 1; 28 | *value += variable_log_potentials[i]; 29 | } else 30 | { 31 | (*cfg)[i] = 0; 32 | *value += variable_log_potentials[length_ + i]; 33 | } 34 | } 35 | } 36 | 37 | // Compute the score of a given assignment. 
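  // (Editor's note) The variables of this factor are independent, so Maximize above
  // reduces to a per-position choice between the "on" potential (index i) and the
  // "off" potential (index length_ + i); Evaluate below scores a given 0/1 assignment
  // by summing the potentials it selects.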
38 | void Evaluate(const vector& variable_log_potentials, 39 | const vector&, 40 | const Configuration configuration, 41 | double* value) 42 | override 43 | { 44 | const vector* sequence = 45 | static_cast*>(configuration); 46 | *value = 0.0; 47 | for (int i = 0; i < length_; ++i) 48 | { 49 | if ((*sequence)[i] == 1) 50 | *value += variable_log_potentials[i]; 51 | else 52 | *value += variable_log_potentials[length_ + i]; 53 | } 54 | } 55 | 56 | // Given a configuration with a probability (weight), 57 | // increment the vectors of variable and additional posteriors. 58 | void UpdateMarginalsFromConfiguration(const Configuration& configuration, 59 | double weight, 60 | vector* variable_posteriors, 61 | vector*) 62 | override 63 | { 64 | const vector* sequence = 65 | static_cast*>(configuration); 66 | for (int i = 0; i < length_; ++i) 67 | { 68 | if ((*sequence)[i] == 1) 69 | (*variable_posteriors)[i] += weight; 70 | else 71 | (*variable_posteriors)[length_ + i] += weight; 72 | } 73 | } 74 | 75 | // Count how many common values two configurations have. 76 | int CountCommonValues(const Configuration& configuration1, 77 | const Configuration& configuration2) 78 | override 79 | { 80 | const vector* sequence1 = 81 | static_cast*>(configuration1); 82 | const vector* sequence2 = 83 | static_cast*>(configuration2); 84 | assert(sequence1->size() == sequence2->size()); 85 | int count = 0; 86 | for (int i = 0; i < sequence1->size(); ++i) { 87 | if ((*sequence1)[i] == (*sequence2)[i]) 88 | ++count; 89 | } 90 | return count; 91 | } 92 | 93 | // Check if two configurations are the same. 94 | bool SameConfiguration(const Configuration& configuration1, 95 | const Configuration& configuration2) 96 | override 97 | { 98 | const vector* sequence1 = 99 | static_cast*>(configuration1); 100 | const vector* sequence2 = 101 | static_cast*>(configuration2); 102 | 103 | assert(sequence1->size() == sequence2->size()); 104 | for (int i = 0; i < sequence1->size(); ++i) { 105 | if ((*sequence1)[i] != (*sequence2)[i]) 106 | return false; 107 | } 108 | return true; 109 | } 110 | 111 | // Delete configuration. 112 | void DeleteConfiguration(Configuration configuration) 113 | override 114 | { 115 | vector* sequence = static_cast*>(configuration); 116 | delete sequence; 117 | } 118 | 119 | Configuration CreateConfiguration() 120 | override 121 | { 122 | vector* sequence = new vector(length_, -1); 123 | return static_cast(sequence); 124 | } 125 | 126 | public: 127 | void Initialize(int length) 128 | { 129 | length_ = length; 130 | } 131 | 132 | virtual size_t GetNumAdditionals() override { return num_additionals_; } 133 | 134 | protected: 135 | // Number of states for each position. 
136 | int length_; 137 | int num_additionals_ = 0; 138 | }; 139 | 140 | } // namespace AD3 141 | -------------------------------------------------------------------------------- /lvmhelpers/binary_topk.h: -------------------------------------------------------------------------------- 1 | // k-best assignments for independent binary variables 2 | // (optimized version of zeroth order viterbi) 3 | // author: vlad niculae 4 | // license: mit 5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | 13 | #define CODE_MAX_SIZE 512 14 | 15 | typedef std::bitset configuration; 16 | typedef float floating; 17 | 18 | struct scored_cfg 19 | { 20 | floating score; 21 | configuration cfg; 22 | 23 | /* 24 | * set the ith variable on, adding v to score 25 | */ 26 | scored_cfg update(unsigned i, floating v) 27 | { 28 | scored_cfg out {this->score + v, this->cfg}; 29 | out.cfg[i] = 1; 30 | return out; 31 | } 32 | 33 | std::vector cfg_vector(int dim) 34 | { 35 | std::vector out; 36 | for (int i = 0; i < dim; ++i) 37 | out.push_back((int) cfg.test(i)); 38 | return out; 39 | } 40 | 41 | void populate_out(floating* out, int dim) 42 | { 43 | for (int i = 0; i < dim; ++i) 44 | out[i] = (float) cfg.test(i); 45 | } 46 | }; 47 | 48 | 49 | /* 50 | * Merge the two k-best lists, depending if the next state is 0 or 1. 51 | * floatinghe k-best for 0 is (a_begin, a_end) 52 | * floatinghe k-best for 1 is [(b.score + val, b.cfg + [1]) for b in (a_begin, a_end)] 53 | * 54 | * Implementation is standard list merge, stopping once we produced k items. 55 | * We also avoid building the b vector. 56 | */ 57 | std::vector::iterator 58 | merge_branch(std::vector::iterator a_begin, 59 | std::vector::iterator a_end, 60 | unsigned i, 61 | floating val, 62 | std::vector::iterator out_begin, 63 | int k) 64 | { 65 | 66 | auto b_begin = a_begin; 67 | auto b_end = a_end; 68 | int inserted = 0; 69 | 70 | while((inserted < k) & (a_begin != a_end) & (b_begin != b_end)) { 71 | auto b_begin_item = b_begin->update(i, val); 72 | if (b_begin_item.score > a_begin->score) { 73 | *out_begin = b_begin_item; 74 | ++b_begin; 75 | } else { 76 | *out_begin = *a_begin; 77 | ++a_begin; 78 | } 79 | ++out_begin; 80 | ++inserted; 81 | } 82 | 83 | while((inserted < k) & (a_begin != a_end)) { 84 | *out_begin = *a_begin; 85 | ++a_begin; 86 | ++out_begin; 87 | ++inserted; 88 | } 89 | 90 | while((inserted < k) & (b_begin != b_end)) { 91 | *out_begin = b_begin->update(i, val); 92 | ++b_begin; 93 | ++out_begin; 94 | ++inserted; 95 | } 96 | 97 | return out_begin; 98 | } 99 | 100 | std::vector topk(const std::vector& x, int k) 101 | { 102 | assert(k > 1); 103 | // partial configuration starting with 0 104 | scored_cfg c0 = {0, 0}; 105 | 106 | // partial configuration starting with 1 107 | scored_cfg c1 = c0.update(0, x[0]); 108 | 109 | std::vector curr(k), next(k); 110 | if (x[0] >= 0) { 111 | curr[0] = c1; 112 | curr[1] = c0; 113 | } else { 114 | curr[0] = c0; 115 | curr[1] = c1; 116 | } 117 | 118 | auto curr_begin = curr.begin(); 119 | auto curr_end = curr_begin + 2; 120 | auto next_begin = next.begin(); 121 | auto next_end = next_begin; 122 | 123 | for (unsigned i = 1; i < x.size(); ++i) { 124 | next_end = merge_branch(curr_begin, curr_end, i, x[i], next_begin, k); 125 | std::swap(curr_begin, next_begin); 126 | std::swap(curr_end, next_end); 127 | } 128 | return std::vector (curr_begin, curr_end); 129 | } 130 | 131 | std::vector topk(floating* x, int size, int k) 132 | { 133 | std::vector xvec(x, x + size); 134 | return topk(xvec, k); 
135 | } 136 | 137 | -------------------------------------------------------------------------------- /lvmhelpers/budget.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "ad3/Factor.h" 6 | #include "ad3/GenericFactor.h" 7 | 8 | #include "bernoulli.h" 9 | 10 | namespace AD3 { 11 | class FactorBudget : public FactorBernoulli 12 | { 13 | public: 14 | 15 | FactorBudget() {} 16 | virtual ~FactorBudget() { ClearActiveSet(); } 17 | 18 | void Initialize(int length, int budget) 19 | { 20 | length_ = length; 21 | budget_ = budget; 22 | } 23 | 24 | void Maximize(const vector& variable_log_potentials, 25 | const vector& add_, 26 | Configuration& configuration, 27 | double* value) 28 | override 29 | { 30 | // Create a local copy of the log potentials. 31 | vector eta(length_); 32 | vector* y = static_cast*>(configuration); 33 | *value = 0.0; 34 | 35 | // start with all variables off 36 | for (int i = 0; i < length_; ++i) 37 | { 38 | eta.at(i) = (variable_log_potentials.at(i) - 39 | variable_log_potentials.at(length_ + i)); 40 | *value += variable_log_potentials.at(length_ + i); 41 | y->at(i) = 0; 42 | } 43 | 44 | double valaux; 45 | 46 | // first, try including everything with positive score 47 | size_t num_active = 0; 48 | double sum = 0.0; 49 | for (size_t i = 0; i < length_; ++i) { 50 | if (eta[i] > 0) { 51 | sum += eta[i]; 52 | y->at(i) = 1; 53 | ++num_active; 54 | } 55 | } 56 | 57 | // if we went over budget, we sort, and only include the top 58 | if (num_active > budget_) 59 | { 60 | vector> scores(length_); 61 | for (size_t i = 0; i < length_; ++i) { 62 | scores.at(i).first = -eta.at(i); 63 | scores.at(i).second = i; 64 | } 65 | sort(scores.begin(), scores.end()); 66 | num_active = 0; 67 | sum = 0.0; 68 | for (size_t k = 0; k < budget_; ++k) { 69 | valaux = -scores[k].first; 70 | if (valaux < 0) 71 | break; 72 | int i = scores[k].second; 73 | y->at(i) = 1; 74 | sum += valaux; 75 | ++num_active; 76 | } 77 | 78 | for (size_t k = num_active; k < length_; ++k) { 79 | int i = scores[k].second; 80 | y->at(i) = 0; 81 | } 82 | } 83 | 84 | *value += sum; 85 | } 86 | 87 | protected: 88 | int budget_; 89 | }; 90 | 91 | } 92 | -------------------------------------------------------------------------------- /lvmhelpers/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import RelaxedOneHotCategorical, RelaxedBernoulli 4 | from torch.distributions import Categorical, Bernoulli 5 | 6 | 7 | def gumbel_softmax_sample( 8 | logits: torch.Tensor, 9 | temperature: float = 1.0, 10 | straight_through: bool = False): 11 | """Samples from a Gumbel-Sotmax/Concrete of a Categorical distribution. 12 | More details in: 13 | - Gumbel-Softmax: https://arxiv.org/abs/1611.01144 14 | - Concrete distribution: https://arxiv.org/abs/1611.00712 15 | 16 | Arguments: 17 | logits {torch.Tensor} -- tensor of logits, the output of an inference network. 18 | Size: [batch_size, n_categories] 19 | 20 | Keyword Arguments: 21 | temperature {float} -- temperature of the softmax relaxation. The lower the 22 | temperature (-->0), the closer the sample is to a discrete sample. 23 | (default: {1.0}) 24 | straight_through {bool} -- Whether to use the straight-through estimator. 25 | (default: {False}) 26 | 27 | Returns: 28 | torch.Tensor -- the relaxed sample. 
29 | Size: [batch_size, n_categories] 30 | """ 31 | 32 | sample = RelaxedOneHotCategorical( 33 | logits=logits, temperature=temperature).rsample() 34 | 35 | if straight_through: 36 | size = sample.size() 37 | indexes = sample.argmax(dim=-1) 38 | hard_sample = torch.zeros_like(sample).view(-1, size[-1]) 39 | hard_sample.scatter_(1, indexes.view(-1, 1), 1) 40 | hard_sample = hard_sample.view(*size) 41 | 42 | sample = sample + (hard_sample - sample).detach() 43 | return sample 44 | 45 | 46 | def gumbel_softmax_bit_vector_sample( 47 | logits: torch.Tensor, 48 | temperature: float = 1.0, 49 | straight_through: bool = False): 50 | """Samples from a Gumbel-Sotmax/Concrete of independent Bernoulli distributions. 51 | More details in: 52 | - Gumbel-Softmax: https://arxiv.org/abs/1611.01144 53 | - Concrete distribution: https://arxiv.org/abs/1611.00712 54 | 55 | Arguments: 56 | logits {torch.Tensor} -- tensor of logits, the output of an inference network. 57 | Size: [batch_size, n_bits] 58 | 59 | Keyword Arguments: 60 | temperature {float} -- temperature of the softmax relaxation. The lower the 61 | temperature (-->0), the closer the sample is to discrete samples. 62 | (default: {1.0}) 63 | straight_through {bool} -- Whether to use the straight-through estimator. 64 | (default: {False}) 65 | 66 | Returns: 67 | torch.Tensor -- the relaxed sample. 68 | Size: [batch_size, n_bits] 69 | """ 70 | 71 | sample = RelaxedBernoulli( 72 | logits=logits, temperature=temperature).rsample() 73 | 74 | if straight_through: 75 | hard_sample = (logits > 0).to(torch.float) 76 | sample = sample + (hard_sample - sample).detach() 77 | 78 | return sample 79 | 80 | 81 | class GumbelSoftmaxWrapper(nn.Module): 82 | """ 83 | Gumbel-Softmax Wrapper for a network that parameterizes a Categorical distribution. 84 | Assumes that during the forward pass, 85 | the network returns scores over the potential output categories. 86 | The wrapper transforms them into a sample from the Gumbel-Softmax (GS) distribution. 87 | """ 88 | 89 | def __init__(self, 90 | agent, 91 | temperature=1.0, 92 | trainable_temperature=False, 93 | straight_through=False): 94 | """ 95 | Arguments: 96 | agent -- The agent to be wrapped. agent.forward() has to output 97 | scores over the categories 98 | 99 | Keyword Arguments: 100 | temperature {float} -- The temperature of the Gumbel-Softmax distribution 101 | (default: {1.0}) 102 | trainable_temperature {bool} -- If set to True, the temperature becomes 103 | a trainable parameter of the model (default: {False}) 104 | straight_through {bool} -- Whether straigh-through Gumbel-Softmax is used 105 | (default: {False}) 106 | """ 107 | super(GumbelSoftmaxWrapper, self).__init__() 108 | self.agent = agent 109 | self.straight_through = straight_through 110 | if not trainable_temperature: 111 | self.temperature = temperature 112 | else: 113 | self.temperature = torch.nn.Parameter( 114 | torch.tensor([temperature]), requires_grad=True) 115 | 116 | self.distr_type = Categorical 117 | 118 | def forward(self, *args, **kwargs): 119 | """Forward pass. 120 | 121 | Returns: 122 | sample {torch.Tensor} -- Gumbel-Softmax relaxed sample. 123 | Size: [batch_size, n_categories] 124 | scores {torch.Tensor} -- the output of the network. 125 | Can be useful for logging purposes. 126 | Size: [batch_size, n_categories] 127 | entropy {torch.Tensor} -- the entropy of the distribution. 128 | We assume a Categorical to compute this, which is common-practice, 129 | but may not be ideal. 
See appendix in https://arxiv.org/abs/1611.00712 130 | Size: [batch_size] 131 | """ 132 | scores = self.agent(*args, **kwargs) 133 | sample = gumbel_softmax_sample( 134 | scores, self.temperature, self.straight_through) 135 | distr = self.distr_type(logits=scores) 136 | entropy = distr.entropy() 137 | return sample, scores, entropy 138 | 139 | def update_temperature( 140 | self, 141 | current_step: int, 142 | temperature_update_freq: int, 143 | temperature_decay: float): 144 | """use this at the end of each training step to anneal the temperature according 145 | to max(0.5, exp(-rt)) with r and t being the decay rate and training step, 146 | respectively. 147 | 148 | Arguments: 149 | current_step {int} -- current global step in the training process 150 | temperature_update_freq {int} -- how often to update the temperature 151 | temperature_decay {float} -- decay rate r 152 | """ 153 | if current_step % temperature_update_freq == 0: 154 | rt = temperature_decay * torch.tensor(current_step) 155 | self.temperature = torch.max( 156 | torch.tensor(0.5), torch.exp(-rt)) 157 | 158 | 159 | class Gumbel(torch.nn.Module): 160 | """ 161 | The training loop for the Gumbel-Softmax method to train discrete latent variables. 162 | Encoder needs to be GumbelSoftmaxWrapper. 163 | Decoder needs to be utils.DeterministicWrapper. 164 | """ 165 | def __init__( 166 | self, 167 | encoder, 168 | decoder, 169 | loss_fun, 170 | encoder_entropy_coeff=0.0, 171 | decoder_entropy_coeff=0.0): 172 | super(Gumbel, self).__init__() 173 | self.encoder = encoder 174 | self.decoder = decoder 175 | self.loss = loss_fun 176 | self.encoder_entropy_coeff = encoder_entropy_coeff 177 | self.decoder_entropy_coeff = decoder_entropy_coeff 178 | 179 | def forward(self, encoder_input, decoder_input, labels): 180 | discrete_latent_z, encoder_scores, encoder_entropy = self.encoder(encoder_input) 181 | decoder_output = self.decoder(discrete_latent_z, decoder_input) 182 | 183 | # entropy component of the final loss, we can 184 | # compute already but will only use it later on 185 | entropy_loss = -(encoder_entropy.mean() * self.encoder_entropy_coeff) 186 | 187 | argmax = encoder_scores.argmax(dim=-1) 188 | 189 | loss, logs = self.loss( 190 | encoder_input, 191 | argmax, 192 | decoder_input, 193 | decoder_output, 194 | labels) 195 | 196 | full_loss = loss.mean() + entropy_loss 197 | 198 | for k, v in logs.items(): 199 | if hasattr(v, 'mean'): 200 | logs[k] = v.mean() 201 | 202 | logs['loss'] = loss.mean() 203 | logs['encoder_entropy'] = encoder_entropy.mean() 204 | logs['distr'] = self.encoder.distr_type(logits=encoder_scores) 205 | return {'loss': full_loss, 'log': logs} 206 | 207 | 208 | class BitVectorGumbelSoftmaxWrapper(GumbelSoftmaxWrapper): 209 | """ 210 | Gumbel-Softmax Wrapper for a network that parameterizes 211 | independent Bernoulli distributions. 212 | Assumes that during the forward pass, 213 | the network returns scores for the Bernoulli parameters. 214 | The wrapper transforms them into a sample from the Gumbel-Softmax (GS) distribution. 215 | """ 216 | def __init__(self, 217 | agent, 218 | temperature=1.0, 219 | trainable_temperature=False, 220 | straight_through=False): 221 | """ 222 | Arguments: 223 | agent -- The agent to be wrapped. 
agent.forward() has to output 224 | scores for each Bernoulli 225 | 226 | Keyword Arguments: 227 | temperature {float} -- The temperature of the Gumbel-Softmax distribution 228 | (default: {1.0}) 229 | trainable_temperature {bool} -- If set to True, the temperature becomes 230 | a trainable parameter of the model (default: {False}) 231 | straight_through {bool} -- Whether straigh-through Gumbel-Softmax is used 232 | (default: {False}) 233 | """ 234 | super(BitVectorGumbelSoftmaxWrapper, self).__init__( 235 | agent, 236 | temperature, 237 | trainable_temperature, 238 | straight_through) 239 | self.distr_type = Bernoulli 240 | 241 | def forward(self, *args, **kwargs): 242 | """Forward pass. 243 | 244 | Returns: 245 | sample {torch.Tensor} -- Gumbel-Softmax relaxed sample. 246 | Size: [batch_size, n_bits] 247 | scores {torch.Tensor} -- the output of the network. 248 | Can be useful for logging purposes. 249 | Size: [batch_size, n_bits] 250 | entropy {torch.Tensor} -- the entropy of the distribution. 251 | We assume independent Bernoulli to compute this, which is common-practice, 252 | but may not be ideal. See appendix in https://arxiv.org/abs/1611.00712 253 | Size: [batch_size] 254 | """ 255 | scores = self.agent(*args, **kwargs) 256 | sample = gumbel_softmax_bit_vector_sample( 257 | scores, self.temperature, self.straight_through) 258 | distr = self.distr_type(logits=scores) 259 | entropy = distr.entropy().sum(dim=-1) 260 | return sample, scores, entropy 261 | 262 | 263 | class BitVectorGumbel(torch.nn.Module): 264 | """ 265 | The training loop for the Gumbel-Softmax method to train a 266 | bit-vector of independent latent variables. 267 | Encoder needs to be BitVectorGumbelSoftmaxWrapper. 268 | Decoder needs to be utils.DeterministicWrapper. 269 | """ 270 | def __init__( 271 | self, 272 | encoder, 273 | decoder, 274 | loss_fun, 275 | encoder_entropy_coeff=0.0, 276 | decoder_entropy_coeff=0.0): 277 | super(BitVectorGumbel, self).__init__() 278 | self.encoder = encoder 279 | self.decoder = decoder 280 | self.loss = loss_fun 281 | self.encoder_entropy_coeff = encoder_entropy_coeff 282 | self.decoder_entropy_coeff = decoder_entropy_coeff 283 | 284 | def forward(self, encoder_input, decoder_input, labels): 285 | discrete_latent_z, encoder_scores, encoder_entropy = self.encoder(encoder_input) 286 | decoder_output = self.decoder(discrete_latent_z, decoder_input) 287 | 288 | # entropy component of the final loss, we can 289 | # compute already but will only use it later on 290 | entropy_loss = -(encoder_entropy.mean() * self.encoder_entropy_coeff) 291 | 292 | argmax = (encoder_scores > 0).to(torch.float) 293 | 294 | loss, logs = self.loss( 295 | encoder_input, 296 | argmax, 297 | decoder_input, 298 | decoder_output, 299 | labels) 300 | 301 | full_loss = loss.mean() + entropy_loss 302 | 303 | for k, v in logs.items(): 304 | if hasattr(v, 'mean'): 305 | logs[k] = v.mean() 306 | 307 | logs['loss'] = loss.mean() 308 | logs['encoder_entropy'] = encoder_entropy.mean() 309 | logs['distr'] = self.encoder.distr_type(logits=encoder_scores) 310 | return {'loss': full_loss, 'log': logs} 311 | -------------------------------------------------------------------------------- /lvmhelpers/marg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from entmax import entmax15, sparsemax 4 | 5 | 6 | def entropy(p: torch.Tensor): 7 | """Numerically stable computation of Shannon's entropy 8 | for probability distributions with zero-valued 
elements. 9 | 10 | Arguments: 11 | p {torch.Tensor} -- tensor of probabilities. 12 | Size: [batch_size, n_categories] 13 | 14 | Returns: 15 | {torch.Tensor} -- the entropy of p. 16 | Size: [batch_size] 17 | """ 18 | nz = (p > 0).to(p.device) 19 | 20 | eps = torch.finfo(p.dtype).eps 21 | p_stable = p.clone().clamp(min=eps, max=1 - eps) 22 | 23 | out = torch.where( 24 | nz, 25 | p_stable * torch.log(p_stable), 26 | torch.tensor(0., device=p.device, dtype=torch.float)) 27 | 28 | return -(out).sum(-1) 29 | 30 | 31 | class ExplicitWrapper(nn.Module): 32 | """ 33 | Explicit Marginalization Wrapper for a network. 34 | Assumes that the during the forward pass, 35 | the network returns scores over the potential output categories. 36 | The wrapper transforms them into a tuple of (sample from the Categorical, 37 | log-prob of the sample, entropy for the Categorical). 38 | """ 39 | def __init__(self, agent, normalizer='entmax'): 40 | super(ExplicitWrapper, self).__init__() 41 | self.agent = agent 42 | 43 | normalizer_dict = { 44 | 'softmax': torch.softmax, 45 | 'sparsemax': sparsemax, 46 | 'entmax': entmax15} 47 | self.normalizer = normalizer_dict[normalizer] 48 | 49 | def forward(self, *args, **kwargs): 50 | scores = self.agent(*args, **kwargs) 51 | distr = self.normalizer(scores, dim=-1) 52 | entropy_distr = entropy(distr) 53 | sample = scores.argmax(dim=-1) 54 | return sample, distr, entropy_distr 55 | 56 | 57 | class Marginalizer(torch.nn.Module): 58 | """ 59 | The training loop for the marginalization method to train discrete latent variables. 60 | Encoder needs to be ExplicitWrapper. 61 | Decoder needs to be utils.DeterministicWrapper. 62 | """ 63 | def __init__( 64 | self, 65 | encoder, 66 | decoder, 67 | loss_fun, 68 | encoder_entropy_coeff=0.0, 69 | decoder_entropy_coeff=0.0): 70 | super(Marginalizer, self).__init__() 71 | self.encoder = encoder 72 | self.decoder = decoder 73 | self.loss = loss_fun 74 | self.encoder_entropy_coeff = encoder_entropy_coeff 75 | self.decoder_entropy_coeff = decoder_entropy_coeff 76 | 77 | def forward(self, encoder_input, decoder_input, labels): 78 | discrete_latent_z, encoder_probs, encoder_entropy = self.encoder(encoder_input) 79 | batch_size, latent_size = encoder_probs.shape 80 | 81 | entropy_loss = -(encoder_entropy.mean() * self.encoder_entropy_coeff) 82 | 83 | losses = torch.zeros_like(encoder_probs) 84 | logs_global = None 85 | 86 | for possible_discrete_latent_z in range(latent_size): 87 | if encoder_probs[:, possible_discrete_latent_z].sum().detach() != 0: 88 | # if it's zero, all batch examples 89 | # will be multiplied by zero anyway, 90 | # so skip computations 91 | possible_discrete_latent_z_ = \ 92 | possible_discrete_latent_z + \ 93 | torch.zeros( 94 | batch_size, dtype=torch.long).to(encoder_probs.device) 95 | decoder_output = self.decoder( 96 | possible_discrete_latent_z_, decoder_input) 97 | 98 | loss_sum_term, logs = self.loss( 99 | encoder_input, 100 | discrete_latent_z, 101 | decoder_input, 102 | decoder_output, 103 | labels) 104 | 105 | losses[:, possible_discrete_latent_z] += loss_sum_term 106 | 107 | if not logs_global: 108 | logs_global = {k: 0.0 for k in logs.keys()} 109 | for k, v in logs.items(): 110 | if hasattr(v, 'mean'): 111 | # expectation of accuracy 112 | logs_global[k] += ( 113 | encoder_probs[:, possible_discrete_latent_z] * v).mean() 114 | 115 | for k, v in logs.items(): 116 | if hasattr(v, 'mean'): 117 | logs[k] = logs_global[k] 118 | 119 | # encoder_probs: [batch_size, latent_size] 120 | # losses: [batch_size, latent_size] 
121 | # encoder_probs.unsqueeze(1): [batch_size, 1, latent_size] 122 | # losses.unsqueeze(-1): [batch_size, latent_size, 1] 123 | # entropy_loss: [] 124 | # full_loss: [] 125 | loss = encoder_probs.unsqueeze(1).bmm(losses.unsqueeze(-1)).squeeze() 126 | full_loss = loss.mean() + entropy_loss.mean() 127 | 128 | logs['loss'] = loss.mean() 129 | logs['encoder_entropy'] = encoder_entropy.mean() 130 | logs['support'] = (encoder_probs != 0).sum(-1).to(torch.float).mean() 131 | logs['distr'] = encoder_probs 132 | return {'loss': full_loss, 'log': logs} 133 | -------------------------------------------------------------------------------- /lvmhelpers/nvil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical, Bernoulli 4 | 5 | 6 | class BaselineNN(nn.Module): 7 | """ 8 | Neural network that outputs the baseline for NVIL. 9 | """ 10 | def __init__(self, input_size): 11 | super(BaselineNN, self).__init__() 12 | 13 | self.input_size = input_size 14 | # define the linear layers 15 | self.fc1 = nn.Linear(self.input_size, 512) 16 | self.fc2 = nn.Linear(512, 512) 17 | self.fc3 = nn.Linear(512, 512) 18 | self.fc4 = nn.Linear(512, 1) 19 | 20 | def forward(self, image): 21 | 22 | # feed through neural network 23 | h = image.view(-1, self.input_size) 24 | 25 | h = torch.relu(self.fc1(h)) 26 | h = torch.relu(self.fc2(h)) 27 | h = torch.relu(self.fc3(h)) 28 | h = self.fc4(h) 29 | # h: [batch_size] 30 | return h 31 | 32 | 33 | class NVILWrapper(nn.Module): 34 | """ 35 | NVIL Wrapper for a network. Assumes that the during the forward pass, 36 | the network returns scores over the potential output categories. 37 | The wrapper transforms them into a tuple of (sample from the Categorical, 38 | log-prob of the sample, entropy for the Categorical). 39 | """ 40 | def __init__(self, agent, input_size): 41 | super(NVILWrapper, self).__init__() 42 | self.agent = agent 43 | self.baseline_nn = BaselineNN(input_size) 44 | 45 | def forward(self, *args, **kwargs): 46 | """Forward pass. 47 | 48 | Returns: 49 | sample {torch.Tensor} -- Categorical sample. 50 | Size: [batch_size] 51 | scores {torch.Tensor} -- the output of the network. 52 | Important to compute the policy component of the loss. 53 | Size: [batch_size, n_categories] 54 | entropy {torch.Tensor} -- the entropy of the Categorical distribution 55 | parameterized by the scores. 56 | Size: [batch_size] 57 | """ 58 | scores = self.agent(*args, **kwargs) 59 | 60 | distr = Categorical(logits=scores) 61 | entropy = distr.entropy() 62 | 63 | sample = distr.sample() 64 | 65 | return sample, scores, entropy 66 | 67 | 68 | class NVIL(torch.nn.Module): 69 | """ 70 | The training loop for the NVIL method to train discrete latent variables. 71 | Encoder needs to be NVILWrapper. 72 | Decoder needs to be utils.DeterministicWrapper. 
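
    Editor's summary of the surrogate objective assembled in ``forward`` below (this
    restates the code; it adds no behaviour):

        full_loss = ((loss - baseline).detach() * log_prob_z).mean()   # score-function term
                    + loss.mean()                                       # pathwise/decoder term
                    + ((loss.detach() - baseline) ** 2).mean()          # baseline regression
                    - encoder_entropy_coeff * encoder_entropy.mean()    # entropy bonus

    where ``baseline`` comes from the input-dependent ``BaselineNN`` defined above.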
73 | """ 74 | def __init__( 75 | self, 76 | encoder, 77 | decoder, 78 | loss_fun, 79 | encoder_entropy_coeff=0.0, 80 | decoder_entropy_coeff=0.0): 81 | super(NVIL, self).__init__() 82 | self.encoder = encoder 83 | self.decoder = decoder 84 | self.loss = loss_fun 85 | self.encoder_entropy_coeff = encoder_entropy_coeff 86 | self.decoder_entropy_coeff = decoder_entropy_coeff 87 | 88 | def forward(self, encoder_input, decoder_input, labels): 89 | discrete_latent_z, encoder_scores, encoder_entropy = \ 90 | self.encoder(encoder_input) 91 | decoder_output = \ 92 | self.decoder(discrete_latent_z, decoder_input) 93 | 94 | argmax = encoder_scores.argmax(dim=-1) 95 | 96 | loss, logs = self.loss( 97 | encoder_input, 98 | argmax, 99 | decoder_input, 100 | decoder_output, 101 | labels) 102 | 103 | encoder_categorical_helper = Categorical(logits=encoder_scores) 104 | encoder_sample_log_probs = encoder_categorical_helper.log_prob(discrete_latent_z) 105 | 106 | baseline = self.encoder.baseline_nn(encoder_input).squeeze() 107 | 108 | baseline = baseline.reshape(-1, loss.size(0)).mean(dim=0).squeeze() 109 | 110 | policy_loss = ((loss - baseline).detach() * encoder_sample_log_probs).mean() 111 | entropy_loss = -(encoder_entropy.mean() * self.encoder_entropy_coeff) 112 | mse = ((loss.detach()-baseline)**2).mean() 113 | 114 | full_loss = policy_loss + loss.mean() + mse + entropy_loss 115 | 116 | for k, v in logs.items(): 117 | if hasattr(v, 'mean'): 118 | logs[k] = v.mean() 119 | 120 | logs['baseline'] = baseline 121 | logs['loss'] = loss.mean() 122 | logs['encoder_entropy'] = encoder_entropy.mean() 123 | logs['distr'] = encoder_categorical_helper.probs 124 | 125 | return {'loss': full_loss, 'log': logs} 126 | 127 | 128 | class BitVectorNVILWrapper(nn.Module): 129 | """ 130 | NVIL Wrapper for a network that parameterizes 131 | independent Bernoulli distributions. 132 | Assumes that the during the forward pass, 133 | the network returns scores for the Bernoulli parameters. 134 | The wrapper transforms them into a tuple of (sample from the Bernoulli, 135 | log-prob of the sample, entropy for the independent Bernoulli). 136 | """ 137 | def __init__(self, agent, input_size): 138 | super(BitVectorNVILWrapper, self).__init__() 139 | self.agent = agent 140 | self.baseline_nn = BaselineNN(input_size) 141 | 142 | def forward(self, *args, **kwargs): 143 | scores = self.agent(*args, **kwargs) 144 | 145 | distr = Bernoulli(logits=scores) 146 | entropy = distr.entropy().sum(dim=1) 147 | 148 | sample = distr.sample() 149 | 150 | return sample, scores, entropy 151 | 152 | 153 | class BitVectorNVIL(torch.nn.Module): 154 | """ 155 | The training loop for the NVIL method to train 156 | a bit-vector of independent latent variables. 157 | Encoder needs to be BitVectorNVILWrapper. 158 | Decoder needs to be utils.DeterministicWrapper. 
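    Unlike the categorical NVIL class above, the latent here is a float
    bit-vector of shape [batch_size, latent_size] sampled from independent
    Bernoullis, and the policy, reconstruction, baseline-MSE and entropy terms
    are combined per example before a single mean over the batch.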
159 | """ 160 | def __init__( 161 | self, 162 | encoder, 163 | decoder, 164 | loss_fun, 165 | encoder_entropy_coeff=0.0, 166 | decoder_entropy_coeff=0.0): 167 | super(BitVectorNVIL, self).__init__() 168 | self.encoder = encoder 169 | self.decoder = decoder 170 | self.loss = loss_fun 171 | self.encoder_entropy_coeff = encoder_entropy_coeff 172 | self.decoder_entropy_coeff = decoder_entropy_coeff 173 | self.mean_baseline = 0.0 174 | self.n_points = 0.0 175 | 176 | def forward(self, encoder_input, decoder_input, labels): 177 | discrete_latent_z, encoder_scores, encoder_entropy = \ 178 | self.encoder(encoder_input) 179 | decoder_output = self.decoder(discrete_latent_z, decoder_input) 180 | 181 | argmax = (encoder_scores > 0).to(torch.float) 182 | 183 | loss, logs = self.loss( 184 | encoder_input, 185 | argmax, 186 | decoder_input, 187 | decoder_output, 188 | labels) 189 | 190 | encoder_bernoull_distr = Bernoulli(logits=encoder_scores) 191 | encoder_sample_log_probs = \ 192 | encoder_bernoull_distr.log_prob(discrete_latent_z).sum(dim=1) 193 | 194 | baseline = self.encoder.baseline_nn(encoder_input).squeeze() 195 | 196 | policy_loss = (loss - baseline).detach() * encoder_sample_log_probs 197 | entropy_loss = - encoder_entropy * self.encoder_entropy_coeff 198 | mse = (loss.detach()-baseline)**2 199 | 200 | full_loss = (policy_loss + loss + mse + entropy_loss).mean() 201 | 202 | for k, v in logs.items(): 203 | if hasattr(v, 'mean'): 204 | logs[k] = v.mean() 205 | 206 | logs['baseline'] = baseline 207 | logs['loss'] = loss.mean() 208 | logs['encoder_entropy'] = encoder_entropy.mean() 209 | logs['distr'] = encoder_bernoull_distr 210 | 211 | return {'loss': full_loss, 'log': logs} 212 | -------------------------------------------------------------------------------- /lvmhelpers/pbernoulli.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport Factor, GenericFactor, PGenericFactor 8 | 9 | cdef extern from "bernoulli.h" namespace "AD3": 10 | 11 | cdef cppclass FactorBernoulli(GenericFactor): 12 | FactorBernoulli() 13 | void Initialize(int length) 14 | -------------------------------------------------------------------------------- /lvmhelpers/pbernoulli.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport Factor, GenericFactor, PGenericFactor 8 | 9 | 10 | cdef class PFactorBernoulli(PGenericFactor): 11 | 12 | def __cinit__(self, bool allocate=True): 13 | self.allocate = allocate 14 | if allocate: 15 | self.thisptr = new FactorBernoulli() 16 | 17 | def __dealloc__(self): 18 | if self.allocate: 19 | del self.thisptr 20 | 21 | def initialize(self, int length): 22 | (self.thisptr).Initialize(length) 23 | -------------------------------------------------------------------------------- /lvmhelpers/pbinary_topk.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp.vector cimport vector 5 | 6 | 7 | cdef extern from "binary_topk.h": 8 | 9 | cdef cppclass scored_cfg: 10 | float score 11 | vector[int] cfg_vector(int dim) 12 | void populate_out(float* out, int dim) 13 | 14 | vector[scored_cfg] 
topk(vector[float] x, int k) 15 | vector[scored_cfg] topk(float* x, int size, int k) 16 | 17 | 18 | cpdef binary_topk(vector[float] scores, int k): 19 | cdef vector[scored_cfg] out = topk(scores, k) 20 | return [(c.score, c.cfg_vector(scores.size())) for c in out] 21 | 22 | 23 | def batched_topk(float[:, :] scores, 24 | float[:, :, :] configs, 25 | int k): 26 | 27 | cdef vector[scored_cfg] out 28 | cdef int size = scores.shape[1] 29 | cdef int i, j 30 | 31 | for i in range(scores.shape[0]): 32 | out = topk(&scores[i, 0], size, k) 33 | for j in range(k): 34 | out[j].populate_out(&configs[i, j, 0], size) 35 | -------------------------------------------------------------------------------- /lvmhelpers/pbudget.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport Factor, GenericFactor, PGenericFactor 8 | 9 | from .pbernoulli cimport FactorBernoulli 10 | 11 | cdef extern from "budget.h" namespace "AD3": 12 | 13 | cdef cppclass FactorBudget(FactorBernoulli): 14 | FactorBudget() 15 | void Initialize(int length, int budget) 16 | -------------------------------------------------------------------------------- /lvmhelpers/pbudget.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport Factor, GenericFactor, PGenericFactor 8 | 9 | 10 | cdef class PFactorBudget(PGenericFactor): 11 | 12 | def __cinit__(self, bool allocate=True): 13 | self.allocate = allocate 14 | if allocate: 15 | self.thisptr = new FactorBudget() 16 | 17 | def __dealloc__(self): 18 | if self.allocate: 19 | del self.thisptr 20 | 21 | def initialize(self, int length, int budget): 22 | (self.thisptr).Initialize(length, budget) 23 | -------------------------------------------------------------------------------- /lvmhelpers/psequence.pxd: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport Factor, GenericFactor, PGenericFactor 8 | 9 | cdef extern from "sequence.h" namespace "AD3": 10 | 11 | cdef cppclass FactorSequence(GenericFactor): 12 | FactorSequence() 13 | void Initialize(vector[int] num_states) 14 | 15 | 16 | cdef extern from "sequence_binary.h" namespace "AD3": 17 | 18 | cdef cppclass FactorSequenceBinary(FactorSequence): 19 | FactorSequenceBinary() 20 | void Initialize(int length) 21 | -------------------------------------------------------------------------------- /lvmhelpers/psequence.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | 4 | from libcpp cimport bool 5 | from libcpp.vector cimport vector 6 | 7 | from lpsmap.ad3qp.base cimport PGenericFactor 8 | 9 | 10 | cdef class PFactorSequence(PGenericFactor): 11 | 12 | def __cinit__(self, bool allocate=True): 13 | self.allocate = allocate 14 | if allocate: 15 | self.thisptr = new FactorSequence() 16 | 17 | def __dealloc__(self): 18 | if self.allocate: 19 | del self.thisptr 20 | 21 | def initialize(self, vector[int] num_states): 22 | (self.thisptr).Initialize(num_states) 23 | 24 | 25 | 
cdef class PFactorSequenceBinary(PGenericFactor): 26 | 27 | def __cinit__(self, bool allocate=True): 28 | self.allocate = allocate 29 | if allocate: 30 | self.thisptr = new FactorSequenceBinary() 31 | 32 | def __dealloc__(self): 33 | if self.allocate: 34 | del self.thisptr 35 | 36 | def initialize(self, int length): 37 | (self.thisptr).Initialize(length) 38 | -------------------------------------------------------------------------------- /lvmhelpers/sequence.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2012 Andre Martins 2 | // All Rights Reserved. 3 | // 4 | // This file is part of AD3 2.1. 5 | // 6 | // AD3 2.1 is free software: you can redistribute it and/or modify 7 | // it under the terms of the GNU Lesser General Public License as published by 8 | // the Free Software Foundation, either version 3 of the License, or 9 | // (at your option) any later version. 10 | // 11 | // AD3 2.1 is distributed in the hope that it will be useful, 12 | // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | // GNU Lesser General Public License for more details. 15 | // 16 | // You should have received a copy of the GNU Lesser General Public License 17 | // along with AD3 2.1. If not, see . 18 | 19 | #pragma once 20 | 21 | #include "ad3/GenericFactor.h" 22 | #include 23 | 24 | namespace AD3 { 25 | 26 | class FactorSequence : public GenericFactor 27 | { 28 | protected: 29 | virtual double GetNodeScore(int position, 30 | int state, 31 | const vector& variable_log_potentials, 32 | const vector& additional_log_potentials) 33 | { 34 | return variable_log_potentials[offset_states_[position] + state]; 35 | } 36 | 37 | // The edge connects node[position-1] to node[position]. 38 | virtual double GetEdgeScore(int position, 39 | int previous_state, 40 | int state, 41 | const vector& variable_log_potentials, 42 | const vector& additional_log_potentials) 43 | { 44 | int index = index_edges_[position][previous_state][state]; 45 | return additional_log_potentials[index]; 46 | } 47 | 48 | virtual void AddNodePosterior(int position, 49 | int state, 50 | double weight, 51 | vector* variable_posteriors, 52 | vector* additional_posteriors) 53 | { 54 | (*variable_posteriors)[offset_states_[position] + state] += weight; 55 | } 56 | 57 | // The edge connects node[position-1] to node[position]. 58 | virtual void AddEdgePosterior(int position, 59 | int previous_state, 60 | int state, 61 | double weight, 62 | vector* variable_posteriors, 63 | vector* additional_posteriors) 64 | { 65 | int index = index_edges_[position][previous_state][state]; 66 | (*additional_posteriors)[index] += weight; 67 | } 68 | 69 | public: 70 | FactorSequence() {} 71 | virtual ~FactorSequence() { ClearActiveSet(); } 72 | 73 | // Obtain the best configuration. 74 | void Maximize(const vector& variable_log_potentials, 75 | const vector& additional_log_potentials, 76 | Configuration& configuration, 77 | double* value) 78 | { 79 | // Decode using the Viterbi algorithm. 80 | int length = num_states_.size(); 81 | vector> values(length); 82 | vector> path(length); 83 | 84 | // Initialization. 
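        // Fill in the scores of the first position, including the edge from the
        // implicit start state (previous_state = 0 at position 0).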
85 | int num_states = num_states_[0]; 86 | values[0].resize(num_states); 87 | path[0].resize(num_states); 88 | for (int l = 0; l < num_states; ++l) { 89 | values[0][l] = 90 | GetNodeScore( 91 | 0, l, variable_log_potentials, additional_log_potentials) + 92 | GetEdgeScore( 93 | 0, 0, l, variable_log_potentials, additional_log_potentials); 94 | path[0][l] = -1; // This won't be used. 95 | } 96 | 97 | // Recursion. 98 | for (int i = 0; i < length - 1; ++i) { 99 | int num_states = num_states_[i + 1]; 100 | values[i + 1].resize(num_states); 101 | path[i + 1].resize(num_states); 102 | for (int k = 0; k < num_states; ++k) { 103 | double best_value = -std::numeric_limits::infinity(); 104 | int best = -1; 105 | for (int l = 0; l < num_states_[i]; ++l) { 106 | double val = 107 | values[i][l] + GetEdgeScore(i + 1, 108 | l, 109 | k, 110 | variable_log_potentials, 111 | additional_log_potentials); 112 | if (best < 0 || val > best_value) { 113 | best_value = val; 114 | best = l; 115 | } 116 | } 117 | values[i + 1][k] = 118 | best_value + GetNodeScore(i + 1, 119 | k, 120 | variable_log_potentials, 121 | additional_log_potentials); 122 | path[i + 1][k] = best; 123 | } 124 | } 125 | 126 | // Termination. 127 | double best_value = -std::numeric_limits::infinity(); 128 | int best = -1; 129 | for (int l = 0; l < num_states_[length - 1]; ++l) { 130 | double val = 131 | values[length - 1][l] + GetEdgeScore(length, 132 | l, 133 | 0, 134 | variable_log_potentials, 135 | additional_log_potentials); 136 | if (best < 0 || val > best_value) { 137 | best_value = val; 138 | best = l; 139 | } 140 | } 141 | 142 | // Path (state sequence) backtracking. 143 | vector* sequence = static_cast*>(configuration); 144 | assert(sequence->size() == length); 145 | (*sequence)[length - 1] = best; 146 | for (int i = length - 1; i > 0; --i) { 147 | (*sequence)[i - 1] = path[i][(*sequence)[i]]; 148 | } 149 | 150 | *value = best_value; 151 | } 152 | 153 | // Compute the score of a given assignment. 154 | void Evaluate(const vector& variable_log_potentials, 155 | const vector& additional_log_potentials, 156 | const Configuration configuration, 157 | double* value) 158 | { 159 | const vector* sequence = 160 | static_cast*>(configuration); 161 | *value = 0.0; 162 | int previous_state = 0; 163 | for (int i = 0; i < sequence->size(); ++i) { 164 | int state = (*sequence)[i]; 165 | *value += GetNodeScore( 166 | i, state, variable_log_potentials, additional_log_potentials); 167 | *value += GetEdgeScore(i, 168 | previous_state, 169 | state, 170 | variable_log_potentials, 171 | additional_log_potentials); 172 | previous_state = state; 173 | } 174 | *value += GetEdgeScore(sequence->size(), 175 | previous_state, 176 | 0, 177 | variable_log_potentials, 178 | additional_log_potentials); 179 | } 180 | 181 | // Given a configuration with a probability (weight), 182 | // increment the vectors of variable and additional posteriors. 
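    // The generic active-set solver calls this for every configuration it keeps,
    // turning its distribution over configurations into variable/edge posteriors.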
183 | void UpdateMarginalsFromConfiguration(const Configuration& configuration, 184 | double weight, 185 | vector* variable_posteriors, 186 | vector* additional_posteriors) 187 | { 188 | const vector* sequence = 189 | static_cast*>(configuration); 190 | int previous_state = 0; 191 | for (int i = 0; i < sequence->size(); ++i) { 192 | int state = (*sequence)[i]; 193 | AddNodePosterior( 194 | i, state, weight, variable_posteriors, additional_posteriors); 195 | AddEdgePosterior(i, 196 | previous_state, 197 | state, 198 | weight, 199 | variable_posteriors, 200 | additional_posteriors); 201 | previous_state = state; 202 | } 203 | AddEdgePosterior(sequence->size(), 204 | previous_state, 205 | 0, 206 | weight, 207 | variable_posteriors, 208 | additional_posteriors); 209 | } 210 | 211 | // Count how many common values two configurations have. 212 | int CountCommonValues(const Configuration& configuration1, 213 | const Configuration& configuration2) 214 | { 215 | const vector* sequence1 = 216 | static_cast*>(configuration1); 217 | const vector* sequence2 = 218 | static_cast*>(configuration2); 219 | assert(sequence1->size() == sequence2->size()); 220 | int count = 0; 221 | for (int i = 0; i < sequence1->size(); ++i) { 222 | if ((*sequence1)[i] == (*sequence2)[i]) 223 | ++count; 224 | } 225 | return count; 226 | } 227 | 228 | // Check if two configurations are the same. 229 | bool SameConfiguration(const Configuration& configuration1, 230 | const Configuration& configuration2) 231 | { 232 | const vector* sequence1 = 233 | static_cast*>(configuration1); 234 | const vector* sequence2 = 235 | static_cast*>(configuration2); 236 | 237 | assert(sequence1->size() == sequence2->size()); 238 | for (int i = 0; i < sequence1->size(); ++i) { 239 | if ((*sequence1)[i] != (*sequence2)[i]) 240 | return false; 241 | } 242 | return true; 243 | } 244 | 245 | // Delete configuration. 246 | void DeleteConfiguration(Configuration configuration) 247 | { 248 | vector* sequence = static_cast*>(configuration); 249 | delete sequence; 250 | } 251 | 252 | Configuration CreateConfiguration() 253 | { 254 | int length = num_states_.size(); 255 | vector* sequence = new vector(length, -1); 256 | return static_cast(sequence); 257 | } 258 | 259 | public: 260 | // num_states contains the number of states at each position 261 | // in the sequence. The start and stop positions are not considered here. 262 | // Note: the variables and the the additional log-potentials must be ordered 263 | // properly. 264 | void Initialize(const vector& num_states) 265 | { 266 | int length = num_states.size(); 267 | num_states_ = num_states; 268 | index_edges_.resize(length + 1); 269 | offset_states_.resize(length); 270 | int offset = 0; 271 | for (int i = 0; i < length; ++i) { 272 | offset_states_[i] = offset; 273 | offset += num_states_[i]; 274 | } 275 | int index = 0; 276 | for (int i = 0; i <= length; ++i) { 277 | // If i == 0, the previous state is the start symbol. 278 | int num_previous_states = (i > 0) ? num_states_[i - 1] : 1; 279 | // If i == length, the previous state is the final symbol. 280 | int num_current_states = (i < length) ? 
num_states_[i] : 1; 281 | index_edges_[i].resize(num_previous_states); 282 | for (int j = 0; j < num_previous_states; ++j) { 283 | index_edges_[i][j].resize(num_current_states); 284 | for (int k = 0; k < num_current_states; ++k) { 285 | index_edges_[i][j][k] = index; 286 | ++index; 287 | } 288 | } 289 | } 290 | num_additionals_ = index; 291 | } 292 | 293 | virtual size_t GetNumAdditionals() override { return num_additionals_; } 294 | 295 | protected: 296 | // Number of states for each position. 297 | vector num_states_; 298 | // Offset of states for each position. 299 | vector offset_states_; 300 | // At each position, map from edges of states to a global index which 301 | // matches the index of additional_log_potentials_. 302 | vector>> index_edges_; 303 | int num_additionals_; 304 | }; 305 | 306 | } // namespace AD3 307 | -------------------------------------------------------------------------------- /lvmhelpers/sequence_binary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "sequence.h" 4 | 5 | namespace AD3 { 6 | 7 | class FactorSequenceBinary : public FactorSequence 8 | { 9 | protected: 10 | virtual double GetNodeScore( 11 | int position, 12 | int state, 13 | const vector& variable_log_potentials, 14 | const vector& additional_log_potentials) override 15 | { 16 | if (state == 1) { 17 | return variable_log_potentials[position]; 18 | } else { 19 | return variable_log_potentials[length_ + position]; 20 | } 21 | } 22 | 23 | // The edge connects node[position-1] to node[position]. 24 | virtual double GetEdgeScore( 25 | int position, 26 | int previous_state, 27 | int state, 28 | const vector& variable_log_potentials, 29 | const vector& additional_log_potentials) override 30 | { 31 | // only consider positive-to-positive transitions; 32 | // this automatically rules out the initial and final transitions 33 | if (previous_state == 1 && state == 1) { 34 | return additional_log_potentials[position - 1]; 35 | } else 36 | return 0; 37 | } 38 | 39 | virtual void AddNodePosterior( 40 | int position, 41 | int state, 42 | double weight, 43 | vector* variable_posteriors, 44 | vector* additional_posteriors) override 45 | { 46 | if (state == 1) { 47 | (*variable_posteriors)[position] += weight; 48 | } else { 49 | (*variable_posteriors)[length_ + position] += weight; 50 | } 51 | } 52 | 53 | // The edge connects node[position-1] to node[position]. 54 | virtual void AddEdgePosterior( 55 | int position, 56 | int previous_state, 57 | int state, 58 | double weight, 59 | vector* variable_posteriors, 60 | vector* additional_posteriors) override 61 | { 62 | if (previous_state == 1 && state == 1) { 63 | (*additional_posteriors)[position - 1] += weight; 64 | } 65 | } 66 | 67 | public: 68 | void Initialize(const int length) 69 | { 70 | length_ = length; 71 | num_states_ = vector(length, 2); 72 | num_additionals_ = length - 1; 73 | } 74 | 75 | public: 76 | FactorSequenceBinary() {} 77 | virtual ~FactorSequenceBinary() { ClearActiveSet(); } 78 | 79 | int length_; 80 | }; 81 | 82 | } // namespace AD3 83 | -------------------------------------------------------------------------------- /lvmhelpers/sfe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical, Bernoulli 4 | 5 | 6 | class SFEWrapper(nn.Module): 7 | """ 8 | SFE Wrapper for a network. 
Assumes that the during the forward pass, 9 | the network returns scores over the potential output categories. 10 | The wrapper transforms them into a tuple of (sample from the Categorical, 11 | log-prob of the sample, entropy for the Categorical). 12 | """ 13 | def __init__(self, agent, baseline_type): 14 | """ 15 | Arguments: 16 | agent -- The agent to be wrapped. agent.forward() has to output 17 | scores over the categories 18 | baseline_type {str} -- which baseline to use. Either 'runavg' 19 | or 'sample'. 20 | """ 21 | super(SFEWrapper, self).__init__() 22 | self.agent = agent 23 | self.baseline_type = baseline_type 24 | 25 | def forward(self, *args, **kwargs): 26 | """Forward pass. 27 | 28 | Returns: 29 | sample {torch.Tensor} -- SFE sample. 30 | Size: [batch_size] 31 | scores {torch.Tensor} -- the output of the network. 32 | Important to compute the policy component of the SFE loss. 33 | Size: [batch_size, n_categories] 34 | entropy {torch.Tensor} -- the entropy of the Categorical distribution 35 | parameterized by the scores. 36 | Size: [batch_size] 37 | """ 38 | scores = self.agent(*args, **kwargs) 39 | 40 | distr = Categorical(logits=scores) 41 | entropy = distr.entropy() 42 | 43 | sample = distr.sample() 44 | 45 | return sample, scores, entropy 46 | 47 | 48 | class SFEDeterministicWrapper(nn.Module): 49 | """ 50 | Simple wrapper that makes a deterministic agent (without sampling) 51 | compatible with SFE-based training, by 52 | adding zero log-probability and entropy values to the output. 53 | No sampling is run on top of the wrapped agent, 54 | it is passed as is. 55 | """ 56 | def __init__(self, agent): 57 | super(SFEDeterministicWrapper, self).__init__() 58 | self.agent = agent 59 | 60 | def forward(self, *args, **kwargs): 61 | out = self.agent(*args, **kwargs) 62 | device = next(self.parameters()).device 63 | return out, torch.zeros(1).to(device), torch.zeros(1).to(device) 64 | 65 | 66 | class ScoreFunctionEstimator(torch.nn.Module): 67 | """ 68 | The training loop for the SFE method to train discrete latent variables. 69 | Encoder/Decoder needs to be either SFEWrapper or SFEDeterministicWrapper. 
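    A minimal usage sketch; `EncoderNet`, `DecoderNet` and `loss_fn` are
    placeholder names (not defined in this repository):

        encoder = SFEWrapper(EncoderNet(), baseline_type='runavg')
        decoder = SFEDeterministicWrapper(DecoderNet())
        estimator = ScoreFunctionEstimator(encoder, decoder, loss_fn)
        out = estimator(encoder_input, decoder_input, labels)
        out['loss'].backward()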
70 | """ 71 | def __init__( 72 | self, 73 | encoder, 74 | decoder, 75 | loss_fun, 76 | encoder_entropy_coeff=0.0, 77 | decoder_entropy_coeff=0.0): 78 | super(ScoreFunctionEstimator, self).__init__() 79 | self.encoder = encoder 80 | self.decoder = decoder 81 | self.loss = loss_fun 82 | self.encoder_entropy_coeff = encoder_entropy_coeff 83 | self.decoder_entropy_coeff = decoder_entropy_coeff 84 | self.mean_baseline = 0.0 85 | self.n_points = 0.0 86 | 87 | def forward(self, encoder_input, decoder_input, labels): 88 | discrete_latent_z, encoder_scores, encoder_entropy = \ 89 | self.encoder(encoder_input) 90 | decoder_output, decoder_scores, decoder_entropy = \ 91 | self.decoder(discrete_latent_z, decoder_input) 92 | 93 | argmax = encoder_scores.argmax(dim=-1) 94 | 95 | loss, logs = self.loss( 96 | encoder_input, 97 | argmax, 98 | decoder_input, 99 | decoder_output, 100 | labels) 101 | 102 | encoder_categorical_helper = Categorical(logits=encoder_scores) 103 | encoder_sample_log_probs = encoder_categorical_helper.log_prob(discrete_latent_z) 104 | if len(decoder_scores.size()) != 1: 105 | decoder_categorical_helper = Categorical(logits=decoder_scores) 106 | decoder_sample_log_probs = decoder_categorical_helper.log_prob(decoder_output) 107 | else: 108 | decoder_sample_log_probs = decoder_scores 109 | 110 | if self.encoder.baseline_type == 'runavg': 111 | baseline = self.mean_baseline 112 | elif self.encoder.baseline_type == 'sample': 113 | alt_z_sample = encoder_categorical_helper.sample().detach() 114 | decoder_output, _, _ = self.decoder(alt_z_sample, decoder_input) 115 | baseline, _ = self.loss( 116 | encoder_input, 117 | alt_z_sample, 118 | decoder_input, 119 | decoder_output, 120 | labels) 121 | 122 | policy_loss = ( 123 | (loss.detach() - baseline) * 124 | (encoder_sample_log_probs + decoder_sample_log_probs) 125 | ).mean() 126 | entropy_loss = -( 127 | encoder_entropy.mean() * 128 | self.encoder_entropy_coeff + 129 | decoder_entropy.mean() * 130 | self.decoder_entropy_coeff) 131 | 132 | if self.training and self.encoder.baseline_type == 'runavg': 133 | self.n_points += 1.0 134 | self.mean_baseline += ( 135 | loss.detach().mean() - self.mean_baseline) / self.n_points 136 | 137 | full_loss = policy_loss + entropy_loss + loss.mean() 138 | 139 | for k, v in logs.items(): 140 | if hasattr(v, 'mean'): 141 | logs[k] = v.mean() 142 | 143 | logs['baseline'] = self.mean_baseline 144 | logs['loss'] = loss.mean() 145 | logs['encoder_entropy'] = encoder_entropy.mean() 146 | logs['decoder_entropy'] = decoder_entropy.mean() 147 | logs['distr'] = encoder_categorical_helper.probs 148 | 149 | return {'loss': full_loss, 'log': logs} 150 | 151 | 152 | class BitVectorSFEWrapper(nn.Module): 153 | """ 154 | SFE Wrapper for a network that parameterizes 155 | independent Bernoulli distributions. 156 | Assumes that the during the forward pass, 157 | the network returns scores for the Bernoulli parameters. 158 | The wrapper transforms them into a tuple of (sample from the Bernoulli, 159 | log-prob of the sample, entropy for the independent Bernoulli). 160 | """ 161 | def __init__(self, agent, baseline_type): 162 | """ 163 | Arguments: 164 | agent -- The agent to be wrapped. agent.forward() has to output 165 | scores for each Bernoulli 166 | baseline_type {str} -- which baseline to use. Either 'runavg' 167 | or 'sample'. 168 | """ 169 | super(BitVectorSFEWrapper, self).__init__() 170 | self.agent = agent 171 | self.baseline_type = baseline_type 172 | 173 | def forward(self, *args, **kwargs): 174 | """Forward pass. 
175 | 176 | Returns: 177 | sample {torch.Tensor} -- SFE sample. 178 | Size: [batch_size, n_bits] 179 | scores {torch.Tensor} -- the output of the network. 180 | Important to compute the policy component of the SFE loss. 181 | Size: [batch_size, n_bits] 182 | entropy {torch.Tensor} -- the entropy of the independent Bernoulli 183 | parameterized by the scores. 184 | Size: [batch_size] 185 | """ 186 | scores = self.agent(*args, **kwargs) 187 | 188 | distr = Bernoulli(logits=scores) 189 | entropy = distr.entropy().sum(dim=1) 190 | 191 | sample = distr.sample() 192 | 193 | return sample, scores, entropy 194 | 195 | 196 | class BitVectorScoreFunctionEstimator(torch.nn.Module): 197 | """ 198 | The training loop for the SFE method to train 199 | a bit-vector of independent latent variables. 200 | Encoder/Decoder needs to be either BitVectorSFEWrapper or SFEDeterministicWrapper. 201 | """ 202 | def __init__( 203 | self, 204 | encoder, 205 | decoder, 206 | loss_fun, 207 | encoder_entropy_coeff=0.0, 208 | decoder_entropy_coeff=0.0): 209 | super(BitVectorScoreFunctionEstimator, self).__init__() 210 | self.encoder = encoder 211 | self.decoder = decoder 212 | self.loss = loss_fun 213 | self.encoder_entropy_coeff = encoder_entropy_coeff 214 | self.decoder_entropy_coeff = decoder_entropy_coeff 215 | self.mean_baseline = 0.0 216 | self.n_points = 0.0 217 | 218 | def forward(self, encoder_input, decoder_input, labels): 219 | discrete_latent_z, encoder_scores, encoder_entropy = \ 220 | self.encoder(encoder_input) 221 | decoder_output, decoder_scores, decoder_entropy = \ 222 | self.decoder(discrete_latent_z, decoder_input) 223 | 224 | argmax = (encoder_scores > 0).to(torch.float) 225 | 226 | loss, logs = self.loss( 227 | encoder_input, 228 | argmax, 229 | decoder_input, 230 | decoder_output, 231 | labels) 232 | 233 | encoder_bernoull_distr = Bernoulli(logits=encoder_scores) 234 | encoder_sample_log_probs = \ 235 | encoder_bernoull_distr.log_prob(discrete_latent_z).sum(dim=1) 236 | if len(decoder_scores.size()) != 1: 237 | decoder_categorical_helper = Bernoulli(logits=decoder_scores) 238 | decoder_sample_log_probs = \ 239 | decoder_categorical_helper.log_prob(decoder_output).sum(dim=1) 240 | else: 241 | decoder_sample_log_probs = decoder_scores 242 | 243 | if self.encoder.baseline_type == 'runavg': 244 | baseline = self.mean_baseline 245 | elif self.encoder.baseline_type == 'sample': 246 | alt_z_sample = encoder_bernoull_distr.sample().detach() 247 | decoder_output, _, _ = self.decoder(alt_z_sample, decoder_input) 248 | baseline, _ = self.loss( 249 | encoder_input, 250 | alt_z_sample, 251 | decoder_input, 252 | decoder_output, 253 | labels) 254 | 255 | policy_loss = ( 256 | loss.detach() - baseline) * ( 257 | encoder_sample_log_probs + decoder_sample_log_probs) 258 | entropy_loss = -( 259 | encoder_entropy * 260 | self.encoder_entropy_coeff + 261 | decoder_entropy * 262 | self.decoder_entropy_coeff) 263 | 264 | full_loss = (policy_loss + entropy_loss + loss).mean() 265 | 266 | if self.training and self.encoder.baseline_type == 'runavg': 267 | self.n_points += 1.0 268 | self.mean_baseline += ( 269 | loss.detach().mean() - self.mean_baseline) / self.n_points 270 | 271 | for k, v in logs.items(): 272 | if hasattr(v, 'mean'): 273 | logs[k] = v.mean() 274 | 275 | logs['baseline'] = self.mean_baseline 276 | logs['loss'] = loss.mean() 277 | logs['encoder_entropy'] = encoder_entropy.mean() 278 | logs['decoder_entropy'] = decoder_entropy.mean() 279 | logs['distr'] = encoder_bernoull_distr 280 | 281 | return {'loss': 
full_loss, 'log': logs} 282 | -------------------------------------------------------------------------------- /lvmhelpers/sparsemap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from lpsmap.ad3qp.factor_graph import PFactorGraph 5 | from .pbernoulli import PFactorBernoulli 6 | from .pbudget import PFactorBudget 7 | from .psequence import PFactorSequenceBinary 8 | 9 | 10 | class SparseMAP(torch.autograd.Function): 11 | 12 | @classmethod 13 | def run_sparsemap(cls, ctx, x): 14 | ctx.n = x.shape[0] 15 | ctx.fg = PFactorGraph() 16 | ctx.fg.set_verbosity(1) 17 | ctx.variables = [ctx.fg.create_binary_variable() 18 | for _ in range(2 * ctx.n)] 19 | ctx.f = cls.make_factor(ctx) 20 | ctx.fg.declare_factor(ctx.f, ctx.variables) 21 | x_np = x.detach().cpu().numpy().astype(np.double) 22 | x_np = np.concatenate([x_np, np.zeros_like(x_np)]) 23 | 24 | # initialize better 25 | if ctx.init: 26 | init = torch.rand(x_np.shape[0], dtype=torch.double) 27 | ctx.f.init_active_set_from_scores(init, []) 28 | 29 | _, _ = ctx.f.solve_qp(x_np, [], max_iter=ctx.max_iter) 30 | aset, p = ctx.f.get_sparse_solution() 31 | 32 | p = p[:len(aset)] 33 | aset = torch.tensor(aset, dtype=torch.float32, device=x.device) 34 | p = torch.tensor(p).to(x.device) 35 | ctx.mark_non_differentiable(aset) 36 | return p, aset 37 | 38 | @classmethod 39 | def jv(cls, ctx, dp): 40 | d_eta_u = np.empty(2 * ctx.n, dtype=np.float32) 41 | d_eta_v = np.empty(0, dtype=np.float32) 42 | ctx.f.dist_jacobian_vec(dp.cpu().numpy(), d_eta_u, d_eta_v) 43 | d_eta_u = torch.tensor(d_eta_u[:ctx.n], dtype=torch.float32, device=dp.device) 44 | return d_eta_u 45 | 46 | 47 | class SequenceSparseMAP(torch.autograd.Function): 48 | 49 | @classmethod 50 | def run_sparsemap(cls, ctx, x, t): 51 | ctx.n = x.shape[0] 52 | ctx.fg = PFactorGraph() 53 | ctx.fg.set_verbosity(1) 54 | ctx.variables = [ctx.fg.create_binary_variable() 55 | for _ in range(2 * ctx.n)] 56 | ctx.f = cls.make_factor(ctx) 57 | ctx.fg.declare_factor(ctx.f, ctx.variables) 58 | x_np = x.detach().cpu().numpy().astype(np.double) 59 | x_np = np.concatenate([x_np, np.zeros_like(x_np)]) 60 | 61 | # edge potentials 62 | t_np = t.detach().cpu().numpy().astype(np.double) 63 | ctx.f.set_additional_log_potentials(t_np) 64 | 65 | # initialize better 66 | if ctx.init: 67 | init = torch.rand(x_np.shape[0], dtype=torch.double) 68 | additionals = torch.zeros_like(t_np) 69 | ctx.f.init_active_set_from_scores(init, additionals) 70 | 71 | _, _ = ctx.f.solve_qp(x_np, t_np, max_iter=ctx.max_iter) 72 | aset, p = ctx.f.get_sparse_solution() 73 | 74 | p = p[:len(aset)] 75 | aset = torch.tensor(aset, dtype=torch.float32, device=x.device) 76 | p = torch.tensor(p).to(x.device) 77 | ctx.mark_non_differentiable(aset) 78 | return p, aset 79 | 80 | @classmethod 81 | def jv(cls, ctx, dp): 82 | d_eta_u = np.empty(2 * ctx.n, dtype=np.float32) 83 | d_eta_v = np.empty(ctx.n-1, dtype=np.float32) 84 | ctx.f.dist_jacobian_vec(dp.cpu().numpy(), d_eta_u, d_eta_v) 85 | d_eta_u = torch.tensor(d_eta_u[:ctx.n], dtype=torch.float32, device=dp.device) 86 | d_eta_v = torch.tensor(d_eta_v, dtype=torch.float32, device=dp.device) 87 | return d_eta_u, d_eta_v 88 | 89 | 90 | class BernSparseMAP(SparseMAP): 91 | 92 | @classmethod 93 | def make_factor(cls, ctx): 94 | f = PFactorBernoulli() 95 | f.initialize(ctx.n) 96 | return f 97 | 98 | @classmethod 99 | def forward(cls, ctx, x, max_iter, init): 100 | ctx.max_iter = max_iter 101 | ctx.init = init 102 | return 
cls.run_sparsemap(ctx, x) 103 | 104 | @classmethod 105 | def backward(cls, ctx, dp, daset): 106 | return cls.jv(ctx, dp), None, None, None 107 | 108 | 109 | class BudgetSparseMAP(SparseMAP): 110 | 111 | @classmethod 112 | def make_factor(cls, ctx): 113 | f = PFactorBudget() 114 | f.initialize(ctx.n, ctx.budget) 115 | return f 116 | 117 | @classmethod 118 | def forward(cls, ctx, x, budget, max_iter, init): 119 | ctx.n = x.shape[0] 120 | ctx.init = init 121 | ctx.max_iter = max_iter 122 | ctx.budget = budget 123 | return cls.run_sparsemap(ctx, x) 124 | 125 | @classmethod 126 | def backward(cls, ctx, dp, daset): 127 | return cls.jv(ctx, dp), None, None, None, None 128 | 129 | 130 | class SequenceBinarySparseMAP(SequenceSparseMAP): 131 | 132 | @classmethod 133 | def make_factor(cls, ctx): 134 | f = PFactorSequenceBinary() 135 | f.initialize(ctx.n) 136 | return f 137 | 138 | @classmethod 139 | def forward(cls, ctx, x, t, max_iter, init): 140 | ctx.n = x.shape[0] 141 | ctx.init = init 142 | ctx.max_iter = max_iter 143 | return cls.run_sparsemap(ctx, x, t) 144 | 145 | @classmethod 146 | def backward(cls, ctx, dp, daset): 147 | d_eta_u, d_eta_v = cls.jv(ctx, dp) 148 | return d_eta_u, d_eta_v, None, None 149 | 150 | 151 | def bernoulli_smap(x, max_iter=100, init=True): 152 | return BernSparseMAP.apply(x, max_iter, init) 153 | 154 | 155 | def budget_smap(x, budget=5, max_iter=100, init=True): 156 | return BudgetSparseMAP.apply(x, budget, max_iter, init) 157 | 158 | 159 | def sequence_smap(x, t, max_iter=100, init=True): 160 | return SequenceBinarySparseMAP.apply(x, t, max_iter, init) 161 | 162 | 163 | def main(): 164 | 165 | torch.manual_seed(42) 166 | 167 | x = torch.randn(5, requires_grad=True) 168 | print(x) 169 | p, aset = bernoulli_smap(x) 170 | print(p) 171 | print(aset) 172 | 173 | p, aset = budget_smap(x, budget=3) 174 | print(p) 175 | print(aset) 176 | 177 | print(torch.autograd.grad(p[0], x)) 178 | print(torch.autograd.grad(p[1], x)) 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /lvmhelpers/structmarg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from entmax import sparsemax 4 | 5 | from .sparsemap import bernoulli_smap, budget_smap 6 | from .pbinary_topk import batched_topk 7 | 8 | 9 | def entropy(p): 10 | nz = (p > 0).to(p.device) 11 | 12 | eps = torch.finfo(p.dtype).eps 13 | p_stable = p.clone().clamp(min=eps, max=1 - eps) 14 | 15 | out = torch.where( 16 | nz, 17 | p_stable * torch.log(p_stable), 18 | torch.tensor(0., device=p.device, dtype=torch.float)) 19 | 20 | return -(out).sum(-1) 21 | 22 | 23 | class TopKSparsemaxWrapper(nn.Module): 24 | """ 25 | Top-k sparsemax Wrapper for a network that parameterizes 26 | independent Bernoulli distributions. 27 | Assumes that the during the forward pass, 28 | the network returns scores for the Bernoulli parameters. 29 | The wrapper transforms them into a tuple of (sample from the Bernoulli, 30 | log-prob of the sample, entropy for the independent Bernoulli). 
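    Concretely, the forward pass enumerates the k highest-scoring bit vectors
    with `batched_topk` and applies sparsemax to their scores, so the returned
    distribution has support over at most k configurations per example.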
31 | """ 32 | def __init__(self, agent, k=10): 33 | super(TopKSparsemaxWrapper, self).__init__() 34 | self.agent = agent 35 | 36 | self.k = k 37 | 38 | def forward(self, *args, **kwargs): 39 | scores = self.agent(*args, **kwargs) 40 | batch_size, latent_size = scores.shape 41 | # get the top-k bit-vectors 42 | bit_vector_z = torch.empty((batch_size, self.k, latent_size), dtype=torch.float32) 43 | batched_topk(scores.detach().cpu().numpy(), bit_vector_z.numpy(), self.k) 44 | bit_vector_z = bit_vector_z.to(scores.device) 45 | 46 | # rank the top-k using sparsemax 47 | scores = torch.einsum("bkj,bj->bk", bit_vector_z, scores) 48 | distr = sparsemax(scores, dim=-1) 49 | 50 | # get the entropy 51 | distr_flat = distr.view(-1) 52 | mask = distr_flat > 0 53 | distr_flat = distr_flat[mask] 54 | entropy_distr = - distr_flat @ torch.log(distr_flat) 55 | 56 | sample = bit_vector_z 57 | 58 | return sample, distr, entropy_distr / batch_size 59 | 60 | 61 | class TopKSparsemaxMarg(torch.nn.Module): 62 | """ 63 | The training loop for the Top-k sparsemax method to train 64 | a bit-vector of independent latent variables. 65 | Encoder needs to be TopKSparsemaxWrapper. 66 | Decoder needs to be utils.DeterministicWrapper. 67 | """ 68 | def __init__( 69 | self, 70 | encoder, 71 | decoder, 72 | loss_fun, 73 | encoder_entropy_coeff=0.0, 74 | decoder_entropy_coeff=0.0): 75 | super(TopKSparsemaxMarg, self).__init__() 76 | self.encoder = encoder 77 | self.decoder = decoder 78 | self.loss = loss_fun 79 | self.encoder_entropy_coeff = encoder_entropy_coeff 80 | self.decoder_entropy_coeff = decoder_entropy_coeff 81 | 82 | def forward(self, encoder_input, decoder_input, labels): 83 | bit_vector_z, encoder_probs, encoder_entropy = self.encoder(encoder_input) 84 | batch_size = bit_vector_z.shape[0] 85 | k = self.encoder.k 86 | latent_size = bit_vector_z.shape[-1] 87 | 88 | entropy_loss = -(encoder_entropy * self.encoder_entropy_coeff) 89 | # bit_vector_z: [batch_size, k, latent_size] 90 | # bit_vector_z_flat: [batch_size * k, latent_size] 91 | bit_vector_z_flat = bit_vector_z.view(-1, latent_size) 92 | # encoder_input: [batch_size, input_size] 93 | # encoder_input_rep: [batch_size, k, input_size] 94 | # encoder_input_rep_flat: [batch_size * k, input_size] 95 | encoder_input_rep = encoder_input.unsqueeze(1).repeat((1, k, 1)) 96 | encoder_input_rep_flat = encoder_input_rep.view(-1, encoder_input.shape[-1]) 97 | 98 | # decoder_input: [batch_size, input_size] 99 | # decoder_input_rep: [batch_size, k, input_size] 100 | # decoder_input_rep_flat: [batch_size * k, input_size] 101 | decoder_input_rep = decoder_input.unsqueeze(1).repeat((1, k, 1)) 102 | decoder_input_rep_flat = decoder_input_rep.view(-1, decoder_input.shape[-1]) 103 | 104 | # TODO: this label format is specific to VAE... 
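        # each example is repeated once per candidate bit-vector so that the loss
        # can be evaluated for every (example, configuration) pair; the resulting
        # per-pair losses are weighted by the sparsemax probabilities further down.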
105 | # labels: [batch_size, input_size] 106 | # labels_rep: [batch_size, k, input_size] 107 | # labels_rep_flat: [batch_size * k, input_size] 108 | labels_rep = labels.unsqueeze(1).repeat((1, k, 1)) 109 | labels_rep_flat = labels_rep.view(-1, labels.shape[-1]) 110 | 111 | # encoder_probs: [batch_size, k] 112 | # encoder_probs_flat: [batch_size * k] 113 | encoder_probs_flat = encoder_probs.view(-1) 114 | 115 | # removing components that would end up being zero-ed out 116 | mask = encoder_probs_flat > 0 117 | # encoder_input_rep_flat: [<=batch_size * k, input_size] 118 | encoder_input_rep_flat = encoder_input_rep_flat[mask] 119 | # decoder_input_rep_flat: [<=batch_size * k, input_size] 120 | decoder_input_rep_flat = decoder_input_rep_flat[mask] 121 | # labels_rep_flat: [<=batch_size * k, input_size] 122 | labels_rep_flat = labels_rep_flat[mask] 123 | # encoder_probs_flat: [<=batch_size * k] 124 | encoder_probs_flat = encoder_probs_flat[mask] 125 | # bit_vector_z_flat: [<=batch_size * k, latent_size] 126 | bit_vector_z_flat = bit_vector_z_flat[mask] 127 | 128 | # decoder_output: [<=batch_size * k, input_size, out_classes] 129 | decoder_output = self.decoder(bit_vector_z_flat, decoder_input) 130 | 131 | # loss_components: [<=batch_size * k] 132 | loss_components, logs = self.loss( 133 | encoder_input_rep_flat, 134 | bit_vector_z_flat, 135 | decoder_input_rep_flat, 136 | decoder_output, 137 | labels_rep_flat) 138 | 139 | # loss: [] 140 | loss = (encoder_probs_flat @ loss_components) / batch_size 141 | 142 | full_loss = loss.mean() + entropy_loss 143 | 144 | for k, v in logs.items(): 145 | if hasattr(v, 'mean'): 146 | logs[k] = v.mean() 147 | 148 | logs['loss'] = loss.mean().detach() 149 | logs['encoder_entropy'] = encoder_entropy.detach() 150 | logs['support'] = (encoder_probs > 0).sum(dim=-1).to(torch.float) 151 | logs['distr'] = encoder_probs 152 | logs['loss_output'] = loss_components 153 | return {'loss': full_loss, 'log': logs} 154 | 155 | 156 | class SparseMAPWrapper(nn.Module): 157 | """ 158 | SparseMAP Wrapper for a network that parameterizes 159 | independent Bernoulli distributions. 160 | Assumes that the during the forward pass, 161 | the network returns scores for the Bernoulli parameters. 162 | The wrapper transforms them into a tuple of (sample from the Bernoulli, 163 | log-prob of the sample, entropy for the independent Bernoulli). 
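    In practice the forward pass returns the active-set configurations of all
    examples concatenated along the first dimension, their SparseMAP
    probabilities, the batch-averaged entropy of that distribution, the example
    index of each configuration, and the per-example support sizes, since each
    example may keep a different number of configurations.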
164 | """ 165 | def __init__(self, agent, budget=0, init=False, max_iter=300): 166 | super(SparseMAPWrapper, self).__init__() 167 | self.agent = agent 168 | self.budget = budget 169 | self.init = init 170 | self.max_iter = max_iter 171 | 172 | def forward(self, *args, **kwargs): 173 | scores = self.agent(*args, **kwargs) 174 | batch_size, latent_size = scores.shape 175 | 176 | distr = [] 177 | sample = [] 178 | idxs = [] 179 | 180 | support = [] 181 | for k in range(batch_size): 182 | zl = scores[k] 183 | if self.budget > 0: 184 | distri, samplei = budget_smap( 185 | zl, budget=self.budget, init=self.init, max_iter=self.max_iter) 186 | else: 187 | distri, samplei = bernoulli_smap( 188 | zl, init=self.init, max_iter=self.max_iter) 189 | samplei = samplei[distri > 0] 190 | distri = distri[distri > 0] 191 | supp = len(distri) 192 | assert supp > 0 193 | sample.append(samplei) 194 | distr.append(distri) 195 | idxs.extend(supp * [k]) 196 | support.append(supp) 197 | 198 | sample = torch.cat(sample) 199 | distr = torch.cat(distr) 200 | entropy_distr = -distr @ torch.log(distr) 201 | 202 | return sample, distr, entropy_distr / batch_size, idxs, support 203 | 204 | 205 | class SparseMAPMarg(torch.nn.Module): 206 | """ 207 | The training loop for the SparseMAP marginalization method to train 208 | a bit-vector of independent latent variables. 209 | Encoder needs to be SparseMAPWrapper. 210 | Decoder needs to be utils.DeterministicWrapper. 211 | """ 212 | def __init__( 213 | self, 214 | encoder, 215 | decoder, 216 | loss_fun, 217 | encoder_entropy_coeff=0.0, 218 | decoder_entropy_coeff=0.0): 219 | super(SparseMAPMarg, self).__init__() 220 | self.encoder = encoder 221 | self.decoder = decoder 222 | self.loss = loss_fun 223 | self.encoder_entropy_coeff = encoder_entropy_coeff 224 | self.decoder_entropy_coeff = decoder_entropy_coeff 225 | 226 | def forward(self, encoder_input, decoder_input, labels): 227 | bit_vector_z, encoder_probs, encoder_entropy, idxs, support = \ 228 | self.encoder(encoder_input) 229 | batch_size = encoder_input.shape[0] 230 | 231 | entropy_loss = -(encoder_entropy * self.encoder_entropy_coeff) 232 | 233 | decoder_output = self.decoder(bit_vector_z) 234 | 235 | loss_components, logs = self.loss( 236 | encoder_input[idxs], 237 | bit_vector_z, 238 | decoder_input[idxs], 239 | decoder_output, 240 | labels[idxs]) 241 | 242 | loss = (encoder_probs @ loss_components) / batch_size 243 | full_loss = loss + entropy_loss 244 | 245 | for k, v in logs.items(): 246 | if hasattr(v, 'mean'): 247 | logs[k] = v.mean() 248 | 249 | logs['loss'] = loss.detach() 250 | logs['encoder_entropy'] = encoder_entropy.detach() 251 | logs['support'] = torch.tensor(support).to(torch.float) 252 | logs['distr'] = encoder_probs.detach() 253 | logs['loss_output'] = loss_components.detach() 254 | logs['idxs'] = idxs 255 | return {'loss': full_loss, 'log': logs} 256 | -------------------------------------------------------------------------------- /lvmhelpers/sum_and_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of the code here is inspired by or copied from 3 | https://github.com/Runjing-Liu120/RaoBlackwellizedSGD/ 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.distributions import Categorical 9 | 10 | 11 | def get_concentrated_mask(class_weights, topk): 12 | """ 13 | Returns a logical mask indicating the categories with the top k largest 14 | probabilities, as well as the catogories corresponding to those with the 15 | top k 
largest probabilities. 16 | 17 | Parameters 18 | ---------- 19 | class_weights : torch.Tensor 20 | Array of class weights, with each row corresponding to a datapoint, 21 | each column corresponding to the probability of the datapoint 22 | belonging to that category 23 | topk : int 24 | the k in top-k 25 | 26 | Returns 27 | ------- 28 | mask_topk : torch.Tensor 29 | Boolean array, same dimension as class_weights, 30 | with entry 1 if the corresponding class weight is 31 | in the topk for that observation 32 | topk_domain: torch.LongTensor 33 | Array specifying the indices of class_weights that correspond to 34 | the topk observations 35 | """ 36 | 37 | mask_topk = torch.zeros(class_weights.shape).to(class_weights.device) 38 | 39 | seq_tensor = torch.LongTensor([i for i in range(class_weights.shape[0])]) 40 | 41 | if topk > 0: 42 | _, topk_domain = torch.topk(class_weights, topk) 43 | 44 | for i in range(topk): 45 | mask_topk[seq_tensor, topk_domain[:, i]] = 1 46 | else: 47 | topk_domain = None 48 | 49 | return mask_topk, topk_domain, seq_tensor 50 | 51 | 52 | class SumAndSampleWrapper(nn.Module): 53 | """ 54 | Sum&Sample Wrapper for a network. Assumes that the during the forward pass, 55 | the network returns scores over the potential output categories. 56 | The wrapper transforms them into a tuple of (sample from the Categorical, 57 | log-prob of the sample, entropy for the Categorical). 58 | 59 | See: https://arxiv.org/abs/1810.04777 60 | """ 61 | def __init__(self, agent, topk=10, baseline_type=None): 62 | super(SumAndSampleWrapper, self).__init__() 63 | self.agent = agent 64 | self.topk = topk 65 | self.baseline_type = baseline_type 66 | 67 | def forward(self, *args, **kwargs): 68 | scores = self.agent(*args, **kwargs) 69 | 70 | distr = Categorical(logits=scores) 71 | entropy = distr.entropy() 72 | 73 | sample = scores.argmax(dim=-1) 74 | 75 | return sample, scores, entropy 76 | 77 | 78 | class SumAndSample(torch.nn.Module): 79 | """ 80 | The training loop for the Sum&Sample method to train discrete latent variables. 81 | Encoder needs to be SumAndSampleWrapper. 82 | Decoder needs to be utils.DeterministicWrapper. 
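    The estimator sums the objective exactly over the `topk` most probable
    categories and covers the remaining probability mass with a single
    score-function sample (the Rao-Blackwellized sum-and-sample estimator
    referenced in SumAndSampleWrapper above).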
83 | """ 84 | def __init__( 85 | self, 86 | encoder, 87 | decoder, 88 | loss_fun, 89 | encoder_entropy_coeff=0.0, 90 | decoder_entropy_coeff=0.0): 91 | super(SumAndSample, self).__init__() 92 | self.encoder = encoder 93 | self.decoder = decoder 94 | self.loss = loss_fun 95 | self.encoder_entropy_coeff = encoder_entropy_coeff 96 | self.decoder_entropy_coeff = decoder_entropy_coeff 97 | self.mean_baseline = 0.0 98 | self.n_points = 0.0 99 | 100 | def forward(self, encoder_input, decoder_input, labels): 101 | discrete_latent_z, encoder_scores, encoder_entropy = \ 102 | self.encoder(encoder_input) 103 | batch_size, _ = encoder_scores.shape 104 | 105 | # encoder_log_prob: [batch_size, latent_size] 106 | encoder_log_prob = torch.log_softmax(encoder_scores, dim=-1) 107 | # entropy component of the final loss, we can 108 | # compute already but will only use it later on 109 | entropy_loss = -(encoder_entropy.mean() * self.encoder_entropy_coeff) 110 | 111 | if self.training: 112 | # encoder_prob: [batch_size, latent_size] 113 | encoder_prob = torch.softmax(encoder_scores.detach(), dim=-1) 114 | # this is the indicator C_k 115 | # concentrated_mask: [batch_size, latent_size] 116 | # topk_domain: [batch_size, self.encoder.topk] 117 | # seq_tensor: [batch_size] 118 | concentrated_mask, topk_domain, seq_tensor = \ 119 | get_concentrated_mask(encoder_prob, self.encoder.topk) 120 | concentrated_mask = concentrated_mask.float().detach() 121 | 122 | ############################ 123 | # compute the summed term 124 | summed_term = 0.0 125 | 126 | for ii in range(self.encoder.topk): 127 | # get categories to be summed 128 | possible_z = topk_domain[:, ii] 129 | 130 | decoder_output = self.decoder( 131 | possible_z, decoder_input) 132 | 133 | loss_sum_term, logs = self.loss( 134 | encoder_input, 135 | possible_z, 136 | decoder_input, 137 | decoder_output, 138 | labels) 139 | 140 | if self.encoder.baseline_type == 'runavg': 141 | baseline = self.mean_baseline 142 | elif self.encoder.baseline_type == 'sample': 143 | alt_z_sample = Categorical(logits=encoder_log_prob).sample().detach() 144 | decoder_output = self.decoder(alt_z_sample, decoder_input) 145 | baseline, _ = self.loss( 146 | encoder_input, 147 | alt_z_sample, 148 | decoder_input, 149 | decoder_output, 150 | labels) 151 | else: 152 | baseline = 0. 
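                # surrogate term: differentiating
                #   (loss.detach() - baseline) * log q(z_k | x) + loss
                # gives the score-function gradient for the encoder plus the
                # usual gradient of the loss w.r.t. the decoder parameters.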
153 | 154 | # get log class probabilities 155 | encoder_log_prob_i = encoder_log_prob[seq_tensor, possible_z] 156 | # compute gradient estimate 157 | grad_estimate_loss = \ 158 | (loss_sum_term.detach() - baseline) * encoder_log_prob_i + \ 159 | loss_sum_term 160 | # sum 161 | summed_weights = encoder_prob[seq_tensor, possible_z].squeeze() 162 | summed_term = summed_term + (grad_estimate_loss * summed_weights) 163 | 164 | if self.training and self.encoder.baseline_type == 'runavg': 165 | self.n_points += 1.0 166 | self.mean_baseline += ( 167 | loss_sum_term.detach().mean() - self.mean_baseline 168 | ) / self.n_points 169 | 170 | # only compute argmax for training log 171 | if ii == 0: 172 | # save this log in a different variable 173 | train_logs = logs 174 | for k, v in train_logs.items(): 175 | if hasattr(v, 'mean'): 176 | train_logs[k] = v.mean() 177 | 178 | ############################ 179 | # compute sampled term 180 | sampled_weight = torch.sum( 181 | encoder_prob * (1 - concentrated_mask), 182 | dim=1, 183 | keepdim=True) 184 | 185 | if not(self.encoder.topk == encoder_prob.shape[1]): 186 | # if we didn't sum everything 187 | # we sample from the remaining terms 188 | 189 | # class weights conditioned on being in the diffuse set 190 | conditional_encoder_prob = (encoder_prob + 1e-12) * \ 191 | (1 - concentrated_mask) / (sampled_weight + 1e-12) 192 | 193 | # sample from conditional distribution 194 | cat_rv = Categorical(probs=conditional_encoder_prob) 195 | conditional_z_sample = cat_rv.sample().detach() 196 | 197 | decoder_output = self.decoder( 198 | conditional_z_sample, decoder_input) 199 | 200 | loss_sum_term, _ = self.loss( 201 | encoder_input, 202 | conditional_z_sample, 203 | decoder_input, 204 | decoder_output, 205 | labels) 206 | 207 | if self.encoder.baseline_type == 'runavg': 208 | baseline = self.mean_baseline 209 | if self.encoder.baseline_type == 'sample': 210 | alt_z_sample = Categorical(logits=encoder_log_prob).sample().detach() 211 | decoder_output = self.decoder(alt_z_sample, decoder_input) 212 | baseline, _ = self.loss( 213 | encoder_input, 214 | alt_z_sample, 215 | decoder_input, 216 | decoder_output, 217 | labels) 218 | else: 219 | baseline = 0. 
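            # same surrogate as in the summed term above, now applied to the single
            # sample drawn from the leftover (non-top-k) probability mass; its
            # contribution is weighted by `sampled_weight` further below.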
220 | 221 | # get log class probabilities 222 | encoder_log_prob_i = encoder_log_prob[seq_tensor, conditional_z_sample] 223 | # compute gradient estimate 224 | grad_estimate_loss_sample = \ 225 | (loss_sum_term.detach() - baseline) * encoder_log_prob_i + \ 226 | loss_sum_term 227 | 228 | if self.training and self.encoder.baseline_type == 'runavg': 229 | self.n_points += 1.0 230 | self.mean_baseline += ( 231 | loss_sum_term.detach().mean() - self.mean_baseline 232 | ) / self.n_points 233 | else: 234 | grad_estimate_loss_sample = 0.0 235 | 236 | loss = grad_estimate_loss_sample * sampled_weight.squeeze() + summed_term 237 | 238 | # restore the log of argmax 239 | logs = train_logs 240 | 241 | with torch.no_grad(): 242 | decoder_output = self.decoder(discrete_latent_z, decoder_input) 243 | map_loss, map_logs = self.loss( 244 | encoder_input, 245 | discrete_latent_z, 246 | decoder_input, 247 | decoder_output, 248 | labels) 249 | 250 | for k, v in map_logs.items(): 251 | if hasattr(v, 'mean'): 252 | map_logs[k] = v.mean() 253 | 254 | if not self.training: 255 | loss, logs = map_loss, map_logs 256 | 257 | full_loss = loss.mean() + entropy_loss 258 | 259 | logs['loss'] = map_loss.mean() 260 | logs['encoder_entropy'] = encoder_entropy.mean() 261 | logs['distr'] = encoder_prob 262 | 263 | return {'loss': full_loss, 'log': logs} 264 | -------------------------------------------------------------------------------- /lvmhelpers/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn as nn 3 | 4 | 5 | class DeterministicWrapper(nn.Module): 6 | """ 7 | Simple wrapper that makes a deterministic agent. 8 | No sampling is run on top of the wrapped agent, 9 | it is passed as is. 10 | """ 11 | def __init__(self, agent): 12 | super(DeterministicWrapper, self).__init__() 13 | self.agent = agent 14 | 15 | def forward(self, *args, **kwargs): 16 | out = self.agent(*args, **kwargs) 17 | return out 18 | 19 | 20 | def populate_common_params( 21 | arg_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 22 | 23 | # discrete latent variable training 24 | arg_parser.add_argument("--mode", type=str, default="sfe", 25 | choices=[ 26 | "sfe", "nvil", "gs", "marg", "sumsample", 27 | "topksparse", "sparsemap"], 28 | help="""Method to train the discrete/structured 29 | latent variable model""") 30 | arg_parser.add_argument("--entropy_coeff", type=float, default=1.0, 31 | help="""Entropy loss term coefficient (regularization)""") 32 | arg_parser.add_argument("--latent_size", type=int, default=10, 33 | help="Number of categories (default: 10)") 34 | 35 | # Marginalization 36 | arg_parser.add_argument("--normalizer", type=str, default="softmax", 37 | choices=["softmax", "sparsemax"], 38 | help="""Normalizer to use when parameterizing 39 | a discrete distribution over categories""") 40 | 41 | # Structured Marginalization 42 | arg_parser.add_argument("--topksparse", type=int, default=10, 43 | help="""k in top-k sparsemax. 
Not to be confused with the --topk option 44 | of the sum&sample estimator""") 45 | arg_parser.add_argument("--noinit", action="store_true") 46 | arg_parser.add_argument("--budget", type=int, default=0) 47 | 48 | # Gumbel-Softmax 49 | arg_parser.add_argument("--gs_tau", type=float, default=1.0, 50 | help="GS temperature") 51 | arg_parser.add_argument("--straight_through", action="store_true") 52 | arg_parser.add_argument("--temperature_decay", type=float, default=1e-5, 53 | help="""temperature decay constant for 54 | Gumbel-Softmax (default: 1e-5)""") 55 | arg_parser.add_argument("--temperature_update_freq", type=int, default=1000, 56 | help="""temperature decay frequency for 57 | Gumbel-Softmax, in steps (default: 1000)""") 58 | 59 | # SFE 60 | arg_parser.add_argument("--baseline_type", type=str, default="runavg", 61 | choices=["runavg", "sample"], 62 | help="""baseline to use in SFE. runavg is the running average 63 | and sample is a self-critic baseline""") 64 | 65 | # sum and sample 66 | arg_parser.add_argument("--topk", type=int, default=1, 67 | help="""number of classes summed over 68 | for sum&sample gradient estimator""") 69 | 70 | # random seed 71 | arg_parser.add_argument("--random_seed", type=int, default=42, 72 | help="Set random seed") 73 | 74 | # trainer params 75 | arg_parser.add_argument("--n_epochs", type=int, default=10, 76 | help="Number of epochs to train (default: 10)") 77 | arg_parser.add_argument("--load_from_checkpoint", type=str, default=None, 78 | help="""If the parameter is set, model, 79 | trainer, and optimizer states are loaded from the 80 | checkpoint (default: None)""") 81 | 82 | # cuda setup 83 | arg_parser.add_argument("--no_cuda", default=False, help="disable cuda", 84 | action="store_true") 85 | 86 | # dataset 87 | arg_parser.add_argument("--batch_size", type=int, default=32, 88 | help="Input batch size for training (default: 32)") 89 | 90 | # optimizer 91 | arg_parser.add_argument("--optimizer", type=str, default="adam", 92 | choices=["adam", "sgd", "adagrad"], 93 | help="Optimizer to use [adam, sgd, adagrad] (default: adam)") 94 | arg_parser.add_argument("--lr", type=float, default=1e-3, 95 | help="Learning rate (default: 1e-3)") 96 | arg_parser.add_argument("--weight_decay", type=float, default=1e-5, 97 | help="L2 regularization constant (default: 1e-5)") 98 | 99 | return arg_parser 100 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython 2 | entmax 3 | h5py 4 | numpy 5 | pytorch-lightning==0.9.0 6 | tensorboard==2.2.0 7 | torch==1.6.0 8 | torchvision==0.7.0 9 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_gs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode gs \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. 
\ 10 | --temperature_decay 1e-5 \ 11 | --temperature_update_freq 1000 \ 12 | --random_seed 42 13 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_gs_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode gs \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. \ 10 | --temperature_decay 1e-4 \ 11 | --temperature_update_freq 1000 \ 12 | --random_seed 42 13 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_gs_st.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode gs \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. \ 10 | --temperature_decay 1e-5 \ 11 | --temperature_update_freq 1000 \ 12 | --straight_through \ 13 | --random_seed 42 14 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_gs_st_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode gs \ 5 | --lr 0.0005 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. \ 10 | --temperature_decay 1e-5 \ 11 | --temperature_update_freq 1000 \ 12 | --straight_through \ 13 | --random_seed 42 14 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_nvil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode nvil \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. 10 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_nvil_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode nvil \ 5 | --lr 0.0005 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. 10 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sfe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sfe \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. \ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sfe_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sfe \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. 
\ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sfe_plus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sfe \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --baseline_type sample \ 10 | --weight_decay 0. 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sfe_plus_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sfe \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --baseline_type sample \ 10 | --weight_decay 0. 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sparsemap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sparsemap \ 5 | --lr 0.002 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. \ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sparsemap_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sparsemap \ 5 | --lr 0.002 \ 6 | --batch_size 16 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. \ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sparsemap_budget.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sparsemap \ 5 | --lr 0.002 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. \ 10 | --budget 16 \ 11 | --random_seed 42 12 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_sparsemap_budget_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode sparsemap \ 5 | --lr 0.002 \ 6 | --batch_size 16 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. \ 10 | --budget 64 \ 11 | --random_seed 42 12 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_topksparse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode topksparse \ 5 | --lr 0.002 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 32 \ 9 | --weight_decay 0. \ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/bit_vector/bit_vector_topksparse_128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/bit_vector-vae/train.py \ 4 | --mode topksparse \ 5 | --lr 0.001 \ 6 | --batch_size 64 \ 7 | --n_epochs 100 \ 8 | --latent_size 128 \ 9 | --weight_decay 0. 
\ 10 | --random_seed 42 11 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_gs_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode gs \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.01 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --weight_decay 0. \ 17 | --temperature_decay 1e-5 \ 18 | --temperature_update_freq 1000 \ 19 | --random_seed ${seed} 20 | done 21 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_gs_st_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode gs \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.01 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --weight_decay 0. \ 17 | --temperature_decay 1e-5 \ 18 | --temperature_update_freq 1000 \ 19 | --straight_through \ 20 | --random_seed ${seed} 21 | done 22 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_marg_softmax_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode marg \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.1 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --normalizer softmax \ 17 | --weight_decay 0. \ 18 | --random_seed ${seed} 19 | done 20 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_marg_sparsemax_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode marg \ 8 | --lr 0.005 \ 9 | --entropy_coeff 0.1 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --normalizer sparsemax \ 17 | --weight_decay 0. \ 18 | --random_seed ${seed} 19 | done 20 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_nvil_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode nvil \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.1 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --loss_type nll \ 17 | --weight_decay 0. 
\ 18 | --random_seed ${seed} 19 | done 20 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_sfe_nll_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode sfe \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.05 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --loss_type nll \ 17 | --baseline_type runavg \ 18 | --weight_decay 0. \ 19 | --random_seed ${seed} 20 | done 21 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_sfe_plus_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode sfe \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.05 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --loss_type acc \ 17 | --baseline_type sample \ 18 | --weight_decay 0. \ 19 | --random_seed ${seed} 20 | done 21 | -------------------------------------------------------------------------------- /scripts/signal_game/signal_game_sfe_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 321)) 6 | python experiments/signal-game/train.py \ 7 | --mode sfe \ 8 | --lr 0.001 \ 9 | --entropy_coeff 0.01 \ 10 | --batch_size 64 \ 11 | --n_epochs 500 \ 12 | --game_size 16 \ 13 | --latent_size 256 \ 14 | --embedding_size 256 \ 15 | --hidden_size 512 \ 16 | --loss_type acc \ 17 | --baseline_type runavg \ 18 | --weight_decay 0. 
\ 19 | --random_seed ${seed} 20 | done 21 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_gumbel_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode gs \ 8 | --random_seed ${seed} \ 9 | --lr 0.001 \ 10 | --batch_size 64 \ 11 | --n_epochs 200 \ 12 | --latent_size 10 \ 13 | --temperature_decay 0.0001 \ 14 | --temperature_update_freq 1000 \ 15 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 16 | done 17 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_gumbel_st_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode gs \ 8 | --random_seed ${seed} \ 9 | --lr 0.001 \ 10 | --batch_size 64 \ 11 | --n_epochs 200 \ 12 | --latent_size 10 \ 13 | --temperature_decay 0.0001 \ 14 | --temperature_update_freq 1000 \ 15 | --straight_through \ 16 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 17 | done 18 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_marg_softmax_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode marg \ 8 | --normalizer softmax \ 9 | --random_seed ${seed} \ 10 | --lr 0.001 \ 11 | --batch_size 64 \ 12 | --n_epochs 200 \ 13 | --latent_size 10 \ 14 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 15 | done 16 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_marg_sparsemax_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode marg \ 8 | --normalizer sparsemax \ 9 | --random_seed ${seed} \ 10 | --lr 0.0005 \ 11 | --batch_size 64 \ 12 | --n_epochs 200 \ 13 | --latent_size 10 \ 14 | --warm_start_path checkpoints/ssvae/warm_start/sparsemax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=99.ckpt 15 | done 16 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_nvil_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode nvil \ 8 | --random_seed ${seed} \ 9 | --lr 0.001 \ 10 | --batch_size 64 \ 11 | --n_epochs 200 \ 12 | --latent_size 10 \ 13 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 14 | done 15 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_sfe_plus_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python 
experiments/semi_supervised-vae/train.py \ 7 | --mode sfe \ 8 | --random_seed ${seed} \ 9 | --lr 0.0001 \ 10 | --batch_size 64 \ 11 | --n_epochs 200 \ 12 | --latent_size 10 \ 13 | --baseline_type sample \ 14 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 15 | done 16 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_sfe_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode sfe \ 8 | --random_seed ${seed} \ 9 | --lr 0.0001 \ 10 | --batch_size 64 \ 11 | --n_epochs 200 \ 12 | --latent_size 10 \ 13 | --baseline_type runavg \ 14 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 15 | done 16 | -------------------------------------------------------------------------------- /scripts/ssvae/ssvae_sumsample_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {1..10} 4 | do 5 | ((seed=$i + 42)) 6 | python experiments/semi_supervised-vae/train.py \ 7 | --mode sumsample \ 8 | --topk 1 \ 9 | --random_seed ${seed} \ 10 | --lr 0.001 \ 11 | --batch_size 64 \ 12 | --n_epochs 200 \ 13 | --latent_size 10 \ 14 | --warm_start_path checkpoints/ssvae/warm_start/softmax/lr-0.001_baseline-runavg/version_0/checkpoints/epoch\=91.ckpt 15 | done 16 | -------------------------------------------------------------------------------- /scripts/ssvae_warm_start_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/semi_supervised-vae/train.py \ 4 | --n_epochs 100 \ 5 | --lr 1e-3 \ 6 | --labeled_only \ 7 | --normalizer softmax \ 8 | --batch_size 64 9 | 10 | -------------------------------------------------------------------------------- /scripts/ssvae_warm_start_sparsemax.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments/semi_supervised-vae/train.py \ 4 | --n_epochs 100 \ 5 | --lr 1e-3 \ 6 | --labeled_only \ 7 | --normalizer sparsemax \ 8 | --batch_size 64 9 | 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from setuptools import setup, find_packages 5 | from setuptools.extension import Extension 6 | from setuptools.command.build_ext import build_ext 7 | from distutils.command.clean import clean 8 | 9 | from Cython.Build import cythonize 10 | 11 | from lpsmap import config 12 | 13 | 14 | AD3_FLAGS_UNIX = [ 15 | '-std=c++11', 16 | '-O3', 17 | '-Wall', 18 | '-Wno-sign-compare', 19 | '-Wno-overloaded-virtual', 20 | '-c', 21 | '-fmessage-length=0', 22 | '-fPIC', 23 | '-ffast-math', 24 | '-march=native' 25 | ] 26 | 27 | 28 | AD3_FLAGS_MSVC = [ 29 | '/O2', 30 | '/fp:fast', 31 | '/favor:INTEL64', 32 | '/wd4267' # suppress sign-compare--like warning 33 | ] 34 | 35 | 36 | AD3_CFLAGS = { 37 | 'cygwin': AD3_FLAGS_UNIX, 38 | 'mingw32': AD3_FLAGS_UNIX, 39 | 'unix': AD3_FLAGS_UNIX, 40 | 'msvc': AD3_FLAGS_MSVC 41 | } 42 | 43 | 44 | # support compiler-specific cflags in extensions and libs 45 | class our_build_ext(build_ext): 46 | def build_extensions(self): 47 | 48 | # bug in distutils: 
flag not valid for c++ 49 | flag = '-Wstrict-prototypes' 50 | if (hasattr(self.compiler, 'compiler_so') 51 | and flag in self.compiler.compiler_so): 52 | self.compiler.compiler_so.remove(flag) 53 | 54 | compiler_type = self.compiler.compiler_type 55 | compile_args = AD3_CFLAGS.get(compiler_type, []) 56 | 57 | for e in self.extensions: 58 | e.extra_compile_args.extend(compile_args) 59 | 60 | build_ext.build_extensions(self) 61 | 62 | 63 | class our_clean(clean): 64 | def run(self): 65 | 66 | if os.path.exists('build'): 67 | shutil.rmtree('build') 68 | 69 | for dirpath, dirnames, filenames in os.walk('.'): 70 | for filename in filenames: 71 | if any(filename.endswith(suffix) for suffix in 72 | (".so", ".pyd", ".dll", ".pyc")): 73 | os.unlink(os.path.join(dirpath, filename)) 74 | continue 75 | extension = os.path.splitext(filename)[1] 76 | if extension in ['.c', '.cpp']: 77 | pyx_file = str.replace(filename, extension, '.pyx') 78 | if os.path.exists(os.path.join(dirpath, pyx_file)): 79 | os.unlink(os.path.join(dirpath, filename)) 80 | for dirname in dirnames: 81 | if dirname == '__pycache__': 82 | shutil.rmtree(os.path.join(dirpath, dirname)) 83 | clean.run(self) 84 | 85 | 86 | # this is a backport of a workaround for a problem in distutils. 87 | 88 | cmdclass = {'build_ext': our_build_ext, 89 | 'clean': our_clean} 90 | 91 | 92 | extensions = [ 93 | Extension('lvmhelpers.pbinary_topk', 94 | ["lvmhelpers/pbinary_topk.pyx"]), 95 | Extension('lvmhelpers.pbernoulli', 96 | ["lvmhelpers/pbernoulli.pyx"], 97 | libraries=['ad3'], 98 | library_dirs=[config.get_libdir()], 99 | include_dirs=[config.get_include()]), 100 | Extension('lvmhelpers.pbudget', 101 | ["lvmhelpers/pbudget.pyx"], 102 | libraries=['ad3'], 103 | library_dirs=[config.get_libdir()], 104 | include_dirs=[config.get_include()]), 105 | Extension('lvmhelpers.psequence', 106 | ["lvmhelpers/psequence.pyx"], 107 | libraries=['ad3'], 108 | library_dirs=[config.get_libdir()], 109 | include_dirs=[config.get_include()]), 110 | ] 111 | 112 | with open('requirements.txt') as f: 113 | requirements = f.read().splitlines() 114 | 115 | setup( 116 | name='lvmhelpers', 117 | version='0.3', 118 | url='https://github.com/goncalomcorreia/explicit-sparse-marginalization', 119 | author='Gonçalo Correia', 120 | author_email='goncalommac@gmail.com', 121 | packages=find_packages(), 122 | install_requires=requirements, 123 | cmdclass=cmdclass, 124 | include_package_data=True, 125 | ext_modules=cythonize(extensions) 126 | ) 127 | --------------------------------------------------------------------------------
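
Two usage notes on the files above (these notes are editorial additions, not part of the repository). First, `setup.py` imports `config` from `lpsmap` (presumably the LP-SparseMAP library) to locate the AD3 headers and libraries, so that package must be importable before the Cython extensions in `lvmhelpers` can be built, even though it is not listed in `requirements.txt`. Second, the sketch below illustrates how the shared command-line options defined in `lvmhelpers/utils.py` could be wired into an experiment entry point; the `get_opts` helper and the parser description are illustrative assumptions, and the actual wiring lives in each `experiments/*/opts.py` and `train.py`.

    import argparse

    from lvmhelpers.utils import populate_common_params


    def get_opts():
        # start from a bare parser and add the shared discrete-latent-variable
        # options (--mode, --normalizer, --lr, --latent_size, --baseline_type, ...)
        parser = argparse.ArgumentParser(description="example experiment")  # description is hypothetical
        parser = populate_common_params(parser)
        # experiment-specific flags (e.g. --game_size, --embedding_size, --hidden_size)
        # would be registered here, as each experiments/*/opts.py does
        return parser.parse_args()


    if __name__ == "__main__":
        opts = get_opts()
        print(opts.mode, opts.normalizer, opts.latent_size)

Invoking such a stub with `--mode marg --normalizer sparsemax --latent_size 256` mirrors the configuration used by `scripts/signal_game/signal_game_marg_sparsemax_seeds.sh`, minus the experiment-specific architecture flags.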