├── .gitignore ├── LICENSE.md ├── README.md ├── api ├── diffusion.py ├── dse.py ├── dsmi.py └── information_utils.py └── assets ├── compare-cifar10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png ├── compare-stl10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png ├── curse_of_dim.png ├── def_DSE.png ├── def_DSMI.png ├── logos ├── MetaAI_logo.png ├── Mila_logo.png └── Yale_logo.png ├── main_figure_DSE(Z).png ├── main_figure_DSMI(Z;X).png ├── main_figure_DSMI(Z;Y).png ├── method_comparison.png ├── procedure.png ├── visualize_embeddings.png └── vs_imagenet_acc.png /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | data 3 | external_models 4 | pretrained_models 5 | **/__pycache__ 6 | tools/_jupyter/.ipynb_checkpoints 7 | training 8 | validation 9 | testing 10 | 11 | # Logs 12 | logs 13 | **/*.log 14 | 15 | # Models 16 | **/*.pt 17 | **/*.pkl 18 | **/*.pth 19 | **/*.pth.tar 20 | 21 | # Slurm 22 | **/slurm* 23 | **/bash_*.sh -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Non-Commercial License Yale Copyright © 2024 Yale University. 2 | 3 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. 
In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 4 | 5 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | [ICMLW 2023, IEEE CISS 2024] DSE/DSMI 4 |

5 | 6 |

7 | Diffusion Spectral Entropy and Mutual Information 8 |

9 | 10 |
11 | 12 | [![ArXiv](https://img.shields.io/badge/ArXiv-DSE/DSMI-firebrick)](https://arxiv.org/abs/2312.04823) 13 | [![Slides](https://img.shields.io/badge/Slides-yellow)](https://chenliu-1996.github.io/slides/DSE_slides.pdf) 14 | [![Twitter](https://img.shields.io/twitter/follow/KrishnaswamyLab.svg?style=social)](https://twitter.com/KrishnaswamyLab) 15 | [![Twitter](https://img.shields.io/twitter/follow/DanqiLiao.svg?style=social)](https://x.com/DanqiLiao73090) 16 | [![Twitter](https://img.shields.io/twitter/follow/ChenLiu-1996.svg?style=social)](https://twitter.com/ChenLiu_1996) 17 | [![LinkedIn](https://img.shields.io/badge/LinkedIn-ChenLiu-1996?color=blue)](https://www.linkedin.com/in/chenliu1996/) 18 | [![Github Stars](https://img.shields.io/github/stars/ChenLiu-1996/DiffusionSpectralEntropy.svg?style=social&label=Stars)](https://github.com/ChenLiu-1996/DiffusionSpectralEntropy/) 19 | 20 |
21 | 22 | 23 | **Krishnaswamy Lab, Yale University** 24 | 25 | This is the **official** implementation of 26 | 27 | [**Assessing Neural Network Representations During Training Using Noise-Resilient Diffusion Spectral Entropy**](https://arxiv.org/abs/2312.04823) 28 | 29 | 30 | 31 | 32 | ## Announcement 33 | **Due to certain internal policies, we removed the codebase from public access. However, for the benefit of future researchers, we hereby provide the DSE/DSMI functions.** 34 | 35 | ## Citation 36 | ``` 37 | @inproceedings{DSE2023, 38 | title={Assessing Neural Network Representations During Training Using Noise-Resilient Diffusion Spectral Entropy}, 39 | author={Liao, Danqi and Liu, Chen and Christensen, Ben and Tong, Alexander and Huguet, Guillaume and Wolf, Guy and Nickel, Maximilian and Adelstein, Ian and Krishnaswamy, Smita}, 40 | booktitle={ICML 2023 Workshop on Topology, Algebra and Geometry in Machine Learning (TAG-ML)}, 41 | year={2023}, 42 | } 43 | @inproceedings{DSE2024, 44 | title={Assessing neural network representations during training using noise-resilient diffusion spectral entropy}, 45 | author={Liao, Danqi and Liu, Chen and Christensen, Benjamin W and Tong, Alexander and Huguet, Guillaume and Wolf, Guy and Nickel, Maximilian and Adelstein, Ian and Krishnaswamy, Smita}, 46 | booktitle={2024 58th Annual Conference on Information Sciences and Systems (CISS)}, 47 | pages={1--6}, 48 | year={2024}, 49 | organization={IEEE} 50 | } 51 | ``` 52 | 53 | 54 | ## API: Your One-Stop Shop 55 | Here we present the refactored and reorganized APIs for this project. 
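If the repository is not at hand, the following NumPy sketch approximates what the default path of `diffusion_spectral_entropy` computes, based on the steps visible in `api/diffusion.py` and `api/dse.py`: a Gaussian kernel $G$, the symmetric normalization $K = D^{-1/2} G D^{-1/2}$, absolute eigenvalues raised to the power `t`, normalization into a distribution, and entropy in bits. It is a sketch under stated assumptions, not the repository's implementation. The function names `gaussian_affinity`, `symmetric_diffusion_matrix` and `dse_sketch` are hypothetical. Distances are computed with NumPy rather than `sklearn.metrics.pairwise_distances`. The exact eigensolver is assumed to be `np.linalg.eigvalsh`, because the repository's `exact_eigvals` is not shown here. Zero probabilities are dropped rather than padded with `np.finfo(float).eps`.

```python
import numpy as np


def gaussian_affinity(X: np.ndarray, sigma: float = 10.0) -> np.ndarray:
    """Gaussian kernel on pairwise Euclidean distances, in the same form as api/diffusion.py."""
    sq_norms = np.sum(X**2, axis=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clip tiny negatives caused by round-off.
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T, 0.0)
    # The 1 / (sigma * sqrt(2 pi)) prefactor is kept only for parity; it cancels after normalization.
    return np.exp(-sq_dists / (2.0 * sigma**2)) / (sigma * np.sqrt(2.0 * np.pi))


def symmetric_diffusion_matrix(G: np.ndarray) -> np.ndarray:
    """K = D^{-1/2} G D^{-1/2}; it is similar to, and shares eigenvalues with, P = D^{-1} G."""
    d_inv_sqrt = 1.0 / np.sqrt(G.sum(axis=1))
    return d_inv_sqrt[:, None] * G * d_inv_sqrt[None, :]


def dse_sketch(X: np.ndarray, sigma: float = 10.0, t: int = 1) -> float:
    """DSE in bits: -sum_i p_i log2 p_i, where p_i = |eig_i|^t / sum_j |eig_j|^t."""
    K = symmetric_diffusion_matrix(gaussian_affinity(X, sigma))
    # K is symmetric, so the symmetric eigensolver applies.
    eigvals = np.abs(np.linalg.eigvalsh(K)) ** t
    prob = eigvals / eigvals.sum()
    prob = prob[prob > 0]  # convention: 0 * log 0 = 0
    return float(-np.sum(prob * np.log2(prob)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.uniform(0, 1, size=(500, 64))  # toy "embeddings": 500 samples, 64 features
    print("DSE, t=1:", dse_sketch(X, sigma=10.0, t=1))
    print("DSE, t=2:", dse_sketch(X, sigma=10.0, t=2))
```

Since $p_i \propto |\lambda_i|^t$, a larger `t` concentrates the eigenvalue mass on the dominant eigenvalues. As a result, the sketch's DSE is non-increasing in `t`, which matches the role of `t` as a noise-suppressing diffusion time.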
56 | 57 | ### Diffusion Spectral Entropy 58 | [Go to function](./api/dse.py#L10) 59 | ``` 60 | api > dse.py > diffusion_spectral_entropy 61 | ``` 62 | 63 | ### Diffusion Spectral Mutual Information 64 | [Go to function](./api/dsmi.py#L7) 65 | ``` 66 | api > dsmi.py > diffusion_spectral_mutual_information 67 | ``` 68 | 69 | ### Unit Tests for DSE and DSMI 70 | Run the following commands from inside the `api` directory to execute the built-in unit tests. They must be run from there because the modules import each other by bare module name (e.g., `from diffusion import compute_diffusion_matrix`). 71 | ``` 72 | python dse.py 73 | python dsmi.py 74 | ``` 75 | 76 | ## Overview 77 | > We propose a framework to measure **entropy** and **mutual information** in high-dimensional data, which makes it applicable to modern neural networks. 78 | 79 | With respect to a given set of data samples, we can measure (1) the entropy of the neural representation at a specific layer and (2) the mutual information between a random variable (e.g., model input or output) and the neural representation at a specific layer. 80 | 81 | Compared to the classic Shannon formulation using the binning method, e.g., as in the famous paper **_Deep Learning and the Information Bottleneck Principle_** [[PDF]](https://arxiv.org/abs/1503.02406) [[Github1]](https://github.com/stevenliuyi/information-bottleneck) [[Github2]](https://github.com/artemyk/ibsgd), our proposed method is more robust and expressive. 82 | 83 | ## Main Advantage 84 | No binning and hence **no curse of dimensionality**. Therefore, **it works on modern deep neural networks** (e.g., ResNet-50), not just on toy models with double-digit layer widths. See "Limitations of the Classic Shannon Entropy and Mutual Information" in our paper for details. 85 | 86 | 87 | 88 | ## A One-Minute Explanation of the Methods 89 | Conceptually, we build a data graph from the neural network representations of all data points in a dataset, and compute the diffusion matrix of the data graph. This matrix is a condensed representation of the diffusion geometry of the neural representation manifold.
Our proposed **Diffusion Spectral Entropy (DSE)** and **Diffusion Spectral Mutual Information (DSMI)** can be computed from this diffusion matrix. 90 | 91 | 92 | 93 | ## Quick Flavors of the Results 94 | 95 | ### Definition 96 | 97 | 98 | ### Theoretical Results 99 | An important point is that the proposed DSE and DSMI are not conceptually the same as their classic Shannon counterparts. They are defined differently, and while they retain the gist of "entropy" and "mutual information" measures, they have their own unique properties. For example, DSE is *more sensitive to the underlying dimension and structures (e.g., number of branches or clusters) than to the spread or noise in the data itself, which is contracted to the manifold by raising the diffusion operator to the power of $t$*. 100 | 101 | In the theoretical results, we derive upper and lower bounds for the proposed DSE and DSMI. More interestingly, we show that if a data distribution originates as a single Gaussian blob but later evolves into $k$ distinct Gaussian blobs, the upper bound of the expected DSE increases. This has implications for the training process of classification networks, which learn to separate data representations into clusters. 102 | 103 | ### Empirical Results 104 | We first use toy experiments to show that DSE and DSMI "behave properly" as measures of entropy and mutual information. We also demonstrate that they are more robust to high dimensions than their classic counterparts. 105 | 106 | We then look at how well DSE and DSMI behave at higher dimensions. In the figure below, we show how DSMI outperforms other mutual information estimators when the dimension is high. In addition, the runtime comparison shows that DSMI scales better with respect to dimension. 107 | 108 | 109 | 110 |
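The decomposition documented in the `diffusion_spectral_mutual_information` docstring, $\mathrm{DSMI}(A; B) = \sum_i p(B = b_i)\,[\mathrm{DSE}(A^*) - \mathrm{DSE}(A \mid B = b_i)]$ with $|A^*| = |\{B = b_i\}|$, can be sketched for the common case of discrete class labels. This is a hypothetical, self-contained illustration, not the repository's code. The names `dse_bits` and `dsmi_discrete_sketch` are made up for this example. Averaging $\mathrm{DSE}(A^*)$ over `num_repetitions` random subsets drawn with NumPy is an assumption about how the repetitions are combined. The binning and spectral-clustering path for continuous references is not reproduced.

```python
import numpy as np


def dse_bits(X: np.ndarray, sigma: float = 10.0, t: int = 1) -> float:
    """DSE in bits from the symmetrically normalized Gaussian kernel of X (shape [n, d])."""
    sq_norms = np.sum(X**2, axis=1)
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T, 0.0)
    # The kernel prefactor is omitted: it cancels in D^{-1/2} G D^{-1/2}.
    G = np.exp(-sq_dists / (2.0 * sigma**2))
    d_inv_sqrt = 1.0 / np.sqrt(G.sum(axis=1))
    K = d_inv_sqrt[:, None] * G * d_inv_sqrt[None, :]
    eigvals = np.abs(np.linalg.eigvalsh(K)) ** t
    prob = eigvals / eigvals.sum()
    prob = prob[prob > 0]
    return float(-np.sum(prob * np.log2(prob)))


def dsmi_discrete_sketch(Z: np.ndarray, labels: np.ndarray, sigma: float = 10.0,
                         t: int = 1, num_repetitions: int = 5, seed: int = 0) -> float:
    """sum_i p(Y = y_i) * [DSE(Z*) - DSE(Z | Y = y_i)], with |Z*| == |{Y = y_i}|."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).reshape(-1)
    n = len(labels)
    dsmi = 0.0
    for y in np.unique(labels):
        idx = np.flatnonzero(labels == y)
        h_conditional = dse_bits(Z[idx], sigma, t)
        # Size-matched random subsets of Z, so both entropies use the same sample count.
        h_subsampled = np.mean([
            dse_bits(Z[rng.choice(n, size=len(idx), replace=False)], sigma, t)
            for _ in range(num_repetitions)
        ])
        dsmi += (len(idx) / n) * (h_subsampled - h_conditional)
    return float(dsmi)
```

A quick sanity check: with a single class, every $A^*$ is just a permutation of $A$. Eigenvalues are permutation-invariant, so the sketch returns 0 up to round-off.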
111 | 112 | Finally, it's time to put them into practice! We use DSE and DSMI to visualize the training dynamics of classification networks with 6 backbones (3 ConvNets and 3 Transformers) under 3 training conditions and 3 random seeds. We evaluate the penultimate (second-to-last) layer of each network. This layer is widely believed to embed a rich representation of the data, and it is often used for visualization, linear-probing evaluation, etc. 113 | 114 | 115 | 116 | DSE(Z) increases during training. This happens during both generalizable training and overfitting. The former case coincides with our theoretical finding that DSE(Z) should increase as the model learns to separate data representations into clusters. 117 | 118 | 119 | 120 | DSMI(Z; Y) increases during generalizable training but stagnates during overfitting. This is very much expected. 121 | 122 | 123 | 124 | DSMI(Z; X) shows quite intriguing trends. On MNIST, it keeps increasing. On CIFAR-10 and STL-10, it peaks quickly and then gradually decreases. Recall that IB [Tishby et al.] suggests that I(Z; X) should decrease, while [Saxe et al. ICLR'18] argues the opposite. We find that both of them could be correct, since the trend we observe is dataset-dependent. One possibility is that MNIST features are too easy to learn (and perhaps the models all overfit?); we leave this to future exploration. 125 | 126 | 127 | ## Utility Studies: How can we use DSE and DSMI? 128 | One may ask: besides just peeking into the training dynamics of neural networks, how can we _REALLY_ use DSE and DSMI? This is where the utility studies come in. 129 | 130 | ### Guiding network initialization with DSE 131 | We sought to assess the effects of network initialization in terms of DSE.
We were motivated by two observations: (1) the initial DSEs of different models are not always the same, even when the same random initialization method is used; (2) if DSE starts low, it grows monotonically; if DSE starts high, it first decreases and then increases. 132 | 133 | We found that if we initialize the convolutional layers with weights $\sim \mathcal{N}(0, \sigma^2)$, DSE $S_D(Z)$ is affected by $\sigma$. We then trained ResNet models initialized at high ($\approx \log(n)$) versus low ($\approx 0$) DSE by setting $\sigma=0.1$ and $\sigma=0.01$, respectively. The training history suggests that initializing the network at a lower $S_D(Z)$ can improve the convergence speed and final performance. We believe this is because the high initial DSE from random initialization corresponds to an undesirable high-entropy state. The network must first escape this state (causing DSE to decrease) before it migrates to the desirable high-entropy state (causing DSE to increase). 134 | 135 | 136 | 137 | ### ImageNet cross-model correlation 138 | So far, we have monitored DSE and DSMI **along the training process of the same model**. Now we show how DSMI correlates with downstream classification accuracy **across many different pre-trained models**. The following result demonstrates the potential of using DSMI to pre-screen promising models for your specialized dataset. 139 | 140 | 141 | 142 | 143 | ## Reproducing Results in the Ongoing Submission 144 | Removed due to internal policies. 145 | 146 | ## Preparation 147 | 148 | ### Environment 149 | We developed the codebase in a miniconda environment. 150 | Tested on Python 3.9.13 + PyTorch 1.12.1.
151 | How we created the conda environment: 152 | **Some packages may no longer be required.** 153 | ``` 154 | conda create --name dse pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 155 | conda activate dse 156 | conda install -c anaconda scikit-image pillow matplotlib seaborn tqdm 157 | python -m pip install -U giotto-tda 158 | python -m pip install POT torch-optimizer 159 | python -m pip install tinyimagenet 160 | python -m pip install natsort 161 | python -m pip install phate 162 | python -m pip install DiffusionEMD 163 | python -m pip install magic-impute 164 | python -m pip install timm 165 | python -m pip install pytorch-lightning 166 | ``` 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /api/diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import pairwise_distances 3 | import warnings 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | def compute_diffusion_matrix(X: np.array, sigma: float = 10.0): 9 | ''' 10 | Adapted from 11 | https://github.com/professorwug/diffusion_curvature/blob/master/diffusion_curvature/core.py 12 | 13 | Given input X, returns a symmetric matrix K (as a numpy ndarray) that has the same eigenvalues as the diffusion matrix P. 14 | Uses the "anisotropic" kernel. 15 | Inputs: 16 | X: a numpy array of size n x d 17 | sigma: a float 18 | conceptually, the neighborhood size of Gaussian kernel. 19 | Returns: 20 | K: a numpy array of size n x n that has the same eigenvalues as the diffusion matrix. 21 | ''' 22 | 23 | # Construct the distance matrix. 24 | D = pairwise_distances(X) 25 | 26 | # Gaussian kernel 27 | G = (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp((-D**2) / (2 * sigma**2)) 28 | 29 | # Anisotropic density normalization.
30 | Deg = np.diag(1 / np.sum(G, axis=1)**0.5) 31 | K = Deg @ G @ Deg 32 | 33 | # Now K has the exact same eigenvalues as the diffusion matrix `P` 34 | # which is defined as `P = D^{-1} G`, with `D = np.diag(np.sum(G, axis=1))` (K = D^{-1/2} G D^{-1/2} is similar to P). 35 | 36 | return K 37 | -------------------------------------------------------------------------------- /api/dse.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from information_utils import approx_eigvals, exact_eigvals 3 | from diffusion import compute_diffusion_matrix 4 | import os 5 | import random 6 | 7 | from sklearn.metrics import pairwise_distances 8 | 9 | 10 | def diffusion_spectral_entropy(embedding_vectors: np.array, 11 | gaussian_kernel_sigma: float = 10, 12 | t: int = 1, 13 | max_N: int = 10000, 14 | chebyshev_approx: bool = False, 15 | eigval_save_path: str = None, 16 | eigval_save_precision: np.dtype = np.float16, 17 | classic_shannon_entropy: bool = False, 18 | matrix_entry_entropy: bool = False, 19 | num_bins_per_dim: int = 2, 20 | random_seed: int = 0, 21 | verbose: bool = False): 22 | ''' 23 | >>> If `classic_shannon_entropy` is False (default) 24 | 25 | Diffusion Spectral Entropy over a set of N vectors, each of D dimensions. 26 | 27 | DSE = - sum_i [p_i log p_i], with p_i = |eig_i|^t / sum_j |eig_j|^t 28 | where each `eig_i` is an eigenvalue of `P`, 29 | where `P` is the diffusion matrix computed on the data graph of the [N, D] vectors. 30 | 31 | >>> If `classic_shannon_entropy` is True 32 | 33 | Classic Shannon Entropy over a set of N vectors, each of D dimensions. 34 | 35 | CSE = - sum_i [p(x) log p(x)] 36 | where each p(x) is the probability density of a histogram bin, after some sort of binning.
37 | 38 | args: 39 | embedding_vectors: np.array of shape [N, D] 40 | N: number of data points / samples 41 | D: number of feature dimensions of the neural representation 42 | 43 | gaussian_kernel_sigma: float 44 | The bandwidth of Gaussian kernel (for computation of the diffusion matrix) 45 | Can be adjusted per the dataset. 46 | Increase if the data points are very far away from each other. 47 | 48 | t: int 49 | Power of diffusion matrix (equivalent to power of diffusion eigenvalues) 50 | <-> Iteration of diffusion process 51 | Usually small, e.g., 1 or 2. 52 | Can be adjusted per dataset. 53 | Rule of thumb: after powering eigenvalues to `t`, there should be approximately 54 | 1 percent of eigenvalues that remain larger than 0.01 55 | 56 | max_N: int 57 | Max number of data points / samples used for computation. 58 | 59 | chebyshev_approx: bool 60 | Whether or not to use Chebyshev moments for faster approximation of eigenvalues. 61 | Currently we DO NOT RECOMMEND USING THIS. Eigenvalues may be changed quite a bit. 62 | 63 | eigval_save_path: str 64 | If provided, 65 | (1) If running for the first time, will save the computed eigenvalues in this location. 66 | (2) Otherwise, if the file already exists, skip eigenvalue computation and load from this file. 67 | 68 | eigval_save_precision: np.dtype 69 | We use `np.float16` by default to reduce storage space required. 70 | For best precision, use `np.float64` instead. 71 | 72 | classic_shannon_entropy: bool 73 | Toggle between DSE and CSE. False (default) == DSE. 74 | 75 | matrix_entry_entropy: bool 76 | An alternative formulation where, instead of computing the entropy on 77 | diffusion matrix eigenvalues, we compute the entropy on diffusion matrix entries. 78 | Only relevant to DSE. 79 | 80 | num_bins_per_dim: int 81 | Number of bins per feature dim. 82 | Only relevant to CSE (i.e., `classic_shannon_entropy` is True). 83 | 84 | verbose: bool 85 | Whether or not to print progress to console. 
86 | ''' 87 | 88 | # Subsample embedding vectors if the number of data samples is too large. 89 | if max_N is not None and embedding_vectors is not None and len( 90 | embedding_vectors) > max_N: 91 | if random_seed is not None: 92 | random.seed(random_seed) 93 | rand_inds = np.array( 94 | random.sample(range(len(embedding_vectors)), k=max_N)) 95 | embedding_vectors = embedding_vectors[rand_inds, :] 96 | 97 | if not classic_shannon_entropy: 98 | # Computing Diffusion Spectral Entropy. 99 | if verbose: print('Computing Diffusion Spectral Entropy...') 100 | 101 | if matrix_entry_entropy: 102 | if verbose: print('Computing diffusion matrix.') 103 | # Compute diffusion matrix `P`. 104 | K = compute_diffusion_matrix(embedding_vectors, 105 | sigma=gaussian_kernel_sigma) 106 | # Row-normalize to get a proper row-stochastic matrix P 107 | D_inv = np.diag(1.0 / np.sum(K, axis=1)) 108 | P = D_inv @ K 109 | 110 | if verbose: print('Diffusion matrix computed.') 111 | 112 | entries = P.reshape(-1) 113 | entries = np.abs(entries) 114 | prob = entries / entries.sum() 115 | 116 | else: 117 | if eigval_save_path is not None and os.path.exists( 118 | eigval_save_path): 119 | if verbose: 120 | print('Loading pre-computed eigenvalues from %s' % 121 | eigval_save_path) 122 | eigvals = np.load(eigval_save_path)['eigvals'] 123 | eigvals = eigvals.astype( 124 | np.float64) # mitigate rounding error. 125 | if verbose: print('Pre-computed eigenvalues loaded.') 126 | 127 | else: 128 | if verbose: print('Computing diffusion matrix.') 129 | # Note that `K` is a symmetric matrix with the same eigenvalues as the diffusion matrix `P`.
130 | K = compute_diffusion_matrix(embedding_vectors, 131 | sigma=gaussian_kernel_sigma) 132 | if verbose: print('Diffusion matrix computed.') 133 | 134 | if verbose: print('Computing eigenvalues.') 135 | if chebyshev_approx: 136 | if verbose: print('Using Chebyshev approximation.') 137 | eigvals = approx_eigvals(K) 138 | else: 139 | eigvals = exact_eigvals(K) 140 | if verbose: print('Eigenvalues computed.') 141 | 142 | if eigval_save_path is not None: 143 | os.makedirs(os.path.dirname(eigval_save_path) or '.', 144 | exist_ok=True) 145 | # Save eigenvalues. 146 | eigvals = eigvals.astype( 147 | eigval_save_precision) # reduce storage space. 148 | with open(eigval_save_path, 'wb+') as f: 149 | np.savez(f, eigvals=eigvals) 150 | if verbose: 151 | print('Eigenvalues saved to %s' % eigval_save_path) 152 | 153 | # Eigenvalues may be negative. We only care about the magnitude, not the sign. 154 | eigvals = np.abs(eigvals) 155 | 156 | # Power eigenvalues to `t` to mitigate the effect of noise. 157 | eigvals = eigvals**t 158 | 159 | prob = eigvals / eigvals.sum() 160 | 161 | else: 162 | # Computing Classic Shannon Entropy. 163 | if verbose: print('Computing Classic Shannon Entropy...') 164 | 165 | vecs = embedding_vectors.copy() 166 | 167 | # Min-Max scale each dimension. 168 | vecs = (vecs - np.min(vecs, axis=0)) / (np.max(vecs, axis=0) - 169 | np.min(vecs, axis=0)) 170 | 171 | # Bin along each dimension. 172 | bins = np.linspace(0, 1, num_bins_per_dim + 1)[:-1] 173 | vecs = np.digitize(vecs, bins=bins) 174 | 175 | # Count probability.
176 | counts = np.unique(vecs, axis=0, return_counts=True)[1] 177 | prob = counts / np.sum(counts) 178 | 179 | prob = prob + np.finfo(float).eps 180 | entropy = -np.sum(prob * np.log2(prob)) 181 | 182 | return entropy 183 | 184 | def adjacency_spectral_entropy(embedding_vectors: np.array, 185 | gaussian_kernel_sigma: float = 10, 186 | anisotropic: bool = False, 187 | use_knn: bool = False, 188 | knn: int = 10, 189 | max_N: int = 10000, 190 | eigval_save_path: str = None, 191 | eigval_save_precision: np.dtype = np.float16, 192 | random_seed: int = 0, 193 | verbose: bool = False): 194 | ''' 195 | Entropy based on the eigenvalues of the adjacency matrix instead of the diffusion matrix. 196 | 197 | embedding_vectors: np.array of shape [N, D] 198 | N: number of data points / samples 199 | D: number of feature dimensions of the neural representation 200 | 201 | gaussian_kernel_sigma: float 202 | The bandwidth of Gaussian kernel (for computation of the affinity matrix) 203 | Can be adjusted per the dataset. 204 | Increase if the data points are very far away from each other. 205 | 206 | anisotropic: bool 207 | Whether to use anisotropic normalization. 208 | Default False. 209 | 210 | use_knn: bool 211 | Whether to use KNN for computing adjacency matrix (binarized) 212 | Default False, in which case a Gaussian kernel is used for the adjacency (non-binarized). 213 | 214 | knn: int 215 | Number of neighbors for the KNN adjacency matrix (each point counts as one of its own neighbors, since its distance to itself is 0). 216 | 217 | max_N: int 218 | Max number of data points / samples used for computation. 219 | 220 | eigval_save_path: str 221 | If provided, 222 | (1) If running for the first time, will save the computed eigenvalues in this location. 223 | (2) Otherwise, if the file already exists, skip eigenvalue computation and load from this file. 224 | 225 | eigval_save_precision: np.dtype 226 | We use `np.float16` by default to reduce storage space required. 227 | For best precision, use `np.float64` instead.
228 | 229 | verbose: bool 230 | Whether or not to print progress to console. 231 | ''' 232 | # Subsample embedding vectors if the number of data samples is too large. 233 | if max_N is not None and embedding_vectors is not None and len( 234 | embedding_vectors) > max_N: 235 | if random_seed is not None: 236 | random.seed(random_seed) 237 | rand_inds = np.array( 238 | random.sample(range(len(embedding_vectors)), k=max_N)) 239 | embedding_vectors = embedding_vectors[rand_inds, :] 240 | 241 | if eigval_save_path is not None and os.path.exists(eigval_save_path): 242 | if verbose: 243 | print('Loading pre-computed eigenvalues from %s' % 244 | eigval_save_path) 245 | eigvals = np.load(eigval_save_path)['eigvals'] 246 | eigvals = eigvals.astype( 247 | np.float64) # mitigate rounding error. 248 | if verbose: print('Pre-computed eigenvalues loaded.') 249 | else: 250 | if verbose: print('Computing adjacency matrix.') 251 | adj_matrix = None 252 | 253 | # Construct the distance matrix. 254 | D = pairwise_distances(embedding_vectors) 255 | if use_knn != True: 256 | ''' Gaussian kernel adj ''' 257 | G = (1 / (gaussian_kernel_sigma * np.sqrt(2 * np.pi))) * np.exp((-D**2) / (2 * gaussian_kernel_sigma**2)) 258 | 259 | if anisotropic == True: 260 | # Anisotropic density normalization. 261 | Deg = np.diag(1 / np.sum(G, axis=1)**0.5) 262 | K = Deg @ G @ Deg 263 | adj_matrix = K 264 | else: 265 | adj_matrix = G 266 | 267 | else: 268 | N = D.shape[0] 269 | adj_matrix = np.zeros(D.shape) 270 | ''' KNN binarized adj ''' 271 | mink_index = np.argpartition(D, knn-1, axis=1) 272 | mink_vals = D[np.arange(N), mink_index[:, knn-1]] # the k-th smallest distance in each row of D 273 | filter_mask = np.tile(mink_vals.reshape(N, 1), (1, N)) # (N, N) filter mask 274 | adj_matrix = (D <= filter_mask) * 1 275 | if verbose: print('Create binary Adj Matrix...
with mean ', np.mean(np.sum(adj_matrix, axis=1))) 276 | if verbose: print('Adjacency matrix computed.') 277 | 278 | if verbose: print('Computing eigenvalues.') 279 | eigvals = exact_eigvals(adj_matrix) 280 | if verbose: print('Eigenvalues computed.') 281 | 282 | if eigval_save_path is not None: 283 | os.makedirs(os.path.dirname(eigval_save_path) or '.', 284 | exist_ok=True) 285 | # Save eigenvalues. 286 | eigvals = eigvals.astype( 287 | eigval_save_precision) # reduce storage space. 288 | with open(eigval_save_path, 'wb+') as f: 289 | np.savez(f, eigvals=eigvals) 290 | if verbose: 291 | print('Eigenvalues saved to %s' % eigval_save_path) 292 | 293 | # Eigenvalues may be negative. We only care about the magnitude, not the sign. 294 | eigvals = np.abs(eigvals) 295 | 296 | prob = eigvals / eigvals.sum() 297 | 298 | prob = prob + np.finfo(float).eps 299 | entropy = -np.sum(prob * np.log2(prob)) 300 | 301 | return entropy 302 | 303 | 304 | 305 | if __name__ == '__main__': 306 | print('Testing Diffusion Spectral Entropy.') 307 | print('\n1st run, random vecs, without saving eigvals.') 308 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 309 | DSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors) 310 | print('DSE =', DSE) 311 | 312 | print( 313 | '\n2nd run, random vecs, saving eigvals (np.float16). May be slightly off due to float16 saving.' 314 | ) 315 | tmp_path = './test_dse_eigval.npz' 316 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 317 | DSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors, 318 | eigval_save_path=tmp_path) 319 | print('DSE =', DSE) 320 | 321 | print( 322 | '\n3rd run, loading eigvals from 2nd run. May be slightly off due to float16 saving.'
323 | ) 324 | embedding_vectors = None # does not matter, will be ignored anyways 325 | DSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors, 326 | eigval_save_path=tmp_path) 327 | print('DSE =', DSE) 328 | os.remove(tmp_path) 329 | 330 | print('\n4th run, random vecs, saving eigvals (np.float64).') 331 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 332 | DSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors, 333 | eigval_save_path=tmp_path, 334 | eigval_save_precision=np.float64) 335 | print('DSE =', DSE) 336 | 337 | print('\n5th run, loading eigvals from 4th run. Shall be identical.') 338 | embedding_vectors = None # does not matter, will be ignored anyways 339 | DSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors, 340 | eigval_save_path=tmp_path) 341 | print('DSE =', DSE) 342 | os.remove(tmp_path) 343 | 344 | print('\n6th run, Classic Shannon Entropy.') 345 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 346 | CSE = diffusion_spectral_entropy(embedding_vectors=embedding_vectors, 347 | classic_shannon_entropy=True) 348 | print('CSE =', CSE) 349 | 350 | print( 351 | '\n7th run, Entropy on diffusion matrix entries rather than eigenvalues.' 352 | ) 353 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 354 | DSE_matrix_entry = diffusion_spectral_entropy( 355 | embedding_vectors=embedding_vectors, matrix_entry_entropy=True) 356 | print('DSE-matrix-entry =', DSE_matrix_entry) 357 | 358 | print( 359 | '\n8th run, Entropy on KNN binarized adjacency matrix.' 360 | ) 361 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 362 | knn_binarized_entropy = adjacency_spectral_entropy( 363 | embedding_vectors=embedding_vectors, use_knn=True, knn=10, verbose=True) 364 | print('KNN binarized adjacency matrix =', knn_binarized_entropy) 365 | 366 | print( 367 | '\n9th run, Entropy on Gaussian adjacency matrix.' 
368 | ) 369 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 370 | gaussian_adj_entropy = adjacency_spectral_entropy( 371 | embedding_vectors=embedding_vectors, anisotropic=False, verbose=True) 372 | print('Gaussian adjacency matrix =', gaussian_adj_entropy) 373 | 374 | print( 375 | '\n10th run, Entropy on Anisotropic Gaussian adjacency matrix.' 376 | ) 377 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 378 | aniso_adj_entropy = adjacency_spectral_entropy( 379 | embedding_vectors=embedding_vectors, anisotropic=True, verbose=True) 380 | print('Anisotropic Gaussian adjacency matrix =', aniso_adj_entropy) 381 | -------------------------------------------------------------------------------- /api/dsmi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dse import diffusion_spectral_entropy, adjacency_spectral_entropy 3 | from sklearn.cluster import SpectralClustering 4 | import random 5 | 6 | 7 | def diffusion_spectral_mutual_information( 8 | embedding_vectors: np.array, 9 | reference_vectors: np.array, 10 | reference_discrete: bool = None, 11 | gaussian_kernel_sigma: float = 10, 12 | t: int = 1, 13 | chebyshev_approx: bool = False, 14 | num_repetitions: int = 5, 15 | n_clusters: int = 10, 16 | precomputed_clusters: np.array = None, 17 | classic_shannon_entropy: bool = False, 18 | matrix_entry_entropy: bool = False, 19 | num_bins_per_dim: int = 2, 20 | random_seed: int = 0, 21 | verbose: bool = False): 22 | ''' 23 | DSMI between two sets of random variables. 24 | The first (`embedding_vectors`) must be a set of N vectors, each of dimension D. 25 | The second (`reference_vectors`) must be a set of N vectors, each of dimension D'. 26 | D is not necessarily the same as D'.
27 | In some common cases, we may have the following as `reference_vectors` 28 | - class labels (D' == 1) of shape [N, 1] 29 | - flattened input signals/images of shape [N, D'] 30 | 31 | DSMI(A; B) = DSE(A) - DSE(A | B) 32 | where DSE is the diffusion spectral entropy. 33 | 34 | DSE(A | B) = sum_i [p(B = b_i) DSE(A | B = b_i)] 35 | where i = 1, ..., m 36 | m = number of categories in random variable B 37 | if B itself is a discrete variable (e.g., class label), this is straightforward 38 | otherwise, we discretize B into categories/clusters (scalar binning if D' == 1, spectral clustering if D' > 1) 39 | 40 | For numerical consistency, instead of computing DSE(A) on all data points of A, 41 | we estimate it from a subset of A, with the subset size equal to |{B = b_i}|. 42 | 43 | The final computation is: 44 | 45 | DSMI(A; B) = DSE(A) - DSE(A | B) = sum_i [p(B = b_i) (DSE(A*) - DSE(A | B = b_i))] 46 | where A* is a subsampled version of A, with len(A*) == |{B = b_i}|. 47 | 48 | args: 49 | embedding_vectors: np.array of shape [N, D] 50 | N: number of data points / samples 51 | D: number of feature dimensions of the neural representation 52 | 53 | reference_vectors: np.array of shape [N, D'] 54 | N: number of data points / samples 55 | D': number of feature dimensions of the neural representation or input/output variable 56 | 57 | reference_discrete: bool 58 | Whether `reference_vectors` is discrete or continuous. 59 | This determines whether or not we perform clustering/binning on `reference_vectors`. 60 | NOTE: If True, we assume D' == 1. Common case: `reference_vectors` is the discrete class labels. 61 | If not provided, will be inferred from `reference_vectors`. 62 | 63 | gaussian_kernel_sigma: float 64 | The bandwidth of Gaussian kernel (for computation of the diffusion matrix) 65 | Can be adjusted per the dataset. 66 | Increase if the data points are very far away from each other.
67 | 68 | t: int 69 | Power of the diffusion matrix (equivalent to the power of the diffusion eigenvalues), 70 | i.e., the number of iterations of the diffusion process. 71 | Usually small, e.g., 1 or 2. 72 | Can be adjusted per dataset. 73 | Rule of thumb: after raising the eigenvalues to the power `t`, approximately 74 | 1 percent of the eigenvalues should remain larger than 0.01. 75 | 76 | chebyshev_approx: bool 77 | Whether or not to use Chebyshev moments for faster approximation of eigenvalues. 78 | Currently we DO NOT RECOMMEND USING THIS: the approximated eigenvalues can deviate substantially from the exact ones. 79 | 80 | num_repetitions: int 81 | Number of repetitions used for DSE(A*) estimation. 82 | The variance is usually low, so a small number should suffice. 83 | 84 | random_seed: int 85 | Random seed, for reproducible DSE(A*) estimation. 86 | 87 | n_clusters: int 88 | Number of clusters for `reference_vectors`. 89 | Only used when `reference_discrete` is False (`reference_vectors` is not discrete). 90 | If D' == 1 --> will use scalar binning. 91 | If D' > 1 --> will use spectral clustering. 92 | 93 | precomputed_clusters: np.array 94 | If provided, it is used directly as the cluster assignments for `reference_vectors`. 95 | Only used when `reference_discrete` is False (`reference_vectors` is not discrete). 96 | NOTE: When you have a fixed set of `reference_vectors` (e.g., a set of images), 97 | you probably want to compute the spectral clustering only once and reuse the computed 98 | clusters for subsequent DSMI computations. 99 | 100 | classic_shannon_entropy: bool 101 | Whether or not to replace DSE with CSE (classic Shannon entropy) in the computation. 102 | NOTE: If True, the resulting mutual information will be CSMI instead of DSMI. 103 | 104 | matrix_entry_entropy: bool 105 | An alternative formulation where, instead of computing the entropy on 106 | diffusion matrix eigenvalues, we compute the entropy on diffusion matrix entries. 
107 | Only relevant to DSE. 108 | 109 | num_bins_per_dim: int 110 | Number of bins per feature dim.
111 | Only relevant to CSE (i.e., `classic_shannon_entropy` is True). 112 | 113 | verbose: bool 114 | Whether or not to print progress to console. 115 | ''' 116 | 117 | # Reshape from [N, ] to [N, 1]. 118 | if len(reference_vectors.shape) == 1: 119 | reference_vectors = reference_vectors.reshape( 120 | reference_vectors.shape[0], 1) 121 | 122 | N_embedding, _ = embedding_vectors.shape 123 | N_reference, D_reference = reference_vectors.shape 124 | 125 | if N_embedding != N_reference: 126 | if verbose: 127 | print( 128 | 'WARNING: DSMI embedding and reference do not have the same N: %s vs %s' 129 | % (N_embedding, N_reference)) 130 | 131 | if reference_discrete is None: 132 | # Infer whether `reference_vectors` is discrete. 133 | # Criteria: D' == 1 and `reference_vectors` is an integer type. 134 | reference_discrete = D_reference == 1 \ 135 | and np.issubdtype( 136 | reference_vectors.dtype, np.integer) 137 | 138 | # 139 | '''STEP 1. Prepare the category/cluster assignments.''' 140 | 141 | if reference_discrete: 142 | # `reference_vectors` is expected to be discrete class labels. 143 | assert D_reference == 1, \ 144 | 'DSMI `reference_discrete` is set to True, but shape of `reference_vectors` is not [N, 1].' 145 | precomputed_clusters = reference_vectors 146 | 147 | elif D_reference == 1: 148 | # `reference_vectors` is a set of continuous scalars. 149 | # Perform scalar binning if cluster assignments are not provided. 150 | if precomputed_clusters is None: 151 | vecs = reference_vectors.copy() 152 | # Min-Max scale each dimension. 153 | vecs = (vecs - np.min(vecs, axis=0)) / (np.max(vecs, axis=0) - 154 | np.min(vecs, axis=0)) 155 | # Bin along each dimension. 156 | bins = np.linspace(0, 1, n_clusters + 1)[:-1] 157 | vecs = np.digitize(vecs, bins=bins) 158 | precomputed_clusters = vecs 159 | 160 | else: 161 | # `reference_vectors` is a set of continuous vectors. 162 | # Perform spectral clustering if cluster assignments are not provided. 
163 | if precomputed_clusters is None: 164 | cluster_op = SpectralClustering( 165 | n_clusters=n_clusters, 166 | affinity='nearest_neighbors', 167 | assign_labels='cluster_qr', 168 | random_state=0).fit(reference_vectors) 169 | precomputed_clusters = cluster_op.labels_ 170 | 171 | clusters_list, cluster_cnts = np.unique(precomputed_clusters, 172 | return_counts=True) 173 | 174 | # 175 | '''STEP 2. Compute DSMI.''' 176 | MI_by_class = [] 177 | 178 | for cluster_idx in clusters_list: 179 | # DSE(A | B = b_i) 180 | inds = (precomputed_clusters == cluster_idx).reshape(-1) 181 | embeddings_curr_class = embedding_vectors[inds, :] 182 | 183 | entropy_AgivenB_curr_class = diffusion_spectral_entropy( 184 | embedding_vectors=embeddings_curr_class, 185 | gaussian_kernel_sigma=gaussian_kernel_sigma, 186 | t=t, 187 | chebyshev_approx=chebyshev_approx, 188 | classic_shannon_entropy=classic_shannon_entropy, 189 | matrix_entry_entropy=matrix_entry_entropy, 190 | num_bins_per_dim=num_bins_per_dim) 191 | 192 | # DSE(A*) 193 | if random_seed is not None: 194 | random.seed(random_seed) 195 | entropy_A_estimation_list = [] 196 | for _ in np.arange(num_repetitions): 197 | rand_inds = np.array( 198 | random.sample(range(precomputed_clusters.shape[0]), 199 | k=np.sum(precomputed_clusters == cluster_idx))) 200 | embeddings_random_subset = embedding_vectors[rand_inds, :] 201 | 202 | entropy_A_subsample_rep = diffusion_spectral_entropy( 203 | embedding_vectors=embeddings_random_subset, 204 | gaussian_kernel_sigma=gaussian_kernel_sigma, 205 | t=t, 206 | chebyshev_approx=chebyshev_approx, 207 | classic_shannon_entropy=classic_shannon_entropy, 208 | matrix_entry_entropy=matrix_entry_entropy, 209 | num_bins_per_dim=num_bins_per_dim) 210 | entropy_A_estimation_list.append(entropy_A_subsample_rep) 211 | 212 | entropy_A_estimation = np.mean(entropy_A_estimation_list) 213 | 214 | MI_by_class.append((entropy_A_estimation - entropy_AgivenB_curr_class)) 215 | 216 | mutual_information = 
np.sum(cluster_cnts / np.sum(cluster_cnts) * 217 | np.array(MI_by_class)) 218 | 219 | return mutual_information, precomputed_clusters 220 | 221 | def adjacency_spectral_mutual_information( 222 | embedding_vectors: np.array, 223 | reference_vectors: np.array, 224 | reference_discrete: bool = None, 225 | gaussian_kernel_sigma: float = 10, 226 | use_knn: bool = False, 227 | anisotropic: bool = False, 228 | num_repetitions: int = 5, 229 | n_clusters: int = 10, 230 | precomputed_clusters: np.array = None, 231 | random_seed: int = 0, 232 | verbose: bool = False): 233 | ''' 234 | ASMI (mutual information based on the adjacency spectral entropy) between two sets of random variables. 235 | The first (`embedding_vectors`) must be a set of N vectors, each of dimension D. 236 | The second (`reference_vectors`) must be a set of N vectors, each of dimension D'. 237 | D is not necessarily the same as D'. 238 | In some common cases, we may have the following as `reference_vectors`: 239 | - class labels (D' == 1) of shape [N, 1] 240 | - flattened input signals/images of shape [N, D'] 241 | 242 | ASMI(A; B) = ASE(A) - ASE(A | B) 243 | where ASE is the adjacency spectral entropy. 244 | 245 | ASE(A | B) = sum_i [p(B = b_i) ASE(A | B = b_i)] 246 | where i = 1, ..., m 247 | m = number of categories in random variable B 248 | If B itself is a discrete variable (e.g., class label), this is straightforward; 249 | otherwise, we create discrete categories/clusters in B by scalar binning (D' == 1) or spectral clustering (D' > 1). 250 | 251 | For numerical consistency, instead of computing ASE(A) on all data points of A, 252 | we estimate it from a random subset of A whose size equals the size of {B = b_i}. 253 | 254 | The final computation is: 255 | 256 | ASMI(A; B) = ASE(A) - ASE(A | B) = sum_i [p(B = b_i) (ASE(A*) - ASE(A | B = b_i))] 257 | where A* is a random subsample of A, with len(A*) == len({B = b_i}).
258 | 259 | args: 260 | embedding_vectors: np.array of shape [N, D] 261 | N: number of data points / samples 262 | D: number of feature dimensions of the neural representation 263 | 264 | reference_vectors: np.array of shape [N, D'] 265 | N: number of data points / samples 266 | D': number of feature dimensions of the neural representation or input/output variable 267 | 268 | reference_discrete: bool 269 | Whether `reference_vectors` is discrete or continuous. 270 | This determines whether or not we perform clustering/binning on `reference_vectors`. 271 | NOTE: If True, we assume D' == 1. Common case: `reference_vectors` is the discrete class labels. 272 | If not provided, it will be inferred from `reference_vectors`. 273 | 274 | gaussian_kernel_sigma: float 275 | The bandwidth of the Gaussian kernel (for computation of the adjacency matrix). 276 | Can be adjusted per dataset. 277 | Increase if the data points are very far away from each other. 278 | 279 | num_repetitions: int 280 | Number of repetitions used for ASE(A*) estimation. 281 | The variance is usually low, so a small number should suffice. 282 | 283 | random_seed: int 284 | Random seed, for reproducible ASE(A*) estimation. 285 | 286 | n_clusters: int 287 | Number of clusters for `reference_vectors`. 288 | Only used when `reference_discrete` is False (`reference_vectors` is not discrete). 289 | If D' == 1 --> will use scalar binning. 290 | If D' > 1 --> will use spectral clustering. 291 | 292 | precomputed_clusters: np.array 293 | If provided, it is used directly as the cluster assignments for `reference_vectors`. 294 | Only used when `reference_discrete` is False (`reference_vectors` is not discrete). 295 | NOTE: When you have a fixed set of `reference_vectors` (e.g., a set of images), 296 | you probably want to compute the spectral clustering only once and reuse the computed 297 | clusters for subsequent ASMI computations.
298 | 299 | use_knn: bool 300 | Whether to use a KNN-binarized adjacency matrix instead of a Gaussian-kernel adjacency matrix. 301 | anisotropic: bool 302 | Whether to use an anisotropic Gaussian kernel when building the adjacency matrix. 303 | 304 | verbose: bool 305 | Whether or not to print progress to console. 306 | ''' 307 | 308 | # Reshape from [N, ] to [N, 1]. 309 | if len(reference_vectors.shape) == 1: 310 | reference_vectors = reference_vectors.reshape( 311 | reference_vectors.shape[0], 1) 312 | 313 | N_embedding, _ = embedding_vectors.shape 314 | N_reference, D_reference = reference_vectors.shape 315 | 316 | if N_embedding != N_reference: 317 | if verbose: 318 | print( 319 | 'WARNING: ASMI embedding and reference do not have the same N: %s vs %s' 320 | % (N_embedding, N_reference)) 321 | 322 | if reference_discrete is None: 323 | # Infer whether `reference_vectors` is discrete. 324 | # Criteria: D' == 1 and `reference_vectors` is an integer type. 325 | reference_discrete = D_reference == 1 \ 326 | and np.issubdtype( 327 | reference_vectors.dtype, np.integer) 328 | 329 | # 330 | '''STEP 1. Prepare the category/cluster assignments.''' 331 | 332 | if reference_discrete: 333 | # `reference_vectors` is expected to be discrete class labels. 334 | assert D_reference == 1, \ 335 | 'ASMI `reference_discrete` is set to True, but shape of `reference_vectors` is not [N, 1].' 336 | precomputed_clusters = reference_vectors 337 | 338 | elif D_reference == 1: 339 | # `reference_vectors` is a set of continuous scalars. 340 | # Perform scalar binning if cluster assignments are not provided. 341 | if precomputed_clusters is None: 342 | vecs = reference_vectors.copy() 343 | # Min-Max scale each dimension. 344 | vecs = (vecs - np.min(vecs, axis=0)) / (np.max(vecs, axis=0) - 345 | np.min(vecs, axis=0)) 346 | # Bin along each dimension.
347 | bins = np.linspace(0, 1, n_clusters + 1)[:-1] 348 | vecs = np.digitize(vecs, bins=bins) 349 | precomputed_clusters = vecs 350 | 351 | else: 352 | # `reference_vectors` is a set of continuous vectors. 353 | # Perform spectral clustering if cluster assignments are not provided. 354 | if precomputed_clusters is None: 355 | cluster_op = SpectralClustering( 356 | n_clusters=n_clusters, 357 | affinity='nearest_neighbors', 358 | assign_labels='cluster_qr', 359 | random_state=0).fit(reference_vectors) 360 | precomputed_clusters = cluster_op.labels_ 361 | 362 | clusters_list, cluster_cnts = np.unique(precomputed_clusters, 363 | return_counts=True) 364 | 365 | # 366 | '''STEP 2. Compute ASMI.''' 367 | MI_by_class = [] 368 | 369 | for cluster_idx in clusters_list: 370 | # ASE(A | B = b_i) 371 | inds = (precomputed_clusters == cluster_idx).reshape(-1) 372 | embeddings_curr_class = embedding_vectors[inds, :] 373 | 374 | entropy_AgivenB_curr_class = adjacency_spectral_entropy( 375 | embedding_vectors=embeddings_curr_class, 376 | gaussian_kernel_sigma=gaussian_kernel_sigma, 377 | use_knn=use_knn, 378 | anisotropic=anisotropic) 379 | 380 | # ASE(A*) 381 | if random_seed is not None: 382 | random.seed(random_seed) 383 | entropy_A_estimation_list = [] 384 | for _ in np.arange(num_repetitions): 385 | rand_inds = np.array( 386 | random.sample(range(precomputed_clusters.shape[0]), 387 | k=np.sum(precomputed_clusters == cluster_idx))) 388 | embeddings_random_subset = embedding_vectors[rand_inds, :] 389 | 390 | entropy_A_subsample_rep = adjacency_spectral_entropy( 391 | embedding_vectors=embeddings_random_subset, 392 | gaussian_kernel_sigma=gaussian_kernel_sigma, 393 | use_knn=use_knn, 394 | anisotropic=anisotropic) 395 | entropy_A_estimation_list.append(entropy_A_subsample_rep) 396 | 397 | entropy_A_estimation = np.mean(entropy_A_estimation_list) 398 | 399 | MI_by_class.append((entropy_A_estimation - entropy_AgivenB_curr_class)) 400 | 401 | mutual_information = 
np.sum(cluster_cnts
/ np.sum(cluster_cnts) * 402 | np.array(MI_by_class)) 403 | 404 | return mutual_information, precomputed_clusters 405 | 406 | 407 | if __name__ == '__main__': 408 | print('Testing Diffusion Spectral Mutual Information.') 409 | print('\n1st run. DSMI, Embeddings vs discrete class labels.') 410 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 411 | class_labels = np.uint8(np.random.uniform(0, 11, (1000, 1))) 412 | DSMI, _ = diffusion_spectral_mutual_information( 413 | embedding_vectors=embedding_vectors, reference_vectors=class_labels) 414 | print('DSMI =', DSMI) 415 | 416 | print('\n2nd run. DSMI, Embeddings vs continuous scalars') 417 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 418 | continuous_scalars = np.random.uniform(-1, 1, (1000, 1)) 419 | DSMI, _ = diffusion_spectral_mutual_information( 420 | embedding_vectors=embedding_vectors, 421 | reference_vectors=continuous_scalars) 422 | print('DSMI =', DSMI) 423 | 424 | print('\n3rd run. DSMI, Embeddings vs Input Image') 425 | embedding_vectors = np.random.uniform(0, 1, (1000, 256)) 426 | input_image = np.random.uniform(-1, 1, (1000, 3, 32, 32)) 427 | input_image = input_image.reshape(input_image.shape[0], -1) 428 | DSMI, _ = diffusion_spectral_mutual_information( 429 | embedding_vectors=embedding_vectors, reference_vectors=input_image) 430 | print('DSMI =', DSMI) 431 | 432 | print('\n4th run. DSMI, Classification dataset.') 433 | from sklearn.datasets import make_classification 434 | embedding_vectors, class_labels = make_classification(n_samples=1000, 435 | n_features=5) 436 | DSMI, _ = diffusion_spectral_mutual_information( 437 | embedding_vectors=embedding_vectors, reference_vectors=class_labels) 438 | print('DSMI =', DSMI) 439 | 440 | print('\n5th run. 
CSMI, Classification dataset.') 441 | embedding_vectors, class_labels = make_classification(n_samples=1000, 442 | n_features=5) 443 | CSMI, _ = diffusion_spectral_mutual_information( 444 | embedding_vectors=embedding_vectors, 445 | reference_vectors=class_labels, 446 | classic_shannon_entropy=True) 447 | print('CSMI =', CSMI) 448 | 449 | print('\n6th run. DSMI-matrix-entry, Classification dataset.') 450 | embedding_vectors, class_labels = make_classification(n_samples=1000, 451 | n_features=5) 452 | DSMI_matrix_entry, _ = diffusion_spectral_mutual_information( 453 | embedding_vectors=embedding_vectors, 454 | reference_vectors=class_labels, 455 | matrix_entry_entropy=True) 456 | print('DSMI-matrix-entry =', DSMI_matrix_entry) 457 | 458 | print('\n7th run. ASMI-KNN, Classification dataset.') 459 | embedding_vectors, class_labels = make_classification(n_samples=1000, 460 | n_features=5) 461 | ASMI_knn, _ = adjacency_spectral_mutual_information( 462 | embedding_vectors=embedding_vectors, 463 | reference_vectors=class_labels, 464 | use_knn=True) 465 | print('ASMI-KNN =', ASMI_knn) 466 | 467 | print('\n8th run. ASMI-Gaussian, Classification dataset.') 468 | embedding_vectors, class_labels = make_classification(n_samples=1000, 469 | n_features=5) 470 | ASMI_gausadj, _ = adjacency_spectral_mutual_information( 471 | embedding_vectors=embedding_vectors, 472 | reference_vectors=class_labels) 473 | print('ASMI-Gaussian-Adj =', ASMI_gausadj) 474 | 475 | print('\n9th run.
ASMI-Gaussian-Anisotropic, Classification dataset.') 476 | embedding_vectors, class_labels = make_classification(n_samples=1000, 477 | n_features=5) 478 | ASMI_anisotropic, _ = adjacency_spectral_mutual_information( 479 | embedding_vectors=embedding_vectors, 480 | reference_vectors=class_labels, 481 | anisotropic=True) 482 | print('ASMI-Anisotropic-Adj =', ASMI_anisotropic) 483 | 484 | 485 | -------------------------------------------------------------------------------- /api/information_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from DiffusionEMD.diffusion_emd import estimate_dos 3 | 4 | 5 | def approx_eigvals(A: np.array, filter_thr: float = 1e-3): 6 | ''' 7 | Estimate the eigenvalues of a matrix `A` using 8 | Chebyshev approximation of the eigenspectrum. 9 | 10 | Assumes the eigenvalues of `A` are within [-1, 1]. 11 | If `filter_thr` is not None, entries with magnitude below it are zeroed out first. 12 | There is no guarantee that the estimated eigenvalues are accurate. 13 | ''' 14 | 15 | matrix = A.copy() 16 | N = matrix.shape[0] 17 | 18 | if filter_thr is not None: 19 | matrix[np.abs(matrix) < filter_thr] = 0 20 | 21 | # Chebyshev approximation of eigenspectrum. 22 | eigs, cdf = estimate_dos(matrix) 23 | 24 | # CDF to PDF conversion. 25 | pdf = np.zeros_like(cdf) 26 | for i in range(len(cdf) - 1): 27 | pdf[i] = cdf[i + 1] - cdf[i] 28 | 29 | # Estimate the set of eigenvalues. 30 | counts = N * pdf / np.sum(pdf) 31 | eigenvalues = [] 32 | for i, count in enumerate(counts): 33 | if np.round(count) > 0: 34 | eigenvalues += [eigs[i]] * int(np.round(count)) 35 | 36 | eigenvalues = np.array(eigenvalues) 37 | 38 | return eigenvalues 39 | 40 | 41 | def exact_eigvals(A: np.array): 42 | ''' 43 | Compute the exact eigenvalues. 44 | ''' 45 | if np.allclose(A, A.T, rtol=1e-5, atol=1e-8): 46 | # Symmetric matrix.
47 | eigenvalues = np.linalg.eigvalsh(A) 48 | else: 49 | eigenvalues = np.linalg.eigvals(A) 50 | 51 | return eigenvalues 52 | 53 | 54 | def exact_eig(A: np.array): 55 | ''' 56 | Compute the exact eigenvalues and eigenvectors, sorted by eigenvalue in descending order. 57 | ''' 58 | 59 | # Use `eigh` for (numerically) symmetric matrices: it is faster and returns real eigenvalues. 60 | if np.allclose(A, A.T, rtol=1e-5, atol=1e-8): 61 | # Symmetric matrix. 62 | eigenvalues_P, eigenvectors_P = np.linalg.eigh(A) 63 | else: 64 | eigenvalues_P, eigenvectors_P = np.linalg.eig(A) 65 | 66 | # Sort eigenvalues (and the matching eigenvector columns) in descending order. 67 | sorted_idx = np.argsort(eigenvalues_P)[::-1] 68 | eigenvalues_P = eigenvalues_P[sorted_idx] 69 | eigenvectors_P = eigenvectors_P[:, sorted_idx] 70 | 71 | return eigenvalues_P, eigenvectors_P 72 | -------------------------------------------------------------------------------- /assets/compare-cifar10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/compare-cifar10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png -------------------------------------------------------------------------------- /assets/compare-stl10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/compare-stl10-supervised-resnet-ConvInitStd-1e-2-1e-1-seed1-2-3.png -------------------------------------------------------------------------------- /assets/curse_of_dim.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/curse_of_dim.png --------------------------------------------------------------------------------
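The final computation described in the `diffusion_spectral_mutual_information` docstring in `api/dsmi.py` is a count-weighted sum over categories of DSE(A*) - DSE(A | B = b_i), where A* is a random subset of A with the same size as {B = b_i}. That combination step can be sketched independently of any particular entropy estimator. The sketch below is a minimal illustration under that assumption, not the repository's implementation: `subsampled_mutual_information` is a hypothetical helper, and the entropy function is passed in as a parameter. The usage example plugs in a toy stand-in (the standard deviation of the first feature), which is NOT diffusion spectral entropy; it only shows that well-separated classes give a positive score and that a constant "entropy" gives exactly zero.

```python
import random

import numpy as np


def subsampled_mutual_information(embeddings, labels, entropy_fn,
                                  num_repetitions=5, random_seed=0):
    """Hypothetical helper: sum_i p(b_i) * (H(A*) - H(A | B = b_i)).

    `entropy_fn` maps an [n, D] array to a scalar. A* is a random subset
    of all rows with the same size as the current category.
    """
    labels = np.asarray(labels).reshape(-1)
    classes, counts = np.unique(labels, return_counts=True)
    mi_by_class = []
    for c in classes:
        mask = labels == c
        # H(A | B = b_i): entropy of the rows belonging to category c.
        h_conditional = entropy_fn(embeddings[mask])
        # H(A*): average entropy over size-matched random subsets of all rows.
        if random_seed is not None:
            random.seed(random_seed)
        h_subsampled = np.mean([
            entropy_fn(embeddings[random.sample(range(len(labels)),
                                                k=int(mask.sum()))])
            for _ in range(num_repetitions)
        ])
        mi_by_class.append(h_subsampled - h_conditional)
    # Weight each category by its empirical probability p(B = b_i).
    weights = counts / counts.sum()
    return float(np.sum(weights * np.array(mi_by_class)))


rng = np.random.default_rng(0)
embeddings = np.concatenate([rng.normal(0.0, 0.1, (100, 2)),
                             rng.normal(10.0, 0.1, (100, 2))])
labels = np.array([0] * 100 + [1] * 100)

# Toy stand-in "entropy": spread of the first feature (NOT DSE).
mi_separated = subsampled_mutual_information(
    embeddings, labels, lambda X: float(X[:, 0].std()))
# A constant "entropy" carries no information about the labels.
mi_constant = subsampled_mutual_information(embeddings, labels, lambda X: 1.0)
```

Passing the estimator as a function keeps the weighting and subsampling logic separate from the choice of DSE, CSE, or ASE, which mirrors how the two functions in `dsmi.py` share the same outer loop.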
/assets/def_DSE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/def_DSE.png -------------------------------------------------------------------------------- /assets/def_DSMI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/def_DSMI.png -------------------------------------------------------------------------------- /assets/logos/MetaAI_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/logos/MetaAI_logo.png -------------------------------------------------------------------------------- /assets/logos/Mila_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/logos/Mila_logo.png -------------------------------------------------------------------------------- /assets/logos/Yale_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/logos/Yale_logo.png -------------------------------------------------------------------------------- /assets/main_figure_DSE(Z).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/main_figure_DSE(Z).png -------------------------------------------------------------------------------- 
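For continuous scalar references (D' == 1), both MI functions in `api/dsmi.py` min-max scale the values and bin them with `np.digitize` against `np.linspace(0, 1, n_clusters + 1)[:-1]`. The sketch below reproduces that step in isolation so its edge behavior is visible: labels start at 1, and the maximum value falls into the last bin instead of an extra one, because the right edge 1.0 is dropped. `bin_scalars` is a hypothetical name, and the guard against a constant column (where the original expression would divide by zero) is an addition of this sketch, not part of the repository code.

```python
import numpy as np


def bin_scalars(values, n_clusters):
    """Hypothetical helper mirroring the scalar-binning branch in dsmi.py."""
    vecs = np.asarray(values, dtype=float).reshape(-1, 1)
    vmin, vmax = vecs.min(axis=0), vecs.max(axis=0)
    # Assumption of this sketch: avoid 0/0 when all values are equal.
    span = np.where(vmax > vmin, vmax - vmin, 1.0)
    # Min-max scale to [0, 1].
    vecs = (vecs - vmin) / span
    # Left bin edges only; the right edge 1.0 is dropped, so the maximum
    # value lands in the last bin (label n_clusters).
    bins = np.linspace(0, 1, n_clusters + 1)[:-1]
    return np.digitize(vecs, bins=bins)


labels = bin_scalars([-1.0, -0.5, 0.0, 0.5, 1.0], n_clusters=4)
constant_labels = bin_scalars([3.0, 3.0, 3.0], n_clusters=4)
```

Because the resulting labels are only used as category identifiers in `np.unique`, starting at 1 rather than 0 has no effect on the MI estimate.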
/assets/main_figure_DSMI(Z;X).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/main_figure_DSMI(Z;X).png -------------------------------------------------------------------------------- /assets/main_figure_DSMI(Z;Y).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/main_figure_DSMI(Z;Y).png -------------------------------------------------------------------------------- /assets/method_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/method_comparison.png -------------------------------------------------------------------------------- /assets/procedure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/procedure.png -------------------------------------------------------------------------------- /assets/visualize_embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/visualize_embeddings.png -------------------------------------------------------------------------------- /assets/vs_imagenet_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenLiu-1996/DiffusionSpectralEntropy/faf208cbfaf39fc2264053b742afe58083a6dd5a/assets/vs_imagenet_acc.png 
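`approx_eigvals` in `api/information_utils.py` turns the Chebyshev-estimated spectral CDF returned by `estimate_dos` into a list of eigenvalues. It differences the CDF into a PDF, rescales it to N counts, and repeats each grid point by its rounded count. The sketch below isolates that post-processing on a hand-made CDF so it runs without `DiffusionEMD`; `eigenvalues_from_cdf` is a hypothetical name. The second call shows one concrete reason the docstring gives no accuracy guarantee: rounding (NumPy rounds halves to even) can make the number of returned eigenvalues differ from N.

```python
import numpy as np


def eigenvalues_from_cdf(grid, cdf, n):
    """Hypothetical helper mirroring the CDF -> eigenvalue-list step of approx_eigvals."""
    cdf = np.asarray(cdf, dtype=float)
    # PDF mass of interval [grid[i], grid[i + 1]] is assigned to grid[i];
    # the last entry stays 0, as in the original loop.
    pdf = np.zeros_like(cdf)
    pdf[:-1] = np.diff(cdf)
    # Rescale the PDF to (approximately) n eigenvalue counts.
    counts = n * pdf / np.sum(pdf)
    eigenvalues = []
    for value, count in zip(grid, counts):
        k = int(np.round(count))
        if k > 0:
            eigenvalues += [value] * k
    return np.array(eigenvalues)


grid = np.array([0.0, 0.5, 1.0])
cdf = np.array([0.0, 0.5, 1.0])
approx_4 = eigenvalues_from_cdf(grid, cdf, n=4)
# With n=3 each count is 1.5, which rounds to 2, so 4 values come back.
approx_3 = eigenvalues_from_cdf(grid, cdf, n=3)
```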
--------------------------------------------------------------------------------
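The rule of thumb for `t` in the `diffusion_spectral_mutual_information` docstring (after raising the diffusion eigenvalues to the power `t`, roughly 1 percent should remain above 0.01) can be checked directly once the eigenvalues are available. The sketch below assumes a row-normalized Gaussian kernel P = D^{-1} K with K_ij = exp(-||x_i - x_j||^2 / (2 sigma^2)); the exact kernel built in `dse.py` is not shown in this section and may differ. Because P is similar to the symmetric matrix D^{-1/2} K D^{-1/2}, the sketch calls `np.linalg.eigvalsh` on the symmetric form. `diffusion_eigenvalues` and `fraction_above` are hypothetical helpers.

```python
import numpy as np


def diffusion_eigenvalues(X, sigma):
    """Eigenvalues (descending) of P = D^{-1} K for an assumed Gaussian kernel K."""
    sq_norms = np.sum(X ** 2, axis=1)
    # Pairwise squared Euclidean distances, clipped at 0 to absorb round-off.
    d2 = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * X @ X.T, 0.0)
    K = np.exp(-d2 / (2.0 * sigma ** 2))
    degrees = K.sum(axis=1)
    # D^{-1/2} K D^{-1/2} is symmetric and similar to D^{-1} K,
    # so both share the same real eigenvalues.
    S = K / np.sqrt(np.outer(degrees, degrees))
    return np.linalg.eigvalsh(S)[::-1]


def fraction_above(eigenvalues, t, threshold=0.01):
    """Fraction of eigenvalues whose magnitude, raised to `t`, exceeds `threshold`."""
    return float(np.mean(np.abs(eigenvalues) ** t > threshold))


X = np.random.default_rng(0).uniform(0, 1, (200, 10))
eigvals = diffusion_eigenvalues(X, sigma=1.0)
# Scan a few candidate values of t and pick one near the 1 percent target.
fractions = {t: fraction_above(eigvals, t) for t in (1, 2, 5, 10)}
```

Since all eigenvalues of a row-stochastic matrix have magnitude at most 1, the fraction can only shrink as `t` grows, so scanning increasing `t` and stopping near 1 percent is a reasonable way to apply the rule of thumb.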