├── .gitattributes ├── .github └── workflows │ ├── docs.yml │ └── pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── SpatialDE ├── __init__.py ├── _internal │ ├── __init__.py │ ├── distance_cache.py │ ├── gpflow_helpers.py │ ├── kernels.py │ ├── models.py │ ├── optimizer.py │ ├── score_test.py │ ├── sm_kernel.py │ ├── svca.py │ ├── tf_dataset.py │ ├── util.py │ └── util_mixture.py ├── aeh.py ├── de_test.py ├── dp_hmrf.py ├── gaussian_process.py ├── io.py └── svca.py ├── docs ├── .gitignore ├── Makefile ├── make.bat └── source │ ├── _templates │ ├── class.rst │ └── module.rst │ ├── conf.py │ └── index.rst ├── pyproject.toml └── setup.cfg /.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv filter=lfs diff=lfs merge=lfs -text 2 | *.tsv filter=lfs diff=lfs merge=lfs -text 3 | *.jpg filter=lfs diff=lfs merge=lfs -text 4 | *.xlsx filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: [push] 3 | 4 | jobs: 5 | docs: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v2 9 | - uses: actions/setup-python@v2 10 | with: 11 | python-version: "3.9" 12 | - name: Cache pip 13 | uses: actions/cache@v2 14 | with: 15 | # This path is specific to Ubuntu 16 | path: ~/.cache/pip 17 | # Look to see if there is a cache hit for the corresponding requirements file 18 | key: ${{ runner.os }}-pip-${{ hashFiles('setup.cfg') }} 19 | restore-keys: | 20 | ${{ runner.os }}-pip- 21 | ${{ runner.os }}- 22 | - run: pip install .[docs] 23 | - run: pip install patsy # missing dep of NaiveDE 24 | - name: Running the Sphinx to gh-pages Action 25 | uses: uibcdf/action-sphinx-docs-to-gh-pages@v2.1.0 26 | with: 27 | branch: master 28 | dir_docs: docs/source 29 | sphinxopts: '' 30 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | - uses: pre-commit/action@v3.0.0 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | # Visual studio code 124 | .vscode/ 125 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.12.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ilia Kats , Stegle Group 4 | Copyright (c) 2018 Teichmann Group 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpatialDE 2 2 | 3 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 4 | [![Docs: latest](https://img.shields.io/badge/docs-latest-blue.svg)](https://pmbio.github.io/SpatialDE) 5 | 6 | This is the next iteration of the SpatialDE package for analyzing spatial transcriptomics data. 7 | -------------------------------------------------------------------------------- /SpatialDE/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import version as __version__ 2 | 3 | from .de_test import test 4 | 5 | from .gaussian_process import GP, SGPIPM, GPControl, fit, fit_fast, fit_detailed 6 | from .dp_hmrf import ( 7 | tissue_segmentation, 8 | TissueSegmentationParameters, 9 | TissueSegmentationStatus, 10 | TissueSegmentation, 11 | ) 12 | from .aeh import spatial_patterns, SpatialPatternParameters, SpatialPatterns 13 | from .svca import test_spatial_interactions, fit_spatial_interactions 14 | from .io import read_spaceranger 15 | 16 | import tensorflow as tf 17 | 18 | gpus = tf.config.experimental.list_physical_devices("GPU") 19 | if gpus: 20 | try: 21 | for gpu in gpus: 22 | tf.config.experimental.set_memory_growth(gpu, True) 23 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 24 | except RuntimeError as e: 25 | print(e) 26 | del tf 27 | del gpus 28 | -------------------------------------------------------------------------------- /SpatialDE/_internal/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SpatialDE/_internal/distance_cache.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import tensorflow as tf 3 | from gpflow import default_float 4 | from gpflow.utilities import to_default_float 5 | from gpflow.utilities.ops import square_distance, difference_matrix 6 | 7 | 8 | def cached(variable): 9 | def cache(func): 10 | func = tf.function(func, experimental_compile=True, experimental_relax_shapes=True) 11 | 12 | @functools.wraps(func) 13 | def caching_wrapper(self, *args, **kwargs): 14 | if not hasattr(self, variable) or getattr(self, variable) is None: 15 | mat = func(self.X) 16 | if self._cache: 17 | setattr(self, variable, mat) 18 | else: 19 | mat = getattr(self, variable) 20 | return mat 21 | 22 | return caching_wrapper 23 | 24 | return cache 25 | 26 | 27 | class DistanceCache: 28 | def __init__(self, X: tf.Tensor, cache=True): 29 | self.X = X 30 | self._cache = cache 31 | 32 | @property 33 | @cached("_squaredEuclidean") 34 | def squaredEuclideanDistance(X): 35 | return square_distance(to_default_float(X), None) 36 | 37 | @property 38 | @cached("_sumDiff") 39 | def sumOfDifferences(X): 40 | return tf.reduce_sum(difference_matrix(to_default_float(X), None), axis=-1) 41 | -------------------------------------------------------------------------------- /SpatialDE/_internal/gpflow_helpers.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from abc import ABCMeta, abstractmethod 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from typing import Optional, Union 6 | 7 | import numpy as np 8 | 
import pandas as pd 9 | import gpflow 10 | import tensorflow as tf 11 | 12 | from .models import Model 13 | from .sm_kernel import * 14 | from .util import gower_factor 15 | 16 | 17 | class Linear(gpflow.kernels.Linear): 18 | def K_novar(self, X, X2=None): 19 | if X2 is None: 20 | return tf.matmul(X, X, transpose_b=True) 21 | else: 22 | return tf.tensordot(X, X2, [-1, -1]) 23 | 24 | 25 | class SMPlusLinearKernel(gpflow.kernels.Sum): 26 | def __init__(self, sm_kernel): 27 | super().__init__([sm_kernel, Linear()]) 28 | 29 | def K_novar(self, X, X2=None): 30 | return self._reduce([k.K_novar(X, X2) for k in self.kernels]) 31 | 32 | @property 33 | def spectral_mixture(self): 34 | return self.kernels[0] 35 | 36 | @property 37 | def linear(self): 38 | return self.kernels[1] 39 | 40 | @staticmethod 41 | def _scaled_var(X, k): 42 | return gower_factor(k(X), k.variance) 43 | 44 | def scaled_variance(self, X): 45 | smvars = tf.convert_to_tensor([self._scaled_var(X, k) for k in self.kernels[0]]) 46 | linvar = self._scaled_var(X, self.kernels[1]) 47 | return tf.concat((smvars, tf.expand_dims(linvar, axis=0)), axis=0) 48 | 49 | 50 | class GeneGPModel(metaclass=ABCMeta): 51 | @abstractmethod 52 | def freeze(self): 53 | pass 54 | 55 | @staticmethod 56 | def mixture_kernel(X, Y, ncomponents=5, ard=True, minvar=1e-3): 57 | range = tf.reduce_min(tf.reduce_max(X, axis=0) - tf.reduce_min(X, axis=0)) 58 | dist = tf.sqrt(gpflow.utilities.ops.square_distance(X, None)) 59 | dist = tf.linalg.set_diag(dist, tf.fill((X.shape[0],), tf.cast(np.inf, dist.dtype))) 60 | min_1nndist = tf.reduce_min(dist) 61 | 62 | datarange = tf.math.reduce_max( 63 | tf.math.reduce_max(X, axis=0) - tf.math.reduce_min(X, axis=0) 64 | ) 65 | 66 | minperiod = 2 * min_1nndist 67 | varinit = min_1nndist + 0.5 * (range - min_1nndist) 68 | periodinit = minperiod + 0.5 * (range - minperiod) 69 | 70 | if ard: 71 | varinit = tf.repeat(varinit, X.shape[1]) 72 | periodinit = tf.repeat(periodinit, X.shape[1]) 73 | 74 | maxvar = 10 * tf.math.reduce_variance(Y) 75 | kernels = [] 76 | for v in np.linspace(minvar, np.minimum(1, 0.9 * maxvar), ncomponents): 77 | k = Spectral( 78 | variance=gpflow.Parameter( 79 | v, 80 | transform=tfp.bijectors.Sigmoid( 81 | low=gpflow.utilities.to_default_float(0), 82 | high=gpflow.utilities.to_default_float(maxvar), 83 | ), 84 | ), 85 | lengthscales=gpflow.Parameter( 86 | varinit, transform=tfp.bijectors.Sigmoid(low=0.1 * min_1nndist, high=datarange) 87 | ), 88 | periods=gpflow.Parameter( 89 | periodinit, transform=tfp.bijectors.Sigmoid(low=minperiod, high=2 * datarange) 90 | ), 91 | ) 92 | kernels.append(k) 93 | kern = SpectralMixture(kernels) 94 | return SMPlusLinearKernel(kern) 95 | 96 | 97 | class GPR(gpflow.models.GPR, GeneGPModel): 98 | def __init__( 99 | self, 100 | X: np.ndarray, 101 | Y: np.ndarray, 102 | n_kernel_components: int = 5, 103 | ard: bool = True, 104 | minvar: float = 1e-3, 105 | ): 106 | kern = self.mixture_kernel(X, Y, n_kernel_components, ard, minvar) 107 | super().__init__(data=[X, Y], kernel=kern, mean_function=gpflow.mean_functions.Constant()) 108 | 109 | def freeze(self): 110 | X = self.data[0] 111 | self.data = (None, self.data[1]) 112 | frozen = gpflow.utilities.freeze(self) 113 | frozen.data = self.data = (X, self.data[1]) 114 | return frozen 115 | 116 | 117 | class SGPR(gpflow.models.SGPR, GeneGPModel): 118 | def __init__( 119 | self, 120 | X: np.ndarray, 121 | Y: np.ndarray, 122 | inducing_variable: Union[np.ndarray, gpflow.inducing_variables.InducingPoints], 123 | n_kernel_components: int 
= 5, 124 | ard: bool = True, 125 | minvar: float = 1e-3, 126 | ): 127 | kern = self.mixture_kernel(X, Y, n_kernel_components, ard, minvar) 128 | super().__init__( 129 | data=[X, Y], 130 | kernel=kern, 131 | inducing_variable=inducing_variable, 132 | mean_function=gpflow.mean_functions.Constant(), 133 | ) 134 | 135 | def freeze(self): 136 | X = self.data[0] 137 | self.data = (None, self.data[1]) 138 | 139 | trainable_inducers = self.inducing_variable.Z.trainable 140 | if not trainable_inducers: 141 | Z = self.inducing_variable.Z 142 | frozen = gpflow.utilities.freeze(self) 143 | frozen.data = self.data = (X, self.data[1]) 144 | if not trainable_inducers: 145 | frozen.inducing_variable.Z = Z 146 | return frozen 147 | 148 | def log_marginal_likelihood(self): 149 | return self.elbo() 150 | 151 | 152 | @dataclass(frozen=True) 153 | class VarPart: 154 | spectral_mixture: tf.Tensor 155 | linear: tf.Tensor 156 | noise: tf.Tensor 157 | 158 | 159 | @dataclass(frozen=True) 160 | class Variance: 161 | scaled_variance: VarPart 162 | fraction_variance: VarPart 163 | var_fraction_variance: VarPart 164 | 165 | 166 | class GeneGP(Model): 167 | def __init__(self, model: GeneGPModel, minimize_fun, *args, **kwargs): 168 | self.model = model 169 | 170 | self._frozen = False 171 | self._variancevars = list(self.model.kernel.parameters) 172 | self._variancevars.append(self.model.likelihood.variance) 173 | 174 | self._trainable_variance_idx = [] 175 | offset = 0 176 | variancevars = set(self._variancevars) 177 | for v in self.model.parameters: 178 | if v in variancevars: 179 | self._trainable_variance_idx.extend([offset + int(i) for i in range(tf.size(v))]) 180 | offset += int(tf.size(v)) 181 | 182 | self.__invHess = None 183 | self.model.likelihood.variance.assign(tf.math.reduce_variance(self.model.data[1])) 184 | 185 | t0 = time() 186 | self._optimize(minimize_fun, *args, **kwargs) 187 | self._freeze() 188 | t = time() - t0 189 | 190 | self._time = t 191 | 192 | @property 193 | def kernel(self): 194 | return self.model.kernel 195 | 196 | @property 197 | def K(self): 198 | return self.model.kernel.K_novar(self.model.data[0]) 199 | 200 | @property 201 | def y(self): 202 | return tf.squeeze(self.model.data[1]).numpy() 203 | 204 | def predict_mean(self, X=None): 205 | if X is None: 206 | X = self.model.data[0] 207 | return self.model.predict_f(X)[0] 208 | 209 | def plot_power_spectrum(self, xlim: float = None, ylim: float = None, **kwargs): 210 | return self.model.kernel.spectral_mixture.plot_power_spectrum(xlim, ylim, **kwargs) 211 | 212 | @property 213 | def time(self): 214 | return self._time 215 | 216 | @staticmethod 217 | def _concat_tensors(tens): 218 | return tf.concat([tf.reshape(t, (-1,)) for t in tens], axis=0) 219 | 220 | @property 221 | def _invHess(self): 222 | if self.__invHess is None: 223 | # tf.hessians doesn't work yet (https://github.com/tensorflow/tensorflow/issues/29781) 224 | # and tape.jacobian() doesn't like lengthscale and period parameters for some reason 225 | # (it aborts with AttributeError: Tensor.graph is meaningless when eager execution is enabled.). 
226 | # So we need to do this the hard way 227 | with tf.GradientTape(persistent=True) as tape: 228 | y = self.model.log_marginal_likelihood() 229 | tape.watch(y) 230 | grad = self._concat_tensors(tape.gradient(y, self.model.trainable_variables)) 231 | grads = tf.split( 232 | grad, tf.ones((tf.size(grad),), dtype=tf.int32) 233 | ) # this is necessary to be able to get the gradient of each entry 234 | hess = tf.stack( 235 | [ 236 | self._concat_tensors( 237 | tape.gradient( 238 | g, 239 | self.model.trainable_variables, 240 | unconnected_gradients=tf.UnconnectedGradients.ZERO, 241 | ) 242 | ) 243 | for g in grads 244 | ] 245 | ) 246 | self._invHess = tf.linalg.inv(-hess) # invert the negative Hessian, i.e. the observed Fisher information (matches the hess_inv of the minimized negative log-likelihood below) 247 | return self.__invHess 248 | 249 | @_invHess.setter 250 | def _invHess(self, invhess): 251 | x, y = tf.meshgrid(self._trainable_variance_idx, self._trainable_variance_idx) 252 | invhess = tf.reshape( 253 | tf.gather_nd(invhess, tf.stack([tf.reshape(x, (-1,)), tf.reshape(y, (-1,))], axis=1)), 254 | (len(self._trainable_variance_idx), len(self._trainable_variance_idx)), 255 | ) 256 | self.__invHess = invhess 257 | 258 | def _optimize(self, minimize_fun, *args, **kwargs): 259 | res = minimize_fun( 260 | lambda: -self.model.log_marginal_likelihood(), 261 | self.model.trainable_variables, 262 | *args, 263 | **kwargs, 264 | ) 265 | if isinstance(res, dict) and "hess_inv" in res: 266 | self._invHess = gpflow.utilities.to_default_float(res["hess_inv"]) 267 | elif hasattr(res, "hess_inv"): 268 | self._invHess = gpflow.utilities.to_default_float(res.hess_inv) 269 | 270 | def _freeze(self): 271 | if self._frozen: 272 | return 273 | # this code calculates the variance of the fraction of spatial variance estimate 274 | # We use the negative of the Hessian of the marginal log-likelihood as observed Fisher information, the inverse of which is the 275 | # asymptotic covariance matrix of the estimate. We then use the Delta method to get the asymptotic variance of FSV.
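# In delta-method notation: with theta_hat the unconstrained parameter estimates and Sigma ~ (-H)^-1 their
# asymptotic covariance (inverse observed Fisher information), Var(g(theta_hat)) ~ grad g(theta_hat)^T Sigma grad g(theta_hat)
# for each variance fraction g; the Jacobian and the row sums of (grads @ invHess) * grads below compute exactly that.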
276 | # TODO: I'm not quite sure if this is valid for the case of free inducing points, since these are variational parameters 277 | with tf.GradientTape() as t: 278 | variances = self._concat_tensors( 279 | [ 280 | self.model.kernel.scaled_variance(self.model.data[0]), 281 | tf.expand_dims(self.model.likelihood.variance, axis=0), 282 | ] 283 | ) 284 | totalvar = tf.reduce_sum(variances) 285 | variance_fractions = variances / totalvar 286 | 287 | grads = t.jacobian( 288 | variance_fractions, [v.unconstrained_variable for v in self._variancevars] 289 | ) 290 | grads = tf.concat([tf.expand_dims(g, -1) if g.ndim < 2 else g for g in grads], axis=1) 291 | errors = tf.reduce_sum((grads @ self._invHess) * grads, axis=1) 292 | 293 | variances = VarPart(variances[0:-2], variances[-2], variances[-1]) 294 | variance_fractions = VarPart( 295 | variance_fractions[0:-2], variance_fractions[-2], variance_fractions[-1] 296 | ) 297 | errors = VarPart(errors[0:-2], errors[-2], errors[-1]) 298 | 299 | self.variances = Variance(variances, variance_fractions, errors) 300 | 301 | self.model = self.model.freeze() 302 | self._frozen = True 303 | 304 | 305 | class DataSetResults(dict): 306 | def __init__(self, *args, **kwargs): 307 | super().__init__(*args, **kwargs) 308 | 309 | def __setitem__(self, key, value: GeneGP): 310 | if not isinstance(value, GeneGP): 311 | raise TypeError("value must be a GeneGP object") 312 | super().__setitem__(key, value) 313 | 314 | def to_df(self, modelcol: str = "model"): 315 | df = defaultdict(lambda: []) 316 | for gene, res in self.items(): 317 | df["gene"].append(gene) 318 | variances = res.variances 319 | df["FSV"].append((1 - variances.fraction_variance.noise).numpy()) 320 | df["s2_FSV"].append(variances.var_fraction_variance.noise.numpy()) 321 | 322 | for i, k in enumerate(res.kernel.spectral_mixture.kernels): 323 | df["sm_variance_%i" % i].append(k.variance.numpy()) 324 | df["sm_lengthscale_%i" % i].append(k.lengthscales.numpy()) 325 | df["sm_period_%i" % i].append(k.periods.numpy()) 326 | df["linear_variance"].append(res.kernel.linear.variance.numpy()) 327 | df["noise_variance"].append(res.model.likelihood.variance.numpy()) 328 | df["time"].append(res.time) 329 | df[modelcol].append(res) 330 | df = pd.DataFrame(df) 331 | return df 332 | -------------------------------------------------------------------------------- /SpatialDE/_internal/kernels.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | from abc import ABCMeta, abstractmethod 3 | import math 4 | 5 | import tensorflow as tf 6 | from gpflow import default_float 7 | from gpflow.utilities import to_default_float 8 | from gpflow.utilities.ops import square_distance, difference_matrix 9 | 10 | from .distance_cache import DistanceCache 11 | 12 | 13 | def scale(X: tf.Tensor, lengthscale: Optional[float] = 1): 14 | if X is not None: 15 | return X / lengthscale 16 | else: 17 | return X 18 | 19 | 20 | def scaled_difference_matrix( 21 | X: tf.Tensor, 22 | Y: Optional[tf.Tensor] = None, 23 | lengthscale: Optional[float] = 1, 24 | ): 25 | return difference_matrix(scale(X, lengthscale), scale(Y, lengthscale)) 26 | 27 | 28 | def scaled_squared_distance( 29 | X: tf.Tensor, Y: Optional[tf.Tensor] = None, lengthscale: Optional[float] = 1 30 | ): 31 | return square_distance(scale(X, lengthscale), scale(Y, lengthscale)) 32 | 33 | 34 | class Kernel(metaclass=ABCMeta): 35 | def __init__(self, cache: DistanceCache): 36 | self._cache = cache 37 | 38 | def K(self, 
X: Optional[tf.Tensor] = None, Y: Optional[tf.Tensor] = None): 39 | if X is None and Y is not None: 40 | X, Y = Y, X 41 | if (X is None or X is self._cache.X) and Y is None: 42 | return self._K(cache=True) 43 | else: 44 | X = to_default_float(X) 45 | if Y is not None: 46 | Y = to_default_float(Y) 47 | return self._K(X, Y) 48 | 49 | def K_diag(self, X: Optional[tf.Tensor]): 50 | if X is None: 51 | return self._K_diag(cache=True) 52 | else: 53 | return self._K_diag(to_default_float(X)) 54 | 55 | @abstractmethod 56 | def _K(self, X: Optional[tf.Tensor] = None, Y: Optional[tf.Tensor] = None, cache: bool = False): 57 | pass 58 | 59 | @abstractmethod 60 | def _K_diag(self, X: Optional[tf.Tensor] = None, cache: bool = False): 61 | pass 62 | 63 | 64 | class StationaryKernel(Kernel): 65 | def __init__(self, cache: DistanceCache, lengthscale=1): 66 | super().__init__(cache) 67 | self.lengthscale = lengthscale 68 | 69 | def _K_diag(self, X: Optional[tf.Tensor] = None, cache: bool = False): 70 | if cache: 71 | n = self._cache.X.shape[0] 72 | else: 73 | n = X.shape[0] 74 | return self._K_diag_impl(n) 75 | 76 | def _K_diag_impl(self, n: int): 77 | return tf.repeat(tf.convert_to_tensor(1, dtype=default_float()), n) 78 | 79 | 80 | class SquaredExponential(StationaryKernel): 81 | def _K(self, X: Optional[tf.Tensor] = None, Y: Optional[tf.Tensor] = None, cache: bool = False): 82 | if cache: 83 | dist = self._cache.squaredEuclideanDistance / self.lengthscale**2 84 | else: 85 | dist = scaled_squared_distance(X, Y, self.lengthscale) 86 | return tf.exp(-0.5 * dist) 87 | 88 | 89 | class Cosine(StationaryKernel): 90 | def _K(self, X: Optional[tf.Tensor] = None, Y: Optional[tf.Tensor] = None, cache: bool = False): 91 | if cache: 92 | dist = self._cache.sumOfDifferences / self.lengthscale 93 | else: 94 | dist = tf.reduce_sum(scaled_difference_matrix(X, Y, self.lengthscale), axis=-1) 95 | return tf.cos(2 * math.pi * dist) 96 | 97 | 98 | class Linear(Kernel): 99 | def _K(self, X: tf.Tensor, Y: Optional[tf.Tensor] = None): 100 | if Y is None: 101 | Y = X 102 | return tf.sum(X[:, tf.newaxis, :] * Y[tf.newaxis, ...], axis=-1) 103 | 104 | def _K_diag(self, X: tf.Tensor, cache: bool = False): 105 | if cache: 106 | X = self._cache.X 107 | return tf.sum(tf.square(X), axis=-1) 108 | -------------------------------------------------------------------------------- /SpatialDE/_internal/models.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Union 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | import scipy 7 | from scipy import optimize 8 | from scipy.misc import derivative 9 | from scipy.stats import chi2 10 | 11 | from .kernels import Kernel 12 | 13 | 14 | class Model: 15 | def __init__(self, X: np.ndarray, kernel: Kernel): 16 | self.X = X 17 | self.n = X.shape[0] 18 | self.kernel = kernel 19 | 20 | self._K = None 21 | self._y = None 22 | 23 | def _reset(self): 24 | pass 25 | 26 | @property 27 | def K(self): 28 | if self._K is not None: 29 | return self._K 30 | else: 31 | return self.kernel.K(self.X) 32 | 33 | @property 34 | def n_parameters(self) -> float: 35 | return 0 36 | 37 | @property 38 | def FSV(self) -> float: 39 | return np.nan 40 | 41 | @property 42 | def s2_FSV(self) -> float: 43 | return np.nan 44 | 45 | @property 46 | def logdelta(self) -> float: 47 | return np.nan 48 | 49 | @logdelta.setter 50 | def logdelta(self, ld: float): 51 | pass 52 | 53 | @property 54 | def s2_logdelta(self) -> 
float: 55 | return np.nan 56 | 57 | @property 58 | def delta(self) -> float: 59 | return np.nan 60 | 61 | @delta.setter 62 | def delta(self, nd: float): 63 | pass 64 | 65 | @property 66 | def s2_delta(self) -> float: 67 | return np.nan 68 | 69 | @property 70 | def mu(self) -> float: 71 | return np.nan 72 | 73 | @property 74 | def sigma_s2(self) -> float: 75 | return np.nan 76 | 77 | @property 78 | def sigma_n2(self) -> float: 79 | return np.nan 80 | 81 | @property 82 | def y(self) -> np.ndarray: 83 | return self._y 84 | 85 | @y.setter 86 | def y(self, newy: np.ndarray): 87 | self._y = newy 88 | self._reset() 89 | self._ychanged() 90 | 91 | def optimize(self): 92 | pass 93 | 94 | @property 95 | def log_marginal_likelihood(self) -> float: 96 | self._check_y() 97 | return self._lml() 98 | 99 | def _lml(self): 100 | return np.nan 101 | 102 | def _ychanged(self): 103 | pass 104 | 105 | def _check_y(self): 106 | if self.y is None: 107 | raise RuntimeError("assign observed values first") 108 | if self.y.shape[0] != self.n: 109 | raise RuntimeError("different numbers of observations in y and X") 110 | 111 | def __enter__(self): 112 | self._K = self.K 113 | 114 | def __exit__(self, *args): 115 | self._K = None 116 | 117 | 118 | class GPModel(Model): 119 | def __init__(self, X: np.ndarray, kernel: Kernel): 120 | super().__init__(X, kernel) 121 | 122 | self._y = None 123 | self._logdelta = 0 124 | self._reset() 125 | 126 | def __enter__(self): 127 | super().__enter__() 128 | self._reset() 129 | # Gower normalization factor for covariance matric K 130 | # Based on https://github.com/PMBio/limix/blob/master/limix/utils/preprocess.py 131 | self.gower = (np.trace(self.K) - np.sum(np.mean(self.K, axis=0))) / (self.n - 1) 132 | 133 | return self 134 | 135 | def __exit__(self, *args): 136 | super().__exit__(*args) 137 | 138 | def _reset(self): 139 | self._s2_FSV = None 140 | self._s2_delta = None 141 | self._s2_logdelta = None 142 | self._mu = None 143 | self._sigma_s2 = None 144 | self._sigma_n2 = None 145 | 146 | @property 147 | @abstractmethod 148 | def n_parameters(self): 149 | pass 150 | 151 | @property 152 | def FSV(self): 153 | return self.gower / (self.delta + self.gower) 154 | 155 | @property 156 | def s2_FSV(self): 157 | if self._s2_FSV is None: 158 | self._s2_FSV = self._calc_s2_FSV() 159 | return self._s2_FSV 160 | 161 | @property 162 | def logdelta(self) -> float: 163 | return self._logdelta 164 | 165 | @logdelta.setter 166 | def logdelta(self, ld: float): 167 | self._logdelta = ld 168 | self._reset() 169 | self._logdeltachanged() 170 | 171 | @property 172 | def s2_logdelta(self) -> float: 173 | if self._s2_logdelta is None: 174 | self._s2_logdelta = self._calc_s2_logdelta() 175 | return self._s2_logdelta 176 | 177 | @property 178 | def s2_delta(self) -> float: 179 | if self._s2_delta is None: 180 | self._s2_delta = self._calc_s2_delta() 181 | return self._s2_delta 182 | 183 | def _logdeltachanged(self): 184 | pass 185 | 186 | @property 187 | def delta(self) -> float: 188 | return np.exp(self.logdelta) 189 | 190 | @delta.setter 191 | def delta(self, d: float): 192 | self.logdelta = np.log(d) 193 | 194 | def _objective(self, func): 195 | def obj(logdelta): 196 | self.logdelta = logdelta 197 | return func() 198 | 199 | return obj 200 | 201 | def optimize(self): 202 | res = optimize.minimize( 203 | self._objective(lambda: -self.log_marginal_likelihood), 204 | 10, 205 | method="l-bfgs-b", 206 | bounds=[(-10, 20)], 207 | jac=False, 208 | options={"eps": 1e-4}, 209 | ) 210 | self.logdelta = res.x[0] 
211 | return res 212 | 213 | @abstractmethod 214 | def _lml(self): 215 | pass 216 | 217 | @property 218 | def mu(self) -> float: 219 | if self._mu is None: 220 | self._mu = self._calc_mu() 221 | return self._mu 222 | 223 | @property 224 | def sigma_s2(self) -> float: 225 | if self._sigma_s2 is None: 226 | self._sigma_s2 = self._calc_sigma_s2() 227 | return self._sigma_s2 228 | 229 | @property 230 | def sigma_n2(self) -> float: 231 | if self._sigma_n2 is None: 232 | self._sigma_n2 = self._calc_sigma_n2() 233 | return self._sigma_n2 234 | 235 | @abstractmethod 236 | def _calc_mu(self) -> float: 237 | pass 238 | 239 | @abstractmethod 240 | def _calc_sigma_s2(self) -> float: 241 | pass 242 | 243 | @abstractmethod 244 | def _calc_sigma_s2(self) -> float: 245 | pass 246 | 247 | def _calc_s2_logdelta(self) -> float: 248 | ld = self.logdelta 249 | s2 = -1 / derivative( 250 | self._objective(lambda: self.log_marginal_likelihood), self.logdelta, n=2 251 | ) 252 | self.logdelta = ld 253 | return s2 254 | 255 | def _calc_s2_delta(self) -> float: 256 | return self.s2_logdelta * self.delta 257 | 258 | def _calc_s2_FSV(self) -> float: 259 | ld = self.logdelta 260 | der = derivative(self._objective(lambda: self.FSV), self.logdelta, n=1) 261 | self.logdelta = ld 262 | return der**2 / self.s2_logdelta 263 | 264 | 265 | class SGPR(GPModel): 266 | def __init__(self, X: np.ndarray, Z: np.ndarray, kern: Kernel): 267 | super().__init__(X, kernel=kern) 268 | self.Z = Z 269 | self.z = Z.shape[0] 270 | 271 | def __enter__(self): 272 | super().__enter__() 273 | K_uu = self.kernel.K(self.Z) 274 | K_uf = self.kernel.K(self.Z, self.X) 275 | K_ff = self.kernel.K_diag(self.X) 276 | 277 | L = np.linalg.cholesky(K_uu + 1e-6 * np.eye(self.z)) 278 | LK_uf = scipy.linalg.solve_triangular(L, K_uf, lower=True) 279 | A = LK_uf @ LK_uf.T 280 | 281 | self._Lambda, U = np.linalg.eigh(A) 282 | self._B = U.T @ LK_uf 283 | self._B1 = np.sum(self._B, axis=-1) 284 | self._By = None 285 | self._traceterm = np.sum(K_ff) - np.trace(A) 286 | 287 | return self 288 | 289 | def __exit__(self, *args): 290 | super().__exit__(*args) 291 | self._Lambda = None 292 | self._B = None 293 | self._B1 = None 294 | self._By = None 295 | self._traceterm = None 296 | 297 | @property 298 | def n_parameters(self) -> float: 299 | return 3 300 | 301 | def _lml(self) -> float: 302 | delta = self.delta 303 | return 0.5 * ( 304 | -self.n * np.log(2 * np.pi) 305 | - self.z * np.log(self.sigma_s2) 306 | - (self.n - self.z) * np.log(self.sigma_n2) 307 | - np.sum(np.log(delta + self._Lambda)) 308 | - self.n 309 | - self._traceterm / delta 310 | ) 311 | 312 | def _residual_quadratic(self) -> float: 313 | self._check_y() 314 | return ( 315 | np.sum(self.y**2) 316 | - 2 * np.sum(self.y) * self.mu 317 | + self.mu**2 * self.n 318 | - np.sum((self._By - self._B1 * self.mu) ** 2 / self._dL) 319 | ) 320 | 321 | def _calc_mu(self): 322 | self._check_y() 323 | sy = np.sum(self.y) 324 | ytones = np.sum(self._By * self._B1 / self._dL) 325 | onesones = np.sum(self._B1**2 / self._dL) 326 | 327 | return (sy - ytones) / (self.n - onesones) 328 | 329 | def _calc_sigma_s2(self): 330 | return self.sigma_n2 / self.delta 331 | 332 | def _calc_sigma_n2(self): 333 | return self._residual_quadratic() / self.n 334 | 335 | def _ychanged(self): 336 | if self._B is not None: 337 | self._By = np.dot(self._B, self.y) 338 | 339 | def _logdeltachanged(self): 340 | self._dL = self.delta + self._Lambda 341 | 342 | 343 | class GPR(GPModel): 344 | def __init__(self, X: np.ndarray, kern: Kernel): 345 | 
super().__init__(X, kernel=kern) 346 | 347 | def __enter__(self): 348 | super().__enter__() 349 | 350 | K = self.kernel.K(self.X) 351 | self._Lambda, self._U = np.linalg.eigh(K) 352 | self._U1 = np.sum(self._U, axis=0) 353 | self._Uy = None 354 | 355 | return self 356 | 357 | def __exit__(self, *args): 358 | super().__exit__(*args) 359 | self._Lambda = self._U = None 360 | self._U1 = None 361 | self._Uy = None 362 | 363 | @property 364 | def n_parameters(self) -> float: 365 | return 3 366 | 367 | def _lml(self) -> float: 368 | return 0.5 * ( 369 | -self.n * np.log(2 * np.pi) 370 | - np.sum(np.log(self._dL)) 371 | - self.n 372 | - self.n * np.log(self.sigma_s2) 373 | ) 374 | 375 | def _residual_quadratic(self) -> float: 376 | return np.sum((self._Uy - self._U1 * self.mu) ** 2 / self._dL) 377 | 378 | def _calc_mu(self) -> float: 379 | return np.sum(self._U1 * self._Uy / self._dL) / np.sum(np.square(self._U1) / self._dL) 380 | 381 | def _calc_sigma_s2(self) -> float: 382 | return self._residual_quadratic() / self.n 383 | 384 | def _calc_sigma_n2(self) -> float: 385 | return self.sigma_s2 * self.delta 386 | 387 | def _ychanged(self): 388 | if self._U is not None: 389 | self._Uy = np.dot(self._U.T, self.y) 390 | 391 | def _logdeltachanged(self): 392 | self._dL = self.delta + self._Lambda 393 | 394 | 395 | class Constant(Model): 396 | def __init__(self, X: np.ndarray): 397 | super().__init__(X, Kernel()) 398 | 399 | def _reset(self): 400 | self._mu = None 401 | self._s2 = None 402 | 403 | @property 404 | def n_parameters(self) -> float: 405 | return 2 406 | 407 | def optimize(self): 408 | self._check_y() 409 | self._mu = np.mean(self.y) 410 | self._s2 = np.var(self.y, ddof=0) 411 | return optimize.OptimizeResult(success=True) 412 | 413 | @property 414 | def mu(self) -> float: 415 | if self._mu is None: 416 | self.optimize() 417 | return self._mu 418 | 419 | @property 420 | def sigma_n2(self) -> float: 421 | if self._s2 is None: 422 | self.optimize() 423 | return self._s2 424 | 425 | def _lml(self) -> float: 426 | return ( 427 | -0.5 * self.n * np.log(2 * np.pi * self.sigma_n2) 428 | - 0.5 * np.sum(np.square(self.y - self.mu)) / self.sigma_n2 429 | ) 430 | 431 | 432 | class Null(Model): 433 | def __init__(self, X: np.ndarray): 434 | super().__init__(X, Kernel()) 435 | 436 | def _reset(self): 437 | self._s2 = None 438 | 439 | @property 440 | def n_parameters(self) -> float: 441 | return 1 442 | 443 | def optimize(self): 444 | self._check_y() 445 | self._s2 = np.sum(np.square(self.y)) / self.n 446 | return optimize.OptimizeResult(success=True) 447 | 448 | @property 449 | def sigma_n2(self) -> float: 450 | if self._s2 is None: 451 | self.optimize() 452 | return self._s2 453 | 454 | def _lml(self) -> float: 455 | return ( 456 | -0.5 * self.n * np.log(2 * np.pi * self.sigma_n2) 457 | - 0.5 * np.square(self.y) / self.sigma_n2 458 | ) 459 | 460 | 461 | def model_factory(X: np.ndarray, Z: Union[np.ndarray, None], kern: Kernel, *args, **kwargs): 462 | if Z is None: 463 | return GPR(X, kern, *args, **kwargs) 464 | else: 465 | return SGPR(X, Z, kern, *args, **kwargs) 466 | -------------------------------------------------------------------------------- /SpatialDE/_internal/optimizer.py: -------------------------------------------------------------------------------- 1 | import scipy.optimize 2 | import tensorflow as tf 3 | 4 | from .util import concat_tensors, assign_concat 5 | 6 | 7 | class MultiScipyOptimizer: 8 | def __init__(self, objective, variables): 9 | self.objective = objective 10 | self.variables 
= variables 11 | self._obj = self._wrap_func(objective, variables) 12 | 13 | def minimize(self, method="bfgs", **scipy_kwargs): 14 | res = scipy.optimize.minimize( 15 | self._obj, 16 | concat_tensors(self.variables).numpy(), 17 | method=method, 18 | jac=True, 19 | **scipy_kwargs, 20 | ) 21 | assign_concat(res.x, self.variables) 22 | return res 23 | 24 | @classmethod 25 | def _wrap_func(cls, func, vars): 26 | def _objective(x): 27 | assign_concat(x, vars) 28 | with tf.GradientTape() as t: 29 | obj = func() 30 | grads = concat_tensors( 31 | t.gradient(obj, vars, unconnected_gradients=tf.UnconnectedGradients.ZERO) 32 | ) 33 | return obj, grads 34 | 35 | _objective = tf.function(_objective) 36 | 37 | def _obj(x): 38 | loss, grad = _objective(x) 39 | return loss.numpy(), grad.numpy() 40 | 41 | return _obj 42 | -------------------------------------------------------------------------------- /SpatialDE/_internal/score_test.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, fields 3 | from typing import Optional, Union, List, Tuple 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_probability as tfp 8 | from gpflow import default_float 9 | from gpflow.utilities import to_default_float 10 | 11 | tfd = tfp.distributions 12 | from scipy.optimize import minimize 13 | 14 | from .kernels import Kernel 15 | 16 | from enum import Enum, auto 17 | import math 18 | 19 | 20 | @dataclass(frozen=True) 21 | class ScoreTestResults: 22 | kappa: Union[float, tf.Tensor] 23 | nu: Union[float, tf.Tensor] 24 | U_tilde: Union[float, tf.Tensor] 25 | e_tilde: Union[float, tf.Tensor] 26 | I_tilde: Union[float, tf.Tensor] 27 | pval: Union[float, tf.Tensor] 28 | 29 | def to_dict(self): 30 | ret = {} 31 | for f in fields(self): 32 | obj = getattr(self, f.name) 33 | if tf.is_tensor(obj): 34 | obj = obj.numpy() 35 | ret[f.name] = obj 36 | return ret 37 | 38 | 39 | def combine_pvalues( 40 | results: Union[ScoreTestResults, List[ScoreTestResults], tf.Tensor, np.ndarray] 41 | ) -> float: 42 | if isinstance(results, ScoreTestResults): 43 | pvals = results.pval 44 | elif isinstance(results, list): 45 | pvals = tf.stack([r.pval for r in results], axis=0) 46 | elif tf.is_tensor(results): 47 | pvals = results 48 | elif isinstance(results, np.ndarray): 49 | pvals = tf.convert_to_tensor(results) 50 | else: 51 | raise TypeError("Unknown type for results.") 52 | 53 | comb = tf.reduce_mean(tf.tan((0.5 - pvals) * math.pi)) 54 | return tfd.Cauchy( 55 | tf.convert_to_tensor(0, comb.dtype), tf.convert_to_tensor(1, comb.dtype) 56 | ).survival_function(comb) 57 | 58 | 59 | class ScoreTest(ABC): 60 | @dataclass 61 | class NullModel(ABC): 62 | pass 63 | 64 | def __init__( 65 | self, 66 | omnibus: bool = False, 67 | kernel: Optional[Union[Kernel, List[Kernel]]] = None, 68 | yidx: Optional[tf.Tensor] = None, 69 | ): 70 | self._yidx = yidx 71 | self.omnibus = omnibus 72 | self.n = None 73 | if kernel is not None: 74 | self.kernel = kernel 75 | 76 | def __call__( 77 | self, y: tf.Tensor, nullmodel: Optional[NullModel] = None 78 | ) -> Tuple[ScoreTestResults, NullModel]: 79 | y = tf.squeeze(y) 80 | if self._yidx is not None: 81 | y = tf.gather(y, self._yidx) 82 | try: 83 | if nullmodel is None: 84 | nullmodel = self._fit_null(y) 85 | stat, e_tilde, I_tau_tau = self._test(y, nullmodel) 86 | return self._calc_test(stat, e_tilde, I_tau_tau), nullmodel 87 | except TypeError as e: 88 | if y.dtype is not default_float(): 89
| raise 90 | # raise TypeError( 91 | # f"Value vector has wrong dtype. Expected: {repr(default_float())}, given: {repr(y.dtype)}" 92 | # ) 93 | else: 94 | raise 95 | 96 | @property 97 | def kernel(self) -> List[Kernel]: 98 | return self._kernel 99 | 100 | @kernel.setter 101 | def kernel(self, kernel: Union[Kernel, List[Kernel]]): 102 | self._kernel = [kernel] if isinstance(kernel, Kernel) else kernel 103 | if len(self._kernel) > 1: 104 | if self.omnibus: 105 | self._K = self._kernel[0].K() 106 | for k in self._kernel[1:]: 107 | self._K += k.K() 108 | else: 109 | self._K = tf.stack([k.K() for k in kernel], axis=0) 110 | else: 111 | self._K = self._kernel[0].K() 112 | self._K = to_default_float(self._K) 113 | self.n = tf.shape(self._K)[0] 114 | 115 | if self._yidx is not None: 116 | x, y = tf.meshgrid(self._yidx, self._yidx) 117 | idx = tf.reshape(tf.stack((y, x), axis=2), (-1, 2)) 118 | if tf.size(tf.shape(self._K)) > 2: 119 | bdim = tf.shape(self._K)[0] 120 | idx = tf.tile(idx, (bdim, 1)) 121 | idx = tf.concat( 122 | ( 123 | tf.repeat( 124 | tf.range(bdim, dtype=self._yidx.dtype), tf.square(tf.size(self._yidx)) 125 | )[:, tf.newaxis], 126 | idx, 127 | ), 128 | axis=1, 129 | ) 130 | self._K = tf.squeeze( 131 | tf.reshape( 132 | tf.gather_nd(self._K, idx), 133 | (-1, tf.size(self._yidx), tf.size(self._yidx)), 134 | ) 135 | ) 136 | 137 | @abstractmethod 138 | def _test( 139 | self, y: tf.Tensor, nullmodel: NullModel 140 | ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: 141 | pass 142 | 143 | @abstractmethod 144 | def _fit_null(self, y: tf.Tensor) -> NullModel: 145 | pass 146 | 147 | @staticmethod 148 | def _calc_test(stat, e_tilde, I_tau_tau) -> ScoreTestResults: 149 | kappa = I_tau_tau / (2 * e_tilde) 150 | nu = 2 * e_tilde**2 / I_tau_tau 151 | pval = tfd.Chi2(nu).survival_function(stat / kappa) 152 | return ScoreTestResults(kappa, nu, stat, e_tilde, I_tau_tau, pval) 153 | 154 | 155 | class NegativeBinomialScoreTest(ScoreTest): 156 | @dataclass 157 | class NullModel(ScoreTest.NullModel): 158 | mu: tf.Tensor 159 | alpha: tf.Tensor 160 | 161 | def __init__( 162 | self, 163 | sizefactors: tf.Tensor, 164 | omnibus: bool = False, 165 | kernel: Optional[Union[Kernel, List[Kernel]]] = None, 166 | ): 167 | self.sizefactors = tf.squeeze(tf.cast(sizefactors, tf.float64)) 168 | if tf.rank(self.sizefactors) > 1: 169 | raise ValueError("Size factors vector must have rank 1") 170 | 171 | yidx = tf.cast(tf.squeeze(tf.where(self.sizefactors > 0)), tf.int32) 172 | if tf.shape(yidx)[0] != tf.shape(self.sizefactors)[0]: 173 | self.sizefactors = tf.gather(self.sizefactors, yidx) 174 | else: 175 | yidx = None 176 | super().__init__(omnibus, kernel, yidx) 177 | 178 | def _fit_null(self, y: tf.Tensor) -> NullModel: 179 | scaledy = tf.cast(y, tf.float64) / self.sizefactors 180 | res = minimize( 181 | lambda *args: self._negative_negbinom_loglik(*args).numpy(), 182 | x0=[ 183 | tf.math.log(tf.reduce_mean(scaledy)), 184 | tf.math.log( 185 | tf.maximum(1e-8, self._moments_dispersion_estimate(scaledy, self.sizefactors)) 186 | ), 187 | ], 188 | args=(tf.cast(y, tf.float64), self.sizefactors), 189 | jac=lambda *args: self._grad_negative_negbinom_loglik(*args).numpy(), 190 | method="bfgs", 191 | ) 192 | mu = tf.exp(res.x[0]) * self.sizefactors 193 | alpha = tf.exp(res.x[1]) 194 | return self.NullModel(mu, alpha) 195 | 196 | def _test( 197 | self, y: tf.Tensor, nullmodel: NullModel 198 | ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: 199 | return self._do_test( 200 | self._K, 201 | to_default_float(y), 202
| to_default_float(nullmodel.alpha), 203 | to_default_float(nullmodel.mu), 204 | ) 205 | 206 | @staticmethod 207 | @tf.function(experimental_compile=True) 208 | def _do_test( 209 | K: tf.Tensor, rawy: tf.Tensor, alpha: tf.Tensor, mu: tf.Tensor 210 | ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: 211 | W = mu / (1 + alpha * mu) 212 | stat = 0.5 * tf.reduce_sum( 213 | (rawy / mu - 1) * W * tf.tensordot(K, W * (rawy / mu - 1), axes=(-1, -1)), axis=-1 214 | ) 215 | 216 | P = tf.linalg.diag(W) - W[:, tf.newaxis] * W[tf.newaxis, :] / tf.reduce_sum(W) 217 | PK = W[:, tf.newaxis] * K - W[:, tf.newaxis] * ((W[tf.newaxis, :] @ K) / tf.reduce_sum(W)) 218 | trace_PK = tf.linalg.trace(PK) 219 | e_tilde = 0.5 * trace_PK 220 | I_tau_tau = 0.5 * tf.reduce_sum(PK * PK, axis=(-2, -1)) 221 | I_tau_theta = 0.5 * tf.reduce_sum(PK * P, axis=(-2, -1)) 222 | I_theta_theta = 0.5 * tf.reduce_sum(tf.square(P), axis=(-2, -1)) 223 | I_tau_tau_tilde = I_tau_tau - tf.square(I_tau_theta) / I_theta_theta 224 | 225 | return stat, e_tilde, I_tau_tau_tilde 226 | 227 | @staticmethod 228 | @tf.function(experimental_compile=True) 229 | def _moments_dispersion_estimate(y, sizefactors): 230 | """ 231 | This is lifted from the first DESeq paper 232 | """ 233 | v = tf.math.reduce_variance(y) 234 | m = tf.reduce_mean(y) 235 | s = tf.reduce_mean(1 / sizefactors) 236 | return (v - s * m) / tf.square(m) 237 | 238 | @staticmethod 239 | @tf.function(experimental_compile=True) 240 | def _negative_negbinom_loglik(params, counts, sizefactors): 241 | logmu = params[0] 242 | logalpha = params[1] 243 | mus = tf.exp(logmu) * sizefactors 244 | nexpalpha = tf.exp(-logalpha) 245 | ct_plus_alpha = counts + nexpalpha 246 | return -tf.reduce_sum( 247 | tf.math.lgamma(ct_plus_alpha) 248 | - tf.math.lgamma(nexpalpha) 249 | - ct_plus_alpha * tf.math.log(1 + tf.exp(logalpha) * mus) 250 | + counts * logalpha 251 | + counts * tf.math.log(mus) 252 | - tf.math.lgamma(counts + 1) 253 | ) 254 | 255 | @staticmethod 256 | @tf.function(experimental_compile=True) 257 | def _grad_negative_negbinom_loglik(params, counts, sizefactors): 258 | logmu = params[0] 259 | logalpha = params[1] 260 | mu = tf.exp(logmu) 261 | mus = mu * sizefactors 262 | nexpalpha = tf.exp(-logalpha) 263 | one_alpha_mu = 1 + tf.exp(logalpha) * mus 264 | 265 | grad0 = tf.reduce_sum((counts - mus) / one_alpha_mu) # d/d_mu 266 | grad1 = tf.reduce_sum( 267 | nexpalpha 268 | * ( 269 | tf.math.digamma(nexpalpha) 270 | - tf.math.digamma(counts + nexpalpha) 271 | + tf.math.log(one_alpha_mu) 272 | ) 273 | + (counts - mus) / one_alpha_mu 274 | ) # d/d_alpha 275 | return -tf.convert_to_tensor((grad0, grad1)) 276 | 277 | 278 | class NormalScoreTest(ScoreTest): 279 | @dataclass 280 | class NullModel(ScoreTest.NullModel): 281 | mu: tf.Tensor 282 | sigmasq: tf.Tensor 283 | 284 | def _fit_null(self, y: tf.Tensor) -> NullModel: 285 | return self.NullModel(tf.reduce_mean(y), tf.math.reduce_variance(y)) 286 | 287 | def _test( 288 | self, y: tf.Tensor, nullmodel: NullModel 289 | ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: 290 | return self._do_test( 291 | self._K, 292 | to_default_float(y), 293 | to_default_float(nullmodel.sigmasq), 294 | to_default_float(nullmodel.mu), 295 | ) 296 | 297 | @staticmethod 298 | @tf.function(experimental_compile=True) 299 | def _do_test( 300 | K: tf.Tensor, rawy: tf.Tensor, sigmasq: tf.Tensor, mu: tf.Tensor 301 | ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: 302 | W = 1 / sigmasq # W^-1 303 | stat = 0.5 * tf.reduce_sum( 304 | (rawy - mu) * W *
tf.tensordot(K, W * (rawy - mu), axes=(-1, -1)), axis=-1 305 | ) 306 | 307 | P = tf.linalg.diag(W) - W[:, tf.newaxis] * W[tf.newaxis, :] / tf.reduce_sum(W) 308 | PK = W[:, tf.newaxis] * K - W[:, tf.newaxis] * ((W[tf.newaxis, :] @ K) / tf.reduce_sum(W)) 309 | trace_PK = tf.linalg.trace(PK) 310 | e_tilde = 0.5 * trace_PK 311 | I_tau_tau = 0.5 * tf.reduce_sum(PK * PK, axis=(-2, -1)) 312 | I_tau_theta = 0.5 * tf.reduce_sum(PK * P, axis=(-2, -1)) 313 | I_theta_theta = 0.5 * tf.reduce_sum(tf.square(P), axis=(-2, -1)) 314 | I_tau_tau_tilde = I_tau_tau - tf.square(I_tau_theta) / I_theta_theta 315 | 316 | return stat, e_tilde, I_tau_tau_tilde 317 | -------------------------------------------------------------------------------- /SpatialDE/_internal/sm_kernel.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from gpflow import Parameter 7 | from gpflow.kernels import Stationary, Sum 8 | from gpflow.utilities import positive 9 | from gpflow.utilities.ops import square_distance, difference_matrix 10 | from gpflow.utilities import to_default_float 11 | 12 | 13 | class Spectral(Stationary): 14 | def __init__(self, variance=1.0, lengthscales=1, periods=1, **kwargs): 15 | super().__init__( 16 | variance=variance, 17 | lengthscales=Parameter(lengthscales, transform=positive(lower=to_default_float(1e-6))), 18 | **kwargs, 19 | ) 20 | self.periods = Parameter(periods, transform=positive(lower=to_default_float(1e-6))) 21 | 22 | self._validate_ard_active_dims(self.periods) 23 | 24 | @property 25 | def ard(self) -> bool: 26 | return self.lengthscales.shapes.ndims > 0 or self.periods.shapes.ndims > 0 27 | 28 | def K(self, X, X2=None): 29 | return self.variance * self.K_novar(X, X2) 30 | 31 | def K_novar(self, X, X2=None): 32 | Xs = X / self.periods 33 | if X2 is not None: 34 | X2s = X2 / self.periods 35 | else: 36 | X2s = None 37 | dist = difference_matrix(Xs, X2s) 38 | cospart = tf.cos(2 * np.pi * tf.reduce_sum(dist, axis=-1)) 39 | 40 | dist = square_distance(self.scale(X), self.scale(X2)) 41 | exppart = tf.exp(-0.5 * dist) 42 | 43 | return cospart * exppart 44 | 45 | def K_diag(self, X): 46 | return tf.fill(tf.shape(X)[:-1], tf.squeeze(self.variance)) 47 | 48 | def log_power_spectrum(self, s): 49 | if not tf.is_tensor(s): 50 | s = tf.convert_to_tensor(s, dtype=self.variance.dtype) 51 | else: 52 | s = tf.cast(s, dtype=self.variance.dtype) 53 | if s.ndim < 2: 54 | s = tf.expand_dims(s, 1) 55 | loc = tf.broadcast_to(self.periods, (s.shape[1],)) 56 | scale_diag = tf.broadcast_to(self.lengthscales / (0.25 * np.pi**2), (s.shape[1],)) 57 | mvd = tfp.distributions.MultivariateNormalDiag(loc=1 / loc, scale_diag=1 / scale_diag) 58 | return tf.math.log(tf.constant(0.5, dtype=self.variance.dtype)) + tf.reduce_logsumexp( 59 | [mvd.log_prob(s), mvd.log_prob(-s)], axis=0 60 | ) 61 | 62 | 63 | class SpectralMixture(Sum): 64 | def __init__(self, kernels=None, dimnames=None, **kwargs): 65 | if kernels is None: 66 | kernels = [Spectral()] 67 | elif isinstance(kernels, list): 68 | if not all([isinstance(k, Spectral) for k in kernels]): 69 | raise ValueError("Not all kernels are Spectral") 70 | else: 71 | kernels = [Spectral() for _ in range(kernels)] 72 | super().__init__(kernels) 73 | 74 | if dimnames is None: 75 | self.dimnames = ("X", "Y") 76 | else: 77 | self.dimnames = dimnames 78 | 79 | def K_novar(self, X, X2=None): 80 | return self._reduce([k.K_novar(X, 
X2) for k in self.kernels]) 81 | 82 | def log_power_spectrum(self, s): 83 | dens = [] 84 | for k in self.kernels: 85 | dens.append(k.variance * k.log_power_spectrum(s)) 86 | return tf.reduce_logsumexp(dens, axis=0) 87 | 88 | def plot_power_spectrum(self, xlim: float = None, ylim: float = None, **kwargs): 89 | if xlim is None or ylim is None: 90 | lengthscales = tf.convert_to_tensor([k.lengthscales for k in self.kernels]) 91 | if lengthscales.ndim < 2: 92 | lengthscales = tf.tile(tf.expand_dims(lengthscales, axis=1), (1, 2)) 93 | periods = tf.convert_to_tensor([k.periods for k in self.kernels]) 94 | if periods.ndim < 2: 95 | periods = tf.tile(tf.expand_dims(periods, axis=1), (1, 2)) 96 | maxfreq = tf.math.argmin(periods, axis=0, output_type=tf.int32) 97 | maxfreq = tf.stack([maxfreq, tf.range(maxfreq.shape[0], dtype=tf.int32)], axis=1) 98 | limits = 1 / tf.gather_nd(periods, maxfreq) 99 | limits += 2 * tf.gather_nd(lengthscales, maxfreq) 100 | if xlim is None: 101 | xlim = limits[0].numpy() 102 | else: 103 | xlim = np.asarray([xlim])[0] 104 | if ylim is None: 105 | ylim = limits[1].numpy() 106 | else: 107 | ylim = np.asarray([ylim])[0] 108 | 109 | limtype = np.promote_types(xlim.dtype, ylim.dtype) 110 | xlim = xlim.astype(limtype) 111 | ylim = ylim.astype(limtype) 112 | 113 | dim = max( 114 | [k.lengthscales.ndim for k in self.kernels] + [k.periods.ndim for k in self.kernels] 115 | ) 116 | fig, ax = plt.subplots() 117 | if dim < 1: 118 | x = tf.linspace(0, xlim, 1000) 119 | ps = self.log_power_spectrum(x) 120 | ax.plot(x, ps) 121 | ax.set_xlim((0, xlim.numpy())) 122 | ax.set_xlabel("frequency") 123 | ax.set_ylabel("log spectral density") 124 | else: 125 | x, y = tf.meshgrid(tf.linspace(0.0, xlim, 1000), tf.linspace(0.0, ylim, 1000)) 126 | ps = tf.reshape( 127 | self.log_power_spectrum( 128 | tf.stack([tf.reshape(x, (-1,)), tf.reshape(y, (-1,))], axis=1) 129 | ), 130 | x.shape, 131 | ) 132 | pos = ax.contourf( 133 | x, y, ps, levels=tf.linspace(tf.reduce_min(ps), tf.reduce_max(ps), 100), **kwargs 134 | ) 135 | ax.set_xlabel(f"{self.dimnames[0]} frequency") 136 | ax.set_ylabel(f"{self.dimnames[1]} frequency") 137 | cbar = fig.colorbar(pos) 138 | cbar.ax.set_ylabel("log spectral density") 139 | return ax 140 | 141 | def __iter__(self): 142 | self._i = 0 143 | return self 144 | 145 | def __next__(self): 146 | if self._i < len(self.kernels): 147 | k = self.kernels[self._i] 148 | self._i += 1 149 | return k 150 | else: 151 | raise StopIteration 152 | -------------------------------------------------------------------------------- /SpatialDE/_internal/svca.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Optional, List 3 | 4 | import numpy as np 5 | import scipy 6 | 7 | import tensorflow as tf 8 | import gpflow 9 | from gpflow import default_float 10 | from gpflow.utilities import to_default_float 11 | 12 | from .util import gower_factor, quantile_normalize 13 | from .score_test import ScoreTest 14 | from .optimizer import MultiScipyOptimizer 15 | 16 | 17 | class SVCA(tf.Module): 18 | _fracvar = namedtuple("FractionVariance", "intrinsic environmental noise") 19 | _fracvar_interact = namedtuple("FractionVraiance", "intrinsic environmental interaction noise") 20 | 21 | def __init__( 22 | self, 23 | expression: np.ndarray, 24 | X: np.ndarray, 25 | sizefactors: np.ndarray, 26 | kernel: Optional[gpflow.kernels.Kernel] = None, 27 | ): 28 | self.expression = to_default_float(expression) 29 | self.sizefactors 
= to_default_float(sizefactors) 30 | self._ncells, self._ngenes = tf.shape(self.expression) 31 | self._X = to_default_float(X) 32 | 33 | self._current_expression = tf.Variable( 34 | tf.zeros((self._ncells, self._ngenes - 1), dtype=default_float()), trainable=False 35 | ) 36 | self.intrinsic_variance_matrix = tf.Variable( 37 | tf.zeros((self._ncells, self._ncells), dtype=default_float()), trainable=False 38 | ) 39 | self._sigmas = gpflow.Parameter( 40 | tf.ones((4,), dtype=default_float()), transform=gpflow.utilities.positive(lower=1e-9) 41 | ) 42 | self._currentgene = tf.Variable(0, dtype=tf.int32, trainable=False) 43 | self.muhat = tf.Variable(tf.ones((self._ncells,), dtype=default_float()), trainable=False) 44 | 45 | self.kernel = kernel 46 | 47 | self._opt = MultiScipyOptimizer(lambda: -self.profile_log_reml(), self.trainable_variables) 48 | self._use_interactions = tf.Variable(False, dtype=tf.bool, trainable=False) 49 | 50 | self._old_interactions = False 51 | 52 | @property 53 | def sizefactors(self) -> np.ndarray: 54 | return self._sizefactors 55 | 56 | @sizefactors.setter 57 | def sizefactors(self, sizefactors: np.ndarray): 58 | self._sizefactors = np.squeeze(sizefactors) 59 | if len(self._sizefactors.shape) != 1: 60 | raise ValueError("Size factors vector must have rank 1") 61 | self._log_sizefactors = tf.squeeze(tf.math.log(to_default_float(sizefactors))) 62 | 63 | @property 64 | def kernel(self): 65 | return self._kernel 66 | 67 | @kernel.setter 68 | def kernel(self, kern): 69 | self._kernel = kern 70 | self._init_kern = gpflow.utilities.read_values(kern) 71 | 72 | @property 73 | def currentgene(self) -> int: 74 | return self._currentgene.numpy() 75 | 76 | @currentgene.setter 77 | def currentgene(self, gene: int): 78 | gene = tf.cast(gene, self._ngenes.dtype) 79 | if gene < 0 or gene >= self._ngenes: 80 | raise IndexError(f"gene must be between 0 and {self._ngenes}") 81 | 82 | self._currentgene.assign(gene) 83 | idx = [] if gene == 0 else tf.range(gene) 84 | idx = ( 85 | tf.concat((idx, tf.range(gene + 1, self._ngenes)), axis=0) 86 | if gene < self._ngenes - 1 87 | else idx 88 | ) 89 | self._current_expression.assign( 90 | tf.gather(self.expression, idx, axis=1) / self._sizefactors[:, tf.newaxis] 91 | ) 92 | 93 | intvar = tf.matmul(self._current_expression, self._current_expression, transpose_b=True) 94 | self.intrinsic_variance_matrix.assign(intvar / gower_factor(intvar)) 95 | 96 | muhat = self.expression[:, gene] 97 | muhat = tf.where(muhat < 1, 1, muhat) # avoid problems with log link 98 | self.muhat.assign(muhat) 99 | 100 | self._sigmas.assign( 101 | tf.fill( 102 | (4,), 103 | 0.25 104 | * tf.math.reduce_variance( 105 | tf.math.log(self.expression[:, gene] + 1) - self._log_sizefactors 106 | ), 107 | ) 108 | ) 109 | if self._kernel is not None: 110 | gpflow.utilities.multiple_assign(self._kernel, self._init_kern) 111 | 112 | def profile_log_reml(self): 113 | Vchol = tf.linalg.cholesky(self.V()) 114 | r = self._r(Vchol) 115 | quad = tf.tensordot( 116 | r, 117 | tf.squeeze(tf.linalg.cholesky_solve(Vchol, r[:, tf.newaxis])), 118 | axes=(-1, -1), 119 | ) 120 | ldet = tf.reduce_sum(tf.math.log(tf.linalg.diag_part(Vchol))) 121 | ldet2 = tf.math.log( 122 | tf.reduce_sum( 123 | tf.linalg.cholesky_solve(Vchol, tf.ones((self._ncells, 1), dtype=default_float())) 124 | ) 125 | ) 126 | 127 | return -ldet - 0.5 * quad - 0.5 * ldet2 128 | 129 | def _alphahat(self, Vchol): 130 | Vinvnu = tf.linalg.cholesky_solve(Vchol, self.nu[:, tf.newaxis]) 131 | VinvX = tf.linalg.cholesky_solve(Vchol, 
tf.ones((self._ncells, 1), dtype=default_float())) 132 | return tf.reduce_sum(Vinvnu) / tf.reduce_sum(VinvX) 133 | 134 | def alphahat(self): 135 | return self._alphahat(tf.linalg.cholesky(self.V())) 136 | 137 | def _betahat(self, Vchol): 138 | return tf.squeeze(self.D() @ tf.linalg.cholesky_solve(Vchol, self._r(Vchol)[:, tf.newaxis])) 139 | 140 | def betahat(self): 141 | return self._betahat(tf.linalg.cholesky(self.V())) 142 | 143 | @property 144 | def nu(self): 145 | return ( 146 | tf.math.log(self.muhat) 147 | + self.expression[:, self._currentgene] / self.muhat 148 | - 1 149 | - self._log_sizefactors 150 | ) 151 | 152 | def _r(self, Vchol): 153 | return self.nu - self._alphahat(Vchol) 154 | 155 | def r(self): 156 | Vchol = tf.linalg.cholesky(self.V()) 157 | return self._r(Vchol) 158 | 159 | @tf.function 160 | def estimate_muhat(self): 161 | Vchol = tf.linalg.cholesky(self.V()) 162 | self.muhat.assign( 163 | tf.exp(self._alphahat(Vchol) + self._betahat(Vchol) + self._log_sizefactors) 164 | ) 165 | 166 | def V(self): 167 | V = self.D() 168 | V = tf.linalg.set_diag(V, tf.linalg.diag_part(V) + 1 / self.muhat) 169 | return V 170 | 171 | # no property here, apparently tf.function has a problem with conditionals in properties 172 | def D(self): 173 | var = self.intrinsic_variance + self.environmental_variance 174 | var = tf.linalg.set_diag(var, tf.linalg.diag_part(var) + self.noise_variance) 175 | if self._use_interactions: 176 | var += self.interaction_variance 177 | return var 178 | 179 | def dV_dsigma(self): 180 | if self._use_interactions: 181 | return tf.stack( 182 | ( 183 | self.intrinsic_variance_matrix, 184 | self.environmental_variance_matrix, 185 | self.interaction_variance_matrix, 186 | tf.eye(self._ncells, dtype=default_float()), 187 | ), 188 | axis=0, 189 | ) 190 | else: 191 | return tf.stack( 192 | ( 193 | self.intrinsic_variance_matrix, 194 | self.environmental_variance_matrix, 195 | tf.eye(self._ncells, dtype=default_float()), 196 | ), 197 | axis=0, 198 | ) 199 | 200 | def fraction_variance(self): 201 | intrinsic = gower_factor(self.intrinsic_variance) 202 | environ = gower_factor(self.environmental_variance) 203 | noise = self.noise_variance 204 | 205 | totalgower = intrinsic + environ + noise 206 | if self._use_interactions: 207 | interact = gower_factor(self.interaction_variance) 208 | totalgower += interact 209 | 210 | return self._fracvar_interact( 211 | (intrinsic / totalgower).numpy(), 212 | (environ / totalgower).numpy(), 213 | (interact / totalgower).numpy(), 214 | (noise / totalgower).numpy(), 215 | ) 216 | else: 217 | return self._fracvar( 218 | (intrinsic / totalgower).numpy(), 219 | (environ / totalgower).numpy(), 220 | (noise / totalgower).numpy(), 221 | ) 222 | 223 | @property 224 | def environmental_variance_matrix(self): 225 | return self.kernel.K(self._X) 226 | 227 | @property 228 | def interaction_variance_matrix(self): 229 | envmat = self.environmental_variance_matrix 230 | intmat = envmat @ tf.matmul(self.intrinsic_variance_matrix, envmat, transpose_b=True) 231 | return intmat / gower_factor(intmat) 232 | 233 | @property 234 | def intrinsic_variance(self): 235 | return self._sigmas[0] * self.intrinsic_variance_matrix 236 | 237 | @property 238 | def environmental_variance(self): 239 | return self._sigmas[1] * self.environmental_variance_matrix 240 | 241 | @property 242 | def interaction_variance(self): 243 | return self._sigmas[2] * self.interaction_variance_matrix 244 | 245 | @property 246 | def noise_variance(self): 247 | return self._sigmas[3] 248 | 249 | 
def use_interactions(self, interact: bool): 250 | self._old_interactions = self._use_interactions.numpy() 251 | self._use_interactions.assign(interact) 252 | return self 253 | 254 | def __enter__(self): 255 | return self 256 | 257 | def __exit__(self, *args): 258 | self._use_interactions.assign(self._old_interactions) 259 | 260 | def optimize(self, abstol: float = 1e-5, maxiter: int = 1000): 261 | oldsigmas = self._sigmas.numpy() 262 | for i in range(maxiter): 263 | self._opt.minimize() 264 | sigmas = self._sigmas.numpy() 265 | if np.all(np.abs(sigmas - oldsigmas) < abstol): 266 | break 267 | oldsigmas = sigmas 268 | self.estimate_muhat() 269 | 270 | 271 | class SVCAInteractionScoreTest(ScoreTest): 272 | def __init__( 273 | self, 274 | expression_matrix: np.ndarray, 275 | X: np.ndarray, 276 | sizefactors: np.ndarray, 277 | kernel: Optional[gpflow.kernels.Kernel] = None, 278 | ): 279 | super().__init__() 280 | self._model = SVCA(expression_matrix, X, sizefactors, kernel) 281 | self._model.use_interactions(False) 282 | 283 | @property 284 | def kernel(self) -> List[gpflow.kernels.Kernel]: 285 | if self._model.kernel is not None: 286 | return [self._model.kernel] 287 | else: 288 | return [] 289 | 290 | @kernel.setter 291 | def kernel(self, kernel: gpflow.kernels.Kernel): 292 | self._model.kernel = kernel 293 | 294 | def _fit_null(self, y): 295 | self._model.currentgene = y 296 | self._model.optimize() 297 | return None 298 | 299 | def _test(self, y, nullmodel: None): 300 | return self._do_test( 301 | self._model.r(), 302 | self._model.V(), 303 | self._model.dV_dsigma(), 304 | self._model.interaction_variance_matrix, 305 | ) 306 | 307 | @staticmethod 308 | @tf.function(experimental_compile=True) 309 | def _do_test(residual, V, dV, interaction_mat): 310 | cholV = tf.linalg.cholesky(V) 311 | Vinvres = tf.squeeze(tf.linalg.cholesky_solve(cholV, residual[:, tf.newaxis])) 312 | stat = 0.5 * tf.tensordot( 313 | Vinvres, tf.tensordot(interaction_mat, Vinvres, axes=(-1, -1)), axes=(-1, -1) 314 | ) 315 | 316 | Vinv_int = tf.linalg.cholesky_solve(cholV, interaction_mat) 317 | Vinv_dV = tf.linalg.cholesky_solve(cholV[tf.newaxis, ...], dV) 318 | 319 | Vinv_X = tf.squeeze( 320 | tf.linalg.cholesky_solve( 321 | cholV, tf.ones((tf.shape(residual)[0], 1), dtype=default_float()) 322 | ) 323 | ) 324 | hatMat = Vinv_X[:, tf.newaxis] * Vinv_X[tf.newaxis, :] / tf.reduce_sum(Vinv_X) 325 | 326 | P_int = Vinv_int - hatMat @ interaction_mat 327 | P_dV = Vinv_dV - hatMat[tf.newaxis, ...] 
@ dV 328 | 329 | e_tilde = 0.5 * tf.linalg.trace(P_int) 330 | I_tau_tau = tf.reduce_sum(tf.transpose(P_int) * P_int) 331 | I_tau_theta = tf.reduce_sum(tf.transpose(P_int) * P_dV, axis=[-2, -1]) 332 | I_theta_theta = tf.reduce_sum( 333 | tf.linalg.matrix_transpose(P_dV[tf.newaxis, ...]) * P_dV[:, tf.newaxis, ...], 334 | axis=[-2, -1], 335 | ) 336 | 337 | I_tau_tau_tilde = 0.5 * ( 338 | I_tau_tau 339 | - tf.tensordot( 340 | I_tau_theta, 341 | tf.squeeze(tf.linalg.solve(I_theta_theta, I_tau_theta[..., tf.newaxis])), 342 | axes=(-1, -1), 343 | ) 344 | ) 345 | 346 | return stat, e_tilde, I_tau_tau_tilde 347 | -------------------------------------------------------------------------------- /SpatialDE/_internal/tf_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | from scipy.sparse import issparse 5 | import tensorflow as tf 6 | 7 | from anndata import AnnData 8 | 9 | 10 | class AnnDataIterator: 11 | def __init__( 12 | self, 13 | adata: AnnData, 14 | genes: Optional[List[str]] = None, 15 | layer: Optional[str] = None, 16 | dtype: Optional[Union[np.dtype, tf.DType]] = None, 17 | ): 18 | self.adata = adata 19 | self.genes = genes 20 | if self.genes is None: 21 | self.genes = self.adata.var_names 22 | self.layer = layer 23 | if dtype is not None: 24 | outtype = dtype 25 | elif self.layer is None: 26 | outtype = self.adata.X.dtype 27 | else: 28 | outtype = self.adata.layers[layer].dtype 29 | self.output_types = (outtype, tf.string) 30 | 31 | def __call__(self): 32 | for i, g in enumerate(self.genes): 33 | if self.layer is None: 34 | data = self.adata.X[:, i] 35 | else: 36 | data = self.adata.layers[self.layer][:, i] 37 | if issparse(data): 38 | data = data.toarray() 39 | with tf.device(tf.DeviceSpec(device_type="CPU").to_string()): 40 | gene = tf.convert_to_tensor(g) 41 | yield tf.convert_to_tensor(np.squeeze(data), dtype=self.output_types[0]), gene 42 | 43 | 44 | class AnnDataDataset(tf.data.Dataset): 45 | def __new__( 46 | cls, 47 | adata: AnnData, 48 | genes: Optional[List[str]] = None, 49 | layer: Optional[str] = None, 50 | dtype: Optional[Union[np.dtype, tf.DType]] = None, 51 | ): 52 | it = AnnDataIterator(adata, genes, layer, dtype) 53 | return ( 54 | tf.data.Dataset.from_generator(it, output_types=it.output_types) 55 | .repeat(1) 56 | .prefetch(tf.data.experimental.AUTOTUNE) 57 | ) 58 | -------------------------------------------------------------------------------- /SpatialDE/_internal/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import issparse 3 | import scipy.stats 4 | import pandas as pd 5 | import tensorflow as tf 6 | 7 | import NaiveDE 8 | from anndata import AnnData 9 | 10 | from enum import Enum, auto 11 | from typing import Optional, Union 12 | 13 | import logging 14 | 15 | from .distance_cache import DistanceCache 16 | from .kernels import Linear, SquaredExponential, Cosine 17 | 18 | 19 | def get_dtype(df: pd.DataFrame, msg="Data frame"): 20 | dtys = df.dtypes.unique() 21 | if dtys.size > 1: 22 | logging.warning("%s has more than one dtype, selecting the first one" % msg) 23 | return dtys[0] 24 | 25 | 26 | def normalize_counts( 27 | adata: AnnData, 28 | sizefactorcol: Optional[str] = None, 29 | layer: Optional[str] = None, 30 | copy: bool = False, 31 | ): 32 | if copy: 33 | adata = adata.copy() 34 | 35 | if sizefactorcol is None: 36 | sizefactors = pd.DataFrame({"sizefactors": 
calc_sizefactors(adata, layer=layer)}) 37 | sizefactorcol = "np.log(sizefactors)" 38 | else: 39 | sizefactors = adata.obs 40 | X = adata.X if layer is None else adata.layers[layer] 41 | stabilized = NaiveDE.stabilize(dense_slice(X.T)) 42 | regressed = np.asarray(NaiveDE.regress_out(sizefactors, stabilized, sizefactorcol).T) 43 | if layer is None: 44 | adata.X = regressed 45 | else: 46 | adata.layers[layer] = regressed 47 | return adata 48 | 49 | 50 | def dense_slice(slice): 51 | if issparse(slice): 52 | slice = slice.toarray() 53 | else: 54 | slice = np.asarray(slice) # work around anndata.ArrayView 55 | return np.squeeze(slice) 56 | 57 | 58 | def bh_adjust(pvals): 59 | order = np.argsort(pvals) 60 | alpha = np.minimum( 61 | 1, np.maximum.accumulate(len(pvals) / np.arange(1, len(pvals) + 1) * pvals[order]) 62 | ) 63 | return alpha[np.argsort(order)] 64 | 65 | 66 | def calc_sizefactors(adata: AnnData, layer=None): 67 | X = adata.X if layer is None else adata.layers[layer] 68 | return np.asarray(X.sum(axis=1)).squeeze() 69 | 70 | 71 | def get_l_limits(cache: DistanceCache): 72 | R2 = cache.squaredEuclideanDistance 73 | R2 = R2[R2 > 1e-8] 74 | 75 | l_min = tf.sqrt(tf.reduce_min(R2)) * 2.0 76 | l_max = tf.sqrt(tf.reduce_max(R2)) 77 | 78 | return l_min, l_max 79 | 80 | 81 | def factory(kern: str, cache: DistanceCache, lengthscale: Optional[float] = None): 82 | if kern == "linear": 83 | return Linear(cache) 84 | elif kern == "SE": 85 | return SquaredExponential(cache, lengthscale=lengthscale) 86 | elif kern == "PER": 87 | return Cosine(cache, lengthscale=lengthscale) 88 | else: 89 | raise ValueError("unknown kernel") 90 | 91 | 92 | def kspace_walk(kernel_space: dict, cache: DistanceCache): 93 | for kern, lengthscales in kernel_space.items(): 94 | try: 95 | for l in lengthscales: 96 | yield factory(kern, cache, l), kern 97 | except TypeError: 98 | yield factory(kern, cache, lengthscales), kern 99 | 100 | 101 | def default_kernel_space(cache: DistanceCache): 102 | l_min, l_max = get_l_limits(cache) 103 | return { 104 | "SE": np.logspace(np.log10(l_min), np.log10(l_max), 5), 105 | "PER": np.logspace(np.log10(l_min), np.log10(l_max), 5), 106 | } 107 | 108 | 109 | def concat_tensors(tens): 110 | return tf.concat([tf.reshape(t, (-1,)) for t in tens], axis=0) 111 | 112 | 113 | def assign_concat(x, vars): 114 | offset = 0 115 | for v in vars: 116 | newval = tf.reshape(x[offset : (offset + tf.size(v))], v.shape) 117 | v.assign(newval) 118 | offset += tf.size(v) 119 | 120 | 121 | def gower_factor(mat, varcomp=1): 122 | """Gower normalization factor for covariance matric K 123 | 124 | Based on https://github.com/PMBio/limix/blob/master/limix/utils/preprocess.py 125 | """ 126 | return ( 127 | varcomp 128 | * (tf.linalg.trace(mat) - tf.reduce_sum(tf.reduce_mean(mat, axis=0))) 129 | / tf.cast(tf.shape(mat)[0] - 1, mat.dtype) 130 | ) 131 | 132 | 133 | def quantile_normalize(mat): 134 | idx = np.argsort(mat, axis=0) + 0.5 135 | return scipy.stats.norm(loc=0, scale=1).ppf(idx / mat.shape[0]) 136 | -------------------------------------------------------------------------------- /SpatialDE/_internal/util_mixture.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import tensorflow as tf 4 | 5 | 6 | @tf.function(experimental_relax_shapes=True) 7 | def prune_components(labels: tf.Tensor, pihat: tf.Tensor, threshold: tf.Tensor, everything=False): 8 | toretain = tf.squeeze(tf.where(tf.reduce_any(pihat > threshold, axis=1)), axis=1) 9 | if not 
everything: 10 | toretain = tf.range( 11 | tf.reduce_max(toretain) + 1 12 | ) # we can not prune everything during optimization, then vhat3_cumsum would be wrong 13 | return tf.cast(toretain, labels.dtype), labels 14 | return prune_labels(labels, tf.cast(toretain, labels.dtype)) 15 | 16 | 17 | @tf.function(experimental_relax_shapes=True) 18 | def prune_labels(labels: tf.Tensor, toretain: Optional[tf.Tensor] = None): 19 | if toretain is None: 20 | ulabels, _ = tf.unique(labels) 21 | toretain = tf.sort(ulabels) 22 | else: 23 | toretain = tf.sort(toretain) 24 | diffs = toretain[1:] - toretain[:-1] 25 | missing = tf.cast(tf.where(diffs > 1), labels.dtype) 26 | if tf.size(missing) > 0: 27 | missing = tf.squeeze(missing, axis=1) 28 | todrop = tf.TensorArray(labels.dtype, size=tf.size(missing), infer_shape=False) 29 | shift = tf.cast(0, labels.dtype) 30 | for i in tf.range(tf.size(missing)): 31 | m = missing[i] 32 | idx = tf.where(labels > toretain[m] - shift) 33 | shift += diffs[m] - 1 34 | labels = tf.tensor_scatter_nd_sub(labels, idx, tf.repeat(diffs[m] - 1, tf.size(idx))) 35 | todrop = todrop.write(i, tf.range(toretain[m] + 1, toretain[m] + diffs[m])) 36 | todrop = todrop.concat() 37 | if toretain[0] > 0: 38 | todrop = tf.concat((tf.range(toretain[0]), todrop), axis=0) 39 | labels = labels - toretain[0] 40 | idx = tf.squeeze( 41 | tf.sparse.to_dense( 42 | tf.sets.difference( 43 | tf.range(tf.reduce_max(toretain) + 1)[tf.newaxis, :], 44 | tf.cast(todrop[tf.newaxis, :], dtype=labels.dtype), 45 | ) 46 | ) 47 | ) 48 | else: 49 | idx = tf.cast(tf.range(tf.size(toretain)), labels.dtype) 50 | return idx, labels 51 | -------------------------------------------------------------------------------- /SpatialDE/aeh.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Tuple 2 | import warnings 3 | from dataclasses import dataclass 4 | from collections import namedtuple 5 | from collections.abc import Iterable 6 | from numbers import Real, Integral 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from gpflow import default_float, default_jitter, Parameter, set_trainable 12 | from gpflow.utilities import to_default_float, positive 13 | from gpflow.optimizers import Scipy 14 | 15 | from anndata import AnnData 16 | 17 | from ._internal.kernels import SquaredExponential 18 | from ._internal.util import normalize_counts, get_l_limits 19 | from ._internal.distance_cache import DistanceCache 20 | from ._internal.util_mixture import prune_components, prune_labels 21 | 22 | 23 | @dataclass(frozen=True) 24 | class SpatialPatternParameters: 25 | """ 26 | Parameters for automated expession histology. 27 | 28 | Args: 29 | nclasses: Maximum number of regions to consider. Defaults to the square root of the number of observations. 30 | lengthscales: List of kernel lenthscales. Defaults to a single lengthscale of the minimum distance between 31 | observations. 32 | pattern_prune_threshold: Probability threshold at which unused patterns are removed. Defaults to ``1e-6``. 33 | method: Optimization algorithm, must be known to ``scipy.optimize.minimize``. Defaults to ``l-bfgs-b``. 34 | tol: Convergence tolerance. Defaults to 1e-9. 35 | maxiter: Maximum number of iterations. Defaults to ``1000``. 36 | gamma_1: Parameter of the noise variance prior, defaults to ``1e-14``. 37 | gamma_2: Parameter of the noise variance prior, defaults to ``1e-14``. 38 | eta_1: Parameter of the Dirichlet process hyperprior, defaults to ``1``. 
39 | eta_2: Parameter of the Dirichlet process hyperprior, defaults to ``1``. 40 | """ 41 | 42 | nclasses: Optional[Integral] = None 43 | lengthscales: Optional[Union[Real, List[Real]]] = None 44 | pattern_prune_threshold: float = 1e-6 45 | method: str = "l-bfgs-b" 46 | tol: Optional[Real] = 1e-9 47 | maxiter: Integral = 1000 48 | gamma_1: Real = 1e-14 49 | gamma_2: Real = 1e-14 50 | eta_1: Real = 1 51 | eta_2: Real = 1 52 | 53 | def __post_init__(self): 54 | if self.nclasses is not None: 55 | assert not isinstance( 56 | self.lengthscales, Iterable 57 | ), "You must specify either nclasses or a list of lengthscales" 58 | if isinstance(self.lengthscales, Real): 59 | assert self.lengthscales > 0, "Lengthscales must be positive" 60 | elif self.lengthscales is not None: 61 | for l in self.lengthscales: 62 | assert l > 0, "Lengthscales must be positive" 63 | assert ( 64 | self.pattern_prune_threshold >= 0 and self.pattern_prune_threshold <= 1 65 | ), "Class pruning threshold must be between 0 and 1" 66 | assert self.method in ("l-bfgs-b", "bfgs"), "Method must be either bfgs or l-bfgs-b" 67 | if self.tol is not None: 68 | assert self.tol > 0, "Tolerance must be greater than 0" 69 | assert self.maxiter >= 1, "Maximum number of iterations must greater than or equal to 1" 70 | assert self.gamma_1 > 0, "Gamma1 hyperparameter must be positive" 71 | assert self.gamma_2 > 0, "Gamma2 hyperparameter must be positive" 72 | assert self.eta_1 > 0, "Eta1 hyperparameter must be positive" 73 | assert self.eta_2 > 0, "Eta2 hyperparameter must be positive" 74 | 75 | 76 | @dataclass(frozen=True) 77 | class SpatialPatterns: 78 | """ 79 | Results of automated expression histology. 80 | 81 | Args: 82 | converged: Whether the optimization converged. 83 | status: Status of the optimization. 84 | labels: The estimated region labels. 85 | pattern_probabilities: N_obs x N_patterns array with the estimated region probabilities for each observation. 86 | niter: Number of iterations for the optimization. 87 | elbo_trace: ELBO values at each iteration. 
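patterns: N_obs x N_patterns array with the posterior mean of each retained expression pattern (one column per pattern, matching ``pattern_probabilities``).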
88 | """ 89 | 90 | converged: bool 91 | status: str 92 | labels: np.ndarray 93 | pattern_probabilities: np.ndarray 94 | patterns: np.ndarray 95 | niter: int 96 | elbo_trace: np.ndarray 97 | 98 | 99 | class _SpatialPatterns(tf.Module): 100 | def __init__( 101 | self, 102 | X: np.ndarray, 103 | counts: np.ndarray, 104 | nclasses: Integral, 105 | lengthscales: List[Real], 106 | gamma_1: Real, 107 | gamma_2: Real, 108 | eta_1: Real, 109 | eta_2: Real, 110 | rng: np.random.Generator = np.random.default_rng(), 111 | ): 112 | self.X = to_default_float(X) 113 | self.counts = to_default_float(counts) 114 | self.nsamples, self.ngenes = tf.shape(counts) 115 | self._fnsamples, self._fngenes = to_default_float(tf.shape(counts)) 116 | self.nclasses = nclasses 117 | self._fnclasses = to_default_float(self.nclasses) 118 | 119 | self.gamma_1 = to_default_float(gamma_1) 120 | self.gamma_2 = to_default_float(gamma_2) 121 | self.eta_1 = to_default_float(eta_1) 122 | self.eta_2 = to_default_float(eta_2) 123 | 124 | dcache = DistanceCache(X) 125 | if lengthscales is None: 126 | l_min, l_max = get_l_limits(dcache) 127 | lengthscales = [0.5 * l_min] 128 | elif not isinstance(lengthscales, Iterable): 129 | lengthscales = [lengthscales] 130 | self.kernels = [] 131 | Kernel = namedtuple("DecomposedKernel", "Lambda U") 132 | 133 | lcounts = np.unique(lengthscales, return_counts=True) 134 | for l, c in zip(*lcounts): 135 | k = SquaredExponential(dcache, lengthscale=l).K() 136 | S, U = tf.linalg.eigh(k) 137 | self.kernels.extend([Kernel(S, U)] * c) 138 | if len(self.kernels) == 1: 139 | self.kernels = self.kernels * nclasses 140 | 141 | self.phi = Parameter( 142 | rng.uniform(-0.01, 0.01, (self.ngenes, self.nclasses)), dtype=default_float() 143 | ) # to break ties, otherwise all gradients are the same 144 | self.etahat_2 = Parameter( 145 | eta_2 + default_jitter(), dtype=default_float(), transform=positive(lower=eta_2) 146 | ) 147 | self.gammahat_2 = Parameter( 148 | self.gammahat_1, 149 | dtype=default_float(), 150 | transform=positive(lower=gamma_2), 151 | ) 152 | 153 | @property 154 | def etahat_1(self): 155 | return self.eta_1 + self._fnclasses - 1 156 | 157 | @property 158 | def pihat(self): 159 | return tf.nn.softmax(self.phi, axis=1) 160 | 161 | @property 162 | def gammahat_1(self): 163 | return self.gamma_1 + 0.5 * self._fnsamples * self._fngenes 164 | 165 | @property 166 | def _sigmahat(self): 167 | return self.gammahat_1 / self.gammahat_2 168 | 169 | def Sigma_hat_inv(self, c=0): 170 | k = self.kernels[c] 171 | EigUt = k.U * ((self._sigmahat * self._N(c) * k.Lambda + 1) / k.Lambda)[tf.newaxis, :] 172 | return tf.matmul(k.U, EigUt, transpose_b=True) 173 | 174 | def Sigma_hat(self, c=0): 175 | k = self.kernels[c] 176 | EigUt = k.U * (k.Lambda / (self._sigmahat * self._N(c) * k.Lambda + 1))[tf.newaxis, :] 177 | return tf.matmul(k.U, EigUt, transpose_b=True) 178 | 179 | @property 180 | def mu_hat(self): 181 | ybar = self._ybar() 182 | return tf.stack( 183 | [self._mu_hat(c, ybar=ybar[:, c]) for c in tf.range(self.nclasses)], 184 | axis=1, 185 | ) 186 | 187 | def _mu_hat(self, c=None, Sigma_hat=None, ybar=None): 188 | assert c is not None or Sigma_hat is not None and ybar is not None 189 | if Sigma_hat is None: 190 | Sigma_hat = self.Sigma_hat(c) 191 | if ybar is None: 192 | ybar = self._ybar(c) 193 | return self._sigmahat * tf.tensordot(Sigma_hat, ybar, axes=(-1, -1)) 194 | 195 | def _ybar(self, c=None): 196 | if c is None: 197 | return self.counts @ self.pihat 198 | else: 199 | return tf.tensordot(self.counts, 
self.pihat[:, c], axes=(-1, -1)) 200 | 201 | def _N(self, c=None): 202 | if c is None: 203 | return tf.reduce_sum(self.pihat, axis=0) 204 | else: 205 | return tf.reduce_sum(self.pihat[:, c]) 206 | 207 | @property 208 | def _lhat(self): 209 | return tf.math.log(self.gammahat_2) - tf.math.digamma(self.gammahat_1) 210 | 211 | @property 212 | def _alphahat(self): 213 | return self.etahat_1 / self.etahat_2 214 | 215 | @property 216 | def _alphahat1(self): 217 | return 1 + self._N()[:-1] 218 | 219 | @property 220 | def _alphahat2(self): 221 | pihat_cumsum = tf.cumsum(tf.reduce_sum(self.pihat, axis=0), reverse=True) 222 | return pihat_cumsum[1:] + self._alphahat 223 | 224 | @property 225 | def _vhat2(self): 226 | return tf.math.digamma(self._alphahat1) - tf.math.digamma(self._alphahat1 + self._alphahat2) 227 | 228 | @property 229 | def _vhat3(self): 230 | return tf.math.digamma(self._alphahat2) - tf.math.digamma(self._alphahat1 + self._alphahat2) 231 | 232 | def elbo(self): 233 | pihat = self.pihat 234 | dotcounts = tf.reduce_sum(tf.square(self.counts), axis=0) 235 | sigmahat = self._sigmahat 236 | ybar = self._ybar() 237 | lhat = self._lhat 238 | N = self._N() 239 | 240 | term1 = 0.5 * (self._fnsamples * self._fngenes * lhat + sigmahat * tf.reduce_sum(dotcounts)) 241 | 242 | term2 = 0 243 | for c in range(self.nclasses): 244 | k = self.kernels[c] 245 | UTybar = tf.tensordot(k.U, ybar[:, c], axes=(0, 0)) 246 | Lambdahat = sigmahat * N[c] * k.Lambda + 1 247 | 248 | ybar_mu = tf.square(sigmahat) * tf.tensordot( 249 | UTybar, k.Lambda / Lambdahat * UTybar, axes=(-1, -1) 250 | ) 251 | pimu = N[c] * ( 252 | sigmahat**3 253 | * tf.tensordot(UTybar * tf.square(k.Lambda / Lambdahat), UTybar, axes=(-1, -1)) 254 | + tf.reduce_sum(k.Lambda / Lambdahat) 255 | ) 256 | muinv_ybar = tf.square(sigmahat) * tf.tensordot( 257 | UTybar * k.Lambda / tf.square(Lambdahat), UTybar, axes=(-1, -1) 258 | ) 259 | trace_inv = tf.reduce_sum(1 / Lambdahat) 260 | logdet = tf.reduce_sum(tf.math.log(Lambdahat)) 261 | 262 | term2 += ybar_mu - 0.5 * (pimu + muinv_ybar + trace_inv + logdet) 263 | 264 | term3 = tf.reduce_sum(N[:-1] * self._vhat2) 265 | 266 | vhat3_cumsum = tf.concat(((0,), tf.cumsum(self._vhat3)), axis=0) 267 | term4 = tf.reduce_sum(N * vhat3_cumsum) 268 | 269 | term5 = tf.reduce_sum((self._alphahat - self._alphahat2) * self._vhat3) 270 | term6 = tf.reduce_sum(pihat * self.phi) 271 | 272 | term7 = tf.reduce_sum(tf.reduce_logsumexp(self.phi, axis=1)) 273 | term8 = ( 274 | -self.gammahat_1 * (1 - tf.math.log(self.gammahat_2)) 275 | + tf.math.lgamma(self.gammahat_1) 276 | + (self.gammahat_1 + 1) * lhat 277 | ) 278 | term9 = tf.reduce_sum((self._alphahat1 - 1) * self._vhat2) 279 | term10 = tf.reduce_sum(tf.math.lbeta(tf.stack((self._alphahat1, self._alphahat2), axis=1))) 280 | term11 = ( 281 | -(self.eta_1 + self._fnclasses) * tf.math.log(self.etahat_2) 282 | + (self.etahat_2 - self.eta_2) * self._alphahat 283 | ) 284 | 285 | elbo = ( 286 | -term1 + term2 + term3 + term4 + term5 - term6 + term7 + term8 - term9 + term10 + term11 287 | ) / (self._fnsamples * self._fnclasses * self._fngenes) 288 | return elbo 289 | 290 | 291 | def spatial_patterns( 292 | adata: AnnData, 293 | genes: Optional[List[str]] = None, 294 | normalized=False, 295 | spatial_key="spatial", 296 | layer: Optional[str] = None, 297 | params: SpatialPatternParameters = SpatialPatternParameters(), 298 | rng: np.random.Generator = np.random.default_rng(), 299 | copy: bool = False, 300 | ) -> Tuple[SpatialPatterns, Union[AnnData, None]]: 301 | """ 302 | 
Detect spatial patterns of gene expression and assign genes to patterns. 303 | 304 | Uses a Gaussian process mixture. A Dirichlet process prior allows 305 | to automatically determine the number of distinct regions in the dataset. 306 | 307 | Args: 308 | adata: The annotated data matrix. 309 | genes: List of genes to base the analysis on. Defaults to all genes. 310 | normalized: Whether the data are already normalized to an approximately Gaussian likelihood. 311 | If ``False``, they will be normalized using the workflow from Svensson et al, 2018. 312 | spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored. 313 | layer: Name of the AnnData object layer to use. By default ``adata.X`` is used. 314 | params: Parameters for the algorithm, e.g. prior distributions, spatial smoothness, etc. 315 | rng: Random number generator. 316 | copy: Whether to return a copy of ``adata`` with results or write the results into ``adata`` 317 | in-place. 318 | 319 | Returns: 320 | A tuple. The first element is a :py:class:`SpatialPatterns`, the second is ``None`` if ``copy == False`` 321 | or an ``AnnData`` object. Patterns will be in ``adata.obs["spatial_pattern_0"]``, ..., 322 | ``adata.obs["spatial_pattern_n"]``. 323 | """ 324 | if not normalized and genes is None: 325 | warnings.warn( 326 | "normalized is False and no genes are given. Assuming that adata contains complete data set, will normalize and fit a GP for every gene." 327 | ) 328 | data = normalize_counts(adata, copy=True) if not normalized else adata 329 | if genes is not None: 330 | data = data[:, genes] 331 | 332 | X = data.obsm[spatial_key] 333 | counts = data.X if layer is None else adata.layers[layer] 334 | 335 | # This is important, we only care about co-expression, not absolute levels. 
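    # Concretely, each gene is centered and scaled to unit variance across observations; the two
    # statements below are the TensorFlow analogue of the NumPy sketch (for illustration only)
    #   counts = (counts - counts.mean(axis=0)) / counts.std(axis=0)
    # so the GP mixture is driven by relative spatial co-variation rather than expression magnitude.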
336 | counts = counts - tf.reduce_mean(counts, axis=0) 337 | counts = counts / tf.math.reduce_std(counts, axis=0) 338 | 339 | nclasses = params.nclasses 340 | if nclasses is None: 341 | if isinstance(params.lengthscales, Iterable): 342 | nclasses = len(params.lengthscales) 343 | else: 344 | nclasses = int(np.ceil(np.sqrt(data.n_vars))) 345 | 346 | patterns = _SpatialPatterns( 347 | X, 348 | counts, 349 | nclasses, 350 | params.lengthscales, 351 | params.gamma_1, 352 | params.gamma_2, 353 | params.eta_1, 354 | params.eta_2, 355 | rng, 356 | ) 357 | opt = Scipy() 358 | elbo_trace = [patterns.elbo()] 359 | res = opt.minimize( 360 | lambda: -patterns.elbo(), 361 | patterns.trainable_variables, 362 | method=params.method, 363 | step_callback=lambda step, vars, vals: elbo_trace.append(patterns.elbo()), 364 | tol=params.tol, 365 | options={"maxiter": params.maxiter}, 366 | ) 367 | 368 | prune_threshold = tf.convert_to_tensor(params.pattern_prune_threshold, dtype=default_float()) 369 | idx, labels = prune_components( 370 | tf.argmax(patterns.pihat, axis=1), 371 | tf.transpose(patterns.pihat), 372 | prune_threshold, 373 | everything=True, 374 | ) 375 | pihat = tf.linalg.normalize(tf.gather(patterns.pihat, idx, axis=1), ord=1, axis=1)[0] 376 | patterns = tf.gather(patterns.mu_hat, idx, axis=1).numpy() 377 | 378 | if copy: 379 | adata = adata.copy() 380 | toreturn = adata 381 | else: 382 | toreturn = None 383 | for i in range(patterns.shape[1]): 384 | adata.obs[f"spatial_pattern_{i}"] = patterns[:, i] 385 | 386 | return ( 387 | SpatialPatterns( 388 | converged=res.success, 389 | status=res.message, 390 | labels=labels.numpy(), 391 | pattern_probabilities=pihat.numpy(), 392 | patterns=patterns, 393 | niter=res.nit, 394 | elbo_trace=np.asarray(elbo_trace), 395 | ), 396 | toreturn, 397 | ) 398 | -------------------------------------------------------------------------------- /SpatialDE/de_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import time 3 | import warnings 4 | from itertools import zip_longest 5 | from typing import Optional, Dict, Tuple, Union, List, Literal 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from gpflow import default_float 10 | 11 | from anndata import AnnData 12 | 13 | from tqdm.auto import tqdm 14 | 15 | from ._internal.util import DistanceCache 16 | from ._internal.util import bh_adjust, calc_sizefactors, default_kernel_space, kspace_walk 17 | from ._internal.score_test import ( 18 | NegativeBinomialScoreTest, 19 | NormalScoreTest, 20 | combine_pvalues, 21 | ) 22 | from ._internal.tf_dataset import AnnDataDataset 23 | 24 | 25 | def _add_individual_score_test_result(resultdict, kernel, kname, gene): 26 | if "kernel" not in resultdict: 27 | resultdict["kernel"] = [kname] 28 | else: 29 | resultdict["kernel"].append(kname) 30 | if "gene" not in resultdict: 31 | resultdict["gene"] = [gene] 32 | else: 33 | resultdict["gene"].append(gene) 34 | for key, var in vars(kernel).items(): 35 | if key[0] != "_": 36 | if key not in resultdict: 37 | resultdict[key] = [var] 38 | else: 39 | resultdict[key].append(var) 40 | return resultdict 41 | 42 | 43 | def _merge_individual_results(individual_results): 44 | merged = {} 45 | for res in individual_results: 46 | for k, v in res.items(): 47 | if k not in merged: 48 | merged[k] = v if not np.isscalar(v) else [v] 49 | else: 50 | if isinstance(merged[k], np.ndarray): 51 | merged[k] = np.concatenate((merged[k], v)) 52 | elif isinstance(v, list): 53 | 
merged[k].extend(v) 54 | else: 55 | merged[k].append(v) 56 | return pd.DataFrame(merged) 57 | 58 | 59 | def test( 60 | adata: AnnData, 61 | layer: Optional[str] = None, 62 | omnibus: bool = False, 63 | spatial_key: str = "spatial", 64 | kernel_space: Optional[Dict[str, Union[float, List[float]]]] = None, 65 | sizefactors: Optional[np.ndarray] = None, 66 | stack_kernels: Optional[bool] = None, 67 | obs_dist: Literal["NegativeBinomial", "Normal"] = "NegativeBinomial", 68 | use_cache: bool = True, 69 | ) -> Tuple[pd.DataFrame, Union[pd.DataFrame, None]]: 70 | """ 71 | Test for spatially variable genes. 72 | 73 | Perform a score test to detect spatially variable genes in a spatial transcriptomics 74 | dataset. Multiple kernels can be tested to detect genes with different spatial patterns and lengthscales. 75 | The test uses a count-based likelihood and thus operates on raw count data. Two ways of handling multiple 76 | kernels are implemented: omnibus and Cauchy combination. The Cauchy combination tests each kernel separately 77 | and combines the p-values afterwards, while the omnibus test tests all kernels simultaneously. With multiple 78 | kernels the omnibus test is faster, but may have slightly less statistical power than the Cauchy combination. 79 | 80 | Args: 81 | adata: The annotated data matrix. 82 | layer: Name of the AnnData object layer to use. By default ``adata.X`` is used. 83 | omnibus: Whether to do an omnibus test. 84 | spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored. 85 | kernel_space: Kernels to test against. Dictionary with the name of the kernel function as key and list of 86 | lengthscales (if applicable) as values. Currently, three kernel functions are known: 87 | 88 | * ``SE``, the squared exponential kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}; l) = \\exp\\left(-\\frac{\\lVert \\boldsymbol{x}^{(1)} - \\boldsymbol{x}^{(2)} \\rVert}{l^2}\\right)` 89 | * ``PER``, the periodic kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}; l) = \\cos\\left(2 \pi \\frac{\\sum_i (x^{(1)}_i - x^{(2)}_i)}{l}\\right)` 90 | * ``linear``, the linear kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}) = (\\boldsymbol{x}^{(1)})^\\top \\boldsymbol{x}^{(2)}` 91 | 92 | By default, 5 squared exponential and 5 periodic kernels with lengthscales spanning the range of the 93 | data will be used. 94 | sizefactors: Scaling factors for the observations. Default to total read counts. 95 | stack_kernels: When using the Cauchy combination, all tests can be performed in one operation by stacking 96 | the kernel matrices. This leads to increased memory consumption, but will drastically improve runtime 97 | on GPUs for smaller data sets. Defaults to ``True`` for datasets with less than 2000 observations and 98 | ``False`` otherwise. 99 | obs_dist: Distribution of the observations. If set as "Normal", model the regression to have Gaussian mean field error with identity link function. 100 | use_cache: Whether to use a pre-computed distance matrix for all kernels instead of computing the distance 101 | matrix anew for each kernel. Increases memory consumption, but is somewhat faster. 102 | 103 | Returns: 104 | If ``omnibus==True``, a tuple with a Pandas DataFrame as the first element and ``None`` as the second. 105 | The DataFrame contains the results of the test for each gene, in particular p-values and BH-adjusted p-values. 106 | Otherwise, a tuple of two DataFrames. 
The first contains the combined results, while the second contains results 107 | from individual tests. 108 | """ 109 | logging.info("Performing DE test") 110 | 111 | X = adata.obsm[spatial_key] 112 | dcache = DistanceCache(X, use_cache) 113 | if sizefactors is None: 114 | sizefactors = calc_sizefactors(adata) 115 | if kernel_space is None: 116 | kernel_space = default_kernel_space(dcache) 117 | individual_results = None if omnibus else [] 118 | if stack_kernels is None and adata.n_obs <= 2000 or stack_kernels or omnibus: 119 | kernels = [] 120 | kernelnames = [] 121 | for k, name in kspace_walk(kernel_space, dcache): 122 | kernels.append(k) 123 | kernelnames.append(name) 124 | if obs_dist == "NegativeBinomial": 125 | test = NegativeBinomialScoreTest( 126 | sizefactors, 127 | omnibus, 128 | kernels, 129 | ) 130 | else: 131 | test = NormalScoreTest(omnibus, kernels) 132 | 133 | results = [] 134 | with tqdm(total=adata.n_vars) as pbar: 135 | for i, (y, g) in AnnDataDataset(adata, dtype=default_float(), layer=layer).enumerate(): 136 | i = i.numpy() 137 | g = g.numpy().decode("utf-8") 138 | t0 = time() 139 | result, _ = test(y) 140 | t = time() - t0 141 | pbar.update() 142 | res = {"gene": g, "time": t} 143 | resultdict = result.to_dict() 144 | if omnibus: 145 | res.update(resultdict) 146 | else: 147 | res["pval"] = combine_pvalues(result).numpy() 148 | results.append(res) 149 | if not omnibus: 150 | for k, n in zip(kernels, kernelnames): 151 | _add_individual_score_test_result(resultdict, k, n, g) 152 | individual_results.append(resultdict) 153 | else: # doing all tests at once with stacked kernels leads to excessive memory consumption 154 | results = [[0, []] for _ in range(adata.n_vars)] 155 | nullmodels = [] 156 | test = NegativeBinomialScoreTest(sizefactors) 157 | for k, n in kspace_walk(kernel_space, dcache): 158 | test.kernel = k 159 | if len(nullmodels) > 0: 160 | nullit = nullmodels 161 | havenull = True 162 | else: 163 | nullit = () 164 | havenull = False 165 | with tqdm(total=adata.n_vars) as pbar: 166 | for null, (i, (y, g)) in zip_longest( 167 | nullit, AnnDataDataset(adata, dtype=default_float(), layer=layer).enumerate() 168 | ): 169 | i = i.numpy() 170 | g = g.numpy().decode("utf-8") 171 | 172 | t0 = time() 173 | res, null = test(y, null) 174 | t = time() - t0 175 | if not havenull: 176 | nullmodels.append(null) 177 | pbar.update() 178 | results[i][0] += t 179 | results[i][1].append(res) 180 | resultdict = res.to_dict() 181 | individual_results.append( 182 | _add_individual_score_test_result(resultdict, k, n, g) 183 | ) 184 | for i, g in enumerate(adata.var_names): 185 | results[i] = { 186 | "gene": g, 187 | "time": results[i][0], 188 | "pval": combine_pvalues(results[i][1]).numpy(), 189 | } 190 | 191 | results = pd.DataFrame(results) 192 | results["padj"] = bh_adjust(results.pval.to_numpy()) 193 | 194 | if individual_results is not None: 195 | individual_results = _merge_individual_results(individual_results) 196 | return results, individual_results 197 | -------------------------------------------------------------------------------- /SpatialDE/dp_hmrf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Tuple 2 | from dataclasses import dataclass 3 | from enum import Enum, auto 4 | import logging 5 | import warnings 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import tensorflow as tf 10 | from gpflow.utilities.ops import square_distance 11 | 12 | from anndata import AnnData 13 | 14 | from 
._internal.util import calc_sizefactors, dense_slice 15 | from ._internal.util_mixture import prune_components, prune_labels 16 | 17 | 18 | @dataclass(frozen=True) 19 | class TissueSegmentationParameters: 20 | """ 21 | Parameters for tissue segmentation. 22 | 23 | Args: 24 | nclasses: Maximum number of regions to consider. Defaults to the square root of the number of observations. 25 | neighbors: Number of neighbors for the nearest-neighbor graph. Defaults to a fully connected graph (there is 26 | no speed difference). A value of 0 makes the model ignore spatial information and reduces it to a Poisson 27 | mixture model with a Dirichlet process prior. 28 | smoothness_factor: Spatial smoothness. Larger values induce more fine-grained segmentations. This value will 29 | be multiplied with the minimum squared distance within the data set, so it is dimensionless. Defaults to ``2``. 30 | class_prune_threshold: Probability threshold at which unused regions are removed. Defaults to ``1e-6``. 31 | abstol: Absolute convergence tolerance. Defaults to ``1e-12``. 32 | reltol: Relative convergence tolerance. Defaults to ``1e-12``. 33 | maxiter: Maximum number of iterations. Defaults to ``1000``. 34 | gamma_1: Parameter of the Poisson mean prior, defaults to ``1e-14``. 35 | gamma_2: Parameter of the Poisson mean prior, defaults to ``1e-14``. 36 | eta_1: Parameter of the Dirichlet process hyperprior, defaults to ``1``. 37 | eta_2: Parameter of the Dirichlet process hyperprior, defaults to ``1``. 38 | """ 39 | 40 | nclasses: Optional[int] = None 41 | neighbors: Optional[int] = None 42 | smoothness_factor: float = 2 43 | class_prune_threshold: float = 1e-6 44 | abstol: float = 1e-12 45 | reltol: float = 1e-12 46 | maxiter: int = 1000 47 | gamma_1: float = 1e-14 48 | gamma_2: float = 1e-14 49 | eta_1: float = 1 50 | eta_2: float = 1 51 | 52 | def __post_init__(self): 53 | assert ( 54 | self.nclasses is None or self.nclasses >= 1 55 | ), "Number of classes must be None or at least 1" 56 | assert ( 57 | self.neighbors is None or self.neighbors >= 0 58 | ), "Number of neighbors must be None or at least 0" 59 | assert self.smoothness_factor > 0, "Smoothness factor must be greater than 0" 60 | assert ( 61 | self.class_prune_threshold >= 0 and self.class_prune_threshold <= 1 62 | ), "Class pruning threshold must be between 0 and 1" 63 | assert self.abstol > 0, "Absolute tolerance must be greater than 0" 64 | assert self.reltol > 0, "Relative tolerance must be greater than 0" 65 | assert self.maxiter >= 1, "Maximum number of iterations must be greater than or equal to 1" 66 | assert self.gamma_1 > 0, "Gamma1 hyperparameter must be greater than 0" 67 | assert self.gamma_2 > 0, "Gamma2 hyperparameter must be greater than 0" 68 | assert self.eta_1 > 0, "Eta1 hyperparameter must be greater than 0" 69 | assert self.eta_2 > 0, "Eta2 hyperparameter must be greater than 0" 70 | 71 | 72 | class TissueSegmentationStatus(Enum): 73 | AbsoluteToleranceReached = auto() 74 | RelativeToleranceReached = auto() 75 | MaximumIterationsReached = auto() 76 | 77 | 78 | @dataclass(frozen=True) 79 | class TissueSegmentation: 80 | """ 81 | Results of tissue segmentation. 82 | 83 | Args: 84 | converged: Whether the optimization converged. 85 | status: Status of the optimization. 86 | labels: The estimated region labels. 87 | class_probabilities: N_obs x N_regions array with the estimated region probabilities for each observation. 88 | gammahat_1: N_classes x N_genes array with the estimated parameter of the gene expression posterior. 
89 | gammahat_2: N_classes x 1 array with the estimated parameter of the gene expression posterior. 90 | niter: Number of iterations for the optimization. 91 | prune_iterations: Iterations at which unneeded regions were removed. 92 | elbo_trace: ELBO values at each iteration. 93 | nclasses_trace: Number of regions at each iteration. 94 | """ 95 | 96 | converged: bool 97 | status: TissueSegmentationStatus 98 | labels: np.ndarray 99 | class_probabilities: np.ndarray 100 | gammahat_1: np.ndarray 101 | gammahat_2: np.ndarray 102 | niter: int 103 | prune_iterations: np.ndarray 104 | elbo_trace: np.ndarray 105 | nclasses_trace: np.ndarray 106 | 107 | 108 | @tf.function(experimental_relax_shapes=True) 109 | def _segment( 110 | counts: tf.Tensor, 111 | sizefactors: tf.Tensor, 112 | distances: tf.Tensor, 113 | nclasses: tf.Tensor, 114 | fnclasses: tf.Tensor, 115 | ngenes: tf.Tensor, 116 | labels: tf.Tensor, 117 | gamma_1: tf.Tensor, 118 | gamma_2: tf.Tensor, 119 | eta_1: tf.Tensor, 120 | eta_2: tf.Tensor, 121 | alphahat_1: tf.Tensor, 122 | alphahat_2: tf.Tensor, 123 | etahat_2: tf.Tensor, 124 | gammahat_1: tf.Tensor, 125 | gammahat_2: tf.Tensor, 126 | ): 127 | eta_1_nclasses = eta_1 + fnclasses 128 | if labels is not None and distances is not None: 129 | p_x_neighborhood = tf.TensorArray(counts.dtype, size=tf.shape(gammahat_1)[0]) 130 | for c in tf.range(tf.shape(gammahat_1)[0]): 131 | p_x_neighborhood = p_x_neighborhood.write( 132 | c, -tf.reduce_sum(tf.where(labels != c, distances, 0), axis=1) 133 | ) 134 | p_x_neighborhood = p_x_neighborhood.stack() 135 | p_x_neighborhood = p_x_neighborhood - tf.reduce_logsumexp( 136 | p_x_neighborhood, axis=0, keepdims=True 137 | ) 138 | else: 139 | p_x_neighborhood = tf.convert_to_tensor(0, counts.dtype) 140 | 141 | lambdahat_1 = gammahat_1 / gammahat_2 142 | lambdahat_2 = tf.math.digamma(gammahat_1) - tf.math.log(gammahat_2) 143 | alpha12 = alphahat_1 + alphahat_2 144 | dgalpha = tf.math.digamma(alpha12) 145 | vhat2 = tf.math.digamma(alphahat_1) - dgalpha 146 | vhat3 = tf.math.digamma(alphahat_2) - dgalpha 147 | alphahat = (eta_1_nclasses - 1) / etahat_2 148 | vhat3_cumsum = tf.cumsum(vhat3) - vhat3 149 | 150 | vhat_sum = (tf.concat(((0,), vhat3_cumsum), axis=0) + tf.concat((vhat2, (0,)), axis=0))[ 151 | :, tf.newaxis 152 | ] 153 | 154 | phi = ( 155 | p_x_neighborhood 156 | + vhat_sum 157 | + tf.matmul(lambdahat_2, counts, transpose_b=True) 158 | - sizefactors * tf.reduce_sum(lambdahat_1, axis=1, keepdims=True) 159 | ) 160 | pihat = tf.nn.softmax(phi, axis=0) 161 | pihat_cumsum = tf.cumsum(pihat, axis=0, reverse=True) - pihat 162 | 163 | vhat3_sum = tf.reduce_sum(vhat3) 164 | gammahat_1 = gamma_1 + pihat @ counts 165 | gammahat_2 = gamma_2 + tf.matmul(pihat, sizefactors, transpose_b=True) 166 | etahat_2 = eta_2 - vhat3_sum 167 | alphahat_1 = 1 + ngenes * tf.reduce_sum(pihat, axis=1)[:-1] 168 | alphahat_2 = ngenes * tf.reduce_sum(pihat_cumsum, axis=1)[:-1] + alphahat 169 | 170 | elbo = ( 171 | tf.reduce_sum(pihat * p_x_neighborhood) 172 | + tf.reduce_sum( 173 | pihat 174 | * ( 175 | vhat_sum 176 | + tf.matmul(lambdahat_2, counts, transpose_b=True) 177 | - tf.reduce_sum(lambdahat_1, axis=1, keepdims=True) * sizefactors 178 | ) 179 | ) 180 | + tf.reduce_sum((alphahat - alphahat_2) * vhat3) 181 | + tf.reduce_sum((gamma_1 - gammahat_1) * lambdahat_2) 182 | + tf.reduce_sum((gammahat_2 - gamma_2) * lambdahat_1) 183 | - tf.reduce_sum(pihat * phi) 184 | + tf.reduce_sum(tf.reduce_logsumexp(phi, axis=0)) 185 | - tf.reduce_sum(gammahat_1 * tf.math.log(gammahat_2) - 
tf.math.lgamma(gammahat_1)) 186 | - tf.reduce_sum((alphahat_1 - 1) * vhat2) 187 | + tf.reduce_sum(tf.math.lbeta(tf.stack((alphahat_1, alphahat_2), axis=1))) 188 | - eta_1_nclasses * tf.math.log(etahat_2) 189 | + (etahat_2 - eta_2) * alphahat 190 | ) / tf.cast(nclasses * ngenes * tf.shape(counts)[0], counts.dtype) 191 | 192 | return pihat, alphahat_1, alphahat_2, etahat_2, gammahat_1, gammahat_2, elbo 193 | 194 | 195 | def tissue_segmentation( 196 | adata: AnnData, 197 | layer: Optional[str] = None, 198 | genes: Optional[List[str]] = None, 199 | sizefactors: Optional[np.ndarray] = None, 200 | spatial_key: str = "spatial", 201 | params: TissueSegmentationParameters = TissueSegmentationParameters(), 202 | labels: Optional[Union[np.ndarray, tf.Tensor]] = None, 203 | rng: np.random.Generator = np.random.default_rng(), 204 | copy=False, 205 | ) -> Tuple[TissueSegmentation, Union[AnnData, None]]: 206 | """ 207 | Segment a spatial transcriptomics dataset into distinct spatial regions. 208 | 209 | Uses a hidden Markov random field (HMRF) model with a Poisson likelihood. A Dirichlet process prior allows 210 | to automatically determine the number of distinct regions in the dataset. 211 | 212 | Args: 213 | adata: The annotated data matrix. 214 | layer: Name of the AnnData object layer to use. By default ``adata.X`` is used. 215 | genes: List of genes to base the segmentation on. Defaults to all genes. 216 | sizefactors: Scaling factors for the observations. Defaults to total read counts. 217 | spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored. 218 | params: Parameters for the algorithm, e.g. prior distributions, spatial smoothness, etc. 219 | labels: Initial label for each observation. Defaults to a random initialization. 220 | rng: Random number generator. 221 | copy: Whether to return a copy of ``adata`` with results or write the results into ``adata`` 222 | in-place. 223 | 224 | Returns: 225 | A tuple. The first element is a :py:class:`TissueSegmentation`, the second is ``None`` if ``copy == False`` 226 | or an ``AnnData`` object. Region labels will be in ``adata.obs["segmentation_labels"]`` and region 227 | probabilities in ``adata.obsm["segmentation_class_probabilities"]``. 228 | """ 229 | if genes is None and sizefactors is None: 230 | warnings.warn( 231 | "Neither genes nor sizefactors are given. 
Assuming that adata contains complete data set, will calculate size factors and perform segmentation using the complete data set.", 232 | RuntimeWarning, 233 | ) 234 | 235 | if sizefactors is None: 236 | sizefactors = calc_sizefactors(adata, layer=layer) 237 | if genes is not None: 238 | ngenes = len(genes) 239 | data = adata[:, genes] 240 | else: 241 | ngenes = adata.n_vars 242 | data = adata 243 | try: 244 | X = data.obsm[spatial_key] 245 | except KeyError: 246 | X = None 247 | 248 | dtype = tf.float64 249 | labels_dtype = tf.int32 250 | nclasses = params.nclasses 251 | nsamples = data.n_obs 252 | if nclasses is None: 253 | nclasses = tf.cast( 254 | tf.math.ceil(tf.sqrt(tf.convert_to_tensor(nsamples, dtype=tf.float32))), tf.int32 255 | ) 256 | fngenes = tf.cast(ngenes, dtype=dtype) 257 | fnclasses = tf.cast(nclasses, dtype=dtype) 258 | 259 | sizefactors = tf.convert_to_tensor(sizefactors[np.newaxis, :], dtype=dtype) 260 | 261 | gamma_1 = tf.convert_to_tensor(params.gamma_1, dtype=dtype) 262 | gamma_2 = tf.convert_to_tensor(params.gamma_2, dtype=dtype) 263 | eta_1 = tf.convert_to_tensor(params.eta_1, dtype=dtype) 264 | eta_2 = tf.convert_to_tensor(params.eta_2, dtype=dtype) 265 | 266 | counts = tf.convert_to_tensor( 267 | dense_slice(data.X if layer is None else data.layers[layer]), dtype=dtype 268 | ) 269 | 270 | distances = None 271 | if X is not None and (params.neighbors is None or params.neighbors > 0): 272 | X = tf.convert_to_tensor(X, dtype=dtype) 273 | distances = square_distance(X, None) 274 | if params.neighbors is not None and params.neighbors < nsamples: 275 | distances, indices = tf.math.top_k(-distances, k=params.neighbors + 1, sorted=True) 276 | distances = -distances[:, 1:] 277 | distances = 2 * params.smoothness_factor * tf.reduce_min(distances) / distances 278 | indices = indices[:, 1:] 279 | indices = tf.stack( 280 | ( 281 | tf.repeat(tf.range(distances.shape[0]), indices.shape[1]), 282 | tf.reshape(indices, -1), 283 | ), 284 | axis=1, 285 | ) 286 | dists = tf.reshape(distances, -1) 287 | distances = tf.scatter_nd(indices, dists, (distances.shape[0], distances.shape[0])) 288 | distances = tf.tensor_scatter_nd_update( 289 | distances, indices[:, ::-1], dists 290 | ) # symmetrize 291 | else: 292 | distances = tf.linalg.set_diag( 293 | distances, tf.repeat(tf.convert_to_tensor(np.inf, dtype), tf.shape(distances)[0]) 294 | ) 295 | distances = 2 * params.smoothness_factor * tf.reduce_min(distances) / distances 296 | else: 297 | logging.info("Not using spatial information, fitting Poisson mixture model instead.") 298 | 299 | if labels is not None: 300 | labels = tf.squeeze(tf.convert_to_tensor(labels, dtype=labels_dtype)) 301 | if tf.rank(labels) > 1 or tf.shape(labels)[0] != nsamples: 302 | labels = None 303 | warnings.warn( 304 | "Shape of given labels does not conform to data. 
Initializing with random labels.", 305 | RuntimeWarning, 306 | ) 307 | if labels is None and distances is not None: 308 | labels = tf.convert_to_tensor(rng.choice(nclasses, nsamples), dtype=labels_dtype) 309 | 310 | alphahat_1 = tf.ones(shape=(nclasses - 1,), dtype=dtype) 311 | alphahat_2 = tf.ones(shape=(nclasses - 1,), dtype=dtype) 312 | etahat_2 = eta_1 + fnclasses - 1 313 | gammahat_1 = tf.fill((nclasses, ngenes), tf.convert_to_tensor(1e-6, dtype=dtype)) 314 | gammahat_2 = tf.fill((nclasses, 1), tf.convert_to_tensor(1e-6, dtype=dtype)) 315 | 316 | prune_threshold = tf.convert_to_tensor(params.class_prune_threshold, dtype=dtype) 317 | lastelbo = -tf.convert_to_tensor(np.inf, dtype=dtype) 318 | elbos = [] 319 | nclassestrace = [] 320 | pruneidx = [] 321 | converged = False 322 | status = TissueSegmentationStatus.MaximumIterationsReached 323 | for i in range(params.maxiter): 324 | (pihat, alphahat_1, alphahat_2, etahat_2, gammahat_1, gammahat_2, elbo,) = _segment( 325 | counts, 326 | sizefactors, 327 | distances, 328 | nclasses, 329 | fnclasses, 330 | ngenes, 331 | labels, 332 | gamma_1, 333 | gamma_2, 334 | eta_1, 335 | eta_2, 336 | alphahat_1, 337 | alphahat_2, 338 | etahat_2, 339 | gammahat_1, 340 | gammahat_2, 341 | ) 342 | labels = tf.math.argmax(pihat, axis=0, output_type=labels_dtype) 343 | elbos.append(elbo.numpy()) 344 | nclassestrace.append(nclasses) 345 | elbodiff = tf.abs(elbo - lastelbo) 346 | if elbodiff < params.abstol: 347 | converged = True 348 | status = TissueSegmentationStatus.AbsoluteToleranceReached 349 | break 350 | elif elbodiff / tf.minimum(tf.abs(elbo), tf.abs(lastelbo)) < params.reltol: 351 | converged = True 352 | status = TissueSegmentationStatus.AbsoluteToleranceReached 353 | break 354 | lastelbo = elbo 355 | 356 | if i == 1 or not i % 10: 357 | idx, labels = prune_components(labels, pihat, prune_threshold, everything=True) 358 | if tf.size(idx) < tf.shape(gammahat_1)[0]: 359 | pruneidx.append(i) 360 | alphahat_1 = tf.gather(alphahat_1, idx[:-1], axis=0) 361 | alphahat_2 = tf.gather(alphahat_2, idx[:-1], axis=0) 362 | gammahat_1 = tf.gather(gammahat_1, idx, axis=0) 363 | gammahat_2 = tf.gather(gammahat_2, idx, axis=0) 364 | nclasses = tf.size(idx) 365 | 366 | idx, labels = prune_components(labels, pihat, prune_threshold, everything=True) 367 | pihat = tf.linalg.normalize(tf.gather(pihat, idx, axis=0), ord=1, axis=0)[0] 368 | gammahat_1 = tf.gather(gammahat_1, idx, axis=0) 369 | gammahat_2 = tf.gather(gammahat_2, idx, axis=0) 370 | 371 | if copy: 372 | adata = adata.copy() 373 | toreturn = adata 374 | else: 375 | toreturn = None 376 | labels = labels.numpy() 377 | pihat = pihat.numpy().T 378 | adata.obs["segmentation_labels"] = pd.Categorical(labels) 379 | adata.obsm["segmentation_class_probabilities"] = pihat 380 | return ( 381 | TissueSegmentation( 382 | converged, 383 | status, 384 | labels, 385 | pihat, 386 | gammahat_1.numpy(), 387 | gammahat_2.numpy(), 388 | i, 389 | np.asarray(pruneidx), 390 | np.asarray(elbos), 391 | np.asarray(nclassestrace), 392 | ), 393 | toreturn, 394 | ) 395 | -------------------------------------------------------------------------------- /SpatialDE/gaussian_process.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import time 3 | import warnings 4 | from typing import Optional, Dict, List 5 | from enum import Enum, auto 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from tqdm.auto import tqdm 11 | 12 | import tensorflow as tf 13 | 14 | import 
NaiveDE 15 | from anndata import AnnData 16 | 17 | from ._internal.kernels import SquaredExponential, Cosine, Linear 18 | from ._internal.models import Model, Constant, Null, model_factory 19 | from ._internal.util import ( 20 | DistanceCache, 21 | default_kernel_space, 22 | kspace_walk, 23 | dense_slice, 24 | normalize_counts, 25 | ) 26 | from ._internal.tf_dataset import AnnDataDataset 27 | from ._internal.gpflow_helpers import * 28 | 29 | 30 | class GP(Enum): 31 | GPR = auto() 32 | """ 33 | Dense Gaussian process. 34 | """ 35 | SGPR = auto() 36 | """ 37 | Sparse Gaussian process. 38 | """ 39 | 40 | 41 | class SGPIPM(Enum): 42 | free = auto() 43 | """Inducing points are initialized randomly and their positions are optimized together with the other parameters.""" 44 | random = auto() 45 | """Inducing points are placed at random locations.""" 46 | grid = auto() 47 | """Inducing points are placed in a regular grid.""" 48 | 49 | 50 | @dataclass(frozen=True) 51 | class GPControl: 52 | """ 53 | Parameters for Gaussian process fitting. 54 | 55 | Args: 56 | gp: Type of GP to fit. 57 | ipm: Inducing point method. Only used if ``gp == GP.SGPR``. 58 | ncomponents: Number of kernel components. 59 | ard: Whether to use automatic relevance determination. This amounts to having one 60 | lengthscale per spatial dimension. 61 | ninducers: Number of inducing points. 62 | """ 63 | 64 | gp: Optional[GP] = None 65 | ipm: SGPIPM = SGPIPM.grid 66 | ncomponents: int = 5 67 | ard: bool = True 68 | ninducers: Optional[int] = None 69 | 70 | 71 | def inducers_grid(X, ninducers): 72 | rngmin = X.min(0) 73 | rngmax = X.max(0) 74 | xvals = np.linspace(rngmin[0], rngmax[0], int(np.ceil(np.sqrt(ninducers)))) 75 | yvals = np.linspace(rngmin[1], rngmax[1], int(np.ceil(np.sqrt(ninducers)))) 76 | xx, xy = np.meshgrid(xvals, yvals) 77 | return np.hstack((xx.reshape((xx.size, 1)), xy.reshape((xy.size, 1)))) 78 | 79 | 80 | def fit_model(model: Model, genes: Union[List[str], np.ndarray], counts: np.ndarray): 81 | results = [] 82 | with warnings.catch_warnings(): 83 | warnings.simplefilter("ignore", RuntimeWarning) 84 | with model: 85 | for i, gene in enumerate(tqdm(genes)): 86 | y = dense_slice(counts[:, i]) 87 | model.y = y 88 | t0 = time() 89 | 90 | res = model.optimize() 91 | t = time() - t0 92 | res = { 93 | "gene": gene, 94 | "max_ll": model.log_marginal_likelihood, 95 | "max_delta": model.delta, 96 | "max_mu_hat": model.mu, 97 | "max_s2_t_hat": model.sigma_s2, 98 | "max_s2_e_hat": model.sigma_n2, 99 | "time": t, 100 | "n": model.n, 101 | "FSV": model.FSV, 102 | "s2_FSV": np.abs( 103 | model.s2_FSV 104 | ), # we are at the optimum, so this should never be negative. 105 | "s2_logdelta": np.abs( 106 | model.s2_logdelta 107 | ), # Negative results are due to numerical errors when evaluating vanishing Hessians 108 | "converged": res.success, 109 | "M": model.n_parameters, 110 | } 111 | for (k, v) in vars(model.kernel).items(): 112 | if k not in res and k[0] != "_": 113 | res[k] = v 114 | 115 | results.append(res) 116 | return pd.DataFrame(results) 117 | 118 | 119 | def fit_detailed( 120 | adata: AnnData, 121 | genes: Optional[List[str]] = None, 122 | layer: Optional[str] = None, 123 | normalized: bool = False, 124 | sizefactor_col: Optional[str] = None, 125 | spatial_key: str = "spatial", 126 | control: Optional[GPControl] = GPControl(), 127 | rng: np.random.Generator = np.random.default_rng(), 128 | ) -> DataSetResults: 129 | """ 130 | Fits Gaussian processes to genes. 
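    A minimal invocation might look like the following sketch, assuming ``adata`` is an
    :py:class:`AnnData` object with spatial coordinates in ``adata.obsm["spatial"]`` (the gene
    names are placeholders)::

        from SpatialDE.gaussian_process import fit_detailed

        results = fit_detailed(adata, genes=["gene_a", "gene_b"])
        summary = results.to_df()  # DataFrame with the key results per gene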
131 | 132 | A Gaussian process based on highly interpretable spectral mixture kernels (Wilson et al. 2013, Wilson 2014) is fitted 133 | separately to each gene. Sparse GPs are used on large datasets (>1000 observations) to improve speed. 134 | This uses a Gaussian likelihood and requires appropriate data normalization. 135 | 136 | Args: 137 | adata: The annotated data matrix. 138 | genes: List of genes to base the analysis on. Defaults to all genes. 139 | layer: Name of the AnnData object layer to use. By default ``adata.X`` is used. 140 | normalized: Whether the data are already normalized to an approximately Gaussian likelihood. 141 | If ``False``, they will be normalized using the workflow from Svensson et al, 2018. 142 | sizefactor_col: Column in ``adata.obs`` to be used for normalization. If ``None``, total number of 143 | counts per spot will be used. 144 | spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored. 145 | control: Parameters for the Gaussian process, e.g. number of kernel components, number of inducing points. 146 | rng: Random number generator. 147 | 148 | Returns: 149 | A dictionary with the results. The dictionary has an additional method ``to_df``, which returns a DataFrame 150 | with the key results. 151 | """ 152 | if not normalized and genes is None: 153 | warnings.warn( 154 | "normalized is False and no genes are given. Assuming that adata contains complete data set, will normalize and fit a GP for every gene." 155 | ) 156 | 157 | if not normalized: 158 | adata = normalize_counts(adata, sizefactor_col, layer, copy=True) 159 | 160 | data = adata[:, genes] if genes is not None else adata 161 | X = data.obsm[spatial_key] 162 | counts = data.X if layer is None else data.layers[layer] 163 | 164 | gp = control.gp 165 | if gp is None: 166 | if data.n_obs < 1000: 167 | gp = GP.GPR 168 | else: 169 | gp = GP.SGPR 170 | 171 | results = DataSetResults() 172 | X = tf.convert_to_tensor(X, dtype=gpflow.config.default_float()) 173 | t = tqdm(data.var_names) 174 | opt = gpflow.optimizers.Scipy() 175 | 176 | logging.info("Fitting gene models") 177 | if gp == GP.GPR: 178 | for g, gene in enumerate(t): 179 | t.set_description(gene, refresh=False) 180 | model = GPR( 181 | X, 182 | Y=tf.convert_to_tensor( 183 | dense_slice(counts[:, g])[:, np.newaxis], 184 | dtype=gpflow.config.default_float(), 185 | ), 186 | n_kernel_components=control.ncomponents, 187 | ard=control.ard, 188 | ) 189 | results[gene] = GeneGP(model, opt.minimize, method="bfgs") 190 | elif gp == GP.SGPR: 191 | ninducers = ( 192 | np.ceil(np.sqrt(data.n_obs)).astype(np.int32) 193 | if control.ninducers is None 194 | else control.ninducers 195 | ) 196 | if control.ipm == SGPIPM.free or control.ipm == SGPIPM.random: 197 | inducers = X[rng.integers(0, X.shape[0], ninducers), :] 198 | elif control.ipm == SGPIPM.grid: 199 | rngmin = tf.reduce_min(X, axis=0) 200 | rngmax = tf.reduce_max(X, axis=0) 201 | xvals = tf.linspace(rngmin[0], rngmax[0], int(np.ceil(np.sqrt(ninducers)))) 202 | yvals = tf.linspace(rngmin[1], rngmax[1], int(np.ceil(np.sqrt(ninducers)))) 203 | xx, xy = tf.meshgrid(xvals, yvals) 204 | inducers = tf.stack((tf.reshape(xx, (-1,)), tf.reshape(xy, (-1,))), axis=1) 205 | inducers = gpflow.inducing_variables.InducingPoints(inducers) 206 | if control.ipm != SGPIPM.free: 207 | gpflow.utilities.set_trainable(inducers, False) 208 | 209 | method = "BFGS" 210 | if control.ipm == SGPIPM.free and ninducers > 1e3: 211 | method = "L-BFGS-B" 212 | 213 | for g, gene in enumerate(t): 214 | 
t.set_description(gene, refresh=False) 215 | model = SGPR( 216 | X, 217 | Y=tf.constant( 218 | dense_slice(counts[:, g])[:, np.newaxis], 219 | dtype=gpflow.config.default_float(), 220 | ), 221 | inducing_variable=inducers, 222 | n_kernel_components=control.ncomponents, 223 | ard=control.ard, 224 | ) 225 | results[gene] = GeneGP(model, opt.minimize, method=method) 226 | 227 | logging.info("Finished fitting models to %i genes" % data.n_vars) 228 | return results 229 | 230 | 231 | def fit_fast( 232 | adata: AnnData, 233 | genes: Optional[List[str]] = None, 234 | layer: Optional[str] = None, 235 | normalized: bool = False, 236 | sizefactor_col: Optional[str] = None, 237 | sparse: Optional[bool] = None, 238 | spatial_key: str = "spatial", 239 | kernel_space: Optional[Dict[str, Union[float, List[float]]]] = None, 240 | ) -> pd.DataFrame: 241 | """ 242 | Fits Gaussian processes to genes. 243 | 244 | This uses the inflexible but fast technique of Svensson et al. (2018). In particular, the kernel lengthscale is not 245 | optimized, but must be given beforehand. Multiple kernel functions and lengthscales can be specified, the best-fitting 246 | model will be retained. To further improve speed, sparse GPs are used for large (>1000 observations) data sets with 247 | inducing points located on a regular grid. 248 | 249 | Args: 250 | adata: The annotated data matrix. 251 | genes: List of genes to base the analysis on. Defaults to all genes. 252 | layer: Name of the AnnData object layer to use. By default ``adata.X`` is used. 253 | normalized: Whether the data are already normalized to an approximately Gaussian likelihood. 254 | If ``False``, they will be normalized using the workflow from Svensson et al, 2018. 255 | sizefactor_col: Column in ``adata.obs`` to be used for normalization. If ``None``, total number of 256 | counts per spot will be used. 257 | spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored. 258 | sparse: Whether to use sparse GPs. Slightly faster on large datasets, but less precise. 259 | Defaults to sparse GPs if more than 1000 data points are given. 260 | kernel_space: Kernels to test against. Dictionary with the name of the kernel function as key and list of 261 | lengthscales (if applicable) as values. Currently, three kernel functions are known: 262 | 263 | * ``SE``, the squared exponential kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}; l) = \\exp\\left(-\\frac{\\lVert \\boldsymbol{x}^{(1)} - \\boldsymbol{x}^{(2)} \\rVert}{l^2}\\right)` 264 | * ``PER``, the periodic kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}; l) = \\cos\\left(2 \pi \\frac{\\sum_i (x^{(1)}_i - x^{(2)}_i)}{l}\\right)` 265 | * ``linear``, the linear kernel :math:`k(\\boldsymbol{x}^{(1)}, \\boldsymbol{x}^{(2)}) = (\\boldsymbol{x}^{(1)})^\\top \\boldsymbol{x}^{(2)}` 266 | 267 | By default, 5 squared exponential and 5 periodic kernels with lengthscales spanning the range of the 268 | data will be used. 269 | 270 | Returns: 271 | A Pandas DataFrame with the results. 272 | """ 273 | if not normalized and genes is None: 274 | warnings.warn( 275 | "normalized is False and no genes are given. Assuming that adata contains complete data set, will normalize and fit a GP for every gene." 
276 |         )
277 | 
278 |     if not normalized:
279 |         adata = normalize_counts(adata, sizefactor_col, layer, copy=True)
280 | 
281 |     data = adata[:, genes] if genes is not None else adata
282 | 
283 |     X = data.obsm[spatial_key]
284 |     counts = data.X if layer is None else data.layers[layer]
285 | 
286 |     dcache = DistanceCache(X)
287 |     if kernel_space is None:
288 |         kernel_space = default_kernel_space(dcache)
289 | 
290 |     logging.info("Fitting gene models")
291 |     n_models = 0
292 |     Z = None
293 |     if (sparse is None and X.shape[0] > 1000) or sparse:
294 |         Z = inducers_grid(X, np.maximum(100, np.sqrt(data.n_obs)))
295 | 
296 |     results = []
297 |     for kern, kname in kspace_walk(kernel_space, dcache):
298 |         model = model_factory(X, Z, kern)
299 |         res = fit_model(model, data.var_names, counts)
300 |         res["model"] = kname
301 |         results.append(res)
302 |         n_models += 1
303 | 
304 |     n_genes = data.n_vars
305 |     logging.info("Finished fitting {} models to {} genes".format(n_models, n_genes))
306 | 
307 |     results = pd.concat(results, sort=True).reset_index(drop=True)
308 |     sizes = (
309 |         results.groupby(["model", "gene"], sort=False).size().groupby("model", sort=False).unique()
310 |     )
311 |     results = results.set_index("model")
312 |     results.loc[sizes > 1, "M"] += 1
313 |     results = results.reset_index()
314 |     results["BIC"] = -2 * results["max_ll"] + results["M"] * np.log(results["n"])
315 | 
316 |     results = results.loc[results.groupby(["model", "gene"], sort=False)["max_ll"].idxmax()]
317 |     results = results.loc[results.groupby("gene", sort=False)["BIC"].idxmin()]
318 | 
319 |     return results.reset_index(drop=True)
320 | 
321 | 
322 | def fit(
323 |     adata: AnnData,
324 |     genes: Optional[List[str]] = None,
325 |     layer: Optional[str] = None,
326 |     normalized=False,
327 |     spatial_key: str = "spatial",
328 |     control: Optional[GPControl] = GPControl(),
329 |     kernel_space: Optional[Dict[str, float]] = None,
330 |     rng: np.random.Generator = np.random.default_rng(),
331 | ) -> pd.DataFrame:
332 |     """
333 |     Fits Gaussian processes to genes.
334 | 
335 |     This dispatches to :py:func:`fit_fast` if ``control`` is ``None``, otherwise to :py:func:`fit_detailed`.
336 |     All arguments are forwarded.
337 | 
338 |     Returns: A Pandas DataFrame with the results.
339 |     """
340 |     if control is None:
341 |         return fit_fast(adata, genes, layer, normalized, spatial_key=spatial_key, kernel_space=kernel_space)
342 |     else:
343 |         return (
344 |             fit_detailed(adata, genes, layer, normalized, spatial_key=spatial_key, control=control, rng=rng)
345 |             .to_df(modelcol="model")
346 |             .reset_index(drop=True)
347 |         )
348 | 
--------------------------------------------------------------------------------
/SpatialDE/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import warnings
4 | import json
5 | import logging
6 | 
7 | import numpy as np
8 | from scipy.sparse import csr_matrix
9 | import pandas as pd
10 | from matplotlib.image import imread
11 | 
12 | import h5py
13 | from anndata import AnnData
14 | 
15 | 
16 | def read_spaceranger(spaceranger_out_dir: str, read_images: bool = True) -> AnnData:
17 |     """
18 |     Read 10x SpaceRanger output.
19 | 
20 |     Args:
21 |         spaceranger_out_dir: Path to the directory with SpaceRanger output.
22 |         read_images: Whether to also read images into memory.
23 | 
24 |     Returns:
25 |         An annotated data matrix.
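        Example:
            A minimal usage sketch; the path is hypothetical and should point to the SpaceRanger
            ``outs`` directory containing the ``spatial`` subfolder and the filtered
            feature-barcode matrix::

                from SpatialDE.io import read_spaceranger

                adata = read_spaceranger("/path/to/spaceranger/outs", read_images=False)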
26 | """ 27 | fname = glob.glob(os.path.join(spaceranger_out_dir, "*filtered_feature_bc_matrix.h5")) 28 | if len(fname) == 0: 29 | raise FileNotFoundError( 30 | "filtered_feature_bc_matrix.h5 file not found in specified directory" 31 | ) 32 | elif len(fname) > 1: 33 | warnings.warn( 34 | "Multiple files ending with filtered_feature_bc_matrix.h5 found in specified directory, using the first one", 35 | RuntimeWarning, 36 | ) 37 | fname = fname[0] 38 | with h5py.File(fname, "r") as f: 39 | matrix = f["matrix"] 40 | sparsemat = csr_matrix( 41 | (matrix["data"][...], matrix["indices"][...], matrix["indptr"][...]), 42 | shape=matrix["shape"][...][::-1], 43 | ) 44 | 45 | barcodes = matrix["barcodes"][...].astype(np.unicode) 46 | 47 | adata = AnnData(X=sparsemat) 48 | 49 | features = matrix["features"] 50 | adata.var_names = features["name"][...].astype(np.unicode) 51 | adata.var["id"] = features["id"][...].astype(np.unicode) 52 | for f in features["_all_tag_keys"]: 53 | feature = features[f][...] 54 | if feature.dtype.kind in ("S", "U"): 55 | feature = feature.astype(np.unicode) 56 | adata.var[f.astype(np.unicode)] = feature 57 | 58 | _, counts = np.unique(adata.var_names, return_counts=True) 59 | if np.sum(counts > 1) > 0: 60 | logging.warning("Duplicate gene names present. Converting to unique names.") 61 | adata.var_names_make_unique() 62 | 63 | tissue_positions = ( 64 | pd.read_csv( 65 | os.path.join(spaceranger_out_dir, "spatial", "tissue_positions_list.csv"), 66 | names=( 67 | "barcode", 68 | "in_tissue", 69 | "array_row", 70 | "array_col", 71 | "pxl_col_in_fullres", 72 | "pxl_row_in_fullres", 73 | ), 74 | ) 75 | .set_index("barcode") 76 | .loc[barcodes] 77 | .drop("in_tissue", axis=1) 78 | ) 79 | adata.obsm["spatial"] = tissue_positions[ 80 | ["pxl_row_in_fullres", "pxl_col_in_fullres"] 81 | ].to_numpy() 82 | adata.obs = tissue_positions.drop(["pxl_row_in_fullres", "pxl_col_in_fullres"], axis=1) 83 | 84 | with open(os.path.join(spaceranger_out_dir, "spatial", "scalefactors_json.json"), "r") as f: 85 | meta = json.load(f) 86 | adata.uns["spot_diameter_fullres"] = meta["spot_diameter_fullres"] 87 | if read_images: 88 | adata.uns["tissue_lowres_image"] = imread( 89 | os.path.join(spaceranger_out_dir, "spatial", "tissue_lowres_image.png") 90 | ) 91 | adata.uns["tissue_hires_image"] = imread( 92 | os.path.join(spaceranger_out_dir, "spatial", "tissue_hires_image.png") 93 | ) 94 | adata.uns["tissue_hires_scalef"] = meta["tissue_hires_scalef"] 95 | adata.uns["tissue_lowres_scalef"] = meta["tissue_lowres_scalef"] 96 | 97 | return adata 98 | -------------------------------------------------------------------------------- /SpatialDE/svca.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import Optional, List, Tuple, Union 3 | import warnings 4 | from dataclasses import dataclass, field 5 | 6 | from tqdm.auto import tqdm, trange 7 | from anndata import AnnData 8 | import numpy as np 9 | import pandas as pd 10 | import gpflow 11 | from gpflow.utilities import to_default_float 12 | import tensorflow as tf 13 | import tensorflow_probability as tfp 14 | 15 | from ._internal.svca import SVCA, SVCAInteractionScoreTest 16 | from ._internal.optimizer import MultiScipyOptimizer 17 | from ._internal.util import get_l_limits, bh_adjust, calc_sizefactors, dense_slice 18 | from ._internal.distance_cache import DistanceCache 19 | 20 | 21 | def test_spatial_interactions( 22 | adata: AnnData, 23 | layer: Optional[str] = None, 24 | 
    spatial_key: str = "spatial",
25 |     ard: bool = False,
26 |     sizefactors: Optional[np.ndarray] = None,
27 |     copy: bool = False,
28 | ) -> Tuple[pd.DataFrame, Union[AnnData, None]]:
29 |     """
30 |     Fit an SVCA model (Arnol, 2019) and test for spatial cell-cell interactions.
31 | 
32 |     In contrast to the original publication, which used a Gaussian approximation with a
33 |     likelihood ratio test, this fits a Poisson GLMM and uses a score test.
34 | 
35 |     Args:
36 |         adata: The annotated data matrix.
37 |         layer: Name of the AnnData object layer to use. By default ``adata.X`` is used.
38 |         spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored.
39 |         ard: Whether to use automatic relevance determination for the kernel. This amounts to
40 |             a separate lengthscale for each spatial dimension.
41 |         sizefactors: Scaling factors for the observations. Defaults to total read counts.
42 |         copy: Whether to return a copy of ``adata`` with results or write the results into ``adata``
43 |             in-place.
44 | 
45 |     Returns:
46 |         A tuple. The first element is a Pandas DataFrame with the test results; the second is an
47 |         ``AnnData`` object with the results if ``copy == True`` and ``None`` otherwise.
48 |     """
49 |     if sizefactors is None:
50 |         sizefactors = calc_sizefactors(adata)
51 | 
52 |     X = adata.obsm[spatial_key]
53 |     l_min, l_max = get_l_limits(DistanceCache(X))
54 |     lscales = l_min if not ard else [l_min] * X.shape[1]
55 |     kernel = gpflow.kernels.SquaredExponential(lengthscales=lscales)
56 |     kernel.lengthscales.transform = tfp.bijectors.Sigmoid(
57 |         low=to_default_float(0.5 * l_min), high=to_default_float(2 * l_max)
58 |     )
59 |     gpflow.set_trainable(kernel.variance, False)
60 | 
61 |     results = []
62 |     parameters = []
63 |     test = SVCAInteractionScoreTest(
64 |         dense_slice(adata.X if layer is None else adata.layers[layer]), X, sizefactors, kernel
65 |     )
66 | 
67 |     params = gpflow.utilities.parameter_dict(test.kernel[0])
68 |     sortedkeys = sorted(params.keys())
69 |     dtype = np.dtype([(k, params[k].dtype.as_numpy_dtype) for k in sortedkeys])
70 | 
71 |     for i, g in enumerate(tqdm(adata.var_names)):
72 |         t0 = time()
73 |         res, _ = test(i, None)
74 |         t = time() - t0
75 |         results.append({"time": t, "pval": res.pval.numpy(), "gene": g})
76 |         params = gpflow.utilities.read_values(test.kernel[0])
77 |         parameters.append(tuple([params[k] for k in sortedkeys]))
78 | 
79 |     results = pd.DataFrame(results)
80 |     results.loc[
81 |         results.pval > 1, "pval"
82 |     ] = 1  # this seems to be a bug in tensorflow_probability, survival_function should never be >1
83 |     results["padj"] = bh_adjust(results.pval.to_numpy())
84 | 
85 |     if copy:
86 |         adata = adata.copy()
87 |         toreturn = adata
88 |     else:
89 |         toreturn = None
90 |     adata.varm["svca"] = np.array(parameters, dtype=dtype)
91 |     adata.obsm["svca_sizefactors"] = sizefactors
92 |     adata.uns["svca_ard"] = ard
93 |     adata.uns["svca_layer"] = layer
94 | 
95 |     return results, toreturn
96 | 
97 | 
98 | def fit_spatial_interactions(
99 |     adata: AnnData,
100 |     layer: Optional[str] = None,
101 |     genes: Optional[List[str]] = None,
102 |     spatial_key: str = "spatial",
103 |     ard: bool = False,
104 |     sizefactors: Optional[np.ndarray] = None,
105 | ) -> pd.DataFrame:
106 |     """
107 |     Estimate magnitude of spatial cell-cell interactions using an SVCA model (Arnol, 2019).
108 | 
109 |     In contrast to the original publication, which used a Gaussian approximation, this fits
110 |     a Poisson GLMM. This function is intended to be used after :py:func:`test_spatial_interactions`
111 |     to analyse the genes showing significant cell-cell interactions.
112 | 
113 |     Args:
114 |         adata: The annotated data matrix.
115 |         layer: Name of the AnnData object layer to use. By default ``adata.X`` is used.
116 |         genes: List of genes to analyze. Defaults to all genes.
117 |         spatial_key: Key in ``adata.obsm`` where the spatial coordinates are stored.
118 |         ard: Whether to use automatic relevance determination for the kernel. This amounts to
119 |             a separate lengthscale for each spatial dimension.
120 |         sizefactors: Scaling factors for the observations. Defaults to total read counts.
121 | 
122 |     Returns: A Pandas DataFrame with the results.
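
    A typical two-step workflow might look like the following sketch (``adata`` is a hypothetical
    ``AnnData`` object with spatial coordinates in ``adata.obsm["spatial"]``; the 0.05 FDR cutoff
    is arbitrary)::

        from SpatialDE.svca import test_spatial_interactions, fit_spatial_interactions

        test_results, _ = test_spatial_interactions(adata)
        hits = test_results.loc[test_results.padj < 0.05, "gene"].tolist()
        interaction_fractions = fit_spatial_interactions(adata, genes=hits)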
123 |     """
124 |     if (
125 |         "svca" not in adata.varm
126 |         or "svca_sizefactors" not in adata.obsm
127 |         or "svca_ard" not in adata.uns
128 |         or "svca_layer" not in adata.uns
129 |     ):
130 |         warnings.warn("SVCA parameters not found in adata. Performing ab initio fitting.")
131 |         if genes is None:
132 |             genes = adata.var_names
133 |         if sizefactors is None:
134 |             sizefactors = calc_sizefactors(adata)
135 |         trainable = True
136 |     else:
137 |         if genes is None:
138 |             warnings.warn("No genes given. Fitting all genes.")
139 |             genes = adata.var_names
140 |         sizefactors = adata.obsm["svca_sizefactors"]
141 |         ard = adata.uns["svca_ard"]
142 |         trainable = False
143 | 
144 |     X = adata.obsm[spatial_key]
145 |     l_min, l_max = get_l_limits(DistanceCache(X))
146 |     lscales = l_min if not ard else [l_min] * X.shape[1]
147 |     kernel = gpflow.kernels.SquaredExponential(lengthscales=lscales)
148 |     kernel.lengthscales.transform = tfp.bijectors.Sigmoid(
149 |         low=to_default_float(0.5 * l_min), high=to_default_float(2 * l_max)
150 |     )
151 |     gpflow.set_trainable(kernel.variance, False)
152 | 
153 |     model = SVCA(
154 |         dense_slice(
155 |             adata.X if adata.uns["svca_layer"] is None else adata.layers[adata.uns["svca_layer"]]
156 |         ),
157 |         X,
158 |         sizefactors,
159 |         kernel,
160 |     )
161 |     model.use_interactions(True)
162 | 
163 |     idx = np.argsort(adata.var_names)
164 |     idx = idx[np.searchsorted(adata.var_names.to_numpy(), genes, sorter=idx)]
165 |     if not trainable:
166 |         param_names = adata.varm["svca"].dtype.names
167 | 
168 |     results = []
169 |     for g, i in zip(tqdm(genes), idx):
170 |         model.currentgene = i
171 |         if not trainable:
172 |             gpflow.utilities.multiple_assign(
173 |                 model.kernel, {n: v for n, v in zip(param_names, adata.varm["svca"][i])}
174 |             )
175 |         t0 = time()
176 |         model.optimize()
177 |         t = time() - t0
178 |         fracvars = model.fraction_variance()._asdict()
179 |         fracvars.update({"gene": g, "time": t})
180 |         results.append(fracvars)
181 |     return pd.DataFrame(results)
182 | 
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | source/generated
2 | build
3 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: {{ _('Module Attributes') }} 8 | 9 | .. autosummary:: 10 | {% for item in attributes %} 11 | {{ item }} 12 | {%- endfor %} 13 | {% endif %} 14 | {% endblock %} 15 | 16 | {% block functions %} 17 | {% if functions %} 18 | .. rubric:: {{ _('Functions') }} 19 | 20 | .. autosummary:: 21 | :toctree: 22 | {% for item in functions %} 23 | {{ item }} 24 | {%- endfor %} 25 | {% endif %} 26 | {% endblock %} 27 | 28 | {% block classes %} 29 | {% if classes %} 30 | .. rubric:: {{ _('Classes') }} 31 | 32 | .. autosummary:: 33 | :toctree: 34 | :template: class.rst 35 | {% for item in classes %} 36 | {{ item }} 37 | {%- endfor %} 38 | {% endif %} 39 | {% endblock %} 40 | 41 | {% block exceptions %} 42 | {% if exceptions %} 43 | .. rubric:: {{ _('Exceptions') }} 44 | 45 | .. autosummary:: 46 | :toctree: 47 | {% for item in exceptions %} 48 | {{ item }} 49 | {%- endfor %} 50 | {% endif %} 51 | {% endblock %} 52 | 53 | {% block modules %} 54 | {% if modules %} 55 | .. rubric:: Modules 56 | 57 | .. autosummary:: 58 | :toctree: 59 | :recursive: 60 | {% for item in modules %} 61 | {{ item }} 62 | {%- endfor %} 63 | {% endif %} 64 | {% endblock %} 65 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "SpatialDE" 21 | copyright = "2021, Ilia Kats, Valentine Svensson" 22 | author = "Ilia Kats, Valentine Svensson" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.autodoc", 32 | "sphinx.ext.napoleon", 33 | "sphinx_autodoc_typehints", 34 | "sphinx.ext.autosummary", 35 | "sphinx_rtd_theme", 36 | "sphinx.ext.mathjax", 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ["_templates"] 41 | autodoc_default_flags = ["members"] 42 | autosummary_generate = True 43 | autosummary_imported_members = True 44 | 45 | # List of patterns, relative to source directory, that match files and 46 | # directories to ignore when looking for source files. 47 | # This pattern also affects html_static_path and html_extra_path. 48 | exclude_patterns = ["docs"] 49 | 50 | # -- docstring parsing config ------------------------------------------------ 51 | napoleon_numpy_docstring = False 52 | napoleon_include_init_with_doc = False 53 | napoleon_include_special_with_doc = False 54 | napoleon_attr_annotations = True 55 | 56 | autodoc_typehints = "both" 57 | autodoc_preserve_defaults = True 58 | always_document_param_types = True 59 | 60 | 61 | # -- Options for HTML output ------------------------------------------------- 62 | 63 | # The theme to use for HTML and HTML Help pages. See the documentation for 64 | # a list of builtin themes. 65 | # 66 | html_theme = "sphinx_rtd_theme" 67 | html_theme_options = {"collapse_navigation": False, "style_external_links": True} 68 | 69 | # Add any paths that contain custom static files (such as style sheets) here, 70 | # relative to this directory. They are copied after the builtin static files, 71 | # so a file named "default.css" will overwrite the builtin "default.css". 72 | html_static_path = ["_static"] 73 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. SpatialDE documentation master file, created by 2 | sphinx-quickstart on Fri Oct 9 11:58:49 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to SpatialDE's documentation! 7 | ===================================== 8 | .. toctree:: 9 | :maxdepth: 10 10 | :caption: Table of Contents 11 | :name: mastertoc 12 | 13 | API reference 14 | ============= 15 | .. 
autosummary::
16 |     :caption: API reference
17 |     :toctree: generated
18 |     :template: module.rst
19 | 
20 |     SpatialDE
21 | ..
22 | .. .. automodule:: SpatialDE
23 | ..    :members:
24 | ..    :undoc-members:
25 | ..    :imported-members:
26 | ..
27 | 
28 | 
29 | Indices and tables
30 | ==================
31 | 
32 | * :ref:`genindex`
33 | * :ref:`modindex`
34 | * :ref:`search`
35 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools >= 47.0.0",
4 |     "wheel",
5 |     "setuptools_scm[toml] >= 3.4",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 | 
9 | [tool.setuptools_scm]
10 | write_to = "SpatialDE/version.py"
11 | 
12 | [tool.black]
13 | line-length = 100
14 | target-version = ['py37']
15 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = SpatialDE
3 | url = https://github.com/ilia-kats/SpatialDE
4 | description = Spatial and Temporal DE test
5 | long_description = file: README.md
6 | author = Ilia Kats, Valentine Svensson
7 | author_email = i.kats@dkfz-heidelberg.de, valentine@nxn.se
8 | license = MIT
9 | 
10 | [options]
11 | packages = find:
12 | python_requires = >= 3.7
13 | install_requires =
14 |     numpy
15 |     scipy >= 1.0
16 |     pandas >= 1.0
17 |     matplotlib >= 3.0
18 |     tqdm
19 |     gpflow >= 2.0
20 |     tensorflow >= 2.3
21 |     tensorflow-probability >= 0.10
22 |     anndata >= 0.7
23 |     NaiveDE
24 |     h5py
25 | 
26 | [options.extras_require]
27 | docs =
28 |     sphinx >= 4.0
29 |     sphinx-autodoc-typehints
30 |     sphinx-rtd-theme
31 | 
--------------------------------------------------------------------------------
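
A minimal end-to-end sketch of the API documented above. All paths, the ``some_genes`` list, and the
``adata`` contents are hypothetical placeholders; the calls follow the signatures shown in
SpatialDE/io.py and SpatialDE/gaussian_process.py:

    from SpatialDE.io import read_spaceranger
    from SpatialDE.gaussian_process import fit, GPControl

    adata = read_spaceranger("/path/to/spaceranger/outs")  # hypothetical path to a 10x Visium run
    fast_results = fit(adata, control=None)                # control=None dispatches to fit_fast
    some_genes = ["gene_of_interest"]                      # placeholder list of gene names
    detailed_results = fit(adata, genes=some_genes, control=GPControl(ncomponents=5))  # fit_detailed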