├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
└── mine
    ├── __init__.py
    ├── ib.py
    ├── mine.py
    ├── models.py
    └── utils
        ├── __init__.py
        ├── data_loader.py
        ├── log.py
        └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # temp
2 | *__pycache__
3 | *.ipynb_checkpoints/
4 | # datasets
5 | data
6 | # log outputs
7 | logs
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021 Mohith Sakthivel
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mutual Information Neural Estimation
2 |
3 | This repository contains a PyTorch implementation of the Information Bottleneck (IB) using Mutual Information Neural Estimation (MINE). [[Belghazi et al., 2018]](#references)
4 |
5 | A standard baseline MLP (as described in the Deep VIB paper [[Alemi et al., 2017]](#references)) has been used for comparison.
6 |
7 | ## Setup
8 |
9 | ```
10 | git clone https://github.com/mohith-sakthivel/mine-pytorch.git mine
11 | cd mine
12 |
13 | conda env create -f environment.yml
14 | conda activate mine
15 | ```
16 |
17 | ## Run
18 | * To run the baseline model with default parameters:
19 | ```
20 | python3 -m mine.ib --deter
21 | ```
22 |
23 | The baseline model is a standard MLP with 3 hidden layers and ReLU non-linearity. During training, an exponentially weighted average of the parameters is maintained, and these averaged parameters are used at test time.
24 |
25 |
26 | * To run the MINE+IB model:
27 | ```
28 | python3 -m mine.ib --mine
29 | ```
30 |
31 | ## Note
32 | This repo contains an implementation of MINE for information minimization only. For information maximization, you should also incorporate adaptive gradient clipping as mentioned in [Belghazi et al.](#references). This is because MI is unbounded in typical high-dimensional use cases, and hence gradients from the MI estimate can overwhelm gradients from the primary loss.
33 |
34 |
35 |
36 | ## References
37 | 1. M I Belghazi, A Baratin, S Rajeswar, S Ozair, Y Bengio, A Courville, R D Hjelm - MINE: Mutual Information Neural Estimation, ICML, 2018. ([paper](https://arxiv.org/abs/1801.04062))
38 |
39 | 2.
A A Alemi, I Fischer, J V Dillon, K Murphy - Deep Variational Information Bottleneck, ICLR, 2017. ([paper](https://arxiv.org/abs/1612.00410)) 40 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mine 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - anyio=2.2.0=py38h578d9bd_0 9 | - argon2-cffi=20.1.0=py38h25fe258_2 10 | - async_generator=1.10=py_0 11 | - attrs=20.3.0=pyhd3deb0d_0 12 | - babel=2.9.0=pyhd3deb0d_0 13 | - backcall=0.2.0=pyh9f0ad1d_0 14 | - backports=1.0=py_2 15 | - backports.functools_lru_cache=1.6.1=py_0 16 | - blas=1.0=mkl 17 | - bleach=3.3.0=pyh44b312d_0 18 | - brotlipy=0.7.0=py38h8df0ef7_1001 19 | - bzip2=1.0.8=h7b6447c_0 20 | - ca-certificates=2020.12.5=ha878542_0 21 | - certifi=2020.12.5=py38h578d9bd_1 22 | - cffi=1.14.5=py38h261ae71_0 23 | - chardet=4.0.0=py38h578d9bd_1 24 | - cryptography=2.9.2=py38h766eaa4_0 25 | - cudatoolkit=10.2.89=hfd86e86_1 26 | - decorator=4.4.2=py_0 27 | - defusedxml=0.7.1=pyhd8ed1ab_0 28 | - entrypoints=0.3=pyhd8ed1ab_1003 29 | - ffmpeg=4.3=hf484d3e_0 30 | - freetype=2.10.4=h5ab3b9f_0 31 | - gmp=6.2.1=h2531618_2 32 | - gnutls=3.6.5=h71b1129_1002 33 | - idna=2.10=pyh9f0ad1d_0 34 | - importlib-metadata=3.7.2=py38h578d9bd_0 35 | - intel-openmp=2020.2=254 36 | - ipykernel=5.5.0=py38h81c977d_1 37 | - ipython=7.21.0=py38h81c977d_0 38 | - ipython_genutils=0.2.0=py_1 39 | - jedi=0.18.0=py38h578d9bd_2 40 | - jinja2=2.11.3=pyh44b312d_0 41 | - jpeg=9b=h024ee3a_2 42 | - json5=0.9.5=pyh9f0ad1d_0 43 | - jsonschema=3.2.0=pyhd8ed1ab_3 44 | - jupyter-packaging=0.7.12=pyhd8ed1ab_0 45 | - jupyter_client=6.1.11=pyhd8ed1ab_1 46 | - jupyter_core=4.7.1=py38h578d9bd_0 47 | - jupyter_server=1.4.1=py38h578d9bd_0 48 | - jupyterlab=3.0.10=pyhd8ed1ab_0 49 | - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 50 | - jupyterlab_server=2.3.0=pyhd8ed1ab_0 51 | - lame=3.100=h7b6447c_0 52 | - lcms2=2.11=h396b838_0 53 | - ld_impl_linux-64=2.33.1=h53a641e_7 54 | - libedit=3.1.20191231=h14c3975_1 55 | - libffi=3.3=he6710b0_2 56 | - libgcc-ng=9.1.0=hdf63c60_0 57 | - libiconv=1.15=h63c8f33_5 58 | - libpng=1.6.37=hbc83047_0 59 | - libsodium=1.0.18=h36c2ea0_1 60 | - libstdcxx-ng=9.1.0=hdf63c60_0 61 | - libtiff=4.1.0=h2733197_1 62 | - libuv=1.40.0=h7b6447c_0 63 | - lz4-c=1.9.3=h2531618_0 64 | - markupsafe=1.1.1=py38h8df0ef7_2 65 | - mistune=0.8.4=py38h25fe258_1002 66 | - mkl=2020.2=256 67 | - mkl-service=2.3.0=py38he904b0f_0 68 | - mkl_fft=1.3.0=py38h54f3939_0 69 | - mkl_random=1.1.1=py38h0573a6f_0 70 | - nbclassic=0.2.6=pyhd8ed1ab_0 71 | - nbclient=0.5.3=pyhd8ed1ab_0 72 | - nbconvert=6.0.7=py38h578d9bd_3 73 | - nbformat=5.1.2=pyhd8ed1ab_1 74 | - ncurses=6.2=he6710b0_1 75 | - nest-asyncio=1.4.3=pyhd8ed1ab_0 76 | - nettle=3.4.1=hbb512f6_0 77 | - ninja=1.10.2=py38hff7bd54_0 78 | - notebook=6.2.0=py38h578d9bd_0 79 | - numpy=1.19.2=py38h54aff64_0 80 | - numpy-base=1.19.2=py38hfa32c7d_0 81 | - olefile=0.46=py_0 82 | - openh264=2.1.0=hd408876_0 83 | - openssl=1.1.1j=h27cfd23_0 84 | - packaging=20.9=pyh44b312d_0 85 | - pandoc=2.12=h7f98852_0 86 | - pandocfilters=1.4.2=py_1 87 | - parso=0.8.1=pyhd8ed1ab_0 88 | - pexpect=4.8.0=pyh9f0ad1d_2 89 | - pickleshare=0.7.5=py_1003 90 | - pillow=8.1.2=py38he98fc37_0 91 | - pip=21.0.1=py38h06a4308_0 92 | - prometheus_client=0.9.0=pyhd3deb0d_0 93 | - prompt-toolkit=3.0.16=pyha770c72_0 94 | - ptyprocess=0.7.0=pyhd3deb0d_0 95 | - pycparser=2.20=pyh9f0ad1d_2 96 | - 
pygments=2.8.1=pyhd8ed1ab_0 97 | - pyopenssl=19.1.0=py38_0 98 | - pyparsing=2.4.7=pyh9f0ad1d_0 99 | - pyrsistent=0.17.3=py38h25fe258_1 100 | - pysocks=1.7.1=py38h578d9bd_3 101 | - python=3.8.5=h7579374_1 102 | - python-dateutil=2.8.1=py_0 103 | - python_abi=3.8=1_cp38 104 | - pytorch=1.8.0=py3.8_cuda10.2_cudnn7.6.5_0 105 | - pytz=2021.1=pyhd8ed1ab_0 106 | - pyzmq=19.0.2=py38ha71036d_2 107 | - readline=8.1=h27cfd23_0 108 | - requests=2.25.1=pyhd3deb0d_0 109 | - send2trash=1.5.0=py_0 110 | - setuptools=52.0.0=py38h06a4308_0 111 | - six=1.15.0=py38h06a4308_0 112 | - sniffio=1.2.0=py38h578d9bd_1 113 | - sqlite=3.33.0=h62c20be_0 114 | - terminado=0.9.2=py38h578d9bd_0 115 | - testpath=0.4.4=py_0 116 | - tk=8.6.10=hbc83047_0 117 | - torchvision=0.9.0=py38_cu102 118 | - tornado=6.1=py38h25fe258_0 119 | - traitlets=5.0.5=py_0 120 | - typing_extensions=3.7.4.3=pyha847dfd_0 121 | - urllib3=1.26.3=pyhd8ed1ab_0 122 | - wcwidth=0.2.5=pyh9f0ad1d_2 123 | - webencodings=0.5.1=py_1 124 | - wheel=0.36.2=pyhd3eb1b0_0 125 | - xz=5.2.5=h7b6447c_0 126 | - zeromq=4.3.3=he6710b0_3 127 | - zipp=3.4.1=pyhd8ed1ab_0 128 | - zlib=1.2.11=h7b6447c_3 129 | - zstd=1.4.5=h9ceee32_0 130 | - pip: 131 | - absl-py==0.11.0 132 | - aiohttp==3.7.4.post0 133 | - appdirs==1.4.4 134 | - async-timeout==3.0.1 135 | - autopep8==1.5.5 136 | - cachetools==4.2.1 137 | - distlib==0.3.1 138 | - filelock==3.0.12 139 | - fsspec==0.8.7 140 | - future==0.18.2 141 | - google-auth==1.27.1 142 | - google-auth-oauthlib==0.4.3 143 | - grpcio==1.36.1 144 | - jupyter-http-over-ws==0.0.8 145 | - markdown==3.3.4 146 | - multidict==5.1.0 147 | - oauthlib==3.1.0 148 | - protobuf==3.15.5 149 | - pyasn1==0.4.8 150 | - pyasn1-modules==0.2.8 151 | - pycodestyle==2.6.0 152 | - pyyaml==5.3.1 153 | - requests-oauthlib==1.3.0 154 | - rsa==4.7.2 155 | - tensorboard==2.4.1 156 | - tensorboard-plugin-wit==1.8.0 157 | - toml==0.10.2 158 | - tqdm==4.59.0 159 | - werkzeug==1.0.1 160 | - yarl==1.6.3 161 | prefix: /home/mohith/miniconda3/envs/mine 162 | -------------------------------------------------------------------------------- /mine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mohith-sakthivel/mine-pytorch/eedf0103a05837836f1e97747d6e3b2502edc5f1/mine/__init__.py -------------------------------------------------------------------------------- /mine/ib.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import tqdm 4 | import yaml 5 | import pathlib 6 | import argparse 7 | import itertools 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.distributions as dist 13 | import torch.nn.functional as F 14 | 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | 18 | 19 | from mine.mine import get_estimator 20 | from mine.models import MLP 21 | from mine.utils.log import Logger 22 | from mine.utils.train import PolyakAveraging, BetaScheduler 23 | from mine.utils.data_loader import CustomSampler, CustomBatchSampler 24 | 25 | 26 | class Classifier(nn.Module): 27 | """ 28 | Multi-label classifier with cross entropy loss 29 | """ 30 | 31 | def __init__(self, base_net, K, lr=1e-4, base_net_args={}, use_polyak=True, logdir='.'): 32 | super().__init__() 33 | self._K = K 34 | self._lr = lr 35 | self._use_polyak = use_polyak 36 | self.logdir = logdir 37 | 38 | self._base_net = base_net( 39 | input_dim=28*28, output_dim=K, **base_net_args) 40 
| self._logits = nn.Linear(K, 10) 41 | self._model_list = [self._base_net, self._logits] 42 | self._current_epoch = 0 43 | self._initialize_weights() 44 | self._configure_optimizers() 45 | self._configure_callbacks() 46 | 47 | def _initialize_weights(self): 48 | for (name, param) in self.named_parameters(): 49 | if 'weight' in name: 50 | nn.init.xavier_uniform_(param) 51 | elif 'bias' in name: 52 | nn.init.zeros_(param) 53 | else: 54 | raise ValueError 55 | 56 | def _configure_optimizers(self): 57 | optimizer = optim.Adam( 58 | self.get_model_parameters(), lr=self._lr, betas=(0.5, 0.999)) 59 | scheduler = { 60 | 'scheduler': optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97), 61 | 'frequency': 2 62 | } 63 | self.optimizers, self.schedulers = ([optimizer], [scheduler]) 64 | 65 | @property 66 | def device(self): 67 | return self._logits.weight.device 68 | 69 | def _configure_callbacks(self): 70 | self._callbacks = [] 71 | if self._use_polyak: 72 | self._callbacks.append(PolyakAveraging()) 73 | 74 | def invoke_callback(self, hook): 75 | for callback in self._callbacks: 76 | if hasattr(callback, hook): 77 | func = getattr(callback, hook) 78 | func(self) 79 | 80 | def get_model_parameters(self): 81 | return itertools.chain(*[module.parameters() for module in self._model_list]) 82 | 83 | def step_epoch(self): 84 | self._current_epoch += 1 85 | 86 | @staticmethod 87 | def _get_grad_norm(params, device): 88 | total_grad = torch.zeros([], device=device) 89 | for param in params: 90 | total_grad += param.grad.data.norm().square() 91 | return total_grad.sqrt() 92 | 93 | def _unpack_batch(self, batch): 94 | batch = [item.to(self.device) for item in batch] 95 | x, y = batch 96 | x = x.view(x.shape[0], -1) 97 | return (x, y) 98 | 99 | def _get_embedding(self, x, mc_samples=1): 100 | z = self._base_net(x) 101 | if self._base_net.is_stochastic(): 102 | mean, std = z 103 | z = dist.Independent(dist.Normal(mean, std), 104 | 1).rsample([mc_samples]) 105 | else: 106 | z = z.unsqueeze(dim=0) 107 | return z 108 | 109 | def forward(self, x, mc_samples=1): 110 | x = self._get_embedding(x, mc_samples=mc_samples) 111 | return self._logits(x).mean(dim=0) 112 | 113 | def _get_eval_stats(self, batch, batch_idx, mc_samples=1): 114 | stats = {} 115 | x, y = self._unpack_batch(batch) 116 | y_pred = self(x, mc_samples) 117 | y_pred = torch.argmax(y_pred, dim=1) 118 | stats['error'] = torch.sum(y != y_pred)/len(y)*100 119 | return stats 120 | 121 | def training_step(self, batch, batch_idx, logger): 122 | opt = self.optimizers[0] 123 | opt.zero_grad() 124 | x, y = self._unpack_batch(batch) 125 | z = self._get_embedding(x).mean(dim=0) 126 | logits = self._logits(z) 127 | loss = F.cross_entropy(logits, y) 128 | stats = {'loss': loss.detach().cpu().numpy()} 129 | logger.scalar(stats['loss'], 'cross_ent', 130 | accumulator='train', progbar=True) 131 | loss.backward() 132 | opt.step() 133 | grad_norm = self._get_grad_norm( 134 | self.get_model_parameters(), self.device) 135 | logger.scalar(grad_norm, 'model_grad_norm', accumulator='train') 136 | return stats 137 | 138 | def evaluation_step(self, batch, batch_idx, logger, mc_samples=1, tag='error', accumulator='test'): 139 | stats = self._get_eval_stats(batch, batch_idx, mc_samples) 140 | logger.scalar(stats['error'], tag, accumulator=accumulator) 141 | 142 | 143 | class MINE_Classifier(Classifier): 144 | """ 145 | Classifier that uses Information Bottleneck (IB) regularization using MINE 146 | """ 147 | 148 | def __init__(self, base_net, K, beta=1e-3, mine_args={}, 
**kwargs): 149 | super().__init__(base_net, K, **kwargs) 150 | self._mine = get_estimator(28*28, K, mine_args) 151 | self._beta = BetaScheduler( 152 | 0, beta, 0) if isinstance(beta, (int, float)) else beta 153 | self._configure_mine_optimizers() 154 | 155 | def _configure_mine_optimizers(self): 156 | opt, sch = self._mine._configure_optimizers() 157 | self.optimizers.append(opt) 158 | if sch is not None: 159 | self.schedulers.append(sch) 160 | 161 | def step_epoch(self): 162 | super().step_epoch() 163 | self._mine.step_epoch() 164 | 165 | def _unpack_margin_batch(self, batch): 166 | batch = [item.to(self.device) for item in batch] 167 | x, y = batch 168 | x = x.view(x.shape[0], -1) 169 | x, x_margin = torch.chunk(x, chunks=2, dim=0) 170 | y, y_margin = torch.chunk(y, chunks=2, dim=0) 171 | return x, y, x_margin, y_margin 172 | 173 | def _get_train_embedding(self, x): 174 | z = self._base_net(x) 175 | z_dist = None 176 | if self._base_net.is_stochastic(): 177 | mean, std = z 178 | z_dist = dist.Independent(dist.Normal(mean, std), 1) 179 | z = z_dist.rsample() 180 | return z, z_dist 181 | 182 | def model_train(self, x, y, x_margin, opt, logger): 183 | """ Train classifier """ 184 | opt.zero_grad() 185 | # calculate loss 186 | z, z_dist = self._get_train_embedding(x) 187 | self._cache = {'z': z.detach()} # cache z for MINE loss calculation 188 | if x_margin is not None: 189 | z_margin, _ = self._get_train_embedding(x_margin) 190 | self._cache['z_margin'] = z_margin.detach() 191 | else: 192 | self._cache['z_margin'] = z_margin = None 193 | logits = self._logits(z) 194 | cross_entropy = F.cross_entropy(logits, y) 195 | mi_xz = self._mine.get_mi_bound(x, z, z_margin, update_ema=True) 196 | loss = cross_entropy + self._beta.get(self._current_epoch) * mi_xz 197 | loss /= math.log(2) 198 | # log train stats 199 | if z_dist is not None: 200 | logger.scalar(z_dist.entropy().mean(), 'z_post_ent', 201 | accumulator='train', progbar=True) 202 | logger.scalar(z_dist.stddev.mean(), 'z_post_std_dev', 203 | accumulator='train', progbar=False) 204 | logger.scalar(cross_entropy, 'cross_ent', 205 | accumulator='train', progbar=True) 206 | logger.scalar(mi_xz, 'mi_xz', accumulator='train', progbar=True) 207 | logger.scalar(loss, 'total_loss', 208 | accumulator='train', progbar=False) 209 | # step optimizer 210 | loss.backward() 211 | opt.step() 212 | grad_norm = self._get_grad_norm( 213 | self.get_model_parameters(), self.device) 214 | logger.scalar(grad_norm, 'model_grad_norm', accumulator='train') 215 | 216 | def mine_train(self, x, z, z_margin, opt, logger): 217 | opt.zero_grad() 218 | # calculate loss 219 | loss = -self._mine.get_mi_bound(x, z, z_margin, update_ema=True) 220 | # log stats 221 | logger.scalar(loss, 'estimator_loss', 222 | accumulator='train', progbar=True) 223 | # step optimizer 224 | loss.backward() 225 | opt.step() 226 | grad_norm = self._get_grad_norm(self._mine.parameters(), self.device) 227 | logger.scalar(grad_norm, 'mine_grad_norm', accumulator='train') 228 | 229 | def training_step(self, batch, batch_idx, logger, train_mine=False): 230 | # unpack data and optimizers 231 | model_opt, mine_opt = self.optimizers 232 | if self._mine.variant == 'marginal': 233 | x, y, x_margin, _ = self._unpack_margin_batch(batch) 234 | else: 235 | x, y, x_margin = *self._unpack_batch(batch), None 236 | # train model 237 | self.model_train(x, y, x_margin, model_opt, logger) 238 | # train statistics network 239 | if train_mine: 240 | self.mine_train(x, self._cache['z'], self._cache['z_margin'], 241 | 
                            mine_opt, logger)
242 |         self._cache = {}
243 |
244 |     def mine_training_step(self, batch, batch_idx, logger):
245 |         # unpack data and optimizers
246 |         _, mine_opt = self.optimizers
247 |         if self._mine.variant == 'marginal':
248 |             x, _, x_margin, _ = self._unpack_margin_batch(batch)
249 |         else:
250 |             x, _, x_margin = *self._unpack_batch(batch), None
251 |         # get z embeddings from the encoder
252 |         with torch.no_grad():
253 |             z, _ = self._get_train_embedding(x)
254 |             z_margin = None
255 |             if self._mine.variant == 'marginal':
256 |                 z_margin, _ = self._get_train_embedding(x_margin)
257 |         # train statistics network
258 |         self.mine_train(x, z, z_margin, mine_opt, logger)
259 |
260 |
261 | MODELS = {
262 |     'deter': Classifier,
263 |     'mine': MINE_Classifier
264 | }
265 |
266 |
267 | def run(args):
268 |
269 |     if args['seed'] is not None:
270 |         torch.manual_seed(args['seed'])
271 |
272 |     Model = MODELS.get(args['model_id'])
273 |
274 |     # setup datasets and dataloaders
275 |     data_transforms = transforms.Compose([transforms.ToTensor(),
276 |                                           transforms.Normalize(0.5, 0.5)])
277 |
278 |     train_dataset = datasets.MNIST(
279 |         './data', train=True, download=True, transform=data_transforms)
280 |     # use a custom dataloader if z marginals are to be calculated over the true marginal
281 |     if args['model_id'] == 'mine' and args['model_args']['mine_args']['variant'] == 'marginal':
282 |         # each batch then contains 2*batch_size samples
283 |         # [batch_size samples (for the joint) + batch_size samples (for the marginals)]
284 |         sampler = CustomSampler(train_dataset, secondary_replacement=False)
285 |         batch_sampler = CustomBatchSampler(sampler,
286 |                                            batch_size=args['batch_size'],
287 |                                            drop_last=False)
288 |         train_loader = torch.utils.data.DataLoader(train_dataset,
289 |                                                    batch_sampler=batch_sampler,
290 |                                                    num_workers=args['workers'])
291 |     else:
292 |         train_loader = torch.utils.data.DataLoader(train_dataset,
293 |                                                    batch_size=args['batch_size'],
294 |                                                    shuffle=True,
295 |                                                    num_workers=args['workers'])
296 |
297 |     test_dataset = datasets.MNIST(
298 |         './data', train=False, download=True, transform=data_transforms)
299 |     test_loader = torch.utils.data.DataLoader(test_dataset,
300 |                                               batch_size=args['batch_size'],
301 |                                               shuffle=False,
302 |                                               num_workers=args['workers'])
303 |
304 |     # setup logging
305 |     logdir = pathlib.Path(args['logdir'])
306 |     time_stamp = time.strftime("%d-%m-%Y_%H:%M:%S")
307 |     logdir = logdir.joinpath(args['model_id'], '_'.join(
308 |         [args['exp_name'], 's{}'.format(args['seed']), time_stamp]))
309 |     logger = Logger(log_dir=logdir)
310 |     # save experiment parameters
311 |     with open(logdir.joinpath('hparams.yaml'), 'w') as out:
312 |         yaml.dump(args, out)
313 |     args['model_args']['logdir'] = logdir
314 |
315 |     model = Model(MLP, **args['model_args'])
316 |     print('Using {}...'.format(args['device']))
317 |     model.to(args['device'])
318 |
319 |     # Training loop
320 |     model.invoke_callback('on_train_start')
321 |     for epoch in tqdm.trange(1, args['epochs']+1, disable=True):
322 |         model.step_epoch()
323 |         model.train(True)
324 |
325 |         for batch_idx, batch in enumerate(tqdm.tqdm(train_loader,
326 |                                                     desc='Model | {}/{} Epochs'.format(
327 |                                                         epoch-1, args['epochs']),
328 |                                                     unit=' batches',
329 |                                                     postfix=logger.get_progbar_desc(),
330 |                                                     leave=False)):
331 |             # Train MINE
332 |             if args['model_id'] == 'mine':
333 |                 for _ in tqdm.trange(args['mine_train_steps'], disable=True):
334 |                     _ = model.mine_training_step(batch, batch_idx, logger)
335 |             # Train Model
336 |             _ = 
model.training_step(batch, batch_idx, logger) 337 | model.invoke_callback('on_train_batch_end') 338 | # Post epoch processing 339 | _ = logger.scalar_queue_flush('train', epoch) 340 | 341 | for sch in model.schedulers: 342 | if epoch % sch['frequency'] == 0: 343 | sch['scheduler'].step() 344 | 345 | # Run validation step 346 | if (args['validation_freq'] is not None and 347 | epoch % args['validation_freq'] == 0): 348 | model.eval() 349 | # testset used in validation step for observation/study purpose 350 | for batch_idx, batch in enumerate(test_loader): 351 | model.evaluation_step(batch, batch_idx, logger, mc_samples=1, 352 | tag='error', accumulator='validation') 353 | if args['mc_samples'] > 1: 354 | model.evaluation_step(batch, batch_idx, logger, mc_samples=args['mc_samples'], 355 | tag='error_mc', accumulator='validation') 356 | _ = logger.scalar_queue_flush('validation', epoch) 357 | # invoke post training callbacks 358 | model.invoke_callback('on_train_end') 359 | 360 | # Test model 361 | model.eval() 362 | for batch_idx, batch in enumerate(test_loader): 363 | model.evaluation_step(batch, batch_idx, logger, 364 | mc_samples=1, tag='error') 365 | if args['mc_samples'] > 1: 366 | model.evaluation_step(batch, batch_idx, logger, 367 | mc_samples=args['mc_samples'], tag='error_mc') 368 | test_out = logger.scalar_queue_flush('test', epoch) 369 | 370 | print('***************************************************') 371 | print('Model Test Error: {:.4f}%'.format(test_out['error'])) 372 | if args['mc_samples'] > 1: 373 | print('Model Test Error ({} sample avg): {:.4f}%'.format( 374 | args['mc_samples'], test_out['error_mc'])) 375 | print('***************************************************') 376 | logger.close() 377 | 378 | 379 | def get_default_args(model_id): 380 | """ 381 | Returns default experiment arguments 382 | """ 383 | 384 | args = { 385 | 'exp_name': 'mine_ib', 386 | 'seed': 0, 387 | # Trainer args 388 | 'device': 'cuda' if torch.cuda.is_available() else 'cpu', 389 | 'epochs': 200, 390 | 'logdir': './logs', 391 | 'validation_freq': 1, 392 | # Dataset args 393 | 'batch_size': 100, 394 | 'workers': 4, 395 | # Model args 396 | 'model_args': { 397 | 'lr': 1e-4, 398 | 'use_polyak': True, 399 | } 400 | } 401 | 402 | if model_id == 'deter': 403 | args['model_args']['K'] = 1024 404 | args['model_args']['base_net_args'] = { 405 | 'layers': [784, 1024, 1024], 'stochastic': False} 406 | args['mc_samples'] = 1 407 | 408 | elif model_id == 'mine': 409 | args['model_args']['K'] = 256 410 | args['model_args']['base_net_args'] = { 411 | 'layers': [784, 1024, 1024], 'stochastic': True} 412 | args['mc_samples'] = 12 413 | args['model_args']['beta'] = 1e-3 414 | args['model_args']['mine_args'] = {} 415 | args['model_args']['mine_args']['estimator'] = 'dv' 416 | args['model_args']['mine_args']['est_lr'] = 2e-4 417 | args['model_args']['mine_args']['variant'] = 'unbiased' 418 | args['mine_train_steps'] = 1 419 | return args 420 | 421 | 422 | if __name__ == "__main__": 423 | parser = argparse.ArgumentParser(prog='Information BottleNeck with MINE') 424 | 425 | parser.add_argument('--exp_name', action='store', type=str, 426 | help='Experiment Name') 427 | parser.add_argument('--seed', action='store', type=int) 428 | parser.add_argument('--logdir', action='store', type=str, 429 | help='Directory to log results') 430 | 431 | group = parser.add_mutually_exclusive_group(required=True) 432 | group.add_argument('--deter', action='store_const', dest='model_id', const='deter', 433 | help='Run baseline') 434 | 
group.add_argument('--mine', action='store_const', dest='model_id', const='mine', 435 | help='Run MINE + IB model') 436 | 437 | estimator = parser.add_mutually_exclusive_group() 438 | estimator.add_argument('--dv', dest='estimator', action='store_const', 439 | const='dv', help='Use Donsker-Varadhan estimator') 440 | estimator.add_argument('--fdiv', dest='estimator', action='store_const', 441 | const='fdiv', help='Use f-divergence estimator') 442 | estimator.add_argument('--nwj', dest='estimator', action='store_const', 443 | const='nwj', help='Use NWJ estimator') 444 | 445 | variant = parser.add_mutually_exclusive_group() 446 | variant.add_argument('--unbiased', dest='variant', action='store_const', 447 | const='unbiased', help='Use unbiased MI estimator') 448 | variant.add_argument('--biased', dest='variant', action='store_const', 449 | const='biased', help='Use biased MI estimator') 450 | variant.add_argument('--marginal', dest='variant', action='store_const', 451 | const='marginal', help='Use samples from true marginal for MI estimation') 452 | 453 | parser.add_argument('--epochs', action='store', type=int) 454 | parser.add_argument('--beta', action='store', type=float, 455 | help='information bottleneck coefficient') 456 | args = parser.parse_args() 457 | 458 | model_args = ['K', 'lr', 'use_polyak', 'beta'] 459 | mine_args = ['estimator', 'est_lr', 'variant'] 460 | 461 | exp_args = get_default_args(args.model_id) 462 | for key, value in args.__dict__.items(): 463 | if value is not None: 464 | if key in model_args: 465 | exp_args['model_args'][key] = value 466 | elif key in mine_args: 467 | exp_args['model_args']['mine_args'][key] = value 468 | else: 469 | exp_args[key] = value 470 | 471 | run(exp_args) 472 | -------------------------------------------------------------------------------- /mine/mine.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.autograd import Function 9 | 10 | from mine.models import StatisticsNet 11 | 12 | 13 | class UnbiasedLogMeanExp(Function): 14 | """ 15 | Calculates and uses gradients with reduced bias 16 | """ 17 | 18 | epsilon = 1e-6 19 | 20 | @staticmethod 21 | def forward(ctx, i, ema): 22 | ctx.save_for_backward(i, ema) 23 | mean_numel = torch.tensor(i.shape[0], dtype=torch.float) 24 | output = i.logsumexp(dim=0).subtract(torch.log(mean_numel)) 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | i, ema = ctx.saved_tensors 30 | grad_i = grad_ema = None 31 | mean_numel = torch.tensor(i.shape[0], dtype=torch.float) 32 | grad_i = grad_output * \ 33 | (i.exp() / ((ema+UnbiasedLogMeanExp.epsilon)*mean_numel)) 34 | return grad_i, grad_ema 35 | 36 | 37 | class MINE_Base(nn.Module): 38 | 39 | def __init__(self, input_dim, K, est_lr=2e-4, variant='unbiased'): 40 | super().__init__() 41 | self._T = StatisticsNet(input_dim, K) 42 | self._est_lr = est_lr 43 | self.variant = variant 44 | self._current_epoch = 0 45 | 46 | def _configure_optimizers(self): 47 | opt = optim.Adam(self._T.parameters(), 48 | lr=self._est_lr, betas=(0.5, 0.999)) 49 | sch = None 50 | return opt, sch 51 | 52 | def step_epoch(self): 53 | self._current_epoch += 1 54 | 55 | 56 | class MINE_DV(MINE_Base): 57 | _ANNEAL_PERIOD = 0 58 | _EMA_ANNEAL_PERIOD = 0 59 | 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(*args, **kwargs) 62 | self._decay = 0.994 # decay for ema (not tuned) 63 | 
self._ema = None 64 | 65 | def _update_ema(self, t_margin): 66 | with torch.no_grad(): 67 | exp_t = t_margin.exp().mean(dim=0) 68 | if self._ema is not None: 69 | self._ema = self._decay * self._ema + (1-self._decay) * exp_t 70 | else: 71 | self._ema = exp_t 72 | 73 | def get_mi_bound(self, x, z, z_margin=None, update_ema=False): 74 | t_joint = self._T(x, z).mean(dim=0) 75 | if z_margin is not None: 76 | t_margin = self._T(x, z_margin) 77 | else: 78 | t_margin = self._T(x, z[torch.randperm(x.shape[0])]) 79 | # maintain an exponential moving average of exp_t under the marginal distribution 80 | # done to reduce bias in the estimator 81 | if ((self.variant == 'unbiased' and update_ema) and 82 | self._current_epoch > self._EMA_ANNEAL_PERIOD): 83 | self._update_ema(t_margin) 84 | # Calculate biased or unbiased estimate 85 | if self.variant == 'unbiased' and self._current_epoch > self._ANNEAL_PERIOD: 86 | log_exp_t = UnbiasedLogMeanExp.apply(t_margin, self._ema) 87 | else: 88 | log_exp_t = t_margin.logsumexp( 89 | dim=0).subtract(math.log(x.shape[0])) 90 | # mi lower bound 91 | return t_joint - log_exp_t 92 | 93 | 94 | class MINE_f_Div(MINE_Base): 95 | 96 | def get_mi_bound(self, x, z, z_margin=None, update_ema=False): 97 | t_joint = self._T(x, z).mean(dim=0) 98 | if z_margin is not None: 99 | t_margin = self._T(x, z_margin) 100 | else: 101 | t_margin = self._T(x, z[torch.randperm(x.shape[0])]) 102 | 103 | exp_t = torch.mean(torch.exp(t_margin-1), dim=0) 104 | # mi lower bound 105 | return t_joint - exp_t 106 | 107 | 108 | class NWJ(nn.Module): 109 | """ 110 | NWJ (Nguyen, Wainwright, and Jordan) estimator 111 | """ 112 | 113 | def __init__(self, input_dim, K, est_lr=2e-4, variant='unbiased'): 114 | super().__init__() 115 | self._critic = StatisticsNet(input_dim, K) 116 | # from mine.models import ConcatCritic, BiLinearCritic 117 | # self._critic = ConcatCritic(input_dim, K) 118 | # self._critic = BiLinearCritic(input_dim, K) 119 | self._est_lr = est_lr 120 | self._current_epoch = 0 121 | self.variant = variant 122 | 123 | def _configure_optimizers(self): 124 | opt = optim.Adam(self._critic.parameters(), 125 | lr=self._est_lr, betas=(0.5, 0.999)) 126 | sch = None 127 | return opt, sch 128 | 129 | def step_epoch(self): 130 | self._current_epoch += 1 131 | 132 | @staticmethod 133 | def logmeanexp_nondiag(tensor): 134 | batch_size = tensor.shape[0] 135 | device = tensor.device 136 | dim = (0, 1) 137 | numel = batch_size * (batch_size-1) 138 | logsumexp = torch.logsumexp(tensor - torch.diag(np.inf * torch.ones(batch_size, device=device)), dim=dim) 139 | return logsumexp - np.math.log(numel) 140 | 141 | def get_mi_bound(self, x, z, z_margin=None, update_ema=None): 142 | joint = self._critic(x, z).mean(dim=0) 143 | if z_margin is not None: 144 | margin = self._critic(x, z_margin) 145 | else: 146 | margin = self._critic(x, z[torch.randperm(x.shape[0])]) 147 | margin = torch.logsumexp(margin, dim=0) - np.math.log(z.shape[0]) 148 | margin = torch.exp(margin-1) 149 | return joint - margin 150 | 151 | 152 | def get_estimator(input_dim, K, args_dict): 153 | args_dict = args_dict.copy() 154 | estimator = args_dict.pop('estimator') 155 | if estimator == 'dv': 156 | return MINE_DV(input_dim, K, **args_dict) 157 | elif estimator == 'fdiv': 158 | return MINE_f_Div(input_dim, K, **args_dict) 159 | elif estimator == 'nwj': 160 | return NWJ(input_dim, K, **args_dict) 161 | else: 162 | raise ValueError 163 | -------------------------------------------------------------------------------- /mine/models.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeatureExtractor(nn.Module): 7 | def __init__(self, input_dim, output_dim, stochastic=False, 8 | init_std_bias=-5.0, min_std=1e-8): 9 | super().__init__() 10 | self._stochastic = stochastic 11 | if stochastic: 12 | self._init_std_bias = init_std_bias 13 | self._min_std = min_std 14 | self.latent_out = 2 * output_dim 15 | else: 16 | self.latent_out = output_dim 17 | self._net = None 18 | 19 | def is_stochastic(self): 20 | return self._stochastic 21 | 22 | def forward(self, x): 23 | x = self._net(x) 24 | if self._stochastic: 25 | # parameterize outputs as a sample from a gaussian 26 | mean, std = torch.chunk(x, chunks=2, dim=1) 27 | std = nn.functional.softplus( 28 | std+self._init_std_bias) + self._min_std 29 | return mean, std 30 | return x 31 | 32 | 33 | class MLP(FeatureExtractor): 34 | """ 35 | MLP with gaussian or deterministic outputs 36 | """ 37 | 38 | def __init__(self, input_dim, output_dim, layers, act=nn.ReLU, **kwargs): 39 | super().__init__(input_dim, output_dim, **kwargs) 40 | 41 | net_layers = nn.ModuleList() 42 | inp = input_dim 43 | for layer_dim in layers: 44 | net_layers.append(nn.Linear(inp, layer_dim)) 45 | net_layers.append(act()) 46 | inp = layer_dim 47 | 48 | net_layers.append(nn.Linear(inp, self.latent_out)) 49 | self._net = nn.Sequential(*net_layers) 50 | 51 | 52 | def make_layers(cfg, batch_norm=False): 53 | layers = [] 54 | in_channels = 3 55 | for v in cfg: 56 | if v == 'M': 57 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 58 | else: 59 | v = int(v) 60 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 61 | if batch_norm: 62 | layers += [conv2d, 63 | nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 64 | else: 65 | layers += [conv2d, nn.ReLU(inplace=True)] 66 | in_channels = v 67 | return nn.Sequential(*layers) 68 | 69 | 70 | class VGGNet(FeatureExtractor): 71 | 72 | cfgs = { 73 | 'vgg_9': [32, 'M', 64, 'M', 128, 128, 'M', 256, 256, 'M'], 74 | 'vgg_11': [32, 'M', 64, 'M', 128, 128, 'M', 256, 256, 'M', 256, 256, 'M'] 75 | } 76 | 77 | def __init__(self, input_dim, output_dim, bn=False, arch='vgg_11', **kwargs): 78 | super().__init__(input_dim, output_dim, **kwargs) 79 | 80 | self._arch = arch 81 | self._base_layers = make_layers(self.cfgs[arch], bn) 82 | 83 | with torch.no_grad(): 84 | dum_input = torch.zeros((1,) + input_dim, dtype=torch.float32) 85 | flat_shape = self._base_layers(dum_input).shape.numel() 86 | 87 | self._net = nn.Sequential( 88 | self._base_layers, 89 | nn.Flatten(start_dim=1), 90 | nn.Linear(flat_shape, 256), 91 | nn.ReLU(inplace=True), 92 | nn.Linear(256, self.latent_out) 93 | ) 94 | self._initialize_weights() 95 | 96 | def _initialize_weights(self) -> None: 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_( 100 | m.weight, mode='fan_out', nonlinearity='relu') 101 | if m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | nn.init.constant_(m.weight, 1) 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.Linear): 107 | nn.init.normal_(m.weight, 0, 0.01) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | 111 | class StatisticsNet(nn.Module): 112 | """ 113 | Network for estimating mutual information between two random variables 114 | """ 115 | 116 | def __init__(self, x_dim, z_dim): 117 | super().__init__() 118 | self._layers = nn.ModuleList() 119 | 
self._layers.append(nn.Linear(x_dim+z_dim, 512)) 120 | self._layers.append(nn.Linear(512, 512)) 121 | self._out_layer = nn.Linear(512, 1) 122 | self._initialize_weights() 123 | 124 | def _initialize_weights(self): 125 | for (name, param) in self._layers.named_parameters(): 126 | if 'weight' in name: 127 | nn.init.kaiming_normal_(param, nonlinearity='relu') 128 | elif 'bias' in name: 129 | nn.init.zeros_(param) 130 | 131 | for (name, param) in self._out_layer.named_parameters(): 132 | if 'weight' in name: 133 | nn.init.kaiming_normal_(param, nonlinearity='linear') 134 | elif 'bias' in name: 135 | nn.init.zeros_(param) 136 | 137 | def forward(self, x, z): 138 | x = torch.cat([x, z], dim=1) 139 | x = x + 0.3 * torch.randn_like(x) 140 | for hid_layer in self._layers: 141 | x = F.elu(hid_layer(x)) 142 | x = x + 0.5 * torch.randn_like(x) 143 | return self._out_layer(x) 144 | -------------------------------------------------------------------------------- /mine/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mohith-sakthivel/mine-pytorch/eedf0103a05837836f1e97747d6e3b2502edc5f1/mine/utils/__init__.py -------------------------------------------------------------------------------- /mine/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler, RandomSampler, BatchSampler 2 | 3 | 4 | class CustomSampler(Sampler): 5 | def __init__(self, data_source, secondary_replacement=False): 6 | self.data_source = data_source 7 | self.secondary_replacement = secondary_replacement 8 | self._batch_sampler = RandomSampler(data_source) 9 | self._marginal_sampler = RandomSampler(data_source, 10 | replacement=secondary_replacement) 11 | 12 | def __len__(self) -> int: 13 | return len(self.data_source) 14 | 15 | def __iter__(self): 16 | return zip(self._batch_sampler.__iter__(), self._marginal_sampler.__iter__()) 17 | 18 | 19 | class CustomBatchSampler(BatchSampler): 20 | def __iter__(self): 21 | batch = [] 22 | secondary_batch = [] 23 | for idx, secondary_idx in self.sampler: 24 | batch.append(idx) 25 | secondary_batch.append(secondary_idx) 26 | if len(batch) == self.batch_size: 27 | yield batch + secondary_batch 28 | batch = [] 29 | secondary_batch = [] 30 | if len(batch) > 0 and not self.drop_last: 31 | yield batch + secondary_batch -------------------------------------------------------------------------------- /mine/utils/log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.tensorboard.writer import SummaryWriter 5 | 6 | 7 | class Logger(): 8 | 9 | def __init__(self, log_dir): 10 | self.log_dir = log_dir 11 | self._tb_logger = SummaryWriter(log_dir=self.log_dir) 12 | self._data = {} 13 | self._progbar = StatsDescriptor() 14 | 15 | def scalar(self, value, tag, step=None, accumulator=None, progbar=False): 16 | assert step is None or accumulator is None 17 | if isinstance(value, torch.Tensor): 18 | value = value.detach().cpu().numpy() 19 | if accumulator is None: 20 | self._tb_logger.add_scalar(tag, value, global_step=step) 21 | _ = self._progbar.add(tag, value) if progbar else None 22 | else: 23 | if not accumulator in self._data.keys(): 24 | self._data[accumulator] = {} 25 | if tag in self._data[accumulator].keys(): 26 | self._data[accumulator][tag].append(value) 27 | else: 28 | self._data[accumulator][tag] = [value] 29 | if progbar and not 
self._progbar.contains('_'.join([accumulator, tag])):
30 |                 self._progbar.add('_'.join([accumulator, tag]), None)
31 |
32 |     def scalar_queue_flush(self, accumulator, step=None):
33 |         assert accumulator in self._data.keys()
34 |         out = {}
35 |         if len(self._data[accumulator]) > 0:
36 |             for key, values in self._data[accumulator].items():
37 |                 out[key] = np.mean(values)
38 |                 self._tb_logger.add_scalar('/'.join([accumulator, key]),
39 |                                            out[key],
40 |                                            global_step=step)
41 |                 if self._progbar.contains('_'.join([accumulator, key])):
42 |                     self._progbar.add(key, out[key])
43 |
44 |         _ = self._data.pop(accumulator)
45 |         return out
46 |
47 |     def scalar_queue_group_flush(self, accumulator, step=None):
48 |         assert accumulator in self._data.keys()
49 |         out = {}
50 |         if len(self._data[accumulator]) > 0:
51 |             for key, values in self._data[accumulator].items():
52 |                 out[key] = np.mean(values)
53 |                 if self._progbar.contains('_'.join([accumulator, key])):
54 |                     self._progbar.add(key, out[key])
55 |             self._tb_logger.add_scalars(accumulator, out, global_step=step)
56 |         _ = self._data.pop(accumulator)
57 |         return out
58 |
59 |     def get_progbar_desc(self):
60 |         return self._progbar.get_descriptor()
61 |
62 |     def close(self):
63 |         self._tb_logger.flush()
64 |         self._tb_logger.close()
65 |
66 |
67 | class StatsDescriptor:
68 |     def __init__(self):
69 |         self._stats = {}
70 |
71 |     def get_descriptor(self):
72 |         desc = []
73 |         for key, val in self._stats.items():
74 |             if val is not None:
75 |                 desc.append(key + '={:.4f}'.format(val))
76 |         return ' '.join(desc)
77 |
78 |     def add(self, key, value):
79 |         self._stats[key] = value
80 |
81 |     def contains(self, tag):
82 |         return tag in self._stats
--------------------------------------------------------------------------------
/mine/utils/train.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 | import torch
4 |
5 |
6 | class PolyakAveraging():
7 |     """
8 |     Calculates an exponential moving average of parameter weights during training
9 |     """
10 |
11 |     def __init__(self, alpha=0.999):
12 |         self._alpha = alpha
13 |         self._avg_module = None
14 |
15 |     def on_train_start(self, module):
16 |         if self._avg_module is None:
17 |             self._avg_module = copy.deepcopy(module)
18 |         else:
19 |             warnings.warn(
20 |                 "Existing EMA (exponential moving average) values of the model are being reused for a new 'Trainer.fit' sequence", RuntimeWarning)
21 |
22 |     def on_train_batch_end(self, module):
23 |         device = self._avg_module.device
24 |         for src, dst in zip(module.get_model_parameters(), self._avg_module.get_model_parameters()):
25 |             dst.detach().copy_(self._alpha * dst + (1-self._alpha) * src.detach().to(device))
26 |
27 |     def on_train_end(self, module):
28 |         torch.save(module.state_dict(), module.logdir.joinpath('train_end.pth'))
29 |         device = module.device
30 |         for src, dst in zip(self._avg_module.get_model_parameters(), module.get_model_parameters()):
31 |             dst.detach().copy_(src.detach().to(device))
32 |         torch.save(module.state_dict(), module.logdir.joinpath('ema_weights.pth'))
33 |         self._avg_module = None
34 |
35 |
36 | class BetaScheduler:
37 |     """
38 |     Schedules beta after an initial warm-up period. The value is annealed linearly
39 |     from the initial value to the steady-state value.
40 | """ 41 | def __init__(self, warm_up, value, anneal_period, init_value=0): 42 | self._warm_up = warm_up 43 | self._value = value 44 | self._anneal_period = anneal_period 45 | self._init_value = init_value 46 | 47 | def get(self, epoch): 48 | if epoch <= self._warm_up: 49 | return self._init_value 50 | elif epoch > (self._warm_up + self._anneal_period): 51 | return self._value 52 | else: 53 | ratio = (epoch-self._warm_up)/(self._anneal_period) 54 | return ratio * (self._value - self._init_value) + self._init_value --------------------------------------------------------------------------------
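
For reference, `BetaScheduler` above is what `MINE_Classifier` in `mine/ib.py` uses for the IB coefficient: a plain float `beta` is wrapped as `BetaScheduler(0, beta, 0)` (a constant schedule), while a scheduler instance is used as-is. Below is a minimal usage sketch of the scheduler on its own; the warm-up and annealing numbers are illustrative, not defaults from this repo.

```python
from mine.utils.train import BetaScheduler

# beta is 0 for the first 10 epochs, ramps linearly to 1e-3 over the next 20 epochs,
# and stays at 1e-3 afterwards
beta_schedule = BetaScheduler(warm_up=10, value=1e-3, anneal_period=20, init_value=0)

for epoch in (1, 10, 20, 30, 31, 100):
    print(epoch, beta_schedule.get(epoch))
# epochs 1 and 10 -> 0, epoch 20 -> 5e-04, epochs 30/31/100 -> 1e-03
```

Such a scheduler instance can be passed directly as the `beta` argument of `MINE_Classifier` in place of a float.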