├── requirements.txt
├── tokenizers
│   └── punkt
│       ├── PY3
│       │   ├── english.pickle
│       │   └── README
│       └── NOTICE
├── utils
│   ├── infinite_iterator.py
│   ├── invertible_network_utils.py
│   ├── losses.py
│   └── latent_spaces.py
├── LICENSE
├── .gitignore
├── README.md
├── encoders.py
├── datasets.py
├── main_imgtxt.py
└── main_mlp.py
/requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.8.1 2 | numpy==1.24.2 3 | pandas==1.5.3 4 | scikit-learn==1.2.2 5 | scipy==1.10.1 6 | torch==1.13.1 7 | torchvision==0.14.1 8 | --------------------------------------------------------------------------------
/tokenizers/punkt/PY3/english.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imantdaunhawer/multimodal-contrastive-learning/HEAD/tokenizers/punkt/PY3/english.pickle --------------------------------------------------------------------------------
/tokenizers/punkt/NOTICE: -------------------------------------------------------------------------------- 1 | NLTK tokenizer, see https://github.com/nltk/nltk 2 | 3 | We provide the nltk English tokenizer here for convenience. It can also be 4 | downloaded using a Python shell as follows: 5 | ``` 6 | import nltk 7 | nltk.download('punkt') 8 | ``` 9 | --------------------------------------------------------------------------------
/utils/infinite_iterator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code originates from the following projects: 3 | - https://github.com/brendel-group/cl-ica 4 | - https://github.com/ysharma1126/ssl_identifiability 5 | """ 6 | 7 | from typing import Iterable 8 | 9 | 10 | class InfiniteIterator: 11 | """Infinitely repeat the iterable.""" 12 | def __init__(self, iterable: Iterable): 13 | self._iterable = iterable 14 | self.iterator = iter(self._iterable) 15 | 16 | def __iter__(self): 17 | return self 18 | 19 | def __next__(self): 20 | for _ in range(2): 21 | try: 22 | return next(self.iterator) 23 | except StopIteration: 24 | # reset iterator 25 | del self.iterator 26 | self.iterator = iter(self._iterable) 27 | --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Imant Daunhawer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | 23 | --- 24 | 25 | Note that individual files are adapted from the following projects, which might 26 | be subject to a different license: 27 | - https://github.com/ysharma1126/ssl_identifiability 28 | - https://github.com/brendel-group/cl-ica 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # MAC files 2 | *.DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | .vscode 10 | .idea 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # Project-specific 138 | data/ 139 | models/ 140 | *.backup 141 | *.log 142 | lsf.o* 143 | tags 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Identifiability Results for Multimodal Contrastive Learning 2 | 3 | Official code for the paper [Identifiability Results for Multimodal 4 | Contrastive Learning](https://arxiv.org/abs/2303.09166) presented at 5 | [ICLR 2023](https://iclr.cc/Conferences/2023). 
This repository contains 6 | PyTorch code to reproduce the numerical simulation and the image/text 7 | experiment. The code to generate the image/text data from scratch is provided 8 | in a separate repository: [Multimodal3DIdent](https://github.com/imantdaunhawer/Multimodal3DIdent). 9 | 10 | ## Installation 11 | 12 | This project was developed with Python 3.10 and PyTorch 1.13.1. The 13 | dependencies can be installed as follows: 14 | 15 | ```bash 16 | # install dependencies (preferably, inside your conda/virtual environment) 17 | $ pip install -r requirements.txt 18 | 19 | # test if PyTorch was installed with CUDA support; should not raise an error 20 | $ python -c "import torch; assert torch.cuda.device_count() > 0, 'No cuda support'" 21 | ``` 22 | 23 | ## Numerical Simulation 24 | 25 | ```bash 26 | # train a model on data with style-change probability 0.75 and with statistical and 27 | # causal dependencies, and save it to the directory "models/mlp_example" 28 | $ python main_mlp.py --model-id "mlp_example" \ 29 | --style-change-prob 0.75 \ 30 | --statistical-dependence \ 31 | --content-dependent-style 32 | 33 | # evaluate the trained model 34 | $ python main_mlp.py --model-id "mlp_example" --load-args --evaluate 35 | ``` 36 | 37 | ## Image/Text Experiment 38 | 39 | ```bash 40 | # download and extract the dataset 41 | $ wget https://zenodo.org/record/7678231/files/m3di.tar.gz 42 | $ tar -xzf m3di.tar.gz 43 | 44 | # train a model with encoding size 4 and save it to the directory "models/imgtxt_example" 45 | $ python main_imgtxt.py --datapath "m3di" --model-id "imgtxt_example" --encoding-size 4 46 | 47 | # evaluate the trained model 48 | $ python main_imgtxt.py --datapath "m3di" --model-id "imgtxt_example" --load-args --evaluate 49 | ``` 50 | 51 | ## BibTeX 52 | If you find this project useful, please cite our paper: 53 | 54 | ```bibtex 55 | @inproceedings{daunhawer2023multimodal, 56 | author = { 57 | Daunhawer, Imant and 58 | Bizeul, Alice and 59 | Palumbo, Emanuele and 60 | Marx, Alexander and 61 | Vogt, Julia E. 62 | }, 63 | title = {Identifiability Results for Multimodal Contrastive Learning}, 64 | booktitle = {International Conference on Learning Representations}, 65 | year = {2023} 66 | } 67 | ``` 68 | 69 | ## Acknowledgements 70 | 71 | This project builds on the following resources. Please cite them appropriately. 72 | - https://github.com/ysharma1126/ssl_identifiability <3 73 | - https://github.com/brendel-group/cl-ica <3 74 | --------------------------------------------------------------------------------
/encoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of encoder architectures. 3 | """ 4 | 5 | from torch import nn 6 | from typing import List, Union 7 | from typing_extensions import Literal 8 | 9 | 10 | def get_mlp(n_in: int, n_out: int, 11 | layers: List[int], 12 | layer_normalization: Union[None, Literal["bn"], Literal["gn"]] = None, 13 | act_inf_param=0.01): 14 | """ 15 | Creates an MLP. 16 | 17 | This code originates from the following projects: 18 | - https://github.com/brendel-group/cl-ica 19 | - https://github.com/ysharma1126/ssl_identifiability 20 | 21 | Args: 22 | n_in: Dimensionality of the input data 23 | n_out: Dimensionality of the output data 24 | layers: Number of neurons for each hidden layer 25 | layer_normalization: Normalization for each hidden layer.
26 | Possible values: bn (batch norm), gn (group norm), None 27 | """ 28 | modules: List[nn.Module] = [] 29 | 30 | def add_module(n_layer_in: int, n_layer_out: int, last_layer: bool = False): 31 | modules.append(nn.Linear(n_layer_in, n_layer_out)) 32 | # perform normalization & activation not in last layer 33 | if not last_layer: 34 | if layer_normalization == "bn": 35 | modules.append(nn.BatchNorm1d(n_layer_out)) 36 | elif layer_normalization == "gn": 37 | modules.append(nn.GroupNorm(1, n_layer_out)) 38 | modules.append(nn.LeakyReLU(negative_slope=act_inf_param)) 39 | 40 | return n_layer_out 41 | 42 | n_out_last_layer = n_in 43 | if len(layers) == 0: 44 | assert n_in == n_out, "Network with no layers must have matching n_in and n_out" 45 | # no hidden layers: start from an identity mapping 46 | modules.append(nn.Identity()) 47 | 48 | layers.append(n_out) 49 | 50 | for i, l in enumerate(layers): 51 | n_out_last_layer = add_module(n_out_last_layer, l, i == len(layers)-1) 52 | 53 | return nn.Sequential(*modules) 54 | 55 | 56 | class TextEncoder2D(nn.Module): 57 | """2D-ConvNet to encode text data.""" 58 | 59 | def __init__(self, input_size, output_size, sequence_length, 60 | embedding_dim=128, fbase=25): 61 | super(TextEncoder2D, self).__init__() 62 | if sequence_length < 24 or sequence_length > 31: 63 | raise ValueError( 64 | "TextEncoder2D expects sequence_length between 24 and 31") 65 | self.fbase = fbase 66 | self.embedding = nn.Linear(input_size, embedding_dim) 67 | self.convnet = nn.Sequential( 68 | # input size: 1 x sequence_length x embedding_dim 69 | nn.Conv2d(1, fbase, 4, 2, 1, bias=True), 70 | nn.BatchNorm2d(fbase), 71 | nn.ReLU(True), 72 | nn.Conv2d(fbase, fbase * 2, 4, 2, 1, bias=True), 73 | nn.BatchNorm2d(fbase * 2), 74 | nn.ReLU(True), 75 | nn.Conv2d(fbase * 2, fbase * 4, 4, 2, 1, bias=True), 76 | nn.BatchNorm2d(fbase * 4), 77 | nn.ReLU(True), 78 | # size: (fbase * 4) x 3 x 16 79 | ) 80 | self.ldim = fbase * 4 * 3 * 16 81 | self.linear = nn.Linear(self.ldim, output_size) 82 | 83 | def forward(self, x): 84 | x = self.embedding(x).unsqueeze(1) 85 | x = self.convnet(x) 86 | x = x.view(-1, self.ldim) 87 | x = self.linear(x) 88 | return x 89 | --------------------------------------------------------------------------------
/utils/invertible_network_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create invertible mixing networks. 3 | 4 | This code originates from the following projects: 5 | - https://github.com/brendel-group/cl-ica 6 | - https://github.com/ysharma1126/ssl_identifiability 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from scipy.stats import ortho_group 13 | from typing import Union 14 | from typing_extensions import Literal 15 | 16 | 17 | def construct_invertible_mlp(n: int = 20, n_layers: int = 2, n_iter_cond_thresh: int = 10000, 18 | cond_thresh_ratio: float = 0.25, 19 | weight_matrix_init: Union[Literal["pcl"], Literal["rvs"]] = 'pcl', 20 | act_fct: Union[Literal["relu"], Literal["leaky_relu"], Literal["elu"], 21 | Literal["smooth_leaky_relu"], Literal["softplus"]] = 'leaky_relu', 22 | verbose=False, 23 | ): 24 | """ 25 | Create an (approximately) invertible mixing network based on an MLP. 26 | Based on the mixing code by Hyvarinen et al. 27 | 28 | Args: 29 | n: Dimensionality of the input and output data 30 | n_layers: Number of layers in the MLP. 31 | n_iter_cond_thresh: How many random matrices to use as a pool to find weights.
32 | cond_thresh_ratio: Relative threshold how much the invertibility 33 | (based on the condition number) can be violated in each layer. 34 | weight_matrix_init: How to initialize the weight matrices. 35 | act_fct: Activation function for hidden layers. 36 | """ 37 | 38 | class SmoothLeakyReLU(nn.Module): 39 | def __init__(self, alpha=0.2): 40 | super().__init__() 41 | self.alpha = alpha 42 | 43 | def forward(self, x): 44 | return self.alpha * x + (1 - self.alpha) * torch.log(1 + torch.exp(x)) 45 | 46 | def get_act_fct(act_fct): 47 | if act_fct == 'relu': 48 | return torch.nn.ReLU, {}, 1 49 | if act_fct == 'leaky_relu': 50 | return torch.nn.LeakyReLU, {'negative_slope': 0.2}, 1 51 | elif act_fct == 'elu': 52 | return torch.nn.ELU, {'alpha': 1.0}, 1 53 | elif act_fct == 'max_out': 54 | raise NotImplementedError() 55 | elif act_fct == 'smooth_leaky_relu': 56 | return SmoothLeakyReLU, {'alpha': 0.2}, 1 57 | elif act_fct == 'softplus': 58 | return torch.nn.Softplus, {'beta': 1}, 1 59 | else: 60 | raise Exception(f'activation function {act_fct} not defined.') 61 | 62 | layers = [] 63 | act_fct, act_kwargs, _ = get_act_fct(act_fct) 64 | 65 | # Subfuction to normalize mixing matrix 66 | def l2_normalize(Amat, axis=0): 67 | # axis: 0=column-normalization, 1=row-normalization 68 | l2norm = np.sqrt(np.sum(Amat * Amat, axis)) 69 | Amat = Amat / l2norm 70 | return Amat 71 | 72 | condList = np.zeros([n_iter_cond_thresh]) 73 | if weight_matrix_init == 'pcl': 74 | for i in range(n_iter_cond_thresh): 75 | A = np.random.uniform(-1, 1, [n, n]) 76 | A = l2_normalize(A, axis=0) 77 | condList[i] = np.linalg.cond(A) 78 | condList.sort() # Ascending order 79 | condThresh = condList[int(n_iter_cond_thresh * cond_thresh_ratio)] 80 | if verbose: 81 | print("condition number threshold: {0:f}".format(condThresh)) 82 | 83 | for i in range(n_layers): 84 | 85 | lin_layer = nn.Linear(n, n, bias=False) 86 | 87 | if weight_matrix_init == 'pcl': 88 | condA = condThresh + 1 89 | while condA > condThresh: 90 | weight_matrix = np.random.uniform(-1, 1, (n, n)) 91 | weight_matrix = l2_normalize(weight_matrix, axis=0) 92 | 93 | condA = np.linalg.cond(weight_matrix) 94 | # print(" L{0:d}: cond={1:f}".format(i, condA)) 95 | if verbose: 96 | print(f"layer {i+1}/{n_layers}, condition number: {np.linalg.cond(weight_matrix)}") 97 | lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32) 98 | 99 | elif weight_matrix_init == 'rvs': 100 | weight_matrix = ortho_group.rvs(n) 101 | lin_layer.weight.data = torch.tensor(weight_matrix, dtype=torch.float32) 102 | elif weight_matrix_init == 'expand': 103 | pass 104 | else: 105 | raise Exception(f'weight matrix {weight_matrix_init} not implemented') 106 | 107 | layers.append(lin_layer) 108 | 109 | if i < n_layers - 1: 110 | layers.append(act_fct(**act_kwargs)) 111 | 112 | mixing_net = nn.Sequential(*layers) 113 | 114 | # fix parameters 115 | for p in mixing_net.parameters(): 116 | p.requires_grad = False 117 | 118 | return mixing_net 119 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of loss functions. 
3 | 4 | This code originates from the following projects: 5 | - https://github.com/brendel-group/cl-ica 6 | - https://github.com/ysharma1126/ssl_identifiability 7 | """ 8 | 9 | 10 | from abc import ABC, abstractmethod 11 | import torch 12 | 13 | 14 | class CLLoss(ABC): 15 | """Abstract class to define losses in the CL framework that use one 16 | positive pair and one negative pair""" 17 | 18 | @abstractmethod 19 | def loss(self, z1, z2_con_z1, z3, z1_rec, z2_con_z1_rec, z3_rec): 20 | """ 21 | z1_t = h(z1) 22 | z2_t = h(z2) 23 | z3_t = h(z3) 24 | and z1 ~ p(z1), z3 ~ p(z3) 25 | and z2 ~ p(z2 | z1) 26 | 27 | returns the total loss and componentwise contributions 28 | """ 29 | pass 30 | 31 | def __call__(self, z1, z2_con_z1, z3, z1_rec, z2_con_z1_rec, z3_rec): 32 | return self.loss(z1, z2_con_z1, z3, z1_rec, z2_con_z1_rec, z3_rec) 33 | 34 | 35 | class LpSimCLRLoss(CLLoss): 36 | """Extended InfoNCE objective for non-normalized representations based on an Lp norm. 37 | 38 | Args: 39 | p: Exponent of the norm to use. 40 | tau: Rescaling parameter of exponent. 41 | alpha: Weighting factor between the two summands. 42 | simclr_compatibility_mode: Use logsumexp (as used in SimCLR loss) instead of logmeanexp 43 | pow: Use p-th power of Lp norm instead of Lp norm. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | p: int = 2, 49 | tau: float = 1.0, 50 | alpha: float = 0.5, 51 | simclr_compatibility_mode: bool = False, 52 | simclr_denominator: bool = True, 53 | pow: bool = True, 54 | ): 55 | self.p = p 56 | self.tau = tau 57 | self.alpha = alpha 58 | self.simclr_compatibility_mode = simclr_compatibility_mode 59 | self.simclr_denominator = simclr_denominator 60 | self.pow = pow 61 | 62 | def loss(self, z1, z2_con_z1, z3, z1_rec, z2_con_z1_rec, z3_rec): 63 | del z1, z2_con_z1, z3 64 | 65 | if self.p < 1.0: 66 | # add small epsilon to make calculation of norm numerically more stable 67 | neg = torch.norm( 68 | torch.abs(z1_rec.unsqueeze(0) - z3_rec.unsqueeze(1) + 1e-12), 69 | p=self.p, 70 | dim=-1, 71 | ) 72 | pos = torch.norm( 73 | torch.abs(z1_rec - z2_con_z1_rec) + 1e-12, p=self.p, dim=-1 74 | ) 75 | else: 76 | neg = torch.pow(z1_rec.unsqueeze(1) - z3_rec.unsqueeze(0), float(self.p)).sum(dim=-1) 77 | pos = torch.pow(z1_rec - z2_con_z1_rec, float(self.p)).sum(dim=-1) 78 | 79 | if not self.pow: 80 | neg = neg.pow(1.0 / self.p) 81 | pos = pos.pow(1.0 / self.p) 82 | 83 | # all = torch.cat((neg, pos.unsqueeze(1)), dim=1) 84 | 85 | if self.simclr_compatibility_mode: 86 | neg_and_pos = torch.cat((neg, pos.unsqueeze(1)), dim=1) 87 | 88 | loss_pos = pos / self.tau 89 | loss_neg = torch.logsumexp(-neg_and_pos / self.tau, dim=1) 90 | else: 91 | if self.simclr_denominator: 92 | neg_and_pos = torch.cat((neg, pos.unsqueeze(1)), dim=1) 93 | else: 94 | neg_and_pos = neg 95 | 96 | loss_pos = pos / self.tau 97 | loss_neg = _logmeanexp(-neg_and_pos / self.tau, dim=1) 98 | 99 | loss = 2 * (self.alpha * loss_pos + (1.0 - self.alpha) * loss_neg) 100 | 101 | loss_mean = torch.mean(loss) 102 | # loss_std = torch.std(loss) 103 | 104 | loss_pos_mean = torch.mean(loss_pos) 105 | loss_neg_mean = torch.mean(loss_neg) 106 | 107 | return loss_mean, loss, [loss_pos_mean, loss_neg_mean] 108 | 109 | 110 | def _logmeanexp(x, dim): 111 | # do the -log thing to use logsumexp to calculate the mean and not the sum 112 | # as log sum_j exp(x_j - log N) = log sim_j exp(x_j)/N = log mean(exp(x_j) 113 | N = torch.tensor(x.shape[dim], dtype=x.dtype, device=x.device) 114 | return torch.logsumexp(x, dim=dim) - torch.log(N) 115 | 116 | 117 | def 
infonce_loss(z1, z2, sim_metric, criterion, tau=1.0): 118 | """ 119 | This code originates from the following project: 120 | - https://github.com/ysharma1126/ssl_identifiability 121 | """ 122 | sim11 = sim_metric(z1.unsqueeze(-2), z1.unsqueeze(-3)) / tau 123 | sim22 = sim_metric(z2.unsqueeze(-2), z2.unsqueeze(-3)) / tau 124 | sim12 = sim_metric(z1.unsqueeze(-2), z2.unsqueeze(-3)) / tau 125 | d = sim12.shape[-1] 126 | sim11[..., range(d), range(d)] = float('-inf') 127 | sim22[..., range(d), range(d)] = float('-inf') 128 | raw_scores1 = torch.cat([sim12, sim11], dim=-1) 129 | raw_scores2 = torch.cat([sim22, sim12.transpose(-1, -2)], dim=-1) 130 | raw_scores = torch.cat([raw_scores1, raw_scores2], dim=-2) 131 | targets = torch.arange(2 * d, dtype=torch.long, device=raw_scores.device) 132 | loss_value = criterion(raw_scores, targets) 133 | return loss_value 134 | -------------------------------------------------------------------------------- /utils/latent_spaces.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of latent spaces for the multimodal setup. 3 | 4 | Parts of this code originate from the files spaces.py and latent_spaces.py 5 | from the following projects: 6 | - https://github.com/brendel-group/cl-ica 7 | - https://github.com/ysharma1126/ssl_identifiability 8 | """ 9 | 10 | from typing import Callable, List 11 | from abc import ABC, abstractmethod 12 | import numpy as np 13 | import torch 14 | 15 | 16 | class Space(ABC): 17 | @abstractmethod 18 | def uniform(self, size, device): 19 | pass 20 | 21 | @abstractmethod 22 | def normal(self, mean, std, size, device): 23 | pass 24 | 25 | @property 26 | @abstractmethod 27 | def dim(self): 28 | pass 29 | 30 | 31 | class NRealSpace(Space): 32 | def __init__(self, n): 33 | self.n = n 34 | 35 | @property 36 | def dim(self): 37 | return self.n 38 | 39 | def uniform(self, size, device="cpu"): 40 | raise NotImplementedError("Not defined on R^n") 41 | 42 | def normal(self, mean, std, size, device="cpu", change_prob=1., Sigma=None): 43 | """Sample from a Normal distribution in R^N. 44 | Args: 45 | mean: Value(s) to sample around. 46 | std: Concentration parameter of the distribution (=standard deviation). 47 | size: Number of samples to draw. 
48 | device: torch device identifier 49 | """ 50 | if mean is None: 51 | mean = torch.zeros(self.n) 52 | if len(mean.shape) == 1 and mean.shape[0] == self.n: 53 | mean = mean.unsqueeze(0) 54 | if not torch.is_tensor(std): 55 | std = torch.ones(self.n) * std 56 | if len(std.shape) == 1 and std.shape[0] == self.n: 57 | std = std.unsqueeze(0) 58 | assert len(mean.shape) == 2 59 | assert len(std.shape) == 2 60 | 61 | if torch.is_tensor(mean): 62 | mean = mean.to(device) 63 | if torch.is_tensor(std): 64 | std = std.to(device) 65 | change_indices = torch.distributions.binomial.Binomial(probs=change_prob).sample((size, self.n)).to(device) 66 | if Sigma is not None: 67 | changes = np.random.multivariate_normal(np.zeros(self.n), Sigma, size) 68 | changes = torch.FloatTensor(changes).to(device) 69 | else: 70 | changes = torch.randn((size, self.n), device=device) * std 71 | return mean + change_indices * changes 72 | 73 | 74 | class LatentSpace: 75 | """Combines a topological space with a marginal and conditional density to sample from.""" 76 | 77 | def __init__( 78 | self, space: Space, sample_marginal: Callable, sample_conditional: Callable 79 | ): 80 | self.space = space 81 | self._sample_marginal = sample_marginal 82 | self._sample_conditional = sample_conditional 83 | 84 | @property 85 | def sample_conditional(self): 86 | if self._sample_conditional is None: 87 | raise RuntimeError("sample_conditional was not set") 88 | return lambda *args, **kwargs: self._sample_conditional( 89 | self.space, *args, **kwargs 90 | ) 91 | 92 | @sample_conditional.setter 93 | def sample_conditional(self, value: Callable): 94 | assert callable(value) 95 | self._sample_conditional = value 96 | 97 | @property 98 | def sample_marginal(self): 99 | if self._sample_marginal is None: 100 | raise RuntimeError("sample_marginal was not set") 101 | return lambda *args, **kwargs: self._sample_marginal( 102 | self.space, *args, **kwargs 103 | ) 104 | 105 | @sample_marginal.setter 106 | def sample_marginal(self, value: Callable): 107 | assert callable(value) 108 | self._sample_marginal = value 109 | 110 | @property 111 | def dim(self): 112 | return self.space.dim 113 | 114 | 115 | class ProductLatentSpace(LatentSpace): 116 | """A latent space which is the cartesian product of other latent spaces.""" 117 | 118 | def __init__(self, spaces: List[LatentSpace], a=None, B=None): 119 | """Assumes that the list of spaces is [c, s] or [c, s, m1, m2].""" 120 | self.spaces = spaces 121 | self.a = a 122 | self.B = B 123 | 124 | # determine dimensions, assuming the ordering [c, s, m1, m2] 125 | assert len(spaces) in (2, 4) # either [c, s] or [c, s, m1, m2] 126 | self.content_n = spaces[0].dim 127 | self.style_n = spaces[1].dim 128 | self.modality_n = 0 129 | if len(spaces) > 2: 130 | assert spaces[2].dim == spaces[3].dim # can be relaxed 131 | self.modality_n = spaces[2].dim 132 | 133 | def sample_conditional(self, z, size, **kwargs): 134 | z_new = [] 135 | n = 0 136 | for s in self.spaces: 137 | if len(z.shape) == 1: 138 | z_s = z[n : n + s.space.n] 139 | else: 140 | z_s = z[:, n : n + s.space.n] 141 | n += s.space.n 142 | z_new.append(s.sample_conditional(z=z_s, size=size, **kwargs)) 143 | 144 | return torch.cat(z_new, -1) 145 | 146 | def sample_marginal(self, size, **kwargs): 147 | z = [s.sample_marginal(size=size, **kwargs) for s in self.spaces] 148 | if self.a is not None and self.B is not None: 149 | content_dependent_style = torch.einsum("ij,kj -> ki", self.B, z[0]) + self.a 150 | z[1] += content_dependent_style # index 1 is style 151 
| return torch.cat(z, -1) 152 | 153 | def sample_z1_and_z2(self, size, device): 154 | z = self.sample_marginal(size=size, device=device) # z = (c, s, m1, m2) 155 | z_tilde = self.sample_conditional(z, size=size, device=device) # s -> s_tilde 156 | z1 = self.z_to_zi(z, modality=1) # z1 = (c, s, m1) 157 | z2 = self.z_to_zi(z_tilde, modality=2) # z2 = (c, s_tilde, m2) 158 | return z1, z2 159 | 160 | def z_to_csm(self, z): 161 | nc, ns, nm = self.content_n, self.style_n, self.modality_n 162 | ix_c = torch.tensor(range(0, nc), dtype=int) 163 | ix_s = torch.tensor(range(nc, nc + ns), dtype=int) 164 | ix_m1 = torch.tensor(range(nc + ns, nc + ns + nm), dtype=int) 165 | ix_m2 = torch.tensor(range(nc + ns + nm, nc + ns + nm*2), dtype=int) 166 | c = z[:, ix_c] 167 | s = z[:, ix_s] 168 | m1 = z[:, ix_m1] 169 | m2 = z[:, ix_m2] 170 | return c, s, m1, m2 171 | 172 | def zi_to_csmi(self, zi): 173 | nc, ns, nm = self.content_n, self.style_n, self.modality_n 174 | ix_c = torch.tensor(range(0, nc), dtype=int) 175 | ix_s = torch.tensor(range(nc, nc + ns), dtype=int) 176 | ix_mi = torch.tensor(range(nc + ns, nc + ns + nm), dtype=int) 177 | c = zi[:, ix_c] 178 | s = zi[:, ix_s] 179 | mi = zi[:, ix_mi] 180 | return c, s, mi 181 | 182 | def z_to_zi(self, z, modality): 183 | assert modality in [1, 2] 184 | c, s, m1, m2 = self.z_to_csm(z) 185 | if modality == 1: 186 | zi = torch.cat((c, s, m1), dim=-1) 187 | elif modality == 2: 188 | zi = torch.cat((c, s, m2), dim=-1) 189 | return zi 190 | 191 | @property 192 | def dim(self): 193 | return sum([s.dim for s in self.spaces]) 194 | -------------------------------------------------------------------------------- /tokenizers/punkt/PY3/README: -------------------------------------------------------------------------------- 1 | Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected) 2 | 3 | Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have 4 | been contributed by various people using NLTK for sentence boundary detection. 5 | 6 | For information about how to use these models, please confer the tokenization HOWTO: 7 | http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html 8 | and chapter 3.8 of the NLTK book: 9 | http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation 10 | 11 | There are pretrained tokenizers for the following languages: 12 | 13 | File Language Source Contents Size of training corpus(in tokens) Model contributed by 14 | ======================================================================================================================================================================= 15 | czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss 16 | Literarni Noviny 17 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 18 | danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 
1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss 19 | (Berlingske Avisdata, Copenhagen) Weekend Avisen 20 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 21 | dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss 22 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 23 | english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss 24 | (American) 25 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 26 | estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss 27 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 28 | finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss 29 | Text Bank (Suomen Kielen newspapers 30 | Tekstipankki) 31 | Finnish Center for IT Science 32 | (CSC) 33 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 34 | french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss 35 | (European) 36 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 37 | german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss 38 | (Switzerland) CD-ROM 39 | (Uses "ss" 40 | instead of "ß") 41 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 42 | greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss 43 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 44 | italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss 45 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 46 | norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss 47 | (Bokmål and Information Technologies, 48 | Nynorsk) Bergen 49 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 50 | polish.pickle Polish Polish National Corpus Literature, newspapers, etc. 
~1,000,000 Krzysztof Langner 51 | (http://www.nkjp.pl/) 52 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 53 | portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss 54 | (Brazilian) (Linguateca) 55 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 56 | slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss 57 | Slovene Academy for Arts 58 | and Sciences 59 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 60 | spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss 61 | (European) 62 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 63 | swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss 64 | (and some other texts) 65 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 66 | turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss 67 | (Türkçe Derlem Projesi) 68 | University of Ankara 69 | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- 70 | 71 | The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to 72 | Unicode using the codecs module. 73 | 74 | Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection. 75 | Computational Linguistics 32: 485-525. 76 | 77 | ---- Training Code ---- 78 | 79 | # import punkt 80 | import nltk.tokenize.punkt 81 | 82 | # Make a new Tokenizer 83 | tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer() 84 | 85 | # Read in training corpus (one example: Slovene) 86 | import codecs 87 | text = codecs.open("slovene.plain","Ur","iso-8859-2").read() 88 | 89 | # Train tokenizer 90 | tokenizer.train(text) 91 | 92 | # Dump pickled tokenizer 93 | import pickle 94 | out = open("slovene.pickle","wb") 95 | pickle.dump(tokenizer, out) 96 | out.close() 97 | 98 | --------- 99 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of datasets. 3 | """ 4 | 5 | import io 6 | import json 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from collections import Counter, OrderedDict 12 | from nltk.tokenize import sent_tokenize, word_tokenize 13 | from torchvision.datasets.folder import pil_loader 14 | 15 | 16 | class OrderedCounter(Counter, OrderedDict): 17 | """Counter that remembers the order of elements encountered.""" 18 | 19 | def __repr__(self): 20 | return "%s(%r)" % (self.__class__.__name__, OrderedDict(self)) 21 | 22 | def __reduce__(self): 23 | return self.__class__, (OrderedDict(self),) 24 | 25 | 26 | class Multimodal3DIdent(torch.utils.data.Dataset): 27 | """Multimodal3DIdent Dataset. 
28 | 29 | Attributes: 30 | FACTORS (dict): names of factors for image and text modalities. 31 | DISCRETE_FACTORS (dict): names of discrete factors, respectively. 32 | """ 33 | 34 | FACTORS = { 35 | "image": { 36 | 0: "object_shape", 37 | 1: "object_ypos", 38 | 2: "object_xpos", 39 | # 3: "object_zpos", # is constant 40 | 4: "object_alpharot", 41 | 5: "object_betarot", 42 | 6: "object_gammarot", 43 | 7: "spotlight_pos", 44 | 8: "object_color", 45 | 9: "spotlight_color", 46 | 10: "background_color" 47 | }, 48 | "text": { 49 | 0: "object_shape", 50 | 1: "object_ypos", 51 | 2: "object_xpos", 52 | # 3: "object_zpos", # is constant 53 | 4: "object_color_index", 54 | 5: "text_phrasing" 55 | } 56 | } 57 | 58 | DISCRETE_FACTORS = { 59 | "image": { 60 | 0: "object_shape", 61 | 1: "object_ypos", 62 | 2: "object_xpos", 63 | # 3: "object_zpos", # is constant 64 | }, 65 | "text": { 66 | 0: "object_shape", 67 | 1: "object_ypos", 68 | 2: "object_xpos", 69 | # 3: "object_zpos", # is constant 70 | 4: "object_color_index", 71 | 5: "text_phrasing" 72 | } 73 | } 74 | 75 | def __init__(self, data_dir, mode="train", transform=None, 76 | has_labels=True, vocab_filepath=None): 77 | """ 78 | Args: 79 | data_dir (string): path to directory. 80 | mode (string): name of data split, 'train', 'val', or 'test'. 81 | transform (callable): Optional transform to be applied. 82 | has_labels (bool): Indicates if the data has ground-truth labels. 83 | vocab_filepath (str): Optional path to a saved vocabulary. If None, 84 | the vocabulary will be (re-)created. 85 | """ 86 | self.mode = mode 87 | self.transform = transform 88 | self.has_labels = has_labels 89 | self.data_dir = data_dir 90 | self.data_dir_mode = os.path.join(data_dir, mode) 91 | self.latents_text_filepath = \ 92 | os.path.join(self.data_dir_mode, "latents_text.csv") 93 | self.latents_image_filepath = \ 94 | os.path.join(self.data_dir_mode, "latents_image.csv") 95 | self.text_filepath = \ 96 | os.path.join(self.data_dir_mode, "text", "text_raw.txt") 97 | self.image_dir = os.path.join(self.data_dir_mode, "images") 98 | 99 | # load text 100 | text_in_sentences, text_in_words = self._load_text() 101 | self.text_in_sentences = text_in_sentences # sentence-tokenized text 102 | self.text_in_words = text_in_words # word-tokenized text 103 | 104 | # determine num_samples and max_sequence_length 105 | self.num_samples = len(self.text_in_sentences) 106 | self.max_sequence_length = \ 107 | max([len(sent) for sent in self.text_in_words]) + 1 # +1 for "eos" 108 | 109 | # load or create the vocabulary (i.e., word <-> index maps) 110 | self.w2i, self.i2w = self._load_vocab(vocab_filepath) 111 | self.vocab_size = len(self.w2i) 112 | if vocab_filepath: 113 | self.vocab_filepath = vocab_filepath 114 | else: 115 | self.vocab_filepath = os.path.join(self.data_dir, "vocab.json") 116 | 117 | # optionally, load ground-truth labels 118 | if has_labels: 119 | self.labels = self._load_labels() 120 | 121 | # create list of image filepaths 122 | image_paths = [] 123 | width = int(np.ceil(np.log10(self.num_samples))) 124 | for i in range(self.num_samples): 125 | fp = os.path.join(self.image_dir, str(i).zfill(width) + ".png") 126 | image_paths.append(fp) 127 | self.image_paths = image_paths 128 | 129 | def get_w2i(self, word): 130 | try: 131 | return self.w2i[word] 132 | except KeyError: 133 | return "{unk}" # special token for unknown words 134 | 135 | def _load_text(self): 136 | print(f"Tokenization of {self.mode} data...") 137 | 138 | # load raw text 139 | with open(self.text_filepath, "r") 
as f: 140 | text_raw = f.read() 141 | 142 | # create sentence-tokenized text 143 | text_in_sentences = sent_tokenize(text_raw) 144 | 145 | # create word-tokenized text 146 | text_in_words = [word_tokenize(sent) for sent in text_in_sentences] 147 | 148 | return text_in_sentences, text_in_words 149 | 150 | def _load_labels(self): 151 | 152 | # load image labels 153 | z_image = pd.read_csv(self.latents_image_filepath) 154 | 155 | # load text labels 156 | z_text = pd.read_csv(self.latents_text_filepath) 157 | 158 | # check if all factors are present 159 | for v in self.FACTORS["image"].values(): 160 | assert v in z_image.keys() 161 | for v in self.FACTORS["text"].values(): 162 | assert v in z_text.keys() 163 | 164 | # create label dict 165 | labels = {"z_image": z_image, "z_text": z_text} 166 | 167 | return labels 168 | 169 | def _create_vocab(self, vocab_filepath): 170 | print(f"Creating vocabulary as '{vocab_filepath}'...") 171 | 172 | if self.mode != "train": 173 | raise ValueError("Vocabulary should be created from training data") 174 | 175 | # initialize counter and word <-> index maps 176 | ordered_counter = OrderedCounter() # counts occurrence of each word 177 | w2i = dict() # word-to-index map 178 | i2w = dict() # index-to-word map 179 | unique_words = [] 180 | 181 | # add special tokens for padding, end-of-string, and unknown words 182 | special_tokens = ["{pad}", "{eos}", "{unk}"] 183 | for st in special_tokens: 184 | i2w[len(w2i)] = st 185 | w2i[st] = len(w2i) 186 | 187 | for i, words in enumerate(self.text_in_words): 188 | ordered_counter.update(words) 189 | 190 | for w, _ in ordered_counter.items(): 191 | if w not in special_tokens: 192 | i2w[len(w2i)] = w 193 | w2i[w] = len(w2i) 194 | else: 195 | unique_words.append(w) 196 | if len(w2i) != len(i2w): 197 | print(unique_words) 198 | raise ValueError("Mismatch between w2i and i2w mapping") 199 | 200 | # save vocabulary to disk 201 | vocab = dict(w2i=w2i, i2w=i2w) 202 | with io.open(vocab_filepath, "wb") as vocab_file: 203 | jd = json.dumps(vocab, ensure_ascii=False) 204 | vocab_file.write(jd.encode("utf8", "replace")) 205 | 206 | return vocab 207 | 208 | def _load_vocab(self, vocab_filepath=None): 209 | if vocab_filepath is not None: 210 | with open(vocab_filepath, "r") as vocab_file: 211 | vocab = json.load(vocab_file) 212 | else: 213 | new_filepath = os.path.join(self.data_dir, "vocab.json") 214 | vocab = self._create_vocab(vocab_filepath=new_filepath) 215 | return (vocab["w2i"], vocab["i2w"]) 216 | 217 | def __getitem__(self, idx): 218 | if torch.is_tensor(idx): 219 | idx = idx.tolist() 220 | 221 | # load image 222 | img_name = self.image_paths[idx] 223 | image = pil_loader(img_name) 224 | if self.transform is not None: 225 | image = self.transform(image) 226 | 227 | # load text 228 | words = self.text_in_words[idx] 229 | words = words + ["{eos}"] 230 | words = words + ["{pad}" for c in range(self.max_sequence_length-len(words))] 231 | indices = [self.get_w2i(word) for word in words] 232 | indices_onehot = torch.nn.functional.one_hot( 233 | torch.Tensor(indices).long(), self.vocab_size).float() 234 | 235 | # load labels 236 | if self.has_labels: 237 | z_image = {k: v[idx] for k, v in self.labels["z_image"].items()} 238 | z_text = {k: v[idx] for k, v in self.labels["z_text"].items()} 239 | else: 240 | z_image, z_text = None, None 241 | 242 | sample = { 243 | "image": image, 244 | "text": indices_onehot, 245 | "z_image": z_image, 246 | "z_text": z_text} 247 | return sample 248 | 249 | def __len__(self): 250 | return 
self.num_samples 251 | -------------------------------------------------------------------------------- /main_imgtxt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Experiment with image/text pairs. 3 | """ 4 | 5 | import argparse 6 | import json 7 | import os 8 | import random 9 | import uuid 10 | import warnings 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | from sklearn.kernel_ridge import KernelRidge 16 | from sklearn.linear_model import LinearRegression, LogisticRegression 17 | from sklearn.metrics import accuracy_score, r2_score 18 | from sklearn.model_selection import GridSearchCV 19 | from sklearn.neural_network import MLPClassifier, MLPRegressor 20 | from sklearn.preprocessing import StandardScaler 21 | from torch.nn.utils import clip_grad_norm_ 22 | from torch.utils.data import DataLoader 23 | from torchvision import transforms 24 | from torchvision.models import resnet18 25 | 26 | from datasets import Multimodal3DIdent 27 | from encoders import TextEncoder2D 28 | from utils.infinite_iterator import InfiniteIterator 29 | from utils.losses import infonce_loss 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--datapath", type=str, required=True) 35 | parser.add_argument("--model-dir", type=str, default="models") 36 | parser.add_argument("--model-id", type=str, default=None) 37 | parser.add_argument("--encoding-size", type=int, default=4) 38 | parser.add_argument("--hidden-size", type=int, default=100) 39 | parser.add_argument("--tau", type=float, default=1.0) 40 | parser.add_argument("--lr", type=float, default=1e-5) 41 | parser.add_argument("--batch-size", type=int, default=256) 42 | parser.add_argument("--train-steps", type=int, default=100001) 43 | parser.add_argument("--log-steps", type=int, default=1000) 44 | parser.add_argument("--checkpoint-steps", type=int, default=10000) 45 | parser.add_argument("--evaluate", action='store_true') 46 | parser.add_argument("--seed", type=int, default=np.random.randint(32**2-1)) 47 | parser.add_argument("--workers", type=int, default=4) 48 | parser.add_argument("--no-cuda", action="store_true") 49 | parser.add_argument("--save-all-checkpoints", action="store_true") 50 | parser.add_argument("--load-args", action="store_true") 51 | args = parser.parse_args() 52 | return args, parser 53 | 54 | 55 | def train_step(data, f1, f2, loss_func, optimizer, params): 56 | 57 | # reset grad 58 | if optimizer is not None: 59 | optimizer.zero_grad() 60 | 61 | # compute loss 62 | x1 = data['image'] 63 | x2 = data['text'] 64 | hz1 = f1(x1) 65 | hz2 = f2(x2) 66 | loss_value1 = loss_func(hz1, hz2) 67 | loss_value2 = loss_func(hz2, hz1) 68 | loss_value = 0.5 * (loss_value1 + loss_value2) # symmetrized infonce loss 69 | 70 | # backprop 71 | if optimizer is not None: 72 | loss_value.backward() 73 | clip_grad_norm_(params, max_norm=2.0, norm_type=2) # stabilizes training 74 | optimizer.step() 75 | 76 | return loss_value.item() 77 | 78 | 79 | def val_step(data, f1, f2, loss_func): 80 | return train_step(data, f1, f2, loss_func, optimizer=None, params=None) 81 | 82 | 83 | def get_data(dataset, f1, f2, loss_func, dataloader_kwargs): 84 | loader = DataLoader(dataset, **dataloader_kwargs) 85 | iterator = InfiniteIterator(loader) 86 | labels_image = {v: [] for v in Multimodal3DIdent.FACTORS["image"].values()} 87 | labels_text = {v: [] for v in Multimodal3DIdent.FACTORS["text"].values()} 88 | rdict = {"hz_image": [], "hz_text": [], "loss_values": 
[], 89 | "labels_image": labels_image, "labels_text": labels_text} 90 | i = 0 91 | with torch.no_grad(): 92 | while (i < len(dataset)): # NOTE: can yield slightly too many samples 93 | 94 | # load batch 95 | i += loader.batch_size 96 | data = next(iterator) # contains images, texts, and labels 97 | 98 | # compute loss 99 | loss_value = val_step(data, f1, f2, loss_func) 100 | rdict["loss_values"].append([loss_value]) 101 | 102 | # collect representations 103 | hz_image = f1(data["image"]) 104 | hz_text = f2(data["text"]) 105 | rdict["hz_image"].append(hz_image.detach().cpu().numpy()) 106 | rdict["hz_text"].append(hz_text.detach().cpu().numpy()) 107 | 108 | # collect image labels 109 | for k in rdict["labels_image"]: 110 | labels_k = data["z_image"][k] 111 | rdict["labels_image"][k].append(labels_k) 112 | 113 | # collect text labels 114 | for k in rdict["labels_text"]: 115 | labels_k = data["z_text"][k] 116 | rdict["labels_text"][k].append(labels_k) 117 | 118 | # concatenate each list of values along the batch dimension 119 | for k, v in rdict.items(): 120 | if type(v) == list: 121 | rdict[k] = np.concatenate(v, axis=0) 122 | elif type(v) == dict: 123 | for k2, v2 in v.items(): 124 | rdict[k][k2] = np.concatenate(v2, axis=0) 125 | return rdict 126 | 127 | 128 | def evaluate_prediction(model, metric, X_train, y_train, X_test, y_test): 129 | model.fit(X_train, y_train) 130 | y_pred = model.predict(X_test) 131 | return metric(y_test, y_pred) 132 | 133 | 134 | def main(): 135 | 136 | # parse args 137 | args, parser = parse_args() 138 | 139 | # create save_dir, where the model/results are or will be saved 140 | if args.model_id is None: 141 | setattr(args, "model_id", uuid.uuid4()) 142 | args.save_dir = os.path.join(args.model_dir, args.model_id) 143 | if not os.path.exists(args.save_dir): 144 | os.makedirs(args.save_dir) 145 | 146 | # optionally, reuse existing arguments from args.json (only for evaluation) 147 | if args.evaluate and args.load_args: 148 | with open(os.path.join(args.save_dir, 'args.json'), 'r') as fp: 149 | loaded_args = json.load(fp) 150 | arguments_to_load = ["encoding_size", "hidden_size"] 151 | for arg in arguments_to_load: 152 | setattr(args, arg, loaded_args[arg]) 153 | # NOTE: Any new arguments that shall be automatically loaded for the 154 | # evaluation of a trained model must be added to 'arguments_to_load'. 
155 | 156 | # print args 157 | print("Arguments:") 158 | for k, v in vars(args).items(): 159 | print(f"\t{k}: {v}") 160 | 161 | # save args to disk (only for training) 162 | if not args.evaluate: 163 | with open(os.path.join(args.save_dir, 'args.json'), 'w') as fp: 164 | json.dump(args.__dict__, fp) 165 | 166 | # set all seeds 167 | if args.seed is not None: 168 | np.random.seed(args.seed) 169 | random.seed(args.seed) 170 | torch.manual_seed(args.seed) 171 | 172 | # set device 173 | if torch.cuda.is_available() and not args.no_cuda: 174 | device = "cuda" 175 | else: 176 | device = "cpu" 177 | warnings.warn("cuda is not available or --no-cuda was set.") 178 | 179 | # define similarity metric and loss function 180 | sim_metric = torch.nn.CosineSimilarity(dim=-1) 181 | criterion = torch.nn.CrossEntropyLoss() 182 | loss_func = lambda z1, z2: infonce_loss( 183 | z1, z2, sim_metric=sim_metric, criterion=criterion, tau=args.tau) 184 | 185 | # define augmentations (only normalization of the input images) 186 | mean_per_channel = [0.4327, 0.2689, 0.2839] # values from 3DIdent 187 | std_per_channel = [0.1201, 0.1457, 0.1082] # values from 3DIdent 188 | transform = transforms.Compose([ 189 | transforms.ToTensor(), 190 | transforms.Normalize(mean_per_channel, std_per_channel)]) 191 | 192 | # define kwargs 193 | dataset_kwargs = {"transform": transform} 194 | dataloader_kwargs = { 195 | "batch_size": args.batch_size, "shuffle": True, "drop_last": True, 196 | "num_workers": args.workers, "pin_memory": True} 197 | 198 | # define dataloaders 199 | train_dataset = Multimodal3DIdent(args.datapath, mode="train", **dataset_kwargs) 200 | vocab_filepath = train_dataset.vocab_filepath 201 | if args.evaluate: 202 | val_dataset = Multimodal3DIdent(args.datapath, mode="val", 203 | vocab_filepath=vocab_filepath, 204 | **dataset_kwargs) 205 | test_dataset = Multimodal3DIdent(args.datapath, mode="test", 206 | vocab_filepath=vocab_filepath, 207 | **dataset_kwargs) 208 | else: 209 | train_loader = DataLoader(train_dataset, **dataloader_kwargs) 210 | train_iterator = InfiniteIterator(train_loader) 211 | 212 | # define image encoder 213 | encoder_img = torch.nn.Sequential( 214 | resnet18(num_classes=args.hidden_size), 215 | torch.nn.LeakyReLU(), 216 | torch.nn.Linear(args.hidden_size, args.encoding_size)) 217 | encoder_img = torch.nn.DataParallel(encoder_img) 218 | encoder_img.to(device) 219 | 220 | # define text encoder 221 | sequence_length = train_dataset.max_sequence_length 222 | encoder_txt = TextEncoder2D( 223 | input_size=train_dataset.vocab_size, 224 | output_size=args.encoding_size, 225 | sequence_length=sequence_length) 226 | encoder_txt = torch.nn.DataParallel(encoder_txt) 227 | encoder_txt.to(device) 228 | 229 | # for evaluation, always load saved encoders 230 | if args.evaluate: 231 | path_img = os.path.join(args.save_dir, "encoder_img.pt") 232 | path_txt = os.path.join(args.save_dir, "encoder_txt.pt") 233 | encoder_img.load_state_dict(torch.load(path_img, map_location=device)) 234 | encoder_txt.load_state_dict(torch.load(path_txt, map_location=device)) 235 | 236 | # define the optimizer 237 | params = list(encoder_img.parameters())+list(encoder_txt.parameters()) 238 | optimizer = torch.optim.Adam(params, lr=args.lr) 239 | 240 | # training 241 | # -------- 242 | if not args.evaluate: 243 | 244 | # training loop 245 | step = 1 246 | loss_values = [] # list to keep track of loss values 247 | while (step <= args.train_steps): 248 | 249 | # training step 250 | data = next(train_iterator) # contains images, texts, 
and labels 251 | loss_value = train_step(data, encoder_img, encoder_txt, loss_func, optimizer, params) 252 | loss_values.append(loss_value) 253 | 254 | # print average loss value 255 | if step % args.log_steps == 1 or step == args.train_steps: 256 | print(f"Step: {step} \t", 257 | f"Loss: {loss_value:.4f} \t", 258 | f"Avg. Loss: {np.mean(loss_values[-args.log_steps:]):.4f} \t") 259 | 260 | # save models and intermediate checkpoints 261 | if step % args.checkpoint_steps == 1 or step == args.train_steps: 262 | torch.save(encoder_img.state_dict(), os.path.join(args.save_dir, "encoder_img.pt")) 263 | torch.save(encoder_txt.state_dict(), os.path.join(args.save_dir, "encoder_txt.pt")) 264 | if args.save_all_checkpoints: 265 | torch.save(encoder_img.state_dict(), os.path.join(args.save_dir, "encoder_img_%d.pt" % step)) 266 | torch.save(encoder_txt.state_dict(), os.path.join(args.save_dir, "encoder_txt_%d.pt" % step)) 267 | step += 1 268 | 269 | # evaluation 270 | # ---------- 271 | if args.evaluate: 272 | 273 | # collect encodings and labels from the validation and test data 274 | val_dict = get_data(val_dataset, encoder_img, encoder_txt, loss_func, dataloader_kwargs) 275 | test_dict = get_data(test_dataset, encoder_img, encoder_txt, loss_func, dataloader_kwargs) 276 | 277 | # print average loss values 278 | print(f"Val. Loss: {np.mean(val_dict['loss_values']):.4f}") 279 | print(f"Test Loss: {np.mean(test_dict['loss_values']):.4f}") 280 | 281 | # handle edge case when the encodings are 1-dimensional 282 | if args.encoding_size == 1: 283 | for m in ["image", "text"]: 284 | val_dict[f"hz_{m}"] = val_dict[f"hz_{m}"].reshape(-1, 1) 285 | test_dict[f"hz_{m}"] = test_dict[f"hz_{m}"].reshape(-1, 1) 286 | 287 | # standardize the encodings 288 | for m in ["image", "text"]: 289 | scaler = StandardScaler() 290 | val_dict[f"hz_{m}"] = scaler.fit_transform(val_dict[f"hz_{m}"]) 291 | test_dict[f"hz_{m}"] = scaler.transform(test_dict[f"hz_{m}"]) 292 | 293 | # evaluate how well each factor can be predicted from the encodings 294 | results = [] 295 | for m in ["image", "text"]: 296 | factors_m = Multimodal3DIdent.FACTORS[m] 297 | discrete_factors_m = Multimodal3DIdent.DISCRETE_FACTORS[m] 298 | for ix, factor_name in factors_m.items(): 299 | 300 | # select data 301 | train_inputs = val_dict[f"hz_{m}"] 302 | test_inputs = test_dict[f"hz_{m}"] 303 | train_labels = val_dict[f"labels_{m}"][factor_name] 304 | test_labels = test_dict[f"labels_{m}"][factor_name] 305 | data = [train_inputs, train_labels, test_inputs, test_labels] 306 | r2_linreg, r2_krreg, acc_logreg, acc_mlp = [np.nan] * 4 307 | 308 | # check if factor ix is discrete for modality m 309 | if ix in discrete_factors_m: 310 | factor_type = "discrete" 311 | else: 312 | factor_type = "continuous" 313 | 314 | # for continuous factors, do regression and compute R2 score 315 | if factor_type == "continuous": 316 | # linear regression 317 | linreg = LinearRegression(n_jobs=-1) 318 | r2_linreg = evaluate_prediction(linreg, r2_score, *data) 319 | # nonlinear regression 320 | gskrreg = GridSearchCV( 321 | KernelRidge(kernel='rbf', gamma=0.1), 322 | param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3], 323 | "gamma": np.logspace(-2, 2, 4)}, 324 | cv=3, n_jobs=-1) 325 | r2_krreg = evaluate_prediction(gskrreg, r2_score, *data) 326 | # NOTE: MLP is a lightweight alternative 327 | # r2_krreg = evaluate_prediction( 328 | # MLPRegressor(max_iter=1000), r2_score, *data) 329 | 330 | # for discrete factors, do classification and compute accuracy 331 | if factor_type == "discrete": 332 | # logistic classification
333 | logreg = LogisticRegression(n_jobs=-1, max_iter=1000) 334 | acc_logreg = evaluate_prediction(logreg, accuracy_score, *data) 335 | # nonlinear classification 336 | mlpreg = MLPClassifier(max_iter=1000) 337 | acc_mlp = evaluate_prediction(mlpreg, accuracy_score, *data) 338 | 339 | # append results 340 | results.append([ix, m, factor_name, factor_type, 341 | r2_linreg, r2_krreg, acc_logreg, acc_mlp]) 342 | 343 | # convert evaluation results into tabular form 344 | columns = ["ix", "modality", "factor_name", "factor_type", 345 | "r2_linreg", "r2_krreg", "acc_logreg", "acc_mlp"] 346 | df_results = pd.DataFrame(results, columns=columns) 347 | df_results.to_csv(os.path.join(args.save_dir, "results.csv")) 348 | print(df_results.to_string()) 349 | 350 | 351 | if __name__ == "__main__": 352 | main() 353 | -------------------------------------------------------------------------------- /main_mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Numerical simulation for the multimodal setup. 3 | 4 | This code builds on the following projects: 5 | - https://github.com/brendel-group/cl-ica 6 | - https://github.com/ysharma1126/ssl_identifiability 7 | """ 8 | 9 | import argparse 10 | import json 11 | import os 12 | import random 13 | import uuid 14 | import warnings 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | from scipy.stats import wishart 20 | from sklearn import kernel_ridge, linear_model 21 | from sklearn.metrics import r2_score 22 | from sklearn.model_selection import GridSearchCV, train_test_split 23 | from sklearn.neural_network import MLPRegressor 24 | from sklearn.preprocessing import StandardScaler 25 | from torch.nn.utils import clip_grad_norm_ 26 | 27 | import encoders 28 | from utils.invertible_network_utils import construct_invertible_mlp 29 | from utils.latent_spaces import LatentSpace, NRealSpace, ProductLatentSpace 30 | from utils.losses import LpSimCLRLoss 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--model-dir", type=str, default="models") 36 | parser.add_argument("--model-id", type=str, default=None) 37 | parser.add_argument("--encoding-size", type=int, default=5) 38 | parser.add_argument("--content-n", type=int, default=5) 39 | parser.add_argument("--style-n", type=int, default=5) 40 | parser.add_argument("--modality-n", type=int, default=5) 41 | parser.add_argument("--style-change-prob", type=float, default=1.0) 42 | parser.add_argument("--statistical-dependence", action='store_true') 43 | parser.add_argument("--content-dependent-style", action='store_true') 44 | parser.add_argument("--c-param", type=float, default=1.0) 45 | parser.add_argument("--m-param", type=float, default=1.0) 46 | parser.add_argument("--n-mixing-layer", type=int, default=3) 47 | parser.add_argument("--shared-mixing", action='store_true') 48 | parser.add_argument("--shared-encoder", action='store_true') 49 | parser.add_argument("--seed", type=int, default=np.random.randint(32**2-1)) 50 | parser.add_argument("--batch-size", type=int, default=6144) 51 | parser.add_argument("--lr", type=float, default=1e-4) 52 | parser.add_argument("--train-steps", type=int, default=300001) 53 | parser.add_argument("--log-steps", type=int, default=1000) 54 | parser.add_argument("--evaluate", action='store_true') 55 | parser.add_argument("--num-eval-batches", type=int, default=5) 56 | parser.add_argument("--permuted-content", action="store_true") 57 | parser.add_argument("--mlp-eval", 
action="store_true") 58 | parser.add_argument("--no-cuda", action="store_true") 59 | parser.add_argument("--load-args", action="store_true") 60 | args = parser.parse_args() 61 | 62 | return args, parser 63 | 64 | 65 | def train_step(data, h1, h2, loss_func, optimizer, params): 66 | 67 | # reset grad 68 | if optimizer is not None: 69 | optimizer.zero_grad() 70 | 71 | # compute symmetrized loss 72 | z1, z2, z1_, z2_ = data 73 | hz1 = h1(z1) 74 | hz1_ = h1(z1_) 75 | hz2 = h2(z2) 76 | hz2_ = h2(z2_) 77 | total_loss_value1, _, _ = loss_func(z1, z2, z1_, hz1, hz2, hz1_) 78 | total_loss_value2, _, _ = loss_func(z2, z1, z2_, hz2, hz1, hz2_) 79 | total_loss_value = 0.5 * (total_loss_value1 + total_loss_value2) 80 | 81 | # backprop 82 | if optimizer is not None: 83 | total_loss_value.backward() 84 | clip_grad_norm_(params, max_norm=2.0, norm_type=2) # stabilizes training 85 | optimizer.step() 86 | 87 | return total_loss_value.item() 88 | 89 | 90 | def val_step(data, h1, h2, loss_func): 91 | return train_step(data, h1, h2, loss_func, optimizer=None, params=None) 92 | 93 | 94 | def generate_data(latent_space, h1, h2, device, num_batches=1, batch_size=4096, 95 | loss_func=None, permuted_content=False): 96 | 97 | rdict = {k: [] for k in 98 | ["c", "s", "s~", "m1", "m2", "c'", "hz1", "hz2", "loss_values"]} 99 | with torch.no_grad(): 100 | for _ in range(num_batches): 101 | 102 | # sample batch of latents 103 | z1, z2 = latent_space.sample_z1_and_z2(batch_size, device) 104 | 105 | # compute representations 106 | hz1 = h1(z1) 107 | hz2 = h2(z2) 108 | if permuted_content: 109 | nc = latent_space.content_n 110 | z1_intervened = z1.clone() 111 | z2_intervened = z2.clone() 112 | perm = torch.randperm(len(z1)) 113 | z1_intervened[:, :nc] = z1_intervened[perm, :nc] 114 | z2_intervened[:, :nc] = z2_intervened[perm, :nc] 115 | hz1 = h1(z1_intervened) 116 | hz2 = h2(z2_intervened) 117 | c_perm = z1_intervened[:, 0:nc] 118 | 119 | # compute loss 120 | if loss_func is not None: 121 | z1_, z2_ = latent_space.sample_z1_and_z2(batch_size, device) 122 | data = [z1, z2, z1_, z2_] 123 | loss_value = val_step(data, h1, h2, loss_func) 124 | rdict["loss_values"].append([loss_value]) 125 | 126 | # partition latents into content, style, modality-specific factors 127 | c, s1, m1 = latent_space.zi_to_csmi(z1) 128 | _, s2, m2 = latent_space.zi_to_csmi(z2) # NOTE: same content c 129 | 130 | # collect labels and representations 131 | rdict["c"].append(c.detach().cpu().numpy()) 132 | rdict["s"].append(s1.detach().cpu().numpy()) 133 | rdict["s~"].append(s2.detach().cpu().numpy()) 134 | rdict["m1"].append(m1.detach().cpu().numpy()) 135 | rdict["m2"].append(m2.detach().cpu().numpy()) 136 | rdict["hz1"].append(hz1.detach().cpu().numpy()) 137 | rdict["hz2"].append(hz2.detach().cpu().numpy()) 138 | if permuted_content: 139 | rdict["c'"].append(c_perm.detach().cpu().numpy()) 140 | 141 | # concatenate each list of values along the batch dimension 142 | for k, v in rdict.items(): 143 | if len(v) > 0: 144 | v = np.concatenate(v, axis=0) 145 | rdict[k] = np.array(v) 146 | return rdict 147 | 148 | 149 | def evaluate_prediction(model, metric, X_train, y_train, X_test, y_test): 150 | # handle edge cases when inputs or labels are zero-dimensional 151 | if any([0 in x.shape for x in [X_train, y_train, X_test, y_test]]): 152 | return np.nan 153 | assert X_train.shape[1] == X_test.shape[1] 154 | assert y_train.shape[1] == y_test.shape[1] 155 | # handle edge cases when the inputs are one-dimensional 156 | if X_train.shape[1] == 1: 157 | X_train = 
X_train.reshape(-1, 1) 158 | model.fit(X_train, y_train) 159 | y_pred = model.predict(X_test) 160 | return metric(y_test, y_pred) 161 | 162 | 163 | def main(): 164 | 165 | # parse args 166 | args, _ = parse_args() 167 | 168 | # create save_dir, where the model/results are or will be saved 169 | if args.model_id is None: 170 | setattr(args, "model_id", str(uuid.uuid4()))  # str, so it works with os.path.join and json.dump 171 | args.save_dir = os.path.join(args.model_dir, args.model_id) 172 | if not os.path.exists(args.save_dir): 173 | os.makedirs(args.save_dir) 174 | 175 | # optionally, reuse existing arguments from args.json (only for evaluation) 176 | if args.evaluate and args.load_args: 177 | with open(os.path.join(args.save_dir, 'args.json'), 'r') as fp: 178 | loaded_args = json.load(fp) 179 | arguments_to_load = [ 180 | "style_change_prob", "statistical_dependence", 181 | "content_dependent_style", "modality_n", "style_n", "content_n", 182 | "encoding_size", "shared_mixing", "n_mixing_layer", "shared_encoder", 183 | "c_param", "m_param"] 184 | for arg in arguments_to_load: 185 | setattr(args, arg, loaded_args[arg]) 186 | # NOTE: Any new arguments that shall be automatically loaded for the 187 | # evaluation of a trained model must be added to 'arguments_to_load'. 188 | 189 | # print args 190 | print("Arguments:") 191 | for k, v in vars(args).items(): 192 | print(f"\t{k}: {v}") 193 | 194 | # save args to disk (only for training) 195 | if not args.evaluate: 196 | with open(os.path.join(args.save_dir, 'args.json'), 'w') as fp: 197 | json.dump(args.__dict__, fp) 198 | 199 | # set all seeds 200 | np.random.seed(args.seed) 201 | random.seed(args.seed) 202 | torch.manual_seed(args.seed) 203 | 204 | # load training seed, which ensures consistent latent spaces for evaluation 205 | if args.evaluate: 206 | with open(os.path.join(args.save_dir, 'args.json'), 'r') as fp: 207 | train_seed = json.load(fp)["seed"] 208 | assert args.seed != train_seed 209 | else: 210 | train_seed = args.seed 211 | 212 | # set device 213 | if torch.cuda.is_available() and not args.no_cuda: 214 | device = "cuda" 215 | else: 216 | device = "cpu" 217 | warnings.warn("cuda is not available or --no-cuda was set.") 218 | 219 | # define loss function 220 | loss_func = LpSimCLRLoss() 221 | 222 | # shorthand notation for dimensionality 223 | nc, ns, nm = args.content_n, args.style_n, args.modality_n 224 | 225 | # define latents 226 | latent_spaces_list = [] 227 | Sigma_c, Sigma_s, Sigma_a, Sigma_m1, Sigma_m2, a, B = [None] * 7 228 | rgen = torch.Generator(device=device) 229 | rgen.manual_seed(train_seed) # ensures same latents for train and eval 230 | if args.statistical_dependence: 231 | Sigma_c = wishart.rvs(nc, np.eye(nc), size=1, random_state=train_seed) 232 | Sigma_s = wishart.rvs(ns, np.eye(ns), size=1, random_state=train_seed) 233 | Sigma_a = wishart.rvs(ns, np.eye(ns), size=1, random_state=train_seed) 234 | Sigma_m1 = wishart.rvs(nm, np.eye(nm), size=1, random_state=train_seed) 235 | Sigma_m2 = wishart.rvs(nm, np.eye(nm), size=1, random_state=train_seed) 236 | if args.content_dependent_style: 237 | B = torch.randn(ns, nc, device=device, generator=rgen) 238 | a = torch.randn(ns, device=device, generator=rgen) 239 | # content c 240 | space_c = NRealSpace(nc) 241 | sample_marginal_c = lambda space, size, device=device: \ 242 | space.normal(None, args.m_param, size, device, Sigma=Sigma_c) 243 | sample_conditional_c = lambda space, z, size, device=device: z 244 | latent_spaces_list.append(LatentSpace( 245 | space=space_c, 246 | sample_marginal=sample_marginal_c, 247 |
sample_conditional=sample_conditional_c)) 248 | # style s 249 | space_s = NRealSpace(ns) 250 | sample_marginal_s = lambda space, size, device=device: \ 251 | space.normal(None, args.m_param, size, device, Sigma=Sigma_s) 252 | sample_conditional_s = lambda space, z, size, device=device: \ 253 | space.normal(z, args.c_param, size, device, 254 | change_prob=args.style_change_prob, Sigma=Sigma_a) 255 | latent_spaces_list.append(LatentSpace( 256 | space=space_s, 257 | sample_marginal=sample_marginal_s, 258 | sample_conditional=sample_conditional_s)) 259 | # modality-specific m1 and m2 260 | if nm > 0: 261 | space_m1 = NRealSpace(nm) 262 | sample_marginal_m1 = lambda space, size, device=device: \ 263 | space.normal(None, args.m_param, size, device, Sigma=Sigma_m1) 264 | sample_conditional_m1 = lambda space, z, size, device=device: z 265 | latent_spaces_list.append(LatentSpace( 266 | space=space_m1, 267 | sample_marginal=sample_marginal_m1, 268 | sample_conditional=sample_conditional_m1)) 269 | space_m2 = NRealSpace(nm) 270 | sample_marginal_m2 = lambda space, size, device=device: \ 271 | space.normal(None, args.m_param, size, device, Sigma=Sigma_m2) 272 | sample_conditional_m2 = lambda space, z, size, device=device: z 273 | latent_spaces_list.append(LatentSpace( 274 | space=space_m2, 275 | sample_marginal=sample_marginal_m2, 276 | sample_conditional=sample_conditional_m2)) 277 | # combine latents 278 | latent_space = ProductLatentSpace(spaces=latent_spaces_list, a=a, B=B) 279 | 280 | # define mixing functions 281 | f1 = construct_invertible_mlp( 282 | n=nc + ns + nm, 283 | n_layers=args.n_mixing_layer, 284 | cond_thresh_ratio=0.001, 285 | n_iter_cond_thresh=25000) 286 | f1 = f1.to(device) 287 | f2 = construct_invertible_mlp( 288 | n=nc + ns + nm, 289 | n_layers=args.n_mixing_layer, 290 | cond_thresh_ratio=0.001, 291 | n_iter_cond_thresh=25000) 292 | f2 = f2.to(device) 293 | # for evaluation, always load saved mixing functions 294 | if args.evaluate: 295 | f1_path = os.path.join(args.save_dir, 'f1.pt') 296 | f1.load_state_dict(torch.load(f1_path, map_location=device)) 297 | f2_path = os.path.join(args.save_dir, 'f2.pt') 298 | f2.load_state_dict(torch.load(f2_path, map_location=device)) 299 | # freeze parameters 300 | for p in f1.parameters(): 301 | p.requires_grad = False 302 | for p in f2.parameters(): 303 | p.requires_grad = False 304 | # save mixing functions to disk 305 | if args.save_dir and not args.evaluate: 306 | if not os.path.exists(args.save_dir): 307 | os.makedirs(args.save_dir) 308 | torch.save(f1.state_dict(), os.path.join(args.save_dir, "f1.pt")) 309 | torch.save(f2.state_dict(), os.path.join(args.save_dir, "f2.pt")) 310 | 311 | # define encoders 312 | g1 = encoders.get_mlp( 313 | n_in=nc + ns + nm, 314 | n_out=args.encoding_size, 315 | layers=[(nc + ns + nm) * 10, 316 | (nc + ns + nm) * 50, 317 | (nc + ns + nm) * 50, 318 | (nc + ns + nm) * 50, 319 | (nc + ns + nm) * 50, 320 | (nc + ns + nm) * 10]) 321 | g1 = g1.to(device) 322 | g2 = encoders.get_mlp( 323 | n_in=nc + ns + nm, 324 | n_out=args.encoding_size, 325 | layers=[(nc + ns + nm) * 10, 326 | (nc + ns + nm) * 50, 327 | (nc + ns + nm) * 50, 328 | (nc + ns + nm) * 50, 329 | (nc + ns + nm) * 50, 330 | (nc + ns + nm) * 10]) 331 | g2 = g2.to(device) 332 | # for evaluation, always load saved encoders 333 | if args.evaluate: 334 | g1_path = os.path.join(args.save_dir, 'g1.pt') 335 | g1.load_state_dict(torch.load(g1_path, map_location=device)) 336 | g2_path = os.path.join(args.save_dir, 'g2.pt') 337 | 
g2.load_state_dict(torch.load(g2_path, map_location=device)) 338 | 339 | # for convenience, define h as a composition of mixing function and encoder 340 | if args.shared_mixing: 341 | f2 = f1 # overwrites the second mixing function 342 | if args.shared_encoder: 343 | g2 = g1 # overwrites the second encoder 344 | h1 = lambda z: g1(f1(z)) 345 | h2 = lambda z: g2(f2(z)) 346 | 347 | # define optimizer 348 | if not args.evaluate: 349 | if args.shared_encoder: 350 | params = list(g1.parameters()) 351 | else: 352 | params = list(g1.parameters()) + list(g2.parameters()) 353 | optimizer = torch.optim.Adam(params, lr=args.lr) 354 | 355 | # training 356 | # -------- 357 | step = 1 358 | while step <= args.train_steps and not args.evaluate: 359 | 360 | # training step 361 | z1, z2 = latent_space.sample_z1_and_z2(args.batch_size, device) 362 | z1_, z2_ = latent_space.sample_z1_and_z2(args.batch_size, device) 363 | data = [z1, z2, z1_, z2_] 364 | train_step(data, h1, h2, loss_func, optimizer, params) 365 | 366 | # every log_steps, we have a checkpoint and small evaluation 367 | if step % args.log_steps == 1 or step == args.train_steps: 368 | 369 | # save encoders to disk 370 | if args.save_dir and not args.evaluate: 371 | torch.save(g1.state_dict(), os.path.join(args.save_dir, "g1.pt")) 372 | torch.save(g2.state_dict(), os.path.join(args.save_dir, "g2.pt")) 373 | 374 | # lightweight evaluation with linear classifiers 375 | print(f"\nStep: {step} \t") 376 | data_dict = generate_data(latent_space, h1, h2, device, loss_func=loss_func) 377 | print(f"Loss: {np.mean(data_dict['loss_values']):.4f} \t") 378 | data_dict["hz1"] = StandardScaler().fit_transform(data_dict["hz1"]) 379 | for k in ["c", "s", "s~", "m1", "m2"]: 380 | inputs, labels = data_dict["hz1"], data_dict[k] 381 | train_inputs, test_inputs, train_labels, test_labels = \ 382 | train_test_split(inputs, labels) 383 | data = [train_inputs, train_labels, test_inputs, test_labels] 384 | r2_linear = evaluate_prediction( 385 | linear_model.LinearRegression(n_jobs=-1), r2_score, *data) 386 | print(f"{k} r2_linear: {r2_linear}") 387 | step += 1 388 | 389 | # evaluation 390 | # ---------- 391 | if args.evaluate: 392 | 393 | # generate encodings and labels for the validation and test data 394 | val_dict = generate_data( 395 | latent_space, h1, h2, device, 396 | num_batches=args.num_eval_batches, 397 | loss_func=loss_func, 398 | permuted_content=args.permuted_content) 399 | test_dict = generate_data( 400 | latent_space, h1, h2, device, 401 | num_batches=args.num_eval_batches, 402 | loss_func=loss_func, 403 | permuted_content=args.permuted_content) 404 | 405 | # print average loss values 406 | print(f"Val Loss: {np.mean(val_dict['loss_values']):.4f} \t") 407 | print(f"Test Loss: {np.mean(test_dict['loss_values']):.4f} \t") 408 | 409 | # standardize the encodings 410 | for m in [1, 2]: 411 | scaler = StandardScaler() 412 | val_dict[f"hz{m}"] = scaler.fit_transform(val_dict[f"hz{m}"]) 413 | test_dict[f"hz{m}"] = scaler.transform(test_dict[f"hz{m}"]) 414 | 415 | # train predictors on data from val_dict and evaluate on test_dict 416 | results = [] 417 | for m in [1, 2]: 418 | for k in ["c", "s", "s~", "m1", "m2", "c'"]: 419 | 420 | # select data 421 | train_inputs, test_inputs = val_dict[f"hz{m}"], test_dict[f"hz{m}"] 422 | train_labels, test_labels = val_dict[k], test_dict[k] 423 | data = [train_inputs, train_labels, test_inputs, test_labels] 424 | 425 | # linear regression 426 | r2_linear = evaluate_prediction( 427 | linear_model.LinearRegression(n_jobs=-1), r2_score, *data) 428 |
429 | # nonlinear regression 430 | if args.mlp_eval: 431 | model = MLPRegressor(max_iter=1000) # lightweight option 432 | else: 433 | # grid search is time- and memory-intensive 434 | model = GridSearchCV( 435 | kernel_ridge.KernelRidge(kernel='rbf', gamma=0.1), 436 | param_grid={"alpha": [1e0, 0.1, 1e-2, 1e-3], 437 | "gamma": np.logspace(-2, 2, 4)}, 438 | cv=3, n_jobs=-1) 439 | r2_nonlinear = evaluate_prediction(model, r2_score, *data) 440 | 441 | # append results 442 | results.append((f"hz{m}", k, r2_linear, r2_nonlinear)) 443 | 444 | # convert evaluation results into tabular form 445 | cols = ["encoding", "predicted_factors", "r2_linear", "r2_nonlinear"] 446 | df_results = pd.DataFrame(results, columns=cols) 447 | df_results.to_csv(os.path.join(args.save_dir, "results.csv")) 448 | print("Regression results:") 449 | print(df_results.to_string()) 450 | 451 | 452 | if __name__ == "__main__": 453 | main() 454 | --------------------------------------------------------------------------------