├── .gitignore
├── README.md
├── calc_NUDFT.ipynb
├── cfg
│   ├── cfg_mnist.yaml
│   └── cfg_run.yaml
├── data
│   ├── pbmc_x.zip
│   └── pbmc_y.zip
├── dataset.py
├── idc_evaluate.ipynb
├── idc_example.ipynb
├── img
│   ├── freq_bias.png
│   ├── img.png
│   ├── nudft_ALLAML.png
│   ├── pbmc.gif
│   ├── supervised_train_plots.png
│   └── supervised_train_plots2.png
├── interpretability_metrics.py
├── lspin_pbmc.py
├── model.py
├── requirements.txt
├── run.py
└── train_evaluate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | # Interpretable Deep Clustering
7 |
8 |
9 |
10 |
11 |
12 | ## An official implementation of the ICML 2024 accepted paper: [Interpretable Deep Clustering for Tabular Data](https://openreview.net/pdf?id=QPy7zLfvof)
13 |
14 |
15 |
16 |
17 | ## UPDATES:
18 |
19 | - 2024-10-08:
20 |   * We added an example notebook that produces a NUDFT plot and compares gated vs. non-gated supervised models.
21 |   * We present below a corrected Figure 4 from the paper: the frequency values are normalized for each model, and gated features are used for the IDC model rather than raw features.
22 |
23 |
24 | ## How to run on your data:
25 |
26 | 1. Do you have a dataset without labels? Use our colab example notebook:
27 | 2. If you have a labeled dataset, please follow the colab with evaluation example:
28 |
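For readers preparing their own data: `NumpyTableDataset.setup` in `dataset.py` reads the samples from an `.npz` file under the key `arr_0`, which is the key `np.savez` assigns to the first positional array. The following is a minimal sketch of writing and re-reading such a file; the toy array and the temporary file name `my_samples.npz` are illustrative assumptions, not part of the repository.

```python
import os
import tempfile

import numpy as np

# Hypothetical toy table: 6 samples x 4 features (illustration only).
X = np.arange(24, dtype=np.float64).reshape(6, 4)

# np.savez stores positional arrays under the keys arr_0, arr_1, ...
path = os.path.join(tempfile.mkdtemp(), "my_samples.npz")
np.savez(path, X)

# Read the array back the same way dataset.py does.
with np.load(path) as data:
    loaded = data["arr_0"]

print(loaded.shape)
```

A file written this way has the N x D layout (samples x features) that the `setup` docstring describes; `filepath_samples` is the key `cfg/cfg_run.yaml` uses for the samples path.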
29 |
30 | ## Fixed Figure 4. Spectral properties of the learned predictive function using ALLAML dataset.
31 |
32 | The model trained with the gating network (IDC) has higher Fourier amplitudes at all frequency levels than both
33 | the model without gates (IDCw/o_gates) and the baseline (TELL). This suggests that IDC is better suited to the inductive bias of tabular data.
34 |
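To make "Fourier amplitudes" and the per-model normalization concrete, here is a small sketch that evaluates a non-uniform discrete Fourier transform directly, in O(N·K) time, and scales the amplitudes by their maximum. It is not the `nfft_adjoint` call used in `calc_NUDFT.ipynb`: the sign and scaling conventions, the function names `nudft_amplitudes` and `normalize_by_max`, and the toy data are all assumptions made for illustration.

```python
import cmath
import math


def nudft_amplitudes(x, y, freqs):
    """Return |F(k)| for each k, with F(k) = (1/N) * sum_j y_j * exp(-2*pi*i*k*x_j)."""
    n = len(x)
    amps = []
    for k in freqs:
        total = sum(yj * cmath.exp(-2j * math.pi * k * xj) for xj, yj in zip(x, y))
        amps.append(abs(total) / n)
    return amps


def normalize_by_max(values):
    """Scale the values so that the largest one equals 1."""
    peak = max(values)
    return [v / peak for v in values]


# Hypothetical toy data: non-uniformly spaced feature values and a target
# that oscillates along that feature.
x = [0.00, 0.07, 0.19, 0.26, 0.41, 0.48, 0.63, 0.77, 0.85, 0.98]
y = [math.cos(2 * math.pi * 3 * xi) for xi in x]

freqs = [1, 2, 3, 4, 5]
spectrum = normalize_by_max(nudft_amplitudes(x, y, freqs))
print(spectrum)
```

After normalization, spectra from different models share a common scale (the maximum is 1), so the curves can be compared by shape across frequencies rather than by raw magnitude.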
35 |
36 |
37 | ### Citation:
38 | Please cite our paper if you use this code:
39 |
40 |
41 | ```
42 |
43 | @InProceedings{pmlr-v235-svirsky24a,
44 | title = {Interpretable Deep Clustering for Tabular Data},
45 | author = {Svirsky, Jonathan and Lindenbaum, Ofir},
46 | booktitle = {Proceedings of the 41st International Conference on Machine Learning},
47 | pages = {47314--47330},
48 | year = {2024},
49 | editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
50 | volume = {235},
51 | series = {Proceedings of Machine Learning Research},
52 | month = {21--27 Jul},
53 | publisher = {PMLR},
54 | pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/svirsky24a/svirsky24a.pdf},
55 | url = {https://proceedings.mlr.press/v235/svirsky24a.html},
56 | abstract = {Clustering is a fundamental learning task widely used as a first step in data analysis. For example, biologists use cluster assignments to analyze genome sequences, medical records, or images. Since downstream analysis is typically performed at the cluster level, practitioners seek reliable and interpretable clustering models. We propose a new deep-learning framework for general domain tabular data that predicts interpretable cluster assignments at the instance and cluster levels. First, we present a self-supervised procedure to identify the subset of the most informative features from each data point. Then, we design a model that predicts cluster assignments and a gate matrix that provides cluster-level feature selection. Overall, our model provides cluster assignments with an indication of the driving feature for each sample and each cluster. We show that the proposed method can reliably predict cluster assignments in biological, text, image, and physics tabular datasets. Furthermore, using previously proposed metrics, we verify that our model leads to interpretable results at a sample and cluster level. Our code is available on https://github.com/jsvir/idc.}
57 | }
58 | ```
59 |
60 |
61 |
62 | ### TODO list:
63 | - [x] Add interpretability evaluation scripts
64 | - [ ] Add experiments configs
65 | - [ ] Add features index outputs
66 | - [x] Add synthetic dataset generation code
--------------------------------------------------------------------------------
/calc_NUDFT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# Calculation of NUDFT \n",
7 | "\n",
8 | "We provide an example for a pre-trained supervised model, [LSPIN](https://proceedings.mlr.press/v162/yang22i/yang22i.pdf), trained on the PBMC dataset and compared against a deep classifier.\n",
9 | "\n",
10 | "We provide the checkpoints in the `ckpts` directory.\n"
11 | ],
12 | "metadata": {
13 | "collapsed": false
14 | },
15 | "id": "9e52f756a2211364"
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "source": [],
20 | "metadata": {
21 | "collapsed": false
22 | },
23 | "id": "e40cd92806e61978"
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "source": [
28 | "### Imports and util functions\n"
29 | ],
30 | "metadata": {
31 | "collapsed": false
32 | },
33 | "id": "9dbae214166550cd"
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 1,
38 | "outputs": [],
39 | "source": [
40 | "import platform\n",
41 | "from nfft import nfft_adjoint\n",
42 | "import torch\n",
43 | "import numpy as np\n",
44 | "from matplotlib import pyplot as plt\n",
45 | "from omegaconf import OmegaConf\n",
46 | "import seaborn as sns\n",
47 | "from lspin_pbmc import PBMC, Classifier, GatingNet\n",
48 | "\n",
49 | "# which dimension of the y_hat outputs (normalized with softmax) to use\n",
50 | "TARGET_IDX = 1\n",
51 | "COLUMNS_SUBSET = 100000\n",
52 | "\n",
53 | "\n",
54 | "def spectrum_NUDFT(x, y, kmax=50, nk=1000):\n",
55 | "    kvals = np.linspace(0.1, kmax, nk + 1)\n",
56 | "    nufft = (1 / len(x)) * nfft_adjoint(-(x * kmax / nk), y, 2 * (nk + 1))[nk + 1:]\n",
57 | "    return [kvals, np.array(nufft, dtype=complex)]\n",
58 | "\n",
59 | "\n",
60 | "def select_k_columns_with_max_variance(matrix):\n",
61 | "    variance_per_column = torch.var(matrix, dim=0)\n",
62 | "    _, sorted_indices = torch.sort(variance_per_column, descending=True)\n",
63 | "    top_k_indices = sorted_indices[:COLUMNS_SUBSET]\n",
64 | "    selected_columns = matrix[:, top_k_indices]\n",
65 | "    return selected_columns"
66 | ],
67 | "metadata": {
68 | "collapsed": false,
69 | "ExecuteTime": {
70 | "end_time": "2024-10-08T12:57:48.715663Z",
71 | "start_time": "2024-10-08T12:57:44.858113500Z"
72 | }
73 | },
74 | "id": "e1c3233b55927f30"
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "source": [
79 | "### Load pre-trained checkpoints"
80 | ],
81 | "metadata": {
82 | "collapsed": false
83 | },
84 | "id": "697ba1c508082d"
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 2,
89 | "outputs": [],
90 | "source": [
91 | "gated_cfg = OmegaConf.create(dict(\n",
92 | " gated=True,\n",
93 | " input_dim=17126,\n",
94 | " n_clusters=2,\n",
95 | " dataset=\"PBMC\",\n",
96 | " data_dir=\"C:/data/fs/pbmc\" if platform.system() == \"Windows\" else \".\" ,\n",
97 | " batch_size=256,\n",
98 | " repitions=5,\n",
99 | " sigma=0.5,\n",
100 | " reg_beta=100,\n",
101 | " devices=1,\n",
102 | " accelerator=\"gpu\",\n",
103 | " max_epochs=100,\n",
104 | " deterministic=True,\n",
105 | " logger=True,\n",
106 | " log_every_n_steps=10,\n",
107 | " check_val_every_n_epoch=1,\n",
108 | " enable_checkpointing=False,\n",
109 | "))\n",
110 | "\n",
111 | "nongated_cfg = OmegaConf.create(dict(\n",
112 | " gated=False,\n",
113 | " input_dim=17126,\n",
114 | " n_clusters=2,\n",
115 | " dataset=\"PBMC\",\n",
116 | " data_dir=\"C:/data/fs/pbmc\" if platform.system() == \"Windows\" else \".\" ,\n",
117 | " batch_size=256,\n",
118 | " repitions=5,\n",
119 | " devices=1,\n",
120 | " accelerator=\"gpu\",\n",
121 | " max_epochs=100,\n",
122 | " deterministic=True,\n",
123 | " logger=True,\n",
124 | " log_every_n_steps=10,\n",
125 | " check_val_every_n_epoch=1,\n",
126 | " enable_checkpointing=False,\n",
127 | "))\n",
128 | "\n",
129 | "\n"
130 | ],
131 | "metadata": {
132 | "collapsed": false,
133 | "ExecuteTime": {
134 | "end_time": "2024-10-08T12:57:48.730048700Z",
135 | "start_time": "2024-10-08T12:57:48.711662700Z"
136 | }
137 | },
138 | "id": "545756af156b20df"
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "source": [
143 | "The models were trained for ~400 epochs and we present the training/validation plots:\n",
144 | "\n",
145 | " \n",
146 | " \n"
147 | ],
148 | "metadata": {
149 | "collapsed": false
150 | },
151 | "id": "101afcef26dfe116"
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 8,
156 | "outputs": [
157 | {
158 | "name": "stdout",
159 | "output_type": "stream",
160 | "text": [
161 | "Dataset PBMC stats:\n",
162 | "X.shape: (20742, 17126)\n",
163 | "Y.shape: (20742,)\n",
164 | "X.min=-2.734075760955218, X.max=144.01736006468113\n",
165 | "Y.min=0, Y.max=1\n",
166 | "Label 0 has 10479 samples\n",
167 | "Label 1 has 10263 samples\n",
168 | "Split to train/test: train 16594 test 4148\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "gating_net = GatingNet(gated_cfg)\n",
174 | "classifier_gated = Classifier(gated_cfg)\n",
175 | "gating_net.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_last_pbmc_beta_10_seed_0.pth\")[\"gating\"])\n",
176 | "classifier_gated.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_last_pbmc_beta_10_seed_0.pth\")[\"clustering\"])\n",
177 | "\n",
178 | "# # load clustering model without gates:\n",
179 | "classifier = Classifier(nongated_cfg)\n",
180 | "classifier.load_state_dict(torch.load(\"ckpts/supervised/sparse_model_nogates_best_pbmc_seed_0.pth\")[\"clustering\"])\n",
181 | "\n",
182 | "classifier = classifier.to('cpu')\n",
183 | "classifier_gated = classifier_gated.to('cpu')\n",
184 | "\n",
185 | "classifier.eval()\n",
186 | "classifier_gated.eval()\n",
187 | "gating_net.eval()\n",
188 | "\n",
189 | "_, test_dataset = PBMC.setup(gated_cfg.data_dir)\n",
190 | "\n"
191 | ],
192 | "metadata": {
193 | "collapsed": false,
194 | "ExecuteTime": {
195 | "end_time": "2024-10-08T13:17:14.874885100Z",
196 | "start_time": "2024-10-08T13:16:56.649883800Z"
197 | }
198 | },
199 | "id": "a055dcc6d21c43e6"
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 10,
204 | "outputs": [
205 | {
206 | "data": {
207 | "text/plain": ""
208 | },
209 | "metadata": {},
210 | "output_type": "display_data"
211 | },
212 | {
213 | "data": {
214 | "text/plain": "",
215 | "image/png": ""
216 | },
217 | "metadata": {},
218 | "output_type": "display_data"
219 | }
220 | ],
221 | "source": [
222 | "plt.clf()\n",
223 | "plt.figure(figsize=(10, 6))\n",
224 | "plt.yscale('log') # Set the y-axis to logarithmic scale\n",
225 | "plt.xscale('log')\n",
226 | "plt.xticks(fontsize=20)\n",
227 | "plt.yticks(fontsize=20)\n",
228 | "plt.xlim((0.1, 100))\n",
229 | "plt.xlabel('$|k|$', fontsize=20)\n",
230 | "plt.ylabel('$P_{f(x)}(k)$', fontsize=20)\n",
231 | "plt.grid(False)\n",
232 | "\n",
233 | "with torch.no_grad():\n",
234 | "    x_tensor = torch.tensor(test_dataset.data).float()\n",
235 | "    filtered_dataset = select_k_columns_with_max_variance(x_tensor).cpu().numpy()\n",
236 | "    normalized_dataset = filtered_dataset\n",
237 | "    gates = gating_net.get_gates(x_tensor)\n",
238 | "\n",
239 | "    spectrum_nudft_all = []\n",
240 | "    for feat_id in range(filtered_dataset.shape[-1]):\n",
241 | "        spectrum_nudft_all.append(spectrum_NUDFT(filtered_dataset[:, feat_id], test_dataset.targets))\n",
242 | "    spectra = [np.sqrt(np.abs(t[1] ** 2)).reshape(-1, 1) for t in spectrum_nudft_all]\n",
243 | "    k = spectrum_nudft_all[0][0]\n",
244 | "    k_repeated = np.concatenate([k] * len(spectra), axis=0).reshape(-1)\n",
245 | "\n",
246 | "    # gated data, gated classifier predictions:\n",
247 | "    e = classifier_gated.encoder(x_tensor * gates)\n",
248 | "    y_hat = torch.softmax(classifier_gated.head(e), dim=1)[:, TARGET_IDX].numpy()\n",
249 | "    classifier_gated_spectrum_nudft_all = []\n",
250 | "    non_zero_ids = torch.nonzero(gates.sum(dim=0) > 0, as_tuple=True)[0].long()\n",
251 | "    max_spectrum_val = - np.inf\n",
252 | "    max_spectrum_gated_val = - np.inf\n",
253 | "\n",
254 | "    # raw data, classifier predictions:\n",
255 | "    e_raw = classifier.encoder(x_tensor)\n",
256 | "    y_hat_raw = torch.softmax(classifier.head(e_raw), dim=1)[:, TARGET_IDX].numpy()\n",
257 | "    classifier_spectrum_nudft_all = []\n",
258 | "    normalized_gated_dataset = normalized_dataset * gates.numpy()\n",
259 | "    for feat_id in range(normalized_dataset.shape[-1]):\n",
260 | "        feat_spectrum = spectrum_NUDFT(normalized_gated_dataset[:, feat_id], y_hat)\n",
261 | "        max_spectrum_gated_val = max(max_spectrum_gated_val, np.abs(feat_spectrum[1]).max())\n",
262 | "        classifier_gated_spectrum_nudft_all.append(feat_spectrum)\n",
263 | "\n",
264 | "        feat_spectrum_raw = spectrum_NUDFT(normalized_dataset[:, feat_id], y_hat_raw)\n",
265 | "        max_spectrum_val = max(max_spectrum_val, np.abs(feat_spectrum_raw[1]).max())\n",
266 | "        classifier_spectrum_nudft_all.append(feat_spectrum_raw)\n",
267 | "\n",
268 | "    classifier_gated_spectra = [np.abs(t[1]).reshape(-1) for t in classifier_gated_spectrum_nudft_all]\n",
269 | "    classifier_gated_spectra_single_column = np.concatenate(classifier_gated_spectra, axis=0)\n",
270 | "    classifier_gated_spectra_single_column = classifier_gated_spectra_single_column / max_spectrum_gated_val\n",
271 | "    sns.lineplot(x=k_repeated, y=classifier_gated_spectra_single_column, label='$P_{f_{lspin}}(k)$')\n",
272 | "\n",
273 | "    classifier_model_spectra = [np.abs(t[1]).reshape(-1) for t in classifier_spectrum_nudft_all]\n",
274 | "    classifier_spectra_single_column = np.concatenate(classifier_model_spectra, axis=0)\n",
275 | "    classifier_spectra_single_column = classifier_spectra_single_column / max_spectrum_val\n",
276 | "    sns.lineplot(x=k_repeated, y=classifier_spectra_single_column, label='$P_{f_{classifier}}(k)$')\n",
277 | "\n",
278 | "plt.legend(fontsize=16)\n",
279 | "plt.show()"
280 | ],
281 | "metadata": {
282 | "collapsed": false,
283 | "ExecuteTime": {
284 | "end_time": "2024-10-08T13:30:05.517370800Z",
285 | "start_time": "2024-10-08T13:18:02.330028300Z"
286 | }
287 | },
288 | "id": "bb12fc6aae6893cd"
289 | },
290 | {
291 | "cell_type": "markdown",
292 | "source": [
293 | "As can be seen from the plot, gating the features adds a high-frequency bias and makes the predictor more appropriate for tabular datasets. Please refer to Figure 1 of the paper [An Inductive Bias for Tabular Deep Learning](https://proceedings.neurips.cc/paper_files/paper/2023/file/8671b6dffc08b4fcf5b8ce26799b2bef-Paper-Conference.pdf):\n",
294 | "\n",
295 | "\n",
296 | " \n",
297 | "\n",
298 | "Due to their heterogeneous nature, tabular datasets tend to describe higher frequency target functions compared to images. The spectra corresponding to image datasets (curves in color) tend to feature lower Fourier amplitudes at higher frequencies than heterogeneous tabular datasets (cyan region).\n"
299 | ],
300 | "metadata": {
301 | "collapsed": false
302 | },
303 | "id": "91c4ea1eac38133e"
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "outputs": [],
309 | "source": [],
310 | "metadata": {
311 | "collapsed": false
312 | },
313 | "id": "5a310757b3524a5a"
314 | }
315 | ],
316 | "metadata": {
317 | "kernelspec": {
318 | "display_name": "Python 3",
319 | "language": "python",
320 | "name": "python3"
321 | },
322 | "language_info": {
323 | "codemirror_mode": {
324 | "name": "ipython",
325 | "version": 2
326 | },
327 | "file_extension": ".py",
328 | "mimetype": "text/x-python",
329 | "name": "python",
330 | "nbconvert_exporter": "python",
331 | "pygments_lexer": "ipython2",
332 | "version": "2.7.6"
333 | }
334 | },
335 | "nbformat": 4,
336 | "nbformat_minor": 5
337 | }
338 |
--------------------------------------------------------------------------------
/cfg/cfg_mnist.yaml:
--------------------------------------------------------------------------------
1 | dataset: MNIST10K
2 | data_dir: idc/data
3 | scaler: MinMaxScaler
4 | batch_size: 100
5 | seeds: 1
6 | epochs: &epochs 700
7 |
8 | ae_non_gated_epochs: 10
9 | ae_pretrain_epochs: 300
10 | start_global_gates_training_on_epoch: 400
11 |
12 | mask_percentage: 0.9
13 | latent_noise_std: 0.01
14 |
15 | trainer:
16 |   devices: 1
17 |   accelerator: gpu
18 |   max_epochs: *epochs
19 |   deterministic: true
20 |   logger: true
21 |   log_every_n_steps: 10
22 |   check_val_every_n_epoch: 10
23 |   enable_checkpointing: false
24 |   num_sanity_val_steps: 0
25 |
26 |
27 | # GTCR loss
28 | gtcr_loss: true
29 | gtcr_projection_dim: null # for large number of features use it
30 | gtcr_eps: 1
31 |
32 |
33 | # Compression loss
34 | eps: 0.1
35 |
36 | # Gating Net
37 | use_gating: true
38 | gates_hidden_dim: 784
39 |
40 | # EncoderDecoder
41 | encdec:
42 | - 512
43 | - 512
44 | - 2048
45 | - &bn_layer 10
46 | - 2048
47 | - 512
48 | - 512
49 |
50 | clustering_head:
51 | - *bn_layer
52 | - 2048
53 |
54 | tau: 100
55 |
56 | aux_classifier:
57 | - 2048
58 |
59 | local_gates_lambda: 1
60 | global_gates_lambda: 0.0001
61 | gtcr_lambda: 0.01
62 |
63 | lr:
64 |   pretrain: 1e-3
65 |   clustering: 1e-2
66 |   aux_classifier: 1e-2
67 |
68 | sched:
69 |   pretrain_min_lr: 1e-6
70 |   clustering_min_lr: 1e-6
71 |
72 |
73 |
74 | save_seed_checkpoints: false
75 | validate: true
--------------------------------------------------------------------------------
/cfg/cfg_run.yaml:
--------------------------------------------------------------------------------
1 | filepath_samples: idc/data/pbmc_x.npz
2 | num_clusters: 2
3 |
4 | batch_size: 256
5 | seeds: 1
6 | epochs: &epochs 200
7 |
8 | ae_non_gated_epochs: 5 #50 we reduce the number of epochs for training inside a notebook
9 | ae_pretrain_epochs: 10 #100 we reduce the number of epochs for training inside a notebook
10 | start_global_gates_training_on_epoch: 150
11 |
12 | mask_percentage: 0.9
13 | latent_noise_std: 0.01
14 |
15 | trainer:
16 |   devices: 1
17 |   accelerator: gpu
18 |   max_epochs: *epochs
19 |   deterministic: true
20 |   logger: true
21 |   log_every_n_steps: 10
22 |   check_val_every_n_epoch: 10
23 |   enable_checkpointing: false
24 |   num_sanity_val_steps: 0
25 |
26 |
27 | # GTCR loss
28 | gtcr_loss: true
29 | gtcr_projection_dim: 1024 # for large number of features use it
30 | gtcr_eps: 1
31 |
32 |
33 | # Compression loss
34 | eps: 0.1
35 |
36 | # Gating Net
37 | use_gating: true
38 | gates_hidden_dim: 1024
39 |
40 | # EncoderDecoder
41 | encdec:
42 | - 512
43 | - 512
44 | - 2048
45 | - &bn_layer 128
46 | - 2048
47 | - 512
48 | - 512
49 |
50 | clustering_head:
51 | - *bn_layer
52 | - 2048
53 |
54 | tau: 100
55 |
56 | aux_classifier:
57 | - 2048
58 |
59 | local_gates_lambda: 100
60 | global_gates_lambda: 10
61 | gtcr_lambda: 0.01
62 |
63 | lr:
64 |   pretrain: 1e-3
65 |   clustering: 1e-3
66 |   aux_classifier: 1e-1
67 |
68 | sched:
69 |   pretrain_min_lr: 1e-4
70 |   clustering_min_lr: 1e-4
71 |
72 | save_seed_checkpoints: false
73 |
--------------------------------------------------------------------------------
/data/pbmc_x.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/data/pbmc_x.zip
--------------------------------------------------------------------------------
/data/pbmc_y.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/data/pbmc_y.zip
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import MNIST
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import torch
5 | from sklearn import preprocessing
6 | from scipy.io import loadmat
7 | from sklearn.preprocessing import MinMaxScaler
8 | from scipy.stats import zscore
9 | import matplotlib.pyplot as plt
10 | from sklearn import datasets
11 |
12 |
13 | class ClusteringDataset(Dataset):
14 |     def __init__(self, data, labels=None, num_clusters=None):
15 |         super().__init__()
16 |         self.data = data
17 |         self.labels = labels
18 |         self._num_clusters = num_clusters
19 |         if num_clusters is None and labels is None:
20 |             raise ValueError("At least one of the values should be provided (labels/num_clusters)")
21 |         self.print_stats()
22 |
23 |     def __getitem__(self, index: int):
24 |         if self.labels is None:
25 |             return torch.tensor(self.data[index]).float()
26 |         return torch.tensor(self.data[index]).float(), torch.tensor(self.labels[index]).long()
27 |
28 |     def __len__(self) -> int:
29 |         return len(self.data)
30 |
31 |     @property
32 |     def num_clusters(self):
33 |         return self._num_clusters if self._num_clusters is not None else len(np.unique(self.labels))
34 |
35 |     def num_features(self):
36 |         return self.data.shape[-1]
37 |
38 |     def print_stats(self):
39 |         print('X.shape: ', self.data.shape)
40 |         print(f"X.min={self.data.min()}, X.max={self.data.max()}")
41 |         if self.labels is not None:
42 |             print('Y.shape: ', self.labels.shape)
43 |             for y_u in np.unique(self.labels):
44 |                 print(f'{y_u}: {np.sum(self.labels == y_u)}')
45 |             print(f"Y.min={self.labels.min()}, Y.max={self.labels.max()}")
46 |
47 |     @classmethod
48 |     def setup(cls, cfg):
49 |         pass
50 |
51 |
52 | class PBMC(ClusteringDataset):
53 |     def __init__(self, data, targets):
54 |         super().__init__(data, targets)
55 |
56 |     @classmethod
57 |     def setup(cls, cfg):
58 |         data_dir = cfg.data_dir
59 |         with np.load(f"{data_dir}/pbmc_x.npz") as data:
60 |             X = data['arr_0']
61 |         with np.load(f"{data_dir}/pbmc_y.npz") as data:
62 |             Y = data['arr_0']
63 |         Y = Y - Y.min()
64 |         scaler = getattr(preprocessing, cfg.scaler)()
65 |         X = scaler.fit_transform(X)
66 |         return cls(X, Y)
67 |
68 |
69 | class BIASE(ClusteringDataset):
70 |     def __init__(self, data, targets):
71 |         super().__init__(data, targets)
72 |
73 |     @classmethod
74 |     def setup(cls, cfg):
75 |         name = 'biase'
76 |         data_dir = cfg.data_dir
77 |         dataset_x = f"{data_dir}/{name}/{name}_data.csv"
78 |         dataset_y = f"{data_dir}/{name}/{name}_celldata.csv"
79 |         with open(dataset_x) as r:
80 |             data = [l.strip() for l in r.readlines()]
81 |         cell_keys = data[0].split(',')[1:]
82 |         rows = [np.array([float(v) for v in row.split(',')[1:]]).reshape((1, -1)) for row in data[1:]]
83 |         # remove_zero_columns is the module-level helper defined below, not a BIASE method
84 |         X = remove_zero_columns(np.concatenate(rows, axis=0).transpose())
85 |         with open(dataset_y) as r:
86 |             y_data = [l.strip().split(',') for l in r.readlines()[1:]]
87 |         cell2class = {row[0]: row[2] for row in y_data}
88 |         class2count = {}
89 |         for cell, clas in cell2class.items():
90 |             class2count.setdefault(clas, 0)
91 |             class2count[clas] += 1
92 |
93 |         print(class2count)
94 |         class2id = {c: i for i, c in enumerate(sorted(set(cell2class.values())))}  # sort after set() for stable ids
95 |
96 |         Y = []
97 |         for cell_key in cell_keys:
98 |             Y.append(class2id[cell2class[cell_key]])
99 |         Y = np.array(Y).reshape(-1)
100 |         X = BIASE.transform(X)  # NOTE: `transform` is not defined on BIASE or its base class in this file
101 |
102 |         X = np.log(1 + X)
103 |         X = X + .001 * np.random.normal(0, 1, (X.shape))
104 |         scaler = getattr(preprocessing, cfg.scaler)()
105 |         X = scaler.fit_transform(X)
106 |         return cls(X, Y)
107 |
108 |
109 | class INTESTINE(ClusteringDataset):
110 |     def __init__(self, data, targets):
111 |         super().__init__(data, targets)
112 |
113 |     @classmethod
114 |     def setup(cls, cfg):
115 |         scaler = getattr(preprocessing, cfg.scaler)()
116 |         name = 'intestine'
117 |         data_dir = cfg.data_dir
118 |         dataset_x = f"{data_dir}/{name}/{name}_data.csv"
119 |         dataset_y = f"{data_dir}/{name}/{name}_celldata.csv"
120 |         with open(dataset_x) as r:
121 |             data = [l.strip() for l in r.readlines()]
122 |         cell_keys = data[0].split(',')[1:]
123 |         rows = [np.array([float(v) for v in row.split(',')[1:]]).reshape((1, -1)) for row in data[1:]]
124 |         X = np.concatenate(rows, axis=0).T
125 |         with open(dataset_y) as r:
126 |             y_data = [l.strip().split(',') for l in r.readlines()[1:]]
127 |         cell2class = {row[0]: row[2] for row in y_data}
128 |         class2count = {}
129 |         for cell, clas in cell2class.items():
130 |             class2count.setdefault(clas, 0)
131 |             class2count[clas] += 1
132 |         print(class2count)
133 |         class2id = {c: i for i, c in enumerate(sorted(set(list(cell2class.values()))))}
134 |         Y = []
135 |         for cell_key in cell_keys:
136 |             Y.append(class2id[cell2class[cell_key]])
137 |         Y = np.array(Y).reshape(-1)
138 |         X = scaler.fit_transform(X)
139 |         return cls(X, Y)
140 |
141 |
142 | class CNAE9(ClusteringDataset):
143 |     def __init__(self, data, targets):
144 |         super().__init__(data, targets)
145 |
146 |     @classmethod
147 |     def setup(cls, cfg):
148 |         scaler = getattr(preprocessing, cfg.scaler)()
149 |         data = np.loadtxt(f"{cfg.data_dir}/cnae_9_numpy.txt")
150 |         X = data[:, :-1]
151 |         Y = data[:, -1]
152 |         Y = Y - Y.min()
153 |         X = scaler.fit_transform(X)
154 |         return cls(X, Y)
155 |
156 |
157 | class MFEATZERNIKE(ClusteringDataset):
158 |     def __init__(self, data, targets):
159 |         super().__init__(data, targets)
160 |
161 |     @classmethod
162 |     def setup(cls, cfg):
163 |         scaler = getattr(preprocessing, cfg.scaler)()
164 |         data = np.loadtxt(f"{cfg.data_dir}/mfeat_zernike_numpy.txt")
165 |         X = data[:, :-1]
166 |         Y = data[:, -1]
167 |         Y = Y - Y.min()
168 |         X = scaler.fit_transform(X)
169 |         return cls(X, Y)
170 |
171 |
172 | class ALLAML(ClusteringDataset):
173 |     def __init__(self, data, targets):
174 |         super().__init__(data, targets)
175 |
176 |     @classmethod
177 |     def setup(cls, cfg):
178 |         dataset = loadmat(f"{cfg.data_dir}/ALLAML.mat")
179 |         X = dataset.get('X')
180 |         Y = dataset.get('Y').reshape(-1)
181 |         Y = Y - Y.min()
182 |         scaler = getattr(preprocessing, cfg.scaler)()
183 |         X = scaler.fit_transform(X)
184 |         return cls(X, Y)
185 |
186 |
187 | class PROSTATE(ClusteringDataset):
188 |     def __init__(self, data, targets):
189 |         super().__init__(data, targets)
190 |
191 |     @classmethod
192 |     def setup(cls, cfg):
193 |         dataset = loadmat(f"{cfg.data_dir}/PROSTATE.mat")
194 |         X = dataset.get('X')
195 |         Y = dataset.get('Y').reshape(-1)
196 |         Y = Y - Y.min()  # to start from zero
197 |         scaler = getattr(preprocessing, cfg.scaler)()
198 |         X = scaler.fit_transform(X)
199 |         return cls(X, Y)
200 |
201 |
202 | class TOX171(ClusteringDataset):
203 |     def __init__(self, data, targets):
204 |         super().__init__(data, targets)
205 |
206 |     @classmethod
207 |     def setup(cls, cfg):
208 |         dataset = loadmat(f"{cfg.data_dir}/TOX171.mat")
209 |         X = dataset.get('X')
210 |         Y = dataset.get('Y').reshape(-1)
211 |         Y = Y - Y.min()  # to start from zero
212 |         scaler = getattr(preprocessing, cfg.scaler)()
213 |         X = scaler.fit_transform(X)
214 |         return cls(X, Y)
215 |
216 |
217 | class SRBCT(ClusteringDataset):
218 |     def __init__(self, data, targets):
219 |         super().__init__(data, targets)
220 |
221 |     @classmethod
222 |     def setup(cls, cfg):
223 |         dataset = loadmat(f"{cfg.data_dir}/SRBCT.mat")
224 |         X = dataset.get('X')
225 |         Y = dataset.get('Y').reshape(-1)
226 |         Y = Y - Y.min()  # to start from zero
227 |         scaler = getattr(preprocessing, cfg.scaler)()
228 |         X = scaler.fit_transform(X)
229 |         return cls(X, Y)
230 |
231 |
232 | class MNIST60K(ClusteringDataset):
233 |     def __init__(self, data, targets):
234 |         super().__init__(data, targets)
235 |
236 |     @classmethod
237 |     def setup(cls, cfg):
238 |         scaler = getattr(preprocessing, cfg.scaler)()
239 |         X = MNIST(cfg.data_dir, train=True, download=True).data.reshape(-1, 784).cpu().numpy()
240 |         Y = MNIST(cfg.data_dir, train=True, download=True).targets.cpu().numpy()
241 |         X = scaler.fit_transform(X)
242 |         return cls(X, Y)
243 |
244 |
245 | class MNIST10K(ClusteringDataset):
246 |     def __init__(self, data, targets):
247 |         super().__init__(data, targets)
248 |
249 |     @classmethod
250 |     def setup(cls, cfg):
251 |         scaler = getattr(preprocessing, cfg.scaler)()
252 |         X = MNIST(cfg.data_dir, train=True, download=True).data.reshape(-1, 784).cpu().numpy()
253 |         Y = MNIST(cfg.data_dir, train=True, download=True).targets.cpu().numpy()
254 |         X = scaler.fit_transform(X)
255 |         X = X[:10000]
256 |         Y = Y[:10000]
257 |         return cls(X, Y)
258 |
259 |
260 | class NumpyTableDataset(ClusteringDataset):
261 |     def __init__(self, data, labels=None, num_clusters=None):
262 |         super().__init__(data, labels, num_clusters)
263 |
264 |     @classmethod
265 |     def setup(cls, filepath_samples: str, filepath_labels: str = None, num_clusters: int = None):
266 |         """
267 |         :param filepath_samples: the path to the npz file, the format of the numpy array should be NxD
268 |             (number of samples x number of features)
269 |         :param filepath_labels: the path to the npz file, the format of the numpy array should be N
270 |             (number of samples)
271 |         :param num_clusters: the integer number of expected clusters
272 |         """
273 |         with np.load(filepath_samples) as data:
274 |             X = data['arr_0']
275 |
276 |         if filepath_labels is not None:
277 |             with np.load(filepath_labels) as data:
278 |                 Y = data['arr_0']
279 |             X = preprocessing.StandardScaler().fit_transform(X)
280 |             Y = Y - Y.min()
281 |         else:
282 |             Y = None
283 |         return cls(X, Y, num_clusters)
284 |
285 |
286 | def remove_zero_columns(X):
287 |     non_zero_columns = []
288 |     for col in range(X.shape[1]):
289 |         if np.min(X[:, col]) == 0 and np.max(X[:, col]) == 0:
290 |             continue
291 |         else:
292 |             non_zero_columns.append(col)
293 |     X = X[:, non_zero_columns]
294 |     return X
295 |
296 |
297 | class Synthetic(Dataset):
298 | def __init__(self, X, Y):
299 | super().__init__()
300 | self.data = X
301 | self.targets = Y
302 |
303 | def __getitem__(self, index: int):
304 | x = self.data[index]
305 | return torch.tensor(x).float(), torch.tensor(self.targets[index]).long()
306 |
307 | def __len__(self) -> int:
308 | return len(self.data)
309 |
310 | @classmethod
311 | def setup(cls, num_samples=5000, num_features=3, num_clusters=3, num_noise_dims=10):
312 | """
313 | Make num_clusters + 1 clusters in 3d and adds additional num_noise_dims noise features
314 | :param num_samples: number of samples in the dataset
315 | :param num_features: number of features in the dataset
316 | :param num_clusters: number of clusters in the dataset
317 | :param num_noise_dims: number of noise dimensions in addition to num_features
318 | :return: generates a dataset
319 | """
320 | x_2d, y_2d = datasets.make_blobs(num_samples, num_features-1, centers=num_clusters, cluster_std=.5,
321 | random_state=0)
322 | # split the points for cluster==2 into 2 clusters:
323 | max_x = x_2d[:, 1].max()
324 | min_x = x_2d[:, 1].min()
325 | x_y_2 = x_2d[y_2d == 2][:, 1]
326 | x_y_2 = MinMaxScaler((0, 1)).fit_transform(x_y_2.reshape(-1, 1)).reshape(-1)
327 | x_y_2 = MinMaxScaler((min_x, max_x)).fit_transform(x_y_2.reshape(-1, 1)).reshape(-1)
328 | x_2d[:, 1][y_2d == 2] = x_y_2
329 |
330 | z = np.random.rand(num_samples)
331 | y_2d[(y_2d == 2) & (z > 0.5)] = 3
332 |
333 | x_2d[:, 0][y_2d == 0] = x_2d[:, 0][y_2d == 1]
334 |
335 | bg = np.random.normal(loc=0, scale=0.01, size=(num_samples, num_noise_dims))
336 | X = np.concatenate([x_2d, z.reshape(-1, 1), bg], axis=1)
337 | X[:, 2][y_2d == 3] = X[:, 2][y_2d == 3] + 0.5 # separate in z axis
338 | X[:, 2][y_2d == 0] = MinMaxScaler(
339 | (X[:, 2][(y_2d == 3) | (y_2d == 2)].min(), X[:, 2][(y_2d == 3) | (y_2d == 2)].max())).fit_transform(
340 | X[:, 2][y_2d == 0].reshape(-1, 1)).reshape(-1)
341 | X[:, 2][y_2d == 1] = MinMaxScaler(
342 | (X[:, 2][(y_2d == 3) | (y_2d == 2)].min(), X[:, 2][(y_2d == 3) | (y_2d == 2)].max())).fit_transform(
343 | X[:, 2][y_2d == 1].reshape(-1, 1)).reshape(-1)
344 |
345 | Y = y_2d
346 |
347 | X4 = X[Y == 3]
348 | max_len = len(X4)
349 | X1 = X[Y == 0][:max_len, :]
350 | X2 = X[Y == 1][:max_len, :]
351 | X3 = X[Y == 2][:max_len, :]
352 |
353 | Y1 = Y[Y == 0][:max_len]
354 | Y2 = Y[Y == 1][:max_len]
355 | Y3 = Y[Y == 2][:max_len]
356 | Y4 = Y[Y == 3][:max_len]
357 | X = np.concatenate([X1, X2, X3, X4], axis=0)
358 | Y = np.concatenate([Y1, Y2, Y3, Y4], axis=0)
359 |
360 | print("Class stats:")
361 | for y_i in np.unique(Y):
362 | print(f"{y_i}: {len(Y[Y == y_i])} samples")
363 | X[:, :3] = zscore(X[:, :3])
364 |
365 | plt.style.use('classic')
366 | plt.rcParams['axes.spines.right'] = False
367 | plt.rcParams['axes.spines.top'] = False
368 | fig = plt.figure()
369 | fig.set_facecolor('w')
370 | plt.scatter(X[:, 0], X[:, 1], c=Y, s=100, alpha=0.8, cmap='viridis', edgecolor='k', linewidth=2)
371 | plt.xlabel('$X_1$', fontsize=30)
372 | plt.ylabel('$X_2$', fontsize=30)
373 | plt.tight_layout()
374 | plt.xticks([])
375 | plt.yticks([])
376 | plt.savefig("synth_X_1_X_2.png")
377 |
378 | plt.clf()
379 | fig = plt.figure()
380 | fig.set_facecolor('w')
381 | plt.scatter(X[:, 0], X[:, 2], c=Y, s=100, alpha=0.8, cmap='viridis', edgecolor='k', linewidth=2)
382 | plt.xlabel('$X_1$', fontsize=30)
383 | plt.ylabel('$X_3$', fontsize=30)
384 | plt.tight_layout()
385 | plt.xticks([])
386 | plt.yticks([])
387 | plt.savefig("synth_X_1_X_3.png")
388 | return cls(X, Y)
389 |
390 | def num_classes(self):
391 | return len(np.unique(self.targets))
392 |
393 | def num_features(self):
394 | return self.data.shape[-1]
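
The synthetic generator in this file repeatedly uses `MinMaxScaler` to squeeze one cluster's coordinate into the range spanned by another group of points. A rough, dependency-free sketch of that rescaling step follows; the helper name `rescale_to_range` is hypothetical and not part of this repository, and plain Python lists stand in for NumPy arrays.

```python
def rescale_to_range(values, new_min, new_max):
    """Linearly map `values` so that min(values) -> new_min and max(values) -> new_max."""
    old_min, old_max = min(values), max(values)
    span = old_max - old_min
    if span == 0:
        # Degenerate input: every value is identical, so place all points at new_min.
        return [float(new_min) for _ in values]
    scale = (new_max - new_min) / span
    return [new_min + (v - old_min) * scale for v in values]


if __name__ == "__main__":
    # Map one cluster's coordinate into the [0, 1] interval.
    print(rescale_to_range([2.0, 4.0, 6.0], 0.0, 1.0))
```

Because the mapping is linear, the relative ordering of the points inside the cluster is preserved; only their location and spread along that axis change.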
--------------------------------------------------------------------------------
/img/freq_bias.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/freq_bias.png
--------------------------------------------------------------------------------
/img/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/img.png
--------------------------------------------------------------------------------
/img/nudft_ALLAML.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/nudft_ALLAML.png
--------------------------------------------------------------------------------
/img/pbmc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/pbmc.gif
--------------------------------------------------------------------------------
/img/supervised_train_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/supervised_train_plots.png
--------------------------------------------------------------------------------
/img/supervised_train_plots2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsvir/idc/6288558d24268fa842944dd7f0490a62fd9d1fdf/img/supervised_train_plots2.png
--------------------------------------------------------------------------------
/interpretability_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from munkres import Munkres
3 | import torch
4 | from sklearn.metrics import confusion_matrix, jaccard_score
5 | from scipy.spatial import distance_matrix
6 | from tqdm import tqdm
7 | from sklearn.svm import LinearSVC
8 |
9 |
10 | def generalizability(X_train, gates_train, Y_train, X_test, gates_test, Y_test):
11 | """
12 |     Measures how well the interpretation (the gated features) generalizes to a simple prediction model: returns the test accuracy of a Linear Support Vector Classifier trained on X * gates.
13 | """
14 | classifier = LinearSVC()
15 | classifier.fit(X_train * gates_train, Y_train)
16 | return classifier.score(X_test * gates_test, Y_test)
17 |
18 |
19 | def faithfulness(gates_i, x, inference_fn, y, num_features=784):
20 | """
21 | Are the identified features significant for prediction?
22 | """
23 | importance_vec = np.sum(gates_i > 0, axis=0)
24 | importance_ind = np.where(np.sum(gates_i > 0, axis=0) > 0)[0]
25 | importance_ind_sort = importance_ind[np.argsort(-importance_vec[importance_vec > 0])]
26 | mask = np.ones(num_features)
27 | acc_arr_bad = []
28 | for i in importance_ind_sort:
29 | mask[i] = 0
30 | y_hat = inference_fn(x * mask)
31 | if isinstance(y_hat, torch.Tensor):
32 | y_hat = y_hat.cpu().numpy()
33 | mean_val = get_accuracy(y_hat, y, 10)
34 | acc_arr_bad.append(mean_val)
35 | return np.corrcoef(importance_vec[importance_ind_sort], np.array(acc_arr_bad))[0, 1]
36 |
37 |
38 | def stability(x, gates, k=2, subset_size=10000,p=2):
39 | """
40 | Are explanations to similar samples consistent?
41 |     inputs:
42 |         x is N x D matrix of samples
43 |         gates is N x D matrix predicted by STG for x
44 |         k is the number of neighbors
45 |     outputs:
46 |         mean Lipschitz constant of the explanation function
47 | """
48 | dist_mat_x = distance_matrix(x, x, p=p)
49 | nn_dist_mat = np.sort(dist_mat_x, axis=1)[:, 0:k]
50 | nn_ind_mat = np.argsort(dist_mat_x, axis=1)[:, 0:k]
51 |     lipschitz_constants = []
52 |     for i in tqdm(range(min(subset_size, len(x)))):  # cap at the number of available samples
53 |         lipschitz_constants.append(max(distance_matrix(gates[nn_ind_mat[i], :], gates[nn_ind_mat[i], :])[0][1:] / nn_dist_mat[i][1:]))
54 |     return np.mean(np.array(lipschitz_constants))
55 |
56 |
57 | def diversity(y, gates, num_clusters=10, num_features=784):
58 | """
59 | How different are the selected variables for instances of distinct classes?
60 | For formula see appendix A.7 in
61 | Yang, Junchen, Ofir Lindenbaum, and Yuval Kluger. "Locally sparse neural networks for tabular biomedical data." International Conference on Machine Learning. PMLR, 2022.
62 | """
63 | per_matrix = np.zeros((num_clusters, num_clusters))
64 | all_gates = []
65 | for i in range(num_clusters):
66 | indices_p = np.where(y == i)[0]
67 | onez_p = np.zeros(num_features)
68 | active_gates = np.where(np.median(gates[indices_p, :], axis=0) > 0)[0]
69 | onez_p[active_gates] = 1
70 | all_gates = np.append(all_gates, active_gates)
71 | for j in range(num_clusters):
72 | indices_n = np.where(y == j)[0]
73 | active_gates_n = np.where(np.median(gates[indices_n, :], axis=0) > 0)[0]
74 | onez_n = np.zeros(num_features)
75 | onez_n[active_gates_n] = 1
76 | per_matrix[i, j] = jaccard_score(onez_n, onez_p)
77 |
78 | diversity = 100 * (1 - (per_matrix / (num_clusters * (num_clusters - 1))).sum())
79 | return diversity
80 |
81 |
82 | def uniqueness(x, gates, k=2, subset_size=10000, p=2):
83 | """
84 |     Uniqueness of the selected features for similar samples (how granular are the explanations?)
85 |     inputs:
86 |         x is N x D matrix of samples
87 |         gates is N x D matrix predicted by STG for x
88 |         k is the number of neighbors
89 | """
90 | dist_mat_x = distance_matrix(x, x, p=p)
91 | nn_dist_mat = np.sort(dist_mat_x, axis=1)[:, 0:k]
92 | nn_ind_mat = np.argsort(dist_mat_x, axis=1)[:, 0:k]
93 | vals = []
94 |     for i in tqdm(range(min(subset_size, len(x)))):  # cap at the number of available samples
95 | vals.append(min(distance_matrix(gates[nn_ind_mat[i], :], gates[nn_ind_mat[i], :])[0][1:] / nn_dist_mat[i][1:]))
96 | return np.mean(np.array(vals))
97 |
98 |
99 |
100 | def get_accuracy(cluster_assignments, y_true, n_clusters):
101 | '''
102 | Computes the accuracy based on the provided kmeans cluster assignments
103 | and true labels, using the Munkres algorithm
104 | cluster_assignments: array of labels, outputted by kmeans
105 | y_true: true labels
106 | n_clusters: number of clusters in the dataset
107 |     returns: the clustering accuracy under the optimal 1:1 assignment
108 |         of clusters to labels
109 | '''
110 | confusion_mat = confusion_matrix(y_true, cluster_assignments, labels=None)
111 | # compute accuracy based on optimal 1:1 assignment of clusters to labels
112 | cost_matrix = calculate_cost_matrix(confusion_mat, n_clusters)
113 | indices = Munkres().compute(cost_matrix)
114 | kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices)
115 | y_pred = kmeans_to_true_cluster_labels[cluster_assignments]
116 | return np.mean(y_pred == y_true)
117 |
118 |
119 | def calculate_cost_matrix(C, n_clusters):
120 | cost_matrix = np.zeros((n_clusters, n_clusters))
121 | # cost_matrix[i,j] will be the cost of assigning cluster i to label j
122 | for j in range(n_clusters):
123 |         s = np.sum(C[:, j])  # number of examples in cluster j
124 | for i in range(n_clusters):
125 | t = C[i, j]
126 | cost_matrix[j, i] = s - t
127 | return cost_matrix
128 |
129 |
130 | def get_cluster_labels_from_indices(indices):
131 | n_clusters = len(indices)
132 | clusterLabels = np.zeros(n_clusters)
133 | for i in range(n_clusters):
134 | clusterLabels[i] = indices[i][1]
135 | return clusterLabels
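
`get_accuracy` in this file scores a clustering by finding the one-to-one mapping from cluster ids to labels that maximizes agreement, solved with the Munkres algorithm. As a minimal sketch of the same objective, the snippet below swaps in a brute-force search over all permutations instead of Munkres. It is only practical for a handful of clusters, and the name `matched_accuracy` is hypothetical.

```python
from itertools import permutations


def matched_accuracy(cluster_assignments, y_true, n_clusters):
    """Best accuracy over all one-to-one mappings from cluster ids to labels."""
    best = 0.0
    for mapping in permutations(range(n_clusters)):
        # mapping[c] is the label assigned to cluster c under this candidate matching.
        hits = sum(mapping[c] == t for c, t in zip(cluster_assignments, y_true))
        best = max(best, hits / len(y_true))
    return best


if __name__ == "__main__":
    # Cluster ids 0 and 1 are swapped relative to the labels; the matching undoes that.
    print(matched_accuracy([1, 1, 0, 0, 2, 2], [0, 0, 1, 1, 2, 2], 3))
```

The search visits `n_clusters!` mappings, which is why an assignment solver such as Munkres is the usual choice once the number of clusters grows.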
--------------------------------------------------------------------------------
/lspin_pbmc.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | import torch.nn.functional as F
3 | from torch.utils.data import DataLoader
4 | from pytorch_lightning import LightningModule
5 | from torch.utils.data import Dataset
6 | from sklearn.preprocessing import StandardScaler
7 | from pytorch_lightning import Trainer, seed_everything
8 | import argparse
9 | import torch
10 | import math
11 | import numpy as np
12 | import os
13 | from pytorch_lightning.loggers import TensorBoardLogger
14 | from pytorch_lightning.callbacks import LearningRateMonitor
15 | import platform
16 |
17 |
18 | def parse_args(args):
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--gated", action="store_true")
21 | parser.add_argument("--dataset", type=str, default="PBMC")
22 | parser.add_argument("--data_dir", type=str, default="C:/data/fs/pbmc" if platform.system() == "Windows" else ".")
23 | parser.add_argument("--batch_size", type=int, default=256)
24 | parser.add_argument("--repitions", type=int, default=5)
25 |     parser.add_argument("--lr", type=float, default=1e-3)
26 |
27 | # gatenet config
28 | parser.add_argument("--sigma", type=float, default=0.5)
29 | parser.add_argument("--reg_beta", type=float, default=10)
30 |     parser.add_argument("--target_sparsity", type=float, default=0.9)
31 |     parser.add_argument("--gates_lr", type=float, default=2e-3)
32 |
33 |
34 | # trainer config
35 | parser.add_argument("--devices", type=int, default=1)
36 | parser.add_argument("--accelerator", type=str, default="gpu")
37 | parser.add_argument("--max_epochs", type=int, default=1000)
38 | parser.add_argument("--deterministic", type=bool, default=True)
39 | parser.add_argument("--logger", type=bool, default=True)
40 | parser.add_argument("--log_every_n_steps", type=int, default=10)
41 | parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
42 | parser.add_argument("--enable_checkpointing", type=bool, default=False)
43 |
44 | args = parser.parse_args(args)
45 | return args
46 |
47 |
48 | class PBMC(Dataset):
49 | def __init__(self, X, Y):
50 | super().__init__()
51 | self.data = X
52 | self.targets = Y
53 |
54 | def __getitem__(self, index: int):
55 | x = self.data[index]
56 | x = x.reshape(-1)
57 | return torch.tensor(x).float(), torch.tensor(self.targets[index]).long()
58 |
59 | def __len__(self) -> int:
60 | return len(self.data)
61 |
62 | @classmethod
63 | def setup(cls, data_dir, test_size=0.2):
64 | with np.load(f"{data_dir}/pbmc_x.npz") as data:
65 | X = data['arr_0']
66 | with np.load(f"{data_dir}/pbmc_y.npz") as data:
67 | Y = data['arr_0']
68 |
69 | Y = Y - Y.min()
70 | X = StandardScaler().fit_transform(X)
71 | print(f'Dataset PBMC stats:')
72 | print('X.shape: ', X.shape)
73 | print('Y.shape: ', Y.shape)
74 | print(f"X.min={X.min()}, X.max={X.max()}")
75 | print(f"Y.min={Y.min()}, Y.max={Y.max()}")
76 |
77 | for y_uniq in np.unique(Y):
78 | print(f"Label {y_uniq} has {len(Y[Y == y_uniq])} samples")
79 |
80 | np.random.seed(1948)
81 | random_index = np.random.permutation(len(X))
82 | test_size = int(len(X) * test_size)
83 | x_test = X[random_index][:test_size]
84 | y_test = Y[random_index][:test_size]
85 |
86 | x_train = X[random_index][test_size:]
87 | y_train = Y[random_index][test_size:]
88 |
89 | print(f"Split to train/test: train {len(x_train)} test {len(x_test)}")
90 | return cls(x_train, y_train), cls(x_test, y_test)
91 |
92 | def num_classes(self):
93 | return len(np.unique(self.targets))
94 |
95 | def num_features(self):
96 | return self.data.shape[-1]
97 |
98 |
99 | class BaseModule(LightningModule):
100 | def __init__(self, cfg):
101 | super().__init__()
102 | self.cfg = cfg
103 | self.save_hyperparameters()
104 | self.best_evaluation_stats = {}
105 | self.automatic_optimization = False
106 |         self.best_accuracy = - np.inf
107 | self.classifier_net = Classifier(cfg)
108 | self.val_cluster_list = []
109 | self.val_label_list = []
110 | self.best_acc = - 100
111 |
112 | if cfg.gated:
113 | self.gating_net = GatingNet(cfg)
114 | self.val_cluster_list_gated = []
115 | self.open_gates = []
116 | self.best_local_feats = None
117 |
118 | def training_step(self, batch, batch_idx):
119 | opt = self.optimizers()
120 | sch = self.lr_schedulers()
121 | x, y = batch
122 | x = x.reshape(x.size(0), -1)
123 | opt.zero_grad()
124 |
125 | if hasattr(self, 'gating_net'):
126 | mu, _, gates = self.gating_net(x)
127 | ae_emb = self.classifier_net.encoder(x * gates)
128 |
129 | reg_loss = self.gating_net.regularization(mu)
130 | self.log("train/reg_loss", reg_loss.item())
131 | else:
132 | ae_emb = self.classifier_net.encoder(x)
133 | reg_loss = 0
134 |
135 | cluster_logits = self.classifier_net.head(ae_emb)
136 | ce_loss = F.cross_entropy(cluster_logits, y)
137 |
138 | self.log("train/ce_loss", ce_loss.item())
139 | loss = ce_loss + self.cfg.reg_beta * reg_loss
140 | self.manual_backward(loss)
141 | opt.step()
142 | sch.step()
143 |
144 | if self.global_step % 100 == 0:
145 | if hasattr(self, 'gating_net'):
146 | print(f"Epoch {self.current_epoch} "
147 | f"step {self.global_step} "
148 | f"train/reg_loss {reg_loss.item()} "
149 | f"train/ce_loss {ce_loss.item()}")
150 | def configure_optimizers(self):
151 |
152 | if hasattr(self, 'gating_net'):
153 | params =[ { # classifier
154 | "params": chain(
155 | self.classifier_net.encoder.parameters(),
156 | self.classifier_net.head.parameters()),
157 | "lr": self.cfg.lr,
158 |
159 | },
160 | { # gates
161 | "params": self.gating_net.net.parameters(),
162 | "lr": self.cfg.gates_lr,
163 | }]
164 |
165 | else:
166 | params = chain(
167 | self.classifier_net.encoder.parameters(),
168 | self.classifier_net.head.parameters(),
169 | )
170 | optimizer = torch.optim.SGD(
171 | params=params,
172 | lr=self.cfg.lr)
173 |
174 | steps = self.train_dataset.__len__() // self.batch_size * self.cfg.max_epochs
175 | print(f"Cosine annealing LR scheduling is applied during {steps} steps")
176 | sched = torch.optim.lr_scheduler.CosineAnnealingLR(
177 | optimizer=optimizer,
178 | T_max=steps,
179 | eta_min=1e-4)
180 | return [optimizer], [sched]
181 |
182 | def validation_step(self, batch, batch_idx):
183 | x, y = batch
184 | if hasattr(self, 'gating_net'):
185 | gates = self.gating_net.get_gates(x)
186 | ae_emb = self.classifier_net.encoder(x * gates)
187 | self.open_gates.append(self.gating_net.num_open_gates(x))
188 | else:
189 | ae_emb = self.classifier_net.encoder(x)
190 | cluster_logits = self.classifier_net.head(ae_emb)
191 | y_hat = cluster_logits.argmax(dim=-1)
192 | self.val_cluster_list.append(y_hat.cpu())
193 | self.val_label_list.append(y.cpu())
194 |
195 | def on_validation_epoch_start(self):
196 | self.val_cluster_list = []
197 | self.val_cluster_list_gated = []
198 | self.val_label_list = []
199 | if hasattr(self, 'gating_net'):
200 | self.open_gates = []
201 |
202 | def on_validation_epoch_end(self):
203 | if self.current_epoch > 0:
204 | cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
205 | label_mtx = torch.cat(self.val_label_list, dim=0)
206 | acc = torch.mean((cluster_mtx == label_mtx).float()).item()
207 | if self.best_accuracy < acc:
208 | self.best_accuracy = acc
209 | if hasattr(self, 'gating_net'):
210 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.classifier_net.state_dict()}
211 | torch.save(meta_dict, f'sparse_model_best_pbmc_beta_{self.cfg.reg_beta}_seed_{self.cfg.seed}.pth')
212 | print(f"New best accuracy: {acc} open gates: {np.mean(self.open_gates).item()}")
213 | else:
214 | meta_dict = {"clustering": self.classifier_net.state_dict()}
215 | torch.save(meta_dict, f'sparse_model_nogates_best_pbmc_seed_{self.cfg.seed}.pth')
216 | print(f"New best accuracy: {acc}")
217 | format_str = '' # '_kmeans' if self.current_epoch == 9 else ''
218 | self.log(f'val/acc_single{format_str}', acc) # this is ACC
219 | if hasattr(self, 'gating_net'):
220 | self.log("val/num_open_gates", np.mean(self.open_gates).item())
221 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.classifier_net.state_dict()}
222 | torch.save(meta_dict, f'sparse_model_last_pbmc_beta_{self.cfg.reg_beta}_seed_{self.cfg.seed}.pth')
223 | self.update_stats(acc, np.mean(self.open_gates).item())
224 | else:
225 | meta_dict = {"clustering": self.classifier_net.state_dict()}
226 | torch.save(meta_dict, f'sparse_model_nogates_last_pbmc_seed_{self.cfg.seed}.pth')
227 | self.update_stats(acc, None)
228 |
229 | def update_stats(self, acc, local_feats=None):
230 | if self.best_acc <= acc:
231 | self.best_acc = acc
232 | if local_feats is not None:
233 | self.best_local_feats = local_feats
234 |
235 |
236 | class ClassificationModule(BaseModule):
237 | def __init__(self, cfg):
238 | self.train_dataset, self.test_dataset = PBMC.setup(cfg.data_dir)
239 | print(f"Train Dataset length: {self.train_dataset.__len__()}")
240 | print(f"Test Dataset length: {self.test_dataset.__len__()}")
241 | cfg.input_dim = self.train_dataset.num_features()
242 | cfg.n_clusters = self.train_dataset.num_classes()
243 | self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size)
244 | super().__init__(cfg)
245 |
246 | def train_dataloader(self):
247 | return DataLoader(self.train_dataset,
248 | batch_size=self.batch_size,
249 | drop_last=True,
250 | shuffle=True,
251 | num_workers=0)
252 |
253 | def val_dataloader(self):
254 | return DataLoader(self.test_dataset,
255 | batch_size=self.batch_size,
256 | drop_last=False,
257 | shuffle=False,
258 | num_workers=0)
259 |
260 |
261 | class Classifier(torch.nn.Module):
262 | def __init__(self, cfg):
263 | super(Classifier, self).__init__()
264 | self.cfg = cfg
265 | self.encoder = torch.nn.Sequential(
266 | torch.nn.Linear(cfg.input_dim, 512),
267 | torch.nn.BatchNorm1d(512),
268 | torch.nn.ReLU(),
269 | )
270 | self.head = torch.nn.Sequential(
271 | torch.nn.Linear(512, 2048),
272 | torch.nn.BatchNorm1d(2048),
273 | torch.nn.ReLU(),
274 | torch.nn.Linear(2048, cfg.n_clusters),
275 | )
276 |
277 | self.encoder.apply(self.init_weights_normal)
278 | self.head.apply(self.init_weights_normal)
279 |
280 | @staticmethod
281 | def init_weights_normal(m):
282 | if isinstance(m, torch.nn.Linear):
283 | torch.nn.init.xavier_normal_(m.weight)
284 |             if m.bias is not None:  # 'bias' is stored in m._parameters, so it never appears in vars(m)
285 |                 m.bias.data.fill_(0.0)
286 |
287 | def pretrain_forward(self, x):
288 | return self.decoder(self.encoder(x))
289 |
290 |
291 | class GatingNet(torch.nn.Module):
292 | def __init__(self, cfg):
293 | super(GatingNet, self).__init__()
294 | self.cfg = cfg
295 | self._sqrt_2 = math.sqrt(2)
296 | self.sigma = cfg.sigma
297 | self.net = torch.nn.Sequential(
298 | torch.nn.Linear(cfg.input_dim, 512),
299 | torch.nn.ReLU(),
300 | torch.nn.Linear(512, 2048),
301 | torch.nn.ReLU(),
302 | torch.nn.Linear(2048, 512),
303 | torch.nn.ReLU(),
304 | torch.nn.Linear(512, cfg.input_dim),
305 | torch.nn.Tanh()
306 | )
307 | self.net.apply(self.init_weights)
308 |
309 | def init_weights(self, m):
310 | if isinstance(m, torch.nn.Linear):
311 | torch.nn.init.xavier_normal_(m.weight)
312 | if m.out_features == self.cfg.input_dim:
313 | m.bias.data.fill_(.5)
314 | else:
315 | m.bias.data.fill_(0.0)
316 |
317 | def global_forward(self, batch_size, y):
318 | noise = torch.normal(mean=0, std=self.sigma, size=(batch_size, self.cfg.input_dim),
319 | device=self.global_gates_net.weight.device)
320 | z = torch.tanh(self.global_gates_net(y)).reshape(1, -1).repeat(batch_size, 1) + noise * self.training
321 | gates = self.hard_sigmoid(z)
322 | return torch.tanh(self.global_gates_net(y)), gates
323 |
324 | def open_global_gates(self):
325 | return self.hard_sigmoid(torch.tanh(self.global_gates_net.weight)).sum(dim=1).mean().cpu().item()
326 |
327 | def forward(self, x):
328 | noise = torch.normal(mean=0, std=self.sigma, size=x.size(), device=x.device)
329 | mu = self.net(x)
330 | z = mu + noise * self.training
331 | gates = self.hard_sigmoid(z)
332 | sparse_x = x * gates
333 | return mu, sparse_x, gates
334 |
335 | @staticmethod
336 | def hard_sigmoid(x):
337 | return torch.clamp(x + .5, 0.0, 1.0)
338 |
339 | def regularization(self, mu, reduction_func=torch.mean):
340 | return max(reduction_func(0.5 - 0.5 * torch.erf((-0.5 - mu) / (0.5 * self._sqrt_2))),
341 | torch.tensor(1 - self.cfg.target_sparsity, device=mu.device, dtype=mu.data.dtype))
342 |
343 | def get_gates(self, x):
344 | with torch.no_grad():
345 | gates = self.hard_sigmoid(self.net(x))
346 | return gates
347 |
348 | def num_open_gates(self, x):
349 | return torch.sum(self.get_gates(x) > 0).item() / x.size(0)
350 |
351 |
352 | def train_test(cfg):
353 | torch.use_deterministic_algorithms(True)
354 | torch.backends.cudnn.deterministic = True
355 | torch.backends.cudnn.benchmark = False
356 | gated_str = "_gated" if cfg.gated else ""
357 | with open(f"results_{os.path.basename(__file__)}{gated_str}_reg_beta_{cfg.reg_beta}.txt", mode='w') as f:
358 |
359 | header = '\t'.join(['seed', 'acc', 'local_gates'])
360 | f.write(f"{header}\n")
361 | f.flush()
362 |
363 | for seed in range(cfg.repitions):
364 | cfg.seed = seed
365 | seed_everything(seed)
366 | np.random.seed(seed)
367 | if not os.path.exists(cfg.dataset):
368 | os.makedirs(cfg.dataset)
369 | model = ClassificationModule(cfg)
370 | logger = TensorBoardLogger(cfg.dataset, name=os.path.basename(__file__), log_graph=False)
371 | trainer = Trainer(
372 | devices=cfg.devices,
373 | accelerator=cfg.accelerator,
374 | max_epochs=cfg.max_epochs,
375 | deterministic=cfg.deterministic,
376 | logger=cfg.logger,
377 | log_every_n_steps=cfg.log_every_n_steps,
378 | check_val_every_n_epoch=cfg.check_val_every_n_epoch,
379 | enable_checkpointing=cfg.enable_checkpointing,
380 | callbacks=[LearningRateMonitor(logging_interval='step')]
381 | )
382 | trainer.logger = logger
383 | trainer.fit(model)
384 | if cfg.gated:
385 | results_str = '\t'.join([f'{seed}', f'{model.best_acc}', f'{model.best_local_feats}'])
386 | else:
387 | results_str = '\t'.join([f'{seed}', f'{model.best_acc}'])
388 | f.write(f"{results_str}\n")
389 | f.flush()
390 |
391 |
392 | if __name__ == "__main__":
393 | cfg = parse_args(None)
394 | train_test(cfg)
395 |
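
`parse_args` in this file declares several trainer flags with `type=bool`. With `argparse`, that applies `bool()` to the raw command-line string, so any non-empty value, including `"False"`, parses as `True`. The sketch below shows the pitfall and one common workaround, assuming Python 3.9+ for `argparse.BooleanOptionalAction`. The `build_parser` helper and the `--naive_flag` option are hypothetical names used for illustration only.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # Pitfall: bool("False") is True, so any non-empty value enables this flag.
    parser.add_argument("--naive_flag", type=bool, default=True)
    # BooleanOptionalAction (Python 3.9+) generates --deterministic / --no-deterministic.
    parser.add_argument("--deterministic", action=argparse.BooleanOptionalAction, default=True)
    # Fractional hyperparameters need type=float so CLI strings such as "0.01" parse.
    parser.add_argument("--lr", type=float, default=1e-3)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--naive_flag", "False", "--no-deterministic", "--lr", "0.01"]
    )
    print(args.naive_flag, args.deterministic, args.lr)
```

With this pattern a run can be made non-deterministic by passing `--no-deterministic`, whereas the `type=bool` form silently keeps the flag enabled.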
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | def init_weights_normal(m):
6 | if isinstance(m, torch.nn.Linear):
7 | torch.nn.init.normal_(m.weight, std=0.001)
8 |         if m.bias is not None:  # 'bias' is stored in m._parameters, so it never appears in vars(m)
9 |             m.bias.data.fill_(0.0)
10 |
11 |
12 | def clustering_head(cfg):
13 | return torch.nn.Sequential(
14 | torch.nn.Linear(cfg.clustering_head[0], cfg.clustering_head[1]),
15 | torch.nn.BatchNorm1d(cfg.clustering_head[1]),
16 | torch.nn.ReLU(),
17 | torch.nn.Linear(cfg.clustering_head[1], cfg.n_clusters)).apply(init_weights_normal)
18 |
19 |
20 | def aux_classifier_head(cfg):
21 | return torch.nn.Sequential(
22 | torch.nn.Linear(cfg.input_dim, cfg.aux_classifier[0]),
23 | torch.nn.BatchNorm1d(cfg.aux_classifier[0]),
24 | torch.nn.ReLU(),
25 | torch.nn.Linear(cfg.aux_classifier[0], cfg.n_clusters)).apply(init_weights_normal)
26 |
27 |
28 | class EncoderDecoder(torch.nn.Module):
29 | def __init__(self, cfg):
30 | super(EncoderDecoder, self).__init__()
31 | self.cfg = cfg
32 | self.encoder = []
33 | self.encoder = self.build_encoder()
34 | self.decoder = self.build_decoder()
35 | self.encoder.apply(init_weights_normal)
36 | self.decoder.apply(init_weights_normal)
37 |
38 | def build_encoder(self):
39 | layers = [
40 | torch.nn.Linear(self.cfg.input_dim, self.cfg.encdec[0]),
41 | torch.nn.BatchNorm1d(self.cfg.encdec[0]),
42 | torch.nn.ReLU()
43 | ]
44 | hidden_layers = len(self.cfg.encdec) // 2 + 1
45 | for layer_idx in range(1, hidden_layers):
46 | if layer_idx == hidden_layers - 1:
47 | layers += [torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx])]
48 | else:
49 | layers += [
50 | torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx]),
51 | torch.nn.BatchNorm1d(self.cfg.encdec[layer_idx]),
52 | torch.nn.ReLU()
53 | ]
54 | return torch.nn.Sequential(*layers)
55 |
56 | def build_decoder(self):
57 | hidden_layers = len(self.cfg.encdec) // 2 + 1
58 | layers = []
59 | for layer_idx in range(hidden_layers, len(self.cfg.encdec)):
60 | layers += [
61 | torch.nn.Linear(self.cfg.encdec[layer_idx - 1], self.cfg.encdec[layer_idx]),
62 | torch.nn.BatchNorm1d(self.cfg.encdec[layer_idx]),
63 | torch.nn.ReLU()
64 | ]
65 | layers += [torch.nn.Linear(self.cfg.encdec[-1], self.cfg.input_dim)]
66 | return torch.nn.Sequential(*layers)
67 |
68 | def forward(self, x):
69 | return self.decoder(self.encoder(x))
70 |
71 | class GatingNet(torch.nn.Module):
72 | def __init__(self, cfg):
73 | super(GatingNet, self).__init__()
74 | self.cfg = cfg
75 | self._sqrt_2 = math.sqrt(2)
76 | self.sigma = 0.5
77 | self.local_gates = torch.nn.Sequential(
78 | torch.nn.Linear(cfg.input_dim, cfg.gates_hidden_dim),
79 | torch.nn.Tanh(),
80 | torch.nn.Linear(cfg.gates_hidden_dim, cfg.input_dim),
81 | torch.nn.Tanh()
82 | )
83 | self.local_gates.apply(self.init_weights)
84 | self.global_gates_net = torch.nn.Embedding(self.cfg.n_clusters, self.cfg.input_dim)
85 | torch.nn.init.normal_(self.global_gates_net.weight, std=0.01)
86 |
87 | @staticmethod
88 | def init_weights(m):
89 | if isinstance(m, torch.nn.Linear):
90 | torch.nn.init.normal_(m.weight, std=0.001)
91 |             if m.bias is not None:  # 'bias' is stored in m._parameters, so it never appears in vars(m)
92 |                 m.bias.data.fill_(0.0)
93 |
94 | def global_forward(self, y):
95 | noise = torch.normal(mean=0, std=self.sigma, size=(y.size(0), self.cfg.input_dim),
96 | device=self.global_gates_net.weight.device)
97 | z = torch.tanh(self.global_gates_net(y)) + .5 * noise * self.training
98 | gates = self.hard_sigmoid(z)
99 | return torch.tanh(self.global_gates_net(y)), gates
100 |
101 | def open_global_gates(self):
102 | return self.hard_sigmoid(torch.tanh(self.global_gates_net.weight)).sum(dim=1).mean().cpu().item()
103 |
104 | def forward(self, x):
105 | noise = torch.normal(mean=0, std=self.sigma, size=x.size(), device=x.device)
106 | mu = self.local_gates(x)
107 | z = mu + .5 * noise * self.training
108 | gates = self.hard_sigmoid(z)
109 | sparse_x = x * gates
110 | return mu, sparse_x, gates
111 |
112 | @staticmethod
113 | def hard_sigmoid(x):
114 | return torch.clamp(x + .5, 0.0, 1.0)
115 |
116 | def regularization(self, mu, reduction_func=torch.mean):
117 | return reduction_func(0.5 - 0.5 * torch.erf((-1 / 2 - mu) / self._sqrt_2))
118 |
119 | def get_gates(self, x):
120 | with torch.no_grad():
121 | gates = self.hard_sigmoid(self.local_gates(x))
122 | return gates
123 |
124 |     def num_open_gates(self, x):
125 | return self.get_gates(x).sum(dim=1).cpu().median(dim=0)[0].item()
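
`GatingNet` in this file turns a gate pre-activation into a value in [0, 1] with `hard_sigmoid`, and `regularization` penalizes the probability that a noisy gate ends up non-zero. The scalar sketch below mirrors those two expressions using only the standard library. The `sigma` argument and the function name `open_gate_probability` are assumptions added for illustration; `sigma=1.0` reproduces the per-element expression inside `regularization` in this file.

```python
import math


def hard_sigmoid(z):
    """Clamp z + 0.5 into [0, 1], mirroring GatingNet.hard_sigmoid for a scalar."""
    return min(max(z + 0.5, 0.0), 1.0)


def open_gate_probability(mu, sigma=1.0):
    """P(mu + eps > -0.5) for eps ~ N(0, sigma^2): the chance that the gate is non-zero."""
    return 0.5 - 0.5 * math.erf((-0.5 - mu) / (sigma * math.sqrt(2)))


if __name__ == "__main__":
    for mu in (-1.0, -0.5, 0.0, 1.0):
        # A larger mu gives a wider-open gate and a higher penalty term.
        print(mu, hard_sigmoid(mu), round(open_gate_probability(mu), 4))
```

Averaging `open_gate_probability` over all features and samples gives a differentiable stand-in for the number of open gates, which is what the regularizer pushes down during training.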
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.1
2 | pytorch-lightning==2.0.0
3 | scikit-learn==1.1.2
4 | scipy==1.10.0-rc1
5 | omegaconf==2.2.3
6 | matplotlib==3.6.3
7 | matplotlib-inline==0.1.6
8 | umap-learn==0.5.6
9 | torchvision==0.15.2
10 | munkres==1.1.4
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | import torch
3 | import math
4 | from omegaconf import OmegaConf
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 | from pytorch_lightning import LightningModule
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from pytorch_lightning import Trainer, seed_everything
11 | import os
12 | from pytorch_lightning.loggers import TensorBoardLogger
13 | from pytorch_lightning.callbacks import LearningRateMonitor
14 | from sklearn.metrics import silhouette_score, davies_bouldin_score
15 | import argparse
16 | from dataset import NumpyTableDataset
17 | from model import clustering_head, aux_classifier_head, EncoderDecoder, GatingNet
18 | import umap
19 |
20 |
21 | class TotalCodingRateWithProjection(torch.nn.Module):
22 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """
23 |
24 | def __init__(self, cfg):
25 | super().__init__()
26 | self.eps = cfg.gtcr_eps
27 | if cfg.gtcr_projection_dim is not None:
28 | self.random_matrix = torch.tensor(np.random.normal(
29 | loc=0.0,
30 | scale=1.0 / np.sqrt(cfg.gtcr_projection_dim),
31 | size=(cfg.input_dim, cfg.gtcr_projection_dim)
32 | )).float()
33 | else:
34 | self.random_matrix = None
35 |
36 | def compute_discrimn_loss(self, W):
37 | p, m = W.shape # [d, B]
38 | I = torch.eye(p, device=W.device)
39 | scalar = p / (m * self.eps)
40 | logdet = torch.logdet(I + scalar * W.matmul(W.T))
41 | return logdet / 2.
42 |
43 | def forward(self, x):
44 | if self.random_matrix is not None:
45 | x = x @ self.random_matrix.to(x.device)
46 | return - self.compute_discrimn_loss(x.T)
47 |
48 |
49 | class MaximalCodingRateReduction(torch.nn.Module):
50 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """
51 |
52 | def __init__(self, eps=0.01, gamma=1, compress_only=False):
53 | super(MaximalCodingRateReduction, self).__init__()
54 | self.eps = eps
55 | self.gamma = gamma
56 | self.compress_only = compress_only
57 |
58 | def compute_discrimn_loss(self, W):
59 | p, m = W.shape
60 | I = torch.eye(p, device=W.device)
61 | scalar = p / (m * self.eps)
62 | logdet = torch.logdet(I + scalar * W.matmul(W.T))
63 | return logdet / 2.
64 |
65 | def compute_compress_loss(self, W, Pi):
66 | p, m = W.shape
67 | k, _, _ = Pi.shape
68 | I = torch.eye(p, device=W.device).expand((k, p, p))
69 | trPi = Pi.sum(2) + 1e-8
70 | scale = (p / (trPi * self.eps)).view(k, 1, 1)
71 | W = W.view((1, p, m))
72 | log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
73 | compress_loss = (trPi.squeeze() * log_det / (2 * m)).sum()
74 | return compress_loss
75 |
76 | def forward(self, X, Y, num_classes=None):
77 |         # This function supports Y as an integer label vector or a membership probability matrix.
78 | if len(Y.shape) == 1:
79 | # if Y is a label vector
80 | if num_classes is None:
81 | num_classes = Y.max() + 1
82 | Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device)
83 | for indx, label in enumerate(Y):
84 | Pi[label, 0, indx] = 1
85 | else:
86 |             # if Y is a probability matrix
87 | if num_classes is None:
88 | num_classes = Y.shape[1]
89 | Pi = Y.T.reshape((num_classes, 1, -1))
90 |
91 | W = X.T
92 | compress_loss = self.compute_compress_loss(W, Pi)
93 | if not self.compress_only:
94 | discrimn_loss = self.compute_discrimn_loss(W)
95 | return discrimn_loss, compress_loss
96 | else:
97 | return None, compress_loss
98 |
99 |
100 | class BaseModule(LightningModule):
101 | def __init__(self, cfg):
102 | super().__init__()
103 | self.cfg = cfg
104 |
105 | self.train_dataset = NumpyTableDataset.setup(
106 | filepath_samples=cfg.get("filepath_samples"),
107 | num_clusters=cfg.get("num_clusters", None)
108 | )
109 | self.val_dataset = self.train_dataset
110 |
111 | print(f"Dataset length: {self.train_dataset.__len__()}")
112 | self.cfg.input_dim = self.train_dataset.num_features()
113 | self.cfg.n_clusters = self.train_dataset.num_clusters
114 | self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size)
115 |
116 | self.save_hyperparameters()
117 | self.best_evaluation_stats = {}
118 | self.ae_train = False
119 | self.automatic_optimization = False
120 | self.best_accuracy = - np.inf  # np.infty was removed in NumPy 2.0
121 | self.gating_net = GatingNet(self.cfg)
122 | self.encdec = EncoderDecoder(self.cfg)
123 | self.clustering_head = clustering_head(self.cfg)
124 | self.aux_classifier_head = aux_classifier_head(self.cfg)
125 | self.mcrr = MaximalCodingRateReduction(eps=self.cfg.eps, compress_only=True)
126 | self.gtcr_loss = TotalCodingRateWithProjection(self.cfg)
127 |
128 | self.open_gates = []
129 | self.val_embs_list = []
130 |
131 | self.max_silhouette_score = []
132 | self.min_dbi_score = []
133 |
134 | def train_dataloader(self):
135 | return DataLoader(self.train_dataset,
136 | batch_size=self.batch_size,
137 | drop_last=True,
138 | shuffle=True,
139 | num_workers=0)
140 |
141 | def val_dataloader(self):
142 | return DataLoader(self.val_dataset,
143 | batch_size=self.batch_size,
144 | drop_last=False,
145 | shuffle=False,
146 | num_workers=0)
147 |
148 | def global_gates_step(self, x):
149 | gates = self.gating_net.get_gates(x)
150 | ae_emb = self.encdec.encoder(x * gates)
151 | cluster_logits = self.clustering_head(ae_emb)
152 | y_hat = cluster_logits.argmax(dim=-1)
153 | glob_gates_mu, glob_gates = self.gating_net.global_forward(y_hat)
154 | reg_loss = self.gating_net.regularization(glob_gates_mu)
155 | aux_y_hat = self.aux_classifier_head(x * gates * glob_gates)
156 | aux_loss = F.cross_entropy(aux_y_hat, y_hat)
157 | self.log('glob_gates_reg_loss', reg_loss.item())
158 | self.log('glob_gates_ce_loss', aux_loss.item())
159 | return aux_loss + self.cfg.global_gates_lambda * reg_loss
160 |
161 | def ae_step(self, x):
162 | if self.current_epoch > self.cfg.ae_non_gated_epochs:
163 | mu, _, gates = self.gating_net(x)
164 | reg_loss = self.gating_net.regularization(mu)
165 | gtcr_loss = self.gtcr_loss(gates) / x.size(0)
166 | self.log("pretrain/gates_reg_loss", reg_loss.item())
167 | self.log("pretrain/gates_tcr_loss", gtcr_loss.item())
168 | loss = self.cosine_increase_lambda(
169 | min_val=0.,
170 | max_val=self.cfg.local_gates_lambda
171 | ) * reg_loss + gtcr_loss * self.cfg.gtcr_lambda
172 | else:
173 | gates = torch.ones_like(x, device=x.device).float()
174 | loss = 0
175 |
176 | # task 1: reconstruct x from x
177 | x_recon = self.encdec(x)
178 | x_recon_loss = F.mse_loss(x_recon, x)
179 | self.log("pretrain/x_recon_loss", x_recon_loss.item())
180 |
181 | # task 2: reconstruct x from gated x:
182 | x_recon_from_gated = self.encdec(x * gates)
183 | x_from_gated_x_recon_loss = F.mse_loss(x_recon_from_gated, x)
184 | self.log("pretrain/x_from_gated_x_recon_loss", x_from_gated_x_recon_loss.item())
185 |
186 | # task 3: reconstruct x from randomly masked x
187 | mask_rnd = torch.rand(x.size()).to(x.device)
188 | mask = torch.ones(x.size()).to(x.device).float()
189 | mask[mask_rnd < self.cfg.mask_percentage] = 0
190 | x_recon_masked = self.encdec(x * mask)
191 | input_noised_recon_loss = F.mse_loss(x_recon_masked, x)
192 | self.log("pretrain/input_noised_recon_loss", input_noised_recon_loss.item())
193 |
194 | # task 4: reconstruct x from noisy embedding
195 | e = self.encdec.encoder(x)
196 | e = e * torch.normal(mean=1., std=self.cfg.latent_noise_std, size=e.size(), device=e.device)
197 | recon_noised = self.encdec.decoder(e)
198 | noised_aug_loss = F.mse_loss(recon_noised, x)
199 | self.log("pretrain/latent_noised_recon_loss", noised_aug_loss.item())
200 |
201 | # combined loss:
202 | loss = loss + x_recon_loss + x_from_gated_x_recon_loss + input_noised_recon_loss + noised_aug_loss
203 | return loss
204 |
205 | def training_step(self, x, batch_idx):
206 | ae_opt, clust_opt, glob_gates_opt = self.optimizers()
207 | pretrain_sched, sch = self.lr_schedulers()
208 | x = x.reshape(x.size(0), -1)
209 |
210 | # reconstruction step + local gates training
211 | if self.current_epoch <= self.cfg.ae_pretrain_epochs:
212 | ae_opt.zero_grad()
213 | loss = self.ae_step(x)
214 | self.manual_backward(loss)
215 | ae_opt.step()
216 | pretrain_sched.step()
217 | return
218 |
219 | # clusters compression step
220 | clust_opt.zero_grad()
221 | gates = self.gating_net.get_gates(x)
222 | ae_emb = self.encdec.encoder(x * gates)
223 | cluster_logits = self.clustering_head(ae_emb)
224 | loss = self.mcrr_loss(ae_emb, cluster_logits)
225 | self.manual_backward(loss)
226 | clust_opt.step()
227 |
228 | # global gates training
229 | if self.current_epoch >= self.cfg.start_global_gates_training_on_epoch:
230 | glob_gates_opt.zero_grad()
231 | loss = self.global_gates_step(x)
232 | self.manual_backward(loss)
233 | glob_gates_opt.step()
234 | sch.step()
235 |
236 | def configure_optimizers(self):
237 | pretrain_optimizer = torch.optim.Adam(
238 | params=chain(
239 | self.encdec.parameters(),
240 | self.gating_net.local_gates.parameters(),
241 | ),
242 | lr=self.cfg.lr.pretrain)
243 |
244 | cluster_optimizer = torch.optim.Adam(
245 | params=chain(
246 | self.clustering_head.parameters(),
247 | ),
248 | lr=self.cfg.lr.clustering)
249 |
250 | glob_gates_opt = torch.optim.SGD(
251 | params=chain(
252 | self.aux_classifier_head.parameters(),
253 | self.gating_net.global_gates_net.parameters(),
254 | ),
255 | lr=self.cfg.lr.aux_classifier)
256 |
257 | steps = self.train_dataset.__len__() // self.batch_size * (
258 | self.cfg.trainer.max_epochs - self.cfg.ae_pretrain_epochs)
259 | pretrain_steps = self.train_dataset.__len__() // self.batch_size * self.cfg.ae_pretrain_epochs
260 | # pretrain_steps = self.dataset.__len__() // self.batch_size * self.cfg.trainer.max_epochs
261 | print(f"Cosine annealing LR scheduling is applied during {steps} steps")
262 | sched = torch.optim.lr_scheduler.CosineAnnealingLR(
263 | optimizer=cluster_optimizer,
264 | T_max=steps,
265 | eta_min=self.cfg.sched.clustering_min_lr)
266 | pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
267 | optimizer=pretrain_optimizer,
268 | T_max=pretrain_steps,
269 | eta_min=self.cfg.sched.pretrain_min_lr)
270 | return [pretrain_optimizer, cluster_optimizer, glob_gates_opt], [pretrain_sched, sched]
271 |
272 | def cosine_increase_lambda(self, min_val, max_val):
273 | epoch = self.current_epoch - self.cfg.ae_pretrain_epochs
274 | total_epochs = self.cfg.ae_pretrain_epochs - self.cfg.ae_non_gated_epochs
275 | return min_val + 0.5 * (max_val - min_val) * (1. + np.cos(epoch * math.pi / total_epochs))
276 |
277 | def validation_step(self, x, batch_idx):
278 | if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
279 | gates = self.gating_net.get_gates(x)
280 | ae_emb = self.encdec.encoder(x * gates)
281 | cluster_logits = self.clustering_head(ae_emb)
282 | y_hat = cluster_logits.argmax(dim=-1)
283 | self.val_cluster_list.append(y_hat.cpu())
284 | self.open_gates.append(self.gating_net.num_open_gates(x))
285 | self.val_embs_list.append(ae_emb)
286 |
287 | def on_validation_epoch_start(self):
288 | self.val_cluster_list = []
289 | self.open_gates = []
290 | self.val_embs_list = []
291 |
292 | @staticmethod
293 | def plot_clustering(val_embs_list, cluster_mtx, current_epoch, silhouette, dbi):
294 | reducer = umap.UMAP(n_neighbors=10, min_dist=0.1, n_components=2, random_state=0)
295 | embedding = reducer.fit_transform(torch.cat(val_embs_list, dim=0).cpu().numpy())
296 | plt.figure(figsize=(10, 7))
297 | plt.scatter(embedding[:, 0], embedding[:, 1], c=cluster_mtx.numpy(), s=50, edgecolor='k')
298 | plt.title(f'Clustering (UMAP). Epoch: {current_epoch}. Silhouette: {silhouette:0.3f}. DBI: {dbi:0.3f}')
299 | plt.savefig(f"umap_epoch_{current_epoch}.png")
300 |
301 | def on_validation_epoch_end(self):
302 | if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
303 | if self.current_epoch < self.cfg.ae_pretrain_epochs - 1:
304 | return
305 | else:
306 | cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
307 | self.log("num_open_gates", np.mean(self.open_gates).item())
308 | self.log("num_open_global_gates", self.gating_net.open_global_gates())
309 | if self.cfg.save_seed_checkpoints:
310 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
311 | torch.save(meta_dict, f'sparse_model_last_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')
312 | try:
313 | silhouette_score_embs = silhouette_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
314 | cluster_mtx.numpy())
315 | self.log(f'silhouette_score_embs', silhouette_score_embs)
316 | self.max_silhouette_score.append(silhouette_score_embs)
317 | except Exception:  # e.g. the score is undefined when only one cluster is predicted
318 | silhouette_score_embs = -1
319 | try:
320 | dbi_score = davies_bouldin_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
321 | cluster_mtx.numpy())
322 | self.log(f'dbi_score_embs', dbi_score)
323 | self.min_dbi_score.append(dbi_score)
324 | except Exception:  # e.g. the score is undefined when only one cluster is predicted
325 | dbi_score = 0
326 |
327 | self.plot_clustering(self.val_embs_list, cluster_mtx, self.current_epoch, silhouette_score_embs, dbi_score)
328 |
329 | def mcrr_loss(self, c, logits):
330 | logprobs = torch.log_softmax(logits, dim=-1)
331 | prob = GumbleSoftmax(self.tau())(logprobs)
332 | _, compress_loss = self.mcrr(F.normalize(c), prob, num_classes=self.cfg.n_clusters)
333 | compress_loss /= c.size(1)
334 | self.log(f'compress_loss', compress_loss.item())
335 | return compress_loss
336 |
337 | def tau(self):
338 | return self.cfg.tau
339 |
340 |
341 | class GumbleSoftmax(torch.nn.Module):
342 | def __init__(self, tau, straight_through=False):
343 | super().__init__()
344 | self.tau = tau
345 | self.straight_through = straight_through
346 |
347 | def forward(self, logps):
348 | gumble = torch.rand_like(logps).log().mul(-1).log().mul(-1)
349 | logits = logps + gumble
350 | out = (logits / self.tau).softmax(dim=1)
351 | if not self.straight_through:
352 | return out
353 | else:
354 | out_binary = (logits * 1e8).softmax(dim=1).detach()
355 | out_diff = (out_binary - out).detach()
356 | return out_diff + out
357 |
358 |
359 | if __name__ == "__main__":
360 | parser = argparse.ArgumentParser()
361 | parser.add_argument('--cfg', type=str)
362 | args = parser.parse_args()
363 | cfg = OmegaConf.load(args.cfg)
364 | torch.use_deterministic_algorithms(True)
365 | torch.backends.cudnn.deterministic = True
366 | torch.backends.cudnn.benchmark = False
367 | for seed in range(cfg.seeds):
368 | cfg.seed = seed
369 | seed_everything(seed)
370 | np.random.seed(seed)
371 | model = BaseModule(cfg)
372 | logger = TensorBoardLogger("logs", name=os.path.basename(__file__), log_graph=False)
373 | trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
374 | trainer.logger = logger
375 | trainer.fit(model)
376 |
--------------------------------------------------------------------------------
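The `cosine_increase_lambda` method in `run.py` above shifts the epoch by `ae_pretrain_epochs`, so the value passed to the cosine is negative (or zero) throughout pretraining. The standalone sketch below reproduces that arithmetic with plain `math` to show that the weight ramps from `min_val` to `max_val` over the gated part of pretraining. The epoch counts (100 pretraining epochs, 20 non-gated epochs) are hypothetical values chosen for illustration and are not taken from the repository's config files.

```python
import math


def cosine_increase_lambda(current_epoch, ae_pretrain_epochs, ae_non_gated_epochs,
                           min_val=0.0, max_val=1.0):
    """Standalone copy of the schedule arithmetic, with the config values as arguments."""
    # During pretraining current_epoch <= ae_pretrain_epochs, so epoch <= 0.
    epoch = current_epoch - ae_pretrain_epochs
    total_epochs = ae_pretrain_epochs - ae_non_gated_epochs
    # cos(-pi) = -1 at the first gated epoch gives min_val; cos(0) = 1 at the end gives max_val.
    return min_val + 0.5 * (max_val - min_val) * (1.0 + math.cos(epoch * math.pi / total_epochs))


# Hypothetical settings, for illustration only.
PRETRAIN_EPOCHS, NON_GATED_EPOCHS = 100, 20
for current in (20, 40, 60, 80, 100):
    weight = cosine_increase_lambda(current, PRETRAIN_EPOCHS, NON_GATED_EPOCHS)
    print(f"epoch {current:3d}: lambda weight = {weight:.4f}")
```

Under these assumptions the regularization weight is 0 at epoch 20, reaches half of `max_val` at epoch 60, and equals `max_val` at epoch 100, so the gate sparsity penalty is introduced gradually rather than at full strength.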
/train_evaluate.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | import torch
3 | import math
4 | from omegaconf import OmegaConf
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 | from pytorch_lightning import LightningModule
8 | import numpy as np
9 | from pytorch_lightning import Trainer, seed_everything
10 | import os
11 | from pytorch_lightning.loggers import TensorBoardLogger
12 | from pytorch_lightning.callbacks import LearningRateMonitor
13 | from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, silhouette_score, davies_bouldin_score
14 | import argparse
15 | import dataset
16 | from model import clustering_head, aux_classifier_head, EncoderDecoder, GatingNet
17 |
18 |
19 | class TotalCodingRateWithProjection(torch.nn.Module):
20 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """
21 | def __init__(self, cfg):
22 | super().__init__()
23 | self.eps = cfg.gtcr_eps
24 | if cfg.gtcr_projection_dim is not None:
25 | self.random_matrix = torch.tensor(np.random.normal(
26 | loc=0.0,
27 | scale=1.0 / np.sqrt(cfg.gtcr_projection_dim),
28 | size=(cfg.input_dim, cfg.gtcr_projection_dim)
29 | )).float()
30 | else:
31 | self.random_matrix = None
32 |
33 | def compute_discrimn_loss(self, W):
34 | p, m = W.shape # [d, B]
35 | I = torch.eye(p, device=W.device)
36 | scalar = p / (m * self.eps)
37 | logdet = torch.logdet(I + scalar * W.matmul(W.T))
38 | return logdet / 2.
39 |
40 | def forward(self, x):
41 | if self.random_matrix is not None:
42 | x = x @ self.random_matrix.to(x.device)
43 | return - self.compute_discrimn_loss(x.T)
44 |
45 |
46 | class MaximalCodingRateReduction(torch.nn.Module):
47 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py """
48 |
49 | def __init__(self, eps=0.01, gamma=1, compress_only=False):
50 | super(MaximalCodingRateReduction, self).__init__()
51 | self.eps = eps
52 | self.gamma = gamma
53 | self.compress_only = compress_only
54 |
55 | def compute_discrimn_loss(self, W):
56 | p, m = W.shape
57 | I = torch.eye(p, device=W.device)
58 | scalar = p / (m * self.eps)
59 | logdet = torch.logdet(I + scalar * W.matmul(W.T))
60 | return logdet / 2.
61 |
62 | def compute_compress_loss(self, W, Pi):
63 | p, m = W.shape
64 | k, _, _ = Pi.shape
65 | I = torch.eye(p, device=W.device).expand((k, p, p))
66 | trPi = Pi.sum(2) + 1e-8
67 | scale = (p / (trPi * self.eps)).view(k, 1, 1)
68 | W = W.view((1, p, m))
69 | log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
70 | compress_loss = (trPi.squeeze() * log_det / (2 * m)).sum()
71 | return compress_loss
72 |
73 | def forward(self, X, Y, num_classes=None):
74 | # This function supports Y as an integer label vector or a membership probability matrix.
75 | if len(Y.shape) == 1:
76 | # if Y is a label vector
77 | if num_classes is None:
78 | num_classes = Y.max() + 1
79 | Pi = torch.zeros((num_classes, 1, Y.shape[0]), device=Y.device)
80 | for indx, label in enumerate(Y):
81 | Pi[label, 0, indx] = 1
82 | else:
83 | # if Y is a probability matrix
84 | if num_classes is None:
85 | num_classes = Y.shape[1]
86 | Pi = Y.T.reshape((num_classes, 1, -1))
87 |
88 | W = X.T
89 | compress_loss = self.compute_compress_loss(W, Pi)
90 | if not self.compress_only:
91 | discrimn_loss = self.compute_discrimn_loss(W)
92 | return discrimn_loss, compress_loss
93 | else:
94 | return None, compress_loss
95 |
96 |
97 | class BaseModule(LightningModule):
98 | def __init__(self, cfg):
99 | super().__init__()
100 | self.cfg = cfg
101 | self.train_dataset = getattr(dataset, cfg.dataset).setup(cfg)
102 | self.val_dataset = self.train_dataset
103 |
104 | print(f"Dataset length: {self.train_dataset.__len__()}")
105 | self.cfg.input_dim = self.train_dataset.num_features()
106 | self.cfg.n_clusters = self.train_dataset.num_clusters
107 | self.batch_size = min(self.train_dataset.__len__(), cfg.batch_size)
108 |
109 | self.save_hyperparameters()
110 | self.best_evaluation_stats = {}
111 | self.ae_train = False
112 | self.automatic_optimization = False
113 | self.best_accuracy = - np.inf  # np.infty was removed in NumPy 2.0
114 | self.gating_net = GatingNet(self.cfg)
115 | self.encdec = EncoderDecoder(self.cfg)
116 | self.clustering_head = clustering_head(self.cfg)
117 | self.aux_classifier_head = aux_classifier_head(self.cfg)
118 | self.mcrr = MaximalCodingRateReduction(eps=self.cfg.eps, compress_only=True)
119 | self.gtcr_loss = TotalCodingRateWithProjection(self.cfg)
120 |
121 | self.val_cluster_list = []
122 | self.val_cluster_list_gated = []
123 | self.val_label_list = []
124 | self.open_gates = []
125 | self.val_embs_list = []
126 |
127 | self.best_acc = - 100
128 | self.best_ari = - 100
129 | self.best_nmi = - 100
130 | self.best_local_feats = None
131 | self.best_global_feats = None
132 | self.max_silhouette_score = []
133 | self.min_dbi_score = []
134 |
135 | def train_dataloader(self):
136 | return DataLoader(self.train_dataset,
137 | batch_size=self.batch_size,
138 | drop_last=True,
139 | shuffle=True,
140 | num_workers=0)
141 |
142 | def val_dataloader(self):
143 | return DataLoader(self.val_dataset,
144 | batch_size=self.batch_size,
145 | drop_last=False,
146 | shuffle=False,
147 | num_workers=0)
148 |
149 | def update_stats(self, acc, ari, nmi, local_feats, global_feats):
150 | if self.best_acc <= acc:
151 | self.best_acc = acc
152 | self.best_ari = ari
153 | self.best_nmi = nmi
154 | self.best_local_feats = local_feats
155 | self.best_global_feats = global_feats
156 |
157 | def global_gates_step(self, x):
158 | gates = self.gating_net.get_gates(x)
159 | ae_emb = self.encdec.encoder(x * gates)
160 | cluster_logits = self.clustering_head(ae_emb)
161 | y_hat = cluster_logits.argmax(dim=-1)
162 | glob_gates_mu, glob_gates = self.gating_net.global_forward(y_hat)
163 | reg_loss = self.gating_net.regularization(glob_gates_mu)
164 | aux_y_hat = self.aux_classifier_head(x * gates * glob_gates)
165 | aux_loss = F.cross_entropy(aux_y_hat, y_hat)
166 | self.log('train/glob_gates_reg_loss', reg_loss.item())
167 | self.log('train/glob_gates_ce_loss', aux_loss.item())
168 | return aux_loss + self.cfg.global_gates_lambda * reg_loss
169 |
170 | def ae_step(self, x):
171 | if self.current_epoch > self.cfg.ae_non_gated_epochs:
172 | mu, _, gates = self.gating_net(x)
173 | reg_loss = self.gating_net.regularization(mu)
174 | gtcr_loss = self.gtcr_loss(gates) / x.size(0)
175 | self.log("pretrain/gates_reg_loss", reg_loss.item())
176 | self.log("pretrain/gates_tcr_loss", gtcr_loss.item())
177 | loss = self.cosine_increase_lambda(
178 | min_val=0.,
179 | max_val=self.cfg.local_gates_lambda
180 | ) * reg_loss + gtcr_loss * self.cfg.gtcr_lambda
181 | else:
182 | gates = torch.ones_like(x, device=x.device).float()
183 | loss = 0
184 |
185 | # task 1: reconstruct x from x
186 | x_recon = self.encdec(x)
187 | x_recon_loss = F.mse_loss(x_recon, x)
188 | self.log("pretrain/x_recon_loss", x_recon_loss.item())
189 |
190 | # task 2: reconstruct x from gated x:
191 | x_recon_from_gated = self.encdec(x * gates)
192 | x_from_gated_x_recon_loss = F.mse_loss(x_recon_from_gated, x)
193 | self.log("pretrain/x_from_gated_x_recon_loss", x_from_gated_x_recon_loss.item())
194 |
195 | # task 3: reconstruct x from randomly masked x
196 | mask_rnd = torch.rand(x.size()).to(x.device)
197 | mask = torch.ones(x.size()).to(x.device).float()
198 | mask[mask_rnd < self.cfg.mask_percentage] = 0
199 | x_recon_masked = self.encdec(x * mask)
200 | input_noised_recon_loss = F.mse_loss(x_recon_masked, x)
201 | self.log("pretrain/input_noised_recon_loss", input_noised_recon_loss.item())
202 |
203 | # task 4: reconstruct x from noisy embedding
204 | e = self.encdec.encoder(x)
205 | e = e * torch.normal(mean=1., std=self.cfg.latent_noise_std, size=e.size(), device=e.device)
206 | recon_noised = self.encdec.decoder(e)
207 | noised_aug_loss = F.mse_loss(recon_noised, x)
208 | self.log("pretrain/latent_noised_recon_loss", noised_aug_loss.item())
209 |
210 | # combined loss:
211 | loss = loss + x_recon_loss + x_from_gated_x_recon_loss + input_noised_recon_loss + noised_aug_loss
212 | return loss
213 |
214 | def training_step(self, batch, batch_idx):
215 | ae_opt, clust_opt, glob_gates_opt = self.optimizers()
216 | pretrain_sched, sch = self.lr_schedulers()
217 | x, _ = batch
218 | x = x.reshape(x.size(0), -1)
219 |
220 | # reconstruction step + local gates training
221 | if self.current_epoch <= self.cfg.ae_pretrain_epochs:
222 | ae_opt.zero_grad()
223 | loss = self.ae_step(x)
224 | self.manual_backward(loss)
225 | ae_opt.step()
226 | pretrain_sched.step()
227 | return
228 |
229 | # clusters compression step
230 | clust_opt.zero_grad()
231 | gates = self.gating_net.get_gates(x)
232 | ae_emb = self.encdec.encoder(x * gates)
233 | cluster_logits = self.clustering_head(ae_emb)
234 | loss = self.mcrr_loss(ae_emb, cluster_logits)
235 | self.manual_backward(loss)
236 | clust_opt.step()
237 |
238 | # global gates training
239 | if self.current_epoch >= self.cfg.start_global_gates_training_on_epoch:
240 | glob_gates_opt.zero_grad()
241 | loss = self.global_gates_step(x)
242 | self.manual_backward(loss)
243 | glob_gates_opt.step()
244 | sch.step()
245 |
246 | def configure_optimizers(self):
247 | pretrain_optimizer = torch.optim.Adam(
248 | params=chain(
249 | self.encdec.parameters(),
250 | self.gating_net.local_gates.parameters(),
251 | ),
252 | lr=self.cfg.lr.pretrain)
253 |
254 | cluster_optimizer = torch.optim.Adam(
255 | params=chain(
256 | self.clustering_head.parameters(),
257 | ),
258 | lr=self.cfg.lr.clustering)
259 |
260 | glob_gates_opt = torch.optim.SGD(
261 | params=chain(
262 | self.aux_classifier_head.parameters(),
263 | self.gating_net.global_gates_net.parameters(),
264 | ),
265 | lr=self.cfg.lr.aux_classifier)
266 |
267 | steps = self.train_dataset.__len__() // self.batch_size * (
268 | self.cfg.trainer.max_epochs - self.cfg.ae_pretrain_epochs)
269 | pretrain_steps = self.train_dataset.__len__() // self.batch_size * self.cfg.ae_pretrain_epochs
270 | # pretrain_steps = self.dataset.__len__() // self.batch_size * self.cfg.trainer.max_epochs
271 | print(f"Cosine annealing LR scheduling is applied during {steps} steps")
272 | sched = torch.optim.lr_scheduler.CosineAnnealingLR(
273 | optimizer=cluster_optimizer,
274 | T_max=steps,
275 | eta_min=self.cfg.sched.clustering_min_lr)
276 | pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
277 | optimizer=pretrain_optimizer,
278 | T_max=pretrain_steps,
279 | eta_min=self.cfg.sched.pretrain_min_lr)
280 | return [pretrain_optimizer, cluster_optimizer, glob_gates_opt], [pretrain_sched, sched]
281 |
282 | def cosine_increase_lambda(self, min_val, max_val):
283 | epoch = self.current_epoch - self.cfg.ae_pretrain_epochs
284 | total_epochs = self.cfg.ae_pretrain_epochs - self.cfg.ae_non_gated_epochs
285 | return min_val + 0.5 * (max_val - min_val) * (1. + np.cos(epoch * math.pi / total_epochs))
286 |
287 | def validation_step(self, batch, batch_idx):
288 | x, y = batch
289 | gates = self.gating_net.get_gates(x)
290 | ae_emb = self.encdec.encoder(x * gates)
291 | cluster_logits = self.clustering_head(ae_emb)
292 | y_hat = cluster_logits.argmax(dim=-1)
293 | self.val_cluster_list.append(y_hat.cpu())
294 | self.val_label_list.append(y.cpu())
295 | self.open_gates.append(self.gating_net.num_open_gates(x))
296 | self.val_embs_list.append(ae_emb)
297 |
298 | def on_validation_epoch_start(self):
299 | self.val_cluster_list = []
300 | self.val_cluster_list_gated = []
301 | self.val_label_list = []
302 | self.open_gates = []
303 | self.val_embs_list = []
304 |
305 | @staticmethod
306 | def cluster_match(cluster_mtx, label_mtx, n_classes=10, print_result=True):
307 | cluster_indx = list(cluster_mtx.unique())
308 | assigned_label_list = []
309 | assigned_count = []
310 | while (len(assigned_label_list) <= n_classes) and len(cluster_indx) > 0:
311 | max_label_list = []
312 | max_count_list = []
313 | for indx in cluster_indx:
314 | mask = cluster_mtx == indx
315 | label_elements, counts = label_mtx[mask].unique(return_counts=True)
316 | for assigned_label in assigned_label_list:
317 | counts[label_elements == assigned_label] = 0
318 | max_count_list.append(counts.max())
319 | max_label_list.append(label_elements[counts.argmax()])
320 |
321 | max_label = torch.stack(max_label_list)
322 | max_count = torch.stack(max_count_list)
323 | assigned_label_list.append(max_label[max_count.argmax()])
324 | assigned_count.append(max_count.max())
325 | cluster_indx.pop(max_count.argmax().item())
326 | total_correct = torch.tensor(assigned_count).sum().item()
327 | total_sample = cluster_mtx.shape[0]
328 | acc = total_correct / total_sample
329 | if print_result:
330 | print('{}/{} ({}%) correct'.format(total_correct, total_sample, acc * 100))
331 | else:
332 | return total_correct, total_sample, acc
333 |
334 | def on_validation_epoch_end(self):
335 | """ Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/func.py"""
336 | if not (self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs) and self.current_epoch > 0:
337 | if self.current_epoch < self.cfg.ae_pretrain_epochs - 1:
338 | return
339 | else:
340 | cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
341 | label_mtx = torch.cat(self.val_label_list, dim=0)
342 | _, _, acc_single = self.cluster_match(
343 | cluster_mtx,
344 | label_mtx,
345 | n_classes=label_mtx.max() + 1,
346 | print_result=False)
347 | if self.best_accuracy < acc_single:
348 | print("New best accuracy:", acc_single)
349 | self.best_accuracy = acc_single
350 | if self.cfg.save_seed_checkpoints:
351 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
352 | torch.save(meta_dict, f'sparse_model_best_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')
353 |
354 | nmi = normalized_mutual_info_score(label_mtx.numpy(), cluster_mtx.numpy())
355 | ari = adjusted_rand_score(label_mtx.numpy(), cluster_mtx.numpy())
356 | format_str = '' # '_kmeans' if self.current_epoch == 9 else ''
357 | self.log(f'val/acc_single{format_str}', acc_single) # this is ACC
358 | self.log(f'val/NMI{format_str}', nmi)
359 | self.log(f'val/ARI{format_str}', ari)
360 | self.log("val/num_open_gates", np.mean(self.open_gates).item())
361 | self.log("val/num_open_global_gates", self.gating_net.open_global_gates())
362 | if self.cfg.save_seed_checkpoints:
363 | meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_head.state_dict()}
364 | torch.save(meta_dict, f'sparse_model_last_{self.cfg.dataset}_seed_{self.cfg.seed}.pth')
365 |
366 | self.update_stats(acc_single, ari, nmi, np.mean(self.open_gates).item(),
367 | self.gating_net.open_global_gates())
368 |
369 | try:
370 | silhouette_score_embs = silhouette_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
371 | cluster_mtx.numpy())
372 | self.log(f'val/silhouette_score_embs', silhouette_score_embs)
373 | self.max_silhouette_score.append(silhouette_score_embs)
374 | except Exception:  # e.g. the score is undefined when only one cluster is predicted
375 | pass
376 | try:
377 | dbi_score = davies_bouldin_score(torch.cat(self.val_embs_list, dim=0).cpu().numpy(),
378 | cluster_mtx.numpy())
379 | self.log(f'val/dbi_score_embs', dbi_score)
380 | self.min_dbi_score.append(dbi_score)
381 | except Exception:  # e.g. the score is undefined when only one cluster is predicted
382 | pass
383 |
384 | def mcrr_loss(self, c, logits):
385 | logprobs = torch.log_softmax(logits, dim=-1)
386 | prob = GumbleSoftmax(self.tau())(logprobs)
387 | _, compress_loss = self.mcrr(F.normalize(c), prob, num_classes=self.cfg.n_clusters)
388 | compress_loss /= c.size(1)
389 | self.log(f'train/compress_loss', compress_loss.item())
390 | return compress_loss
391 |
392 | def tau(self):
393 | return self.cfg.tau
394 |
395 |
396 | class GumbleSoftmax(torch.nn.Module):
397 | def __init__(self, tau, straight_through=False):
398 | super().__init__()
399 | self.tau = tau
400 | self.straight_through = straight_through
401 |
402 | def forward(self, logps):
403 | gumble = torch.rand_like(logps).log().mul(-1).log().mul(-1)
404 | logits = logps + gumble
405 | out = (logits / self.tau).softmax(dim=1)
406 | if not self.straight_through:
407 | return out
408 | else:
409 | out_binary = (logits * 1e8).softmax(dim=1).detach()
410 | out_diff = (out_binary - out).detach()
411 | return out_diff + out
412 |
413 |
414 | if __name__ == "__main__":
415 | parser = argparse.ArgumentParser()
416 | parser.add_argument('--cfg', type=str)
417 | args = parser.parse_args()
418 | cfg = OmegaConf.load(args.cfg)
419 | torch.use_deterministic_algorithms(True)
420 | torch.backends.cudnn.deterministic = True
421 | torch.backends.cudnn.benchmark = False
422 | if not cfg.validate:
423 | cfg.trainer.check_val_every_n_epoch = cfg.trainer.max_epochs + 1 # validation will never run
424 |
425 | with open(f"results_{os.path.basename(__file__)}.txt", mode='a') as f:
426 | header = '\t'.join(['seed', 'acc', 'ari', 'nmi', 'local_gates', 'global_gates',
427 | 'topk_max_silhouette_score', 'topk_min_dbi_score'])
428 | f.write(f"{header}\n")
429 | f.flush()
430 | for seed in range(cfg.seeds):
431 | cfg.seed = seed
432 | seed_everything(seed)
433 | np.random.seed(seed)
434 | if not os.path.exists(cfg.dataset):
435 | os.makedirs(cfg.dataset)
436 | model = BaseModule(cfg)
437 | logger = TensorBoardLogger(cfg.dataset, name=os.path.basename(__file__), log_graph=False)
438 | trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
439 | trainer.logger = logger
440 | trainer.fit(model)
441 | topk_max_siluetter_score = np.mean(sorted(model.max_silhouette_score, reverse=True)[:10])
442 | topk_min_dbi_score = np.mean(sorted(model.min_dbi_score)[:10])
443 | results_str = '\t'.join(
444 | [f'{seed}',
445 | f'{model.best_acc}',
446 | f'{model.best_ari}',
447 | f'{model.best_nmi}',
448 | f'{model.best_local_feats}',
449 | f'{model.best_global_feats}',
450 | f'{topk_max_siluetter_score}',
451 | f'{topk_min_dbi_score}',
452 | ])
453 | with open(f"results_{os.path.basename(__file__)}.txt", mode='a') as f:
454 | f.write(f"{results_str}\n")
455 | f.flush()
456 |
--------------------------------------------------------------------------------
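The `cluster_match` static method in `train_evaluate.py` above computes clustering accuracy by greedily pairing predicted clusters with ground-truth labels: in each round it picks the cluster whose most frequent not-yet-assigned label has the highest count, then retires that cluster and that label. The pure-Python sketch below illustrates the same greedy idea on plain lists. The function name `greedy_cluster_accuracy` is hypothetical, and tie-breaking may differ from the tensor-based version in the repository.

```python
from collections import Counter


def greedy_cluster_accuracy(clusters, labels):
    """Greedy one-to-one matching of predicted clusters to labels; returns accuracy."""
    remaining = sorted(set(clusters))  # clusters not yet matched to a label
    used_labels = set()                # labels already claimed by some cluster
    correct = 0
    while remaining:
        best = None  # (count, cluster, label) of the strongest candidate this round
        for cluster in remaining:
            counts = Counter(
                label for c, label in zip(clusters, labels)
                if c == cluster and label not in used_labels
            )
            if not counts:
                continue  # this cluster only contains labels that are already taken
            label, count = counts.most_common(1)[0]
            if best is None or count > best[0]:
                best = (count, cluster, label)
        if best is None:
            break  # no unmatched label is left for any remaining cluster
        count, cluster, label = best
        correct += count
        used_labels.add(label)
        remaining.remove(cluster)
    return correct / len(clusters)


# Toy example: three predicted clusters, three ground-truth labels.
predicted = [0, 0, 0, 1, 1, 2]
truth = [1, 1, 0, 0, 0, 2]
print(greedy_cluster_accuracy(predicted, truth))
```

Because the matching is greedy, it yields a valid one-to-one assignment but not necessarily the best one; an optimal assignment solver can score higher on some inputs.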