├── README.md ├── __init__.py ├── arguments.py ├── data ├── __init__.py ├── cifar_contrastive_loader.py ├── idx_set_0.csv ├── idx_set_1.csv ├── idx_set_2.csv ├── idx_set_3.csv ├── idx_set_4.csv ├── randomcrops_2000_0.25_1.0.csv └── train_idxes.csv ├── experiment_scripts ├── cifar experiment │ ├── cifar10Hhard_MCInfoNCE.sh │ ├── cifar10Hsoft_ELK.sh │ ├── cifar10Hsoft_HedgedInstance.sh │ ├── cifar10Hsoft_MCInfoNCE.sh │ └── cifar10hard_MCInfoNCE.sh └── controlled experiment │ ├── ELK_ambiguous_seed_4.sh │ ├── ELK_ambiguous_seed_5.sh │ ├── ELK_ambiguous_seed_6.sh │ ├── ELK_ambiguous_seed_7.sh │ ├── ELK_ambiguous_seed_8.sh │ ├── HedgedInstance_ambiguous_seed_4.sh │ ├── HedgedInstance_ambiguous_seed_5.sh │ ├── HedgedInstance_ambiguous_seed_6.sh │ ├── HedgedInstance_ambiguous_seed_7.sh │ ├── HedgedInstance_ambiguous_seed_8.sh │ ├── MCInfoNCE_2D_seed_4.sh │ ├── MCInfoNCE_2D_seed_5.sh │ ├── MCInfoNCE_2D_seed_6.sh │ ├── MCInfoNCE_2D_seed_7.sh │ ├── MCInfoNCE_2D_seed_8.sh │ ├── MCInfoNCE_ambiguous_seed_4.sh │ ├── MCInfoNCE_ambiguous_seed_5.sh │ ├── MCInfoNCE_ambiguous_seed_6.sh │ ├── MCInfoNCE_ambiguous_seed_7.sh │ ├── MCInfoNCE_ambiguous_seed_8.sh │ ├── MCInfoNCE_clear_seed_4.sh │ ├── MCInfoNCE_clear_seed_5.sh │ ├── MCInfoNCE_clear_seed_6.sh │ ├── MCInfoNCE_clear_seed_7.sh │ ├── MCInfoNCE_clear_seed_8.sh │ ├── MCInfoNCE_encoderdim_128_seed_4.sh │ ├── MCInfoNCE_encoderdim_128_seed_5.sh │ ├── MCInfoNCE_encoderdim_128_seed_6.sh │ ├── MCInfoNCE_encoderdim_128_seed_7.sh │ ├── MCInfoNCE_encoderdim_128_seed_8.sh │ ├── MCInfoNCE_encoderdim_16_seed_4.sh │ ├── MCInfoNCE_encoderdim_16_seed_5.sh │ ├── MCInfoNCE_encoderdim_16_seed_6.sh │ ├── MCInfoNCE_encoderdim_16_seed_7.sh │ ├── MCInfoNCE_encoderdim_16_seed_8.sh │ ├── MCInfoNCE_encoderdim_32_seed_4.sh │ ├── MCInfoNCE_encoderdim_32_seed_5.sh │ ├── MCInfoNCE_encoderdim_32_seed_6.sh │ ├── MCInfoNCE_encoderdim_32_seed_7.sh │ ├── MCInfoNCE_encoderdim_32_seed_8.sh │ ├── MCInfoNCE_encoderdim_4_seed_4.sh │ ├── MCInfoNCE_encoderdim_4_seed_5.sh │ ├── MCInfoNCE_encoderdim_4_seed_6.sh │ ├── MCInfoNCE_encoderdim_4_seed_7.sh │ ├── MCInfoNCE_encoderdim_4_seed_8.sh │ ├── MCInfoNCE_encoderdim_64_seed_4.sh │ ├── MCInfoNCE_encoderdim_64_seed_5.sh │ ├── MCInfoNCE_encoderdim_64_seed_6.sh │ ├── MCInfoNCE_encoderdim_64_seed_7.sh │ ├── MCInfoNCE_encoderdim_64_seed_8.sh │ ├── MCInfoNCE_encoderdim_8_seed_4.sh │ ├── MCInfoNCE_encoderdim_8_seed_5.sh │ ├── MCInfoNCE_encoderdim_8_seed_6.sh │ ├── MCInfoNCE_encoderdim_8_seed_7.sh │ ├── MCInfoNCE_encoderdim_8_seed_8.sh │ ├── MCInfoNCE_gaussiandistr_seed_4.sh │ ├── MCInfoNCE_gaussiandistr_seed_5.sh │ ├── MCInfoNCE_gaussiandistr_seed_6.sh │ ├── MCInfoNCE_gaussiandistr_seed_7.sh │ ├── MCInfoNCE_gaussiandistr_seed_8.sh │ ├── MCInfoNCE_highdim_10_seed_4.sh │ ├── MCInfoNCE_highdim_10_seed_5.sh │ ├── MCInfoNCE_highdim_10_seed_6.sh │ ├── MCInfoNCE_highdim_128_seed_4.sh │ ├── MCInfoNCE_highdim_128_seed_5.sh │ ├── MCInfoNCE_highdim_128_seed_6.sh │ ├── MCInfoNCE_highdim_16_seed_4.sh │ ├── MCInfoNCE_highdim_16_seed_5.sh │ ├── MCInfoNCE_highdim_16_seed_6.sh │ ├── MCInfoNCE_highdim_32_seed_4.sh │ ├── MCInfoNCE_highdim_32_seed_5.sh │ ├── MCInfoNCE_highdim_32_seed_6.sh │ ├── MCInfoNCE_highdim_40_seed_4.sh │ ├── MCInfoNCE_highdim_40_seed_5.sh │ ├── MCInfoNCE_highdim_40_seed_6.sh │ ├── MCInfoNCE_highdim_48_seed_4.sh │ ├── MCInfoNCE_highdim_48_seed_5.sh │ ├── MCInfoNCE_highdim_48_seed_6.sh │ ├── MCInfoNCE_highdim_56_seed_4.sh │ ├── MCInfoNCE_highdim_56_seed_5.sh │ ├── MCInfoNCE_highdim_56_seed_6.sh │ ├── MCInfoNCE_highdim_64_seed_4.sh │ ├── 
MCInfoNCE_highdim_64_seed_5.sh │ ├── MCInfoNCE_highdim_64_seed_6.sh │ ├── MCInfoNCE_injective_seed_4.sh │ ├── MCInfoNCE_injective_seed_5.sh │ ├── MCInfoNCE_injective_seed_6.sh │ ├── MCInfoNCE_injective_seed_7.sh │ ├── MCInfoNCE_injective_seed_8.sh │ ├── MCInfoNCE_laplacedistr_seed_4.sh │ ├── MCInfoNCE_laplacedistr_seed_5.sh │ ├── MCInfoNCE_laplacedistr_seed_6.sh │ ├── MCInfoNCE_laplacedistr_seed_7.sh │ ├── MCInfoNCE_laplacedistr_seed_8.sh │ ├── MCInfoNCE_nmcsamples_16_seed_4.sh │ ├── MCInfoNCE_nmcsamples_16_seed_5.sh │ ├── MCInfoNCE_nmcsamples_16_seed_6.sh │ ├── MCInfoNCE_nmcsamples_16_seed_7.sh │ ├── MCInfoNCE_nmcsamples_16_seed_8.sh │ ├── MCInfoNCE_nmcsamples_1_seed_4.sh │ ├── MCInfoNCE_nmcsamples_1_seed_5.sh │ ├── MCInfoNCE_nmcsamples_1_seed_6.sh │ ├── MCInfoNCE_nmcsamples_1_seed_7.sh │ ├── MCInfoNCE_nmcsamples_1_seed_8.sh │ ├── MCInfoNCE_nmcsamples_256_seed_4.sh │ ├── MCInfoNCE_nmcsamples_256_seed_5.sh │ ├── MCInfoNCE_nmcsamples_256_seed_6.sh │ ├── MCInfoNCE_nmcsamples_256_seed_7.sh │ ├── MCInfoNCE_nmcsamples_256_seed_8.sh │ ├── MCInfoNCE_nmcsamples_4_seed_4.sh │ ├── MCInfoNCE_nmcsamples_4_seed_5.sh │ ├── MCInfoNCE_nmcsamples_4_seed_6.sh │ ├── MCInfoNCE_nmcsamples_4_seed_7.sh │ ├── MCInfoNCE_nmcsamples_4_seed_8.sh │ ├── MCInfoNCE_nmcsamples_64_seed_4.sh │ ├── MCInfoNCE_nmcsamples_64_seed_5.sh │ ├── MCInfoNCE_nmcsamples_64_seed_6.sh │ ├── MCInfoNCE_nmcsamples_64_seed_7.sh │ └── MCInfoNCE_nmcsamples_64_seed_8.sh ├── main.py ├── main_cifar.py ├── models ├── __init__.py ├── encoder.py ├── encoder_resnet.py ├── generator.py └── state_dicts │ └── .gitignore ├── results └── .gitignore ├── thumbnail.png └── utils ├── __init__.py ├── approx_vmf_norm_const.R ├── losses.py ├── metrics.py ├── scheduler.py ├── utils.py └── vmf_sampler.py /README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Contrastive Learning Recovers the Correct Aleatoric Uncertainty of Ambiguous Inputs 2 | 3 | Michael Kirchhof, Enkelejda Kasneci, Seong Joon Oh 4 | 5 | --- 6 | 7 | ![Deterministic ](thumbnail.png) 8 | 9 | _Contrastively trained encoders have recently been proven to invert the data-generating process: they encode each input, e.g., an image, into the true latent vector that generated the image (Zimmermann et al., 2021). However, real-world observations often have inherent ambiguities. For instance, images may be blurred or only show a 2D view of a 3D object, so multiple latents could have generated them. This makes the true posterior for the latent vector probabilistic with heteroscedastic uncertainty. In this setup, we extend the common InfoNCE objective and encoders to predict latent distributions instead of points. We prove that these distributions recover the correct posteriors of the data-generating process, including its level of aleatoric uncertainty, up to a rotation of the latent space. In addition to providing calibrated uncertainty estimates, these posteriors allow the computation of credible intervals in image retrieval. They comprise images with the same latent as a given query, subject to its uncertainty._ 10 | 11 | **Link**: https://arxiv.org/abs/2302.02865 12 | 13 | --- 14 | ### Installation 15 | This code was tested on Python 3.8. Use the code below to create a fitting conda environment. 
16 | 17 | ```commandline 18 | conda create --name probcontrlearning python=3.8 19 | conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 20 | conda install tqdm scipy matplotlib argparse 21 | pip install wandb tueplots 22 | ``` 23 | 24 | If you want to do experiments on CIFAR-10H, you need to download the pretrained ResNet18 and the CIFAR-10H labels. [Download the ResNet weights](https://drive.google.com/file/d/17fmN8eQdLpq2jIMQ_X0IXDPXfI9oVWgq/view?usp=sharing), unzip and copy them into `models/state_dicts/resnet18.pt`. Then, [download the CIFAR-10H labels](https://github.com/jcpeterson/cifar-10h/blob/master/data/cifar10h-probs.npy) and copy them into `data/cifar10h-probs.npy`. The CIFAR-10 data itself is downloaded automatically. 25 | 26 | --- 27 | ### Reproducing Paper Results 28 | 29 | The `experiment_scripts` folder contains all shell files to reproduce our results. Here's an example: 30 | ``` 31 | python main.py --loss MCInfoNCE --g_dim_z 10 --g_dim_x 10 --e_dim_z 10 \ 32 | --g_pos_kappa 20 --g_post_kappa_min 16 --g_post_kappa_max 32 \ 33 | --n_phases 1 --n_batches_per_half_phase 50000 --bs 512 \ 34 | --l_n_samples 512 --n_neg 32 --use_wandb False --seed 4 35 | ``` 36 | These flags mean the following (`arguments.py` contains descriptions of all parameters): 37 | 38 | * `main.py` is used to train encoders in the controlled experiments. `main_cifar.py` trains and tests encoders on the CIFAR experiment. 39 | * `--loss` is `MCInfoNCE` by default. We also provide implementations for Expected Likelihood Kernels (`ELK`) and Hedged Instance Embeddings (`HedgedInstance`). 40 | * `--g_dim_z`, `--g_dim_x`, and `--e_dim_z` control the dimensions of the latent space of the generative process, the image space, and the latent space of the encoder, respectively. 41 | * `--g_pos_kappa` controls $\kappa_\text{pos}$, i.e., how close latents have to be in the generative latent space to be considered positive samples to one another. 42 | * `--g_post_kappa_min` and `--g_post_kappa_max` control in which range the true posterior kappas should lie (higher = less ambiguous; use `"Inf"` to remove the uncertainty). 43 | * `--n_phases` controls how many phases we want in our training. Each phase first trains $\hat{\mu}(x)$ for `--n_batches_per_half_phase` batches of size `--bs` and then $\hat{\kappa}(x)$. If you want to train $\hat{\mu}(x)$ and $\hat{\kappa}(x)$ simultaneously, set `--n_phases 0`. 44 | * `--l_n_samples` is the number of MC samples used to calculate the MCInfoNCE loss. 45 | * `--n_neg` is the number of negative contrastive samples per positive pair. Use `--n_neg 0` to use rolled samples from within the batch as negative samples. 46 | * `--use_wandb` is a boolean flag that controls whether to log to wandb or to store results in a local `/results` folder. If you want to use wandb, enter your API key with `--wandb_key`. 47 | * `--seed` is the random seed that controls the initialization of the generative process. We used seeds `1, 2, 3` for development and hyperparameter tuning and `4, 5, 6, 7, 8` for the paper results. 48 | 49 | --- 50 | 51 | ### Applying MCInfoNCE to Your Own Problem 52 | 53 | If you want to obtain probabilistic embeddings for your own contrastive learning problem, you need two things: 54 | 55 | 1) Copy-paste the `MCInfoNCE()` loss from `utils/losses.py` into your project. The most important hyperparameters to tune are `kappa_init` and `n_neg`. We found `16` to be a solid starting value for both.
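To make point 1) a bit more concrete, here is a self-contained toy sketch (not the reference implementation in `utils/losses.py`) of the general recipe: draw Monte-Carlo samples from each predicted embedding distribution and average an InfoNCE term over the draws. Every name in it is illustrative, the exact estimator in `utils/losses.py` may differ, and the Gaussian-perturb-and-renormalize sampler is only a crude stand-in for a proper vMF sampler (see `utils/vmf_sampler.py`). The `mu_*`/`kappa_*` inputs are what your encoder should output, see point 2) below.

```python
# Toy sketch only -- NOT the MCInfoNCE() from utils/losses.py. It illustrates the overall idea:
# sample embeddings from each predicted distribution, score reference/positive/negative pairs,
# and average an InfoNCE term over the Monte-Carlo draws.
import torch
import torch.nn.functional as F


def sample_around(mu, kappa, n_samples):
    # Crude stand-in for vMF sampling: perturb mu with noise scaled by 1/sqrt(kappa), renormalize.
    noise = torch.randn(n_samples, *mu.shape, device=mu.device) / kappa.sqrt()
    return F.normalize(mu.unsqueeze(0) + noise, dim=-1)


def mc_infonce_sketch(mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg,
                      pos_kappa=20.0, n_samples=16):
    # mu_ref/mu_pos: (B, D) unit vectors, mu_neg: (B, N, D); kappa_*: matching positive scalars.
    z_ref = sample_around(mu_ref, kappa_ref.unsqueeze(-1), n_samples)      # (S, B, D)
    z_pos = sample_around(mu_pos, kappa_pos.unsqueeze(-1), n_samples)      # (S, B, D)
    z_neg = sample_around(mu_neg, kappa_neg.unsqueeze(-1), n_samples)      # (S, B, N, D)
    sim_pos = pos_kappa * (z_ref * z_pos).sum(-1, keepdim=True)            # (S, B, 1)
    sim_neg = pos_kappa * torch.einsum("sbd,sbnd->sbn", z_ref, z_neg)      # (S, B, N)
    logits = torch.cat([sim_pos, sim_neg], dim=-1)                         # (S, B, 1 + N)
    # InfoNCE term per draw, averaged over draws and the batch
    return (-F.log_softmax(logits, dim=-1)[..., 0]).mean()


# Toy usage with random "embeddings": batch 8, dim 10, 4 negatives, kappa around 16
B, D, N = 8, 10, 4
mu_ref = F.normalize(torch.randn(B, D), dim=-1); kappa_ref = torch.full((B,), 16.0)
mu_pos = F.normalize(torch.randn(B, D), dim=-1); kappa_pos = torch.full((B,), 16.0)
mu_neg = F.normalize(torch.randn(B, N, D), dim=-1); kappa_neg = torch.full((B, N), 16.0)
print(mc_infonce_sketch(mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg))
```

In a real setup you would plug in the repository's loss and sampler instead of this toy version.
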
56 | 2) Make your encoder output both 57 | 1) mean (your typical penultimate-layer embedding, normalized to an $L_2$ norm of 1) and 58 | 2) kappa (a scalar value indicating the certainty). 59 | 60 | You can use an explicit network to predict kappa, as for example in `models/encoder.py`, but you can also implicitly parameterize it via the norm of your embedding, as in `models/encoder_resnet.py`. The latter has been confirmed to work plug-and-play with ResNet and VGG architectures. We'd be happy to learn whether it also works on yours. 61 | 62 | --- 63 | 64 | ### How to Cite: 65 | ``` 66 | @article{kirchhof2023probabilistic, 67 | author={Kirchhof, Michael and Kasneci, Enkelejda and Oh, Seong Joon}, 68 | title={Probabilistic Contrastive Learning Recovers the Correct Aleatoric Uncertainty of Ambiguous Inputs}, 69 | journal={arXiv preprint arXiv:2302.02865}, 70 | year={2023} 71 | } 72 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkirchhof/Probabilistic_Contrastive_Learning/b0f70c07e2bcf85e8eb13bf1e62fbd521fb6dd7d/__init__.py -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | """ 5 | Thanks to stackoverflow user Maxim 6 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse/43357954#43357954 7 | :param v: A command line argument with values [yes, true, t, y, 1, True, no, false, f, n, 0, False] 8 | :return: Boolean version of the command line argument 9 | """ 10 | 11 | if v.lower() in ('yes', 'true', 't', 'y', '1', 'True'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0', 'False'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | ####################################### 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | 22 | ##### Generator Parameters 23 | parser.add_argument('--g_dim_z', default=10, type=int, help='Dimensionality of latent space.') 24 | parser.add_argument('--g_dim_x', default=10, type=int, help='Dimensionality of x space.') 25 | parser.add_argument('--g_dim_hidden', default=10, type=int, help='Dimensionality of hidden layers.') 26 | parser.add_argument('--g_n_hidden', default=1, type=int, help='Number of hidden layers. (The network will additionally have one dim_x->dim_hidden and one dim_hidden->dim_z layer)') 27 | parser.add_argument('--g_pos_kappa', default=20, type=int, help='kappa of the implicit positive distribution in the z space.') 28 | # For controlled experiment: 29 | parser.add_argument('--g_post_family', default="vmf", type=str, help="How GT posteriors should be distributed (vmf, Gaussian, Laplace). Note: Predicted posteriors are *always* vmf") 30 | parser.add_argument('--g_post_kappa_min', default=16, type=float, help='How concentrated GT posteriors should be at least. (Use float("Inf") for the crisp case)') 31 | parser.add_argument('--g_post_kappa_max', default=32, type=float, help='How concentrated GT posteriors should be at most. (Use float("Inf") for the crisp case)') 32 | parser.add_argument('--g_min_spread', default=0.5, type=float, help='How much area of the sphere the generator needs to span to be accepted (measured by maximum cosine distance between means). 
A value of 1 accepts any generator.') 33 | parser.add_argument('--has_joint_backbone', default=False, type=str2bool, help="Whether the kappa and mu functions should share the same backbone or be independent.") 34 | 35 | ##### Encoder Parameters 36 | parser.add_argument('--e_dim_z', default=10, type=int, help='Dimensionality of latent space.') 37 | parser.add_argument('--e_dim_hidden', default=0, type=int, help='Dimensionality of hidden layers. Use 0 to use the standard Zimmermann setting (10*e_dim_z for first and last, 50*e_dim_z for others)') 38 | parser.add_argument('--e_n_hidden', default=6, type=int, help='Number of hidden layers.') 39 | parser.add_argument('--e_post_kappa_min', default=16, type=float, help='To which range the encoder posterior kappas should be initialized. (Use float("Inf") for the crisp case)') 40 | parser.add_argument('--e_post_kappa_max', default=32, type=float, help='To which range the encoder posterior kappas should be initialized. (Use float("Inf") for the crisp case)') 41 | 42 | ##### Training parameters 43 | parser.add_argument('--train', default=True, type=str2bool, help="Whether to train. If false, uses an already trained checkpoint.") 44 | parser.add_argument('--lr', default=1e-4, type=float, help='Learning Rate for network parameters.') 45 | parser.add_argument('--n_phases', default=1, type=int, help='If 0, train kappa and mu together for 2 * n_batches_per_half_phase batches. Otherwise, do n_phases of split training, each training first mu and then kappa.') 46 | parser.add_argument('--n_batches_per_half_phase', default=50000, type=int, help='Number of training batches per half phase (i.e. per mu and per kappa)') 47 | parser.add_argument('--lr_decrease_after_phase', default=0.5, type=float, help="Factor to multiply lr by after each half phase.") 48 | parser.add_argument('--bs', default=128, type=int, help='Mini-batch size to use.') 49 | parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.') 50 | parser.add_argument('--n_neg', default=16, type=int, help='Number of negative samples per image. If 0, use rolled batch as negative samples.') 51 | parser.add_argument('--oversampling_factor', default=10, type=int, help="How many candidates to generate per wanted rejection sample (higher=faster, but more RAM)") 52 | # For CIFAR experiment 53 | parser.add_argument('--traindata', default="test_softlabels", type=str, help="Which training data to use. train_hardlabels uses the normal CIFAR-10 train set. test_hardlabels a subset of the CIFAR-10 test set. 
test_softlabels a subset of the CIFAR-10H soft label test set.") 54 | parser.add_argument('--pretrained', default=True, type=str2bool, help="Whether to use ResNet-18 weights pretrained on CIFAR-10 for faster convergence.") 55 | 56 | ##### Loss parameters 57 | parser.add_argument('--loss', default="MCInfoNCE", type=str, help="Which loss (MCInfoNCE, ELK, HedgedInstance)") 58 | parser.add_argument('--l_n_samples', default=128, type=int, help='Number of MC samples to calculate the loss.') 59 | parser.add_argument('--l_learnable_params', default=False, type=str2bool, help="Whether loss parameters (e.g., temperature) should be learnable.") 60 | parser.add_argument('--l_hib_a', default=1, type=float, help="The multiplication constant in the HIB loss.") 61 | parser.add_argument('--l_hib_b', default=0, type=float, help="The addition constant in the HIB loss.") 62 | 63 | ##### Test parameters 64 | parser.add_argument('--eval_every', default=500, type=int, help="After how many batches to evaluate during training.") 65 | parser.add_argument('--n_numerical_eval', default=10000, type=int, help="On how many x-samples should posteriors be numerically evaluated.") 66 | parser.add_argument('--eval_std_instead_of_param', default=True, type=str2bool, help="Turn this to true if you want to compare non-vMF generators to vMF encoders") 67 | parser.add_argument('--n_graphical_eval', default=400, type=int, help="On how many x-samples should posteriors be graphically evaluated. (this might take a long time for high numbers) 0 to turn off.") 68 | parser.add_argument('--savefolder', default="test", type=str, help="Where to save the results") 69 | # for CIFAR experiment 70 | parser.add_argument('--test', default=True, type=str2bool, help="Whether to eval on the test set.") 71 | 72 | ##### Weights and Biases parameters 73 | parser.add_argument('--use_wandb', default=False, type=str2bool, help="Turns on or off wandb tracking") 74 | parser.add_argument('--wandb_key', default="ADD YOUR WANDB API KEY HERE", type=str, help="Your Wandb API key") 75 | parser.add_argument('--wandb_project', default="prob_contr_learning", type=str, help="Project name") 76 | 77 | return parser.parse_args() -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkirchhof/Probabilistic_Contrastive_Learning/b0f70c07e2bcf85e8eb13bf1e62fbd521fb6dd7d/data/__init__.py -------------------------------------------------------------------------------- /data/cifar_contrastive_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets 2 | from torchvision import transforms 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torchvision.transforms import functional as functional_transforms 7 | from PIL import Image 8 | import os 9 | import json 10 | 11 | class ContrastiveCifar(): 12 | def __init__(self, mode="train", seed=1, batch_size=64, device=torch.device("cuda:0")): 13 | super().__init__() 14 | self.device = device 15 | 16 | # Load data 17 | self.transform = transforms.Compose( 18 | [transforms.ToTensor(), 19 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2471, 0.2435, 0.2616])]) 20 | self.data = datasets.CIFAR10(root='data/data_CIFAR10_test', train=False, download=True, transform=self.transform) 21 | plabels_path = "data/cifar10h-probs.npy" 22 | if os.path.isfile(plabels_path): 23 | 
self.plabels = torch.from_numpy(np.load(plabels_path)) 24 | else: 25 | raise FileNotFoundError("Could not find CIFAR-10H labels under " + plabels_path + ". Please download them (see README -> Installation).") 26 | 27 | # Limit to train/val/test 28 | # Each idx_set_i.csv contains 2000 image ids as cross-validation splits of the original 10000 idxes. 29 | # We use i = {seed, seed + 1, seed + 2} MOD 5 for train, 30 | # i = seed + 3 MOD 5 for val 31 | # i = seed + 4 MOD 5 for test 32 | if mode == "train": 33 | train_1 = np.loadtxt(f"data/idx_set_{seed % 5}.csv", delimiter=",").astype("int") 34 | train_2 = np.loadtxt(f"data/idx_set_{(seed + 1) % 5}.csv", delimiter=",").astype("int") 35 | train_3 = np.loadtxt(f"data/idx_set_{(seed + 2) % 5}.csv", delimiter=",").astype("int") 36 | idxes = np.concatenate((train_1, train_2, train_3)) 37 | elif mode == "val": 38 | idxes = np.loadtxt(f"data/idx_set_{(seed + 3) % 5}.csv", delimiter=",").astype("int") 39 | elif mode == "test": 40 | idxes = np.loadtxt(f"data/idx_set_{(seed + 4) % 5}.csv", delimiter=",").astype("int") 41 | self.data.data = self.data.data[idxes] 42 | self.plabels = self.plabels[idxes] 43 | self.len = self.plabels.shape[0] 44 | 45 | # Create dataloader 46 | self.data.targets = self.plabels 47 | self.dl = DataLoader(self.data, batch_size=batch_size, shuffle=(mode == "train"), num_workers=2) 48 | 49 | # Create tensorized versions 50 | self.t_data = torch.from_numpy(self.data.data).to(device) 51 | self.plabels = self.plabels.to(device) 52 | 53 | # Prepare negative sampling 54 | self.p_different_class = None 55 | 56 | def get_dataloader(self): 57 | # If we just want to loop over our data, e.g., for validation and test 58 | return self.dl 59 | 60 | def sample_x(self, n=64): 61 | # Return some random images, without labels 62 | ids = torch.multinomial(torch.ones(self.len, device=self.device), n, replacement=False) 63 | x = torch.stack([self.data.__getitem__(i)[0] for i in ids], dim=0).to(self.device) 64 | 65 | return x 66 | 67 | def sample(self, n=64, same_ref=False, n_repeat=1, n_neg=1): 68 | # For generating contrastive samples, similar to models/generator.py 69 | ref_ids = torch.multinomial(torch.ones(self.len, device=self.device), n, replacement=False) 70 | if same_ref: 71 | ref_ids[:,:] = ref_ids[0,:] 72 | if n_repeat > 1: 73 | ref_ids.repeat(n_repeat, 1) 74 | 75 | # generate pos and neg samples 76 | pos_ids = self._sample_pos_by_candidates(ref_ids) 77 | if n_neg > 0: 78 | neg_ids = self._sample_neg(ref_ids, n_neg) 79 | 80 | # cast ids to images 81 | x_ref = torch.stack([self.data.__getitem__(i)[0] for i in ref_ids], dim=0).to(self.device) 82 | x_pos = torch.stack([self.data.__getitem__(i)[0] for i in pos_ids], dim=0).to(self.device).unsqueeze(1) 83 | if n_neg > 0: 84 | x_neg = torch.stack([self.data.__getitem__(i)[0] for i in torch.flatten(neg_ids)], dim=0).to(self.device) 85 | x_neg = torch.reshape(x_neg, [*neg_ids.shape, *x_neg.shape[1:]]) 86 | else: 87 | x_neg = None 88 | 89 | return x_ref, x_pos, x_neg 90 | 91 | def _sample_neg(self, ref_ids, n_neg=1): 92 | if self.p_different_class is None: 93 | # First time, calculate it: 94 | self.p_different_class = 1 - torch.matmul(self.plabels, self.plabels.t()) 95 | 96 | batchsize = ref_ids.shape[0] 97 | 98 | # Generate candidates until each z_ref has a sample 99 | partner_ids = torch.zeros((batchsize, n_neg), device=self.device) 100 | needs_partner = torch.ones((batchsize, n_neg), dtype=torch.uint8, device=self.device) 101 | while torch.any(needs_partner): 102 | # Limit ourselves to those 
samples that need partners (for efficiency) 103 | requires_partner = torch.any(needs_partner, dim=1) 104 | 105 | # Sample whether other samples are neg to the ref 106 | is_ref_and_cand_wanted = torch.bernoulli(self.p_different_class[ref_ids[requires_partner]]) 107 | is_ref_and_cand_wanted = is_ref_and_cand_wanted.type(torch.uint8) 108 | # Choose samples 109 | # in is_ref_and_cand_wanted we might have rows with full 0. This crashes torch.multinomial. 110 | # In case we have no 1, give everything a one and then filter out everything again afterwards 111 | p_select_bigger0 = is_ref_and_cand_wanted.float() + (torch.sum(is_ref_and_cand_wanted, dim=1) == 0).unsqueeze(1) 112 | chosen_idxes = torch.multinomial(p_select_bigger0, n_neg, replacement=False) 113 | 114 | # Choose the actual matches for each ref sample: 115 | for sub_idx, overall_idx in enumerate(requires_partner.nonzero()[:, 0]): 116 | # sub_idx is the index with respect to those that require a partner (the first that requires a partner, the second, ...) 117 | # overall_idx is the general idx of those samples (e.g., 8, 17, 52, ...) 118 | # The chosen_idx will probably contain samples with probability 0, because we forced it to sample n things, 119 | # even if there were less than n possible 1s in the array. 120 | n_matches = torch.sum(is_ref_and_cand_wanted[sub_idx]) 121 | n_needed = torch.sum(needs_partner[overall_idx, :]) 122 | n_new_samples = torch.min(n_matches, n_needed).type(torch.int) 123 | if n_new_samples > 0: 124 | # One trick we can use is that the prob-0 choices are always at the end 125 | chosen_idx = chosen_idxes[sub_idx, :n_new_samples] 126 | partner_ids[overall_idx, n_neg - n_needed:(n_neg - n_needed + n_new_samples)] = chosen_idx 127 | needs_partner[overall_idx, n_neg - n_needed:(n_neg - n_needed + n_new_samples)] = False 128 | 129 | # The dataloader expects int on cpu 130 | partner_ids = partner_ids.cpu().type(torch.uint8) 131 | return partner_ids 132 | 133 | def _sample_pos_by_candidates(self, ref_ids): 134 | batchsize = ref_ids.shape[0] 135 | id_partner = torch.zeros(batchsize, device=self.device).long() 136 | needs_partner = torch.ones(batchsize, dtype=torch.uint8, device=self.device) 137 | while torch.any(needs_partner): 138 | # Draw a class that we assume the reference belongs to 139 | ref_class = torch.multinomial(self.plabels[ref_ids], num_samples=1).squeeze(1) 140 | 141 | # See if we can find positive matches in that class 142 | is_ref_and_cand_pos = torch.bernoulli(self.plabels[:, ref_class].t()) 143 | p_select_bigger0 = is_ref_and_cand_pos + (torch.sum(is_ref_and_cand_pos, dim=1) == 0).unsqueeze(1) 144 | chosen_idxes = torch.multinomial(p_select_bigger0, num_samples=1, replacement=False) 145 | 146 | n_matches = torch.sum(is_ref_and_cand_pos, dim=1) 147 | id_partner[torch.logical_and(needs_partner, n_matches > 0)] = chosen_idxes[torch.logical_and(needs_partner, n_matches > 0), 0] 148 | needs_partner[torch.logical_and(needs_partner, n_matches > 0)] = False 149 | 150 | return id_partner 151 | 152 | def get_data(self): 153 | return self.data 154 | 155 | 156 | class ContrastiveCifarHard(ContrastiveCifar): 157 | def __init__(self, mode="train", seed=1, batch_size=64, device=torch.device("cuda:0")): 158 | super().__init__(mode=mode, seed=seed, batch_size=batch_size, device=device) 159 | 160 | # Make softlabels hard 161 | for i in torch.arange(self.plabels.shape[0]): 162 | hard_labels = torch.zeros(self.plabels.shape[1], device=self.device) 163 | hard_labels[torch.argmax(self.plabels[i,:])] = 1. 
164 | self.plabels[i,:] = hard_labels 165 | 166 | # Create dataloader 167 | self.data.targets = self.plabels.cpu().numpy() 168 | self.dl = DataLoader(self.data, batch_size=batch_size, shuffle=(mode == "train"), num_workers=2) 169 | 170 | 171 | class ContrastiveCifarHardTrain(ContrastiveCifar): 172 | def __init__(self, mode="train", batch_size=64, device=torch.device("cuda:0"), random_augs=False): 173 | super().__init__(mode=mode, batch_size=batch_size, device=device) 174 | 175 | # Load data 176 | if random_augs: 177 | self.transform = transforms.Compose( 178 | [transforms.RandomCrop(32, padding=4), 179 | transforms.RandomHorizontalFlip(), 180 | transforms.ToTensor(), 181 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2471, 0.2435, 0.2616])]) 182 | self.data = datasets.CIFAR10(root='data/data_CIFAR10_test', train=True, download=True, transform=self.transform) 183 | self.plabels = torch.zeros((len(self.data.targets), 10)) 184 | self.plabels.scatter_(dim=1, index=torch.Tensor(self.data.targets).type(torch.long).unsqueeze(1), value=1.) 185 | self.len = self.plabels.shape[0] 186 | 187 | # Create dataloader 188 | self.dl = DataLoader(self.data, batch_size=batch_size, shuffle=(mode == "train"), num_workers=2) 189 | 190 | # Create tensorized versions 191 | self.t_data = torch.from_numpy(self.data.data).to(device) 192 | self.plabels = self.plabels.to(device) 193 | 194 | 195 | def make_lossy_dataloader(dataset, batchsize=64, shuffle=False): 196 | lossy_dataset = LossyCifar(dataset) 197 | return DataLoader(lossy_dataset, batch_size=batchsize, shuffle=shuffle, num_workers=2) 198 | 199 | ######################################################################################################################## 200 | # The following code is modified from https://github.com/tinkoff-ai/probabilistic-embeddings, Apache License 2.0 # 201 | ######################################################################################################################## 202 | 203 | 204 | class DatasetWrapper(torch.utils.data.Dataset): 205 | """Base class for dataset extension.""" 206 | 207 | def __init__(self, dataset): 208 | self._dataset = dataset 209 | 210 | @property 211 | def dataset(self): 212 | """Get base dataset.""" 213 | return self._dataset 214 | 215 | @property 216 | def classification(self): 217 | """Whether dataset is classification or verification.""" 218 | return self.dataset.classification 219 | 220 | @property 221 | def openset(self): 222 | return self.dataset.openset 223 | 224 | @property 225 | def labels(self): 226 | """Get dataset labels array. 227 | Labels are integers in the range [0, N-1]. 228 | """ 229 | return self.dataset.labels 230 | 231 | @property 232 | def has_quality(self): 233 | """Whether dataset assigns quality score to each sample or not.""" 234 | return self.dataset.has_quality 235 | 236 | def __len__(self): 237 | """Get dataset length.""" 238 | return len(self.dataset) 239 | 240 | def __getitem__(self, index): 241 | """Get element of the dataset. 242 | Classification dataset returns tuple (image, label). 243 | Verification dataset returns ((image1, image2), label). 244 | Datasets with quality assigned to each sample return tuples like 245 | (image, label, quality) or ((image1, image2), label, (quality1, quality2)). 
246 | """ 247 | return self.dataset[index] 248 | 249 | 250 | class LossyCifar(DatasetWrapper): 251 | """Add lossy transformations to input data.""" 252 | def __init__(self, dataset): 253 | super().__init__(dataset) 254 | 255 | crop_min = 0.25 256 | crop_max = 1.0 257 | if crop_min > crop_max: 258 | raise AssertionError("Crop min size is greater than max.") 259 | # See if we already stored random crops for this configuration (to make them the same across runs) 260 | filepath = f'./data/randomcrops_{len(dataset)}_{crop_min}_{crop_max}.csv' 261 | if os.path.exists(filepath): 262 | self._center_crop = np.loadtxt(filepath, delimiter=",") 263 | else: 264 | self._center_crop = np.random.random(len(dataset)) * (crop_max - crop_min) + crop_min 265 | np.savetxt(filepath, self._center_crop, delimiter=",") 266 | 267 | @property 268 | def has_quality(self): 269 | """Whether dataset assigns quality score to each sample or not.""" 270 | return True 271 | 272 | def __getitem__(self, index): 273 | """Get element of the dataset. 274 | Classification dataset returns tuple (image, soft_labels, quality). 275 | """ 276 | image, label = self.dataset[index] 277 | 278 | if isinstance(image, Image.Image): 279 | image = np.asarray(image) 280 | 281 | center_crop = self._center_crop[index] 282 | if abs(center_crop - 1) > 1e-6: 283 | if isinstance(image, np.ndarray): 284 | # Image in HWC format. 285 | size = int(round(min(image.shape[0], image.shape[1]) * center_crop)) 286 | y_offset = (image.shape[0] - size) // 2 287 | x_offset = (image.shape[1] - size) // 2 288 | image = image[y_offset:y_offset + size, x_offset:x_offset + size] 289 | elif isinstance(image, torch.Tensor): 290 | # Image in CHW format. 291 | size = int(round(min(image.shape[1], image.shape[2]) * center_crop)) 292 | old_size = [image.shape[1], image.shape[2]] 293 | image = functional_transforms.center_crop(image, size) 294 | image = functional_transforms.resize(image, old_size) 295 | else: 296 | raise ValueError("Expected Numpy or torch Tensor.") 297 | if isinstance(image, np.ndarray): 298 | image = Image.fromarray(image) 299 | quality = center_crop 300 | return image, label, quality 301 | -------------------------------------------------------------------------------- /experiment_scripts/cifar experiment/cifar10Hhard_MCInfoNCE.sh: -------------------------------------------------------------------------------- 1 | python main_cifar.py \ 2 | --e_dim_z 16 \ 3 | --g_pos_kappa 16 \ 4 | --e_post_kappa_min 16 \ 5 | --e_post_kappa_max 16 \ 6 | --l_n_samples 128 \ 7 | --bs 128 \ 8 | --traindata test_hardlabels \ 9 | --n_neg 0 \ 10 | --loss MCInfoNCE \ 11 | --n_phases 0 \ 12 | --seed 1 13 | -------------------------------------------------------------------------------- /experiment_scripts/cifar experiment/cifar10Hsoft_ELK.sh: -------------------------------------------------------------------------------- 1 | python main_cifar.py \ 2 | --e_dim_z 8 \ 3 | --g_pos_kappa 32 \ 4 | --e_post_kappa_min 32 \ 5 | --e_post_kappa_max 32 \ 6 | --l_n_samples 128 \ 7 | --bs 128 \ 8 | --n_neg 1 \ 9 | --loss ELK \ 10 | --n_phases 0 \ 11 | --seed 1 12 | -------------------------------------------------------------------------------- /experiment_scripts/cifar experiment/cifar10Hsoft_HedgedInstance.sh: -------------------------------------------------------------------------------- 1 | python main_cifar.py \ 2 | --e_dim_z 8 \ 3 | --g_pos_kappa 32 \ 4 | --e_post_kappa_min 32 \ 5 | --e_post_kappa_max 32 \ 6 | --l_n_samples 128 \ 7 | --bs 128 \ 8 | --n_neg 0 \ 9 | --loss 
HedgedInstance \ 10 | --n_phases 1 \ 11 | --l_hib_a 2 \ 12 | --l_hib_b 1 \ 13 | --seed 1 14 | -------------------------------------------------------------------------------- /experiment_scripts/cifar experiment/cifar10Hsoft_MCInfoNCE.sh: -------------------------------------------------------------------------------- 1 | python main_cifar.py \ 2 | --e_dim_z 8 \ 3 | --g_pos_kappa 16 \ 4 | --e_post_kappa_min 16 \ 5 | --e_post_kappa_max 16 \ 6 | --l_n_samples 128 \ 7 | --bs 128 \ 8 | --n_neg 32 \ 9 | --loss MCInfoNCE \ 10 | --n_phases 1 \ 11 | --seed 1 12 | -------------------------------------------------------------------------------- /experiment_scripts/cifar experiment/cifar10hard_MCInfoNCE.sh: -------------------------------------------------------------------------------- 1 | python main_cifar.py \ 2 | --e_dim_z 8 \ 3 | --g_pos_kappa 64 \ 4 | --e_post_kappa_min 64 \ 5 | --e_post_kappa_max 64 \ 6 | --l_n_samples 128 \ 7 | --bs 128 \ 8 | --traindata train_hardlabels \ 9 | --n_neg 0 \ 10 | --loss MCInfoNCE \ 11 | --n_phases 1 \ 12 | --seed 1 13 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/ELK_ambiguous_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss ELK \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/ELK_ambiguous_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss ELK \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/ELK_ambiguous_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss ELK \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/ELK_ambiguous_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss ELK \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- 
/experiment_scripts/controlled experiment/ELK_ambiguous_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss ELK \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/HedgedInstance_ambiguous_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 0 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss HedgedInstance \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --l_hib_a 1 \ 17 | --l_hib_b 0 \ 18 | --seed 4 19 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/HedgedInstance_ambiguous_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 0 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss HedgedInstance \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --l_hib_a 1 \ 17 | --l_hib_b 0 \ 18 | --seed 5 19 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/HedgedInstance_ambiguous_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 0 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss HedgedInstance \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --l_hib_a 1 \ 17 | --l_hib_b 0 \ 18 | --seed 6 19 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/HedgedInstance_ambiguous_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 0 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss HedgedInstance \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --l_hib_a 1 \ 17 | --l_hib_b 0 \ 18 | --seed 7 19 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/HedgedInstance_ambiguous_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | 
--l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 0 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss HedgedInstance \ 14 | --l_learnable_params False \ 15 | --n_phases 0 \ 16 | --l_hib_a 1 \ 17 | --l_hib_b 0 \ 18 | --seed 8 19 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_2D_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 2 \ 3 | --g_dim_x 2 \ 4 | --e_dim_z 2 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 4096 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_2D_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 2 \ 3 | --g_dim_x 2 \ 4 | --e_dim_z 2 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 4096 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_2D_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 2 \ 3 | --g_dim_x 2 \ 4 | --e_dim_z 2 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 4096 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_2D_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 2 \ 3 | --g_dim_x 2 \ 4 | --e_dim_z 2 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 4096 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_2D_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 2 \ 3 | --g_dim_x 2 \ 4 | --e_dim_z 2 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 1 \ 11 | --n_batches_per_half_phase 4096 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_ambiguous_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | 
--e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_ambiguous_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_ambiguous_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_ambiguous_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_ambiguous_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_clear_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 64 \ 7 | --g_post_kappa_max 128 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_clear_seed_5.sh: 
-------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 64 \ 7 | --g_post_kappa_max 128 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_clear_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 64 \ 7 | --g_post_kappa_max 128 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_clear_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 64 \ 7 | --g_post_kappa_max 128 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_clear_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 64 \ 7 | --g_post_kappa_max 128 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_128_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_128_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | 
-------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_128_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_128_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_128_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_16_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_16_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_16_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | 
--n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_16_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_16_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_32_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_32_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_32_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_32_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 32 \ 5 | 
--g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_32_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_4_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 4 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_4_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 4 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_4_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 4 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_4_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 4 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_4_seed_8.sh: 
-------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 4 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_64_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_64_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_64_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_64_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_64_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | 
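Note: the `MCInfoNCE_encoderdim_*` scripts differ only in the encoder output dimension `--e_dim_z` and in `--seed`; the generative process is kept at `--g_dim_z 10` / `--g_dim_x 10` throughout. The following is a minimal launcher sketch (not part of the repository), assuming the scripts are run from the repository root in an environment where `python main.py` works:

```bash
# Hypothetical convenience loop over the encoder-dimension ablation scripts in this folder.
for dim in 4 8 16 32 64 128; do
  for seed in 4 5 6 7 8; do
    bash "experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_${dim}_seed_${seed}.sh"
  done
done
```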
-------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_8_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 8 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_8_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 8 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_8_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 8 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_8_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 8 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_encoderdim_8_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 8 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Gaussian \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | 
--n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 4 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Gaussian \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 5 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Gaussian \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 6 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Gaussian \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 7 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Gaussian \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 8 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_10_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_10_seed_5.sh: 
-------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_10_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_128_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 128 \ 3 | --g_dim_x 128 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_128_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 128 \ 3 | --g_dim_x 128 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_128_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 128 \ 3 | --g_dim_x 128 \ 4 | --e_dim_z 128 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_16_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 16 \ 3 | --g_dim_x 16 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | 
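Note: in the `MCInfoNCE_highdim_*` scripts, the generator latent dimension (`--g_dim_z`), the observation dimension (`--g_dim_x`), and the encoder dimension (`--e_dim_z`) are increased together, and only seeds 4–6 are run. A minimal launcher sketch (not part of the repository), assuming execution from the repository root:

```bash
# Hypothetical loop over the dimension-scaling scripts in this folder.
for dim in 10 16 32 40 48 56 64 128; do
  for seed in 4 5 6; do
    bash "experiment_scripts/controlled experiment/MCInfoNCE_highdim_${dim}_seed_${seed}.sh"
  done
done
```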
-------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_16_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 16 \ 3 | --g_dim_x 16 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_16_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 16 \ 3 | --g_dim_x 16 \ 4 | --e_dim_z 16 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_32_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 32 \ 3 | --g_dim_x 32 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_32_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 32 \ 3 | --g_dim_x 32 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_32_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 32 \ 3 | --g_dim_x 32 \ 4 | --e_dim_z 32 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_40_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 40 \ 3 | --g_dim_x 40 \ 4 | --e_dim_z 40 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | 
--use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_40_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 40 \ 3 | --g_dim_x 40 \ 4 | --e_dim_z 40 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_40_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 40 \ 3 | --g_dim_x 40 \ 4 | --e_dim_z 40 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_48_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 48 \ 3 | --g_dim_x 48 \ 4 | --e_dim_z 48 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_48_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 48 \ 3 | --g_dim_x 48 \ 4 | --e_dim_z 48 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_48_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 48 \ 3 | --g_dim_x 48 \ 4 | --e_dim_z 48 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_56_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 56 \ 3 | --g_dim_x 56 \ 4 | --e_dim_z 56 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | 
--g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_56_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 56 \ 3 | --g_dim_x 56 \ 4 | --e_dim_z 56 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_56_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 56 \ 3 | --g_dim_x 56 \ 4 | --e_dim_z 56 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_64_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 64 \ 3 | --g_dim_x 64 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_64_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 64 \ 3 | --g_dim_x 64 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_highdim_64_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 64 \ 3 | --g_dim_x 64 \ 4 | --e_dim_z 64 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 256 \ 10 | --n_neg 8 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py 
\ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min "inf" \ 7 | --g_post_kappa_max "inf" \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min "inf" \ 7 | --g_post_kappa_max "inf" \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min "inf" \ 7 | --g_post_kappa_max "inf" \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min "inf" \ 7 | --g_post_kappa_max "inf" \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min "inf" \ 7 | --g_post_kappa_max "inf" \ 8 | --l_n_samples 512 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Laplace \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 4 18 | 
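Note: the `MCInfoNCE_injective_*` scripts set `--g_post_kappa_min`/`--g_post_kappa_max` to `"inf"`, so the generator posteriors collapse to points and the data-generating process becomes noise-free (injective). The `MCInfoNCE_gaussiandistr_*` and `MCInfoNCE_laplacedistr_*` scripts instead change the posterior family via `--g_post_family`; scripts without this flag presumably fall back to the default family defined in `arguments.py` (a von Mises–Fisher posterior, judging from `utils/vmf_sampler.py`). Hypothetical single-seed invocations, reusing only the flags shown above:

```bash
# Examples of the generator-misspecification runs (assuming execution from the repository root).
bash "experiment_scripts/controlled experiment/MCInfoNCE_gaussiandistr_seed_4.sh"  # Gaussian generator posteriors
bash "experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_4.sh"   # Laplace generator posteriors
bash "experiment_scripts/controlled experiment/MCInfoNCE_injective_seed_4.sh"      # noise-free (injective) generator
```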
-------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Laplace \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 5 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Laplace \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 6 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Laplace \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 7 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_laplacedistr_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --g_post_family Laplace \ 9 | --l_n_samples 512 \ 10 | --bs 512 \ 11 | --n_neg 32 \ 12 | --n_batches_per_half_phase 50000 \ 13 | --use_wandb False \ 14 | --loss MCInfoNCE \ 15 | --l_learnable_params False \ 16 | --n_phases 1 \ 17 | --seed 8 18 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_16_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 16 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_16_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | 
--g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 16 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_16_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 16 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_16_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 16 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_16_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 16 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_1_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 1 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_1_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 1 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_1_seed_6.sh: 
-------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 1 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_1_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 1 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_1_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 1 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_256_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_256_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_256_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | 
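Note: `--l_n_samples` is passed as `n_samples` to the `MCInfoNCE` loss in `main.py`, i.e. it controls how many Monte Carlo samples are drawn from each predicted posterior when estimating the objective. The `MCInfoNCE_nmcsamples_*` scripts sweep this value over 1, 4, 16, 64, and 256 with seeds 4–8, keeping everything else fixed. A minimal launcher sketch (not part of the repository), assuming execution from the repository root:

```bash
# Hypothetical loop over the Monte-Carlo-sample ablation scripts.
for n in 1 4 16 64 256; do
  for seed in 4 5 6 7 8; do
    bash "experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_${n}_seed_${seed}.sh"
  done
done
```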
-------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_256_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_256_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 256 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_4_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 4 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_4_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 4 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_4_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 4 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_4_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 4 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 
50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_4_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 4 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_64_seed_4.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 64 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 4 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_64_seed_5.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 64 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 5 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_64_seed_6.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 64 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 6 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_64_seed_7.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | --g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 64 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 7 17 | -------------------------------------------------------------------------------- /experiment_scripts/controlled experiment/MCInfoNCE_nmcsamples_64_seed_8.sh: -------------------------------------------------------------------------------- 1 | python main.py \ 2 | --g_dim_z 10 \ 3 | --g_dim_x 10 \ 4 | --e_dim_z 10 \ 5 | --g_pos_kappa 20 \ 6 | 
--g_post_kappa_min 16 \ 7 | --g_post_kappa_max 32 \ 8 | --l_n_samples 64 \ 9 | --bs 512 \ 10 | --n_neg 32 \ 11 | --n_batches_per_half_phase 50000 \ 12 | --use_wandb False \ 13 | --loss MCInfoNCE \ 14 | --l_learnable_params False \ 15 | --n_phases 1 \ 16 | --seed 8 17 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from models.generator import Generator 3 | from models.encoder import Encoder 4 | from utils.losses import MCInfoNCE, ELK, HedgedInstance 5 | from tqdm import tqdm 6 | from arguments import get_args 7 | from utils.utils import init_seeds, pairwise_cos_sims, init_wandb 8 | from utils.metrics import numerical_eval, graphical_eval 9 | import torch 10 | import os 11 | import json 12 | from shutil import rmtree 13 | import wandb 14 | from torch.optim.lr_scheduler import StepLR 15 | 16 | def train_loop(args, gen, enc, loss): 17 | with torch.no_grad(): 18 | eval_set = gen._sample_x(args.n_numerical_eval) 19 | 20 | # train 21 | enc.train() 22 | loss.train() 23 | running_loss = 0 24 | # If n_phases == 0, means we want to do one long phase with 2 * args.n_batches_per_half_phase batches 25 | n_total_batches = max(args.n_phases, 1) * 2 * args.n_batches_per_half_phase 26 | for b in tqdm(range(n_total_batches), position=0, leave=True): 27 | # Choose which parameter training phase we are in 28 | if b % args.n_batches_per_half_phase == 0: 29 | if args.n_phases == 0: 30 | # Train parameters jointly 31 | # To avoid restarting the optimizer, do this only in the beginning 32 | if b == 0: 33 | params = list(enc.mu_net.parameters()) + list(enc.kappa_net.parameters()) 34 | if args.l_learnable_params: 35 | params += list(loss.parameters()) 36 | optim = torch.optim.Adam(params, lr=args.lr) 37 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase, gamma=args.lr_decrease_after_phase) 38 | n_neg = args.n_neg 39 | else: 40 | # First train mu then kappa (then mu, then kappa, then mu...) 
41 | lr = args.lr * args.lr_decrease_after_phase**(b / (2 * args.n_batches_per_half_phase)) 42 | if (b / args.n_batches_per_half_phase) % 2 == 0: 43 | # time to train mu: 44 | params = list(enc.mu_net.parameters()) 45 | if args.l_learnable_params: 46 | params += list(loss.parameters()) 47 | optim = torch.optim.Adam(params, lr=lr) 48 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase / 2, gamma=args.lr_decrease_after_phase) 49 | # Use rolled negatives in the same batch when training mu in high dimensions (faster and no empirical difference) 50 | n_neg = args.n_neg if args.g_dim_z == 2 else 0 51 | else: 52 | # time to train kappa 53 | params = list(enc.kappa_net.parameters()) 54 | if args.l_learnable_params: 55 | params += list(loss.parameters()) 56 | optim = torch.optim.Adam(params, lr=lr) 57 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase / 2, gamma=args.lr_decrease_after_phase) 58 | # Use repeated mu when training kappa 59 | n_neg = args.n_neg 60 | 61 | # Train 62 | optim.zero_grad() 63 | # Use the generator to create a batch 64 | x_ref, x_pos, x_neg = gen.sample(n=args.bs, n_neg=n_neg, oversampling_factor=args.oversampling_factor) 65 | mu_ref, kappa_ref = enc(x_ref) 66 | mu_pos, kappa_pos = enc(x_pos) 67 | # If we have n_neg = 0, then do not try to forward it 68 | if x_neg is not None: 69 | mu_neg, kappa_neg = enc(x_neg) 70 | else: 71 | mu_neg = None 72 | kappa_neg = None 73 | # Calculate loss 74 | cur_loss = loss(mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg) 75 | cur_loss.backward() 76 | running_loss += cur_loss.detach().cpu().item() 77 | optim.step() 78 | scheduler.step() 79 | 80 | # Val 81 | if b == 0 or (b+1) % args.eval_every == 0 or b == n_total_batches - 1: 82 | if b == 0: 83 | avg_loss = running_loss 84 | else: 85 | avg_loss = running_loss / args.eval_every 86 | print(f'Loss: {avg_loss}') 87 | enc.eval() 88 | cosdist_mse, cosdist_corr, cosdist_rankcorr, kappa_mse, kappa_corr, kappa_rankcorr = numerical_eval(args, gen, enc, eval_set, eval_std_instead_of_param=args.eval_std_instead_of_param) 89 | if args.n_graphical_eval > 0: 90 | graphical_eval(args, gen, enc, eval_set[:args.n_graphical_eval], print_examples=b == n_total_batches - 1) 91 | plt.suptitle(f'After {b} batches: Loss={avg_loss:.4f}. Cosdist MSE={cosdist_mse:.3f} (cor={cosdist_corr:.3f}). 
Kappa MSE={kappa_mse:.3f} (cor={kappa_corr:.3f}).') 92 | plt.savefig(f'results/{args.savefolder}/embeds_after_{b:06d}_batches.png') 93 | plt.close() 94 | results_dict = {"cosdist_mse": cosdist_mse.detach().cpu().item(), 95 | "cosdist_corr": cosdist_corr, 96 | "cosdist_rankcorr": cosdist_rankcorr, 97 | "kappa_mse": kappa_mse.detach().cpu().item(), 98 | "kappa_corr": kappa_corr, 99 | "kappa_rankcorr":kappa_rankcorr, 100 | "temperature":loss.kappa.detach().cpu().item(), 101 | "loss": avg_loss, 102 | "batches": b} 103 | if args.use_wandb: 104 | wandb.log(results_dict) 105 | with open(f'results/{args.savefolder}/latest_results.json', 'w') as f2: 106 | json.dump(results_dict, f2) 107 | enc.train() 108 | running_loss = 0 109 | 110 | enc.eval() 111 | return enc 112 | 113 | def create_well_conditioned_generator(args): 114 | gen = None 115 | spread = 1 116 | while spread > args.g_min_spread or gen is None: 117 | # Create a new generator candidate 118 | gen = Generator(dim_x=args.g_dim_x, dim_hidden=args.g_dim_hidden, dim_z=args.g_dim_z, n_hidden=args.g_n_hidden, 119 | pos_kappa=args.g_pos_kappa, post_kappa_min=args.g_post_kappa_min, post_kappa_max=args.g_post_kappa_max, 120 | family=args.g_post_family, has_joint_backbone=args.has_joint_backbone) 121 | 122 | # Measure how much space in the latent space it fills 123 | samples = gen._sample_x(1000) 124 | mus, _ = gen(samples) 125 | spread = torch.min(pairwise_cos_sims(mus)) 126 | 127 | return gen 128 | 129 | def get_loss(args): 130 | if args.loss == "MCInfoNCE": 131 | loss = MCInfoNCE(kappa_init=args.g_pos_kappa, n_samples=args.l_n_samples) 132 | elif args.loss == "ELK": 133 | loss = ELK(kappa_init=args.g_pos_kappa) 134 | elif args.loss == "HedgedInstance": 135 | loss = HedgedInstance(kappa_init=args.l_hib_a, n_samples=args.l_n_samples, b_init=args.l_hib_b) 136 | else: 137 | raise NotImplementedError(f"loss {args.loss} is not implemented.") 138 | 139 | return loss 140 | 141 | if __name__=="__main__": 142 | args = get_args() 143 | 144 | ################### SETUP ################### 145 | init_seeds(args.seed) 146 | loss = get_loss(args) 147 | gen = create_well_conditioned_generator(args) 148 | enc = Encoder(dim_x=args.g_dim_x, dim_hidden=args.e_dim_hidden, dim_z=args.e_dim_z, n_hidden=args.e_n_hidden, 149 | post_kappa_min=args.e_post_kappa_min, post_kappa_max=args.e_post_kappa_max, x_samples=gen._sample_x(1000), 150 | has_joint_backbone=args.has_joint_backbone) 151 | 152 | # Clean up the output folder 153 | os.makedirs("results", exist_ok=True) 154 | rmtree(f'results/{args.savefolder}', ignore_errors=True) 155 | os.makedirs(f'results/{args.savefolder}', exist_ok=True) 156 | with open(f'results/{args.savefolder}/parameters.json', 'w') as f: 157 | json.dump(args.__dict__, f, indent=2) 158 | 159 | init_wandb(args) 160 | 161 | ################### TRAIN ################### 162 | train_loop(args, gen, enc, loss) 163 | 164 | print("Fin.") 165 | -------------------------------------------------------------------------------- /main_cifar.py: -------------------------------------------------------------------------------- 1 | from models.encoder_resnet import ResnetProbEncoder 2 | from tqdm import tqdm 3 | from arguments import get_args 4 | from utils.utils import init_seeds, init_wandb 5 | from utils.metrics import eval_cifar 6 | from utils.scheduler import WarmupCosineLR 7 | import torch 8 | import os 9 | import json 10 | from shutil import rmtree 11 | import wandb 12 | from main import get_loss 13 | from data.cifar_contrastive_loader import ContrastiveCifar, 
make_lossy_dataloader, ContrastiveCifarHard, ContrastiveCifarHardTrain 14 | from torch.optim.lr_scheduler import StepLR 15 | 16 | 17 | def train_loop(args, gen, enc, loss, gen_val): 18 | # train 19 | enc.train() 20 | loss.train() 21 | running_loss = 0 22 | best_rcorr_corrupt = -2 23 | best_rcorr_entropy = -2 24 | n_total_batches = max(args.n_phases, 1) * 2 * args.n_batches_per_half_phase 25 | for b in tqdm(range(n_total_batches), position=0, leave=True): 26 | if args.n_phases == 0: 27 | # Train parameters jointly 28 | # To avoid restarting the optimizer, do this only in the beginning 29 | if b == 0: 30 | params = list(enc.parameters()) 31 | if args.l_learnable_params: 32 | params += list(loss.parameters()) 33 | optim = torch.optim.AdamW(params, lr=args.lr) 34 | if args.pretrained: 35 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase, gamma=args.lr_decrease_after_phase) 36 | else: 37 | # Use an lr scheduler better suited for starting from scratch 38 | scheduler = WarmupCosineLR(optim, warmup_epochs = 0.2 * n_total_batches, max_epochs=n_total_batches) 39 | # Use repeated mu when training kappa 40 | n_neg = args.n_neg 41 | training_phase = "joint" 42 | else: 43 | # Choose which parameter training phase we are in 44 | if b % args.n_batches_per_half_phase == 0: 45 | lr = args.lr * args.lr_decrease_after_phase**(b / args.n_batches_per_half_phase) 46 | if (b / args.n_batches_per_half_phase) % 2 == 0: 47 | # time to train mu: 48 | params = list(enc.parameters()) 49 | if args.l_learnable_params: 50 | params += list(loss.parameters()) 51 | optim = torch.optim.AdamW(params, lr=lr) 52 | if args.pretrained: 53 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase, gamma=args.lr_decrease_after_phase) 54 | else: 55 | # Use an lr scheduler better suited for starting from scratch 56 | scheduler = WarmupCosineLR(optim, warmup_epochs = 0.2 * n_total_batches / 2, max_epochs=n_total_batches / 2) 57 | # Use single n when training mu 58 | n_neg = 0 59 | training_phase = "mu" 60 | else: 61 | # time to train kappa 62 | params = list(enc.parameters()) 63 | if args.l_learnable_params: 64 | params += list(loss.parameters()) 65 | optim = torch.optim.AdamW(params, lr=lr) 66 | if args.pretrained: 67 | scheduler = StepLR(optim, step_size=args.n_batches_per_half_phase, gamma=args.lr_decrease_after_phase) 68 | else: 69 | # Use an lr scheduler better suited for starting from scratch 70 | scheduler = WarmupCosineLR(optim, warmup_epochs = 0.2 * n_total_batches / 2, max_epochs=n_total_batches / 2) 71 | # Use repeated mu when training kappa 72 | n_neg = args.n_neg 73 | training_phase = "kappa" 74 | 75 | # Train 76 | optim.zero_grad() 77 | x_ref, x_pos, x_neg = gen.sample(n=args.bs, n_neg=n_neg) 78 | mu_ref, kappa_ref = enc(x_ref) 79 | # Need to do some reshaping since we have two batch size dimensions (batch and n_pos), but enc expects one 80 | pos_bs = x_pos.shape[:2] 81 | x_pos = torch.reshape(x_pos, [pos_bs[0] * pos_bs[1], *x_pos.shape[2:]]) 82 | mu_pos, kappa_pos = enc(x_pos) 83 | mu_pos = torch.reshape(mu_pos, [*pos_bs, *mu_pos.shape[1:]]) 84 | kappa_pos = torch.reshape(kappa_pos, [*pos_bs, *kappa_pos.shape[1:]]) 85 | if n_neg > 0: 86 | # Need to do some reshaping since we have two batch size dimensions (batch and n_neg), but enc expects one 87 | neg_bs = x_neg.shape[:2] 88 | x_neg = torch.reshape(x_neg, [neg_bs[0] * neg_bs[1], *x_neg.shape[2:]]) 89 | mu_neg, kappa_neg = enc(x_neg) 90 | mu_neg = torch.reshape(mu_neg, [*neg_bs, *mu_neg.shape[1:]]) 91 | kappa_neg = torch.reshape(kappa_neg, 
[*neg_bs, *kappa_neg.shape[1:]]) 92 | else: 93 | mu_neg = None 94 | kappa_neg = None 95 | # mu and kappa are not stored as individual networks, but as direction and norm of the same parameter 96 | # Hence, we need to turn off their gradients here instead of in the optimizers 97 | if training_phase == "mu": 98 | kappa_ref = kappa_ref.detach() 99 | kappa_pos = kappa_pos.detach() 100 | if n_neg > 0: 101 | kappa_neg = kappa_neg.detach() 102 | elif training_phase == "kappa": 103 | pass 104 | elif training_phase == "joint": 105 | pass 106 | cur_loss = loss(mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg) 107 | cur_loss.backward() 108 | running_loss += cur_loss.detach().cpu().item() 109 | optim.step() 110 | scheduler.step() 111 | 112 | # Val 113 | if b == 0 or (b+1) % args.eval_every == 0: 114 | if b == 0: 115 | avg_loss = running_loss 116 | else: 117 | avg_loss = running_loss / args.eval_every 118 | print(f'Loss: {avg_loss}') 119 | enc.eval() 120 | r1, mapr, rcorr_entropy, r1_corrupt, mapr_corrupt, rcorr_corrupt = \ 121 | eval_cifar(args, enc, gen_val.get_dataloader(), make_lossy_dataloader(gen_val.data)) 122 | results_dict = {"r1": r1, 123 | "mapr": mapr, 124 | "rcorr_entropy": rcorr_entropy, 125 | "r1_corrupt": r1_corrupt, 126 | "mapr_corrupt": mapr_corrupt, 127 | "rcorr_corrupt": rcorr_corrupt, 128 | "temperature": loss.kappa.detach().cpu().item(), 129 | "loss": avg_loss, 130 | "batches": b} 131 | if args.use_wandb: 132 | wandb.log(results_dict) 133 | with open(f'results/{args.savefolder}/results_val_after_{(b+1):06d}_batches.json', 'w') as f2: 134 | json.dump(results_dict, f2) 135 | 136 | # Save the model if it achieved a new best 137 | if rcorr_corrupt > best_rcorr_corrupt: 138 | best_rcorr_corrupt = rcorr_corrupt 139 | best_rcorr_entropy = rcorr_entropy 140 | torch.save(enc.state_dict(), f"results/{args.savefolder}/encoder_params.pth") 141 | 142 | enc.train() 143 | running_loss = 0 144 | 145 | enc.eval() 146 | print(f"Validation score in best epoch: rcorr_entropy: {best_rcorr_entropy}, rcorr_corrupt: {best_rcorr_corrupt}") 147 | 148 | return enc 149 | 150 | 151 | def get_traindata(args): 152 | if args.traindata == "test_softlabels": 153 | gen = ContrastiveCifar(mode="train", seed=args.seed, batch_size=args.bs) 154 | elif args.traindata == "test_hardlabels": 155 | gen = ContrastiveCifarHard(mode="train", seed=args.seed, batch_size=args.bs) 156 | elif args.traindata == "train_hardlabels": 157 | gen = ContrastiveCifarHardTrain(mode="train", batch_size=args.bs, random_augs=not args.pretrained) 158 | else: 159 | raise NotImplementedError("traindata is not implemented.") 160 | 161 | return gen 162 | 163 | 164 | if __name__ == "__main__": 165 | args = get_args() 166 | 167 | ################### SETUP ################### 168 | init_seeds(args.seed) 169 | loss = get_loss(args) 170 | gen = get_traindata(args) 171 | gen_val = ContrastiveCifar(mode="val", seed=args.seed, batch_size=args.bs) 172 | enc = ResnetProbEncoder(dim_z=args.e_dim_z, post_kappa_min=args.e_post_kappa_min, post_kappa_max=args.e_post_kappa_max, pretrained=args.pretrained) 173 | 174 | # Clean up the output folder 175 | if args.train: 176 | os.makedirs("results", exist_ok=True) 177 | # Delete possible old results if we are (re-)training 178 | rmtree(f'results/{args.savefolder}', ignore_errors=True) 179 | os.makedirs(f'results/{args.savefolder}', exist_ok=True) 180 | with open(f'results/{args.savefolder}/parameters.json', 'w') as f: 181 | json.dump(args.__dict__, f, indent=2) 182 | 183 | init_wandb(args) 184 | 185 | 
################### TRAIN ################### 186 | if args.train: 187 | train_loop(args, gen, enc, loss, gen_val) 188 | 189 | ################### TEST ################### 190 | if args.test: 191 | # Load best model 192 | enc.load_state_dict(torch.load(f"results/{args.savefolder}/encoder_params.pth")) 193 | 194 | gen_test = ContrastiveCifar(mode="test", seed=args.seed, batch_size=args.bs) 195 | r1, mapr, rcorr_entropy, r1_corrupt, mapr_corrupt, rcorr_corrupt = \ 196 | eval_cifar(args, enc, gen_test.get_dataloader(), make_lossy_dataloader(gen_test.data), "testset", False) 197 | results_dict = {"r1": r1, 198 | "mapr": mapr, 199 | "rcorr_entropy": rcorr_entropy, 200 | "r1_corrupt": r1_corrupt, 201 | "mapr_corrupt": mapr_corrupt, 202 | "rcorr_corrupt": rcorr_corrupt} 203 | if args.use_wandb: 204 | wandb.log(results_dict) 205 | with open(f'results/{args.savefolder}/results_testset.json', 'w') as f2: 206 | json.dump(results_dict, f2) 207 | 208 | print("Fin.") 209 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkirchhof/Probabilistic_Contrastive_Learning/b0f70c07e2bcf85e8eb13bf1e62fbd521fb6dd7d/models/__init__.py -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.utils import construct_mlp 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, n_hidden=2, dim_x=10, dim_z=2, dim_hidden=32, 7 | post_kappa_min=20, post_kappa_max=80, x_samples=None, 8 | device=torch.device('cuda:0'), has_joint_backbone=False): 9 | super().__init__() 10 | 11 | # Save parameters 12 | self.device = device 13 | self.post_kappa_min = torch.tensor(post_kappa_min, device=device) 14 | self.post_kappa_max = torch.tensor(post_kappa_max, device=device) 15 | self.dim_x = dim_x 16 | self.dim_z = dim_z 17 | 18 | # Create networks 19 | self.has_joint_backbone = has_joint_backbone 20 | self.mu_net = construct_mlp(n_hidden=n_hidden, dim_x=dim_x, dim_z=dim_z, dim_hidden=dim_hidden) 21 | self.kappa_net = construct_mlp(n_hidden=n_hidden, dim_x=dim_x if not has_joint_backbone else dim_z, dim_z=1, dim_hidden=dim_hidden) 22 | self.mu_net = self.mu_net.to(device) 23 | self.kappa_net = self.kappa_net.to(device) 24 | 25 | # Bring the kappa network to the correct range 26 | self.kappa_upscale = 1. 27 | self.kappa_add = 0. 28 | with torch.no_grad(): 29 | self._rescale_kappa(x_samples) 30 | 31 | # Turn on gradients 32 | for p in self.mu_net.parameters(): 33 | p.requires_grad = True 34 | for p in self.kappa_net.parameters(): 35 | p.requires_grad = self.kappa_upscale.item() < float("Inf") # if we use infinite kappas, gradients break. So, turn off. 
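        # Illustrative sketch of the kappa rescaling, with assumed toy values (see _rescale_kappa
        # below, which was already applied above): forward() computes
        #     kappa = exp(kappa_upscale * softplus(kappa_net(x)) + kappa_add),
        # and the scale/shift are chosen so the observed softplus range [s_min, s_max] maps onto
        # [post_kappa_min, post_kappa_max] on a log scale. For example, assuming s_min = 0.5,
        # s_max = 2.0 and the default range [20, 80]:
        #     kappa_upscale = (log(80) - log(20)) / (2.0 - 0.5) ≈ 0.924
        #     kappa_add     = log(80) - 0.924 * 2.0             ≈ 2.53
        # so s = 0.5 gives kappa ≈ exp(2.996) ≈ 20 and s = 2.0 gives kappa ≈ exp(4.382) ≈ 80.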
36 | 37 | def forward(self, x): 38 | # Return posterior (z-space) means and kappas for a batch of x 39 | mu = self.mu_net(x) 40 | mu = mu / torch.norm(mu, dim=-1).unsqueeze(-1) 41 | kappa = torch.exp(self.kappa_upscale * torch.log(1 + torch.exp(self.kappa_net(x if not self.has_joint_backbone else mu))) + self.kappa_add) 42 | return mu, kappa 43 | 44 | def _rescale_kappa(self, x_samples=None): 45 | # Goal: Find scale and shift parameters to bring the kappas to the desired range 46 | # indicated by self.post_kappa_min and self.post_kappa_max 47 | if torch.isinf(self.post_kappa_min) or torch.isinf(self.post_kappa_max): 48 | self.kappa_upscale = torch.ones(1, device=self.device) * float("inf") 49 | self.kappa_add = torch.ones(1, device=self.device) * float("inf") 50 | else: 51 | if self.post_kappa_max <= self.post_kappa_min: 52 | raise("post_kappa_max has to be > post_kappa_min.") 53 | if x_samples is None: 54 | raise("Please provide x_samples to the encoder to know which region of x we're dealing with.") 55 | kappa_samples = torch.log(1 + torch.exp(self.kappa_net(x_samples))) 56 | sample_min = torch.min(kappa_samples) 57 | sample_max = torch.max(kappa_samples) 58 | 59 | self.kappa_upscale = (torch.log(self.post_kappa_max) - torch.log(self.post_kappa_min)) / ( 60 | sample_max - sample_min) 61 | self.kappa_add = torch.log(self.post_kappa_max) - self.kappa_upscale * sample_max 62 | -------------------------------------------------------------------------------- /models/encoder_resnet.py: -------------------------------------------------------------------------------- 1 | # Code modified from https://github.com/huyvnphan/PyTorch_CIFAR10 under MIT license 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | from utils.utils import construct_mlp 7 | 8 | __all__ = [ 9 | "ResNet", 10 | "resnet18", 11 | "resnet34", 12 | "resnet50", 13 | ] 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d( 19 | in_planes, 20 | out_planes, 21 | kernel_size=3, 22 | stride=stride, 23 | padding=dilation, 24 | groups=groups, 25 | bias=False, 26 | dilation=dilation, 27 | ) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__( 39 | self, 40 | inplanes, 41 | planes, 42 | stride=1, 43 | downsample=None, 44 | groups=1, 45 | base_width=64, 46 | dilation=1, 47 | norm_layer=None, 48 | ): 49 | super(BasicBlock, self).__init__() 50 | if norm_layer is None: 51 | norm_layer = nn.BatchNorm2d 52 | if groups != 1 or base_width != 64: 53 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 54 | if dilation > 1: 55 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 56 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = norm_layer(planes) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = norm_layer(planes) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += 
identity 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | expansion = 4 86 | 87 | def __init__( 88 | self, 89 | inplanes, 90 | planes, 91 | stride=1, 92 | downsample=None, 93 | groups=1, 94 | base_width=64, 95 | dilation=1, 96 | norm_layer=None, 97 | ): 98 | super(Bottleneck, self).__init__() 99 | if norm_layer is None: 100 | norm_layer = nn.BatchNorm2d 101 | width = int(planes * (base_width / 64.0)) * groups 102 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 103 | self.conv1 = conv1x1(inplanes, width) 104 | self.bn1 = norm_layer(width) 105 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 106 | self.bn2 = norm_layer(width) 107 | self.conv3 = conv1x1(width, planes * self.expansion) 108 | self.bn3 = norm_layer(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | identity = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | identity = self.downsample(x) 129 | 130 | out += identity 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | def __init__( 138 | self, 139 | block, 140 | layers, 141 | num_classes=10, 142 | zero_init_residual=False, 143 | groups=1, 144 | width_per_group=64, 145 | replace_stride_with_dilation=None, 146 | norm_layer=None, 147 | ): 148 | super(ResNet, self).__init__() 149 | if norm_layer is None: 150 | norm_layer = nn.BatchNorm2d 151 | self._norm_layer = norm_layer 152 | 153 | self.inplanes = 64 154 | self.dilation = 1 155 | if replace_stride_with_dilation is None: 156 | # each element in the tuple indicates if we should replace 157 | # the 2x2 stride with a dilated convolution instead 158 | replace_stride_with_dilation = [False, False, False] 159 | if len(replace_stride_with_dilation) != 3: 160 | raise ValueError( 161 | "replace_stride_with_dilation should be None " 162 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 163 | ) 164 | self.groups = groups 165 | self.base_width = width_per_group 166 | 167 | # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 168 | self.conv1 = nn.Conv2d( 169 | 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False 170 | ) 171 | # END 172 | 173 | self.bn1 = norm_layer(self.inplanes) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 176 | self.layer1 = self._make_layer(block, 64, layers[0]) 177 | self.layer2 = self._make_layer( 178 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 179 | ) 180 | self.layer3 = self._make_layer( 181 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 182 | ) 183 | self.layer4 = self._make_layer( 184 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 185 | ) 186 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 187 | self.fc = nn.Linear(512 * block.expansion, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 192 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 193 | nn.init.constant_(m.weight, 1) 194 | nn.init.constant_(m.bias, 0) 195 | 196 | # 
Zero-initialize the last BN in each residual branch, 197 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 198 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 199 | if zero_init_residual: 200 | for m in self.modules(): 201 | if isinstance(m, Bottleneck): 202 | nn.init.constant_(m.bn3.weight, 0) 203 | elif isinstance(m, BasicBlock): 204 | nn.init.constant_(m.bn2.weight, 0) 205 | 206 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append( 221 | block( 222 | self.inplanes, 223 | planes, 224 | stride, 225 | downsample, 226 | self.groups, 227 | self.base_width, 228 | previous_dilation, 229 | norm_layer, 230 | ) 231 | ) 232 | self.inplanes = planes * block.expansion 233 | for _ in range(1, blocks): 234 | layers.append( 235 | block( 236 | self.inplanes, 237 | planes, 238 | groups=self.groups, 239 | base_width=self.base_width, 240 | dilation=self.dilation, 241 | norm_layer=norm_layer, 242 | ) 243 | ) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def forward(self, x): 248 | x = self.conv1(x) 249 | x = self.bn1(x) 250 | x = self.relu(x) 251 | x = self.maxpool(x) 252 | 253 | x = self.layer1(x) 254 | x = self.layer2(x) 255 | x = self.layer3(x) 256 | x = self.layer4(x) 257 | 258 | x = self.avgpool(x) 259 | x = x.reshape(x.size(0), -1) 260 | #x = self.fc(x) # Commented this out because we do not want class logits but embeddings as outputs 261 | 262 | return x 263 | 264 | 265 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 266 | model = ResNet(block, layers, **kwargs) 267 | if pretrained: 268 | script_dir = os.path.dirname(__file__) 269 | file_path = script_dir + "/state_dicts/" + arch + ".pt" 270 | if os.path.isfile(file_path): 271 | state_dict = torch.load(file_path, map_location=device) 272 | model.load_state_dict(state_dict) 273 | else: 274 | raise FileNotFoundError("Could not find pretrained weights under " + file_path + ". Please download them (see README -> Installation).") 275 | return model 276 | 277 | 278 | def resnet18(pretrained=False, progress=True, device="cpu", **kwargs): 279 | """Constructs a ResNet-18 model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | return _resnet( 285 | "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs 286 | ) 287 | 288 | 289 | def resnet34(pretrained=False, progress=True, device="cpu", **kwargs): 290 | """Constructs a ResNet-34 model. 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet( 296 | "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs 297 | ) 298 | 299 | 300 | def resnet50(pretrained=False, progress=True, device="cpu", **kwargs): 301 | """Constructs a ResNet-50 model. 
302 | Args: 303 | pretrained (bool): If True, returns a model pre-trained on ImageNet 304 | progress (bool): If True, displays a progress bar of the download to stderr 305 | """ 306 | return _resnet( 307 | "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs 308 | ) 309 | 310 | 311 | class ResnetProbEncoder(nn.Module): 312 | def __init__(self, dim_z=10, post_kappa_min=20, post_kappa_max=80, device=torch.device("cuda:0"), pretrained=True): 313 | super().__init__() 314 | self.device = device 315 | 316 | # Load a pretrained ResNet18 backend 317 | self.backend = resnet18(pretrained=pretrained) 318 | self.backend = self.backend.to(self.device) 319 | 320 | # Add a layer that casts down the resnet embeddings to dim_z dimensions 321 | self.post_kappa_min = post_kappa_min 322 | self.lin = torch.nn.Linear(in_features=512, out_features=dim_z, bias=False, device=self.device) 323 | nn.init.xavier_normal_(self.lin.weight, gain=post_kappa_min if post_kappa_min < float("inf") else 1) 324 | self.lin = nn.utils.weight_norm(self.lin, dim=0, name="weight") 325 | 326 | def forward(self, x): 327 | # Return posterior (z-space) means and kappas for a batch of x 328 | x = self.backend(x) 329 | embed = self.lin(x) 330 | 331 | mu = embed / torch.norm(embed, dim=-1).unsqueeze(-1) 332 | 333 | if self.post_kappa_min < float("inf"): 334 | kappa = torch.norm(embed, dim=-1).unsqueeze(-1) 335 | else: 336 | kappa = torch.ones(embed.shape[:-1], device=self.device).unsqueeze(-1) * float("inf") 337 | return mu, kappa 338 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.utils import construct_mlp, vmf_norm_ratio 4 | from utils.vmf_sampler import VonMisesFisher 5 | from torch.distributions.normal import Normal 6 | from torch.distributions.laplace import Laplace 7 | from utils.losses import smoothness_loss 8 | 9 | 10 | def smoothen_via_training(gen, print_progress=False): 11 | # Turn on gradients 12 | for p in gen.mu_net.parameters(): 13 | p.requires_grad = True 14 | gen.train() 15 | 16 | # train 17 | optim = torch.optim.Adam(gen.parameters(), lr=0.01) 18 | running_loss = 0 19 | for b in range(5000): 20 | optim.zero_grad() 21 | with torch.no_grad(): 22 | x = gen._sample_x(64) 23 | x.requires_grad = True 24 | mu, _ = gen(x) 25 | loss = 0 26 | loss += smoothness_loss(x, mu) 27 | running_loss += loss.detach() 28 | loss.backward() 29 | optim.step() 30 | 31 | if print_progress and b % 500 == 0: 32 | if b == 0: 33 | avg_loss = running_loss 34 | else: 35 | avg_loss = running_loss / 500 36 | print(f'Loss: {avg_loss}') 37 | running_loss = 0 38 | 39 | # Turn off gradients 40 | gen.eval() 41 | for p in gen.mu_net.parameters(): 42 | p.requires_grad = False 43 | 44 | return gen 45 | 46 | class Generator(nn.Module): 47 | def __init__(self, n_hidden=2, dim_x=10, dim_z=2, dim_hidden=32, pos_kappa=10, n_samples=10, 48 | post_kappa_min=20, post_kappa_max=80, family="vmf", device=torch.device('cuda:0'), 49 | has_joint_backbone=False): 50 | super().__init__() 51 | 52 | # Save parameters 53 | self.device = device 54 | self.post_kappa_min = torch.tensor(post_kappa_min, device=device) 55 | self.post_kappa_max = torch.tensor(post_kappa_max, device=device) 56 | self.dim_x = dim_x 57 | self.dim_z = dim_z 58 | self.pos_kappa = pos_kappa 59 | self.family = family 60 | 61 | # For sampling 62 | self.n_samples = n_samples 63 | self.denom_const = 
None # will be calculated on demand below 64 | 65 | # Create networks 66 | self.has_joint_backbone = has_joint_backbone 67 | self.mu_net = construct_mlp(n_hidden=n_hidden, dim_x=dim_x, dim_z=dim_z, dim_hidden=dim_hidden) 68 | self.kappa_net = construct_mlp(n_hidden=n_hidden - 1, dim_x=dim_x if not has_joint_backbone else dim_z, dim_z=1, dim_hidden=dim_hidden) 69 | self.mu_net = self.mu_net.to(device) 70 | self.kappa_net = self.kappa_net.to(device) 71 | 72 | # Turn off gradients 73 | for p in self.mu_net.parameters(): 74 | p.requires_grad = False 75 | for p in self.kappa_net.parameters(): 76 | p.requires_grad = False 77 | 78 | # Bring the kappa network to the correct range 79 | self.kappa_upscale = 1. 80 | self.kappa_add = 0. 81 | self._rescale_kappa() 82 | 83 | smoothen_via_training(self) 84 | 85 | def forward(self, x): 86 | # Return posterior (z-space) means and kappas for a batch of x 87 | mu = self.mu_net(x) 88 | mu = mu / torch.norm(mu, dim=-1).unsqueeze(-1) 89 | kappa = torch.exp(self.kappa_upscale * torch.log(1 + torch.exp(self.kappa_net(x if not self.has_joint_backbone else mu))) + self.kappa_add) 90 | return mu, kappa 91 | 92 | def _rescale_kappa(self): 93 | # Goal: Find scale and shift parameters to bring the kappas to the desired range 94 | # indicated by self.post_kappa_min and self.post_kappa_max 95 | if torch.isinf(self.post_kappa_min) or torch.isinf(self.post_kappa_max): 96 | self.kappa_upscale = torch.ones(1, device=self.device) * float("inf") 97 | self.kappa_add = torch.ones(1, device=self.device) * float("inf") 98 | else: 99 | if self.post_kappa_max <= self.post_kappa_min: 100 | raise("post_kappa_max has to be > post_kappa_min.") 101 | x_samples = self._sample_x(1000) 102 | kappa_samples = torch.log(1 + torch.exp(self.kappa_net(x_samples))) 103 | sample_min = torch.min(kappa_samples) 104 | sample_max = torch.max(kappa_samples) 105 | 106 | self.kappa_upscale = (torch.log(self.post_kappa_max) - torch.log(self.post_kappa_min)) / ( 107 | sample_max - sample_min) 108 | self.kappa_add = torch.log(self.post_kappa_max) - self.kappa_upscale * sample_max 109 | 110 | def sample(self,n=64, n_neg=1, oversampling_factor=1, same_ref=False): 111 | # Generates (x_ref, x_pos, x_neg) triplets. 112 | # Input: 113 | # n - integer, batchsize (number of x_ref) 114 | # n_neg - integer, number of negatives (0 to return None) 115 | # oversampling_factor - integer, how many candidates to generate to select x_pos and x_neg from. 
116 | # Use a value as high as possible, otherwise need to resample 117 | # same_ref - boolean, whether to use the same x_ref for the whole batch (for debugging) 118 | 119 | # Generate random samples from x-space 120 | x_ref = self._sample_x(n) 121 | if same_ref: 122 | x_ref[:,:] = x_ref[0,:] 123 | z_ref = self._sample_z_from_x(x_ref) 124 | 125 | # generate pos and neg samples to the above samples 126 | x_pos, x_neg, _, _ = self._sample_pos_neg_by_candidates(z_ref, n_neg, oversampling_factor) 127 | 128 | return x_ref, x_pos, x_neg 129 | 130 | def _sample_x(self, n): 131 | return torch.rand((n, self.dim_x), device=self.device) 132 | 133 | def _sample_z_from_x(self, x): 134 | # Takes a batch of x, encodes their posteriors and draws from them 135 | mu, kappa = self.forward(x) 136 | if self.family == "vmf": 137 | z_distrs = VonMisesFisher(mu, kappa) 138 | elif self.family == "Gaussian": 139 | z_distrs = Normal(mu, 1/torch.sqrt(kappa)) 140 | elif self.family == "Laplace": 141 | z_distrs = Laplace(mu, 1/kappa) 142 | z_samples = z_distrs.sample() 143 | z_samples = torch.nn.functional.normalize(z_samples, dim=-1) 144 | return z_samples 145 | 146 | def _sample_pos_neg_by_candidates(self, z_ref, n_neg=1, oversampling_factor=1): 147 | # Sample x-candidates, encode them into z and try to find pos/neg matches to the reference points 148 | # Works if the area that z_pos covers inside the whole z space is relatively high. 149 | # z_ref - [batchsize, x_dim] batch of reference points 150 | # oversampling_factor - integer, how many candidates to generate to select x_pos and x_neg from. 151 | # Use a value as high as possible, otherwise need to resample 152 | 153 | x_pos, z_pos = self._sample_candidates(z_ref, n=1, want_pos=True, oversampling_factor=oversampling_factor) 154 | if n_neg > 0: 155 | x_neg, z_neg = self._sample_candidates(z_ref, n=n_neg, want_pos=False, oversampling_factor=oversampling_factor) 156 | else: 157 | x_neg = None 158 | z_neg = None 159 | 160 | return x_pos, x_neg, z_pos, z_neg 161 | 162 | def _sample_candidates(self, z_ref, n=1, want_pos=True, oversampling_factor=1): 163 | batchsize = z_ref.shape[0] 164 | 165 | # Generate candidates until each z_ref has a sample 166 | x_partner = torch.zeros((batchsize, n, self.dim_x), device=self.device) 167 | z_partner = torch.zeros((batchsize, n, self.dim_z), device=self.device) 168 | needs_partner = torch.ones((batchsize, n), dtype=torch.uint8, device=self.device) 169 | while torch.any(needs_partner): 170 | requires_partner = torch.any(needs_partner, dim=1) 171 | n_require_partner = torch.sum(requires_partner) 172 | x_cand = self._sample_x(n_require_partner * n * oversampling_factor) 173 | z_cand = self._sample_z_from_x(x_cand) 174 | 175 | # sample whether the candidates are pos/neg to the ref 176 | # Each x_ref has its own candidates 177 | cand_per_ref = z_cand.reshape(n_require_partner, n*oversampling_factor, z_cand.shape[-1]) 178 | prob_ref_and_cand_pos = self._pos_prob(z_ref[requires_partner].unsqueeze(1), cand_per_ref) 179 | is_ref_and_cand_pos = torch.bernoulli(prob_ref_and_cand_pos) 180 | is_ref_and_cand_pos = is_ref_and_cand_pos.type(torch.uint8) 181 | is_ref_and_cand_wanted = is_ref_and_cand_pos == want_pos 182 | 183 | # Choose samples 184 | # in is_ref_and_cand_wanted we might have rows with full 0. This crashes torch.multinomial. 
185 | # In case we have no 1, give everything a one and then filter out everything again afterwards 186 | p_select_bigger0 = is_ref_and_cand_wanted.float() + (torch.sum(is_ref_and_cand_wanted, dim=1) == 0).unsqueeze(1) 187 | chosen_idxes = torch.multinomial(p_select_bigger0, n, replacement=False) 188 | # Currently, chosen_idxes indices the columns per row. 189 | # We want to get back to the original indexing of the flattened x_cand and z_cand tensors: 190 | chosen_idxes = chosen_idxes + torch.arange(n_require_partner, device=chosen_idxes.device).unsqueeze(1) * n * oversampling_factor 191 | 192 | if n > 1: 193 | # If we need several samples, we need to fill in the tensor sample by sample, because we might have 194 | # a different amount of valid candidates per sample and cannot tensorize this indexing 195 | for sub_idx, overall_idx in enumerate(requires_partner.nonzero()[:,0]): 196 | # sub_idx is the index with respect to those that require a partner (the first that requires a partner, the second, ...) 197 | # overall_idx is the general idx of those samples (e.g., 8, 17, 52, ...) 198 | # The chosen_idx will probably contain samples with probability 0, because we forced it to sample n things, 199 | # even if there were less than n possible 1s in the array. 200 | n_matches = torch.sum(is_ref_and_cand_wanted[sub_idx]) 201 | n_needed = torch.sum(needs_partner[overall_idx, :]) 202 | n_new_samples = torch.min(n_matches, n_needed).type(torch.int) 203 | if n_new_samples > 0: 204 | # One trick we can use is that the prob-0 choices are always at the end 205 | chosen_idx = chosen_idxes[sub_idx,:n_new_samples] 206 | x_partner[overall_idx, n - n_needed:(n - n_needed + n_new_samples)] = x_cand[chosen_idx, :] 207 | z_partner[overall_idx, n - n_needed:(n - n_needed + n_new_samples)] = z_cand[chosen_idx, :] 208 | needs_partner[overall_idx, n - n_needed:(n - n_needed + n_new_samples)] = False 209 | elif n == 1: 210 | # We can speed up the indexing by tensorizing it 211 | n_matches = torch.sum(is_ref_and_cand_wanted, dim=1) 212 | x_partner[requires_partner.nonzero()[n_matches > 0, 0], 0] = x_cand[chosen_idxes[n_matches > 0, 0], :] 213 | z_partner[requires_partner.nonzero()[n_matches > 0, 0], 0] = z_cand[chosen_idxes[n_matches > 0, 0], :] 214 | needs_partner[requires_partner.nonzero()[n_matches > 0, 0], 0] = False 215 | 216 | return x_partner, z_partner 217 | 218 | def _pos_prob(self, z1, z2): 219 | # Returns P(Y = 1|z_1, z_2) based on the P(z_2|Y=1, z_1) pos-vMF distribution 220 | # and the uniform distribution for negative samples 221 | # Input: 222 | # z_1 - [batchsize_1, z_dim] tensor containing rowwise normalized zs 223 | # z_2 - [batchsize_2, z_dim] tensor containing rowwise normalized zs 224 | # Output: 225 | # [batchsize_1, batchsize_2] tensor containing probabilities P(Y=1) in [0, 1] 226 | 227 | # Calculate these constants here and not in the class init, because not all strategies need them 228 | if self.denom_const is None: 229 | self.denom_const = torch.tensor(vmf_norm_ratio(self.pos_kappa, self.dim_z), device=self.device) 230 | 231 | cos = torch.sum(z1 * z2, dim=-1) 232 | log_pos_dens = self.pos_kappa * cos 233 | log_neg_dens = self.denom_const 234 | 235 | return torch.exp(log_pos_dens - torch.logsumexp(torch.stack((log_pos_dens, log_neg_dens * torch.ones(log_pos_dens.shape, device=self.device)), dim=0), dim=0)) 236 | -------------------------------------------------------------------------------- /models/state_dicts/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkirchhof/Probabilistic_Contrastive_Learning/b0f70c07e2bcf85e8eb13bf1e62fbd521fb6dd7d/thumbnail.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkirchhof/Probabilistic_Contrastive_Learning/b0f70c07e2bcf85e8eb13bf1e62fbd521fb6dd7d/utils/__init__.py -------------------------------------------------------------------------------- /utils/approx_vmf_norm_const.R: -------------------------------------------------------------------------------- 1 | # This script compares the true log(cp(kappa)) normalizing factor to approximations from the literature 2 | 3 | # True Cp(k) 4 | logcp_ratio = function(kappa, p = 512){ 5 | # Calculate for the big kappa 6 | res = (p/2 - 1) * log(kappa) - p/2 * log(2 * pi) - log(besselI(kappa, p / 2 - 1, expon.scaled = TRUE)) - kappa 7 | res[res == Inf | res == -Inf] = NA 8 | # Calculate for kappa = 0 9 | res0 = -log(2) - p/2 * log(pi) + lgamma(p / 2) 10 | res0[res0 == Inf | res0 == -Inf] = NA 11 | return(res0 - res) 12 | } 13 | 14 | # Taylor approximation 15 | # Values generated for kappa: 10:500 16 | p = 128 17 | kappa = 10:500 18 | gt = logcp_ratio(kappa, p) 19 | approx = lm(y ~ x + I(x^1.55), data=data.frame(y=gt, x=kappa)) 20 | summary(approx) 21 | taylor = predict(approx, data.frame(x=kappa)) 22 | plot(kappa[!is.na(taylor)], predict(approx, newdata=data.frame(x=kappa)), xlim=range(kappa), type = "l", las=1, #main=expression("Taylor Approximation of"~log(C[p](kappa))), 23 | xlab="", ylab="", col="red") 24 | points(x = kappa[!is.na(gt)], y = gt[!is.na(gt)], col="black", type="l") 25 | title(sub=paste0(c("p =", p), collapse=" ")) 26 | print(summary(approx)) 27 | 28 | 29 | # Just the normalizing constant (for ELK) 30 | logcp = function(kappa, p=10){ 31 | res = (p/2 - 1) * log(kappa) - p/2 * log(2 * pi) - log(besselI(kappa, p / 2 - 1, expon.scaled = TRUE)) - kappa 32 | res[res == Inf | res == -Inf] = NA 33 | return(res) 34 | } 35 | 36 | p = 10 37 | kappa = 10:50 38 | gt = logcp(kappa, p=p) 39 | approx = lm(y ~ x + I(x^1.1), data=data.frame(y=gt, x=kappa)) 40 | summary(approx) 41 | taylor = predict(approx, data.frame(x=kappa)) 42 | plot(kappa[!is.na(taylor)], predict(approx, newdata=data.frame(x=kappa)), xlim=range(kappa), type = "l", las=1, #main=expression("Taylor Approximation of"~log(C[p](kappa))), 43 | xlab="", ylab="", col="red") 44 | points(x = kappa[!is.na(gt)], y = gt[!is.na(gt)], col="black", type="l") 45 | title(sub=paste0(c("p =", p), collapse=" ")) 46 | print(summary(approx)) -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from utils.vmf_sampler import VonMisesFisher 4 | from 
utils.utils import pairwise_cos_sims, pairwise_l2_dists, log_vmf_norm_const 5 | 6 | 7 | class MCInfoNCE(nn.Module): 8 | def __init__(self, kappa_init=20, n_samples=16, device=torch.device('cuda:0')): 9 | super().__init__() 10 | 11 | self.n_samples = n_samples 12 | self.kappa = torch.nn.Parameter(torch.ones(1, device=device) * kappa_init, requires_grad=True) 13 | 14 | def forward(self, mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg): 15 | # mu_neg and mu_pos is of dimension [batch, n_neg, dim] 16 | # mu_ref is dimension [batch, dim] 17 | mu_ref = mu_ref.unsqueeze(1) 18 | kappa_ref = kappa_ref.unsqueeze(1) 19 | 20 | # Draw samples (new dimension 0 contains the samples) 21 | samples_ref = VonMisesFisher(mu_ref, kappa_ref).rsample(self.n_samples) # [n_MC, batch, n_pos, dim] 22 | samples_pos = VonMisesFisher(mu_pos, kappa_pos).rsample(self.n_samples) 23 | if mu_neg is not None: 24 | samples_neg = VonMisesFisher(mu_neg, kappa_neg).rsample(self.n_samples) 25 | else: 26 | # If we don't get negative samples, treat the next batch sample as negative sample 27 | samples_neg = torch.roll(samples_pos, 1, 1) 28 | 29 | # calculate the standard log contrastive loss for each vmf sample 30 | negs = torch.logsumexp(torch.sum(samples_ref * samples_neg, dim=3) * self.kappa - torch.log(torch.ones(1).cuda() * samples_neg.shape[2]), dim=2) 31 | log_denominator_pos = torch.logsumexp(torch.stack((torch.sum(samples_ref * samples_pos, dim=3).squeeze(2) * self.kappa, negs), dim=0), dim=0) 32 | log_numerator_pos = torch.sum(samples_ref * samples_pos, dim=3) * self.kappa 33 | log_py1_pos = log_numerator_pos - log_denominator_pos.unsqueeze(2) 34 | 35 | # Average over the samples (we actually want a logmeanexp, that's why we substract log(n_samples)) 36 | log_py1_pos = torch.logsumexp(log_py1_pos, dim=0) - torch.log(torch.ones(1, device=self.kappa.device) * self.n_samples) 37 | 38 | # Calculate loss 39 | loss = torch.mean(log_py1_pos) 40 | return -loss 41 | 42 | 43 | class ELK(nn.Module): 44 | def __init__(self, kappa_init=20, device=torch.device('cuda:0')): 45 | super().__init__() 46 | 47 | self.kappa = torch.nn.Parameter(torch.ones(1, device=device) * kappa_init, requires_grad=True) 48 | 49 | def log_ppk_vmf_vec(self, mu1, kappa1, mu2, kappa2): 50 | p = mu1.shape[-1] 51 | 52 | kappa3 = torch.linalg.norm(kappa1 * mu1 + kappa2 * mu2, dim=-1).unsqueeze(-1) 53 | ppk = log_vmf_norm_const(kappa1, p) + log_vmf_norm_const(kappa2, p) - log_vmf_norm_const(kappa3, p) 54 | ppk = ppk * self.kappa 55 | 56 | return ppk.squeeze(-1) 57 | 58 | def forward(self, mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg): 59 | # mu_neg and mu_pos is of dimension [batch, n_neg, dim] 60 | # mu_ref is dimension [batch, dim] 61 | mu_ref = mu_ref.unsqueeze(1) 62 | kappa_ref = kappa_ref.unsqueeze(1) 63 | 64 | # Calculate similarities 65 | sim_pos = self.log_ppk_vmf_vec(mu_ref, kappa_ref, mu_pos, kappa_pos) 66 | if mu_neg is not None: 67 | sim_neg = self.log_ppk_vmf_vec(mu_ref, kappa_ref, mu_neg, kappa_neg) 68 | else: 69 | # If we don't get negative samples, treat the next batch sample as negative sample 70 | sim_neg = torch.roll(sim_pos, 1, 0) 71 | 72 | # Calculate loss 73 | loss = torch.mean(sim_pos, dim=1) - torch.logsumexp(torch.cat((sim_pos, sim_neg), dim=1), dim=1) 74 | loss = -torch.mean(loss) 75 | return loss 76 | 77 | 78 | class HedgedInstance(nn.Module): 79 | def __init__(self, kappa_init=1, b_init=0, n_samples=16, device=torch.device('cuda:0')): 80 | super().__init__() 81 | 82 | self.n_samples = n_samples 83 | self.kappa = 
torch.nn.Parameter(torch.ones(1, device=device) * kappa_init, requires_grad=True) # kappa is "a" in the notation of their paper 84 | self.b = torch.nn.Parameter(torch.ones(1, device=device) * b_init, requires_grad=True) 85 | 86 | def forward(self, mu_ref, kappa_ref, mu_pos, kappa_pos, mu_neg, kappa_neg): 87 | # mu_neg and mu_pos is of dimension [batch, n_neg, dim] 88 | # mu_ref is dimension [batch, dim] 89 | mu_ref = mu_ref.unsqueeze(1) 90 | kappa_ref = kappa_ref.unsqueeze(1) 91 | 92 | # Draw samples (new dimension 0 contains the samples) 93 | samples_ref = VonMisesFisher(mu_ref, kappa_ref).rsample(self.n_samples) # [n_MC, batch, n_pos, dim] 94 | samples_pos = VonMisesFisher(mu_pos, kappa_pos).rsample(self.n_samples) 95 | if mu_neg is not None: 96 | samples_neg = VonMisesFisher(mu_neg, kappa_neg).rsample(self.n_samples) 97 | else: 98 | # If we don't get negative samples, treat the next batch sample as negative sample 99 | samples_neg = torch.roll(samples_pos, 1, 1) 100 | 101 | # calculate the standard log contrastive loss for each vmf sample 102 | py1_pos = torch.sigmoid(self.kappa * torch.sum(samples_ref * samples_pos, dim=-1) + self.b) 103 | py1_neg = torch.sigmoid(self.kappa * torch.sum(samples_ref * samples_neg, dim=-1) + self.b) 104 | 105 | # Average over the samples 106 | log_py1_pos = torch.mean(torch.log(py1_pos), dim=0) 107 | log_py0_neg = torch.mean(torch.log(1 - py1_neg), dim=0) 108 | 109 | # Calculate loss 110 | loss = torch.mean(log_py1_pos) + torch.mean(log_py0_neg) / log_py0_neg.shape[-1] 111 | return -loss 112 | 113 | 114 | def smoothness_loss(x, z): 115 | x_dist = pairwise_l2_dists(x) 116 | z_dist = 1 - pairwise_cos_sims(z)/ 2 117 | 118 | loss = torch.mean((x_dist - z_dist)**2 * (torch.sqrt(torch.ones(1, device=x.device) * 2) - z_dist.detach())**4) 119 | 120 | return loss 121 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from utils.vmf_sampler import VonMisesFisher 5 | from torch.distributions.normal import Normal 6 | from torch.distributions.laplace import Laplace 7 | from utils.utils import pairwise_cos_sims, pairwise_l2_dists 8 | from scipy.stats import spearmanr 9 | from abc import abstractmethod, ABC 10 | from collections import OrderedDict 11 | import faiss 12 | from mpl_toolkits.axes_grid1 import ImageGrid 13 | from tueplots import bundles 14 | bundles.icml2022(family="sans-serif", usetex=False, column="half", nrows=1) 15 | plt.rcParams.update(bundles.icml2022()) 16 | 17 | 18 | def vis_2d_sphere(point_list, ax=None): 19 | # Input: 20 | # point_list: a list of 2D numpy arrays. 
Each array will be plotted with its own color 21 | 22 | if ax is None: 23 | ax = plt.gca() 24 | 25 | for points in point_list: 26 | ax.scatter(points[:, 0], points[:, 1]) 27 | 28 | ax.set_xlim(-1.2, 1.2) 29 | ax.set_ylim(-1.2, 1.2) 30 | ax.set_aspect('equal') 31 | 32 | 33 | def eval_generator_smoothness(gen, ax=None): 34 | 35 | if ax is None: 36 | ax = plt.gca() 37 | 38 | n = 100 39 | x = gen._sample_x(n) 40 | mu, _ = gen(x) 41 | x = x.detach().cpu() 42 | mu = mu.detach().cpu() 43 | 44 | cos_dist = (1 - pairwise_cos_sims(mu)) / 2 45 | l2_dist = pairwise_l2_dists(x) 46 | corr = np.corrcoef(cos_dist.numpy(), l2_dist.numpy())[1, 0] 47 | 48 | ax.scatter(l2_dist, cos_dist) 49 | ax.set_xlabel("l2 Distance in input space") 50 | ax.set_ylabel("cos dist in embedding space") 51 | ax.set_title(f'corr={corr}') 52 | 53 | 54 | def numerical_eval(args, gen, enc, x_eval=None, print_results=True, eval_std_instead_of_param=False): 55 | def mse(a, b): 56 | return torch.mean(torch.sqrt((a - b)**2)) 57 | 58 | with torch.no_grad(): 59 | if x_eval is None: 60 | x_eval = gen._sample_x(args.n_eval) 61 | mu_enc, kappa_enc = enc(x_eval) 62 | mu_gen, kappa_gen = gen(x_eval) 63 | if eval_std_instead_of_param: 64 | if args.g_post_family == "vmf": 65 | gen_samples = VonMisesFisher(mu_gen, kappa_gen).sample(100) 66 | elif args.g_post_family == "Gaussian": 67 | gen_samples = Normal(mu_gen, 1 / torch.sqrt(kappa_gen)).sample([100]) 68 | elif args.g_post_family == "Laplace": 69 | gen_samples = Laplace(mu_gen, 1 / kappa_gen).sample([100]) 70 | enc_samples = VonMisesFisher(mu_enc, kappa_enc).sample(100) 71 | gen_samples = torch.nn.functional.normalize(gen_samples, dim=-1) 72 | kappa_enc = (enc_samples * mu_enc).sum(-1).abs().mean(0) 73 | kappa_gen = (gen_samples * mu_gen).sum(-1).abs().mean(0) 74 | 75 | # Evaluate means. 76 | # We want them to be equal up to a rotation of the sphere. 77 | # Zimmermann do this by regressing the mu_enc and mu_gen, saying they should be r=1 if one is only a rotation 78 | # of the other. Not sure if this holds. Rotation matrices can have imaginary parts. 79 | # We will do it by comparing the pairwise cosine distances between the mu_s 80 | cosdist_enc = pairwise_cos_sims(mu_enc) 81 | cosdist_gen = pairwise_cos_sims(mu_gen) 82 | cosdist_mse = mse(cosdist_gen, cosdist_enc) 83 | cosdist_corr = np.corrcoef(cosdist_gen.cpu().numpy(), cosdist_enc.cpu().numpy())[1, 0] 84 | cosdist_rankcorr = spearmanr(cosdist_gen.cpu().numpy(), cosdist_enc.cpu().numpy())[0] 85 | 86 | # Evaluate kappas 87 | # For kappas, we want exact match. But possibly if we are matching up to a scale, that might indicate something. 88 | kappa_mse = mse(kappa_enc, kappa_gen) 89 | kappa_corr = np.corrcoef(kappa_enc.cpu().flatten().numpy(), kappa_gen.cpu().flatten().numpy())[1, 0] 90 | kappa_rankcorr = spearmanr(kappa_enc.cpu().flatten().numpy(), kappa_gen.cpu().flatten().numpy())[0] 91 | 92 | if print_results: 93 | print(f'Cosdist MSE = {cosdist_mse:.3f}, corr = {cosdist_corr:.3f}, rankcorr = {cosdist_rankcorr:.3f}. 
Kappa MSE = {kappa_mse:.3f}, corr = {kappa_corr:.3f}, rankcorr = {kappa_rankcorr:.3f}.') 94 | return cosdist_mse, cosdist_corr, cosdist_rankcorr, kappa_mse, kappa_corr, kappa_rankcorr 95 | 96 | 97 | def graphical_eval(args, gen, enc, x_eval=None, print_examples=False): 98 | with torch.no_grad(): 99 | if x_eval is None: 100 | x_eval = gen._sample_x(args.n_eval) 101 | mu_enc, kappa_enc = enc(x_eval) 102 | mu_gen, kappa_gen = gen(x_eval) 103 | n_examples = torch.arange(0, x_eval.shape[0] - 1, torch.floor(torch.ones(1) * (x_eval.shape[0] - 1) / 30).type(torch.long).item()) 104 | examples_ids = torch.argsort(kappa_gen.squeeze())[n_examples] 105 | samples_vmf_enc = VonMisesFisher(mu_enc[examples_ids], kappa_enc[examples_ids]).sample(10) 106 | samples_vmf_gen = VonMisesFisher(mu_gen[examples_ids], kappa_gen[examples_ids]).sample(10) 107 | 108 | # The 2D sphere of enc and gen should be the same, up to a rotation 109 | f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9.5, 9.5)) 110 | vis_2d_sphere([samples_vmf_gen[:, i, :].cpu().numpy() for i in range(samples_vmf_gen.shape[1])], ax1) 111 | vis_2d_sphere([samples_vmf_enc[:, i, :].cpu().numpy() for i in range(samples_vmf_enc.shape[1])], ax2) 112 | ax1.set_title("Generator space $\mathcal{Z}$") 113 | ax2.set_title("Encoder space $\mathcal{E}$") 114 | 115 | # Show the correlation between the cosine dists 116 | cosdist_enc = pairwise_cos_sims(mu_enc) 117 | cosdist_gen = pairwise_cos_sims(mu_gen) 118 | ax3.scatter(cosdist_gen.cpu().numpy(), cosdist_enc.cpu().numpy()) 119 | ax3.set_xlabel("Generator pairwise cos-sim") 120 | ax3.set_ylabel("Encoder pairwise cos-sim") 121 | 122 | # Show the correlation between the kappas 123 | ax4.scatter(kappa_gen.cpu().flatten().numpy(), kappa_enc.cpu().flatten().numpy()) 124 | ax4.set_xlabel("Generator $\kappa$") 125 | ax4.set_ylabel("Encoder $\kappa$") 126 | 127 | # Print some exemplary predictions 128 | if print_examples: 129 | print("Exemplary true and predicted mu and kappas:") 130 | print("GT mu(x):") 131 | print(mu_gen[examples_ids]) 132 | print("GT kappa(x):") 133 | print(kappa_gen[examples_ids]) 134 | print("Pred mu(x):") 135 | print(mu_enc[examples_ids]) 136 | print("Pred kappa(x):") 137 | print(kappa_enc[examples_ids]) 138 | 139 | def eval_cifar(args, enc, dataloader, corrupted_dataloader, filename_suffix=None, want_plot=False): 140 | with torch.no_grad(): 141 | # Run through the dataset: 142 | embeddings = list() 143 | confidences = list() 144 | soft_labels = list() 145 | for img, soft_label in dataloader: 146 | embed, conf = enc(img.to(enc.device)) 147 | embeddings.append(embed.detach().cpu()) 148 | confidences.append(conf.detach().cpu().squeeze()) 149 | soft_labels.append(soft_label) 150 | embeddings = torch.cat(embeddings, dim=0) 151 | confidences = torch.cat(confidences, dim=0) 152 | soft_labels = torch.cat(soft_labels, dim=0) 153 | _, hard_labels = torch.max(soft_labels, dim=1) 154 | gt_entropy = soft_labels * torch.log(soft_labels) 155 | gt_entropy[torch.isnan(gt_entropy)] = 0 156 | gt_entropy = torch.sum(gt_entropy, dim=1) 157 | 158 | # predicted entropy vs class entropy 159 | rcorr_entropy = spearmanr(confidences.numpy(), gt_entropy.numpy())[0] 160 | 161 | # MAP@R vs confidence filter-out rates 162 | mapr = NearestNeighboursMetrics()(embeddings, hard_labels, confidences) 163 | 164 | # predicted entropy vs cifar10s entropy 165 | # rcorr_cifar10s = spearmanr(confidences.numpy()[np.logical_not(np.isnan(entropies_cifar10s))], 166 | # 
entropies_cifar10s[np.logical_not(np.isnan(entropies_cifar10s))])[0] 167 | # print(rcorr_cifar10s) 168 | 169 | # Plots 170 | if want_plot: 171 | if filename_suffix is not None: 172 | path = f"results/{args.savefolder}/uncertain_images_{filename_suffix}.png" 173 | path_erc = f"results/{args.savefolder}/erc_plot_{filename_suffix}.pdf" 174 | path_conf_vs_entropy = f"results/{args.savefolder}/confidence_vs_entropy_{filename_suffix}.png" 175 | path_conf_vs_corruption = f"results/{args.savefolder}/confidence_vs_corruption_{filename_suffix}.png" 176 | path_uncertain_retrieval = f"results/{args.savefolder}/uncertain_retrieval_{filename_suffix}.png" 177 | # path_conf_vs_cifar10s = f"results/{args.savefolder}/confidence_vs_cifar10s_{filename_suffix}.png" 178 | else: 179 | path = f"results/{args.savefolder}/uncertain_images.png" 180 | path_erc = f"results/{args.savefolder}/erc_plot.pdf" 181 | path_conf_vs_entropy = f"results/{args.savefolder}/confidence_vs_entropy.png" 182 | path_conf_vs_corruption = f"results/{args.savefolder}/confidence_vs_corruption.png" 183 | path_uncertain_retrieval = f"results/{args.savefolder}/uncertain_retrieval.png" 184 | # path_conf_vs_cifar10s = f"results/{args.savefolder}/confidence_vs_cifar10s.png" 185 | # path_cifar_10h_vs_10s = "cifar_10h_vs_10s.png" 186 | #query_ids = torch.multinomial(torch.ones(embeddings.shape[0]), 5, replacement=False) 187 | with plt.rc_context(bundles.icml2022()): 188 | erc_plot(mapr["erc-recall@1"], path_erc) 189 | scatter_plot(confidences.numpy(), gt_entropy.numpy(), path_conf_vs_entropy, 190 | xlabel="Predicted Confidence $\kappa$", ylabel="Negative Entropy of CIFAR-10H") 191 | query_ids = torch.argsort(confidences, descending=False)[torch.Tensor([30, 500, 1000, 1302, 1800]).type(torch.long)] 192 | print("Kappa values of the example images:") 193 | print(confidences[query_ids]) 194 | #query_ids = torch.arange(0, embeddings.shape[0] - 1, torch.floor(torch.ones(1) * embeddings.shape[0]).item() / 5).type(torch.long) 195 | uncertain_retrieval(embeddings[query_ids], confidences[query_ids], query_ids, 196 | embeddings, confidences, dataloader.dataset, path=path_uncertain_retrieval) 197 | uncertain_images(confidences, hard_labels, dataloader, path) 198 | # scatter_plot(confidences.numpy()[np.logical_not(np.isnan(entropies_cifar10s))], 199 | # entropies_cifar10s[np.logical_not(np.isnan(entropies_cifar10s))], path_conf_vs_cifar10s, 200 | # xlabel="Predicted Confidence $\kappa$", ylabel="Negative Entropy of CIFAR-10S") 201 | # scatter_plot(gt_entropy.numpy()[np.logical_not(np.isnan(entropies_cifar10s))], 202 | # entropies_cifar10s[np.logical_not(np.isnan(entropies_cifar10s))], path_cifar_10h_vs_10s, 203 | # xlabel="Negative Entropy of CIFAR-10H", ylabel="Negative Entropy of CIFAR-10S") 204 | 205 | # Run through the corrupted data 206 | embeddings = list() 207 | confidences = list() 208 | soft_labels = list() 209 | quality = list() 210 | for img, soft_label, qual in corrupted_dataloader: 211 | embed, conf = enc(img.to(enc.device)) 212 | embeddings.append(embed.detach().cpu()) 213 | confidences.append(conf.detach().cpu().squeeze()) 214 | soft_labels.append(soft_label) 215 | quality.append(qual) 216 | embeddings = torch.cat(embeddings, dim=0) 217 | confidences = torch.cat(confidences, dim=0) 218 | soft_labels = torch.cat(soft_labels, dim=0) 219 | _, hard_labels = torch.max(soft_labels, dim=1) 220 | quality = torch.cat(quality, dim=0) 221 | 222 | # predicted entropy vs corruption level 223 | rcorr_corrupt = spearmanr(confidences.numpy(), quality.numpy())[0] 
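        # Illustrative note with assumed toy values: spearmanr compares rankings only, so
        # rcorr_corrupt reaches 1.0 whenever less-corrupted images (larger "quality", i.e. a
        # larger proportion of the image shown) receive larger predicted kappas, regardless of
        # the absolute kappa scale, e.g.
        #     spearmanr([5.0, 12.0, 30.0], [0.25, 0.5, 1.0])[0] == 1.0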
224 | 225 | # MAP@R vs confidence filter-out rates on corrupted data 226 | mapr_corrupt = NearestNeighboursMetrics()(embeddings, hard_labels, confidences) 227 | 228 | # Another plot 229 | if want_plot: 230 | scatter_plot(confidences.numpy(), quality.numpy(), path_conf_vs_corruption, 231 | xlabel="Predicted Confidence $\kappa$", ylabel="Proportion of Image Shown") 232 | 233 | return mapr["recall@1"], mapr["mapr"].detach().cpu().item(), rcorr_entropy, \ 234 | mapr_corrupt["recall@1"], mapr_corrupt["mapr"].detach().cpu().item(), rcorr_corrupt 235 | 236 | def erc_plot(erc, path="erc_plot.png"): 237 | # erc - a tensor with the cumsums of average errors for the least confident 1, 2, 3, 4, ... samples 238 | plt.figure(figsize=(3.25, 2.)) 239 | plt.plot(np.arange(1, len(erc) + 1) / len(erc), erc, color="#4878d0") 240 | plt.xlabel("Percentage of Excluded Lowest-certain Samples") 241 | plt.ylabel("Recall@1") 242 | plt.grid(zorder=-1, color="lightgrey", lw=0.5) 243 | plt.savefig(path) 244 | plt.close() 245 | 246 | return None 247 | 248 | def scatter_plot(x, y, path="scatterplot.png", xlabel="", ylabel=""): 249 | # erc - a tensor with the cumsums of average errors for the least confident 1, 2, 3, 4, ... samples 250 | plt.scatter(x, y) 251 | plt.xlabel(xlabel) 252 | plt.ylabel(ylabel) 253 | plt.savefig(path) 254 | plt.close() 255 | 256 | return None 257 | 258 | 259 | def plot_images(dataset, ids, fig=plt.figure(figsize=(20., 20.))): 260 | # Provided a 2D tensor of image ids, picks them from the dataset and plots them in the 261 | # same matrix structure they have in ids 262 | # If ids includes negative ids, they are skipped and plotted as a white picture 263 | 264 | # Setup image grid 265 | grid = ImageGrid(fig, 111, # similar to subplot(111) 266 | nrows_ncols=(ids.shape[0], ids.shape[1]), # creates 2x2 grid of axes 267 | axes_pad=0.1, # pad between axes in inch. 268 | ) 269 | 270 | # Plot images 271 | # Iterating over the grid returns the Axes. 
272 | for ax, id in zip(grid, ids.flatten()): 273 | ax.set_axis_off() 274 | if id >= 0: 275 | im, _ = dataset.__getitem__(id) 276 | ax.imshow(torch.minimum(torch.ones(1), torch.maximum(torch.zeros(1), im.permute(1, 2, 0) * 277 | torch.tensor([0.2471, 0.2435, 0.2616]) + torch.tensor( 278 | [0.4914, 0.4822, 0.4465])))) 279 | 280 | 281 | def uncertain_images(confidence, labels, dataloader, path="uncertain_images.png"): 282 | # Find most certain/uncertain images per class 283 | chosenclasses = np.array(np.arange(10)) 284 | chosen_ids = [] 285 | for lab in chosenclasses: 286 | ids = np.array([i for i in np.arange(len(labels)) if labels[i] == lab]) 287 | order = torch.argsort(confidence[ids]) 288 | first = ids[order[:5]] 289 | last = ids[order[-5:]] 290 | chosen_ids.append(np.concatenate((first, last)).tolist()) 291 | chosen_ids = np.array(chosen_ids) 292 | chosen_ids[1, 1] = -1 293 | 294 | fig = plt.figure(figsize=(20., 20.)) 295 | plot_images(dataloader.dataset, chosen_ids, fig) 296 | fig.savefig(path) 297 | plt.close() 298 | 299 | return None 300 | 301 | def uncertain_retrieval(q_embed, q_conf, q_id, r_embed, r_conf, dataset, alpha=0.05, path="uncertain_retrieval.png"): 302 | ''' 303 | Retrieves and plots the top images for each query image from a retrieval dataset 304 | :param q_embed: tensor of shape [n_query, dim], mean embeds of images to be searched 305 | :param q_conf: tensor of shape [n_query], kappa confidence values of images to be searched 306 | :param r_embed: tensor of shape [retrieval_dataset_size, dim], mean embeds of all images in the desired dataset 307 | :param r_conf: tensor of shape [retrieval_dataset_size], kappa confidence values of all images in the desired dataset 308 | (currently unused) 309 | :param dataset: a dataset object where we can retrieve images via a __getitem__ method 310 | :param alpha: The "confidence level" for the maximum-a-posteriori interval 311 | :param path: String, where to save the image 312 | :return: Nothing, but a plot is plotted 313 | ''' 314 | # order queries by their confidence 315 | order = torch.argsort(q_conf, descending=True) 316 | q_embed = q_embed[order] 317 | q_id = q_id[order] 318 | q_conf = q_conf[order] 319 | q_conf = q_conf.unsqueeze(1) 320 | 321 | # Calculate confidence intervals 322 | samples = VonMisesFisher(q_embed, q_conf).sample(10000) 323 | dot_prods = torch.sum(q_embed.unsqueeze(0) * samples, dim=-1) 324 | thresholds = torch.quantile(dot_prods, alpha, dim=0) 325 | 326 | # See which region of the embedding space this covers to decide how many samples we should show 327 | unif_samples = torch.zeros((1000000, r_embed.shape[1])).normal_() 328 | unif_samples = unif_samples / unif_samples.norm(dim=1).unsqueeze(1) 329 | dot_prods = torch.sum(q_embed[0].unsqueeze(0) * unif_samples, dim=-1) 330 | covered_sphere_pct = torch.sum(dot_prods.unsqueeze(0) >= thresholds.unsqueeze(1), dim=1) / 1000000 331 | n_select = torch.round(covered_sphere_pct / torch.max(covered_sphere_pct) * 21).type(torch.long) # we want 14 for the maximum one 332 | 333 | # collect query image ids that fall into these confidence intervals 334 | dot_prod_q_r = torch.sum(q_embed.unsqueeze(1) * r_embed.unsqueeze(0), dim=-1) 335 | is_in_interval = dot_prod_q_r >= thresholds.unsqueeze(1) 336 | # Prevent retrieving the query image itself 337 | for i, id in enumerate(q_id): 338 | is_in_interval[i, id] = 0 339 | # Edge case: No similarities found. 
Then just drop that query from the plot 340 | n_found = is_in_interval.sum(dim=1) 341 | q_id = q_id[n_found > 0] 342 | is_in_interval = is_in_interval[n_found > 0,:] 343 | # The space is very clustered. To avoid picking all images from the same cluster, one could weight 344 | # the samples by the inverse density of the embedding space at their positions, estimated by NN (disabled below) 345 | #dens = torch.sum(torch.exp(80 * torch.sum(r_embed.unsqueeze(0) * r_embed.unsqueeze(1), dim=-1)), dim=1) 346 | selected_ids = torch.multinomial(is_in_interval.type(torch.float), torch.max(n_select), replacement=False)# / dens.unsqueeze(0), torch.max(n_select), replacement=False) 347 | # Remember to only use the n_select first ones of each row. 348 | # We only had to choose the max for all, because a tensor cannot have a different number of entries in each row. 349 | 350 | # plot 351 | # Transform the selected_ids into an array that will look good in uncertain_images 352 | ncol = 7 # Columns of retrieved images 353 | dist_between_query_and_retrieval = 4 354 | dist_between_retrieval_rows = 6 355 | retrieval_tensors = [] 356 | for i in range(selected_ids.shape[0]): 357 | nrows = ((n_select[i] - 1) // (ncol) + 1) + dist_between_retrieval_rows 358 | id_array = -torch.ones(nrows * (ncol)) 359 | selected = selected_ids[i, :n_select[i]] 360 | # Order them by similarity 361 | similarities = dot_prod_q_r[i, selected] 362 | selected = selected[torch.argsort(similarities, descending=True)] 363 | id_array[:n_select[i]] = selected 364 | id_array = torch.reshape(id_array, (nrows, ncol)) 365 | query_array = -torch.ones((nrows, dist_between_query_and_retrieval)) 366 | query_array[0, 0] = q_id[i] 367 | id_array = torch.cat((query_array, id_array), dim=1) 368 | retrieval_tensors.append(id_array) 369 | id_matrix = torch.cat(retrieval_tensors, dim=0).type(torch.long).numpy() 370 | 371 | fig = plt.figure(figsize=(20., 20.)) 372 | plot_images(dataset, id_matrix, fig) 373 | fig.savefig(path) 374 | plt.close() 375 | 376 | ######################################################################################################################## 377 | # The following code is modified from https://github.com/tinkoff-ai/probabilistic-embeddings, Apache License 2.0 # 378 | ######################################################################################################################## 379 | def asarray(x): 380 | if isinstance(x, torch.Tensor): 381 | x = x.cpu() 382 | return np.ascontiguousarray(x) 383 | 384 | 385 | class NearestNeighboursBase(ABC): 386 | """Base class for all nearest neighbour metrics.""" 387 | 388 | @property 389 | @abstractmethod 390 | def match_self(self): 391 | """Whether to compare each sample with self or not.""" 392 | pass 393 | 394 | @property 395 | @abstractmethod 396 | def need_positives(self): 397 | """Whether metric requires positive scores or not.""" 398 | pass 399 | 400 | @property 401 | @abstractmethod 402 | def need_confidences(self): 403 | """Whether metric requires confidences or not.""" 404 | pass 405 | 406 | @abstractmethod 407 | def num_nearest(self, labels): 408 | """Get the number of required neighbours. 409 | Args: 410 | labels: Dataset labels. 411 | """ 412 | pass 413 | 414 | @abstractmethod 415 | def __call__(self, nearest_same, nearest_scores, class_sizes, positive_scores=None, confidences=None): 416 | """Compute metric value. 417 | Args: 418 | nearest_same: Binary labels of nearest neighbours equal to 1 iff class is equal to the query. 419 | nearest_scores: Similarity scores of nearest neighbours.
420 | class_sizes: Class size for each element. 421 | positive_scores (optional): Similarity scores of elements with the same class (depends on match_self). 422 | confidences (optional): Confidence for each element of the batch with shape (B). 423 | Returns: 424 | Metric value. 425 | """ 426 | pass 427 | 428 | 429 | class RecallK(NearestNeighboursBase): 430 | """Recall@K metric.""" 431 | def __init__(self, k): 432 | self._k = k 433 | 434 | @property 435 | def match_self(self): 436 | """Whether to compare each sample with self or not.""" 437 | return False 438 | 439 | @property 440 | def need_positives(self): 441 | """Whether metric requires positive scores or not.""" 442 | return False 443 | 444 | @property 445 | def need_confidences(self): 446 | """Whether metric requires confidences or not.""" 447 | return False 448 | 449 | def num_nearest(self, labels): 450 | """Get the number of required neighbours. 451 | Args: 452 | labels: Dataset labels. 453 | """ 454 | return self._k 455 | 456 | def __call__(self, nearest_same, nearest_scores, class_sizes, positive_scores=None, confidences=None): 457 | """Compute metric value. 458 | Args: 459 | nearest_same: Binary labels of nearest neighbours equal to 1 iff class is equal to the query. 460 | nearest_scores: Similarity scores of nearest neighbours. 461 | class_sizes: Class size for each element. 462 | positive_scores: Similarity scores of elements with the same class. 463 | confidences (optional): Confidence for each element of the batch with shape (B). 464 | Returns: 465 | Metric value. 466 | """ 467 | mask = class_sizes > 1 468 | if mask.sum().item() == 0: 469 | return np.nan 470 | has_same, _ = nearest_same[mask, :self._k].max(1) 471 | return has_same.float().mean().item() 472 | 473 | 474 | class ERCRecallK(NearestNeighboursBase): 475 | """Error-versus-Reject-Curve based on Recall@K metric.""" 476 | def __init__(self, k): 477 | self._k = k 478 | 479 | @property 480 | def match_self(self): 481 | """Whether to compare each sample with self or not.""" 482 | return False 483 | 484 | @property 485 | def need_positives(self): 486 | """Whether metric requires positive scores or not.""" 487 | return False 488 | 489 | @property 490 | def need_confidences(self): 491 | """Whether metric requires confidences or not.""" 492 | return True 493 | 494 | def num_nearest(self, labels): 495 | """Get the number of required neighbours. 496 | Args: 497 | labels: Dataset labels. 498 | """ 499 | return self._k 500 | 501 | def __call__(self, nearest_same, nearest_scores, class_sizes, positive_scores=None, confidences=None): 502 | """Compute metric value. 503 | Args: 504 | nearest_same: Binary labels of nearest neighbours equal to 1 iff class is equal to the query. 505 | nearest_scores: Similarity scores of nearest neighbours. 506 | class_sizes: Class size for each element. 507 | positive_scores: Similarity scores of elements with the same class. 508 | confidences (optional): Confidence for each element of the batch with shape (B). 509 | Returns: 510 | Metric value. 511 | """ 512 | if confidences is None: 513 | raise ValueError("Can't compute ERC without confidences.") 514 | mask = class_sizes > 1 515 | if mask.sum().item() == 0: 516 | return np.nan 517 | recalls, _ = nearest_same[mask, :self._k].max(1) 518 | errors = 1 - recalls.float() 519 | confidences = confidences[mask] 520 | 521 | b = len(errors) 522 | order = torch.argsort(confidences, descending=True) 523 | errors = errors[order] # High confidence first.
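# The cumulative mean below gives, for each j, the average error over the j most confident samples; 1 minus this is the Recall@1 curve that erc_plot visualizes.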
524 | mean_errors = errors.cumsum(0) / torch.arange(1, b + 1, device=errors.device) 525 | # We want to plot the R@1, not the error, so return the correct predictions 526 | correct = 1 - mean_errors 527 | correct = torch.flip(correct, dims=(0, )) # In the later plot, we want the most confident samples last 528 | return correct.cpu().numpy() 529 | 530 | class ATRBase(NearestNeighboursBase): 531 | """Base class for @R metrics. 532 | All @R metrics search for the number of neighbours equal to class size. 533 | Args: 534 | match_self: Whether to compare each sample with self or not. 535 | Inputs: 536 | - parameters: Embeddings distributions tensor with shape (B, P). 537 | - labels: Label for each embedding with shape (B). 538 | Outputs: 539 | - Metric value. 540 | """ 541 | 542 | def __init__(self, match_self=False): 543 | super().__init__() 544 | self._match_self = match_self 545 | 546 | @property 547 | @abstractmethod 548 | def oversample(self): 549 | """Sample times more nearest neighbours.""" 550 | pass 551 | 552 | @abstractmethod 553 | def _aggregate(self, nearest_same, nearest_scores, num_nearest, class_sizes, positive_scores, confidences=None): 554 | """Compute metric value. 555 | Args: 556 | nearest_same: Matching labels for nearest neighbours with shape (B, R). 557 | Matches are coded with 1 and mismatches with 0. 558 | nearest_scores: Score for each neighbour with shape (B, R). 559 | num_nearest: Number of nearest neighbours for each element of the batch with shape (B). 560 | class_sizes: Number of elements in the class for each element of the batch. 561 | positive_scores: Similarity scores of elements with the same class. 562 | confidences (optional): Confidence for each element of the batch with shape (B). 563 | """ 564 | pass 565 | 566 | @property 567 | def match_self(self): 568 | """Whether to compare each sample with self or not.""" 569 | return self._match_self 570 | 571 | @property 572 | def need_positives(self): 573 | """Whether metric requires positive scores or not.""" 574 | return True 575 | 576 | @property 577 | def need_confidences(self): 578 | """Whether metric requires confidences or not.""" 579 | return False 580 | 581 | def num_nearest(self, labels): 582 | """Get maximum number of required neighbours. 583 | Args: 584 | labels: Dataset labels. 585 | """ 586 | max_r = torch.bincount(labels).max().item() 587 | max_r *= self.oversample 588 | return max_r 589 | 590 | def __call__(self, nearest_same, nearest_scores, class_sizes, positive_scores, confidences=None): 591 | """Compute metric value. 592 | Args: 593 | nearest_same: Binary labels of nearest neighbours equal to 1 iff class is equal to the query. 594 | nearest_scores: Similarity scores of nearest neighbours. 595 | class_sizes: Number of elements in the class for each element of the batch. 596 | positive_scores: Similarity scores of elements with the same class. 597 | confidences (optional): Confidence for each element of the batch with shape (B). 598 | Returns: 599 | Metric value. 600 | """ 601 | num_positives = class_sizes if self.match_self else class_sizes - 1 602 | num_nearest = torch.clip(num_positives * self.oversample, max=nearest_same.shape[1]) 603 | return self._aggregate(nearest_same, nearest_scores, num_nearest, class_sizes, positive_scores, 604 | confidences=confidences) 605 | 606 | 607 | class MAPR(ATRBase): 608 | """MAP@R metric. 609 | See "A Metric Learning Reality Check" (2020) for details.
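For a query with R same-class samples, MAP@R = (1/R) * sum_{i=1..R} P(i), where P(i) is the precision at rank i if the i-th retrieved neighbour shares the query's class and 0 otherwise.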
610 | """ 611 | 612 | @property 613 | def oversample(self): 614 | """Sample times more nearest neighbours.""" 615 | return 1 616 | 617 | def _aggregate(self, nearest_same, nearest_scores, num_nearest, class_sizes, positive_scores, confidences=None): 618 | """Compute MAP@R. 619 | Args: 620 | nearest_same: Matching labels for nearest neighbours with shape (B, R). 621 | Matches are coded with 1 and mismatches with 0. 622 | nearest_scores: (unused) Score for each neighbour with shape (B, R). 623 | num_nearest: Number of nearest neighbours for each element of the batch with shape (B). 624 | class_sizes: (unused) Number of elements in the class for each element of the batch. 625 | positive_scores: Similarity scores of elements with the same class. 626 | confidences (optional): Confidence for each element of the batch with shape (B). 627 | """ 628 | b, r = nearest_same.shape 629 | device = nearest_same.device 630 | range = torch.arange(1, r + 1, device=device) # (R). 631 | count_mask = range[None].tile(b, 1) <= num_nearest[:, None] # (B, R). 632 | precisions = count_mask * nearest_same * torch.cumsum(nearest_same, dim=1) / range[None] # (B, R). 633 | maprs = precisions.sum(-1) / torch.clip(num_nearest, min=1) # (B). 634 | return maprs.mean() 635 | 636 | 637 | class ERCMAPR(ATRBase): 638 | """ERC curve for MAP@R metric.""" 639 | 640 | @property 641 | def need_confidences(self): 642 | """Whether metric requires confidences or not.""" 643 | return True 644 | 645 | @property 646 | def oversample(self): 647 | """Sample times more nearest neighbours.""" 648 | return 1 649 | 650 | def _aggregate(self, nearest_same, nearest_scores, num_nearest, class_sizes, positive_scores, confidences=None): 651 | """Compute MAP@R ERC. 652 | Args: 653 | nearest_same: Matching labels for nearest neighbours with shape (B, R). 654 | Matches are coded with 1 and mismatches with 0. 655 | nearest_scores: (unused) Score for each neighbour with shape (B, R). 656 | num_nearest: Number of nearest neighbours for each element of the batch with shape (B). 657 | class_sizes: (unused) Number of elements in the class for each element of the batch. 658 | positive_scores: Similarity scores of elements with the same class. 659 | confidences (optional): Confidence for each element of the batch with shape (B). 660 | """ 661 | if confidences is None: 662 | raise ValueError("Can't compute ERC without confidences.") 663 | b, r = nearest_same.shape 664 | device = nearest_same.device 665 | range = torch.arange(1, r + 1, device=device) # (R). 666 | count_mask = range[None].tile(b, 1) <= num_nearest[:, None] # (B, R). 667 | precisions = count_mask * nearest_same * torch.cumsum(nearest_same, dim=1) / range[None] # (B, R). 668 | maprs = precisions.sum(-1) / torch.clip(num_nearest, min=1) # (B). 669 | errors = 1 - maprs.float() 670 | 671 | b = len(errors) 672 | order = torch.argsort(confidences, descending=True) 673 | errors = errors[order] # High confidence first. 
674 | mean_errors = errors.cumsum(0) / torch.arange(1, b + 1, device=errors.device) 675 | return mean_errors.mean().cpu().item() 676 | 677 | 678 | class KNNIndex: 679 | BACKENDS = { 680 | "faiss": faiss.IndexFlatL2 681 | } 682 | 683 | def __init__(self, dim, backend="faiss"): 684 | self._index = self.BACKENDS[backend](dim) 685 | 686 | def __enter__(self): 687 | if self._index is None: 688 | raise RuntimeError("Can't create context multiple times.") 689 | return self._index 690 | 691 | def __exit__(self, exc_type, exc_value, traceback): 692 | self._index.reset() 693 | self._index = None 694 | 695 | class NearestNeighboursMetrics: 696 | """Metrics based on nearest neighbours search. 697 | Args: 698 | match_via_cosine: Whether to l2-normalize the embeddings (i.e., use cosine 699 | similarity) before the nearest neighbour search. 700 | Inputs: 701 | - embeddings: Embedding tensor with shape (B, P). 702 | - labels: Label for each embedding with shape (B). 703 | Outputs: 704 | - Metrics values. 705 | """ 706 | 707 | METRICS = { 708 | "recall": RecallK, 709 | "erc-recall@1": lambda: ERCRecallK(1), 710 | "mapr": MAPR, 711 | "erc-mapr": ERCMAPR 712 | } 713 | 714 | @staticmethod 715 | def get_default_config(backend="faiss", broadcast_backend="torch", metrics=None, prefetch_factor=2, recall_k_values=(1,5)): 716 | """Get metrics parameters. 717 | Args: 718 | backend: KNN search engine ("faiss"). 719 | broadcast_backend: Torch doesn't support broadcast for gather method. 720 | We can emulate this behaviour with Numpy ("numpy") or tiling ("torch"). 721 | metrics: List of metric names to compute ("recall", "erc-recall@1", "mapr", "erc-mapr"). 722 | By default compute all available metrics. 723 | prefetch_factor: Nearest neighbours number scaler for presampling. 724 | recall_k_values: List of K values to compute recall at. 725 | """ 726 | return OrderedDict([ 727 | ("backend", backend), 728 | ("broadcast_backend", broadcast_backend), 729 | ("metrics", metrics), 730 | ("prefetch_factor", prefetch_factor), 731 | ("recall_k_values", recall_k_values) 732 | ]) 733 | 734 | def __init__(self, *, match_via_cosine=True): 735 | self._config = self.get_default_config() 736 | 737 | self._metrics = OrderedDict() 738 | metric_names = self._config["metrics"] if self._config["metrics"] is not None else list(self.METRICS) 739 | for name in metric_names: 740 | if name == "recall": 741 | for k in self._config["recall_k_values"]: 742 | k = int(k) 743 | self._metrics["{}@{}".format(name, k)] = self.METRICS[name](k) 744 | else: 745 | metric = self.METRICS[name]() 746 | self._metrics[name] = metric 747 | 748 | self.match_via_cosine = match_via_cosine 749 | 750 | def __call__(self, embeddings, labels, confidences=None): 751 | """ 752 | Computes the retrieval metrics for the given embeddings 753 | :param embeddings: (Batchsize, dim) tensor of embeddings 754 | :param labels: (Batchsize) tensor of class labels 755 | :param confidences: (Batchsize) tensor of confidences (higher = more confident). 756 | If None, the similarity to the closest neighbor is used 757 | :return: An OrderedDict with one value per metric 758 | """ 759 | if len(labels) != len(embeddings): 760 | raise ValueError("Batch size mismatch between labels and embeddings.") 761 | embeddings = embeddings.detach() # (B, P). 762 | labels = labels.detach() # (B). 763 | 764 | if self.match_via_cosine: 765 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1) 766 | 767 | # Find desired nearest neighbours number for each sample and total. 768 | label_counts = torch.bincount(labels) # (L). 769 | class_sizes = label_counts[labels] # (B).
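# Metrics that must not match a sample with itself need one extra neighbour, hence the + int(not metric.match_self) below.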
770 | num_nearest = max(metric.num_nearest(labels) + int(not metric.match_self) for metric in self._metrics.values()) 771 | num_nearest = min(num_nearest, len(labels)) 772 | 773 | # Gather nearest neighbours (sorted in score descending order). 774 | nearest, scores = self._find_nearest(embeddings, num_nearest) # (B, R), (B, R). 775 | num_nearest = torch.full((len(nearest),), num_nearest, device=labels.device) 776 | nearest_labels = self._gather_broadcast(labels[None], 1, nearest, backend=self._config["broadcast_backend"]) # (B, R). 777 | nearest_same = nearest_labels == labels[:, None] # (B, R). 778 | 779 | # If necessary, compute confidence as the similarity to the closest neighbor 780 | need_confidences = any([metric.need_confidences for metric in self._metrics.values()]) 781 | if need_confidences and confidences is None: 782 | confidences = scores[:,0] 783 | 784 | need_positives = any(metric.need_positives for metric in self._metrics.values()) 785 | if need_positives: 786 | positive_scores, _, positive_same_mask = self._get_positives(embeddings, labels) 787 | else: 788 | positive_scores, positive_same_mask = None, None 789 | 790 | need_nms = any(not metric.match_self for metric in self._metrics.values()) 791 | if need_nms: 792 | no_self_mask = torch.arange(len(labels), device=embeddings.device)[:, None] != nearest 793 | nearest_same_nms, _ = self._gather_mask(nearest_same, num_nearest, no_self_mask) 794 | scores_nms, num_nearest = self._gather_mask(scores, num_nearest, no_self_mask) 795 | if need_positives: 796 | positive_scores_nms, _ = self._gather_mask(positive_scores, class_sizes, ~positive_same_mask) 797 | else: 798 | positive_scores_nms = None 799 | 800 | metrics = OrderedDict() 801 | for name, metric in self._metrics.items(): 802 | if metric.match_self: 803 | metrics[name] = metric(nearest_same, scores, class_sizes, positive_scores, confidences=confidences) 804 | else: 805 | metrics[name] = metric(nearest_same_nms, scores_nms, class_sizes, positive_scores_nms, confidences=confidences) 806 | return metrics 807 | 808 | def _find_nearest(self, embeddings, max_nearest): 809 | """Find nearest neighbours for each element of the batch. 810 | """ 811 | embeddings = embeddings.unsqueeze(1) # We only have one "mode" 812 | b, c, d = embeddings.shape 813 | # Find neighbors using simple L2/dot scoring. 814 | prefetch = min(max_nearest * self._config["prefetch_factor"], b) 815 | candidates_indices, sim = self._multimodal_knn(embeddings, prefetch) # (B, C * R). 816 | return candidates_indices.reshape((b, -1)), sim.reshape((b, -1)) 817 | 818 | def _get_positives(self, embeddings, labels): 819 | label_counts = torch.bincount(labels) 820 | num_labels = len(label_counts) 821 | max_label_count = label_counts.max().item() 822 | by_label = torch.full((num_labels, max_label_count), -1, dtype=torch.long) 823 | counts = np.zeros(num_labels, dtype=np.int64) 824 | for i, label in enumerate(labels.cpu().numpy()): 825 | by_label[label][counts[label]] = i 826 | counts[label] += 1 827 | by_label = by_label.to(labels.device) # (L, C). 828 | indices = by_label[labels] # (B, C). 829 | num_positives = torch.from_numpy(counts).long().to(labels.device)[labels] 830 | positive_parameters = self._gather_broadcast(embeddings[None], 1, indices[..., None], 831 | backend=self._config["broadcast_backend"]) # (B, C, P). 832 | with torch.no_grad(): 833 | positive_scores = torch.sum(embeddings[:, None, :] * positive_parameters, dim = -1) # (B, C). 
834 | same_mask = indices == torch.arange(len(labels), device=indices.device)[:, None] 835 | # Sort first elements in each row according to counts. 836 | no_sort_mask = torch.arange(positive_scores.shape[1], device=embeddings.device)[None] >= num_positives[:, None] 837 | positive_scores[no_sort_mask] = positive_scores.min() - 1 838 | positive_scores, order = torch.sort(positive_scores, dim=1, descending=True) 839 | same_mask = torch.gather(same_mask, 1, order) 840 | return positive_scores, num_positives, same_mask 841 | 842 | def _multimodal_knn(self, x, k): 843 | """Find nearest neighbours for multimodal queries. 844 | Args: 845 | x: Embeddings with shape (B, C, D) where C is the number of modalities. 846 | k: Number of nearest neighbours. 847 | Returns: 848 | Nearest neighbours indices with shape (B, C, K). Indices are in the range [0, B - 1]. 849 | """ 850 | b, c, d = x.shape 851 | if k > b: 852 | raise ValueError("Number of nearest neighbours is too large: {} for batch size {}.".format(k, b)) 853 | x_flat = asarray(x).reshape((b * c, d)) 854 | with KNNIndex(d, backend=self._config["backend"]) as index: 855 | index.add(x_flat) 856 | sim, indices = index.search(x_flat, k) # (B * C, K), indices are in [0, B * C - 1]. 857 | indices //= c # (B * C, K), indices are in [0, B - 1]. 858 | return torch.from_numpy(indices.reshape((b, c, k))).long().to(x.device), torch.from_numpy(sim.reshape((b, c, k))).to(x.device) 859 | 860 | @staticmethod 861 | def _remove_duplicates(indices, num_unique): 862 | """Take first n unique values from each row. 863 | Args: 864 | indices: Input indices with shape (B, K). 865 | num_unique: Number of unique indices in each row. 866 | Returns: 867 | Unique indices with shape (B, num_unique) and new scores if scores are provided. 868 | """ 869 | b, k = indices.shape 870 | if k == 1: 871 | return indices 872 | sorted_indices, order = torch.sort(indices, dim=1, stable=True) 873 | mask = sorted_indices[:, 1:] != sorted_indices[:, :-1] # (B, K - 1). 874 | mask = torch.cat([torch.ones_like(mask[:, :1]), mask], dim=1) # (B, K). 875 | mask = torch.gather(mask, 1, torch.argsort(order, dim=1)) 876 | counts = torch.cumsum(mask, 1) # (B, K). 877 | mask &= counts <= num_unique # (B, K). 878 | 879 | # Some FAISS indices allow duplicates. In this case total number of unique elements is less than min_unique. 880 | # Add tail samples to get exact min_unique number. 881 | num_extra_zeros = torch.clip(num_unique - counts[:, -1], 0) 882 | counts = torch.cumsum(~mask, 1) 883 | sums = counts[:, -1].unsqueeze(-1) # (B, 1). 884 | counts = torch.cat((sums, sums - counts[:, :-1]), dim=-1) # (B, K). 885 | mask |= counts <= num_extra_zeros[:, None] 886 | 887 | unique = indices[mask].reshape(b, num_unique) # (B, R), all indices are unique. 888 | return unique 889 | 890 | @staticmethod 891 | def _gather_mask(matrix, lengths, mask): 892 | b, n = matrix.shape 893 | device = matrix.device 894 | length_mask = torch.arange(n, device=device)[None].tile(b, 1) < lengths[:, None] # (B, N). 895 | mask = mask & length_mask 896 | counts = mask.sum(1) # (B). 
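# Below, all rows are padded to the same length so that the masked entries can be reshaped into a rectangular (B, max_count) tensor; counts stores the number of valid entries per row.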
897 | max_count = counts.max() 898 | padding = max_count - counts.min() 899 | if padding > 0: 900 | matrix = torch.cat((matrix, torch.zeros(b, padding, dtype=matrix.dtype, device=device)), dim=1) 901 | mask = torch.cat((mask, torch.ones(b, padding, dtype=torch.bool, device=device)), dim=1) 902 | mask &= torch.cumsum(mask, 1) <= max_count 903 | return matrix[mask].reshape(b, max_count), counts 904 | 905 | @staticmethod 906 | def _gather_broadcast(input, dim, index, backend="torch"): 907 | if backend == "torch": 908 | shape = np.maximum(np.array(input.shape), np.array(index.shape)).tolist() 909 | index[index < 0] += shape[dim] 910 | shape[dim] = input.shape[dim] 911 | input = input.broadcast_to(shape) 912 | shape[dim] = index.shape[dim] 913 | index = index.broadcast_to(shape) 914 | return input.gather(dim, index) 915 | elif backend == "numpy": 916 | result_array = np.take_along_axis(asarray(input), 917 | asarray(index), 918 | dim) 919 | result = torch.from_numpy(result_array).to(dtype=input.dtype, device=input.device) 920 | return result 921 | else: 922 | raise ValueError("Unknown broadcast backend: {}.".format(backend)) -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import List 4 | 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class WarmupCosineLR(_LRScheduler): 10 | """ 11 | Sets the learning rate of each parameter group to follow a linear warmup schedule 12 | between warmup_start_lr and base_lr followed by a cosine annealing schedule between 13 | base_lr and eta_min. 14 | .. warning:: 15 | It is recommended to call :func:`.step()` for :class:`WarmupCosineLR` 16 | after each iteration as calling it after each epoch will keep the starting lr at 17 | warmup_start_lr for the first epoch, which is close to 0 in most cases. 18 | .. warning:: 19 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 20 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 21 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 22 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 23 | train and validation methods. 24 | Args: 25 | optimizer (Optimizer): Wrapped optimizer. 26 | warmup_epochs (int): Maximum number of iterations for linear warmup 27 | max_epochs (int): Maximum number of iterations 28 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 1e-8. 29 | eta_min (float): Minimum learning rate. Default: 1e-8. 30 | last_epoch (int): The index of last epoch. Default: -1. 31 | Example: 32 | >>> layer = nn.Linear(10, 1) 33 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 34 | >>> scheduler = WarmupCosineLR(optimizer, warmup_epochs=10, max_epochs=40) 35 | >>> # 36 | >>> # the default case 37 | >>> for epoch in range(40): 38 | ... # train(...) 39 | ... # validate(...) 40 | ... scheduler.step() 41 | >>> # 42 | >>> # passing epoch param case 43 | >>> for epoch in range(40): 44 | ... scheduler.step(epoch) 45 | ... # train(...) 46 | ... # validate(...)
47 | """ 48 | 49 | def __init__( 50 | self, 51 | optimizer: Optimizer, 52 | warmup_epochs: int, 53 | max_epochs: int, 54 | warmup_start_lr: float = 1e-8, 55 | eta_min: float = 1e-8, 56 | last_epoch: int = -1, 57 | ) -> None: 58 | 59 | self.warmup_epochs = warmup_epochs 60 | self.max_epochs = max_epochs 61 | self.warmup_start_lr = warmup_start_lr 62 | self.eta_min = eta_min 63 | 64 | super(WarmupCosineLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self) -> List[float]: 67 | """ 68 | Compute learning rate using chainable form of the scheduler 69 | """ 70 | if not self._get_lr_called_within_step: 71 | warnings.warn( 72 | "To get the last learning rate computed by the scheduler, " 73 | "please use `get_last_lr()`.", 74 | UserWarning, 75 | ) 76 | 77 | if self.last_epoch == 0: 78 | return [self.warmup_start_lr] * len(self.base_lrs) 79 | elif self.last_epoch < self.warmup_epochs: 80 | return [ 81 | group["lr"] 82 | + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 83 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 84 | ] 85 | elif self.last_epoch == self.warmup_epochs: 86 | return self.base_lrs 87 | elif (self.last_epoch - 1 - self.max_epochs) % ( 88 | 2 * (self.max_epochs - self.warmup_epochs) 89 | ) == 0: 90 | return [ 91 | group["lr"] 92 | + (base_lr - self.eta_min) 93 | * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) 94 | / 2 95 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 96 | ] 97 | 98 | return [ 99 | ( 100 | 1 101 | + math.cos( 102 | math.pi 103 | * (self.last_epoch - self.warmup_epochs) 104 | / (self.max_epochs - self.warmup_epochs) 105 | ) 106 | ) 107 | / ( 108 | 1 109 | + math.cos( 110 | math.pi 111 | * (self.last_epoch - self.warmup_epochs - 1) 112 | / (self.max_epochs - self.warmup_epochs) 113 | ) 114 | ) 115 | * (group["lr"] - self.eta_min) 116 | + self.eta_min 117 | for group in self.optimizer.param_groups 118 | ] 119 | 120 | def _get_closed_form_lr(self) -> List[float]: 121 | """ 122 | Called when epoch is passed as a param to the `step` function of the scheduler. 
123 | """ 124 | if self.last_epoch < self.warmup_epochs: 125 | return [ 126 | self.warmup_start_lr 127 | + self.last_epoch 128 | * (base_lr - self.warmup_start_lr) 129 | / (self.warmup_epochs - 1) 130 | for base_lr in self.base_lrs 131 | ] 132 | 133 | return [ 134 | self.eta_min 135 | + 0.5 136 | * (base_lr - self.eta_min) 137 | * ( 138 | 1 139 | + math.cos( 140 | math.pi 141 | * (self.last_epoch - self.warmup_epochs) 142 | / (self.max_epochs - self.warmup_epochs) 143 | ) 144 | ) 145 | for base_lr in self.base_lrs 146 | ] -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from numpy import random 5 | from numpy import i0 # modified Bessel function of first kind order 0, I_0 6 | from scipy.special import ive # exponential modified Bessel function of first kind, I_v * exp(-abs(kappa)) 7 | import wandb 8 | import os 9 | 10 | def init_seeds(seed=123): 11 | torch.backends.cudnn.deterministic = True; 12 | np.random.seed(seed); 13 | random.seed(seed) 14 | torch.manual_seed(seed); 15 | torch.cuda.manual_seed(seed); 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | def construct_mlp(n_hidden=2, dim_x=10, dim_z=2, dim_hidden=32): 19 | if dim_hidden == 0: 20 | # Zimmermann setup 21 | dim_first = 10 * dim_z 22 | dim_middle = 50 * dim_z 23 | dim_last = 10 * dim_z 24 | else: 25 | dim_first = dim_hidden 26 | dim_middle = dim_hidden 27 | dim_last = dim_hidden 28 | 29 | layers = [] 30 | layers.append(nn.Linear(dim_x, dim_first)) 31 | prev_dim = dim_first 32 | for i in range(n_hidden - 1): 33 | layers.append(nn.LeakyReLU()) 34 | layers.append(nn.Linear(prev_dim, dim_middle)) 35 | prev_dim = dim_middle 36 | if n_hidden - 1 >= 0: 37 | layers.append(nn.LeakyReLU()) 38 | layers.append(nn.Linear(prev_dim, dim_last)) 39 | prev_dim = dim_last 40 | layers.append(nn.LeakyReLU()) 41 | layers.append(nn.Linear(prev_dim, dim_z)) 42 | 43 | return nn.Sequential(*layers) 44 | 45 | def _vmf_normalize(kappa, dim): 46 | """Compute normalization constant using built-in numpy/scipy Bessel 47 | approximations. 48 | Works well on small kappa and mu. 
49 | Imported from https://github.com/jasonlaska/spherecluster/blob/develop/spherecluster/von_mises_fisher_mixture.py 50 | """ 51 | if kappa < 1e-15: 52 | kappa = 1e-15 53 | 54 | num = (dim / 2.0 - 1.0) * np.log(kappa) 55 | 56 | if dim / 2.0 - 1.0 < 1e-15: 57 | denom = (dim / 2.0) * np.log(2.0 * np.pi) + np.log(i0(kappa)) 58 | else: 59 | denom = (dim / 2.0) * np.log(2.0 * np.pi) + np.log(ive(dim / 2.0 - 1.0, kappa)) + kappa 60 | 61 | if np.isinf(num): 62 | raise ValueError("VMF scaling numerator was inf.") 63 | 64 | if np.isinf(denom): 65 | raise ValueError("VMF scaling denominator was inf.") 66 | 67 | const = np.exp(num - denom) 68 | 69 | if const == 0: 70 | raise ValueError("VMF norm const was 0.") 71 | 72 | return const 73 | 74 | def vmf_norm_ratio(kappa, dim): 75 | # Approximates log(norm_const(0) / norm_const(kappa)) of a vMF distribution 76 | # See approx_vmf_norm_const.R to see how it was approximated 77 | 78 | if dim==2: 79 | return -2.439 + 0.9904 * kappa + 2.185e-4 * kappa**1.55 80 | elif dim==4: 81 | return -4.817 + 0.9713 * kappa + 6.479e-4 * kappa**1.55 82 | elif dim==8: 83 | return -7.908 + 0.9344 * kappa + 1.477e-3 * kappa**1.55 84 | elif dim==10: 85 | return -9.024 + 0.9165 * kappa + 1.877e-3 * kappa**1.55 86 | elif dim==12: 87 | return -9.958 + 0.8990 * kappa + 2.267e-3 * kappa**1.55 88 | elif dim==16: 89 | return -11.43 + 0.8649 * kappa + 3.020e-3 * kappa**1.55 90 | elif dim==32: 91 | return -14.38 + 0.7416 * kappa + 5.686e-3 * kappa**1.55 92 | elif dim==40: 93 | return -14.92 + 0.6868 * kappa + 6.837e-3 * kappa**1.55 94 | elif dim==48: 95 | return -15.13 + 0.6360 * kappa + 7.882e-3 * kappa**1.55 96 | elif dim==56: 97 | return -15.12 + 0.5889 * kappa + 8.833e-3 * kappa**1.55 98 | elif dim==64: 99 | return -14.94 + 0.5450 * kappa + 0.009698 * kappa**1.55 100 | elif dim==96: 101 | return -13.42 + 0.3973 * kappa + 1.246e-2 * kappa**1.55 102 | elif dim==128: 103 | return -11.44 + 0.2839 * kappa + 1.425e-2 * kappa**1.55 104 | elif dim==256: 105 | return -4.7340339 + 0.0289469 * kappa + 0.0173026 * kappa**1.55 106 | elif dim==512: 107 | return 0.8674 - 0.1124 * kappa + 0.01589 * kappa**1.55 108 | else: 109 | return np.log(_vmf_normalize(0, dim)) - np.log(_vmf_normalize(kappa, dim)) 110 | 111 | def log_vmf_norm_const(kappa, dim=10): 112 | # Approximates the log vMF normalization constant (for the ELK loss) 113 | # See approx_vmf_norm_const.R to see how it was approximated 114 | 115 | if dim==4: 116 | return -0.826604 - 0.354357 * kappa - 0.383723 * kappa**1.1 117 | if dim==8: 118 | return -1.29737 + 0.36841 * kappa - 0.80936 * kappa**1.1 119 | elif dim==10: 120 | return -1.27184 + 0.67365 * kappa - 0.98726 * kappa**1.1 121 | elif dim==16: 122 | return -0.23773 + 1.39146 * kappa - 1.39819 * kappa**1.1 123 | elif dim==32: 124 | return 8.07579 + 2.28954 * kappa - 1.86925 * kappa**1.1 125 | elif dim==64: 126 | return 38.82967 + 2.34269 * kappa - 1.77425 * kappa**1.1 127 | else: 128 | return np.log(_vmf_normalize(kappa, dim)) 129 | 130 | def pairwise_cos_sims(z): 131 | # Calculate pairwise cosine distances between rows in z return as flat vector 132 | cos_dists = torch.matmul(z, z.t()) 133 | cos_dists = cos_dists[torch.tril_indices(cos_dists.shape[0], cos_dists.shape[1]).unbind()] 134 | return cos_dists 135 | 136 | def pairwise_l2_dists(z): 137 | l2_dists = torch.cdist(z, z) 138 | l2_dists = l2_dists[torch.tril_indices(l2_dists.shape[0], l2_dists.shape[1]).unbind()] 139 | return l2_dists 140 | 141 | def init_wandb(args): 142 | ### If wandb-logging is turned on, initialize the 
wandb-run here: 143 | if args.use_wandb: 144 | import re 145 | if args.wandb_key != "": 146 | _ = os.system('wandb login {}'.format(args.wandb_key)) 147 | os.environ['WANDB_API_KEY'] = args.wandb_key 148 | else: 149 | print("No wandb key provided. Hoping that one was specified as environment variable.") 150 | # For the groupname, remove the seed, so that we can group per seed. 151 | group = re.sub("_seed_[^_.]*", "", args.savefolder) 152 | wandb.init(project=args.wandb_project, group=group, name=args.savefolder, dir=f"results/{args.savefolder}/") 153 | wandb.init(settings=wandb.Settings(start_method='fork')) 154 | wandb.config.update(args) 155 | -------------------------------------------------------------------------------- /utils/vmf_sampler.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The vMF Embeddings Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """torch.distributions.Distribution implementation of a von Mises-Fisher. 17 | 18 | Code was adapted from: 19 | https://github.com/nicola-decao/s-vae-pytorch/blob/master/hyperspherical_vae/distributions/von_mises_fisher.py 20 | """ 21 | 22 | import math 23 | import torch 24 | 25 | EPS = 1e-14 26 | 27 | 28 | class VonMisesFisher(torch.distributions.Distribution): 29 | """torch.distributions.Distribution implementation of a von Mises-Fisher.""" 30 | 31 | arg_constraints = { 32 | "loc": torch.distributions.constraints.real, 33 | "scale": torch.distributions.constraints.positive, 34 | } 35 | support = torch.distributions.constraints.real 36 | has_rsample = True 37 | _mean_carrier_measure = 0 38 | 39 | def __init__(self, loc, scale, validate_args=None, k=1): 40 | self.dtype = loc.dtype 41 | self.loc = loc 42 | self.scale = scale 43 | self.device = loc.device 44 | self.__m = loc.shape[-1] 45 | self.__e1 = (torch.zeros(loc.shape[-1], device=self.device)) 46 | self.__e1[0] = 1.0 47 | self.k = k 48 | self.log_norm_const = None # for log_prob (if ever used) 49 | 50 | super().__init__(self.loc.size(), validate_args=validate_args) 51 | 52 | def sample(self, shape=torch.Size()): 53 | with torch.no_grad(): 54 | return self.rsample(shape) 55 | 56 | def rsample(self, shape=torch.Size()): 57 | shape = shape if isinstance(shape, torch.Size) else torch.Size([shape]) 58 | 59 | # only use __sample_w3 for 3D vMFs, otherwise __sample_w_rej 60 | # This samples the 1 dimensional "mixture variable" that indicates how far we are from the mode (1, 0, ..., 0) 61 | w = ( 62 | self.__sample_w3(shape=shape) if self.__m == 3 else self.__sample_w_rej( 63 | shape=shape)) 64 | 65 | # Draw uniform points on the unit sphere for the m-1 "other" variables 66 | v = (torch.distributions.Normal( 67 | torch.tensor(0, dtype=self.dtype, device=self.device), 68 | torch.tensor(1, dtype=self.dtype, device=self.device), 69 | ).sample(shape + torch.Size(self.loc.shape)).transpose( 70 | 0, -1)[1:]).transpose(0, -1) 71 | v = v / v.norm(dim=-1, keepdim=True) 
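# v is now a uniformly distributed direction in the m-1 coordinates orthogonal to the first axis; combined with the radial component w below, it forms a vMF sample around e1 = (1, 0, ..., 0), which is then rotated onto loc.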
72 | 73 | # Assemble the vector from the 1-dim rejection sample and the other m-1 dims 74 | w_ = torch.sqrt(torch.clamp(1 - (w**2), EPS)) 75 | x = torch.cat((w, w_ * v), -1) 76 | 77 | # rotate to get the modal value from (1, 0, ..., 0) to the intended mu 78 | z = self.__householder_rotation(x) 79 | z = z.type(self.dtype) 80 | 81 | # One last sanity check because this sometimes returns NaN 82 | if torch.any(torch.isnan(z)): 83 | return self.rsample(shape) 84 | else: 85 | return z 86 | 87 | def __sample_w3(self, shape): 88 | shape = shape + torch.Size(self.scale.shape) 89 | u = ( 90 | torch.distributions.Uniform( 91 | torch.tensor(0, dtype=self.dtype, device=self.device), 92 | torch.tensor(1, dtype=self.dtype, device=self.device), 93 | ).sample(shape)) 94 | self.__w = (1 + torch.stack( 95 | [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0).logsumexp(0) / 96 | self.scale) 97 | return self.__w 98 | 99 | def __sample_w_rej(self, shape): 100 | c = torch.sqrt((4 * (self.scale**2)) + (self.__m - 1)**2) 101 | b_true = (-2 * self.scale + c) / (self.__m - 1) 102 | 103 | # Using a Taylor approximation with a smooth switch from 10 < scale < 11 to 104 | # avoid numerical errors for large scale. 105 | b_app = (self.__m - 1) / (4 * self.scale) 106 | s = torch.min( 107 | torch.max( 108 | torch.tensor([0.0], dtype=self.dtype, device=self.device), 109 | self.scale - 10, 110 | ), 111 | torch.tensor([1.0], dtype=self.dtype, device=self.device), 112 | ) 113 | b = b_app * s + b_true * (1 - s) 114 | 115 | a = (self.__m - 1 + 2 * self.scale + c) / 4 116 | d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1) 117 | 118 | self.__b, (self.__e, self.__w) = b, self.__while_loop( 119 | b, a, d, shape, k=self.k) 120 | return self.__w 121 | 122 | @staticmethod 123 | def first_nonzero(x, dim, invalid_val=-1): 124 | mask = x > 0 125 | idx = torch.where( 126 | mask.any(dim=dim), 127 | mask.float().max(dim=1)[1].squeeze(), 128 | torch.tensor(invalid_val, device=x.device), 129 | ) 130 | return idx 131 | 132 | def __while_loop(self, b, a, d, shape, k=20, eps=1e-20): 133 | # Matrix while loop: samples a matrix of [A, k] samples, to avoid looping
135 | is_inf = self.scale == float("Inf") 136 | b, a, d, is_inf = [ 137 | e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1) 138 | for e in (b, a, d, is_inf) 139 | ] 140 | w, e, bool_mask = ( 141 | torch.zeros_like(b, device=self.device), 142 | torch.zeros_like(b, device=self.device), 143 | (torch.ones_like(b, device=self.device) == 1), 144 | ) 145 | 146 | sample_shape = torch.Size([b.shape[0], k]) 147 | shape = shape + torch.Size(self.scale.shape) 148 | 149 | while bool_mask.sum() != 0: 150 | con1 = torch.tensor( 151 | (self.__m - 1) / 2, dtype=torch.float64, device=self.device) 152 | con2 = torch.tensor( 153 | (self.__m - 1) / 2, dtype=torch.float64, device=self.device) 154 | e_ = ( 155 | torch.distributions.Beta(con1, con2).sample(sample_shape).type(self.dtype)) 156 | 157 | u = ( 158 | torch.distributions.Uniform( 159 | torch.tensor(0 + eps, dtype=self.dtype, device=self.device), 160 | torch.tensor(1 - eps, dtype=self.dtype, device=self.device), 161 | ).sample(sample_shape).type(self.dtype)) 162 | 163 | w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_) 164 | t = (2 * a * b) / (1 - (1 - b) * e_) 165 | 166 | accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u) 167 | 168 | # For samples with infinite kappa, return the mean (by just returning a w_ = 1) 169 | w_[is_inf] = 1 170 | accept[is_inf] = True 171 | 172 | accept_idx = self.first_nonzero( 173 | accept, dim=-1, invalid_val=-1).unsqueeze(1) 174 | accept_idx_clamped = accept_idx.clamp(0) 175 | w_ = w_.gather(1, accept_idx_clamped.view(-1, 1)) 176 | e_ = e_.gather(1, accept_idx_clamped.view(-1, 1)) 177 | 178 | reject = accept_idx < 0 179 | accept = ~reject if torch.__version__ >= "1.2.0" else 1 - reject 180 | 181 | w[bool_mask * accept] = w_[bool_mask * accept] 182 | e[bool_mask * accept] = e_[bool_mask * accept] 183 | 184 | bool_mask[bool_mask * accept] = reject[bool_mask * accept] 185 | 186 | return e.reshape(shape), w.reshape(shape) 187 | 188 | def __householder_rotation(self, x): 189 | u = self.__e1 - self.loc 190 | u = u / u.norm(dim=-1, keepdim=True).clamp_min(EPS) 191 | z = x - 2 * (x * u).sum(-1, keepdim=True) * u 192 | return z 193 | 194 | def log_prob(self, value): 195 | # Input: value - [batch, vMF, ?,dim] 196 | # Output: density of value under vMF WITHOUT NORMALIZING CONSTANTS, because their derivative is not implemented 197 | 198 | log_p = self.scale.unsqueeze(0) * torch.sum(value * self.loc.unsqueeze(0), dim=-1, keepdim=True) 199 | log_p = log_p.squeeze(-1) 200 | 201 | return log_p 202 | --------------------------------------------------------------------------------