├── .gitignore ├── wass_div.png ├── data ├── sg_demo.npz └── mihi_demo.mat ├── dist_align.png ├── README.md ├── src ├── utils.py └── hiwa.py └── HiWA-Demo.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ -------------------------------------------------------------------------------- /wass_div.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerdslab/PyHiWA/HEAD/wass_div.png -------------------------------------------------------------------------------- /data/sg_demo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerdslab/PyHiWA/HEAD/data/sg_demo.npz -------------------------------------------------------------------------------- /dist_align.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerdslab/PyHiWA/HEAD/dist_align.png -------------------------------------------------------------------------------- /data/mihi_demo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nerdslab/PyHiWA/HEAD/data/mihi_demo.mat -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Wasserstein Alignment (HiWA) 2 | -------------------------------------------- 3 | This repository contains a Python package implementing the algorithm described in this paper: 4 | 5 | John Lee, Max Dabagia, E Dyer, C Rozell: Hierarchical Wasserstein Alignment for Multimodal Distributions, NeurIPS 2019. https://arxiv.org/abs/1906.11768 6 | 7 | ## Overview 8 | ---------- 9 | Optimal transport approaches to distribution alignment attempt to minimize the divergence between two distributions, as quantified by the Wasserstein distance between them. HiWA introduces some further assumptions which make this problem tractable even in the presence of noise and ambiguity, which are unavoidable in real-world datasets. These assumptions are: 10 | 1. Well-defined cluster structure exists in each dataset. 11 | 2. The inter- and intra-cluster structure is consistent between datasets. 12 | HiWA leverages this cluster structure by first determining how best to align the clusters, and then using this information to influence a aligning transformation of the entire dataset to match the target. 13 | 14 | ## Contents 15 | ---------- 16 | The `HiWA` class in this repository is a self-contained implementation of the algorithm. The included Jupyter Notebook is a comprehensive demo with the algorithm applied to both a synthetic dataset and a real-world neuroscience problem on decoding movement intention from neuron firing patterns in the primary motor cortex of a non-human primate. 17 | 18 | ## Dependencies 19 | --------------- 20 | `numpy, scipy, matplotlib, scikit-learn` 21 | 22 | ## Tips 23 | ------- 24 | - The dimensionality reduction technique used to construct the low-dimensional embedding is *critically* important to the algorithm's success. The first thing to check if it is not working is whether the low-dimensional embeddings of the source and target datasets are capturing the same structure, and whether that structure is alignable (i.e. there are no pathological symmetries). 25 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | import matplotlib.pyplot as plt 4 | from scipy.linalg import sqrtm 5 | from scipy.stats import mode 6 | from sklearn.neighbors import NearestNeighbors 7 | from itertools import permutations 8 | from mpl_toolkits.mplot3d import Axes3D 9 | 10 | def load_data(file_name): 11 | mat_dict = loadmat(file_name, appendmat=True) 12 | return np.array(mat_dict['Tte']), np.array(mat_dict['Ttr']), np.array(mat_dict['Xte']), np.array(mat_dict['Xtr']), np.array(mat_dict['Yte']), np.array(mat_dict['Ytr']) 13 | 14 | 15 | def normal(X): 16 | mean_X = np.mean(X, axis=0) 17 | cov_X = np.cov(X, rowvar=False) 18 | X_n = X 19 | #for col in range(X.shape[1]): 20 | # X_n[:, col] = (X_n[:, col] - mean_X[col]) / np.sqrt(cov_X[col, col]) 21 | return np.matmul(X_n - mean_X, np.linalg.inv(sqrtm(cov_X))) 22 | 23 | def remove_const_cols(Y): 24 | return Y[:, ~np.all(Y[1:] == Y[:-1], axis=0)] 25 | 26 | def map_X_3D(X): 27 | return np.column_stack((X[:, 0], X[:, 1], np.linalg.norm(X, axis=1))) 28 | 29 | def LS_oracle(X_test, Y_test): 30 | X_n = X_test 31 | H_inv = X_n.T @ np.linalg.pinv(Y_test).T 32 | return H_inv 33 | 34 | def plot_2d_clusters(X, labels): 35 | for i in np.unique(labels): 36 | plt.plot(X[labels == i, 0], X[labels == i, 1], linestyle='', marker='.', markersize=15) 37 | 38 | 39 | def plot_3d_clusters(X, labels): 40 | for i in np.unique(labels): 41 | plt.scatter(X[id == i, 0], X[id == i, 1], X[id == i, 2], marker='.') 42 | 43 | 44 | def eval_R2(X, Y): 45 | X = normal(X) 46 | Y = normal(Y) 47 | return 1 - np.mean(np.power(Y - X, 2), axis=0).sum() / np.var(Y, axis=0).sum() 48 | 49 | def match_clustered_labels(X, X_labels, Y, Y_labels): 50 | """Find the optimal permutation of labels found by clustering, via 1-nearest neighbors classification accuracy. 51 | 52 | Parameters 53 | ---------- 54 | X : array-like, shape (n_samples_x, n_features) 55 | Dataset of labels to match to. 56 | 57 | X_labels : array-like, shape (n_samples_x, ) 58 | Labels to match to. 59 | 60 | Y : array-like, shape (n_samples_y, n_features) 61 | Dataset of labels to be matched. 62 | 63 | Y_labels : array-like, shape (n_samples_y, ) 64 | Labels to be matched. 65 | 66 | Returns 67 | ---------- 68 | Y_new : New Y labels. 69 | 70 | """ 71 | nbrs = NearestNeighbors() 72 | nbrs.fit(X) 73 | idx = nbrs.kneighbors(Y, n_neighbors=1, return_distance=False).squeeze() 74 | 75 | Y_new = np.zeros(Y_labels.shape, dtype=int) 76 | for i in np.unique(Y_labels): 77 | idc = idx[Y_labels == i] 78 | Y_new[Y_labels == i], _ = mode(X_labels[idc]) 79 | 80 | return Y_new 81 | -------------------------------------------------------------------------------- /src/hiwa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.decomposition import PCA 3 | from scipy.linalg import sqrtm, orth 4 | 5 | def _normal(X): 6 | """Normalize a matrix of observations in rows and variables in columns.""" 7 | return (X - np.mean(X, axis=0)) @ np.linalg.inv(sqrtm(np.cov(X, rowvar=False))) 8 | 9 | 10 | def _closed_form_rotation_solver(m): 11 | """Closed form rotation solver via SVD method.""" 12 | u, _, vh = np.linalg.svd(m) 13 | return u @ vh 14 | 15 | 16 | def _sinkhorn(p, q, X, Y, gamma, maxiter): 17 | """Entropy-Regularized (Sinkhorn) Optimal Transport between distributions""" 18 | x2, y2 = np.sum(np.power(X, 2), axis=0), np.sum(np.power(Y, 2), axis=0) 19 | C = np.tile(y2[np.newaxis, :], (X.shape[1], 1)) + np.tile(x2[:, np.newaxis], (1, Y.shape[1])) - 2 * (X.T @ Y) 20 | K = np.exp(-C / gamma) 21 | b = np.full(q.shape, 1 / len(q)) 22 | for n in range(maxiter): 23 | a = p / (K @ b) 24 | b = q / (K.T @ a) 25 | if np.isnan(a).any(): 26 | raise ArithmeticError('NaN found!') 27 | P = np.diag(a.squeeze()) @ K @ np.diag(b.squeeze()) 28 | return P, np.sum(C * P) 29 | 30 | 31 | def _sinkhorn_clusters(p, q, C, gamma, maxiter): 32 | """Entropy-Regularized (Sinkhorn) Optimal Transport between clusters""" 33 | K = np.exp(-C / gamma) 34 | b = np.ones(q.shape) 35 | for n in range(maxiter): 36 | a = p / (K @ b) 37 | b = q / (K.T @ a) 38 | 39 | P = np.diag(a.squeeze()) @ K @ np.diag(b.squeeze()) 40 | 41 | return P, np.sum(C * P) 42 | 43 | 44 | def _rMSE(X, Rg, Rgt): 45 | """Relative mean square error""" 46 | return np.linalg.norm(Rgt @ X.T - Rg @ X.T, 'fro') ** 2 / np.linalg.norm(Rgt @ X.T, 'fro') ** 2 47 | 48 | 49 | def _eval_R2(X, Rg, Rgt): 50 | """Correlation coefficient""" 51 | X = (Rgt @ X.T).T 52 | Y = (Rg @ X.T).T 53 | return 1 - np.mean(Y - X, axis=0).sum() / np.var(Y, axis=0).sum() 54 | 55 | 56 | class HiWA: 57 | """Hierarchical Wasserstein Alignment (HiWA) 58 | 59 | Applies nested OT between points XX and YY with decreasing entropy 60 | 61 | Parameters 62 | ---------- 63 | dim_red_method : object with a fit_transform() method or None 64 | Method to compute a low-d embedding of the source and target distributions. 65 | Defaults to PCA. 66 | 67 | parallelize : Boolean 68 | Whether to use the multiprocessing module to run ADMM iterations concurrently. 69 | Defaults to True 70 | 71 | normalize : Boolean 72 | Whether to normalize source and target before attempting alignment. 73 | Defaults to True, should be performed to prevent numerical errors. 74 | 75 | maxiter : int or None 76 | Maximum iterations for ADMM 77 | Defaults to 300 78 | 79 | tol : float, double, or None 80 | Stopping criterion for ADMM 81 | The change between rotation matrices on subsequent iterations must be 82 | greater than this (measured as Frobenius norm of the difference) 83 | Defaults to 0.1 84 | 85 | mu : float, double, or None 86 | ADMM parameter 87 | smaller = slower & more accurate, larger = faster & less accurate 88 | Defaults to 0.005 89 | 90 | shorn_maxiter : int or None 91 | Maximum iterations for Sinkhorn OT between clusters 92 | Defaults to 1000 93 | 94 | shorn_gamma : float, double, or None 95 | Entropy temperature for Sinkhorn OT between clusters 96 | larger = slower & more accurate, smaller = faster & less accurate 97 | Defaults to 0.2 98 | 99 | sa_maxiter : int or None 100 | Maximum iterations for subspace alignment procedure 101 | Defaults to 100 102 | 103 | sa_tol : float, double, or None 104 | Stopping criterion for subspace alignment procedure 105 | Defaults to 0.01 106 | 107 | sa_shorn_maxiter : int or None 108 | Maximum iterations for Sinkhorn OT within SA procedure 109 | Defaults to 150 110 | 111 | sa_shorn_gamma : float, double, or None 112 | Entropy temperature for Sinkhorn OT within SA procedure 113 | larger = slower & more accurate, smaller = faster & less accurate 114 | Defaults to 0.1 115 | """ 116 | 117 | def __init__(self, dim_red_method=PCA(n_components=2), normalize=True, maxiter=300, 118 | tol=1e-1, mu=5e-3, shorn_maxiter=1000, shorn_gamma=2e-1, sa_maxiter=100, 119 | sa_tol=1e-2, sa_shorn_maxiter=150, sa_shorn_gamma=1e-1): 120 | 121 | # Save parameters 122 | self.dim_red_method = dim_red_method 123 | self.normalize = normalize 124 | self.maxiter = maxiter 125 | self.tol = tol 126 | self.mu = mu 127 | self.shorn_maxiter = shorn_maxiter 128 | self.shorn_gamma = shorn_gamma 129 | self.sa_maxiter = sa_maxiter 130 | self.sa_tol = sa_tol 131 | self.sa_shorn_maxiter = sa_shorn_maxiter 132 | self.sa_shorn_gamma = sa_shorn_gamma 133 | 134 | def fit(self, X, X_labels, Y, Y_labels, **kwargs): 135 | """Fit the model with X, learning a rotation to match X to Y. 136 | 137 | Parameters 138 | ---------- 139 | X : array-like, shape (n_samples_x, n_features) 140 | Source dataset, to be rotated. 141 | 142 | X_labels : array-like, shape (n_samples_x, ) 143 | Cluster labels for the source dataset. 144 | 145 | Y : arraylike, shape (n_samples_y, n_features) 146 | Target dataset, of the same number of features as X, to rotate X to match. 147 | 148 | Y_labels : array-like, shape (n_samples_y, ) 149 | Cluster labels for the target dataset. 150 | 151 | """ 152 | # If not provided, compute transformations for source and target datasets using the method specified during 153 | # initialization 154 | if self.normalize: 155 | X = _normal(X) 156 | Y = _normal(Y) 157 | X_transform = kwargs.get('X_transform', np.linalg.pinv(X) @ self.dim_red_method.fit_transform(X)) 158 | Y_transform = kwargs.get('Y_transform', np.linalg.pinv(Y) @ self.dim_red_method.fit_transform(Y)) 159 | self.Rgt = kwargs.get('Rgt', np.identity(X.shape[1])) 160 | 161 | # Initialization 162 | h_dim, num_clusters_x, num_clusters_y = X.shape[1], len(np.unique(X_labels)), len(np.unique(Y_labels)) 163 | # Rg = np.identity(h_dim) 164 | Rg = _closed_form_rotation_solver(np.random.random((h_dim, h_dim))) 165 | P = np.full((num_clusters_x, num_clusters_y), 1 / (num_clusters_x * num_clusters_y)) 166 | p = np.full((num_clusters_x, 1), 1 / num_clusters_x) 167 | q = np.full((num_clusters_y, 1), 1 / num_clusters_y) 168 | # Lagrangian multipliers 169 | L = np.zeros((h_dim, h_dim, num_clusters_x, num_clusters_y)) 170 | # Auxiliary variables 171 | R = np.zeros((h_dim, h_dim, num_clusters_x, num_clusters_y)) 172 | R[:, :, :, :] = np.identity(h_dim)[:, :, np.newaxis, np.newaxis] 173 | 174 | C = np.zeros((num_clusters_x, num_clusters_y)) 175 | 176 | diagnostics = {'gamma': np.zeros(self.maxiter, dtype=float), 177 | 'Rg_norm': np.zeros(self.maxiter, dtype=float), 178 | 'rMSE': np.zeros(self.maxiter, dtype=float), 179 | 'R2': np.zeros(self.maxiter, dtype=float), 180 | 'C': np.zeros(C.shape, dtype=float)} 181 | 182 | # Compute low-d embeddings in high-d space and scale 183 | X_mbed = (X_transform @ X_transform.T @ X.T).T / np.sqrt(h_dim) 184 | Y_mbed = (Y_transform @ Y_transform.T @ Y.T).T / np.sqrt(h_dim) 185 | 186 | clust_ids_x = np.unique(X_labels) 187 | clust_ids_y = np.unique(Y_labels) 188 | 189 | # Distributed ADMM 190 | for n in range(self.maxiter): 191 | # Solve for each Q (potentially in parallel) 192 | 193 | for i in range(num_clusters_x): 194 | for j in range(num_clusters_y): 195 | T = (self.mu / h_dim) * (Rg - L[:, :, i, j]) 196 | R[:, :, i, j], _, C[i, j] = self._subspace_alignment_solver( 197 | X_mbed[(X_labels == clust_ids_x[i]).squeeze(), :], 198 | Y_mbed[(Y_labels == clust_ids_y[j]).squeeze(), :], 199 | P[i, j], T) 200 | 201 | # Solve for P 202 | P, _ = _sinkhorn_clusters(p, q, C, self.shorn_gamma, self.shorn_maxiter) 203 | 204 | # Solve for global rotation, Rg 205 | Rg_prev = Rg 206 | Rg = _closed_form_rotation_solver( 207 | np.mean(np.reshape(R + L, [h_dim, h_dim, num_clusters_x * num_clusters_y], order='F'), axis=2)) 208 | 209 | # Update Lagrangian multipliers 210 | L = L + R - Rg[:, :, np.newaxis, np.newaxis] 211 | 212 | diagnostics['gamma'][n] = self.shorn_gamma 213 | diagnostics['Rg_norm'][n] = np.linalg.norm(Rg_prev - Rg, 'fro') 214 | diagnostics['rMSE'][n] = _rMSE(X, Rg, self.Rgt) 215 | diagnostics['R2'][n] = _eval_R2(X, Rg, self.Rgt) 216 | 217 | if (np.isnan(P).any() or diagnostics['Rg_norm'][n] <= self.tol) and n >= 5: 218 | diagnostics['gamma'] = diagnostics['gamma'][0:n + 1] 219 | diagnostics['Rg_norm'] = diagnostics['Rg_norm'][0:n + 1] 220 | diagnostics['rMSE'] = diagnostics['rMSE'][0:n + 1] 221 | diagnostics['R2'] = diagnostics['R2'][0:n + 1] 222 | diagnostics['C'] = C 223 | break 224 | 225 | self.Rg = Rg 226 | self.P = P 227 | self.diagnostics = diagnostics 228 | 229 | def transform(self, X): 230 | """Transform X by applying the learned rotation to it. 231 | 232 | Parameters 233 | ---------- 234 | X : array-like, shape (n_samples_x, n_features) 235 | Source dataset, to be rotated. 236 | 237 | Returns 238 | ---------- 239 | X_new : array-like, shape(n_samples_x, n_features) 240 | """ 241 | return (self.Rg @ X.T).T 242 | 243 | def fit_transform(self, X, X_labels, Y, Y_labels, **kwargs): 244 | """Fit the model with X, learning a rotation to match X to Y, and apply the learned rotation to it. 245 | 246 | Parameters 247 | ---------- 248 | X : array-like, shape (n_samples_x, n_features) 249 | Source dataset, to be rotated. 250 | 251 | X_labels : array-like, shape (n_samples_x, ) 252 | Cluster labels for the source dataset. 253 | 254 | Y : arraylike, shape (n_samples_y, n_features) 255 | Target dataset, of the same number of features as X, to rotate X to match. 256 | 257 | Y_labels : array-like, shape (n_samples_y, ) 258 | Cluster labels for the target dataset. 259 | 260 | Returns 261 | ---------- 262 | X_new : array-like, shape(n_samples_x, n_features) 263 | 264 | """ 265 | self.fit(X, X_labels, Y, Y_labels, **kwargs) 266 | return self.transform(X) 267 | 268 | def _subspace_alignment_solver(self, X, Y, P, T): 269 | """Earth Mover's Distance on low-rank projections""" 270 | # Initialization 271 | h_dim, num_x, num_y = X.shape[1], X.shape[0], Y.shape[0] 272 | R = orth(np.random.random((h_dim, h_dim))) 273 | # R = np.identity(h_dim) 274 | Q = np.full((num_x, num_y), 1 / (num_x * num_y)) 275 | p = np.full((num_x, 1), 1 / num_x) 276 | q = np.full((num_y, 1), 1 / num_y) 277 | 278 | # Alternating minimization 279 | for i in range(self.sa_maxiter): 280 | R_prev = R 281 | 282 | # Solve rotation 283 | R = _closed_form_rotation_solver(2 * P * (Y.T @ Q.T @ X) + T) 284 | 285 | # Solve Sinkhorn OT 286 | Q, dist = _sinkhorn(p, q, R @ X.T, Y.T, self.sa_shorn_gamma / P, self.sa_shorn_maxiter) 287 | 288 | if np.linalg.norm(R_prev - R, 2) <= self.sa_tol: 289 | break 290 | 291 | return R, Q, dist 292 | -------------------------------------------------------------------------------- /HiWA-Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Hierarchical Wasserstein Alignment (HiWA)\n", 8 | "\n", 9 | "_Aim:_ Decompose data into potential clusters and then find global alignment across two datasets by leveraging clustering structure.\n", 10 | "\n", 11 | "_Paper describing HiWA:_ John Lee, Max Dabagia, E Dyer, C Rozell: Hierarchical Wasserstein Alignment for Multimodal Distributions, May 2019.\n", 12 | "\n", 13 | "#### In this demo there are steps to:\n", 14 | "\n", 15 | "- Dimensionality reduction to first simplify problem\n", 16 | "- Clustering data to feed labels into HiWA\n", 17 | "- Alignment between clusters and then global alignment with HiWA" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 67, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Import necessary packages\n", 27 | "\n", 28 | "import numpy as np\n", 29 | "from src import utils, hiwa\n", 30 | "from sklearn.datasets import make_spd_matrix\n", 31 | "from sklearn.decomposition import PCA, FactorAnalysis\n", 32 | "from sklearn.manifold import Isomap\n", 33 | "from sklearn.neighbors import NearestNeighbors\n", 34 | "from sklearn.mixture import GaussianMixture\n", 35 | "from scipy.linalg import pinv\n", 36 | "import time\n", 37 | "import os\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### Example A - Synthetic Gaussians " 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "#### Step 0. Load data from numpy archive" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 262, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "with np.load(os.path.join('data', 'sg_demo.npz')) as dataz:\n", 63 | " X_te = dataz['X_te']\n", 64 | " T_te = dataz['T_te']\n", 65 | " X_tr = dataz['X_tr']\n", 66 | " T_tr = dataz['T_tr']" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "#### Step 1. Align data with HiWA" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 272, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "name": "stdout", 83 | "output_type": "stream", 84 | "text": [ 85 | "Time elapsed: 0.57946 seconds\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "# Fit the model and transform the source dataset\n", 91 | "t1 = time.time()\n", 92 | "\n", 93 | "# This indicates to the model it should compute a low-d embedding with isomap, \n", 94 | "# and that the parameter datasets will not be normalized.\n", 95 | "hwa = hiwa.HiWA(dim_red_method=Isomap(n_components=2), normalize=True)\n", 96 | "\n", 97 | "# Fit the model, also passing in the correct rotation in Rgt since we can \n", 98 | "# calculate it in this case (for computing error metrics)\n", 99 | "# Note that we can provide the low-d mapping for either the source or target; \n", 100 | "# the model will calculate whatever isn't provided\n", 101 | "hwa.fit(X_te, T_te, X_tr, T_tr, Rgt=Rgt)\n", 102 | "\n", 103 | "# Transform the dataset with the learned rotation\n", 104 | "X_te_rec = hwa.transform(X_te)\n", 105 | "t2 = time.time()\n", 106 | "print('Time elapsed: {:.5} seconds'.format(t2 - t1))" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "#### Step 2. Evaluate performance with 1NN target accuracy" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 273, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "NN Accuracy (Before): 26.00%\n", 126 | "NN Accuracy (After): 94.75%\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "# Evaluate nearest neighbors classification accuracy before alignment\n", 132 | "nbrs = NearestNeighbors()\n", 133 | "nbrs.fit(X_te)\n", 134 | "idx = nbrs.kneighbors(X_tr, n_neighbors=1, return_distance=False).squeeze()\n", 135 | "nn_hwa_prealign = (T_te[idx] == T_tr).sum() / T_tr.shape[0]\n", 136 | "print('NN Accuracy (Before): {:.2%}'.format(nn_hwa_prealign))\n", 137 | "\n", 138 | "# Evaluate nearest neighbors classification accuracy after alignment\n", 139 | "nbrs = NearestNeighbors()\n", 140 | "nbrs.fit(X_te_rec)\n", 141 | "idx = nbrs.kneighbors(X_tr, n_neighbors=1, return_distance=False).squeeze()\n", 142 | "nn_hwa = (T_te[idx] == T_tr).sum() / T_tr.shape[0]\n", 143 | "print('NN Accuracy (After): {:.2%}'.format(nn_hwa))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "#### Step 3. Visualize results" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 274, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "image/png": "\n", 161 | "text/plain": [ 162 | "
" 163 | ] 164 | }, 165 | "metadata": { 166 | "needs_background": "light" 167 | }, 168 | "output_type": "display_data" 169 | } 170 | ], 171 | "source": [ 172 | "# Plot the results, plus some diagnostics\n", 173 | "plt.figure(figsize=(12, 11))\n", 174 | "plt.subplot(2, 3, 1)\n", 175 | "plt.title('Rotated (NN Accuracy = {:.2%})'.format(nn_hwa_prealign))\n", 176 | "utils.plot_2d_clusters(X_te, T_te)\n", 177 | "plt.subplot(2, 3, 2)\n", 178 | "plt.title('Rotated (NN Accuracy = {:.2%})'.format(nn_hwa))\n", 179 | "utils.plot_2d_clusters(X_te_rec, T_te)\n", 180 | "plt.subplot(2, 3, 3)\n", 181 | "plt.title('Target')\n", 182 | "utils.plot_2d_clusters(X_tr, T_tr)\n", 183 | "ax = plt.subplot(2, 3, 4)\n", 184 | "plt.plot(hwa.diagnostics['Rg_norm'])\n", 185 | "ax.set_yscale('log')\n", 186 | "plt.title('Residual Norm')\n", 187 | "plt.xlabel('Iteration')\n", 188 | "plt.subplot(2, 3, 5)\n", 189 | "plt.imshow(hwa.P)\n", 190 | "plt.xticks(np.arange(4))\n", 191 | "plt.yticks(np.arange(4))\n", 192 | "plt.title('Correspondence (Entropy$ = $ %.4f)' %(-np.sum(hwa.P * np.log(hwa.P))))\n", 193 | "plt.subplot(2, 3, 6)\n", 194 | "plt.xticks(np.arange(4))\n", 195 | "plt.yticks(np.arange(4))\n", 196 | "plt.title('Distance Matrix')\n", 197 | "c = plt.imshow(hwa.diagnostics['C'])" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Example B - Motor cortex during reaching movements (NHP)\n", 205 | "_Data collected by:_ Matthew Perich, Lee Miller Lab (Northwestern)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "#### Step 0. Load datasets and reach direction labels" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 2, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# Load the datasets from a .mat file\n", 222 | "T_te, T_tr, X_te, X_tr, Y_te, Y_tr = utils.load_data(os.path.join('data', 'mihi_demo.mat'))\n", 223 | "T_te, T_tr = T_te.squeeze(), T_tr.squeeze()" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "#### Step 1. Apply initial dimensionality reduction" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 3, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "# Apply initial dimensionality reduction (n_components)\n", 240 | "Y_te_3d = FactorAnalysis(n_components=3).fit_transform(utils.remove_const_cols(Y_te))\n", 241 | "X_tr_3d = utils.map_X_3D(X_tr)\n", 242 | "X_transform = np.linalg.pinv(X_tr_3d) @ X_tr" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "#### Step 2. Align data with HiWA" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 4, 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "Time elapsed: 8.0 seconds\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "# Fit the model and transform the source dataset\n", 267 | "t1 = time.time()\n", 268 | "\n", 269 | "# This indicates to the model it should compute a low-d embedding with isomap, \n", 270 | "# and that the parameter datasets will not be normalized\n", 271 | "hwa = hiwa.HiWA(dim_red_method=Isomap(n_components=2, n_neighbors=12), normalize=True)\n", 272 | "\n", 273 | "# Fit the model, also passing in the correct rotation in Rgt since we can \n", 274 | "# calculate it in this case (for computing error metrics)\n", 275 | "# Note that we can provide the low-d mapping for either the source or target; \n", 276 | "# the model will calculate whatever isn't provided\n", 277 | "hwa.fit(Y_te_3d, T_te, X_tr_3d, T_tr, Y_transform=X_transform, Rgt=utils.LS_oracle(utils.map_X_3D(X_te), Y_te_3d))\n", 278 | "\n", 279 | "# Transform the dataset with the learned rotation\n", 280 | "Y_te_rec = hwa.transform(Y_te_3d)\n", 281 | "t2 = time.time()\n", 282 | "print('Time elapsed: {:.5} seconds'.format(t2 - t1))" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "#### Step 3. Evaluate performance metrics (R2 and target accuracy)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 6, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "name": "stdout", 299 | "output_type": "stream", 300 | "text": [ 301 | "R2 Value: 0.6307\n", 302 | "NN Accuracy (Before): 24.40%\n", 303 | "NN Accuracy (After): 52.65%\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "# Evaluate decoding accuracy after alignment\n", 309 | "r2_hwa = utils.eval_R2(Y_te_rec[:,0:2], X_te)\n", 310 | "print('R2 Value: {:.4}'.format(r2_hwa))\n", 311 | "\n", 312 | "# Evaluate nearest neighbors classification accuracy before alignment\n", 313 | "nbrs = NearestNeighbors()\n", 314 | "nbrs.fit(Y_te_3d)\n", 315 | "idx = nbrs.kneighbors(X_tr_3d, n_neighbors=1, return_distance=False).squeeze()\n", 316 | "nn_hwa_prealign = (T_te[idx] == T_tr).sum() / T_tr.shape[0]\n", 317 | "print('NN Accuracy (Before): {:.2%}'.format(nn_hwa_prealign))\n", 318 | "\n", 319 | "# Evaluate nearest neighbors classification accuracy after alignment\n", 320 | "nbrs = NearestNeighbors()\n", 321 | "nbrs.fit(Y_te_rec)\n", 322 | "idx = nbrs.kneighbors(X_tr_3d, n_neighbors=1, return_distance=False).squeeze()\n", 323 | "nn_hwa = (T_te[idx] == T_tr).sum() / T_tr.shape[0]\n", 324 | "print('NN Accuracy (After): {:.2%}'.format(nn_hwa))" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": {}, 330 | "source": [ 331 | "#### Step 4. Visualize results" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 7, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "image/png": "\n", 342 | "text/plain": [ 343 | "
" 344 | ] 345 | }, 346 | "metadata": { 347 | "needs_background": "light" 348 | }, 349 | "output_type": "display_data" 350 | } 351 | ], 352 | "source": [ 353 | "# Plot the results, plus some diagnostics\n", 354 | "plt.figure(figsize=(12, 11))\n", 355 | "plt.subplot(2, 3, 1)\n", 356 | "plt.title('Rotated (NN Accuracy = {:.2%})'.format(nn_hwa_prealign))\n", 357 | "utils.plot_2d_clusters(Y_te_3d, T_te)\n", 358 | "plt.subplot(2, 3, 2)\n", 359 | "plt.title('Rotated (NN Accuracy = {:.2%})'.format(nn_hwa))\n", 360 | "utils.plot_2d_clusters(Y_te_rec, T_te)\n", 361 | "plt.subplot(2, 3, 3)\n", 362 | "plt.title('Target')\n", 363 | "utils.plot_2d_clusters(X_tr, T_tr)\n", 364 | "ax = plt.subplot(2, 3, 4)\n", 365 | "plt.plot(hwa.diagnostics['Rg_norm'])\n", 366 | "ax.set_yscale('log')\n", 367 | "plt.title('Residual Norm')\n", 368 | "plt.xlabel('Iteration')\n", 369 | "plt.subplot(2, 3, 5)\n", 370 | "plt.imshow(hwa.P)\n", 371 | "plt.xticks(np.arange(4))\n", 372 | "plt.yticks(np.arange(4))\n", 373 | "plt.title('Correspondence (Entropy$ = $ %.4f)' %(-np.sum(hwa.P * np.log(hwa.P))))\n", 374 | "plt.subplot(2, 3, 6)\n", 375 | "plt.xticks(np.arange(4))\n", 376 | "plt.yticks(np.arange(4))\n", 377 | "plt.title('Distance Matrix')\n", 378 | "c = plt.imshow(hwa.diagnostics['C'])" 379 | ] 380 | } 381 | ], 382 | "metadata": { 383 | "kernelspec": { 384 | "display_name": "Python 3", 385 | "language": "python", 386 | "name": "python3" 387 | }, 388 | "language_info": { 389 | "codemirror_mode": { 390 | "name": "ipython", 391 | "version": 3 392 | }, 393 | "file_extension": ".py", 394 | "mimetype": "text/x-python", 395 | "name": "python", 396 | "nbconvert_exporter": "python", 397 | "pygments_lexer": "ipython3", 398 | "version": "3.7.3" 399 | } 400 | }, 401 | "nbformat": 4, 402 | "nbformat_minor": 2 403 | } 404 | --------------------------------------------------------------------------------