├── .gitignore ├── README.md ├── configs ├── celeba.yml ├── cifar10.yml ├── cifar100.yml ├── fmnist.yml ├── mnist.yml ├── news20.yml └── omniglot.yml ├── main.py ├── minimal_requirements.txt ├── models ├── losses.py ├── model.py ├── model_smalltree.py └── networks.py ├── train ├── train.py ├── train_tree.py └── validate_tree.py ├── tree_exploration.ipynb ├── treevae.png ├── treevae.yml └── utils ├── data_utils.py ├── model_utils.py ├── plotting_utils.py ├── training_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | omniglot/* 7 | out/* 8 | plots/* 9 | omniglot_plot_ckpts/* 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | .idea/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # models and logs 138 | models/experiments/ 139 | models/logs/ 140 | results_* 141 | bin/* 142 | pyvenv.cfg 143 | Untitled.ipynb 144 | wandb/* 145 | run/* 146 | old* 147 | dataset/* 148 | *.png 149 | data/* 150 | logs/* 151 | datasets/* 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tree Variational Autoencoders 2 | This is the PyTorch repository for the NeurIPS 2023 Spotlight Publication (https://neurips.cc/virtual/2023/poster/71188). 3 | 4 | TreeVAE is a new generative method that learns the optimal tree-based posterior distribution of latent variables to capture the hierarchical structures present in the data. It adapts the architecture to discover the optimal tree for encoding dependencies between latent variables. TreeVAE optimizes the balance between shared and specialized architecture, enhancing the learning and adaptation capabilities of generative models. 5 | An example of a tree learned by TreeVAE is depicted in the figure below. Each edge and each split are encoded by neural networks, while the circles depict latent variables. Each sample is associated with a probability distribution over different paths of the discovered tree. 
The resulting tree thus organizes the data into an interpretable hierarchical structure in an unsupervised fashion, optimizing the amount of shared information between samples. In CIFAR-10, for example, the method divides the vehicles and animals into two different subtrees, and similar groups (such as planes and ships) share common ancestors. 6 | 7 | ![Alt text](https://github.com/lauramanduchi/treevae/blob/main/treevae.png?raw=true) 8 | For running TreeVAE: 9 | 10 | 1. Create a new environment with the ```treevae.yml``` or ```minimal_requirements.txt``` file. 11 | 2. Select the dataset you wish to use by changing the default ```config_name``` in the ```main.py``` parser, or by passing ```--config_name``` on the command line. 12 | 3. If needed, adapt the default configuration of the selected dataset (```configs/data_name.yml```). The full set of config parameters with their explanations can be found in ```configs/mnist.yml```. 13 | 4. For Weights & Biases support, set project & entity in ```train/train.py``` and change the value of ```wandb_logging``` to ```online``` in the config file. 14 | 5. Run ```main.py```. 15 | 16 | For exploring TreeVAE results (including the discovered tree, the generation of new images, the clustering performance and much more) we created a Jupyter notebook (```tree_exploration.ipynb```): 17 | 1. Run the steps above with ```save_model``` set to ```True```. 18 | 2. Copy the experiment path where the model is saved (it will be printed out). 19 | 3. Open ```tree_exploration.ipynb```, replace the experiment path with yours, and have fun exploring the model! 20 | 21 | DISCLAIMER: This PyTorch repository was thoroughly debugged and tested. However, please note that the experiments of the submission were performed using the repository with the TensorFlow code (https://github.com/lauramanduchi/treevae-tensorflow). 
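The interplay between a dataset config and the command-line flags of ```main.py``` can be sketched as follows. This is only an illustration: the actual merging is done by ```prepare_config``` in ```utils/utils.py```, which is not reproduced here, so the helper ```apply_overrides``` and the trimmed-down config dictionary below are hypothetical stand-ins. The sketch assumes that a flag left unset on the command line arrives as ```None``` (argparse's default) and that such flags leave the YAML value untouched.

```python
import copy


def apply_overrides(config, overrides):
    """Hypothetical helper: return a copy of `config` in which every
    non-None override replaces the matching key in whichever top-level
    section (data, training, globals) contains it."""
    merged = copy.deepcopy(config)
    for key, value in overrides.items():
        if value is None:  # flag was not given on the command line
            continue
        for section in merged.values():
            if isinstance(section, dict) and key in section:
                section[key] = value
    return merged


# Trimmed-down stand-in for the contents of configs/mnist.yml
config = {
    "run_name": "mnist",
    "data": {"data_name": "mnist", "num_clusters_data": 10},
    "training": {"num_epochs": 150, "batch_size": 256},
    "globals": {"save_model": False, "seed": 42},
}

# What argparse might produce for: --num_epochs 10 --save_model True
overrides = {"num_epochs": 10, "save_model": True, "seed": None}

merged = apply_overrides(config, overrides)
print(merged["training"]["num_epochs"],
      merged["globals"]["save_model"],
      merged["globals"]["seed"])
```

Under these assumptions the original dictionary is left unchanged, so the YAML defaults remain available for comparison after the overrides are applied.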
22 | 23 | ## Citing 24 | To cite TreeVAE, please use the following BibTeX entry: 25 | 26 | ``` 27 | @inproceedings{ 28 | manduchi2023tree, 29 | title={Tree Variational Autoencoders}, 30 | author={Laura Manduchi and Moritz Vandenhirtz and Alain Ryser and Julia E Vogt}, 31 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 32 | year={2023}, 33 | url={https://openreview.net/forum?id=adq0oXb9KM} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /configs/celeba.yml: -------------------------------------------------------------------------------- 1 | run_name: 'celeba' 2 | 3 | data: 4 | data_name: 'celeba' 5 | num_clusters_data: 1 6 | 7 | training: 8 | num_epochs: 150 9 | num_epochs_smalltree: 150 10 | num_epochs_intermediate_fulltrain: 0 11 | num_epochs_finetuning: 0 12 | batch_size: 256 13 | lr: 0.001 14 | weight_decay: 0.00001 15 | decay_lr: 0.1 16 | decay_stepsize: 100 17 | decay_kl: 0.01 18 | kl_start: 0.01 19 | 20 | inp_shape: 12288 21 | latent_dim: [64,64,64,64,64,64] 22 | mlp_layers: [4096, 512, 512, 512, 512, 512] 23 | initial_depth: 1 24 | activation: 'mse' 25 | encoder: 'cnn2' 26 | grow: True 27 | prune: True 28 | num_clusters_tree: 10 29 | augment: True 30 | augmentation_method: 'InfoNCE,instancewise_full' 31 | aug_decisions_weight: 100 32 | compute_ll: False 33 | 34 | globals: 35 | wandb_logging: 'disabled' 36 | eager_mode: False 37 | seed: 42 38 | save_model: True 39 | config_name: 'celeba' -------------------------------------------------------------------------------- /configs/cifar10.yml: -------------------------------------------------------------------------------- 1 | run_name: 'cifar10' 2 | 3 | data: 4 | data_name: 'cifar10' 5 | num_clusters_data: 10 6 | 7 | training: 8 | num_epochs: 150 9 | num_epochs_smalltree: 150 10 | num_epochs_intermediate_fulltrain: 0 11 | num_epochs_finetuning: 0 12 | batch_size: 256 13 | lr: 0.001 14 | weight_decay: 0.00001 15 | decay_lr: 
0.1 16 | decay_stepsize: 100 17 | decay_kl: 0.01 18 | kl_start: 0.01 19 | 20 | inp_shape: 3072 21 | latent_dim: [64,64,64,64,64,64] 22 | mlp_layers: [4096, 512, 512, 512, 512, 512] 23 | initial_depth: 1 24 | activation: 'mse' 25 | encoder: 'cnn2' 26 | grow: True 27 | prune: True 28 | num_clusters_tree: 10 29 | augment: True 30 | augmentation_method: 'InfoNCE,instancewise_full' 31 | aug_decisions_weight: 100 32 | compute_ll: False 33 | 34 | globals: 35 | wandb_logging: 'disabled' 36 | eager_mode: False 37 | seed: 42 38 | save_model: False 39 | config_name: 'cifar10' 40 | -------------------------------------------------------------------------------- /configs/cifar100.yml: -------------------------------------------------------------------------------- 1 | run_name: 'cifar100' 2 | 3 | data: 4 | data_name: 'cifar100' 5 | num_clusters_data: 20 6 | 7 | training: 8 | num_epochs: 150 9 | num_epochs_smalltree: 150 10 | num_epochs_intermediate_fulltrain: 0 11 | num_epochs_finetuning: 0 12 | batch_size: 256 13 | lr: 0.001 14 | weight_decay: 0.00001 15 | decay_lr: 0.1 16 | decay_stepsize: 100 17 | decay_kl: 0.01 18 | kl_start: 0.01 19 | 20 | inp_shape: 3072 21 | latent_dim: [64,64,64,64,64,64] 22 | mlp_layers: [4096, 512, 512, 512, 512, 512] 23 | initial_depth: 1 24 | activation: 'mse' 25 | encoder: 'cnn2' 26 | grow: True 27 | prune: True 28 | num_clusters_tree: 20 29 | augment: True 30 | augmentation_method: 'InfoNCE,instancewise_full' 31 | aug_decisions_weight: 100 32 | compute_ll: False 33 | 34 | globals: 35 | wandb_logging: 'disabled' 36 | eager_mode: False 37 | seed: 42 38 | save_model: False 39 | config_name: 'cifar100' 40 | -------------------------------------------------------------------------------- /configs/fmnist.yml: -------------------------------------------------------------------------------- 1 | run_name: 'fmnist' 2 | 3 | data: 4 | data_name: 'fmnist' 5 | num_clusters_data: 10 6 | 7 | training: 8 | num_epochs: 150 9 | num_epochs_smalltree: 150 10 | 
num_epochs_intermediate_fulltrain: 80 11 | num_epochs_finetuning: 200 12 | batch_size: 256 13 | lr: 0.001 14 | weight_decay: 0.00001 15 | decay_lr: 0.1 16 | decay_stepsize: 100 17 | decay_kl: 0.001 18 | kl_start: 0.0 19 | 20 | inp_shape: 784 21 | latent_dim: [8, 8, 8, 8, 8, 8] 22 | mlp_layers: [128, 128, 128, 128, 128, 128] 23 | initial_depth: 1 24 | activation: "sigmoid" 25 | encoder: 'cnn1' 26 | grow: True 27 | prune: True 28 | num_clusters_tree: 10 29 | compute_ll: False 30 | augment: False 31 | augmentation_method: 'simple' 32 | aug_decisions_weight: 1 33 | 34 | globals: 35 | wandb_logging: 'disabled' 36 | eager_mode: True 37 | seed: 42 38 | save_model: False 39 | config_name: 'fmnist' -------------------------------------------------------------------------------- /configs/mnist.yml: -------------------------------------------------------------------------------- 1 | run_name: 'mnist' # name of the run 2 | 3 | data: 4 | data_name: 'mnist' # name of the dataset 5 | num_clusters_data: 10 # number of true clusters in the data (if known), this is used only for evaluation purposes 6 | 7 | training: 8 | num_epochs: 150 # number of epochs to train the initial tree 9 | num_epochs_smalltree: 150 # number of epochs to train the sub-tree during growing 10 | num_epochs_intermediate_fulltrain: 80 # number of epochs to train the full tree during growing 11 | num_epochs_finetuning: 200 # number of epochs to train the final tree 12 | batch_size: 256 # batch size 13 | lr: 0.001 # learning rate 14 | weight_decay: 0.00001 # optimizer weight decay 15 | decay_lr: 0.1 # learning rate decay 16 | decay_stepsize: 100 # number of epochs after which learning rate decays 17 | decay_kl: 0.001 # KL-annealing weight increase per epoch (capped at 1) 18 | kl_start: 0.0 # KL-annealing weight initialization 19 | 20 | inp_shape: 784 # The total dimensions of the input data (if rgb images of 32x32 then 32x32x3) 21 | latent_dim: [8, 8, 8, 8, 8, 8] # A list of latent dimensions for each depth of 
the tree from the bottom to the root, last value is the dimensionality of the root node 22 | mlp_layers: [128, 128, 128, 128, 128, 128] # A list of hidden units number for the MLP transformations for each depth of the tree from bottom to root 23 | initial_depth: 1 # The initial depth of the tree (root has depth 0 and a root with two leaves has depth 1) 24 | activation: "sigmoid" # The name of the activation function for the reconstruction loss [sigmoid, mse] 25 | encoder: 'cnn1' # Type of encoder/decoder used 26 | grow: True # Whether to grow the tree 27 | prune: True # Whether to prune the tree of empty leaves 28 | num_clusters_tree: 10 # The maximum number of leaves of the final tree 29 | compute_ll: False # Whether to compute the log-likelihood estimation at the end of the training (it might take some time) 30 | augment: False # Whether to use contrastive learning through augmentation 31 | augmentation_method: 'simple' # The type of augmentation method used if augment is True 32 | aug_decisions_weight: 1 # The weight of the contrastive losses 33 | 34 | globals: 35 | wandb_logging: 'disabled' # Whether to log to wandb [online, offline, disabled] 36 | eager_mode: True # Whether to run in eager or graph mode 37 | seed: 42 # Random seed 38 | save_model: False # Whether to save the model. 
Set to True for inspecting models in notebook 39 | config_name: 'mnist' 40 | -------------------------------------------------------------------------------- /configs/news20.yml: -------------------------------------------------------------------------------- 1 | run_name: 'news20' 2 | 3 | data: 4 | data_name: 'news20' 5 | num_clusters_data: 20 6 | 7 | training: 8 | num_epochs: 150 9 | num_epochs_smalltree: 150 10 | num_epochs_intermediate_fulltrain: 80 11 | num_epochs_finetuning: 200 12 | batch_size: 256 13 | lr: 0.001 14 | weight_decay: 0.00001 15 | decay_lr: 0.1 16 | decay_stepsize: 100 17 | decay_kl: 0.001 18 | kl_start: 0.0 19 | 20 | inp_shape: 2000 21 | latent_dim: [4, 4, 4, 4, 4, 4, 4] 22 | mlp_layers: [128, 128, 128, 128, 128, 128, 128] 23 | initial_depth: 1 24 | activation: "sigmoid" 25 | encoder: 'mlp' 26 | grow: True 27 | prune: True 28 | num_clusters_tree: 20 29 | compute_ll: False 30 | augment: False 31 | augmentation_method: 'simple' 32 | aug_decisions_weight: 1 33 | 34 | globals: 35 | wandb_logging: 'disabled' 36 | eager_mode: True 37 | seed: 42 38 | save_model: False 39 | config_name: 'news20' -------------------------------------------------------------------------------- /configs/omniglot.yml: -------------------------------------------------------------------------------- 1 | run_name: 'omniglot' 2 | 3 | data: 4 | data_name: 'omniglot' 5 | num_clusters_data: 5 6 | path: 'datasets/omniglot' 7 | 8 | training: 9 | num_epochs: 150 10 | num_epochs_smalltree: 150 11 | num_epochs_intermediate_fulltrain: 80 12 | num_epochs_finetuning: 200 13 | batch_size: 256 14 | lr: 0.001 15 | weight_decay: 0.00001 16 | decay_lr: 0.1 17 | decay_stepsize: 100 18 | decay_kl: 0.001 19 | kl_start: 0.001 20 | 21 | inp_shape: 784 22 | latent_dim: [8, 8, 8, 8, 8, 8] 23 | mlp_layers: [128, 128, 128, 128, 128, 128] 24 | initial_depth: 1 25 | activation: "sigmoid" 26 | encoder: 'cnn_omni' 27 | grow: True 28 | prune: True 29 | num_clusters_tree: 5 30 | compute_ll: False 31 | 
augment: True 32 | augmentation_method: 'simple' 33 | aug_decisions_weight: 1 34 | 35 | 36 | globals: 37 | wandb_logging: 'online' 38 | eager_mode: True 39 | seed: 42 40 | save_model: False 41 | config_name: 'omniglot' -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs the treeVAE model. 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | import distutils.util 7 | 8 | from train.train import run_experiment 9 | from utils.utils import prepare_config 10 | 11 | 12 | def main(): 13 | project_dir = Path(__file__).absolute().parent 14 | print("Project directory:", project_dir) 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | # Model parameters 19 | parser.add_argument('--data_name', type=str, help='the dataset') 20 | parser.add_argument('--num_epochs', type=int, help='the number of training epochs') 21 | parser.add_argument('--num_epochs_finetuning', type=int, help='the number of finetuning epochs') 22 | parser.add_argument('--num_epochs_intermediate_fulltrain', type=int, help='the number of full-tree training epochs during growing') 23 | parser.add_argument('--num_epochs_smalltree', type=int, help='the number of sub-tree training epochs') 24 | 25 | parser.add_argument('--num_clusters_data', type=int, help='the number of clusters in the data') 26 | parser.add_argument('--num_clusters_tree', type=int, help='the max number of leaves of the tree') 27 | 28 | parser.add_argument('--kl_start', type=float, nargs='?', const=0., 29 | help='initial KL divergence from where annealing starts') 30 | parser.add_argument('--decay_kl', type=float, help='KL divergence annealing') 31 | parser.add_argument('--latent_dim', type=str, help='specifies the latent dimensions of the tree') 32 | parser.add_argument('--mlp_layers', type=str, help='specifies the hidden units of the MLPs for each depth of the tree') 33 | 34 | parser.add_argument('--grow', type=lambda x: 
bool(distutils.util.strtobool(x)), help='whether to grow the tree') 35 | parser.add_argument('--augment', type=lambda x: bool(distutils.util.strtobool(x)), help='augment images or not') 36 | parser.add_argument('--augmentation_method', type=str, help='none vs simple augmentation vs contrastive approaches') 37 | parser.add_argument('--aug_decisions_weight', type=float, 38 | help='weight of similarity regularizer for augmented images') 39 | parser.add_argument('--compute_ll', type=lambda x: bool(distutils.util.strtobool(x)), 40 | help='whether to compute the log-likelihood') 41 | 42 | # Other parameters 43 | parser.add_argument('--save_model', type=lambda x: bool(distutils.util.strtobool(x)), 44 | help='specifies if the model should be saved') 45 | parser.add_argument('--eager_mode', type=lambda x: bool(distutils.util.strtobool(x)), 46 | help='specifies if the model should be run in graph or eager mode') 47 | parser.add_argument('--num_workers', type=int, help='number of workers in dataloader') 48 | parser.add_argument('--seed', type=int, help='random number generator seed') 49 | parser.add_argument('--wandb_logging', type=str, help='online, disabled, offline enables logging in wandb') 50 | 51 | # Specify config name 52 | parser.add_argument('--config_name', default='mnist', type=str, 53 | choices=['mnist', 'fmnist', 'news20', 'omniglot', 'cifar10', 'cifar100', 'celeba'], 54 | help='the override file name for config.yml') 55 | 56 | args = parser.parse_args() 57 | configs = prepare_config(args, project_dir) 58 | run_experiment(configs) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /minimal_requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==22.1.0 2 | cudatoolkit==11.7 # conda install cudatoolkit=11.7 -c pytorch -c nvidia 3 | cudnn==8.9.2.26 4 | numpy==1.23.4 5 | Pillow==9.3.0 6 | python==3.9.15 # Do this first 7 | 
PyYAML==6.0 8 | scikit_learn==1.1.3 9 | scipy==1.10.0 10 | torch==2.0.1 # conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 11 | torchmetrics==1.0.3 12 | torchvision==0.15.2 13 | tqdm==4.65.0 14 | wandb==0.13.5 -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions for the reconstruction term of the ELBO. 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def loss_reconstruction_binary(x, x_decoded_mean, weights): 9 | x = torch.flatten(x, start_dim=1) 10 | x_decoded_mean = [torch.flatten(decoded_leaf, start_dim=1) for decoded_leaf in x_decoded_mean] 11 | loss = torch.sum( 12 | torch.stack([weights[i] * 13 | F.binary_cross_entropy(input = x_decoded_mean[i], target = x, reduction='none').sum(dim=-1) 14 | for i in range(len(x_decoded_mean))], dim=-1), dim=-1) 15 | return loss 16 | 17 | def loss_reconstruction_mse(x, x_decoded_mean, weights): 18 | x = torch.flatten(x, start_dim=1) 19 | x_decoded_mean = [torch.flatten(decoded_leaf, start_dim=1) for decoded_leaf in x_decoded_mean] 20 | loss = torch.sum( 21 | torch.stack([weights[i] * 22 | F.mse_loss(input = x_decoded_mean[i], target = x, reduction='none').sum(dim=-1) 23 | for i in range(len(x_decoded_mean))], dim=-1), dim=-1) 24 | return loss 25 | 26 | def loss_reconstruction_cov_mse_eval(x, x_decoded_mean, weights): 27 | # NOTE Only use for evaluation purposes, as the clamping stops gradient flow 28 | # NOTE WE ASSUME IDENTITY MATRIX BECAUSE WE ASSUME THIS IMPLICITLY WHEN ONLY OPTIMIZING MSE 29 | scale = torch.diag(torch.ones_like(x_decoded_mean[0])) 30 | logpx = torch.zeros_like(weights[0]) 31 | for i in range(len(x_decoded_mean)): 32 | x_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.clamp(x_decoded_mean[i],0,1), covariance_matrix=scale) 33 | logpx = logpx + weights[i] * 
x_dist.log_prob(x) 34 | return logpx 35 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | TreeVAE model. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributions as td 7 | from utils.model_utils import construct_tree, compute_posterior 8 | from models.networks import get_encoder, get_decoder, MLP, Router, Dense 9 | from models.losses import loss_reconstruction_binary, loss_reconstruction_mse 10 | from utils.model_utils import return_list_tree 11 | from utils.training_utils import calc_aug_loss 12 | 13 | class TreeVAE(nn.Module): 14 | """ 15 | A class used to represent a tree-based VAE. 16 | 17 | TreeVAE specifies a Variational Autoencoder with a tree-structured posterior distribution of latent variables. 18 | It is defined by a bottom-up chain of deterministic transformations that, from the input x, compute the root 19 | representation of the data, and a probabilistic top-down architecture which takes the form of a tree. The 20 | top-down tree is described by the probability distributions of its nodes (each of which depends on its parent) and the 21 | probability distribution of the decisions (the probability of following a certain path in the tree). 22 | Each node of the tree is described by the class Node in utils.model_utils. 
23 | 24 | Attributes 25 | ---------- 26 | activation : str 27 | The name of the activation function for the reconstruction loss [sigmoid, mse] 28 | loss : models.losses 29 | The loss function used by the decoder to reconstruct the input 30 | alpha : float 31 | KL-annealing weight initialization 32 | encoded_sizes : list 33 | A list of latent dimensions for each depth of the tree from the bottom to the root 34 | hidden_layers : list 35 | A list with the number of hidden units of the MLP transformations for each depth of the tree from bottom to root 36 | depth : int 37 | The depth of the tree (root has depth 0 and a root with two leaves has depth 1) 38 | inp_shape : int 39 | The total dimensions of the input data (if RGB images of 32x32 then 32x32x3) 40 | augment : bool 41 | Whether to use contrastive learning through augmentation, if False no augmentation is used 42 | augmentation_method : str 43 | The type of augmentation method used 44 | aug_decisions_weight : float 45 | The weight of the contrastive loss used in the decisions 46 | return_x : torch.Tensor (bool) 47 | Whether to return the input in the return dictionary of the forward method 48 | return_elbo : torch.Tensor (bool) 49 | Whether to return the sample-specific elbo in the return dictionary of the forward method 50 | return_bottomup : torch.Tensor (bool) 51 | Whether to return the list of bottom-up transformations (including encoder) 52 | bottom_up : list 53 | The list of bottom-up transformations [encoder, MLP, MLP, ...] up to the root 54 | contrastive_mlp : list 55 | The list of transformations from the bottom-up embeddings to the latent spaces 56 | in which the contrastive losses are applied 57 | transformations : list 58 | List of transformations (MLPs) associated with each node of the tree from root to bottom (left to right) 59 | denses : list 60 | List of dense layers for the sharing of top-down and bottom-up (MLPs) associated with each node of the tree 61 | from root to bottom (left to right). 
62 | decisions : list 63 | List of decisions associated with each node of the tree from root to bottom (left to right) 64 | decoders : list 65 | List of decoders one for each leaf 66 | decisions_q : list 67 | List of decisions of the bottom-up associated with each node of the tree from root to bottom (left to right) 68 | tree : utils.model_utils.Node 69 | The root node of the tree 70 | 71 | Methods 72 | ------- 73 | forward(x) 74 | Compute the forward pass of the treeVAE model and return a dictionary of losses and optional outputs 75 | (like input, bottom-up and sample-specific elbo) when needed. 76 | compute_leaves() 77 | Return a list of leaf-nodes from left to right of the current tree (self.tree). 78 | compute_depth() 79 | Calculate the depth of the given tree (self.tree). 80 | attach_smalltree(node, small_model) 81 | Attach a sub tree (small_model) to the given node of the current tree. 82 | compute_reconstruction(x) 83 | Given the input x, it computes the reconstructions. 84 | generate_images(n_samples, device) 85 | Generate n_samples new images by sampling from the root and propagating through the entire tree. 86 | """ 87 | 88 | def __init__(self, **kwargs): 89 | """ 90 | Parameters 91 | ---------- 92 | kwargs : dict 93 | A dictionary of attributes (see config file). 
94 | """ 95 | super(TreeVAE, self).__init__() 96 | self.kwargs = kwargs 97 | 98 | self.activation = self.kwargs['activation'] 99 | if self.activation == "sigmoid": 100 | self.loss = loss_reconstruction_binary 101 | elif self.activation == "mse": 102 | self.loss = loss_reconstruction_mse 103 | else: 104 | raise NotImplementedError 105 | # KL-annealing weight initialization 106 | self.alpha = torch.tensor(self.kwargs['kl_start']) 107 | 108 | # saving important variables to initialize the tree 109 | self.encoded_sizes = self.kwargs['latent_dim'] 110 | self.hidden_layers = self.kwargs['mlp_layers'] 111 | # check that the number of layers for bottom up is equal to top down 112 | if len(self.encoded_sizes) != len(self.hidden_layers): 113 | raise ValueError('Model is misspecified: latent_dim and mlp_layers must have the same length') 114 | self.depth = self.kwargs['initial_depth'] 115 | self.inp_shape = self.kwargs['inp_shape'] 116 | self.augment = self.kwargs['augment'] 117 | self.augmentation_method = self.kwargs['augmentation_method'] 118 | self.aug_decisions_weight = self.kwargs['aug_decisions_weight'] 119 | self.return_x = torch.tensor([False]) 120 | self.return_bottomup = torch.tensor([False]) 121 | self.return_elbo = torch.tensor([False]) 122 | 123 | # bottom up: the inference chain that from input computes the d units till the root 124 | if self.activation == "mse": 125 | size = int((self.inp_shape / 3)**0.5) 126 | encoder = get_encoder(architecture=self.kwargs['encoder'], encoded_size=self.hidden_layers[0], 127 | x_shape=self.inp_shape, size=size) 128 | else: 129 | encoder = get_encoder(architecture=self.kwargs['encoder'], encoded_size=self.hidden_layers[0], 130 | x_shape=self.inp_shape) 131 | 132 | self.bottom_up = nn.ModuleList([encoder]) 133 | for i in range(1, len(self.hidden_layers)): 134 | self.bottom_up.append(MLP(self.hidden_layers[i-1], self.encoded_sizes[i], self.hidden_layers[i])) 135 | 136 | # MLPs if we use contrastive loss on d's 137 | if len([i for i in self.augmentation_method if i in 
['instancewise_first', 'instancewise_full']]) > 0: 138 | self.contrastive_mlp = nn.ModuleList([]) 139 | for i in range(0, len(self.hidden_layers)): 140 | self.contrastive_mlp.append(MLP(input_size=self.hidden_layers[i], encoded_size=self.encoded_sizes[i], hidden_unit=min(self.hidden_layers))) 141 | 142 | # top down: the generative model that from x computes the prior prob of all nodes from root till leaves 143 | # it has a tree structure which is constructed by passing a list of transformations and routers from root to 144 | # leaves visiting nodes layer-wise from left to right 145 | # N.B. root has None as transformation and leaves have None as routers 146 | # the encoded sizes and layers are reversed from bottom up 147 | # e.g. for bottom up [MLP(256, 32), MLP(128, 16), MLP(64, 8)] the list of top-down transformations are 148 | # [None, MLP(16, 64), MLP(16, 64), MLP(32, 128), MLP(32, 128), MLP(32, 128), MLP(32, 128)] 149 | 150 | # select the top down generative networks 151 | encoded_size_gen = self.encoded_sizes[-(self.depth+1):] # e.g. encoded_sizes 32,16,8, depth 1 152 | encoded_size_gen = encoded_size_gen[::-1] # encoded_size_gen = 16,8 => 8,16 153 | layers_gen = self.hidden_layers[-(self.depth+1):] # e.g. 
hidden_layers 256,128,64, depth 1 154 | layers_gen = layers_gen[::-1] # layers_gen = 128,64 => 64,128 155 | 156 | # add root transformation and dense layer, the dense layer is the layer that connects the bottom-up with the nodes 157 | self.transformations = nn.ModuleList([None]) 158 | self.denses = nn.ModuleList([Dense(layers_gen[0], encoded_size_gen[0])]) 159 | # attach the rest of transformations and dense layers for each node 160 | for i in range(self.depth): 161 | for j in range(2 ** (i + 1)): 162 | self.transformations.append(MLP(encoded_size_gen[i], encoded_size_gen[i+1], layers_gen[i])) # MLP from depth i to i+1 163 | self.denses.append(Dense(layers_gen[i+1], encoded_size_gen[i+1])) # Dense at depth i+1 from bottom-up to top-down 164 | 165 | # compute the list of decisions for both bottom-up (decisions_q) and top-down (decisions) 166 | # for each node of the tree 167 | self.decisions = nn.ModuleList([]) 168 | self.decisions_q = nn.ModuleList([]) 169 | for i in range(self.depth): 170 | for _ in range(2 ** i): 171 | self.decisions.append(Router(encoded_size_gen[i], hidden_units=layers_gen[i])) # Router at node of depth i 172 | self.decisions_q.append(Router(layers_gen[i], hidden_units=layers_gen[i])) 173 | # the leaves do not have decisions (we set it to None) 174 | for _ in range(2 ** (self.depth)): 175 | self.decisions.append(None) 176 | self.decisions_q.append(None) 177 | 178 | # compute the list of decoders to attach to each node, note that internal nodes do not have a decoder 179 | # e.g. 
for a tree with depth 2: decoders = [None, None, None, Dec, Dec, Dec, Dec] 180 | self.decoders = nn.ModuleList([None for i in range(self.depth) for j in range(2 ** i)]) 181 | for _ in range(2 ** (self.depth)): 182 | self.decoders.append(get_decoder(architecture=self.kwargs['encoder'], input_shape=encoded_size_gen[-1], 183 | output_shape=self.inp_shape, activation=self.activation)) 184 | 185 | # construct the tree 186 | self.tree = construct_tree(transformations=self.transformations, routers=self.decisions, 187 | routers_q=self.decisions_q, denses=self.denses, decoders=self.decoders) 188 | 189 | def forward(self, x): 190 | """ 191 | Forward pass of the treeVAE model. 192 | 193 | Parameters 194 | ---------- 195 | x : tensor 196 | Input data (batch-size, input-size) 197 | 198 | Returns 199 | ------- 200 | dict 201 | a dictionary 202 | {'rec_loss': reconstruction loss, 203 | 'kl_root': the KL loss of the root, 204 | 'kl_decisions': the KL loss of the decisions, 205 | 'kl_nodes': the KL loss of the nodes, 206 | 'aug_decisions': the weighted contrastive loss, 207 | 'p_c_z': the probability of each sample to be assigned to each leaf with size: #samples x #leaves, 208 | 'node_leaves': a list of leaf nodes, each one described by a dictionary 209 | {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding} 210 | } 211 | """ 212 | # Small constant to prevent numerical instability 213 | epsilon = 1e-7 214 | device = x.device 215 | 216 | # compute deterministic bottom up 217 | d = x 218 | encoders = [] 219 | emb_contr = [] 220 | 221 | for i in range(0, len(self.hidden_layers)): 222 | d, _, _ = self.bottom_up[i](d) 223 | # store bottom-up embeddings for top-down 224 | encoders.append(d) 225 | 226 | # pass through contrastive MLP's if contrastive learning is selected 227 | if 'instancewise_full' in self.augmentation_method: 228 | _, emb_c, _ = self.contrastive_mlp[i](d) 229 | emb_contr.append(emb_c) 230 | elif 'instancewise_first' in 
self.augmentation_method: 231 | if i == 0: 232 | _, emb_c, _ = self.contrastive_mlp[i](d) 233 | emb_contr.append(emb_c) 234 | 235 | # create a list of nodes of the tree that need to be processed, self.tree is the root of the tree 236 | list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}] 237 | # initialize KL losses 238 | kl_nodes_tot = torch.zeros(len(x), device=device) 239 | kl_decisions_tot = torch.zeros(len(x), device=device) 240 | aug_decisions_loss = torch.zeros(1, device=device) 241 | leaves_prob = [] 242 | reconstructions = [] 243 | node_leaves = [] 244 | 245 | # iterate over all nodes in the tree 246 | while len(list_nodes) != 0: 247 | # store info regarding the current node 248 | current_node = list_nodes.pop(0) 249 | node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob'] 250 | z_parent_sample = current_node['z_parent_sample'] 251 | # access deterministic bottom up mu and sigma hat (computed above) 252 | d = encoders[-(1+depth_level)] 253 | z_mu_q_hat, z_sigma_q_hat = node.dense(d) 254 | 255 | # here we are in the root 256 | if depth_level == 0: 257 | # the root has a standard gaussian prior 258 | z_mu_p, z_sigma_p = torch.zeros_like(z_mu_q_hat, device=device), torch.ones_like(z_sigma_q_hat, device=device) 259 | z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p + epsilon)), 1) 260 | # the posterior parameters of z (from q(z|x)) come directly from the top layer of the deterministic bottom-up 261 | z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat 262 | 263 | # otherwise we are in the rest of the nodes of the tree 264 | else: 265 | # the generative probability distribution of internal nodes is a gaussian with mu and sigma that are 266 | # the outputs of the top-down network conditioned on the sampled parent 267 | _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample) 268 | z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p + epsilon)), 1) 269 | # to avoid posterior
collapse there is a share of information between the bottom-up and top-down 270 | z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p) 271 | 272 | # compute sample z using mu_q and sigma_q 273 | z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1) 274 | z_sample = z.rsample() 275 | 276 | # compute KL node 277 | kl_node = prob * td.kl_divergence(z, z_p) 278 | kl_node = torch.clamp(kl_node, min=-1, max=1000) 279 | 280 | if depth_level == 0: 281 | kl_root = kl_node 282 | else: 283 | kl_nodes_tot += kl_node 284 | 285 | # if there is a router (i.e. decision probability) then we are in the internal nodes (not leaves) 286 | if node.router is not None: 287 | # compute the probability of the sample to go to the left child 288 | prob_child_left = node.router(z_sample).squeeze() 289 | prob_child_left_q = node.routers_q(d).squeeze() 290 | 291 | # compute the KL of the decisions 292 | kl_decisions = prob_child_left_q * (epsilon + prob_child_left_q / (prob_child_left + epsilon)).log() + \ 293 | (1 - prob_child_left_q) * (epsilon + (1 - prob_child_left_q) / (1 - prob_child_left + epsilon)).log() 294 | kl_decisions = prob * kl_decisions 295 | kl_decisions_tot += kl_decisions 296 | 297 | # compute the contrastive loss of the embeddings and the decisions 298 | if self.training is True and self.augment is True and 'simple' not in self.augmentation_method: 299 | if depth_level == 0: 300 | # compute the contrastive loss for all the bottom-up representations 301 | aug_decisions_loss += calc_aug_loss(prob_parent=prob, prob_router=prob_child_left_q, augmentation_methods=self.augmentation_method, emb_contr=emb_contr) 302 | else: 303 | # compute the contrastive loss for the decisions 304 | aug_decisions_loss += calc_aug_loss(prob_parent=prob, prob_router=prob_child_left_q, augmentation_methods=self.augmentation_method, emb_contr=[]) 305 | 306 | # we are not in a leaf, so we have to add the left and right child to the list 307 | 
prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q) 308 | node_left, node_right = node.left, node.right 309 | list_nodes.append( 310 | {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample}) 311 | list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right, 312 | 'z_parent_sample': z_sample}) 313 | 314 | # if there is a decoder then we are in one of the leaves 315 | elif node.decoder is not None: 316 | # if we are in a leaf we need to store the prob of reaching that leaf and compute reconstructions 317 | # as the nodes are explored left to right, these probabilities will be also ordered left to right 318 | leaves_prob.append(prob) 319 | dec = node.decoder 320 | reconstructions.append(dec(z_sample)) 321 | node_leaves.append({'prob': prob, 'z_sample': z_sample}) 322 | 323 | # here we are in an internal node with pruned leaves and thus only have one child 324 | elif node.router is None and node.decoder is None: 325 | node_left, node_right = node.left, node.right 326 | child = node_left if node_left is not None else node_right 327 | list_nodes.append( 328 | {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample}) 329 | 330 | kl_nodes_loss = torch.clamp(torch.mean(kl_nodes_tot), min=-10, max=1e10) 331 | kl_decisions_loss = torch.mean(kl_decisions_tot) 332 | kl_root_loss = torch.mean(kl_root) 333 | 334 | # p_c_z is the probability of reaching a leaf and is of shape [batch_size, num_clusters] 335 | p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1) 336 | 337 | rec_losses = self.loss(x, reconstructions, leaves_prob) 338 | rec_loss = torch.mean(rec_losses, dim=0) 339 | 340 | return_dict = { 341 | 'rec_loss': rec_loss, 342 | 'kl_root': kl_root_loss, 343 | 'kl_decisions': kl_decisions_loss, 344 | 'kl_nodes': kl_nodes_loss, 345 | 'aug_decisions': self.aug_decisions_weight * aug_decisions_loss, 346 | 'p_c_z': p_c_z, 347
| 'node_leaves': node_leaves, 348 | } 349 | 350 | if self.return_elbo: 351 | return_dict['elbo_samples'] = kl_nodes_tot + kl_decisions_tot + kl_root + rec_losses 352 | 353 | if self.return_bottomup: 354 | return_dict['bottom_up'] = encoders 355 | 356 | if self.return_x: 357 | return_dict['input'] = x 358 | 359 | return return_dict 360 | 361 | 362 | def compute_leaves(self): 363 | """ 364 | Computes the leaves of the tree 365 | 366 | Returns 367 | ------- 368 | list 369 | A list of the leaves from left to right. 370 | A leaf is defined by a dictionary: {'node': leaf node, 'depth': depth of the leaf node}. 371 | A leaf node is defined by the class Node in utils.model_utils. 372 | """ 373 | # iterate over all nodes in the tree to find the leaves 374 | list_nodes = [{'node': self.tree, 'depth': 0}] 375 | nodes_leaves = [] 376 | while len(list_nodes) != 0: 377 | current_node = list_nodes.pop(0) 378 | node, depth_level = current_node['node'], current_node['depth'] 379 | if node.router is not None: 380 | node_left, node_right = node.left, node.right 381 | list_nodes.append( 382 | {'node': node_left, 'depth': depth_level + 1}) 383 | list_nodes.append({'node': node_right, 'depth': depth_level + 1}) 384 | elif node.router is None and node.decoder is None: 385 | # we are in an internal node with pruned leaves and thus only have one child 386 | node_left, node_right = node.left, node.right 387 | child = node_left if node_left is not None else node_right 388 | list_nodes.append({'node': child, 'depth': depth_level + 1}) 389 | else: 390 | nodes_leaves.append(current_node) 391 | return nodes_leaves 392 | 393 | 394 | def compute_depth(self): 395 | """ 396 | Computes the depth of the tree 397 | 398 | Returns 399 | ------- 400 | int 401 | The depth of the tree (the root has depth 0 and a root with two leaves has depth 1).
402 | """ 403 | # computes depth of the tree 404 | nodes_leaves = self.compute_leaves() 405 | d = [] 406 | for i in range(len(nodes_leaves)): 407 | d.append(nodes_leaves[i]['depth']) 408 | return max(d) 409 | 410 | def attach_smalltree(self, node, small_model): 411 | """ 412 | Attach a trained small tree of the class SmallTreeVAE (models.model_smalltree) to the given node of the full 413 | TreeVAE. The small tree has one root and two leaves. It does not return anything but changes self.tree 414 | 415 | Parameters 416 | ---------- 417 | node : utils.model_utils.Node 418 | The selected node of TreeVAE where to attach the sub-tree, which was trained separately. 419 | small_model: models.model_smalltree.SmallTreeVAE 420 | The sub-tree with one root and two leaves that needs to be attached to TreeVAE. 421 | """ 422 | assert node.left is None and node.right is None 423 | node.router = small_model.decision 424 | node.routers_q = small_model.decision_q 425 | node.decoder = None 426 | for j in range(2): 427 | dense = small_model.denses[j] 428 | transformation = small_model.transformations[j] 429 | decoder = small_model.decoders[j] 430 | # insert each leaf of the small tree as child of the node of TreeVAE 431 | node.insert(transformation, None, None, dense, decoder) 432 | 433 | # once the small tree is attached we re-compute the list of transformations, routers etc 434 | transformations, routers, denses, decoders, routers_q = return_list_tree(self.tree) 435 | 436 | # we then need to re-initialize the parameters of TreeVAE 437 | self.decisions_q = routers_q 438 | self.transformations = transformations 439 | self.decisions = routers 440 | self.denses = denses 441 | self.decoders = decoders 442 | self.depth = self.compute_depth() 443 | return 444 | 445 | 446 | def compute_reconstruction(self, x): 447 | """ 448 | Given the input x, it computes the reconstructions. 449 | 450 | Parameters 451 | ---------- 452 | x: Tensor 453 | Input data. 
454 | 455 | Returns 456 | ------- 457 | list 458 | The reconstructions of the input data, one tensor per leaf, obtained by computing a forward pass of the model. 459 | List 460 | A list of leaf nodes, each one described by a dictionary 461 | {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding} 462 | """ 463 | assert self.training is False 464 | epsilon = 1e-7 465 | device = x.device 466 | 467 | # compute deterministic bottom up 468 | d = x 469 | encoders = [] 470 | 471 | for i in range(0, len(self.hidden_layers)): 472 | d, _, _ = self.bottom_up[i](d) 473 | # store the bottom-up layers for the top down computation 474 | encoders.append(d) 475 | 476 | # create a list of nodes of the tree that need to be processed 477 | list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}] 478 | 479 | # initialize the lists that collect leaf probabilities, reconstructions and leaf nodes 480 | leaves_prob = [] 481 | reconstructions = [] 482 | node_leaves = [] 483 | 484 | # iterate over the nodes 485 | while len(list_nodes) != 0: 486 | 487 | # store info regarding the current node 488 | current_node = list_nodes.pop(0) 489 | node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob'] 490 | z_parent_sample = current_node['z_parent_sample'] 491 | # access deterministic bottom up mu and sigma hat (computed above) 492 | d = encoders[-(1+depth_level)] 493 | z_mu_q_hat, z_sigma_q_hat = node.dense(d) 494 | 495 | if depth_level == 0: 496 | z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat 497 | else: 498 | # the generative mu and sigma are the outputs of the top-down network given the sampled parent 499 | _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample) 500 | z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p) 501 | 502 | # compute sample z using mu_q and sigma_q 503 | z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1) 504 | z_sample = z.rsample() 505 | 506 | # if we are in the internal nodes
(not leaves) 507 | if node.router is not None: 508 | 509 | prob_child_left_q = node.routers_q(d).squeeze() 510 | 511 | # we are not in a leaf, so we have to add the left and right child to the list 512 | prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q) 513 | 514 | node_left, node_right = node.left, node.right 515 | list_nodes.append( 516 | {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample}) 517 | list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right, 518 | 'z_parent_sample': z_sample}) 519 | 520 | elif node.decoder is not None: 521 | # if we are in a leaf we need to store the prob of reaching that leaf and compute reconstructions 522 | # as the nodes are explored left to right, these probabilities will be also ordered left to right 523 | leaves_prob.append(prob) 524 | dec = node.decoder 525 | reconstructions.append(dec(z_sample)) 526 | node_leaves.append({'prob': prob, 'z_sample': z_sample}) 527 | 528 | elif node.router is None and node.decoder is None: 529 | # We are in an internal node with pruned leaves and thus only have one child 530 | node_left, node_right = node.left, node.right 531 | child = node_left if node_left is not None else node_right 532 | list_nodes.append( 533 | {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample}) 534 | 535 | return reconstructions, node_leaves 536 | 537 | def generate_images(self, n_samples, device): 538 | """ 539 | Generate K x n_samples new images by sampling from the root and propagating through the entire tree. 540 | For each sample the method generates K images, where K is the number of leaves. 541 | 542 | Parameters 543 | ---------- 544 | n_samples: int 545 | Number of generated samples the function should output. 
546 | device: torch.device 547 | Either cpu or gpu 548 | 549 | Returns 550 | ------- 551 | list 552 | A list of K tensors containing the leaf-specific generations obtained by sampling from the root and 553 | propagating through the entire tree, where K is the number of leaves. 554 | Tensor 555 | The probability of each generated sample to be assigned to each leaf with size: #samples x #leaves, 556 | """ 557 | assert self.training is False 558 | epsilon = 1e-7 559 | sizes = self.encoded_sizes 560 | list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(n_samples, device=device), 'z_parent_sample': None}] 561 | leaves_prob = [] 562 | reconstructions = [] 563 | while len(list_nodes) != 0: 564 | current_node = list_nodes.pop(0) 565 | node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob'] 566 | z_parent_sample = current_node['z_parent_sample'] 567 | 568 | if depth_level == 0: 569 | z_mu_p, z_sigma_p = torch.zeros([n_samples, sizes[-1]], device=device), torch.ones([n_samples, sizes[-1]], device=device) 570 | z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1) 571 | z_sample = z_p.rsample() 572 | 573 | else: 574 | _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample) 575 | z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1) 576 | z_sample = z_p.rsample() 577 | 578 | if node.router is not None: 579 | prob_child_left = node.router(z_sample).squeeze() 580 | prob_node_left, prob_node_right = prob * prob_child_left, prob * ( 581 | 1 - prob_child_left) 582 | node_left, node_right = node.left, node.right 583 | list_nodes.append( 584 | {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample}) 585 | list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right, 586 | 'z_parent_sample': z_sample}) 587 | 588 | elif node.decoder is not None: 589 | # here we are in a leaf node and we attach the corresponding generations 
590 | leaves_prob.append(prob) 591 | dec = node.decoder 592 | reconstructions.append(dec(z_sample)) 593 | 594 | elif node.router is None and node.decoder is None: 595 | # We are in an internal node with pruned leaves and thus only have one child 596 | node_left, node_right = node.left, node.right 597 | child = node_left if node_left is not None else node_right 598 | list_nodes.append( 599 | {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample}) 600 | p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1) 601 | 602 | return reconstructions, p_c_z 603 | -------------------------------------------------------------------------------- /models/model_smalltree.py: -------------------------------------------------------------------------------- 1 | """ 2 | SmallTreeVAE model (used for the growing procedure of TreeVAE). 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributions as td 7 | from models.networks import get_decoder, MLP, Router, Dense 8 | from utils.model_utils import compute_posterior 9 | from models.losses import loss_reconstruction_binary, loss_reconstruction_mse 10 | from utils.training_utils import calc_aug_loss 11 | 12 | class SmallTreeVAE(nn.Module): 13 | """ 14 | A class used to represent a sub-tree VAE with one root and two children. 15 | 16 | SmallTreeVAE specifies a sub-tree of TreeVAE with one root and two children. It is used in the 17 | growing procedure of TreeVAE. At each growing step a new SmallTreeVAE is attached to a leaf of TreeVAE and 18 | trained separately to reduce computational time. 
19 | 20 | Attributes 21 | ---------- 22 | activation : str 23 | The name of the activation function for the reconstruction loss [sigmoid, mse] 24 | loss : models.losses 25 | The loss function used by the decoder to reconstruct the input 26 | alpha : float 27 | KL-annealing weight initialization 28 | depth : int 29 | The depth at which the sub-tree will be attached (root has depth 0 and a root with two leaves has depth 1) 30 | inp_shape : int 31 | The total dimensions of the input data (if images of 32x32 then 32x32x3) 32 | augment : bool 33 | Whether to use contrastive learning through augmentation, if False no augmentation is used 34 | augmentation_method : str 35 | The type of augmentation method used 36 | aug_decisions_weight : float 37 | The weight of the contrastive loss used in the decisions 38 | denses : nn.ModuleList 39 | List of dense layers for the sharing of top-down and bottom-up (MLPs) associated with each of the two leaf 40 | nodes of the sub-tree from left to right. 41 | transformations : nn.ModuleList 42 | List of transformations (MLPs) associated with each of the two leaf nodes of the sub-tree from left to right 43 | decision : Router 44 | The decision associated with the root of the sub-tree. 45 | decoders : nn.ModuleList 46 | List of two decoders one for each leaf of the sub-tree 47 | decision_q : Router 48 | The decision of the bottom-up associated with the root of the sub-tree 49 | 50 | Methods 51 | ------- 52 | forward(x) 53 | Compute the forward pass of the SmallTreeVAE model and return a dictionary of losses. 54 | """ 55 | def __init__(self, depth, **kwargs): 56 | """ 57 | Parameters 58 | ---------- 59 | depth: int 60 | The depth at which the sub-tree will be attached to TreeVAE 61 | kwargs : dict 62 | A dictionary of attributes (see config file).
63 | """ 64 | super(SmallTreeVAE, self).__init__() 65 | self.kwargs = kwargs 66 | 67 | self.activation = self.kwargs['activation'] 68 | if self.activation == "sigmoid": 69 | self.loss = loss_reconstruction_binary 70 | elif self.activation == "mse": 71 | self.loss = loss_reconstruction_mse 72 | else: 73 | raise NotImplementedError 74 | # KL-annealing weight initialization 75 | self.alpha=self.kwargs['kl_start'] 76 | 77 | encoded_sizes = self.kwargs['latent_dim'] 78 | hidden_layers = self.kwargs['mlp_layers'] 79 | self.depth = depth 80 | encoded_size_gen = encoded_sizes[-(self.depth+1):-(self.depth-1)] # e.g. encoded_size_gen = 32,16, depth 2 81 | self.encoded_size = encoded_size_gen[::-1] # self.encoded_size = 32,16 => 16,32 82 | layers_gen = hidden_layers[-(self.depth+1):-(self.depth-1)] # e.g. encoded_sizes 256,128,64, depth 2 83 | self.hidden_layer = layers_gen[::-1] # encoded_size_gen = 256,128 => 128,256 84 | 85 | self.inp_shape = self.kwargs['inp_shape'] 86 | self.augment = self.kwargs['augment'] 87 | self.augmentation_method = self.kwargs['augmentation_method'] 88 | self.aug_decisions_weight = self.kwargs['aug_decisions_weight'] 89 | 90 | self.denses = nn.ModuleList([Dense(self.hidden_layer[1], self.encoded_size[1]) for _ in range(2)]) 91 | self.transformations = nn.ModuleList([MLP(self.encoded_size[0], self.encoded_size[1], self.hidden_layer[0]) for _ in range(2)]) 92 | self.decision = Router(self.encoded_size[0], hidden_units=self.hidden_layer[0]) 93 | self.decision_q = Router(self.hidden_layer[0], hidden_units=self.hidden_layer[0]) 94 | self.decoders = nn.ModuleList([get_decoder(architecture=self.kwargs['encoder'], input_shape=self.encoded_size[1], 95 | output_shape=self.inp_shape, activation=self.activation) for _ in range(2)]) 96 | 97 | def forward(self, x, z_parent, p, bottom_up): 98 | """ 99 | Forward pass of the SmallTreeVAE model. 
100 | 101 | Parameters 102 | ---------- 103 | x : tensor 104 | Input data (batch-size, input-size) 105 | z_parent: tensor 106 | The embeddings of the parent of the two children of SmallTreeVAE (which are the embeddings of the TreeVAE 107 | leaf where the SmallTreeVAE will be attached) 108 | p: list 109 | Probabilities of falling into the selected TreeVAE leaf where the SmallTreeVAE will be attached 110 | bottom_up: list 111 | The list of bottom-up transformations [encoder, MLP, MLP, ...] up to the root 112 | 113 | Returns 114 | ------- 115 | dict 116 | a dictionary 117 | {'rec_loss': reconstruction loss, 118 | 'kl_decisions': the KL loss of the decisions, 119 | 'kl_nodes': the KL loss of the nodes, 120 | 'aug_decisions': the weighted contrastive loss, 121 | 'p_c_z': the probability of each sample to be assigned to each leaf with size: #samples x #leaves, 122 | } 123 | """ 124 | epsilon = 1e-7 # Small constant to prevent numerical instability 125 | device = x.device 126 | 127 | # Extract relevant bottom-up 128 | d_q = bottom_up[-self.depth] 129 | d = bottom_up[-self.depth - 1] 130 | 131 | prob_child_left = self.decision(z_parent).squeeze() 132 | prob_child_left_q = self.decision_q(d_q).squeeze() 133 | leaves_prob = [p * prob_child_left_q, p * (1 - prob_child_left_q)] 134 | 135 | kl_decisions = prob_child_left_q * torch.log(epsilon + prob_child_left_q / (prob_child_left + epsilon)) +\ 136 | (1 - prob_child_left_q) * torch.log(epsilon + (1 - prob_child_left_q) / 137 | (1 - prob_child_left + epsilon)) 138 | kl_decisions = torch.mean(p * kl_decisions) 139 | 140 | # Contrastive loss 141 | aug_decisions_loss = torch.zeros(1, device=device) 142 | if self.training is True and self.augment is True and 'simple' not in self.augmentation_method: 143 | aug_decisions_loss += calc_aug_loss(prob_parent=p, prob_router=prob_child_left_q, 144 | augmentation_methods=self.augmentation_method) 145 | 146 | reconstructions = [] 147 | kl_nodes = torch.zeros(1, device=device) 148 | for i in 
range(2): 149 | # Compute posterior parameters 150 | z_mu_q_hat, z_sigma_q_hat = self.denses[i](d) 151 | _, z_mu_p, z_sigma_p = self.transformations[i](z_parent) 152 | z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1) 153 | z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p) 154 | 155 | # Compute sample z using mu_q and sigma_q 156 | z_q = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1) 157 | z_sample = z_q.rsample() 158 | 159 | # Compute KL node 160 | kl_node = torch.mean(leaves_prob[i] * td.kl_divergence(z_q, z_p)) 161 | kl_nodes += kl_node 162 | 163 | reconstructions.append(self.decoders[i](z_sample)) 164 | 165 | kl_nodes_loss = torch.clamp(kl_nodes, min=-10, max=1e10) 166 | 167 | # Probability of falling in each leaf 168 | p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1) 169 | 170 | rec_losses = self.loss(x, reconstructions, leaves_prob) 171 | rec_loss = torch.mean(rec_losses, dim=0) 172 | 173 | return { 174 | 'rec_loss': rec_loss, 175 | 'kl_decisions': kl_decisions, 176 | 'kl_nodes': kl_nodes_loss, 177 | 'aug_decisions': self.aug_decisions_weight * aug_decisions_loss, 178 | 'p_c_z': p_c_z, 179 | } 180 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder, decoder, transformation, router, and dense layer architectures. 
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def actvn(x): 10 | return F.leaky_relu(x, negative_slope=0.3) 11 | 12 | class EncoderSmall(nn.Module): 13 | def __init__(self, input_shape, output_shape): 14 | super(EncoderSmall, self).__init__() 15 | 16 | self.dense1 = nn.Linear(in_features=input_shape, out_features=4*output_shape, bias=False) 17 | self.bn1 = nn.BatchNorm1d(4*output_shape) 18 | self.dense2 = nn.Linear(in_features=4*output_shape, out_features=4*output_shape, bias=False) 19 | self.bn2 = nn.BatchNorm1d(4*output_shape) 20 | self.dense3 = nn.Linear(in_features=4*output_shape, out_features=2*output_shape, bias=False) 21 | self.bn3 = nn.BatchNorm1d(2*output_shape) 22 | self.dense4 = nn.Linear(in_features=2*output_shape, out_features=output_shape, bias=False) 23 | self.bn4 = nn.BatchNorm1d(output_shape) 24 | 25 | def forward(self, inputs): 26 | x = self.dense1(inputs) 27 | x = self.bn1(x) 28 | x = actvn(x) 29 | x = self.dense2(x) 30 | x = self.bn2(x) 31 | x = actvn(x) 32 | x = self.dense3(x) 33 | x = self.bn3(x) 34 | x = actvn(x) 35 | x = self.dense4(x) 36 | x = self.bn4(x) 37 | x = actvn(x) 38 | return x, None, None 39 | 40 | class DecoderSmall(nn.Module): 41 | def __init__(self, input_shape, output_shape, activation): 42 | super(DecoderSmall, self).__init__() 43 | self.activation = activation 44 | self.dense1 = nn.Linear(in_features=input_shape, out_features=128, bias=False) 45 | self.bn1 = nn.BatchNorm1d(128) 46 | self.dense2 = nn.Linear(in_features=128, out_features=256, bias=False) 47 | self.bn2 = nn.BatchNorm1d(256) 48 | self.dense3 = nn.Linear(in_features=256, out_features=512, bias=False) 49 | self.bn3 = nn.BatchNorm1d(512) 50 | self.dense4 = nn.Linear(in_features=512, out_features=512, bias=False) 51 | self.bn4 = nn.BatchNorm1d(512) 52 | self.dense5 = nn.Linear(in_features=512, out_features=output_shape, bias=True) 53 | 54 | def forward(self, inputs): 55 | x = self.dense1(inputs) 56 | x = 
self.bn1(x) 57 | x = actvn(x) 58 | x = self.dense2(x) 59 | x = self.bn2(x) 60 | x = actvn(x) 61 | x = self.dense3(x) 62 | x = self.bn3(x) 63 | x = actvn(x) 64 | x = self.dense4(x) 65 | x = self.bn4(x) 66 | x = actvn(x) 67 | x = self.dense5(x) 68 | if self.activation == "sigmoid": 69 | x = torch.sigmoid(x) 70 | return x 71 | 72 | 73 | class EncoderSmallCnn(nn.Module): 74 | def __init__(self, encoded_size): 75 | super(EncoderSmallCnn, self).__init__() 76 | n_maps_output = encoded_size//4 77 | self.cnn0 = nn.Conv2d(in_channels=1, out_channels=n_maps_output//4, kernel_size=3, stride=2, padding=0, bias=False) 78 | self.cnn1 = nn.Conv2d(in_channels=n_maps_output//4, out_channels=n_maps_output//2, kernel_size=3, stride=2, padding=0, bias=False) 79 | self.cnn2 = nn.Conv2d(in_channels=n_maps_output//2, out_channels=n_maps_output, kernel_size=3, stride=2, padding=0, bias=False) 80 | self.bn0 = nn.BatchNorm2d(n_maps_output//4) 81 | self.bn1 = nn.BatchNorm2d(n_maps_output//2) 82 | self.bn2 = nn.BatchNorm2d(n_maps_output) 83 | 84 | def forward(self, x): 85 | x = self.cnn0(x) 86 | x = self.bn0(x) 87 | x = actvn(x) 88 | x = self.cnn1(x) 89 | x = self.bn1(x) 90 | x = actvn(x) 91 | x = self.cnn2(x) 92 | x = self.bn2(x) 93 | x = actvn(x) 94 | x = x.view(x.size(0), -1) 95 | return x, None, None 96 | 97 | class DecoderSmallCnn(nn.Module): 98 | def __init__(self, input_shape, activation): 99 | super(DecoderSmallCnn, self).__init__() 100 | self.activation = activation 101 | self.dense = nn.Linear(in_features=input_shape, out_features=3 * 3 * 32, bias=False) 102 | self.bn = nn.BatchNorm1d(3 * 3 * 32) 103 | self.bn1 = nn.BatchNorm2d(16) 104 | self.bn2 = nn.BatchNorm2d(8) 105 | self.cnn1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, bias=False) 106 | self.cnn2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False) 107 | self.cnn3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, 
kernel_size=3, stride=2, padding=1, output_padding=1, bias=True) 108 | 109 | def forward(self, inputs): 110 | x = self.dense(inputs) 111 | x = self.bn(x) 112 | x = actvn(x) 113 | x = x.view(-1, 32, 3, 3) 114 | x = self.cnn1(x) 115 | x = self.bn1(x) 116 | x = actvn(x) 117 | x = self.cnn2(x) 118 | x = self.bn2(x) 119 | x = actvn(x) 120 | x = self.cnn3(x) 121 | if self.activation == 'sigmoid': 122 | x = torch.sigmoid(x) 123 | return x 124 | 125 | 126 | class EncoderOmniglot(nn.Module): 127 | def __init__(self, encoded_size): 128 | super(EncoderOmniglot, self).__init__() 129 | self.cnns = nn.ModuleList([ 130 | nn.Conv2d(in_channels=1, out_channels=encoded_size//4, kernel_size=4, stride=1, padding=1, bias=False), 131 | nn.Conv2d(in_channels=encoded_size//4, out_channels=encoded_size//4, kernel_size=4, stride=2, padding=1, bias=False), 132 | nn.Conv2d(in_channels=encoded_size//4, out_channels=encoded_size//2, kernel_size=4, stride=1, padding=1, bias=False), 133 | nn.Conv2d(in_channels=encoded_size//2, out_channels=encoded_size//2, kernel_size=4, stride=2, padding=1, bias=False), 134 | nn.Conv2d(in_channels=encoded_size//2, out_channels=encoded_size, kernel_size=4, stride=1, padding=1, bias=False), 135 | nn.Conv2d(in_channels=encoded_size, out_channels=encoded_size, kernel_size=5, bias=False) 136 | ]) 137 | self.bns = nn.ModuleList([ 138 | nn.BatchNorm2d(encoded_size//4), 139 | nn.BatchNorm2d(encoded_size//4), 140 | nn.BatchNorm2d(encoded_size//2), 141 | nn.BatchNorm2d(encoded_size//2), 142 | nn.BatchNorm2d(encoded_size), 143 | nn.BatchNorm2d(encoded_size) 144 | ]) 145 | 146 | def forward(self, x): 147 | for i in range(len(self.cnns)): 148 | x = self.cnns[i](x) 149 | x = self.bns[i](x) 150 | x = actvn(x) 151 | x = x.view(x.size(0), -1) 152 | return x, None, None 153 | 154 | class DecoderOmniglot(nn.Module): 155 | def __init__(self, input_shape, activation): 156 | super(DecoderOmniglot, self).__init__() 157 | self.activation = activation 158 | self.dense = 
nn.Linear(in_features=input_shape, out_features=2 * 2 * 128, bias=False) 159 | self.cnns = nn.ModuleList([ 160 | nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=2, bias=False), 161 | nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=1, padding=1, bias=False), 162 | nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=0, bias=False), 163 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=4, stride=1, padding=1, bias=False), 164 | nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=0, output_padding=1, bias=False) 165 | ]) 166 | self.cnns.append(nn.Conv2d(in_channels=32, out_channels=1, kernel_size=4, stride=1, padding=1, bias=True)) 167 | self.bn = nn.BatchNorm1d(2 * 2 * 128) 168 | self.bns = nn.ModuleList([ 169 | nn.BatchNorm2d(128), 170 | nn.BatchNorm2d(64), 171 | nn.BatchNorm2d(64), 172 | nn.BatchNorm2d(32), 173 | nn.BatchNorm2d(32) 174 | ]) 175 | 176 | def forward(self, inputs): 177 | x = self.dense(inputs) 178 | x = self.bn(x) 179 | x = actvn(x) 180 | x = x.view(-1, 128, 2, 2) 181 | for i in range(len(self.bns)): 182 | x = self.cnns[i](x) 183 | x = self.bns[i](x) 184 | x = actvn(x) 185 | x = self.cnns[-1](x) 186 | if self.activation == "sigmoid": 187 | x = torch.sigmoid(x) 188 | return x 189 | 190 | 191 | class ResnetBlock(nn.Module): 192 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 193 | super(ResnetBlock, self).__init__() 194 | 195 | self.learned_shortcut = (fin != fout) 196 | self.fin = fin 197 | self.fout = fout 198 | if fhidden is None: 199 | self.fhidden = min(fin, fout) 200 | else: 201 | self.fhidden = fhidden 202 | 203 | # Submodules 204 | self.conv_0 = nn.Conv2d(in_channels=fin, out_channels=self.fhidden, kernel_size=3, stride=1, padding=1) 205 | self.conv_1 = nn.Conv2d(in_channels=self.fhidden, out_channels=self.fout, kernel_size=3, stride=1, padding=1, bias=is_bias) 206 | if self.learned_shortcut: 207 | self.conv_s = 
nn.Conv2d(in_channels=fin, out_channels=self.fout, kernel_size=1, stride=1, padding=0, bias=False) 208 | self.bn0 = nn.BatchNorm2d(self.fin) 209 | self.bn1 = nn.BatchNorm2d(self.fhidden) 210 | 211 | def forward(self, x): 212 | x_s = self._shortcut(x) 213 | dx = self.conv_0(actvn(self.bn0(x))) 214 | dx = self.conv_1(actvn(self.bn1(dx))) 215 | out = x_s + 0.1 * dx 216 | return out 217 | 218 | def _shortcut(self, x): 219 | if self.learned_shortcut: 220 | x_s = self.conv_s(x) 221 | else: 222 | x_s = x 223 | return x_s 224 | 225 | class Resnet_Encoder(nn.Module): 226 | def __init__(self, s0=2, nf=8, nf_max=256, size=32): 227 | super(Resnet_Encoder, self).__init__() 228 | 229 | self.s0 = s0 230 | self.nf = nf 231 | self.nf_max = nf_max 232 | self.size = size 233 | 234 | # Submodules 235 | nlayers = int(torch.log2(torch.tensor(size / s0).float())) 236 | self.nf0 = min(nf_max, nf * 2 ** nlayers) 237 | 238 | blocks = [ 239 | ResnetBlock(nf, nf) 240 | ] 241 | 242 | for i in range(nlayers): 243 | nf0 = min(nf * 2 ** i, nf_max) 244 | nf1 = min(nf * 2 ** (i + 1), nf_max) 245 | blocks += [ 246 | nn.AvgPool2d(kernel_size=3, stride=2, padding=1), 247 | ResnetBlock(nf0, nf1), 248 | ] 249 | 250 | self.conv_img = nn.Conv2d(3, 1 * nf, kernel_size=3, padding=1) 251 | 252 | self.resnet = nn.Sequential(*blocks) 253 | 254 | self.bn0 = nn.BatchNorm2d(self.nf0) 255 | 256 | 257 | def forward(self, x): 258 | out = self.conv_img(x) 259 | out = self.resnet(out) 260 | out = actvn(self.bn0(out)) 261 | out = out.view(out.size(0), -1) 262 | return out, None, None 263 | 264 | class Resnet_Decoder(nn.Module): 265 | def __init__(self, s0=2, nf=8, nf_max=256, ndim=64, activation='sigmoid', size=32): 266 | super(Resnet_Decoder, self).__init__() 267 | 268 | self.s0 = s0 269 | self.nf = nf 270 | self.nf_max = nf_max 271 | self.activation = activation 272 | 273 | # Submodules 274 | nlayers = int(torch.log2(torch.tensor(size / s0).float())) 275 | self.nf0 = min(nf_max, nf * 2 ** nlayers) 276 | 277 | self.fc 
= nn.Linear(ndim, self.nf0 * s0 * s0) 278 | 279 | blocks = [] 280 | for i in range(nlayers): 281 | nf0 = min(nf * 2 ** (nlayers - i), nf_max) 282 | nf1 = min(nf * 2 ** (nlayers - i - 1), nf_max) 283 | blocks += [ 284 | ResnetBlock(nf0, nf1), 285 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 286 | ] 287 | blocks += [ 288 | ResnetBlock(nf, nf), 289 | ] 290 | self.resnet = nn.Sequential(*blocks) 291 | 292 | self.bn0 = nn.BatchNorm2d(nf) 293 | self.conv_img = nn.ConvTranspose2d(nf, 3, kernel_size=3, padding=1) 294 | 295 | 296 | def forward(self, z): 297 | out = self.fc(z) 298 | out = out.view(-1, self.nf0, self.s0, self.s0) 299 | out = self.resnet(out) 300 | out = self.conv_img(actvn(self.bn0(out))) 301 | if self.activation == 'sigmoid': 302 | out = torch.sigmoid(out) 303 | return out 304 | 305 | 306 | # Small branch transformation 307 | class MLP(nn.Module): 308 | def __init__(self, input_size, encoded_size, hidden_unit): 309 | super(MLP, self).__init__() 310 | self.dense1 = nn.Linear(input_size, hidden_unit, bias=False) 311 | self.bn1 = nn.BatchNorm1d(hidden_unit) 312 | self.mu = nn.Linear(hidden_unit, encoded_size) 313 | self.sigma = nn.Linear(hidden_unit, encoded_size) 314 | 315 | def forward(self, inputs): 316 | x = self.dense1(inputs) 317 | x = self.bn1(x) 318 | x = actvn(x) 319 | mu = self.mu(x) 320 | sigma = F.softplus(self.sigma(x)) 321 | return x, mu, sigma 322 | 323 | 324 | class Dense(nn.Module): 325 | def __init__(self, input_size, encoded_size): 326 | super(Dense, self).__init__() 327 | self.mu = nn.Linear(input_size, encoded_size) 328 | self.sigma = nn.Linear(input_size, encoded_size) 329 | 330 | def forward(self, inputs): 331 | x = inputs 332 | mu = self.mu(x) 333 | sigma = F.softplus(self.sigma(x)) 334 | return mu, sigma 335 | 336 | 337 | class Router(nn.Module): 338 | def __init__(self, input_size, hidden_units=128): 339 | super(Router, self).__init__() 340 | self.dense1 = nn.Linear(input_size, hidden_units, bias=False) 341 | 
self.dense2 = nn.Linear(hidden_units, hidden_units, bias=False) 342 | self.bn1 = nn.BatchNorm1d(hidden_units) 343 | self.bn2 = nn.BatchNorm1d(hidden_units) 344 | self.dense3 = nn.Linear(hidden_units, 1) 345 | 346 | def forward(self, inputs, return_last_layer=False): 347 | x = self.dense1(inputs) 348 | x = self.bn1(x) 349 | x = actvn(x) 350 | x = self.dense2(x) 351 | x = self.bn2(x) 352 | x = actvn(x) 353 | d = torch.sigmoid(self.dense3(x)) 354 | if return_last_layer: 355 | return d, x 356 | else: 357 | return d 358 | 359 | 360 | def get_encoder(architecture, encoded_size, x_shape, size=None): 361 | if architecture == 'mlp': 362 | encoder = EncoderSmall(input_shape=x_shape, output_shape=encoded_size) 363 | elif architecture == 'cnn1': 364 | encoder = EncoderSmallCnn(encoded_size) 365 | elif architecture == 'cnn2': 366 | encoder = Resnet_Encoder(s0=4, nf=32, nf_max=256, size=size) 367 | elif architecture == 'cnn_omni': 368 | encoder = EncoderOmniglot(encoded_size) 369 | else: 370 | raise ValueError('The encoder architecture is misspecified.') 371 | return encoder 372 | 373 | 374 | def get_decoder(architecture, input_shape, output_shape, activation): 375 | if architecture == 'mlp': 376 | decoder = DecoderSmall(input_shape, output_shape, activation) 377 | elif architecture == 'cnn1': 378 | decoder = DecoderSmallCnn(input_shape, activation) 379 | elif architecture == 'cnn2': 380 | size = int((output_shape/3)**0.5) 381 | decoder = Resnet_Decoder(s0=4, nf=32, nf_max=256, ndim=input_shape, activation=activation, size=size) 382 | elif architecture == 'cnn_omni': 383 | decoder = DecoderOmniglot(input_shape, activation) 384 | else: 385 | raise ValueError('The decoder architecture is misspecified.') 386 | return decoder 387 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run training and validation functions of TreeVAE.
3 | """ 4 | import time 5 | from pathlib import Path 6 | import wandb 7 | import uuid 8 | import os 9 | import torch 10 | 11 | from utils.data_utils import get_data 12 | from utils.utils import reset_random_seeds 13 | from train.train_tree import run_tree 14 | from train.validate_tree import val_tree 15 | 16 | 17 | def run_experiment(configs): 18 | """ 19 | Run the experiments for TreeVAE as defined in the config setting. This method will set up the device, the correct 20 | experimental paths, initialize Wandb for tracking, generate the dataset, train and grow the TreeVAE model, and 21 | finally it will validate the result. All final results and validations will be stored in Wandb, while the most 22 | important ones will be also printed out in the terminal. If specified, the model will also be saved for further 23 | exploration using the Jupyter Notebook: tree_exploration.ipynb. 24 | 25 | Parameters 26 | ---------- 27 | configs: dict 28 | The config setting for training and validating TreeVAE defined in configs or in the command line. 
29 | """ 30 | # Setting device on GPU if available, else CPU 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | 33 | # Additional info when using cuda 34 | if device.type == 'cuda': 35 | print("Using", torch.cuda.get_device_name(0)) 36 | else: 37 | print("No GPU available") 38 | 39 | # Set paths 40 | project_dir = Path(__file__).absolute().parent 41 | timestr = time.strftime("%Y%m%d-%H%M%S") 42 | ex_name = "{}_{}".format(str(timestr), uuid.uuid4().hex[:5]) 43 | experiment_path = configs['globals']['results_dir'] / configs['data']['data_name'] / ex_name 44 | experiment_path.mkdir(parents=True) 45 | os.makedirs(os.path.join(project_dir, '../models/logs', ex_name)) 46 | print("Experiment path: ", experiment_path) 47 | 48 | # Wandb 49 | os.environ['WANDB_CACHE_DIR'] = os.path.join(project_dir, '../wandb', '.cache', 'wandb') 50 | os.environ["WANDB_SILENT"] = "true" 51 | 52 | # ADD YOUR WANDB ENTITY 53 | wandb.init( 54 | project="treevae", 55 | entity="test", 56 | config=configs, 57 | mode=configs['globals']['wandb_logging'] 58 | ) 59 | 60 | if configs['globals']['wandb_logging'] in ['online', 'disabled']: 61 | wandb.run.name = wandb.run.name.split("-")[-1] + "-"+ configs['run_name'] 62 | elif configs['globals']['wandb_logging'] == 'offline': 63 | wandb.run.name = configs['run_name'] 64 | else: 65 | raise ValueError('wandb needs to be set to online, offline or disabled.') 66 | 67 | # Reproducibility 68 | reset_random_seeds(configs['globals']['seed']) 69 | 70 | # Generate a new dataset each run 71 | trainset, trainset_eval, testset = get_data(configs) 72 | 73 | # Run the full training of treeVAE model, including the growing of the tree 74 | model = run_tree(trainset, trainset_eval, testset, device, configs) 75 | 76 | # Save model 77 | if configs['globals']['save_model']: 78 | print("\nSaving weights at ", experiment_path) 79 | torch.save(model.state_dict(), experiment_path / 'model_weights.pt') 80 | 81 | # Evaluation of TreeVAE 82 | print("\n" 
* 2) 83 | print("Evaluation") 84 | print("\n" * 2) 85 | val_tree(trainset_eval, testset, model, device, experiment_path, configs) 86 | wandb.finish(quiet=True) 87 | return 88 | -------------------------------------------------------------------------------- /train/train_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training function of TreeVAE and SmallTreeVAE. 3 | """ 4 | import wandb 5 | import numpy as np 6 | import gc 7 | import torch 8 | import torch.optim as optim 9 | 10 | from utils.training_utils import train_one_epoch, validate_one_epoch, AnnealKLCallback, Custom_Metrics, \ 11 | get_ind_small_tree, compute_growing_leaf, compute_pruning_leaf, get_optimizer, predict 12 | from utils.data_utils import get_gen 13 | from utils.model_utils import return_list_tree, construct_data_tree 14 | from models.model import TreeVAE 15 | from models.model_smalltree import SmallTreeVAE 16 | 17 | 18 | def run_tree(trainset, trainset_eval, testset, device, configs): 19 | """ 20 | Run the TreeVAE model as defined in the config setting. The method will first train a TreeVAE model with initial 21 | depth defined in config (initial_depth). After training TreeVAE for epochs=num_epochs, if grow=True then it will 22 | start the iterative growing schedule. At each step, a SmallTreeVAE will be trained for num_epochs_smalltree and 23 | attached to the selected leaf of TreeVAE. The resulting TreeVAE will then grow at each step and will be finetuned 24 | throughout the growing procedure for num_epochs_intermediate_fulltrain and at the end of the growing procedure for 25 | num_epochs_finetuning. 
26 | 27 | Parameters 28 | ---------- 29 | trainset: torch.utils.data.Dataset 30 | The train dataset 31 | trainset_eval: torch.utils.data.Dataset 32 | The validation dataset 33 | testset: torch.utils.data.Dataset 34 | The test dataset 35 | device: torch.device 36 | The device in which to validate the model 37 | configs: dict 38 | The config setting for training and validating TreeVAE defined in configs or in the command line 39 | 40 | Returns 41 | ------ 42 | models.model.TreeVAE 43 | The trained TreeVAE model 44 | """ 45 | 46 | graph_mode = not configs['globals']['eager_mode'] 47 | gen_train = get_gen(trainset, configs, validation=False, shuffle=True) 48 | gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False) 49 | gen_test = get_gen(testset, configs, validation=True, shuffle=False) 50 | _ = gc.collect() 51 | 52 | # Define model & optimizer 53 | model = TreeVAE(**configs['training']) 54 | model.to(device) 55 | 56 | if graph_mode: 57 | model = torch.compile(model) 58 | 59 | optimizer = get_optimizer(model, configs) 60 | 61 | # Initialize schedulers 62 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'], 63 | gamma=configs['training']['decay_lr']) 64 | alpha_scheduler = AnnealKLCallback(model, decay=configs['training']['decay_kl'], 65 | start=configs['training']['kl_start']) 66 | 67 | # Initialize Metrics 68 | metrics_calc_train = Custom_Metrics(device).to(device) 69 | metrics_calc_val = Custom_Metrics(device).to(device) 70 | 71 | ################################# TRAINING TREEVAE with depth defined in config ################################# 72 | 73 | # Training the initial tree 74 | for epoch in range(configs['training']['num_epochs']): # loop over the dataset multiple times 75 | train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device) 76 | validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device) 77 | lr_scheduler.step() 78 | 
alpha_scheduler.on_epoch_end(epoch) 79 | _ = gc.collect() 80 | 81 | ################################# GROWING THE TREE ################################# 82 | 83 | # Start the growing loop of the tree 84 | # Compute metrics and set node.expand False for the nodes that should not grow 85 | # This loop goes layer-wise 86 | grow = configs['training']['grow'] 87 | initial_depth = configs['training']['initial_depth'] 88 | max_depth = len(configs['training']['mlp_layers']) - 1 89 | if initial_depth >= max_depth: 90 | grow = False 91 | growing_iterations = 0 92 | while grow and growing_iterations < 150: 93 | 94 | # full model finetuning during growing after every 3 splits 95 | if configs['training']['num_epochs_intermediate_fulltrain']>0: 96 | if growing_iterations != 0 and growing_iterations % 3 == 0: 97 | # Initialize optimizer and schedulers 98 | optimizer = get_optimizer(model, configs) 99 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'], 100 | gamma=configs['training']['decay_lr']) 101 | alpha_scheduler = AnnealKLCallback(model, decay=configs['training']['decay_kl'], 102 | start=configs['training']['kl_start']) 103 | 104 | # Training the initial split 105 | print('\nTree intermediate finetuning\n') 106 | for epoch in range(configs['training']['num_epochs_intermediate_fulltrain']): 107 | train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device) 108 | validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device) 109 | lr_scheduler.step() 110 | alpha_scheduler.on_epoch_end(epoch) 111 | _ = gc.collect() 112 | 113 | # extract information of leaves 114 | node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves') 115 | node_leaves_test = predict(gen_test, model, device, 'node_leaves') 116 | 117 | # compute which leaf to grow and split 118 | ind_leaf, leaf, n_effective_leaves = compute_growing_leaf(gen_train_eval, model, node_leaves_train, max_depth, 119 | 
configs['training']['batch_size'], 120 | max_leaves=configs['training']['num_clusters_tree']) 121 | if ind_leaf is None: 122 | break 123 | else: 124 | print('\nGrowing tree: Leaf %d at depth %d\n' % (ind_leaf, leaf['depth'])) 125 | depth, node = leaf['depth'], leaf['node'] 126 | 127 | # get subset of data that has high prob. of falling in subtree 128 | ind_train = get_ind_small_tree(node_leaves_train[ind_leaf], n_effective_leaves) 129 | ind_test = get_ind_small_tree(node_leaves_test[ind_leaf], n_effective_leaves) 130 | gen_train_small = get_gen(trainset, configs, shuffle=True, smalltree=True, smalltree_ind=ind_train) 131 | gen_test_small = get_gen(testset, configs, shuffle=False, validation=True, smalltree=True, 132 | smalltree_ind=ind_test) 133 | 134 | # preparation for the smalltree training 135 | # initialize the smalltree 136 | small_model = SmallTreeVAE(depth=depth+1, **configs['training']) 137 | small_model.to(device) 138 | if graph_mode: 139 | small_model = torch.compile(small_model) 140 | 141 | # Optimizer for smalltree 142 | optimizer = get_optimizer(small_model, configs) 143 | 144 | # Initialize schedulers 145 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'], 146 | gamma=configs['training']['decay_lr']) 147 | alpha_scheduler = AnnealKLCallback(small_model, decay=configs['training']['decay_kl'], 148 | start=configs['training']['kl_start']) 149 | 150 | # Training the smalltree subsplit 151 | for epoch in range(configs['training']['num_epochs_smalltree']): 152 | train_one_epoch(gen_train_small, model, optimizer, metrics_calc_train, epoch, device, train_small_tree=True, 153 | small_model=small_model, ind_leaf=ind_leaf) 154 | validate_one_epoch(gen_test_small, model, metrics_calc_val, epoch, device, train_small_tree=True, 155 | small_model=small_model, ind_leaf=ind_leaf) 156 | lr_scheduler.step() 157 | alpha_scheduler.on_epoch_end(epoch) 158 | _ = gc.collect() 159 | 160 | # attach smalltree to full tree by
assigning decisions and adding new children nodes to full tree 161 | model.attach_smalltree(node, small_model) 162 | 163 | # Check if reached the max number of effective leaves before finetuning unnecessarily 164 | if n_effective_leaves + 1 == configs['training']['num_clusters_tree']: 165 | node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves') 166 | _, _, max_growth = compute_growing_leaf(gen_train_eval, model, node_leaves_train, max_depth, 167 | configs['training']['batch_size'], 168 | max_leaves=configs['training']['num_clusters_tree'], check_max=True) 169 | if max_growth is True: 170 | break 171 | 172 | growing_iterations += 1 173 | 174 | # The growing loop of the tree is concluded! 175 | # check whether we need to prune the final tree and log pre-pruning dendrogram 176 | prune = configs['training']['prune'] 177 | if prune: 178 | node_leaves_test, prob_leaves_test = predict(gen_test, model, device, 'node_leaves', 'prob_leaves') 179 | if len(node_leaves_test) < 2: 180 | prune = False 181 | else: 182 | print('\nStarting pruning!\n') 183 | yy = np.squeeze(np.argmax(prob_leaves_test, axis=-1)) 184 | y_test = testset.dataset.targets[testset.indices] 185 | data_tree = construct_data_tree(model, y_predicted=yy, y_true=y_test, n_leaves=len(node_leaves_test), 186 | data_name=configs['data']['data_name']) 187 | 188 | table = wandb.Table(columns=["node_id", "node_name", "parent", "size"], data=data_tree) 189 | fields = {"node_name": "node_name", "node_id": "node_id", "parent": "parent", "size": "size"} 190 | dendro = wandb.plot_table(vega_spec_name="stacey/flat_tree", data_table=table, fields=fields) 191 | wandb.log({"dendogram_pre_pruned": dendro}) 192 | 193 | # prune the tree 194 | while prune: 195 | # check pruning conditions 196 | node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves') 197 | ind_leaf, leaf = compute_pruning_leaf(model, node_leaves_train) 198 | 199 | if ind_leaf is None: 200 | print('\nPruning finished!\n') 201 |
break 202 | else: 203 | # prune leaves and internal nodes without children 204 | print(f'\nPruning leaf {ind_leaf}!\n') 205 | current_node = leaf['node'] 206 | while all(child is None for child in [current_node.left, current_node.right]): 207 | if current_node.parent is not None: 208 | parent = current_node.parent 209 | # root does not get pruned 210 | else: 211 | break 212 | parent.prune_child(current_node) 213 | current_node = parent 214 | 215 | 216 | # reinitialize model 217 | transformations, routers, denses, decoders, routers_q = return_list_tree(model.tree) 218 | model.decisions_q = routers_q 219 | model.transformations = transformations 220 | model.decisions = routers 221 | model.denses = denses 222 | model.decoders = decoders 223 | model.depth = model.compute_depth() 224 | _ = gc.collect() 225 | 226 | ################################# FULL MODEL FINETUNING ################################# 227 | 228 | 229 | print('\n*****************model depth %d******************\n' % (model.depth)) 230 | print('\n*****************model finetuning******************\n') 231 | 232 | # Initialize optimizer and schedulers 233 | optimizer = get_optimizer(model, configs) 234 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'], gamma=configs['training']['decay_lr']) 235 | alpha_scheduler = AnnealKLCallback(model, decay=max(0.01,1/max(1,configs['training']['num_epochs_finetuning']-1)), start=configs['training']['kl_start']) 236 | # finetune the full tree 237 | print('\nTree final finetuning\n') 238 | for epoch in range(configs['training']['num_epochs_finetuning']): # loop over the dataset multiple times 239 | train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device) 240 | validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device) 241 | lr_scheduler.step() 242 | alpha_scheduler.on_epoch_end(epoch) 243 | _ = gc.collect() 244 | 245 | return model 246 | 247 | 248 | 
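The pruning loop in run_tree above removes the selected leaf and then keeps walking towards the root, removing every ancestor that is left without children, while never removing the root itself. The following is a minimal, self-contained sketch of that upward walk under simplifying assumptions: `Node`, `add_children`, and `prune_upwards` are hypothetical names introduced only for illustration, and this `prune_child` merely detaches the child, whereas the repository's own method may do additional bookkeeping.

```python
class Node:
    """Hypothetical stand-in for a tree node; not the repository's class."""

    def __init__(self, name):
        self.name = name
        self.left = None
        self.right = None
        self.parent = None

    def add_children(self, left, right):
        # Attach two children and set their parent pointers.
        self.left, self.right = left, right
        left.parent = right.parent = self

    def prune_child(self, child):
        # Detach `child` by identity (assumed behaviour for this sketch).
        if self.left is child:
            self.left = None
        elif self.right is child:
            self.right = None
        else:
            raise ValueError("node is not a child of this parent")
        child.parent = None


def prune_upwards(leaf):
    """Remove `leaf`, then each ancestor left childless; the root is kept."""
    removed = []
    current = leaf
    while current.left is None and current.right is None:
        parent = current.parent
        if parent is None:  # reached the root, which is never pruned
            break
        parent.prune_child(current)
        removed.append(current.name)
        current = parent
    return removed


# Small tree: r -> (a, b), a -> (c, d)
root, a, b, c, d = (Node(n) for n in "rabcd")
root.add_children(a, b)
a.add_children(c, d)

first = prune_upwards(c)   # only c is removed: a still has child d
second = prune_upwards(d)  # d is removed, then a, which is now childless
```

In this toy tree the walk stops at the root because it still has the child `b`; had `b` been pruned as well, the `parent is None` check would stop the loop instead.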
-------------------------------------------------------------------------------- /train/validate_tree.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score 4 | import gc 5 | import yaml 6 | import torch 7 | import scipy 8 | from tqdm import tqdm 9 | 10 | from utils.data_utils import get_gen 11 | from utils.utils import cluster_acc, dendrogram_purity, leaf_purity 12 | from utils.training_utils import compute_leaves, validate_one_epoch, Custom_Metrics, predict 13 | from utils.model_utils import construct_data_tree 14 | from models.losses import loss_reconstruction_cov_mse_eval 15 | 16 | 17 | def val_tree(trainset, testset, model, device, experiment_path, configs): 18 | """ 19 | Run the validation of a trained instance of TreeVAE on both the train and test datasets. All final results and 20 | validations will be stored in Wandb, while the most important ones will be also printed out in the terminal. 
21 | 22 | Parameters 23 | ---------- 24 | trainset: torch.utils.data.Dataset 25 | The train dataset 26 | testset: torch.utils.data.Dataset 27 | The test dataset 28 | model: models.model.TreeVAE 29 | The trained TreeVAE model 30 | device: torch.device 31 | The device in which to validate the model 32 | experiment_path: str 33 | The experimental path where to store the tree 34 | configs: dict 35 | The config setting for training and validating TreeVAE defined in configs or in the command line 36 | """ 37 | 38 | ############ Training set performance ############ 39 | 40 | # get the data loader 41 | gen_train_eval = get_gen(trainset, configs, validation=True, shuffle=False) 42 | y_train = trainset.dataset.targets[trainset.indices].numpy() 43 | # compute the leaf probabilities 44 | prob_leaves_train = predict(gen_train_eval, model, device, 'prob_leaves') 45 | _ = gc.collect() 46 | # compute the predicted cluster 47 | y_train_pred = np.squeeze(np.argmax(prob_leaves_train, axis=-1)).numpy() 48 | # compute clustering metrics 49 | acc, idx = cluster_acc(y_train, y_train_pred, return_index=True) 50 | nmi = normalized_mutual_info_score(y_train, y_train_pred) 51 | ari = adjusted_rand_score(y_train, y_train_pred) 52 | wandb.log({"Train Accuracy": acc, "Train Normalized Mutual Information": nmi, "Train Adjusted Rand Index": ari}) 53 | # compute confusion matrix 54 | swap = dict(zip(range(len(idx)), idx)) 55 | y_wandb = np.array([swap[i] for i in y_train_pred], dtype=np.uint8) 56 | wandb.log({"Train_confusion_matrix": 57 | wandb.plot.confusion_matrix(probs=None, y_true=y_train, preds=y_wandb, class_names=range(len(idx)))}) 58 | 59 | ############ Test set performance ############ 60 | 61 | # get the data loader 62 | gen_test = get_gen(testset, configs, validation=True, shuffle=False) 63 | y_test = testset.dataset.targets[testset.indices].numpy() 64 | # compute one validation pass through the test set to log losses 65 | metrics_calc_test = Custom_Metrics(device) 66 | 
validate_one_epoch(gen_test, model, metrics_calc_test, 0, device, test=True) 67 | _ = gc.collect() 68 | # predict the leaf probabilities and the leaves 69 | node_leaves_test, prob_leaves_test = predict(gen_test, model, device, 'node_leaves', 'prob_leaves') 70 | _ = gc.collect() 71 | # compute the predicted cluster 72 | y_test_pred = np.squeeze(np.argmax(prob_leaves_test, axis=-1)).numpy() 73 | # Calculate clustering metrics 74 | acc, idx = cluster_acc(y_test, y_test_pred, return_index=True) 75 | nmi = normalized_mutual_info_score(y_test, y_test_pred) 76 | ari = adjusted_rand_score(y_test, y_test_pred) 77 | wandb.log({"Test Accuracy": acc, "Test Normalized Mutual Information": nmi, "Test Adjusted Rand Index": ari}) 78 | # Calculate confusion matrix 79 | swap = dict(zip(range(len(idx)), idx)) 80 | y_wandb = np.array([swap[i] for i in y_test_pred], dtype=np.uint8) 81 | wandb.log({"Test_confusion_matrix": wandb.plot.confusion_matrix(probs=None, 82 | y_true=y_test, preds=y_wandb, 83 | class_names=range(len(idx)))}) 84 | 85 | # Determine indices of samples that fall into each leaf for Dendrogram Purity & Leaf Purity 86 | leaves = compute_leaves(model.tree) 87 | ind_samples_of_leaves = [] 88 | for i in range(len(leaves)): 89 | ind_samples_of_leaves.append([leaves[i]['node'], np.where(y_test_pred == i)[0]]) 90 | # Calculate leaf and dendrogram purity 91 | dp = dendrogram_purity(model.tree, y_test, ind_samples_of_leaves) 92 | lp = leaf_purity(model.tree, y_test, ind_samples_of_leaves) 93 | # Note: Only comparable DP & LP values wrt baselines if they have the same n_leaves for all methods 94 | wandb.log({"Test Dendrogram Purity": dp, "Test Leaf Purity": lp}) 95 | 96 | # Save the tree structure of TreeVAE and log it 97 | data_tree = construct_data_tree(model, y_predicted=y_test_pred, y_true=y_test, n_leaves=len(node_leaves_test), 98 | data_name=configs['data']['data_name']) 99 | 100 | if configs['globals']['save_model']: 101 | with open(experiment_path / 'data_tree.npy', 'wb')
as save_file: 102 | np.save(save_file, data_tree) 103 | with open(experiment_path / 'config.yaml', 'w', encoding='utf8') as outfile: 104 | yaml.dump(configs, outfile, default_flow_style=False, allow_unicode=True) 105 | 106 | table = wandb.Table(columns=["node_id", "node_name", "parent", "size"], data=data_tree) 107 | fields = {"node_name": "node_name", "node_id": "node_id", "parent": "parent", "size": "size"} 108 | dendro = wandb.plot_table(vega_spec_name="stacey/flat_tree", data_table=table, fields=fields) 109 | wandb.log({"dendogram_final": dendro}) 110 | 111 | # Printing important results 112 | print(np.unique(y_test_pred, return_counts=True)) 113 | print("Accuracy:", acc) 114 | print("Normalized Mutual Information:", nmi) 115 | print("Adjusted Rand Index:", ari) 116 | print("Dendrogram Purity:", dp) 117 | print("Leaf Purity:", lp) 118 | print("Digits", np.unique(y_test)) 119 | 120 | # Compute the log-likelihood of the test data 121 | # ATTENTION: it might take a while! If not interested, disable the setting in configs 122 | if configs['training']['compute_ll']: 123 | compute_likelihood(testset, model, device, configs) 124 | return 125 | 126 | 127 | def compute_likelihood(testset, model, device, configs): 128 | """ 129 | Compute the approximated log-likelihood calculated using 1000 importance-weighted samples. 130 | 131 | Parameters 132 | ---------- 133 | testset: torch.utils.data.Dataset 134 | The test dataset 135 | model: models.model.TreeVAE 136 | The trained TreeVAE model 137 | device: torch.device 138 | The device on which to validate the model 139 | configs: dict 140 | The config setting for training and validating TreeVAE defined in configs or in the command line 141 | """ 142 | ESTIMATION_SAMPLES = 1000 143 | gen_test = get_gen(testset, configs, validation=True, shuffle=False) 144 | print('\nComputing the log likelihood....
it might take a while.') 145 | if configs['training']['activation'] == 'sigmoid': 146 | elbo = np.zeros((len(testset), ESTIMATION_SAMPLES)) 147 | for j in tqdm(range(ESTIMATION_SAMPLES)): 148 | elbo[:, j] = predict(gen_test, model, device, 'elbo') 149 | _ = gc.collect() 150 | elbo_new = elbo[:, :ESTIMATION_SAMPLES] 151 | log_likel = np.log(1 / ESTIMATION_SAMPLES) + scipy.special.logsumexp(-elbo_new, axis=1) 152 | marginal_log_likelihood = np.sum(log_likel) / len(testset) 153 | wandb.log({"test log-likelihood": marginal_log_likelihood}) 154 | print("Test log-likelihood", marginal_log_likelihood) 155 | output_elbo, output_rec_loss = predict(gen_test, model, device, 'elbo', 'rec_loss') 156 | print('Test ELBO:', -torch.mean(output_elbo)) 157 | print('Test Reconstruction Loss:', torch.mean(output_rec_loss)) 158 | 159 | elif configs['training']['activation'] == 'mse': 160 | # Correct calculation of ELBO and Loglikelihood for 3channel images without assuming diagonal gaussian for 161 | # reconstruction 162 | old_loss = model.loss 163 | model.loss = loss_reconstruction_cov_mse_eval 164 | # Note that for comparability to other papers, one might want to add Uniform(0,1) noise to the input images 165 | # (in 0,255), to go from the discrete to the assumed continuous inputs 166 | # x_test_elbo = x_test * 255 167 | # x_test_elbo = (x_test_elbo + tfd.Uniform().sample(x_test_elbo.shape)) / 256 168 | output_elbo, output_rec_loss = predict(gen_test, model, device, 'elbo', 'rec_loss') 169 | nelbo = torch.mean(output_elbo) 170 | nelbo_bpd = nelbo / (torch.log(torch.tensor(2)) * configs['training']['inp_shape']) + 8 # Add 8 to account normalizing of inputs 171 | model.loss = old_loss 172 | elbo = np.zeros((len(testset), ESTIMATION_SAMPLES)) 173 | for j in range(ESTIMATION_SAMPLES): 174 | # x_test_elbo = x_test * 255 175 | # x_test_elbo = (x_test_elbo + tfd.Uniform().sample(x_test_elbo.shape)) / 256 176 | output_elbo = predict(gen_test, model, device, 'elbo') 177 | elbo[:, j] = 
output_elbo 178 | # Change to bpd 179 | elbo_new = elbo[:, :ESTIMATION_SAMPLES] 180 | log_likel = np.log(1 / ESTIMATION_SAMPLES) + scipy.special.logsumexp(-elbo_new, axis=1) 181 | marginal_log_likelihood = np.sum(log_likel) / len(testset) 182 | marginal_log_likelihood = marginal_log_likelihood / ( 183 | torch.log(torch.tensor(2)) * configs['training']['inp_shape']) - 8 184 | wandb.log({"test log-likelihood": marginal_log_likelihood}) 185 | print('Test Log-Likelihood Bound:', marginal_log_likelihood) 186 | print('Test ELBO:', -nelbo_bpd) 187 | print('Test Reconstruction Loss:', 188 | torch.mean(output_rec_loss) / (torch.log(torch.tensor(2)) * configs['training']['inp_shape']) + 8) 189 | model.loss = old_loss 190 | else: 191 | raise NotImplementedError 192 | return 193 | -------------------------------------------------------------------------------- /treevae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lauramanduchi/treevae/698bbe9d618c7355289ab10f9bee43e488db159d/treevae.png -------------------------------------------------------------------------------- /treevae.yml: -------------------------------------------------------------------------------- 1 | name: treevae 2 | channels: 3 | - anaconda 4 | - defaults 5 | - conda-forge 6 | - bioconda 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - anyio=3.5.0=py311h06a4308_0 11 | - appdirs=1.4.4=pyh9f0ad1d_0 12 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 13 | - argon2-cffi-bindings=21.2.0=py311h5eee18b_0 14 | - asttokens=2.0.5=pyhd3eb1b0_0 15 | - attrs=23.1.0=py311h06a4308_0 16 | - backcall=0.2.0=pyhd3eb1b0_0 17 | - bleach=4.1.0=pyhd3eb1b0_0 18 | - brotlipy=0.7.0=py311h5eee18b_1002 19 | - bzip2=1.0.8=h7f98852_4 20 | - ca-certificates=2023.08.22=h06a4308_0 21 | - certifi=2023.7.22=py311h06a4308_0 22 | - cffi=1.15.1=py311h5eee18b_3 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - click=8.1.7=unix_pyh707e725_0 
25 | - comm=0.1.2=py311h06a4308_0 26 | - cryptography=41.0.3=py311hdda0065_0 27 | - cuda-version=11.8=h70ddcb2_2 28 | - cudatoolkit=11.8.0=h4ba93d1_12 29 | - cudnn=8.8.0.121=h838ba91_3 30 | - debugpy=1.6.7=py311h6a678d5_0 31 | - decorator=5.1.1=pyhd3eb1b0_0 32 | - defusedxml=0.7.1=pyhd3eb1b0_0 33 | - docker-pycreds=0.4.0=py_0 34 | - entrypoints=0.4=py311h06a4308_0 35 | - executing=0.8.3=pyhd3eb1b0_0 36 | - filelock=3.12.4=pyhd8ed1ab_0 37 | - freetype=2.12.1=h4a9f257_0 38 | - gitdb=4.0.10=pyhd8ed1ab_0 39 | - gitpython=3.1.37=pyhd8ed1ab_0 40 | - gmp=6.2.1=h58526e2_0 41 | - gmpy2=2.1.2=py311h6a5fa03_1 42 | - icu=73.2=h59595ed_0 43 | - idna=3.4=py311h06a4308_0 44 | - ipykernel=6.25.0=py311h92b7b1e_0 45 | - ipython=8.15.0=py311h06a4308_0 46 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 47 | - ipywidgets=8.0.4=py311h06a4308_0 48 | - jedi=0.18.1=py311h06a4308_1 49 | - jinja2=3.1.2=pyhd8ed1ab_1 50 | - joblib=1.2.0=py311h06a4308_0 51 | - jsonschema=4.17.3=py311h06a4308_0 52 | - jupyter=1.0.0=pyhd8ed1ab_10 53 | - jupyter_client=8.1.0=py311h06a4308_0 54 | - jupyter_console=6.6.3=py311h06a4308_0 55 | - jupyter_core=5.3.0=py311h06a4308_0 56 | - jupyter_server=1.13.5=pyhd3eb1b0_0 57 | - jupyterlab_widgets=3.0.5=py311h06a4308_0 58 | - lcms2=2.15=hb7c19ff_3 59 | - ld_impl_linux-64=2.40=h41732ed_0 60 | - lerc=4.0.0=h27087fc_0 61 | - libabseil=20230802.1=cxx17_h59595ed_0 62 | - libblas=3.9.0=19_linux64_openblas 63 | - libcblas=3.9.0=19_linux64_openblas 64 | - libdeflate=1.19=hd590300_0 65 | - libexpat=2.5.0=hcb278e6_1 66 | - libffi=3.4.2=h7f98852_5 67 | - libgcc-ng=13.2.0=h807b86a_2 68 | - libgfortran-ng=13.2.0=h69a702a_2 69 | - libgfortran5=13.2.0=ha4646dd_2 70 | - libhwloc=2.9.3=default_h554bfaf_1009 71 | - libiconv=1.17=h166bdaf_0 72 | - libjpeg-turbo=3.0.0=hd590300_1 73 | - liblapack=3.9.0=19_linux64_openblas 74 | - libmagma=2.7.1=hc72dce7_6 75 | - libmagma_sparse=2.7.1=h8354cda_6 76 | - libnsl=2.0.1=hd590300_0 77 | - libopenblas=0.3.24=pthreads_h413a1c8_0 78 | - 
libpng=1.6.39=h5eee18b_0 79 | - libprotobuf=4.24.3=hf27288f_1 80 | - libsodium=1.0.18=h7b6447c_0 81 | - libsqlite=3.43.2=h2797004_0 82 | - libstdcxx-ng=13.2.0=h7e041cc_2 83 | - libtiff=4.6.0=ha9c0a0a_2 84 | - libuuid=2.38.1=h0b41bf4_0 85 | - libuv=1.46.0=hd590300_0 86 | - libwebp-base=1.3.2=h5eee18b_0 87 | - libxcb=1.15=h7f8727e_0 88 | - libxml2=2.11.5=h232c23b_1 89 | - libzlib=1.2.13=hd590300_5 90 | - llvm-openmp=17.0.2=h4dfa4b3_0 91 | - magma=2.7.1=ha770c72_6 92 | - markupsafe=2.1.3=py311h459d7ec_1 93 | - matplotlib-inline=0.1.6=py311h06a4308_0 94 | - mkl=2022.2.1=h84fe81f_16997 95 | - mpc=1.3.1=hfe3b2da_0 96 | - mpfr=4.2.0=hb012696_0 97 | - mpmath=1.3.0=pyhd8ed1ab_0 98 | - nbclassic=0.5.5=py311h06a4308_0 99 | - nbformat=5.9.2=py311h06a4308_0 100 | - nccl=2.19.3.1=h6103f9b_0 101 | - ncurses=6.4=hcb278e6_0 102 | - nest-asyncio=1.5.6=py311h06a4308_0 103 | - networkx=3.1=pyhd8ed1ab_0 104 | - notebook=6.5.4=py311h06a4308_0 105 | - notebook-shim=0.2.2=py311h06a4308_0 106 | - numpy=1.26.0=py311h64a7726_0 107 | - openjpeg=2.5.0=h488ebb8_3 108 | - openssl=3.1.3=hd590300_0 109 | - packaging=23.1=py311h06a4308_0 110 | - pandoc=2.12=h06a4308_3 111 | - pandocfilters=1.5.0=pyhd3eb1b0_0 112 | - parso=0.8.3=pyhd3eb1b0_0 113 | - pathtools=0.1.2=py_1 114 | - pexpect=4.8.0=pyhd3eb1b0_3 115 | - pickleshare=0.7.5=pyhd3eb1b0_1003 116 | - pillow=10.1.0=py311ha6c5da5_0 117 | - pip=23.3=pyhd8ed1ab_0 118 | - platformdirs=3.10.0=py311h06a4308_0 119 | - prometheus_client=0.14.1=py311h06a4308_0 120 | - prompt-toolkit=3.0.36=py311h06a4308_0 121 | - prompt_toolkit=3.0.36=hd3eb1b0_0 122 | - protobuf=4.24.3=py311h46cbc50_1 123 | - psutil=5.9.5=py311h459d7ec_1 124 | - ptyprocess=0.7.0=pyhd3eb1b0_2 125 | - pure_eval=0.2.2=pyhd3eb1b0_0 126 | - pycparser=2.21=pyhd3eb1b0_0 127 | - pygments=2.15.1=py311h06a4308_1 128 | - pyopenssl=23.2.0=py311h06a4308_0 129 | - pyrsistent=0.18.0=py311h5eee18b_0 130 | - pysocks=1.7.1=py311h06a4308_0 131 | - python=3.11.6=hab00c5b_0_cpython 132 | - 
python-dateutil=2.8.2=pyhd3eb1b0_0 133 | - python-fastjsonschema=2.16.2=py311h06a4308_0 134 | - python_abi=3.11=4_cp311 135 | - pytorch=2.0.0=cuda112py311hf4e4fe6_303 136 | - pytorch-gpu=2.0.0=cuda112py311h398211c_303 137 | - pyyaml=6.0.1=py311h459d7ec_1 138 | - pyzmq=25.1.0=py311h6a678d5_0 139 | - qtconsole-base=5.4.4=pyha770c72_0 140 | - qtpy=2.4.0=pyhd8ed1ab_0 141 | - readline=8.2=h8228510_1 142 | - requests=2.31.0=py311h06a4308_0 143 | - scikit-learn=1.2.2=py311h6a678d5_1 144 | - scipy=1.11.3=py311h64a7726_1 145 | - send2trash=1.8.0=pyhd3eb1b0_1 146 | - sentry-sdk=1.32.0=pyhd8ed1ab_0 147 | - setproctitle=1.3.3=py311h459d7ec_0 148 | - setuptools=68.2.2=pyhd8ed1ab_0 149 | - six=1.16.0=pyh6c4a22f_0 150 | - sleef=3.5.1=h9b69904_2 151 | - smmap=3.0.5=pyh44b312d_0 152 | - sniffio=1.2.0=py311h06a4308_1 153 | - stack_data=0.2.0=pyhd3eb1b0_0 154 | - sympy=1.12=pypyh9d50eac_103 155 | - tbb=2021.10.0=h00ab1b0_1 156 | - terminado=0.17.1=py311h06a4308_0 157 | - testpath=0.6.0=py311h06a4308_0 158 | - threadpoolctl=2.2.0=pyh0d69192_0 159 | - tk=8.6.13=h2797004_0 160 | - torchmetrics=0.11.4=py311h92b7b1e_1 161 | - torchvision=0.15.2=cuda112py311h3f38234_2 162 | - tornado=6.3.3=py311h5eee18b_0 163 | - tqdm=4.65.0=py311h92b7b1e_0 164 | - traitlets=5.7.1=py311h06a4308_0 165 | - typing-extensions=4.8.0=hd8ed1ab_0 166 | - typing_extensions=4.8.0=pyha770c72_0 167 | - tzdata=2023c=h71feb2d_0 168 | - urllib3=1.26.16=py311h06a4308_0 169 | - wandb=0.15.12=pyhd8ed1ab_0 170 | - wcwidth=0.2.5=pyhd3eb1b0_0 171 | - webencodings=0.5.1=py311h06a4308_1 172 | - websocket-client=0.58.0=py311h06a4308_4 173 | - wheel=0.41.2=pyhd8ed1ab_0 174 | - widgetsnbextension=4.0.5=py311h06a4308_0 175 | - xz=5.2.6=h166bdaf_0 176 | - yaml=0.2.5=h7b6447c_0 177 | - zeromq=4.3.4=h2531618_0 178 | - zlib=1.2.13=hd590300_5 179 | - zstd=1.5.5=hfc55251_0 180 | - pip: 181 | - beautifulsoup4==4.12.2 182 | - jupyterlab-pygments==0.2.2 183 | - mistune==3.0.2 184 | - nbclient==0.8.0 185 | - nbconvert==7.9.2 186 | - 
soupsieve==2.5 187 | - tinycss2==1.2.1 188 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for data loading. 3 | """ 4 | import os 5 | import torch 6 | import torchvision 7 | import torchvision.transforms as T 8 | import numpy as np 9 | from torch.utils.data import TensorDataset, DataLoader, Subset, ConcatDataset 10 | from PIL import Image 11 | from sklearn.datasets import fetch_20newsgroups 12 | from sklearn.feature_extraction.text import TfidfVectorizer 13 | from sklearn.model_selection import train_test_split 14 | from utils.utils import reset_random_seeds 15 | 16 | def get_data(configs): 17 | """Compute and process the data specified in the configs file. 18 | 19 | Parameters 20 | ---------- 21 | configs : dict 22 | A dictionary of config settings, where the data_name, the number of clusters in the data and augmentation 23 | details are specified. 
24 | 25 | Returns 26 | ------ 27 | list 28 | A list of three tensor datasets: trainset, trainset_eval, testset 29 | """ 30 | data_name = configs['data']['data_name'] 31 | augment = configs['training']['augment'] 32 | augmentation_method = configs['training']['augmentation_method'] 33 | n_classes = configs['data']['num_clusters_data'] 34 | 35 | data_path = './data/' 36 | 37 | if data_name == 'mnist': 38 | reset_random_seeds(configs['globals']['seed']) 39 | full_trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=T.ToTensor()) 40 | full_testset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=T.ToTensor()) 41 | 42 | # get only num_clusters digits 43 | indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes) 44 | trainset = Subset(full_trainset, indx_train) 45 | trainset_eval = Subset(full_trainset, indx_train) 46 | testset = Subset(full_testset, indx_test) 47 | 48 | 49 | elif data_name == 'fmnist': 50 | reset_random_seeds(configs['globals']['seed']) 51 | full_trainset = torchvision.datasets.FashionMNIST(root=data_path, train=True, download=True, transform=T.ToTensor()) 52 | full_testset = torchvision.datasets.FashionMNIST(root=data_path, train=False, download=True, transform=T.ToTensor()) 53 | 54 | # get only num_clusters digits 55 | indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes) 56 | trainset = Subset(full_trainset, indx_train) 57 | trainset_eval = Subset(full_trainset, indx_train) 58 | testset = Subset(full_testset, indx_test) 59 | 60 | 61 | elif data_name == 'news20': 62 | reset_random_seeds(configs['globals']['seed']) 63 | newsgroups_train = fetch_20newsgroups(subset='train') 64 | newsgroups_test = fetch_20newsgroups(subset='test') 65 | vectorizer = TfidfVectorizer(max_features=2000, dtype=np.float32) 66 | x_train = torch.from_numpy(vectorizer.fit_transform(newsgroups_train.data).toarray()) 67 | x_test = 
torch.from_numpy(vectorizer.transform(newsgroups_test.data).toarray()) 68 | y_train = torch.from_numpy(newsgroups_train.target) 69 | y_test = torch.from_numpy(newsgroups_test.target) 70 | 71 | # get only num_clusters digits 72 | indx_train, indx_test = select_subset(y_train, y_test, n_classes) 73 | trainset = Subset(TensorDataset(x_train, y_train), indx_train) 74 | trainset_eval = Subset(TensorDataset(x_train, y_train), indx_train) 75 | testset = Subset(TensorDataset(x_test, y_test), indx_test) 76 | trainset.dataset.targets = torch.tensor(trainset.dataset.tensors[1]) 77 | trainset_eval.dataset.targets = torch.tensor(trainset_eval.dataset.tensors[1]) 78 | testset.dataset.targets = torch.tensor(testset.dataset.tensors[1]) 79 | 80 | 81 | elif data_name == 'omniglot': 82 | reset_random_seeds(configs['globals']['seed']) 83 | 84 | transform_eval = T.Compose([ 85 | T.ToTensor(), 86 | T.Resize([28,28], antialias=True), 87 | ]) 88 | 89 | if augment and augmentation_method == ['simple']: 90 | transform = T.Compose([ 91 | T.ToTensor(), 92 | T.Resize([28,28], antialias=True), 93 | T.RandomAffine(degrees=10, translate=(1/28, 1/28), scale=(0.9, 1.1), shear=0.01, fill=1), 94 | ]) 95 | elif augment is False: 96 | transform = transform_eval 97 | else: 98 | raise NotImplementedError 99 | 100 | # Download the datasets and apply transformations 101 | trainset_premerge = torchvision.datasets.Omniglot(root=data_path, background=True, download=True, transform=transform) 102 | testset_premerge = torchvision.datasets.Omniglot(root=data_path, background=False, download=True, transform=transform) 103 | trainset_premerge_eval = torchvision.datasets.Omniglot(root=data_path, background=True, download=True, transform=transform_eval) 104 | testset_premerge_eval = torchvision.datasets.Omniglot(root=data_path, background=False, download=True, transform=transform_eval) 105 | 106 | # Get the corresponding labels y_train and y_test 107 | y_train_ind = torch.tensor([sample[1] for sample in 
trainset_premerge]) 108 | y_test_ind = torch.tensor([sample[1] for sample in testset_premerge]) 109 | 110 | # Create a list of all alphabet labels from both datasets 111 | alphabets = trainset_premerge._alphabets + testset_premerge._alphabets 112 | 113 | # Replace character labels by alphabet labels 114 | y_train_pre = [] 115 | y_test_pre = [] 116 | for value in y_train_ind: 117 | alphabet = trainset_premerge._characters[value].split("/")[0] 118 | alphabet_ind = alphabets.index(alphabet) 119 | y_train_pre.append(alphabet_ind) 120 | for value in y_test_ind: 121 | alphabet = testset_premerge._characters[value].split("/")[0] 122 | alphabet_ind = alphabets.index(alphabet) 123 | y_test_pre.append(alphabet_ind) 124 | 125 | y = np.array(y_train_pre + y_test_pre) 126 | 127 | # Select alphabets 128 | num_clusters = n_classes 129 | if num_clusters !=50: 130 | alphabets_selected = get_selected_omniglot_alphabets()[:num_clusters] 131 | alphabets_ind = [] 132 | for i in alphabets_selected: 133 | alphabets_ind.append(alphabets.index(i)) 134 | else: 135 | alphabets_ind = np.arange(50) 136 | 137 | indx = np.array([], dtype=int) 138 | for i in range(num_clusters): 139 | indx = np.append(indx, np.where(y == alphabets_ind[i])[0]) 140 | indx = np.sort(indx) 141 | 142 | # Split and stratify by digits 143 | digits_label = torch.concatenate([y_train_ind, y_test_ind+len(torch.unique(y_train_ind))]) 144 | indx_train, indx_test = train_test_split(indx, test_size=0.2, random_state=configs['globals']['seed'], stratify=digits_label[indx]) 145 | indx_train = np.sort(indx_train) 146 | indx_test = np.sort(indx_test) 147 | 148 | # Define alphabets as labels 149 | y = y+50 150 | for idx, alphabet in enumerate(alphabets_ind): 151 | y[y==alphabet+50] = idx 152 | 153 | # Define mapping from digit to label 154 | mapping_train = [] 155 | for value in torch.unique(y_train_ind): 156 | alphabet = trainset_premerge._characters[value].split("/")[0] 157 | alphabet_ind = alphabets.index(alphabet) 158 | 
mapping_train.append(alphabet_ind) 159 | mapping_test = [] 160 | for value in torch.unique(y_test_ind): 161 | alphabet = testset_premerge._characters[value].split("/")[0] 162 | alphabet_ind = alphabets.index(alphabet) 163 | mapping_test.append(alphabet_ind) 164 | 165 | custom_target_transform_train = T.Lambda(lambda y: mapping_train[y]) 166 | custom_target_transform_test = T.Lambda(lambda y: mapping_test[y]) 167 | 168 | trainset_premerge.target_transform = custom_target_transform_train 169 | trainset_premerge_eval.target_transform = custom_target_transform_train 170 | testset_premerge.target_transform = custom_target_transform_test 171 | testset_premerge_eval.target_transform = custom_target_transform_test 172 | 173 | # Define datasets 174 | fullset = ConcatDataset([trainset_premerge, testset_premerge]) 175 | fullset_eval = ConcatDataset([trainset_premerge_eval, testset_premerge_eval]) 176 | fullset.targets = torch.from_numpy(y) 177 | fullset_eval.targets = torch.from_numpy(y) 178 | trainset = Subset(fullset, indx_train) 179 | trainset_eval = Subset(fullset_eval, indx_train) 180 | testset = Subset(fullset_eval, indx_test) 181 | 182 | 183 | 184 | elif data_name in ['cifar10', 'cifar100', 'cifar10_vehicles', 'cifar10_animals']: 185 | reset_random_seeds(configs['globals']['seed']) 186 | aug_strength = 0.5 187 | 188 | 189 | transform_eval = T.Compose([ 190 | T.ToTensor(), 191 | ]) 192 | 193 | if augment is True: 194 | aug_transforms = T.Compose([ 195 | T.RandomResizedCrop(32, interpolation=Image.BICUBIC, scale=(0.2, 1.0)), 196 | T.RandomHorizontalFlip(), 197 | T.RandomApply([T.ColorJitter(0.8 * aug_strength, 0.8 * aug_strength, 0.8 * aug_strength, 0.2 * aug_strength)], p=0.8), 198 | T.RandomGrayscale(p=0.2), 199 | T.ToTensor(), 200 | ]) 201 | if augmentation_method == ['simple']: 202 | transform = aug_transforms 203 | else: 204 | transform = ContrastiveTransformations(aug_transforms, n_views=2) 205 | else: 206 | transform = transform_eval 207 | 208 | if data_name == 
'cifar100': 209 | if n_classes==20: 210 | dataset = CIFAR100Coarse 211 | else: 212 | dataset = torchvision.datasets.CIFAR100 213 | else: 214 | dataset = torchvision.datasets.CIFAR10 215 | 216 | full_trainset = dataset(root=data_path, train=True, download=True, transform=transform) 217 | full_trainset_eval = dataset(root=data_path, train=True, download=True, transform=transform_eval) 218 | full_testset = dataset(root=data_path, train=False, download=True, transform=transform_eval) 219 | 220 | if data_name == 'cifar10_vehicles': 221 | indx_train = [index for index, value in enumerate(full_trainset.targets) if value in (0, 1, 8, 9)] 222 | indx_test = [index for index, value in enumerate(full_testset.targets) if value in (0, 1, 8, 9)] 223 | elif data_name == 'cifar10_animals': 224 | indx_train = [index for index, value in enumerate(full_trainset.targets) if value not in (0, 1, 8, 9)] 225 | indx_test = [index for index, value in enumerate(full_testset.targets) if value not in (0, 1, 8, 9)] 226 | else: 227 | indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes) 228 | 229 | trainset = Subset(full_trainset, indx_train) 230 | trainset_eval = Subset(full_trainset_eval, indx_train) 231 | testset = Subset(full_testset, indx_test) 232 | 233 | trainset.dataset.targets = torch.tensor(trainset.dataset.targets) 234 | trainset_eval.dataset.targets = torch.tensor(trainset_eval.dataset.targets) 235 | testset.dataset.targets = torch.tensor(testset.dataset.targets) 236 | 237 | elif data_name == 'celeba': 238 | reset_random_seeds(configs['globals']['seed']) 239 | aug_strength = 0.25 240 | 241 | # Slightly different reshaping from TF implementation to be inline with WAE 242 | transform_eval = T.Compose([ 243 | T.Lambda(lambda x: T.functional.crop(x, left=15, top=40, width=148, height=148)), 244 | T.Resize([64,64], antialias=True), 245 | T.ToTensor(), 246 | ]) 247 | if augment is True: 248 | aug_transforms = T.Compose([ 249 | T.Lambda(lambda x: 
T.functional.crop(x, left=15, top=40, width=148, height=148)), 250 | T.Resize([64,64], antialias=True), 251 | T.RandomResizedCrop(64, interpolation=Image.BICUBIC, scale = (3/4,1), ratio=(4/5,5/4)), 252 | T.RandomHorizontalFlip(), 253 | T.RandomApply([T.ColorJitter(0.8 * aug_strength, 0.8 * aug_strength, 0.8 * aug_strength)], p=0.8), 254 | T.ToTensor(), 255 | ]) 256 | if augmentation_method == ['simple']: 257 | transform = aug_transforms 258 | else: 259 | transform = ContrastiveTransformations(aug_transforms, n_views=2) 260 | else: 261 | transform = transform_eval 262 | 263 | 264 | full_trainset = torchvision.datasets.CelebA(root=data_path, split='train', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform) 265 | full_trainset_eval = torchvision.datasets.CelebA(root=data_path, split='train', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform_eval) 266 | full_testset = torchvision.datasets.CelebA(root=data_path, split='test', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform_eval) 267 | 268 | indx_train = np.arange(len(full_trainset)) 269 | indx_test = np.arange(len(full_testset)) 270 | 271 | trainset = Subset(full_trainset, indx_train) 272 | trainset_eval = Subset(full_trainset_eval, indx_train) 273 | testset = Subset(full_testset, indx_test) 274 | trainset.dataset.targets = torch.zeros(trainset.dataset.attr.shape[0], dtype=torch.int8) 275 | trainset_eval.dataset.targets = torch.zeros(trainset_eval.dataset.attr.shape[0], dtype=torch.int8) 276 | testset.dataset.targets = torch.zeros(testset.dataset.attr.shape[0], dtype=torch.int8) 277 | 278 | else: 279 | raise NotImplementedError('This dataset is not supported!') 280 | 281 | assert trainset.__class__ == testset.__class__ == trainset_eval.__class__ == Subset 282 | return trainset, trainset_eval, testset 283 | 284 | 285 | def get_gen(dataset, configs, validation=False, shuffle=True, smalltree=False, smalltree_ind=None):
286 | """Given a dataset and a config dictionary, return the corresponding DataLoader. 287 | 288 | Parameters 289 | ---------- 290 | dataset : torch.dataset 291 | A tensor dataset. 292 | configs : dict 293 | A dictionary of config settings. 294 | validation : bool, optional 295 | If True, the last batch is kept. During training it is preferable to drop the last batch if it 296 | has a different shape, to avoid changing the batch normalization statistics. 297 | shuffle : bool, optional 298 | Whether to shuffle the dataset at every epoch. 299 | smalltree : bool, optional 300 | Whether the method should output the DataLoader for the small tree training, where a subset of training inputs 301 | is used. 302 | smalltree_ind : list 303 | For training the small tree during the growing strategy of TreeVAE, only a subset of training inputs will be 304 | used for efficiency. 305 | 306 | Returns 307 | ------ 308 | DataLoader 309 | The dataloader of the provided dataset. 310 | """ 311 | batch_size = configs['training']['batch_size'] 312 | drop_last = not validation 313 | try: 314 | num_workers = configs['parser']['num_workers'] 315 | except (KeyError, TypeError): 316 | num_workers = 6 317 | 318 | if smalltree: 319 | dataset = Subset(dataset, smalltree_ind) 320 | 321 | # Call the DataLoader when contrastive learning is used 322 | if configs['training']['augment'] and configs['training']['augmentation_method'] != ['simple'] and not validation: 323 | # As one datapoint leads to two samples, we have to halve the batch size to retain the same number of samples per batch 324 | assert batch_size % 2 == 0 325 | batch_size = batch_size // 2 326 | data_gen = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, 327 | persistent_workers=True, collate_fn=custom_collate_fn, drop_last=drop_last) 328 | 329 | # Call the DataLoader without contrastive learning 330 | else: 331 | data_gen = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
num_workers=num_workers, pin_memory=True, 332 | persistent_workers=True, drop_last=drop_last) 333 | return data_gen 334 | 335 | 336 | def select_subset(y_train, y_test, num_classes): 337 | # Select a random subset of labels where the number of different labels equal num_classes. 338 | digits = np.random.choice([i for i in range(len(np.unique(y_train)))], size=num_classes, replace=False) 339 | indx_train = np.array([], dtype=int) 340 | indx_test = np.array([], dtype=int) 341 | for i in range(num_classes): 342 | indx_train = np.append(indx_train, np.where(y_train == digits[i])[0]) 343 | indx_test = np.append(indx_test, np.where(y_test == digits[i])[0]) 344 | return np.sort(indx_train), np.sort(indx_test) 345 | 346 | 347 | def custom_collate_fn(batch): 348 | # Concatenate the augmented versions 349 | batch = torch.utils.data.default_collate(batch) 350 | batch[0] = batch[0].transpose(1, 0).reshape(-1,*batch[0].shape[2:]) 351 | batch[1] = batch[1].repeat(2) 352 | return batch 353 | 354 | 355 | class ContrastiveTransformations(object): 356 | 357 | def __init__(self, base_transforms, n_views=2): 358 | self.base_transforms = base_transforms 359 | self.n_views = n_views 360 | 361 | def __call__(self, x): 362 | return torch.stack([self.base_transforms(x) for i in range(self.n_views)],dim=0) 363 | 364 | 365 | def get_selected_omniglot_alphabets(): 366 | return ['Braille', 'Glagolitic', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Bengali'] 367 | 368 | 369 | class CIFAR100Coarse(torchvision.datasets.CIFAR100): 370 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 371 | super(CIFAR100Coarse, self).__init__(root, train, transform, target_transform, download) 372 | 373 | # update labels 374 | coarse_labels = np.array([ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 375 | 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 376 | 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 377 | 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 378 | 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 379 | 16, 4, 17, 4, 2, 
0, 17, 4, 18, 17, 380 | 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 381 | 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 382 | 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 383 | 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) 384 | self.targets = coarse_labels[self.targets] 385 | 386 | # update classes 387 | self.classes = [['beaver', 'dolphin', 'otter', 'seal', 'whale'], 388 | ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'], 389 | ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'], 390 | ['bottle', 'bowl', 'can', 'cup', 'plate'], 391 | ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'], 392 | ['clock', 'keyboard', 'lamp', 'telephone', 'television'], 393 | ['bed', 'chair', 'couch', 'table', 'wardrobe'], 394 | ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'], 395 | ['bear', 'leopard', 'lion', 'tiger', 'wolf'], 396 | ['bridge', 'castle', 'house', 'road', 'skyscraper'], 397 | ['cloud', 'forest', 'mountain', 'plain', 'sea'], 398 | ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'], 399 | ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'], 400 | ['crab', 'lobster', 'snail', 'spider', 'worm'], 401 | ['baby', 'boy', 'girl', 'man', 'woman'], 402 | ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'], 403 | ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'], 404 | ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'], 405 | ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'], 406 | ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']] -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for model. 
3 | """ 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | def compute_posterior(mu_q, mu_p, sigma_q, sigma_p): 8 | epsilon = 1e-7 9 | z_sigma_q = 1 / (1 / (sigma_q + epsilon) + 1 / (sigma_p + epsilon)) 10 | z_mu_q = (mu_q / (sigma_q + epsilon) + 11 | mu_p / (sigma_p + epsilon)) * z_sigma_q 12 | return z_mu_q, z_sigma_q 13 | 14 | 15 | def construct_tree(transformations, routers, routers_q, denses, decoders): 16 | """ 17 | Construct the tree by passing a list of transformations and routers from root to leaves, visiting nodes 18 | layer-wise from left to right. 19 | 20 | :param transformations: list of transformations to attach to the nodes of the tree 21 | :param routers: list of decisions to attach to the nodes of the tree 22 | :param denses: list of dense networks that compute the node-specific q from the bottom-up embedding d 23 | :param decoders: list of decoders to attach to the nodes, they should be set to None except the leaves 24 | :return: the root of the tree 25 | """ 26 | if len(transformations) != len(routers) or len(transformations) != len(denses) \ 27 | or len(transformations) != len(decoders): 28 | raise ValueError('Transformations, routers, denses and decoders must have the same length to construct the tree.') 29 | root = Node(transformation=transformations[0], router=routers[0], routers_q=routers_q[0], dense=denses[0], decoder=decoders[0]) 30 | for i in range(1, len(transformations)): 31 | root.insert(transformation=transformations[i], router=routers[i], routers_q=routers_q[i], dense=denses[i], decoder=decoders[i]) 32 | return root 33 | 34 | 35 | class Node: 36 | def __init__(self, transformation, router, routers_q, dense, decoder=None, expand=True): 37 | self.left = None 38 | self.right = None 39 | self.parent = None 40 | self.transformation = transformation 41 | self.dense = dense 42 | self.router = router 43 | self.routers_q = routers_q 44 | self.decoder = decoder 45 | self.expand = expand 46 | 47 | def insert(self, transformation=None, router=None, routers_q=None,
dense=None, decoder=None): 48 | queue = [] 49 | node = self 50 | queue.append(node) 51 | while len(queue) > 0: 52 | node = queue.pop(0) 53 | if node.expand: 54 | if node.left is None: 55 | node.left = Node(transformation, router, routers_q, dense, decoder) 56 | node.left.parent = node 57 | return 58 | elif node.right is None: 59 | node.right = Node(transformation, router, routers_q, dense, decoder) 60 | node.right.parent = node 61 | return 62 | else: 63 | queue.append(node.left) 64 | queue.append(node.right) 65 | print('\nAttention node has not been inserted!\n') 66 | return 67 | 68 | def prune_child(self, child): 69 | if child is self.left: 70 | self.left = None 71 | self.router = None 72 | 73 | elif child is self.right: 74 | self.right = None 75 | self.router = None 76 | 77 | else: 78 | raise ValueError("This is not my child! (Node is not a child of this parent.)") 79 | 80 | def return_list_tree(root): 81 | list_nodes = [root] 82 | denses = [] 83 | transformations = [] 84 | routers = [] 85 | routers_q = [] 86 | decoders = [] 87 | while len(list_nodes) != 0: 88 | current_node = list_nodes.pop(0) 89 | denses.append(current_node.dense) 90 | transformations.append(current_node.transformation) 91 | routers.append(current_node.router) 92 | routers_q.append(current_node.routers_q) 93 | decoders.append(current_node.decoder) 94 | if current_node.router is not None: 95 | node_left, node_right = current_node.left, current_node.right 96 | list_nodes.append(node_left) 97 | list_nodes.append(node_right) 98 | elif current_node.router is None and current_node.decoder is None: 99 | # We are in an internal node with pruned leaves and thus only have one child 100 | node_left, node_right = current_node.left, current_node.right 101 | child = node_left if node_left is not None else node_right 102 | list_nodes.append(child) 103 | return nn.ModuleList(transformations), nn.ModuleList(routers), nn.ModuleList(denses), nn.ModuleList(decoders), nn.ModuleList(routers_q) 104 | 105 | 106 | def 
construct_tree_fromnpy(model, data_tree, configs): 107 | from models.model_smalltree import SmallTreeVAE 108 | nodes = {0: {'node': model.tree, 'depth': 0}} 109 | 110 | for i in range(1, len(data_tree)-1): 111 | node_left = data_tree[i] 112 | node_right = data_tree[i + 1] 113 | id_node_left = node_left[0] 114 | id_node_right = node_right[0] 115 | 116 | if node_left[2] == node_right[2]: 117 | id_parent = node_left[2] 118 | 119 | parent = nodes[id_parent] 120 | node = parent['node'] 121 | depth = parent['depth'] 122 | 123 | new_depth = depth + 1 124 | 125 | small_model = SmallTreeVAE(new_depth+1, **configs['training']) 126 | 127 | node.router = small_model.decision 128 | node.routers_q = small_model.decision_q 129 | 130 | node.decoder = None 131 | n = [] 132 | for j in range(2): 133 | dense = small_model.denses[j] 134 | transformation = small_model.transformations[j] 135 | decoder = small_model.decoders[j] 136 | n.append(Node(transformation, None, None, dense, decoder)) 137 | 138 | node.left = n[0] 139 | node.right = n[1] 140 | 141 | nodes[id_node_left] = {'node': node.left, 'depth': new_depth} 142 | nodes[id_node_right] = {'node': node.right, 'depth': new_depth} 143 | elif data_tree[i][2] != data_tree[i - 1][2]: # Internal node w/ 1 child only 144 | id_parent = node_left[2] 145 | 146 | parent = nodes[id_parent] 147 | node = parent['node'] 148 | depth = parent['depth'] 149 | 150 | new_depth = depth + 1 151 | 152 | small_model = SmallTreeVAE(new_depth+1, **configs['training']) 153 | 154 | node.router = None 155 | node.routers_q = None 156 | 157 | node.decoder = None 158 | n = [] 159 | for j in range(1): 160 | dense = small_model.denses[j] 161 | transformation = small_model.transformations[j] 162 | decoder = small_model.decoders[j] 163 | n.append(Node(transformation, None, None, dense, decoder)) 164 | 165 | node.left = n[0] 166 | nodes[id_node_left] = {'node': node.left, 'depth': new_depth} 167 | 168 | transformations, routers, denses, decoders, routers_q = 
return_list_tree(model.tree) 169 | model.decisions_q = routers_q 170 | model.transformations = transformations 171 | model.decisions = routers 172 | model.denses = denses 173 | model.decoders = decoders 174 | model.depth = model.compute_depth() 175 | return model 176 | 177 | 178 | def construct_data_tree(model, y_predicted, y_true, n_leaves, data_name): 179 | list_nodes = [{'node':model.tree, 'id': 0, 'parent':None}] 180 | data = [] 181 | i = 0 182 | labels = [i for i in range(n_leaves)] 183 | while len(list_nodes) != 0: 184 | current_node = list_nodes.pop(0) 185 | if current_node['node'].router is not None: 186 | data.append([current_node['id'], str(current_node['id']), current_node['parent'], 10]) 187 | node_left, node_right = current_node['node'].left, current_node['node'].right 188 | i += 1 189 | list_nodes.append({'node':node_left, 'id': i, 'parent': current_node['id']}) 190 | i += 1 191 | list_nodes.append({'node':node_right, 'id': i, 'parent': current_node['id']}) 192 | elif current_node['node'].router is None and current_node['node'].decoder is None: 193 | # We are in an internal node with pruned leaves and will only add the non-pruned leaves 194 | data.append([current_node['id'], str(current_node['id']), current_node['parent'], 10]) 195 | node_left, node_right = current_node['node'].left, current_node['node'].right 196 | child = node_left if node_left is not None else node_right 197 | i += 1 198 | list_nodes.append({'node': child, 'id': i, 'parent': current_node['id']}) 199 | else: 200 | y_leaf = labels.pop(0) 201 | ind = np.where(y_predicted == y_leaf)[0] 202 | digits, counts = np.unique(y_true[ind], return_counts=True) 203 | tot = len(ind) 204 | if tot == 0: 205 | name = 'no digits' 206 | else: 207 | counts = np.round(counts / np.sum(counts), 2) 208 | ind = np.where(counts > 0.1)[0] 209 | name = ' ' 210 | for j in ind: 211 | if data_name == 'fmnist': 212 | items = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 213 | 
'Bag', 'Boot'] 214 | name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' ' 215 | elif data_name == 'cifar10': 216 | items = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 217 | 'truck'] 218 | name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' ' 219 | elif data_name == 'news20': 220 | items = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 221 | 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale','rec.autos', 222 | 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 223 | 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 224 | 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 225 | 'talk.religion.misc'] 226 | name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' ' 227 | elif data_name == 'omniglot': 228 | from utils.data_utils import get_selected_omniglot_alphabets 229 | items = get_selected_omniglot_alphabets() 230 | if np.unique(y_true).shape[0]>len(items): 231 | items=np.arange(50) 232 | 233 | name = name + items[digits[j]] + ': ' + str(counts[j]) + ' ' 234 | else: 235 | name = name + str(digits[j]) + ': ' + str(counts[j]) + ' ' 236 | name = name + 'tot ' + str(tot) 237 | data.append([current_node['id'], name, current_node['parent'], 1]) 238 | return data 239 | -------------------------------------------------------------------------------- /utils/plotting_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributions as td 4 | from matplotlib import pyplot as plt 5 | from utils.model_utils import construct_tree, compute_posterior 6 | import re 7 | import networkx as nx 8 | from sklearn.decomposition import PCA 9 | 10 | 11 | 12 | def hierarchy_pos(G, root, levels=None, width=1., height=1.): 13 | ''' 14 | Encodes the hierarchy for the tree layout in a graph. 
15 | From https://stackoverflow.com/questions/29586520/can-one-get-hierarchical-graphs-from-networkx-with-python-3 16 | If there is a cycle that is reachable from root, then this will recurse infinitely. 17 | G: the graph 18 | root: the root node 19 | levels: a dictionary 20 | key: level number (starting from 0) 21 | value: number of nodes in this level 22 | width: horizontal space allocated for drawing 23 | height: vertical space allocated for drawing''' 24 | TOTAL = "total" 25 | CURRENT = "current" 26 | def make_levels(levels, node=root, currentLevel=0, parent=None): 27 | """Compute the number of nodes for each level 28 | """ 29 | if not currentLevel in levels: 30 | levels[currentLevel] = {TOTAL : 0, CURRENT : 0} 31 | levels[currentLevel][TOTAL] += 1 32 | neighbors = G.neighbors(node) 33 | for neighbor in neighbors: 34 | if not neighbor == parent: 35 | levels = make_levels(levels, neighbor, currentLevel + 1, node) 36 | return levels 37 | 38 | def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0): 39 | dx = 1/levels[currentLevel][TOTAL] 40 | left = dx/2 41 | pos[node] = ((left + dx*levels[currentLevel][CURRENT])*width, vert_loc) 42 | levels[currentLevel][CURRENT] += 1 43 | neighbors = G.neighbors(node) 44 | for neighbor in neighbors: 45 | if not neighbor == parent: 46 | pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc-vert_gap) 47 | return pos 48 | if levels is None: 49 | levels = make_levels({}) 50 | else: 51 | levels = {l:{TOTAL: levels[l], CURRENT:0} for l in levels} 52 | vert_gap = height / (max([l for l in levels])+1) 53 | return make_pos({}) 54 | 55 | 56 | def plot_tree_graph(data): 57 | 58 | # put each entry of a leaf label (d[1]), including the final 'tot' count, on its own line ('\n'-separated) 59 | data = data.copy() 60 | for d in data: 61 | if d[3] == 1: 62 | #d[1] = d[1].replace('tot', '\ntot') 63 | pattern = 
r'(\w+:\s\d+\.\d+|\d+:\s\d+\.\d+|\w+\s\d+|\d+\s\d+|\w+:\s\d+|\d+:\s\d+|\w+:\s\d+\s\w+|\d+:\s\d+\s\w+|\w+\s\d+\s\w+|\d+\s\d+\s\w+|\w+:\s\d+\.\d+\s\w+|\d+:\s\d+\.\d+\s\w+)' 64 | 65 | # Split the string using the regular expression pattern 66 | result = re.findall(pattern, d[1]) 67 | 68 | # Join the resulting list to format it as desired 69 | d[1] = '\n'.join(result) 70 | 71 | # Create a directed graph 72 | G = nx.DiGraph() 73 | 74 | # Add nodes and edges to the graph 75 | for node in data: 76 | node_id, label, parent_id, node_type = node 77 | G.add_node(node_id, label=label, node_type=node_type) 78 | if parent_id is not None: 79 | G.add_edge(parent_id, node_id) 80 | 81 | # Get positions of graph nodes 82 | pos = hierarchy_pos(G, 0, levels=None, width=1, height=1) 83 | 84 | # get the labels of the nodes 85 | labels = nx.get_node_attributes(G, 'label') 86 | 87 | # Initialize node color and size lists 88 | node_colors = [] 89 | node_sizes = [] 90 | 91 | # Iterate through nodes to set colors and sizes 92 | for node_id, node_data in G.nodes(data=True): 93 | if G.out_degree(node_id) == 0: # Leaf nodes have out-degree 0 94 | node_colors.append('lightgreen') 95 | node_sizes.append(4000) 96 | 97 | else: 98 | node_colors.append('lightblue') 99 | node_sizes.append(1000) 100 | 101 | # Draw the graph with different node properties 102 | tree = plt.figure(figsize=(10, 5)) 103 | nx.draw(G, pos=pos, labels=labels, with_labels=True, node_size=node_sizes, node_color=node_colors, font_size=7) 104 | 105 | plt.show() 106 | 107 | 108 | 109 | def get_node_embeddings(model, x): 110 | assert model.training == False 111 | epsilon = 1e-7 112 | device = x.device 113 | 114 | # compute deterministic bottom up 115 | d = x 116 | encoders = [] 117 | 118 | for i in range(0, len(model.hidden_layers)): 119 | d, _, _ = model.bottom_up[i](d) 120 | # store the bottom-up layers for the top-down computation 121 | encoders.append(d) 122 | 123 | # Create a list to store node information 124 | node_info_list = 
[] 125 | 126 | # Create a list of nodes of the tree that need to be processed 127 | list_nodes = [{'node': model.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}] 128 | 129 | while len(list_nodes) != 0: 130 | # Store info regarding the current node 131 | current_node = list_nodes.pop(0) 132 | node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob'] 133 | z_parent_sample = current_node['z_parent_sample'] 134 | 135 | # Access deterministic bottom-up mu and sigma hat (computed above) 136 | d = encoders[-(1 + depth_level)] 137 | z_mu_q_hat, z_sigma_q_hat = node.dense(d) 138 | 139 | if depth_level == 0: 140 | z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat 141 | else: 142 | # The generative mu and sigma are the output of the top-down network given the sampled parent 143 | _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample) 144 | z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p) 145 | 146 | # Compute sample z using mu_q and sigma_q 147 | z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1) 148 | z_sample = z.rsample() 149 | 150 | # Store information in the list 151 | node_info = {'prob': prob, 'z_sample': z_sample} 152 | node_info_list.append(node_info) 153 | 154 | if node.router is not None: 155 | prob_child_left_q = node.routers_q(d).squeeze() 156 | 157 | # We are not in a leaf, so we have to add the left and right child to the list 158 | prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q) 159 | 160 | node_left, node_right = node.left, node.right 161 | list_nodes.append( 162 | {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample}) 163 | list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right, 164 | 'z_parent_sample': z_sample}) 165 | 166 | elif node.decoder is None and (node.left is not None or node.right is not None): 
167 | # We are in an internal node with pruned leaves and thus only have one child 168 | node_left, node_right = node.left, node.right 169 | child = node_left if node_left is not None else node_right 170 | list_nodes.append( 171 | {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample}) 172 | 173 | return node_info_list 174 | 175 | 176 | 177 | # Create a function to draw scatter plots as nodes 178 | def draw_scatter_node(node_id, node_embeddings, colors, ax, pca = True): 179 | 180 | # if list is empty --> node has been pruned 181 | if node_embeddings[node_id]['z_sample'] == []: 182 | # return empty plot 183 | ax.set_title(f"Node {node_id}") 184 | ax.set_xticks([]) 185 | ax.set_yticks([]) 186 | return 187 | 188 | z_sample = node_embeddings[node_id]['z_sample'] 189 | weights = node_embeddings[node_id]['prob'] 190 | 191 | if pca: 192 | pca_fit = PCA(n_components=2) 193 | z_sample = pca_fit.fit_transform(z_sample) 194 | 195 | 196 | ax.scatter(z_sample[:, 0], z_sample[:, 1], c=colors, cmap='tab10', alpha=weights, s = 0.25) 197 | ax.set_title(f"Node {node_id}") 198 | # no ticks 199 | ax.set_xticks([]) 200 | ax.set_yticks([]) 201 | 202 | 203 | def splits_to_right_and_left(node_id, data): 204 | # Initialize splits to right and left to 0 205 | splits_to_right = 0 206 | splits_to_left = 0 207 | 208 | # root node 209 | 210 | while True: 211 | # root node 212 | if node_id == 0: 213 | return splits_to_left, splits_to_right 214 | 215 | # previous node has same parent 216 | elif data[node_id-1][2] == data[node_id][2]: 217 | splits_to_right += 1 218 | node_id = data[node_id][2] 219 | 220 | else: 221 | splits_to_left += 1 222 | node_id = data[node_id][2] 223 | 224 | 225 | def get_depth(node_id, data): 226 | # Initialize the depth to 0 227 | depth = 0 228 | 229 | # Find the node in the data list 230 | node = next(node for node in data if node[0] == node_id) 231 | 232 | # Recursively calculate the depth 233 | if node[2] is not None: 234 | depth = 1 + 
get_depth(node[2], data) 235 | 236 | return depth 237 | 238 | 239 | # Create the tree graph with scatter plots as nodes 240 | def draw_tree_with_scatter_plots(data, node_embeddings, label_list, pca = True): 241 | 242 | # Create a directed graph 243 | G = nx.DiGraph() 244 | 245 | # Add nodes and edges to the graph 246 | for node in data: 247 | node_id, label, parent_id, node_type = node 248 | G.add_node(node_id, label=label, node_type=node_type) 249 | if parent_id is not None: 250 | G.add_edge(parent_id, node_id) 251 | 252 | # Get positions of graph nodes 253 | pos = hierarchy_pos(G, 0, levels=None, width=1, height=1) 254 | 255 | # get the labels of the nodes 256 | labels = nx.get_node_attributes(G, 'label') 257 | 258 | 259 | fig, ax = plt.subplots(figsize=(20, 10)) 260 | 261 | for node_id, node_data in G.nodes(data=True): 262 | x, y = pos[node_id] 263 | 264 | # Create a subplot for each node, centered on the node 265 | sub_ax = fig.add_axes([x, y+0.9, 0.1, 0.1]) 266 | draw_scatter_node(node_id, node_embeddings, label_list, sub_ax, pca) 267 | 268 | # Draw the lines between above nodes, need to consider the position of the subplots 269 | 270 | # first need a list of edges in the order of the nodes and the positions of the nodes 271 | # Calculate the positions of the connection lines 272 | # offset by -0.05 for each left split and by +0.05 for each right split 273 | 274 | node_positions = {} 275 | 276 | for node in data: 277 | node_id, label, parent_id, node_type = node 278 | x, y = pos[node_id] 279 | depth = get_depth(node_id, data) 280 | splits_to_left, splits_to_right = splits_to_right_and_left(node_id, data) 281 | 282 | # calculate the position of the node 283 | x = x - splits_to_left * 0.05 + splits_to_right * 0.05 + 0.05 284 | y = y + 1.1 - depth * 0.05 285 | 286 | node_positions[node_id] = (x, y) 287 | 288 | # draw the connection lines 289 | if parent_id is not None: 290 | x_parent, y_parent = node_positions[parent_id] 291 | ax.plot([x_parent, x], [y_parent, 
y], color='black', alpha=0.5) 292 | 293 | 294 | # Set the limits of the plot 295 | ax.set_ylim(0, 1) 296 | ax.set_xlim(0, 1) 297 | ax.axis('off') 298 | 299 | plt.show() 300 | 301 | 302 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for training. 3 | """ 4 | import torch 5 | import math 6 | import numpy as np 7 | import gc 8 | import wandb 9 | from tqdm import tqdm 10 | import torch.optim as optim 11 | from torchmetrics import Metric 12 | from sklearn.metrics.cluster import normalized_mutual_info_score 13 | from utils.utils import cluster_acc 14 | from torch.utils.data import TensorDataset 15 | 16 | 17 | def train_one_epoch(train_loader, model, optimizer, metrics_calc, epoch_idx, device, train_small_tree=False, 18 | small_model=None, ind_leaf=None): 19 | """ 20 | Train TreeVAE or SmallTreeVAE model for one epoch. 21 | 22 | Parameters 23 | ---------- 24 | train_loader: DataLoader 25 | The train data loader 26 | model: models.model.TreeVAE 27 | The TreeVAE model 28 | optimizer: optim 29 | The optimizer for training the model 30 | metrics_calc: Metric 31 | The metrics to keep track of while training 32 | epoch_idx: int 33 | The current epoch 34 | device: torch.device 35 | The device on which to train the model 36 | train_small_tree: bool 37 | If set to True, then the subtree (small_model) will be trained (and afterwards attached to model) 38 | small_model: models.model.SmallTreeVAE 39 | The SmallTreeVAE model (which is then attached to a selected leaf of TreeVAE) 40 | ind_leaf: int 41 | The index of the TreeVAE leaf where the small_model will be attached 42 | """ 43 | if train_small_tree: 44 | # if we train the small tree, then the full tree is frozen 45 | model.eval() 46 | small_model.train() 47 | model.return_bottomup[0] = True 48 | model.return_x[0] = True 49 | alpha = small_model.alpha 50 | 
else: 51 | # otherwise we are training the full tree 52 | model.train() 53 | alpha = model.alpha 54 | 55 | metrics_calc.reset() 56 | 57 | for batch_idx, batch in enumerate(tqdm(train_loader)): 58 | inputs, labels = batch 59 | inputs, labels = inputs.to(device), labels.to(device) 60 | # Zero your gradients for every batch 61 | optimizer.zero_grad() 62 | 63 | # Make predictions for this batch 64 | if train_small_tree: 65 | # Gradient-free pass of full tree 66 | with torch.no_grad(): 67 | outputs_full = model(inputs) 68 | x, node_leaves, bottom_up = outputs_full['input'], outputs_full['node_leaves'], outputs_full['bottom_up'] 69 | # Passing through subtree for updating its parameters 70 | outputs = small_model(x, node_leaves[ind_leaf]['z_sample'], node_leaves[ind_leaf]['prob'], bottom_up) 71 | outputs['kl_root'] = torch.tensor(0., device=device) 72 | else: 73 | outputs = model(inputs) 74 | 75 | # Compute the loss and its gradients 76 | rec_loss = outputs['rec_loss'] 77 | kl_losses = outputs['kl_root'] + outputs['kl_decisions'] + outputs['kl_nodes'] 78 | loss_value = rec_loss + alpha * kl_losses + outputs['aug_decisions'] 79 | loss_value.backward() 80 | 81 | # Adjust learning weights 82 | optimizer.step() 83 | 84 | # Store metrics 85 | # Note that y_pred is used for computing nmi. 86 | # During subtree training, the nmi is calculated relative to only the subtree. 
87 | y_pred = outputs['p_c_z'].argmax(dim=-1) 88 | metrics_calc.update(loss_value, outputs['rec_loss'], outputs['kl_decisions'], outputs['kl_nodes'], 89 | outputs['kl_root'], outputs['aug_decisions'], 90 | (1 - torch.mean(y_pred.float()) if outputs['p_c_z'].shape[1] <= 2 else torch.tensor(0., 91 | device=device)), 92 | labels, y_pred) 93 | 94 | if train_small_tree: 95 | model.return_bottomup[0] = False 96 | model.return_x[0] = False 97 | 98 | # Calculate and log metrics 99 | metrics = metrics_calc.compute() 100 | metrics['alpha'] = alpha 101 | wandb.log({'train': metrics}) 102 | prints = f"Epoch {epoch_idx}, Train : " 103 | for key, value in metrics.items(): 104 | prints += f"{key}: {value:.3f} " 105 | print(prints) 106 | metrics_calc.reset() 107 | _ = gc.collect() 108 | return 109 | 110 | 111 | def validate_one_epoch(test_loader, model, metrics_calc, epoch_idx, device, test=False, train_small_tree=False, 112 | small_model=None, ind_leaf=None): 113 | model.eval() 114 | if train_small_tree: 115 | small_model.eval() 116 | model.return_bottomup[0] = True 117 | model.return_x[0] = True 118 | alpha = small_model.alpha 119 | else: 120 | alpha = model.alpha 121 | 122 | metrics_calc.reset() 123 | 124 | with torch.no_grad(): 125 | for batch_idx, batch in enumerate(tqdm(test_loader)): 126 | inputs, labels = batch 127 | inputs, labels = inputs.to(device), labels.to(device) 128 | # Make predictions for this batch 129 | if train_small_tree: 130 | # Pass of full tree 131 | outputs_full = model(inputs) 132 | x, node_leaves, bottom_up = outputs_full['input'], outputs_full['node_leaves'], outputs_full[ 133 | 'bottom_up'] 134 | # Passing through subtree 135 | outputs = small_model(x, node_leaves[ind_leaf]['z_sample'], node_leaves[ind_leaf]['prob'], bottom_up) 136 | outputs['kl_root'] = torch.tensor(0., device=device) 137 | else: 138 | outputs = model(inputs) 139 | 140 | # Compute the loss (no gradients are needed during validation) 141 | rec_loss = outputs['rec_loss'] 142 | kl_losses = outputs['kl_root'] + 
outputs['kl_decisions'] + outputs['kl_nodes'] 143 | loss_value = rec_loss + alpha * kl_losses + outputs['aug_decisions'] 144 | 145 | # Store metrics 146 | y_pred = outputs['p_c_z'].argmax(dim=-1) 147 | metrics_calc.update(loss_value, outputs['rec_loss'], outputs['kl_decisions'], outputs['kl_nodes'], 148 | outputs['kl_root'], 149 | outputs['aug_decisions'], ( 150 | 1 - torch.mean(outputs['p_c_z'].argmax(dim=-1).float()) if outputs['p_c_z'].shape[ 151 | 1] <= 2 else torch.tensor( 152 | 0., device=device)), labels, y_pred) 153 | 154 | if train_small_tree: 155 | model.return_bottomup[0] = False 156 | model.return_x[0] = False 157 | 158 | # Calculate and log metrics 159 | metrics = metrics_calc.compute() 160 | if not test: 161 | wandb.log({'validation': metrics}) 162 | prints = f"Epoch {epoch_idx}, Validation: " 163 | else: 164 | wandb.log({'test': metrics}) 165 | prints = f"Test: " 166 | for key, value in metrics.items(): 167 | prints += f"{key}: {value:.3f} " 168 | print(prints) 169 | metrics_calc.reset() 170 | _ = gc.collect() 171 | return 172 | 173 | 174 | def predict(loader, model, device, *return_flags): 175 | model.eval() 176 | 177 | if 'bottom_up' in return_flags: 178 | model.return_bottomup[0] = True 179 | if 'X_aug' in return_flags: 180 | model.return_x[0] = True 181 | if 'elbo' in return_flags: 182 | model.return_elbo[0] = True 183 | 184 | results = {name: [] for name in return_flags} 185 | # Create a dictionary to map return flags to corresponding functions 186 | return_functions = { 187 | 'node_leaves': lambda: move_to(outputs['node_leaves'], 'cpu'), 188 | 'bottom_up': lambda: move_to(outputs['bottom_up'], 'cpu'), 189 | 'prob_leaves': lambda: move_to(outputs['p_c_z'], 'cpu'), 190 | 'X_aug': lambda: move_to(outputs['input'], 'cpu'), 191 | 'y': lambda: labels, 192 | 'elbo': lambda: move_to(outputs['elbo_samples'], 'cpu'), 193 | 'rec_loss': lambda: move_to(outputs['rec_loss'], 'cpu') 194 | } 195 | 196 | with torch.no_grad(): 197 | for batch_idx, (inputs, 
labels) in enumerate(tqdm(loader)): 198 | inputs = inputs.to(device) 199 | # Make predictions for this batch 200 | outputs = model(inputs) 201 | 202 | for return_flag in return_flags: 203 | results[return_flag].append(return_functions[return_flag]()) 204 | 205 | for return_flag in return_flags: 206 | if return_flag == 'bottom_up': 207 | bottom_up = results[return_flag] 208 | results[return_flag] = [torch.cat([sublist[i] for sublist in bottom_up], dim=0) for i in 209 | range(len(bottom_up[0]))] 210 | elif return_flag == 'node_leaves': 211 | node_leaves_combined = [] 212 | node_leaves = results[return_flag] 213 | for i in range(len(node_leaves[0])): 214 | node_leaves_combined.append(dict()) 215 | for key in node_leaves[0][i].keys(): 216 | node_leaves_combined[i][key] = torch.cat([sublist[i][key] for sublist in node_leaves], dim=0) 217 | results[return_flag] = node_leaves_combined 218 | elif return_flag == 'rec_loss': 219 | results[return_flag] = torch.stack(results[return_flag], dim=0) 220 | else: 221 | results[return_flag] = torch.cat(results[return_flag], dim=0) 222 | 223 | if 'bottom_up' in return_flags: 224 | model.return_bottomup[0] = False 225 | if 'X_aug' in return_flags: 226 | model.return_x[0] = False 227 | if 'elbo' in return_flags: 228 | model.return_elbo[0] = False 229 | 230 | if len(return_flags) == 1: 231 | return list(results.values())[0] 232 | else: 233 | return tuple(results.values()) 234 | 235 | 236 | def move_to(obj, device): 237 | if torch.is_tensor(obj): 238 | return obj.to(device) 239 | elif isinstance(obj, dict): 240 | res = {} 241 | for k, v in obj.items(): 242 | res[k] = move_to(v, device) 243 | return res 244 | elif isinstance(obj, list): 245 | res = [] 246 | for v in obj: 247 | res.append(move_to(v, device)) 248 | return res 249 | elif isinstance(obj, tuple): 250 | res = tuple(tensor.to(device) for tensor in obj) 251 | return res 252 | else: 253 | raise TypeError("Invalid type for move_to") 254 | 255 | 256 | class AnnealKLCallback: 257 | 
def __init__(self, model, decay=0.01, start=0.): 258 | self.decay = decay 259 | self.start = start 260 | self.model = model 261 | self.model.alpha = torch.tensor(min(1, start)) 262 | 263 | def on_epoch_end(self, epoch, logs=None): 264 | value = self.start + (epoch + 1) * self.decay 265 | self.model.alpha = torch.tensor(min(1, value)) 266 | 267 | 268 | class Decay(): 269 | def __init__(self, lr=0.001, drop=0.1, epochs_drop=50): 270 | self.lr = lr 271 | self.drop = drop 272 | self.epochs_drop = epochs_drop 273 | 274 | def learning_rate_scheduler(self, epoch): 275 | initial_lrate = self.lr 276 | drop = self.drop 277 | epochs_drop = self.epochs_drop 278 | lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop)) 279 | return lrate 280 | 281 | 282 | def calc_aug_loss(prob_parent, prob_router, augmentation_methods, emb_contr=[]): 283 | aug_decisions_loss = torch.zeros(1, device=prob_parent.device) 284 | prob_parent = prob_parent.detach() 285 | 286 | # Get router probabilities of X' and X'' 287 | p1, p2 = prob_router[:len(prob_router) // 2], prob_router[len(prob_router) // 2:] 288 | # Perform invariance regularization 289 | for aug_method in augmentation_methods: 290 | # Perform invariance regularization in the decisions 291 | if aug_method == 'InfoNCE': 292 | p1_normed = torch.nn.functional.normalize(torch.stack([p1, 1 - p1], 1), dim=1) 293 | p2_normed = torch.nn.functional.normalize(torch.stack([p2, 1 - p2], 1), dim=1) 294 | pair_sim = torch.exp(torch.sum(p1_normed * p2_normed, dim=1)) 295 | p_normed = torch.cat([p1_normed, p2_normed], dim=0) 296 | matrix_sim = torch.exp(torch.matmul(p_normed, p_normed.t())) 297 | norm_factor = torch.sum(matrix_sim, dim=1) - torch.diag(matrix_sim) 298 | pair_sim = pair_sim.repeat(2) # storing sim for X' and X'' 299 | info_nce_sample = -torch.log(pair_sim / norm_factor) 300 | info_nce = torch.sum(prob_parent * info_nce_sample) / torch.sum(prob_parent) 301 | aug_decisions_loss += info_nce 302 | # Perform invariance 
regularization in the bottom-up embeddings 303 | elif aug_method == 'instancewise_full': 304 | looplen = len(emb_contr) 305 | for i in range(looplen): 306 | temp_instance = 0.5 307 | emb = emb_contr[i] 308 | emb1, emb2 = emb[:len(emb) // 2], emb[len(emb) // 2:] 309 | emb1_normed = torch.nn.functional.normalize(emb1, dim=1) 310 | emb2_normed = torch.nn.functional.normalize(emb2, dim=1) 311 | pair_sim = torch.exp(torch.sum(emb1_normed * emb2_normed, dim=1) / temp_instance) 312 | emb_normed = torch.cat([emb1_normed, emb2_normed], dim=0) 313 | matrix_sim = torch.exp(torch.matmul(emb_normed, emb_normed.t()) / temp_instance) 314 | norm_factor = torch.sum(matrix_sim, dim=1) - torch.diag(matrix_sim) 315 | pair_sim = pair_sim.repeat(2) # storing sim for X' and X'' 316 | info_nce_sample = -torch.log(pair_sim / norm_factor) 317 | info_nce = torch.mean(info_nce_sample) 318 | info_nce = info_nce / looplen 319 | aug_decisions_loss += info_nce 320 | else: 321 | raise NotImplementedError 322 | 323 | return aug_decisions_loss 324 | 325 | 326 | def get_ind_small_tree(node_leaves, n_effective_leaves): 327 | prob = node_leaves['prob'] 328 | ind = np.where(prob >= min(1 / n_effective_leaves, 0.5))[0] # To circumvent problems with n_effective_leaves==1 329 | return ind 330 | 331 | 332 | def compute_leaves(tree): 333 | list_nodes = [{'node': tree, 'depth': 0}] 334 | nodes_leaves = [] 335 | while len(list_nodes) != 0: 336 | current_node = list_nodes.pop(0) 337 | node, depth_level = current_node['node'], current_node['depth'] 338 | if node.router is not None: 339 | node_left, node_right = node.left, node.right 340 | list_nodes.append( 341 | {'node': node_left, 'depth': depth_level + 1}) 342 | list_nodes.append({'node': node_right, 'depth': depth_level + 1}) 343 | elif node.router is None and node.decoder is None: 344 | # We are in an internal node with pruned leaves and thus only have one child 345 | node_left, node_right = node.left, node.right 346 | child = node_left if node_left is not 
None else node_right 347 | list_nodes.append( 348 | {'node': child, 'depth': depth_level + 1}) 349 | else: 350 | nodes_leaves.append(current_node) 351 | return nodes_leaves 352 | 353 | 354 | def compute_growing_leaf(loader, model, node_leaves, max_depth, batch_size, max_leaves, check_max=False): 355 | """ 356 | Compute the leaf of the TreeVAE model that should be further split. 357 | 358 | Parameters 359 | ---------- 360 | loader: DataLoader 361 | The data loader used to compute the leaf 362 | model: models.model.TreeVAE 363 | The TreeVAE model 364 | node_leaves: list 365 | A list of leaf nodes, each one described by a dictionary 366 | {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding} 367 | max_depth: int 368 | The maximum depth of the tree 369 | batch_size: int 370 | The batch size 371 | max_leaves: int 372 | The maximum number of leaves of the tree 373 | check_max: bool 374 | Whether to check that we reached the maximum number of leaves 375 | Returns 376 | ------ 377 | list: List containing: 378 | ind_leaf: index of the selected leaf 379 | leaf: the selected leaf 380 | n_effective_leaves: number of leaves that are not empty 381 | """ 382 | 383 | # count effective number of leaves (non empty leaves) 384 | weights = [node_leaves[i]['prob'] for i in range(len(node_leaves))] 385 | weights_summed = [weights[i].sum() for i in range(len(weights))] 386 | n_effective_leaves = len(np.where(weights_summed / np.sum(weights_summed) >= 0.01)[0]) 387 | print("\nNumber of effective leaves: ", n_effective_leaves) 388 | 389 | # grow until reaching required n_effective_leaves 390 | if n_effective_leaves >= max_leaves: 391 | print('\nReached maximum number of leaves\n') 392 | return None, None, True 393 | 394 | elif check_max: 395 | return None, None, False 396 | 397 | else: 398 | leaves = compute_leaves(model.tree) 399 | n_samples = [] 400 | if loader.dataset.dataset.__class__ is TensorDataset: 401 | y_train = 
loader.dataset.dataset.tensors[1][loader.dataset.indices] 402 | else: 403 | y_train = loader.dataset.dataset.targets[loader.dataset.indices] 404 | # Calculating ground-truth nodes-to-split for logging and model development 405 | # NOTE: labels are used to evaluate leaf metrics, they are not used to select the leaf 406 | for i in range(len(node_leaves)): 407 | depth, node = leaves[i]['depth'], leaves[i]['node'] 408 | if not node.expand: 409 | continue 410 | ind = get_ind_small_tree(node_leaves[i], n_effective_leaves) 411 | y_train_small = y_train[ind] 412 | # printing distribution of ground-truth classes in leaves 413 | print(f"Leaf {i}: ", np.unique(y_train_small, return_counts=True)) 414 | n_samples.append(len(y_train_small)) 415 | 416 | # Highest number of samples indicates splitting 417 | split_values = n_samples 418 | ind_leaves = np.argsort(np.array(split_values)) 419 | ind_leaves = ind_leaves[::-1] 420 | 421 | print("Ranking of leaves to split: ", ind_leaves) 422 | for i in ind_leaves: 423 | if n_samples[i] < batch_size: 424 | wandb.log({'Skipped Split': 1}) 425 | print("We don't split leaves with fewer samples than batch size") 426 | continue 427 | elif leaves[i]['depth'] == max_depth or not leaves[i]['node'].expand: 428 | leaves[i]['node'].expand = False 429 | print('\nReached maximum architecture\n') 430 | print('\n!!ATTENTION!! 
architecture is not deep enough\n') 431 | break 432 | else: 433 | ind_leaf = i 434 | leaf = leaves[ind_leaf] 435 | print(f'\nSplitting leaf {ind_leaf}\n') 436 | return ind_leaf, leaf, n_effective_leaves 437 | 438 | return None, None, n_effective_leaves 439 | 440 | 441 | def compute_pruning_leaf(model, node_leaves_train): 442 | leaves = compute_leaves(model.tree) 443 | n_leaves = len(node_leaves_train) 444 | weights = [node_leaves_train[i]['prob'] for i in range(n_leaves)] 445 | 446 | # Assign each sample to a leaf by argmax(weights) 447 | max_indeces = np.array([np.argmax(col) for col in zip(*weights)]) 448 | 449 | n_samples = [] 450 | for i in range(n_leaves): 451 | print(f"Leaf {i}: ", sum(max_indeces == i), "samples") 452 | n_samples.append(sum(max_indeces == i)) 453 | 454 | # Prune leaves with less than 1% of all samples 455 | ind_leaf = np.argmin(n_samples) 456 | if n_samples[ind_leaf] < 0.01 * sum(n_samples): 457 | leaf = leaves[ind_leaf] 458 | return ind_leaf, leaf 459 | else: 460 | return None, None 461 | 462 | 463 | def get_optimizer(model, configs): 464 | optimizer = optim.Adam(params=model.parameters(), lr=configs['training']['lr'], 465 | weight_decay=configs['training']['weight_decay']) 466 | return optimizer 467 | 468 | 469 | class Custom_Metrics(Metric): 470 | def __init__(self, device): 471 | super().__init__() 472 | self.add_state("loss_value", default=torch.tensor(0., device=device)) 473 | self.add_state("rec_loss", default=torch.tensor(0., device=device)) 474 | self.add_state("kl_root", default=torch.tensor(0., device=device)) 475 | self.add_state("kl_decisions", default=torch.tensor(0., device=device)) 476 | self.add_state("kl_nodes", default=torch.tensor(0., device=device)) 477 | self.add_state("aug_decisions", default=torch.tensor(0., device=device)) 478 | self.add_state("perc_samples", default=torch.tensor(0., device=device)) 479 | self.add_state("y_true", default=[]) 480 | self.add_state("y_pred", default=[]) 481 | self.add_state("n_samples", 
default=torch.tensor(0, dtype=torch.int, device=device)) 482 | 483 | def update(self, loss_value: torch.Tensor, rec_loss: torch.Tensor, kl_decisions: torch.Tensor, 484 | kl_nodes: torch.Tensor, kl_root: torch.Tensor, aug_decisions: torch.Tensor, perc_samples: torch.Tensor, 485 | y_true: torch.Tensor, y_pred: torch.Tensor): 486 | assert y_true.shape == y_pred.shape 487 | 488 | n_samples = y_true.numel() 489 | self.n_samples += n_samples 490 | self.loss_value += loss_value.item() * n_samples 491 | self.rec_loss += rec_loss.item() * n_samples 492 | self.kl_root += kl_root.item() * n_samples 493 | self.kl_decisions += kl_decisions.item() * n_samples 494 | self.kl_nodes += kl_nodes.item() * n_samples 495 | self.aug_decisions += aug_decisions.item() * n_samples 496 | self.perc_samples += perc_samples.item() * n_samples 497 | self.y_true.append(y_true) 498 | self.y_pred.append(y_pred) 499 | 500 | def compute(self): 501 | self.y_true = torch.cat(self.y_true, dim=0) 502 | self.y_pred = torch.cat(self.y_pred, dim=0) 503 | nmi = normalized_mutual_info_score(self.y_true.cpu().numpy(), self.y_pred.cpu().numpy()) 504 | acc = cluster_acc(self.y_true.cpu().numpy(), self.y_pred.cpu().numpy(), return_index=False) 505 | 506 | metrics = dict({'loss_value': self.loss_value / self.n_samples, 'rec_loss': self.rec_loss / self.n_samples, 507 | 'kl_decisions': self.kl_decisions / self.n_samples, 'kl_root': self.kl_root / self.n_samples, 508 | 'kl_nodes': self.kl_nodes / self.n_samples, 509 | 'aug_decisions': self.aug_decisions / self.n_samples, 510 | 'perc_samples': self.perc_samples / self.n_samples, 'nmi': nmi, 'accuracy': acc}) 511 | 512 | return metrics 513 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utility functions. 
3 | """ 4 | 5 | import numpy as np 6 | from scipy.optimize import linear_sum_assignment as linear_assignment 7 | from scipy.special import comb 8 | import torch 9 | import os 10 | import random 11 | from pathlib import Path 12 | import yaml 13 | 14 | 15 | def cluster_acc(y_true, y_pred, return_index=False): 16 | """ 17 | Calculate clustering accuracy. 18 | # Arguments 19 | y: true labels, numpy.array with shape `(n_samples,)` 20 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 21 | # Return 22 | accuracy, in [0,1] 23 | """ 24 | y_true = y_true.astype(np.int64) 25 | assert y_pred.size == y_true.size 26 | D = max(y_pred.astype(int).max(), y_true.astype(int).max()) + 1 27 | w = np.zeros((int(D), (D)), dtype=np.int64) 28 | for i in range(y_pred.size): 29 | w[int(y_pred[i]), int(y_true[i])] += 1 30 | ind = np.array(linear_assignment(w.max() - w)) 31 | if return_index: 32 | assert all(ind[0] == range(len(ind[0]))) # Assert rows don't change order 33 | cluster_acc = sum(w[ind[0], ind[1]]) * 1.0 / y_pred.size 34 | return cluster_acc, ind[1] 35 | else: 36 | return sum([w[ind[0,i], ind[1,i]] for i in range(len(ind[0]))]) * 1.0 / y_pred.size 37 | 38 | 39 | def reset_random_seeds(seed): 40 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | torch.backends.cudnn.deterministic = True 44 | # No determinism as nn.Upsample has no deterministic implementation 45 | #torch.use_deterministic_algorithms(True) 46 | torch.backends.cudnn.benchmark = False 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed(seed) 49 | os.environ['PYTHONHASHSEED'] = str(seed) 50 | 51 | 52 | def merge_yaml_args(configs, args): 53 | arg_dict = args.__dict__ 54 | configs['parser'] = dict() 55 | for key, value in arg_dict.items(): 56 | flag = True 57 | # Replace/Create values in config if they are defined by arg in parser. 
        if arg_dict[key] is not None:
            for key_config in configs.keys():
                # If value of config is dict itself, then search key-value pairs inside this dict for matching the arg
                if isinstance(configs[key_config], dict):
                    for key2, value2 in configs[key_config].items():
                        if key == key2:
                            configs[key_config][key2] = value
                            flag = False
                # If value of config is not a dict, check whether key matches the arg
                else:
                    if key == key_config:
                        configs[key_config] = value
                        flag = False
                # Break out of loop if key got replaced
                if not flag:
                    break
            # If arg does not match any keys of config, define a new key
            # (for-else: this branch runs only when the loop above ended without a break)
            else:
                print("Could not find this key in config, therefore adding it:", key)
                configs['parser'][key] = arg_dict[key]
    return configs


def prepare_config(args, project_dir):
    # Load config
    data_name = args.config_name + '.yml'
    config_path = project_dir / 'configs' / data_name

    with config_path.open(mode='r') as yamlfile:
        configs = yaml.safe_load(yamlfile)

    # Override config if args in parser
    configs = merge_yaml_args(configs, args)
    # Comma-separated strings (e.g. passed on the command line) are parsed into lists
    if isinstance(configs['training']['latent_dim'], str):
        a = configs['training']['latent_dim'].split(",")
        configs['training']['latent_dim'] = [int(i) for i in a]
    if isinstance(configs['training']['mlp_layers'], str):
        a = configs['training']['mlp_layers'].split(",")
        configs['training']['mlp_layers'] = [int(i) for i in a]

    a = configs['training']['augmentation_method'].split(",")
    configs['training']['augmentation_method'] = [str(i) for i in a]

    configs['globals']['results_dir'] = os.path.join(project_dir, 'models/experiments')
    configs['globals']['results_dir'] = Path(configs['globals']['results_dir']).absolute()

    # Prepare for passing x' and x'' through model by setting batch size to an even number
    if (configs['training']['augment'] is True
            and configs['training']['augmentation_method'] != ['simple']
            and configs['training']['batch_size'] % 2 != 0):
        configs['training']['batch_size'] += 1

    return configs


def count_values_in_sequence(seq):
    from collections import defaultdict
    res = defaultdict(lambda: 0)
    for key in seq:
        res[key] += 1
    return dict(res)


def dendrogram_purity(tree_root, ground_truth, ind_samples_of_leaves):
    total_per_label_frequencies = count_values_in_sequence(ground_truth)
    total_per_label_pairs_count = {k: comb(v, 2, True) for k, v in total_per_label_frequencies.items()}
    total_n_of_pairs = sum(total_per_label_pairs_count.values())
    one_div_total_n_of_pairs = 1. / total_n_of_pairs
    purity = 0.

    def calculate_purity(node, level):
        nonlocal purity
        if node.decoder:
            # Match node to leaf samples
            ind_leaf = np.where([node == ind_samples_of_leaves[ind_leaf][0] for ind_leaf in range(len(ind_samples_of_leaves))])[0].item()
            ind_samples_of_leaf = ind_samples_of_leaves[ind_leaf][1]
            node_total_dp_count = len(ind_samples_of_leaf)
            # Count how many samples of the given leaf fall into each ground-truth class
            # (for TreeVAE this uses ground_truth, i.e. the class of each sample, and the leaf each sample falls into)
            node_per_label_frequencies = count_values_in_sequence(
                [ground_truth[idx] for idx in ind_samples_of_leaf])
            # From above, deduce how many same-class pairs fall into this leaf
            node_per_label_pairs_count = {k: comb(v, 2, True) for k, v in node_per_label_frequencies.items()}

        elif node.router is None and node.decoder is None:
            # We are in an internal node with pruned leaves and thus only have one child. Therefore no purity calculation here!
            node_left, node_right = node.left, node.right
            child = node_left if node_left is not None else node_right
            node_per_label_frequencies, node_total_dp_count = calculate_purity(child, level + 1)
            return node_per_label_frequencies, node_total_dp_count

        else:
            # It is an inner splitting node
            left_child_per_label_freq, left_child_total_dp_count = calculate_purity(node.left, level + 1)
            right_child_per_label_freq, right_child_total_dp_count = calculate_purity(node.right, level + 1)
            node_total_dp_count = left_child_total_dp_count + right_child_total_dp_count
            # Count how many samples of given internal node fall into which ground-truth class (=sum of their children's values)
            node_per_label_frequencies = {k: left_child_per_label_freq.get(k, 0) + right_child_per_label_freq.get(k, 0)
                                          for k in set(left_child_per_label_freq) | set(right_child_per_label_freq)}

            # Class-wise, count how many pairs of samples of a class have this node as least common ancestor
            # (=product of their children's values, because these are all possible pairs coming from different sides)
            node_per_label_pairs_count = {k: left_child_per_label_freq.get(k) * right_child_per_label_freq.get(k)
                                          for k in set(left_child_per_label_freq) & set(right_child_per_label_freq)}

        # Given the class-wise number of pairs with given node as least common ancestor node, calculate their purity
        for label, pair_count in node_per_label_pairs_count.items():
            label_freq = node_per_label_frequencies[label]
            label_pairs = node_per_label_pairs_count[label]
            # (1 / n_all_pairs) * purity (= n_samples_of_this_class_in_node / n_samples_in_node)
            # * n_class_pairs_with_this_node_as_least_common_ancestor
            # (the last term represents the sum over pairs with this node being least common ancestor)
            purity += one_div_total_n_of_pairs * label_freq / node_total_dp_count * label_pairs
        return node_per_label_frequencies, node_total_dp_count

    calculate_purity(tree_root, 0)
    return purity


def leaf_purity(tree_root, ground_truth, ind_samples_of_leaves):
    values = []  # purity rate per leaf
    weights = []  # n_samples per leaf

    # For each leaf calculate the maximum over classes for in-leaf purity (i.e. majority class / n_samples_in_leaf)
    def get_leaf_purities(node):
        nonlocal values
        nonlocal weights
        if node.decoder:
            ind_leaf = np.where([node == ind_samples_of_leaves[ind_leaf][0] for ind_leaf in range(len(ind_samples_of_leaves))])[0].item()
            ind_samples_of_leaf = ind_samples_of_leaves[ind_leaf][1]
            node_total_dp_count = len(ind_samples_of_leaf)
            node_per_label_counts = count_values_in_sequence(
                [ground_truth[idx] for idx in ind_samples_of_leaf])
            if node_total_dp_count > 0:
                purity_rate = max(node_per_label_counts.values()) / node_total_dp_count
            else:
                purity_rate = 1.0
            values.append(purity_rate)
            weights.append(node_total_dp_count)
        elif node.router is None and node.decoder is None:
            # We are in an internal node with pruned leaves and thus only have one child.
            node_left, node_right = node.left, node.right
            child = node_left if node_left is not None else node_right
            get_leaf_purities(child)
        else:
            get_leaf_purities(node.left)
            get_leaf_purities(node.right)

    get_leaf_purities(tree_root)
    assert len(values) == len(ind_samples_of_leaves), "Didn't iterate through all leaves"
    # Return the sample-weighted mean leaf purity
    return np.average(values, weights=weights)


def display_image(image):
    assert image.dim() == 3
    if image.size()[0] == 1:
        return torch.clamp(image.squeeze(0), 0, 1)
    elif image.size()[0] == 3:
        return torch.clamp(image.permute(1, 2, 0), 0, 1)
    elif image.size()[-1] == 3:
        return torch.clamp(image, 0, 1)
    else:
        raise NotImplementedError
--------------------------------------------------------------------------------
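The recursion in `dendrogram_purity` can be hard to check by eye, so the following is a minimal brute-force sketch of the quantity it appears to compute: for every pair of samples sharing a class, take the subtree under their least common ancestor and measure the fraction of its samples that belong to that class, then average over all such pairs. The tree representation here (a leaf as a list of sample indices, an inner node as a 2-tuple) and the helper names `leaves_under`, `lca_samples`, and `brute_force_dendrogram_purity` are assumptions made for illustration only; they are not part of the repository, which uses node objects with `left`, `right`, `router`, and `decoder` attributes.

```python
from itertools import combinations


def leaves_under(tree):
    """Return all sample indices below a node (leaf = list, inner node = 2-tuple)."""
    if isinstance(tree, list):
        return list(tree)
    left, right = tree
    return leaves_under(left) + leaves_under(right)


def lca_samples(tree, i, j):
    """Return the samples under the least common ancestor of samples i and j."""
    if isinstance(tree, tuple):
        for child in tree:
            members = leaves_under(child)
            if i in members and j in members:
                # Both samples sit in the same child, so the LCA is deeper.
                return lca_samples(child, i, j)
    # Either a leaf, or the two samples separate at this node.
    return leaves_under(tree)


def brute_force_dendrogram_purity(tree, labels):
    """Average, over same-class pairs, the class purity of the pair's LCA subtree."""
    pairs = [(i, j) for i, j in combinations(range(len(labels)), 2)
             if labels[i] == labels[j]]
    total = 0.0
    for i, j in pairs:
        members = lca_samples(tree, i, j)
        total += sum(labels[m] == labels[i] for m in members) / len(members)
    return total / len(pairs)


# Hypothetical toy tree: the root splits into leaf {0, 1, 2} and an inner node
# whose children are the leaves {3, 4} and {5}.
toy_tree = ([0, 1, 2], ([3, 4], [5]))
toy_labels = [0, 0, 1, 1, 1, 0]
purity = brute_force_dendrogram_purity(toy_tree, toy_labels)
print(purity)
```

On this toy input, counting by hand with the recursive formula (pairs inside a leaf via `comb(v, 2)`, pairs split at an inner node via the product of the children's class counts) gives the same value, 11/18. The brute-force version is quadratic in the number of samples, so it is only suitable as a sanity check on small trees.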