├── .gitignore ├── README.md ├── calc_NUDFT.ipynb ├── cfg ├── cfg_mnist.yaml └── cfg_run.yaml ├── data ├── pbmc_x.zip └── pbmc_y.zip ├── dataset.py ├── idc_evaluate.ipynb ├── idc_example.ipynb ├── img ├── freq_bias.png ├── img.png ├── nudft_ALLAML.png ├── pbmc.gif ├── supervised_train_plots.png └── supervised_train_plots2.png ├── interpretability_metrics.py ├── lspin_pbmc.py ├── model.py ├── requirements.txt ├── run.py └── train_evaluate.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | # Interpretable Deep Clustering 7 | 8 |

9 | 10 |

11 | 12 | ## An official implementation of the ICML 2024 accepted paper: [Interpretable Deep Clustering for Tabular Data](https://openreview.net/pdf?id=QPy7zLfvof) 13 |

14 | 15 |

16 | 17 | ## UPDATES: 18 | 19 | - 2024-10-08: 20 | * We added an example notebook that produces a NUDFT plot and compares gated vs. non-gated supervised models. 21 | * We present below a corrected Figure 4 from the paper: the frequency values are normalized per model, and the gated features (rather than the raw features) are used for the IDC model. 22 | 23 | 24 | ## How to run on your data: 25 | 26 | 1. Do you have a dataset without labels? Use our Colab example notebook: Open In Colab 27 | 2. If you have a labeled dataset, please follow the Colab example with evaluation: Open In Colab 28 | 29 | 30 | ## Fixed Figure 4. Spectral properties of the learned predictive function on the ALLAML dataset. 31 | 32 | The model trained with the gating network (IDC) has higher Fourier amplitudes at all frequency levels than the same model trained 33 | without gates (IDCw/o_gates) and the baseline (TELL). This suggests that IDC can better handle the inductive bias of tabular data. 34 |
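The notebook `calc_NUDFT.ipynb` computes these spectra with `nfft.nfft_adjoint`. As a rough illustration only, the sketch below computes the same kind of quantity by direct summation using NumPy alone; the helpers `nudft_spectrum` and `normalized_model_spectra` are hypothetical and not part of this repository. It evaluates frequencies starting from 0 (the notebook starts its `kvals` at 0.1 for the log-scale plot) and divides each model's spectra by that model's largest amplitude, mirroring the per-model normalization described in the 2024-10-08 update. For the gated model, `X` would be the gated features (`x * gates`), as in the notebook.

```python
import numpy as np


def nudft_spectrum(x, y, kmax=50.0, nk=1000):
    """Direct non-uniform DFT of targets y sampled at feature values x.

    Returns (freqs, amps): freqs holds nk + 1 evenly spaced frequencies in
    [0, kmax], and amps[k] = |(1 / n) * sum_j y[j] * exp(-2j*pi*freqs[k]*x[j])|.
    Time and memory are O(n * nk), so this only suits small inputs.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    freqs = np.linspace(0.0, kmax, nk + 1)
    # Phase matrix of shape (nk + 1, n): one row per frequency.
    phases = np.exp(-2j * np.pi * np.outer(freqs, x))
    coeffs = phases @ y / len(x)
    return freqs, np.abs(coeffs)


def normalized_model_spectra(X, y_hat, kmax=50.0, nk=1000):
    """Per-feature spectra of one model's predictions y_hat.

    All spectra are divided by the largest amplitude over every feature and
    frequency, so each model's curves peak at 1.
    """
    spectra = []
    freqs = None
    for j in range(X.shape[1]):
        freqs, amps = nudft_spectrum(X[:, j], y_hat, kmax, nk)
        spectra.append(amps)
    spectra = np.stack(spectra)  # shape: (num_features, nk + 1)
    peak = spectra.max()
    if peak > 0:
        spectra = spectra / peak
    return freqs, spectra


# Toy usage: a binary "prediction" that depends on feature 0 only.
rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y_hat = (np.sin(2 * np.pi * 5 * X[:, 0]) > 0).astype(float)
freqs, spectra = normalized_model_spectra(X, y_hat, kmax=20.0, nk=200)
print(spectra.shape)  # (3, 201)
```

Because the toy predictions are non-negative, the zero-frequency term equals their mean and is the largest amplitude for every feature, so after normalization the curves all start at 1; the comparison between models is then carried by how slowly each curve decays at higher frequencies.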

35 | 36 |

37 | ### Citation: 38 | Please cite our paper if you use this code: 39 | 40 | 41 | ``` 42 | 43 | @InProceedings{pmlr-v235-svirsky24a, 44 | title = {Interpretable Deep Clustering for Tabular Data}, 45 | author = {Svirsky, Jonathan and Lindenbaum, Ofir}, 46 | booktitle = {Proceedings of the 41st International Conference on Machine Learning}, 47 | pages = {47314--47330}, 48 | year = {2024}, 49 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix}, 50 | volume = {235}, 51 | series = {Proceedings of Machine Learning Research}, 52 | month = {21--27 Jul}, 53 | publisher = {PMLR}, 54 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/svirsky24a/svirsky24a.pdf}, 55 | url = {https://proceedings.mlr.press/v235/svirsky24a.html}, 56 | abstract = {Clustering is a fundamental learning task widely used as a first step in data analysis. For example, biologists use cluster assignments to analyze genome sequences, medical records, or images. Since downstream analysis is typically performed at the cluster level, practitioners seek reliable and interpretable clustering models. We propose a new deep-learning framework for general domain tabular data that predicts interpretable cluster assignments at the instance and cluster levels. First, we present a self-supervised procedure to identify the subset of the most informative features from each data point. Then, we design a model that predicts cluster assignments and a gate matrix that provides cluster-level feature selection. Overall, our model provides cluster assignments with an indication of the driving feature for each sample and each cluster. We show that the proposed method can reliably predict cluster assignments in biological, text, image, and physics tabular datasets. Furthermore, using previously proposed metrics, we verify that our model leads to interpretable results at a sample and cluster level. 
Our code is available on https://github.com/jsvir/idc.} 57 | } 58 | ``` 59 | 60 | 61 | 62 | ### TODO list: 63 | - [x] Add interpretability evaluation scripts 64 | - [ ] Add experiment configs 65 | - [ ] Add feature index outputs 66 | - [x] Add synthetic dataset generation code -------------------------------------------------------------------------------- /calc_NUDFT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Calculation of NUDFT \n", 7 | "\n", 8 | "We provide an example with a pre-trained supervised model, [LSPIN](https://proceedings.mlr.press/v162/yang22i/yang22i.pdf), trained on the PBMC dataset and compared against a deep classifier.\n", 9 | "\n", 10 | "We provide the checkpoints in the ckpts directory.\n" 11 | ], 12 | "metadata": { 13 | "collapsed": false 14 | }, 15 | "id": "9e52f756a2211364" 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "source": [], 20 | "metadata": { 21 | "collapsed": false 22 | }, 23 | "id": "e40cd92806e61978" 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "source": [ 28 | "### Imports and util functions\n" 29 | ], 30 | "metadata": { 31 | "collapsed": false 32 | }, 33 | "id": "9dbae214166550cd" 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 1, 38 | "outputs": [], 39 | "source": [ 40 | "import platform\n", 41 | "from nfft import nfft_adjoint\n", 42 | "import torch\n", 43 | "import numpy as np\n", 44 | "from matplotlib import pyplot as plt\n", 45 | "from omegaconf import OmegaConf\n", 46 | "import seaborn as sns\n", 47 | "from lspin_pbmc import PBMC, Classifier, GatingNet\n", 48 | "\n", 49 | "# which dimension of the y_hat outputs (normalized with softmax) to use\n", 50 | "TARGET_IDX = 1\n", 51 | "COLUMNS_SUBSET=100000\n", 52 | "\n", 53 | "\n", 54 | "def spectrum_NUDFT(x, y, kmax=50, nk=1000):\n", 55 | " kvals = np.linspace(0.1, kmax, nk+1)\n", 56 | " nufft = (1 / len(x)) * nfft_adjoint(-(x * kmax / nk), y, 2 * (nk
+ 1))[nk + 1:]\n", 57 | " return [kvals, np.array(nufft, dtype=\"complex128\")]\n", 58 | "\n", 59 | "\n", 60 | "def select_k_columns_with_max_variance(matrix):\n", 61 | " variance_per_column = torch.var(matrix, dim=0)\n", 62 | " _, sorted_indices = torch.sort(variance_per_column, descending=True)\n", 63 | " top_k_indices = sorted_indices[:COLUMNS_SUBSET]\n", 64 | " selected_columns = matrix[:, top_k_indices]\n", 65 | " return selected_columns" 66 | ], 67 | "metadata": { 68 | "collapsed": false, 69 | "ExecuteTime": { 70 | "end_time": "2024-10-08T12:57:48.715663Z", 71 | "start_time": "2024-10-08T12:57:44.858113500Z" 72 | } 73 | }, 74 | "id": "e1c3233b55927f30" 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "source": [ 79 | "### Load pre-trained checkpoints" 80 | ], 81 | "metadata": { 82 | "collapsed": false 83 | }, 84 | "id": "697ba1c508082d" 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 2, 89 | "outputs": [], 90 | "source": [ 91 | "gated_cfg = OmegaConf.create(dict(\n", 92 | " gated=True,\n", 93 | " input_dim=17126,\n", 94 | " n_clusters=2,\n", 95 | " dataset=\"PBMC\",\n", 96 | " data_dir=\"C:/data/fs/pbmc\" if platform.system() == \"Windows\" else \".\" ,\n", 97 | " batch_size=256,\n", 98 | " repitions=5,\n", 99 | " sigma=0.5,\n", 100 | " reg_beta=100,\n", 101 | " devices=1,\n", 102 | " accelerator=\"gpu\",\n", 103 | " max_epochs=100,\n", 104 | " deterministic=True,\n", 105 | " logger=True,\n", 106 | " log_every_n_steps=10,\n", 107 | " check_val_every_n_epoch=1,\n", 108 | " enable_checkpointing=False,\n", 109 | "))\n", 110 | "\n", 111 | "nongated_cfg = OmegaConf.create(dict(\n", 112 | " gated=False,\n", 113 | " input_dim=17126,\n", 114 | " n_clusters=2,\n", 115 | " dataset=\"PBMC\",\n", 116 | " data_dir=\"C:/data/fs/pbmc\" if platform.system() == \"Windows\" else \".\" ,\n", 117 | " batch_size=256,\n", 118 | " repitions=5,\n", 119 | " devices=1,\n", 120 | " accelerator=\"gpu\",\n", 121 | " max_epochs=100,\n", 122 | " deterministic=True,\n", 123 |
| " logger=True,\n", 124 | " log_every_n_steps=10,\n", 125 | " check_val_every_n_epoch=1,\n", 126 | " enable_checkpointing=False,\n", 127 | "))\n", 128 | "\n", 129 | "\n" 130 | ], 131 | "metadata": { 132 | "collapsed": false, 133 | "ExecuteTime": { 134 | "end_time": "2024-10-08T12:57:48.730048700Z", 135 | "start_time": "2024-10-08T12:57:48.711662700Z" 136 | } 137 | }, 138 | "id": "545756af156b20df" 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "source": [ 143 | "The models were trained for ~400 epochs and we present the training/validation plots:\n", 144 | "\n", 145 | "\"Alt\n", 146 | "\"Alt\n" 147 | ], 148 | "metadata": { 149 | "collapsed": false 150 | }, 151 | "id": "101afcef26dfe116" 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "Dataset PBMC stats:\n", 162 | "X.shape: (20742, 17126)\n", 163 | "Y.shape: (20742,)\n", 164 | "X.min=-2.734075760955218, X.max=144.01736006468113\n", 165 | "Y.min=0, Y.max=1\n", 166 | "Label 0 has 10479 samples\n", 167 | "Label 1 has 10263 samples\n", 168 | "Split to train/test: train 16594 test 4148\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "gating_net = GatingNet(gated_cfg)\n", 174 | "classifier_gated = Classifier(gated_cfg)\n", 175 | "gating_net.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_last_pbmc_beta_10_seed_0.pth\")[\"gating\"])\n", 176 | "classifier_gated.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_last_pbmc_beta_10_seed_0.pth\")[\"clustering\"])\n", 177 | "\n", 178 | "# # load clustering model without gates:\n", 179 | "classifier = Classifier(nongated_cfg)\n", 180 | "classifier.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_nogates_best_pbmc_seed_0.pth\")[\"clustering\"])\n", 181 | "\n", 182 | "classifier = classifier.to('cpu')\n", 183 | "classifier_gated = classifier_gated.to('cpu')\n", 184 | "\n", 185 | "classifier.eval()\n", 186 | 
"classifier_gated.eval()\n", 187 | "gating_net.eval()\n", 188 | "\n", 189 | "_, test_dataset = PBMC.setup(gated_cfg.data_dir)\n", 190 | "\n" 191 | ], 192 | "metadata": { 193 | "collapsed": false, 194 | "ExecuteTime": { 195 | "end_time": "2024-10-08T13:17:14.874885100Z", 196 | "start_time": "2024-10-08T13:16:56.649883800Z" 197 | } 198 | }, 199 | "id": "a055dcc6d21c43e6" 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": "
" 208 | }, 209 | "metadata": {}, 210 | "output_type": "display_data" 211 | }, 212 | { 213 | "data": { 214 | "text/plain": "
", 215 | "image/png": "" 216 | }, 217 | "metadata": {}, 218 | "output_type": "display_data" 219 | } 220 | ], 221 | "source": [ 222 | "plt.clf()\n", 223 | "plt.figure(figsize=(10, 6))\n", 224 | "plt.yscale('log') # Set the y-axis to logarithmic scale\n", 225 | "plt.xscale('log')\n", 226 | "plt.xticks(fontsize=20)\n", 227 | "plt.yticks(fontsize=20)\n", 228 | "plt.xlim((0.1, 100))\n", 229 | "plt.xlabel('$|k|$', fontsize=20)\n", 230 | "plt.ylabel('$P_{f(x)}(k)$', fontsize=20)\n", 231 | "plt.grid(False)\n", 232 | "\n", 233 | "with torch.no_grad():\n", 234 | " x_tensor = torch.tensor(test_dataset.data).float()\n", 235 | " filtered_dataset = select_k_columns_with_max_variance(x_tensor).cpu().numpy()\n", 236 | " normalized_dataset = filtered_dataset \n", 237 | " gates = gating_net.get_gates(x_tensor)\n", 238 | "\n", 239 | " spectrum_nudft_all = []\n", 240 | " for feat_id in range(filtered_dataset.shape[-1]):\n", 241 | " spectrum_nudft_all.append(spectrum_NUDFT(filtered_dataset[:, feat_id], test_dataset.targets))\n", 242 | " spectra = [np.sqrt(np.abs(t[1] ** 2)).reshape(-1, 1) for t in spectrum_nudft_all]\n", 243 | " k = spectrum_nudft_all[0][0]\n", 244 | " k_repeated = np.concatenate([k] * len(spectra), axis=0).reshape(-1)\n", 245 | "\n", 246 | " # gated data, gated classifier predictions:\n", 247 | " e = classifier_gated.encoder(x_tensor * gates)\n", 248 | " y_hat = torch.softmax(classifier_gated.head(e), dim=1)[:, TARGET_IDX].numpy()\n", 249 | " classifier_gated_spectrum_nudft_all = []\n", 250 | " non_zero_ids = torch.nonzero(gates.sum(dim=0) > 0, as_tuple=True)[0].long()\n", 251 | " max_spectrum_val = - np.inf\n", 252 | " max_spectrum_gated_val = - np.inf\n", 253 | "\n", 254 | " # raw data, classifier predictions:\n", 255 | " e_raw = classifier.encoder(x_tensor)\n", 256 | " y_hat_raw = torch.softmax(classifier.head(e_raw), dim=1)[:, TARGET_IDX].numpy()\n", 257 | " classifier_spectrum_nudft_all = []\n", 258 | " normalized_gated_dataset = normalized_dataset * 
gates.numpy()\n", 259 | " for feat_id in range(normalized_dataset.shape[-1]):\n", 260 | " feat_spectrum = spectrum_NUDFT(normalized_gated_dataset[:, feat_id], y_hat)\n", 261 | " max_spectrum_gated_val = max(max_spectrum_gated_val, np.abs(feat_spectrum[1]).max())\n", 262 | " classifier_gated_spectrum_nudft_all.append(feat_spectrum)\n", 263 | "\n", 264 | " feat_spectrum_raw = spectrum_NUDFT(normalized_dataset[:, feat_id], y_hat_raw)\n", 265 | " max_spectrum_val = max(max_spectrum_val, np.abs(feat_spectrum_raw[1]).max())\n", 266 | " classifier_spectrum_nudft_all.append(feat_spectrum_raw)\n", 267 | "\n", 268 | " classifier_gated_spectra = [np.abs(t[1]).reshape(-1) for t in classifier_gated_spectrum_nudft_all]\n", 269 | " classifier_gated_spectra_single_column = np.concatenate(classifier_gated_spectra, axis=0)\n", 270 | " classifier_gated_spectra_single_column = classifier_gated_spectra_single_column / max_spectrum_gated_val\n", 271 | " sns.lineplot(x=k_repeated, y=classifier_gated_spectra_single_column, label='$P_{f_{lspin}}(k)$')\n", 272 | "\n", 273 | " classifier_model_spectra = [np.abs(t[1]).reshape(-1) for t in classifier_spectrum_nudft_all]\n", 274 | " classifier_spectra_single_column = np.concatenate(classifier_model_spectra, axis=0)\n", 275 | " classifier_spectra_single_column = classifier_spectra_single_column / max_spectrum_val\n", 276 | " sns.lineplot(x=k_repeated, y=classifier_spectra_single_column, label='$P_{f_{classifier}}(k)$')\n", 277 | "\n", 278 | "plt.legend(fontsize=16)\n", 279 | "plt.show()" 280 | ], 281 | "metadata": { 282 | "collapsed": false, 283 | "ExecuteTime": { 284 | "end_time": "2024-10-08T13:30:05.517370800Z", 285 | "start_time": "2024-10-08T13:18:02.330028300Z" 286 | } 287 | }, 288 | "id": "bb12fc6aae6893cd" 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "source": [ 293 | "As can be seen from the plot, gating the features adds a high-frequency bias, which makes the predictor better suited to tabular data.
Please refer to the paper [An Inductive Bias for Tabular Deep Learning](https://proceedings.neurips.cc/paper_files/paper/2023/file/8671b6dffc08b4fcf5b8ce26799b2bef-Paper-Conference.pdf) and Figure 1:\n", 294 | "\n", 295 | "\n", 296 | "\"Alt\n", 297 | "\n", 298 | "Due to their heterogeneous nature, tabular datasets tend to describe higher frequency target functions compared to images. The spectra corresponding to image datasets (curves in color) tend to feature lower Fourier amplitudes at higher frequencies than heterogeneous tabular datasets (cyan region).\n" 299 | ], 300 | "metadata": { 301 | "collapsed": false 302 | }, 303 | "id": "91c4ea1eac38133e" 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "outputs": [], 309 | "source": [], 310 | "metadata": { 311 | "collapsed": false 312 | }, 313 | "id": "5a310757b3524a5a" 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 2 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython2", 332 | "version": "2.7.6" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 5 337 | } 338 | -------------------------------------------------------------------------------- /cfg/cfg_mnist.yaml: -------------------------------------------------------------------------------- 1 | dataset: MNIST10K 2 | data_dir: idc/data 3 | scaler: MinMaxScaler 4 | batch_size: 100 5 | seeds: 1 6 | epochs: &epochs 700 7 | 8 | ae_non_gated_epochs: 10 9 | ae_pretrain_epochs: 300 10 | start_global_gates_training_on_epoch: 400 11 | 12 | mask_percentage: 0.9 13 | latent_noise_std: 0.01 14 | 15 | trainer: 16 | devices: 1 17 | accelerator: gpu 18 | max_epochs: *epochs 19 | deterministic: true 20 | logger: true 21 |
log_every_n_steps: 10 22 | check_val_every_n_epoch: 10 23 | enable_checkpointing: false 24 | num_sanity_val_steps: 0 25 | 26 | 27 | # GTCR loss 28 | gtcr_loss: true 29 | gtcr_projection_dim: null # for large number of features use it 30 | gtcr_eps: 1 31 | 32 | 33 | # Compression loss 34 | eps: 0.1 35 | 36 | # Gating Net 37 | use_gating: true 38 | gates_hidden_dim: 784 39 | 40 | # EncoderDecoder 41 | encdec: 42 | - 512 43 | - 512 44 | - 2048 45 | - &bn_layer 10 46 | - 2048 47 | - 512 48 | - 512 49 | 50 | clustering_head: 51 | - *bn_layer 52 | - 2048 53 | 54 | tau: 100 55 | 56 | aux_classifier: 57 | - 2048 58 | 59 | local_gates_lambda: 1 60 | global_gates_lambda: 0.0001 61 | gtcr_lambda: 0.01 62 | 63 | lr: 64 | pretrain: 1e-3 65 | clustering: 1e-2 66 | aux_classifier: 1e-2 67 | 68 | sched: 69 | pretrain_min_lr: 1e-6 70 | clustering_min_lr: 1e-6 71 | 72 | 73 | 74 | save_seed_checkpoints: false 75 | validate: true -------------------------------------------------------------------------------- /cfg/cfg_run.yaml: -------------------------------------------------------------------------------- 1 | filepath_samples: idc/data/pbmc_x.npz 2 | num_clusters: 2 3 | 4 | batch_size: 256 5 | seeds: 1 6 | epochs: &epochs 200 7 | 8 | ae_non_gated_epochs: 5 #50 we reduce the number of epochs for training inside a notebook 9 | ae_pretrain_epochs: 10 #100 we reduce the number of epochs for training inside a notebook 10 | start_global_gates_training_on_epoch: 150 11 | 12 | mask_percentage: 0.9 13 | latent_noise_std: 0.01 14 | 15 | trainer: 16 | devices: 1 17 | accelerator: gpu 18 | max_epochs: *epochs 19 | deterministic: true 20 | logger: true 21 | log_every_n_steps: 10 22 | check_val_every_n_epoch: 10 23 | enable_checkpointing: false 24 | num_sanity_val_steps: 0 25 | 26 | 27 | # GTCR loss 28 | gtcr_loss: true 29 | gtcr_projection_dim: 1024 # for large number of features use it 30 | gtcr_eps: 1 31 | 32 | 33 | # Compression loss 34 | eps: 0.1 35 | 36 | # Gating Net 37 | use_gating: true 
38 | gates_hidden_dim: 1024 39 | 40 | # EncoderDecoder 41 | encdec: 42 | - 512 43 | - 512 44 | - 2048 45 | - &bn_layer 128 46 | - 2048 47 | - 512 48 | - 512 49 | 50 | clustering_head: 51 | - *bn_layer 52 | - 2048 53 | 54 | tau: 100 55 | 56 | aux_classifier: 57 | - 2048 58 | 59 | local_gates_lambda: 100 60 | global_gates_lambda: 10 61 | gtcr_lambda: 0.01 62 | 63 | lr: 64 | pretrain: 1e-3 65 | clustering: 1e-3 66 | aux_classifier: 1e-1 67 | 68 | sched: 69 | pretrain_min_lr: 1e-4 70 | clustering_min_lr: 1e-4 71 | 72 | save_seed_checkpoints: false 73 | -------------------------------------------------------------------------------- /data/pbmc_x.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/data/pbmc_x.zip -------------------------------------------------------------------------------- /data/pbmc_y.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/data/pbmc_y.zip -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import MNIST 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import torch 5 | from sklearn import preprocessing 6 | from scipy.io import loadmat 7 | from sklearn.preprocessing import MinMaxScaler 8 | from scipy.stats import zscore 9 | import matplotlib.pyplot as plt 10 | from sklearn import datasets 11 | 12 | 13 | class ClusteringDataset(Dataset): 14 | def __init__(self, data, labels=None, num_clusters=None): 15 | super().__init__() 16 | self.data = data 17 | self.labels = labels 18 | self._num_clusters = num_clusters 19 | if num_clusters is None and labels is None: 20 | raise ValueError("At least one of the values should be provided 
(labels/num_clusters)") 21 | self.print_stats() 22 | 23 | def __getitem__(self, index: int): 24 | if self.labels is None: 25 | return torch.tensor(self.data[index]).float() 26 | return torch.tensor(self.data[index]).float(), torch.tensor(self.labels[index]).long() 27 | 28 | def __len__(self) -> int: 29 | return len(self.data) 30 | 31 | @property 32 | def num_clusters(self): 33 | return self._num_clusters if self._num_clusters is not None else len(np.unique(self.labels)) 34 | 35 | def num_features(self): 36 | return self.data.shape[-1] 37 | 38 | def print_stats(self): 39 | print('X.shape: ', self.data.shape) 40 | print(f"X.min={self.data.min()}, X.max={self.data.max()}") 41 | if self.labels is not None: 42 | print('Y.shape: ', self.labels.shape) 43 | for y_u in np.unique(self.labels): 44 | print(f'{y_u}: {np.sum(self.labels == y_u)}') 45 | print(f"Y.min={self.labels.min()}, Y.max={self.labels.max()}") 46 | 47 | @classmethod 48 | def setup(cls, cfg): 49 | pass 50 | 51 | 52 | class PBMC(ClusteringDataset): 53 | def __init__(self, data, targets): 54 | super().__init__(data, targets) 55 | 56 | @classmethod 57 | def setup(cls, cfg): 58 | data_dir = cfg.data_dir 59 | with np.load(f"{data_dir}/pbmc_x.npz") as data: 60 | X = data['arr_0'] 61 | with np.load(f"{data_dir}/pbmc_y.npz") as data: 62 | Y = data['arr_0'] 63 | Y = Y - Y.min() 64 | scaler = getattr(preprocessing, cfg.scaler)() 65 | X = scaler.fit_transform(X) 66 | return cls(X, Y) 67 | 68 | 69 | class BIASE(ClusteringDataset): 70 | def __init__(self, data, targets): 71 | super().__init__(data, targets) 72 | 73 | @classmethod 74 | def setup(cls, cfg): 75 | name = 'biase' 76 | data_dir = cfg.data_dir 77 | dataset_x = f"{data_dir}/{name}/{name}_data.csv" 78 | dataset_y = f"{data_dir}/{name}/{name}_celldata.csv" 79 | with open(dataset_x) as r: 80 | data = [l.strip() for l in r.readlines()] 81 | cell_keys = data[0].split(',')[1:] 82 | rows = [np.array([float(v) for v in row.split(',')[1:]]).reshape((1, -1)) for row in 
data[1:]] 83 | X = remove_zero_columns( 84 | np.concatenate(rows, axis=0).transpose()) # np.concatenate(rows, axis=0).transpose() 85 | with open(dataset_y) as r: 86 | y_data = [l.strip().split(',') for l in r.readlines()[1:]] 87 | cell2class = {row[0]: row[2] for row in y_data} 88 | class2count = {} 89 | for cell, clas in cell2class.items(): 90 | class2count.setdefault(clas, 0) 91 | class2count[clas] += 1 92 | 93 | print(class2count) 94 | class2id = {c: i for i, c in enumerate(sorted(set(cell2class.values())))} 95 | 96 | Y = [] 97 | for cell_key in cell_keys: 98 | Y.append(class2id[cell2class[cell_key]]) 99 | Y = np.array(Y).reshape(-1) 100 | X = BIASE.transform(X) 101 | 102 | X = np.log(1 + X) 103 | X = X + .001 * np.random.normal(0, 1, (X.shape)) 104 | scaler = getattr(preprocessing, cfg.scaler)() 105 | X = scaler.fit_transform(X) 106 | return cls(X, Y) 107 | 108 | 109 | class INTESTINE(ClusteringDataset): 110 | def __init__(self, data, targets): 111 | super().__init__(data, targets) 112 | 113 | @classmethod 114 | def setup(cls, cfg): 115 | scaler = getattr(preprocessing, cfg.scaler)() 116 | name = 'intestine' 117 | data_dir = cfg.data_dir 118 | dataset_x = f"{data_dir}/{name}/{name}_data.csv" 119 | dataset_y = f"{data_dir}/{name}/{name}_celldata.csv" 120 | with open(dataset_x) as r: 121 | data = [l.strip() for l in r.readlines()] 122 | cell_keys = data[0].split(',')[1:] 123 | rows = [np.array([float(v) for v in row.split(',')[1:]]).reshape((1, -1)) for row in data[1:]] 124 | X = np.concatenate(rows, axis=0).T 125 | with open(dataset_y) as r: 126 | y_data = [l.strip().split(',') for l in r.readlines()[1:]] 127 | cell2class = {row[0]: row[2] for row in y_data} 128 | class2count = {} 129 | for cell, clas in cell2class.items(): 130 | class2count.setdefault(clas, 0) 131 | class2count[clas] += 1 132 | print(class2count) 133 | class2id = {c: i for i, c in enumerate(sorted(set(list(cell2class.values()))))} 134 | Y = [] 135 | for cell_key in cell_keys: 136 |
Y.append(class2id[cell2class[cell_key]]) 137 | Y = np.array(Y).reshape(-1) 138 | X = scaler.fit_transform(X) 139 | return cls(X, Y) 140 | 141 | 142 | class CNAE9(ClusteringDataset): 143 | def __init__(self, data, targets): 144 | super().__init__(data, targets) 145 | 146 | @classmethod 147 | def setup(cls, cfg): 148 | scaler = getattr(preprocessing, cfg.scaler)() 149 | data = np.loadtxt(f"{cfg.data_dir}/cnae_9_numpy.txt") 150 | X = data[:, :-1] 151 | Y = data[:, -1] 152 | Y = Y - Y.min() 153 | X = scaler.fit_transform(X) 154 | return cls(X, Y) 155 | 156 | 157 | class MFEATZERNIKE(ClusteringDataset): 158 | def __init__(self, data, targets): 159 | super().__init__(data, targets) 160 | 161 | @classmethod 162 | def setup(cls, cfg): 163 | scaler = getattr(preprocessing, cfg.scaler)() 164 | data = np.loadtxt(f"{cfg.data_dir}/mfeat_zernike_numpy.txt") 165 | X = data[:, :-1] 166 | Y = data[:, -1] 167 | Y = Y - Y.min() 168 | X = scaler.fit_transform(X) 169 | return cls(X, Y) 170 | 171 | 172 | class ALLAML(ClusteringDataset): 173 | def __init__(self, data, targets): 174 | super().__init__(data, targets) 175 | 176 | @classmethod 177 | def setup(cls, cfg): 178 | dataset = loadmat(f"{cfg.data_dir}/ALLAML.mat") 179 | X = dataset.get('X') 180 | Y = dataset.get('Y').reshape(-1) 181 | Y = Y - Y.min() 182 | scaler = getattr(preprocessing, cfg.scaler)() 183 | X = scaler.fit_transform(X) 184 | return cls(X, Y) 185 | 186 | 187 | class PROSTATE(ClusteringDataset): 188 | def __init__(self, data, targets): 189 | super().__init__(data, targets) 190 | 191 | @classmethod 192 | def setup(cls, cfg): 193 | dataset = loadmat(f"{cfg.data_dir}/PROSTATE.mat") 194 | X = dataset.get('X') 195 | Y = dataset.get('Y').reshape(-1) 196 | Y = Y - Y.min() # to start from zero 197 | scaler = getattr(preprocessing, cfg.scaler)() 198 | X = scaler.fit_transform(X) 199 | return cls(X, Y) 200 | 201 | 202 | class TOX171(ClusteringDataset): 203 | def __init__(self, data, targets): 204 | super().__init__(data, 
targets) 205 | 206 | @classmethod 207 | def setup(cls, cfg): 208 | dataset = loadmat(f"{cfg.data_dir}/TOX171.mat") 209 | X = dataset.get('X') 210 | Y = dataset.get('Y').reshape(-1) 211 | Y = Y - Y.min() # to start from zero 212 | scaler = getattr(preprocessing, cfg.scaler)() 213 | X = scaler.fit_transform(X) 214 | return cls(X, Y) 215 | 216 | 217 | class SRBCT(ClusteringDataset): 218 | def __init__(self, data, targets): 219 | super().__init__(data, targets) 220 | 221 | @classmethod 222 | def setup(cls, cfg): 223 | dataset = loadmat(f"{cfg.data_dir}/SRBCT.mat") 224 | X = dataset.get('X') 225 | Y = dataset.get('Y').reshape(-1) 226 | Y = Y - Y.min() # to start from zero 227 | scaler = getattr(preprocessing, cfg.scaler)() 228 | X = scaler.fit_transform(X) 229 | return cls(X, Y) 230 | 231 | 232 | class MNIST60K(ClusteringDataset): 233 | def __init__(self, data, targets): 234 | super().__init__(data, targets) 235 | 236 | @classmethod 237 | def setup(cls, cfg): 238 | scaler = getattr(preprocessing, cfg.scaler)() 239 | X = MNIST(cfg.data_dir, train=True, download=True).data.reshape(-1, 784).cpu().numpy() 240 | Y = MNIST(cfg.data_dir, train=True, download=True).targets.cpu().numpy() 241 | X = scaler.fit_transform(X) 242 | return cls(X, Y) 243 | 244 | 245 | class MNIST10K(ClusteringDataset): 246 | def __init__(self, data, targets): 247 | super().__init__(data, targets) 248 | 249 | @classmethod 250 | def setup(cls, cfg): 251 | scaler = getattr(preprocessing, cfg.scaler)() 252 | X = MNIST(cfg.data_dir, train=True, download=True).data.reshape(-1, 784).cpu().numpy() 253 | Y = MNIST(cfg.data_dir, train=True, download=True).targets.cpu().numpy() 254 | X = scaler.fit_transform(X) 255 | X = X[:10000] 256 | Y = Y[:10000] 257 | return cls(X, Y) 258 | 259 | 260 | class NumpyTableDataset(ClusteringDataset): 261 | def __init__(self, data, labels=None, num_clusters=None): 262 | super().__init__(data, labels, num_clusters) 263 | 264 | @classmethod 265 | def setup(cls, filepath_samples: 
str, filepath_labels: str = None, num_clusters: int = None): 266 | """ 267 | :param filepath_samples: the path to the npz file, the format of the numpy array should be NxD 268 | (number of samples x number of features) 269 | :param filepath_labels: the path to the npz file, the format of the numpy array should be N 270 | (number of samples) 271 | :param num_clusters: the integer number of expected clusters 272 | """ 273 | with np.load(filepath_samples) as data: 274 | X = data['arr_0'] 275 | 276 | if filepath_labels is not None: 277 | with np.load(filepath_labels) as data: 278 | Y = data['arr_0'] 279 | X = preprocessing.StandardScaler().fit_transform(X) 280 | Y = Y - Y.min() 281 | else: 282 | Y = None 283 | return cls(X, Y, num_clusters) 284 | 285 | 286 | def remove_zero_columns(X): 287 | non_zero_columns = [] 288 | for col in range(X.shape[1]): 289 | if np.min(X[:, col]) == 0 and np.max(X[:, col]) == 0: 290 | continue 291 | else: 292 | non_zero_columns.append(col) 293 | X = X[:, non_zero_columns] 294 | return X 295 | 296 | 297 | class Synthetic(Dataset): 298 | def __init__(self, X, Y): 299 | super().__init__() 300 | self.data = X 301 | self.targets = Y 302 | 303 | def __getitem__(self, index: int): 304 | x = self.data[index] 305 | return torch.tensor(x).float(), torch.tensor(self.targets[index]).long() 306 | 307 | def __len__(self) -> int: 308 | return len(self.data) 309 | 310 | @classmethod 311 | def setup(cls, num_samples=5000, num_features=3, num_clusters=3, num_noise_dims=10): 312 | """ 313 | Make num_clusters + 1 clusters in 3d and adds additional num_noise_dims noise features 314 | :param num_samples: number of samples in the dataset 315 | :param num_features: number of features in the dataset 316 | :param num_clusters: number of clusters in the dataset 317 | :param num_noise_dims: number of noise dimensions in addition to num_features 318 | :return: generates a dataset 319 | """ 320 | x_2d, y_2d = datasets.make_blobs(num_samples, num_features-1, 
centers=num_clusters, cluster_std=.5, 321 | random_state=0) 322 | # split the points for cluster==2 into 2 clusters: 323 | max_x = x_2d[:, 1].max() 324 | min_x = x_2d[:, 1].min() 325 | x_y_2 = x_2d[y_2d == 2][:, 1] 326 | x_y_2 = MinMaxScaler((0, 1)).fit_transform(x_y_2.reshape(-1, 1)).reshape(-1) 327 | x_y_2 = MinMaxScaler((min_x, max_x)).fit_transform(x_y_2.reshape(-1, 1)).reshape(-1) 328 | x_2d[:, 1][y_2d == 2] = x_y_2 329 | 330 | z = np.random.rand(num_samples) 331 | y_2d[(y_2d == 2) & (z > 0.5)] = 3 332 | 333 | x_2d[:, 0][y_2d == 0] = x_2d[:, 0][y_2d == 1] 334 | 335 | bg = np.random.normal(loc=0, scale=0.01, size=(num_samples, num_noise_dims)) 336 | X = np.concatenate([x_2d, z.reshape(-1, 1), bg], axis=1) 337 | X[:, 2][y_2d == 3] = X[:, 2][y_2d == 3] + 0.5 # separate in z axis 338 | X[:, 2][y_2d == 0] = MinMaxScaler( 339 | (X[:, 2][(y_2d == 3) | (y_2d == 2)].min(), X[:, 2][(y_2d == 3) | (y_2d == 2)].max())).fit_transform( 340 | X[:, 2][y_2d == 0].reshape(-1, 1)).reshape(-1) 341 | X[:, 2][y_2d == 1] = MinMaxScaler( 342 | (X[:, 2][(y_2d == 3) | (y_2d == 2)].min(), X[:, 2][(y_2d == 3) | (y_2d == 2)].max())).fit_transform( 343 | X[:, 2][y_2d == 1].reshape(-1, 1)).reshape(-1) 344 | 345 | Y = y_2d 346 | 347 | X4 = X[Y == 3] 348 | max_len = len(X4) 349 | X1 = X[Y == 0][:max_len, :] 350 | X2 = X[Y == 1][:max_len, :] 351 | X3 = X[Y == 2][:max_len, :] 352 | 353 | Y1 = Y[Y == 0][:max_len] 354 | Y2 = Y[Y == 1][:max_len] 355 | Y3 = Y[Y == 2][:max_len] 356 | Y4 = Y[Y == 3][:max_len] 357 | X = np.concatenate([X1, X2, X3, X4], axis=0) 358 | Y = np.concatenate([Y1, Y2, Y3, Y4], axis=0) 359 | 360 | print("Class stats:") 361 | for y_i in np.unique(Y): 362 | print(f"{y_i}: {len(Y[Y == y_i])} samples") 363 | X[:, :3] = zscore(X[:, :3]) 364 | 365 | plt.style.use('classic') 366 | plt.rcParams['axes.spines.right'] = False 367 | plt.rcParams['axes.spines.top'] = False 368 | fig = plt.figure() 369 | fig.set_facecolor('w') 370 | plt.scatter(X[:, 0], X[:, 1], c=Y, s=100, alpha=0.8, 
cmap='viridis', edgecolor='k', linewidth=2) 371 | plt.xlabel('$X_1$', fontsize=30) 372 | plt.ylabel('$X_2$', fontsize=30) 373 | plt.tight_layout() 374 | plt.xticks([]) 375 | plt.yticks([]) 376 | plt.savefig("synth_X_1_X_2.png") 377 | 378 | plt.clf() 379 | fig = plt.figure() 380 | fig.set_facecolor('w') 381 | plt.scatter(X[:, 0], X[:, 2], c=Y, s=100, alpha=0.8, cmap='viridis', edgecolor='k', linewidth=2) 382 | plt.xlabel('$X_1$', fontsize=30) 383 | plt.ylabel('$X_3$', fontsize=30) 384 | plt.tight_layout() 385 | plt.xticks([]) 386 | plt.yticks([]) 387 | plt.savefig("synth_X_1_X_3.png") 388 | return cls(X, Y) 389 | 390 | def num_classes(self): 391 | return len(np.unique(self.targets)) 392 | 393 | def num_features(self): 394 | return self.data.shape[-1] -------------------------------------------------------------------------------- /img/freq_bias.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/freq_bias.png -------------------------------------------------------------------------------- /img/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/img.png -------------------------------------------------------------------------------- /img/nudft_ALLAML.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/nudft_ALLAML.png -------------------------------------------------------------------------------- /img/pbmc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/pbmc.gif -------------------------------------------------------------------------------- /img/supervised_train_plots.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/supervised_train_plots.png -------------------------------------------------------------------------------- /img/supervised_train_plots2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/supervised_train_plots2.png -------------------------------------------------------------------------------- /interpretability_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from munkres import Munkres 3 | import torch 4 | from sklearn.metrics import confusion_matrix, jaccard_score 5 | from scipy.spatial import distance_matrix 6 | from tqdm import tqdm 7 | from sklearn.svm import LinearSVC 8 | 9 | 10 | def generalizability(X_train, gates_train, Y_train, X_test, gates_test, Y_test): 11 | """ 12 | How the interpretation of the prediction generalizes to other simple prediction models, e.g. Linear Support Vector Classification 13 | """ 14 | classifier = LinearSVC() 15 | classifier.fit(X_train * gates_train, Y_train) 16 | return classifier.score(X_test * gates_test, Y_test) 17 | 18 | 19 | def faithfulness(gates_i, x, inference_fn, y, num_features=784): 20 | """ 21 | Are the identified features significant for prediction? 
22 | """ 23 | importance_vec = np.sum(gates_i > 0, axis=0) 24 | importance_ind = np.where(np.sum(gates_i > 0, axis=0) > 0)[0] 25 | importance_ind_sort = importance_ind[np.argsort(-importance_vec[importance_vec > 0])] 26 | mask = np.ones(num_features) 27 | acc_arr_bad = [] 28 | for i in importance_ind_sort: 29 | mask[i] = 0 30 | y_hat = inference_fn(x * mask) 31 | if isinstance(y_hat, torch.Tensor): 32 | y_hat = y_hat.cpu().numpy() 33 | mean_val = get_accuracy(y_hat, y, 10) 34 | acc_arr_bad.append(mean_val) 35 | return np.corrcoef(importance_vec[importance_ind_sort], np.array(acc_arr_bad))[0, 1] 36 | 37 | 38 | def stability(x, gates, k=2, subset_size=10000, p=2): 39 | """ 40 | Are explanations for similar samples consistent? 41 | inputs: 42 | x is N x D matrix of samples 43 | gates is N x D matrix predicted by STG for x 44 | k is the number of nearest neighbors (including the sample itself) 45 | outputs: 46 | mean Lipschitz constant of the explanation function 47 | """ 48 | dist_mat_x = distance_matrix(x, x, p=p) 49 | nn_dist_mat = np.sort(dist_mat_x, axis=1)[:, 0:k] 50 | nn_ind_mat = np.argsort(dist_mat_x, axis=1)[:, 0:k] 51 | lipchitz_constants = [] 52 | for i in tqdm(range(min(subset_size, len(x)))): 53 | lipchitz_constants.append(max(distance_matrix(gates[nn_ind_mat[i], :], gates[nn_ind_mat[i], :])[0][1:] / nn_dist_mat[i][1:])) 54 | return np.mean(np.array(lipchitz_constants)) 55 | 56 | 57 | def diversity(y, gates, num_clusters=10, num_features=784): 58 | """ 59 | How different are the selected variables for instances of distinct classes? 60 | For the formula, see Appendix A.7 in 61 | Yang, Junchen, Ofir Lindenbaum, and Yuval Kluger. "Locally sparse neural networks for tabular biomedical data." International Conference on Machine Learning. PMLR, 2022.
62 | """ 63 | per_matrix = np.zeros((num_clusters, num_clusters)) 64 | all_gates = [] 65 | for i in range(num_clusters): 66 | indices_p = np.where(y == i)[0] 67 | onez_p = np.zeros(num_features) 68 | active_gates = np.where(np.median(gates[indices_p, :], axis=0) > 0)[0] 69 | onez_p[active_gates] = 1 70 | all_gates = np.append(all_gates, active_gates) 71 | for j in range(num_clusters): 72 | indices_n = np.where(y == j)[0] 73 | active_gates_n = np.where(np.median(gates[indices_n, :], axis=0) > 0)[0] 74 | onez_n = np.zeros(num_features) 75 | onez_n[active_gates_n] = 1 76 | per_matrix[i, j] = jaccard_score(onez_n, onez_p) 77 | 78 | diversity = 100 * (1 - (per_matrix / (num_clusters * (num_clusters - 1))).sum()) 79 | return diversity 80 | 81 | 82 | def uniqueness(x, gates, k=2, subset_size=10000, p=2): 83 | """ 84 | Uniqueness of the selected features for similar samples (how granular are our explanations?) 85 | inputs: 86 | x is N x D matrix of samples 87 | gates is N x D matrix predicted by STG for x 88 | k is the number of nearest neighbors (including the sample itself) 89 | """ 90 | dist_mat_x = distance_matrix(x, x, p=p) 91 | nn_dist_mat = np.sort(dist_mat_x, axis=1)[:, 0:k] 92 | nn_ind_mat = np.argsort(dist_mat_x, axis=1)[:, 0:k] 93 | vals = [] 94 | for i in tqdm(range(min(subset_size, len(x)))): 95 | vals.append(min(distance_matrix(gates[nn_ind_mat[i], :], gates[nn_ind_mat[i], :])[0][1:] / nn_dist_mat[i][1:])) 96 | return np.mean(np.array(vals)) 97 | 98 | 99 | 100 | def get_accuracy(cluster_assignments, y_true, n_clusters): 101 | ''' 102 | Computes the accuracy based on the provided cluster assignments 103 | and true labels, using the Munkres (Hungarian) algorithm 104 | cluster_assignments: array of predicted cluster labels (e.g. from k-means) 105 | y_true: true labels 106 | n_clusters: number of clusters in the dataset 107 | returns: the accuracy, i.e. the fraction of samples whose optimally mapped 108 | cluster matches the true label 109 | ''' 110 | confusion_mat = confusion_matrix(y_true, cluster_assignments, labels=None) 111 | # compute accuracy
based on optimal 1:1 assignment of clusters to labels 112 | cost_matrix = calculate_cost_matrix(confusion_mat, n_clusters) 113 | indices = Munkres().compute(cost_matrix) 114 | kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices) 115 | y_pred = kmeans_to_true_cluster_labels[cluster_assignments] 116 | return np.mean(y_pred == y_true) 117 | 118 | 119 | def calculate_cost_matrix(C, n_clusters): 120 | cost_matrix = np.zeros((n_clusters, n_clusters)) 121 | # cost_matrix[i,j] will be the cost of assigning cluster i to label j 122 | for j in range(n_clusters): 123 | s = np.sum(C[:, j]) # number of examples in cluster j 124 | for i in range(n_clusters): 125 | t = C[i, j] 126 | cost_matrix[j, i] = s - t 127 | return cost_matrix 128 | 129 | 130 | def get_cluster_labels_from_indices(indices): 131 | n_clusters = len(indices) 132 | clusterLabels = np.zeros(n_clusters) 133 | for i in range(n_clusters): 134 | clusterLabels[i] = indices[i][1] 135 | return clusterLabels -------------------------------------------------------------------------------- /lspin_pbmc.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import torch.nn.functional as F 3 | from torch.utils.data import DataLoader 4 | from pytorch_lightning import LightningModule 5 | from torch.utils.data import Dataset 6 | from sklearn.preprocessing import StandardScaler 7 | from pytorch_lightning import Trainer, seed_everything 8 | import argparse 9 | import torch 10 | import math 11 | import numpy as np 12 | import os 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from pytorch_lightning.callbacks import LearningRateMonitor 15 | import platform 16 | 17 | 18 | def parse_args(args): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--gated", action="store_true") 21 | parser.add_argument("--dataset", type=str, default="PBMC") 22 | parser.add_argument("--data_dir", type=str, default="C:/data/fs/pbmc" if
platform.system() == "Windows" else ".") 23 | parser.add_argument("--batch_size", type=int, default=256) 24 | parser.add_argument("--repitions", type=int, default=5) 25 | parser.add_argument("--lr", type=float, default=1e-3) 26 | 27 | # gatenet config 28 | parser.add_argument("--sigma", type=float, default=0.5) 29 | parser.add_argument("--reg_beta", type=float, default=10) 30 | parser.add_argument("--target_sparsity", type=float, default=0.9) 31 | parser.add_argument("--gates_lr", type=float, default=2e-3) 32 | 33 | 34 | # trainer config 35 | parser.add_argument("--devices", type=int, default=1) 36 | parser.add_argument("--accelerator", type=str, default="gpu") 37 | parser.add_argument("--max_epochs", type=int, default=1000) 38 | parser.add_argument("--deterministic", type=bool, default=True) 39 | parser.add_argument("--logger", type=bool, default=True) 40 | parser.add_argument("--log_every_n_steps", type=int, default=10) 41 | parser.add_argument("--check_val_every_n_epoch", type=int, default=1) 42 | parser.add_argument("--enable_checkpointing", type=bool, default=False) 43 | 44 | args = parser.parse_args(args) 45 | return args 46 | 47 | 48 | class PBMC(Dataset): 49 | def __init__(self, X, Y): 50 | super().__init__() 51 | self.data = X 52 | self.targets = Y 53 | 54 | def __getitem__(self, index: int): 55 | x = self.data[index] 56 | x = x.reshape(-1) 57 | return torch.tensor(x).float(), torch.tensor(self.targets[index]).long() 58 | 59 | def __len__(self) -> int: 60 | return len(self.data) 61 | 62 | @classmethod 63 | def setup(cls, data_dir, test_size=0.2): 64 | with np.load(f"{data_dir}/pbmc_x.npz") as data: 65 | X = data['arr_0'] 66 | with np.load(f"{data_dir}/pbmc_y.npz") as data: 67 | Y = data['arr_0'] 68 | 69 | Y = Y - Y.min() 70 | X = StandardScaler().fit_transform(X) 71 | print(f'Dataset PBMC stats:') 72 | print('X.shape: ', X.shape) 73 | print('Y.shape: ', Y.shape) 74 | print(f"X.min={X.min()}, X.max={X.max()}") 75 | print(f"Y.min={Y.min()}, Y.max={Y.max()}") 76 |
77 | for y_uniq in np.unique(Y): 78 | print(f"Label {y_uniq} has {len(Y[Y == y_uniq])} samples") 79 | 80 | np.random.seed(1948) 81 | random_index = np.random.permutation(len(X)) 82 | test_size = int(len(X) * test_size) 83 | x_test = X[random_index][:test_size] 84 | y_test = Y[random_index][:test_size] 85 | 86 | x_train = X[random_index][test_size:] 87 | y_train = Y[random_index][test_size:] 88 | 89 | print(f"Split to train/test: train {len(x_train)} test {len(x_test)}") 90 | return cls(x_train, y_train), cls(x_test, y_test) 91 | 92 | def num_classes(self): 93 | return len(np.unique(self.targets)) 94 | 95 | def num_features(self): 96 | return self.data.shape[-1] 97 | 98 | 99 | class BaseModule(LightningModule): 100 | def __init__(self, cfg): 101 | super().__init__() 102 | self.cfg = cfg 103 | self.save_hyperparameters() 104 | self.best_evaluation_stats = {} 105 | self.automatic_optimization = False 106 | self.best_accuracy = - np.inf 107 | self.classifier_net = Classifier(cfg) 108 | self.val_cluster_list = [] 109 | self.val_label_list = [] 110 | self.best_acc = - 100 111 | 112 | if cfg.gated: 113 | self.gating_net = GatingNet(cfg) 114 | self.val_cluster_list_gated = [] 115 | self.open_gates = [] 116 | self.best_local_feats = None 117 | 118 | def training_step(self, batch, batch_idx): 119 | opt = self.optimizers() 120 | sch = self.lr_schedulers() 121 | x, y = batch 122 | x = x.reshape(x.size(0), -1) 123 | opt.zero_grad() 124 | 125 | if hasattr(self, 'gating_net'): 126 | mu, _, gates = self.gating_net(x) 127 | ae_emb = self.classifier_net.encoder(x * gates) 128 | 129 | reg_loss = self.gating_net.regularization(mu) 130 | self.log("train/reg_loss", reg_loss.item()) 131 | else: 132 | ae_emb = self.classifier_net.encoder(x) 133 | reg_loss = 0 134 | 135 | cluster_logits = self.classifier_net.head(ae_emb) 136 | ce_loss = F.cross_entropy(cluster_logits, y) 137 | 138 | self.log("train/ce_loss", ce_loss.item()) 139 | loss = ce_loss + self.cfg.reg_beta * reg_loss 140 |
self.manual_backward(loss) 141 | opt.step() 142 | sch.step() 143 | 144 | if self.global_step % 100 == 0: 145 | if hasattr(self, 'gating_net'): 146 | print(f"Epoch {self.current_epoch} " 147 | f"step {self.global_step} " 148 | f"train/reg_loss {reg_loss.item()} " 149 | f"train/ce_loss {ce_loss.item()}") 150 | def configure_optimizers(self): 151 | 152 | if hasattr(self, 'gating_net'): 153 | params =[ { # classifier 154 | "params": chain( 155 | self.classifier_net.encoder.parameters(), 156 | self.classifier_net.head.parameters()), 157 | "lr": self.cfg.lr, 158 | 159 | }, 160 | { # gates 161 | "params": self.gating_net.net.parameters(), 162 | "lr": self.cfg.gates_lr, 163 | }] 164 | 165 | else: 166 | params = chain( 167 | self.classifier_net.encoder.parameters(), 168 | self.classifier_net.head.parameters(), 169 | ) 170 | optimizer = torch.optim.SGD( 171 | params=params, 172 | lr=self.cfg.lr) 173 | 174 | steps = self.train_dataset.__len__() // self.batch_size * self.cfg.max_epochs 175 | print(f"Cosine annealing LR scheduling is applied during {steps} steps") 176 | sched = torch.optim.lr_scheduler.CosineAnnealingLR( 177 | optimizer=optimizer, 178 | T_max=steps, 179 | eta_min=1e-4) 180 | return [optimizer], [sched] 181 | 182 | def validation_step(self, batch, batch_idx): 183 | x, y = batch 184 | if hasattr(self, 'gating_net'): 185 | gates = self.gating_net.get_gates(x) 186 | ae_emb = self.classifier_net.encoder(x * gates) 187 | self.open_gates.append(self.gating_net.num_open_gates(x)) 188 | else: 189 | ae_emb = self.classifier_net.encoder(x) 190 | cluster_logits = self.classifier_net.head(ae_emb) 191 | y_hat = cluster_logits.argmax(dim=-1) 192 | self.val_cluster_list.append(y_hat.cpu()) 193 | self.val_label_list.append(y.cpu()) 194 | 195 | def on_validation_epoch_start(self): 196 | self.val_cluster_list = [] 197 | self.val_cluster_list_gated = [] 198 | self.val_label_list = [] 199 | if hasattr(self, 'gating_net'): 200 | self.open_gates = [] 201 | 202 | def 
on_validation_epoch_end(self): 203 | if self.current_epoch > 0: 204 | cluster_mtx = torch.cat(self.val_cluster_list, dim=0) 205 | label_mtx = torch.cat(self.val_label_list, dim=0) 206 | acc = torch.mean((cluster_mtx == label_mtx).float()).item() 207 | if self.best_accuracy < acc: 208 | self.best_accuracy = acc 209 | if hasattr(self, 'gating_net'): 210 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.classifier_net.state_dict()} 211 | torch.save(meta_dict, f'sparse_model_best_pbmc_beta_{self.cfg.reg_beta}_seed_{self.cfg.seed}.pth') 212 | print(f"New best accuracy: {acc} open gates: {np.mean(self.open_gates).item()}") 213 | else: 214 | meta_dict = {"clustering": self.classifier_net.state_dict()} 215 | torch.save(meta_dict, f'sparse_model_nogates_best_pbmc_seed_{self.cfg.seed}.pth') 216 | print(f"New best accuracy: {acc}") 217 | format_str = '' # '_kmeans' if self.current_epoch == 9 else '' 218 | self.log(f'val/acc_single{format_str}', acc) # this is ACC 219 | if hasattr(self, 'gating_net'): 220 | self.log("val/num_open_gates", np.mean(self.open_gates).item()) 221 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.classifier_net.state_dict()} 222 | torch.save(meta_dict, f'sparse_model_last_pbmc_beta_{self.cfg.reg_beta}_seed_{self.cfg.seed}.pth') 223 | self.update_stats(acc, np.mean(self.open_gates).item()) 224 | else: 225 | meta_dict = {"clustering": self.classifier_net.state_dict()} 226 | torch.save(meta_dict, f'sparse_model_nogates_last_pbmc_seed_{self.cfg.seed}.pth') 227 | self.update_stats(acc, None) 228 | 229 | def update_stats(self, acc, local_feats=None): 230 | if self.best_acc <= acc: 231 | self.best_acc = acc 232 | if local_feats is not None: 233 | self.best_local_feats = local_feats 234 | 235 | 236 | class ClassificationModule(BaseModule): 237 | def __init__(self, cfg): 238 | self.train_dataset, self.test_dataset = PBMC.setup(cfg.data_dir) 239 | print(f"Train Dataset length: {self.train_dataset.__len__()}") 240 
| print(f"Test Dataset length: {self.test_dataset.__len__()}") 241 | cfg.input_dim = self.train_dataset.num_features() 242 | cfg.n_clusters = self.train_dataset.num_classes() 243 | self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size) 244 | super().__init__(cfg) 245 | 246 | def train_dataloader(self): 247 | return DataLoader(self.train_dataset, 248 | batch_size=self.batch_size, 249 | drop_last=True, 250 | shuffle=True, 251 | num_workers=0) 252 | 253 | def val_dataloader(self): 254 | return DataLoader(self.test_dataset, 255 | batch_size=self.batch_size, 256 | drop_last=False, 257 | shuffle=False, 258 | num_workers=0) 259 | 260 | 261 | class Classifier(torch.nn.Module): 262 | def __init__(self, cfg): 263 | super(Classifier, self).__init__() 264 | self.cfg = cfg 265 | self.encoder = torch.nn.Sequential( 266 | torch.nn.Linear(cfg.input_dim, 512), 267 | torch.nn.BatchNorm1d(512), 268 | torch.nn.ReLU(), 269 | ) 270 | self.head = torch.nn.Sequential( 271 | torch.nn.Linear(512, 2048), 272 | torch.nn.BatchNorm1d(2048), 273 | torch.nn.ReLU(), 274 | torch.nn.Linear(2048, cfg.n_clusters), 275 | ) 276 | 277 | self.encoder.apply(self.init_weights_normal) 278 | self.head.apply(self.init_weights_normal) 279 | 280 | @staticmethod 281 | def init_weights_normal(m): 282 | if isinstance(m, torch.nn.Linear): 283 | torch.nn.init.xavier_normal_(m.weight) 284 | if 'bias' in vars(m).keys(): 285 | m.bias.data.fill_(0.0) 286 | 287 | def pretrain_forward(self, x): 288 | return self.decoder(self.encoder(x)) 289 | 290 | 291 | class GatingNet(torch.nn.Module): 292 | def __init__(self, cfg): 293 | super(GatingNet, self).__init__() 294 | self.cfg = cfg 295 | self._sqrt_2 = math.sqrt(2) 296 | self.sigma = cfg.sigma 297 | self.net = torch.nn.Sequential( 298 | torch.nn.Linear(cfg.input_dim, 512), 299 | torch.nn.ReLU(), 300 | torch.nn.Linear(512, 2048), 301 | torch.nn.ReLU(), 302 | torch.nn.Linear(2048, 512), 303 | torch.nn.ReLU(), 304 | torch.nn.Linear(512, cfg.input_dim), 305 | 
torch.nn.Tanh() 306 | ) 307 | self.net.apply(self.init_weights) 308 | 309 | def init_weights(self, m): 310 | if isinstance(m, torch.nn.Linear): 311 | torch.nn.init.xavier_normal_(m.weight) 312 | if m.out_features == self.cfg.input_dim: 313 | m.bias.data.fill_(.5) 314 | else: 315 | m.bias.data.fill_(0.0) 316 | 317 | def global_forward(self, batch_size, y): 318 | noise = torch.normal(mean=0, std=self.sigma, size=(batch_size, self.cfg.input_dim), 319 | device=self.global_gates_net.weight.device) 320 | z = torch.tanh(self.global_gates_net(y)).reshape(1, -1).repeat(batch_size, 1) + noise * self.training 321 | gates = self.hard_sigmoid(z) 322 | return torch.tanh(self.global_gates_net(y)), gates 323 | 324 | def open_global_gates(self): 325 | return self.hard_sigmoid(torch.tanh(self.global_gates_net.weight)).sum(dim=1).mean().cpu().item() 326 | 327 | def forward(self, x): 328 | noise = torch.normal(mean=0, std=self.sigma, size=x.size(), device=x.device) 329 | mu = self.net(x) 330 | z = mu + noise * self.training 331 | gates = self.hard_sigmoid(z) 332 | sparse_x = x * gates 333 | return mu, sparse_x, gates 334 | 335 | @staticmethod 336 | def hard_sigmoid(x): 337 | return torch.clamp(x + .5, 0.0, 1.0) 338 | 339 | def regularization(self, mu, reduction_func=torch.mean): 340 | return max(reduction_func(0.5 - 0.5 * torch.erf((-0.5 - mu) / (0.5 * self._sqrt_2))), 341 | torch.tensor(1 - self.cfg.target_sparsity, device=mu.device, dtype=mu.data.dtype)) 342 | 343 | def get_gates(self, x): 344 | with torch.no_grad(): 345 | gates = self.hard_sigmoid(self.net(x)) 346 | return gates 347 | 348 | def num_open_gates(self, x): 349 | return torch.sum(self.get_gates(x) > 0).item() / x.size(0) 350 | 351 | 352 | def train_test(cfg): 353 | torch.use_deterministic_algorithms(True) 354 | torch.backends.cudnn.deterministic = True 355 | torch.backends.cudnn.benchmark = False 356 | gated_str = "_gated" if cfg.gated else "" 357 | with 
open(f"results_{os.path.basename(__file__)}{gated_str}_reg_beta_{cfg.reg_beta}.txt", mode='w') as f: 358 | 359 | header = '\t'.join(['seed', 'acc', 'local_gates']) 360 | f.write(f"{header}\n") 361 | f.flush() 362 | 363 | for seed in range(cfg.repitions): 364 | cfg.seed = seed 365 | seed_everything(seed) 366 | np.random.seed(seed) 367 | if not os.path.exists(cfg.dataset): 368 | os.makedirs(cfg.dataset) 369 | model = ClassificationModule(cfg) 370 | logger = TensorBoardLogger(cfg.dataset, name=os.path.basename(__file__), log_graph=False) 371 | trainer = Trainer( 372 | devices=cfg.devices, 373 | accelerator=cfg.accelerator, 374 | max_epochs=cfg.max_epochs, 375 | deterministic=cfg.deterministic, 376 | logger=cfg.logger, 377 | log_every_n_steps=cfg.log_every_n_steps, 378 | check_val_every_n_epoch=cfg.check_val_every_n_epoch, 379 | enable_checkpointing=cfg.enable_checkpointing, 380 | callbacks=[LearningRateMonitor(logging_interval='step')] 381 | ) 382 | trainer.logger = logger 383 | trainer.fit(model) 384 | if cfg.gated: 385 | results_str = '\t'.join([f'{seed}', f'{model.best_acc}', f'{model.best_local_feats}']) 386 | else: 387 | results_str = '\t'.join([f'{seed}', f'{model.best_acc}']) 388 | f.write(f"{results_str}\n") 389 | f.flush() 390 | 391 | 392 | if __name__ == "__main__": 393 | cfg = parse_args(None) 394 | train_test(cfg) 395 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def init_weights_normal(m): 6 | if isinstance(m, torch.nn.Linear): 7 | torch.nn.init.normal_(m.weight, std=0.001) 8 | if 'bias' in vars(m).keys(): 9 | m.bias.data.fill_(0.0) 10 | 11 | 12 | def clustering_head(cfg): 13 | return torch.nn.Sequential( 14 | torch.nn.Linear(cfg.clustering_head[0], cfg.clustering_head[1]), 15 | torch.nn.BatchNorm1d(cfg.clustering_head[1]), 16 | torch.nn.ReLU(), 17 | 
torch.nn.Linear(cfg.clustering_head[1], cfg.n_clusters)).apply(init_weights_normal) 18 | 19 | 20 | def aux_classifier_head(cfg): 21 | return torch.nn.Sequential( 22 | torch.nn.Linear(cfg.input_dim, cfg.aux_classifier[0]), 23 | torch.nn.BatchNorm1d(cfg.aux_classifier[0]), 24 | torch.nn.ReLU(), 25 | torch.nn.Linear(cfg.aux_classifier[0], cfg.n_clusters)).apply(init_weights_normal) 26 | 27 | 28 | class EncoderDecoder(torch.nn.Module): 29 | def __init__(self, cfg): 30 | super(EncoderDecoder, self).__init__() 31 | self.cfg = cfg 32 | self.encoder = [] 33 | self.encoder = self.build_encoder() 34 | self.decoder = self.build_decoder() 35 | self.encoder.apply(init_weights_normal) 36 | self.decoder.apply(init_weights_normal) 37 | 38 | def build_encoder(self): 39 | layers = [ 40 | torch.nn.Linear(self.cfg.input_dim, self.cfg.encdec[0]), 41 | torch.nn.BatchNorm1d(self.cfg.encdec[0]), 42 | torch.nn.ReLU() 43 | ] 44 | hidden_layers = len(self.cfg.encdec) // 2 + 1 45 | for layer_idx in range(1, hidden_layers): 46 | if layer_idx == hidden_layers - 1: 47 | layers += [torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx])] 48 | else: 49 | layers += [ 50 | torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx]), 51 | torch.nn.BatchNorm1d(self.cfg.encdec[layer_idx]), 52 | torch.nn.ReLU() 53 | ] 54 | return torch.nn.Sequential(*layers) 55 | 56 | def build_decoder(self): 57 | hidden_layers = len(self.cfg.encdec) // 2 + 1 58 | layers = [] 59 | for layer_idx in range(hidden_layers, len(self.cfg.encdec)): 60 | layers += [ 61 | torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx]), 62 | torch.nn.BatchNorm1d(self.cfg.encdec[layer_idx]), 63 | torch.nn.ReLU() 64 | ] 65 | layers += [torch.nn.Linear(self.cfg.encdec[-1], self.cfg.input_dim)] 66 | return torch.nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | return self.decoder(self.encoder(x)) 70 | 71 | class GatingNet(torch.nn.Module): 72 | def __init__(self, cfg): 73 | 
super(GatingNet, self).__init__() 74 | self.cfg = cfg 75 | self._sqrt_2 = math.sqrt(2) 76 | self.sigma = 0.5 77 | self.local_gates = torch.nn.Sequential( 78 | torch.nn.Linear(cfg.input_dim, cfg.gates_hidden_dim), 79 | torch.nn.Tanh(), 80 | torch.nn.Linear(cfg.gates_hidden_dim, cfg.input_dim), 81 | torch.nn.Tanh() 82 | ) 83 | self.local_gates.apply(self.init_weights) 84 | self.global_gates_net = torch.nn.Embedding(self.cfg.n_clusters, self.cfg.input_dim) 85 | torch.nn.init.normal_(self.global_gates_net.weight, std=0.01) 86 | 87 | @staticmethod 88 | def init_weights(m): 89 | if isinstance(m, torch.nn.Linear): 90 | torch.nn.init.normal_(m.weight, std=0.001) 91 | if 'bias' in vars(m).keys(): 92 | m.bias.data.fill_(0.0) 93 | 94 | def global_forward(self, y): 95 | noise = torch.normal(mean=0, std=self.sigma, size=(y.size(0), self.cfg.input_dim), 96 | device=self.global_gates_net.weight.device) 97 | z = torch.tanh(self.global_gates_net(y)) + .5 * noise * self.training 98 | gates = self.hard_sigmoid(z) 99 | return torch.tanh(self.global_gates_net(y)), gates 100 | 101 | def open_global_gates(self): 102 | return self.hard_sigmoid(torch.tanh(self.global_gates_net.weight)).sum(dim=1).mean().cpu().item() 103 | 104 | def forward(self, x): 105 | noise = torch.normal(mean=0, std=self.sigma, size=x.size(), device=x.device) 106 | mu = self.local_gates(x) 107 | z = mu + .5 * noise * self.training 108 | gates = self.hard_sigmoid(z) 109 | sparse_x = x * gates 110 | return mu, sparse_x, gates 111 | 112 | @staticmethod 113 | def hard_sigmoid(x): 114 | return torch.clamp(x + .5, 0.0, 1.0) 115 | 116 | def regularization(self, mu, reduction_func=torch.mean): 117 | return reduction_func(0.5 - 0.5 * torch.erf((-1 / 2 - mu) / self._sqrt_2)) 118 | 119 | def get_gates(self, x): 120 | with torch.no_grad(): 121 | gates = self.hard_sigmoid(self.local_gates(x)) 122 | return gates 123 | 124 | def num_open_gates(self, x, ): 125 | return self.get_gates(x).sum(dim=1).cpu().median(dim=0)[0].item() 
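Both `GatingNet` classes above use the same stochastic-gate pattern: a gate is `hard_sigmoid(mu + noise)`, so it is open (strictly positive) exactly when `mu + noise > -0.5`, and `regularization` penalizes that probability through a Gaussian CDF written with `erf`. The following is a minimal, self-contained sketch of that relationship under stated assumptions, not code from this repository: the helper names `sample_gates` and `prob_gate_open` are hypothetical, and `noise_std` is an assumed explicit parameter for the standard deviation of the injected noise (the `erf` denominator is `sqrt(2)` in model.py and `0.5 * sqrt(2)` in lspin_pbmc.py). It checks the closed-form open probability against a Monte Carlo estimate.

```python
import math

import torch


def hard_sigmoid(z: torch.Tensor) -> torch.Tensor:
    # Same clipping as GatingNet.hard_sigmoid: shift by 0.5, then clamp to [0, 1].
    return torch.clamp(z + 0.5, 0.0, 1.0)


def sample_gates(mu: torch.Tensor, noise_std: float, training: bool) -> torch.Tensor:
    # Hypothetical helper: z = mu + eps with eps ~ N(0, noise_std^2) while training,
    # and z = mu at evaluation time (like `noise * self.training` in the repo).
    eps = torch.randn_like(mu) * noise_std if training else torch.zeros_like(mu)
    return hard_sigmoid(mu + eps)


def prob_gate_open(mu: torch.Tensor, noise_std: float) -> torch.Tensor:
    # Hypothetical helper: P(gate > 0) = P(mu + eps > -0.5) = Phi((mu + 0.5) / noise_std),
    # expressed with erf in the same form as GatingNet.regularization.
    return 0.5 - 0.5 * torch.erf((-0.5 - mu) / (noise_std * math.sqrt(2)))


torch.manual_seed(0)
mu = torch.tensor([-1.0, -0.5, 0.0, 0.5])
noise_std = 0.5

# Monte Carlo estimate of how often each gate is open (strictly positive).
draws = sample_gates(mu.repeat(20000, 1), noise_std, training=True)
empirical_open = (draws > 0).float().mean(dim=0)
analytic_open = prob_gate_open(mu, noise_std)

# At evaluation time the gates are deterministic: clamp(mu + 0.5, 0, 1).
eval_gates = sample_gates(mu, noise_std, training=False)

print("empirical open rate:", empirical_open)
print("analytic open rate: ", analytic_open)
print("eval-time gates:    ", eval_gates)
```

Because `prob_gate_open` is differentiable in `mu`, averaging it over features gives a smooth surrogate for the fraction of open gates, which appears to be the role `regularization` plays in the training losses; the hard clamp itself has zero gradient wherever a gate is fully closed or fully open.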
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.1 2 | pytorch-lightning==2.0.0 3 | scikit-learn==1.1.2 4 | scipy==1.10.0-rc1 5 | omegaconf==2.2.3 6 | matplotlib==3.6.3 7 | matplotlib-inline==0.1.6 8 | umap-learn==0.5.6 9 | torchvision==0.15.2 10 | munkres==1.1.4 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import torch 3 | import math 4 | from omegaconf import OmegaConf 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from pytorch_lightning import LightningModule 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from pytorch_lightning import Trainer, seed_everything 11 | import os 12 | from pytorch_lightning.loggers import TensorBoardLogger 13 | from pytorch_lightning.callbacks import LearningRateMonitor 14 | from sklearn.metrics import silhouette_score, davies_bouldin_score 15 | import argparse 16 | from dataset import NumpyTableDataset 17 | from model import clustering_head, aux_classifier_head, EncoderDecoder, GatingNet 18 | import umap 19 | 20 | 21 | class TotalCodingRateWithProjection(torch.nn.Module): 22 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """ 23 | 24 | def __init__(self, cfg): 25 | super().__init__() 26 | self.eps = cfg.gtcr_eps 27 | if cfg.gtcr_projection_dim is not None: 28 | self.random_matrix = torch.tensor(np.random.normal( 29 | loc=0.0, 30 | scale=1.0 / np.sqrt(cfg.gtcr_projection_dim), 31 | size=(cfg.input_dim, cfg.gtcr_projection_dim) 32 | )).float() 33 | else: 34 | self.random_matrix = None 35 | 36 | def compute_discrimn_loss(self, W): 37 | p, m = W.shape # [d, B] 38 | I = torch.eye(p, device=W.device) 39 | scalar = p / (m * self.eps) 40 | logdet = 
torch.logdet(I + scalar * W.matmul(W.T)) 41 | return logdet / 2. 42 | 43 | def forward(self, x): 44 | if self.random_matrix is not None: 45 | x = x @ self.random_matrix.to(x.device) 46 | return - self.compute_discrimn_loss(x.T) 47 | 48 | 49 | class MaximalCodingRateReduction(torch.nn.Module): 50 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """ 51 | 52 | def __init__(self, eps=0.01, gamma=1, compress_only=False): 53 | super(MaximalCodingRateReduction, self).__init__() 54 | self.eps = eps 55 | self.gamma = gamma 56 | self.compress_only = compress_only 57 | 58 | def compute_discrimn_loss(self, W): 59 | p, m = W.shape 60 | I = torch.eye(p, device=W.device) 61 | scalar = p / (m * self.eps) 62 | logdet = torch.logdet(I + scalar * W.matmul(W.T)) 63 | return logdet / 2. 64 | 65 | def compute_compress_loss(self, W, Pi): 66 | p, m = W.shape 67 | k, _, _ = Pi.shape 68 | I = torch.eye(p, device=W.device).expand((k, p, p)) 69 | trPi = Pi.sum(2) + 1e-8 70 | scale = (p / (trPi * self.eps)).view(k, 1, 1) 71 | W = W.view((1, p, m)) 72 | log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2))) 73 | compress_loss = (trPi.squeeze() * log_det / (2 * m)).sum() 74 | return compress_loss 75 | 76 | def forward(self, X, Y, num_classes=None): 77 | # This function supports Y either as integer labels or as membership probabilities.
78 | if len(Y.shape) == 1: 79 | # if Y is a label vector 80 | if num_classes is None: 81 | num_classes = Y.max() + 1 82 | Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device) 83 | for indx, label in enumerate(Y): 84 | Pi[label, 0, indx] = 1 85 | else: 86 | # if Y is a probability matrix 87 | if num_classes is None: 88 | num_classes = Y.shape[1] 89 | Pi = Y.T.reshape((num_classes, 1, -1)) 90 | 91 | W = X.T 92 | compress_loss = self.compute_compress_loss(W, Pi) 93 | if not self.compress_only: 94 | discrimn_loss = self.compute_discrimn_loss(W) 95 | return discrimn_loss, compress_loss 96 | else: 97 | return None, compress_loss 98 | 99 | 100 | class BaseModule(LightningModule): 101 | def __init__(self, cfg): 102 | super().__init__() 103 | self.cfg = cfg 104 | 105 | self.train_dataset = NumpyTableDataset.setup( 106 | filepath_samples=cfg.get("filepath_samples"), 107 | num_clusters=cfg.get("num_clusters", None) 108 | ) 109 | self.val_dataset = self.train_dataset 110 | 111 | print(f"Dataset length: {self.train_dataset.__len__()}") 112 | self.cfg.input_dim = self.train_dataset.num_features() 113 | self.cfg.n_clusters = self.train_dataset.num_clusters 114 | self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size) 115 | 116 | self.save_hyperparameters() 117 | self.best_evaluation_stats = {} 118 | self.ae_train = False 119 | self.automatic_optimization = False 120 | self.best_accuracy = - np.inf 121 | self.gating_net = GatingNet(self.cfg) 122 | self.encdec = EncoderDecoder(self.cfg) 123 | self.clustering_head = clustering_head(self.cfg) 124 | self.aux_classifier_head = aux_classifier_head(self.cfg) 125 | self.mcrr = MaximalCodingRateReduction(eps=self.cfg.eps, compress_only=True) 126 | self.gtcr_loss = TotalCodingRateWithProjection(self.cfg) 127 | 128 | self.open_gates = [] 129 | self.val_embs_list = [] 130 | 131 | self.max_silhouette_score = [] 132 | self.min_dbi_score = [] 133 | 134 | def train_dataloader(self): 135 | return
DataLoader(self.train_dataset, 136 | batch_size=self.batch_size, 137 | drop_last=True, 138 | shuffle=True, 139 | num_workers=0) 140 | 141 | def val_dataloader(self): 142 | return DataLoader(self.val_dataset, 143 | batch_size=self.batch_size, 144 | drop_last=False, 145 | shuffle=False, 146 | num_workers=0) 147 | 148 | def global_gates_step(self, x): 149 | gates = self.gating_net.get_gates(x) 150 | ae_emb = self.encdec.encoder(x * gates) 151 | cluster_logits = self.clustering_head(ae_emb) 152 | y_hat = cluster_logits.argmax(dim=-1) 153 | glob_gates_mu, glob_gates = self.gating_net.global_forward(y_hat) 154 | reg_loss = self.gating_net.regularization(glob_gates_mu) 155 | aux_y_hat = self.aux_classifier_head(x * gates * glob_gates) 156 | aux_loss = F.cross_entropy(aux_y_hat, y_hat) 157 | self.log('glob_gates_reg_loss', reg_loss.item()) 158 | self.log('glob_gates_ce_loss', aux_loss.item()) 159 | return aux_loss + self.cfg.global_gates_lambda * reg_loss 160 | 161 | def ae_step(self, x): 162 | if self.current_epoch > self.cfg.ae_non_gated_epochs: 163 | mu, _, gates = self.gating_net(x) 164 | reg_loss = self.gating_net.regularization(mu) 165 | gtcr_loss = self.gtcr_loss(gates) / x.size(0) 166 | self.log("pretrain/gates_reg_loss", reg_loss.item()) 167 | self.log("pretrain/gates_tcr_loss", gtcr_loss.item()) 168 | loss = self.cosine_increase_lambda( 169 | min_val=0., 170 | max_val=self.cfg.local_gates_lambda 171 | ) * reg_loss + gtcr_loss * self.cfg.gtcr_lambda 172 | else: 173 | gates = torch.ones_like(x, device=x.device).float() 174 | loss = 0 175 | 176 | # task 1: reconstruct x from x 177 | x_recon = self.encdec(x) 178 | x_recon_loss = F.mse_loss(x_recon, x) 179 | self.log("pretrain/x_recon_loss", x_recon_loss.item()) 180 | 181 | # task 2: reconstruct x from gated x: 182 | x_recon_from_gated = self.encdec(x * gates) 183 | x_from_gated_x_recon_loss = F.mse_loss(x_recon_from_gated, x) 184 | self.log("pretrain/x_from_gated_x_recon_loss", x_from_gated_x_recon_loss.item()) 185 | 

        # task 3: reconstruct x from randomly masked x
        mask_rnd = torch.rand(x.size()).to(x.device)
        mask = torch.ones(x.size()).to(x.device).float()
        mask[mask_rnd < self.cfg.mask_percentage] = 0
        x_recon_masked = self.encdec(x * mask)
        input_noised_recon_loss = F.mse_loss(x_recon_masked, x)
        self.log("pretrain/input_noised_recon_loss", input_noised_recon_loss.item())

        # task 4: reconstruct x from noisy embedding
        e = self.encdec.encoder(x)
        e = e * torch.normal(mean=1., std=self.cfg.latent_noise_std, size=e.size(), device=e.device)
        recon_noised = self.encdec.decoder(e)
        noised_aug_loss = F.mse_loss(recon_noised, x)
        self.log("pretrain/latent_noised_recon_loss", noised_aug_loss.item())

        # combined loss:
        loss = loss + x_recon_loss + x_from_gated_x_recon_loss + input_noised_recon_loss + noised_aug_loss
        return loss

    def training_step(self, x, batch_idx):
        ae_opt, clust_opt, glob_gates_opt = self.optimizers()
        pretrain_sched, sch = self.lr_schedulers()
        x = x.reshape(x.size(0), -1)

        # reconstruction step + local gates training
        if self.current_epoch <= self.cfg.ae_pretrain_epochs:
            ae_opt.zero_grad()
            loss = self.ae_step(x)
            self.manual_backward(loss)
            ae_opt.step()
            pretrain_sched.step()
            return

        # clusters compression step
        clust_opt.zero_grad()
        gates = self.gating_net.get_gates(x)
        ae_emb = self.encdec.encoder(x * gates)
        cluster_logits = self.clustering_head(ae_emb)
        loss = self.mcrr_loss(ae_emb, cluster_logits)
        self.manual_backward(loss)
        clust_opt.step()

        # global gates training
        if self.current_epoch >= self.cfg.start_global_gates_training_on_epoch:
            glob_gates_opt.zero_grad()
            loss = self.global_gates_step(x)
            self.manual_backward(loss)
            glob_gates_opt.step()
        sch.step()

    def configure_optimizers(self):
        pretrain_optimizer = torch.optim.Adam(
            params=chain(
                self.encdec.parameters(),
                self.gating_net.local_gates.parameters(),
            ),
            lr=self.cfg.lr.pretrain)

        cluster_optimizer = torch.optim.Adam(
            params=chain(
                self.clustering_head.parameters(),
            ),
            lr=self.cfg.lr.clustering)

        glob_gates_opt = torch.optim.SGD(
            params=chain(
                self.aux_classifier_head.parameters(),
                self.gating_net.global_gates_net.parameters(),
            ),
            lr=self.cfg.lr.aux_classifier)

        steps = self.train_dataset.__len__() // self.batch_size * (
            self.cfg.trainer.max_epochs - self.cfg.ae_pretrain_epochs)
        pretrain_steps = self.train_dataset.__len__() // self.batch_size * self.cfg.ae_pretrain_epochs
        # pretrain_steps = self.dataset.__len__() // self.batch_size * self.cfg.trainer.max_epochs
        print(f"Cosine annealing LR scheduling is applied during {steps} steps")
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=cluster_optimizer,
            T_max=steps,
            eta_min=self.cfg.sched.clustering_min_lr)
        pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=pretrain_optimizer,
            T_max=pretrain_steps,
            eta_min=self.cfg.sched.pretrain_min_lr)
        return [pretrain_optimizer, cluster_optimizer, glob_gates_opt], [pretrain_sched, sched]

    def cosine_increase_lambda(self, min_val, max_val):
        epoch = self.current_epoch - self.cfg.ae_pretrain_epochs
        total_epochs = self.cfg.ae_pretrain_epochs - self.cfg.ae_non_gated_epochs
        return min_val + 0.5 * (max_val - min_val) * (1. + np.cos(epoch * math.pi / total_epochs))

    def validation_step(self, x, batch_idx):
        if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
            gates = self.gating_net.get_gates(x)
            ae_emb = self.encdec.encoder(x * gates)
            cluster_logits = self.clustering_head(ae_emb)
            y_hat = cluster_logits.argmax(dim=-1)
            self.val_cluster_list.append(y_hat.cpu())
            self.open_gates.append(self.gating_net.num_open_gates(x))
            self.val_embs_list.append(ae_emb)

    def on_validation_epoch_start(self):
        self.val_cluster_list = []
        self.open_gates = []
        self.val_embs_list = []

    @staticmethod
    def plot_clustering(val_embs_list, cluster_mtx, current_epoch, silhouette, dbi):
        reducer = umap.UMAP(n_neighbors=10, min_dist=0.1, n_components=2, random_state=0)
        embedding = reducer.fit_transform(torch.cat(val_embs_list, dim=0).cpu().numpy())
        plt.figure(figsize=(10, 7))
        plt.scatter(embedding[:, 0], embedding[:, 1], c=cluster_mtx.numpy(), s=50, edgecolor='k')
        plt.title(f'Clustering (UMAP). Epoch: {current_epoch}. Silhouette: {silhouette:0.3f}. DBI: {dbi:0.3f}')
        plt.savefig(f"umap_epoch_{current_epoch}.png")
        plt.close()  # release the figure so figures do not accumulate across epochs

    def on_validation_epoch_end(self):
        if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
            if self.current_epoch < self.cfg.ae_pretrain_epochs - 1:
                return
            else:
                cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
                self.log("num_open_gates", np.mean(self.open_gates).item())
                self.log("num_open_global_gates", self.gating_net.open_global_gates())
                if self.cfg.save_seed_checkpoints:
                    # the clustering network is stored as self.clustering_head
                    meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
                    torch.save(meta_dict, f'sparse_model_last_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')
                try:
                    silhouette_score_embs = silhouette_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
                                                             cluster_mtx.numpy())
                    self.log(f'silhouette_score_embs', silhouette_score_embs)
                    self.max_silhouette_score.append(silhouette_score_embs)
                except ValueError:  # raised when fewer than two clusters are predicted
                    silhouette_score_embs = -1
                try:
                    dbi_score = davies_bouldin_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
                                                     cluster_mtx.numpy())
                    self.log(f'dbi_score_embs', dbi_score)
                    self.min_dbi_score.append(dbi_score)
                except ValueError:  # raised when fewer than two clusters are predicted
                    dbi_score = 0

                self.plot_clustering(self.val_embs_list, cluster_mtx, self.current_epoch, silhouette_score_embs, dbi_score)

    def mcrr_loss(self, c, logits):
        logprobs = torch.log_softmax(logits, dim=-1)
        prob = GumbleSoftmax(self.tau())(logprobs)
        _, compress_loss = self.mcrr(F.normalize(c), prob, num_classes=self.cfg.n_clusters)
        compress_loss /= c.size(1)
        self.log(f'compress_loss', compress_loss.item())
        return compress_loss

    def tau(self):
        return self.cfg.tau


class GumbleSoftmax(torch.nn.Module):
    def __init__(self, tau, straight_through=False):
        super().__init__()
        self.tau = tau
        self.straight_through = straight_through

    def forward(self, logps):
        # Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1)
        gumble = torch.rand_like(logps).log().mul(-1).log().mul(-1)
        logits = logps + gumble
        out = (logits / self.tau).softmax(dim=1)
        if not self.straight_through:
            return out
        else:
            out_binary = (logits * 1e8).softmax(dim=1).detach()
            out_diff = (out_binary - out).detach()
            return out_diff + out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str)
    args = parser.parse_args()
    cfg = OmegaConf.load(args.cfg)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    for seed in range(cfg.seeds):
        cfg.seed = seed
        seed_everything(seed)
        np.random.seed(seed)
        model = BaseModule(cfg)
        logger = TensorBoardLogger("logs", name=os.path.basename(__file__), log_graph=False)
        trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
        trainer.logger = logger
        trainer.fit(model)

--------------------------------------------------------------------------------
/train_evaluate.py:
--------------------------------------------------------------------------------
from itertools import chain
import torch
import math
from omegaconf import OmegaConf
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule
import numpy as np
from pytorch_lightning import Trainer, seed_everything
import os
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, silhouette_score, davies_bouldin_score
import argparse
import dataset
from model import clustering_head, aux_classifier_head, EncoderDecoder, GatingNet


class TotalCodingRateWithProjection(torch.nn.Module):
    """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """
    def __init__(self, cfg):
        super().__init__()
        self.eps = cfg.gtcr_eps
        if cfg.gtcr_projection_dim is not None:
            self.random_matrix = torch.tensor(np.random.normal(
                loc=0.0,
                scale=1.0 / np.sqrt(cfg.gtcr_projection_dim),
                size=(cfg.input_dim, cfg.gtcr_projection_dim)
            )).float()
        else:
            self.random_matrix = None

    def compute_discrimn_loss(self, W):
        p, m = W.shape  # [d, B]
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def forward(self, x):
        if self.random_matrix is not None:
            x = x @ self.random_matrix.to(x.device)
        return - self.compute_discrimn_loss(x.T)


class MaximalCodingRateReduction(torch.nn.Module):
    """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """

    def __init__(self, eps=0.01, gamma=1, compress_only=False):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps
        self.gamma = gamma
        self.compress_only = compress_only

    def compute_discrimn_loss(self, W):
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss(self, W, Pi):
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device).expand((k, p, p))
        trPi = Pi.sum(2) + 1e-8
        scale = (p / (trPi * self.eps)).view(k, 1, 1)
        W = W.view((1, p, m))
        log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
        compress_loss = (trPi.squeeze() * log_det / (2 * m)).sum()
        return compress_loss

    def forward(self, X, Y, num_classes=None):
        # This function supports Y as integer labels or as membership probabilities.
        if len(Y.shape) == 1:
            # if Y is a label vector
            if num_classes is None:
                num_classes = int(Y.max()) + 1
            Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device)
            for indx, label in enumerate(Y):
                Pi[label, 0, indx] = 1
        else:
            # if Y is a probability matrix
            if num_classes is None:
                num_classes = Y.shape[1]
            Pi = Y.T.reshape((num_classes, 1, -1))

        W = X.T
        compress_loss = self.compute_compress_loss(W, Pi)
        if not self.compress_only:
            discrimn_loss = self.compute_discrimn_loss(W)
            return discrimn_loss, compress_loss
        else:
            return None, compress_loss


class BaseModule(LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.train_dataset = getattr(dataset, cfg.dataset).setup(cfg)
        self.val_dataset = self.train_dataset

        print(f"Dataset length: {self.train_dataset.__len__()}")
        self.cfg.input_dim = self.train_dataset.num_features()
        self.cfg.n_clusters = self.train_dataset.num_clusters
        self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size)

        self.save_hyperparameters()
        self.best_evaluation_stats = {}
        self.ae_train = False
        self.automatic_optimization = False
        self.best_accuracy = - np.inf  # np.infty was removed in NumPy 2.0
        self.gating_net = GatingNet(self.cfg)
        self.encdec = EncoderDecoder(self.cfg)
        self.clustering_head = clustering_head(self.cfg)
        self.aux_classifier_head = aux_classifier_head(self.cfg)
        self.mcrr = MaximalCodingRateReduction(eps=self.cfg.eps, compress_only=True)
        self.gtcr_loss = TotalCodingRateWithProjection(self.cfg)

        self.val_cluster_list = []
        self.val_cluster_list_gated = []
        self.val_label_list = []
        self.open_gates = []
        self.val_embs_list = []

        self.best_acc = - 100
        self.best_ari = - 100
        self.best_nmi = - 100
        self.best_local_feats = None
        self.best_global_feats = None
        self.max_silhouette_score = []
        self.min_dbi_score = []

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          drop_last=True,
                          shuffle=True,
                          num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          drop_last=False,
                          shuffle=False,
                          num_workers=0)

    def update_stats(self, acc, ari, nmi, local_feats, global_feats):
        if self.best_acc <= acc:
            self.best_acc = acc
            self.best_ari = ari
            self.best_nmi = nmi
            self.best_local_feats = local_feats
            self.best_global_feats = global_feats

    def global_gates_step(self, x):
        gates = self.gating_net.get_gates(x)
        ae_emb = self.encdec.encoder(x * gates)
        cluster_logits = self.clustering_head(ae_emb)
        y_hat = cluster_logits.argmax(dim=-1)
        glob_gates_mu, glob_gates = self.gating_net.global_forward(y_hat)
        reg_loss = self.gating_net.regularization(glob_gates_mu)
        aux_y_hat = self.aux_classifier_head(x * gates * glob_gates)
        aux_loss = F.cross_entropy(aux_y_hat, y_hat)
        self.log('train/glob_gates_reg_loss', reg_loss.item())
        self.log('train/glob_gates_ce_loss', aux_loss.item())
        return aux_loss + self.cfg.global_gates_lambda * reg_loss

    def ae_step(self, x):
        if self.current_epoch > self.cfg.ae_non_gated_epochs:
            mu, _, gates = self.gating_net(x)
            reg_loss = self.gating_net.regularization(mu)
            gtcr_loss = self.gtcr_loss(gates) / x.size(0)
            self.log("pretrain/gates_reg_loss", reg_loss.item())
            self.log("pretrain/gates_tcr_loss", gtcr_loss.item())
            loss = self.cosine_increase_lambda(
                min_val=0.,
                max_val=self.cfg.local_gates_lambda
            ) * reg_loss + gtcr_loss * self.cfg.gtcr_lambda
        else:
            gates = torch.ones_like(x, device=x.device).float()
            loss = 0

        # task 1: reconstruct x from x
        x_recon = self.encdec(x)
        x_recon_loss = F.mse_loss(x_recon, x)
        self.log("pretrain/x_recon_loss", x_recon_loss.item())

        # task 2: reconstruct x from gated x:
        x_recon_from_gated = self.encdec(x * gates)
        x_from_gated_x_recon_loss = F.mse_loss(x_recon_from_gated, x)
        self.log("pretrain/x_from_gated_x_recon_loss", x_from_gated_x_recon_loss.item())

        # task 3: reconstruct x from randomly masked x
        mask_rnd = torch.rand(x.size()).to(x.device)
        mask = torch.ones(x.size()).to(x.device).float()
        mask[mask_rnd < self.cfg.mask_percentage] = 0
        x_recon_masked = self.encdec(x * mask)
        input_noised_recon_loss = F.mse_loss(x_recon_masked, x)
        self.log("pretrain/input_noised_recon_loss", input_noised_recon_loss.item())

        # task 4: reconstruct x from noisy embedding
        e = self.encdec.encoder(x)
        e = e * torch.normal(mean=1., std=self.cfg.latent_noise_std, size=e.size(), device=e.device)
        recon_noised = self.encdec.decoder(e)
        noised_aug_loss = F.mse_loss(recon_noised, x)
        self.log("pretrain/latent_noised_recon_loss", noised_aug_loss.item())

        # combined loss:
        loss = loss + x_recon_loss + x_from_gated_x_recon_loss + input_noised_recon_loss + noised_aug_loss
        return loss

    def training_step(self, batch, batch_idx):
        ae_opt, clust_opt, glob_gates_opt = self.optimizers()
        pretrain_sched, sch = self.lr_schedulers()
        x, _ = batch
        x = x.reshape(x.size(0), -1)

        # reconstruction step + local gates training
        if self.current_epoch <= self.cfg.ae_pretrain_epochs:
            ae_opt.zero_grad()
            loss = self.ae_step(x)
            self.manual_backward(loss)
            ae_opt.step()
            pretrain_sched.step()
            return

        # clusters compression step
        clust_opt.zero_grad()
        gates = self.gating_net.get_gates(x)
        ae_emb = self.encdec.encoder(x * gates)
        cluster_logits = self.clustering_head(ae_emb)
        loss = self.mcrr_loss(ae_emb, cluster_logits)
        self.manual_backward(loss)
        clust_opt.step()

        # global gates training
        if self.current_epoch >= self.cfg.start_global_gates_training_on_epoch:
            glob_gates_opt.zero_grad()
            loss = self.global_gates_step(x)
            self.manual_backward(loss)
            glob_gates_opt.step()
        sch.step()

    def configure_optimizers(self):
        pretrain_optimizer = torch.optim.Adam(
            params=chain(
                self.encdec.parameters(),
                self.gating_net.local_gates.parameters(),
            ),
            lr=self.cfg.lr.pretrain)

        cluster_optimizer = torch.optim.Adam(
            params=chain(
                self.clustering_head.parameters(),
            ),
            lr=self.cfg.lr.clustering)

        glob_gates_opt = torch.optim.SGD(
            params=chain(
                self.aux_classifier_head.parameters(),
                self.gating_net.global_gates_net.parameters(),
            ),
            lr=self.cfg.lr.aux_classifier)

        steps = self.train_dataset.__len__() // self.batch_size * (
            self.cfg.trainer.max_epochs - self.cfg.ae_pretrain_epochs)
        pretrain_steps = self.train_dataset.__len__() // self.batch_size * self.cfg.ae_pretrain_epochs
        # pretrain_steps = self.dataset.__len__() // self.batch_size * self.cfg.trainer.max_epochs
        print(f"Cosine annealing LR scheduling is applied during {steps} steps")
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=cluster_optimizer,
            T_max=steps,
            eta_min=self.cfg.sched.clustering_min_lr)
        pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=pretrain_optimizer,
            T_max=pretrain_steps,
            eta_min=self.cfg.sched.pretrain_min_lr)
        return [pretrain_optimizer, cluster_optimizer, glob_gates_opt], [pretrain_sched, sched]

    def cosine_increase_lambda(self, min_val, max_val):
        epoch = self.current_epoch - self.cfg.ae_pretrain_epochs
        total_epochs = self.cfg.ae_pretrain_epochs - self.cfg.ae_non_gated_epochs
        return min_val + 0.5 * (max_val - min_val) * (1. + np.cos(epoch * math.pi / total_epochs))

    def validation_step(self, batch, batch_idx):
        x, y = batch
        gates = self.gating_net.get_gates(x)
        ae_emb = self.encdec.encoder(x * gates)
        cluster_logits = self.clustering_head(ae_emb)
        y_hat = cluster_logits.argmax(dim=-1)
        self.val_cluster_list.append(y_hat.cpu())
        self.val_label_list.append(y.cpu())
        self.open_gates.append(self.gating_net.num_open_gates(x))
        self.val_embs_list.append(ae_emb)

    def on_validation_epoch_start(self):
        self.val_cluster_list = []
        self.val_cluster_list_gated = []
        self.val_label_list = []
        self.open_gates = []
        self.val_embs_list = []

    @staticmethod
    def cluster_match(cluster_mtx, label_mtx, n_classes=10, print_result=True):
        cluster_indx = list(cluster_mtx.unique())
        assigned_label_list = []
        assigned_count = []
        while (len(assigned_label_list) <= n_classes) and len(cluster_indx) > 0:
            max_label_list = []
            max_count_list = []
            for indx in cluster_indx:
                mask = cluster_mtx == indx
                label_elements, counts = label_mtx[mask].unique(return_counts=True)
                for assigned_label in assigned_label_list:
                    counts[label_elements == assigned_label] = 0
                max_count_list.append(counts.max())
                max_label_list.append(label_elements[counts.argmax()])

            max_label = torch.stack(max_label_list)
            max_count = torch.stack(max_count_list)
            assigned_label_list.append(max_label[max_count.argmax()])
            assigned_count.append(max_count.max())
            cluster_indx.pop(max_count.argmax().item())
        total_correct = torch.tensor(assigned_count).sum().item()
        total_sample = cluster_mtx.shape[0]
        acc = total_correct / total_sample
        if print_result:
            print('{}/{} ({}%) correct'.format(total_correct, total_sample, acc * 100))
        # always return the counts so callers can unpack them whether or not they print
        return total_correct, total_sample, acc

    def on_validation_epoch_end(self):
        """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/func.py"""
        if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
            if self.current_epoch < self.cfg.ae_pretrain_epochs - 1:
                return
            else:
                cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
                label_mtx = torch.cat(self.val_label_list, dim=0)
                _, _, acc_single = self.cluster_match(
                    cluster_mtx,
                    label_mtx,
                    n_classes=label_mtx.max() + 1,
                    print_result=False)
                if self.best_accuracy < acc_single:
                    print("New best accuracy:", acc_single)
                    self.best_accuracy = acc_single
                    if self.cfg.save_seed_checkpoints:
                        # the clustering network is stored as self.clustering_head
                        meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
                        torch.save(meta_dict, f'sparse_model_best_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')

                nmi = normalized_mutual_info_score(label_mtx.numpy(), cluster_mtx.numpy())
                ari = adjusted_rand_score(label_mtx.numpy(), cluster_mtx.numpy())
                format_str = ''  # '_kmeans' if self.current_epoch == 9 else ''
                self.log(f'val/acc_single{format_str}', acc_single)  # this is ACC
                self.log(f'val/NMI{format_str}', nmi)
                self.log(f'val/ARI{format_str}', ari)
                self.log("val/num_open_gates", np.mean(self.open_gates).item())
                self.log("val/num_open_global_gates", self.gating_net.open_global_gates())
                if self.cfg.save_seed_checkpoints:
                    meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
                    torch.save(meta_dict, f'sparse_model_last_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')

                self.update_stats(acc_single, ari, nmi, np.mean(self.open_gates).item(),
                                  self.gating_net.open_global_gates())

                try:
                    silhouette_score_embs = silhouette_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
                                                             cluster_mtx.numpy())
                    self.log(f'val/silhouette_score_embs', silhouette_score_embs)
                    self.max_silhouette_score.append(silhouette_score_embs)
                except ValueError:  # raised when fewer than two clusters are predicted
                    pass
                try:
                    dbi_score = davies_bouldin_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
                                                     cluster_mtx.numpy())
                    self.log(f'val/dbi_score_embs', dbi_score)
                    self.min_dbi_score.append(dbi_score)
                except ValueError:  # raised when fewer than two clusters are predicted
                    pass

    def mcrr_loss(self, c, logits):
        logprobs = torch.log_softmax(logits, dim=-1)
        prob = GumbleSoftmax(self.tau())(logprobs)
        _, compress_loss = self.mcrr(F.normalize(c), prob, num_classes=self.cfg.n_clusters)
        compress_loss /= c.size(1)
        self.log(f'train/compress_loss', compress_loss.item())
        return compress_loss

    def tau(self):
        return self.cfg.tau


class GumbleSoftmax(torch.nn.Module):
    def __init__(self, tau, straight_through=False):
        super().__init__()
        self.tau = tau
        self.straight_through = straight_through

    def forward(self, logps):
        # Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1)
        gumble = torch.rand_like(logps).log().mul(-1).log().mul(-1)
        logits = logps + gumble
        out = (logits / self.tau).softmax(dim=1)
        if not self.straight_through:
            return out
        else:
            out_binary = (logits * 1e8).softmax(dim=1).detach()
            out_diff = (out_binary - out).detach()
            return out_diff + out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str)
    args = parser.parse_args()
    cfg = OmegaConf.load(args.cfg)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if not cfg.validate:
        cfg.trainer.check_val_every_n_epoch = cfg.trainer.max_epochs + 1  # validation will never run

    with open(f"results_{os.path.basename(__file__)}.txt", mode='a') as f:
        header = '\t'.join(['seed', 'acc', 'ari', 'nmi', 'local_gates', 'global_gates',
                            'topk_max_silhouette_score', 'topk_min_dbi_score'])
        f.write(f"{header}\n")
        f.flush()
    for seed in range(cfg.seeds):
        cfg.seed = seed
        seed_everything(seed)
        np.random.seed(seed)
        os.makedirs(cfg.dataset, exist_ok=True)
        model = BaseModule(cfg)
        logger = TensorBoardLogger(cfg.dataset, name=os.path.basename(__file__), log_graph=False)
        trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
        trainer.logger = logger
        trainer.fit(model)
        # mean of the 10 best per-epoch scores: highest silhouette, lowest Davies-Bouldin index
        topk_max_silhouette_score = np.mean(sorted(model.max_silhouette_score, reverse=True)[:10])
        topk_min_dbi_score = np.mean(sorted(model.min_dbi_score)[:10])
        results_str = '\t'.join(
            [f'{seed}',
             f'{model.best_acc}',
             f'{model.best_ari}',
             f'{model.best_nmi}',
             f'{model.best_local_feats}',
             f'{model.best_global_feats}',
             f'{topk_max_silhouette_score}',
             f'{topk_min_dbi_score}',
             ])
        with open(f"results_{os.path.basename(__file__)}.txt", mode='a') as f:
            f.write(f"{results_str}\n")
            f.flush()

--------------------------------------------------------------------------------
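`cosine_increase_lambda` is only called from `ae_step`. At that point `current_epoch` lies in (`ae_non_gated_epochs`, `ae_pretrain_epochs`], so the offset `current_epoch - ae_pretrain_epochs` is never positive and the cosine argument moves from -π up to 0. As a result, the weight on the local-gate regularizer ramps *up* from `min_val` to `max_val` instead of decaying.

The standalone sketch below restates that formula with plain `math` so the ramp can be checked outside Lightning. The free-function signature and the epoch counts are illustrative assumptions, not part of the repository:

```python
import math


def cosine_increase_lambda(current_epoch, ae_pretrain_epochs, ae_non_gated_epochs, min_val, max_val):
    """Standalone restatement of BaseModule.cosine_increase_lambda (illustrative only)."""
    # Non-positive during pretraining, so the cosine argument runs from -pi up to 0
    epoch = current_epoch - ae_pretrain_epochs
    total_epochs = ae_pretrain_epochs - ae_non_gated_epochs
    return min_val + 0.5 * (max_val - min_val) * (1.0 + math.cos(epoch * math.pi / total_epochs))


if __name__ == "__main__":
    # Hypothetical config: ae_non_gated_epochs=10, ae_pretrain_epochs=50, local_gates_lambda=1.0
    for e in (10, 30, 50):
        print(e, round(cosine_increase_lambda(e, 50, 10, 0.0, 1.0), 3))
```

Because cosine is increasing on [-π, 0], the weight rises monotonically and reaches `max_val` exactly at the last pretraining epoch.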
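The ACC logged as `val/acc_single` comes from `cluster_match`. That function repeatedly picks, among the unmatched clusters, the one whose most frequent still-unclaimed ground-truth label has the highest count. It assigns that label to the cluster and sums the matched counts.

Below is a simplified pure-Python analogue of that greedy one-to-one matching. `greedy_cluster_accuracy` is a hypothetical helper, not part of the repository, and its tie-breaking may differ from the torch version:

```python
from collections import Counter


def greedy_cluster_accuracy(clusters, labels):
    """Greedy one-to-one cluster-to-label matching, a simplified analogue of cluster_match."""
    remaining = sorted(set(clusters))
    used_labels = set()
    correct = 0
    while remaining:
        best = None  # (count, cluster, label) of the strongest remaining match
        for c in remaining:
            # Count labels inside cluster c, ignoring labels already claimed by another cluster
            counts = Counter(lab for cl, lab in zip(clusters, labels) if cl == c and lab not in used_labels)
            label, count = counts.most_common(1)[0] if counts else (None, 0)
            if best is None or count > best[0]:
                best = (count, c, label)
        count, c, label = best
        remaining.remove(c)
        if label is not None:
            used_labels.add(label)
        correct += count
    return correct / len(clusters)


if __name__ == "__main__":
    print(greedy_cluster_accuracy([0, 0, 0, 1], [0, 0, 1, 1]))
```

A greedy match is cheap but is not guaranteed to find the optimal assignment. An optimal (Hungarian) matching can therefore report a higher ACC on the same predictions.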
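`mcrr_loss` turns the clustering head's log-probabilities into soft cluster memberships with `GumbleSoftmax`, a Gumbel-softmax relaxation. Gumbel noise `-log(-log(U))` is added to the log-probabilities, and the result goes through a softmax with temperature `tau`.

The plain-Python sketch below reproduces that sampling step for a single sample. The function names and the clamp on `U` are assumptions made for this illustration. Because the argmax of the noisy scores is an exact categorical sample (the Gumbel-max trick), the frequency of the winning index should approach the input probabilities:

```python
import math
import random


def gumbel_noise(rng):
    # U ~ Uniform(0, 1), clamped away from 0 and 1 so both logarithms stay finite
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def gumbel_softmax_sample(log_probs, tau, rng):
    """Relaxed one-hot sample: softmax((log_probs + g) / tau) with Gumbel noise g."""
    scores = [(lp + gumbel_noise(rng)) / tau for lp in log_probs]
    top = max(scores)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    rng = random.Random(0)
    log_p = [math.log(p) for p in (0.7, 0.2, 0.1)]
    print(gumbel_softmax_sample(log_p, tau=0.5, rng=rng))
    # Gumbel-max trick: the argmax of the noisy scores follows the input distribution
    wins = [0, 0, 0]
    for _ in range(5000):
        sample = gumbel_softmax_sample(log_p, tau=0.5, rng=rng)
        wins[sample.index(max(sample))] += 1
    print([w / 5000 for w in wins])
```

Lowering `tau` pushes each sample closer to one-hot. The repository's `straight_through=True` branch goes further: it returns a hard one-hot in the forward pass while keeping the soft gradient.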