├── .github └── workflows │ ├── pre-commit.yml │ └── python-publish.yml ├── .gitignore ├── DiffusionEMD ├── Spherical_MNIST_Comparisons.py ├── __init__.py ├── convolutional_sinkhorn.py ├── dataset.py ├── diffusion_emd.py ├── emd.py ├── estimate_utils.py ├── metric_tree.py └── version.py ├── LICENSE ├── LICENSE.md ├── README.rst ├── assets └── schematic_600_400.png ├── notebooks ├── Line Example.ipynb └── Swissroll Example.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py └── tox.ini /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | on: 3 | push: 4 | branches-ignore: 5 | - 'master' 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Cancel Previous Runs 12 | uses: styfle/cancel-workflow-action@0.6.0 13 | with: 14 | access_token: ${{ github.token }} 15 | - uses: actions/checkout@v2 16 | with: 17 | fetch-depth: 0 18 | 19 | - uses: actions/setup-python@v2 20 | with: 21 | python-version: "3.7" 22 | architecture: "x64" 23 | 24 | - uses: actions/cache@v2 25 | with: 26 | path: ~/.cache/pre-commit 27 | key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}- 28 | 29 | - uses: pre-commit/action@v2.0.0 30 | continue-on-error: true 31 | 32 | - name: Commit files 33 | run: | 34 | if [[ `git status --porcelain --untracked-files=no` ]]; then 35 | git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" 36 | git config --local user.name "github-actions[bot]" 37 | git commit -m "pre-commit" -a 38 | fi 39 | 40 | - name: Push changes 41 | uses: ad-m/github-push-action@master 42 | with: 43 | github_token: ${{ secrets.GITHUB_TOKEN }} 44 | branch: ${{ github.ref }} 45 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
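# Note: the steps below publish the built distribution to both TestPyPI and
# PyPI, so the TEST_PYPI_API_TOKEN and PYPI_API_TOKEN repository secrets they
# reference are assumed to be configured before a release is published.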
8 | 9 | name: Publish Python 🐍 distributions 📦 to PyPI 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install build 29 | - name: Build package 30 | run: python -m build 31 | - name: Publish distribution 📦 to Test PyPI 32 | uses: pypa/gh-action-pypi-publish@release/v1 33 | with: 34 | skip_existing: true 35 | user: __token__ 36 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 37 | repository_url: https://test.pypi.org/legacy/ 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | with: 41 | user: __token__ 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | 81 | # exclude old folder by default 82 | /old/ 83 | 84 | # Mac OS-specific storage files 85 | .DS_Store 86 | 87 | # vim 88 | *.swp 89 | *.swo 90 | 91 | # Mypy cache 92 | .mypy_cache/ 93 | 94 | # Snakemake cache 95 | .snakemake/ 96 | -------------------------------------------------------------------------------- /DiffusionEMD/Spherical_MNIST_Comparisons.py: -------------------------------------------------------------------------------- 1 | from sklearn.neighbors import KNeighborsClassifier 2 | from sklearn.model_selection import train_test_split 3 | import torch 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import graphtools 8 | import sklearn.datasets 9 | import pygsp 10 | import sklearn 11 | import ot 12 | import pandas as pd 13 | import scipy.sparse 14 | import pickle 15 | import scprep 16 | from manifold_ot import estimate_utils 17 | 18 | 19 | class Spherical_MNIST_Predictions: 20 | def __init__(self, sphere_graph_path, sphere_signals_path): 21 | if sphere_graph_path and sphere_signals_path: 22 | self.sphere_graph = pickle.load(sphere_graph_path) 23 | self.sphere_signals = 
pickle.load(sphere_signals_path) 24 | else: 25 | print("Please precompute the sphere -- not yet implemented") 26 | self.knn = KNeighborsClassifier(n_neighbors=1) 27 | 28 | def knn_classify(self, embeddings, num_neighbors=1): 29 | # perform a train/test split (by default 50-50) 30 | knn = KNeighborsClassifier(n_neighbors=num_neighbors) 31 | X_train, X_test, y_train, y_test = train_test_split( 32 | embeddings, self.dataset_labels, random_state=0 33 | ) 34 | knn.fit(X_train, y_train) 35 | # get prediction accuracy 36 | preds = knn.predict(X_test) 37 | acc = np.sum((preds == y_test).float()) / len(X_test) 38 | return acc 39 | 40 | def MOT_embedding(self, num_evals=1000): 41 | # perform a MOT embedding of the dataset 42 | def apply_anisotropy(K, anisotropy): 43 | if anisotropy == 0: 44 | # do nothing 45 | return K 46 | if scipy.sparse.issparse(K): 47 | d = np.array(K.sum(1)).flatten() 48 | K = K.tocoo() 49 | K.data = K.data / ((d[K.row] * d[K.col]) ** anisotropy) 50 | K = K.tocsr() 51 | return K, d 52 | d = K.sum(1) 53 | K = K / (np.outer(d, d) ** anisotropy) 54 | return K, d 55 | 56 | def apply_vectors(M, d, d_post=None): 57 | if d_post is None: 58 | d_post = d 59 | if scipy.sparse.issparse(M): 60 | M = M.tocoo() 61 | M.data = M.data * (d[M.row] * d_post[M.col]) 62 | return M.tocsr() 63 | return M / np.outer(d, d_post) 64 | 65 | def diffusion_embeddings( 66 | graph, 67 | distribution_labels, 68 | method="chebyshev", 69 | max_scale=7, 70 | min_scale=1, 71 | version=1, 72 | anisotropy=0.0, 73 | k=None, 74 | return_eig=False, 75 | subselect=False, 76 | alpha=1, 77 | ): 78 | """ 79 | Return the vectors whose L1 distances are the EMD between the given distributions. 80 | The graph supplied (a PyGSP graph) should encompass both distributions. 81 | The distributions themselves should be one-hot encoded with the distribution_labels parameter. 82 | """ 83 | assert version >= 3 84 | assert 0 <= anisotropy <= 1 85 | if k is None: 86 | k = graph.N - 1 87 | print(f"Graph has N = {graph.N}. 
Using k = {k}") 88 | diffusions = [] 89 | if version <= 4: 90 | graph.compute_laplacian(lap_type="normalized") 91 | # Lazy symmetric random walk matrix 92 | P = np.eye(graph.N) - graph.L / 2 93 | # e, U = np.linalg.eigh(P) 94 | e, U = scipy.sparse.linalg.eigsh(P, k=k) 95 | for scale in [2 ** i for i in range(1, max_scale)]: 96 | Pt = U @ np.diag(e ** scale) @ U.T 97 | diffusions.append(Pt @ distribution_labels) 98 | else: 99 | A = graph.W 100 | D = np.array(A.sum(axis=0)).squeeze() 101 | P = apply_anisotropy(A, anisotropy) 102 | # Sums along axis=1 are all 1 103 | D_norm = np.array(P.sum(axis=0)).squeeze() 104 | M = apply_vectors(P, D_norm ** -0.5) 105 | e, U = scipy.sparse.linalg.eigsh(M, k=k) 106 | for scale in [2 ** i for i in range(min_scale, max_scale)]: 107 | Pt_sym = U @ np.diag(e ** scale) @ U.T 108 | Pt = apply_vectors(Pt_sym, D_norm ** -0.5, D_norm ** 0.5) 109 | diffusions.append(Pt @ distribution_labels) 110 | diffusions = np.stack(diffusions, axis=-1) 111 | n, n_samples, n_scales = diffusions.shape 112 | embeddings = [] 113 | for i in range(n_scales): 114 | d = diffusions[..., i] 115 | if (version == 2) or (version == 3): 116 | if i < n_scales - 1: 117 | d -= diffusions[..., -1] 118 | weight = 0.5 ** (n_scales - i - 1) 119 | elif version == 4: 120 | if i < n_scales - 1: 121 | d -= diffusions[..., i + 1] 122 | weight = 0.5 ** (n_scales - i - 1) 123 | elif version == 5: 124 | if i < n_scales - 1: 125 | d -= diffusions[..., -1] 126 | weight = 0.5 ** ((n_scales - i - 1) * alpha) 127 | elif version == 6: 128 | if i < n_scales - 1: 129 | d -= diffusions[..., i + 1] 130 | weight = 0.5 ** ((n_scales - i - 1) * alpha) 131 | lvl_embed = weight * d.T 132 | 133 | embeddings.append(lvl_embed) 134 | 135 | if subselect: 136 | num_samples = approximate_rank_of_scales( 137 | P, 0.5, scales=[2 ** i for i in range(min_scale, max_scale)] 138 | ) 139 | print(num_samples) 140 | augmented_num_samples = [ 141 | min(n * (2 ** (i + min_scale)), graph.N) 142 | for i, n in enumerate(num_samples) 143 | ] 144 | print(augmented_num_samples) 145 | selections = [] 146 | pps = [] 147 | for arank in augmented_num_samples: 148 | selected, pp = randomized_interpolative_decomposition( 149 | np.array(P), arank, arank + 8, return_p=True 150 | ) 151 | selections.append(selected) 152 | pps.append(pp) 153 | # augmented_num_samples 154 | print(embeddings[0].shape, len(embeddings)) 155 | tmp = [] 156 | for s, e, a in zip(selections, embeddings, augmented_num_samples): 157 | tmp.append(e[:, s] * graph.N / a) 158 | embeddings = tmp 159 | embeddings = np.concatenate(embeddings, axis=1) 160 | if return_eig and subselect: 161 | return embeddings, e, U, pps 162 | if return_eig: 163 | return embeddings, e, U 164 | return embeddings 165 | 166 | embeddings = diffusion_embeddings( 167 | self.sphere_graph, self.sphere_signals, version=5, max_scale=12, k=num_evals 168 | ) 169 | return embeddings 170 | -------------------------------------------------------------------------------- /DiffusionEMD/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_emd import DiffusionTree, DiffusionCheb, DiffusionExact, DiffusionTreeV2 2 | from .metric_tree import QuadTree, ClusterTree, MetricTree 3 | from .convolutional_sinkhorn import conv_sinkhorn 4 | 5 | from .version import __version__ 6 | -------------------------------------------------------------------------------- /DiffusionEMD/convolutional_sinkhorn.py: 
-------------------------------------------------------------------------------- 1 | """ Implements convolutional sinkhorn distances from Solomon et al. 2015 2 | """ 3 | import numpy as np 4 | import pygsp 5 | 6 | 7 | def conv_sinkhorn( 8 | W, m_0, m_1, stopThr=1e-4, max_iter=1e3, method="chebyshev", t=50, verbose=False 9 | ): 10 | """ Implements the convolutional sinkhorn operator described in Solomon et 11 | al. 2015. This is sinkhorn except the cost matrix is replaced with the heat 12 | operator which may be easier to apply. 13 | 14 | Notes: It is unclear how to pick t from the manuscript. We will pick by 15 | cross validation. 16 | 17 | Parameters 18 | ---------- 19 | W, n x n adjacency matrix of a graph 20 | m_0, m_1 distributions over W numpy arrays of length n 21 | """ 22 | eps = 1e-8 23 | N = W.shape[0] 24 | G = pygsp.graphs.Graph(W) 25 | if method == "chebyshev": 26 | G.estimate_lmax() 27 | elif method == "exact": 28 | G.compute_fourier_basis() 29 | else: 30 | raise NotImplementedError("Unknown method %s" % method) 31 | heat_filter = pygsp.filters.Heat(G, t) 32 | v = np.ones(N) 33 | w = np.ones(N) 34 | for i in range(1, int(max_iter) + 1): 35 | v_prev = v 36 | v = m_0 / (heat_filter.filter(w, method=method) + eps) 37 | w = m_1 / (heat_filter.filter(v, method=method) + eps) 38 | if i % 100 == 0: 39 | if verbose: 40 | print(i, np.sum(np.abs(v - v_prev))) 41 | if np.sum(np.abs(v - v_prev)) < stopThr: 42 | if verbose: 43 | print("converged at iteration %d" % i) 44 | break 45 | 46 | return np.sum(t * (m_0 * np.log(v + eps) + m_1 * np.log(w + eps))) 47 | -------------------------------------------------------------------------------- /DiffusionEMD/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handles datasets for the manifold OT project 3 | 4 | 5 | """ 6 | import graphtools 7 | import numpy as np 8 | from scipy.stats import special_ortho_group 9 | import sklearn.datasets as skd 10 | import sklearn.metrics 11 | from sklearn.neighbors import kneighbors_graph 12 | from sklearn.neighbors import radius_neighbors_graph 13 | import ot 14 | import pygsp 15 | 16 | 17 | class Dataset(object): 18 | """ Dataset class for Optimal Transport 19 | 20 | Paramters 21 | --------- 22 | X: [N x F] 23 | 24 | labels: [N x M] 25 | 26 | """ 27 | 28 | def __init__(self): 29 | super().__init__() 30 | self.X = None 31 | self.labels = None 32 | self.graph = None 33 | 34 | def get_labels(self): 35 | return self.labels 36 | 37 | def get_data(self): 38 | return self.X 39 | 40 | def standardize_data(self): 41 | """ Standardize data putting it in a unit box around the origin. 
42 | This is necessary for quadtree type algorithms 43 | """ 44 | X = self.X 45 | minx = np.min(self.X, axis=0) 46 | maxx = np.max(self.X, axis=0) 47 | self.std_X = (X - minx) / (maxx - minx) 48 | return self.std_X 49 | 50 | def rotate_to_dim(self, dim): 51 | """ Rotate dataset to a different dimensionality """ 52 | self.rot_mat = special_ortho_group.rvs(dim)[: self.X.shape[1]] 53 | self.high_X = np.dot(self.X, self.rot_mat) 54 | return self.high_X 55 | 56 | class Ring(Dataset): 57 | def __init__(self, n_points, random_state=42): 58 | super().__init__() 59 | self.n_points = n_points 60 | N = n_points 61 | self.random_state = random_state 62 | np.random.seed(42) 63 | self.X = np.linspace(0, 1 - (1 / N), N)[:, None] 64 | self.X_circle = np.stack([np.cos(2 * np.pi * self.X[:,0]), np.sin(2 * np.pi * self.X[:,0])], axis=1) 65 | # print(self.X_circle) 66 | #self.graph = pygsp.graphs.NNGraph( 67 | # self.X_circle, epsilon=0.1, NNtype="radius", rescale=False, center=False 68 | #) 69 | self.graph = pygsp.graphs.Ring(self.n_points) 70 | self.labels = np.eye(N) 71 | 72 | def get_graph(self): 73 | return self.graph 74 | 75 | class Line(Dataset): 76 | def __init__(self, n_points, epsilon=0.1, random_state=42): 77 | super().__init__() 78 | self.n_points = n_points 79 | N = n_points 80 | self.random_state = random_state 81 | np.random.seed(42) 82 | self.X = np.linspace(0, 1, N)[:, None] 83 | # self.X_circle = np.stack( 84 | # [np.cos(2 * np.pi * self.X[:, 0]), np.sin(2 * np.pi * self.X[:, 0])], 85 | # axis=1 86 | # ) 87 | self.graph = self.create_radius_graph(self.X, epsilon) 88 | self.labels = np.eye(N) 89 | 90 | def create_radius_graph(self, X, epsilon): 91 | """ 92 | Create a graph where each node is connected to all other nodes within a certain radius. 93 | """ 94 | adjacency_matrix = radius_neighbors_graph(X, radius=epsilon, mode='connectivity', include_self=False) 95 | 96 | # Create the pygsp graph using the adjacency matrix 97 | pygsp_graph = pygsp.graphs.Graph(adjacency_matrix) 98 | return pygsp_graph 99 | 100 | def get_graph(self): 101 | return self.graph 102 | 103 | 104 | class SklearnDataset(Dataset): 105 | """ Make a dataset based on an SKLearn dataset with a 106 | gaussian centered at each point. 
107 | """ 108 | 109 | def __init__( 110 | self, 111 | name=None, 112 | n_distributions=100, 113 | n_points_per_distribution=50, 114 | noise=0.0, 115 | random_state=42, 116 | ): 117 | super().__init__() 118 | self.name = name 119 | self.n_distributions = n_distributions 120 | self.n_points_per_distribution = 50 121 | self.noise = noise 122 | self.random_state = random_state 123 | if name == "swiss_roll": 124 | f = skd.make_swiss_roll 125 | elif name == "s_curve": 126 | f = skd.make_s_curve 127 | else: 128 | raise NotImplementedError("Unknown sklearn dataset: %s" % name) 129 | self.means, self.t = f( 130 | n_samples=n_distributions, noise=noise, random_state=random_state 131 | ) 132 | rng = np.random.default_rng(random_state) 133 | 134 | clouds = np.array( 135 | [ 136 | rng.multivariate_normal( 137 | mean, 20 * np.identity(3), n_points_per_distribution 138 | ) 139 | for mean in self.means 140 | ] 141 | ) 142 | self.X = np.reshape(clouds, (n_distributions * n_points_per_distribution, 3)) 143 | self.labels = np.repeat( 144 | np.eye(n_distributions), n_points_per_distribution, axis=0 145 | ) 146 | 147 | def get_graph(self): 148 | """ Create a graphtools graph if does not exist 149 | """ 150 | if self.graph is None: 151 | self.graph = graphtools.Graph(self.X, use_pygsp=True) 152 | return self.graph 153 | 154 | class SwissRoll(Dataset): 155 | def __init__( 156 | self, 157 | n_distributions=100, 158 | n_points_per_distribution=50, 159 | noise=0.0, 160 | manifold_noise=1.0, 161 | width=1, 162 | random_state=42, 163 | ): 164 | super().__init__() 165 | rng = np.random.default_rng(random_state) 166 | 167 | mean_t = 1.5 * np.pi * (1 + 2 * rng.uniform(size=(1, n_distributions))) 168 | mean_y = width * rng.uniform(size=(1, n_distributions)) 169 | t_noise = ( 170 | manifold_noise 171 | * 3 172 | * rng.normal(size=(n_distributions, n_points_per_distribution)) 173 | ) 174 | y_noise = ( 175 | manifold_noise 176 | * 7 177 | * rng.normal(size=(n_distributions, n_points_per_distribution)) 178 | ) 179 | ts = np.reshape(t_noise + mean_t.T, -1) 180 | ys = np.reshape(y_noise + mean_y.T, -1) 181 | xs = ts * np.cos(ts) 182 | zs = ts * np.sin(ts) 183 | X = np.stack((xs, ys, zs)) 184 | X += noise * rng.normal(size=(3, n_distributions * n_points_per_distribution)) 185 | self.X = X.T 186 | self.ts = np.squeeze(ts) 187 | self.labels = np.repeat( 188 | np.eye(n_distributions), n_points_per_distribution, axis=0 189 | ) 190 | self.t = mean_t[0] 191 | mean_x = mean_t * np.cos(mean_t) 192 | mean_z = mean_t * np.sin(mean_t) 193 | self.means = np.concatenate((mean_x, mean_y, mean_z)).T 194 | 195 | def get_graph(self): 196 | """ Create a graphtools graph if does not exist 197 | """ 198 | if self.graph is None: 199 | self.graph = graphtools.Graph(self.X, use_pygsp=True) 200 | return self.graph 201 | 202 | 203 | class Sphere(Dataset): 204 | def __init__( 205 | self, 206 | n_distributions=100, 207 | n_points_per_distribution=50, 208 | dim = 3, 209 | noise=0.05, 210 | label_noise = 0.0, 211 | manifold_noise=1.0, 212 | width=1, 213 | flip=False, 214 | random_state=42, 215 | ): 216 | super().__init__() 217 | self.n_distributions = n_distributions 218 | self.n_points_per_distribution = n_points_per_distribution 219 | self.dim = dim 220 | self.noise = noise 221 | self.manifold_noise = manifold_noise 222 | rng = np.random.default_rng(random_state) 223 | 224 | X = rng.normal(0, 1, (self.dim, self.n_distributions)) 225 | X = X / np.linalg.norm(X, axis=0) 226 | self.means = X.T 227 | X = X[:, :, None] 228 | X = np.repeat(X, 
n_points_per_distribution, axis=-1) 229 | noise = noise * rng.normal(size = (dim, n_distributions, n_points_per_distribution)) 230 | X += noise 231 | X = X.reshape(dim, -1) 232 | X = X / np.linalg.norm(X, axis=0) 233 | #X += noise * rng.normal(size=(self.dim, n_distributions, n_points_per_distribution)) 234 | 235 | self.X = X.T 236 | self.labels = np.repeat( 237 | np.eye(n_distributions), n_points_per_distribution, axis=0 238 | ) 239 | 240 | # Flipping noise 241 | if flip: 242 | index_to_flip = np.random.randint(n_distributions * n_points_per_distribution, size = n_distributions) 243 | for i in range(n_distributions): 244 | self.labels[index_to_flip[i], i] = 1 - self.labels[index_to_flip[i], i] 245 | self.labels = self.labels / np.sum(self.labels, axis=0) 246 | 247 | 248 | # Ground truth dists (approximate) and clip for numerical errors 249 | self.gtdists = np.arccos(np.clip(self.means @ self.means.T, 0, 1)) 250 | 251 | def get_graph(self): 252 | """ Create a graphtools graph if does not exist 253 | """ 254 | if self.graph is None: 255 | #self.graph = graphtools.Graph(self.X, use_pygsp=True, knn=10) 256 | self.graph = pygsp.graphs.NNGraph( 257 | self.X, epsilon=0.1, NNtype="radius", rescale=False, center=False 258 | ) 259 | #self.graph = graphtools.Graph(self.X, use_pygsp=True, knn=100) 260 | return self.graph 261 | 262 | class Mnist(Dataset): 263 | def __init__(self): 264 | from torchvision.datasets import MNIST 265 | self.mnist_train = MNIST("/home/atong/data/mnist/", download=True) 266 | self.mnist_test = MNIST("/home/atong/data/mnist/", download=True, train=False) 267 | self.graph = pygsp.graphs.Grid2d(28, 28) 268 | 269 | def get_graph(self): 270 | return self.graph 271 | 272 | 273 | -------------------------------------------------------------------------------- /DiffusionEMD/diffusion_emd.py: -------------------------------------------------------------------------------- 1 | """ These functions provide a way to quickly embed a set of distributions over 2 | a graph into vectors where the L_1 distance between these embeded vectors 3 | corresponds to the Wasserstein distance between distributions. 4 | """ 5 | 6 | import numpy as np 7 | import pygsp 8 | import scipy 9 | from scipy.linalg import qr 10 | 11 | # from scipy.linalg.interpolative import interp_decomp 12 | import scipy.sparse 13 | 14 | from . import estimate_utils 15 | 16 | 17 | def estimate_dos(A, pflag=False, npts=1001): 18 | """ Estimate the density of states of the matrix A 19 | 20 | A should be a matrix of with eigenvalues in tha range [-1, 1]. 21 | """ 22 | c = estimate_utils.moments_cheb_dos(A, A.shape[0], N=50)[0] 23 | return estimate_utils.plot_chebint((c,), pflag=pflag, npts=npts) 24 | 25 | 26 | def approximate_rank(A, thresh): 27 | """ Determines the rank relative to a threshold as defined in 28 | https://doi.org/10.1016/j.acha.2012.03.002 29 | $$R_{\delta}(A) = \| \{ \frac{\sigma_j}{\sigma_0} \ge \delta \}$$ 30 | Where $\sigma_j$ denotes the $jth$ largest singular value of the matrix K 31 | TODO: This function currently assumes symmetricish distribution of eigenvalues. 
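    Example (illustrative): for a symmetric operator T with spectrum in
    [-1, 1], approximate_rank(T, 0.5) estimates how many eigenvalues of T
    have magnitude at least 0.5; approximate_rank_of_scales below applies the
    same estimate per diffusion scale, using thresholds thresh ** (1 / scale).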
32 | """ 33 | eig, density = estimate_dos(A) 34 | approx_rank = np.maximum( 35 | np.max(density[np.where(-eig >= (thresh))]), 0 36 | ) + np.maximum(A.shape[0] - np.min(density[np.where(eig >= thresh)]), 0) 37 | 38 | return int(np.ceil(approx_rank)) 39 | 40 | 41 | def interpolative_decomposition(A, k, return_p=False): 42 | assert k < np.min(A.shape) 43 | q, r, perm = qr(A, pivoting=True) 44 | b = q[:, :k] @ r[:k, :k] 45 | if return_p: 46 | p = np.concatenate([np.eye(k), np.linalg.inv(r[:k, :k]) @ r[:k, k:]], axis=1) 47 | return b, p, perm 48 | return b 49 | 50 | 51 | def approximate_rank_of_scales(A, thresh, scales): 52 | """ Returns one rank per scale, note that higher scales have less accuracy 53 | and may need more evaluations. Number of evaluations is currently set 54 | manually. 55 | """ 56 | eig, density = estimate_dos(A) 57 | ranks = [] 58 | for scale in scales: 59 | if scale == 0: 60 | approx_rank = A.shape[0] 61 | else: 62 | approx_rank = np.maximum( 63 | np.max(density[np.where(-eig >= (thresh ** (1 / scale)))]), 0 64 | ) + np.maximum( 65 | A.shape[0] - np.min(density[np.where(eig >= (thresh ** (1 / scale)))]), 0 66 | ) 67 | approx_rank = int(np.ceil(approx_rank)) 68 | ranks.append(approx_rank) 69 | return ranks 70 | 71 | 72 | def apply_anisotropy(K, anisotropy): 73 | if anisotropy == 0: 74 | # do nothing 75 | return K 76 | 77 | if scipy.sparse.issparse(K): 78 | d = np.array(K.sum(1)).flatten() 79 | K = K.tocoo() 80 | K.data = K.data / ((d[K.row] * d[K.col]) ** anisotropy) 81 | K = K.tocsr() 82 | return K 83 | d = K.sum(1) 84 | K = K / (np.outer(d, d) ** anisotropy) 85 | return K 86 | 87 | 88 | def apply_vectors(M, d, d_post=None): 89 | if d_post is None: 90 | d_post = d 91 | if scipy.sparse.issparse(M): 92 | M = M.tocoo() 93 | M.data = M.data * (d[M.row] * d_post[M.col]) 94 | return M.tocsr() 95 | return M / np.outer(d, d_post) 96 | 97 | def apply_left(M, d): 98 | if scipy.sparse.issparse(M): 99 | M = M.tocoo() 100 | M.data = M.data * (d[M.row]) 101 | return M.tocsr() 102 | else: 103 | M 104 | 105 | def apply_right(M, d): 106 | if scipy.sparse.issparse(M): 107 | M = M.tocoo() 108 | M.data = M.data * (d[M.col]) 109 | return M.tocsr() 110 | 111 | 112 | 113 | def adjacency_to_operator(A, anisotropy): 114 | """ Gets the symmetric conjugate of the diffusion operator and its 115 | row/col sums as a vector. 116 | """ 117 | M = apply_anisotropy(A, anisotropy) 118 | D_norm = np.array(M.sum(axis=0)).squeeze() 119 | return M, D_norm 120 | 121 | 122 | def randomized_interpolative_decomposition( 123 | A, k_1, k_2, k_3=5, tol=1e-6, return_p=False 124 | ): 125 | """ Finds the columns of a large matrix that represent the whole matrix 126 | well in terms of rank. This is done by first projecting to k_2 (of order 127 | k_1) dimensions randomly, then doing QR decomposition. This results in a 128 | matrix S that (approximately) consists of a subset of size k_1 columns of 129 | W. To find the indices that S represents we then randomly project columns 130 | down to k_3 elements to quickly test for equality. This projection ensures 131 | the equality test is parallelizable and scales linearly with the size of W. 132 | Note that this equality test could fail for many reasons, including: 133 | (1) repeated columns in W, or columns that are within our tolerance of L_2 134 | distance. 135 | (2) k_3 is too small resulting in false positives (i.e. columns both in the 136 | null space of our projection that are not equal). 
137 | """ 138 | m, n = A.shape 139 | assert k_1 < k_2 140 | assert k_3 < k_2 141 | assert k_2 <= min(m, n) 142 | # if use_sparse: 143 | # J sparse_interpolative_decomposition(A, k_1, return_p=return_p) 144 | # else: 145 | G = np.random.randn(k_2, A.shape[0]) 146 | W = G @ A 147 | S = interpolative_decomposition(W, k_1, return_p=return_p) 148 | if return_p: 149 | S, P, perm = S 150 | indices = [] 151 | R = np.random.randn(k_3, k_2) 152 | # Slow way, implemented more efficiently 153 | # Q = (R @ S)[:, :, None] - (R @ W)[:, None, :] 154 | count = 0 155 | while len(indices) != k_1: 156 | R = np.random.randn(k_3, k_2) 157 | # print(count, tol * (10**-count)) 158 | indices = np.argwhere( 159 | np.linalg.norm((R @ S)[:, :, None] - (R @ W)[:, None, :], axis=0) 160 | < tol * (10 ** -count) 161 | )[:, 1] 162 | count += 1 163 | if count >= 10: 164 | indices = np.argwhere( 165 | np.linalg.norm(S[:, :, None] - W[:, None, :], axis=0) < tol 166 | )[:, 1] 167 | break 168 | print(count) 169 | if len(indices) != k_1: 170 | raise ValueError("Len indices not equal to k_1: %d != %d" % (len(indices), k_1)) 171 | if return_p: 172 | return indices, P, perm 173 | return indices 174 | 175 | 176 | class DiffusionEMD(object): 177 | """ Base class for DiffusionEMD estimators 178 | """ 179 | 180 | def __init__( 181 | self, 182 | max_scale=10, 183 | n_scales=6, 184 | delta=0, 185 | anisotropy=1, 186 | alpha=0.5, 187 | min_basis=0, 188 | max_basis=None, 189 | **kwargs 190 | ): 191 | self.max_scale = max_scale 192 | # Filter does not tolerate scales below zero 193 | self.n_scales = min(n_scales, max_scale + 1) 194 | self.delta = delta 195 | self.anisotropy = anisotropy 196 | self.alpha = alpha 197 | self.min_basis = min_basis 198 | if max_basis is None: 199 | max_basis = np.inf 200 | self.max_basis = max_basis 201 | self.scales = [ 202 | 2 ** i for i in range(max_scale - self.n_scales + 1, max_scale + 1) 203 | ] 204 | assert 0 <= self.anisotropy <= 1 205 | 206 | def transform(self, y): 207 | pass 208 | 209 | def fit(self, X): 210 | self.X = X 211 | self.N = X.shape[0] 212 | self.M = apply_anisotropy(X, self.anisotropy) 213 | self.D = np.array(self.M.sum(axis=0)).squeeze() 214 | self.T = apply_vectors(self.M, self.D ** -0.5) 215 | 216 | def _compute_rank(self): 217 | self.basis_sizes = approximate_rank_of_scales( 218 | self.T, self.delta, scales=self.scales 219 | ) 220 | self.basis_sizes = np.clip(self.basis_sizes, a_min=self.min_basis, a_max=None) 221 | 222 | def fit_transform(self, X, y, **kwargs): 223 | self.fit(X, **kwargs) 224 | return self.transform(y) 225 | 226 | class DiffusionTree(DiffusionEMD): 227 | def __init__( 228 | self, 229 | max_scale=10, 230 | n_scales=1000, 231 | delta=0, 232 | anisotropy=1, 233 | alpha=0.5, 234 | min_basis=0, 235 | max_basis=None, 236 | ): 237 | n_scales = max_scale + 1 238 | super().__init__( 239 | max_scale=max_scale, 240 | n_scales=n_scales, 241 | delta=delta, 242 | anisotropy=anisotropy, 243 | alpha=alpha, 244 | min_basis=min_basis, 245 | max_basis=max_basis, 246 | ) 247 | 248 | def fit(self, X): 249 | super().fit(X) 250 | self.T = apply_vectors(self.M, self.D ** -0.5) 251 | self._compute_rank() 252 | self._compute_diff_op() 253 | 254 | def _compute_diff_op(self): 255 | self.Ts = [self.T] 256 | self.Ps = [None] 257 | self.bases = [np.arange(self.N)] 258 | self.perms = [None] 259 | for j, arank in enumerate(self.basis_sizes[1:]): 260 | Tj = self.Ts[j] 261 | N = Tj.shape[0] 262 | # If arank is not significantly smaller, don't bother shrinking basis 263 | if arank < min(N * 0.5, 
self.max_basis): 264 | basis, P, perm = randomized_interpolative_decomposition(Tj, arank, min(arank + 8, N), return_p = True) 265 | Tp1 = Tj[basis] 266 | else: 267 | P = None 268 | basis = np.arange(N) 269 | Tp1 = Tj 270 | perm = None 271 | self.perms.append(perm) 272 | self.Ts.append(Tp1 @ Tp1.transpose()) 273 | self.Ps.append(P) 274 | self.bases.append(basis) 275 | 276 | def transform(self, y): 277 | dist_at_scale = y 278 | embeddings = [] 279 | n_scales = len(self.scales) 280 | prev_diffusion = None 281 | for i, s in enumerate(self.scales): 282 | T = self.Ts[i] 283 | P = self.Ps[i] 284 | perm = self.perms[i] 285 | if P is not None: 286 | dist_at_scale = ( 287 | P 288 | @ estimate_utils.permutation_vector_to_matrix(perm) 289 | @ dist_at_scale 290 | ) 291 | diffusion_at_scale = T @ dist_at_scale 292 | if P is not None: 293 | tmp = P.T @ diffusion_at_scale 294 | else: 295 | tmp = diffusion_at_scale 296 | if i > 0: 297 | weight = 0.5 ** ((n_scales - i) * self.alpha) * ( 298 | self.N / diffusion_at_scale.shape[0] 299 | ) 300 | lvl_embed = weight * (tmp - prev_diffusion).T 301 | embeddings.append(lvl_embed) 302 | prev_diffusion = diffusion_at_scale 303 | embeddings.append(tmp.T) 304 | embeddings = np.concatenate(embeddings, axis=1) 305 | self.embeddings = embeddings 306 | return self.embeddings 307 | 308 | class DiffusionTreeV2(DiffusionEMD): 309 | def __init__( 310 | self, 311 | max_scale=10, 312 | n_scales=6, 313 | delta=0, 314 | anisotropy=1, 315 | alpha=0.5, 316 | min_basis=0, 317 | max_basis=None, 318 | ): 319 | n_scales = max_scale + 1 320 | super().__init__( 321 | max_scale=max_scale, 322 | n_scales=n_scales, 323 | delta=delta, 324 | anisotropy=anisotropy, 325 | alpha=alpha, 326 | min_basis=min_basis, 327 | max_basis=max_basis, 328 | ) 329 | 330 | def fit(self, X): 331 | super().fit(X) 332 | self.T = apply_vectors(self.M, self.D ** -0.5) 333 | self._compute_rank() 334 | self._compute_diff_op() 335 | 336 | def _compute_diff_op(self): 337 | self.Ts = [self.T] 338 | self.Ps = [None] 339 | self.bases = [np.arange(self.N)] 340 | for j, arank in enumerate(self.basis_sizes[1:]): 341 | Tj = self.Ts[j] 342 | N = Tj.shape[0] 343 | # If arank is not significantly smaller, don't bother shrinking basis 344 | if arank < min(N * 0.5, self.max_basis): 345 | from scipy.linalg.interpolative import interp_decomp 346 | basis, P = interp_decomp(Tj, arank) 347 | Tp1 = Tj[basis[arank:]] 348 | else: 349 | P = None 350 | basis = np.arange(N) 351 | Tp1 = Tj 352 | self.Ts.append(Tp1 @ Tp1.transpose()) 353 | self.Ps.append(P) 354 | self.bases.append(basis) 355 | 356 | def transform(self, y): 357 | dist_at_scale = y 358 | embeddings = [] 359 | n_scales = len(self.scales) 360 | prev_diffusion = None 361 | for i, s in enumerate(self.scales): 362 | T = self.Ts[i] 363 | P = self.Ps[i] 364 | diffusion_at_scale = T @ dist_at_scale 365 | tmp = diffusion_at_scale 366 | if i > 0: 367 | weight = 0.5 ** ((n_scales - i) * self.alpha) * ( 368 | self.N / diffusion_at_scale.shape[0] 369 | ) 370 | lvl_embed = weight * (tmp - prev_diffusion).T 371 | embeddings.append(lvl_embed) 372 | prev_diffusion = diffusion_at_scale 373 | embeddings.append(tmp.T) 374 | embeddings = np.concatenate(embeddings, axis=1) 375 | self.embeddings = embeddings 376 | return self.embeddings 377 | 378 | class DiffusionCheb(DiffusionEMD): 379 | def __init__( 380 | self, 381 | max_scale=10, 382 | n_scales=6, 383 | delta=0, 384 | anisotropy=1, 385 | alpha=0.5, 386 | min_basis=0, 387 | max_basis=None, 388 | method="chebyshev", 389 | use_diff_wavelets=True, 390 
| cheb_order=32, 391 | ): 392 | self.method = method 393 | self.use_diff_wavelets = use_diff_wavelets 394 | self.cheb_order = cheb_order 395 | super().__init__( 396 | max_scale=max_scale, 397 | n_scales=n_scales, 398 | delta=delta, 399 | anisotropy=anisotropy, 400 | alpha=alpha, 401 | min_basis=min_basis, 402 | max_basis=max_basis, 403 | ) 404 | 405 | def fit(self, X): 406 | super().fit(X) 407 | graph = pygsp.graphs.Graph(self.M) 408 | # Use the normalized laplacian here for eigenvalues in [0, 2] 409 | graph.compute_laplacian("normalized") 410 | if self.method == "exact": 411 | graph.compute_fourier_basis() 412 | else: 413 | graph.estimate_lmax() 414 | 415 | kernels = [lambda x, s=s: np.minimum((1 - x) ** s, 1) for s in self.scales] 416 | self.filter = pygsp.filters.Filter(graph, kernels) 417 | 418 | def _subsample_embeddings(self, embeddings): 419 | # TODO make this work on a concatenated set of embeddings 420 | self.selections = [ 421 | randomized_interpolative_decomposition(self.M, rank, min(rank + 8, self.N)) 422 | if rank < self.max_basis 423 | else np.random.randint(self.N, size=rank) 424 | for rank in self.basis_sizes 425 | ] 426 | embeddings = [ 427 | e[:, s] * self.M.shape[0] / a 428 | for s, e, a in zip(self.selections, embeddings, self.basis_sizes) 429 | ] 430 | return embeddings 431 | 432 | def transform(self, y): 433 | D_labels = (self.D[:, None] ** .5) * y 434 | #D_labels = (self.D[:, None] ** -0.5) * y 435 | diffusions = self.filter.filter( 436 | D_labels, method=self.method, order=self.cheb_order 437 | ) 438 | diffusions = (self.D ** -0.5)[:, None, None] * diffusions 439 | #diffusions = (self.D ** 0.5)[:, None, None] * diffusions 440 | n, n_samples, n_scales = diffusions.shape 441 | embeddings = [] 442 | for k in range(n_scales): 443 | d = diffusions[..., k] 444 | if self.use_diff_wavelets: 445 | # Corresponds to Dual norm version (1) in Leeb and Coifman 2016 446 | if k < n_scales - 1: 447 | d -= diffusions[..., k + 1] 448 | weight = 0.5 ** ((n_scales - k - 1) * self.alpha) 449 | else: 450 | # Corresponds to Dual norm version (2) in Leeb and Coifman 2016 451 | if k < n_scales - 1: 452 | d -= diffusions[..., -1] 453 | weight = 0.5 ** ((n_scales - k - 1) * self.alpha) 454 | lvl_embed = weight * d.T 455 | embeddings.append(lvl_embed) 456 | if self.delta > 0: 457 | self._compute_rank() 458 | embeddings = self._subsample_embeddings(embeddings) 459 | else: 460 | self.basis_sizes = [n_samples] * n_scales 461 | self.embeddings = np.concatenate(embeddings, axis=1) 462 | return self.embeddings 463 | 464 | class DiffusionExact(DiffusionEMD): 465 | def __init__( 466 | self, 467 | max_scale=10, 468 | n_scales=6, 469 | delta=0, 470 | anisotropy=1, 471 | alpha=0.5, 472 | min_basis=0, 473 | max_basis=None, 474 | no_diff = False, 475 | use_diff_wavelets=False, 476 | ): 477 | self.use_diff_wavelets = use_diff_wavelets 478 | self.no_diff = no_diff 479 | super().__init__( 480 | max_scale=max_scale, 481 | n_scales=n_scales, 482 | delta=delta, 483 | anisotropy=anisotropy, 484 | alpha=alpha, 485 | min_basis=min_basis, 486 | max_basis=max_basis, 487 | ) 488 | # Always include the zeroth scale for the exact computation 489 | self.scales = [ 490 | 0, *[2 ** i for i in range(max_scale - self.n_scales+1, max_scale + 1) 491 | ]] 492 | 493 | def fit(self, X): 494 | super().fit(X + scipy.sparse.eye(X.shape[0])) 495 | #self.T = self.T.todense() 496 | 497 | # compute basis 498 | #if delta > 0: 499 | # self._compute_rank() 500 | # self._subsample_basis() 501 | 502 | 503 | def _subsample_basis(self): 504 | # 
TODO make this work on a concatenated set of embeddings 505 | self.selections = [ 506 | randomized_interpolative_decomposition(self.M, rank, min(rank + 8, self.N)) 507 | if rank < min(self.max_basis, self.N) 508 | else np.random.randint(self.N, size=rank) 509 | for rank in self.basis_sizes 510 | ] 511 | 512 | 513 | def transform(self, y): 514 | print(self.D[:, None].shape, y.shape) 515 | print(type(self.D), type(y)) 516 | D_labels = (self.D[:, None] ** -0.5) * y 517 | diffusions = [D_labels] 518 | tmp = D_labels 519 | print(self.scales) 520 | for scale in range(1, max(self.scales)+1): 521 | tmp = self.T @ tmp 522 | if scale in self.scales: 523 | diffusions.append(tmp) 524 | diffusions = np.stack(diffusions, axis=-1) 525 | diffusions = (self.D ** 0.5)[:, None, None] * diffusions 526 | n, n_samples, n_scales = diffusions.shape 527 | embeddings = [] 528 | for k in range(n_scales): 529 | d = diffusions[..., k] 530 | weight = 0.5 ** ((n_scales - k - 1) * self.alpha) 531 | if not self.no_diff: 532 | if self.use_diff_wavelets: 533 | # Corresponds to Dual norm version (1) in Leeb and Coifman 2016 534 | if k < n_scales - 1: 535 | d -= diffusions[..., k + 1] 536 | else: 537 | # Corresponds to Dual norm version (2) in Leeb and Coifman 2016 538 | if k < n_scales - 1: 539 | d -= diffusions[..., -1] 540 | lvl_embed = weight * d.T 541 | embeddings.append(lvl_embed) 542 | if self.delta > 0: 543 | self._compute_rank() 544 | embeddings = self._subsample_embeddings(embeddings) 545 | else: 546 | self.basis_sizes = [n_samples] * n_scales 547 | self.embeddings = np.concatenate(embeddings, axis=1) 548 | return self.embeddings 549 | -------------------------------------------------------------------------------- /DiffusionEMD/emd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ot as pot # Python Optimal Transport package 3 | import scipy.sparse 4 | from sklearn.metrics.pairwise import pairwise_distances 5 | 6 | 7 | def sinkhorn( 8 | p, q, metric="euclidean", 9 | ): 10 | """ 11 | Returns the earth mover's distance between two point clouds 12 | Parameters 13 | ---------- 14 | cloud1 : 2-D array 15 | First point cloud 16 | cloud2 : 2-D array 17 | Second point cloud 18 | Returns 19 | ------- 20 | distance : float 21 | The distance between the two point clouds 22 | """ 23 | p_weights = np.ones(len(p)) / len(p) 24 | q_weights = np.ones(len(q)) / len(q) 25 | 26 | pairwise_dist = np.ascontiguousarray( 27 | pairwise_distances(p, Y=q, metric=metric, n_jobs=-1) 28 | ) 29 | 30 | result = pot.sinkhorn2( 31 | p_weights, 32 | q_weights, 33 | pairwise_dist, 34 | reg=0.05, 35 | numItermax=100, 36 | return_matrix=False, 37 | ) 38 | return np.sqrt(result) 39 | 40 | 41 | def exact( 42 | p, q, metric="euclidean", 43 | ): 44 | """ 45 | Returns the earth mover's distance between two point clouds 46 | Parameters 47 | ---------- 48 | cloud1 : 2-D array 49 | First point cloud 50 | cloud2 : 2-D array 51 | Second point cloud 52 | Returns 53 | ------- 54 | distance : float 55 | The distance between the two point clouds 56 | """ 57 | p_weights = np.ones(len(p)) / len(p) 58 | q_weights = np.ones(len(q)) / len(q) 59 | pairwise_dist = np.ascontiguousarray( 60 | pairwise_distances(p, Y=q, metric=metric, n_jobs=-1) 61 | ) 62 | result = pot.lp.emd2( 63 | p_weights, q_weights, pairwise_dist, numItermax=1e7, return_matrix=False 64 | ) 65 | return result 66 | 67 | 68 | def interpolate_with_ot(p0, p1, tmap, interp_frac, size): 69 | """ 70 | Interpolate between p0 and p1 at fraction 
t_interpolate knowing a transport 71 | map from p0 to p1. 72 | 73 | Parameters 74 | ---------- 75 | p0 : 2-D array 76 | The genes of each cell in the source population 77 | p1 : 2-D array 78 | The genes of each cell in the destination population 79 | tmap : 2-D array 80 | A transport map from p0 to p1 81 | t_interpolate : float 82 | The fraction at which to interpolate 83 | size : int 84 | The number of cells in the interpolated population 85 | Returns 86 | ------- 87 | p05 : 2-D array 88 | An interpolated population of 'size' cells 89 | """ 90 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0 91 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1 92 | p0 = np.asarray(p0, dtype=np.float64) 93 | p1 = np.asarray(p1, dtype=np.float64) 94 | tmap = np.asarray(tmap, dtype=np.float64) 95 | if p0.shape[1] != p1.shape[1]: 96 | raise ValueError("Unable to interpolate. Number of genes do not match") 97 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]: 98 | raise ValueError( 99 | "Unable to interpolate. Tmap size is {}, expected {}".format( 100 | tmap.shape, (len(p0), len(p1)) 101 | ) 102 | ) 103 | # Assume growth is exponential and retrieve growth rate at t_interpolate 104 | # If all sums are the same then this does not change anything 105 | # This only matters if sum is not the same for all rows 106 | p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac) 107 | p = p.flatten(order="C") 108 | p = p / p.sum() 109 | choices = np.random.choice(len(p0) * len(p1), p=p, size=size) 110 | return np.asarray( 111 | [ 112 | p0[i // len(p1)] * (1 - interp_frac) + p1[i % len(p1)] * interp_frac 113 | for i in choices 114 | ], 115 | dtype=np.float64, 116 | ) 117 | 118 | 119 | def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac): 120 | """ 121 | Interpolate between p0 and p1 at fraction t_interpolate knowing a transport 122 | map from p0 to p1. 123 | Parameters 124 | ---------- 125 | p0 : 2-D array 126 | The genes of each cell in the source population 127 | p1 : 2-D array 128 | The genes of each cell in the destination population 129 | tmap : 2-D array 130 | A transport map from p0 to p1 131 | t_interpolate : float 132 | The fraction at which to interpolate 133 | Returns 134 | ------- 135 | p05 : 2-D array 136 | An interpolated population of 'size' cells 137 | """ 138 | assert len(p0) == len(p1) 139 | p0 = p0.toarray() if scipy.sparse.isspmatrix(p0) else p0 140 | p1 = p1.toarray() if scipy.sparse.isspmatrix(p1) else p1 141 | p0 = np.asarray(p0, dtype=np.float64) 142 | p1 = np.asarray(p1, dtype=np.float64) 143 | tmap = np.asarray(tmap, dtype=np.float64) 144 | if p0.shape[1] != p1.shape[1]: 145 | raise ValueError("Unable to interpolate. Number of genes do not match") 146 | if p0.shape[0] != tmap.shape[0] or p1.shape[0] != tmap.shape[1]: 147 | raise ValueError( 148 | "Unable to interpolate. 
Tmap size is {}, expected {}".format( 149 | tmap.shape, (len(p0), len(p1)) 150 | ) 151 | ) 152 | 153 | # Assume growth is exponential and retrieve growth rate at t_interpolate 154 | # If all sums are the same then this does not change anything 155 | # This only matters if sum is not the same for all rows 156 | p = tmap / (tmap.sum(axis=0) / 1.0 - interp_frac) 157 | # p = tmap / np.power(tmap.sum(axis=0), 1.0 - interp_frac) 158 | # p = p.flatten(order="C") 159 | p = p / p.sum(axis=0) 160 | choices = np.array([np.random.choice(len(p0), p=p[i]) for i in range(len(p0))]) 161 | return np.asarray( 162 | [ 163 | p0[i] * (1 - interp_frac) + p1[j] * interp_frac 164 | for i, j in enumerate(choices) 165 | ], 166 | dtype=np.float64, 167 | ) 168 | -------------------------------------------------------------------------------- /DiffusionEMD/estimate_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapated from Vertex frequency codebase. Credit to Gabriel Dolsten. 3 | Algorithms based on https://arxiv.org/pdf/1905.09758.pdf 4 | Goal is to estimate the density of eigenvalues over a known range. 5 | """ 6 | 7 | import numpy as np 8 | import scipy.sparse as ss 9 | import numpy.random as nr 10 | import matplotlib.pyplot as plt 11 | import pygsp 12 | import ot 13 | 14 | 15 | def moments_cheb_dos(A, n, nZ=100, N=10, kind=1): 16 | """ 17 | Compute a column vector of Chebyshev moments of the form c(k) = tr(T_k(A)) 18 | for k = 0 to N-1. This routine does no scaling; the spectrum of A should 19 | already lie in [-1,1]. The traces are computed via a stochastic estimator 20 | with nZ probe 21 | 22 | Args: 23 | A: Matrix or function apply matrix (to multiple RHS) 24 | n: Dimension of the space 25 | nZ: Number of probe vectors with which we compute moments 26 | N: Number of moments to compute 27 | kind: 1 or 2 for first or second kind Chebyshev functions 28 | (default = 1) 29 | 30 | Output: 31 | c: a column vector of N moment estimates 32 | cs: standard deviation of the moment estimator 33 | (std/sqrt(nZ)) 34 | """ 35 | 36 | # Create a function handle if given a matrix 37 | if callable(A): 38 | Afun = A 39 | else: 40 | if isinstance(A, np.ndarray): 41 | A = ss.csr_matrix(A) 42 | 43 | def Afun(x): 44 | return A * x 45 | 46 | if N < 2: 47 | N = 2 48 | 49 | # Set up random probe vectors (allowed to be passed in) 50 | if not isinstance(nZ, int): 51 | Z = nZ 52 | nZ = Z.shape[1] 53 | else: 54 | Z = np.sign(nr.randn(n, nZ)) 55 | 56 | # Estimate moments for each probe vector 57 | cZ = moments_cheb(Afun, Z, N, kind) 58 | c = np.mean(cZ, 1) 59 | cs = np.std(cZ, 1, ddof=1) / np.sqrt(nZ) 60 | 61 | c = c.reshape([N, -1]) 62 | cs = cs.reshape([N, -1]) 63 | return c, cs 64 | 65 | 66 | def moments_cheb(A, V, N=10, kind=1): 67 | """ 68 | Compute a column vector of Chebyshev moments of the form c(k) = v'*T_k(A)*v 69 | for k = 0 to N-1. 
This routine does no scaling; the spectrum of A should 70 | already lie in [-1,1] 71 | 72 | Args: 73 | A: Matrix or function apply matrix (to multiple RHS) 74 | V: Starting vectors 75 | N: Number of moments to compute 76 | kind: 1 or 2 for first or second kind Chebyshev functions 77 | (default = 1) 78 | 79 | Output: 80 | c: a length N vector of moments 81 | """ 82 | 83 | if N < 2: 84 | N = 2 85 | 86 | if not isinstance(V, np.ndarray): 87 | V = V.toarray() 88 | 89 | # Create a function handle if given a matrix 90 | if callable(A): 91 | Afun = A 92 | else: 93 | if isinstance(A, np.ndarray): 94 | A = ss.csr_matrix(A) 95 | 96 | def Afun(x): 97 | return A * x 98 | 99 | n, p = V.shape 100 | c = np.zeros((N, p)) 101 | 102 | # Run three-term recurrence to compute moments 103 | TVp = V # x 104 | TVk = kind * Afun(V) # Ax 105 | c[0] = np.sum(V * TVp, 0) # xx 106 | c[1] = np.sum(V * TVk, 0) # xAx 107 | for i in range(2, N): 108 | TV = 2 * Afun(TVk) - TVp # A*2T_1 - T_o 109 | TVp = TVk 110 | TVk = TV 111 | c[i] = sum(V * TVk, 0) 112 | return c 113 | 114 | 115 | def plot_cheb_argparse(npts, c, xx0=-1, ab=np.array([1, 0])): 116 | """ 117 | Handle argument parsing for plotting routines. Should not be called directly 118 | by users. 119 | 120 | Args: 121 | npts: Number of points in a default mesh 122 | c: Vector of moments 123 | xx0: Input sampling mesh (original coordinates) 124 | ab: Scaling map parameters 125 | 126 | Output: 127 | c: Vector of moments 128 | xx: Input sampling mesh ([-1,1] coordinates) 129 | xx0: Input sampling mesh (original coordinates) 130 | ab: Scaling map parameters 131 | """ 132 | 133 | if isinstance(xx0, int): 134 | # only c is given 135 | xx0 = np.linspace(-1 + 1e-8, 1 - 1e-8, npts) 136 | xx = xx0 137 | else: 138 | if len(xx0) == 2: 139 | # parameters are c, ab 140 | ab = xx0 141 | xx = np.linspace(-1 + 1e-8, 1 - 1e-8, npts) 142 | xx0 = ab[0] * xx + ab[1] 143 | else: 144 | # parameteres are c, xx0 145 | xx = xx0 146 | 147 | # All parameters specified 148 | if not (ab == [1, 0]).all(): 149 | xx = (xx0 - ab[1]) / ab[0] 150 | 151 | return c, xx, xx0, ab 152 | 153 | 154 | def plot_chebint(varargin, npts=1001, pflag=True): 155 | """ 156 | Given a (filtered) set of first-kind Chebyshev moments, compute the integral 157 | of the density: 158 | int_0^s (2/pi)*sqrt(1-x^2)*( c(0)/2+sum_{n=1}^{N-1}c_nT_n(x) ) 159 | Output a plot of cumulative density function by default. 160 | 161 | Args: 162 | c: Array of Chebyshev moments (on [-1,1]) 163 | xx: Evaluation points (defaults to mesh of 1001 pts) 164 | ab: Mapping parameters (default to identity) 165 | pflag: Option to output the plot 166 | 167 | Output: 168 | yy: Estimated cumulative density up to each xx point 169 | """ 170 | 171 | # Parse arguments 172 | c, xx, xx0, ab = plot_cheb_argparse(npts, *varargin) 173 | 174 | N = len(c) 175 | txx = np.arccos(xx) 176 | yy = c[0] * (txx - np.pi) / 2 177 | for idx in np.arange(1, N): 178 | yy += c[idx] * np.sin(idx * txx) / idx 179 | 180 | yy *= -2 / np.pi 181 | 182 | # Plot by default 183 | if pflag: 184 | plt.plot(xx0, yy) 185 | # plt.ion() 186 | plt.show() 187 | # plt.pause(1) 188 | # plt.clf() 189 | 190 | return [xx0, yy] 191 | 192 | 193 | def plot_chebhist(varargin, pflag=True, npts=21): 194 | """ 195 | Given a (filtered) set of first-kind Chebyshev moments, compute the integral 196 | of the density: 197 | int_0^s (2/pi)*sqrt(1-x^2)*( c(0)/2+sum_{n=1}^{N-1}c_nT_n(x) ) 198 | Output a histogram of cumulative density function by default. 
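    (The counts are obtained by evaluating the cumulative density from
    plot_chebint on the xx mesh and differencing adjacent values.)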
199 | 200 | Args: 201 | c: Vector of Chebyshev moments (on [-1,1]) 202 | xx: Evaluation points (defaults to mesh of 21 pts) 203 | ab: Mapping parameters (default to identity) 204 | pflag: Option to output the plot 205 | 206 | Output: 207 | yy: Estimated counts on buckets between xx points 208 | """ 209 | 210 | # Parse arguments 211 | c, xx, xx0, ab = plot_cheb_argparse(npts, *varargin) 212 | 213 | # Compute CDF and bin the difference 214 | yy = plot_chebint((c, xx0, ab), pflag=False) 215 | yy = yy[1:] - yy[:-1] 216 | xm = (xx0[1:] + xx0[:-1]) / 2 217 | 218 | # Plot by default 219 | if pflag: 220 | plt.bar(xm + 1, yy, align="center", width=0.1) 221 | # plt.ion() 222 | plt.show() 223 | # plt.pause(1) 224 | # plt.clf() 225 | 226 | return [xm + 1, yy] 227 | 228 | 229 | def matrix_normalize(W, mode="s"): 230 | """ 231 | Normalize an adjacency matrix. 232 | 233 | Args: 234 | W: weighted adjacency matrix 235 | mode: string indicating the style of normalization; 236 | 's': Symmetric scaling by the degree (default) 237 | 'r': Normalize to row-stochastic 238 | 'c': Normalize to col-stochastic 239 | 240 | Output: 241 | N: a normalized adjacency matrix or stochastic matrix (in sparse form) 242 | """ 243 | 244 | dc = np.asarray(W.sum(0)).squeeze() 245 | dr = np.asarray(W.sum(1)).squeeze() 246 | [i, j, wij] = ss.find(W) 247 | 248 | # Normalize in desired style 249 | if mode in "sl": 250 | wij = wij / np.sqrt(dr[i] * dc[j]) 251 | elif mode == "r": 252 | wij = wij / dr[i] 253 | elif mode == "c": 254 | wij = wij / dc[j] 255 | else: 256 | raise ValueError("Unknown mode!") 257 | 258 | N = ss.csr_matrix((wij, (i, j)), shape=W.shape) 259 | return N 260 | 261 | 262 | def simple_diffusion_embeddings(graph, distribution_labels, subsample=False, scales=7): 263 | """ The plain version, without any frills. 264 | Return the vectors whose L1 distances are the EMD between the given distributions. 265 | The graph supplied (a PyGSP graph) should encompass both distributions. 266 | The distributions themselves should be one-hot encoded with the 267 | distribution_labels parameter. 
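    A minimal usage sketch (names are illustrative): with a pygsp graph G
    built over all points and a one-hot label matrix of shape
    (n_points, n_distributions), emb = simple_diffusion_embeddings(G, labels)
    returns one row per distribution, and np.sum(np.abs(emb[i] - emb[j]))
    (or l1_distance_matrix(emb) below) approximates the EMD between
    distributions i and j.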
268 | """ 269 | heat_filter = pygsp.filters.Heat( 270 | graph, tau=[2 ** i for i in range(1, scales + 1)], normalize=False 271 | ) 272 | diffusions = heat_filter.filter(distribution_labels, method="chebyshev", order=32) 273 | print(diffusions.shape) 274 | if subsample: 275 | rng = np.random.default_rng(42) 276 | if len(diffusions.shape) == 2: 277 | n_samples = 1 278 | n, n_scales = diffusions.shape 279 | else: 280 | n, n_samples, n_scales = diffusions.shape 281 | embeddings = [] 282 | for i in range(n_scales): 283 | d = diffusions[..., i] 284 | weight = 0.5 ** (n_scales - i) 285 | if subsample: 286 | subsample_idx = rng.integers(n, size=n // 10) 287 | lvl_embed = weight * d[subsample_idx].T 288 | else: 289 | lvl_embed = weight * d.T 290 | embeddings.append(lvl_embed) 291 | if len(diffusions.shape) == 2: 292 | embeddings = np.concatenate(embeddings) 293 | else: 294 | embeddings = np.concatenate(embeddings, axis=1) 295 | return embeddings 296 | 297 | 298 | def l1_distance_matrix(embeddings): 299 | """ 300 | Gives a square distance matrix with the L1 distances between the provided embeddings 301 | """ 302 | D = np.zeros((len(embeddings), len(embeddings))) 303 | for i, embed1 in enumerate(embeddings): 304 | for j, embed2 in enumerate(embeddings): 305 | D[i][j] = np.sum(np.abs(embed1 - embed2)) 306 | D[j][i] = D[i][j] 307 | return D 308 | 309 | 310 | def exact_ot(signals, dists): 311 | D = np.zeros((len(signals), len(signals))) 312 | for i, sig1 in enumerate(signals): 313 | for j, sig2 in enumerate(signals): 314 | sig1 = sig1.copy(order="C") 315 | sig2 = sig2.copy(order="C") 316 | dists = dists.copy(order="C") 317 | D[i][j] = ot.emd2(sig1, sig2, dists, processes=-2) 318 | D[j][i] = D[i][j] 319 | return D 320 | 321 | 322 | def permutation_vector_to_matrix(E): 323 | """Convert a permutation vector E (list or rank-1 array, length n) to a 324 | permutation matrix (n by n). The result is returned as a 325 | scipy.sparse.coo_matrix, where the entries at (E[k], k) are 1. 326 | """ 327 | n = len(E) 328 | j = np.arange(n) 329 | return ss.coo_matrix((np.ones(n), (E, j)), shape=(n, n)) 330 | -------------------------------------------------------------------------------- /DiffusionEMD/metric_tree.py: -------------------------------------------------------------------------------- 1 | """ metric_tree.py 2 | This file uses sklearn trees generally used for KNN calculation as an 3 | approximate metric tree for wasserstein distance. Further extensions are 4 | quadtree, and one based on hierarchical clustering. The idea is to use the 5 | tree with edge lengths as the (L2) distance between means. The distance 6 | between any two points embedded in this tree is then the geodesic distance 7 | along the tree. Note that this is an offline algorithm, we do not support 8 | adding points after the initial construction. 
9 | """ 10 | import numpy as np 11 | from sklearn.base import BaseEstimator 12 | from sklearn.utils.validation import check_X_y, check_is_fitted 13 | from sklearn.neighbors import KDTree, BallTree 14 | from sklearn.metrics import DistanceMetric 15 | from sklearn.cluster import MiniBatchKMeans 16 | from scipy.sparse import coo_matrix 17 | 18 | 19 | class QuadTree(object): 20 | """ 21 | This quadtree could be sped up, but is an easy implementation 22 | """ 23 | 24 | def __init__(self, X, n_levels=25, noise=1.0, *args, **kwargs): 25 | assert np.all(np.min(X, axis=0) >= 0) 26 | assert np.all(np.max(X, axis=0) <= 1) 27 | assert n_levels >= 1 28 | self.kwargs = kwargs 29 | self.X = X 30 | self.noise = noise 31 | # self.X = self.X + np.random.randn(*self.X.shape) * noise 32 | self.dims = X.shape[1] 33 | self.n_clusters = 2 ** self.dims 34 | self.n_levels = n_levels 35 | center = np.random.rand(self.dims) * noise 36 | self.tree, self.indices, self.centers, self.dists = self._cluster( 37 | center, np.arange(X.shape[0]), n_levels=self.n_levels - 1, start=0 38 | ) 39 | self.tree = [(0, self.X.shape[0], n_levels, 0), *self.tree] 40 | self.dists = np.array([0, *self.dists]) 41 | self.centers = [center, *self.centers] 42 | self.centers = np.array(self.centers) 43 | 44 | def _cluster(self, center, index, n_levels, start): 45 | """ 46 | Parameters 47 | ---------- 48 | 49 | bounds: 50 | [2 x D] matrix giving min / max of bounding box for this cluster 51 | 52 | """ 53 | if n_levels == 0 or len(index) == 0: 54 | return None 55 | labels = np.ones_like(index) * -1 56 | dim_masks = np.array([self.X[index, d] > center[d] for d in range(self.dims)]) 57 | import itertools 58 | 59 | bin_masks = np.array(list(itertools.product([False, True], repeat=self.dims))) 60 | label_masks = np.all(bin_masks[..., None] == dim_masks[None, ...], axis=1) 61 | for i, mask in enumerate(label_masks): 62 | labels[mask] = i 63 | assert np.all(labels > -1) 64 | shift = 2 ** -(self.n_levels - n_levels + 2) 65 | shifts = np.array(list(itertools.product([-shift, shift], repeat=self.dims))) 66 | cluster_centers = shifts + center 67 | sorted_index = [] 68 | children = [] 69 | ccenters = [] 70 | cdists = [] 71 | is_leaf = [0] * self.n_clusters 72 | unique, ucounts = np.unique(labels, return_counts=True) 73 | counts = np.zeros(self.n_clusters, dtype=np.int32) 74 | for u, c in zip(unique, ucounts): 75 | counts[u] = c 76 | cstart = 0 77 | for i, count, ccenter in zip(unique, counts, cluster_centers): 78 | ret = self._cluster( 79 | ccenter, index[labels == i], n_levels - 1, start + cstart 80 | ) 81 | if ret is None: 82 | sorted_index.extend(index[labels == i]) 83 | is_leaf[i] = 1 84 | continue 85 | sorted_index.extend(ret[1]) 86 | children.extend(ret[0]) 87 | ccenters.extend(ret[2]) 88 | cdists.extend(ret[3]) 89 | cstart += count 90 | 91 | to_return = list( 92 | zip( 93 | *[ 94 | np.array([0, *np.cumsum(counts)]) + start, 95 | np.cumsum(counts) + start, 96 | [n_levels] * self.n_clusters, 97 | is_leaf, 98 | ] 99 | ) 100 | ) 101 | dists = np.linalg.norm(cluster_centers - center[None, :], axis=1) 102 | return ( 103 | [*to_return, *children], 104 | sorted_index, 105 | [*cluster_centers, *ccenters], 106 | [*dists, *cdists], 107 | ) 108 | 109 | def get_arrays(self): 110 | return None, self.indices, self.tree, self.centers, self.dists 111 | 112 | 113 | class ClusterTree(object): 114 | def __init__(self, X, n_clusters=10, n_levels=5, *args, **kwargs): 115 | self.X = X 116 | self.n_clusters = n_clusters 117 | self.n_levels = n_levels 118 | center = 
self.X.mean(axis=0) 119 | self.tree, self.indices, self.centers, self.dists = self._cluster( 120 | center, np.arange(X.shape[0]), n_levels=self.n_levels - 1, start=0 121 | ) 122 | self.tree = [(0, self.X.shape[0], n_levels, n_levels == 1), *self.tree] 123 | self.centers = [center, *self.centers] 124 | self.dists = np.array([0, *self.dists]) 125 | self.centers = np.array(self.centers) 126 | 127 | def _cluster(self, center, index, n_levels, start): 128 | """ 129 | Recursively clusters the points in `index` with MiniBatchKMeans. 130 | Returns a list of tuples, one per subnode, of the form (start, end, level, is_leaf), 131 | where level is the number of levels remaining below the node, followed by the sorted 132 | point index (each node's points occupy the contiguous slice [start, end)), the cluster 133 | centers, and the distance from each cluster center to its parent's center. 134 | """ 135 | if n_levels == 0 or len(index) < self.n_clusters: 136 | return None 137 | cl = MiniBatchKMeans(n_clusters=self.n_clusters) 138 | cl.fit(self.X[index]) 139 | sorted_index = [] 140 | children = [] 141 | ccenters = [] 142 | cdists = [] 143 | is_leaf = [0] * self.n_clusters 144 | unique, ucounts = np.unique(cl.labels_, return_counts=True) 145 | counts = np.zeros(self.n_clusters, dtype=np.int32) 146 | for u, c in zip(unique, ucounts): 147 | counts[u] = c 148 | cstart = 0 149 | for i, count in zip(unique, counts): 150 | ret = self._cluster( 151 | cl.cluster_centers_[i], 152 | index[cl.labels_ == i], 153 | n_levels - 1, 154 | start + cstart, 155 | ) 156 | if ret is None: 157 | sorted_index.extend(index[cl.labels_ == i]) 158 | is_leaf[i] = 1 159 | continue 160 | sorted_index.extend(ret[1]) 161 | children.extend(ret[0]) 162 | ccenters.extend(ret[2]) 163 | cdists.extend(ret[3]) 164 | cstart += count 165 | to_return = list( 166 | zip( 167 | *[ 168 | np.array([0, *np.cumsum(counts)]) + start, 169 | np.cumsum(counts) + start, 170 | [n_levels] * self.n_clusters, 171 | is_leaf, 172 | ] 173 | ) 174 | ) 175 | dists = np.linalg.norm(cl.cluster_centers_ - center[None, :], axis=1) 176 | return ( 177 | [*to_return, *children], 178 | sorted_index, 179 | [*cl.cluster_centers_, *ccenters], 180 | [*dists, *cdists], 181 | ) 182 | 183 | def get_arrays(self): 184 | return None, self.indices, self.tree, self.centers, self.dists 185 | 186 | 187 | class MetricTree(BaseEstimator): 188 | def __init__(self, tree_type="ball", leaf_size=40, metric="euclidean", **kwargs): 189 | self.tree_type = tree_type 190 | if tree_type == "ball": 191 | self.tree_cls = BallTree 192 | elif tree_type == "kd": 193 | self.tree_cls = KDTree 194 | elif tree_type == "cluster": 195 | self.tree_cls = ClusterTree 196 | elif tree_type == "quad": 197 | self.tree_cls = QuadTree 198 | else: 199 | raise NotImplementedError("Unknown tree type") 200 | self.kwargs = kwargs 201 | self.leaf_size = leaf_size 202 | self.metric = metric 203 | self.dist_fn = DistanceMetric.get_metric(metric) 204 | 205 | def get_node_weights(self): 206 | """ Takes the middle of the bounds as the node center for each node 207 | TODO (alex): This could be improved or at least experimented with 208 | """ 209 | node_weights = self.tree.get_arrays()[-1] 210 | if self.tree_type == "ball": 211 | centers = node_weights[0] 212 | n = centers.shape[0] 213 | # Subtracts the child from the parent relying on the order of nodes in the tree 214 | lengths = np.linalg.norm( 215 | centers[np.insert(np.arange(n - 1) // 2, 0, 0)] - centers[np.arange(n)], 216 | axis=1, 217 | ) 218 | return lengths 219 | elif self.tree_type == "kd": 220 | # Averages the two boundaries of the KD box 221 | centers = node_weights.mean(axis=0) 222 | n = centers.shape[0] 223 | # 
Subtracts the child from the parent relying on the order of nodes in the tree 224 | lengths = np.linalg.norm( 225 | centers[np.insert(np.arange(n - 1) // 2, 0, 0)] - centers[np.arange(n)], 226 | axis=1, 227 | ) 228 | return lengths 229 | elif self.tree_type == "cluster": 230 | return node_weights 231 | elif self.tree_type == "quad": 232 | return node_weights 233 | else: 234 | raise NotImplementedError("Unknown tree type") 235 | 236 | def fit_transform(self, X, y): 237 | """ 238 | X is data array (np array) 239 | y is one-hot encoded distribution index (np array of size # points x # 240 | distributions. 241 | """ 242 | X, y = check_X_y(X, y, accept_sparse=True, multi_output=True) 243 | self.classes_ = y.shape[1] # unique_labels(y) 244 | self.X_ = X 245 | self.y_ = y 246 | self.tree = self.tree_cls( 247 | X, leaf_size=self.leaf_size, metric=self.metric, **self.kwargs 248 | ) 249 | tree_indices = self.tree.get_arrays()[1] 250 | node_data = self.tree.get_arrays()[2] 251 | y_indices = y[tree_indices] # reorders point labels by tree order. 252 | 253 | self.edge_weights = self.get_node_weights() 254 | counts = np.empty((len(node_data), y.shape[1])) 255 | for node_idx in reversed(range(len(node_data))): 256 | start, end, is_leaf, radius = node_data[node_idx] 257 | 258 | # Find the number of points present in this range from each distribution 259 | counts[node_idx] = np.sum( 260 | y_indices[start:end], axis=0 261 | ) # as y is a one-hot encoding, we just need to sum over the relevant bits. 262 | 263 | if np.issubdtype(y.dtype, np.floating): 264 | # if is floating then don't worry about the logic below 265 | self.counts_mtx = coo_matrix(counts).T 266 | return self.counts_mtx, self.edge_weights 267 | 268 | # convert to COO format 269 | dim = (self.classes_, len(node_data)) 270 | dist_list = np.arange(1, self.classes_ + 1) 271 | self.counts_mtx = coo_matrix(dim, dtype=np.int32) 272 | for i, count in enumerate(counts): 273 | if np.sum(count) == 0: # if no classes have signals in this region 274 | continue 275 | # get the signals with nonzero representation in the region 276 | # count is a list of the representation per distribution. 277 | # count_copy is used to eliminate distributions without representation 278 | count_copy = count.copy() 279 | count_copy[count_copy > 0] = 1 280 | dists_represented = np.multiply(dist_list, count_copy) 281 | j_list = ( 282 | dists_represented[dists_represented != 0] - 1 283 | ) # we added 1 to the distribution numbers to do the zero trick. 284 | val_list = count[count != 0] 285 | i_list = [i] * len(j_list) 286 | self.counts_mtx += coo_matrix( 287 | (val_list, (j_list, i_list)), shape=dim, dtype=np.int32 288 | ) 289 | 290 | return self.counts_mtx, self.edge_weights 291 | 292 | def transform(self, X): 293 | """ Transforms datasets y to (L1) vector space. 294 | 295 | Returns vectors representing edge weights and weights over vector. 
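A sketch of how the fitted arrays are meant to be used downstream (hedged: this follows the standard tree-Wasserstein construction rather than a verbatim recipe from this repository; ``mt`` is a fitted ``MetricTree`` as in the module docstring example)::

    counts, edge_weights = mt.fit_transform(X, y)
    mass = counts.toarray().astype(float)
    mass /= mass[:, :1]                        # node 0 is the root, so column 0 is each distribution's total mass
    embeddings = mass * edge_weights[None, :]  # weight each subtree's mass by the edge length to its parent
    # The L1 distance between two rows then approximates EMD along the tree,
    # e.g. np.abs(embeddings[0] - embeddings[1]).sum()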
296 | """ 297 | check_is_fitted(self, "X_") 298 | 299 | if X != self.X_: 300 | raise ValueError("X transformed must equal fitted X") 301 | 302 | 303 | if __name__ == "__main__": 304 | mt = MetricTree(tree_type="cluster") 305 | gt = np.repeat(np.arange(10), 100) 306 | gt = ( 307 | (np.repeat(np.arange(max(gt) + 1)[:, None], len(gt), axis=1) == gt) 308 | .astype(int) 309 | .T 310 | ) 311 | counts, edge_weights = mt.fit_transform(X=np.random.random_sample((1000, 3)), y=gt) 312 | print(counts, edge_weights) 313 | print(counts.toarray()[:50]) 314 | -------------------------------------------------------------------------------- /DiffusionEMD/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.0" 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ---------------------------------- 2 | 3 | Non-Commercial License 4 | Yale Copyright © 2024 Yale University. 5 | 6 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | ---------------------------------- 11 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ---------------------------------- 2 | 3 | Non-Commercial License 4 | Yale Copyright © 2024 Yale University. 5 | 6 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. 
In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | 10 | ---------------------------------- 11 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Implementation of Diffusion EMD 2 | =============================== 3 | 4 | Diffusion Earth Mover's Distance embeds the Wasserstein distance between two distributions on a graph into `L^1` in log-linear time. 5 | 6 | Installation 7 | ------------ 8 | 9 | DiffusionEMD is available on `pypi`. Install it by running the following:: 10 | 11 | pip install DiffusionEMD 12 | 13 | Quick Start 14 | ----------- 15 | 16 | 17 | DiffusionEMD is written following the `sklearn` estimator framework. We provide two estimators that operate quite differently. First, the Chebyshev approximation of the diffusion operator in `DiffusionCheb`, which we recommend when the number of distributions is small compared to the number of points. Second, the Interpolative Decomposition method that computes dyadic powers `P^{2^k}` directly in `DiffusionTree`. These two classes are used in the same way: first supply parameters, then fit to a graph and an array of distributions:: 18 | 19 | import numpy as np 20 | from DiffusionEMD import DiffusionCheb 21 | 22 | # Setup an adjacency matrix and a set of distributions to embed 23 | adj = np.ones((10, 10)) 24 | distributions = np.random.randn(10, 5) 25 | dc = DiffusionCheb() 26 | 27 | # Embeddings where the L1 distance approximates the Earth Mover's Distance 28 | embeddings = dc.fit_transform(adj, distributions) 29 | # Shape: (5, 60) 30 | 31 | Requirements can be found in `requirements.txt`. A short sketch using `DiffusionTree` in the same way appears at the end of the Examples section below. 32 | 33 | Examples 34 | -------- 35 | 36 | Examples are in the `notebooks` directory. 37 | 38 | Take a look at the examples provided there to get a sense of how the parameters 39 | behave on simple examples that are easy to visualize.
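For a quick look at the second estimator, the tree-based class can be exercised the same way (a minimal sketch: it assumes `DiffusionTree` is importable from the top-level package like `DiffusionCheb`, and reuses the toy `adj` and `distributions` from the Quick Start above)::

    from DiffusionEMD import DiffusionTree

    dt = DiffusionTree()
    # Same call pattern as DiffusionCheb: a graph adjacency matrix and an
    # array of distributions over its nodes.
    tree_embeddings = dt.fit_transform(adj, distributions)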
40 | 41 | Paper 42 | ----- 43 | 44 | This code implements the algorithms described in this paper: 45 | 46 | ArXiv Link: http://arxiv.org/abs/2102.12833:: 47 | 48 | @InProceedings{pmlr-v139-tong21a, 49 | title = {Diffusion Earth Mover’s Distance and Distribution Embeddings}, 50 | author = {Tong, Alexander Y and Huguet, Guillaume and Natik, Amine and Macdonald, Kincaid and Kuchroo, Manik and Coifman, Ronald and Wolf, Guy and Krishnaswamy, Smita}, 51 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 52 | pages = {10336--10346}, 53 | year = {2021}, 54 | editor = {Meila, Marina and Zhang, Tong}, 55 | volume = {139}, 56 | series = {Proceedings of Machine Learning Research}, 57 | month = {18--24 Jul}, 58 | publisher = {PMLR}, 59 | pdf = {http://proceedings.mlr.press/v139/tong21a/tong21a.pdf}, 60 | url = {http://proceedings.mlr.press/v139/tong21a.html}, 61 | } 62 | -------------------------------------------------------------------------------- /assets/schematic_600_400.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/DiffusionEMD/26f147ca49f87e74587f44565081753a354d46ec/assets/schematic_600_400.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | scipy 3 | matplotlib>=3.0 4 | pot 5 | pygsp 6 | graphtools 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [build_sphinx] 2 | all-files = 1 3 | source-dir = doc/source 4 | build-dir = doc/build 5 | warning-is-error = 0 6 | 7 | [flake8] 8 | ignore = 9 | # top-level module docstring 10 | D100, D104, W503, 11 | # space before : conflicts with black 12 | E203 13 | per-file-ignores = 14 | # imported but unused 15 | __init__.py: F401 16 | max-line-length = 88 17 | exclude = 18 | .git, 19 | __pycache__, 20 | build, 21 | dist, 22 | test, 23 | doc 24 | 25 | [isort] 26 | profile = black 27 | force_single_line = true 28 | force_alphabetical_sort = true 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | import os 4 | 5 | install_requires = [ 6 | "numpy>=1.16.0", 7 | "scipy", 8 | "matplotlib>=3.0", 9 | "pot", 10 | "pygsp", 11 | "graphtools", 12 | ] 13 | 14 | doc_requires = [ 15 | "sphinx", 16 | "sphinxcontrib-napoleon", 17 | "ipykernel", 18 | "nbsphinx", 19 | "autodocsumm", 20 | ] 21 | 22 | test_requires = [ 23 | "nose", 24 | "nose2", 25 | "coverage", 26 | "coveralls", 27 | "parameterized", 28 | "requests", 29 | "packaging", 30 | "mock", 31 | "matplotlib>=3.0", 32 | "black", 33 | ] 34 | 35 | version_py = os.path.join(os.path.dirname(__file__), "DiffusionEMD", "version.py") 36 | version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip() 37 | 38 | readme = open("README.rst").read() 39 | 40 | setup( 41 | name="DiffusionEMD", 42 | packages=find_packages(), 43 | version=version, 44 | description="Diffusion based earth mover's distance.", 45 | author="Alexander Tong", 46 | author_email="alexandertongdev@gmail.com", 47 | license="MIT", 48 | install_requires=install_requires, 49 | extras_require={ 50 | "test": test_requires, 51 | "doc": doc_requires, 52 | }, 53 | 
long_description=readme, 54 | url="https://github.com/KrishnaswamyLab/DiffusionEMD", 55 | ) 56 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | max-complexity = 10 4 | --------------------------------------------------------------------------------