├── .gitignore ├── README.md └── src ├── architectures ├── __init__.py ├── feedforward.py ├── helper.py ├── lstm_latent.py └── stochastic.py ├── commons ├── __init__.py ├── callbacks.py ├── func.py └── types_.py ├── datasets ├── __init__.py ├── data_modules.py ├── dsprites.py ├── helper.py └── mpi3d.py ├── evaluation ├── __init__.py ├── comp_gen.py ├── disentangle_metric_evaluator.py ├── extra_metrics_configs │ ├── beta_vae_sklearn.gin │ ├── dci.gin │ ├── factor_vae_metric.gin │ ├── irs.gin │ ├── mcc.gin │ ├── mig.gin │ ├── modularity_explicitness.gin │ └── sap_score.gin ├── group_metric.py ├── scikit_learn_evaluator.py ├── topo_sim.py └── utils.py ├── models ├── __init__.py ├── ae.py ├── beta_tcvae.py ├── optimizer.py ├── rec_el.py └── vae.py └── scripts ├── configs ├── rec_el.yaml ├── scikitlearn_eval.yaml └── vae.yaml ├── eval_gt_rep.py ├── experiments.py ├── run_el.py ├── run_tcvae.py └── run_vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | *.DS_Store 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CompGen 2 | This repository contains the official code for the paper: "Compositional Generalization in Unsupervised Compositional Representation Learning: A Study on Disentanglement and Emergent Language" (NeurIPS 2022). 3 | 4 | This work proposes a new protocol for evaluating the compositional generalization of learned representations. Our protocol focuses on whether or not it is easy to train a simple 5 | model for downstream tasks on top of the learned representation that generalizes to new combinations of compositional factors. We systematically study $\beta$-VAE, $\beta$-TCVAE and emergent language autoencoders. 
6 | 7 | 8 | 9 | ## Dependencies 10 | ``` 11 | torch == 1.8.1 12 | torchvision == 0.9.1 13 | pytorch-lightning == 1.5.8 14 | wandb == 0.12.10 15 | scikit-learn == 0.22 16 | disentanglement-lib == 1.5 17 | tensorflow == 1.15.0 18 | tensorflow == 1.15.0 19 | tensorflow-datasets == 4.2.0 20 | tensorflow-estimator == 1.15.1 21 | tensorflow-hub == 0.4.0 22 | tensorflow-metadata == 0.30.0 23 | tensorflow-probability == 0.6.0 24 | ``` 25 | ## Data 26 | Two publicly available datasets, [dSprites](https://github.com/deepmind/dsprites-dataset) and [MPI3D](https://github.com/rr-learning/disentanglement_dataset), are used in our work. 27 | 28 | ## Run Experiments 29 | - Set configurations in the ```.yaml``` files under ```scripts/configs```, or directly override arguments in the experiment scripts, e.g. ```run_{MODEL_NAME}.py```. 30 | 31 | - Run pretraining and fine-tuning with 32 | ``` 33 | python run_vae.py -g 0 -ft 34 | python run_tcvae.py -g 0 -ft 35 | python run_el.py -g 0 -ft 36 | ``` 37 | - By default, linear readout models are used. Add `-gbt` to use GBT readout models for evaluation. 38 | 39 | - If a pretrained model with the same config already exists, pretraining is skipped and the previously saved model is used, unless the ```--overwrite``` flag is added. 40 | 41 | - Evaluate the disentanglement/compositionality metrics of pretrained models 42 | ``` 43 | python run_{MODEL_NAME}.py -g 0 --compmetric 44 | ``` 45 | 46 | Add `--nowb` to disable the wandb logger. 
47 | -------------------------------------------------------------------------------- /src/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | -------------------------------------------------------------------------------- /src/architectures/feedforward.py: -------------------------------------------------------------------------------- 1 | """ 2 | A copy of https://github.com/mmrl/disent-and-gen/blob/9b34e902471b452b50855b1b84422a14d16beeeb/src/models/feedforward.py 3 | 4 | FeedForward module 5 | 6 | This module initializes a list of layer configurations into a feed-forward 7 | PyTorch module. This class inherits from nn.Sequential and adds a few more 8 | methods to build a sequential layer from a set of definitions. 9 | 10 | Parameters in the config for each layer follow the order in Pytorch's 11 | documentation. Excluding any of them will use the default ones. We can also 12 | pass kwargs in a dict: 13 | 14 | ('layer_name', <list of args>, <dict of kwargs>) 15 | 16 | This is a list of the configuration values supported: 17 | 18 | Layer Parameters 19 | ============================================================================== 20 | Convolution: n-channels, size, stride, padding 21 | Transposed Convolution: same, output_padding when stride > 1! (use kwargs) 22 | Pooling: size, stride, padding, type 23 | Linear: output size, fit bias 24 | Flatten: start dim, (optional, defaults=-1) end dim 25 | Unflatten: unflatten shape (have to pass the full shape) 26 | Batch-norm: dimensionality (1-2-3d) 27 | Upsample: upsample_shape (hard to infer automatically). Bilinear 28 | Non-linearity: pass whatever arguments that non-linearity supports. 29 | 30 | There is a method called transpose_layer_defs which allows for automatically 31 | transposing the layer definitions for a decoder in a generative model. 
def create_linear(input_size, args, kwargs, transposed=False):
    """Build an ``nn.Linear`` layer and report the feature size that follows it.

    :param input_size: size of the incoming features, either a single number
        or a list/tuple shape whose last entry is the feature dimension.
    :param args: positional layer arguments; ``args[0]`` is the layer width,
        any remaining entries are forwarded to ``nn.Linear`` (e.g. ``bias``).
    :param kwargs: keyword arguments forwarded to ``nn.Linear``.
    :param transposed: if True, build the mirrored layer that maps
        ``args[0]`` features back to the incoming feature size.
    :return: ``(layer, output_size)``. ``output_size`` has the same kind as
        ``input_size`` (number, list or tuple) with the feature dimension
        replaced by ``args[0]``.
    """
    is_shape = isinstance(input_size, (list, tuple))
    in_features = input_size[-1] if is_shape else input_size

    if transposed:
        layer = nn.Linear(args[0], in_features, *args[1:], **kwargs)
    else:
        layer = nn.Linear(in_features, *args, **kwargs)

    # Build a new shape instead of assigning into ``input_size``: item
    # assignment fails for tuples and would silently modify a list that the
    # caller may still be holding on to.
    if isinstance(input_size, tuple):
        output_size = (*input_size[:-1], args[0])
    elif isinstance(input_size, list):
        output_size = [*input_size[:-1], args[0]]
    else:
        output_size = args[0]

    return layer, output_size
def create_pool(kernel_size, stride, padding, mode, kwargs):
    """Create a 2D pooling layer.

    :param kernel_size: pooling window size (for ``'adapt'`` this is the
        target output size passed to ``nn.AdaptiveAvgPool2d``).
    :param stride: pooling stride (unused for ``'adapt'``).
    :param padding: implicit padding (unused for ``'adapt'``).
    :param mode: one of ``'avg'``, ``'max'`` or ``'adapt'``.
    :param kwargs: extra keyword arguments forwarded to the layer.
    :return: the pooling module.
    :raises ValueError: if ``mode`` is not recognised.
    """
    if mode == 'avg':
        return nn.AvgPool2d(kernel_size, stride, padding, **kwargs)

    if mode == 'max':
        # The padding used to be dropped here, so the layer disagreed with the
        # output shape computed from the same (size, stride, padding) triple.
        # It is merged into the keyword arguments so that a caller who already
        # supplies ``padding`` through ``kwargs`` keeps working.
        max_kwargs = {'padding': padding, **kwargs}
        return nn.MaxPool2d(kernel_size, stride, **max_kwargs)

    if mode == 'adapt':
        return nn.AdaptiveAvgPool2d(kernel_size, **kwargs)

    raise ValueError('Unrecognised pooling mode {}'.format(mode))
class FeedForward(nn.Sequential):
    """Sequential network assembled from a list of layer definitions.

    Each definition is ``(layer_type, args, kwargs)`` (shorter forms are
    normalised by ``preprocess_defs``). While the layers are created the
    per-sample output size is tracked, so that each layer can be sized from
    the output of the previous one.

    :param input_size: per-sample input size (without the batch dimension),
        either a number or a list/tuple shape.
    :param layer_defs: iterable of layer definitions, or a dict whose items
        are the definitions.
    :param flatten: if True, ``forward`` flattens the output to
        ``[batch, -1]``.

    After construction ``input_size`` and ``output_size`` hold the
    per-sample sizes at both ends of the network.
    """

    def __init__(self, input_size, layer_defs, flatten=True):
        if isinstance(layer_defs, dict):
            # Use the items of the argument (this used to call ``items`` on
            # the ``dict`` type itself, which raises a TypeError).
            layer_defs = layer_defs.items()
        layer_defs = preprocess_defs(layer_defs)

        # Track the running size on a private copy, so that helpers which
        # update a list shape in place cannot change ``input_size``; forward()
        # relies on ``self.input_size`` to reshape incoming batches.
        if isinstance(input_size, list):
            output_size = list(input_size)
        else:
            output_size = input_size

        cnn_layers = []

        for layer_type, args, kwargs in layer_defs:
            if layer_type == 'linear':
                layer, output_size = create_linear(output_size, args, kwargs)
            elif layer_type == 'conv':
                layer = nn.Conv2d(output_size[0], *args, **kwargs)
                output_size = conv2d_out_shape(output_size, *args)
            elif layer_type == 'tconv':
                layer = nn.ConvTranspose2d(output_size[0], *args, **kwargs)
                output_size = transp_conv2d_out_shape(output_size, *args)
            elif layer_type == 'batch_norm':
                layer = creat_batch_norm(args[0], output_size,
                                         args[1:], kwargs)
            elif layer_type == 'pool':
                layer = create_pool(*args, kwargs)
                # the last positional argument is the pooling mode
                output_size = maxpool2d_out_shape(output_size, *args[:-1])
            elif layer_type == 'dropout':
                layer = nn.Dropout2d(*args, **kwargs)
            elif layer_type == 'flatten':
                layer = nn.Flatten(*args)
                # size the result with the same start/end dims as the layer
                output_size = compute_flattened_size(output_size, *args)
            elif layer_type == 'unflatten':
                layer = Unflatten(args)
                output_size = args
            elif layer_type == 'upsample':
                layer = nn.UpsamplingBilinear2d(size=args)
                output_size = output_size[0], *args
            else:
                # anything else is looked up as a non-linearity by name
                layer = get_nonlinearity(layer_type)(*args, **kwargs)

            cnn_layers.append(layer)

        super().__init__(*cnn_layers)

        self.input_size = input_size
        self.output_size = output_size
        self.flatten = flatten

    def forward(self, inputs):
        """Run the layers on ``inputs``.

        When ``input_size`` is a shape, the batch is first viewed as
        ``[-1, *input_size]``; the result is flattened to ``[batch, -1]`` if
        ``flatten`` is set.
        """
        if isinstance(self.input_size, (list, tuple)):
            inputs = inputs.view(-1, *self.input_size)

        outputs = super().forward(inputs)

        if self.flatten:
            outputs = torch.flatten(outputs, start_dim=1)

        return outputs
layer_type == 'flatten': 277 | layer_type = 'unflatten' 278 | flattened_size = compute_flattened_size(input_size, *args) 279 | args, input_size = input_size, flattened_size 280 | 281 | layer = layer_type, args, kwargs 282 | transposed_layer_defs.append(layer) 283 | 284 | return list(reversed(transposed_layer_defs)) 285 | -------------------------------------------------------------------------------- /src/architectures/helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | from torch import nn 5 | from architectures.feedforward import FeedForward, transpose_layer_defs 6 | 7 | arch_configs= { 8 | 'burgess':{ 9 | 'encoder_cnn': 10 | [ 11 | ('conv', (32, 4, 2, 1)), 12 | ('relu',), 13 | 14 | ('conv', (32, 4, 2, 1)), 15 | ('relu',), 16 | 17 | ('conv', (64, 4, 2, 1)), 18 | ('relu',), 19 | 20 | ('conv', (64, 4, 2, 1)), 21 | ('relu',), 22 | 23 | ('flatten', [1]), 24 | ], 25 | 'encoder_latent': 26 | [ 27 | ('linear', [256]), 28 | ('relu',), 29 | 30 | ('linear', [256]), 31 | ('relu',) 32 | ], 33 | 'lstm_latent': 34 | { 35 | 'hidden_size': 256 36 | } 37 | }, 38 | 39 | 'base': { 40 | 'encoder_cnn': 41 | [ 42 | ('conv', (64, 4, 2, 1)), 43 | ('relu',), 44 | 45 | ('conv', (64, 4, 2, 1)), 46 | ('relu',), 47 | 48 | ('conv', (128, 4, 2, 1)), 49 | ('relu',), 50 | 51 | ('conv', (128, 4, 2, 1)), 52 | ('relu',), 53 | 54 | ('flatten', [1]), 55 | ], 56 | 'encoder_latent': 57 | [ 58 | ('linear', [512]), 59 | ('relu',), 60 | 61 | ('linear', [1024]), 62 | ('relu',), 63 | 64 | ('linear', [1024]), 65 | ('relu',), 66 | 67 | ('linear', [512]), 68 | ('relu',) 69 | ], 70 | 'lstm_latent': 71 | { 72 | 'hidden_size': 512 73 | } 74 | }, 75 | 76 | 'large': { 77 | 'encoder_cnn': 78 | [ 79 | ('conv', (128, 4, 2, 1)), 80 | ('relu',), 81 | 82 | ('conv', (128, 4, 2, 1)), 83 | ('relu',), 84 | 85 | ('conv', (256, 4, 2, 1)), 86 | ('relu',), 87 | 88 | ('conv', (256, 4, 2, 1)), 89 | ('relu',), 90 | 91 | ('flatten', [1]), 92 | 
def build_architectures(input_size, name, latent_size, model, **kwargs):
    """Build the convolutional and latent sub-networks of a model.

    :param input_size: per-sample input shape handed to the CNN encoder.
    :param name: key into ``arch_configs`` selecting the architecture.
    :param latent_size: size of the latent code.
    :param model: model name. Names containing ``'Recurrent'`` get the raw
        ``'lstm_latent'`` config; ``'VAE'`` / ``'BetaTCVAE'`` encoders emit
        ``2 * latent_size`` values (mu and log-variance); ``'AutoEncoder'``
        encoders emit ``latent_size`` values.
    :param kwargs: accepted for call-site compatibility, not used here.
    :return: ``(conv_layers, latent_layers)`` where ``conv_layers`` is
        ``(encoder_cnn, decoder_cnn)`` and ``latent_layers`` is either the
        ``'lstm_latent'`` config dict (recurrent models) or
        ``(encoder_latent, decoder_latent)``.
    :raises NotImplementedError: for a non-recurrent ``model`` that is not
        one of the names above.
    """
    config = arch_configs[name]

    # build conv layers
    encoder_cnn_config = config['encoder_cnn'][:]
    encoder_cnn = FeedForward(input_size, encoder_cnn_config, flatten=False)
    if 'decoder_cnn' in config:
        # read the key that was tested (this used to read 'decoder_config',
        # which no configuration defines)
        decoder_cnn_config = config['decoder_cnn'][:]
    else:
        # no explicit decoder: mirror the encoder definitions
        decoder_cnn_config = transpose_layer_defs(encoder_cnn_config[:], input_size)
    decoder_cnn = FeedForward(encoder_cnn.output_size, decoder_cnn_config, flatten=False)
    conv_layers = (encoder_cnn, decoder_cnn)

    # build latent layers
    if 'Recurrent' in model:
        # recurrent models build their latent layers separately from this config
        return conv_layers, config['lstm_latent']

    encoder_latent_config = config['encoder_latent'][:]
    if model in ['VAE', 'BetaTCVAE']:
        # mu and log_variance
        encoder_latent_config += [('linear', [2 * latent_size])]
    elif model == 'AutoEncoder':
        # latent size = number of discrete code * n_classes
        encoder_latent_config += [('linear', [latent_size])]
    else:
        raise NotImplementedError('Not implemented for {}'.format(model))

    encoder_latent = FeedForward(encoder_cnn.output_size, encoder_latent_config, flatten=False)

    if 'decoder_latent' in config:
        # read the key that was tested (this used to read 'encoder_config')
        decoder_latent_config = config['decoder_latent'][:]
    else:
        # Mirror the encoder, but end on a layer of exactly ``latent_size`` so
        # the transposed stack starts from the latent code rather than from
        # the (mu, log-variance) pair. ``model`` was validated above, so no
        # second check is needed here.
        decoder_latent_config = encoder_latent_config[:-1]
        decoder_latent_config.append(('linear', [latent_size]))
        decoder_latent_config = transpose_layer_defs(decoder_latent_config, encoder_latent.input_size)
    decoder_latent = FeedForward(latent_size, decoder_latent_config, flatten=False)
    latent_layers = (encoder_latent, decoder_latent)

    return conv_layers, latent_layers
def encode_variable_length(self, x, sampling=True):
    """
    Transform a feature into a sequence of discrete tokens (as one-hot vectors)
    by unrolling the encoder LSTM for ``self.latent_size`` steps.

    :param x: image feature with size of self.hidden_size; it is used as the
        initial cell state of the encoder LSTM.
    :param sampling: if True (and the module is in training mode), use
        Gumbel-softmax to sample tokens from the distributions, otherwise
        use the token with the highest probability.
    :return:
        samples: one-hot tokens, stacked step-wise and permuted to
            [B, latent_size, dictionary_size]
        logits: the logits each token was drawn from, same layout
        eos_ind: per-sample count of steps before the first token #0 (EOS)
            was produced, i.e. the index of the first EOS (latent_size if
            none was produced); None when ``self.fix_length`` is set
    """
    _device = x.device
    samples = []
    logits = []
    batch_size = x.shape[0]
    # the hidden state starts at zero, the cell state carries the input feature
    hx = torch.zeros(batch_size, self.hidden_size,
                     device=_device)
    cx = x
    # the first step has no previous token, so its input is all zeros
    lstm_input = torch.zeros(batch_size, self.hidden_size,
                             device=_device)

    if not self.fix_length:
        is_finished = torch.zeros(batch_size, device=_device).bool()
        eos_ind = torch.zeros(batch_size, device=_device, dtype=torch.long)  # the index where the first eos appears
        eos_batch = self.eos_token.to(_device).repeat(batch_size, 1)
    else:
        eos_ind = None

    for num in range(self.latent_size):
        hx, cx = self.encoder_lstm(lstm_input, (hx, cx))
        pre_logits = self.hidden_to_token(hx)  # embedding to category logits
        logits.append(pre_logits)

        if sampling and self.training:
            # sample discrete code with Gumbel-softmax
            z_sampled_soft = gumble_softmax(pre_logits, self.temperature)
        else:
            z_sampled_soft = torch.softmax(pre_logits, dim=-1)

        # hard one-hot token in the forward pass, soft gradients in the backward pass
        z_sampled_onehot, z_argmax = straight_through_discretize(z_sampled_soft)

        if not self.fix_length:
            # record ending state of this step
            # sequences that ended on an earlier step keep emitting EOS; this
            # must happen before is_finished is updated with the current token
            z_sampled_onehot[is_finished] = eos_batch[is_finished]
            is_finished = torch.logical_or(is_finished, z_argmax == 0)
            not_finished = torch.logical_not(is_finished)
            # counts the steps after which the sequence is still running
            eos_ind += not_finished.long()

        samples.append(z_sampled_onehot)

        # the projected embedding of the sampled discrete code is the input for the next step
        lstm_input = self.token_to_hidden(z_sampled_onehot)

    # [latent_size, B, dictionary_size] -> [B, latent_size, dictionary_size]
    logits = torch.stack(logits).permute(1, 0, 2)
    samples = torch.stack(samples).permute(1, 0, 2)

    return samples, logits, eos_ind
if eos_ind == 2, z[0:2] are meaningfull tokens 132 | :return: 133 | """ 134 | batch_size = z.shape[0] 135 | _device = z.device 136 | 137 | z_embeddings = self.token_to_hidden(z.contiguous().view(-1, z.shape[-1])).view(batch_size, self.latent_size, -1) # project one-hot codes into continueious embeddings 138 | hx = torch.zeros(batch_size, self.hidden_size, 139 | device=_device) 140 | cx = torch.zeros(batch_size, self.hidden_size, 141 | device=_device) 142 | outputs = [] 143 | for n in range(self.latent_size): 144 | inputs = z_embeddings[:,n] 145 | hx, cx = self.decoder_lstm(inputs, (hx, cx)) 146 | outputs.append(hx) 147 | 148 | if self.fix_length: 149 | return hx 150 | else: 151 | # we also feed EOS embeddings to decoder LSTM 152 | eos_embeddings = self.token_to_hidden(self.eos_token.to(_device).repeat(batch_size, 1)) 153 | hx, cx = self.decoder_lstm(eos_embeddings, (hx, cx)) 154 | outputs.append(hx) 155 | outputs = torch.stack(outputs).permute(1, 0, 2) 156 | # select right output according to the EOS position in the latent code sequence. 
def straight_through_discretize(z_sampled_soft):
    """
    Turn a soft distribution into its arg-max one-hot code while keeping the
    gradient of the soft input (straight-through estimator).

    :param z_sampled_soft: soft distribution over discrete categories, with
        the categories on the last dimension
    :return: (one-hot tensor shaped like the input whose gradient flows to
        ``z_sampled_soft``, arg-max indices with the last dimension removed)
    """
    winner = z_sampled_soft.argmax(dim=-1, keepdim=True)

    hard = torch.zeros_like(z_sampled_soft)
    hard.scatter_(-1, winner, 1)

    # The detached residual moves the value onto the one-hot code, but
    # contributes nothing to the backward pass.
    residual = (hard - z_sampled_soft).detach()
    z_with_grad = z_sampled_soft + residual

    return z_with_grad, winner.squeeze(-1)
def merge_params(dictionary):
    """
    Flatten a two-level parameter dictionary into a single level.

    Every inner entry ``dictionary[group][name]`` is stored under the key
    ``'{group}_{name}'``; if two entries produce the same key, the later one
    wins.
    :param dictionary: mapping of group name to a mapping of parameters
    :return: the flattened dictionary
    """
    return {
        f'{meta_key}_{key}': dictionary[meta_key][key]
        for meta_key in dictionary
        for key in dictionary[meta_key]
    }
class MetaDataModule(pl.LightningDataModule):
    """
    Shared base for the dataset-specific data modules.

    Stores the split configuration and builds the data loaders. Subclasses
    must implement ``prepare_data`` (which fills ``train_ind`` / ``test_ind``)
    and are expected to create ``train_dataset`` / ``test_dataset`` before the
    loaders are requested.

    The dataset class is looked up in ``DATASETS_DICT`` using the part of
    ``name`` before the first underscore.
    """

    def __init__(self, name, data_dir, batch_size: int = 128, num_workers=4, n_train=None, n_fold=1,
                 random_seed=None, virtual_n_samples=None, in_distribution_test=False, **dataset_config):
        super().__init__()

        # --- loading configuration ---
        self.name = name
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = None
        self.num_workers = num_workers

        # --- split configuration ---
        # use only part of the training set
        self.n_train = n_train
        # the training set will hold n_fold x n_train samples
        self.n_fold = n_fold
        # reported dataset length, set to the total number of training steps for prefetching purposes
        self.virtual_n_samples = virtual_n_samples
        # test on the full training split (ignoring n_train) instead of the held-out split
        self.in_distribution_test = in_distribution_test
        # extra keyword arguments kept for the dataset constructor
        self.dataset_config = dataset_config

        # --- dataset lookup ---
        self.class_name = self.name.split('_')[0]
        self.dataset_class = DATASETS_DICT[self.class_name]
        # number of unique classes for each attribute
        self.num_classes = self.dataset_class.NUM_CLASSES

        self.random_seed = random_seed

        # populated later by prepare_data() / setup() in subclasses
        self.train_ind = None
        self.test_ind = None
        self.train_dataset = None
        self.test_dataset = None

    def prepare_data(self) -> None:
        """Compute the train/test indices; must be provided by subclasses."""
        raise NotImplementedError()

    def train_dataloader(self, shuffle=True):
        """Return the loader over ``train_dataset`` (shuffled unless ``shuffle`` is False)."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=False,
            num_workers=self.num_workers,
            pin_memory=True,
            # persistent workers are only valid when worker processes exist
            persistent_workers=self.num_workers > 0,
        )
        return DataLoader(self.train_dataset, **loader_options)

    def val_dataloader(self):
        """Validation uses the same loader as testing."""
        return self.test_dataloader()

    def test_dataloader(self):
        """Return the loader over ``test_dataset`` in its stored order."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )
        return DataLoader(self.test_dataset, **loader_options)
class MPI3DDataModule(MetaDataModule):
    """
    Data module producing random train/test splits of MPI3D.

    Supported name format : "mpi3d_{subset}_{split_mode}_v{version}"
        subset = {toy, realistic, real}
        split_mode = {random}
        version = {1, ..., 7}, selecting the fraction of samples held out for testing
    """

    def prepare_data(self) -> None:
        """
        Parse ``self.name`` and compute ``train_ind`` / ``test_ind``.

        The shuffled sample order is cached as an ``.npy`` file in ``data_dir``
        and reused on later runs. The first ``test_size`` shuffled ids form the
        test split and the rest the training split.

        :raises IndexError: if the name has fewer than four '_'-separated parts
        :raises ValueError: if the split mode is not supported
        :raises KeyError: if the version is not one of v1..v7
        """
        name_parts = self.name.split('_')
        self.subset_name = name_parts[1].lower()
        self.mode = name_parts[2].lower()
        self.version = name_parts[3].lower()

        # Fail early on an unknown mode (as DSpritesDataModule does); otherwise
        # train_ind/test_ind would silently stay None.
        if self.mode != 'random':
            raise ValueError('Undefined splitting')

        # total size 1036800; values are the fraction held out for testing
        test_ratio = {
            'v1': 1/6,  # 5: 1
            'v2': 1/3,  # 2: 1
            'v3': 1/2,  # 1: 1
            'v4': 7/10,  # 3: 7
            'v5': 9/10,  # 1: 9
            'v6': 95/100,  # 5: 95
            'v7': 99/100,  # 1: 99
        }
        total_size = self.dataset_class.total_sample_size
        test_size = int(total_size * test_ratio[self.version])

        all_ind = np.arange(total_size)
        shuffled_ids_cache_path = os.path.join(self.data_dir,
                                               f"{self.class_name}_{self.subset_name}_{self.mode}_shuffled_ids_seed{self.random_seed}.npy")
        if os.path.isfile(shuffled_ids_cache_path):
            print(f"Load shuffled ids at {shuffled_ids_cache_path}")
            shuffled_ids = np.load(shuffled_ids_cache_path)
        else:
            print(f"Save shuffled ids at {shuffled_ids_cache_path}")
            # NOTE(review): this draws from the global NumPy RNG; random_seed
            # only appears in the cache file name. Presumably the caller seeds
            # the global RNG beforehand -- confirm.
            shuffled_ids = np.random.permutation(all_ind)
            np.save(shuffled_ids_cache_path, shuffled_ids)

        self.train_ind = shuffled_ids[test_size:]
        if self.in_distribution_test:
            # evaluate on the full training split (taken before n_train truncation)
            self.test_ind = shuffled_ids[test_size:]
        else:
            self.test_ind = shuffled_ids[:test_size]

        if self.n_train is not None:
            self.train_ind = self.train_ind[:int(self.n_train*self.n_fold)]

    def setup(self, *args, **kwargs):
        """Instantiate the train and test datasets from the prepared indices."""
        self.train_dataset = self.dataset_class(root=self.data_dir,
                                                range=self.train_ind,
                                                subset=self.subset_name,
                                                n_samples=self.virtual_n_samples,
                                                )
        self.test_dataset = self.dataset_class(root=self.data_dir,
                                               range=self.test_ind,
                                               subset=self.subset_name,
                                               )
class DSprites(DisentangledDataset):
    """DSprites Dataset from [1].

    Disentanglement test Sprites dataset. Procedurally generated 2D shapes.
    The raw file stores 6 latents (color, shape, scale, rotation and x/y
    position of a sprite); this class drops the first (color) column and
    exposes the remaining 5 factors. All possible variations of the latents
    are present. Ordering along dimension 1 is fixed and can be mapped back to
    the exact latent values that generated that image. Pixel outputs are
    different. No noise added.

    Notes
    -----
    - Link : https://github.com/deepmind/dsprites-dataset/
    - hard coded metadata because issue with python 3 loading of python 2

    Parameters
    ----------
    root : string
        Root directory of dataset.

    References
    ----------
    [1] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick,
        M., ... & Lerchner, A. (2017). beta-vae: Learning basic visual concepts
        with a constrained variational framework. In International Conference
        on Learning Representations.

    """
    urls = {
        "train": "https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true"}
    files = {"train": "dsprite_train.npz"}
    lat_names = ('shape', 'scale', 'orientation', 'posX', 'posY')
    lat_sizes = np.array([3, 6, 40, 32, 32])
    task_types = np.array(['cls', 'reg', 'reg', 'reg', 'reg'])
    NUM_CLASSES = list(lat_sizes)
    img_size = (1, 64, 64)
    num_factors = 5
    total_sample_size = 737280
    background_color = COLOUR_BLACK
    latents_values = {'color': np.array([1.]),
                      'shape': np.array([1., 2., 3.]),  # square, ellipse, heart
                      'scale': np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.]),
                      'orientation': np.array([0., 0.16110732, 0.32221463, 0.48332195,
                                               0.64442926, 0.80553658, 0.96664389, 1.12775121,
                                               1.28885852, 1.44996584, 1.61107316, 1.77218047,
                                               1.93328779, 2.0943951, 2.25550242, 2.41660973,
                                               2.57771705, 2.73882436, 2.89993168, 3.061039,
                                               3.22214631, 3.38325363, 3.54436094, 3.70546826,
                                               3.86657557, 4.02768289, 4.1887902, 4.34989752,
                                               4.51100484, 4.67211215, 4.83321947, 4.99432678,
                                               5.1554341, 5.31654141, 5.47764873, 5.63875604,
                                               5.79986336, 5.96097068, 6.12207799, 6.28318531]),  # [0, 2 pi]
                      'posX': np.array([0., 0.03225806, 0.06451613, 0.09677419, 0.12903226,
                                        0.16129032, 0.19354839, 0.22580645, 0.25806452,
                                        0.29032258, 0.32258065, 0.35483871, 0.38709677,
                                        0.41935484, 0.4516129, 0.48387097, 0.51612903,
                                        0.5483871, 0.58064516, 0.61290323, 0.64516129,
                                        0.67741935, 0.70967742, 0.74193548, 0.77419355,
                                        0.80645161, 0.83870968, 0.87096774, 0.90322581,
                                        0.93548387, 0.96774194, 1.]),
                      'posY': np.array([0., 0.03225806, 0.06451613, 0.09677419, 0.12903226,
                                        0.16129032, 0.19354839, 0.22580645, 0.25806452,
                                        0.29032258, 0.32258065, 0.35483871, 0.38709677,
                                        0.41935484, 0.4516129, 0.48387097, 0.51612903,
                                        0.5483871, 0.58064516, 0.61290323, 0.64516129,
                                        0.67741935, 0.70967742, 0.74193548, 0.77419355,
                                        0.80645161, 0.83870968, 0.87096774, 0.90322581,
                                        0.93548387, 0.96774194, 1.]),
                      }

    def __init__(self, root, range, use_latent_class=True, transform=None, n_samples=None, **kwargs):
        """
        Load the subset of dSprites selected by ``range`` into memory.

        :param root: root directory of the dataset
        :param range: an array of indices for a subset of the data
        :param use_latent_class: return latent variables as classes instead of values
        :param transform: optional callable applied to every image in ``__getitem__``
        :param n_samples: length reported by ``__len__``; defaults to the number of
            loaded images. Indices wrap around the loaded images.
        :param kwargs: forwarded to the base class constructor
        """
        super().__init__(root, **kwargs)

        # NOTE(review): allow_pickle=True unpickles the 'metadata' entry; only
        # safe while the .npz file comes from a trusted source.
        # The context manager closes the archive's file handle once the arrays
        # have been read.
        with np.load(self.train_data, allow_pickle=True, encoding='latin1') as dataset_zip:
            self.meta_data = dataset_zip['metadata'][()]
            imgs = dataset_zip['imgs']
            # drop the first (color) column so that 5 factors remain
            latents_values = dataset_zip['latents_values'][:, 1:]
            latents_classes = dataset_zip['latents_classes'][:, 1:]
        self.use_latent_class = use_latent_class

        # select the subset, then add a channel dimension to the images
        self.imgs = torch.from_numpy(imgs[range]).unsqueeze(1).float()
        self.latents_values = torch.from_numpy(latents_values[range]).float()
        self.latents_classes = torch.from_numpy(latents_classes[range]).long()

        self.transform = transform
        self.n_samples = n_samples if n_samples is not None else len(self.imgs)
        self.raw_num_samples = len(self.imgs)

    def download(self):
        """Download the dataset."""
        os.makedirs(self.root, exist_ok=True)
        subprocess.check_call(["curl", "-L", type(self).urls["train"],
                               "--output", self.train_data])

    def __getitem__(self, idx):
        """Get the image of `idx`

        Return
        ------
        sample : torch.Tensor
            Float tensor of shape `img_size` (after the optional transform).

        latent : torch.Tensor
            Tensor of length 5 with the class index (if `use_latent_class`)
            or the value of each factor of variation.
        """
        # map the recursive id to real id
        idx = idx % self.raw_num_samples

        sample = self.imgs[idx]
        latent = self.latents_classes[idx] if self.use_latent_class else self.latents_values[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, latent

    def __len__(self):
        return self.n_samples

    @staticmethod
    def get_partition(range_all: list, range_test=None):
        """
        Get a data partition given the range of each factor.

        :param range_all: a list of 5 arrays (shape, scale, rotation, pos_x, pos_y);
                each array is a range of indices for the corresponding property,
                e.g. [[1,], [0, 1], np.arange(13,26), np.arange(21, 31), np.arange(21, 31)].
        :param range_test: optional list in the same format, or a tuple of such
                lists to merge multiple ranges
        :return: the corresponding indices of the partition data.
                when range_test is None, return indices for range_all;
                when range_test is given, return indices for range_all - range_test, and for range_test
        """
        all_indices = DSprites.range_to_index(range_all)
        if range_test is None:
            return all_indices

        test_indices = DSprites.range_to_index(range_test)
        # check if test indices is a subset of all indices
        assert len(np.setdiff1d(test_indices, all_indices)) == 0
        train_indices = np.setdiff1d(all_indices, test_indices)
        return train_indices, test_indices

    @staticmethod
    def range_to_index(latent_range):
        """
        Convert per-factor index ranges into flat dataset indices.

        :param latent_range: None (all samples), a list of per-factor ranges,
                or a tuple of such lists whose indices are concatenated
        :raises ValueError: for any other input type
        """
        if latent_range is None:
            return np.arange(0, DSprites.total_sample_size)

        if isinstance(latent_range, list):
            # cartesian product of the per-factor ranges, one row per combination
            latents_sampled = np.array(np.meshgrid(*latent_range)).T.reshape(-1, len(latent_range))
            return DSprites.latent_to_index(latents_sampled)

        if isinstance(latent_range, tuple):
            multiple_indices_sampled = [
                DSprites.latent_to_index(
                    np.array(np.meshgrid(*sub_range)).T.reshape(-1, len(sub_range)))
                for sub_range in latent_range
            ]
            return np.concatenate(multiple_indices_sampled, axis=0)

        raise ValueError()

    @staticmethod
    def latent_to_index(latents):
        """Map rows of per-factor class indices to flat indices (mixed-radix encoding over `lat_sizes`)."""
        latents_sizes = DSprites.lat_sizes
        # base of each factor = product of the sizes of all later factors
        latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:],
                                        np.array([1, ])))

        return np.dot(latents, latents_bases).astype(int)

    def sample(self, num, random_state):
        """
        Sample a batch of factors Y and observations X.

        :param num: number of samples to draw
        :param random_state: object providing ``choice`` (e.g. np.random.RandomState)
        :return: (factors, samples) as numpy arrays
        """
        # sample without replacement only when there are more items than requested
        indices = random_state.choice(self.raw_num_samples,
                                      num,
                                      replace=self.raw_num_samples <= num)
        factors = self.latents_classes[indices].numpy().astype(np.int32)
        samples = self.imgs[indices].numpy()
        if len(samples.shape) == 3:  # set channel dim to 1
            samples = samples[:, None]
        if np.issubdtype(samples.dtype, np.uint8):
            samples = samples.astype(np.float32) / 255.
        return factors, samples
class MPI3D(Dataset):
    """
    MPI3D dataset (toy / realistic / real subsets) with its 7 latent factors.

    #==========================================================================
    # Latent Dimension,    Latent values                                 N vals
    #==========================================================================

    # object color:        white=0, green=1, red=2, blue=3,                  6
    #                      brown=4, olive=5
    # object shape:        cone=0, cube=1, cylinder=2,                       6
    #                      hexagonal=3, pyramid=4, sphere=5
    # object size:         small=0, large=1                                  2
    # camera height:       top=0, center=1, bottom=2                         3
    # background color:    purple=0, sea green=1, salmon=2                   3
    # horizontal axis:     40 values linearly spaced [0, 39]                40
    # vertical axis:       40 values linearly spaced [0, 39]                40
    """
    files = {"toy": "mpi3d_toy.npz",
             "realistic": "mpi3d_realistic.npz",
             "real": "mpi3d_real.npz"}
    urls = {
        "toy": 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz',
        "realistic": 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz',
        "real": 'https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz'
    }
    task_types = np.array(['cls', 'cls', 'cls', 'reg', 'cls', 'reg', 'reg',])
    num_factors = 7
    # 'object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'horizontal_axis', 'vertical_axis'
    lat_names = ('color', 'shape', 'size', 'height', 'bg_color', 'x-axis', 'y-axis')
    lat_sizes = np.array([6, 6, 2, 3, 3, 40, 40])
    img_size = (3, 64, 64)
    total_sample_size = 1036800
    NUM_CLASSES = list(lat_sizes)

    lat_values = {'color': np.arange(6),
                  'shape': np.arange(6),
                  'size': np.arange(2),
                  'height': np.arange(3),
                  'bg_color': np.arange(3),
                  'x-axis': np.arange(40),
                  'y-axis': np.arange(40)}

    def __init__(self, root, subset, range=None, n_samples=None):
        """
        Load one MPI3D subset, downloading the file first if it is missing.

        :param root: directory holding (or receiving) the .npz file
        :param subset: one of the keys of ``files`` ('toy', 'realistic', 'real')
        :param range: optional array of indices selecting a subset of the data
        :param n_samples: length reported by ``__len__``; defaults to the number of
            loaded images. Indices wrap around the loaded images.
        """
        self.root = root
        self.subset = subset
        self.file_path = os.path.join(root, self.files[subset])
        if not os.path.exists(self.file_path):
            self.download(self.file_path)

        # close the archive's file handle as soon as the images are read
        with np.load(self.file_path) as archive:
            self.imgs = archive['images']

        # one row of factor indices per image, enumerated in the order of lat_values
        latent_values = np.asarray(list(product(*self.lat_values.values())), dtype=np.int8)
        self.latent_values = torch.from_numpy(latent_values)

        if range is not None:
            self.imgs = self.imgs[range]
            self.latent_values = self.latent_values[range]

        image_transforms = [
            trans.ToTensor(),
            trans.ConvertImageDtype(torch.float32),
        ]

        self.transform = trans.Compose(image_transforms)
        self.n_samples = n_samples if n_samples is not None else len(self.imgs)
        self.raw_num_samples = len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, latent factors)`` for ``idx``, wrapping around the loaded images."""
        # map the recursive id to real id
        idx = idx % self.raw_num_samples

        img, label = self.imgs[idx], self.latent_values[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.n_samples

    def download(self, file_path):
        """Download the file of the current subset to ``file_path``."""
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        print('downloading MPI3D {}'.format(self.subset))
        request.urlretrieve(self.urls[self.subset], file_path)
        print('download complete')

    def sample(self, num, random_state):
        """
        Sample a batch of factors Y and observations X.

        :param num: number of samples to draw
        :param random_state: object providing ``choice`` (e.g. np.random.RandomState)
        :return: (factors, samples); uint8 images are rescaled to float32 in [0, 1]
        """
        # sample without replacement only when there are more items than requested
        indices = random_state.choice(self.raw_num_samples,
                                      num,
                                      replace=self.raw_num_samples <= num)
        factors = self.latent_values[indices].numpy().astype(np.int32)
        samples = self.imgs[indices]
        if np.issubdtype(samples.dtype, np.uint8):
            samples = samples.astype(np.float32) / 255.
        return factors, samples
self.model_class = RidgeCV 44 | elif model == 'logistic': 45 | if 'cv' not in kwargs: 46 | kwargs['cv'] = 5 47 | self.model_class = LogisticRegressionCV 48 | elif model == 'linear': 49 | if 'cv' in kwargs: 50 | del kwargs['cv'] 51 | self.model_class = LinearRegression 52 | elif model == 'random-forest': 53 | self.model_class = RandomForestRegressor 54 | elif model == 'GBTC': 55 | self.model_class = GradientBoostingClassifier 56 | elif model == 'GBTR': 57 | self.model_class = GradientBoostingRegressor 58 | else: 59 | raise ValueError() 60 | 61 | self.train_data = train_data 62 | self.test_data = test_data 63 | self.n_fold = n_fold 64 | self.factor_indices = factor_selector if factor_selector is not None else list(range(n_factors)) 65 | self.kwargs = kwargs 66 | self.reset_model() 67 | 68 | def reset_model(self): 69 | self.model = self.model_class(**self.kwargs) 70 | 71 | def compute_score(self, rep_model, model_zs=None, mode='latent'): 72 | if model_zs is None: 73 | train_X, train_y = infer(rep_model, self.train_data, mode) 74 | train_X, train_y = train_X.numpy(), train_y.numpy() 75 | test_X, test_y = infer(rep_model, self.test_data, mode) 76 | test_X, test_y = test_X.numpy(), test_y.numpy() 77 | else: 78 | (train_X, train_y), (test_X, test_y) = model_zs 79 | 80 | if self.model_class == LogisticRegressionCV: 81 | train_y = train_y.astype(int) 82 | test_y = test_y.astype(int) 83 | 84 | score = [] 85 | n_samples_per_fold = train_X.shape[0] // self.n_fold 86 | for k in self.factor_indices: 87 | train_y_k = train_y[:, k] 88 | score_k = [] 89 | for j in range(self.n_fold): 90 | train_X_fold = train_X[n_samples_per_fold*j:n_samples_per_fold*(j+1)] 91 | train_y_k_one_fold = train_y_k[n_samples_per_fold*j:n_samples_per_fold*(j+1)] 92 | if len(np.unique(train_y_k)) > 1: 93 | try: 94 | self.reset_model() 95 | self.model.fit(train_X_fold, train_y_k_one_fold) 96 | score_k.append(self.model.score(test_X, test_y[:, k])) 97 | except Exception as e: 98 | print("Error message 
def compute_compgen_metric(models, train_data, test_data, batch_size=64, num_workers=4):
    """
    Convenience function to compute the compositional-generalization scores
    (``CompGenalizationMetrics`` with its default settings) for a set of
    models in a given dataset.

    :param models: iterable of representation models to evaluate
    :param train_data: dataset used to fit the predictors; must expose ``num_factors``
    :param test_data: dataset the fitted predictors are scored on
    :param batch_size: batch size of both data loaders
    :param num_workers: number of worker processes of both data loaders
    :return: list with one score array per model, in the order of ``models``
    """
    n_factors = train_data.num_factors

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_data, **loader_options)
    test_loader = DataLoader(test_data, **loader_options)

    com_gen = CompGenalizationMetrics(train_loader, test_loader, n_factors=n_factors)

    return [com_gen(model) for model in models]
def eval_disentangle_metrics(model, dataset, metrics):
    """
    Evaluate disentanglement metrics of ``model`` on ``dataset``.

    For every metric, the gin file ``<metric>.gin`` under ``config_root`` is
    parsed, the configured evaluation function is run, and the entry named in
    ``metric_score_name`` is taken from its output.

    :param model: module exposing ``parameters()`` and ``embed(x, mode='latent')``
    :param dataset: ground-truth dataset handed to the evaluation functions
    :param metrics: iterable of metric names
    :return: dict mapping metric name -> score
    :raises KeyError: if a metric has no entry in ``metric_score_name``
    """
    # Fail before doing any (expensive) evaluation when a score name is missing;
    # the same KeyError would otherwise only surface after the metric has run.
    metrics = list(metrics)
    for metric in metrics:
        if metric not in metric_score_name:
            raise KeyError(metric)

    device = next(model.parameters()).device
    # one seed shared by all metrics of this call
    random_seed = np.random.randint(2**16)

    def representation_function(x):
        # A trailing dimension of size 3 or 1 is treated as a channel axis and
        # moved in front of the spatial axes -- assumes a 4-D batch; TODO confirm.
        if x.shape[-1] == 3 or x.shape[-1] == 1:
            x = np.transpose(x, (0, 3, 1, 2))
        representation = model.embed(torch.from_numpy(x).float().to(device), mode='latent')
        return np.array(representation.detach().cpu())

    @gin.configurable("evaluation")
    def evaluate_model(evaluation_fn=gin.REQUIRED, random_seed=gin.REQUIRED):
        return evaluation_fn(
            dataset,
            representation_function,
            random_state=np.random.RandomState(random_seed))

    results = {}
    for metric in metrics:
        metric_config = f"{metric}.gin"
        # overrides the random_seed binding given in the config file
        eval_bindings = [
            f'evaluation.random_seed = {random_seed}']
        try:
            gin.parse_config_files_and_bindings(
                [os.path.join(config_root, metric_config)], eval_bindings)
            with torch.no_grad():
                out = evaluate_model()
        finally:
            # always reset the global gin state, even if the evaluation failed,
            # so that bindings never leak into a later parse
            gin.clear_config()
        results[metric] = out[metric_score_name[metric]]
        print(f'Metric {metric}: {results[metric]}')
    return results
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @dci 17 | evaluation.random_seed = 0 18 | dci.num_train=10000 19 | dci.num_test=5000 20 | dci.batch_size=1000 21 | -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/factor_vae_metric.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | evaluation.evaluation_fn = @factor_vae_score 17 | evaluation.random_seed = 0 18 | factor_vae_score.num_variance_estimate=10000 19 | factor_vae_score.num_train=10000 20 | factor_vae_score.num_eval=5000 21 | factor_vae_score.batch_size=64 22 | prune_dims.threshold = 0.05 23 | 24 | -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/irs.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @irs 17 | evaluation.random_seed = 0 18 | irs.num_train=10000 19 | irs.batch_size=16 20 | discretizer.num_bins=20 21 | discretizer.discretizer_fn=@histogram_discretizer 22 | -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/mcc.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @mcc 17 | evaluation.random_seed = 0 18 | mcc.num_train=100000 19 | mcc.correlation_fn = "Spearman" 20 | mcc.batch_size=1000 21 | -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/mig.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | evaluation.evaluation_fn = @mig 17 | evaluation.random_seed = 0 18 | mig.num_train=100000 19 | discretizer.discretizer_fn = @histogram_discretizer 20 | discretizer.num_bins = 20 21 | mig.batch_size=1000 -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/modularity_explicitness.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @modularity_explicitness 17 | evaluation.random_seed = 0 18 | modularity_explicitness.num_train=100000 19 | modularity_explicitness.num_test=5000 20 | modularity_explicitness.batch_size=1000 21 | discretizer.discretizer_fn = @histogram_discretizer 22 | discretizer.num_bins = 20 23 | -------------------------------------------------------------------------------- /src/evaluation/extra_metrics_configs/sap_score.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
class GroupMetric(nn.Module):
    """
    Bundle of several independent metrics built from one metric class,
    e.g. one Accuracy per predicted factor.

    Results come back as a list (one entry per metric) or, when ``names``
    was given, as a dict keyed by those names.
    """

    def __init__(self, metric_class, group_size, names=None, **kwargs):
        """
        :param metric_class: callable creating one metric, invoked as ``metric_class(**kwargs)``
        :param group_size: how many metric instances to create
        :param names: optional labels, one per metric
        :param kwargs: forwarded unchanged to every ``metric_class`` call
        """
        super(GroupMetric, self).__init__()
        self.names = names
        if names is not None:
            assert len(names) == group_size
        members = []
        for _ in range(group_size):
            members.append(metric_class(**kwargs))
        self.group_metrics = nn.ModuleList(members)

    def _package(self, values):
        """Return ``values`` as-is, or keyed by ``self.names`` when names were supplied."""
        if self.names is None:
            return values
        return {self.names[slot]: value for slot, value in enumerate(values)}

    def update(self, preds, targets):
        """Feed the i-th prediction/target pair to the i-th metric."""
        for slot, member in enumerate(self.group_metrics):
            member.update(preds[slot], targets[slot])

    def compute(self):
        """Collect ``compute()`` of every metric (list, or dict when named)."""
        return self._package([member.compute() for member in self.group_metrics])

    def forward(self, preds, targets, **kwargs):
        """
        Call every metric on its own prediction/target pair.
        :param preds: indexable collection, ``preds[i]`` goes to the i-th metric
        :param targets: indexable collection, ``targets[i]`` goes to the i-th metric
        :param kwargs: forwarded to every metric call
        :return: list of per-metric outputs, or dict keyed by name
        """
        outputs = [
            member(preds[slot], targets[slot], **kwargs)
            for slot, member in enumerate(self.group_metrics)
        ]
        return self._package(outputs)

    def mean(self):
        """Average of all computed metric values as a float tensor."""
        scores = self.compute()
        if self.names is not None:
            scores = list(scores.values())
        return torch.FloatTensor(scores).mean()

    def reset(self):
        """Reset every metric in the group."""
        for member in self.group_metrics:
            member.reset()
inDist=False, **kwargs): 10 | """ 11 | 12 | :param backbone: 13 | :param datamodule: 14 | :param mode: 15 | :param n_train: number of labeled samples for training a supervised model 16 | :param n_fold: each fold has different n_train samples and results in a model, 17 | average performance of n_fold models are reported. 18 | :param reg_model: 19 | :param cls_model: 20 | :param reverse_task_type: add metrics that use reg models for cls factors and cls models for reg factors 21 | :param kwargs: 22 | """ 23 | self.backbone = backbone 24 | self.mode = mode 25 | self.n_train = n_train 26 | self.n_fold = n_fold 27 | self.reverse_task_type = reverse_task_type 28 | datamodule.prepare_data() 29 | datamodule.setup() 30 | task_types = datamodule.train_dataset.task_types 31 | if not reverse_task_type: 32 | self.factor_ids = {'reg': (task_types == 'reg').nonzero()[0], 33 | 'cls': (task_types == 'cls').nonzero()[0]} 34 | else: 35 | self.factor_ids = {'reg': (task_types == 'cls').nonzero()[0], 36 | 'cls': (task_types == 'reg').nonzero()[0]} 37 | self.metric_names = {'reg': 'R2', 'cls': 'acc'} 38 | self.model_names = {'reg': reg_model, 'cls': cls_model} 39 | 40 | self.regression_metric = CompGenalizationMetrics(train_data=None, 41 | test_data=None, 42 | n_factors=datamodule.train_dataset.num_factors, 43 | model=reg_model, 44 | factor_selector=self.factor_ids['reg'], 45 | n_fold = n_fold 46 | ) 47 | self.classification_metric = CompGenalizationMetrics(train_data=None, 48 | test_data=None, 49 | n_factors=datamodule.train_dataset.num_factors, 50 | model=cls_model, 51 | factor_selector=self.factor_ids['cls'], 52 | n_fold=n_fold 53 | ) 54 | 55 | self.train_data = datamodule.train_dataloader() 56 | self.test_data = datamodule.test_dataloader() 57 | 58 | self.factor_names = datamodule.train_dataset.lat_names 59 | self.checkpoint_name = ckpoint 60 | self.testOnTrain = testOnTrain 61 | self.testInDist = inDist and not testOnTrain 62 | 63 | def eval(self): 64 | train_X, train_y = 
class TopoSimEval:
    """
    Evaluator reporting the topographic similarity of a model's messages
    on the train and test splits of a datamodule.
    """

    def __init__(self, model, datamodule, n_sample=10000):
        """
        :param model: model handed to ``infer_message`` to produce messages
        :param datamodule: provides ``train_dataloader()`` / ``test_dataloader()``;
            ``prepare_data()`` and ``setup()`` are called here
        :param n_sample: maximum number of samples used per split
        """
        super(TopoSimEval, self).__init__()
        self.model = model
        datamodule.prepare_data()
        datamodule.setup()
        self.datamodule = datamodule
        self.n_sample = n_sample

    def eval(self):
        """
        Compute topographic similarity on the train and test splits.
        :return: dict with keys ``topsim_train_<n_sample>`` and ``topsim_test_<n_sample>``
        """
        # n_sample must be passed by keyword: the third positional parameter of
        # infer_message is `sampling`, so a positional n_sample would switch on
        # sampling and leave the sample budget at its default.
        train_messages, train_attributes, train_msg_len = infer_message(
            self.model, self.datamodule.train_dataloader(), n_sample=self.n_sample)
        test_messages, test_attributes, test_msg_len = infer_message(
            self.model, self.datamodule.test_dataloader(), n_sample=self.n_sample)

        res_train = topographic_similarity(train_messages, train_attributes, train_msg_len)
        res_test = topographic_similarity(test_messages, test_attributes, test_msg_len)
        return {
            f'topsim_train_{self.n_sample}': res_train,
            f'topsim_test_{self.n_sample}': res_test,
        }
def edit_dist(_list):
    """
    Normalized edit distance between every unordered pair of sequences.

    :param _list: sliceable collection of sequences (each supporting ``len``)
    :return: list of distances ordered as (0,1), (0,2), ..., (1,2), ...;
        each distance is divided by the length of the longer sequence.
        Two empty sequences are reported as distance 0.0.
    """
    distances = []
    for i, el1 in enumerate(_list[:-1]):
        for el2 in _list[i + 1:]:
            longest = max(len(el1), len(el2))
            if longest == 0:
                # Both sequences are empty: they are identical, and dividing
                # by the length would raise ZeroDivisionError.
                distances.append(0.0)
                continue
            distances.append(editdistance.eval(el1, el2) / longest)
    return distances
class AutoEncoder(pl.LightningModule):
    """
    Plain (deterministic) auto-encoder: conv encoder -> latent encoder ->
    latent decoder -> conv decoder, trained with a reconstruction loss only.
    """

    def __init__(self,
                 input_size: List,
                 architecture: str,
                 latent_size: int,
                 recon_loss: str = 'mse',
                 lr: float = 0.001,
                 optim: str = 'adam',
                 weight_decay: float = 0,
                 **kwargs) -> None:
        """

        :param input_size: (n_channels, image_height, image_width)
        :param architecture: architecture name passed to ``build_architectures``
        :param latent_size: size of the latent code
        :param recon_loss: 'mse' for Gaussian decoder and 'bce' for the bernoulli decoder
        :param lr: learning rate passed to the optimizer
        :param optim: optimizer name passed to ``init_optimizer``
        :param weight_decay: weight decay passed to the optimizer
        :param kwargs: accepted and ignored here (still captured by ``save_hyperparameters``)
        """

        super(AutoEncoder, self).__init__()

        self.save_hyperparameters()
        self.latent_size = latent_size
        self.architecture = architecture
        self.input_size = input_size
        self.recon_loss = recon_loss
        self.optim = optim
        self.lr = lr
        self.weight_decay = weight_decay
        self.setup_models()

    def setup_models(self):
        """Build the conv encoder/decoder and the latent encoder/decoder."""
        (self.encoder_conv, self.decoder_conv), (self.encoder_latent, self.decoder_latent) = build_architectures(
            self.input_size, self.architecture, self.latent_size, model=self.__class__.__name__)

    def encode_latent(self, feat):
        """
        Map a backbone feature to the latent code
        :param feat: output of ``encoder_conv``
        :return: latent code
        """
        return self.encoder_latent(feat)

    def decode_latent(self, z):
        """
        Decode latent variable z into a feature
        :param z: latent code
        :return: feature fed to ``decoder_conv``
        """
        return self.decoder_latent(z)

    def encode(self, input: Tensor) -> Tensor:
        """
        Encodes the input by passing through the encoder network
        and returns the latent code.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) latent code
        """
        feat = self.encoder_conv(input)
        return self.encode_latent(feat)

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: latent variables
        :return: (Tensor) [B x C x H x W]
        """
        return self.decoder_conv(self.decode_latent(z))

    def forward(self, input: Tensor, **kwargs) -> dict:
        """
        Encode then decode the input.
        :param input: (Tensor) input batch
        :param kwargs: ignored
        :return: dict with the reconstruction under key 'recon'
        """
        z = self.encode(input)
        return {'recon': self.decode(z)}

    def embed(self, x, mode, **kwargs):
        """
        Function to call to use the auto-encoder as a backbone model for downstream tasks
        :param x: input batch
        :param mode: 'pre' (conv encoder output), 'latent' (latent code) or
            'post' (latent decoder output)
        :param kwargs: ignored
        :return: the requested representation
        :raises ValueError: if ``mode`` is not one of the three values above
        """
        if mode == 'pre':
            return self.encoder_conv(x)
        z = self.encode(x)
        if mode == 'latent':
            return z
        if mode == 'post':
            return self.decoder_latent(z)
        raise ValueError(f"Unknown mode '{mode}'; expected 'pre', 'latent' or 'post'")

    def compute_loss(self, inputs, results, labels=None):
        """
        :param inputs: input batch
        :param results: output of ``forward``
        :param labels: unused
        :return: dict with 'loss' and 'recon_loss' (the same tensor)
        """
        recon_loss = self.compute_recontruct_loss(inputs, results)
        return {'loss': recon_loss, 'recon_loss': recon_loss}

    def compute_recontruct_loss(self, inputs, results):
        """
        Summed reconstruction error divided by the batch size.
        :param inputs: input batch
        :param results: dict holding the reconstruction under 'recon'
        :return: scalar loss tensor
        :raises ValueError: if ``self.recon_loss`` is neither 'mse' nor 'bce'
        """
        if self.recon_loss == 'mse':
            total = F.mse_loss(results['recon'], inputs, reduction='sum')
        elif self.recon_loss == 'bce':
            total = F.binary_cross_entropy_with_logits(results['recon'], inputs, reduction='sum')
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unknown recon_loss '{self.recon_loss}'; expected 'mse' or 'bce'")
        return total / inputs.size(0)

    def step(self, batch, batch_idx, stage='train') -> tuple:
        """
        Run the model on one batch and compute its losses.
        On the first batch of a stage other than 'test', sample images are
        logged when a logger is attached.
        :param batch: pair ``(x, y)``; only ``x`` is used for the loss
        :param batch_idx: index of the batch within the epoch
        :param stage: prefix used for the logged keys
        :return: ``(loss, log)`` where ``log`` maps '<stage>_<name>' to detached values
        """
        x, y = batch
        results = self.forward(x)
        loss_dict = self.compute_loss(x, results)
        log = {f"{stage}_{k}": v.detach() for k, v in loss_dict.items()}

        if batch_idx == 0 and self.logger and stage != 'test':
            self.sample_images(batch, stage=stage)

        return loss_dict['loss'], log

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss, logs = self.step(batch, batch_idx)
        self.log_dict(logs, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx, 'val')
        self.log_dict(logs, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, logs = self.step(batch, batch_idx, 'test')
        return logs

    def test_epoch_end(self, outputs):
        """
        Average every '*loss*' entry over the test batches, log the averages
        and try to record them next to the hyper-parameters.
        :param outputs: list of the dicts returned by ``test_step``
        """
        if not outputs:
            # Nothing was tested; outputs[0] below would raise IndexError.
            return

        metrics = {}
        for key in outputs[0]:
            if 'loss' in key:
                metrics['{}'.format(key)] = torch.stack([x[key] for x in outputs]).mean()
        self.log_dict({key: val.item() for key, val in metrics.items()}, prog_bar=False)

        # Best effort: not every logger exposes `experiment.add_hparams`.
        try:
            # NOTE(review): only list-valued hyper-parameters are forwarded
            # (as tensors); confirm whether scalar ones should be included too.
            hparams_log = {}
            for key, val in self.hparams.items():
                if isinstance(val, list):
                    hparams_log[key] = torch.tensor(val)
            self.logger.experiment.add_hparams(hparams_log, metrics)
        except Exception:
            print("Failed to add hparams")

    def configure_optimizers(self):
        optimizer = init_optimizer(self.optim, self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return {'optimizer': optimizer}

    def sample_images(self, batch, num=25, stage='train'):
        """
        Log a grid of inputs and a grid of their reconstructions.
        :param batch: pair ``(inputs, labels)``
        :param num: maximum number of images shown
        :param stage: suffix of the logged image keys
        """
        inputs, labels = batch
        if inputs.size(0) > num:
            inputs, labels = inputs[:num], labels[:num]
        # Visualisation only: no need to build an autograd graph.
        with torch.no_grad():
            recons = self.forward(inputs, labels=labels)['recon']
            if self.recon_loss == 'bce':
                # The bernoulli decoder outputs logits.
                recons = torch.sigmoid(recons)

        per_row = int(sqrt(num))
        inputs_grids = vutils.make_grid(inputs, normalize=True, nrow=per_row, pad_value=1)
        recon_grids = vutils.make_grid(recons, normalize=True, nrow=per_row, pad_value=1)
        self.logger.log_image(key=f'input_{stage}', images=[inputs_grids],
                              caption=[f'epoch_{self.current_epoch}'])
        self.logger.log_image(key=f'recon_{stage}', images=[recon_grids],
                              caption=[f'epoch_{self.current_epoch}'])

    @property
    def name(self) -> str:
        return self.make_name()

    @property
    def backbone_name(self) -> str:
        return self.make_backbone_name()

    def make_name(self) -> str:
        """
        Get the name of the model according its parameters
        """
        return "{}_{}_{}_lr{}_{}_wd{}".format(
            self.__class__.__name__,
            self.make_backbone_name(),
            self.recon_loss,
            self.lr,
            self.optim,
            self.weight_decay,
        )

    def make_backbone_name(self) -> str:
        """
        Get the name of the backbone according its parameters
        """
        return "{}_z{}".format(
            self.architecture,
            self.latent_size,
        )

    def get_rep_size(self, mode):
        """
        Size of the representation returned by ``embed`` for ``mode``.
        :param mode: 'latent', 'pre' or 'post'
        :raises ValueError: for any other mode
        """
        if mode == 'latent':
            return self.latent_size
        if mode == 'pre':
            return self.encoder_conv.output_size
        if mode == 'post':
            return self.decoder_latent.output_size
        raise ValueError(f"Unknown mode '{mode}'; expected 'pre', 'latent' or 'post'")
def compute_loss(self, inputs, results, labels=None):
    """
    Total loss: reconstruction term plus the three weighted parts of the
    decomposed KL returned by ``compute_KLD_loss``.
    :param inputs: input batch
    :param results: model outputs passed on to both loss helpers
    :param labels: unused
    :return: dict with 'loss', 'recon_loss', 'mi', 'tc' and 'dim_KL'
    """
    recon_loss = self.compute_recontruct_loss(inputs, results)
    mi, tc, dim_KL = self.compute_KLD_loss(results)

    # Accumulate in the order alpha*mi, beta*tc, gamma*dim_KL.
    loss = recon_loss
    for weight, term in ((self.alpha, mi), (self.beta, tc), (self.gamma, dim_KL)):
        loss = loss + weight * term

    return {
        'loss': loss,
        'recon_loss': recon_loss,
        'mi': mi,
        'tc': tc,
        'dim_KL': dim_KL,
    }
def gaussian_log_density(samples, mean, log_var):
    """ Element-wise log density of a Gaussian distribution
    Borrowed from https://github.com/google-research/disentanglement_lib/
    :param samples: points at which the density is evaluated
    :param mean: means of the Gaussian densities
    :param log_var: log-variances of the Gaussian densities
    :return: tensor of log densities with the broadcast shape of the three inputs
    """
    # log(2*pi) as a Python float: it follows the dtype of the inputs instead
    # of being fixed to float32, and no tensor is allocated per call.
    normalization = math.log(2. * math.pi)
    inv_var = torch.exp(-log_var)
    diff = samples - mean
    return -0.5 * (diff * diff * inv_var + log_var + normalization)
def init_optimizer(optimizer, params, lr=0.01, weight_decay=0.0, **kwargs):
    """
    Build a ``torch.optim`` optimizer from its name.

    :param optimizer: one of 'adam', 'sparseadam', 'adamax', 'rmsprop', 'sgd',
        'nesterov', 'adagrad', 'adadelta'
    :param params: parameters (or parameter groups) to optimize
    :param lr: learning rate
    :param weight_decay: weight decay; not forwarded for 'sparseadam'
    :param kwargs: extra keyword arguments forwarded to the optimizer class
    :return: the constructed optimizer
    :raises ValueError: if ``optimizer`` is not a recognized name
    """
    if optimizer == 'sparseadam':
        # SparseAdam is built without weight decay.
        return optim.SparseAdam(params, lr=lr, **kwargs)

    if optimizer == 'nesterov':
        # SGD rejects nesterov=True with zero momentum, so provide a default
        # when the caller did not choose one.
        kwargs.setdefault('momentum', 0.9)
        return optim.SGD(params, lr=lr, weight_decay=weight_decay,
                         nesterov=True, **kwargs)

    standard = {
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'rmsprop': optim.RMSprop,
        'sgd': optim.SGD,
        'adagrad': optim.Adagrad,
        'adadelta': optim.Adadelta,
    }
    optimizer_cls = standard.get(optimizer) if isinstance(optimizer, str) else None
    if optimizer_cls is None:
        raise ValueError(r'Optimizer {0} not recognized'.format(optimizer))

    return optimizer_cls(params, lr=lr, weight_decay=weight_decay, **kwargs)
def embed(self, x, mode, sampling=False, **kwargs):
    """
    Function to call to use the model as a backbone for downstream tasks.

    :param x: input passed to ``self.encoder_conv`` ('pre') or ``self.encode``
    :param mode: which representation to return:
        'pre'    -> output of ``self.encoder_conv(x)``
        'latent' -> ``argmax`` over the last dimension of the encoder result
                    ``'z'``, cast to float
        'post'   -> ``self.latent_layers.decode`` applied to the encoder result
    :param sampling: forwarded to ``self.encode`` for 'latent' and 'post'
    :param kwargs: accepted for interface compatibility; not used here
    :return: the representation selected by ``mode``
    :raises ValueError: if ``mode`` is not 'pre', 'latent' or 'post'
    """
    # Validate first: previously an invalid mode was only detected after the
    # full encoder pass had already been executed.
    if mode not in ('pre', 'latent', 'post'):
        raise ValueError(
            f"Unknown embedding mode {mode!r}; expected 'pre', 'latent' or 'post'")

    if mode == 'pre':
        return self.encoder_conv(x)

    enc_res = self.encode(x, sampling=sampling)
    if mode == 'latent':
        return enc_res['z'].argmax(-1).float()
    return self.latent_layers.decode(**enc_res)
def make_backbone_name(self) -> str:
    """
    Get the name of the backbone according to its parameters.

    The name combines the architecture, latent size, the optional '_fix' and
    '_determ' tags (present when ``fix_length`` / ``deterministic`` are set),
    the dictionary size and the ``gsm_temperature`` value.
    """
    fix_tag = '_fix' if self.fix_length else ''
    determ_tag = '_determ' if self.deterministic else ''
    template = "{}_z{}{}{}_D{}_gsmT{}"
    return template.format(
        self.architecture,
        self.latent_size,
        fix_tag,
        determ_tag,
        self.dictionary_size,
        self.gsm_temperature,
    )
def encode_latent(self, feat, sampling):
    """
    Encode the feature from the backbone to latent variables and reparameterize it.

    :param feat: feature passed to ``self.encoder_latent``
    :param sampling: (Bool) if sample from latent distribution; forwarded to
        ``self.reparameterize``
    :return: whatever ``self.reparameterize`` returns for the encoded feature
    """
    return self.reparameterize(self.encoder_latent(feat), sampling=sampling)
def reparameterize(self, latent: torch.Tensor, sampling) -> dict:
    """
    Reparameterization trick to sample from N(mu, var) using noise from N(0, 1).

    :param latent: (Tensor) tensor that is split along dim 1 into chunks of
        size ``self.latent_size``; the first chunk is used as the mean ``mu``
        and the second as the log-variance ``log_var``
    :param sampling: (Bool) if true ``z`` is the sampled code
        ``mu + std * eps``, otherwise ``z`` is ``mu`` itself
    :return: dictionary with keys 'mu', 'log_var' and 'z'
    """
    mu, log_var = torch.split(latent, self.latent_size, dim=1)
    # std = exp(0.5 * log_var); the float32 machine epsilon keeps it strictly positive.
    std = torch.exp(0.5 * log_var) + torch.finfo(torch.float32).eps

    # The noise is drawn even when `sampling` is false so that the amount of
    # randomness consumed from the global RNG does not depend on the flag.
    noise = torch.randn_like(std)
    sampled = torch.addcmul(mu, std, noise)

    return {
        'mu': mu,
        'log_var': log_var,
        'z': sampled if sampling else mu,
    }
def compute_recontruct_loss(self, inputs, results):
    """
    Compute the reconstruction loss, summed over all elements and divided by
    the batch size (``inputs.size(0)``).

    :param inputs: target tensor the reconstruction is compared against
    :param results: dictionary holding the reconstruction under key 'recon';
        for 'bce' the reconstruction is treated as logits
    :return: scalar tensor with the per-sample reconstruction loss
    :raises ValueError: if ``self.recon_loss`` is neither 'mse' nor 'bce'
    """
    recon = results['recon']
    if self.recon_loss == 'mse':
        total = F.mse_loss(recon, inputs, reduction='sum')
    elif self.recon_loss == 'bce':
        total = F.binary_cross_entropy_with_logits(recon, inputs, reduction='sum')
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown recon_loss {self.recon_loss!r}; expected 'mse' or 'bce'")
    return total / inputs.size(0)
def test_epoch_end(self, outputs):
    """
    Aggregate the per-batch test logs and report them.

    For every key of the first output whose name contains 'loss', the values
    of all outputs are stacked and averaged; the averages are logged through
    ``self.log_dict``. Afterwards the list-valued hyper-parameters are
    converted to tensors and sent, together with the metrics, to
    ``self.logger.experiment.add_hparams`` on a best-effort basis.

    :param outputs: list of dictionaries returned by ``test_step``
    """
    if not outputs:
        # Nothing was evaluated; indexing outputs[0] below would fail.
        return

    metrics = {
        key: torch.stack([batch_log[key] for batch_log in outputs]).mean()
        for key in outputs[0]
        if 'loss' in key
    }
    self.log_dict({key: val.item() for key, val in metrics.items()}, prog_bar=False)

    try:
        # NOTE(review): only list-valued hyper-parameters are forwarded, as in
        # the previous implementation -- confirm this is intended.
        hparams_log = {
            key: torch.tensor(val)
            for key, val in self.hparams.items()
            if isinstance(val, list)
        }
        self.logger.experiment.add_hparams(hparams_log, metrics)
    except Exception as err:
        # Best effort: hyper-parameter logging must not abort the test run,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"Failed to add hparams: {err}")
def get_rep_size(self, mode):
    """
    Return the size of the representation selected by ``mode``.

    :param mode: 'latent' -> ``self.latent_size``,
                 'pre'    -> ``self.encoder_conv.output_size``,
                 'post'   -> ``self.decoder_latent.output_size``
    :raises ValueError: for any other mode
    """
    if mode == 'latent':
        return self.latent_size
    if mode == 'pre':
        return self.encoder_conv.output_size
    if mode == 'post':
        return self.decoder_latent.output_size
    raise ValueError()
beta: 0.0 8 | lr: 0.0001 9 | recon_loss: 'bce' 10 | 11 | exp_params: 12 | dataset: dsprites90d_random_v5 13 | data_path: "YourPathToData" 14 | train_workers: 4 15 | val_workers: 1 16 | random_seed: 2001 17 | batch_size: 64 # Better to have a square number 18 | max_epochs: 100 19 | 20 | trainer_params: 21 | gpus: [0, ] 22 | 23 | logging_params: 24 | save_dir: "logs/" 25 | -------------------------------------------------------------------------------- /src/scripts/configs/scikitlearn_eval.yaml: -------------------------------------------------------------------------------- 1 | mode: 'latent' 2 | n_train: 1000 3 | ckpoint: 'last' 4 | n_fold: 1 5 | reg_model: 'ridge' 6 | cls_model: 'logistic' -------------------------------------------------------------------------------- /src/scripts/configs/vae.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | name: 'VAE' 3 | input_size: [1, 64, 64] 4 | architecture: 'base' 5 | latent_size: 10 6 | beta: 0.0 7 | lr: 0.0001 8 | recon_loss: 'bce' 9 | 10 | exp_params: 11 | dataset: dsprites90d_random_v5 12 | data_path: "YourPathToData" 13 | train_workers: 4 14 | val_workers: 1 15 | random_seed: 2001 16 | batch_size: 64 # Better to have a square number 17 | max_epochs: 100 18 | 19 | trainer_params: 20 | gpus: [0, ] 21 | 22 | logging_params: 23 | save_dir: "logs/" 24 | -------------------------------------------------------------------------------- /src/scripts/eval_gt_rep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Evaluate ground truth representation 4 | 5 | Created by zhenlinx on 07/31/2022 6 | """ 7 | import os 8 | import sys 9 | sys.path.append(os.path.realpath('..')) 10 | import argparse 11 | import yaml 12 | from itertools import product 13 | import torch 14 | from pytorch_lightning import seed_everything 15 | from pytorch_lightning.loggers import WandbLogger 16 | from copy import deepcopy 17 | 
def polynomial_map(x):
    """
    Map a representation through an element-wise square.

    :param x: tensor of any shape; it is cast to float before squaring
    :return: float tensor with the same shape as ``x`` holding ``x ** 2``
    """
    return x.float() ** 2
def get_gt_rep(self, dataloader):
    """
    Collect the second element of every batch as both representation and target.

    :param dataloader: iterable yielding ``(_, t)`` pairs; the first element
        is ignored
    :return: tuple ``(latents, targets)``; each is the concatenation of
        independent deep copies of all ``t``
    """
    batch_targets = [t for _, t in dataloader]
    latents = torch.cat([deepcopy(t) for t in batch_targets])
    targets = torch.cat([deepcopy(t) for t in batch_targets])
    return latents, targets
128 | evaluator = ScikitLearnGTRepEvaluator(backbone=None, datamodule=dm, **config['eval_params']) 129 | 130 | print(f"======= {evaluator.name} with {exp_name} =======") 131 | ft_root = os.path.join(exp_root, evaluator.name, 132 | f"version_{config['exp_params']['random_seed']}") 133 | 134 | if not os.path.isdir(ft_root): 135 | os.makedirs(ft_root, exist_ok=True) 136 | eval_res = evaluator.eval() 137 | 138 | ft_logger = None if args.nowb else WandbLogger(project=args.project, 139 | name=f"{evaluator.name}_{exp_name}", 140 | save_dir=ft_root, 141 | tags=['GT_Rep', 'scikit_eval_v2', ] + args.tags, 142 | config=config, 143 | reinit=True 144 | ) 145 | 146 | if ft_logger: 147 | ft_logger.log_hyperparams(config) 148 | ft_logger.log_metrics(eval_res) 149 | ft_logger.experiment.finish() 150 | else: 151 | print(eval_res) 152 | 153 | 154 | 155 | def main(): 156 | parser = argparse.ArgumentParser(description='Generic runner for VAE models') 157 | add_vae_argument(parser) 158 | args = parser.parse_args() 159 | config, sklearn_eval_cfg = setup_experiment(args) 160 | 161 | # setting hyperparameters 162 | # for data in ('dsprites90d_random_v5', 'mpi3d_real_random_v5',): 163 | for data in ('mpi3d_real_random_v5',): 164 | for mapping in ('polynomial', ): 165 | for seed in (2001, 2002, 2003): 166 | # for seed in (2003, ): 167 | # sklearn eval 168 | for n_train in (500, ): 169 | config['eval_params'] = sklearn_eval_cfg 170 | config['eval_params']['n_train'] = n_train 171 | config['exp_params'][ 172 | 'data_path'] = 'YOUR_PATH_TO_DATA' 173 | config['exp_params']['random_seed'] = seed 174 | config['exp_params']['dataset'] = data 175 | 176 | if 'mpi3d' in data: 177 | config['exp_params'][ 178 | 'data_path'] = 'YOUR_PATH_TO_DATA' 179 | config['model_params']['input_size'] = [3, 64, 64] 180 | 181 | if args.gbt: 182 | config['eval_params']['reg_model'] = 'GBTR' 183 | config['eval_params']['cls_model'] = 'GBTC' 184 | args.tags = ['GBT', ] 185 | config['eval_params']['mapping'] = mapping 186 
def load_yaml_file(file_name):
    """
    Load a YAML file with ``yaml.safe_load``.

    :param file_name: path of the YAML file to read
    :return: the parsed content of the file
    :raises yaml.YAMLError: if the file cannot be parsed; the message names
        the offending file
    """
    with open(file_name, 'r') as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Previously the error was only printed and the function then
            # failed with UnboundLocalError on `return config`.
            raise yaml.YAMLError(
                f"Could not parse YAML file '{file_name}': {exc}") from exc
get_datamodule(config['exp_params']['dataset'], 56 | data_dir=config['exp_params']['data_path'], 57 | batch_size=config['exp_params']['batch_size'], 58 | num_workers=config['exp_params']['train_workers'], 59 | random_seed=config['exp_params']['random_seed'], 60 | virtual_n_samples=config['exp_params']['val_steps'] * config['exp_params']['batch_size'], 61 | ) 62 | 63 | if not os.path.isdir(exp_root): 64 | os.makedirs(exp_root, exist_ok=True) 65 | 66 | logger = None if args.nowb else WandbLogger(project=args.project, 67 | name=exp_name, 68 | save_dir=exp_root, 69 | tags=[config['model_params']['name'], 'pretrain'] + args.tags, 70 | config=config, 71 | reinit=True 72 | ) 73 | if logger is not None: 74 | logger.watch(vae, log="all") 75 | # Init ModelCheckpoint callback, monitoring 'val_loss' 76 | best_checkpoint_callback = MyModelCheckpoint( 77 | dirpath=os.path.join(exp_root, 'checkpoints'), 78 | monitor='val_loss', 79 | filename="best", 80 | every_n_epochs=1, 81 | verbose=False, 82 | save_last=False, 83 | ) 84 | checkpoint_callback = MyModelCheckpoint( 85 | dirpath=os.path.join(exp_root, 'checkpoints'), 86 | save_top_k=-1, 87 | filename="{epoch:02d}", 88 | every_n_epochs=20, 89 | verbose=False, 90 | save_last=True, 91 | ) 92 | 93 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 94 | callbacks = [checkpoint_callback, best_checkpoint_callback] 95 | if not args.nowb: 96 | callbacks.append(lr_monitor) 97 | 98 | # set runner 99 | runner = Trainer(min_epochs=1, 100 | logger=logger, 101 | log_every_n_steps=100, 102 | num_sanity_val_steps=5, 103 | deterministic=True, 104 | benchmark=False, 105 | max_steps=20 if args.debug else config['exp_params']['train_steps'], 106 | val_check_interval=20 if args.debug else config['exp_params']['val_steps'], 107 | limit_train_batches=1.0, 108 | callbacks=callbacks, 109 | # deterministic=True, 110 | **config['trainer_params'], 111 | ) 112 | 113 | if os.path.isfile(ckpoint_path) and config['exp_params']['max_epochs'] == 0 and not 
def scikitlearn_eval(config, args):
    """
    Evaluate a trained model with ``ScikitLearnEvaluator``.

    Does nothing unless ``args.finetune`` is set. Otherwise it seeds the run,
    rebuilds the model from ``config['model_params']``, loads the checkpoint
    named by ``config['eval_params']['ckpoint']`` from the experiment
    directory, runs the evaluator and either logs the result to W&B or prints
    it when ``args.nowb`` is set.

    :param config: experiment configuration; ``config['eval_params']['inDist']``
        is written by this function
    :param args: parsed command-line arguments (reads ``finetune``, ``gpu``,
        ``TestInDist``, ``nowb``, ``project`` and ``tags``)
    """
    if not args.finetune:
        return

    # https://github.com/pytorch/pytorch/issues/11201
    # import torch.multiprocessing
    # torch.multiprocessing.set_sharing_strategy('file_system')

    exp_cfg = config['exp_params']
    seed = exp_cfg['random_seed']
    seed_everything(seed)

    # learn a task module on learned encoder
    model_cfg = config['model_params']
    vae = vae_models[model_cfg['name']](**model_cfg)
    if args.gpu is not None:
        vae = vae.to(f'cuda:{args.gpu[0]}')

    exp_name = vae.name + '_{}_batch{}_{}epochs'.format(
        exp_cfg['dataset'],
        exp_cfg['batch_size'],
        exp_cfg['max_epochs'],
    )
    exp_root = os.path.join(
        config['logging_params']['save_dir'],
        exp_cfg['dataset'],
        exp_name,
        f"version_{seed}",
    )

    eval_cfg = config['eval_params']  # same dict object as in `config`
    ckpoint_path = os.path.join(exp_root, 'checkpoints', f"{eval_cfg['ckpoint']}.ckpt")

    dm = get_datamodule(
        exp_cfg['dataset'],
        data_dir=exp_cfg['data_path'],
        batch_size=exp_cfg['batch_size'],
        num_workers=0,
        n_train=eval_cfg['n_train'],
        n_fold=eval_cfg['n_fold'],
        random_seed=seed,
        in_distribution_test=args.TestInDist,
    )

    # Stored in the config so it is forwarded to the evaluator below.
    eval_cfg['inDist'] = args.TestInDist

    print("Loading checkpoint at {}".format(ckpoint_path))
    ckpt = torch.load(ckpoint_path, map_location=torch.device('cpu'))
    vae.load_state_dict(ckpt['state_dict'])
    print("Checkpoint loaded!")

    evaluator = ScikitLearnEvaluator(vae, dm, **eval_cfg)
    print(f"======= {evaluator.name} with {exp_name} =======")

    ft_root = os.path.join(exp_root, evaluator.name, f"version_{seed}")
    os.makedirs(ft_root, exist_ok=True)
    eval_res = evaluator.eval()

    if args.nowb:
        ft_logger = None
    else:
        ft_logger = WandbLogger(
            project=args.project,
            name=f"{evaluator.name}_{exp_name}",
            save_dir=ft_root,
            tags=[model_cfg['name'], 'scikit_eval_v2'] + args.tags,
            config=config,
            reinit=True,
        )

    if not ft_logger:
        print(eval_res)
        return

    ft_logger.log_hyperparams(config)
    ft_logger.log_metrics(eval_res)
    ft_logger.experiment.finish()
def add_vae_argument(parser):
    """
    Register the command-line options shared by the experiment scripts.

    :param parser: ``argparse.ArgumentParser`` the options are added to
        (modified in place)
    """
    parser.add_argument('--config', '-c',
                        dest="filename",
                        metavar='FILE',
                        help='path to the config file',
                        default='configs/vae.yaml')
    parser.add_argument('--debug', '-d', action='store_true',
                        help='debug mode')
    parser.add_argument('--test', '-t', action='store_true',
                        help='test vae models on the pretraining task')
    # `train_vae` reads args.notrain, so the flag has to be defined.
    parser.add_argument('--notrain', '-ntr', action='store_true',
                        help='do not run pretraining')
    parser.add_argument('--finetune', '-ft', action='store_true',
                        help='finetune mode: only train a linear head or GBT on trained vae')
    # The runner scripts read args.sklearn to decide whether to call
    # `scikitlearn_eval` after training, so the flag has to be defined.
    parser.add_argument('--sklearn', '-sk', action='store_true',
                        help='run the scikit-learn readout evaluation after training')
    # NOTE(review): declared as a boolean flag although the help text speaks
    # of a path -- confirm the intended type before relying on it.
    parser.add_argument('--ckpt', '-cp', action='store_true',
                        help='path to checkpoints')
    parser.add_argument('--overwrite', '-ow', action='store_true',
                        help='overwrite existing checkpoint otherwise skip training')
    parser.add_argument('--gpu', '-g', type=int, nargs='+',
                        help='gpu ids')
    parser.add_argument('--tags', '-tg', type=str, nargs='+', default=[],
                        help='tags add to experiments')
    parser.add_argument('--project', '-pj', type=str, default='comp_gen',
                        help='the name of project for W&B logger ')
    parser.add_argument('--nowb', '-nw', action='store_true',
                        help='do not run log on weight and bias')
    parser.add_argument('--gbt', '-gbt', action='store_true',
                        help='use gbt readout models, otherwise use linear models')
    parser.add_argument('--compmetric', '-cm', action='store_true',
                        help='test the disentanglement score with dis-lib metrics')
    parser.add_argument('--TestInDist', '-id', action='store_true',
                        help='test the performance in training data')
def main():
    """Train BetaTCVAE models over a hyperparameter grid, optionally evaluating each.

    The ``config`` dict returned by ``setup_experiment`` is updated in place for
    every (dataset, recon_loss, beta, architecture, seed) combination and handed
    to ``train_vae``. When readout evaluation is requested on the command line,
    ``scikitlearn_eval`` is then run for every (mode, n_train) combination.
    """
    parser = argparse.ArgumentParser(description='Generic runner for VAE models')
    add_vae_argument(parser)
    args = parser.parse_args()
    config, sklearn_eval_cfg = setup_experiment(args)

    # ``args.sklearn`` is not guaranteed to exist on the parsed namespace;
    # reading it unguarded raised AttributeError after the first training run.
    # Accept either ``sklearn`` or ``finetune`` (the flag run_vae.py uses for
    # the same evaluation loop) as the request to run the readout evaluation.
    run_readout_eval = getattr(args, 'sklearn', False) or getattr(args, 'finetune', False)

    # setting hyperparameters
    datasets = ('dsprites90d_random_v5', )  # alternative: ('mpi3d_real_random_v6', )
    recon_losses = ('bce', )
    betas = (0, 0.1, 0.5, 1, 4, 8)
    architectures = ('base', )
    eval_grid = tuple(product(('pre', 'post', 'latent'), (1000, 500, 100)))

    # ``product`` varies the right-most factor fastest, i.e. the same order as
    # nesting the loops dataset -> recon_loss -> beta -> architecture.
    for data, recon_loss, beta, arch in product(datasets, recon_losses, betas, architectures):
        if beta == 4 and data == 'mpi3d_real_random_v5':
            seeds = (2002, 2004, 2005)
        else:
            seeds = (2001, 2002, 2003)

        for seed in seeds:
            model_params = config['model_params']
            exp_params = config['exp_params']

            model_params['name'] = 'BetaTCVAE'
            model_params['beta'] = beta  # alpha=1 and gamma=1 by default
            model_params['latent_size'] = 10
            model_params['recon_loss'] = recon_loss
            model_params['architecture'] = arch

            exp_params['random_seed'] = seed
            exp_params['train_steps'] = 500000
            exp_params['val_steps'] = 5000
            exp_params['dataset'] = data

            if 'mpi3d' in data:
                # MPI3D runs use their own data path, an explicit input size
                # and twice as many train/validation steps.
                exp_params['data_path'] = 'YourPathToData'
                model_params['input_size'] = [3, 64, 64]
                exp_params['train_steps'] = 1000000
                exp_params['val_steps'] = 10000

            train_vae(config, args)

            if not run_readout_eval:
                continue

            # sklearn eval
            for mode, n_train in eval_grid:
                # NOTE: this aliases (does not copy) ``sklearn_eval_cfg``, so
                # the settings below persist on that dict across iterations.
                config['eval_params'] = sklearn_eval_cfg
                config['eval_params']['mode'] = mode
                config['eval_params']['n_train'] = n_train
                config['eval_params']['testOnTrain'] = True

                if args.gbt:
                    config['eval_params']['reg_model'] = 'GBTR'
                    config['eval_params']['cls_model'] = 'GBTC'
                    # NOTE(review): replaces any tags given via --tags.
                    args.tags = ['GBT', ]
                scikitlearn_eval(config, args)
), 21 | (0, 1), 22 | ('base', ) 23 | ): 24 | for seed in (2001, 2002, 2003): 25 | config['model_params']['beta'] = beta 26 | config['model_params']['latent_size'] = 10 27 | config['model_params']['recon_loss'] = recon_loss 28 | config['model_params']['architecture'] = arch 29 | 30 | config['exp_params']['train_steps'] = 500000 31 | config['exp_params']['val_steps'] = 5000 32 | config['exp_params']['random_seed'] = seed 33 | config['exp_params']['dataset'] = data 34 | config['exp_params'][ 35 | 'data_path'] = 'YOUR_PATH_TO_DATA' 36 | 37 | if 'mpi3d' in data: 38 | config['exp_params'][ 39 | 'data_path'] = 'YOUR_PATH_TO_DATA' 40 | config['model_params']['input_size'] = [3, 64, 64] 41 | config['exp_params']['train_steps'] = 1000000 42 | config['exp_params']['val_steps'] = 10000 43 | 44 | train_vae(config, args) 45 | 46 | if args.finetune: 47 | # for mode, n_train in product(('pre', 'post', 'latent'), (1000, 500, 100), ): 48 | for mode, n_train in product(('post', 'pre', 'latent'), (500, ), ): 49 | config['eval_params'] = sklearn_eval_cfg 50 | config['eval_params']['mode'] = mode 51 | config['eval_params']['n_train'] = n_train 52 | config['eval_params']['testOnTrain'] = True 53 | 54 | if args.gbt: 55 | config['eval_params']['reg_model'] = 'GBTR' 56 | config['eval_params']['cls_model'] = 'GBTC' 57 | args.tags = ['GBT', ] 58 | scikitlearn_eval(config, args) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | --------------------------------------------------------------------------------