├── .gitignore ├── LICENSE ├── README.md ├── SpaceFlow ├── SpaceFlow.py ├── __init__.py └── util.py ├── images ├── annotation.png ├── domain_segmentation.png └── pSM.png ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── test.py └── tutorials └── seqfish_mouse_embryogenesis.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | test/ 3 | .idea 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Honglei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/476440894.svg)](https://zenodo.org/badge/latestdoi/476440894) 2 | 3 | # SpaceFlow: Identifying Multicellular Spatiotemporal Organization of Cells using Spatial Transcriptome Data 4 | 5 | SpaceFlow is a Python package for identifying spatiotemporal patterns and spatial domains from Spatial Transcriptomic (ST) data. Based on a deep graph network, SpaceFlow provides the following functions: 6 | 1. Encodes the ST data into **low-dimensional embeddings** that reflect both the expression similarity and the spatial proximity of cells in ST data. 7 | 2. Incorporates **spatiotemporal** relationships of cells or spots in ST data through a **pseudo-Spatiotemporal Map (pSM)** derived from the embeddings. 8 | 3. Identifies **spatial domains** with spatially-coherent expression patterns. 9 | 10 | Check out [our paper (Ren et al., Nature Communications, 2022)](https://www.nature.com/articles/s41467-022-31739-w) for the detailed methods and applications. 11 | 12 | SpaceFlow was developed in `Python 3.7` with `Pytorch 1.9.0`. Specific package versions are available in `requirements.txt`. The marker gene identification analysis is performed using the `Scanpy 1.8.1` package. The cell-cell communication inference is performed with `CellChat v1.1.3` in an `R v4.1.2` environment. 13 | 14 | ## Installation 15 | 16 | ### 1.
Prepare environment 17 | To install SpaceFlow, we recommend using the [Anaconda Python Distribution](https://anaconda.org/) and creating an isolated environment, so that SpaceFlow and its dependencies don't conflict or interfere with other packages or applications. To create the environment, run the following command in a terminal: 18 | 19 | ```bash 20 | conda create -n spaceflow_env python=3.7 21 | ``` 22 | 23 | After creating the environment, activate the `spaceflow_env` environment by: 24 | ```bash 25 | conda activate spaceflow_env 26 | ``` 27 | 28 | ### 2. Install `Pytorch` 29 | First, install a `Pytorch` build that matches your machine and environment by following the instructions at: 30 | [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) 31 | 32 | Note that if you want to install `Pytorch` on a `GPU` machine, you need to install **CUDA** first; see the guide for installing **CUDA** here: [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads). 33 | 34 | ### 3. Install SpaceFlow 35 | After successfully installing `Pytorch` (version `>=1.9.0`), install the SpaceFlow package using `pip` by: 36 | ```bash 37 | pip install SpaceFlow 38 | ``` 39 | 40 | If the installation is still not successful, try installing the required packages in `requirements.txt` by: 41 | 42 | ```bash 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | ## Usage 47 | 48 | ### Quick Start by Example ([Jupyter Notebook](tutorials/seqfish_mouse_embryogenesis.ipynb)) 49 | 50 | We will use the mouse organogenesis ST data from [(Lohoff, T. et al. 2022)](https://doi.org/10.1038/s41587-021-01006-2) generated by [seqFISH](https://spatial.caltech.edu/seqfish/) to demonstrate the usage of SpaceFlow. 51 | 52 | The data is available in the [squidpy](https://squidpy.readthedocs.io/en/stable/) package, so we first import the `squidpy` package and load the data. If `squidpy` is not installed,
please run `pip install squidpy` to install it. 53 | 54 | #### 1. Import SpaceFlow and squidpy package 55 | 56 | ```python 57 | import squidpy as sq 58 | import scanpy as sc 59 | from SpaceFlow import SpaceFlow 60 | ``` 61 | 62 | #### 2. Load the ST data from squidpy package 63 | ```python 64 | adata = sq.datasets.seqfish() 65 | sc.pp.filter_genes(adata, min_cells=3) 66 | ``` 67 | 68 | #### 3. Create SpaceFlow Object 69 | We can create a SpaceFlow object from either an `anndata.AnnData` object or a count matrix as input: 70 | 71 | To construct a SpaceFlow object from an `anndata.AnnData` object: 72 | ```python 73 | sf = SpaceFlow.SpaceFlow(adata=adata) 74 | ``` 75 | Parameters: 76 | - `adata`: the annotated gene expression data of size (# of cells, # of genes), type `anndata.AnnData`, with the spatial coordinates of cells stored in `adata.obsm['spatial']`; see `https://anndata.readthedocs.io/en/latest/` for more info about `anndata`. 77 | 78 | To construct a SpaceFlow object from a raw count matrix: 79 | ```python 80 | sf = SpaceFlow.SpaceFlow(count_matrix=adata.X, spatial_locs=adata.obsm['spatial'], sample_names=adata.obs_names, gene_names=adata.var_names) 81 | ``` 82 | Parameters: 83 | - `count_matrix`: the count matrix of gene expression, 2D numpy array of size (# of cells, # of genes), type `numpy.ndarray`, required if `adata` is not given 84 | - `spatial_locs`: spatial locations of cells (or spots) matching the rows of the count matrix, 2D numpy array of size (# of cells, # of spatial dimensions), type `numpy.ndarray`, required if `adata` is not given 85 | - `sample_names`: list of sample names in 1D numpy str array of size (n_cells,), optional 86 | - `gene_names`: list of gene names in 1D numpy str array of size (n_genes,), optional 87 | 88 | 89 | #### 4. Preprocessing the ST Data 90 | Next, we preprocess the ST data by running: 91 | 92 | ```python 93 | sf.preprocessing_data(n_top_genes=3000) 94 | ``` 95 | Parameters: 96 | - `n_top_genes`: the number of top highly variable genes to keep.
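The first two preprocessing steps (normalization and log-transformation) can be illustrated without Scanpy. The sketch below is a minimal, hypothetical helper (`normalize_log1p` is not part of SpaceFlow) that assumes a dense NumPy count matrix and mirrors what `sc.pp.normalize_total(adata, target_sum=1e4)` followed by `sc.pp.log1p(adata)`, as called in `preprocessing_data`, compute: each cell is rescaled to 10,000 total counts, then the natural `log(1 + x)` is applied. Highly variable gene selection and spatial graph construction are not shown.

```python
import numpy as np


def normalize_log1p(counts, target_sum=1e4):
    """Hypothetical sketch: per-cell total-count normalization followed by log1p.

    counts: dense array of shape (n_cells, n_genes) with raw counts.
    """
    counts = np.asarray(counts, dtype=float)
    # Total counts per cell (row), kept 2D so it broadcasts across genes
    totals = counts.sum(axis=1, keepdims=True)
    # Avoid dividing by zero: all-zero cells simply stay at zero
    totals[totals == 0] = 1.0
    # Rescale every cell to `target_sum` counts, then take the natural log(1 + x)
    return np.log1p(counts / totals * target_sum)


counts = np.array([[1, 3, 0], [2, 2, 4]])
X = normalize_log1p(counts)
```

Because `np.expm1` inverts `np.log1p`, each row of the result maps back to values summing to `target_sum`, which is a quick way to sanity-check the transform.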
97 | 98 | The preprocessing includes the normalization and log-transformation of the expression count matrix, the selection of highly variable genes, and the construction of a spatial proximity graph from the spatial coordinates. (For details, see the `preprocessing_data` function in `SpaceFlow/SpaceFlow.py`.) 99 | 100 | #### 5. Train the deep graph network model 101 | 102 | We then train a spatially regularized deep graph network model to learn a low-dimensional embedding that reflects both the expression similarity and the spatial proximity of cells in ST data. 103 | 104 | ```python 105 | sf.train(spatial_regularization_strength=0.1, z_dim=50, lr=1e-3, epochs=1000, max_patience=50, min_stop=100, random_seed=42, gpu=0, regularization_acceleration=True, edge_subset_sz=1000000) 106 | ``` 107 | 108 | Parameters: 109 | - `spatial_regularization_strength`: the strength of spatial regularization; the larger the value, the more spatially coherent the identified spatial domains and spatiotemporal patterns. (default: 0.1) 110 | - `z_dim`: the target size of the learned embedding. (default: 50) 111 | - `lr`: learning rate for optimizing the model. (default: 1e-3) 112 | - `epochs`: the maximum number of epochs for model training. (default: 1000) 113 | - `max_patience`: the maximum number of consecutive epochs to wait for the loss to decrease. If the loss does not decrease for more than this many epochs, training stops, and the parameters that gave the minimal loss are kept as the best model. (default: 50) 114 | - `min_stop`: the earliest epoch at which training can stop early when the loss has not decreased for more than `max_patience` epochs. (default: 100) 115 | - `random_seed`: the random seed set for the random number generators of the `random`, `numpy`, and `torch` packages. (default: 42) 116 | - `gpu`: the index of the NVIDIA GPU to use; if no GPU is available, the model is trained on the CPU, which is slower than GPU training.
(default: 0) 117 | - `regularization_acceleration`: whether to accelerate the calculation of the regularization loss using an edge-subsetting strategy; the strategy is also applied automatically to datasets with more than 5,000 cells (default: True) 118 | - `edge_subset_sz`: the number of randomly sampled cell pairs (edges) used to compute the regularization loss when acceleration is applied (default: 1000000) 119 | 120 | #### 6. Domain segmentation of the ST data 121 | 122 | After model training, the learned low-dimensional embedding can be accessed through `sf.embedding`. 123 | 124 | SpaceFlow uses this learned embedding to identify the spatial domains with the [Leiden](https://www.nature.com/articles/s41598-019-41695-z) algorithm. 125 | 126 | ```python 127 | sf.segmentation(domain_label_save_filepath="./domains.tsv", n_neighbors=50, resolution=1.0) 128 | ``` 129 | Parameters: 130 | 131 | - `domain_label_save_filepath`: the file path for saving the identified domain labels. (default: "./domains.tsv") 132 | - `n_neighbors`: the number of nearest neighbors of each cell used to build the neighborhood graph from the embedding for Leiden clustering. (default: 50) 133 | - `resolution`: the resolution of the Leiden clustering; higher values yield more (finer) domains. (default: 1.0) 134 | 135 | #### 7. Visualization of the annotation and the identified spatial domains 136 | 137 | We next plot the spatial domains using the identified domain labels and the spatial coordinates of cells. 138 | 139 | ```python 140 | sf.plot_segmentation(segmentation_figure_save_filepath="./domain_segmentation.pdf", colormap="tab20", scatter_sz=1., rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right=0.9, bottom=0.1, top=0.9) 141 | ``` 142 | 143 | The expected output is: 144 | 145 | ![Domain Segmentation](images/domain_segmentation.png) 146 | 147 | Parameters: 148 | - `segmentation_figure_save_filepath`: optional, type: str, the file path for saving the figure of the spatial domain visualization.
(default: "./domain_segmentation.pdf") 149 | - `colormap`: optional, type: str, the colormap for the different domains (default: "tab20"); for the full list of colormap options, see [matplotlib](https://matplotlib.org/3.5.1/tutorials/colors/colormaps.html) 150 | - `scatter_sz`: optional, type: float, the marker size in points. (default: 1.0) 151 | - `rsz`: optional, type: float, row size (height) of the figure in inches. (default: 4.0) 152 | - `csz`: optional, type: float, column size (width) of the figure in inches. (default: 4.0) 153 | - `wspace`: optional, type: float, the amount of width reserved for space between subplots, expressed as a fraction of the average axis width (default: 0.4) 154 | - `hspace`: optional, type: float, the amount of height reserved for space between subplots, expressed as a fraction of the average axis height (default: 0.5) 155 | - `left`: optional, type: float, the position of the left edge of the subplots, as a fraction of the figure width (default: 0.125) 156 | - `right`: optional, type: float, the position of the right edge of the subplots, as a fraction of the figure width (default: 0.9) 157 | - `bottom`: optional, type: float, the position of the bottom edge of the subplots, as a fraction of the figure height (default: 0.1) 158 | - `top`: optional, type: float, the position of the top edge of the subplots, as a fraction of the figure height (default: 0.9) 159 | 160 | We can also visualize the expert annotation for comparison by: 161 | 162 | ```python 163 | import scanpy as sc 164 | sc.pl.spatial(adata, color="celltype_mapped_refined", spot_size=0.03) 165 | ``` 166 | 167 | The expected output is: 168 | 169 | ![Expert Annotation](images/annotation.png) 170 | 171 | #### 8. Identify the spatiotemporal patterns of the ST data through a pseudo-Spatiotemporal Map (pSM) 172 | 173 | Next, we apply the diffusion pseudotime (dpt) algorithm to the learned spatially-consistent embedding to generate a pseudo-Spatiotemporal Map (pSM).
This pSM represents a spatially-coherent pseudotime ordering of cells that encodes biological relationships between cells, such as developmental trajectories and cancer progression. 174 | 175 | ```python 176 | sf.pseudo_Spatiotemporal_Map(pSM_values_save_filepath="./pSM_values.tsv", n_neighbors=20, resolution=1.0) 177 | ``` 178 | Parameters: 179 | - `pSM_values_save_filepath`: the file path for saving the inferred pSM values. 180 | - `n_neighbors`: the number of nearest neighbors of each cell used to build the neighborhood graph from the embedding. (default: 20) 181 | - `resolution`: the resolution of the Leiden clustering; higher values yield more (finer) clusters. (default: 1.0) 182 | 183 | #### 9. Visualization of the identified pseudo-Spatiotemporal Map (pSM) 184 | 185 | We next visualize the identified pseudo-Spatiotemporal Map (pSM). 186 | 187 | ```python 188 | sf.plot_pSM(pSM_figure_save_filepath="./pseudo-Spatiotemporal-Map.pdf", colormap="roma", scatter_sz=1., rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right=0.9, bottom=0.1, top=0.9) 189 | ``` 190 | 191 | The expected output is: 192 | 193 | ![pSM](images/pSM.png) 194 | 195 | Parameters: 196 | - `pSM_figure_save_filepath`: optional, type: str, the file path for saving the figure of the pSM visualization. (default: "./pseudo-Spatiotemporal-Map.pdf") 197 | - `colormap`: optional, type: str, the colormap of the pSM (default: 'roma'); for the full list of colormap options, see [Scientific Colormaps](https://www.fabiocrameri.ch/colourmaps-userguide/) 198 | - `scatter_sz`: optional, type: float, the marker size in points.
(default: 1.0) 199 | - `rsz`: optional, type: float, row size (height) of the figure in inches. (default: 4.0) 200 | - `csz`: optional, type: float, column size (width) of the figure in inches. (default: 4.0) 201 | - `wspace`: optional, type: float, the amount of width reserved for space between subplots, expressed as a fraction of the average axis width (default: 0.4) 202 | - `hspace`: optional, type: float, the amount of height reserved for space between subplots, expressed as a fraction of the average axis height (default: 0.5) 203 | - `left`: optional, type: float, the position of the left edge of the subplots, as a fraction of the figure width (default: 0.125) 204 | - `right`: optional, type: float, the position of the right edge of the subplots, as a fraction of the figure width (default: 0.9) 205 | - `bottom`: optional, type: float, the position of the bottom edge of the subplots, as a fraction of the figure height (default: 0.1) 206 | - `top`: optional, type: float, the position of the top edge of the subplots, as a fraction of the figure height (default: 0.9) 207 | 208 | ## Please cite 209 | 210 | Ren, Honglei, et al. "Identifying multicellular spatiotemporal organization of cells with SpaceFlow." Nature Communications 13.1 (2022): 1-14. https://www.nature.com/articles/s41467-022-31739-w 211 | 212 | ## Contact 213 | If you have any questions or find any issues, please contact: [hongleir@uci.edu](mailto:hongleir@uci.edu).
214 | -------------------------------------------------------------------------------- /SpaceFlow/SpaceFlow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import os 4 | import torch 5 | import random 6 | import gudhi 7 | import anndata 8 | import cmcrameri 9 | import numpy as np 10 | import scanpy as sc 11 | import networkx as nx 12 | import torch.nn as nn 13 | import matplotlib.pyplot as plt 14 | from scipy.spatial import distance_matrix 15 | from torch_geometric.nn import GCNConv, DeepGraphInfomax 16 | from sklearn.neighbors import kneighbors_graph 17 | from SpaceFlow.util import sparse_mx_to_torch_edge_list, corruption 18 | 19 | class SpaceFlow(object): 20 | """An object for analysis of spatial transcriptomics data. 21 | 22 | :param adata: the `anndata.AnnData` object as input, see `https://anndata.readthedocs.io/en/latest/` for more info about `anndata`. 23 | :type adata: class:`anndata.AnnData` 24 | :param count_matrix: count matrix of gene expression, 2D numpy array of size (n_cells, n_genes) 25 | :type count_matrix: class:`numpy.ndarray` 26 | :param spatial_locs: spatial locations of cells (or spots) matching the rows of the count matrix, 2D numpy array of size (n_cells, n_dims) 27 | :type spatial_locs: class:`numpy.ndarray` 28 | :param sample_names: list of sample names in 1D numpy str array of size (n_cells,), optional 29 | :type sample_names: class:`numpy.ndarray` or `list` of `str` 30 | :param gene_names: list of gene names in 1D numpy str array of size (n_genes,), optional 31 | :type gene_names: class:`numpy.ndarray` or `list` of `str` 32 | 33 | """ 34 | 35 | def __init__(self, adata=None, count_matrix=None, spatial_locs=None, sample_names=None, gene_names=None): 36 | """ 37 | Inputs 38 | ------ 39 | adata: an anndata.AnnData type object, optional (either input `adata` or both `count_matrix` and `spatial_locs`) 40 | count_matrix : count matrix of gene expression, 2D numpy array of
size (n_cells, n_genes) 41 | spatial_locs : spatial locations of cells (or spots) matching the rows of the count matrix, 2D numpy array of size (n_cells, n_dims) 42 | sample_names : list of sample names in 1D numpy str array of size (n_cells,), optional 43 | gene_names : list of gene names in 1D numpy str array of size (n_genes,), optional 44 | """ 45 | if isinstance(adata, anndata.AnnData): 46 | self.adata = adata 47 | elif count_matrix is not None and spatial_locs is not None: 48 | self.adata = anndata.AnnData(count_matrix.astype(float)) 49 | self.adata.obsm['spatial'] = spatial_locs.astype(float) 50 | if gene_names is not None: 51 | self.adata.var_names = np.array(gene_names).astype(str) 52 | if sample_names is not None: 53 | self.adata.obs_names = np.array(sample_names).astype(str) 54 | else: 55 | print("Please input either an anndata.AnnData or both the count_matrix (count matrix of gene expression, 2D numpy array of size (n_cells, n_genes)) and spatial_locs (spatial locations of cells (or spots), 2D float numpy array of size (n_cells, n_dims)) to initialize the SpaceFlow object.") 56 | exit(1) 57 | 58 | def plt_setting(self, fig_title_sz=30, font_sz=12, font_weight="bold", axes_title_sz=12, axes_label_sz=12, xtick_sz=10, ytick_sz=10, legend_font_sz=10): 59 | """ 60 | Setting the plotting configuration 61 | :param fig_title_sz: fontsize of the figure title, default: 30 62 | :type fig_title_sz: int, optional 63 | :param font_sz: controls default text sizes, default: 12 64 | :type font_sz: int, optional 65 | :param font_weight: controls default text weights, default: 'bold' 66 | :type font_weight: str, optional 67 | :param axes_title_sz: fontsize of the axes title, default: 12 68 | :type axes_title_sz: int, optional 69 | :param axes_label_sz: fontsize of the x and y labels, default 12 70 | :type axes_label_sz: int, optional 71 | :param xtick_sz: fontsize of the x tick label, default 10 72 | :type xtick_sz: int, optional 73 | :param ytick_sz: fontsize of the y tick label, default 10 74 |
:type ytick_sz: int, optional 75 | :param legend_font_sz: legend fontsize, default 10 76 | :type legend_font_sz: int, optional 77 | """ 78 | plt.rc('figure', titlesize=fig_title_sz) # fontsize of the figure title 79 | plt.rc('font', size=font_sz, weight=font_weight) # controls default text sizes 80 | plt.rc('axes', titlesize=axes_title_sz) # fontsize of the axes title 81 | plt.rc('axes', labelsize=axes_label_sz) # fontsize of the x and y labels 82 | plt.rc('xtick', labelsize=xtick_sz) # fontsize of the x tick label 83 | plt.rc('ytick', labelsize=ytick_sz) # fontsize of the y tick label 84 | plt.rc('legend', fontsize=legend_font_sz) # legend fontsize 85 | 86 | def prepare_figure(self, rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right=0.9, bottom=0.1, top=0.9): 87 | """ 88 | Prepare the figure and axes given the configuration 89 | :param rsz: row size (height) of the figure in inches, default: 4.0 90 | :type rsz: float, optional 91 | :param csz: column size (width) of the figure in inches, default: 4.0 92 | :type csz: float, optional 93 | :param wspace: the amount of width reserved for space between subplots, expressed as a fraction of the average axis width, default: 0.4 94 | :type wspace: float, optional 95 | :param hspace: the amount of height reserved for space between subplots, expressed as a fraction of the average axis height, default: 0.5 96 | :type hspace: float, optional 97 | :param left: the position of the left edge of the subplots, as a fraction of the figure width, default: 0.125 98 | :type left: float, optional 99 | :param right: the position of the right edge of the subplots, as a fraction of the figure width, default: 0.9 100 | :type right: float, optional 101 | :param bottom: the position of the bottom edge of the subplots, as a fraction of the figure height, default: 0.1 102 | :type bottom: float, optional 103 | :param top: the position of the top edge of the subplots, as a fraction of the figure height, default: 0.9 104 | :type top: float, optional 105 | """ 106 | fig, axs = plt.subplots(1, 1, figsize=(csz, rsz)) 107 |
plt.subplots_adjust(wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top) 108 | return fig, axs 109 | 110 | def preprocessing_data(self, n_top_genes=None, n_neighbors=10): 111 | """ 112 | Preprocess the spatial transcriptomics data in `self.adata`, which must have adata.obsm['spatial'] set to be the spatial locations. 113 | Steps: total-count normalization, log1p transformation, selection of the highly variable genes, PCA, 114 | and construction of the spatial proximity graph. Note that `self.adata` is modified in place. 115 | Generates: `self.adata_preprocessed`: the preprocessed `anndata.AnnData` object restricted to the highly variable genes 116 | `self.spatial_graph`: (n_cells, n_cells) `scipy.sparse.csr_matrix`, a geometry-aware spatial proximity graph of the spatial spots of cells 117 | :param n_top_genes: the number of top highly variable genes 118 | :type n_top_genes: int, optional 119 | :param n_neighbors: the number of nearest neighbors used to estimate the distance cutoff of the spatial neighbor graph 120 | :type n_neighbors: int, optional 121 | :return: None; the preprocessed data and the spatial graph are stored in 122 | `self.adata_preprocessed` and `self.spatial_graph`, respectively 123 | :rtype: None 124 | 125 | """ 126 | adata = self.adata 127 | if not adata: 128 | print("No annData object found, please run SpaceFlow.SpaceFlow(adata=...) or SpaceFlow.SpaceFlow(count_matrix=..., spatial_locs=...) first!") 129 | return 130 | sc.pp.normalize_total(adata, target_sum=1e4) 131 | sc.pp.log1p(adata) 132 | sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor='cell_ranger', subset=True) 133 | sc.pp.pca(adata) 134 | spatial_locs = adata.obsm['spatial'] 135 | spatial_graph = self.graph_alpha(spatial_locs, n_neighbors=n_neighbors) 136 | 137 | self.adata_preprocessed = adata 138 | self.spatial_graph = spatial_graph 139 | 140 | def graph_alpha(self, spatial_locs, n_neighbors=10): 141 | """ 142 | Construct a geometry-aware spatial proximity graph of the spatial spots of cells using an alpha complex. 143 | :param spatial_locs: the spatial locations of cells (or spots), 2D numpy array of size (n_cells, n_dims)
:type spatial_locs: class:`numpy.ndarray` 145 | :param n_neighbors: the number of nearest neighbors used to estimate the distance cutoff of the spatial neighbor graph based on the alpha complex 146 | :type n_neighbors: int, optional, default: 10 147 | :return: a spatial neighbor graph 148 | :rtype: class:`scipy.sparse.csr_matrix` 149 | """ 150 | A_knn = kneighbors_graph(spatial_locs, n_neighbors=n_neighbors, mode='distance') 151 | estimated_graph_cut = A_knn.sum() / float(A_knn.count_nonzero()) # mean distance to the n_neighbors nearest neighbors 152 | spatial_locs_list = spatial_locs.tolist() 153 | n_node = len(spatial_locs_list) 154 | alpha_complex = gudhi.AlphaComplex(points=spatial_locs_list) 155 | simplex_tree = alpha_complex.create_simplex_tree(max_alpha_square=estimated_graph_cut ** 2) # the alpha filtration uses squared radii 156 | skeleton = simplex_tree.get_skeleton(1) 157 | initial_graph = nx.Graph() 158 | initial_graph.add_nodes_from([i for i in range(n_node)]) 159 | for s in skeleton: 160 | if len(s[0]) == 2: 161 | initial_graph.add_edge(s[0][0], s[0][1]) 162 | 163 | extended_graph = nx.Graph() 164 | extended_graph.add_nodes_from(initial_graph) 165 | extended_graph.add_edges_from(initial_graph.edges) 166 | 167 | # Remove self edges 168 | for i in range(n_node): 169 | try: 170 | extended_graph.remove_edge(i, i) 171 | except nx.NetworkXError: # node i has no self edge 172 | pass 173 | 174 | return nx.to_scipy_sparse_matrix(extended_graph, format='csr') 175 | 176 | def train(self, embedding_save_filepath="./embedding.tsv", spatial_regularization_strength=0.1, z_dim=50, lr=1e-3, epochs=1000, max_patience=50, min_stop=100, random_seed=42, gpu=0, regularization_acceleration=True, edge_subset_sz=1000000): 177 | adata_preprocessed, spatial_graph = getattr(self, "adata_preprocessed", None), getattr(self, "spatial_graph", None) 178 | """ 179 | Train the Deep Graph Infomax model 180 | :param embedding_save_filepath: the default save path for the low-dimensional embeddings 181 | :type embedding_save_filepath: class:`str` 182 | :param spatial_regularization_strength: the strength for spatial regularization 183 | :type spatial_regularization_strength: float, optional,
default: 0.1 184 | :param z_dim: the size of latent dimension 185 | :type z_dim: int, optional, default: 50 186 | :param lr: the learning rate for model optimization 187 | :type lr: float, optional, default: 1e-3 188 | :param epochs: the max epoch number 189 | :type epochs: int, optional, default: 1000 190 | :param max_patience: the tolerance epoch number without training loss decrease 191 | :type max_patience: int, optional, default: 50 192 | :param min_stop: the minimum epoch number for training before any early stop 193 | :type min_stop: int, optional, default: 100 194 | :param random_seed: the random seed 195 | :type random_seed: int, optional, default: 42 196 | :param gpu: the index for gpu device that will be used for model training, if no gpu detected, cpu will be used. 197 | :type gpu: int, optional, default: 0 198 | :param regularization_acceleration: whether or not accelerate the calculation of regularization loss using edge subsetting strategy 199 | :type regularization_acceleration: bool, optional, default: True 200 | :param edge_subset_sz: the edge subset size for regularization acceleration 201 | :type edge_subset_sz: int, optional, default: 1000000 202 | :return: low dimensional embeddings for the ST data, shape: n_cells x z_dim 203 | :rtype: class:`numpy.ndarray` 204 | """ 205 | if not adata_preprocessed: 206 | print("The data has not been preprocessed, please run preprocessing_data() method first!") 207 | return 208 | torch.manual_seed(random_seed) 209 | random.seed(random_seed) 210 | np.random.seed(random_seed) 211 | 212 | device = f"cuda:{gpu}" if torch.cuda.is_available() else 'cpu' 213 | model = DeepGraphInfomax( 214 | hidden_channels=z_dim, encoder=GraphEncoder(adata_preprocessed.shape[1], z_dim), 215 | summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)), 216 | corruption=corruption).to(device) 217 | 218 | expr = adata_preprocessed.X.todense() if type(adata_preprocessed.X).__module__ != np.__name__ else adata_preprocessed.X 219 | 
expr = torch.tensor(expr).float().to(device) 220 | 221 | edge_list = sparse_mx_to_torch_edge_list(spatial_graph).to(device) 222 | 223 | model.train() 224 | min_loss = np.inf 225 | patience = 0 226 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 227 | best_params = {k: v.detach().clone() for k, v in model.state_dict().items()} # copy, since state_dict() tensors share storage with the live parameters 228 | 229 | for epoch in range(epochs): 230 | train_loss = 0.0 231 | torch.set_grad_enabled(True) 232 | optimizer.zero_grad() 233 | z, neg_z, summary = model(expr, edge_list) 234 | loss = model.loss(z, neg_z, summary) 235 | 236 | coords = torch.tensor(adata_preprocessed.obsm['spatial']).float().to(device) 237 | if regularization_acceleration or adata_preprocessed.shape[0] > 5000: 238 | cell_random_subset_1, cell_random_subset_2 = torch.randint(0, z.shape[0], (edge_subset_sz,)).to( 239 | device), torch.randint(0, z.shape[0], (edge_subset_sz,)).to(device) 240 | z1, z2 = torch.index_select(z, 0, cell_random_subset_1), torch.index_select(z, 0, cell_random_subset_2) 241 | c1, c2 = torch.index_select(coords, 0, cell_random_subset_1), torch.index_select(coords, 0, 242 | cell_random_subset_2) 243 | pdist = torch.nn.PairwiseDistance(p=2) 244 | 245 | z_dists = pdist(z1, z2) 246 | z_dists = z_dists / torch.max(z_dists) 247 | 248 | sp_dists = pdist(c1, c2) 249 | sp_dists = sp_dists / torch.max(sp_dists) 250 | n_items = z_dists.size(dim=0) 251 | else: 252 | z_dists = torch.cdist(z, z, p=2) 253 | z_dists = torch.div(z_dists, torch.max(z_dists)).to(device) 254 | sp_dists = torch.cdist(coords, coords, p=2) 255 | sp_dists = torch.div(sp_dists, torch.max(sp_dists)).to(device) 256 | n_items = z.size(dim=0) * z.size(dim=0) 257 | 258 | penalty_1 = torch.div(torch.sum(torch.mul(1.0 - z_dists, sp_dists)), n_items).to(device) 259 | loss = loss + spatial_regularization_strength * penalty_1 260 | 261 | loss.backward() 262 | optimizer.step() 263 | train_loss += loss.item() 264 | 265 | if train_loss > min_loss: 266 | patience += 1 267 | else: 268 | patience = 0 269 | min_loss = train_loss 270 |
best_params = {k: v.detach().clone() for k, v in model.state_dict().items()} 271 | if epoch % 10 == 1: 272 | print(f"Epoch {epoch + 1}/{epochs}, Loss: {str(train_loss)}") 273 | if patience > max_patience and epoch > min_stop: 274 | break 275 | 276 | model.load_state_dict(best_params) 277 | 278 | z, _, _ = model(expr, edge_list) 279 | embedding = z.cpu().detach().numpy() 280 | save_dir = os.path.dirname(embedding_save_filepath) 281 | if save_dir and not os.path.exists(save_dir): 282 | os.makedirs(save_dir) 283 | np.savetxt(embedding_save_filepath, embedding[:, :], delimiter="\t") 284 | print(f"Training complete!\nEmbedding is saved at {embedding_save_filepath}") 285 | 286 | self.embedding = embedding 287 | return embedding 288 | 289 | def segmentation(self, domain_label_save_filepath="./domains.tsv", n_neighbors=50, resolution=1.0): 290 | """ 291 | Perform domain segmentation for ST data using Leiden clustering with low-dimensional embeddings as input 292 | :param domain_label_save_filepath: the default save path for the domain labels 293 | :type domain_label_save_filepath: class:`str`, optional, default: "./domains.tsv" 294 | :param n_neighbors: The size of local neighborhood (in terms of number of neighboring data 295 | points) used for manifold approximation. See `https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.neighbors.html` for details 296 | :type n_neighbors: int, optional, default: 50 297 | :param resolution: A parameter value controlling the coarseness of the clustering. 298 | Higher values lead to more clusters. See `https://scanpy.readthedocs.io/en/stable/generated/scanpy.tl.leiden.html` for details 299 | :type resolution: float, optional, default: 1.0 300 | """ 301 | error_message = "No embedding found, please ensure you have run the train() method before segmentation!"
302 | try: 303 | print("Performing domain segmentation") 304 | embedding_adata = anndata.AnnData(self.embedding) 305 | sc.pp.neighbors(embedding_adata, n_neighbors=n_neighbors, use_rep='X') 306 | sc.tl.leiden(embedding_adata, resolution=float(resolution)) 307 | domains = embedding_adata.obs["leiden"].cat.codes 308 | 309 | save_dir = os.path.dirname(domain_label_save_filepath) 310 | if save_dir and not os.path.exists(save_dir): 311 | os.makedirs(save_dir) 312 | np.savetxt(domain_label_save_filepath, domains, fmt='%d', header='', footer='', comments='') 313 | print(f"Segmentation complete, domain labels of cells or spots saved at {domain_label_save_filepath} !") 314 | self.domains = domains 315 | 316 | except NameError: 317 | print(error_message) 318 | except AttributeError: 319 | print(error_message) 320 | 321 | def plot_segmentation(self, segmentation_figure_save_filepath="./domain_segmentation.pdf", colormap="tab20", scatter_sz=1., rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right=0.9, bottom=0.1, top=0.9): 322 | """ 323 | Plot the domain segmentation of the ST data in spatial coordinates 324 | :param segmentation_figure_save_filepath: the default save path for the figure 325 | :type segmentation_figure_save_filepath: class:`str`, optional, default: "./domain_segmentation.pdf" 326 | :param colormap: The colormap to use.
See `https://matplotlib.org/stable/tutorials/colors/colormaps.html` for the full list of colormaps 327 | :type colormap: str, optional, default: tab20 328 | :param scatter_sz: The marker size in points**2 329 | :type scatter_sz: float, optional, default: 1.0 330 | :param rsz: row size of the figure in inches, default: 4.0 331 | :type rsz: float, optional 332 | :param csz: column size of the figure in inches, default: 4.0 333 | :type csz: float, optional 334 | :param wspace: the amount of width reserved for space between subplots, expressed as a fraction of the average axis width, default: 0.4 335 | :type wspace: float, optional 336 | :param hspace: the amount of height reserved for space between subplots, expressed as a fraction of the average axis height, default: 0.5 337 | :type hspace: float, optional 338 | :param left: the leftmost position of the subplots, as a fraction of the figure, default: 0.125 339 | :type left: float, optional 340 | :param right: the rightmost position of the subplots, as a fraction of the figure, default: 0.9 341 | :type right: float, optional 342 | :param bottom: the bottom position of the subplots, as a fraction of the figure, default: 0.1 343 | :type bottom: float, optional 344 | :param top: the top position of the subplots, as a fraction of the figure, default: 0.9 345 | :type top: float, optional 346 | """ 347 | error_message = "No segmentation data found, please ensure you have run the segmentation() method."
348 | try: 349 | fig, ax = self.prepare_figure(rsz=rsz, csz=csz, wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top) 350 | 351 | pred_clusters = np.array(self.domains).astype(int) 352 | uniq_pred = np.unique(pred_clusters) 353 | n_cluster = len(uniq_pred) 354 | x, y = self.adata_preprocessed.obsm["spatial"][:, 0], self.adata_preprocessed.obsm["spatial"][:, 1] 355 | cmap = plt.get_cmap(colormap) 356 | for cid, cluster in enumerate(uniq_pred): 357 | color = cmap(cid / (n_cluster - 1.0) if n_cluster > 1 else 0.0) 358 | ind = pred_clusters == cluster 359 | ax.scatter(x[ind], y[ind], s=scatter_sz, color=color, label=cluster, marker=".") 360 | ax.set_facecolor("none") 361 | ax.invert_yaxis() 362 | ax.set_title("Domain Segmentation", fontsize=14) 363 | box = ax.get_position() 364 | height_ratio = .8 365 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height * height_ratio]) 366 | lgnd = ax.legend(loc='center left', fontsize=8, bbox_to_anchor=(1, 0.5), scatterpoints=1, handletextpad=0.1, 367 | borderaxespad=.1, ncol=int(math.ceil(n_cluster/10))) 368 | for handle in (lgnd.legend_handles if hasattr(lgnd, "legend_handles") else lgnd.legendHandles): 369 | handle._sizes = [8] 370 | 371 | save_dir = os.path.dirname(segmentation_figure_save_filepath) 372 | if save_dir and not os.path.exists(save_dir): 373 | os.makedirs(save_dir) 374 | plt.savefig(segmentation_figure_save_filepath, dpi=300) 375 | print(f"Plotting complete, segmentation figure saved at {segmentation_figure_save_filepath} !") 376 | plt.close('all') 377 | except NameError: 378 | print(error_message) 379 | except AttributeError: 380 | print(error_message) 381 | 382 | def pseudo_Spatiotemporal_Map(self, pSM_values_save_filepath="./pSM_values.tsv", n_neighbors=20, resolution=1.0): 383 | """ 384 | Compute the pseudo-Spatiotemporal Map (pSM) for the ST data 385 | :param pSM_values_save_filepath: the default save path for the pSM values 386 | :type pSM_values_save_filepath: class:`str`, optional, default: "./pSM_values.tsv" 387 | :param n_neighbors: The size of local
neighborhood (in terms of number of neighboring data 388 | points) used for manifold approximation. See `https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.neighbors.html` for details 389 | :type n_neighbors: int, optional, default: 20 390 | :param resolution: A parameter value controlling the coarseness of the clustering. 391 | Higher values lead to more clusters. See `https://scanpy.readthedocs.io/en/stable/generated/scanpy.tl.leiden.html` for details 392 | :type resolution: float, optional, default: 1.0 393 | """ 394 | error_message = "No embedding found, please ensure you have run the train() method before calculating the pseudo-Spatiotemporal Map!" 395 | max_cell_for_subsampling = 5000 396 | try: 397 | print("Performing pseudo-Spatiotemporal Map") 398 | adata = anndata.AnnData(self.embedding) 399 | sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep='X') 400 | sc.tl.umap(adata) 401 | sc.tl.leiden(adata, resolution=resolution) 402 | sc.tl.paga(adata) 403 | if adata.shape[0] < max_cell_for_subsampling: 404 | sub_adata_x = adata.X 405 | else: 406 | indices = np.arange(adata.shape[0]) 407 | selected_ind = np.random.choice(indices, max_cell_for_subsampling, False) 408 | sub_adata_x = adata.X[selected_ind, :] 409 | sum_dists = distance_matrix(sub_adata_x, sub_adata_x).sum(axis=1) 410 | adata.uns['iroot'] = np.argmax(sum_dists) if adata.shape[0] < max_cell_for_subsampling else selected_ind[np.argmax(sum_dists)] # map a subsample index back to a row of adata 411 | sc.tl.diffmap(adata) 412 | sc.tl.dpt(adata) 413 | pSM_values = adata.obs['dpt_pseudotime'].to_numpy() 414 | save_dir = os.path.dirname(pSM_values_save_filepath) 415 | if save_dir and not os.path.exists(save_dir): 416 | os.makedirs(save_dir) 417 | np.savetxt(pSM_values_save_filepath, pSM_values, fmt='%.5f', header='', footer='', comments='') 418 | print(f"pseudo-Spatiotemporal Map(pSM) calculation complete, pSM values of cells or spots saved at {pSM_values_save_filepath}!") 419 | self.pSM_values = pSM_values 420 | except NameError: 421 | print(error_message) 422 | except AttributeError: 423 | print(error_message) 424 | 425 | def plot_pSM(self,
pSM_figure_save_filepath="./pseudo-Spatiotemporal-Map.pdf", colormap='roma', scatter_sz=1., rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right=0.9, bottom=0.1, top=0.9): 426 | """ 427 | Plot the pseudo-Spatiotemporal Map (pSM) of the ST data in spatial coordinates 428 | :param pSM_figure_save_filepath: the default save path for the figure 429 | :type pSM_figure_save_filepath: class:`str`, optional, default: "./pseudo-Spatiotemporal-Map.pdf" 430 | :param colormap: The colormap to use. See `https://www.fabiocrameri.ch/colourmaps-userguide/` for the list of colormap names 431 | :type colormap: str, optional, default: roma 432 | :param scatter_sz: The marker size in points**2 433 | :type scatter_sz: float, optional, default: 1.0 434 | :param rsz: row size of the figure in inches, default: 4.0 435 | :type rsz: float, optional 436 | :param csz: column size of the figure in inches, default: 4.0 437 | :type csz: float, optional 438 | :param wspace: the amount of width reserved for space between subplots, expressed as a fraction of the average axis width, default: 0.4 439 | :type wspace: float, optional 440 | :param hspace: the amount of height reserved for space between subplots, expressed as a fraction of the average axis height, default: 0.5 441 | :type hspace: float, optional 442 | :param left: the leftmost position of the subplots, as a fraction of the figure, default: 0.125 443 | :type left: float, optional 444 | :param right: the rightmost position of the subplots, as a fraction of the figure, default: 0.9 445 | :type right: float, optional 446 | :param bottom: the bottom position of the subplots, as a fraction of the figure, default: 0.1 447 | :type bottom: float, optional 448 | :param top: the top position of the subplots, as a fraction of the figure, default: 0.9 449 | :type top: float, optional 450 | """ 451 | error_message = "No pseudo-Spatiotemporal Map data found, please ensure you have run the pseudo_Spatiotemporal_Map() method."
452 | try: 453 | fig, ax = self.prepare_figure(rsz=rsz, csz=csz, wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top) 454 | x, y = self.adata_preprocessed.obsm["spatial"][:, 0], self.adata_preprocessed.obsm["spatial"][:, 1] 455 | st = ax.scatter(x, y, s=scatter_sz, c=self.pSM_values, cmap=f"cmc.{colormap}", marker=".") 456 | ax.invert_yaxis() 457 | clb = fig.colorbar(st) 458 | clb.ax.set_ylabel("pseudotime", labelpad=10, rotation=270, fontsize=10, weight='bold') 459 | ax.set_title("pseudo-Spatiotemporal Map", fontsize=14) 460 | ax.set_facecolor("none") 461 | 462 | save_dir = os.path.dirname(pSM_figure_save_filepath) 463 | if save_dir and not os.path.exists(save_dir): 464 | os.makedirs(save_dir) 465 | plt.savefig(pSM_figure_save_filepath, dpi=300) 466 | print(f"Plotting complete, pseudo-Spatiotemporal Map figure saved at {pSM_figure_save_filepath} !") 467 | plt.close('all') 468 | except NameError: 469 | print(error_message) 470 | except AttributeError: 471 | print(error_message) 472 | 473 | class GraphEncoder(nn.Module): 474 | def __init__(self, in_channels, hidden_channels): 475 | super(GraphEncoder, self).__init__() 476 | self.conv = GCNConv(in_channels, hidden_channels, cached=False) 477 | self.prelu = nn.PReLU(hidden_channels) 478 | self.conv2 = GCNConv(hidden_channels, hidden_channels, cached=False) 479 | self.prelu2 = nn.PReLU(hidden_channels) 480 | 481 | def forward(self, x, edge_index): 482 | x = self.conv(x, edge_index) 483 | x = self.prelu(x) 484 | x = self.conv2(x, edge_index) 485 | x = self.prelu2(x) 486 | return x 487 | -------------------------------------------------------------------------------- /SpaceFlow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongleir/SpaceFlow/58d1ab03f1538791f710db840d36a08f396491fb/SpaceFlow/__init__.py -------------------------------------------------------------------------------- /SpaceFlow/util.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | 5 | def sparse_mx_to_torch_edge_list(sparse_mx): 6 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 7 | edge_list = torch.from_numpy( 8 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 9 | return edge_list 10 | 11 | def corruption(x, edge_index): 12 | return x[torch.randperm(x.size(0))], edge_index -------------------------------------------------------------------------------- /images/annotation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongleir/SpaceFlow/58d1ab03f1538791f710db840d36a08f396491fb/images/annotation.png -------------------------------------------------------------------------------- /images/domain_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongleir/SpaceFlow/58d1ab03f1538791f710db840d36a08f396491fb/images/domain_segmentation.png -------------------------------------------------------------------------------- /images/pSM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongleir/SpaceFlow/58d1ab03f1538791f710db840d36a08f396491fb/images/pSM.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | anndata 5 | gudhi 6 | matplotlib 7 | networkx 8 | notebook 9 | scanpy 10 | squidpy 11 | cmcrameri 12 | scikit-learn 13 | 
torch>=1.9.0 14 | torch-geometric>=1.7.2 15 | torch-sparse>=0.6.11 16 | torch-scatter>=2.0.8 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = SpaceFlow 3 | version = 1.0.3 4 | description = Identifying Spatiotemporal Patterns of Cells for Spatial Transcriptome Data 5 | author = Honglei Ren 6 | author_email = hongleir1@gmail.com 7 | 8 | [options] 9 | packages = find: 10 | install_requires = 11 | numpy 12 | scipy 13 | pandas 14 | anndata 15 | gudhi 16 | matplotlib 17 | networkx 18 | notebook 19 | scanpy 20 | squidpy 21 | scikit-learn 22 | torch>=1.9.0 23 | torch-geometric>=1.7.2 24 | torch-sparse>=0.6.11 25 | torch-scatter>=2.0.8 26 | cmcrameri 27 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import scanpy as sc 2 | import squidpy as sq 3 | from SpaceFlow import SpaceFlow 4 | 5 | adata = sq.datasets.seqfish() 6 | sf = SpaceFlow.SpaceFlow(expr_data=adata.X, spatial_locs=adata.obsm['spatial']) 7 | sf.preprocessing_data() 8 | sf.train() 9 | sf.segmentation() 10 | sf.plot_segmentation() 11 | sf.pseudo_Spatiotemporal_Map() 12 | sf.plot_pSM() -------------------------------------------------------------------------------- /tutorials/seqfish_mouse_embryogenesis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6e2c6ffb", 6 | "metadata": {}, 7 | "source": [ 8 | "### 1. 
Import the SpaceFlow and squidpy packages" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bfb92fc1", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import warnings\n", 19 | "warnings.filterwarnings('ignore')\n", 20 | "import squidpy as sq\n", 21 | "from SpaceFlow import SpaceFlow" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "51f08772", 27 | "metadata": {}, 28 | "source": [ 29 | "### 2. Load the ST data from the squidpy package" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "c9751643", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "adata = sq.datasets.seqfish()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "0c555016", 45 | "metadata": {}, 46 | "source": [ 47 | "### 3. Create SpaceFlow Object\n", 48 | "\n", 49 | "We create a SpaceFlow object using the count matrix of gene expression and the corresponding spatial locations of cells (or spots):" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "id": "03d86b43", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "sf = SpaceFlow.SpaceFlow(expr_data=adata.X, \n", 60 | " spatial_locs=adata.obsm['spatial'])" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "44b702dc", 66 | "metadata": {}, 67 | "source": [ 68 | "Parameters:\n", 69 | "- `expr_data`: the count matrix of gene expression, 2D numpy array of size (# of cells, # of genes)\n", 70 | "- `spatial_locs`: spatial locations of cells (or spots) matching the rows of the count matrix, 2D numpy array of size (# of cells, 2)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "f94e1a04", 76 | "metadata": {}, 77 | "source": [ 78 | "### 4.
Preprocessing the ST Data\n", 79 | "Next, we preprocess the ST data by running:" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "id": "c39b6093", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "sf.preprocessing_data(n_top_genes=3000)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "id": "aefc6a0c", 95 | "metadata": {}, 96 | "source": [ 97 | "Parameters:\n", 98 | "- `n_top_genes`: the number of top highly variable genes to select.\n", 99 | "\n", 100 | "The preprocessing includes the normalization and log-transformation of the expression count matrix, the selection of highly variable genes, and the construction of a spatial proximity graph from the spatial coordinates. (For details, see the `preprocessing_data` function in `SpaceFlow/SpaceFlow.py`.)\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "ece4a0d9", 106 | "metadata": {}, 107 | "source": [ 108 | "### 5. Train the deep graph network model\n", 109 | "\n", 110 | "We then train a spatially regularized deep graph network model to learn a low-dimensional embedding that reflects both the expression similarity and the spatial proximity of cells in the ST data. 
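The spatial regularization is an extra penalty that `train()` adds to the Deep Graph Infomax loss, weighted by `spatial_regularization_strength`. As a minimal sketch for illustration only, the code below re-expresses the full (non-subsampled) branch of that penalty in NumPy instead of PyTorch; the names `pairwise_dists` and `spatial_penalty` are hypothetical and not part of the SpaceFlow API. Embedding distances and spatial distances are each scaled to [0, 1] by their maximum, and the penalty is the mean over all ordered cell pairs of (1 - embedding distance) times spatial distance, so it grows when spatially distant cells receive similar embeddings.

```python
import numpy as np


def pairwise_dists(x: np.ndarray) -> np.ndarray:
    """Euclidean distance matrix between the rows of x (n x d -> n x n)."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def spatial_penalty(z: np.ndarray, coords: np.ndarray) -> float:
    """Hypothetical NumPy re-expression of the spatial regularization term.

    z: embedding, shape (n_cells, z_dim); coords: spatial locations, shape (n_cells, 2).
    """
    z_dists = pairwise_dists(z)
    z_dists = z_dists / z_dists.max()      # scale embedding distances to [0, 1]
    sp_dists = pairwise_dists(coords)
    sp_dists = sp_dists / sp_dists.max()   # scale spatial distances to [0, 1]
    n_items = z.shape[0] * z.shape[0]      # average over all ordered cell pairs
    # Large when spatially distant cells (sp_dists near 1) have close embeddings (z_dists near 0)
    return float(np.sum((1.0 - z_dists) * sp_dists) / n_items)


coords = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0]])
z_good = coords.copy()                                  # embedding mirrors the spatial layout
z_bad = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 0.1]])  # distant cells 0 and 2 embedded together
print(spatial_penalty(z_good, coords), spatial_penalty(z_bad, coords))
```

Because both distance sets are normalized by their maximum, the penalty stays in [0, 1] regardless of the coordinate units, which keeps `spatial_regularization_strength` comparable across datasets.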
\n" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "id": "57a701a1", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "Epoch 2/1000, Loss: 2.105508804321289\n", 124 | "Epoch 12/1000, Loss: 1.3464184999465942\n", 125 | "Epoch 22/1000, Loss: 1.09248685836792\n", 126 | "Epoch 32/1000, Loss: 0.8551639914512634\n", 127 | "Epoch 42/1000, Loss: 0.6653984189033508\n", 128 | "Epoch 52/1000, Loss: 0.4879186451435089\n", 129 | "Epoch 62/1000, Loss: 0.34967240691185\n", 130 | "Epoch 72/1000, Loss: 0.274603009223938\n", 131 | "Epoch 82/1000, Loss: 0.2149953991174698\n", 132 | "Epoch 92/1000, Loss: 0.18800733983516693\n", 133 | "Epoch 102/1000, Loss: 0.15587514638900757\n", 134 | "Epoch 112/1000, Loss: 0.15160660445690155\n", 135 | "Epoch 122/1000, Loss: 0.13229766488075256\n", 136 | "Epoch 132/1000, Loss: 0.12213142216205597\n", 137 | "Epoch 142/1000, Loss: 0.11389018595218658\n", 138 | "Epoch 152/1000, Loss: 0.11026303470134735\n", 139 | "Epoch 162/1000, Loss: 0.10858513414859772\n", 140 | "Epoch 172/1000, Loss: 0.0981854647397995\n", 141 | "Epoch 182/1000, Loss: 0.09261471033096313\n", 142 | "Epoch 192/1000, Loss: 0.09784780442714691\n", 143 | "Epoch 202/1000, Loss: 0.09372460097074509\n", 144 | "Epoch 212/1000, Loss: 0.09502124786376953\n", 145 | "Epoch 222/1000, Loss: 0.0898984745144844\n", 146 | "Epoch 232/1000, Loss: 0.08758103847503662\n", 147 | "Epoch 242/1000, Loss: 0.07875576615333557\n", 148 | "Epoch 252/1000, Loss: 0.07564771175384521\n", 149 | "Epoch 262/1000, Loss: 0.07840961217880249\n", 150 | "Epoch 272/1000, Loss: 0.08345237374305725\n", 151 | "Epoch 282/1000, Loss: 0.07537221163511276\n", 152 | "Epoch 292/1000, Loss: 0.07451540231704712\n", 153 | "Epoch 302/1000, Loss: 0.0789402425289154\n", 154 | "Epoch 312/1000, Loss: 0.07124275714159012\n", 155 | "Epoch 322/1000, Loss: 0.07497982680797577\n", 156 | "Epoch 332/1000, Loss: 
0.07454021275043488\n", 157 | "Epoch 342/1000, Loss: 0.06699051707983017\n", 158 | "Epoch 352/1000, Loss: 0.08066301047801971\n", 159 | "Epoch 362/1000, Loss: 0.06973342597484589\n", 160 | "Epoch 372/1000, Loss: 0.07150772213935852\n", 161 | "Epoch 382/1000, Loss: 0.0745912715792656\n", 162 | "Epoch 392/1000, Loss: 0.07225732505321503\n", 163 | "Epoch 402/1000, Loss: 0.06479880213737488\n", 164 | "Epoch 412/1000, Loss: 0.07020575553178787\n", 165 | "Epoch 422/1000, Loss: 0.07892539352178574\n", 166 | "Epoch 432/1000, Loss: 0.06901434808969498\n", 167 | "Training complete!\n", 168 | "Embedding is saved at ./embedding.tsv\n" 169 | ] 170 | }, 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "array([[-1.3693943 , -0.1211542 , 3.066468 , ..., -0.58150786,\n", 175 | " -0.12908368, 3.6877515 ],\n", 176 | " [-1.1312834 , -0.282027 , 3.43263 , ..., -0.6431831 ,\n", 177 | " -0.0902295 , 4.229353 ],\n", 178 | " [-1.0586573 , 2.594063 , 0.5477483 , ..., -0.2179767 ,\n", 179 | " -0.20996477, 1.9479373 ],\n", 180 | " ...,\n", 181 | " [-0.58471024, 2.330395 , -0.04218347, ..., -0.25748822,\n", 182 | " 0.0110341 , 1.8769083 ],\n", 183 | " [-0.8118924 , 0.48114178, 3.4098723 , ..., -0.2737087 ,\n", 184 | " 0.23058248, 2.079208 ],\n", 185 | " [-0.39583313, 1.7552938 , 0.08656111, ..., -0.27222347,\n", 186 | " 1.6503936 , 1.3976719 ]], dtype=float32)" 187 | ] 188 | }, 189 | "execution_count": 5, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "sf.train(spatial_regularization_strength=0.1, \n", 196 | " z_dim=50, \n", 197 | " lr=1e-3, \n", 198 | " epochs=1000, \n", 199 | " max_patience=50, \n", 200 | " min_stop=100, \n", 201 | " random_seed=42, \n", 202 | " gpu=0, \n", 203 | " regularization_acceleration=True, \n", 204 | " edge_subset_sz=1000000)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "e100a063", 210 | "metadata": {}, 211 | "source": [ 212 | "Parameters:\n", 213 | "- `spatial_regularization_strength`: 
the strength of spatial regularization; the larger the value, the more spatially coherent the identified spatial domains and spatiotemporal patterns. (default: 0.1)\n", 214 | "- `z_dim`: the target size of the learned embedding. (default: 50)\n", 215 | "- `lr`: learning rate for optimizing the model. (default: 1e-3)\n", 216 | "- `epochs`: the maximum number of training epochs. (default: 1000)\n", 217 | "- `max_patience`: the maximum number of consecutive epochs to wait for the loss to decrease. If the loss does not decrease for more than this many epochs, training stops, and the model parameters that gave the minimal loss are kept as the best model. (default: 50) \n", 218 | "- `min_stop`: the minimum number of epochs before early stopping can take effect; training stops early only after more than `min_stop` epochs and once the loss has not decreased for more than `max_patience` epochs. (default: 100) \n", 219 | "- `random_seed`: the random seed set for the random number generators of the `random`, `numpy`, and `torch` packages. (default: 42)\n", 220 | "- `gpu`: the index of the NVIDIA GPU; if no GPU is available, the model is trained on the CPU, which is slower. (default: 0) \n", 221 | "- `regularization_acceleration`: whether to accelerate the calculation of the regularization loss using an edge-subsetting strategy; subsetting is also applied automatically when the data contains more than 5,000 cells. (default: True)\n", 222 | "- `edge_subset_sz`: the number of randomly sampled cell pairs used for regularization acceleration. (default: 1000000)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "99fefbb0", 228 | "metadata": {}, 229 | "source": [ 230 | "### 6. Domain segmentation of the ST data\n", 231 | "\n", 232 | "After model training, the learned low-dimensional embedding can be accessed through `sf.embedding`.\n", 233 | "\n", 234 | "SpaceFlow uses this learned embedding to identify the spatial domains based on the [Leiden](https://www.nature.com/articles/s41598-019-41695-z) algorithm. 
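The Leiden labels produced by `segmentation()` are stored in `sf.domains` and written to `domain_label_save_filepath` as one integer per line (the method calls `np.savetxt(..., fmt='%d')`). A minimal sketch for loading such a file and counting cells per domain, assuming that one-label-per-line layout; `load_domain_sizes` is a hypothetical helper, not part of SpaceFlow:

```python
import numpy as np


def load_domain_sizes(path: str) -> dict:
    """Read one integer domain label per line and return {label: number of cells}."""
    # atleast_1d keeps a single-line file from collapsing to a 0-d array
    labels = np.atleast_1d(np.loadtxt(path, dtype=int))
    uniq, counts = np.unique(labels, return_counts=True)
    return {int(label): int(count) for label, count in zip(uniq, counts)}


# Usage, assuming segmentation() was run with the default save path:
# sizes = load_domain_sizes("./domains.tsv")
```

Counting cells per domain is a quick sanity check on `resolution`: many tiny domains usually suggest the value is too high for the dataset.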
\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 6, 240 | "id": "b0cb4cd6", 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "Performing domain segmentation\n", 248 | "Segmentation complete, domain labels of cells or spots saved at ./domains.tsv !\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "sf.segmentation(domain_label_save_filepath=\"./domains.tsv\", \n", 254 | " n_neighbors=50, \n", 255 | " resolution=1.0)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "id": "361a2f7e", 261 | "metadata": {}, 262 | "source": [ 263 | "Parameters:\n", 264 | "\n", 265 | "- `domain_label_save_filepath`: the file path for saving the identified domain labels. (default: \"./domains.tsv\")\n", 266 | "- `n_neighbors`: the number of nearest neighbors per cell used to build the neighborhood graph from the embedding for Leiden clustering. (default: 50)\n", 267 | "- `resolution`: the resolution of the Leiden clustering; larger values yield more, finer-grained domains. (default: 1.0)\n" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "id": "f584534d", 273 | "metadata": {}, 274 | "source": [ 275 | "### 7. Visualization of the identified spatial domains\n", 276 | "\n", 277 | "We next plot the spatial domains using the identified domain labels and the spatial coordinates of the cells."
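In `plot_segmentation()`, each domain is drawn as its own scatter series, and the i-th of n sorted domain labels takes the colormap position i / (n - 1), which is what the expression `(cid * (n_cluster / (n_cluster - 1.0))) / n_cluster` in the source reduces to. The NumPy sketch below reproduces that grouping without Matplotlib so it can be checked in isolation; `group_cells_by_domain` is a hypothetical name, and the fallback position 0.0 for a single domain is an assumption added here to avoid the division by zero the original expression hits when there is only one domain. Each returned position could then be passed to a colormap object such as `plt.get_cmap("tab20")`.

```python
import numpy as np


def group_cells_by_domain(domains, coords):
    """Map each domain label to (colormap position in [0, 1], x coordinates, y coordinates)."""
    domains = np.asarray(domains).astype(int)
    coords = np.asarray(coords, dtype=float)
    uniq = np.unique(domains)  # sorted domain labels
    n = len(uniq)
    groups = {}
    for cid, label in enumerate(uniq):
        # Spread domains evenly over the colormap; a lone domain gets position 0.0
        position = cid / (n - 1.0) if n > 1 else 0.0
        mask = domains == label
        groups[int(label)] = (position, coords[mask, 0], coords[mask, 1])
    return groups


domains = [2, 0, 2, 1]
coords = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.5], [3.0, 3.0]]
for label, (pos, xs, ys) in group_cells_by_domain(domains, coords).items():
    print(label, pos, xs.tolist(), ys.tolist())
```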
278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 7, 283 | "id": "7e4d546d", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Plotting complete, segmentation figure saved at ./domain_segmentation.pdf !\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "sf.plot_segmentation(segmentation_figure_save_filepath=\"./domain_segmentation.pdf\", \n", 296 | " colormap=\"tab20\", \n", 297 | " scatter_sz=1., \n", 298 | " rsz=4., \n", 299 | " csz=4., \n", 300 | " wspace=.4, \n", 301 | " hspace=.5, \n", 302 | " left=0.125, \n", 303 | " right=0.9, \n", 304 | " bottom=0.1, \n", 305 | " top=0.9)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "id": "a13d92ec", 311 | "metadata": {}, 312 | "source": [ 313 | "Parameters:\n", 314 | "- `segmentation_figure_save_filepath`: optional, type: str, the file path for saving the figure of the spatial domain visualization. (default: \"./domain_segmentation.pdf\")\n", 315 | "- `colormap`: optional, type: str, the colormap used for the different domains (default: \"tab20\"); for the full list of colormap options, see [matplotlib](https://matplotlib.org/3.5.1/tutorials/colors/colormaps.html)\n", 316 | "- `scatter_sz`: optional, type: float, the marker size, in points squared.
(default: 1.0)\n", 317 | "- `rsz`: optional, type: float, row size of the figure in inches (default: 4.0)\n", 318 | "- `csz`: optional, type: float, column size of the figure in inches (default: 4.0)\n", 319 | "- `wspace`: optional, type: float, the amount of width reserved for space between subplots, expressed as a fraction of the average axis width (default: 0.4)\n", 320 | "- `hspace`: optional, type: float, the amount of height reserved for space between subplots, expressed as a fraction of the average axis height (default: 0.5)\n", 321 | "- `left`: optional, type: float, the leftmost position of the subplots, as a fraction of the figure (default: 0.125)\n", 322 | "- `right`: optional, type: float, the rightmost position of the subplots, as a fraction of the figure (default: 0.9)\n", 323 | "- `bottom`: optional, type: float, the bottom position of the subplots, as a fraction of the figure (default: 0.1)\n", 324 | "- `top`: optional, type: float, the top position of the subplots, as a fraction of the figure (default: 0.9)\n" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "id": "1998dd75", 330 | "metadata": {}, 331 | "source": [ 332 | "For comparison, we can also visualize the expert annotation:" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 8, 338 | "id": "c20aa423", 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "image/png": "\n", 344 | "text/plain": [ 345 | "
" 346 | ] 347 | }, 348 | "metadata": {}, 349 | "output_type": "display_data" 350 | } 351 | ], 352 | "source": [ 353 | "import scanpy as sc\n", 354 | "sc.pl.spatial(adata, \n", 355 | " color=\"celltype_mapped_refined\",\n", 356 | " spot_size=0.03)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "id": "a85293e6", 362 | "metadata": {}, 363 | "source": [ 364 | "### 8. Identify the spatiotemporal patterns of the ST data through a pseudo-Spatiotemporal Map (pSM)\n", 365 | "\n", 366 | "Next, we apply the diffusion pseudotime (dpt) algorithm to the learned spatially-consistent embedding to generate a pseudo-Spatiotemporal Map (pSM). This pSM represents a spatially-coherent pseudotime ordering of cells that encodes biological relationships between cells, such as developmental trajectories and cancer progression.\n" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 10, 372 | "id": "5bc6ba72", 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "Performing pseudo-Spatiotemporal Map\n", 380 | "pseudo-Spatiotemporal Map(pSM) calculation complete, pSM values of cells or spots saved at ./pSM_values.tsv!\n" 381 | ] 382 | } 383 | ], 384 | "source": [ 385 | "sf.pseudo_Spatiotemporal_Map(pSM_values_save_filepath=\"./pSM_values.tsv\", \n", 386 | " n_neighbors=20, \n", 387 | " resolution=1.0)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "id": "bff9449d", 393 | "metadata": {}, 394 | "source": [ 395 | "Parameters:\n", 396 | "- `pSM_values_save_filepath`: the file path for saving the inferred pSM values. (default: \"./pSM_values.tsv\")\n", 397 | "- `n_neighbors`: the number of nearest neighbors per cell used to build the neighborhood graph from the embedding, on which the Leiden clustering and the diffusion pseudotime are computed. (default: 20) \n", 398 | "- `resolution`: the resolution of the Leiden clustering; larger values yield more, finer-grained clusters.
(default: 1.0)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "28b22f62", 404 | "metadata": {}, 405 | "source": [ 406 | "### 9. Visualization of the identified pseudo-Spatiotemporal Map (pSM)\n", 407 | "\n", 408 | "We next visualize the identified pseudo-Spatiotemporal Map (pSM).\n" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "id": "63335035", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "sf.plot_pSM(pSM_figure_save_filepath=\"./pseudo-Spatiotemporal-Map.pdf\", \n", 419 | " colormap=\"roma\", \n", 420 | " scatter_sz=1., \n", 421 | " rsz=4., \n", 422 | " csz=4., \n", 423 | " wspace=.4, \n", 424 | " hspace=.5, \n", 425 | " left=0.125, \n", 426 | " right=0.9, \n", 427 | " bottom=0.1, \n", 428 | " top=0.9)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "id": "dc3d5432", 434 | "metadata": {}, 435 | "source": [ 436 | "Parameters:\n", 437 | "- `pSM_figure_save_filepath`: optional, type: str, the file path for saving the figure of the pSM visualization. (default: \"./pseudo-Spatiotemporal-Map.pdf\")\n", 438 | "- `colormap`: optional, type: str, the colormap of the pSM (default: 'roma'); for the full list of colormap options, see [Scientific Colormaps](https://www.fabiocrameri.ch/colourmaps-userguide/)\n", 439 | "- `scatter_sz`: optional, type: float, the marker size, in points squared.
(default: 1.0)\n", 440 | "- `rsz`: optional, type: float, row size of the figure in inches (default: 4.0)\n", 441 | "- `csz`: optional, type: float, column size of the figure in inches (default: 4.0)\n", 442 | "- `wspace`: optional, type: float, the amount of width reserved for space between subplots, expressed as a fraction of the average axis width (default: 0.4)\n", 443 | "- `hspace`: optional, type: float, the amount of height reserved for space between subplots, expressed as a fraction of the average axis height (default: 0.5)\n", 444 | "- `left`: optional, type: float, the leftmost position of the subplots, as a fraction of the figure (default: 0.125)\n", 445 | "- `right`: optional, type: float, the rightmost position of the subplots, as a fraction of the figure (default: 0.9)\n", 446 | "- `bottom`: optional, type: float, the bottom position of the subplots, as a fraction of the figure (default: 0.1)\n", 447 | "- `top`: optional, type: float, the top position of the subplots, as a fraction of the figure (default: 0.9)\n" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "id": "fd465665", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [] 457 | } 458 | ], 459 | "metadata": { 460 | "kernelspec": { 461 | "display_name": "spaceflow_env", 462 | "language": "python", 463 | "name": "spaceflow_env" 464 | }, 465 | "language_info": { 466 | "codemirror_mode": { 467 | "name": "ipython", 468 | "version": 3 469 | }, 470 | "file_extension": ".py", 471 | "mimetype": "text/x-python", 472 | "name": "python", 473 | "nbconvert_exporter": "python", 474 | "pygments_lexer": "ipython3", 475 | "version": "3.7.13" 476 | } 477 | }, 478 | "nbformat": 4, 479 | "nbformat_minor": 5 480 | } 481 | --------------------------------------------------------------------------------
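In `pseudo_Spatiotemporal_Map()`, the diffusion-pseudotime root is the cell whose embedding has the largest summed Euclidean distance to the other cells. When the data has at least 5,000 cells, this distance is computed on a random subsample of 5,000 cells. The NumPy sketch below illustrates that selection under a few stated assumptions: `select_root` is a hypothetical helper; SciPy's `distance_matrix` is replaced by the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a·b to avoid a large 3-D intermediate array; and the index within the subsample is mapped back to a row of the full embedding, so the result can serve as `adata.uns['iroot']`.

```python
import numpy as np


def select_root(embedding: np.ndarray, max_cells: int = 5000, seed: int = 42) -> int:
    """Return the row index of the cell with the largest summed distance to the other cells."""
    n = embedding.shape[0]
    if n < max_cells:
        selected = np.arange(n)
    else:
        rng = np.random.default_rng(seed)
        selected = rng.choice(n, max_cells, replace=False)  # subsample without replacement
    sub = np.asarray(embedding[selected], dtype=float)
    # Squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b, clipped at 0 for round-off
    sq_norms = (sub ** 2).sum(axis=1)
    d2 = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * sub @ sub.T, 0.0)
    sum_dists = np.sqrt(d2).sum(axis=1)
    # Translate the position within the subsample back to a row of the full embedding
    return int(selected[np.argmax(sum_dists)])


emb = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0]])
print(select_root(emb))
```

Returning `selected[argmax]` rather than the bare `argmax` matters only in the subsampled case. There, the bare index refers to a position in the subsample, not to a row of the full data.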