├── .gitignore ├── LICENSE ├── README.md ├── data └── processed │ ├── .gitignore │ ├── blobs_overlap.npz │ ├── blobs_overlap_5.npz │ ├── rgbd.npz │ └── voc.npz ├── environment.yml ├── models └── voc │ ├── CoMVC │ ├── best.pt │ └── config.pkl │ └── SiMVC │ ├── best.pt │ └── config.pkl ├── requirements.txt └── src ├── config ├── __init__.py ├── config.py ├── constants.py ├── defaults.py ├── eamc │ ├── __init__.py │ ├── defaults.py │ └── experiments.py └── experiments │ ├── __init__.py │ ├── ccv.py │ ├── coil.py │ ├── fmnist.py │ ├── mnist.py │ ├── rgbd.py │ ├── test.py │ └── voc.py ├── data ├── load.py └── make_dataset.py ├── eamc ├── loss.py └── model.py ├── helpers.py ├── lib ├── backbones.py ├── fusion.py ├── kernel.py ├── loss.py └── optimizer.py ├── models ├── build_model.py ├── callback.py ├── clustering_module.py ├── contrastive_mvc.py ├── ddc.py ├── evaluate.py ├── model_base.py ├── simple_mvc.py └── train.py └── scripts ├── comvc_ablation.sh ├── ddc_loss_ablation.sh └── mnist_noise.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | #lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | *.DS_Store 132 | *.idea/ 133 | 134 | data/raw/* 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 DanielTrosten 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SiMVC & CoMVC 2 | 3 | This repository provides the implementations of SiMVC and CoMVC, presented in the paper: 4 | 5 | "Reconsidering Representation Alignment for Multi-view Clustering" by 6 | Daniel J. Trosten, Sigurd Løkse, Robert Jenssen and Michael Kampffmeyer, in _CVPR 2021_. 7 | 8 | BibTeX: 9 | ```text 10 | @inproceedings{trostenMVC, 11 | title = {Reconsidering Representation Alignment for Multi-view Clustering}, 12 | author = {Daniel J. Trosten and Sigurd Løkse and Robert Jenssen and Michael Kampffmeyer}, 13 | year = 2021, 14 | booktitle = {2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)} 15 | } 16 | ``` 17 | 18 | 19 | ## Installation 20 | Requires Python >= 3.7 (tested on 3.8) 21 | 22 | To install the required packages, run: 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | from the root directory of the repository. Anaconda (or similar) is recommended. 
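For example, a minimal setup sketch using the included `environment.yml` (which defines a conda environment named `mvc` with Python 3.8); adapt the commands to your own system:
```bash
# Create and activate the conda environment declared in environment.yml
conda env create -f environment.yml
conda activate mvc

# Install the Python dependencies listed in requirements.txt
pip install -r requirements.txt
```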
27 | 28 | ## Datasets 29 | ### Included datasets 30 | The following datasets are included as files in this project: 31 | 32 | - `voc` (VOC) 33 | - `rgbd` (RGB-D) 34 | - `blobs_overlap_5` (Toy dataset with 5 clusters) 35 | - `blobs_overlap` (Toy dataset with 3 clusters) 36 | 37 | ### Generating datasets 38 | To generate training-ready datasets, run: 39 | ``` 40 | python -m data.make_dataset <...> 41 | ``` 42 | This will export the training-ready datasets to `data/processed/<name>.npz`. 43 | 44 | Currently, the following datasets can be generated without downloading additional files: 45 | 46 | - `mnist_mv` (E-MNIST) 47 | - `fmnist` (E-FMNIST) 48 | 49 | ### Datasets that require additional downloads 50 | 51 | - `ccv` (CCV): Download the files from [here](https://www.ee.columbia.edu/ln/dvmm/CCV/), and place them in 52 | `data/raw/CCV`. 53 | - `coil` (COIL-20): Download the processed files from 54 | [here](https://www.cs.columbia.edu/CAVE/software/softlib/coil-20.php), and place them in `data/raw/COIL`. 55 | 56 | After downloading and extracting the files, run 57 | ```bash 58 | python -m data.make_dataset ccv coil 59 | ``` 60 | to generate training-ready versions of CCV and COIL-20. 61 | 62 | ### Preparing a custom dataset for training 63 | Create `<name>.npz` in `data/processed/` with the following keys: 64 | ``` 65 | n_views: The number of views, V 66 | labels: One-dimensional array of labels. Shape (n,) 67 | view_0: Data for first view. Shape (n, ...) 68 | . 69 | . 70 | . 71 | view_V: Data for view V. Shape (n, ...) 72 | ``` 73 | Alternatively, call 74 | ```Python 75 | data.make_dataset.export_dataset( 76 | "<name>", # Name of the dataset 77 | views, # List of view-arrays 78 | labels # Label array 79 | ) 80 | ``` 81 | This will automatically export the dataset to an `.npz` file at the correct location. 82 | 83 | Then, in the experiment config, set 84 | ```Python 85 | dataset_config=Dataset("<name>") 86 | ``` 87 | 88 | ## Experiment configuration 89 | Experiment configs are nested configuration objects, where the top-level config is an instance of 90 | `config.defaults.Experiment`. 91 | 92 | The configuration object for the contrastive model on E-MNIST, for instance, looks like this: 93 | ```Python 94 | from config.defaults import Experiment, CNN, DDC, Fusion, Loss, Dataset, CoMVC, Optimizer 95 | 96 | 97 | mnist_contrast = Experiment( 98 | dataset_config=Dataset(name="mnist_mv"), 99 | model_config=CoMVC( 100 | backbone_configs=( 101 | CNN(input_size=(1, 28, 28)), 102 | CNN(input_size=(1, 28, 28)), 103 | ), 104 | fusion_config=Fusion(method="weighted_mean", n_views=2), 105 | projector_config=None, 106 | cm_config=DDC(n_clusters=10), 107 | loss_config=Loss( 108 | funcs="ddc_1|ddc_2|ddc_3|contrast", 109 | # Additional loss parameters go here 110 | ), 111 | optimizer_config=Optimizer( 112 | learning_rate=1e-3, 113 | # Additional optimizer parameters go here 114 | ) 115 | ), 116 | n_epochs=100, 117 | n_runs=20, 118 | ) 119 | ``` 120 | 121 | ## Running an experiment 122 | In the `src` directory, run: 123 | ``` 124 | python -m models.train -c <config name> 125 | ``` 126 | where `<config name>` is the name of an experiment config from one of the files in `src/config/experiments/` or from 127 | `src/config/eamc/experiments.py` (for EAMC experiments). 128 | 129 | ### Overriding config parameters at the command-line 130 | Parameters set in the config object can be overridden at the command line.
For instance, if we want to change the 131 | learning rate for the E-MNIST experiment above from 0.001 to 0.0001, and the number of epochs from 100 to 200, 132 | we can run: 133 | ``` 134 | python -m models.train -c mnist_contrast \ 135 | --n_epochs 200 \ 136 | --model_config__optimizer_config__learning_rate 0.0001 137 | ``` 138 | Note the double underscores to traverse the hierarchy of the config object. 139 | 140 | ## Evaluating an experiment 141 | Run the evaluation script: 142 | ```bash 143 | python -m models.evaluate -c <config name> \ # Name of the experiment config 144 | -t <tag> \ # The unique 8-character ID assigned to the experiment when calling models.train 145 | --plot # Optional flag to plot the representations before and after fusion. 146 | ``` 147 | 148 | ## Ablation studies and noise experiment 149 | To run one of these experiments, execute the corresponding script in the `src/scripts` directory. 150 | -------------------------------------------------------------------------------- /data/processed/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !rgbd.npz 4 | !voc.npz 5 | !blobs_overlap.npz 6 | !blobs_overlap_5.npz -------------------------------------------------------------------------------- /data/processed/blobs_overlap.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/data/processed/blobs_overlap.npz -------------------------------------------------------------------------------- /data/processed/blobs_overlap_5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/data/processed/blobs_overlap_5.npz -------------------------------------------------------------------------------- /data/processed/rgbd.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/data/processed/rgbd.npz -------------------------------------------------------------------------------- /data/processed/voc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/data/processed/voc.npz -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mvc 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - pip -------------------------------------------------------------------------------- /models/voc/CoMVC/best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/models/voc/CoMVC/best.pt -------------------------------------------------------------------------------- /models/voc/CoMVC/config.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/models/voc/CoMVC/config.pkl -------------------------------------------------------------------------------- /models/voc/SiMVC/best.pt: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/models/voc/SiMVC/best.pt -------------------------------------------------------------------------------- /models/voc/SiMVC/config.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/models/voc/SiMVC/config.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tqdm 4 | numpy 5 | scipy 6 | scikit-learn 7 | pandas 8 | pydantic 9 | opencv-python 10 | pyyaml 11 | tabulate 12 | tensorboard 13 | numba 14 | typing_extensions 15 | wandb 16 | plotly 17 | h5py -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | from .constants import * 5 | from .config import Config 6 | from . import defaults, experiments 7 | from .eamc import experiments as eamc_experiments 8 | from .eamc import defaults as eamc_defaults 9 | 10 | 11 | def parse_config_name_arg(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("-c", "--config", dest="config_name", required=True) 14 | return parser.parse_known_args()[0].config_name 15 | 16 | 17 | def set_cfg_value(cfg, key_list, value): 18 | sub_cfg = cfg 19 | for key in key_list[:-1]: 20 | sub_cfg = getattr(sub_cfg, key) 21 | setattr(sub_cfg, key_list[-1], value) 22 | 23 | 24 | def update_cfg(cfg): 25 | sep = "__" 26 | parser = argparse.ArgumentParser() 27 | cfg_dict = hparams_dict(cfg, sep=sep) 28 | 29 | parser.add_argument("-c", "--config", dest="config_name") 30 | 31 | for key, value in cfg_dict.items(): 32 | value_type = type(value) if isinstance(value, (int, float, bool)) else None 33 | parser.add_argument("--" + key, dest=key, default=value, type=value_type) 34 | 35 | args = parser.parse_args() 36 | 37 | for key in cfg_dict.keys(): 38 | key_list = key.split(sep) 39 | value = getattr(args, key) 40 | set_cfg_value(cfg, key_list, value) 41 | 42 | 43 | def get_config_by_name(name): 44 | try: 45 | if name.startswith("eamc"): 46 | cfg = getattr(eamc_experiments, name) 47 | else: 48 | cfg = getattr(experiments, name) 49 | except Exception as err: 50 | raise RuntimeError(f"Config not found: {name}") from err 51 | cfg.model_config.loss_config.n_clusters = cfg.model_config.cm_config.n_clusters 52 | return cfg 53 | 54 | 55 | def get_config_from_file(name=None, tag=None, file_path=None, run=0): 56 | if file_path is None: 57 | file_path = MODELS_DIR / f"{name}-{tag}" / f"run-{run}" / "config.pkl" 58 | with open(file_path, "rb") as f: 59 | cfg = pickle.load(f) 60 | return cfg 61 | 62 | 63 | def get_experiment_config(): 64 | name = parse_config_name_arg() 65 | cfg = get_config_by_name(name) 66 | update_cfg(cfg) 67 | return name, cfg 68 | 69 | 70 | def _insert_hparams(cfg_dict, hp_dict, key_prefix, skip_keys, sep="/"): 71 | hparam_types = (str, int, float, bool) 72 | for key, value in cfg_dict.items(): 73 | if key in skip_keys: 74 | continue 75 | _key = f"{key_prefix}{sep}{key}" if key_prefix else key 76 | if isinstance(value, hparam_types) or value is None: 77 | hp_dict[_key] = value 78 | elif isinstance(value, dict): 79 | _insert_hparams(value, hp_dict, _key, skip_keys, sep=sep) 80 | 81 | 82 | def hparams_dict(cfg, 
sep="/"): 83 | skip_keys = [] 84 | hp_dict = {} 85 | _insert_hparams(cfg.dict(), hp_dict, "", skip_keys, sep=sep) 86 | return hp_dict 87 | -------------------------------------------------------------------------------- /src/config/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Config(BaseModel): 5 | @property 6 | def class_name(self): 7 | return self.__class__.__name__ 8 | 9 | -------------------------------------------------------------------------------- /src/config/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as th 3 | from pathlib import Path 4 | 5 | 6 | CUDA_AVALABLE = th.cuda.is_available() 7 | DEVICE = th.device("cuda" if CUDA_AVALABLE else "cpu") 8 | 9 | PROJECT_ROOT = Path(os.path.abspath(__file__)).parents[2] 10 | DATA_DIR = PROJECT_ROOT / "data" 11 | MODELS_DIR = PROJECT_ROOT / "models" 12 | 13 | DATETIME_FMT = "%Y-%m-%d_%H-%M-%S" 14 | -------------------------------------------------------------------------------- /src/config/defaults.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Union, Optional 2 | from typing_extensions import Literal 3 | 4 | from config import Config 5 | 6 | 7 | class Dataset(Config): 8 | # Name of the dataset. Must correspond to a filename in data/processed/ 9 | name: str 10 | # Number of samples to load. Set to None to load all samples 11 | n_samples: int = None 12 | # Subset of views to load. Set to None to load all views 13 | select_views: Tuple[int, ...] = None 14 | # Subset of labels (classes) to load. Set to None to load all classes 15 | select_labels: Tuple[int, ...] = None 16 | # Number of samples to load for each class. Set to None to load all samples 17 | label_counts: Tuple[int, ...] = None 18 | # Standard deviation of noise added to the views `noise_views`. 19 | noise_sd: float = None 20 | # Subset of views to add noise to 21 | noise_views: Tuple[int, ...] = None 22 | 23 | 24 | class Loss(Config): 25 | # Number of clusters 26 | n_clusters: int = None 27 | # Terms to use in the loss, separated by '|'. E.g. "ddc_1|ddc_2|ddc_3|" for the DDC clustering loss 28 | funcs: str 29 | # Optional weights for the loss terms. Set to None to have all weights equal to 1. 30 | weights: Tuple[Union[float, int], ...] = None 31 | # Multiplication factor for the sigma hyperparameter 32 | rel_sigma = 0.15 33 | # Tau hyperparameter 34 | tau = 0.1 35 | # Delta hyperparameter 36 | delta = 0.1 37 | # Fraction of batch size to use as the number of negative samples in the contrastive loss. Set to -1 to use all 38 | # pairs (except the positive) as negative pairs. 39 | negative_samples_ratio: float = 0.25 40 | # Similarity function for the contrastive loss. "cos" (default) and "gauss" are supported. 41 | contrastive_similarity: Literal["cos", "gauss"] = "cos" 42 | # Enable the adaptive contrastive weighting? 43 | adaptive_contrastive_weight = True 44 | 45 | 46 | class Optimizer(Config): 47 | # Base learning rate 48 | learning_rate: float = 0.001 49 | # Max gradient norm for gradient clipping. 50 | clip_norm: float = 5.0 51 | # Step size for the learning rate scheduler. None disables the scheduler. 
52 | scheduler_step_size: int = None 53 | # Multiplication factor for the learning rate scheduler 54 | scheduler_gamma: float = 0.1 55 | 56 | 57 | class DDC(Config): 58 | # Number of clusters 59 | n_clusters: int = None 60 | # Number of units in the first fully connected layer 61 | n_hidden = 100 62 | # Use batch norm after the first fully connected layer? 63 | use_bn = True 64 | 65 | 66 | class CNN(Config): 67 | # Shape of the input image. Format: CHW 68 | input_size: Tuple[int, ...] = None 69 | # Network layers 70 | layers = ( 71 | ("conv", 5, 5, 32, "relu"), 72 | ("conv", 5, 5, 32, None), 73 | ("bn",), 74 | ("relu",), 75 | ("pool", 2, 2), 76 | ("conv", 3, 3, 32, "relu"), 77 | ("conv", 3, 3, 32, None), 78 | ("bn",), 79 | ("relu",), 80 | ("pool", 2, 2), 81 | ) 82 | 83 | 84 | class MLP(Config): 85 | # Shape of the input 86 | input_size: Tuple[int, ...] = None 87 | # Units in the network layers 88 | layers: Tuple[Union[int, str], ...] = (512, 512, 256) 89 | # Activation function. Can be a single string specifying the activation function for all layers, or a list/tuple of 90 | # string specifying the activation function for each layer. 91 | activation: Union[str, None, List[Union[None, str]], Tuple[Union[None, str], ...]] = "relu" 92 | # Include bias parameters? A single bool for all layers, or a list/tuple of booleans for individual layers. 93 | use_bias: Union[bool, Tuple[bool, ...]] = True 94 | # Include batch norm after layers? A single bool for all layers, or a list/tuple of booleans for individual layers. 95 | use_bn: Union[bool, Tuple[bool, ...]] = False 96 | 97 | 98 | class Fusion(Config): 99 | # Fusion method. "mean" constant weights = 1/V. "weighted_mean": Weighted average with learned weights. 100 | method: Literal["mean", "weighted_mean"] 101 | # Number of views in the dataset 102 | n_views: int 103 | 104 | 105 | class DDCModel(Config): 106 | # Encoder network config 107 | backbone_config: Union[MLP, CNN] 108 | # Clustering module config 109 | cm_config: Union[DDC] 110 | # Loss function config 111 | loss_config: Loss 112 | # Optimizer config 113 | optimizer_config = Optimizer() 114 | 115 | 116 | class SiMVC(Config): 117 | # Tuple of encoder configs. One for each modality. 118 | backbone_configs: Tuple[Union[MLP, CNN], ...] 119 | # Fusion module config. 120 | fusion_config: Fusion 121 | # Clustering module config. 122 | cm_config: Union[DDC] 123 | # Loss function config 124 | loss_config: Loss 125 | # Optimizer config 126 | optimizer_config = Optimizer() 127 | 128 | 129 | class CoMVC(Config): 130 | # Tuple of encoder configs. One for each modality. 131 | backbone_configs: Tuple[Union[MLP, CNN], ...] 132 | # Projection head config. Set to None to remove the projection head. 133 | projector_config: Optional[MLP] 134 | # Fusion module config. 135 | fusion_config: Fusion 136 | # Clustering module config. 137 | cm_config: Union[DDC] 138 | # Loss function config 139 | loss_config: Loss 140 | # Optimizer config 141 | optimizer_config = Optimizer() 142 | 143 | 144 | class Experiment(Config): 145 | # Dataset config 146 | dataset_config: Dataset 147 | # Model config 148 | model_config: Union[CoMVC, SiMVC, DDC] 149 | # Number of training runs 150 | n_runs = 20 151 | # Number of training epochs 152 | n_epochs = 100 153 | # Batch size 154 | batch_size = 100 155 | # Number of epochs between model evaluation. 156 | eval_interval: int = 4 157 | # Number of epochs between model checkpoints. 158 | checkpoint_interval = 20 159 | # Patience for early stopping. 
160 | patience = 50000 161 | # Number of samples to use for evaluation. Set to None to use all samples in the dataset. 162 | n_eval_samples: int = None 163 | # Term in loss function to use for model selection. Set to "tot" to use the sum of all terms. 164 | best_loss_term = "ddc_1" 165 | -------------------------------------------------------------------------------- /src/config/eamc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTrosten/mvc/6e50f6a7bf7338ccb02935a7e273e4e2bbf43370/src/config/eamc/__init__.py -------------------------------------------------------------------------------- /src/config/eamc/defaults.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom implementation of End-to-End Adversarial-Attention Network for Multi-Modal Clustering (EAMC). 3 | https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_End-to-End_Adversarial-Attention_Network_for_Multi-Modal_Clustering_CVPR_2020_paper.pdf 4 | Based on code sent to us by the original authors. 5 | """ 6 | 7 | from typing import Tuple, Union, Optional 8 | 9 | from config.config import Config 10 | from config.defaults import MLP, DDC, CNN, Dataset, Fusion 11 | 12 | 13 | class Loss(Config): 14 | # Multiplication factor for the sigma hyperparameter 15 | rel_sigma: float = 0.15 16 | # Weight of adversarial losses 17 | gamma: float = 10 18 | # Number of clusters 19 | n_clusters: int = None 20 | # Optional weights for the loss terms. Set to None to have all weights equal to 1. 21 | weights: Tuple[Union[float, int], ...] = None 22 | # Terms to use in the loss, separated by '|'. E.g. "ddc_1|ddc_2|ddc_3|" for the DDC clustering loss 23 | funcs = "ddc_1|ddc_2_flipped|ddc_3|att|gen|disc" 24 | 25 | 26 | class AttentionLayer(Config): 27 | # Softmax temperature 28 | tau: float = 10.0 29 | # Config for the attention net. Final layer will be added automatically 30 | mlp_config: MLP = MLP( 31 | layers=(100, 50), 32 | activation=None 33 | ) 34 | # Number of input views 35 | n_views: int = 2 36 | 37 | 38 | class Discriminator(Config): 39 | # Config for the discriminator 40 | mlp_config: MLP = MLP( 41 | layers=(256, 256, 128), 42 | activation="leaky_relu:0.2" 43 | ) 44 | 45 | 46 | class Optimizer(Config): 47 | # Discriminator learning rate 48 | lr_disc: float = 1e-3 49 | # Encoder learning rate 50 | lr_backbones: float = 1e-5 51 | # Attention learning rate 52 | lr_att: float = 1e-4 53 | # Clustering module learning rate 54 | lr_clustering_module: float = 1e-5 55 | # Beta parameters for the discriminator 56 | betas_disc = (0.5, 0.999) 57 | # Beta parameters for the encoders 58 | betas_backbones = (0.95, 0.999) 59 | # Beta parameters for the attention net 60 | betas_att = (0.95, 0.999) 61 | # Beta parameters for the clustering module 62 | betas_clustering_module = (0.95, 0.999) 63 | 64 | 65 | class EAMC(Config): 66 | # Encoder configs 67 | backbone_configs: Tuple[Union[MLP, CNN], ...] 68 | # Attention net config. Set to None to remove attention net 69 | attention_config: Optional[AttentionLayer] = AttentionLayer() 70 | # Optional fusion config to use instead of attention net. 
71 | fusion_config: Fusion = None 72 | # Discriminator config 73 | discriminator_config: Optional[Discriminator] = Discriminator() 74 | # Clustering module config 75 | cm_config: DDC 76 | # Loss config 77 | loss_config: Loss = Loss() 78 | # Optimizer config 79 | optimizer_config: Optimizer = Optimizer() 80 | # Max norm for gradient clipping 81 | clip_norm = 0.5 82 | # Number of consecutive batches to train the encoders, attention net and clustering module. 83 | t: int = 1 84 | # Number of consecutive batches to train the discriminator. 85 | t_disc: int = 1 86 | 87 | 88 | class EAMCExperiment(Config): 89 | # Dataset config 90 | dataset_config: Dataset 91 | # Model config 92 | model_config: EAMC 93 | # Number of training runs 94 | n_runs = 20 95 | # Number of epochs per run 96 | n_epochs = 500 97 | # Batch size 98 | batch_size = 100 99 | # Number of epochs between model evaluation. 100 | eval_interval: int = 5 101 | # Number of epochs between model checkpoints. 102 | checkpoint_interval = 50 103 | # Number of samples to use for evaluation. Set to None to use all samples in the dataset. 104 | n_eval_samples: int = None 105 | # Patience for early stopping. 106 | patience: int = 1e9 107 | # Term in loss function to use for model selection. Set to "tot" to use the sum of all terms. 108 | best_loss_term = "tot" 109 | -------------------------------------------------------------------------------- /src/config/eamc/experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom implementation of End-to-End Adversarial-Attention Network for Multi-Modal Clustering (EAMC). 3 | https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_End-to-End_Adversarial-Attention_Network_for_Multi-Modal_Clustering_CVPR_2020_paper.pdf 4 | Based on code sent to us by the original authors.
5 | """ 6 | 7 | from config.defaults import MLP, CNN, DDC, Dataset 8 | from config.eamc.defaults import EAMCExperiment, EAMC, AttentionLayer, Discriminator, Loss, Optimizer 9 | 10 | BACKBONE_MLP_LAYERS = (200, 200, 500) 11 | CNN_LAYERS = ( 12 | ("conv", 5, 5, 32, "relu"), 13 | ("pool", 2, 2), 14 | ("conv", 5, 5, 64, "relu"), 15 | ("pool", 2, 2), 16 | ("fc", 500), 17 | ("bn",), 18 | ("relu",) 19 | ) 20 | CNN_BACKBONES = ( 21 | CNN(layers=CNN_LAYERS, input_size=(1, 28, 28)), 22 | CNN(layers=CNN_LAYERS, input_size=(1, 28, 28)), 23 | ) 24 | 25 | 26 | eamc_blobs_overlap = EAMCExperiment( 27 | dataset_config=Dataset(name="blobs_overlap"), 28 | model_config=EAMC( 29 | backbone_configs=( 30 | MLP(layers=[32, 32, 32], input_size=(2,)), 31 | MLP(layers=[32, 32, 32], input_size=(2,)), 32 | ), 33 | discriminator_config=Discriminator( 34 | mlp_config=MLP(layers=(32, 32, 32)) 35 | ), 36 | loss_config=Loss(), 37 | cm_config=DDC(n_clusters=3), 38 | optimizer_config=Optimizer(lr_backbones=2e-4, lr_disc=1e-5) 39 | ), 40 | ) 41 | 42 | eamc_blobs_overlap_5 = EAMCExperiment( 43 | dataset_config=Dataset(name="blobs_overlap_5"), 44 | model_config=EAMC( 45 | backbone_configs=( 46 | MLP(layers=[32, 32, 32], input_size=(2,)), 47 | MLP(layers=[32, 32, 32], input_size=(2,)), 48 | ), 49 | discriminator_config=Discriminator( 50 | mlp_config=MLP(layers=(32, 32, 32)) 51 | ), 52 | loss_config=Loss(), 53 | cm_config=DDC(n_clusters=5), 54 | optimizer_config=Optimizer(lr_backbones=2e-4, lr_disc=1e-5) 55 | ), 56 | ) 57 | 58 | eamc_mnist = EAMCExperiment( 59 | dataset_config=Dataset(name="mnist_mv"), 60 | model_config=EAMC( 61 | backbone_configs=CNN_BACKBONES, 62 | cm_config=DDC(n_clusters=10), 63 | ), 64 | ) 65 | 66 | eamc_mnist_var_noise = EAMCExperiment( 67 | dataset_config=Dataset(name="mnist_mv", noise_sd=1.0, noise_views=(1,)), 68 | model_config=EAMC( 69 | backbone_configs=CNN_BACKBONES, 70 | cm_config=DDC(n_clusters=10), 71 | ), 72 | ) 73 | 74 | 75 | eamc_fmnist = EAMCExperiment( 76 | dataset_config=Dataset(name="fmnist"), 77 | model_config=EAMC( 78 | backbone_configs=CNN_BACKBONES, 79 | cm_config=DDC(n_clusters=10), 80 | ), 81 | ) 82 | 83 | eamc_coil = EAMCExperiment( 84 | dataset_config=Dataset(name="coil"), 85 | model_config=EAMC( 86 | backbone_configs=( 87 | CNN(input_size=(1, 128, 128)), 88 | CNN(input_size=(1, 128, 128)), 89 | CNN(input_size=(1, 128, 128)), 90 | ), 91 | cm_config=DDC(n_clusters=20), 92 | attention_config=AttentionLayer(n_views=3) 93 | ), 94 | ) 95 | 96 | eamc_rgbd = EAMCExperiment( 97 | dataset_config=Dataset(name="rgbd"), 98 | model_config=EAMC( 99 | backbone_configs=( 100 | MLP(layers=BACKBONE_MLP_LAYERS, input_size=(2048,)), 101 | MLP(layers=BACKBONE_MLP_LAYERS, input_size=(300,)), 102 | ), 103 | cm_config=DDC(n_clusters=13), 104 | optimizer_config=Optimizer(lr_backbones=6e-5, lr_disc=2e-5) 105 | ), 106 | ) 107 | -------------------------------------------------------------------------------- /src/config/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .test import * 2 | from .mnist import * 3 | from .ccv import * 4 | from .voc import * 5 | from .rgbd import * 6 | from .fmnist import * 7 | from .coil import * 8 | -------------------------------------------------------------------------------- /src/config/experiments/ccv.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, SiMVC, DDC, Fusion, MLP, Loss, Dataset, CoMVC, Optimizer 2 | 3 | ccv = Experiment( 4 
| dataset_config=Dataset(name="ccv"), 5 | model_config=SiMVC( 6 | backbone_configs=( 7 | MLP(input_size=(5000,)), 8 | MLP(input_size=(5000,)), 9 | MLP(input_size=(4000,)), 10 | ), 11 | fusion_config=Fusion(method="weighted_mean", n_views=3), 12 | cm_config=DDC(n_clusters=20), 13 | loss_config=Loss( 14 | funcs="ddc_1|ddc_2|ddc_3", 15 | ), 16 | optimizer_config=Optimizer() 17 | ), 18 | ) 19 | 20 | ccv_contrast = Experiment( 21 | dataset_config=Dataset(name="ccv"), 22 | model_config=CoMVC( 23 | backbone_configs=( 24 | MLP(input_size=(5000,)), 25 | MLP(input_size=(5000,)), 26 | MLP(input_size=(4000,)), 27 | ), 28 | fusion_config=Fusion(method="weighted_mean", n_views=3), 29 | projector_config=None, 30 | cm_config=DDC(n_clusters=20), 31 | loss_config=Loss( 32 | funcs="ddc_1|ddc_2|ddc_3|contrast", 33 | delta=20.0 34 | ), 35 | optimizer_config=Optimizer(scheduler_step_size=50, scheduler_gamma=0.1) 36 | ), 37 | n_epochs=100 38 | ) 39 | -------------------------------------------------------------------------------- /src/config/experiments/coil.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, SiMVC, CNN, DDC, Fusion, Loss, Dataset, CoMVC, Optimizer 2 | 3 | 4 | coil = Experiment( 5 | dataset_config=Dataset(name="coil"), 6 | model_config=SiMVC( 7 | backbone_configs=( 8 | CNN(input_size=(1, 128, 128)), 9 | CNN(input_size=(1, 128, 128)), 10 | CNN(input_size=(1, 128, 128)), 11 | ), 12 | fusion_config=Fusion(method="weighted_mean", n_views=3), 13 | cm_config=DDC(n_clusters=20), 14 | loss_config=Loss( 15 | funcs="ddc_1|ddc_2|ddc_3", 16 | ), 17 | optimizer_config=Optimizer() 18 | ), 19 | n_epochs=100, 20 | ) 21 | 22 | coil_contrast = Experiment( 23 | dataset_config=Dataset(name="coil"), 24 | model_config=CoMVC( 25 | backbone_configs=( 26 | CNN(input_size=(1, 128, 128)), 27 | CNN(input_size=(1, 128, 128)), 28 | CNN(input_size=(1, 128, 128)), 29 | ), 30 | fusion_config=Fusion(method="weighted_mean", n_views=3), 31 | projector_config=None, 32 | cm_config=DDC(n_clusters=20), 33 | loss_config=Loss( 34 | funcs="ddc_1|ddc_2|ddc_3|contrast", 35 | delta=20.0 36 | ), 37 | optimizer_config=Optimizer() 38 | ), 39 | ) -------------------------------------------------------------------------------- /src/config/experiments/fmnist.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, SiMVC, CNN, DDC, Fusion, Loss, Dataset, CoMVC, Optimizer 2 | 3 | fmnist = Experiment( 4 | dataset_config=Dataset(name="fmnist"), 5 | model_config=SiMVC( 6 | backbone_configs=( 7 | CNN(input_size=(1, 28, 28)), 8 | CNN(input_size=(1, 28, 28)), 9 | ), 10 | fusion_config=Fusion(method="weighted_mean", n_views=2), 11 | cm_config=DDC(n_clusters=10), 12 | loss_config=Loss( 13 | funcs="ddc_1|ddc_2|ddc_3", 14 | ), 15 | optimizer_config=Optimizer() 16 | ), 17 | ) 18 | 19 | 20 | fmnist_contrast = Experiment( 21 | dataset_config=Dataset(name="fmnist"), 22 | model_config=CoMVC( 23 | backbone_configs=( 24 | CNN(input_size=(1, 28, 28)), 25 | CNN(input_size=(1, 28, 28)), 26 | ), 27 | fusion_config=Fusion(method="weighted_mean", n_views=2), 28 | projector_config=None, 29 | cm_config=DDC(n_clusters=10), 30 | loss_config=Loss( 31 | funcs="ddc_1|ddc_2|ddc_3|contrast", 32 | ), 33 | optimizer_config=Optimizer() 34 | ), 35 | ) 36 | -------------------------------------------------------------------------------- /src/config/experiments/mnist.py: 
-------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, SiMVC, CNN, DDC, Fusion, Loss, Dataset, CoMVC, Optimizer 2 | 3 | 4 | mnist = Experiment( 5 | dataset_config=Dataset(name="mnist_mv"), 6 | model_config=SiMVC( 7 | backbone_configs=( 8 | CNN(input_size=(1, 28, 28)), 9 | CNN(input_size=(1, 28, 28)), 10 | ), 11 | fusion_config=Fusion(method="weighted_mean", n_views=2), 12 | cm_config=DDC(n_clusters=10), 13 | loss_config=Loss( 14 | funcs="ddc_1|ddc_2|ddc_3", 15 | ), 16 | optimizer_config=Optimizer() 17 | ), 18 | ) 19 | 20 | 21 | mnist_contrast = Experiment( 22 | dataset_config=Dataset(name="mnist_mv"), 23 | model_config=CoMVC( 24 | backbone_configs=( 25 | CNN(input_size=(1, 28, 28)), 26 | CNN(input_size=(1, 28, 28)), 27 | ), 28 | fusion_config=Fusion(method="weighted_mean", n_views=2), 29 | projector_config=None, 30 | cm_config=DDC(n_clusters=10), 31 | loss_config=Loss( 32 | funcs="ddc_1|ddc_2|ddc_3|contrast", 33 | ), 34 | optimizer_config=Optimizer() 35 | ), 36 | ) 37 | -------------------------------------------------------------------------------- /src/config/experiments/rgbd.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, Dataset, SiMVC, DDC, Fusion, MLP, Loss, CoMVC, Optimizer 2 | 3 | 4 | rgbd = Experiment( 5 | dataset_config=Dataset(name="rgbd"), 6 | model_config=SiMVC( 7 | backbone_configs=( 8 | MLP(input_size=(2048,)), 9 | MLP(input_size=(300,)), 10 | ), 11 | fusion_config=Fusion(method="weighted_mean", n_views=2), 12 | cm_config=DDC(n_clusters=13), 13 | loss_config=Loss( 14 | funcs="ddc_1|ddc_2|ddc_3", 15 | ) 16 | ), 17 | ) 18 | 19 | rgbd_contrast = Experiment( 20 | dataset_config=Dataset(name="rgbd"), 21 | model_config=CoMVC( 22 | backbone_configs=( 23 | MLP(input_size=(2048,)), 24 | MLP(input_size=(300,)), 25 | ), 26 | fusion_config=Fusion(method="weighted_mean", n_views=2), 27 | projector_config=None, 28 | cm_config=DDC(n_clusters=13), 29 | loss_config=Loss( 30 | funcs="ddc_1|ddc_2|ddc_3|contrast", 31 | ), 32 | optimizer_config=Optimizer(scheduler_step_size=50, scheduler_gamma=0.5) 33 | ), 34 | ) 35 | -------------------------------------------------------------------------------- /src/config/experiments/test.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, Dataset, SiMVC, MLP, DDC, Fusion, Loss, CoMVC 2 | 3 | 4 | blobs_overlap = Experiment( 5 | dataset_config=Dataset(name="blobs_overlap"), 6 | model_config=SiMVC( 7 | backbone_configs=( 8 | MLP(layers=[32, 32, 32], input_size=(2,)), 9 | MLP(layers=[32, 32, 32], input_size=(2,)), 10 | ), 11 | fusion_config=Fusion(method="weighted_mean", n_views=2), 12 | cm_config=DDC(n_clusters=3), 13 | loss_config=Loss( 14 | funcs="ddc_1|ddc_2|ddc_3", 15 | ), 16 | ), 17 | n_runs=1, 18 | n_epochs=10, 19 | ) 20 | 21 | 22 | blobs_overlap_contrast = Experiment( 23 | dataset_config=Dataset(name="blobs_overlap"), 24 | model_config=CoMVC( 25 | backbone_configs=( 26 | MLP(layers=[32, 32, 32], input_size=(2,)), 27 | MLP(layers=[32, 32, 32], input_size=(2,)), 28 | ), 29 | fusion_config=Fusion(method="weighted_mean", n_views=2), 30 | projector_config=None, 31 | cm_config=DDC(n_clusters=3), 32 | loss_config=Loss( 33 | funcs="ddc_1|ddc_2|ddc_3|contrast", 34 | ) 35 | ), 36 | n_runs=1, 37 | ) 38 | 39 | blobs_overlap_5 = Experiment( 40 | dataset_config=Dataset(name="blobs_overlap_5"), 41 | model_config=SiMVC( 42 | backbone_configs=( 
43 | MLP(layers=[32, 32, 32], input_size=(2,)), 44 | MLP(layers=[32, 32, 32], input_size=(2,)), 45 | ), 46 | fusion_config=Fusion(method="weighted_mean", n_views=2), 47 | cm_config=DDC(n_clusters=5), 48 | loss_config=Loss( 49 | funcs="ddc_1|ddc_2|ddc_3", 50 | ), 51 | ), 52 | n_runs=1, 53 | ) 54 | 55 | blobs_overlap_5_contrast = Experiment( 56 | dataset_config=Dataset(name="blobs_overlap_5"), 57 | model_config=CoMVC( 58 | backbone_configs=( 59 | MLP(layers=[32, 32, 32], input_size=(2,)), 60 | MLP(layers=[32, 32, 32], input_size=(2,)), 61 | ), 62 | fusion_config=Fusion(method="weighted_mean", n_views=2), 63 | projector_config=None, 64 | cm_config=DDC(n_clusters=5), 65 | loss_config=Loss( 66 | funcs="ddc_1|ddc_2|ddc_3|contrast", 67 | ) 68 | ), 69 | n_runs=1, 70 | ) 71 | -------------------------------------------------------------------------------- /src/config/experiments/voc.py: -------------------------------------------------------------------------------- 1 | from config.defaults import Experiment, Dataset, SiMVC, DDC, Fusion, MLP, Loss, CoMVC, Optimizer 2 | 3 | voc = Experiment( 4 | dataset_config=Dataset(name="voc"), 5 | model_config=SiMVC( 6 | backbone_configs=( 7 | MLP(input_size=(512,)), 8 | MLP(input_size=(399,)), 9 | ), 10 | fusion_config=Fusion(method="weighted_mean", n_views=2), 11 | cm_config=DDC(n_clusters=20), 12 | loss_config=Loss( 13 | funcs="ddc_1|ddc_2|ddc_3", 14 | ), 15 | optimizer_config=Optimizer(learning_rate=1e-3, scheduler_step_size=50, scheduler_gamma=0.1) 16 | ), 17 | ) 18 | 19 | voc_contrast = Experiment( 20 | dataset_config=Dataset(name="voc"), 21 | model_config=CoMVC( 22 | backbone_configs=( 23 | MLP(input_size=(512,)), 24 | MLP(input_size=(399,)), 25 | ), 26 | projector_config=None, 27 | fusion_config=Fusion(method="weighted_mean", n_views=2), 28 | cm_config=DDC(n_clusters=20), 29 | loss_config=Loss( 30 | funcs="ddc_1|ddc_2|ddc_3|contrast", 31 | ), 32 | optimizer_config=Optimizer(learning_rate=1e-3) 33 | ), 34 | ) 35 | -------------------------------------------------------------------------------- /src/data/load.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | import config 5 | 6 | 7 | def _load_npz(name): 8 | return np.load(config.DATA_DIR / "processed" / f"{name}.npz") 9 | 10 | 11 | def _fix_labels(l): 12 | uniq = np.unique(l)[None, :] 13 | new = (l[:, None] == uniq).argmax(axis=1) 14 | return new 15 | 16 | 17 | def load_dataset(name, n_samples=None, select_views=None, select_labels=None, label_counts=None, noise_sd=None, 18 | noise_views=None, to_dataset=True, **kwargs): 19 | npz = _load_npz(name) 20 | labels = npz["labels"] 21 | views = [npz[f"view_{i}"] for i in range(npz["n_views"])] 22 | 23 | if select_labels is not None: 24 | mask = np.isin(labels, select_labels) 25 | labels = labels[mask] 26 | views = [v[mask] for v in views] 27 | labels = _fix_labels(labels) 28 | 29 | if label_counts is not None: 30 | idx = [] 31 | unique_labels = np.unique(labels) 32 | assert len(unique_labels) == len(label_counts) 33 | for l, n in zip(unique_labels, label_counts): 34 | _idx = np.random.choice(np.where(labels == l)[0], size=n, replace=False) 35 | idx.append(_idx) 36 | 37 | idx = np.concatenate(idx, axis=0) 38 | labels = labels[idx] 39 | views = [v[idx] for v in views] 40 | 41 | if n_samples is not None: 42 | idx = np.random.choice(labels.shape[0], size=min(labels.shape[0], int(n_samples)), replace=False) 43 | labels = labels[idx] 44 | views = [v[idx] for v in views] 45 | 46 | 
if select_views is not None: 47 | if not isinstance(select_views, (list, tuple)): 48 | select_views = [select_views] 49 | views = [views[i] for i in select_views] 50 | 51 | if noise_sd is not None: 52 | assert noise_views is not None, "'noise_views' has to be specified when 'noise_sd' is not None." 53 | if not isinstance(noise_views, (list, tuple)): 54 | noise_views = [int(noise_views)] 55 | for v in noise_views: 56 | views[v] += np.random.normal(loc=0, scale=float(noise_sd), size=views[v].shape) 57 | 58 | views = [v.astype(np.float32) for v in views] 59 | if to_dataset: 60 | dataset = th.utils.data.TensorDataset(*[th.Tensor(v).to(config.DEVICE, non_blocking=True) for v in views], 61 | th.Tensor(labels).to(config.DEVICE, non_blocking=True)) 62 | else: 63 | dataset = (views, labels) 64 | return dataset 65 | -------------------------------------------------------------------------------- /src/data/make_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import torch as th 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from sklearn.datasets import make_blobs 9 | 10 | import config 11 | 12 | 13 | def export_dataset(name, views, labels): 14 | processed_dir = config.DATA_DIR / "processed" 15 | os.makedirs(processed_dir, exist_ok=True) 16 | file_path = processed_dir / f"{name}.npz" 17 | npz_dict = {"labels": labels, "n_views": len(views)} 18 | for i, v in enumerate(views): 19 | npz_dict[f"view_{i}"] = v 20 | np.savez(file_path, **npz_dict) 21 | 22 | 23 | def _concat_edge_image(img): 24 | img = np.array(img) 25 | dilation = cv2.dilate(img, np.ones((3, 3), np.uint8), iterations=1) 26 | edge = dilation - img 27 | return np.stack((img, edge), axis=-1) 28 | 29 | 30 | def _mnist(add_edge_img, dataset_class=torchvision.datasets.MNIST): 31 | img_transforms = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 32 | if add_edge_img: 33 | img_transforms.insert(0, _concat_edge_image) 34 | transform = transforms.Compose(img_transforms) 35 | dataset = dataset_class(root=config.DATA_DIR / "raw", train=True, download=True, transform=transform) 36 | 37 | loader = th.utils.data.DataLoader(dataset, batch_size=len(dataset)) 38 | data, labels = list(loader)[0] 39 | return data, labels 40 | 41 | 42 | def mnist_mv(): 43 | data, labels = _mnist(add_edge_img=True) 44 | views = np.split(data, data.shape[1], axis=1) 45 | export_dataset("mnist_mv", views=views, labels=labels) 46 | 47 | 48 | def fmnist(): 49 | data, labels = _mnist(add_edge_img=True, dataset_class=torchvision.datasets.FashionMNIST) 50 | views = np.split(data, data.shape[1], axis=1) 51 | export_dataset("fmnist", views=views, labels=labels) 52 | 53 | 54 | def ccv(): 55 | ccv_dir = config.DATA_DIR / "raw" / "CCV" 56 | 57 | def _load_train_test(typ, suffix="Feature"): 58 | if typ: 59 | typ += "-" 60 | train = np.loadtxt(ccv_dir / f"{typ}train{suffix}.txt") 61 | test = np.loadtxt(ccv_dir / f"{typ}test{suffix}.txt") 62 | return np.concatenate((train, test), axis=0) 63 | 64 | views = [_load_train_test(typ) for typ in ["STIP", "SIFT", "MFCC"]] 65 | labels = _load_train_test("", suffix="Label") 66 | 67 | # Only include videos with exactly one label 68 | row_mask = (labels.sum(axis=1) == 1) 69 | labels = labels[row_mask].argmax(axis=1) 70 | views = [v[row_mask] for v in views] 71 | export_dataset("ccv", views=views, labels=labels) 72 | 73 | 74 | def coil(): 75 | from skimage.io import imread 76 | 77 | data_dir = 
config.DATA_DIR / "raw" / "COIL" 78 | img_size = (1, 128, 128) 79 | n_objs = 20 80 | n_imgs = 72 81 | n_views = 3 82 | assert n_imgs % n_views == 0 83 | 84 | n = (n_objs * n_imgs) // n_views 85 | 86 | imgs = np.empty((n_views, n, *img_size)) 87 | labels = [] 88 | 89 | img_idx = np.arange(n_imgs) 90 | 91 | for obj in range(n_objs): 92 | obj_img_idx = np.random.permutation(img_idx).reshape(n_views, n_imgs // n_views) 93 | labels += (n_imgs // n_views) * [obj] 94 | 95 | for view, indices in enumerate(obj_img_idx): 96 | for i, idx in enumerate(indices): 97 | fname = data_dir / f"obj{obj + 1}__{idx}.png" 98 | img = imread(fname)[None, ...] 99 | imgs[view, ((obj * (n_imgs // n_views)) + i)] = img 100 | 101 | assert not np.isnan(imgs).any() 102 | views = [imgs[v] for v in range(n_views)] 103 | labels = np.array(labels) 104 | export_dataset("coil", views=views, labels=labels) 105 | 106 | 107 | def blobs_overlap(): 108 | nc = 1000 109 | ndim = 2 110 | view_1, l1 = make_blobs(n_samples=[nc, 2 * nc], n_features=ndim, cluster_std=1.0, shuffle=False) 111 | view_2, l2 = make_blobs(n_samples=[2 * nc, nc], n_features=ndim, cluster_std=1.0, shuffle=False) 112 | labels = l1 + l2 113 | export_dataset("blobs_overlap", views=[view_1, view_2], labels=labels) 114 | 115 | 116 | def blobs_overlap_5(): 117 | nc = 500 118 | ndim = 2 119 | view_1, _ = make_blobs(n_samples=[3 * nc, 2 * nc], n_features=ndim, cluster_std=1.0, shuffle=False) 120 | view_2, _ = make_blobs(n_samples=[1 * nc, 2 * nc, 2 * nc], n_features=ndim, cluster_std=1.0, shuffle=False) 121 | view_2[(2 * nc): (4 * nc)] = view_2[(2 * nc): (4 * nc)][::-1] 122 | labels = np.concatenate(([nc * [i] for i in range(5)])) 123 | export_dataset("blobs_overlap_5", views=[view_1, view_2], labels=labels) 124 | 125 | 126 | LOADERS = { 127 | "mnist_mv": mnist_mv, 128 | "ccv": ccv, 129 | "blobs_overlap": blobs_overlap, 130 | "blobs_overlap_5": blobs_overlap_5, 131 | "fmnist": fmnist, 132 | "coil": coil 133 | } 134 | 135 | 136 | def main(): 137 | export_sets = sys.argv[1:] if len(sys.argv) > 1 else LOADERS.keys() 138 | for name in export_sets: 139 | print(f"Exporting dataset '{name}'") 140 | LOADERS[name]() 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /src/eamc/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom implementation of End-to-End Adversarial-Attention Network for Multi-Modal Clustering (EAMC). 3 | https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_End-to-End_Adversarial-Attention_Network_for_Multi-Modal_Clustering_CVPR_2020_paper.pdf 4 | Based on code sent to us by the original authors. 
5 | """ 6 | 7 | import torch as th 8 | from torch.nn.functional import binary_cross_entropy 9 | 10 | import config 11 | from lib import loss, kernel 12 | 13 | 14 | class AttLoss(loss.LossTerm): 15 | """ 16 | Attention loss 17 | """ 18 | required_tensors = ["backbone_kernels", "fusion_kernel"] 19 | 20 | def __call__(self, net, cfg, extra): 21 | kc = th.sum(net.weights[None, None, :] * th.stack(extra["backbone_kernels"], dim=-1), dim=-1) 22 | dif = (extra["fusion_kernel"] - kc) 23 | return th.trace(dif @ th.t(dif)) 24 | 25 | 26 | class GenLoss(loss.LossTerm): 27 | """ 28 | Generator loss 29 | """ 30 | def __call__(self, net, cfg, extra): 31 | tot = th.tensor(0., device=config.DEVICE) 32 | target = th.ones(net.output.size(0), device=config.DEVICE) 33 | for _, dv in net.discriminator_outputs: 34 | tot += binary_cross_entropy(dv.squeeze(), target) 35 | return cfg.gamma * tot 36 | 37 | 38 | class DiscLoss(loss.LossTerm): 39 | """ 40 | Discriminator loss 41 | """ 42 | def __call__(self, net, cfg, extra): 43 | tot = th.tensor(0., device=config.DEVICE) 44 | real_target = th.ones(net.output.size(0), device=config.DEVICE) 45 | fake_target = th.zeros(net.output.size(0), device=config.DEVICE) 46 | for d0, dv in net.discriminator_outputs: 47 | tot += binary_cross_entropy(dv.squeeze(), fake_target) + binary_cross_entropy(d0.squeeze(), real_target) 48 | return tot 49 | 50 | 51 | def backbone_kernels(net, cfg): 52 | return [kernel.vector_kernel(h, cfg.rel_sigma) for h in net.backbone_outputs] 53 | 54 | 55 | def fusion_kernel(net, cfg): 56 | return kernel.vector_kernel(net.fused, cfg.rel_sigma) 57 | 58 | 59 | class Loss(loss.Loss): 60 | # Override the TERM_CLASSES and EXTRA_FUNCS of the Loss class, so we can include the EAMC losses. 61 | TERM_CLASSES = { 62 | "ddc_1": loss.DDC1, 63 | "ddc_2_flipped": loss.DDC2Flipped, 64 | "ddc_2": loss.DDC2, 65 | "ddc_3": loss.DDC3, 66 | "att": AttLoss, 67 | "gen": GenLoss, 68 | "disc": DiscLoss 69 | } 70 | EXTRA_FUNCS = { 71 | "hidden_kernel": loss.hidden_kernel, 72 | "backbone_kernels": backbone_kernels, 73 | "fusion_kernel": fusion_kernel 74 | } 75 | -------------------------------------------------------------------------------- /src/eamc/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom implementation of End-to-End Adversarial-Attention Network for Multi-Modal Clustering (EAMC). 3 | https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhou_End-to-End_Adversarial-Attention_Network_for_Multi-Modal_Clustering_CVPR_2020_paper.pdf 4 | Based on code sent to us by the original authors. 5 | """ 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | import helpers 11 | import config 12 | from lib.fusion import get_fusion_module 13 | from lib.backbones import Backbones, MLP 14 | from models.clustering_module import DDC 15 | from eamc.loss import Loss 16 | 17 | 18 | class Discriminator(nn.Module): 19 | def __init__(self, cfg, input_size): 20 | """ 21 | EAMC discriminator 22 | 23 | :param cfg: Discriminator config 24 | :type cfg: config.eamc.defaults.Discriminator 25 | :param input_size: Input size 26 | :type input_size: Union[List[int, ...], Tuple[int, ...], ...] 
27 | """ 28 | super().__init__() 29 | self.mlp = MLP(cfg.mlp_config, input_size=input_size) 30 | self.output_layer = nn.Sequential( 31 | nn.Linear(self.mlp.output_size[0], 1, bias=True), 32 | nn.Sigmoid() 33 | ) 34 | self.d0 = self.dv = None 35 | 36 | def forward(self, x0, xv): 37 | self.d0 = self.output_layer(self.mlp(x0)) 38 | self.dv = self.output_layer(self.mlp(xv)) 39 | return [self.d0, self.dv] 40 | 41 | 42 | class AttentionLayer(nn.Module): 43 | def __init__(self, cfg, input_size): 44 | """ 45 | EAMC attention net 46 | 47 | :param cfg: Attention config 48 | :type cfg: config.eamc.defaults.AttentionLayer 49 | :param input_size: Input size 50 | :type input_size: Union[List[int, ...], Tuple[int, ...], ...] 51 | """ 52 | super().__init__() 53 | self.tau = cfg.tau 54 | self.mlp = MLP(cfg.mlp_config, input_size=[input_size[0] * cfg.n_views]) 55 | self.output_layer = nn.Linear(self.mlp.output_size[0], cfg.n_views, bias=True) 56 | self.weights = None 57 | 58 | def forward(self, xs): 59 | h = th.cat(xs, dim=1) 60 | act = self.output_layer(self.mlp(h)) 61 | e = nn.functional.softmax(th.sigmoid(act) / self.tau, dim=1) 62 | # e = nn.functional.softmax(act, dim=1) 63 | self.weights = th.mean(e, dim=0) 64 | return self.weights 65 | 66 | 67 | class EAMC(nn.Module): 68 | def __init__(self, cfg): 69 | """ 70 | EAMC model 71 | 72 | :param cfg: EAMC config 73 | :type cfg: config.eamc.defaults.EAMC 74 | """ 75 | super().__init__() 76 | 77 | self.cfg = cfg 78 | self.backbones = Backbones(cfg.backbone_configs) 79 | 80 | backbone_output_sizes = self.backbones.output_sizes 81 | assert all([backbone_output_sizes[0] == s for s in backbone_output_sizes]) 82 | assert len(backbone_output_sizes[0]) == 1 83 | hidden_size = backbone_output_sizes[0] 84 | 85 | if cfg.attention_config is not None: 86 | self.fusion = None 87 | self.attention = AttentionLayer(cfg.attention_config, input_size=hidden_size) 88 | self.weights = None 89 | assert getattr(self.cfg, "fusion_config", None) is None, "EAMC attention_config and fusion_config cannot " \ 90 | "both be not-None." 91 | 92 | elif getattr(cfg, "fusion_config", None) is not None: 93 | self.fusion = get_fusion_module(cfg.fusion_config, input_sizes=backbone_output_sizes) 94 | self.attention = None 95 | self.weights = None 96 | 97 | else: 98 | self.attention = None 99 | self.weights = th.full([len(cfg.backbone_configs)], 1/len(cfg.backbone_configs), device=config.DEVICE) 100 | 101 | if cfg.discriminator_config is not None: 102 | self.discriminators = nn.ModuleList( 103 | [Discriminator(cfg.discriminator_config, input_size=hidden_size) 104 | for _ in range(len(cfg.backbone_configs) - 1)] 105 | ) 106 | else: 107 | self.discriminators = None 108 | 109 | self.ddc = DDC(hidden_size, cfg.cm_config) 110 | self.loss = Loss(cfg.loss_config) 111 | 112 | # Initialize weights. 
113 | self.apply(helpers.he_init_weights) 114 | 115 | self.backbone_outputs = None 116 | self.discriminator_outputs = None 117 | self.fused = None 118 | self.hidden = None 119 | self.output = None 120 | 121 | self.clustering_optimizer, self.discriminator_optimizer = self.get_optimizers() 122 | 123 | def get_optimizers(self): 124 | opt = self.cfg.optimizer_config 125 | 126 | # Clustering optimizer 127 | clustering_optimizer_spec = [ 128 | {"params": self.backbones.parameters(), "lr": opt.lr_backbones, "betas": opt.betas_backbones}, 129 | {"params": self.ddc.parameters(), "lr": opt.lr_clustering_module, "betas": opt.betas_clustering_module} 130 | ] 131 | if self.cfg.attention_config is not None: 132 | clustering_optimizer_spec.append( 133 | {"params": self.attention.parameters(), "lr": opt.lr_att, "betas": opt.betas_att} 134 | ) 135 | if getattr(self.cfg, "fusion_config", None) is not None: 136 | clustering_optimizer_spec.append( 137 | {"params": self.fusion.parameters(), "lr": 1e-3} 138 | ) 139 | clustering_optimizer = th.optim.Adam(clustering_optimizer_spec) 140 | 141 | # Discriminator optimizer 142 | if self.cfg.discriminator_config is None: 143 | discriminator_optimizer = None 144 | else: 145 | discriminator_optimizer = th.optim.Adam([ 146 | {"params": self.discriminators.parameters(), "lr": opt.lr_disc, "betas": opt.betas_disc} 147 | ]) 148 | 149 | return clustering_optimizer, discriminator_optimizer 150 | 151 | def forward(self, views): 152 | self.backbone_outputs = self.backbones(views) 153 | if self.discriminators is not None: 154 | self.discriminator_outputs = [ 155 | self.discriminators[i](self.backbone_outputs[0], self.backbone_outputs[i+1]) 156 | for i in range(len(self.backbone_outputs) - 1) 157 | ] 158 | 159 | if self.fusion is not None: 160 | self.fused = self.fusion(self.backbone_outputs) 161 | else: 162 | if self.attention is not None: 163 | self.weights = self.attention(self.backbone_outputs) 164 | 165 | self.fused = th.sum(self.weights[None, None, :] * th.stack(self.backbone_outputs, dim=-1), dim=-1) 166 | 167 | self.output, self.hidden = self.ddc(self.fused) 168 | return self.output 169 | 170 | def calc_losses(self, ignore_in_total=tuple()): 171 | return self.loss(self, ignore_in_total=ignore_in_total) 172 | 173 | @staticmethod 174 | def _get_train_mode(i, cfg): 175 | if cfg.discriminator_config is None: 176 | return "gen" 177 | return "gen" if (i % (cfg.t + cfg.t_disc) < cfg.t) else "disc" 178 | 179 | def train_step(self, batch, epoch, it, n_batches): 180 | train_mode = self._get_train_mode(it, self.cfg) 181 | if train_mode == "disc": 182 | # Train discriminator 183 | opt = self.discriminator_optimizer 184 | loss_key = "disc" 185 | ignore_in_total = ("ddc_1", "ddc_2_flipped", "ddc_3", "att", "gen") 186 | else: 187 | opt = self.clustering_optimizer 188 | loss_key = "tot" 189 | ignore_in_total = ("disc",) 190 | 191 | opt.zero_grad() 192 | _ = self(batch) 193 | losses = self.calc_losses(ignore_in_total=ignore_in_total) 194 | losses[loss_key].backward() 195 | 196 | # Clip gradient? 
197 | if train_mode == "gen" and self.cfg.clip_norm is not None: 198 | th.nn.utils.clip_grad_norm_(self.parameters(), self.cfg.clip_norm) 199 | 200 | opt.step() 201 | return losses 202 | -------------------------------------------------------------------------------- /src/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import numpy as np 4 | import torch.nn as nn 5 | from sklearn.metrics import confusion_matrix 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | import config 9 | 10 | 11 | def npy(t, to_cpu=True): 12 | """ 13 | Convert a tensor to a numpy array. 14 | 15 | :param t: Input tensor 16 | :type t: th.Tensor 17 | :param to_cpu: Call the .cpu() method on `t`? 18 | :type to_cpu: bool 19 | :return: Numpy array 20 | :rtype: np.ndarray 21 | """ 22 | if isinstance(t, (list, tuple)): 23 | # We got a list. Convert each element to numpy 24 | return [npy(ti) for ti in t] 25 | elif isinstance(t, dict): 26 | # We got a dict. Convert each value to numpy 27 | return {k: npy(v) for k, v in t.items()} 28 | # Assuming t is a tensor. 29 | if to_cpu: 30 | return t.cpu().detach().numpy() 31 | return t.detach().numpy() 32 | 33 | 34 | def ensure_iterable(elem, expected_length=1): 35 | if isinstance(elem, (list, tuple)): 36 | assert len(elem) == expected_length, f"Expected iterable {elem} with length {len(elem)} does not have " \ 37 | f"expected length {expected_length}" 38 | else: 39 | elem = expected_length * [elem] 40 | return elem 41 | 42 | 43 | def dict_means(dicts): 44 | """ 45 | Compute the mean value of keys in a list of dicts 46 | 47 | :param dicts: Input dicts 48 | :type dicts: List[dict] 49 | :return: Mean values 50 | :rtype: dict 51 | """ 52 | return pd.DataFrame(dicts).mean(axis=0).to_dict() 53 | 54 | 55 | def add_prefix(dct, prefix, sep="/"): 56 | """ 57 | Add a prefix to all keys in `dct`. 58 | 59 | :param dct: Input dict 60 | :type dct: dict 61 | :param prefix: Prefix 62 | :type prefix: str 63 | :param sep: Separator between prefix and key 64 | :type sep: str 65 | :return: Dict with prefix prepended to all keys 66 | :rtype: dict 67 | """ 68 | return {prefix + sep + key: value for key, value in dct.items()} 69 | 70 | 71 | def ordered_cmat(labels, pred): 72 | """ 73 | Compute the confusion matrix and accuracy corresponding to the best cluster-to-class assignment. 74 | 75 | :param labels: Label array 76 | :type labels: np.array 77 | :param pred: Predictions array 78 | :type pred: np.array 79 | :return: Accuracy and confusion matrix 80 | :rtype: Tuple[float, np.array] 81 | """ 82 | cmat = confusion_matrix(labels, pred) 83 | ri, ci = linear_sum_assignment(-cmat) 84 | ordered = cmat[np.ix_(ri, ci)] 85 | acc = np.sum(np.diag(ordered))/np.sum(ordered) 86 | return acc, ordered 87 | 88 | 89 | def get_save_dir(experiment_name, identifier, run): 90 | """ 91 | Get the save dir for an experiment 92 | 93 | :param experiment_name: Name of the config 94 | :type experiment_name: str 95 | :param identifier: 8-character unique identifier for the current experiment 96 | :type identifier: str 97 | :param run: Current training run 98 | :type run: int 99 | :return: Path to save dir 100 | :rtype: pathlib.Path 101 | """ 102 | if not str(run).startswith("run-"): 103 | run = f"run-{run}" 104 | return config.MODELS_DIR / f"{experiment_name}-{identifier}" / run 105 | 106 | 107 | def he_init_weights(module): 108 | """ 109 | Initialize network weights using the He (Kaiming) initialization strategy. 
110 | 111 | :param module: Network module 112 | :type module: nn.Module 113 | """ 114 | if isinstance(module, (nn.Conv2d, nn.Linear)): 115 | nn.init.kaiming_normal_(module.weight) 116 | 117 | 118 | def num2tuple(num): 119 | return num if isinstance(num, (tuple, list)) else (num, num) 120 | 121 | 122 | def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1): 123 | """ 124 | Compute the output shape of a convolution operation. 125 | 126 | :param h_w: Height and width of input 127 | :type h_w: Tuple[int, int] 128 | :param kernel_size: Size of kernel 129 | :type kernel_size: Union[int, Tuple[int, int]] 130 | :param stride: Stride of convolution 131 | :type stride: Union[int, Tuple[int, int]] 132 | :param pad: Padding (in pixels) 133 | :type pad: Union[int, Tuple[int, int]] 134 | :param dilation: Dilation 135 | :type dilation: Union[int, Tuple[int, int]] 136 | :return: Height and width of output 137 | :rtype: Tuple[int, int] 138 | """ 139 | h_w, kernel_size, stride, = num2tuple(h_w), num2tuple(kernel_size), num2tuple(stride) 140 | pad, dilation = num2tuple(pad), num2tuple(dilation) 141 | 142 | h = math.floor((h_w[0] + 2 * pad[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) 143 | w = math.floor((h_w[1] + 2 * pad[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1) 144 | return h, w 145 | -------------------------------------------------------------------------------- /src/lib/backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 4 | import helpers 5 | 6 | 7 | class Backbone(nn.Module): 8 | def __init__(self): 9 | """ 10 | Backbone base class 11 | """ 12 | super().__init__() 13 | self.layers = nn.ModuleList() 14 | 15 | def forward(self, x): 16 | for layer in self.layers: 17 | x = layer(x) 18 | return x 19 | 20 | 21 | class CNN(Backbone): 22 | def __init__(self, cfg, flatten_output=True, **_): 23 | """ 24 | CNN backbone 25 | 26 | :param cfg: CNN config 27 | :type cfg: config.defaults.CNN 28 | :param flatten_output: Flatten the backbone output? 
29 | :type flatten_output: bool 30 | :param _: 31 | :type _: 32 | """ 33 | super().__init__() 34 | 35 | self.output_size = list(cfg.input_size) 36 | 37 | for layer_type, *layer_params in cfg.layers: 38 | if layer_type == "conv": 39 | self.layers.append(nn.Conv2d(in_channels=self.output_size[0], out_channels=layer_params[2], 40 | kernel_size=layer_params[:2])) 41 | # Update output size 42 | self.output_size[0] = layer_params[2] 43 | self.output_size[1:] = helpers.conv2d_output_shape(self.output_size[1:], kernel_size=layer_params[:2]) 44 | # Add activation 45 | if layer_params[3] == "relu": 46 | self.layers.append(nn.ReLU()) 47 | 48 | elif layer_type == "pool": 49 | self.layers.append(nn.MaxPool2d(kernel_size=layer_params)) 50 | # Update output size 51 | self.output_size[1:] = helpers.conv2d_output_shape(self.output_size[1:], kernel_size=layer_params, 52 | stride=layer_params) 53 | 54 | elif layer_type == "relu": 55 | self.layers.append(nn.ReLU()) 56 | 57 | elif layer_type == "lrelu": 58 | self.layers.append(nn.LeakyReLU(layer_params[0])) 59 | 60 | elif layer_type == "bn": 61 | if len(self.output_size) > 1: 62 | self.layers.append(nn.BatchNorm2d(num_features=self.output_size[0])) 63 | else: 64 | self.layers.append(nn.BatchNorm1d(num_features=self.output_size[0])) 65 | 66 | elif layer_type == "fc": 67 | self.layers.append(nn.Flatten()) 68 | self.output_size = [np.prod(self.output_size)] 69 | self.layers.append(nn.Linear(self.output_size[0], layer_params[0], bias=True)) 70 | self.output_size = [layer_params[0]] 71 | 72 | else: 73 | raise RuntimeError(f"Unknown layer type: {layer_type}") 74 | 75 | if flatten_output: 76 | self.layers.append(nn.Flatten()) 77 | self.output_size = [np.prod(self.output_size)] 78 | 79 | 80 | class MLP(Backbone): 81 | def __init__(self, cfg, input_size=None, **_): 82 | """ 83 | MLP backbone 84 | 85 | :param cfg: MLP config 86 | :type cfg: config.defaults.MLP 87 | :param input_size: Optional input size which overrides the one set in `cfg`. 
88 | :type input_size: Optional[Union[List, Tuple]] 89 | :param _: 90 | :type _: 91 | """ 92 | super().__init__() 93 | self.output_size = self.create_linear_layers(cfg, self.layers, input_size=input_size) 94 | 95 | @staticmethod 96 | def get_activation_module(a): 97 | if a == "relu": 98 | return nn.ReLU() 99 | elif a == "sigmoid": 100 | return nn.Sigmoid() 101 | elif a == "tanh": 102 | return nn.Tanh() 103 | elif a == "softmax": 104 | return nn.Softmax(dim=1) 105 | elif a.startswith("leaky_relu"): 106 | neg_slope = float(a.split(":")[1]) 107 | return nn.LeakyReLU(neg_slope) 108 | else: 109 | raise RuntimeError(f"Invalid MLP activation: {a}.") 110 | 111 | @classmethod 112 | def create_linear_layers(cls, cfg, layer_container, input_size=None): 113 | # `input_size` takes priority over `cfg.input_size` 114 | if input_size is not None: 115 | output_size = list(input_size) 116 | else: 117 | output_size = list(cfg.input_size) 118 | 119 | if len(output_size) > 1: 120 | layer_container.append(nn.Flatten()) 121 | output_size = [np.prod(output_size)] 122 | 123 | n_layers = len(cfg.layers) 124 | activations = helpers.ensure_iterable(cfg.activation, expected_length=n_layers) 125 | use_bias = helpers.ensure_iterable(cfg.use_bias, expected_length=n_layers) 126 | use_bn = helpers.ensure_iterable(cfg.use_bn, expected_length=n_layers) 127 | 128 | for n_units, act, _use_bias, _use_bn in zip(cfg.layers, activations, use_bias, use_bn): 129 | # If we get n_units = -1, then the number of units should be the same as the previous number of units, or 130 | # the input dim. 131 | if n_units == -1: 132 | n_units = output_size[0] 133 | 134 | layer_container.append(nn.Linear(in_features=output_size[0], out_features=n_units, bias=_use_bias)) 135 | if _use_bn: 136 | # Add BN before activation 137 | layer_container.append(nn.BatchNorm1d(num_features=n_units)) 138 | if act is not None: 139 | # Add activation 140 | layer_container.append(cls.get_activation_module(act)) 141 | output_size[0] = n_units 142 | 143 | return output_size 144 | 145 | 146 | class Backbones(nn.Module): 147 | BACKBONE_CONSTRUCTORS = { 148 | "CNN": CNN, 149 | "MLP": MLP 150 | } 151 | 152 | def __init__(self, backbone_configs, flatten_output=True): 153 | """ 154 | Class representing multiple backbones. Call with list of inputs, where inputs[0] goes into the first backbone, 155 | and so on. 156 | 157 | :param backbone_configs: List of backbone configs. Each element corresponds to a backbone. 158 | :type backbone_configs: List[Union[config.defaults.MLP, config.defaults.CNN], ...] 159 | :param flatten_output: Flatten the backbone outputs? 160 | :type flatten_output: bool 161 | """ 162 | super().__init__() 163 | 164 | self.backbones = nn.ModuleList() 165 | for cfg in backbone_configs: 166 | self.backbones.append(self.create_backbone(cfg, flatten_output=flatten_output)) 167 | 168 | @property 169 | def output_sizes(self): 170 | return [bb.output_size for bb in self.backbones] 171 | 172 | @classmethod 173 | def create_backbone(cls, cfg, flatten_output=True): 174 | if cfg.class_name not in cls.BACKBONE_CONSTRUCTORS: 175 | raise RuntimeError(f"Invalid backbone: '{cfg.class_name}'") 176 | return cls.BACKBONE_CONSTRUCTORS[cfg.class_name](cfg, flatten_output=flatten_output) 177 | 178 | def forward(self, views): 179 | assert len(views) == len(self.backbones), f"n_views ({len(views)}) != n_backbones ({len(self.backbones)})." 
180 | outputs = [bb(v) for bb, v in zip(self.backbones, views)] 181 | return outputs 182 | 183 | 184 | -------------------------------------------------------------------------------- /src/lib/fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | 6 | class _Fusion(nn.Module): 7 | def __init__(self, cfg, input_sizes): 8 | """ 9 | Base class for the fusion module 10 | 11 | :param cfg: Fusion config. See config.defaults.Fusion 12 | :param input_sizes: Input shapes 13 | """ 14 | super().__init__() 15 | self.cfg = cfg 16 | self.input_sizes = input_sizes 17 | self.output_size = None 18 | 19 | def forward(self, inputs): 20 | raise NotImplementedError() 21 | 22 | @classmethod 23 | def get_weighted_sum_output_size(cls, input_sizes): 24 | flat_sizes = [np.prod(s) for s in input_sizes] 25 | assert all(s == flat_sizes[0] for s in flat_sizes), f"Fusion method {cls.__name__} requires the flat output" \ 26 | f" shape from all backbones to be identical." \ 27 | f" Got sizes: {input_sizes} -> {flat_sizes}." 28 | return [flat_sizes[0]] 29 | 30 | def get_weights(self, softmax=True): 31 | out = [] 32 | if hasattr(self, "weights"): 33 | out = self.weights 34 | if softmax: 35 | out = nn.functional.softmax(self.weights, dim=-1) 36 | return out 37 | 38 | def update_weights(self, inputs, a): 39 | pass 40 | 41 | 42 | class Mean(_Fusion): 43 | def __init__(self, cfg, input_sizes): 44 | """ 45 | Mean fusion. 46 | 47 | :param cfg: Fusion config. See config.defaults.Fusion 48 | :param input_sizes: Input shapes 49 | """ 50 | super().__init__(cfg, input_sizes) 51 | self.output_size = self.get_weighted_sum_output_size(input_sizes) 52 | 53 | def forward(self, inputs): 54 | return th.mean(th.stack(inputs, -1), dim=-1) 55 | 56 | 57 | class WeightedMean(_Fusion): 58 | """ 59 | Weighted mean fusion. 60 | 61 | :param cfg: Fusion config. See config.defaults.Fusion 62 | :param input_sizes: Input shapes 63 | """ 64 | def __init__(self, cfg, input_sizes): 65 | super().__init__(cfg, input_sizes) 66 | self.weights = nn.Parameter(th.full((self.cfg.n_views,), 1 / self.cfg.n_views), requires_grad=True) 67 | self.output_size = self.get_weighted_sum_output_size(input_sizes) 68 | 69 | def forward(self, inputs): 70 | return _weighted_sum(inputs, self.weights, normalize_weights=True) 71 | 72 | 73 | def _weighted_sum(tensors, weights, normalize_weights=True): 74 | if normalize_weights: 75 | weights = nn.functional.softmax(weights, dim=0) 76 | out = th.sum(weights[None, None, :] * th.stack(tensors, dim=-1), dim=-1) 77 | return out 78 | 79 | 80 | MODULES = { 81 | "mean": Mean, 82 | "weighted_mean": WeightedMean, 83 | } 84 | 85 | 86 | def get_fusion_module(cfg, input_sizes): 87 | return MODULES[cfg.method](cfg, input_sizes) 88 | -------------------------------------------------------------------------------- /src/lib/kernel.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.nn.functional import relu 3 | 4 | 5 | EPSILON = 1E-9 6 | 7 | 8 | def kernel_from_distance_matrix(dist, rel_sigma, min_sigma=EPSILON): 9 | """ 10 | Compute a Gaussian kernel matrix from a distance matrix. 11 | 12 | :param dist: Disatance matrix 13 | :type dist: th.Tensor 14 | :param rel_sigma: Multiplication factor for the sigma hyperparameter 15 | :type rel_sigma: float 16 | :param min_sigma: Minimum value for sigma. For numerical stability. 
17 | :type min_sigma: float 18 | :return: Kernel matrix 19 | :rtype: th.Tensor 20 | """ 21 | # `dist` can sometimes contain negative values due to floating point errors, so just set these to zero. 22 | dist = relu(dist) 23 | sigma2 = rel_sigma * th.median(dist) 24 | # Disable gradient for sigma 25 | sigma2 = sigma2.detach() 26 | sigma2 = th.where(sigma2 < min_sigma, sigma2.new_tensor(min_sigma), sigma2) 27 | k = th.exp(- dist / (2 * sigma2)) 28 | return k 29 | 30 | 31 | def vector_kernel(x, rel_sigma=0.15): 32 | """ 33 | Compute a kernel matrix from the rows of a matrix. 34 | 35 | :param x: Input matrix 36 | :type x: th.Tensor 37 | :param rel_sigma: Multiplication factor for the sigma hyperparameter 38 | :type rel_sigma: float 39 | :return: Kernel matrix 40 | :rtype: th.Tensor 41 | """ 42 | return kernel_from_distance_matrix(cdist(x, x), rel_sigma) 43 | 44 | 45 | def cdist(X, Y): 46 | """ 47 | Pairwise distance between rows of X and rows of Y. 48 | 49 | :param X: First input matrix 50 | :type X: th.Tensor 51 | :param Y: Second input matrix 52 | :type Y: th.Tensor 53 | :return: Matrix containing pairwise distances between rows of X and rows of Y 54 | :rtype: th.Tensor 55 | """ 56 | xyT = X @ th.t(Y) 57 | x2 = th.sum(X**2, dim=1, keepdim=True) 58 | y2 = th.sum(Y**2, dim=1, keepdim=True) 59 | d = x2 - 2 * xyT + th.t(y2) 60 | return d 61 | -------------------------------------------------------------------------------- /src/lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | import config 5 | from lib import kernel 6 | 7 | EPSILON = 1E-9 8 | DEBUG_MODE = False 9 | 10 | 11 | def triu(X): 12 | # Sum of strictly upper triangular part 13 | return th.sum(th.triu(X, diagonal=1)) 14 | 15 | 16 | def _atleast_epsilon(X, eps=EPSILON): 17 | """ 18 | Ensure that all elements are >= `eps`. 19 | 20 | :param X: Input elements 21 | :type X: th.Tensor 22 | :param eps: epsilon 23 | :type eps: float 24 | :return: New version of X where elements smaller than `eps` have been replaced with `eps`. 25 | :rtype: th.Tensor 26 | """ 27 | return th.where(X < eps, X.new_tensor(eps), X) 28 | 29 | 30 | def d_cs(A, K, n_clusters): 31 | """ 32 | Cauchy-Schwarz divergence. 33 | 34 | :param A: Cluster assignment matrix 35 | :type A: th.Tensor 36 | :param K: Kernel matrix 37 | :type K: th.Tensor 38 | :param n_clusters: Number of clusters 39 | :type n_clusters: int 40 | :return: CS-divergence 41 | :rtype: th.Tensor 42 | """ 43 | nom = th.t(A) @ K @ A 44 | dnom_squared = th.unsqueeze(th.diagonal(nom), -1) @ th.unsqueeze(th.diagonal(nom), 0) 45 | 46 | nom = _atleast_epsilon(nom) 47 | dnom_squared = _atleast_epsilon(dnom_squared, eps=EPSILON**2) 48 | 49 | d = 2 / (n_clusters * (n_clusters - 1)) * triu(nom / th.sqrt(dnom_squared)) 50 | return d 51 | 52 | 53 | # ====================================================================================================================== 54 | # Loss terms 55 | # ====================================================================================================================== 56 | 57 | class LossTerm: 58 | # Names of tensors required for the loss computation 59 | required_tensors = [] 60 | 61 | def __init__(self, *args, **kwargs): 62 | """ 63 | Base class for a term in the loss function. 
64 | 65 | :param args: 66 | :type args: 67 | :param kwargs: 68 | :type kwargs: 69 | """ 70 | pass 71 | 72 | def __call__(self, net, cfg, extra): 73 | raise NotImplementedError() 74 | 75 | 76 | class DDC1(LossTerm): 77 | """ 78 | L_1 loss from DDC 79 | """ 80 | required_tensors = ["hidden_kernel"] 81 | 82 | def __call__(self, net, cfg, extra): 83 | return d_cs(net.output, extra["hidden_kernel"], cfg.n_clusters) 84 | 85 | 86 | class DDC2(LossTerm): 87 | """ 88 | L_2 loss from DDC 89 | """ 90 | def __call__(self, net, cfg, extra): 91 | n = net.output.size(0) 92 | return 2 / (n * (n - 1)) * triu(net.output @ th.t(net.output)) 93 | 94 | 95 | class DDC2Flipped(LossTerm): 96 | """ 97 | Flipped version of the L_2 loss from DDC. Used by EAMC 98 | """ 99 | 100 | def __call__(self, net, cfg, extra): 101 | return 2 / (cfg.n_clusters * (cfg.n_clusters - 1)) * triu(th.t(net.output) @ net.output) 102 | 103 | 104 | class DDC3(LossTerm): 105 | """ 106 | L_3 loss from DDC 107 | """ 108 | required_tensors = ["hidden_kernel"] 109 | 110 | def __init__(self, cfg): 111 | super().__init__() 112 | self.eye = th.eye(cfg.n_clusters, device=config.DEVICE) 113 | 114 | def __call__(self, net, cfg, extra): 115 | m = th.exp(-kernel.cdist(net.output, self.eye)) 116 | return d_cs(m, extra["hidden_kernel"], cfg.n_clusters) 117 | 118 | 119 | class Contrastive(LossTerm): 120 | large_num = 1e9 121 | 122 | def __init__(self, cfg): 123 | """ 124 | Contrastive loss function 125 | 126 | :param cfg: Loss function config 127 | :type cfg: config.defaults.Loss 128 | """ 129 | super().__init__() 130 | # Select which implementation to use 131 | if cfg.negative_samples_ratio == -1: 132 | self._loss_func = self._loss_without_negative_sampling 133 | else: 134 | self.eye = th.eye(cfg.n_clusters, device=config.DEVICE) 135 | self._loss_func = self._loss_with_negative_sampling 136 | 137 | # Set similarity function 138 | if cfg.contrastive_similarity == "cos": 139 | self.similarity_func = self._cosine_similarity 140 | elif cfg.contrastive_similarity == "gauss": 141 | self.similarity_func = kernel.vector_kernel 142 | else: 143 | raise RuntimeError(f"Invalid contrastive similarity: {cfg.contrastive_similarity}") 144 | 145 | @staticmethod 146 | def _norm(mat): 147 | return th.nn.functional.normalize(mat, p=2, dim=1) 148 | 149 | @staticmethod 150 | def get_weight(net): 151 | w = th.min(th.nn.functional.softmax(net.fusion.weights.detach(), dim=0)) 152 | return w 153 | 154 | @classmethod 155 | def _normalized_projections(cls, net): 156 | n = net.projections.size(0) // 2 157 | h1, h2 = net.projections[:n], net.projections[n:] 158 | h2 = cls._norm(h2) 159 | h1 = cls._norm(h1) 160 | return n, h1, h2 161 | 162 | @classmethod 163 | def _cosine_similarity(cls, projections): 164 | h = cls._norm(projections) 165 | return h @ h.t() 166 | 167 | def _draw_negative_samples(self, net, cfg, v, pos_indices): 168 | """ 169 | Construct set of negative samples. 
170 | 171 | :param net: Model 172 | :type net: Union[models.simple_mvc.SiMVC, models.contrastive_mvc.CoMVC] 173 | :param cfg: Loss config 174 | :type cfg: config.defaults.Loss 175 | :param v: Number of views 176 | :type v: int 177 | :param pos_indices: Row indices of the positive samples in the concatenated similarity matrix 178 | :type pos_indices: th.Tensor 179 | :return: Indices of negative samples 180 | :rtype: th.Tensor 181 | """ 182 | cat = net.output.detach().argmax(dim=1) 183 | cat = th.cat(v * [cat], dim=0) 184 | 185 | weights = (1 - self.eye[cat])[:, cat[[pos_indices]]].T 186 | n_negative_samples = int(cfg.negative_samples_ratio * cat.size(0)) 187 | negative_sample_indices = th.multinomial(weights, n_negative_samples, replacement=True) 188 | if DEBUG_MODE: 189 | self._check_negative_samples_valid(cat, pos_indices, negative_sample_indices) 190 | return negative_sample_indices 191 | 192 | @staticmethod 193 | def _check_negative_samples_valid(cat, pos_indices, neg_indices): 194 | pos_cats = cat[pos_indices].view(-1, 1) 195 | neg_cats = cat[neg_indices] 196 | assert (pos_cats != neg_cats).detach().cpu().numpy().all() 197 | 198 | @staticmethod 199 | def _get_positive_samples(logits, v, n): 200 | """ 201 | Get positive samples 202 | 203 | :param logits: Input similarities 204 | :type logits: th.Tensor 205 | :param v: Number of views 206 | :type v: int 207 | :param n: Number of samples per view (batch size) 208 | :type n: int 209 | :return: Similarities of positive pairs, and their indices 210 | :rtype: Tuple[th.Tensor, th.Tensor] 211 | """ 212 | diagonals = [] 213 | inds = [] 214 | for i in range(1, v): 215 | diagonal_offset = i * n 216 | diag_length = (v - i) * n 217 | _upper = th.diagonal(logits, offset=diagonal_offset) 218 | _lower = th.diagonal(logits, offset=-1 * diagonal_offset) 219 | _upper_inds = th.arange(0, diag_length) 220 | _lower_inds = th.arange(i * n, v * n) 221 | if DEBUG_MODE: 222 | assert _upper.size() == _lower.size() == _upper_inds.size() == _lower_inds.size() == (diag_length,) 223 | diagonals += [_upper, _lower] 224 | inds += [_upper_inds, _lower_inds] 225 | 226 | pos = th.cat(diagonals, dim=0) 227 | pos_inds = th.cat(inds, dim=0) 228 | return pos, pos_inds 229 | 230 | def _loss_with_negative_sampling(self, net, cfg, extra): 231 | """ 232 | Contrastive loss implementation with negative sampling. 233 | 234 | :param net: Model 235 | :type net: Union[models.simple_mvc.SiMVC, models.contrastive_mvc.CoMVC] 236 | :param cfg: Loss config 237 | :type cfg: config.defaults.Loss 238 | :param extra: 239 | :type extra: 240 | :return: Loss value 241 | :rtype: th.Tensor 242 | """ 243 | n = net.output.size(0) 244 | v = len(net.backbone_outputs) 245 | logits = self.similarity_func(net.projections) / cfg.tau 246 | 247 | pos, pos_inds = self._get_positive_samples(logits, v, n) 248 | neg_inds = self._draw_negative_samples(net, cfg, v, pos_inds) 249 | neg = logits[pos_inds.view(-1, 1), neg_inds] 250 | 251 | inputs = th.cat((pos.view(-1, 1), neg), dim=1) 252 | labels = th.zeros(v * (v - 1) * n, device=config.DEVICE, dtype=th.long) 253 | loss = th.nn.functional.cross_entropy(inputs, labels) 254 | 255 | if cfg.adaptive_contrastive_weight: 256 | loss *= self.get_weight(net) 257 | 258 | return cfg.delta * loss 259 | 260 | def _loss_without_negative_sampling(self, net, cfg, extra): 261 | """ 262 | Contrastive loss implementation without negative sampling. 
263 | Adapted from: https://github.com/google-research/simclr/blob/master/objective.py 264 | 265 | :param net: Model 266 | :type net: Union[models.simple_mvc.SiMVC, models.contrastive_mvc.CoMVC] 267 | :param cfg: Loss config 268 | :type cfg: config.defaults.Loss 269 | :param extra: 270 | :type extra: 271 | :return: 272 | :rtype: 273 | """ 274 | assert len(net.backbone_outputs) == 2, "Contrastive loss without negative sampling only supports 2 views." 275 | n, h1, h2 = self._normalized_projections(net) 276 | 277 | labels = th.arange(0, n, device=config.DEVICE, dtype=th.long) 278 | masks = th.eye(n, device=config.DEVICE) 279 | 280 | logits_aa = ((h1 @ h1.t()) / cfg.tau) - masks * self.large_num 281 | logits_bb = ((h2 @ h2.t()) / cfg.tau) - masks * self.large_num 282 | 283 | logits_ab = (h1 @ h2.t()) / cfg.tau 284 | logits_ba = (h2 @ h1.t()) / cfg.tau 285 | 286 | loss_a = th.nn.functional.cross_entropy(th.cat((logits_ab, logits_aa), dim=1), labels) 287 | loss_b = th.nn.functional.cross_entropy(th.cat((logits_ba, logits_bb), dim=1), labels) 288 | 289 | loss = (loss_a + loss_b) 290 | 291 | if cfg.adaptive_contrastive_weight: 292 | loss *= self.get_weight(net) 293 | 294 | return cfg.delta * loss 295 | 296 | def __call__(self, net, cfg, extra): 297 | return self._loss_func(net, cfg, extra) 298 | 299 | 300 | # ====================================================================================================================== 301 | # Extra functions 302 | # ====================================================================================================================== 303 | 304 | def hidden_kernel(net, cfg): 305 | return kernel.vector_kernel(net.hidden, cfg.rel_sigma) 306 | 307 | 308 | # ====================================================================================================================== 309 | # Loss class 310 | # ====================================================================================================================== 311 | 312 | class Loss(nn.Module): 313 | # Possible terms to include in the loss 314 | TERM_CLASSES = { 315 | "ddc_1": DDC1, 316 | "ddc_2": DDC2, 317 | "ddc_2_flipped": DDC2Flipped, 318 | "ddc_3": DDC3, 319 | "contrast": Contrastive, 320 | } 321 | # Functions to compute the required tensors for the terms. 322 | EXTRA_FUNCS = { 323 | "hidden_kernel": hidden_kernel, 324 | } 325 | 326 | def __init__(self, cfg): 327 | """ 328 | Implementation of a general loss function 329 | 330 | :param cfg: Loss function config 331 | :type cfg: config.defaults.Loss 332 | """ 333 | super().__init__() 334 | self.cfg = cfg 335 | 336 | self.names = cfg.funcs.split("|") 337 | self.weights = cfg.weights if cfg.weights is not None else len(self.names) * [1] 338 | 339 | self.terms = [] 340 | for term_name in self.names: 341 | self.terms.append(self.TERM_CLASSES[term_name](cfg)) 342 | 343 | self.required_extras_names = list(set(sum([t.required_tensors for t in self.terms], []))) 344 | 345 | def forward(self, net, ignore_in_total=tuple()): 346 | extra = {name: self.EXTRA_FUNCS[name](net, self.cfg) for name in self.required_extras_names} 347 | loss_values = {} 348 | for name, term, weight in zip(self.names, self.terms, self.weights): 349 | value = term(net, self.cfg, extra) 350 | # If we got a dict, add each term from the dict with "name/" as the scope. 
351 |             if isinstance(value, dict):
352 |                 for key, _value in value.items():
353 |                     loss_values[f"{name}/{key}"] = weight * _value
354 |             # Otherwise, just add the value to the dict directly
355 |             else:
356 |                 loss_values[name] = weight * value
357 |
358 |         loss_values["tot"] = sum([loss_values[k] for k in loss_values.keys() if k not in ignore_in_total])
359 |         return loss_values
360 |
361 |
--------------------------------------------------------------------------------
/src/lib/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | class Optimizer:
5 |     def __init__(self, cfg, params):
6 |         """
7 |         Wrapper class for optimizers
8 |
9 |         :param cfg: Optimizer config
10 |         :type cfg: config.defaults.Optimizer
11 |         :param params: Parameters to associate with the optimizer
12 |         :type params:
13 |         """
14 |         self.clip_norm = cfg.clip_norm
15 |         self.params = params
16 |         self._opt = th.optim.Adam(params, lr=cfg.learning_rate)
17 |         if cfg.scheduler_step_size is not None:
18 |             assert cfg.scheduler_gamma is not None
19 |             self._sch = th.optim.lr_scheduler.StepLR(self._opt, step_size=cfg.scheduler_step_size,
20 |                                                      gamma=cfg.scheduler_gamma)
21 |         else:
22 |             self._sch = None
23 |
24 |     def zero_grad(self):
25 |         return self._opt.zero_grad()
26 |
27 |     def step(self, epoch):
28 |         if self._sch is not None:
29 |             # Only step the scheduler at integer epochs, and don't step on the first epoch.
30 |             if epoch.is_integer() and epoch > 0:
31 |                 self._sch.step()
32 |
33 |         if self.clip_norm is not None:
34 |             th.nn.utils.clip_grad_norm_(self.params, self.clip_norm)
35 |
36 |         out = self._opt.step()
37 |         return out
38 |
--------------------------------------------------------------------------------
/src/models/build_model.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 | import config
4 | import helpers
5 | from models.ddc import DDCModel
6 | from models.simple_mvc import SiMVC
7 | from models.contrastive_mvc import CoMVC
8 | from eamc.model import EAMC
9 | from data.load import load_dataset
10 |
11 |
12 | MODEL_CONSTRUCTORS = {
13 |     "DDCModel": DDCModel,
14 |     "SiMVC": SiMVC,
15 |     "CoMVC": CoMVC,
16 |     "EAMC": EAMC
17 | }
18 |
19 |
20 | def build_model(model_cfg):
21 |     """
22 |     Build the model specified by `model_cfg`.
23 |
24 |     :param model_cfg: Config of model to build
25 |     :type model_cfg: Union[config.defaults.DDCModel, config.defaults.SiMVC, config.defaults.CoMVC,
26 |                            config.eamc.defaults.EAMC]
27 |     :return: Model
28 |     :rtype: Union[DDCModel, SiMVC, CoMVC, EAMC]
29 |     """
30 |     if model_cfg.class_name not in MODEL_CONSTRUCTORS:
31 |         raise ValueError(f"Invalid model type: {model_cfg.class_name}")
32 |     model = MODEL_CONSTRUCTORS[model_cfg.class_name](model_cfg).to(config.DEVICE, non_blocking=True)
33 |     return model
34 |
35 |
36 | def from_file(experiment_name=None, tag=None, run=None, ckpt="best", return_data=False, return_config=False, **kwargs):
37 |     """
38 |     Load a trained model from disk.
39 |
40 |     :param experiment_name: Name of the experiment (name of the config)
41 |     :type experiment_name: str
42 |     :param tag: 8-character experiment identifier
43 |     :type tag: str
44 |     :param run: Training run to load
45 |     :type run: int
46 |     :param ckpt: Checkpoint to load. Specify a valid checkpoint, or "best" to load the best model.
47 |     :type ckpt: Union[int, str]
48 |     :param return_data: Return the dataset?
49 |     :type return_data: bool
50 |     :param return_config: Return the experiment config?
51 | :type return_config: bool 52 | :param kwargs: 53 | :type kwargs: 54 | :return: Loaded model, dataset (if return_data == True), config (if return_config == True) 55 | :rtype: 56 | """ 57 | try: 58 | cfg = config.get_config_from_file(name=experiment_name, tag=tag) 59 | except FileNotFoundError: 60 | print("WARNING: Could not get pickled config.") 61 | cfg = config.get_config_by_name(experiment_name) 62 | 63 | model_dir = helpers.get_save_dir(experiment_name, identifier=tag, run=run) 64 | if ckpt == "best": 65 | model_file = "best.pt" 66 | else: 67 | model_file = f"checkpoint_{str(ckpt).zfill(4)}.pt" 68 | 69 | model_path = model_dir / model_file 70 | net = build_model(cfg.model_config) 71 | print(f"Loading model from {model_path}") 72 | net.load_state_dict(th.load(model_path, map_location=config.DEVICE)) 73 | net.eval() 74 | 75 | out = [net] 76 | 77 | if return_data: 78 | dataset_kwargs = cfg.dataset_config.dict() 79 | for key, value in kwargs.items(): 80 | dataset_kwargs[key] = value 81 | views, labels = load_dataset(to_dataset=False, **dataset_kwargs) 82 | out = [net, views, labels] 83 | 84 | if return_config: 85 | out.append(cfg) 86 | 87 | if len(out) == 1: 88 | out = out[0] 89 | 90 | return out 91 | -------------------------------------------------------------------------------- /src/models/callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import pickle 4 | import numpy as np 5 | import torch as th 6 | from tabulate import tabulate 7 | 8 | import helpers 9 | 10 | 11 | class Callback: 12 | def __init__(self, epoch_interval=1, batch_interval=1): 13 | """ 14 | Base class for training callbacks 15 | 16 | :param epoch_interval: Number of epochs between calling `at_epoch_end`. 17 | :type epoch_interval: int 18 | :param batch_interval: Number of batches between calling `at_batch_end`. 19 | :type batch_interval: int 20 | """ 21 | self.epoch_interval = epoch_interval 22 | self.batch_interval = batch_interval 23 | 24 | def epoch_end(self, epoch, **kwargs): 25 | if not (epoch % self.epoch_interval): 26 | return self.at_epoch_end(epoch, **kwargs) 27 | 28 | def batch_end(self, epoch, batch, **kwargs): 29 | if (not (epoch % self.epoch_interval)) and (not (batch % self.batch_interval)): 30 | return self.at_batch_end(epoch, batch, **kwargs) 31 | 32 | def at_epoch_end(self, epoch, logs=None, net=None, **kwargs): 33 | pass 34 | 35 | def at_batch_end(self, epoch, batch, outputs=None, losses=None, net=None, **kwargs): 36 | pass 37 | 38 | def at_eval(self, net=None, logs=None, **kwargs): 39 | pass 40 | 41 | 42 | class Printer(Callback): 43 | def __init__(self, print_confusion_matrix=True, **kwargs): 44 | """ 45 | Print logs to the terminal. 46 | 47 | :param print_confusion_matrix: Print the confusion matrix when it is available? 
48 | :type print_confusion_matrix: bool 49 | :param kwargs: 50 | :type kwargs: 51 | """ 52 | super().__init__(**kwargs) 53 | self.ignore_keys = ["iter_losses/"] 54 | if not print_confusion_matrix: 55 | self.ignore_keys.append("metrics/cmat") 56 | 57 | np.set_printoptions(edgeitems=20, linewidth=200) 58 | 59 | def at_epoch_end(self, epoch, logs=None, net=None, **kwargs): 60 | print_logs = logs.copy() 61 | for key in logs.keys(): 62 | if any([key.startswith(ik) for ik in self.ignore_keys]): 63 | del print_logs[key] 64 | 65 | headers = ["Key", "Value"] 66 | values = list(print_logs.items()) 67 | print(tabulate(values, headers=headers), "\n") 68 | 69 | 70 | class ModelSaver(Callback): 71 | def __init__(self, cfg, experiment_name, identifier, run, best_loss_term, checkpoint_interval=1, **kwargs): 72 | """ 73 | Model saver callback. Saves model at specified checkpoints, or when `best_loss_term` in the loss function 74 | reaches the lowest observed value. 75 | 76 | :param cfg: Experiment config 77 | :type cfg: config.defaults.Experiment 78 | :param experiment_name: Name of the experiment 79 | :type experiment_name: str 80 | :param identifier: 8-character unique experiment identifier 81 | :type identifier: str 82 | :param run: Current training run 83 | :type run: int 84 | :param best_loss_term: Term in the loss function to monitor. 85 | :type best_loss_term: str 86 | :param checkpoint_interval: Number of epochs between saving model checkpoints. 87 | :type checkpoint_interval: int 88 | :param kwargs: 89 | :type kwargs: 90 | """ 91 | super().__init__(**kwargs) 92 | 93 | self.best_loss_term = f"eval_losses/{best_loss_term}" 94 | self.min_loss = np.inf 95 | self.checkpoint_interval = checkpoint_interval 96 | self.save_dir = helpers.get_save_dir(experiment_name, identifier, run) 97 | os.makedirs(self.save_dir, exist_ok=True) 98 | self._save_cfg(cfg) 99 | 100 | def _save_cfg(self, cfg): 101 | with open(self.save_dir / "config.yml", "w") as f: 102 | yaml.dump(cfg.dict(), f) 103 | with open(self.save_dir / "config.pkl", "wb") as f: 104 | pickle.dump(cfg, f) 105 | 106 | def _save_model(self, file_name, net): 107 | model_path = self.save_dir / file_name 108 | th.save(net.state_dict(), model_path) 109 | print(f"Model successfully saved: {model_path}") 110 | 111 | def at_epoch_end(self, epoch, outputs=None, logs=None, net=None, **kwargs): 112 | if not (epoch % self.checkpoint_interval): 113 | # Save model checkpoint 114 | self._save_model(f"checkpoint_{str(epoch).zfill(4)}.pt", net) 115 | 116 | avg_loss = logs.get(self.best_loss_term, np.inf) 117 | # Save to model_best if the current loss is the lowest loss encountered 118 | if avg_loss < self.min_loss: 119 | self.min_loss = avg_loss 120 | self._save_model("best.pt", net) 121 | 122 | 123 | class StopTraining(Exception): 124 | pass 125 | 126 | 127 | class EarlyStopping(Callback): 128 | def __init__(self, patience, best_loss_term, **kwargs): 129 | """ 130 | Early stopping callback. Raises a `StopTraining` exception when the term `best_loss_term` in the loss function 131 | has not decreased in `patience` epochs. 132 | 133 | :param patience: Number of epochs to wait for loss decrease 134 | :type patience: int 135 | :param best_loss_term: Term in the loss function to monitor. 
136 | :type best_loss_term: str 137 | :param kwargs: 138 | :type kwargs: 139 | """ 140 | super().__init__(**kwargs) 141 | self.best_loss_term = f"eval_losses/{best_loss_term}" 142 | self.patience = patience 143 | self.min_loss = np.inf 144 | self.best_epoch = 0 145 | 146 | def at_epoch_end(self, epoch, outputs=None, logs=None, net=None, **kwargs): 147 | avg_loss = logs.get(self.best_loss_term, np.inf) 148 | 149 | if np.isnan(avg_loss): 150 | raise StopTraining(f"Got loss = NaN. Training stopped.") 151 | 152 | if avg_loss < self.min_loss: 153 | self.min_loss = avg_loss 154 | self.best_epoch = epoch 155 | 156 | if (epoch - self.best_epoch) >= self.patience: 157 | raise StopTraining(f"Loss has not decreased in {self.patience} epochs. Min. loss was {self.min_loss}. " 158 | f"Training stopped.") 159 | 160 | -------------------------------------------------------------------------------- /src/models/clustering_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class DDC(nn.Module): 5 | def __init__(self, input_dim, cfg): 6 | """ 7 | DDC clustering module 8 | 9 | :param input_dim: Shape of inputs. 10 | :param cfg: DDC config. See `config.defaults.DDC` 11 | """ 12 | super().__init__() 13 | 14 | hidden_layers = [nn.Linear(input_dim[0], cfg.n_hidden), nn.ReLU()] 15 | if cfg.use_bn: 16 | hidden_layers.append(nn.BatchNorm1d(num_features=cfg.n_hidden)) 17 | self.hidden = nn.Sequential(*hidden_layers) 18 | self.output = nn.Sequential(nn.Linear(cfg.n_hidden, cfg.n_clusters), nn.Softmax(dim=1)) 19 | 20 | def forward(self, x): 21 | hidden = self.hidden(x) 22 | output = self.output(hidden) 23 | return output, hidden 24 | -------------------------------------------------------------------------------- /src/models/contrastive_mvc.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | import helpers 5 | from lib.loss import Loss 6 | from lib.optimizer import Optimizer 7 | from lib.backbones import Backbones, MLP 8 | from lib.fusion import get_fusion_module 9 | from models.clustering_module import DDC 10 | from models.model_base import ModelBase 11 | 12 | 13 | class CoMVC(ModelBase): 14 | def __init__(self, cfg): 15 | """ 16 | Implementation of the CoMVC model. 17 | 18 | :param cfg: Model config. See `config.defaults.CoMVC` for documentation on the config object. 19 | """ 20 | super().__init__() 21 | 22 | self.cfg = cfg 23 | self.output = self.hidden = self.fused = self.backbone_outputs = self.projections = None 24 | 25 | # Define Backbones and Fusion modules 26 | self.backbones = Backbones(cfg.backbone_configs) 27 | self.fusion = get_fusion_module(cfg.fusion_config, self.backbones.output_sizes) 28 | 29 | bb_sizes = self.backbones.output_sizes 30 | assert all([bb_sizes[0] == s for s in bb_sizes]), f"CoMVC requires all backbones to have the same " \ 31 | f"output size. Got: {bb_sizes}" 32 | 33 | if cfg.projector_config is None: 34 | self.projector = nn.Identity() 35 | else: 36 | self.projector = MLP(cfg.projector_config, input_size=bb_sizes[0]) 37 | 38 | # Define clustering module 39 | self.ddc = DDC(input_dim=self.fusion.output_size, cfg=cfg.cm_config) 40 | # Define loss-module 41 | self.loss = Loss(cfg=cfg.loss_config) 42 | # Initialize weights. 
43 | self.apply(helpers.he_init_weights) 44 | # Instantiate optimizer 45 | self.optimizer = Optimizer(cfg.optimizer_config, self.parameters()) 46 | 47 | def forward(self, views): 48 | self.backbone_outputs = self.backbones(views) 49 | self.fused = self.fusion(self.backbone_outputs) 50 | self.projections = self.projector(th.cat(self.backbone_outputs, dim=0)) 51 | self.output, self.hidden = self.ddc(self.fused) 52 | return self.output 53 | 54 | -------------------------------------------------------------------------------- /src/models/ddc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import helpers 4 | from lib.loss import Loss 5 | from lib.optimizer import Optimizer 6 | from lib.backbones import Backbones 7 | from models.model_base import ModelBase 8 | from models.clustering_module import DDC 9 | 10 | 11 | class DDCModel(ModelBase): 12 | def __init__(self, cfg): 13 | """ 14 | Full DDC model 15 | 16 | :param cfg: DDC model config 17 | :type cfg: config.defaults.DDCModel 18 | """ 19 | super().__init__() 20 | 21 | self.cfg = cfg 22 | self.backbone_output = self.output = self.hidden = None 23 | self.backbone = Backbones.create_backbone(cfg.backbone_config) 24 | self.ddc_input_size = np.prod(self.backbone.output_size) 25 | self.ddc = DDC([self.ddc_input_size], cfg.cm_config) 26 | self.loss = Loss(cfg.loss_config) 27 | 28 | # Initialize weights. 29 | self.apply(helpers.he_init_weights) 30 | # Instantiate optimizer 31 | self.optimizer = Optimizer(cfg.optimizer_config, self.parameters()) 32 | 33 | def forward(self, x): 34 | if isinstance(x, list): 35 | # We might get a one-element list as input due to multi-view compatibility. 36 | assert len(x) == 1 37 | x = x[0] 38 | 39 | self.backbone_output = self.backbone(x).view(-1, self.ddc_input_size) 40 | self.output, self.hidden = self.ddc(self.backbone_output) 41 | return self.output 42 | -------------------------------------------------------------------------------- /src/models/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch as th 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from tabulate import tabulate 8 | from sklearn.metrics import normalized_mutual_info_score 9 | 10 | import helpers 11 | from models.build_model import from_file 12 | 13 | IGNORE_IN_TOTAL = ("contrast",) 14 | 15 | 16 | def calc_metrics(labels, pred): 17 | """ 18 | Compute metrics. 19 | 20 | :param labels: Label tensor 21 | :type labels: th.Tensor 22 | :param pred: Predictions tensor 23 | :type pred: th.Tensor 24 | :return: Dictionary containing calculated metrics 25 | :rtype: dict 26 | """ 27 | acc, cmat = helpers.ordered_cmat(labels, pred) 28 | metrics = { 29 | "acc": acc, 30 | "cmat": cmat, 31 | "nmi": normalized_mutual_info_score(labels, pred, average_method="geometric"), 32 | } 33 | return metrics 34 | 35 | 36 | def get_log_params(net): 37 | """ 38 | Get the network parameters we want to log. 
39 | 40 | :param net: Model 41 | :type net: 42 | :return: 43 | :rtype: 44 | """ 45 | params_dict = {} 46 | weights = [] 47 | if getattr(net, "fusion", None) is not None: 48 | with th.no_grad(): 49 | weights = net.fusion.get_weights(softmax=True) 50 | 51 | elif hasattr(net, "attention"): 52 | weights = net.weights 53 | 54 | for i, w in enumerate(helpers.npy(weights)): 55 | params_dict[f"fusion/weight_{i}"] = w 56 | 57 | if hasattr(net, "discriminators"): 58 | for i, discriminator in enumerate(net.discriminators): 59 | d0, dv = helpers.npy([discriminator.d0, discriminator.dv]) 60 | params_dict[f"discriminator_{i}/d0/mean"] = d0.mean() 61 | params_dict[f"discriminator_{i}/d0/std"] = d0.std() 62 | params_dict[f"discriminator_{i}/dv/mean"] = dv.mean() 63 | params_dict[f"discriminator_{i}/dv/std"] = dv.std() 64 | 65 | return params_dict 66 | 67 | 68 | def get_eval_data(dataset, n_eval_samples, batch_size): 69 | """ 70 | Create a dataloader to use for evaluation 71 | 72 | :param dataset: Inout dataset. 73 | :type dataset: th.utils.data.Dataset 74 | :param n_eval_samples: Number of samples to include in the evaluation dataset. Set to None to use all available 75 | samples. 76 | :type n_eval_samples: int 77 | :param batch_size: Batch size used for training. 78 | :type batch_size: int 79 | :return: Evaluation dataset loader 80 | :rtype: th.utils.data.DataLoader 81 | """ 82 | if n_eval_samples is not None: 83 | *views, labels = dataset.tensors 84 | n = views[0].size(0) 85 | idx = np.random.choice(n, min(n, n_eval_samples), replace=False) 86 | views, labels = [v[idx] for v in views], labels[idx] 87 | dataset = th.utils.data.TensorDataset(*views, labels) 88 | 89 | eval_loader = th.utils.data.DataLoader(dataset, batch_size=int(batch_size), shuffle=True, num_workers=0, 90 | drop_last=False, pin_memory=False) 91 | return eval_loader 92 | 93 | 94 | def batch_predict(net, eval_data, batch_size): 95 | """ 96 | Compute predictions for `eval_data` in batches. Batching does not influence predictions, but it influences the loss 97 | computations. 98 | 99 | :param net: Model 100 | :type net: 101 | :param eval_data: Evaluation dataloader 102 | :type eval_data: th.utils.data.DataLoader 103 | :param batch_size: Batch size 104 | :type batch_size: int 105 | :return: Label tensor, predictions tensor, list of dicts with loss values, array containing mean and std of cluster 106 | sizes. 
107 | :rtype: 108 | """ 109 | predictions = [] 110 | labels = [] 111 | losses = [] 112 | cluster_sizes = [] 113 | 114 | net.eval() 115 | with th.no_grad(): 116 | for i, (*batch, label) in enumerate(eval_data): 117 | pred = net(batch) 118 | labels.append(helpers.npy(label)) 119 | predictions.append(helpers.npy(pred).argmax(axis=1)) 120 | 121 | # Only calculate losses for full batches 122 | if label.size(0) == batch_size: 123 | batch_losses = net.calc_losses(ignore_in_total=IGNORE_IN_TOTAL) 124 | losses.append(helpers.npy(batch_losses)) 125 | cluster_sizes.append(helpers.npy(pred.sum(dim=0))) 126 | 127 | labels = np.concatenate(labels, axis=0) 128 | predictions = np.concatenate(predictions, axis=0) 129 | net.train() 130 | return labels, predictions, losses, np.array(cluster_sizes).sum(axis=0) 131 | 132 | 133 | def get_logs(cfg, net, eval_data, iter_losses=None, epoch=None, include_params=True): 134 | if iter_losses is not None: 135 | logs = helpers.add_prefix(helpers.dict_means(iter_losses), "iter_losses") 136 | else: 137 | logs = {} 138 | if (epoch is None) or ((epoch % cfg.eval_interval) == 0): 139 | labels, pred, eval_losses, cluster_sizes = batch_predict(net, eval_data, cfg.batch_size) 140 | eval_losses = helpers.dict_means(eval_losses) 141 | logs.update(helpers.add_prefix(eval_losses, "eval_losses")) 142 | logs.update(helpers.add_prefix(calc_metrics(labels, pred), "metrics")) 143 | logs.update(helpers.add_prefix({"mean": cluster_sizes.mean(), "sd": cluster_sizes.std()}, "cluster_size")) 144 | if include_params: 145 | logs.update(helpers.add_prefix(get_log_params(net), "params")) 146 | if epoch is not None: 147 | logs["epoch"] = epoch 148 | return logs 149 | 150 | 151 | def eval_run(cfg, cfg_name, experiment_identifier, run, net, eval_data, callbacks=tuple(), load_best=True): 152 | """ 153 | Evaluate a training run. 154 | 155 | :param cfg: Experiment config 156 | :type cfg: config.defaults.Experiment 157 | :param cfg_name: Config name 158 | :type cfg_name: str 159 | :param experiment_identifier: 8-character unique identifier for the current experiment 160 | :type experiment_identifier: str 161 | :param run: Run to evaluate 162 | :type run: int 163 | :param net: Model 164 | :type net: 165 | :param eval_data: Evaluation dataloder 166 | :type eval_data: th.utils.data.DataLoader 167 | :param callbacks: List of callbacks to call after evaluation 168 | :type callbacks: List 169 | :param load_best: Load the "best.pt" model before evaluation? 170 | :type load_best: bool 171 | :return: Evaluation logs 172 | :rtype: dict 173 | """ 174 | if load_best: 175 | model_path = helpers.get_save_dir(cfg_name, experiment_identifier, run) / "best.pt" 176 | if os.path.isfile(model_path): 177 | net.load_state_dict(th.load(model_path)) 178 | else: 179 | print(f"Unable to load best model for evaluation. Model file not found: {model_path}") 180 | logs = get_logs(cfg, net, eval_data, include_params=True) 181 | for cb in callbacks: 182 | cb.at_eval(net=net, logs=logs) 183 | return logs 184 | 185 | 186 | def eval_experiment(cfg_name, tag, plot=False): 187 | """ 188 | Evaluate a full experiment 189 | 190 | :param cfg_name: Name of the config 191 | :type cfg_name: str 192 | :param tag: 8-character unique identifier for the current experiment 193 | :type tag: str 194 | :param plot: Display a scatterplot of the representations before and after fusion? 
195 | :type plot: bool 196 | """ 197 | max_n_runs = 100 198 | best_logs = None 199 | best_run = None 200 | best_net = None 201 | best_loss = np.inf 202 | 203 | for run in range(max_n_runs): 204 | try: 205 | net, views, labels, cfg = from_file(cfg_name, tag, run, ckpt="best", return_data=True, return_config=True) 206 | except FileNotFoundError: 207 | break 208 | 209 | eval_dataset = th.utils.data.TensorDataset(*[th.tensor(v) for v in views], th.tensor(labels)) 210 | eval_data = get_eval_data(eval_dataset, cfg.n_eval_samples, cfg.batch_size) 211 | run_logs = eval_run(cfg, cfg_name, tag, run, net, eval_data, load_best=False) 212 | del run_logs["metrics/cmat"] 213 | 214 | if run_logs[f"eval_losses/{cfg.best_loss_term}"] < best_loss: 215 | best_loss = run_logs[f"eval_losses/{cfg.best_loss_term}"] 216 | best_logs = run_logs 217 | best_run = run 218 | best_net = net 219 | 220 | print(f"\nBest run was {best_run}.", end="\n\n") 221 | headers = ["Name", "Value"] 222 | values = list(best_logs.items()) 223 | print(tabulate(values, headers=headers), "\n") 224 | 225 | if plot: 226 | plot_representations(views, labels, best_net) 227 | plt.show() 228 | 229 | 230 | def plot_representations(views, labels, net, project_method="pca"): 231 | with th.no_grad(): 232 | output = net([th.tensor(v) for v in views]) 233 | pred = helpers.npy(output).argmax(axis=1) 234 | 235 | hidden = helpers.npy(net.backbone_outputs) 236 | fused = helpers.npy(net.fused) 237 | 238 | hidden = np.concatenate(hidden, axis=0) 239 | view_hue = sum([labels.shape[0] * [str(i + 1)] for i in range(2)], []) 240 | fused_hue = [str(l + 1) for l in pred] 241 | 242 | view_cmap = "tab10" 243 | class_cmap = "hls" 244 | fig, ax = plt.subplots(1, 2, figsize=(20, 10)) 245 | 246 | plot_projection(X=hidden, method=project_method, hue=view_hue, ax=ax[0], title="Before fusion", 247 | legend_title="View", hue_order=sorted(list(set(view_hue))), cmap=view_cmap) 248 | plot_projection(X=fused, method=project_method, hue=fused_hue, ax=ax[1], title="After fusion", 249 | legend_title="Prediction", hue_order=sorted(list(set(fused_hue))), cmap=class_cmap) 250 | 251 | 252 | def plot_projection(X, method, hue, ax, title=None, cmap="tab10", legend_title=None, legend_loc=1, **kwargs): 253 | X = project(X, method) 254 | pl = sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=hue, ax=ax, legend="full", palette=cmap, **kwargs) 255 | leg = pl.get_legend() 256 | leg._loc = legend_loc 257 | if title is not None: 258 | ax.set_title(title) 259 | if legend_title is not None: 260 | leg.set_title(legend_title) 261 | 262 | 263 | def project(X, method): 264 | if method == "pca": 265 | from sklearn.decomposition import PCA 266 | return PCA(n_components=2).fit_transform(X) 267 | elif method == "tsne": 268 | from sklearn.manifold import TSNE 269 | return TSNE(n_components=2).fit_transform(X) 270 | elif method is None: 271 | return X 272 | else: 273 | raise RuntimeError() 274 | 275 | 276 | def parse_args(): 277 | parser = argparse.ArgumentParser() 278 | parser.add_argument("-c", "--config", dest="cfg_name", required=True) 279 | parser.add_argument("-t", "--tag", dest="tag", required=True) 280 | parser.add_argument("--plot", action="store_true") 281 | return parser.parse_args() 282 | 283 | 284 | if __name__ == '__main__': 285 | args = parse_args() 286 | eval_experiment(args.cfg_name, args.tag, args.plot) 287 | -------------------------------------------------------------------------------- /src/models/model_base.py: -------------------------------------------------------------------------------- 1 
| import torch.nn as nn 2 | 3 | 4 | class ModelBase(nn.Module): 5 | def __init__(self): 6 | """ 7 | Model base class 8 | """ 9 | super().__init__() 10 | 11 | self.fusion = None 12 | self.optimizer = None 13 | self.loss = None 14 | 15 | def calc_losses(self, ignore_in_total=tuple()): 16 | return self.loss(self, ignore_in_total=ignore_in_total) 17 | 18 | def train_step(self, batch, epoch, it, n_batches): 19 | self.optimizer.zero_grad() 20 | _ = self(batch) 21 | losses = self.calc_losses() 22 | losses["tot"].backward() 23 | self.optimizer.step(epoch + it / n_batches) 24 | return losses 25 | -------------------------------------------------------------------------------- /src/models/simple_mvc.py: -------------------------------------------------------------------------------- 1 | import helpers 2 | from lib.loss import Loss 3 | from lib.fusion import get_fusion_module 4 | from lib.optimizer import Optimizer 5 | from lib.backbones import Backbones 6 | from models.model_base import ModelBase 7 | from models.clustering_module import DDC 8 | 9 | 10 | class SiMVC(ModelBase): 11 | def __init__(self, cfg): 12 | """ 13 | Implementation of the SiMVC model. 14 | 15 | :param cfg: Model config. See `config.defaults.SiMVC` for documentation on the config object. 16 | """ 17 | super().__init__() 18 | 19 | self.cfg = cfg 20 | self.output = self.hidden = self.fused = self.backbone_outputs = None 21 | 22 | # Define Backbones and Fusion modules 23 | self.backbones = Backbones(cfg.backbone_configs) 24 | self.fusion = get_fusion_module(cfg.fusion_config, self.backbones.output_sizes) 25 | # Define clustering module 26 | self.ddc = DDC(input_dim=self.fusion.output_size, cfg=cfg.cm_config) 27 | # Define loss-module 28 | self.loss = Loss(cfg=cfg.loss_config) 29 | # Initialize weights. 30 | self.apply(helpers.he_init_weights) 31 | 32 | # Instantiate optimizer 33 | self.optimizer = Optimizer(cfg.optimizer_config, self.parameters()) 34 | 35 | def forward(self, views): 36 | self.backbone_outputs = self.backbones(views) 37 | self.fused = self.fusion(self.backbone_outputs) 38 | self.output, self.hidden = self.ddc(self.fused) 39 | return self.output 40 | -------------------------------------------------------------------------------- /src/models/train.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch as th 3 | 4 | import config 5 | import helpers 6 | from data.load import load_dataset 7 | from models import callback 8 | from models.build_model import build_model 9 | from models import evaluate 10 | 11 | 12 | def train(cfg, net, loader, eval_data, callbacks=tuple()): 13 | """ 14 | Train the model for one run. 15 | 16 | :param cfg: Experiment config 17 | :type cfg: config.defaults.Experiment 18 | :param net: Model 19 | :type net: 20 | :param loader: DataLoder for training data 21 | :type loader: th.utils.data.DataLoader 22 | :param eval_data: DataLoder for evaluation data 23 | :type eval_data: th.utils.data.DataLoader 24 | :param callbacks: Training callbacks. 
25 | :type callbacks: List 26 | :return: None 27 | :rtype: None 28 | """ 29 | n_batches = len(loader) 30 | for e in range(1, cfg.n_epochs + 1): 31 | iter_losses = [] 32 | for i, data in enumerate(loader): 33 | *batch, _ = data 34 | try: 35 | batch_losses = net.train_step(batch, epoch=(e-1), it=i, n_batches=n_batches) 36 | except Exception as e: 37 | print(f"Training stopped due to exception: {e}") 38 | return 39 | 40 | iter_losses.append(helpers.npy(batch_losses)) 41 | logs = evaluate.get_logs(cfg, net, eval_data=eval_data, iter_losses=iter_losses, epoch=e, include_params=True) 42 | try: 43 | for cb in callbacks: 44 | cb.epoch_end(e, logs=logs, net=net) 45 | except callback.StopTraining as err: 46 | print(err) 47 | break 48 | 49 | 50 | def main(): 51 | """ 52 | Run an experiment. 53 | """ 54 | experiment_name, cfg = config.get_experiment_config() 55 | dataset = load_dataset(**cfg.dataset_config.dict()) 56 | loader = th.utils.data.DataLoader(dataset, batch_size=int(cfg.batch_size), shuffle=True, num_workers=0, 57 | drop_last=True, pin_memory=False) 58 | eval_data = evaluate.get_eval_data(dataset, cfg.n_eval_samples, cfg.batch_size) 59 | experiment_identifier = wandb.util.generate_id() 60 | 61 | run_logs = [] 62 | for run in range(cfg.n_runs): 63 | net = build_model(cfg.model_config) 64 | print(net) 65 | callbacks = ( 66 | callback.Printer(print_confusion_matrix=(cfg.model_config.cm_config.n_clusters <= 100)), 67 | callback.ModelSaver(cfg=cfg, experiment_name=experiment_name, identifier=experiment_identifier, 68 | run=run, epoch_interval=1, best_loss_term=cfg.best_loss_term, 69 | checkpoint_interval=cfg.checkpoint_interval), 70 | callback.EarlyStopping(patience=cfg.patience, best_loss_term=cfg.best_loss_term, epoch_interval=1) 71 | ) 72 | train(cfg, net, loader, eval_data=eval_data, callbacks=callbacks) 73 | run_logs.append(evaluate.eval_run(cfg=cfg, cfg_name=experiment_name, 74 | experiment_identifier=experiment_identifier, run=run, net=net, 75 | eval_data=eval_data, callbacks=callbacks, load_best=True)) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /src/scripts/comvc_ablation.sh: -------------------------------------------------------------------------------- 1 | python -m models.train -c mnist_contrast --model_config__loss_config__negative_samples_ratio -1 \ 2 | --model_config__loss_config__adaptive_contrastive_weight 0 3 | python -m models.train -c mnist_contrast --model_config__loss_config__negative_samples_ratio -1 4 | python -m models.train -c mnist_contrast --model_config__loss_config__adaptive_contrastive_weight 0 5 | python -m models.train -c mnist_contrast 6 | -------------------------------------------------------------------------------- /src/scripts/ddc_loss_ablation.sh: -------------------------------------------------------------------------------- 1 | # SiMVC 2 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1" 3 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2" 4 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_3" 5 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_2" 6 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_3" 7 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2|ddc_3" 8 | 
--------------------------------------------------------------------------------
/src/scripts/ddc_loss_ablation.sh:
--------------------------------------------------------------------------------
 1 | # SiMVC
 2 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1"
 3 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2"
 4 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_3"
 5 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_2"
 6 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_3"
 7 | python -m models.train -c mnist --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2|ddc_3"
 8 | python -m models.train -c mnist --best_loss_term "tot" # Uses all losses by default
 9 | 
10 | # CoMVC
11 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|contrast"
12 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2|contrast"
13 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_3|contrast"
14 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_2|contrast"
15 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_1|ddc_3|contrast"
16 | python -m models.train -c mnist_contrast --best_loss_term "tot" --model_config__loss_config__funcs "ddc_2|ddc_3|contrast"
17 | python -m models.train -c mnist_contrast --best_loss_term "tot" # Uses all losses by default
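In `ddc_loss_ablation.sh`, the `funcs` field selects which loss terms are active, separated by `|`. The repository's `lib.loss.Loss` is not part of this section, so the following is only an assumed sketch of the bookkeeping implied by the code above: named terms plus a `"tot"` entry, which is what `ModelBase.train_step` backpropagates and from which `calc_losses(ignore_in_total=...)` can exclude terms. `combine_losses` and `term_values` are hypothetical names introduced for illustration.

```python
# Hypothetical sketch of the pipe-separated loss specification
# (e.g. --model_config__loss_config__funcs "ddc_1|ddc_2|contrast").
# NOT the repository's lib.loss.Loss; it only shows how a "tot" entry
# can be assembled from the named terms.
import torch


def combine_losses(term_values: dict, funcs: str, ignore_in_total=tuple()):
    """Return the named terms plus their sum under the key 'tot'."""
    names = funcs.split("|")
    losses = {name: term_values[name] for name in names}
    losses["tot"] = sum(losses[name] for name in names if name not in ignore_in_total)
    return losses


if __name__ == "__main__":
    terms = {"ddc_1": torch.tensor(0.7), "ddc_2": torch.tensor(0.3), "contrast": torch.tensor(1.2)}
    out = combine_losses(terms, funcs="ddc_1|ddc_2|contrast")
    print(out["tot"])  # tensor(2.2000)
```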
--------------------------------------------------------------------------------
/src/scripts/mnist_noise.sh:
--------------------------------------------------------------------------------
 1 | # SiMVC
 2 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.1
 3 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.2
 4 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.3
 5 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.4
 6 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.5
 7 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.6
 8 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.7
 9 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.8
10 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.9
11 | python -m models.train -c mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 1.0
12 | 
13 | # CoMVC
14 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.1
15 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.2
16 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.3
17 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.4
18 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.5
19 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.6
20 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.7
21 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.8
22 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 0.9
23 | python -m models.train -c mnist_contrast --dataset_config__noise_views 1 --dataset_config__noise_sd 1.0
24 | 
25 | # EAMC
26 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.1
27 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.2
28 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.3
29 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.4
30 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.5
31 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.6
32 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.7
33 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.8
34 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 0.9
35 | python -m models.train -c eamc_mnist --dataset_config__noise_views 1 --dataset_config__noise_sd 1.0
36 | 
37 | 
--------------------------------------------------------------------------------
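`mnist_noise.sh` sweeps the standard deviation of additive Gaussian noise applied to one view (`noise_views 1`, `noise_sd` from 0.1 to 1.0). The actual corruption is implemented in the data pipeline (`src/data/load.py` / `make_dataset.py`), which is not part of this section, so the snippet below is only a hedged illustration of the intended effect; `add_view_noise` and the choice of which view receives the noise are assumptions made for the example.

```python
# Hypothetical sketch of the noise ablation configured above: additive Gaussian
# noise with standard deviation `noise_sd` applied to `noise_views` of the views.
# NOT the repository's data pipeline; which views are corrupted is an arbitrary
# choice here (the last ones), purely for illustration.
import torch


def add_view_noise(views, noise_views=1, noise_sd=0.5):
    """Return a copy of `views` where the last `noise_views` views get N(0, sd^2) noise."""
    noisy = [v.clone() for v in views]
    for idx in range(len(views) - noise_views, len(views)):
        noisy[idx] = noisy[idx] + noise_sd * torch.randn_like(noisy[idx])
    return noisy


if __name__ == "__main__":
    views = [torch.zeros(2, 1, 28, 28), torch.zeros(2, 1, 28, 28)]  # two MNIST-like views
    noisy = add_view_noise(views, noise_views=1, noise_sd=0.5)
    print(noisy[0].std().item(), noisy[1].std().item())  # ~0.0 vs ~0.5
```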