├── .gitignore
├── LICENSE.md
├── README.md
├── examples
│   ├── linear_metrics.ipynb
│   ├── pairwise_distances.ipynb
│   └── stochastic_metrics.ipynb
├── netrep
│   ├── __init__.py
│   ├── conv_layers.py
│   ├── kmeans.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── cka.py
│   │   ├── kernel.py
│   │   ├── linear.py
│   │   ├── perm.py
│   │   ├── stochastic.py
│   │   └── stochastic_process.py
│   ├── multiset.py
│   ├── rbf_sampler.py
│   ├── utils.py
│   └── validation.py
├── setup.py
└── tests
    ├── test_metrics.py
    ├── test_multiprocessing.py
    ├── test_multiset.py
    ├── test_stochastic.py
    ├── test_stochastic_process.py
    └── test_utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.py[cod]

# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
lib
lib64
__pycache__

# Installer logs
pip-log.txt

# Unit test / coverage reports
.pytest_cache

# Jupyter
.ipynb_checkpoints
notebooks/

# VirtualEnv
env/
env2/

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Alex Williams

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Generalized Shape Metrics on Neural Representations

![Generalized Shape Metrics on Neural Representations](https://user-images.githubusercontent.com/636625/139737239-5e3054fe-0465-4c9b-b148-a43acc62aa8e.png)

In neuroscience and in deep learning, quantifying the (dis)similarity of neural representations across networks is a topic of substantial interest.

This code package computes [*metrics*](https://en.wikipedia.org/wiki/Metric_(mathematics)) — notions of distance that satisfy the triangle inequality — between neural representations. If we record the activity of `K` networks, we can compute all pairwise distances and collect them into a `K × K` distance matrix. The triangle inequality ensures that all of these distance relationships are, in some sense, self-consistent.
This enables us to apply off-the-shelf algorithms for clustering and dimensionality reduction, which are available through many open-source packages such as [scikit-learn](https://scikit-learn.org/).

Two conference papers **([Neurips '21](https://arxiv.org/abs/2110.14739), [ICLR '23](https://arxiv.org/abs/2211.11665))** describe the approach:

```
@inproceedings{neural_shape_metrics,
  author = {Alex H. Williams and Erin Kunz and Simon Kornblith and Scott W. Linderman},
  title = {Generalized Shape Metrics on Neural Representations},
  year = {2021},
  booktitle = {Advances in Neural Information Processing Systems},
  volume = {34},
}

@inproceedings{stochastic_neural_shape_metrics,
  author = {Lyndon R. Duong and Jingyang Zhou and Josue Nassar and Jules Berman and Jeroen Olieslagers and Alex H. Williams},
  title = {Representational dissimilarity metric spaces for stochastic neural networks},
  year = {2023},
  booktitle = {International Conference on Learning Representations},
}
```

We presented an early version of this work at COSYNE 2021 (see [**7 minute summary on youtube**](https://www.youtube.com/watch?v=Lt_Vo-tQcW0)), and a full workshop talk at COSYNE 2023 ([**30 minute talk on youtube**](https://www.youtube.com/watch?v=e02DWc2z8Hc)).

**Note:** This research code remains a work-in-progress to some extent. It could use more documentation and examples. Please use at your own risk and reach out to us (alex.h.willia@gmail.com) if you have questions.

## A short and preliminary guide

To install, set up standard Python libraries and then install via `pip`:

```
git clone https://github.com/ahwillia/netrep
cd netrep/
pip install -e .
```

Since the code is preliminary, you will be able to use `git pull` to get updates as we release them.

### Computing the distance between two networks

The metrics implemented in this library are extensions of [Procrustes distance](https://en.wikipedia.org/wiki/Procrustes_analysis). Some useful background can be found in Dryden & Mardia's textbook on [*Statistical Shape Analysis*](https://www.wiley.com/en-us/Statistical+Shape+Analysis%3A+With+Applications+in+R%2C+2nd+Edition-p-9780470699621).

The code uses an API similar to [scikit-learn](https://scikit-learn.org/), so we recommend familiarizing yourself with that package.

We start by defining a metric object. The simplest metric to use is `LinearMetric`, which has a hyperparameter `alpha` that regularizes the alignment operation:

```python
from netrep.metrics import LinearMetric

# Rotationally invariant metric (fully regularized).
proc_metric = LinearMetric(alpha=1.0, center_columns=True)

# Linearly invariant metric (no regularization).
cca_metric = LinearMetric(alpha=0.0, center_columns=True)
```

Valid values for the regularization term are `0 <= alpha <= 1`. When `alpha == 0`, the resulting metric is similar to CCA and allows for an invertible linear transformation to align the activations. When `alpha == 1`, the model is fully regularized and only allows for rotational alignments.

We recommend starting with the fully regularized model where `alpha == 1`.
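To see the effect of `alpha` concretely, here is a small self-contained sketch (synthetic data for illustration; with unrelated random activations, the distance should decrease as `alpha` decreases, since more flexible alignments are allowed):

```python
import numpy as np
from netrep.metrics import LinearMetric

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 100))
Y = rng.standard_normal((1000, 100))

for alpha in (1.0, 0.5, 0.0):
    metric = LinearMetric(alpha=alpha, center_columns=True)
    print(alpha, metric.fit_score(X, Y))
```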
Next, we define the data, which are stored in matrices `X` and `Y` that hold paired activations from two networks.
Each row of `X` and `Y` contains a matched sample of neural activations. For example, we might record the activity of 500 neurons in visual cortex in response to 1000 images (or, analogously, feed 1000 images into a deep network and store the activations of 500 hidden units). We would collect the neural responses into a `1000 x 500` matrix `X`. We'd then repeat the experiment in a second animal and store the responses in a second matrix `Y`.

By default, if the numbers of neurons in `X` and `Y` do not match, we zero-pad the dataset with fewer neurons to match the size of the larger dataset. This can be justified on the basis that zero-padding does not distort the geometry of the dataset; it simply embeds it into a higher dimension so that the two may be compared. Alternatively, one could preprocess the data by using PCA (for example) to project the data into a common, lower-dimensional space. The default zero-padding behavior can be deactivated as follows:

```python
LinearMetric(alpha=1.0, zero_pad=True)  # default behavior

LinearMetric(alpha=1.0, zero_pad=False) # throws an error if number of columns in X and Y don't match
```

Now we are ready to fit alignment transformations (which account for the neurons being mismatched across networks). Then, we evaluate the distance in the aligned space. These are done by calling the `fit(...)` and `score(...)` functions, respectively, on the metric instance.

```python
# Given
# -----
# X : ndarray, (num_samples x num_neurons), activations from first network.
#
# Y : ndarray, (num_samples x num_neurons), activations from second network.
#
# metric : an instance of LinearMetric(...)

# Fit alignment transformations.
metric.fit(X, Y)

# Evaluate distance between X and Y, using alignments fit above.
dist = metric.score(X, Y)
```

Since the model is fit and evaluated by separate function calls, it is very easy to cross-validate the estimated distances:

```python
# Given
# -----
# X_train : ndarray, (num_train_samples x num_neurons), training data from first network.
#
# Y_train : ndarray, (num_train_samples x num_neurons), training data from second network.
#
# X_test : ndarray, (num_test_samples x num_neurons), test data from first network.
#
# Y_test : ndarray, (num_test_samples x num_neurons), test data from second network.
#
# metric : an instance of LinearMetric(...)

# Fit alignment transformations to the training set.
metric.fit(X_train, Y_train)

# Evaluate distance on the test set.
dist = metric.score(X_test, Y_test)
```

In fact, we can use scikit-learn's built-in cross-validation tools, since `LinearMetric` extends the `sklearn.base.BaseEstimator` class. So, if you'd like to do 10-fold cross-validation, for example:

```python
from sklearn.model_selection import cross_validate
results = cross_validate(metric, X, Y, return_train_score=True, cv=10)
results["train_score"]  # holds 10 distance estimates between X and Y, using training data.
results["test_score"]   # holds 10 distance estimates between X and Y, using heldout data.
```

We can also call the `transform(...)` function to align the activations:

```python
# Fit alignment transformations.
metric.fit(X, Y)

# Apply alignment transformations.
X_aligned, Y_aligned = metric.transform(X, Y)

# Now, e.g., you could use PCA to visualize the data in the aligned space...
```
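Continuing the comment above, a minimal sketch of that visualization idea (it assumes `matplotlib` is installed, and uses `X_aligned` and `Y_aligned` from the snippet above):

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project both aligned datasets onto the top two PCs of the aligned X.
pca = PCA(n_components=2).fit(X_aligned)
Zx, Zy = pca.transform(X_aligned), pca.transform(Y_aligned)

plt.scatter(Zx[:, 0], Zx[:, 1], s=5, label="network X")
plt.scatter(Zy[:, 0], Zy[:, 1], s=5, label="network Y")
plt.legend()
plt.show()
```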
## Stochastic shape metrics

We also provide a way to compare stochastic neural responses (e.g., biological neural network responses to stimulus repetitions, or latent activations in variational autoencoders). The API is similar to `LinearMetric()`, but requires differently-formatted inputs.

:warning: WARNING :warning: *Fitting the optimal orthogonal transformation in stochastic shape metrics involves a nonconvex optimization procedure that can be caught in local minima. Please be careful and use the `n_restarts` parameter to run the optimization algorithm multiple times.*

**1) Stochastic shape metrics using** `GaussianStochasticMetric()`

The first method models network response distributions as multivariate Gaussians, and computes distances based on the analytic solution to the 2-Wasserstein distance between two Gaussians. This involves computing class-conditional means and covariances for each network, then computing the metric as follows.

```python
# Given
# -----
# Xi : Tuple[ndarray, ndarray]
#   The first array is a (num_classes x num_neurons) array of means, and the
#   second is a (num_classes x num_neurons x num_neurons) array of covariance
#   matrices for the first network.
#
# Xj : Tuple[ndarray, ndarray]
#   Same as Xi, but for the second network's responses.
#
# alpha : float in [0, 2].
#   When alpha=2, this reduces to the deterministic shape metric. When alpha=1,
#   this is the 2-Wasserstein distance between two Gaussians. When alpha=0,
#   this is the Bures metric between the two sets of covariance matrices.

from netrep.metrics import GaussianStochasticMetric

# Fit alignment
metric = GaussianStochasticMetric(alpha, init='rand', n_restarts=50)
metric.fit(Xi, Xj)

# Evaluate the distance between the two networks
dist = metric.score(Xi, Xj)
```

**2) Stochastic shape metrics using** `EnergyStochasticMetric()`

We also provide stochastic shape metrics based on the Energy distance. This metric is non-parametric (it makes no assumptions about the response distributions), and it can therefore take into account higher-order moments of the responses.

```python
# Given
# -----
# Xi : ndarray, (num_classes x num_repeats x num_neurons)
#   First network's responses.
#
# Xj : ndarray, (num_classes x num_repeats x num_neurons)
#   Same as Xi, but for the second network's responses.

from netrep.metrics import EnergyStochasticMetric

# Fit alignment
metric = EnergyStochasticMetric()
metric.fit(Xi, Xj)

# Evaluate the distance between the two networks
dist = metric.score(Xi, Xj)
```
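If your raw recordings are stored as a `(num_classes x num_repeats x num_neurons)` array (the same format `EnergyStochasticMetric` consumes), one way to build the `(means, covariances)` tuples that `GaussianStochasticMetric` expects is sketched below (`responses_i` and `responses_j` are illustrative names, not part of the library):

```python
import numpy as np

def to_means_and_covs(responses):
    """Class-conditional means and covariances from raw repeats."""
    # responses : (num_classes x num_repeats x num_neurons) array.
    means = responses.mean(axis=1)
    covs = np.stack([np.cov(r, rowvar=False) for r in responses])
    return means, covs

Xi = to_means_and_covs(responses_i)
Xj = to_means_and_covs(responses_j)
```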
### Computing distances between many networks

Things start to get really interesting when we consider larger cohorts containing more than just two networks. The `netrep.multiset` file contains some useful methods. Let `Xs = [X1, X2, X3, ..., Xk]` be a list of `num_samples x num_neurons` matrices similar to those described above. We can do the following:

**Computing all pairwise distances.** The following returns a symmetric `k x k` matrix of distances.

```python
metric = LinearMetric(alpha=1.0)

# Compute k x k distance matrices (leverages multiprocessing).
dist_matrix, _ = metric.pairwise_distances(Xs)
```

Setting `verbose=True` (the default) prints a progress bar, which can be useful for very large datasets.

We can also split the data into training and test sets:

```python
import numpy as np

# Split data into training and testing sets.
splitdata = [np.array_split(X, 2) for X in Xs]
traindata = [X_train for (X_train, X_test) in splitdata]
testdata = [X_test for (X_train, X_test) in splitdata]

# Compute all pairwise train and test distances.
train_dists, test_dists = metric.pairwise_distances(traindata, testdata)
```

--------------------------------------------------------------------------------
/examples/pairwise_distances.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing pairwise distances between networks\n",
    "\n",
    "To use the built-in `pairwise_distances()` function, we recommend setting the environment variable 'OMP_NUM_THREADS' to 1. This will prevent oversubscription (i.e. using more threads than available cores), which can cause the function to run slower than expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num cpus: 128\n"
     ]
    }
   ],
   "source": [
    "import multiprocessing\n",
    "import os\n",
    "print(f'num cpus: {multiprocessing.cpu_count()}')\n",
    "# set omp threads to 1 to avoid slowdowns due to parallelization\n",
    "os.environ['OMP_NUM_THREADS'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "from netrep.metrics import LinearMetric "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parallelizing 2016 distance calculations with 128 processes.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Computing distances: 100%|██████████| 2016/2016 [00:07<00:00, 285.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 64) 5478.173387196459\n",
      "(64, 64) 6333.232718084371\n"
     ]
    }
   ],
   "source": [
    "def get_data_linear(n_networks, n_images, n_neurons, rng):\n",
    "    return [rng.standard_normal((n_images, n_neurons)) for _ in range(n_networks)]\n",
    "\n",
    "def compute_pairwise_linear(rng):\n",
    "    n_networks, n_images, n_neurons = 64, 1024, 64\n",
    "    metric = LinearMetric()\n",
    "    train_data = get_data_linear(n_networks, n_images, n_neurons, rng)\n",
    "    test_data = get_data_linear(n_networks, n_images, n_neurons, rng)\n",
    "    D_train, D_test = metric.pairwise_distances(train_data, test_data)\n",
    "\n",
    "    print(D_train.shape, D_train.sum())\n",
    "    print(D_test.shape, D_test.sum())\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "compute_pairwise_linear(rng)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
"codemirror_mode": { 93 | "name": "ipython", 94 | "version": 3 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython3", 101 | "version": "3.9.7" 102 | }, 103 | "orig_nbformat": 4 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 2 107 | } 108 | -------------------------------------------------------------------------------- /netrep/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahwillia/netrep/0186b8a77ec1ebaf541cc8f7173cb2556df8a8f0/netrep/__init__.py -------------------------------------------------------------------------------- /netrep/conv_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | from netrep.validation import check_equal_shapes 4 | from tqdm import tqdm 5 | 6 | 7 | def convolve_metric(metric, X, Y): 8 | """ 9 | Computes representation metric between convolutional layers, 10 | convolving activations with boundary conditions. 11 | 12 | Parameters 13 | ---------- 14 | metric : Metric 15 | Specifies metric to compute. 16 | X : ndarray 17 | Activations from first layer (images x height x width x channel) 18 | Y : ndarray 19 | Activations from second layer (images x height x width x channel) 20 | 21 | Returns 22 | ------- 23 | dists : ndarray 24 | Matrix with shape (height x width). Holds `metric.score()` for 25 | X and Y, convolving over the two spatial dimensions. 26 | """ 27 | 28 | # Inputs are (images x height x width x channel) tensors, holding activations. 29 | X, Y = check_equal_shapes(X, Y, nd=4, zero_pad=metric.zero_pad) 30 | m, h, w, c = X.shape 31 | 32 | # Flattened Y tensor. 33 | Yf = Y.reshape(-1, c) 34 | 35 | # Compute metric over all possible offsets. 36 | pbar = tqdm(total=(w * h)) 37 | dists = np.full((h, w), -1.0) 38 | for i, j in itertools.product(range(h), range(w)): 39 | 40 | # Apply shift to X tensor, then flatten. 41 | shifts = (i - (h // 2), j - (w // 2)) 42 | Xf = np.roll(X, shifts, axis=(1, 2)).reshape(-1, c) 43 | 44 | # Fit and evaluate metric. 45 | metric.fit(Xf, Yf) 46 | dists[i, j] = metric.score(Xf, Yf) 47 | 48 | # Update progress bar. 49 | pbar.update() 50 | 51 | pbar.close() 52 | return dists 53 | -------------------------------------------------------------------------------- /netrep/kmeans.py: -------------------------------------------------------------------------------- 1 | from netrep.barycenter import barycenter 2 | from netrep.metrics import LinearMetric 3 | 4 | 5 | def procrustes_kmeans( 6 | Xs, n_clusters, dist_matrix=None, max_iter=100, random_state=None 7 | ): 8 | """ 9 | Perform K-means clustering in the metric space defined by 10 | the Procrustes metric. 11 | 12 | Parameters 13 | ---------- 14 | Xs : list of p matrices, (m x n) ndarrays. 15 | Matrix-valued datasets to compare. Rotations are learned 16 | and applied in the n-dimensional space. 17 | 18 | n_clusters : int 19 | Number of clusters to fit. 20 | 21 | dist_matrix : pairwise distances, (p x p) symmetric matrix, optional. 22 | Pairwise distances between all p networks. This is used 23 | to seed the k-means algorithm by a k-means++ procedure. 24 | 25 | max_iter : int, optional. 26 | Maximum number of iterations to apply. 27 | 28 | random_state : int or np.random.RandomState 29 | Specifies the state of the random number generator. 
--------------------------------------------------------------------------------
/netrep/kmeans.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.utils import check_array, check_random_state

from netrep.barycenter import barycenter
from netrep.metrics import LinearMetric


def procrustes_kmeans(
        Xs, n_clusters, dist_matrix=None, max_iter=100, random_state=None
    ):
    """
    Perform K-means clustering in the metric space defined by
    the Procrustes metric.

    Parameters
    ----------
    Xs : list of p matrices, (m x n) ndarrays.
        Matrix-valued datasets to compare. Rotations are learned
        and applied in the n-dimensional space.

    n_clusters : int
        Number of clusters to fit.

    dist_matrix : pairwise distances, (p x p) symmetric matrix, optional.
        Pairwise distances between all p networks. This is used
        to seed the k-means algorithm by a k-means++ procedure.

    max_iter : int, optional.
        Maximum number of iterations to apply.

    random_state : int or np.random.RandomState
        Specifies the state of the random number generator.

    Returns
    -------
    centroids : list of (m x n) ndarrays.
        Cluster centroids.

    labels : length-p ndarray.
        Vector holding the cluster labels for each network.

    cent_dists : (n_clusters x p) ndarray
        Matrix holding the distance from each cluster centroid
        to each network.
    """

    # Initialize random number generator.
    rs = check_random_state(random_state)

    # Initialize Procrustes metric.
    proc_metric = LinearMetric(alpha=1.0)

    # Check input.
    Xs = check_array(Xs, allow_nd=True)
    if Xs.ndim != 3:
        raise ValueError(
            "Expected 3d array with shape "
            "(n_datasets x n_observations x n_features), but "
            "got {}-d array with shape {}".format(Xs.ndim, Xs.shape))

    # Initialize pairwise distances between all networks.
    # (Uses the metric's built-in multiprocessing routine; the second
    # return value holds test-set distances and is unused here.)
    if dist_matrix is None:
        dist_matrix, _ = proc_metric.pairwise_distances(Xs, verbose=False)

    # Pick first centroid randomly.
    init_centroid_idx = [rs.choice(len(Xs))]
    init_dists = dist_matrix[init_centroid_idx[0]] ** 2

    # Pick additional clusters according to k-means++ procedure.
    for k in range(1, n_clusters):
        init_centroid_idx.append(
            rs.choice(len(Xs), p=init_dists / init_dists.sum())
        )
        init_dists = np.minimum(
            init_dists,
            dist_matrix[init_centroid_idx[-1]] ** 2
        )

    # Collect centroids.
    centroids = [np.copy(Xs[i]) for i in init_centroid_idx]

    # Determine cluster labels for each datapoint.
    labels = np.array(
        [np.argmin(dist_matrix[j][init_centroid_idx]) for j in range(len(Xs))]
    )

    # Initialize distance to centroids matrix.
    cent_dists = np.zeros((n_clusters, Xs.shape[0]))

    # Main loop.
    for i in range(max_iter):

        # Update cluster centroids.
        for k in range(n_clusters):
            centroids[k] = barycenter(
                [X for X, c in zip(Xs, labels) if c == k],
                group="orth", random_state=rs, max_iter=10,
                warmstart=centroids[k]
            )

        # Compute distance from each datapoint to each centroid.
        for j in range(len(Xs)):
            for k, cent in enumerate(centroids):
                proc_metric.fit(Xs[j], cent)
                cent_dists[k, j] = proc_metric.score(Xs[j], cent)

        # Compute new cluster labels.
        new_labels = np.argmin(cent_dists, axis=0)

        # Check convergence.
        converged = np.all(labels == new_labels)
        labels = new_labels

        # Break loop if converged.
        if converged:
            break

    return centroids, labels, cent_dists

--------------------------------------------------------------------------------
/netrep/metrics/__init__.py:
--------------------------------------------------------------------------------
from netrep.metrics.cka import LinearCKA
from netrep.metrics.linear import LinearMetric
from netrep.metrics.perm import PermutationMetric
from netrep.metrics.stochastic import GaussianStochasticMetric
from netrep.metrics.stochastic import EnergyStochasticMetric
from netrep.metrics.stochastic_process import GPStochasticMetric
--------------------------------------------------------------------------------
/netrep/metrics/cka.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.base import BaseEstimator
3 | 
4 | from netrep.utils import angular_distance
5 | 
6 | class LinearCKA:
7 |     """
8 |     Note: This function differs from the one outlined in
9 |     Kornblith et al. (2019). It introduces an arccos(.)
10 | into the final calculation so that the result satisfies 11 | the conditions of a metric. 12 | """ 13 | 14 | def __init__(self, center_columns=True): 15 | self.center_columns = center_columns 16 | 17 | def fit(self, X, Y): 18 | pass 19 | 20 | def score(self, X, Y): 21 | """ 22 | Parameters 23 | ---------- 24 | X : ndarray 25 | (num_samples x num_neurons) matrix of activations. 26 | Y : ndarray 27 | (num_samples x num_neurons) matrix of activations. 28 | 29 | Returns 30 | ------- 31 | dist : float 32 | Distance between X and Y. 33 | """ 34 | 35 | if self.center_columns: 36 | X = X - np.mean(X, axis=0) 37 | Y = Y - np.mean(Y, axis=0) 38 | 39 | # Compute angular distance between (sample x sample) covariance matrices. 40 | return angular_distance(X @ X.T, Y @ Y.T) 41 | -------------------------------------------------------------------------------- /netrep/metrics/kernel.py: -------------------------------------------------------------------------------- 1 | # class KernelizedMetric(BaseEstimator, MetricMixin): 2 | 3 | # def __init__( 4 | # self, alpha=1.0, gamma=0.0, method="exact", 5 | # kernel="linear", kernel_params=dict(), zero_pad=True): 6 | 7 | # if (alpha > 1) or (alpha < 0): 8 | # raise ValueError( 9 | # "Regularization parameter `alpha` must be between zero and one.") 10 | 11 | # if (gamma > 1) or (gamma < 0): 12 | # raise ValueError( 13 | # "Regularization parameter `gamma` must be between zero and one.") 14 | 15 | # if ((alpha + gamma) > 1) or ((alpha + gamma) < 0): 16 | # raise ValueError( 17 | # "Regularization parameters `alpha` and `gamma` must sum to a " 18 | # "number between zero and one.") 19 | 20 | # self.alpha = alpha 21 | # self.gamma = gamma 22 | # self.method = "exact" 23 | # self.kernel = kernel 24 | # self.kernel_params = kernel_params 25 | # self.zero_pad = zero_pad 26 | 27 | # if ("metric" in kernel_params) and (kernel_params["metric"] != kernel): 28 | # raise ValueError( 29 | # "If 'metric' keyword is included in 'kernel_params' " 30 | # "it must match 'kernel' parameter.") 31 | # else: 32 | # self.kernel_params["metric"] = self.kernel 33 | 34 | 35 | # def fit(self, X, Y): 36 | 37 | # X, Y = check_equal_shapes(X, Y, nd=2, zero_pad=self.zero_pad) 38 | # n_obs, n_feats = X.shape 39 | 40 | # if self.method == "exact": 41 | 42 | # # Compute kernel matrices. 43 | # Kx = centered_kernel(X, **self.kernel_params) 44 | # Ky = centered_kernel(Y, **self.kernel_params) 45 | 46 | # # Whiten kernel matrices. 47 | # Kx_whitened, Zx = _whiten_kernel_matrix(Kx, self.alpha, self.gamma) 48 | # Ky_whitened, Zy = _whiten_kernel_matrix(Ky, self.alpha, self.gamma) 49 | 50 | # # Multiply kernel matrices, compute SVD. 51 | # U, _, Vt = np.linalg.svd(Kx_whitened @ Ky_whitened) 52 | 53 | # # Compute alignment transformations. 54 | # self.Wx_ = Zx @ U 55 | # self.Wy_ = Zy @ Vt.T 56 | 57 | # # Store training set, for prediction at test time. 
58 | # self.X_ = X.copy() 59 | # self.Y_ = Y.copy() 60 | 61 | # elif self.method in ("rand", "randomized"): 62 | 63 | # # Approximate low-rank eigendecompositions 64 | # lam_x, Vx = randomized_kernel_eigh(X, self.kernel_params) 65 | # lam_y, Vy = randomized_kernel_eigh(Y, self.kernel_params) 66 | 67 | # wx = np.full(n_obs, self.gamma) 68 | # wx[:lam_x.size] += (1 - self.alpha - self.gamma) * (lam_x ** 2) 69 | # wx[:lam_x.size] += (self.alpha * lam_x) 70 | 71 | # wx = np.full(n_obs, self.gamma) 72 | # wx[:lam_x.size] += (1 - self.alpha - self.gamma) * (lam_x ** 2) 73 | # wx[:lam_x.size] += (self.alpha * lam_x) 74 | 75 | # assert False 76 | 77 | # return self 78 | 79 | # def transform_X(self, X): 80 | # check_is_fitted(self, attributes=["Wx_"]) 81 | # return centered_kernel(X, self.X_) @ self.Wx_ 82 | 83 | # def transform_Y(self, Y): 84 | # check_is_fitted(self, attributes=["Wy_"]) 85 | # return centered_kernel(Y, self.Y_) @ self.Wy_ 86 | 87 | # def score(self, X, Y): 88 | # return angular_distance(*self.transform(X, Y)) 89 | 90 | 91 | # def _whiten_kernel_matrix(K, a, g, eigval_tol=1e-7): 92 | 93 | # # Compute eigendecomposition for kernel matrix 94 | # w, v = np.linalg.eigh(K) 95 | 96 | # # Regularize eigenvalues. 97 | # w = ((1 - a - g) * (w ** 2)) + (a * w) + g 98 | # w[w < eigval_tol] = eigval_tol # clip minimum eigenvalue 99 | 100 | # # Matrix holding the whitening transformation. 101 | # Z = (v * (1 / np.sqrt(w))[None, :]) @ v.T 102 | 103 | # # Returned (partially) whitened data and whitening matrix. 104 | # return K @ Z, Z 105 | 106 | 107 | 108 | # def randomized_kernel_approx(X, kernel_params, s, upsample_factor, random_state): 109 | 110 | # # Sample s columns of the kernel matrix, randomly at uniform. 111 | # i1 = random_state.choice(len(X), replace=False, size=s) 112 | # C = pairwise_kernels(X, X[i1], **kernel_params) 113 | 114 | # # Find an orthonormal basis for the sampled columns. 115 | # Q, _ = np.linalg.qr(C) 116 | 117 | # # Sample upsample_factor * s columns, using leverage scores 118 | # lev_scores = np.sum(Q * Q, axis=1) 119 | # i2 = random_state.choice( 120 | # len(X) 121 | # size=(upsample_factor * s), 122 | # p=(lev_scores / np.sum(lev_scores)), 123 | # replace=False 124 | # ) 125 | 126 | # # Empirically, including the initially sampled columns helps performance. 127 | # idx = np.unique(np.concatenate((i1, i2))) 128 | 129 | # # Form low rank estimate, L @ L.T, of kernel matrix. 130 | # Ksub = pairwise_kernels(X[idx], **kernel_params) 131 | # L = np.linalg.pinv(Q[idx]) @ scipy.linalg.sqrtm(Ksub) 132 | 133 | # # Compute SVD of L to estimate the eigendecomposition of kernel matrix. 134 | # eigvecs, sqrt_eigvals, _ = np.linalg.svd(L, full_matrices=False) 135 | 136 | # return sqrt_eigvals ** 2, eigvecs 137 | -------------------------------------------------------------------------------- /netrep/metrics/linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import itertools 3 | import multiprocessing 4 | from typing import Literal, Tuple, Optional, List 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from sklearn.utils.validation import check_is_fitted 9 | from sklearn.base import BaseEstimator 10 | from tqdm import tqdm 11 | 12 | from netrep.utils import whiten, angular_distance 13 | from netrep.validation import check_equal_shapes 14 | 15 | class LinearMetric(BaseEstimator): 16 | """Computes distance between two sets of optimally linearly aligned representations. 
17 | """ 18 | 19 | def __init__( 20 | self, 21 | alpha: float = 1.0, 22 | center_columns: bool = True, 23 | zero_pad: bool = True, 24 | score_method: Literal["angular", "euclidean"] = "angular" 25 | ): 26 | """ 27 | Parameters 28 | ---------- 29 | alpha : float 30 | Regularization parameter between zero and one. When 31 | (alpha == 1.0) the metric only allows for rotational 32 | alignments. When (alpha == 0.0) the metric allows for 33 | any invertible linear transformation. 34 | 35 | center_columns : bool 36 | If True, learn a mean-centering operation in addition 37 | to the linear/rotational alignment. 38 | 39 | zero_pad : bool 40 | If False, an error is thrown if representations are 41 | provided with different dimensions. If True, the smaller 42 | matrix is zero-padded prior to allow for an alignment. 43 | Some amount of regularization (alpha > 0) is required to 44 | align zero-padded representations. 45 | 46 | score_method : {'angular','euclidean'}, default='angular' 47 | String specifying ground metric. 48 | """ 49 | 50 | if (alpha > 1) or (alpha < 0): 51 | raise ValueError( 52 | "Regularization parameter `alpha` must be between zero and one.") 53 | 54 | if score_method not in ("euclidean", "angular"): 55 | raise ValueError( 56 | "Expected `score_method` parameter to be in {'angular','euclidean'}. " + 57 | f"Found instead score_method == '{score_method}'." 58 | ) 59 | 60 | self.alpha = alpha 61 | self.center_columns = center_columns 62 | self.zero_pad = zero_pad 63 | self.score_method = score_method 64 | 65 | def partial_fit( 66 | self, 67 | X: npt.NDArray 68 | ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: 69 | """Computes partial whitening transformation for a response matrix.""" 70 | if self.center_columns: 71 | mx = np.mean(X, axis=0) 72 | Xw, Zx = whiten(X - mx[None, :], self.alpha, preserve_variance=True) 73 | else: 74 | mx = np.zeros(X.shape[1]) 75 | Xw, Zx = whiten(X, self.alpha, preserve_variance=True) 76 | return (mx, Xw, Zx) 77 | 78 | def finalize_fit( 79 | self, 80 | cache_X: Tuple[npt.NDArray, npt.NDArray, npt.NDArray], 81 | cache_Y: Tuple[npt.NDArray, npt.NDArray, npt.NDArray], 82 | ) -> LinearMetric: 83 | """ 84 | Takes outputs of 'partial_fit' function and finishes fitting 85 | transformation matrices (Wx, Wy) and bias terms (mx, my) to 86 | align a pair of neural activations. 87 | """ 88 | 89 | # Extract whitened representations. 90 | self.mx_, Xw, Zx = cache_X 91 | self.my_, Yw, Zy = cache_Y 92 | 93 | # Fit optimal rotational alignment. 94 | U, _, Vt = np.linalg.svd(Xw.T @ Yw) 95 | self.Wx_ = Zx @ U 96 | self.Wy_ = Zy @ Vt.T 97 | 98 | return self 99 | 100 | def fit(self, X: npt.NDArray, Y: npt.NDArray) -> LinearMetric: 101 | """Fits transformation matrices (Wx, Wy) and bias terms (mx, my) 102 | to align a pair of neural activation matrices. 103 | 104 | Parameters 105 | ---------- 106 | X : ndarray 107 | (num_samples x num_neurons) matrix of activations. 108 | Y : ndarray 109 | (num_samples x num_neurons) matrix of activations. 110 | """ 111 | X, Y = check_equal_shapes(X, Y, nd=2, zero_pad=self.zero_pad) 112 | return self.finalize_fit( 113 | self.partial_fit(X), 114 | self.partial_fit(Y) 115 | ) 116 | 117 | def transform( 118 | self, 119 | X: npt.NDArray, 120 | Y: npt.NDArray 121 | ) -> Tuple[npt.NDArray, npt.NDArray]: 122 | """Applies linear alignment transformations to X and Y. 123 | 124 | Parameters 125 | ---------- 126 | X : ndarray 127 | (num_samples x num_neurons) matrix of activations. 
128 |         Y : ndarray
129 |             (num_samples x num_neurons) matrix of activations.
130 | 
131 |         Returns
132 |         -------
133 |         tX : ndarray
134 |             Transformed version of X.
135 |         tY : ndarray
136 |             Transformed version of Y.
137 |         """
138 |         X, Y = check_equal_shapes(X, Y, nd=2, zero_pad=self.zero_pad)
139 |         return self._transform_X(X), self._transform_Y(Y)
140 | 
141 |     def fit_score(self, X: npt.NDArray, Y: npt.NDArray) -> float:
142 |         """Fits alignment by calling `fit(X, Y)` and then evaluates
143 |         the distance by calling `score(X, Y)`.
144 | 
145 |         Parameters
146 |         ----------
147 |         X : ndarray
148 |             (num_samples x num_neurons) matrix of activations.
149 |         Y : ndarray
150 |             (num_samples x num_neurons) matrix of activations.
151 | 
152 |         Returns
153 |         -------
154 |         dist : float
155 |             Distance between X and Y.
156 |         """
157 |         return self.fit(X, Y).score(X, Y)
158 | 
159 |     def score(self, X: npt.NDArray, Y: npt.NDArray) -> float:
160 |         """Computes the angular distance between X and Y in
161 |         the aligned space.
162 | 
163 |         Parameters
164 |         ----------
165 |         X : ndarray
166 |             (num_samples x num_neurons) matrix of activations.
167 |         Y : ndarray
168 |             (num_samples x num_neurons) matrix of activations.
169 | 
170 |         Returns
171 |         -------
172 |         dist : float
173 |             Distance between X and Y.
174 |         """
175 |         if self.score_method == "angular":
176 |             return angular_distance(*self.transform(X, Y))
177 |         else:  # self.score_method == "euclidean":
178 |             return np.linalg.norm(
179 |                 np.subtract(*self.transform(X, Y)), ord="fro"
180 |             )
181 | 
182 |     def _transform_X(self, X: npt.NDArray) -> npt.NDArray:
183 |         """Transform X into the aligned space."""
184 |         check_is_fitted(self, attributes=["Wx_"])
185 |         if (X.shape[1] != self.Wx_.shape[0]):
186 |             raise ValueError(
187 |                 "Array with wrong shape passed to transform. "
188 |                 "Expected matrix with {} columns, but got array "
189 |                 "with shape {}.".format(self.Wx_.shape[0], np.shape(X)))
190 |         if self.center_columns:
191 |             return (X - self.mx_[None, :]) @ self.Wx_
192 |         else:
193 |             return (X @ self.Wx_)
194 | 
195 |     def _transform_Y(self, Y: npt.NDArray) -> npt.NDArray:
196 |         """Transform Y into the aligned space."""
197 |         check_is_fitted(self, attributes=["Wy_"])
198 |         if (Y.shape[1] != self.Wy_.shape[0]):
199 |             raise ValueError(
200 |                 "Array with wrong shape passed to transform. "
201 |                 "Expected matrix with {} columns, but got array "
202 |                 "with shape {}.".format(self.Wy_.shape[0], np.shape(Y)))
203 |         if self.center_columns:
204 |             return (Y - self.my_[None, :]) @ self.Wy_
205 |         else:
206 |             return Y @ self.Wy_
207 | 
208 |     def _compute_distance(self, i, j, X, Y, X_test, Y_test):
209 |         """Helper function for multiprocessing."""
210 | 
211 |         self.fit(X, Y)
212 |         dist_train = self.score(X, Y)
213 |         if X_test is None and Y_test is None:
214 |             dist_test = np.inf
215 |         else:
216 |             dist_test = self.score(X_test, Y_test)
217 |         return i, j, dist_train, dist_test
218 | 
219 |     def _compute_distance_star(self, args):
220 |         """Helper function for multiprocessing.
221 |         Using this allows us to use tqdm to track progress via imap_unordered.
222 |         """
223 |         return self._compute_distance(*args)
224 | 
225 |     def pairwise_distances(
226 |         self,
227 |         train_data: List[npt.NDArray],
228 |         test_data: Optional[List[npt.NDArray]] = None,
229 |         processes: Optional[int] = None,
230 |         verbose: bool = True,
231 |     ):
232 |         """Computes pairwise distances between all pairs of networks w/ multiprocessing.
233 | 
234 |         We suggest setting "OMP_NUM_THREADS=1" in your environment variables to avoid oversubscription
235 |         (multiprocesses competing for the same CPU).
236 | 
237 |         Parameters
238 |         ----------
239 |         train_data: List[npt.NDArray]
240 |             List of Size([images, neurons]) for train data.
241 |         test_data: List[npt.NDArray], optional
242 |             List of Size([images, neurons]) for test data. If None, the output
243 |             distance matrix will be np.inf.
244 |             Entries must be ordered so that test_data[i] comes from the same
245 |             network as train_data[i].
246 |         processes: int, optional
247 |             Number of processes to use. If None, defaults to number of CPUs.
248 |         verbose: bool, optional
249 |             Whether to display progress bar.
250 | 
251 |         Returns
252 |         -------
253 |         D_train: npt.NDArray
254 |             n_networks x n_networks distance matrix.
255 |         D_test: npt.NDArray
256 |             n_networks x n_networks distance matrix. If test_data is None, this is
257 |             a matrix of np.inf.
258 |         """
259 |         n_networks = len(train_data)
260 |         n_dists = n_networks*(n_networks-1)//2
261 | 
262 |         # create generator of args for multiprocessing
263 |         ij = itertools.combinations(range(n_networks), 2)
264 |         if test_data is None:
265 |             args = ((i, j, train_data[i], train_data[j], None, None) for i, j in ij)
266 |         else:
267 |             args = ((i, j, train_data[i], train_data[j], test_data[i], test_data[j]) for i, j in ij)
268 | 
269 |         if verbose:
270 |             print(f"Parallelizing {n_dists} distance calculations with {multiprocessing.cpu_count() if processes is None else processes} processes.")
271 |             pbar = lambda x: tqdm(x, total=n_dists, desc="Computing distances")
272 |         else:
273 |             pbar = lambda x: x
274 | 
275 |         with multiprocessing.Pool(processes=processes) as pool:
276 |             results = []
277 |             for result in pbar(pool.imap_unordered(self._compute_distance_star, args)):
278 |                 results.append(result)
279 | 
280 |         D_train = np.zeros((n_networks, n_networks))
281 |         D_test = np.zeros((n_networks, n_networks))
282 | 
283 |         for i, j, dist_train, dist_test in results:
284 |             D_train[i, j], D_train[j, i] = dist_train, dist_train
285 |             D_test[i, j], D_test[j, i] = dist_test, dist_test
286 | 
287 |         return D_train, D_test
288 | 

--------------------------------------------------------------------------------
/netrep/metrics/perm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import itertools
3 | import multiprocessing
4 | from typing import Tuple, List, Optional, Literal
5 | 
6 | import numpy as np
7 | import numpy.typing as npt
8 | from scipy.optimize import linear_sum_assignment as lsa
9 | from sklearn.base import BaseEstimator
10 | from sklearn.utils.validation import check_is_fitted
11 | from tqdm import tqdm
12 | 
13 | from netrep.validation import check_equal_shapes
14 | from netrep.utils import angular_distance
15 | 
16 | class PermutationMetric(BaseEstimator):
17 |     """Computes distance between two sets of optimally permutation-aligned representations.
18 |     """
19 | 
20 |     def __init__(
21 |         self,
22 |         center_columns: bool = True,
23 |         zero_pad: bool = True,
24 |         score_method: Literal["angular", "euclidean"] = "angular"
25 |     ):
26 |         """
27 |         Parameters
28 |         ----------
29 |         center_columns : bool
30 |             If True, learn a mean-centering operation in addition
31 |             to the permutation alignment.
32 | 
33 |         zero_pad : bool
34 |             If False, an error is thrown if representations are
35 |             provided with different dimensions. If True, the smaller
36 |             matrix is zero-padded to allow for an alignment.
37 |             (Unlike `LinearMetric`, no regularization parameter is needed
38 |             to align zero-padded representations.)
39 | 
40 |         score_method : {'angular','euclidean'}, default='angular'
41 |             String specifying ground metric.
42 |         """
43 | 
44 |         if score_method not in ("euclidean", "angular"):
45 |             raise ValueError(
46 |                 "Expected `score_method` parameter to be in {'angular','euclidean'}. " +
47 |                 f"Found instead score_method == '{score_method}'."
48 |             )
49 | 
50 |         self.center_columns = center_columns
51 |         self.zero_pad = zero_pad
52 |         self.score_method = score_method
53 | 
54 |     def partial_fit(self, X: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray]:
55 |         """Computes the column-centering transformation for a neural response matrix.
56 |         """
57 |         if self.center_columns:
58 |             mx = np.mean(X, axis=0)
59 |             Xphi = X - mx[None, :]
60 |         else:
61 |             mx = np.zeros(X.shape[1])
62 |             Xphi = X
63 |         return (mx, Xphi)
64 | 
65 |     def finalize_fit(
66 |         self,
67 |         cache_X: Tuple[npt.NDArray, npt.NDArray],
68 |         cache_Y: Tuple[npt.NDArray, npt.NDArray]
69 |     ) -> PermutationMetric:
70 |         """Takes outputs of 'partial_fit' function and finishes fitting permutation
71 |         matrices (Px, Py) and bias terms (mx, my) to align a pair of neural activations.
72 |         """
73 | 
74 |         # Extract centered representations.
75 |         self.mx_, X = cache_X
76 |         self.my_, Y = cache_Y
77 | 
78 |         # Fit optimal permutation matrices.
79 |         self.Px_, self.Py_ = lsa(X.T @ Y, maximize=True)
80 | 
81 |         return self
82 | 
83 |     def fit(self, X: npt.NDArray, Y: npt.NDArray) -> PermutationMetric:
84 |         """Fits permutation matrices (Px, Py) and bias terms (mx, my) to align a pair of
85 |         neural activation matrices.
86 | 
87 |         Parameters
88 |         ----------
89 |         X : ndarray
90 |             (num_samples x num_neurons) matrix of activations.
91 |         Y : ndarray
92 |             (num_samples x num_neurons) matrix of activations.
93 |         """
94 |         X, Y = check_equal_shapes(X, Y, nd=2, zero_pad=self.zero_pad)
95 |         return self.finalize_fit(
96 |             self.partial_fit(X),
97 |             self.partial_fit(Y)
98 |         )
99 | 
100 |     def transform(
101 |         self,
102 |         X: npt.NDArray,
103 |         Y: npt.NDArray
104 |     ) -> Tuple[npt.NDArray, npt.NDArray]:
105 |         """Applies permutation alignment transformations to X and Y.
106 | 
107 |         Parameters
108 |         ----------
109 |         X : ndarray
110 |             (num_samples x num_neurons) matrix of activations.
111 |         Y : ndarray
112 |             (num_samples x num_neurons) matrix of activations.
113 | 
114 |         Returns
115 |         -------
116 |         tX : ndarray
117 |             Transformed version of X.
118 |         tY : ndarray
119 |             Transformed version of Y.
120 |         """
121 |         X, Y = check_equal_shapes(X, Y, nd=2, zero_pad=self.zero_pad)
122 |         return self._transform_X(X), self._transform_Y(Y)
123 | 
124 |     def fit_score(self, X: npt.NDArray, Y: npt.NDArray) -> float:
125 |         """Fits alignment by calling `fit(X, Y)` and then evaluates
126 |         the distance by calling `score(X, Y)`.
127 | 
128 |         Parameters
129 |         ----------
130 |         X : ndarray
131 |             (num_samples x num_neurons) matrix of activations.
132 |         Y : ndarray
133 |             (num_samples x num_neurons) matrix of activations.
134 | 
135 |         Returns
136 |         -------
137 |         dist : float
138 |             Distance between optimally aligned X and Y.
139 |         """
140 |         return self.fit(X, Y).score(X, Y)
141 | 
142 |     def score(self, X: npt.NDArray, Y: npt.NDArray) -> float:
143 |         """Computes the distance between X and Y in the aligned
144 |         space.
145 | 
146 |         Parameters
147 |         ----------
148 |         X : ndarray
149 |             (num_samples x num_neurons) matrix of activations.
150 |         Y : ndarray
151 |             (num_samples x num_neurons) matrix of activations.
152 | 
153 |         Returns
154 |         -------
155 |         dist : float
156 |             Distance between X and Y.
157 |         """
158 |         if self.score_method == "angular":
159 |             return angular_distance(*self.transform(X, Y))
160 |         else:  # self.score_method == "euclidean":
161 |             return np.linalg.norm(
162 |                 np.subtract(*self.transform(X, Y)), ord="fro"
163 |             )
164 | 
165 | 
166 |     def _transform_X(self, X: npt.NDArray) -> npt.NDArray:
167 |         """Transform X into the aligned space."""
168 |         check_is_fitted(self, attributes=["Px_"])
169 |         if (X.shape[1] != len(self.Px_)):
170 |             raise ValueError(
171 |                 "Array with wrong shape passed to transform. "
172 |                 "Expected matrix with {} columns, but got array "
173 |                 "with shape {}.".format(len(self.Px_), np.shape(X)))
174 |         if self.center_columns:
175 |             return (X - self.mx_[None, :])[:, self.Px_]
176 |         else:
177 |             return X[:, self.Px_]
178 | 
179 |     def _transform_Y(self, Y: npt.NDArray) -> npt.NDArray:
180 |         """Transform Y into the aligned space."""
181 |         check_is_fitted(self, attributes=["Py_"])
182 |         if (Y.shape[1] != len(self.Py_)):
183 |             raise ValueError(
184 |                 "Array with wrong shape passed to transform. "
185 |                 "Expected matrix with {} columns, but got array "
186 |                 "with shape {}.".format(len(self.Py_), np.shape(Y)))
187 |         if self.center_columns:
188 |             return (Y - self.my_[None, :])[:, self.Py_]
189 |         else:
190 |             return Y[:, self.Py_]
191 | 
192 |     def _compute_distance(self, i, j, X, Y, X_test, Y_test):
193 |         """Helper function for multiprocessing."""
194 | 
195 |         self.fit(X, Y)
196 |         dist_train = self.score(X, Y)
197 |         if X_test is None and Y_test is None:
198 |             dist_test = np.inf
199 |         else:
200 |             dist_test = self.score(X_test, Y_test)
201 |         return i, j, dist_train, dist_test
202 | 
203 |     def _compute_distance_star(self, args):
204 |         """Helper function for multiprocessing.
205 |         Using this allows us to use tqdm to track progress via imap_unordered.
206 |         """
207 |         return self._compute_distance(*args)
208 | 
209 |     def pairwise_distances(
210 |         self,
211 |         train_data: List[npt.NDArray],
212 |         test_data: Optional[List[npt.NDArray]] = None,
213 |         processes: Optional[int] = None,
214 |         verbose: bool = True,
215 |     ):
216 |         """Computes pairwise distances between all pairs of networks w/ multiprocessing.
217 | 
218 |         We suggest setting "OMP_NUM_THREADS=1" in your environment variables to avoid oversubscription
219 |         (multiprocesses competing for the same CPU).
220 | 
221 |         Parameters
222 |         ----------
223 |         train_data: List[npt.NDArray]
224 |             List of Size([images, neurons]) for train data.
225 |         test_data: List[npt.NDArray], optional
226 |             List of Size([images, neurons]) for test data. If None, the output
227 |             distance matrix will be np.inf.
228 |             Entries must be ordered so that test_data[i] comes from the same
229 |             network as train_data[i].
230 |         processes: int, optional
231 |             Number of processes to use. If None, defaults to number of CPUs.
232 |         verbose: bool, optional
233 |             Whether to display progress bar.
234 | 
235 |         Returns
236 |         -------
237 |         D_train: npt.NDArray
238 |             n_networks x n_networks distance matrix.
239 |         D_test: npt.NDArray
240 |             n_networks x n_networks distance matrix. If test_data is None, this is
241 |             a matrix of np.inf.
242 | """ 243 | n_networks = len(train_data) 244 | n_dists = n_networks*(n_networks-1)//2 245 | 246 | # create generator of args for multiprocessing 247 | ij = itertools.combinations(range(n_networks), 2) 248 | if test_data is None: 249 | args = ((i, j, train_data[i], train_data[j], None, None) for i, j in ij) 250 | else: 251 | args = ((i, j, train_data[i], train_data[j], test_data[i], test_data[j]) for i, j in ij) 252 | 253 | if verbose: 254 | print(f"Parallelizing {n_dists} distance calculations with {multiprocessing.cpu_count() if processes is None else processes} processes.") 255 | pbar = lambda x: tqdm(x, total=n_dists, desc="Computing distances") 256 | else: 257 | pbar = lambda x: x 258 | 259 | with multiprocessing.Pool(processes=processes) as pool: 260 | results = [] 261 | for result in pbar(pool.imap_unordered(self._compute_distance_star, args)): 262 | results.append(result) 263 | 264 | D_train = np.zeros((n_networks, n_networks)) 265 | D_test = np.zeros((n_networks, n_networks)) 266 | 267 | for i, j, dist_train, dist_test in results: 268 | D_train[i, j], D_train[j, i] = dist_train, dist_train 269 | D_test[i, j], D_test[j, i] = dist_test, dist_test 270 | 271 | return D_train, D_test 272 | -------------------------------------------------------------------------------- /netrep/metrics/stochastic.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import itertools 3 | import multiprocessing 4 | from typing import Tuple, Optional, Union, Literal, List 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from sklearn.utils.validation import check_random_state 9 | from tqdm import tqdm 10 | 11 | from netrep.utils import align, sq_bures_metric, rand_orth 12 | 13 | 14 | class GaussianStochasticMetric: 15 | """2-Wasserstein distance between Gaussian-distributed network responses. 16 | 17 | Attributes 18 | ---------- 19 | alpha: float between 0 and 2 20 | Interpolates between covariance-only and mean-only distance metrics. 21 | When alpha == 0: only uses covariance. 22 | When alpha == 1: computes 2-Wasserstein. 23 | When alpha == 2: only uses means (i.e. deterministic metric). 24 | group: Literal["orth", "perm", "identity"] 25 | Invariance group over which to optimize. 26 | init: Literal["means", "rand"] 27 | Transform initialization. 28 | niter: int 29 | Number of optimization iterations. 30 | tol: float 31 | Optimization tolerance. 32 | n_restarts: int 33 | Number of restarts. Only valid when `init` is "rand". 34 | T: np.ndarray 35 | Optimal alignment matrix. 36 | loss_hist: List[float] 37 | Loss history. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | alpha: float=1.0, 43 | group: Literal["orth", "perm", "identity"] = "orth", 44 | init: Literal["means", "rand"] = "means", 45 | niter: int = 1000, 46 | tol: float = 1e-8, 47 | random_state: Optional[Union[int, np.random.RandomState]]=None, 48 | n_restarts: int = 1, 49 | ): 50 | if (alpha < 0) or (alpha > 2): 51 | raise ValueError("alpha parameter should be between zero and two.") 52 | self.alpha = alpha 53 | self.group = group 54 | self.init = init 55 | self.niter = niter 56 | self.tol = tol 57 | self._rs = check_random_state(random_state) 58 | self.n_restarts = n_restarts 59 | if self.init == "means": 60 | assert n_restarts == 1 61 | 62 | def fit( 63 | self, 64 | X: Tuple[npt.NDArray, npt.NDArray], 65 | Y: Tuple[npt.NDArray, npt.NDArray] 66 | ) -> GaussianStochasticMetric: 67 | """Aligns network responses with interpolated 2-Wasserstein ground metric. 
68 | 69 | Parameters 70 | ---------- 71 | X : Tuple[np.ndarray, np.ndarray] 72 | Tuple of (means, covariances) for first set of network responses. Means has 73 | shape (n_images, n_neurons) and covariances has shape 74 | (n_images, n_neurons, n_neurons). 75 | Y : Tuple[np.ndarray, np.ndarray] 76 | Tuple of (means, covariances) for second set of network responses. Means has 77 | shape (n_images, n_neurons) and covariances has shape 78 | (n_images, n_neurons, n_neurons). 79 | 80 | Returns 81 | ------- 82 | self: GaussianStochasticMetric 83 | Instance of class with optimal alignment matrix stored in `self.T`. 84 | """ 85 | means_X, covs_X = X 86 | means_Y, covs_Y = Y 87 | 88 | assert means_X.shape == means_Y.shape 89 | assert covs_X.shape == covs_Y.shape 90 | assert means_X.shape[0] == covs_X.shape[0] 91 | assert means_X.shape[1] == covs_X.shape[1] 92 | assert means_X.shape[1] == covs_X.shape[2] 93 | 94 | best_loss = np.inf 95 | for _ in range(self.n_restarts): 96 | 97 | if self.init == "means": 98 | init_T = align(means_Y, means_X, group=self.group) 99 | elif self.init == "rand": 100 | init_T = rand_orth(means_X.shape[1], random_state=self._rs) 101 | 102 | T, loss_hist = _fit_gaussian_alignment( 103 | means_X, means_Y, covs_X, covs_Y, init_T, 104 | self.alpha, self.group, self.niter, self.tol 105 | ) 106 | if best_loss > loss_hist[-1]: 107 | best_loss = loss_hist[-1] 108 | best_T = T 109 | 110 | self.T = best_T 111 | self.loss_hist = loss_hist 112 | return self 113 | 114 | def transform( 115 | self, 116 | X: Tuple[npt.NDArray, npt.NDArray], 117 | Y: Tuple[npt.NDArray, npt.NDArray] 118 | ) -> Tuple[Tuple[npt.NDArray, npt.NDArray], Tuple[npt.NDArray, npt.NDArray]]: 119 | """Aligns second set of network responses with first set. 120 | 121 | Parameters 122 | ---------- 123 | X : Tuple[np.ndarray, np.ndarray] 124 | Tuple of (means, covariances) for first set of network responses. Means has 125 | shape (n_images, n_neurons) and covariances has shape 126 | (n_images, n_neurons, n_neurons). 127 | Y : Tuple[np.ndarray, np.ndarray] 128 | Tuple of (means, covariances) for second set of network responses. Means has 129 | shape (n_images, n_neurons) and covariances has shape 130 | (n_images, n_neurons, n_neurons). 131 | 132 | Returns 133 | ------- 134 | X : Tuple[np.ndarray, np.ndarray] 135 | Same as input. 136 | Y_transformed : Tuple[np.ndarray, np.ndarray] 137 | Aligned tuple of (means, covariances) for second set of network responses. 138 | """ 139 | means_Y, covs_Y = Y 140 | Y_transformed = ( 141 | means_Y @ self.T, 142 | np.einsum("ijk,jl,kp->ilp", covs_Y, self.T, self.T, optimize=True) 143 | ) 144 | return X, Y_transformed 145 | 146 | def score( 147 | self, 148 | X: Tuple[npt.NDArray, npt.NDArray], 149 | Y: Tuple[npt.NDArray, npt.NDArray] 150 | ) -> float: 151 | """Computes interpolated 2-Wasserstein distance between aligned network responses. 152 | 153 | Parameters 154 | ---------- 155 | X: Tuple[np.ndarray, np.ndarray] 156 | Tuple of (means, covariances) for first set of network responses. Means has 157 | shape (n_images, n_neurons) and covariances has shape 158 | (n_images, n_neurons, n_neurons). 159 | Y: Tuple[np.ndarray, np.ndarray] 160 | Tuple of (means, covariances) for second set of network responses. Means has 161 | shape (n_images, n_neurons) and covariances has shape 162 | (n_images, n_neurons, n_neurons). 163 | 164 | Returns 165 | ------- 166 | score: float 167 | Interpolated 2-Wasserstein distance between aligned network responses. 
168 | """ 169 | X, Y = self.transform(X, Y) 170 | mX, sX = X 171 | mY, sY = Y 172 | 173 | A = np.sum((mX - mY) ** 2, axis=1) 174 | B = np.array([sq_bures_metric(sx, sy) for sx, sy in zip(sX, sY)]) 175 | mn = np.mean(self.alpha * A + (2 - self.alpha) * B) 176 | # mn should always be positive but sometimes numerical rounding errors 177 | # cause mn to be very slightly negative, causing sqrt(mn) to be nan. 178 | # Thus, we take sqrt(abs(mn)) and pass through the sign. Any large 179 | # negative outputs should be caught by unit tests. 180 | return np.sign(mn) * np.sqrt(abs(mn)) 181 | 182 | def fit_score( 183 | self, 184 | X: Tuple[npt.NDArray, npt.NDArray], 185 | Y: Tuple[npt.NDArray, npt.NDArray] 186 | ) -> float: 187 | """Fits alignment matrix and returns distance. 188 | 189 | Parameters 190 | ---------- 191 | X: Tuple[np.ndarray, np.ndarray] 192 | Tuple of (means, covariances) for first set of network responses. Means has 193 | shape (n_images, n_neurons) and covariances has shape 194 | (n_images, n_neurons, n_neurons). 195 | Y: Tuple[np.ndarray, np.ndarray] 196 | Tuple of (means, covariances) for second set of network responses. Means has 197 | shape (n_images, n_neurons) and covariances has shape 198 | (n_images, n_neurons, n_neurons). 199 | 200 | Returns 201 | ------- 202 | score: float 203 | Interpolated 2-Wasserstein distance between aligned network responses. 204 | """ 205 | return self.fit(X, Y).score(X, Y) 206 | 207 | def _compute_distance(self, i, j, X, Y, X_test, Y_test, eps): 208 | """Helper function for multiprocessing.""" 209 | X = (X[0], X[1] + eps * np.eye(X[1].shape[1])) # regularize covariance 210 | Y = (Y[0], Y[1] + eps * np.eye(Y[1].shape[1])) 211 | 212 | self.fit(X, Y) 213 | dist_train = self.score(X, Y) 214 | if X_test is None and Y_test is None: 215 | dist_test = np.inf 216 | else: 217 | dist_test = self.score(X_test, Y_test) 218 | return i, j, dist_train, dist_test 219 | 220 | def _compute_distance_star(self, args): 221 | """Helper function for multiprocessing. 222 | Using this allows us to use tqdm to track progress via imap_unordered. 223 | """ 224 | return self._compute_distance(*args) 225 | 226 | def pairwise_distances( 227 | self, 228 | train_data: List[Tuple[npt.NDArray, npt.NDArray]], 229 | test_data: Optional[List[Tuple[npt.NDArray, npt.NDArray]]]=None, 230 | eps: float = 1E-6, 231 | processes: Optional[int] = None, 232 | verbose: bool = True, 233 | ): 234 | """Computes pairwise distances between all pairs of networks w/ multiprocessing. 235 | 236 | We suggest setting "OMP_NUM_THREADS=1" in your environment variables to avoid oversubscription 237 | (multiprocesses competing for the same CPU). 238 | 239 | Parameters 240 | ---------- 241 | train_data: List[Tuple[npt.NDArray, npt.NDArray]] 242 | List of tuples of (means, covariances) for train data. 243 | test_data: List[Tuple[npt.NDArray, npt.NDArray]], optional 244 | List of tuples of (means, covariances) for test data. If None, the output 245 | distance matrix will be np.inf. 246 | eps: float, optional 247 | Add eps * I to each covariances to regularize. 248 | processes: int, optional 249 | Number of processes to use. If None, defaults to number of CPUs. 250 | verbose: bool, optional 251 | Whether to display progress bar. 252 | 253 | Returns 254 | ------- 255 | D_train: npt.NDArray 256 | n_networks x n_networks distance matrix. 257 | D_test: npt.NDArray 258 | n_networks x n_networks distance matrix. If test_data is None, this is 259 | a matrix of np.inf. 
260 | """ 261 | n_networks = len(train_data) 262 | n_dists = n_networks*(n_networks-1)//2 263 | 264 | # create generator of args for multiprocessing 265 | ij = itertools.combinations(range(n_networks), 2) 266 | if test_data is None: 267 | args = ((i, j, train_data[i], train_data[j], None, None, eps) for i, j in ij) 268 | else: 269 | args = ((i, j, train_data[i], train_data[j], test_data[i], test_data[j], eps) for i, j in ij) 270 | 271 | if verbose: 272 | print(f"Parallelizing {n_dists} distance calculations with {multiprocessing.cpu_count() if processes is None else processes} processes.") 273 | pbar = lambda x: tqdm(x, total=n_dists, desc="Computing distances") 274 | else: 275 | pbar = lambda x: x 276 | 277 | with multiprocessing.Pool(processes=processes) as pool: 278 | results = [] 279 | for result in pbar(pool.imap_unordered(self._compute_distance_star, args)): 280 | results.append(result) 281 | 282 | D_train = np.zeros((n_networks, n_networks)) 283 | D_test = np.zeros((n_networks, n_networks)) 284 | 285 | for i, j, dist_train, dist_test in results: 286 | D_train[i, j], D_train[j, i] = dist_train, dist_train 287 | D_test[i, j], D_test[j, i] = dist_test, dist_test 288 | 289 | return D_train, D_test 290 | 291 | class EnergyStochasticMetric: 292 | """Optimal alignment of network responses using energy distance as the ground metric. 293 | 294 | Attributes 295 | ---------- 296 | group: Literal["orth", "perm", "identity"] 297 | Invariance group over which to optimize. 298 | niter: int 299 | Number of optimization iterations. 300 | tol: float 301 | Defaults to 1e-6. 302 | Q: np.ndarray 303 | Optimal alignment matrix. 304 | loss_hist: List[float] 305 | """ 306 | 307 | def __init__( 308 | self, 309 | group: Literal["orth", "perm", "identity"] = "orth", 310 | niter: int = 100, 311 | tol: float = 1e-6): 312 | 313 | self.group = group 314 | self.niter = niter 315 | self.tol = tol 316 | 317 | def fit( 318 | self, 319 | X: npt.NDArray, 320 | Y: npt.NDArray 321 | ) -> EnergyStochasticMetric: 322 | """Fits optimal matrix that aligns network responses Y to X. 323 | 324 | Parameters 325 | ---------- 326 | X : np.ndarray 327 | Responses of first network with Size[(images, repeats, neurons]). 328 | Y : np.ndarray 329 | Responses of second network with Size[(images, repeats, neurons]). 330 | 331 | Returns 332 | ------- 333 | self : EnergyStochasticMetric 334 | Class instance with updated state. 335 | """ 336 | assert X.shape == Y.shape 337 | 338 | r = X.shape[1] 339 | 340 | idx = np.array(list(itertools.product(range(r), range(r)))) 341 | X = np.row_stack([x[idx[:, 0]] for x in X]) 342 | Y = np.row_stack([y[idx[:, 1]] for y in Y]) 343 | 344 | w = np.ones(X.shape[0]) 345 | loss_hist = [np.mean(np.linalg.norm(X - Y, axis=-1))] 346 | 347 | for _ in range(self.niter): 348 | Q = align(w[:, None] * Y, w[:, None] * X, group=self.group) 349 | resid = np.linalg.norm(X - Y @ Q, axis=-1) 350 | loss_hist.append(np.mean(resid)) 351 | w = 1 / np.maximum(np.sqrt(resid), 1e-6) 352 | if (loss_hist[-2] - loss_hist[-1]) < self.tol: 353 | break 354 | 355 | self.w = w 356 | self.Q = Q 357 | self.loss_hist = loss_hist 358 | return self 359 | 360 | def transform( 361 | self, 362 | X: npt.NDArray, 363 | Y: npt.NDArray 364 | ) -> Tuple[npt.NDArray, npt.NDArray]: 365 | """Aligns second network responses to first network responses. 366 | 367 | Parameters 368 | ---------- 369 | X : np.ndarray 370 | First network's responses, with Size[(images, repeats, neurons)]. 
371 | Y : np.ndarray 372 | Second network's responses, with Size[(images, repeats, neurons)]. 373 | 374 | Returns 375 | ------- 376 | X : np.ndarray 377 | First network's responses, with Size[(images, repeats, neurons)]. 378 | Y_aligned : np.ndarray 379 | Aligned second network's responses, with Size[(images, repeats, neurons)]. 380 | """ 381 | assert X.shape == Y.shape 382 | Y_aligned = np.einsum("ijk,kl->ijl", Y, self.Q, optimize=True) 383 | return X, Y_aligned 384 | 385 | def score(self, X: npt.NDArray, Y: npt.NDArray) -> float: 386 | """Compute the Energy distance metric between two networks. 387 | 388 | Parameters 389 | ---------- 390 | X : np.ndarray 391 | First network's responses, with Size[(images, repeats, neurons)]. 392 | Y : np.ndarray 393 | Second network's responses, with Size[(images, repeats, neurons)]. 394 | 395 | Returns 396 | ------- 397 | score : float 398 | Energy distance metric between two networks. 399 | """ 400 | X, Y = self.transform(X, Y) 401 | m = X.shape[0] # num images 402 | n_samples = X.shape[1] 403 | 404 | combs = np.array(list( 405 | itertools.combinations(range(n_samples), 2) 406 | )) 407 | prod = np.array(list( 408 | itertools.product(range(n_samples), range(n_samples)) 409 | )) 410 | 411 | d_xy, d_xx, d_yy = 0, 0, 0 412 | for i in range(m): 413 | d_xy += np.mean(np.linalg.norm(X[i][prod[:, 0]] - Y[i][prod[:, 1]], axis=-1)) 414 | d_xx += np.mean(np.linalg.norm(X[i][combs[:, 0]] - X[i][combs[:, 1]], axis=-1)) 415 | d_yy += np.mean(np.linalg.norm(Y[i][combs[:, 0]] - Y[i][combs[:, 1]], axis=-1)) 416 | 417 | return np.sqrt(max(0, (d_xy / m) - .5*((d_xx / m) + (d_yy / m)))) 418 | 419 | def fit_score(self, X: npt.NDArray, Y: npt.NDArray) -> float: 420 | """Fits optimal alignment and computes the Energy distance metric between two networks. 421 | 422 | Parameters 423 | ---------- 424 | X : np.ndarray 425 | First network's responses, with Size[(images, repeats, neurons)]. 426 | Y : np.ndarray 427 | Second network's responses, with Size[(images, repeats, neurons)]. 428 | 429 | Returns 430 | ------- 431 | score : float 432 | Energy distance metric between two networks. 433 | """ 434 | return self.fit(X, Y).score(X, Y) 435 | 436 | def _compute_distance(self, i, j, X, Y, X_test, Y_test): 437 | """Helper function for multiprocessing.""" 438 | 439 | self.fit(X, Y) 440 | dist_train = self.score(X, Y) 441 | if X_test is None and Y_test is None: 442 | dist_test = np.inf 443 | else: 444 | dist_test = self.score(X_test, Y_test) 445 | return i, j, dist_train, dist_test 446 | 447 | def _compute_distance_star(self, args): 448 | """Helper function for multiprocessing. 449 | Using this allows us to use tqdm to track progress via imap_unordered. 450 | """ 451 | return self._compute_distance(*args) 452 | 453 | def pairwise_distances( 454 | self, 455 | train_data: List[Tuple[npt.NDArray, npt.NDArray]], 456 | test_data: Optional[List[Tuple[npt.NDArray, npt.NDArray]]]=None, 457 | processes: Optional[int] = None, 458 | verbose: bool = True, 459 | ): 460 | """Computes pairwise distances between all pairs of networks w/ multiprocessing. 461 | 462 | We suggest setting "OMP_NUM_THREADS=1" in your environment variables to avoid oversubscription 463 | (multiprocesses competing for the same CPU). 464 | 465 | Parameters 466 | ---------- 467 | train_data: List[npt.NDArray] 468 | List of Size([images, repeats, neurons]) for train data. 469 | test_data: List[npt.NDArray], optional 470 | List of Size([images, repeats, neurons]) for test data. 
If None, the output 471 | distance matrix will be np.inf. 472 | processes: int, optional 473 | Number of processes to use. If None, defaults to number of CPUs. 474 | verbose: bool, optional 475 | Whether to display progress bar. 476 | 477 | Returns 478 | ------- 479 | D_train: npt.NDArray 480 | n_networks x n_networks distance matrix. 481 | D_test: npt.NDArray 482 | n_networks x n_networks distance matrix. If test_data is None, this is 483 | a matrix of np.inf. 484 | """ 485 | n_networks = len(train_data) 486 | n_dists = n_networks*(n_networks-1)//2 487 | 488 | # create generator of args for multiprocessing 489 | ij = itertools.combinations(range(n_networks), 2) 490 | if test_data is None: 491 | args = ((i, j, train_data[i], train_data[j], None, None) for i, j in ij) 492 | else: 493 | args = ((i, j, train_data[i], train_data[j], test_data[i], test_data[j]) for i, j in ij) 494 | 495 | if verbose: 496 | print(f"Parallelizing {n_dists} distance calculations with {multiprocessing.cpu_count() if processes is None else processes} processes.") 497 | pbar = lambda x: tqdm(x, total=n_dists, desc="Computing distances") 498 | else: 499 | pbar = lambda x: x 500 | 501 | with multiprocessing.Pool(processes=processes) as pool: 502 | results = [] 503 | for result in pbar(pool.imap_unordered(self._compute_distance_star, args)): 504 | results.append(result) 505 | pool.close() 506 | pool.join() 507 | 508 | D_train = np.zeros((n_networks, n_networks)) 509 | D_test = np.zeros((n_networks, n_networks)) 510 | 511 | for i, j, dist_train, dist_test in results: 512 | D_train[i, j], D_train[j, i] = dist_train, dist_train 513 | D_test[i, j], D_test[j, i] = dist_test, dist_test 514 | 515 | return D_train, D_test 516 | 517 | 518 | def _fit_gaussian_alignment( 519 | means_X: npt.NDArray, 520 | means_Y: npt.NDArray, 521 | covs_X: npt.NDArray, 522 | covs_Y: npt.NDArray, 523 | T: npt.NDArray, 524 | alpha: float, 525 | group: Literal["orth", "perm", "identity"], 526 | niter: int, 527 | tol: float, 528 | ) -> Tuple[npt.NDArray, List[float]]: 529 | """Helper function for fitting alignment between Gaussian-distributed responses.""" 530 | 531 | vX, uX = np.linalg.eigh(covs_X) 532 | sX = np.einsum("ijk,ik,ilk->ijl", uX, np.sqrt(vX), uX, optimize=True) 533 | 534 | vY, uY = np.linalg.eigh(covs_Y) 535 | sY = np.einsum("ijk,ik,ilk->ijl", uY, np.sqrt(vY), uY, optimize=True) 536 | 537 | loss_hist = [] 538 | 539 | for i in range(niter): 540 | Qs = [align(T.T @ sy, sx, group="orth") for sx, sy in zip(sX, sY)] 541 | A = np.row_stack( 542 | [alpha * means_X] + 543 | [(2 - alpha) * sx for sx in sX] 544 | ) 545 | r_sY = [] 546 | B = np.row_stack( 547 | [alpha * means_Y] + 548 | [Q.T @ ((2 - alpha) * sy) for Q, sy in zip(Qs, sY)] 549 | ) 550 | T = align(B, A, group=group) 551 | loss_hist.append(np.linalg.norm(A - B @ T)) 552 | if i < 2: 553 | pass 554 | elif (loss_hist[-2] - loss_hist[-1]) < tol: 555 | break 556 | 557 | return T, loss_hist 558 | -------------------------------------------------------------------------------- /netrep/metrics/stochastic_process.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import itertools 3 | import multiprocessing 4 | from typing import Tuple, Optional, Union, Literal, List 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from sklearn.utils.validation import check_random_state 9 | from tqdm import tqdm 10 | 11 | from netrep.utils import align, sq_bures_metric, rand_orth 12 | 13 | 14 | class GPStochasticMetric: 15 | 
"""2-Wasserstein distance between Gaussian-distributed network responses. 16 | 17 | Attributes 18 | ---------- 19 | alpha: float between 0 and 2 20 | Interpolates between covariance-only and mean-only distance metrics. 21 | When alpha == 0: only uses covariance. 22 | When alpha == 1: computes 2-Wasserstein. 23 | When alpha == 2: only uses means (i.e. deterministic metric). 24 | group: Literal["orth", "perm", "identity"] 25 | Invariance group over which to optimize. 26 | init: Literal["means", "rand"] 27 | Transform initialization. 28 | niter: int 29 | Number of optimization iterations. 30 | tol: float 31 | Optimization tolerance. 32 | n_restarts: int 33 | Number of restarts. Only valid when `init` is "rand". 34 | T: np.ndarray 35 | Optimal alignment matrix. 36 | loss_hist: List[float] 37 | Loss history. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | n_dims, 43 | alpha: float=1.0, 44 | group: Literal["orth", "perm", "identity"] = "orth", 45 | init: Literal["means", "rand"] = "means", 46 | niter: int = 1000, 47 | tol: float = 1e-8, 48 | random_state: Optional[Union[int, np.random.RandomState]]=None, 49 | n_restarts: int = 1, 50 | ): 51 | if (alpha < 0) or (alpha > 2): 52 | raise ValueError("alpha parameter should be between zero and two.") 53 | self.alpha = alpha 54 | self.group = group 55 | self.init = init 56 | self.niter = niter 57 | self.tol = tol 58 | self._rs = check_random_state(random_state) 59 | self.n_restarts = n_restarts 60 | self.n_dims = n_dims 61 | if self.init == "means": 62 | assert n_restarts == 1 63 | 64 | def fit( 65 | self, 66 | X: Tuple[npt.NDArray, npt.NDArray], 67 | Y: Tuple[npt.NDArray, npt.NDArray] 68 | ) -> GPStochasticMetric: 69 | """Aligns network responses with interpolated 2-Wasserstein ground metric. 70 | 71 | Parameters 72 | ---------- 73 | X : Tuple[np.ndarray, np.ndarray] 74 | Tuple of (means, covariances) for first set of network responses. Means has 75 | shape (n_images, n_neurons) and covariances has shape 76 | (n_images, n_neurons, n_neurons). 77 | Y : Tuple[np.ndarray, np.ndarray] 78 | Tuple of (means, covariances) for second set of network responses. Means has 79 | shape (n_images, n_neurons) and covariances has shape 80 | (n_images, n_neurons, n_neurons). 81 | 82 | Returns 83 | ------- 84 | self: GaussianStochasticMetric 85 | Instance of class with optimal alignment matrix stored in `self.T`. 
86 | """ 87 | means_X, covs_X = X 88 | means_Y, covs_Y = Y 89 | 90 | assert means_X.shape == means_Y.shape 91 | assert covs_X.shape == covs_Y.shape 92 | assert means_X.shape[0] == covs_X.shape[0] 93 | 94 | n_times = means_X.shape[0]//self.n_dims 95 | 96 | means_X_t = means_X.reshape(n_times,self.n_dims) 97 | means_Y_t = means_Y.reshape(n_times,self.n_dims) 98 | 99 | best_loss = np.inf 100 | for _ in range(self.n_restarts): 101 | 102 | if self.init == "means": 103 | init_T = align(means_Y_t, means_X_t, group=self.group) 104 | elif self.init == "rand": 105 | init_T = rand_orth(means_X_t.shape[1], random_state=self._rs) 106 | 107 | T, loss_hist = _fit_gp_alignment( 108 | self.n_dims, means_X_t, means_Y_t, covs_X, covs_Y, init_T, 109 | self.alpha, self.group, self.niter, self.tol 110 | ) 111 | if best_loss > loss_hist[-1]: 112 | best_loss = loss_hist[-1] 113 | best_T = T 114 | 115 | self.T = best_T 116 | self.loss_hist = loss_hist 117 | return self 118 | 119 | def transform( 120 | self, 121 | X: Tuple[npt.NDArray, npt.NDArray], 122 | Y: Tuple[npt.NDArray, npt.NDArray] 123 | ) -> Tuple[Tuple[npt.NDArray, npt.NDArray], Tuple[npt.NDArray, npt.NDArray]]: 124 | """Aligns second set of network responses with first set. 125 | 126 | Parameters 127 | ---------- 128 | X : Tuple[np.ndarray, np.ndarray] 129 | Tuple of (means, covariances) for first set of network responses. Means has 130 | shape (n_images, n_neurons) and covariances has shape 131 | (n_images, n_neurons, n_neurons). 132 | Y : Tuple[np.ndarray, np.ndarray] 133 | Tuple of (means, covariances) for second set of network responses. Means has 134 | shape (n_images, n_neurons) and covariances has shape 135 | (n_images, n_neurons, n_neurons). 136 | 137 | Returns 138 | ------- 139 | X : Tuple[np.ndarray, np.ndarray] 140 | Same as input. 141 | Y_transformed : Tuple[np.ndarray, np.ndarray] 142 | Aligned tuple of (means, covariances) for second set of network responses. 143 | """ 144 | means_Y, covs_Y = Y 145 | 146 | n_times = means_Y.shape[0]//self.n_dims 147 | means_Y_t = means_Y.reshape(n_times,self.n_dims) 148 | 149 | T_full = np.kron(np.eye(n_times),self.T) 150 | 151 | Y_transformed = ( 152 | (means_Y_t @ self.T).flatten(), 153 | T_full.T@covs_Y@T_full 154 | ) 155 | return X, Y_transformed 156 | 157 | def score( 158 | self, 159 | X: Tuple[npt.NDArray, npt.NDArray], 160 | Y: Tuple[npt.NDArray, npt.NDArray] 161 | ) -> float: 162 | """Computes interpolated 2-Wasserstein distance between aligned network responses. 163 | 164 | Parameters 165 | ---------- 166 | X: Tuple[np.ndarray, np.ndarray] 167 | Tuple of (means, covariances) for first set of network responses. Means has 168 | shape (n_images, n_neurons) and covariances has shape 169 | (n_images, n_neurons, n_neurons). 170 | Y: Tuple[np.ndarray, np.ndarray] 171 | Tuple of (means, covariances) for second set of network responses. Means has 172 | shape (n_images, n_neurons) and covariances has shape 173 | (n_images, n_neurons, n_neurons). 174 | 175 | Returns 176 | ------- 177 | score: float 178 | Interpolated 2-Wasserstein distance between aligned network responses. 179 | """ 180 | X, Y = self.transform(X, Y) 181 | mX, sX = X 182 | mY, sY = Y 183 | 184 | A = np.sum((mX - mY) ** 2) 185 | B = sq_bures_metric(sX, sY) 186 | mn = np.mean(self.alpha * A + (2 - self.alpha) * B) 187 | # mn should always be positive but sometimes numerical rounding errors 188 | # cause mn to be very slightly negative, causing sqrt(mn) to be nan. 189 | # Thus, we take sqrt(abs(mn)) and pass through the sign. 
Any large 190 | # negative outputs should be caught by unit tests. 191 | return np.sign(mn) * np.sqrt(abs(mn)) 192 | 193 | def fit_score( 194 | self, 195 | X: Tuple[npt.NDArray, npt.NDArray], 196 | Y: Tuple[npt.NDArray, npt.NDArray] 197 | ) -> float: 198 | """Fits alignment matrix and returns distance. 199 | 200 | Parameters 201 | ---------- 202 | X: Tuple[np.ndarray, np.ndarray] 203 | Tuple of (means, covariances) for first set of network responses. Means has 204 | shape (n_images, n_neurons) and covariances has shape 205 | (n_images, n_neurons, n_neurons). 206 | Y: Tuple[np.ndarray, np.ndarray] 207 | Tuple of (means, covariances) for second set of network responses. Means has 208 | shape (n_images, n_neurons) and covariances has shape 209 | (n_images, n_neurons, n_neurons). 210 | 211 | Returns 212 | ------- 213 | score: float 214 | Interpolated 2-Wasserstein distance between aligned network responses. 215 | """ 216 | return self.fit(X, Y).score(X, Y) 217 | 218 | def _compute_distance(self, i, j, X, Y, X_test, Y_test, eps): 219 | """Helper function for multiprocessing.""" 220 | X = (X[0], X[1] + eps * np.eye(X[1].shape[1])) # regularize covariance 221 | Y = (Y[0], Y[1] + eps * np.eye(Y[1].shape[1])) 222 | 223 | self.fit(X, Y) 224 | dist_train = self.score(X, Y) 225 | if X_test is None and Y_test is None: 226 | dist_test = np.inf 227 | else: 228 | dist_test = self.score(X_test, Y_test) 229 | return i, j, dist_train, dist_test 230 | 231 | def _compute_distance_star(self, args): 232 | """Helper function for multiprocessing. 233 | Using this allows us to use tqdm to track progress via imap_unordered. 234 | """ 235 | return self._compute_distance(*args) 236 | 237 | def pairwise_distances( 238 | self, 239 | train_data: List[Tuple[npt.NDArray, npt.NDArray]], 240 | test_data: Optional[List[Tuple[npt.NDArray, npt.NDArray]]]=None, 241 | eps: float = 1E-6, 242 | processes: Optional[int] = None, 243 | verbose: bool = True, 244 | ): 245 | """Computes pairwise distances between all pairs of networks w/ multiprocessing. 246 | 247 | We suggest setting "OMP_NUM_THREADS=1" in your environment variables to avoid oversubscription 248 | (multiprocesses competing for the same CPU). 249 | 250 | Parameters 251 | ---------- 252 | train_data: List[Tuple[npt.NDArray, npt.NDArray]] 253 | List of tuples of (means, covariances) for train data. 254 | test_data: List[Tuple[npt.NDArray, npt.NDArray]], optional 255 | List of tuples of (means, covariances) for test data. If None, the output 256 | distance matrix will be np.inf. 257 | eps: float, optional 258 | Add eps * I to each covariances to regularize. 259 | processes: int, optional 260 | Number of processes to use. If None, defaults to number of CPUs. 261 | verbose: bool, optional 262 | Whether to display progress bar. 263 | 264 | Returns 265 | ------- 266 | D_train: npt.NDArray 267 | n_networks x n_networks distance matrix. 268 | D_test: npt.NDArray 269 | n_networks x n_networks distance matrix. If test_data is None, this is 270 | a matrix of np.inf. 
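Notes
-----
Before fitting, each covariance is regularized in `_compute_distance` as
`cov + eps * np.eye(cov.shape[1])`, which keeps the matrix square roots in
the Bures term numerically stable when covariances are near-singular.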
271 | """ 272 | n_networks = len(train_data) 273 | n_dists = n_networks*(n_networks-1)//2 274 | 275 | # create generator of args for multiprocessing 276 | ij = itertools.combinations(range(n_networks), 2) 277 | if test_data is None: 278 | args = ((i, j, train_data[i], train_data[j], None, None, eps) for i, j in ij) 279 | else: 280 | args = ((i, j, train_data[i], train_data[j], test_data[i], test_data[j], eps) for i, j in ij) 281 | 282 | if verbose: 283 | print(f"Parallelizing {n_dists} distance calculations with {multiprocessing.cpu_count() if processes is None else processes} processes.") 284 | pbar = lambda x: tqdm(x, total=n_dists, desc="Computing distances") 285 | else: 286 | pbar = lambda x: x 287 | 288 | with multiprocessing.Pool(processes=processes) as pool: 289 | results = [] 290 | for result in pbar(pool.imap_unordered(self._compute_distance_star, args)): 291 | results.append(result) 292 | 293 | D_train = np.zeros((n_networks, n_networks)) 294 | D_test = np.zeros((n_networks, n_networks)) 295 | 296 | for i, j, dist_train, dist_test in results: 297 | D_train[i, j], D_train[j, i] = dist_train, dist_train 298 | D_test[i, j], D_test[j, i] = dist_test, dist_test 299 | 300 | return D_train, D_test 301 | 302 | 303 | 304 | def _fit_gp_alignment( 305 | n_dims: int, 306 | means_X: npt.NDArray, 307 | means_Y: npt.NDArray, 308 | covs_X: npt.NDArray, 309 | covs_Y: npt.NDArray, 310 | T: npt.NDArray, 311 | alpha: float, 312 | group: Literal["orth", "perm", "identity"], 313 | niter: int, 314 | tol: float, 315 | ) -> Tuple[npt.NDArray, List[float]]: 316 | """Helper function for fitting alignment between Gaussian-distributed responses.""" 317 | 318 | vX, uX = np.linalg.eigh(covs_X) 319 | sX = np.einsum("jk,k,lk->jl", uX, np.sqrt(vX), uX, optimize=True) 320 | 321 | vY, uY = np.linalg.eigh(covs_Y) 322 | sY = np.einsum("jk,k,lk->jl", uY, np.sqrt(vY), uY, optimize=True) 323 | 324 | loss_hist = [] 325 | 326 | n_times = covs_X.shape[0]//n_dims 327 | 328 | for i in range(niter): 329 | 330 | Qs = align(np.kron(np.eye(n_times),T.T) @ sY, sX, group="orth") 331 | A = np.row_stack( 332 | [alpha * means_X] + 333 | [split((2-alpha)*sX,n_dims,n_dims)] 334 | ) 335 | 336 | B = np.row_stack( 337 | [alpha * means_Y] + 338 | [split(Qs.T@((2-alpha)*sY),n_dims,n_dims)] 339 | ) 340 | 341 | T = align(B, A, group=group) 342 | loss_hist.append(np.linalg.norm(A - B @ T)) 343 | if i < 2: 344 | pass 345 | elif (loss_hist[-2] - loss_hist[-1]) < tol: 346 | break 347 | 348 | return T, loss_hist 349 | 350 | 351 | def split(array, nrows, ncols): 352 | """Split a matrix into sub-matrices.""" 353 | 354 | r, h = array.shape 355 | blocks = array.reshape( 356 | h//nrows, nrows, -1, ncols 357 | ).swapaxes(1,2).reshape(-1, nrows, ncols) 358 | 359 | return blocks.reshape(-1,blocks.shape[-1]) 360 | 361 | -------------------------------------------------------------------------------- /netrep/multiset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from tqdm import tqdm 4 | from sklearn.utils.validation import check_array, check_random_state 5 | from netrep.utils import align 6 | 7 | 8 | def euclidean_tangent_space(Xs, Xbar, group="orth"): 9 | """ 10 | Transform list of K matrices ('Xs'), into an approximate 11 | Euclidean space (tangent space) at a point Xbar. 12 | 13 | Note: assumes that the ground metric is Euclidean. 14 | 15 | Parameters 16 | ---------- 17 | Xs : list of K matrices, (m x n) ndarrays. 18 | Matrix-valued datasets to compare. 
19 | 20 | Xbar : (m x n) ndarray. 21 | Reference point. Each element in 'Xs' is aligned to 'Xbar'. 22 | 23 | group : str 24 | Specifies group of alignment operations. 25 | 26 | Returns 27 | ------- 28 | Xs_tang : list of K matrices, (m x n) ndarrays. 29 | Matrix-valued datasets in tangent space. These are the 30 | residuals of each element in 'Xs' after alignment to 31 | 'Xbar'. 32 | """ 33 | Xs_tang = np.empty((len(Xs), Xbar.shape[0], Xbar.shape[1])) 34 | for i, X in enumerate(Xs): 35 | Xs_tang[i] = Xbar - (X @ align(X, Xbar, group=group)) 36 | return Xs_tang 37 | 38 | 39 | def pairwise_distances( 40 | metric, traindata, testdata=None, verbose=True, 41 | enable_caching=False 42 | ): 43 | """ 44 | Compute pairwise distances between a collection of 45 | networks. Similar to ``scipy.spatial.distance.pdist``. 46 | 47 | Parameters 48 | ---------- 49 | metric : Metric 50 | Metric to evaluate pairwise distances 51 | 52 | traindata : list of K matrices, (m x n) ndarrays. 53 | Matrix-valued datasets to compare. 54 | 55 | testdata : list of K matrices, (p x n) ndarrays, optional. 56 | If provided, metrics are fit to traindata 57 | and then evaluated on testdata. 58 | 59 | verbose : bool, optional 60 | Prints progress bar if True. (Default is True.) 61 | 62 | Returns 63 | ------- 64 | train_dists : (K x K) symmetric matrix. 65 | Matrix of pairwise distances on training set. 66 | 67 | test_dists : (K x K) symmetric matrix, optional. 68 | Matrix of pairwise distances on the test set. 69 | """ 70 | 71 | # Allocate space for distances. 72 | m = len(traindata) 73 | D_train = np.zeros((m, m)) 74 | 75 | if testdata is not None: 76 | D_test = np.zeros((m, m)) 77 | 78 | # Set up progress bar. 79 | if verbose: 80 | pbar = tqdm(total=(m * (m - 1)) // 2) 81 | 82 | # Fit partial whitening transforms to each dataset. 83 | if enable_caching: 84 | caches = [metric.partial_fit(trn) for trn in traindata] 85 | 86 | # Compute all pairwise distances. 87 | for i in range(m): 88 | for j in range(i + 1, m): 89 | 90 | # Fit metric. 91 | if enable_caching: 92 | metric.finalize_fit(caches[i], caches[j]) 93 | else: 94 | metric.fit(traindata[i], traindata[j]) 95 | 96 | # Evaluate distance on the training set. 97 | D_train[i, j] = metric.score(traindata[i], traindata[j]) 98 | D_train[j, i] = D_train[i, j] 99 | 100 | # Evaluate distance on the test set. 101 | if testdata is not None: 102 | D_test[i, j] = metric.score(testdata[i], testdata[j]) 103 | D_test[j, i] = D_test[i, j] 104 | 105 | # Update progress bar. 106 | if verbose: 107 | pbar.update(1) 108 | 109 | # Close progress bar. 110 | if verbose: 111 | pbar.close() 112 | 113 | return D_train if (testdata is None) else (D_train, D_test) 114 | 115 | 116 | def cross_distances( 117 | metric, Xs, Ys, Xs_test=None, Ys_test=None, verbose=True 118 | ): 119 | """ 120 | Compute pairwise distances between two collections 121 | of networks. Similar to ``scipy.spatial.distance.cdist``. 122 | 123 | Parameters 124 | ---------- 125 | metric : Metric 126 | Metric to evaluate pairwise distances 127 | 128 | Xs : list of Nx matrices, (m x n) ndarrays. 129 | First set of matrix-valued datasets to compare. 130 | 131 | Ys : list of Ny matrices, (m x n) ndarrays. 132 | Second set of matrix-valued datasets to compare. 133 | 134 | Xs_test : list of Nx matrices, (p x n) ndarrays, optional. 135 | If provided, metrics are fit to data in Xs 136 | and then evaluated on Xs_test. 137 | 138 | Xs_test : list of Ny matrices, (p x n) ndarrays, optional. 
139 | If provided, metrics are fit to data in Ys
140 | and then evaluated on Ys_test.
141 | 
142 | verbose : bool, optional
143 | Prints progress bar if True. (Default is True.)
144 | 
145 | Returns
146 | -------
147 | train_dists : (Nx x Ny) matrix.
148 | Matrix of pairwise distances on training set.
149 | 
150 | test_dists : (Nx x Ny) matrix, optional.
151 | Matrix of pairwise distances on the test set.
152 | """
153 | 
154 | # Allocate space for training distances.
155 | Nx, Ny = len(Xs), len(Ys)
156 | D_train = np.zeros((Nx, Ny))
157 | 
158 | # Allocate space for testing distances.
159 | if (Xs_test is not None) and (Ys_test is not None):
160 | if len(Xs_test) != Nx:
161 | raise ValueError(
162 | "Length of Xs_test does not match train set."
163 | )
164 | if len(Ys_test) != Ny:
165 | raise ValueError(
166 | "Length of Ys_test does not match train set."
167 | )
168 | D_test = np.zeros((Nx, Ny))
169 | 
170 | elif (Xs_test is None) and (Ys_test is not None):
171 | raise ValueError(
172 | "If 'Ys_test' is specified, 'Xs_test' must also "
173 | "be specified."
174 | )
175 | 
176 | elif (Xs_test is not None) and (Ys_test is None):
177 | raise ValueError(
178 | "If 'Xs_test' is specified, 'Ys_test' must also "
179 | "be specified."
180 | )
181 | 
182 | else:
183 | D_test = None
184 | 
185 | # Create progress bar.
186 | if verbose:
187 | pbar = tqdm(total=(Nx * Ny))
188 | 
189 | # Compute distances.
190 | for i, j in itertools.product(range(Nx), range(Ny)):
191 | metric.fit(Xs[i], Ys[j])
192 | D_train[i, j] = metric.score(Xs[i], Ys[j])
193 | 
194 | if D_test is not None:
195 | D_test[i, j] = metric.score(Xs_test[i], Ys_test[j])
196 | 
197 | if verbose:
198 | pbar.update(1)
199 | 
200 | # Close progress bar.
201 | if verbose:
202 | pbar.close()
203 | 
204 | # Return distance matrices.
205 | return D_train if (D_test is None) else (D_train, D_test)
206 | 
207 | 
208 | def frechet_mean(
209 | Xs, group="orth",
210 | random_state=None, tol=1e-3, max_iter=100,
211 | warmstart=None, verbose=False, method="streaming",
212 | return_aligned_Xs=False
213 | ):
214 | """
215 | Estimate the average (Karcher/Frechet mean) of p networks in the
216 | metric space defined by:
217 | 
218 | d*(X, Y) = min_{T} ||X - Y @ T||^2
219 | 
220 | For some ground metric 'd' and alignment operations 'T'.
221 | 
222 | Parameters
223 | ----------
224 | Xs : list of p matrices, (m x n) ndarrays.
225 | Matrix-valued datasets to compare. Rotations are learned
226 | and applied in the n-dimensional space.
227 | 
228 | group : str
229 | Specifies the set of allowable alignment operations (a group of
230 | isometries). Must be one of ("orth", "perm", "identity").
231 | 
232 | random_state : np.random.RandomState
233 | Specifies state of the random number generator.
234 | 
235 | tol : float
236 | Convergence tolerance.
237 | 
238 | max_iter : int, optional.
239 | Maximum number of iterations to apply. Default = 100.
240 | 
241 | warmstart : (m x n) ndarray, optional
242 | If provided, Xbar is initialized to this estimate.
243 | 
244 | verbose : bool
245 | If True, print progress.
246 | 
247 | return_aligned_Xs : bool
248 | If True, return list of Xs aligned to Xbar.
249 | 
250 | Returns
251 | -------
252 | Xbar : (m x n) ndarray.
253 | Average activation matrix.
254 | 
255 | aligned_Xs : list of (m x n) ndarray
256 | Returned if `return_aligned_Xs` option is set to True.
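Examples
--------
A minimal sketch on random data (defaults as above):

>>> import numpy as np
>>> Xs = [np.random.randn(100, 10) for _ in range(5)]
>>> Xbar = frechet_mean(Xs, group="orth")
>>> Xbar, aligned_Xs = frechet_mean(Xs, group="orth", return_aligned_Xs=True)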
257 | """ 258 | 259 | if method == "streaming": 260 | Xbar = _euclidean_barycenter_streaming( 261 | Xs, group, random_state, tol, max_iter, warmstart, 262 | verbose 263 | ) 264 | elif method == "full_batch": 265 | Xbar = _euclidean_barycenter_full_batch( 266 | Xs, group, random_state, tol, max_iter, warmstart, 267 | verbose 268 | ) 269 | 270 | if return_aligned_Xs: 271 | aligned_Xs = [ 272 | x @ align(x, Xbar, group=group) for x in Xs 273 | ] 274 | 275 | return (Xbar, aligned_Xs) if return_aligned_Xs else Xbar 276 | 277 | 278 | def _euclidean_barycenter_full_batch( 279 | Xs, group, random_state, tol, max_iter, warmstart, verbose 280 | ): 281 | """ 282 | Parameters 283 | ---------- 284 | Xs : list of p matrices, (m x n) ndarrays. 285 | Matrix-valued datasets to compare. Rotations are learned 286 | and applied in the n-dimensional space. 287 | 288 | group : str 289 | Specifies group of ("orth", "perm", "roll", "identity") 290 | 291 | random_state : np.random.RandomState 292 | 293 | tol : float 294 | Convergence tolerance 295 | 296 | max_iter : int, optional. 297 | Maximum number of iterations to apply. 298 | 299 | verbose : bool 300 | If True, print progress. 301 | 302 | Returns 303 | ------- 304 | Xbar : (m x n) ndarray. 305 | Average activation matrix. 306 | """ 307 | 308 | # Handle simple case of no alignment operation. This is just a classic average. 309 | if group == "identity": 310 | return np.mean(Xs, axis=0) 311 | 312 | # Check input 313 | Xs = check_array(Xs, allow_nd=True) 314 | if Xs.ndim != 3: 315 | raise ValueError( 316 | "Expected 3d array with shape" 317 | "(n_datasets x n_observations x n_features), but " 318 | "got {}-d array with shape {}".format(Xs.ndim, Xs.shape)) 319 | 320 | # If only one matrix is provided, the barycenter is trivial. 321 | if Xs.shape[0] == 1: 322 | return Xs[0] 323 | 324 | # Check random state and initialize random permutation over networks. 325 | rs = check_random_state(random_state) 326 | 327 | # Initialize barycenter. 328 | Xbar = Xs[np.random.randint(len(Xs))] if (warmstart is None) else warmstart 329 | X0 = np.empty_like(Xbar) 330 | 331 | # Main loop 332 | itercount, n, chg = 0, 1, np.inf 333 | while (chg > tol) and (itercount < max_iter): 334 | 335 | # Save current barycenter for convergence checking. 336 | np.copyto(X0, Xbar) 337 | Xbar.fill(0.0) 338 | 339 | # Iterate over datasets. Align each dataset to last 340 | # average (held in X0), take running sum. 341 | for x in Xs: 342 | Xbar += x @ align(x, X0, group=group) 343 | 344 | Xbar /= len(Xs) 345 | 346 | # Detect convergence. 347 | chg = np.linalg.norm(Xbar - X0) / np.sqrt(Xbar.size) 348 | 349 | # Display progress. 350 | if verbose: 351 | print(f"Iteration {itercount}, Change: {chg}") 352 | 353 | # Move to next iteration, with new random ordering over datasets. 354 | itercount += 1 355 | 356 | return Xbar 357 | 358 | 359 | def _euclidean_barycenter_streaming( 360 | Xs, group, random_state, tol, max_iter, warmstart, verbose 361 | ): 362 | """ 363 | Parameters 364 | ---------- 365 | Xs : list of p matrices, (m x n) ndarrays. 366 | Matrix-valued datasets to compare. Rotations are learned 367 | and applied in the n-dimensional space. 368 | 369 | group : str 370 | Specifies group of ("orth", "perm", "roll", "identity") 371 | 372 | random_state : np.random.RandomState 373 | 374 | tol : float 375 | Convergence tolerance 376 | 377 | max_iter : int. 378 | Maximum number of iterations to apply. 379 | 380 | warmstart : None or (m x n) ndarray 381 | If provided, Xbar is initialized to this estimate. 
382 | 383 | verbose : bool 384 | If True, print progress. 385 | 386 | Returns 387 | ------- 388 | Xbar : (m x n) ndarray. 389 | Average activation matrix. 390 | """ 391 | 392 | # Handle simple case of no alignment operation. This is just a classic average. 393 | if group == "identity": 394 | return np.mean(Xs, axis=0) 395 | 396 | # Check input 397 | Xs = check_array(Xs, allow_nd=True) 398 | if Xs.ndim != 3: 399 | raise ValueError( 400 | "Expected 3d array with shape" 401 | "(n_datasets x n_observations x n_features), but " 402 | "got {}-d array with shape {}".format(Xs.ndim, Xs.shape)) 403 | 404 | # If only one matrix is provided, the barycenter is trivial. 405 | if Xs.shape[0] == 1: 406 | return Xs[0] 407 | 408 | # Check random state and initialize random permutation over networks. 409 | rs = check_random_state(random_state) 410 | indices = rs.permutation(len(Xs)) 411 | 412 | # Initialize barycenter. 413 | Xbar = Xs[indices[-1]] if (warmstart is None) else warmstart 414 | print(Xbar.shape) 415 | X0 = np.empty_like(Xbar) 416 | 417 | # Main loop 418 | itercount, n, chg = 0, 1, np.inf 419 | while (chg > tol) and (itercount < max_iter): 420 | 421 | # Save current barycenter for convergence checking. 422 | np.copyto(X0, Xbar) 423 | 424 | # Iterate over datasets. 425 | for i in indices: 426 | 427 | # Align i-th dataset to barycenter. 428 | XQ = Xs[i] @ align(Xs[i], X0, group=group) 429 | 430 | # Take a small step towards aligned representation. 431 | Xbar = (n / (n + 1)) * Xbar + (1 / (n + 1)) * XQ 432 | n += 1 433 | 434 | # Detect convergence. 435 | chg = np.linalg.norm(Xbar - X0) / np.sqrt(Xbar.size) 436 | 437 | # Display progress. 438 | if verbose: 439 | print(f"Iteration {itercount}, Change: {chg}") 440 | 441 | # Move to next iteration, with new random ordering over datasets. 442 | rs.shuffle(indices) 443 | itercount += 1 444 | 445 | return Xbar 446 | -------------------------------------------------------------------------------- /netrep/rbf_sampler.py: -------------------------------------------------------------------------------- 1 | from sklearn.base import TransformerMixin, BaseEstimator 2 | from sklearn.utils.validation import check_random_state, check_is_fitted 3 | 4 | from netrep.utils import rand_orth 5 | import numpy as np 6 | 7 | 8 | class RBFOrthoSampler(TransformerMixin, BaseEstimator): 9 | 10 | def __init__(self, gamma=1., n_components=100, random_state=None): 11 | self.gamma = gamma 12 | self.n_components = n_components 13 | self.random_state = random_state 14 | 15 | def fit(self, X, y=None): 16 | 17 | X = self._validate_data(X, accept_sparse='csr') 18 | rs = check_random_state(self.random_state) 19 | n_features = X.shape[1] 20 | 21 | nc = self.n_components // 2 22 | self.random_weights_ = np.full((n_features, nc), np.nan) 23 | 24 | i = 0 25 | while i < nc: 26 | Q = rand_orth(n_features, random_state=rs) 27 | j = min(nc, i + n_features) 28 | self.random_weights_[:, i:j] = Q[:, :(j - i)] 29 | i = j 30 | 31 | self.random_weights_ *= np.sqrt(2 * self.gamma) 32 | self.random_weights_ *= np.sqrt(rs.chisquare(n_features, size=(1, nc))) 33 | 34 | return self 35 | 36 | def transform(self, X): 37 | check_is_fitted(self) 38 | P = X @ self.random_weights_ 39 | return np.column_stack((np.cos(P), np.sin(P))) / np.sqrt(P.shape[1]) 40 | -------------------------------------------------------------------------------- /netrep/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous helper functions. 
3 | """ 4 | from typing import Tuple, Literal, Union, Optional 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from scipy.linalg import orthogonal_procrustes 9 | from scipy.optimize import linear_sum_assignment 10 | import scipy.sparse 11 | from scipy.stats import ortho_group 12 | from sklearn.utils.extmath import randomized_svd 13 | from sklearn.metrics.pairwise import pairwise_kernels 14 | from sklearn.utils.validation import check_random_state 15 | 16 | 17 | def align( 18 | X: npt.NDArray, 19 | Y: npt.NDArray, 20 | group: Literal["orth", "perm", "identity"] = "orth" 21 | ) -> Union[npt.NDArray, scipy.sparse.csr_matrix, scipy.sparse.dia_matrix]: 22 | """Return a matrix that optimally aligns 'X' to 'Y'. Note 23 | that the optimal alignment is the same for either the 24 | angular distance or the Euclidean distance since all 25 | alignments come from sub-groups of the orthogonal group. 26 | 27 | Parameters 28 | ---------- 29 | X : (m x n) ndarray. 30 | Activation patterns across 'm' inputs and 'n' neurons, 31 | sampled from the first network (the one which is transformed 32 | by the alignment operation). 33 | 34 | Y : (m x n) ndarray. 35 | Activation patterns across 'm' inputs and 'n' neurons, 36 | sampled from the second network (the one which is fixed). 37 | 38 | group : Literal["orth", "perm", "identity"] 39 | Specifies the set of allowable alignment operations (a group of 40 | isometries). Must be one of ("orth", "perm", "identity"). 41 | 42 | Returns 43 | ------- 44 | T : (n x n) ndarray or sparse matrix. 45 | Linear operator such that 'X @ T' is optimally aligned to 'Y'. 46 | Note further that 'Y @ T.transpose()' is optimally aligned to 'X', 47 | by symmetry. 48 | """ 49 | 50 | if group == "orth": 51 | return orthogonal_procrustes(X, Y)[0] 52 | 53 | elif group == "perm": 54 | ri, ci = linear_sum_assignment(X.T @ Y, maximize=True) 55 | n = ri.size 56 | return scipy.sparse.csr_matrix( 57 | (np.ones(n), (ri, ci)), shape=(n, n) 58 | ) 59 | 60 | elif group == "identity": 61 | return scipy.sparse.eye(X.shape[1]) 62 | 63 | else: 64 | raise ValueError(f"Specified group '{group}' not recognized.") 65 | 66 | 67 | def posdefsqrt(A): 68 | va, ua = np.linalg.eigh(A) 69 | return (ua * np.sqrt(np.maximum(va, 0.0))[None, :]) @ ua.T 70 | 71 | 72 | def sq_bures_metric_slow(A: npt.NDArray, B: npt.NDArray) -> float: 73 | """Slow way to compute the square of the Bures metric between two 74 | positive-definite matrices. 75 | """ 76 | va, ua = np.linalg.eigh(A) 77 | Asq = ua @ (np.sqrt(np.maximum(va[:, None], 0.0)) * ua.T) 78 | vbab = np.maximum(np.linalg.eigvalsh(Asq @ B @ Asq), 0.0) 79 | return ( 80 | np.trace(A) + np.trace(B) - 2 * np.sum(np.sqrt(vbab)) 81 | ) 82 | 83 | 84 | def sq_bures_metric(A: npt.NDArray, B: npt.NDArray) -> float: 85 | """Slow way to compute the square of the Bures metric between two 86 | positive-definite matrices. 87 | """ 88 | va, ua = np.linalg.eigh(A) 89 | vb, ub = np.linalg.eigh(B) 90 | sva = np.sqrt(np.maximum(va, 0.0)) 91 | svb = np.sqrt(np.maximum(vb, 0.0)) 92 | return ( 93 | np.sum(va) + np.sum(vb) - 2 * np.sum( 94 | np.linalg.svd( 95 | (sva[:, None] * ua.T) @ (ub * svb[None, :]), 96 | compute_uv=False 97 | ) 98 | ) 99 | ) 100 | 101 | 102 | def centered_kernel(*args, **kwargs): 103 | """ 104 | Lightly wraps `sklearn.metrics.pairwise.pairwise_kernels` 105 | to compute the centered kernel matrix. 
106 | """ 107 | K = pairwise_kernels(*args, **kwargs) 108 | sc = np.sum(K, axis=0, keepdims=True) 109 | sr = np.sum(K, axis=1, keepdims=True) 110 | ss = np.sum(sc) 111 | return K - (sc / sr.size) - (sr / sc.size) + (ss / K.size) 112 | 113 | 114 | def angular_distance(X: npt.NDArray, Y: npt.NDArray) -> float: 115 | """Computes angular distance based on Frobenius inner product between two matrices. 116 | 117 | Parameters 118 | ---------- 119 | X : (m x n) ndarray 120 | Y : (m x n) ndarray 121 | 122 | Returns 123 | ------- 124 | distance : float between zero and pi. 125 | """ 126 | normalizer = np.linalg.norm(X.ravel()) * np.linalg.norm(Y.ravel()) 127 | corr = np.dot(X.ravel(), Y.ravel()) / normalizer 128 | # numerical precision issues require us to clip inputs to arccos 129 | return np.arccos(np.clip(corr, -1.0, 1.0)) 130 | 131 | 132 | def trunc_svd(X: npt.NDArray, r: int) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: 133 | """Singular value decomposition, keeping top r components.""" 134 | return randomized_svd(X, n_components=r, n_iter=5) 135 | 136 | 137 | def econ_svd(X: npt.NDArray) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: 138 | """Economic Singular Value Decomposition (SVD).""" 139 | return np.linalg.svd(X, full_matrices=False) 140 | 141 | 142 | def rand_orth( 143 | m: int, 144 | n: Optional[int] = None, 145 | random_state: Optional[Union[int, np.random.RandomState]] = None 146 | ) -> npt.NDArray: 147 | """Creates a random matrix with orthogonal columns or rows. 148 | 149 | Parameters 150 | ---------- 151 | m : int 152 | First dimension 153 | n : int 154 | Second dimension (if None, matrix is m x m) 155 | random_state : int or np.random.RandomState 156 | Specifies the state of the random number generator. 157 | 158 | Returns 159 | ------- 160 | Q : ndarray 161 | An m x n random matrix. If m > n, the columns are orthonormal. 162 | If m < n, the rows are orthonormal. If m == n, the result is 163 | an orthogonal matrix. 164 | """ 165 | rs = check_random_state(random_state) 166 | n = m if n is None else n 167 | 168 | Q = ortho_group.rvs(max(m, n), random_state=rs) 169 | 170 | if Q.shape[0] > m: 171 | Q = Q[:m] 172 | if Q.shape[1] > n: 173 | Q = Q[:, :n] 174 | 175 | return Q 176 | 177 | 178 | def whiten( 179 | X: npt.NDArray, 180 | alpha: float, 181 | preserve_variance: bool = True, 182 | eigval_tol=1e-7 183 | ) -> Tuple[npt.NDArray, npt.NDArray]: 184 | """Return regularized whitening transform for a matrix X. 185 | 186 | Parameters 187 | ---------- 188 | X : ndarray 189 | Matrix with shape `(m, n)` holding `m` observations 190 | in `n`-dimensional feature space. Columns of `X` are 191 | expected to be mean-centered so that `X.T @ X` is 192 | the covariance matrix. 193 | alpha : float 194 | Regularization parameter, `0 <= alpha <= 1`. When 195 | `alpha == 0`, the data matrix is fully whitened. 196 | When `alpha == 1` the data matrix is not transformed 197 | (`Z == eye(X.shape[1])`). 198 | preserve_variance : bool 199 | If True, rescale the (partial) whitening matrix so 200 | that the total variance, trace(X.T @ X), is preserved. 201 | eigval_tol : float 202 | Eigenvalues of covariance matrix are clipped to this 203 | minimum value. 204 | 205 | Returns 206 | ------- 207 | X_whitened : ndarray 208 | Transformed data matrix. 209 | Z : ndarray 210 | Matrix implementing the whitening transformation. 211 | `X_whitened = X @ Z`. 212 | """ 213 | 214 | # Return early if regularization is maximal (no whitening). 
215 | if alpha > (1 - eigval_tol):
216 | return X, np.eye(X.shape[1])
217 | 
218 | # Compute eigendecomposition of covariance matrix
219 | lam, V = np.linalg.eigh(X.T @ X)
220 | lam = np.maximum(lam, eigval_tol)
221 | 
222 | # Compute diagonal of (partial) whitening matrix.
223 | #
224 | # When (alpha == 1), then (d == ones).
225 | # When (alpha == 0), then (d == 1 / sqrt(lam)).
226 | d = alpha + (1 - alpha) * lam ** (-1 / 2)
227 | 
228 | # Rescale the whitening matrix.
229 | if preserve_variance:
230 | 
231 | # Compute the variance of the transformed data.
232 | #
233 | # When (alpha == 1), then new_var = sum(lam)
234 | # When (alpha == 0), then new_var = len(lam)
235 | new_var = np.sum(
236 | (alpha ** 2) * lam
237 | + 2 * alpha * (1 - alpha) * (lam ** 0.5)
238 | + ((1 - alpha) ** 2) * np.ones_like(lam)
239 | )
240 | 
241 | # Now re-scale d so that the variance of (X @ Z)
242 | # will equal the original variance of X.
243 | d *= np.sqrt(np.sum(lam) / new_var)
244 | 
245 | # Form (partial) whitening matrix.
246 | Z = (V * d[None, :]) @ V.T
247 | 
248 | # An alternative regularization strategy would be:
249 | #
250 | # lam, V = np.linalg.eigh(X.T @ X)
251 | # d = lam ** (-(1 - alpha) / 2)
252 | # Z = (V * d[None, :]) @ V.T
253 | 
254 | # Return (partially) whitened data and whitening matrix.
255 | return X @ Z, Z
256 | 
257 | 
258 | 
259 | def rand_struc_orth(
260 | n: int,
261 | n_transforms: int = 3,
262 | random_state: Optional[Union[int, np.random.RandomState]] = None
263 | ) -> npt.NDArray:
264 | """Draws random sign flips for structured orthogonal
265 | transformation. See also, `struc_orth_matvec` function.
266 | 
267 | Parameters
268 | ----------
269 | n : int
270 | Dimensionality.
271 | n_transforms : int
272 | Number of sign flips to perform in between Hadamard
273 | transforms. Default is 3.
274 | random_state : int or np.random.RandomState
275 | Random number specification.
276 | """
277 | rs = check_random_state(random_state)
278 | Ds = np.ones((n_transforms, n), dtype=int)
279 | idx = rs.rand(n_transforms, n) > .5
280 | Ds[idx] = -1
281 | return Ds
282 | 
283 | 
284 | def struc_orth_matvec(Ds, a, transpose=False):
285 | """Structured orthogonal matrix-vector multiply. Modifies
286 | vector `a` in-place.
287 | 
288 | If transpose == False, then this computes:
289 | 
290 | H @ Ds[-1] @ ... H @ Ds[1] @ H @ Ds[0] @ H @ a
291 | 
292 | If transpose == True, then this computes:
293 | 
294 | H @ Ds[0] @ ... H @ Ds[-2] @ H @ Ds[-1] @ H @ a
295 | 
296 | Above, H is a normalized Hadamard matrix (i.e. normalized
297 | by 1 / sqrt(n), so that H is orthogonal).
298 | 
299 | Parameters
300 | ----------
301 | Ds : ndarray
302 | (n_transforms x n) matrix specifying sign flips in
303 | between each Hadamard transform.
304 | 
305 | a : ndarray
306 | Vector with n elements. An error is thrown if n is
307 | not a power of 2.
308 | 
309 | transpose : bool
310 | If True, performs matrix-transpose times vector
311 | multiply. Default is False.
312 | """
313 | 
314 | # Check inputs.
315 | if a.ndim != 1:
316 | raise ValueError("Expected array `a` to be a vector.")
317 | 
318 | if Ds.ndim != 2:
319 | raise ValueError("Expected array `Ds` to be a matrix.")
320 | 
321 | if Ds.shape[1] != a.size:
322 | raise ValueError(
323 | "Dimension mismatch. Expected Ds.shape[1] == a.size.")
324 | 
325 | if ((a.size & (a.size - 1)) != 0):
326 | raise ValueError(
327 | "Expected length of `a` to be a power of two. "
328 | "Saw instead, len(a) == {}.".format(a.size))
329 | 
330 | # Reverse order if transpose is desired.
331 | _Ds = Ds[::-1] if transpose else Ds 332 | 333 | # Perform series of Walsh-Hadamard Transforms and sign flips. 334 | fwht(a) 335 | for D in _Ds: 336 | a *= D 337 | fwht(a) 338 | 339 | # Normalize by sqrt(n) for each WH transform. 340 | a /= np.sqrt(a.size) ** (1 + len(Ds)) 341 | 342 | 343 | # @numba.jit(nopython=True) 344 | def fwht(a): 345 | """ 346 | In-place Fast Walsh–Hadamard Transform. 347 | 348 | Source: https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform 349 | """ 350 | h = 1 351 | while h < len(a): 352 | for i in range(0, len(a), h * 2): 353 | for j in range(i, i + h): 354 | x = a[j] 355 | y = a[j + h] 356 | a[j] = x + y 357 | a[j + h] = x - y 358 | h *= 2 359 | -------------------------------------------------------------------------------- /netrep/validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions to check model inputs. 3 | """ 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | from sklearn.utils.validation import check_array 8 | 9 | 10 | def check_equal_shapes( 11 | X: npt.NDArray, 12 | Y: npt.NDArray, 13 | nd: int = 2, 14 | zero_pad: bool = False 15 | ) -> tuple[npt.NDArray, npt.NDArray]: 16 | """Checks that X and Y have equal shapes.""" 17 | 18 | X = check_array(X, allow_nd=True) 19 | Y = check_array(Y, allow_nd=True) 20 | 21 | if (X.ndim != nd) or (Y.ndim != nd): 22 | raise ValueError( 23 | "Expected {}d arrays, but shapes were {} and " 24 | "{}.".format(nd, X.shape, Y.shape) 25 | ) 26 | 27 | if X.shape != Y.shape: 28 | 29 | if zero_pad and (X.shape[:-1] == Y.shape[:-1]): 30 | 31 | # Number of padded zeros to add. 32 | n = max(X.shape[-1], Y.shape[-1]) 33 | 34 | # Padding specifications for X and Y. 35 | px = np.zeros((nd, 2), dtype="int") 36 | py = np.zeros((nd, 2), dtype="int") 37 | px[-1, -1] = n - X.shape[-1] 38 | py[-1, -1] = n - Y.shape[-1] 39 | 40 | # Pad X and Y with zeros along final axis. 41 | X = np.pad(X, px) 42 | Y = np.pad(Y, py) 43 | 44 | else: 45 | raise ValueError( 46 | "Expected arrays with equal dimensions, " 47 | "but got arrays with shapes {} and {}." 48 | "".format(X.shape, Y.shape)) 49 | 50 | return X, Y 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="netrep", 5 | version="0.0.2", 6 | url="https://github.com/ahwillia/netrep", 7 | 8 | author="Alex Williams", 9 | author_email="alex.h.williams@nyu.edu", 10 | 11 | description="Simple methods for comparing network representations.", 12 | 13 | packages=setuptools.find_packages(), 14 | install_requires=[ 15 | 'numpy>=1.16.5', 16 | 'scipy>=1.3.1', 17 | 'scikit-learn>=0.21.3', 18 | 'tqdm>=4.32.2' 19 | ], 20 | extras_require={ 21 | 'dev': [ 22 | 'pytest>=3.7' 23 | ] 24 | }, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests metrics between network representations. 
3 | """ 4 | import pytest 5 | import numpy as np 6 | from netrep.metrics import LinearMetric, PermutationMetric 7 | from netrep.utils import angular_distance, rand_orth 8 | from sklearn.utils.validation import check_random_state 9 | 10 | TOL = 1e-6 11 | 12 | 13 | @pytest.mark.parametrize('seed', [1, 2, 3]) 14 | @pytest.mark.parametrize('m', [100]) 15 | @pytest.mark.parametrize('n', [10]) 16 | def test_uncentered_procrustes(seed, m, n): 17 | 18 | # Set random seed, draw random rotation 19 | rs = check_random_state(seed) 20 | Q = rand_orth(n, n, random_state=rs) 21 | 22 | # Create a pair of randomly rotated matrices. 23 | X = rs.randn(m, n) 24 | Y = X @ Q 25 | 26 | # Fit model, assert distance == 0. 27 | metric = LinearMetric(alpha=1.0, center_columns=False) 28 | metric.fit(X, Y) 29 | assert abs(metric.score(X, Y)) < TOL 30 | 31 | 32 | @pytest.mark.parametrize('seed', [1, 2, 3]) 33 | @pytest.mark.parametrize('m', [100]) 34 | @pytest.mark.parametrize('n', [10]) 35 | def test_centered_procrustes(seed, m, n): 36 | 37 | # Set random seed, draw random rotation, offset, and isotropic scaling. 38 | rs = check_random_state(seed) 39 | Q = rand_orth(n, n, random_state=rs) 40 | v = rs.randn(1, n) 41 | c = rs.exponential() 42 | 43 | # Create a pair of randomly rotated matrices. 44 | X = rs.randn(m, n) 45 | Y = c * X @ Q + v 46 | 47 | # Fit model, assert distance == 0. 48 | metric = LinearMetric(alpha=1.0, center_columns=True) 49 | metric.fit(X, Y) 50 | assert abs(metric.score(X, Y)) < TOL 51 | 52 | 53 | @pytest.mark.parametrize('seed', [1, 2, 3]) 54 | @pytest.mark.parametrize('m', [100]) 55 | @pytest.mark.parametrize('n', [10]) 56 | def test_uncentered_cca(seed, m, n): 57 | 58 | # Set random seed, draw random linear alignment. 59 | rs = check_random_state(seed) 60 | W = rs.randn(n, n) 61 | 62 | # Create pair of matrices related by a linear transformation. 63 | X = rs.randn(m, n) 64 | Y = X @ W 65 | 66 | # Fit CCA, assert distance == 0. 67 | metric = LinearMetric(alpha=0.0, center_columns=False) 68 | metric.fit(X, Y) 69 | assert metric.score(X, Y) < TOL 70 | 71 | # Fit Procrustes, assert distance is nonzero. 72 | metric = LinearMetric(alpha=1.0, center_columns=False) 73 | metric.fit(X, Y) 74 | assert abs(metric.score(X, Y)) > TOL 75 | 76 | 77 | @pytest.mark.parametrize('seed', [1, 2, 3]) 78 | @pytest.mark.parametrize('m', [100]) 79 | @pytest.mark.parametrize('n', [10]) 80 | def test_centered_cca(seed, m, n): 81 | 82 | # Set random seed, draw random linear alignment and offset. 83 | rs = check_random_state(seed) 84 | W = rs.randn(n, n) 85 | v = rs.randn(1, n) 86 | 87 | # Create a pair of matrices related by a linear transformation. 88 | X = rs.randn(m, n) 89 | Y = X @ W + v 90 | 91 | # Fit model, assert distance is zero. 92 | metric = LinearMetric(alpha=0.0, center_columns=True) 93 | metric.fit(X, Y) 94 | assert abs(metric.score(X, Y)) < TOL 95 | 96 | # Fit Procrustes, assert distance is nonzero. 97 | metric = LinearMetric(alpha=1.0, center_columns=True) 98 | metric.fit(X, Y) 99 | assert abs(metric.score(X, Y)) > TOL 100 | 101 | 102 | @pytest.mark.parametrize('seed', [1, 2, 3]) 103 | @pytest.mark.parametrize('m', [100]) 104 | @pytest.mark.parametrize('n', [10]) 105 | def test_principal_angles(seed, m, n): 106 | 107 | # Set random seed, draw random linear alignment. 108 | rs = check_random_state(seed) 109 | W = rs.randn(n, n) 110 | 111 | # Create pair of matrices related by a linear transformation. 112 | X = rand_orth(m, n) 113 | Y = rand_orth(m, n) 114 | 115 | # Compute metric based on principal angles. 
116 | cos_thetas = np.linalg.svd(X.T @ Y, compute_uv=False) 117 | dist_1 = np.arccos(np.mean(cos_thetas)) 118 | 119 | # Fit model, assert two approaches match. 120 | metric = LinearMetric(alpha=1.0, center_columns=False).fit(X, Y) 121 | assert abs(dist_1 - metric.score(X, Y)) < TOL 122 | 123 | 124 | @pytest.mark.parametrize('seed', [1, 2, 3]) 125 | @pytest.mark.parametrize('alpha', [0.0, 0.5, 1.0]) 126 | @pytest.mark.parametrize('m', [31]) 127 | @pytest.mark.parametrize('n', [30]) 128 | def test_triangle_inequality_linear(seed, alpha, m, n): 129 | 130 | rs = check_random_state(seed) 131 | X = rs.randn(m, n) 132 | Y = rs.randn(m, n) 133 | M = rs.randn(m, n) 134 | 135 | metric = LinearMetric(alpha=alpha, center_columns=True) 136 | 137 | dXY = metric.fit(X, Y).score(X, Y) 138 | dXM = metric.fit(X, M).score(X, M) 139 | dMY = metric.fit(M, Y).score(M, Y) 140 | 141 | assert dXY <= (dXM + dMY + TOL) 142 | 143 | 144 | @pytest.mark.parametrize('seed', [1, 2, 3]) 145 | @pytest.mark.parametrize('center_columns', [True, False]) 146 | @pytest.mark.parametrize('m', [100]) 147 | @pytest.mark.parametrize('n', [10]) 148 | @pytest.mark.parametrize('score_method', ['euclidean', 'angular']) 149 | def test_permutation(seed, center_columns, m, n, score_method): 150 | 151 | # Set random seed, draw random rotation 152 | rs = check_random_state(seed) 153 | 154 | # Create a pair of randomly rotated matrices. 155 | X = rs.randn(m, n) 156 | Y = np.copy(X)[:, rs.permutation(n)] 157 | 158 | # Fit model, assert distance == 0. 159 | metric = PermutationMetric( 160 | center_columns=center_columns, score_method=score_method 161 | ) 162 | assert abs(metric.fit(X, Y).score(X, Y)) < TOL 163 | 164 | 165 | -------------------------------------------------------------------------------- /tests/test_multiprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OMP_NUM_THREADS'] = '1' 3 | 4 | import pytest 5 | import numpy as np 6 | from netrep.metrics import LinearMetric, PermutationMetric 7 | from netrep.metrics import GaussianStochasticMetric, EnergyStochasticMetric 8 | from sklearn.utils.validation import check_random_state 9 | 10 | 11 | def _get_cov(n_images, n_neurons, rs): 12 | A = rs.randn(n_images, n_neurons, n_neurons) 13 | # batched outerproduct 14 | return np.einsum('bij,bkj->bik', A, A) 15 | 16 | 17 | def _get_data_wasserstein(n_networks, n_images, n_neurons, rs): 18 | return [(rs.randn(n_images, n_neurons), _get_cov(n_images, n_neurons, rs)) for _ in range(n_networks)] 19 | 20 | 21 | @pytest.mark.parametrize('metric_type', ['linear', 'permutation', 'gaussian', 'energy']) 22 | @pytest.mark.parametrize('test_set', [False, True]) 23 | def test_pairwise_distances(metric_type, test_set): 24 | 25 | rs = check_random_state(0) 26 | 27 | # Create a pair of randomly rotated matrices. 
28 | n_networks, n_images, n_repeats, n_neurons = 3, 2, 3, 4 29 | 30 | if metric_type == 'linear': 31 | train_data = [rs.randn(n_images, n_neurons) for _ in range(n_networks)] 32 | test_data = [rs.randn(n_images, n_neurons) for _ in range(n_networks)] 33 | 34 | metric = LinearMetric() 35 | 36 | elif metric_type == 'permutation': 37 | train_data = [rs.randn(n_images, n_neurons) for _ in range(n_networks)] 38 | test_data = [rs.randn(n_images, n_neurons) for _ in range(n_networks)] 39 | 40 | metric = PermutationMetric() 41 | 42 | if metric_type == 'energy': 43 | metric = EnergyStochasticMetric() 44 | train_data = [rs.randn(n_images, n_repeats, n_neurons) for _ in range(n_networks)] 45 | test_data = [rs.randn(n_images, n_repeats, n_neurons) for _ in range(n_networks)] 46 | 47 | elif metric_type == 'gaussian': 48 | metric = GaussianStochasticMetric() 49 | train_data = _get_data_wasserstein(n_networks, n_images, n_neurons, rs) 50 | test_data = _get_data_wasserstein(n_networks, n_images, n_neurons, rs) 51 | 52 | if test_set: 53 | D_train, D_test = metric.pairwise_distances(train_data, test_data) 54 | assert D_test.sum() >= 0.0 55 | else: 56 | D_train, D_test = metric.pairwise_distances(train_data) 57 | assert D_test.sum() == np.inf 58 | 59 | assert D_train.shape == (n_networks, n_networks) 60 | assert D_test.shape == (n_networks, n_networks) 61 | assert D_train.sum() >= 0.0 62 | -------------------------------------------------------------------------------- /tests/test_multiset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests metrics between network representations. 3 | """ 4 | import pytest 5 | import numpy as np 6 | from scipy.spatial.distance import pdist 7 | from sklearn.utils.validation import check_random_state 8 | import netrep.metrics 9 | from netrep.multiset import pairwise_distances, frechet_mean 10 | 11 | TOL = 1e-6 12 | 13 | @pytest.mark.parametrize('seed', [1, 2, 3]) 14 | @pytest.mark.parametrize('group', ['orth', 'perm']) 15 | @pytest.mark.parametrize('method', ['full_batch', 'streaming']) 16 | @pytest.mark.parametrize('m', [100]) 17 | @pytest.mark.parametrize('n', [10]) 18 | @pytest.mark.parametrize('num_X', [4]) 19 | def test_frechet_mean(seed, group, method, m, n, num_X): 20 | 21 | # Set random seed, draw random rotation. 22 | rs = check_random_state(seed) 23 | _Xb = rs.randn(m, n) 24 | Xs = [_Xb for _ in range(num_X)] 25 | 26 | Xbar, aligned_Xs = frechet_mean( 27 | Xs, group=group, return_aligned_Xs=True, method=method 28 | ) 29 | 30 | assert np.all( 31 | pdist(np.stack(aligned_Xs).reshape(num_X, -1)) < TOL 32 | ) 33 | -------------------------------------------------------------------------------- /tests/test_stochastic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests metrics betwen stochastic neuralrepresentations. 3 | """ 4 | import pytest 5 | import numpy as np 6 | from netrep.metrics import GaussianStochasticMetric, EnergyStochasticMetric 7 | from netrep.utils import rand_orth 8 | from sklearn.utils.validation import check_random_state 9 | 10 | TOL = 1e-6 11 | 12 | 13 | @pytest.mark.parametrize('seed', [1, 2, 3]) 14 | @pytest.mark.parametrize('m', [10]) 15 | @pytest.mark.parametrize('n', [4]) 16 | def test_gaussian_identity_covs(seed, m, n): 17 | 18 | # Set random seed, draw random rotation 19 | rs = check_random_state(seed) 20 | Q = rand_orth(n, n, random_state=rs) 21 | 22 | # Create a pair of randomly rotated Gaussians. 
    mean_X = rs.randn(m, n)
    mean_Y = mean_X @ Q
    covs_X = np.array([np.eye(n) for _ in range(m)])
    covs_Y = np.array([np.eye(n) for _ in range(m)])

    X = (mean_X, covs_X)
    Y = (mean_Y, covs_Y)

    # Fit model, assert distance == 0.
    metric = GaussianStochasticMetric(group="orth")
    metric.fit(X, Y)
    assert abs(metric.score(X, Y)) < TOL


@pytest.mark.parametrize('seed', [1, 2, 3])
@pytest.mark.parametrize('m', [10])
@pytest.mark.parametrize('n', [4])
def test_gaussian_zero_means(seed, m, n):

    # Set random seed, draw random rotation.
    rs = check_random_state(seed)
    Q = rand_orth(n, n, random_state=rs)
    Us = [rand_orth(n, n, random_state=np.random.RandomState(2)) for _ in range(m)]
    Ps = [u @ np.diag(np.logspace(-1, 1, n)) @ u.T for u in Us]

    # Create a pair of zero-mean Gaussians with rotated covariances.
    mean_X = np.zeros((m, n))
    mean_Y = np.zeros((m, n))
    covs_X = np.array(Ps)
    covs_Y = np.array([Q.T @ p @ Q for p in Ps])

    X = (mean_X, covs_X)
    Y = (mean_Y, covs_Y)

    # Fit model, assert distance == 0.
    metric = GaussianStochasticMetric(group="orth")
    metric.fit(X, Y)
    assert abs(metric.score(X, Y)) < TOL


@pytest.mark.parametrize('seed', [1, 2, 3])
@pytest.mark.parametrize('m', [10])
@pytest.mark.parametrize('n', [4])
def test_gaussian_lower_bound(seed, m, n):

    # Set random seed.
    rs = check_random_state(seed)
    Us = [rand_orth(n, n, random_state=rs) for _ in range(m)]
    Vs = [rand_orth(n, n, random_state=rs) for _ in range(m)]

    # Create a pair of random networks.
    mean_X = rs.randn(m, n)
    mean_Y = rs.randn(m, n)
    covs_X = np.array([u @ np.diag(np.logspace(-2, 2, n)) @ u.T for u in Us])
    covs_Y = np.array([v @ np.diag(np.logspace(0, 1, n)) @ v.T for v in Vs])

    X = (mean_X, covs_X)
    Y = (mean_Y, covs_Y)

    alphas = np.linspace(0, 2, 10)
    dists = np.zeros_like(alphas)

    for i, a in enumerate(alphas):
        dists[i] = GaussianStochasticMetric(
            group="orth",
            alpha=a,
            n_restarts=10,
            init="rand",
            random_state=rs
        ).fit(X, Y).score(X, Y)

    # Interpolation bound: the squared distance at weight alpha should be at
    # least the convex combination of the squared distances at the two
    # endpoints (alpha=0 and alpha=2).
    lower_bound = np.sqrt(
        (1 - (alphas / 2)) * dists[0] ** 2 + (alphas / 2) * dists[-1] ** 2
    )

    assert np.all(dists > (lower_bound - TOL))


@pytest.mark.parametrize('seed', [1, 2, 3])
@pytest.mark.parametrize('m', [4])
@pytest.mark.parametrize('n', [4])
@pytest.mark.parametrize('p', [100])
@pytest.mark.parametrize('noise', [0.1])
def test_energy_distance(seed, m, n, p, noise):

    # Set random seed, draw random rotation.
    rs = check_random_state(seed)
    Q = rand_orth(n, n, random_state=rs)

    # Create a pair of randomly rotated, noisy point clouds.
    xm = rs.randn(m, n)
    X = xm[:, None, :] + noise * rs.randn(m, p, n)
    Y = (xm @ Q)[:, None, :] + noise * rs.randn(m, p, n)

    # Fit model.
    metric = EnergyStochasticMetric(group="orth")
    metric.fit(X, Y)

    # Check that loss monotonically decreases.
    assert np.all(np.diff(metric.loss_hist) <= TOL)
--------------------------------------------------------------------------------
/tests/test_stochastic_process.py:
--------------------------------------------------------------------------------
# %%
"""
Tests metrics between stochastic-process neural representations.
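
A rotation applied independently at each time step should register as a
symmetry for the process-level metric (DSSD) and for the per-time marginal
SSDs, while the single full-covariance SSD computed below stays bounded
away from zero.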
4 | """ 5 | 6 | import pytest 7 | import numpy as np 8 | from netrep.metrics import GPStochasticMetric,GaussianStochasticMetric 9 | from netrep.utils import rand_orth 10 | from sklearn.utils.validation import check_random_state 11 | from sklearn.covariance import EmpiricalCovariance 12 | 13 | from numpy import random as rand 14 | from netrep.utils import rand_orth 15 | 16 | TOL = 1e-6 17 | 18 | # %% Class for sampling from a gaussian process given a kernel 19 | class GaussianProcess: 20 | def __init__(self,kernel,D): 21 | self.kernel = kernel 22 | self.D = D 23 | 24 | def evaluate_kernel(self, xs, ys): 25 | fun = np.vectorize(self.kernel) 26 | return fun(xs[:, None], ys) 27 | 28 | def sample(self,ts): 29 | T = ts.shape[0] 30 | c_g = self.evaluate_kernel(ts,ts) 31 | fs = rand.multivariate_normal( 32 | mean=np.zeros(T), 33 | cov=c_g, 34 | size=self.D 35 | ) 36 | return fs 37 | 38 | 39 | # %% 40 | @pytest.mark.parametrize('seed', [1, 2, 3]) 41 | @pytest.mark.parametrize('t', [10]) # number of time points 42 | @pytest.mark.parametrize('n', [4]) # number of neurons 43 | @pytest.mark.parametrize('k', [100]) # number of samples 44 | def test_gaussian_process(seed, t, n, k): 45 | # Set random seed, draw random rotation 46 | rs = check_random_state(seed) 47 | Q = rand_orth(n, n, random_state=rs) 48 | 49 | # Generate data from a gaussian process with RBF kernel 50 | ts = np.linspace(0,1,t) 51 | gpA = GaussianProcess( 52 | kernel = lambda x, y: 1e-2*(1e-6*(x==y)+np.exp(-np.linalg.norm(x-y)**2/(2*1.**2))), 53 | D=n 54 | ) 55 | sA = np.array([gpA.sample(ts) for _ in range(k)]).reshape(k,n*t) 56 | 57 | # Transform GP according to a rotation applied to individiual 58 | # blocks of the full covariance matrix 59 | A = [sA.mean(0),EmpiricalCovariance().fit(sA).covariance_] 60 | B = [ 61 | np.kron(np.eye(t),Q)@A[0], 62 | np.kron(np.eye(t),Q)@A[1]@(np.kron(np.eye(t),Q)).T 63 | ] 64 | 65 | 66 | # Compute DSSD 67 | metric = GPStochasticMetric(n_dims=n,group="orth") 68 | 69 | dssd = metric.fit_score(A,B) 70 | assert abs(dssd) < TOL 71 | 72 | # Compute marginal SSD 73 | metric = GaussianStochasticMetric(group="orth") 74 | 75 | A_marginal = [ 76 | A[0].reshape(t,n), 77 | np.array([A[1][i*n:(i+1)*n,i*n:(i+1)*n] for i in range(t)]) 78 | ] 79 | 80 | B_marginal = [ 81 | B[0].reshape(t,n), 82 | np.array([B[1][i*n:(i+1)*n,i*n:(i+1)*n] for i in range(t)]) 83 | ] 84 | 85 | marginal_ssd = metric.fit_score(A_marginal,B_marginal) 86 | assert abs(marginal_ssd) < TOL 87 | 88 | # Compute full SSD 89 | metric = GaussianStochasticMetric(group="orth") 90 | 91 | A_full = [A[0][None],A[1][None]] 92 | B_full = [B[0][None],B[1][None]] 93 | 94 | full_ssd = metric.fit_score(A_full,B_full) 95 | 96 | assert abs(full_ssd) > TOL 97 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests utility functions. 
3 | """ 4 | import pytest 5 | 6 | import numpy as np 7 | from numpy.testing import assert_array_almost_equal, assert_allclose 8 | 9 | from netrep.utils import ( 10 | rand_orth, 11 | centered_kernel, 12 | fwht, 13 | rand_struc_orth, 14 | struc_orth_matvec, 15 | whiten, 16 | sq_bures_metric, 17 | sq_bures_metric_slow 18 | ) 19 | 20 | from sklearn.utils.validation import check_random_state 21 | from sklearn.metrics.pairwise import pairwise_kernels 22 | 23 | import scipy.linalg 24 | 25 | ATOL = 1e-6 26 | RTOL = 1e-6 27 | 28 | @pytest.mark.parametrize('seed', [1, 2, 3]) 29 | @pytest.mark.parametrize('m', [50, 100]) 30 | @pytest.mark.parametrize('n', [10, 20]) 31 | def test_whiten(seed, m, n): 32 | # For this test, we assume m > n. 33 | rs = check_random_state(seed) 34 | X = rs.randn(m, n) 35 | XZ, Z = whiten(X, 0.0, preserve_variance=False) 36 | assert_array_almost_equal(XZ.T @ XZ, np.eye(n)) 37 | 38 | 39 | @pytest.mark.parametrize('seed', [1, 2, 3]) 40 | @pytest.mark.parametrize('m', [50, 100]) 41 | @pytest.mark.parametrize('n', [10, 20]) 42 | def test_whiten_preserve_variance(seed, m, n): 43 | # For this test, we assume m > n. 44 | rs = check_random_state(seed) 45 | X = rs.randn(m, n) 46 | XZ, Z = whiten(X, 0.0, preserve_variance=True) 47 | gram = (XZ.T @ XZ) 48 | d = np.full(n, np.sum(gram) / n) 49 | assert_array_almost_equal(gram, np.diag(d)) 50 | 51 | 52 | @pytest.mark.parametrize('seed', [1, 2, 3]) 53 | @pytest.mark.parametrize('alpha', [0.0, 0.5, 1.0]) 54 | @pytest.mark.parametrize('m', [50, 100]) 55 | @pytest.mark.parametrize('n', [10, 20]) 56 | def test_partial_whiten_preserve_variance(seed, alpha, m, n): 57 | # For this test, we assume m > n. 58 | rs = check_random_state(seed) 59 | X = rs.randn(m, n) 60 | XZ, Z = whiten(X, alpha, preserve_variance=True) 61 | assert_allclose(np.trace(X.T @ X), np.trace(XZ.T @ XZ)) 62 | 63 | 64 | @pytest.mark.parametrize('seed', [1, 2, 3]) 65 | @pytest.mark.parametrize('m', [10, 20]) 66 | @pytest.mark.parametrize('n', [10, 20]) 67 | def test_rand_orth(seed, m, n): 68 | Q = rand_orth(m, n, random_state=seed) 69 | 70 | if m == n: 71 | assert_array_almost_equal(Q.T @ Q, np.eye(n)) 72 | assert_array_almost_equal(Q @ Q.T, np.eye(n)) 73 | elif m > n: 74 | assert_array_almost_equal(Q.T @ Q, np.eye(n)) 75 | else: 76 | assert_array_almost_equal(Q @ Q.T, np.eye(m)) 77 | 78 | 79 | @pytest.mark.parametrize('seed', [1, 2, 3]) 80 | @pytest.mark.parametrize('m', [100]) 81 | @pytest.mark.parametrize('n', [10]) 82 | def test_centered_kernel(seed, m, n): 83 | 84 | rs = check_random_state(seed) 85 | 86 | # Check linear kernel is centered. 87 | X = rs.randn(m, n) 88 | K = pairwise_kernels(X - np.mean(X, axis=0), metric="linear") 89 | K2 = centered_kernel(X, metric="linear") 90 | assert_array_almost_equal(K, K2) 91 | assert_array_almost_equal(centered_kernel(X), centered_kernel(X, X)) 92 | 93 | 94 | # @pytest.mark.parametrize('seed', [1, 2, 3]) 95 | # @pytest.mark.parametrize('n', [1, 2, 3, 4, 5]) 96 | # def test_fast_hadamard_transform(seed, n): 97 | 98 | # # Form Hadamard matrix explicitly 99 | # H = scipy.linalg.hadamard(2 ** n) 100 | 101 | # # Draw random vector. 102 | # rs = check_random_state(seed) 103 | # x = rs.randn(2 ** n) 104 | 105 | # # Perform explicit computation 106 | # expected = H @ x 107 | 108 | # # Check that Fast-Walsh_Hadamard transform matches. 109 | # fwht(x) # updates x in-place. 
#     assert_array_almost_equal(expected, x)


# @pytest.mark.parametrize('seed', [1, 2, 3])
# @pytest.mark.parametrize('n', [1, 2, 14])
# @pytest.mark.parametrize('n_transforms', [1, 3, 6])
# def test_structured_orth(seed, n, n_transforms):

#     # Draw random vectors.
#     rs = check_random_state(seed)
#     x = rs.randn(2 ** n)
#     y = rs.randn(2 ** n)

#     # Compute inner product.
#     original_inner_prod = np.dot(x, y)

#     # Draw sign flips.
#     Ds = rand_struc_orth(2 ** n, n_transforms=n_transforms, random_state=rs)

#     # Apply structured orthogonal transformation. If this is
#     # indeed orthogonal, the inner product should be preserved.
#     struc_orth_matvec(Ds, x)
#     struc_orth_matvec(Ds, y)

#     # Check that the inner products match.
#     assert_allclose(
#         np.dot(x, y), original_inner_prod, atol=ATOL, rtol=RTOL)


# @pytest.mark.parametrize('seed', [1, 2, 3])
# @pytest.mark.parametrize('n', [1, 2, 14])
# @pytest.mark.parametrize('n_transforms', [1, 3, 6])
# def test_structured_orth_inverse(seed, n, n_transforms):

#     # Draw random vectors.
#     rs = check_random_state(seed)
#     x = rs.randn(2 ** n)
#     y = x.copy()

#     # Draw sign flips.
#     Ds = rand_struc_orth(2 ** n, n_transforms=n_transforms, random_state=rs)

#     # Apply structured orthogonal transformation, and then apply
#     # the inverse transformation.
#     struc_orth_matvec(Ds, x)
#     struc_orth_matvec(Ds, x, transpose=True)

#     # Check that we recover our original vector.
#     assert_allclose(x, y, atol=ATOL, rtol=RTOL)

@pytest.mark.parametrize('seed', [1, 2, 3])
@pytest.mark.parametrize('n', [1, 2, 14])
def test_bures(seed, n):

    # Draw random covariance matrices.
    rs = check_random_state(seed)
    X = rs.randn(n, n)
    Y = rs.randn(n, n)
    Sx = X @ X.T
    Sy = Y @ Y.T

    # Check that the fast and slow implementations of the squared Bures
    # metric agree, i.e. both compute
    #   tr(Sx) + tr(Sy) - 2 * tr( sqrtm( sqrtm(Sx) @ Sy @ sqrtm(Sx) ) ).
    assert_allclose(
        sq_bures_metric(Sx, Sy),
        sq_bures_metric_slow(Sx, Sy),
        atol=ATOL, rtol=RTOL
    )
--------------------------------------------------------------------------------