├── .gitignore ├── .readthedocs.yml ├── README.md ├── SPACEL ├── Scube │ ├── __init__.py │ ├── alignment.py │ ├── gpr.py │ ├── plot.py │ └── utils_3d.py ├── Splane │ ├── __init__.py │ ├── base_model.py │ ├── graph.py │ ├── kegra │ │ ├── __init__.py │ │ ├── gnn.py │ │ └── gnn_utils.py │ ├── model.py │ ├── pygcn │ │ ├── .gitignore │ │ ├── LICENCE │ │ ├── README.md │ │ ├── data │ │ │ └── cora │ │ │ │ ├── README │ │ │ │ ├── cora.cites │ │ │ │ └── cora.content │ │ ├── figure.png │ │ ├── pygcn │ │ │ ├── __init__.py │ │ │ ├── layers.py │ │ │ ├── models.py │ │ │ ├── train.py │ │ │ └── utils.py │ │ └── setup.py │ ├── pygcn_utils.py │ └── utils.py ├── Spoint │ ├── __init__.py │ ├── base_model.py │ ├── data_augmentation.py │ ├── data_downsample.py │ ├── data_utils.py │ ├── metrics.py │ ├── model.py │ └── spatial_simulation.py ├── __init__.py ├── _version.py └── setting.py ├── docs ├── Makefile ├── _static │ └── img │ │ └── figure1.png ├── api.md ├── conf.py ├── index.md ├── installation.md ├── make.bat ├── requirements.txt ├── tutorials.md └── tutorials │ ├── MERFISH_mouse_brain_Scube.ipynb │ ├── MERFISH_mouse_brain_Splane.ipynb │ ├── STARmap_mouse_brain_GPR.ipynb │ ├── ST_mouse_brain_Scube.ipynb │ ├── ST_mouse_brain_Splane.ipynb │ ├── ST_mouse_brain_Spoint.ipynb │ ├── Stereo-seq_Scube.ipynb │ ├── Visium_human_DLPFC_Spoint.ipynb │ └── Visium_human_breast_cancer_Splane.ipynb ├── environment.yml ├── readthedocs_environment.yml ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | */.DS_Store/* 9 | .DS_Store 10 | .ipynb_checkpoints 11 | */.ipynb_checkpoints/* 12 | 13 | # IPython 14 | profile_default/ 15 | ipython_config.py 16 | 17 | # Remove previous ipynb_checkpoints 18 | # git rm -r .ipynb_checkpoints/ 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | generated/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | cover/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | .pybuilder/ 97 | target/ 98 | 99 | # Jupyter Notebook 100 | 101 | # IPython 102 | 103 | # pyenv 104 | # For a library or package, you might want to ignore these files since the code is 105 | # intended to run in multiple environments; otherwise, check them in: 106 | # .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks 159 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | conda: 4 | environment: readthedocs_environment.yml -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Documentation Status](https://readthedocs.org/projects/spacel/badge/?version=latest)](https://spacel.readthedocs.io/en/latest/?badge=latest) 2 | ![PyPI](https://img.shields.io/pypi/v/SPACEL) 3 | 4 | # SPACEL: characterizing spatial transcriptome architectures by deep-learning 5 | 6 | ![](docs/_static/img/figure1.png "Overview") 7 | SPACEL (**SP**atial **A**rchitecture **C**haracterization by d**E**ep **L**earning) is a Python package of deep-learning-based methods for ST data analysis. SPACEL consists of three modules: 8 | * Spoint embeds a multilayer perceptron with a probabilistic model to deconvolute the cell type composition of each spot in a single ST slice. 9 | * Splane employs a graph convolutional network approach and an adversarial learning algorithm to identify uniform spatial domains that are transcriptomically and spatially coherent across multiple ST slices.
10 | * Scube automatically transforms the spatial coordinate systems of consecutive slices and stacks them together to construct a three-dimensional (3D) alignment of the tissue. 11 | 12 | ## Getting started 13 | * [Requirements](#Requirements) 14 | * [Installation](#Installation) 15 | * Tutorials 16 | * [Spoint tutorial: Deconvolution of cell type composition on human brain Visium dataset](docs/tutorials/Visium_human_DLPFC_Spoint.ipynb) 17 | * [Splane tutorial: Identification of uniform spatial domains on human breast cancer Visium dataset](docs/tutorials/Visium_human_breast_cancer_Splane.ipynb) 18 | * [Splane&Scube tutorial (1/2): Identification of uniform spatial domains on mouse brain MERFISH dataset](docs/tutorials/MERFISH_mouse_brain_Splane.ipynb) 19 | * [Splane&Scube tutorial (2/2): Alignment of consecutive ST slices on mouse brain MERFISH dataset](docs/tutorials/MERFISH_mouse_brain_Scube.ipynb) 20 | * [Scube tutorial: Alignment of consecutive ST slices on mouse embryo Stereo-seq dataset](docs/tutorials/Stereo-seq_Scube.ipynb) 21 | * [Scube tutorial: 3D expression modeling with Gaussian process regression](docs/tutorials/STARmap_mouse_brain_GPR.ipynb) 22 | * [SPACEL workflow (1/3): Deconvolution by Spoint on mouse brain ST dataset](docs/tutorials/ST_mouse_brain_Spoint.ipynb) 23 | * [SPACEL workflow (2/3): Identification of spatial domains by Splane on mouse brain ST dataset](docs/tutorials/ST_mouse_brain_Splane.ipynb) 24 | * [SPACEL workflow (3/3): Alignment of 3D tissue by Scube on mouse brain ST dataset](docs/tutorials/ST_mouse_brain_Scube.ipynb) 25 | 26 | Read the [documentation](https://spacel.readthedocs.io) for more information. 27 | 28 | ## Latest updates 29 | ### Version 1.1.8 2024-07-23 30 | #### Fixed Bugs 31 | - Fixed the conflict between the optax version and Python 3.8. 32 | 33 | ### Version 1.1.7 2024-01-16 34 | #### Fixed Bugs 35 | - Fixed a variable reference error in the function `identify_spatial_domain`. Thanks to @tobias-zehnde for the contribution. 36 | 37 | ### Version 1.1.6 2023-07-27 38 | #### Fixed Bugs 39 | - Fixed a bug regarding the similarity loss weight hyperparameter `simi_l`, which in the previous version did not affect the loss value. 40 | 41 | ## Requirements 42 | **Note**: The current version of SPACEL only supports Linux and macOS, not the Windows platform. 43 | 44 | To install `SPACEL`, you need to install [PyTorch](https://pytorch.org) with GPU support first. If you don't need GPU acceleration, you can simply skip the installation of `cudnn` and `cudatoolkit`. 45 | * Create a conda environment for `SPACEL`: 46 | ``` 47 | conda env create -f environment.yml 48 | ``` 49 | or 50 | ``` 51 | conda create -n SPACEL -c conda-forge -c default cudatoolkit=10.2 python=3.8 rpy2 r-base r-fitdistrplus 52 | ``` 53 | You must choose the correct `PyTorch`, `cudnn`, and `cudatoolkit` versions depending on your graphics driver version. 54 | 55 | Note: If you want to run the 3D expression GPR model in Scube, you need to install the [Open3D](http://www.open3d.org/docs/release/) Python library first. 56 | 57 | ## Installation 58 | * Install `SPACEL`: 59 | ``` 60 | pip install SPACEL 61 | ``` 62 | * Test whether [PyTorch](https://pytorch.org) is available for the GPU: 63 | ``` 64 | python 65 | >>> import torch 66 | >>> torch.cuda.is_available() 67 | ``` 68 | If this command does not return `True`, please check your GPU driver version and `cudatoolkit` version. For more detail, see [CUDA Toolkit Major Component Versions](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions).
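As a quick sanity check after installation, the sketch below runs the Scube alignment step on a list of slices and plots the stacked result. This is a minimal sketch based on the `SPACEL.Scube.align` and `SPACEL.Scube.plot_stacked_slices` functions defined in this repository; the `.h5ad` file names and the `spatial_domain` column are placeholders (in the SPACEL workflow that column is typically produced by Splane).

```python
# Minimal sketch (not part of the original README): pairwise-align a list of
# ST slices with Scube and plot them stacked. Assumes each slice is an AnnData
# object with raw coordinates in .obsm['spatial'] and a categorical spatial
# domain label in .obs['spatial_domain']; the file names are placeholders.
import scanpy as sc
from SPACEL import Scube

ad_list = [sc.read_h5ad(f) for f in ['slice0.h5ad', 'slice1.h5ad', 'slice2.h5ad']]
Scube.align(
    ad_list,
    cluster_key='spatial_domain',  # .obs column used to match neighboring spots
    n_neighbors=15,                # neighbors considered in the target slice
    n_threads=4,                   # parallel workers for the optimizer
)
# Aligned coordinates are written to .obsm['spatial_aligned'] of each slice.
Scube.plot_stacked_slices(ad_list, spatial_key='spatial_aligned', cluster_key='spatial_domain')
```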
69 | -------------------------------------------------------------------------------- /SPACEL/Scube/__init__.py: -------------------------------------------------------------------------------- 1 | from .alignment import * 2 | from .plot import * 3 | from .gpr import * 4 | import os 5 | os.environ["OPENBLAS_NUM_THREADS"] = "1" 6 | os.environ["MKL_NUM_THREADS"] = "1" 7 | -------------------------------------------------------------------------------- /SPACEL/Scube/alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from multiprocessing import Pool 5 | from scipy.optimize import minimize, differential_evolution 6 | from sklearn.metrics import pairwise_distances 7 | from sklearn.neighbors import BallTree 8 | from copy import deepcopy 9 | import time 10 | from functools import partial 11 | 12 | def rotate_loc(x,y,x0,y0,angle): 13 | angle = angle/(180/np.pi) 14 | x1=np.cos(angle)*(x-x0)-np.sin(angle)*(y-y0) + x0 15 | y1=np.cos(angle)*(y-y0)+np.sin(angle)*(x-x0) + y0 16 | return x1,y1 17 | 18 | # Deprecated 19 | def neighbor_simi(source_loc,target_loc, source_cluster, target_cluster, target_p_dis,k,knn_exclude_cutoff): 20 | dis_knn_exclude_cutoff = np.median(np.partition(target_p_dis,kth=k+knn_exclude_cutoff,axis=1)[:,k+knn_exclude_cutoff]) 21 | p_dis = pairwise_distances(source_loc,target_loc) 22 | top5_neighbors = np.argpartition(p_dis,kth=k,axis=1)[:,:k] 23 | n_simi = [] 24 | for i, sc in enumerate(source_cluster): 25 | tmp_neighbors = list(top5_neighbors[i]) 26 | tmp_neighbors_reversed = tmp_neighbors[::-1] 27 | for j in tmp_neighbors_reversed: 28 | if p_dis[i,j] > dis_knn_exclude_cutoff: 29 | tmp_neighbors.remove(j) 30 | else: 31 | break 32 | if len(tmp_neighbors) != 0: 33 | n_simi.append((target_cluster[tmp_neighbors] == sc).sum()/len(tmp_neighbors)) 34 | return np.mean(n_simi) - (len(n_simi)/i-1)**2 35 | 36 | def neighbor_simi_fast(source_loc,target_loc, source_cluster, target_cluster,k,knn_exclude_cutoff,p=2,a=1): 37 | target_tree = BallTree(target_loc, leaf_size=15, metric='minkowski') 38 | t_dis, _ = target_tree.query(target_loc, k=k+knn_exclude_cutoff) 39 | dis_knn_exclude_cutoff = np.median(t_dis[:,k+knn_exclude_cutoff-1]) 40 | s_dis, nearest_neighbors = target_tree.query(source_loc, k=k) 41 | nearest_neighbors[s_dis > dis_knn_exclude_cutoff] = -1 42 | kept_ind = (nearest_neighbors != -1).sum(1)>0 43 | if kept_ind.sum() > 0: 44 | nearest_neighbors = nearest_neighbors[kept_ind] 45 | neighbors_cluster = deepcopy(nearest_neighbors) 46 | neighbors_cluster[nearest_neighbors!=-1] = target_cluster[nearest_neighbors[nearest_neighbors!=-1]] 47 | simi = (((neighbors_cluster - source_cluster[kept_ind].reshape(-1,1)) == 0).sum(1)/(neighbors_cluster!=-1).sum(1)).mean() 48 | else: 49 | simi = 0 50 | overlap = kept_ind.sum()/len(source_cluster) 51 | neighbors_simi = simi - a*((1-overlap)**p) 52 | return neighbors_simi 53 | 54 | def neighbor_simi_new(source_loc,target_loc, source_cluster, target_cluster,k,knn_exclude_cutoff,p=2,a=1): 55 | target_tree = BallTree(target_loc, leaf_size=15, metric='minkowski') 56 | source_tree = BallTree(source_loc, leaf_size=15, metric='minkowski') 57 | t_dis, _ = target_tree.query(target_loc, k=k+knn_exclude_cutoff) 58 | dis_knn_exclude_cutoff = np.median(t_dis[:,k+knn_exclude_cutoff-1]) 59 | s_dis, target_nearest_neighbors = target_tree.query(source_loc, k=k) 60 | _, source_nearest_neighbors = source_tree.query(source_loc, k=k) 61 | 
target_nearest_neighbors[s_dis > dis_knn_exclude_cutoff] = -1 62 | kept_ind = (target_nearest_neighbors != -1).sum(1)>0 63 | if kept_ind.sum() > 0: 64 | target_nearest_neighbors = target_nearest_neighbors[kept_ind] 65 | target_neighbors_cluster = deepcopy(target_nearest_neighbors) 66 | target_neighbors_cluster[target_nearest_neighbors!=-1] = target_cluster[target_nearest_neighbors[target_nearest_neighbors!=-1]] 67 | source_neighbors_cluster = source_cluster[source_nearest_neighbors][kept_ind] 68 | cluster_num = max(np.unique(np.concatenate([target_cluster,source_cluster])))+1 69 | cluster_onehot = np.concatenate([np.eye(cluster_num),np.zeros((1,cluster_num))]) 70 | # simi = (((neighbors_cluster - source_cluster[kept_ind].reshape(-1,1)) == 0).sum(1)/(neighbors_cluster!=-1).sum(1)).mean() 71 | target_cluster_count = cluster_onehot[target_neighbors_cluster].sum(1) 72 | source_cluster_count = cluster_onehot[source_neighbors_cluster].sum(1) 73 | target_cluster_prop = target_cluster_count/target_cluster_count.sum(1,keepdims=True) 74 | source_cluster_prop = source_cluster_count/source_cluster_count.sum(1,keepdims=True) 75 | simi = ((source_cluster_prop - target_cluster_prop)**2).sum(1).mean()/2 76 | else: 77 | simi = 1 78 | overlap = kept_ind.sum()/len(source_cluster) 79 | neighbors_simi = simi + a*((1-overlap)**p) 80 | # print(simi,overlap) 81 | return -neighbors_simi 82 | 83 | def score(warp_param, target_loc, source_loc, target_cluster, source_cluster,k,knn_exclude_cutoff,p,a): 84 | new_source_loc = [] 85 | x0 = 0 86 | y0 = 0 87 | for x,y in source_loc: 88 | new_source_loc.append(rotate_loc(x,y,x0,y0,warp_param[0])) 89 | new_source_loc = np.array(new_source_loc) 90 | new_source_loc[:,0] += warp_param[1] 91 | new_source_loc[:,1] += warp_param[2] 92 | return -neighbor_simi_fast(new_source_loc, target_loc, source_cluster, target_cluster, k, knn_exclude_cutoff,p,a) 93 | 94 | def optimize(target_loc, source_loc, target_cluster, source_cluster, n_neighbors, knn_exclude_cutoff ,p,a, bound_alpha,n_threads,*args,**kwargs): 95 | 96 | source_loc_flip = deepcopy(source_loc) 97 | source_loc_flip[:,0] = -source_loc_flip[:,0] 98 | func1 = partial(score, target_loc=target_loc, source_loc=source_loc, target_cluster=target_cluster, source_cluster=source_cluster,k=n_neighbors,knn_exclude_cutoff=knn_exclude_cutoff,p=p,a=a) 99 | opm1 = differential_evolution(func1, 100 | bounds=((0, 360), (target_loc.min(0)[0]*bound_alpha, target_loc.max(0)[0]*bound_alpha), (target_loc.min(0)[1]*bound_alpha, target_loc.max(0)[1]*bound_alpha)), 101 | workers=n_threads, 102 | updating='immediate' if n_threads == 1 else 'deferred', 103 | *args, 104 | **kwargs) 105 | score1 = -opm1.fun 106 | result1 = opm1.x 107 | 108 | func2 = partial(score, target_loc=target_loc, source_loc=source_loc_flip, target_cluster=target_cluster, source_cluster=source_cluster,k=n_neighbors,knn_exclude_cutoff=knn_exclude_cutoff,p=p,a=a) 109 | opm2 = differential_evolution(func2, 110 | bounds=((0, 360), (target_loc.min(0)[0]*bound_alpha, target_loc.max(0)[0]*bound_alpha), (target_loc.min(0)[1]*bound_alpha, target_loc.max(0)[1]*bound_alpha)), 111 | workers=n_threads, 112 | updating='immediate' if n_threads == 1 else 'deferred', 113 | *args, 114 | **kwargs) 115 | score2 = -opm2.fun 116 | result2 = opm2.x 117 | 118 | if score1 > score2: 119 | return 0, result1, score1, 1, result2, score2 120 | else: 121 | return 1, result2, score2, 0, result1, score1 122 | 123 | def align_pairwise(param, n_neighbors, knn_exclude_cutoff,p,a,bound_alpha,n_threads,*args,**kwargs): 
124 | i, target_loc,source_loc,target_cluster,source_cluster = param 125 | flip, result, score, alter_flip, alter_result,alter_score = optimize(target_loc, source_loc, target_cluster, source_cluster, n_neighbors, knn_exclude_cutoff, p,a,bound_alpha,n_threads,*args,**kwargs) 126 | return [i,flip,result[0],result[1],result[2],score,alter_flip,alter_result[0],alter_result[1],alter_result[2],alter_score] 127 | 128 | def align( 129 | ad_list, 130 | cluster_key='spatial_domain', 131 | output_path=None, 132 | raw_loc_key='spatial', 133 | aligned_loc_key='spatial_aligned', 134 | n_neighbors=15, 135 | knn_exclude_cutoff=None, 136 | p=2, 137 | a=1, 138 | bound_alpha=1, 139 | write_loc_path=None, 140 | n_threads=1, 141 | seed=42, 142 | subset_prop=None, 143 | *args, 144 | **kwargs 145 | ): 146 | """Pairwise alignment. 147 | 148 | Pairwise align the slices in ``ad_list``. The aligned coordinates are saved in ``.obsm[aligned_loc_key]`` of each slice in ``ad_list``. 149 | 150 | Args: 151 | ad_list: A list containing all slice data as AnnData objects. 152 | cluster_key: A string representing one column of ``obs`` in the AnnData object, containing the spatial domain information used for alignment. 153 | output_path: A string representing the directory where the alignment parameters are saved. If ``None``, it defaults to 'Scube_outputs'. 154 | raw_loc_key: A string representing one key of ``obsm`` in the AnnData object of each slice in ``ad_list``, containing the raw coordinates. 155 | aligned_loc_key: A string written as a key of ``obsm`` in the AnnData object of each slice in ``ad_list``, containing the aligned coordinates. 156 | n_neighbors: The number of neighbors in the target slice considered by each spot/cell in the source slice. 157 | knn_exclude_cutoff: A number used to filter the neighbors in MNN. A neighbor is excluded when its distance is larger than the median distance of the nearest ``n_neighbors + knn_exclude_cutoff`` neighbors over all spots/cells in the target slice. If ``None``, it automatically defaults to ``n_neighbors``. 158 | p: Degree of the penalty function. 159 | a: Coefficient of the penalty function. 160 | bound_alpha: A multiplier applied to the minimum and maximum slice coordinates to define the optimization bounds. 161 | write_loc_path: A string representing the file path where the aligned coordinates of all slices are saved as CSV. If ``None``, they won't be saved. 162 | n_threads: The number of parallel threads for the optimization algorithm. 163 | seed: Seed for the optimization algorithm. 164 | subset_prop: The downsampling ratio for cells in each slice.
165 | 166 | Returns: 167 | ``None`` 168 | """ 169 | if subset_prop is not None: 170 | for i in range(len(ad_list)): 171 | ad_list[i] = ad_list[i][np.random.permutation(ad_list[i].obs_names)[:int(ad_list[i].shape[0]*subset_prop)]].copy() 172 | if output_path is None: 173 | output_path = 'Scube_outputs' 174 | if not os.path.exists(output_path): 175 | os.makedirs(output_path) 176 | 177 | # centering X, Y coordinate 178 | for i in range(len(ad_list)): 179 | raw_loc = np.asarray(ad_list[i].obsm[raw_loc_key], dtype=np.float32) 180 | raw_loc[:,:2] = raw_loc[:,:2] - np.median(raw_loc[:,:2],axis=0) 181 | ad_list[i].obsm['spatial_pair'] = raw_loc 182 | 183 | if knn_exclude_cutoff is None: 184 | knn_exclude_cutoff = n_neighbors 185 | 186 | start = time.time() 187 | print('Start alignment...') 188 | res = [] 189 | for i in range(1,len(ad_list)): 190 | print(f'Alignment slice {i} to {i-1}') 191 | target_ind = i-1 192 | source_ind = i 193 | target_xy = ad_list[target_ind].obsm['spatial_pair'][:,:2] 194 | source_xy = ad_list[source_ind].obsm['spatial_pair'][:,:2] 195 | target_cluster = np.asarray(ad_list[target_ind].obs[cluster_key]) 196 | source_cluster = np.asarray(ad_list[source_ind].obs[cluster_key]) 197 | cluster_name = np.unique(np.concatenate([target_cluster,source_cluster])) 198 | cluster_index = np.arange(len(cluster_name)) 199 | cluster_name = dict(zip(cluster_name, cluster_index)) 200 | target_cluster = np.array(pd.Series(target_cluster).replace(cluster_name)) 201 | source_cluster = np.array(pd.Series(source_cluster).replace(cluster_name)) 202 | 203 | param = [i, target_xy, source_xy, target_cluster, source_cluster] 204 | r = align_pairwise(param, n_neighbors=n_neighbors, knn_exclude_cutoff=knn_exclude_cutoff,p=p,a=a,bound_alpha=bound_alpha, n_threads=n_threads,seed=seed) 205 | res.append(r) 206 | 207 | warp_info=np.array(res,dtype=np.float32) 208 | np.save(os.path.join(output_path,'warp_info.npy'),warp_info) 209 | 210 | warp_info = warp_info[:,:6] 211 | score = warp_info[:,5] 212 | print('Runtime: ' + str(time.time() - start),'s') 213 | 214 | ad_list[0].obsm[aligned_loc_key] = ad_list[0].obsm['spatial_pair'] 215 | for i in range(1,len(ad_list)): 216 | target_ind = i-1 217 | source_ind = i 218 | target_loc = ad_list[target_ind].obsm['spatial_pair'] 219 | source_loc = ad_list[source_ind].obsm['spatial_pair'] 220 | target_cluster = np.asarray(ad_list[target_ind].obs[cluster_key]) 221 | source_cluster = np.asarray(ad_list[source_ind].obs[cluster_key]) 222 | old_source_loc = deepcopy(source_loc) 223 | for r in warp_info[:source_ind][::-1]: 224 | if r[1] == 1: 225 | loc_flip = deepcopy(old_source_loc) 226 | loc_flip[:,0] = -loc_flip[:,0] 227 | old_source_loc = loc_flip 228 | new_source_xy = [] 229 | for _x,_y in old_source_loc[:,:2]: 230 | new_source_xy.append(rotate_loc(_x,_y,0,0,r[2])) 231 | new_source_xy = np.array(new_source_xy) 232 | new_source_loc = deepcopy(source_loc) 233 | new_source_loc[:,:2] = new_source_xy 234 | # translate loc 235 | new_source_loc[:,0] += r[3] 236 | new_source_loc[:,1] += r[4] 237 | old_source_loc = deepcopy(new_source_loc) 238 | ad_list[source_ind].obsm[aligned_loc_key] = new_source_loc 239 | 240 | for i in range(len(ad_list)): 241 | if isinstance (ad_list[i].obsm[raw_loc_key],pd.DataFrame): 242 | columns = ad_list[i].obsm[raw_loc_key].columns 243 | elif ad_list[i].obsm[raw_loc_key].shape[1] == 3: 244 | columns = ['X','Y','Z'] 245 | else: 246 | columns = ['X','Y'] 247 | aligned_loc = pd.DataFrame(ad_list[i].obsm[aligned_loc_key], columns=columns, 
index=ad_list[i].obs.index) 248 | ad_list[i].obsm[aligned_loc_key] = aligned_loc 249 | 250 | if write_loc_path is not None: 251 | coo = pd.DataFrame() 252 | for i in range(len(ad_list)): 253 | loc = ad_list[i].obsm[aligned_loc_key] 254 | coo = pd.concat([coo,loc],axis=0) 255 | coo.to_csv(write_loc_path) -------------------------------------------------------------------------------- /SPACEL/Scube/gpr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import gpytorch 5 | import math 6 | import os 7 | import scanpy as sc 8 | from .plot import plot_3d 9 | 10 | class ExactGPModel(gpytorch.models.ExactGP): 11 | def __init__(self, train_x, train_y, likelihood,lengthscale_prior=None,outputscale_prior=None): 12 | super(ExactGPModel, self).__init__(train_x, train_y, likelihood) 13 | self.mean_module = gpytorch.means.ConstantMean() 14 | self.covar_module = gpytorch.kernels.ScaleKernel( 15 | base_kernel=gpytorch.kernels.RBFKernel( 16 | lengthscale_prior=lengthscale_prior 17 | ), 18 | outputscale_prior=outputscale_prior 19 | ) 20 | 21 | def forward(self, x): 22 | mean_x = self.mean_module(x) 23 | covar_x = self.covar_module(x) 24 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 25 | 26 | class GPRmodel(): 27 | """The GPR model class. 28 | 29 | The GPR model class for prediction of 3D spatial transcriptomic data. 30 | 31 | Attributes: 32 | used_genes: A list of gene names used for selecting genes as input. 33 | log_bf: The log Bayes factor (BF) value indicating the variation of each gene in 3D spatial data. 34 | use_gpu: A boolean value indicating whether to use the GPU for training. 35 | subset: An integer value indicating the number of spots/cells to be downsampled for training. 36 | lengthscale_prior: The prior value of the lengthscale parameter in the Gaussian Process Regression (GPR) model. 37 | outputscale_prior: The prior value of the outputscale parameter in the GPR model. 38 | noise_prior: The prior value of the noise parameter in the GPR model. 39 | output_dir: The path where the outputs will be saved. 40 | """ 41 | 42 | def __init__(self,expr,loc,loc_resample,used_genes,use_gpu=False,output_dir=None,**kwargs): 43 | """Initializes the instance of GPR model class. 44 | 45 | Args: 46 | expr: A matrix of expression values at each location in the spatial transcriptomic data. 47 | loc: A matrix of coordinate values at each location in the spatial transcriptomic data. 48 | loc_resample: A matrix of coordinate values at each location in the resampled data. 49 | used_genes: A list of gene names used for selecting genes as input. 50 | use_gpu: A boolean value indicating whether to use the GPU for training. 51 | output_dir: The path where the outputs will be saved. 
52 | """ 53 | if 'subset' not in kwargs.keys(): 54 | self.subset = 10000 55 | else: 56 | self.subset = kwargs['subset'] 57 | 58 | if 'lengthscale_prior' not in kwargs.keys(): 59 | self.lengthscale_prior = None 60 | else: 61 | self.lengthscale_prior = kwargs['lengthscale_prior'] 62 | 63 | if 'outputscale_prior' not in kwargs.keys(): 64 | self.outputscale_prior = None 65 | else: 66 | self.outputscale_prior = kwargs['outputscale_prior'] 67 | 68 | if 'noise_prior' not in kwargs.keys(): 69 | self.noise_prior = None 70 | else: 71 | self.noise_prior = kwargs['noise_prior'] 72 | 73 | if output_dir is not None: 74 | self.output_dir = output_dir 75 | else: 76 | self.output_dir = './gpr_models' 77 | if not os.path.exists(self.output_dir): 78 | os.makedirs(self.output_dir) 79 | 80 | self.model = None 81 | self.flat_model = None 82 | self.loc_resample = torch.tensor(loc_resample,dtype=torch.float32) 83 | self.train_x, self.subset_y = self.prepare_gpr_data(loc,expr.values,self.subset) 84 | self.train_y = None 85 | self.used_genes = used_genes 86 | self.g_ind = [list(expr.columns).index(g) for g in self.used_genes] 87 | self.gene_ind_map = pd.DataFrame(self.g_ind,index=self.used_genes,columns=['g_ind']) 88 | self.log_bf = pd.DataFrame(index=self.used_genes,columns=['log_bf']) 89 | self.use_gpu = use_gpu 90 | if self.use_gpu: 91 | print('Using GPU accelerate') 92 | self.train_x = self.train_x.cuda() 93 | self.loc_resample = self.loc_resample.cuda() 94 | 95 | @staticmethod 96 | def prepare_gpr_data(X,y,subset): 97 | subset_ind = np.random.permutation(X.shape[0])[:subset] 98 | subset_x = X[subset_ind] 99 | subset_y = y[subset_ind] 100 | subset_x = torch.tensor(subset_x,dtype=torch.float32) 101 | subset_y = torch.tensor(subset_y,dtype=torch.float32) 102 | return subset_x, subset_y 103 | 104 | def prepare_gpr_model(self, lengthscale_prior=None,outputscale_prior=None,noise_prior=None,bayesian_alter=False): 105 | 106 | if bayesian_alter: 107 | if self.flat_model is None: 108 | likelihood = gpytorch.likelihoods.GaussianLikelihood( 109 | noise_prior=noise_prior 110 | ) 111 | self.flat_model = ExactGPModel(self.train_x, self.train_y, likelihood, lengthscale_prior=lengthscale_prior,outputscale_prior=outputscale_prior) 112 | # self.init_model(self.flat_model,lengthscale=torch.tensor(99999)) 113 | self.init_model(self.flat_model,lengthscale=torch.tensor(1000)) 114 | else: 115 | if self.model is None: 116 | likelihood = gpytorch.likelihoods.GaussianLikelihood( 117 | noise_prior=noise_prior 118 | ) 119 | self.model = ExactGPModel(self.train_x, self.train_y, likelihood, lengthscale_prior=lengthscale_prior,outputscale_prior=outputscale_prior) 120 | self.init_model(self.model) 121 | 122 | 123 | def init_model(self,model,noise=None,lengthscale=None,outputscale=None,constant=None): 124 | if noise is None: 125 | noise = self.train_y.std()/2 126 | if lengthscale is None: 127 | lengthscale = self.train_y.std() 128 | if outputscale is None: 129 | outputscale = torch.tensor(4) 130 | if constant is None: 131 | constant = torch.tensor(self.train_y.mean()) 132 | 133 | hypers = { 134 | 'likelihood.noise_covar.noise': noise, 135 | 'covar_module.base_kernel.lengthscale': lengthscale, 136 | 'covar_module.outputscale': outputscale, 137 | 'mean_module.constant': constant 138 | } 139 | 140 | model.initialize(**hypers) 141 | model.train_targets = self.train_y 142 | model.zero_grad() 143 | 144 | def train_single_model(self,model,lr=1,training_iter=500,save=False,save_path=None,optimize_method='Adam'): 145 | # Find optimal model 
hyperparameters 146 | model.train() 147 | model.likelihood.train() 148 | 149 | # "Loss" for GPs - the marginal log likelihood 150 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model) 151 | best_loss = np.inf 152 | best_model_state = model.state_dict() 153 | if optimize_method=='Adam': 154 | optimizer = torch.optim.Adam(model.parameters(),lr=lr) 155 | for i in range(training_iter): 156 | optimizer.zero_grad() 157 | output = model(self.train_x) 158 | loss = -mll(output, self.train_y) 159 | loss.backward() 160 | optimizer.step() 161 | print(loss.item()) 162 | if loss.item() < best_loss: 163 | best_model_state = model.state_dict() 164 | best_loss = loss.item() 165 | 166 | # Bugs need to be solved 167 | elif optimize_method=='LBGFS': 168 | optimizer = torch.optim.LBFGS(model.parameters(),line_search_fn='strong_wolfe', lr=lr) 169 | def closure(): 170 | optimizer.zero_grad() 171 | output = model(self.train_x) 172 | loss = -mll(output, self.train_y) 173 | loss.backward() 174 | return loss 175 | for i in range(training_iter): 176 | loss = optimizer.step(closure) 177 | if loss.item() < best_loss: 178 | best_model_state = model.state_dict() 179 | best_loss = loss.item() 180 | else: 181 | raise ValueError('Invalid optimize method: ', optimize_method) 182 | 183 | # Get into evaluation (predictive posterior) mode 184 | model.eval() 185 | model.likelihood.eval() 186 | if save: 187 | torch.save(best_model_state, save_path) 188 | print('Best model loss:', best_loss) 189 | return best_loss 190 | 191 | def optim_lengthscale(self,model,lr=1,l_range=torch.arange(1,12,1),optimize_method='Adam'): 192 | 193 | model.train() 194 | model.likelihood.train() 195 | loss_list = [] 196 | print('Optimize lenthscale...') 197 | for lenthscale_alpha in l_range: 198 | self.init_model(model,lengthscale=lenthscale_alpha*self.train_y.std()) 199 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model) 200 | 201 | if optimize_method=='Adam': 202 | optimizer = torch.optim.Adam(model.parameters(),lr=lr) 203 | optimizer.zero_grad() 204 | output = model(self.train_x) 205 | loss = -mll(output, self.train_y) 206 | loss.backward() 207 | optimizer.step() 208 | elif optimize_method=='LBGFS': 209 | optimizer = torch.optim.LBFGS(model.parameters(),line_search_fn='strong_wolfe', lr=lr) 210 | def closure(): 211 | optimizer.zero_grad() 212 | output = model(self.train_x) 213 | loss = -mll(output, self.train_y) 214 | loss.backward() 215 | return loss 216 | loss = optimizer.step(closure) 217 | else: 218 | raise ValueError('Invalid optimize method: ', optimize_method) 219 | loss_list.append(loss.item()) 220 | best_l_alpha = l_range[np.argmin(loss_list)] 221 | print('Initialize lenthscale alpha as %.3f' % best_l_alpha) 222 | self.init_model(model,lengthscale=best_l_alpha*self.train_y.std()) 223 | model.eval() 224 | model.likelihood.eval() 225 | 226 | def train(self,lr=1,training_iter=500,save_model=True,save_pred=False,cal_bf=False,optim_l=False,optimize_method='Adam'): 227 | """Training GPR model 228 | 229 | Training GPR model. 230 | 231 | Args: 232 | lr: The learning rate used in the training process. 233 | training_iter: The number of iterations for the training. 234 | save_model: A boolean value indicating whether to save the trained model. The default value is True. 235 | save_pred: A boolean value indicating whether to save the prediction results. 236 | cal_bf: A boolean value indicating whether to calculate the BF (Bayes factor) value. 
237 | optim_l: A boolean value indicating whether to optimize the initial length scale value. 238 | optimize_method: The optimization method used in the training. It must be one of 'Adam' and 'LBGFS'. By default, it is set to 'Adam'. 239 | 240 | Returns: 241 | ``None`` 242 | """ 243 | for g,g_i in zip(self.used_genes,self.g_ind): 244 | print(f'Modeling {g}') 245 | self.train_y = self.subset_y[:,g_i] 246 | self.prepare_gpr_model(lengthscale_prior=self.lengthscale_prior,outputscale_prior=self.outputscale_prior,noise_prior=self.noise_prior) 247 | if self.use_gpu: 248 | self.train_y = self.train_y.cuda() 249 | self.model = self.model.cuda() 250 | self.model.likelihood = self.model.likelihood.cuda() 251 | 252 | if optim_l: 253 | self.optim_lengthscale(self.model,lr=lr,optimize_method=optimize_method) 254 | 255 | mll = self.train_single_model( 256 | self.model, 257 | lr=lr, 258 | optimize_method=optimize_method, 259 | training_iter=training_iter, 260 | save=save_model, 261 | save_path=os.path.join(self.output_dir,f'{g}_iter{training_iter}_lr{lr}_model_state.pth')) 262 | 263 | if save_pred: 264 | resampled_pred = self.predict_resampled_spot() 265 | np.save(os.path.join(self.output_dir,f'{g}_iter{training_iter}_lr{lr}_resampled_pred.npy'),resampled_pred) 266 | 267 | if cal_bf: 268 | self.prepare_gpr_model(lengthscale_prior=self.lengthscale_prior,outputscale_prior=self.outputscale_prior,noise_prior=self.noise_prior,bayesian_alter=True) 269 | if self.use_gpu: 270 | self.flat_model = self.flat_model.cuda() 271 | self.flat_model.likelihood = self.flat_model.likelihood.cuda() 272 | flat_mll = self.train_single_model(self.flat_model,lr=lr,training_iter=training_iter,save=save_model,save_path=os.path.join(self.output_dir,f'{g}_iter{training_iter}_lr{lr}_flat_model_state.pth')) 273 | log_bf = -mll+flat_mll 274 | self.log_bf.loc[g,'log_bf'] = log_bf 275 | print(log_bf) 276 | 277 | def load_gene_model(self,gene,training_iter,lr): 278 | g_i = self.gene_ind_map.loc[gene,'g_ind'] 279 | self.train_y = self.subset_y[:,g_i] 280 | if self.use_gpu: 281 | self.train_y = self.train_y.cuda() 282 | if self.model is None: 283 | self.prepare_gpr_model() 284 | state_dict = torch.load(os.path.join(self.output_dir,f'{gene}_iter{training_iter}_lr{lr}_model_state.pth')) 285 | self.model.load_state_dict(state_dict) 286 | self.model.train_targets = self.train_y 287 | 288 | def eval_model(self): 289 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 290 | observed_pred = self.model.likelihood(self.model(self.train_x)) 291 | return observed_pred 292 | 293 | def predict_resampled_spot(self, gene=None, data=None, training_iter=None, lr=None, save_pred=False, save_pred_path=None, n_in_batch=30000): 294 | if gene is not None: 295 | self.load_gene_model(gene, training_iter, lr) 296 | obs_pred_list = [] 297 | for i in range(int(np.ceil(data.shape[0]/n_in_batch))): 298 | low_bound = i*n_in_batch 299 | high_bound = min((i+1)*n_in_batch,data.shape[0]) 300 | data_batch = data[low_bound:high_bound,:] 301 | with torch.no_grad(), gpytorch.settings.fast_pred_var(): 302 | observed_pred_batch = self.model.likelihood(self.model(torch.tensor(data_batch))) 303 | obs_pred_list.append(observed_pred_batch.mean.cpu().numpy()) 304 | obs_pred_concated = np.concatenate(obs_pred_list) 305 | if save_pred: 306 | if save_pred_path is None: 307 | save_pred_path = os.path.join(self.output_dir,f'{gene}_iter{training_iter}_lr{lr}_resampled_pred.npy') 308 | np.save(save_pred_path,obs_pred_concated) 309 | return obs_pred_concated 310 | 311 | def 
plot_gpr_expr( 312 | self, 313 | gene, 314 | training_iter, 315 | lr, 316 | data=None, 317 | pred_path=None, 318 | save=False, 319 | save_path=None, 320 | save_pred=False, 321 | save_pred_path=None, 322 | save_dpi=150, 323 | return_expr=True, 324 | *args,**kwargs 325 | ): 326 | """Plotting predicted expression values. 327 | 328 | Plotting the expression values predicted by the trained GPR model. 329 | 330 | Args: 331 | gene: The name of the gene for prediction. 332 | training_iter: The number of model iteration for prediction. 333 | lr: The learning rate of the model for prediction. 334 | data: The coordinate matrix for prediction. 335 | save: A boolean value indicating whether to save the prediction figure. 336 | save_path: The path where the outputs will be saved. The file extension must be one of the supported picture types. 337 | save_dpi: The DPI (dots per inch) of the saved results. 338 | return_expr: A boolean value indicating whether to return the prediction values. 339 | 340 | Returns: 341 | ``None`` or the prediction values. 342 | """ 343 | if data is None: 344 | data = self.loc_resample 345 | else: 346 | data = torch.tensor(data,dtype=torch.float32) 347 | if self.use_gpu: 348 | data = data.cuda() 349 | # if pred_path is None: 350 | # pred_path = os.path.join(self.output_dir,f'{gene}_iter{training_iter}_lr{lr}_resampled_pred.npy') 351 | if save and save_path is None: 352 | save_path = os.path.join(self.output_dir,f'{gene}_iter{training_iter}_lr{lr}_resampled_pred.png') 353 | if (pred_path is not None) and os.path.exists(pred_path): 354 | resampled_pred = np.load(os.path.join(self.output_dir,f'{gene}_iter{training_iter}_lr{lr}_resampled_pred.npy')) 355 | else: 356 | resampled_pred = self.predict_resampled_spot(gene,data,training_iter,lr,save_pred,save_pred_path) 357 | plot_3d(data.cpu().numpy(), val=resampled_pred,save_path=save_path,save_dpi=save_dpi,*args,**kwargs) 358 | if return_expr: 359 | return resampled_pred 360 | -------------------------------------------------------------------------------- /SPACEL/Scube/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.patches import Patch 3 | from matplotlib.lines import Line2D 4 | import matplotlib 5 | import seaborn as sns 6 | import pandas as pd 7 | import numpy as np 8 | 9 | def plot_3d( 10 | loc, 11 | val=None, 12 | color=None, 13 | figsize=(8,8), 14 | return_fig=False, 15 | elev=None, 16 | azim=None, 17 | xlim=None, 18 | ylim=None, 19 | zlim=None, 20 | frameon=True, 21 | save_path=None, 22 | save_dpi=150, 23 | show=True, 24 | *args, 25 | **kwargs 26 | ): 27 | """Plot all slices stacked in 3D 28 | 29 | Plot all slices stacked with a given number of subplots rows and columns. Spots/cells colored by spatial domain. 30 | 31 | Args: 32 | loc: An array of the coordinates of each spots/cells in all slices, which the first three columns are X-axis, Y-axis, Z-axis coordinates. 33 | val: The colors of each spots/cells given in loc use to plot. 34 | color: The colors of each spots/cells given in loc use to plot. 35 | figsize: Size of the figure. 36 | return_fig: Whether to return the figure. 37 | elev: The elevation angle in the vertical plane in degrees. If ``None`` then the initial value as specified in the ``Axes3D`` constructor is used. 38 | azim: The azimuth angle in the horizontal plane in degrees. If ``None`` then the initial value as specified in the ``Axes3D`` constructor is used. 
39 | xlim: A tuple giving the left and right limits of the X-axis. 40 | ylim: A tuple giving the left and right limits of the Y-axis. 41 | zlim: A tuple giving the left and right limits of the Z-axis. 42 | frameon: Whether to show the coordinate axes. 43 | save_path: A string representing the path where the figure is saved. 44 | save_dpi: The resolution in dots per inch. 45 | show: Whether to show the figure. 46 | 47 | Returns: 48 | A ``matplotlib.figure.Figure`` object 49 | """ 50 | if 'marker' not in kwargs.keys(): 51 | kwargs['marker'] = 'o' 52 | if 's' not in kwargs.keys(): 53 | kwargs['s'] = 5 54 | if 'cmap' not in kwargs.keys(): 55 | kwargs['cmap'] = 'Spectral_r' 56 | fig = plt.figure(figsize=figsize) 57 | ax = fig.add_subplot(projection='3d') 58 | if (elev is not None) and (azim is not None): 59 | ax.view_init(elev, azim) # set the viewing angle 60 | if color is None: 61 | ax.scatter(loc[:,0], loc[:,1], loc[:,2],c=val,*args,**kwargs) 62 | else: 63 | ax.scatter(loc[:,0], loc[:,1], loc[:,2],c=color,*args,**kwargs) 64 | if xlim is not None: 65 | ax.set_xlim(xlim) 66 | if ylim is not None: 67 | ax.set_ylim(ylim) 68 | if zlim is not None: 69 | ax.set_zlim(zlim) 70 | if not frameon: 71 | plt.axis('off') 72 | if save_path is not None: 73 | print(save_path) 74 | plt.savefig(save_path,dpi=save_dpi,bbox_inches='tight') 75 | if show: 76 | plt.show() 77 | else: 78 | plt.close() 79 | if return_fig: 80 | return fig 81 | 82 | def plot_single_slice(adata, spatial_key, cluster_key, frameon=False, i=1, j=1, n=1, s=1): 83 | ind = np.sort(adata.obs[cluster_key].unique().copy()) 84 | color = adata.uns[cluster_key+'_colors'].copy() 85 | dic = dict(zip(ind, color)) 86 | col = adata.obs[cluster_key].replace(dic) 87 | ax = plt.subplot(i,j,n) 88 | if type(adata.obsm[spatial_key]) == pd.core.frame.DataFrame: 89 | plt.scatter(adata.obsm[spatial_key].iloc[:,0], adata.obsm[spatial_key].iloc[:,1], c=col, s=s, rasterized=True) 90 | if type(adata.obsm[spatial_key]) == np.ndarray: 91 | plt.scatter(adata.obsm[spatial_key][:,0], adata.obsm[spatial_key][:,1], c=col, s=s, rasterized=True) 92 | plt.axis('equal') 93 | if not frameon: 94 | plt.axis('off') 95 | 96 | def plot_stacked_slices( 97 | ad_list, 98 | spatial_key, 99 | cluster_key, 100 | legend=True, 101 | frameon=False, 102 | colors=None, 103 | i=1, 104 | j=1, 105 | s=1 106 | ): 107 | """Plot all slices stacked 108 | 109 | Plot all slices stacked in one figure with a given number of subplot rows and columns. Spots/cells are colored by spatial domain. 110 | 111 | Args: 112 | ad_list: A list of ``AnnData`` objects containing all slices. 113 | spatial_key: A string representing one key of ``obsm`` in the AnnData objects of all slices, containing the coordinates used to plot. 114 | cluster_key: A string representing one column of ``obs`` in the AnnData objects of all slices, containing the spatial domain information used to plot. 115 | legend: Whether to display the legend. 116 | frameon: Whether to show the coordinate axes. 117 | colors: A list of colors for each spatial domain to plot. If ``None``, it will default to ``tab10`` or ``tab20`` according to the number of spatial domains. 118 | i: Number of rows of the subplots. 119 | j: Number of columns of the subplots. If i=j=1, it will be a single figure. 120 | s: Size of points.
121 | 122 | Returns: 123 | ``None`` 124 | """ 125 | 126 | clusters = [] 127 | colored_num = 0 128 | for ad in ad_list: 129 | if f'{cluster_key}_colors' in ad.uns.keys(): 130 | colored_num += 1 131 | clusters.extend(ad.obs[cluster_key].cat.categories) 132 | clusters = np.unique(clusters) 133 | if colored_num < len(ad_list): 134 | if colors is None: 135 | if len(clusters) > 10: 136 | colors = [matplotlib.colors.to_hex(c) for c in sns.color_palette('tab20',n_colors=len(clusters))] 137 | else: 138 | colors = [matplotlib.colors.to_hex(c) for c in sns.color_palette('tab10',n_colors=len(clusters))] 139 | color_map = pd.DataFrame(colors,index=clusters,columns=['color']) 140 | for ad in ad_list: 141 | ad.uns[f'{cluster_key}_color_map'] = color_map 142 | ad.uns[f'{cluster_key}_colors'] = [color_map.loc[c,'color'] for c in ad.obs[cluster_key].cat.categories] 143 | else: 144 | color_map = pd.DataFrame(index=clusters,columns=['color']) 145 | for ad in ad_list: 146 | color_map.loc[ad.obs[cluster_key].cat.categories,'color'] = ad.uns[f'{cluster_key}_colors'] 147 | ad.uns[f'{cluster_key}_color_map'] = color_map 148 | if legend: 149 | legend_elements = [ Line2D([], [], marker='.', markersize=10, color=color_map.loc[i,'color'], linestyle='None', label=i) for i in color_map.index] 150 | for k in range(len(ad_list)): 151 | plot_single_slice(ad_list[k], spatial_key, cluster_key, frameon, i, j, n=np.min([k+1,i*j]), s=s) 152 | plt.legend(handles=legend_elements, loc='right') 153 | -------------------------------------------------------------------------------- /SPACEL/Scube/utils_3d.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.spatial.distance import pdist, cdist 6 | 7 | def get_alpha(loc,subset=200,k=200): 8 | dis = cdist(loc,loc[np.random.permutation(loc.shape[0])[:subset]]) 9 | return np.median(np.sort(dis,axis=0)[200]) 10 | 11 | def smooth_mesh(mesh,taubin_iter=None,subdivide_iter=3,show=False): 12 | """Smooth Mesh. 13 | 14 | Smoothing a mesh for a tissue. 15 | 16 | Args: 17 | taubin_iter: A `Float` value indicating the number of iterations for taubin smooth. 18 | subdivide_iter: A `Float` value indicating the number of iterations for subdivide smooth. 19 | show: A `Boolean` value indicating whether to show the mesh. 20 | 21 | Returns: 22 | A smoothed mesh object. 23 | """ 24 | if taubin_iter is not None: 25 | print(f'filter with Taubin with {taubin_iter} iterations') 26 | mesh = mesh.filter_smooth_taubin(number_of_iterations=taubin_iter) 27 | mesh.compute_vertex_normals() 28 | if subdivide_iter is not None: 29 | mesh = mesh.subdivide_loop(number_of_iterations=subdivide_iter) 30 | print( 31 | f'After subdivision it has {len(mesh.vertices)} vertices and {len(mesh.triangles)} triangles' 32 | ) 33 | if show: 34 | o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True) 35 | return mesh 36 | 37 | def create_mesh(loc,alpha=None,show=False): 38 | """Create Mesh. 39 | 40 | Creating a mesh for a tissue. 41 | 42 | Args: 43 | loc: A `DataFrame` object represents the 3D location of each spot. 44 | alpha: A `Float` value used for o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape function. A large alpha, will ignore more details. 45 | show: A `Boolean` value indicating whether to show the mesh. 46 | 47 | Returns: 48 | A mesh object. 
49 | """ 50 | pcd = o3d.geometry.PointCloud() 51 | pcd.points = o3d.utility.Vector3dVector(loc) 52 | if alpha is None: 53 | alpha = get_alpha(loc) 54 | print(f"alpha={alpha:.3f}") 55 | mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(pcd, alpha) 56 | mesh.compute_vertex_normals() 57 | if show: 58 | o3d.visualization.draw_geometries([pcd,mesh], mesh_show_back_face=True) 59 | return mesh 60 | 61 | def sample_in_mesh(mesh,xyz,num_sample=500000,save_sampled=None,save_surface=None): 62 | """Sample spots/cells in a mesh. 63 | 64 | Sampling spots/cells in a mesh. 65 | 66 | Args: 67 | mesh: A mesh object to be sampled. 68 | xyz: A matrix of x, y, and z coordinates of each spot/cell. 69 | num_sample: A `int` value indicating the number of spots/cells to be sampled. 70 | save_surface: A `Boolean` value indicating whether to save the spots/cells on the surface of the mesh. 71 | 72 | Returns: 73 | A smoothed mesh object. 74 | """ 75 | mesh_scene = o3d.t.geometry.TriangleMesh.from_legacy(mesh) 76 | 77 | # Create a scene and add the triangle mesh 78 | scene = o3d.t.geometry.RaycastingScene() 79 | _ = scene.add_triangles(mesh_scene) 80 | 81 | dx,dy,dz = xyz.max(0)-xyz.min(0) 82 | dd = np.math.pow(num_sample/((dx/dz)*(dy/dz)*(dz/dz)), 1/3) 83 | 84 | nx,ny,nz=int((dx/dz)*dd),int((dy/dz)*dd),int(dd) 85 | xmin,ymin,zmin=xyz.min(0) 86 | xmax,ymax,zmax=xyz.max(0) 87 | x_sampled = np.linspace(xmin-1e-7, xmax+1e-7, nx) 88 | y_sampled = np.linspace(ymin-1e-7, ymax+1e-7, ny) 89 | z_sampled = np.linspace(zmin-1e-7, zmax+1e-7, nz) 90 | 91 | xyz_sampled = [] 92 | for x_s in x_sampled: 93 | for y_s in y_sampled: 94 | for z_s in z_sampled: 95 | xyz_sampled.append([x_s,y_s,z_s]) 96 | 97 | xyz_sampled = np.array(xyz_sampled) 98 | 99 | query_point = o3d.core.Tensor(xyz_sampled, dtype=o3d.core.Dtype.Float32) 100 | 101 | # Compute distance of the query point from the surface 102 | unsigned_distance = scene.compute_distance(query_point) 103 | signed_distance = scene.compute_signed_distance(query_point) 104 | occupancy = scene.compute_occupancy(query_point) 105 | 106 | occupancy = np.array(occupancy) 107 | 108 | xyz_sampled_inmesh = np.array(xyz_sampled)[occupancy == 1] 109 | xyz_surface = np.asarray(mesh.vertices) 110 | if save_sampled is not None: 111 | np.save(save_sampled,xyz_sampled_inmesh) 112 | if save_surface is not None: 113 | np.save(save_surface,xyz_surface) 114 | return xyz_sampled_inmesh, xyz_surface 115 | 116 | def convert_colors(val, cmap='Spectral_r'): 117 | val = (val - val.min())/(val.max()-val.min()) 118 | cmap = plt.get_cmap(cmap) 119 | # point_colors = [cmap(val) for val in ano4_gpr] 120 | point_colors = [cmap(v) for v in val] 121 | point_colors = np.array(point_colors)[:,:3] 122 | return point_colors 123 | 124 | def get_surface_colors(mesh, surface_expr, cmap='Spectral_r'): 125 | """Obtaining the color of the surface of mesh. 126 | 127 | Obtaining the color of the surface of mesh according to the expression value. 128 | 129 | Args: 130 | mesh: A mesh object to be colored. 131 | surface_expr: A expression matrix of each spots/cells at the surface of mesh. 132 | cmap: The cmap used for generating colors. 
133 | 134 | Returns: 135 | ``None`` 136 | """ 137 | surface_expr = (surface_expr - surface_expr.min())/(surface_expr.max()-surface_expr.min()) 138 | cmap = plt.get_cmap(cmap) 139 | point_colors = [cmap(val) for val in surface_expr] 140 | point_colors = np.array(point_colors)[:,:3] 141 | mesh.vertex_colors = o3d.utility.Vector3dVector(point_colors) 142 | 143 | def save_view_parameters(mesh, filename, point_size=3): 144 | """Saving the view parameters. 145 | 146 | Saving the view parameters according to user's opperations. 147 | 148 | Args: 149 | mesh: A mesh object to be shown. 150 | filename: The path where the parameters will be saved. 151 | point_size: The size of points to be shown. 152 | 153 | Returns: 154 | ``None`` 155 | """ 156 | vis = o3d.visualization.Visualizer() 157 | vis.create_window() 158 | opt = vis.get_render_option() 159 | opt.mesh_show_back_face = True 160 | opt.point_size=point_size 161 | vis.add_geometry(mesh) 162 | vis.run() 163 | param = vis.get_view_control().convert_to_pinhole_camera_parameters() 164 | o3d.io.write_pinhole_camera_parameters(filename,param) 165 | vis.destroy_window() 166 | 167 | def load_view_parameters(mesh, filename, point_size=3): 168 | """Loading the view parameters. 169 | 170 | Loading the view parameters from a file. 171 | 172 | Args: 173 | mesh: A mesh object to be shown. 174 | filename: The path where the parameters saved. 175 | point_size: The size of points to be shown. 176 | 177 | Returns: 178 | ``None`` 179 | """ 180 | vis = o3d.visualization.Visualizer() 181 | vis.create_window() 182 | ctr = vis.get_view_control() 183 | opt = vis.get_render_option() 184 | opt.mesh_show_back_face = True 185 | opt.point_size=point_size 186 | param = o3d.io.read_pinhole_camera_parameters(filename) 187 | vis.add_geometry(mesh) 188 | ctr.convert_from_pinhole_camera_parameters(param) 189 | vis.run() 190 | vis.destroy_window() -------------------------------------------------------------------------------- /SPACEL/Splane/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['OMP_NUM_THREADS'] = '1' 3 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 4 | from . 
import model,graph,utils 5 | from .model import init_model -------------------------------------------------------------------------------- /SPACEL/Splane/base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | import os 5 | import matplotlib 6 | from sklearn.cluster import KMeans 7 | from sklearn.metrics import davies_bouldin_score 8 | import seaborn as sns 9 | import squidpy as sq 10 | from .utils import clustering 11 | import math 12 | import time 13 | import numpy as np 14 | from tqdm import tqdm 15 | from time import strftime, localtime 16 | import tempfile 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torch.nn.parameter import Parameter 23 | from torch.nn.modules.module import Module 24 | 25 | 26 | class GraphConvolution(Module): 27 | def __init__(self, in_features, out_features, support, bias=True): 28 | super(GraphConvolution, self).__init__() 29 | self.in_features = in_features 30 | self.out_features = out_features 31 | self.support = support 32 | self.weight = Parameter(torch.FloatTensor(in_features * support, out_features)) 33 | if bias: 34 | self.bias = Parameter(torch.FloatTensor(out_features)) 35 | else: 36 | self.register_parameter('bias', None) 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | stdv = 1. / math.sqrt(self.weight.size(1)) 41 | self.weight.data.uniform_(-stdv, stdv) 42 | if self.bias is not None: 43 | self.bias.data.uniform_(-stdv, stdv) 44 | 45 | def forward(self, features, basis): 46 | supports = list() 47 | for i in range(self.support): 48 | supports.append(basis[i].matmul(features)) 49 | supports = torch.cat(supports, dim=1) 50 | output = torch.spmm(supports, self.weight) 51 | if self.bias is not None: 52 | return output + self.bias 53 | else: 54 | return output 55 | 56 | def __repr__(self): 57 | return self.__class__.__name__ + ' (' \ 58 | + str(self.in_features) + ' -> ' \ 59 | + str(self.out_features) + ')' 60 | 61 | class Splane_GCN(nn.Module): 62 | def __init__(self,feature_dims,support,latent_dims=8,hidden_dims=64,dropout=0.8): 63 | super(Splane_GCN, self).__init__() 64 | self.feature_dims = feature_dims 65 | self.support = support 66 | self.latent_dims=latent_dims 67 | self.hidden_dims=hidden_dims 68 | self.dropout = dropout 69 | self.encode_gc1 = GraphConvolution(feature_dims, hidden_dims, support) 70 | self.encode_gc2 = GraphConvolution(hidden_dims, latent_dims, support) 71 | self.decode_gc1 = GraphConvolution(latent_dims, hidden_dims, support) 72 | self.decode_gc2 = GraphConvolution(hidden_dims, feature_dims, support) 73 | 74 | nn.init.kaiming_normal_(self.encode_gc1.weight) 75 | nn.init.xavier_uniform_(self.encode_gc2.weight) 76 | nn.init.kaiming_normal_(self.decode_gc1.weight) 77 | nn.init.xavier_uniform_(self.decode_gc2.weight) 78 | 79 | @staticmethod 80 | def l2_activate(x,dim): 81 | 82 | def scale(z): 83 | zmax = z.max(1, keepdims=True).values 84 | zmin = z.min(1, keepdims=True).values 85 | z_std = torch.nan_to_num(torch.div(z - zmin,(zmax - zmin)),0) 86 | return z_std 87 | 88 | x = scale(x) 89 | x = F.normalize(x, p=2, dim=1) 90 | return x 91 | 92 | def encode(self, x, adj): 93 | x = F.dropout(x, self.dropout, training=self.training) 94 | x = F.leaky_relu(self.encode_gc1(x, adj)) 95 | x = F.dropout(x, self.dropout, training=self.training) 96 | x = self.encode_gc2(x, adj) 97 | return self.l2_activate(x, dim=1) 98 | 99 | def 
decode(self, x, adj): 100 | x = F.dropout(x, self.dropout, training=self.training) 101 | x = F.leaky_relu(self.decode_gc1(x, adj)) 102 | x = F.dropout(x, self.dropout, training=self.training) 103 | x = self.decode_gc2(x, adj) 104 | return x 105 | 106 | def forward(self, x, adj): 107 | z = self.encode(x, adj) 108 | x_ = self.decode(z, adj) 109 | return z, x_ 110 | 111 | class Splane_Disc(nn.Module): 112 | def __init__(self,label,latent_dims=8,hidden_dims=64,dropout=0.5): 113 | super(Splane_Disc, self).__init__() 114 | self.latent_dims=latent_dims 115 | self.hidden_dims=hidden_dims 116 | self.dropout = dropout 117 | self.class_num = label.shape[1] 118 | self.disc = nn.Sequential( 119 | nn.Linear(latent_dims, hidden_dims), 120 | nn.LeakyReLU(), 121 | nn.BatchNorm1d(hidden_dims), 122 | nn.Linear(hidden_dims, hidden_dims), 123 | nn.LeakyReLU(), 124 | nn.BatchNorm1d(hidden_dims), 125 | nn.Dropout(dropout), 126 | nn.Linear(hidden_dims, self.class_num) 127 | ) 128 | 129 | def forward(self, x): 130 | x = self.disc(x) 131 | y = F.softmax(x, dim=1) 132 | return y 133 | 134 | class SplaneModel(): 135 | def __init__( 136 | self, 137 | expr_ad_list, 138 | n_clusters, 139 | X, 140 | graph, 141 | support, 142 | slice_class_onehot, 143 | nb_mask, 144 | train_idx, 145 | test_idx, 146 | celltype_weights, 147 | morans_mean, 148 | lr, 149 | l1, 150 | l2, 151 | latent_dim, 152 | hidden_dims, 153 | gnn_dropout, 154 | use_gpu 155 | ): 156 | self.expr_ad_list = expr_ad_list 157 | self.model_g = Splane_GCN(X.shape[1],support,latent_dims=latent_dim,hidden_dims=hidden_dims,dropout=gnn_dropout) 158 | self.model_d = Splane_Disc(slice_class_onehot,latent_dims=latent_dim,hidden_dims=hidden_dims) 159 | self.graph = graph 160 | self.slice_class_onehot = slice_class_onehot 161 | self.nb_mask = nb_mask 162 | self.train_idx = train_idx 163 | self.test_idx = test_idx 164 | self.celltype_weights = torch.tensor(celltype_weights) 165 | self.morans_mean = morans_mean 166 | self.best_path = None 167 | self.cos_loss_obj = F.cosine_similarity 168 | self.d_loss_obj = F.cross_entropy 169 | self.n_clusters = n_clusters 170 | self.Cluster = KMeans(n_clusters=self.n_clusters,n_init=10,tol=1e-3,algorithm='full',max_iter=1000,random_state=42) 171 | self.optimizer_g = optim.RMSprop(self.model_g.parameters(),lr=lr) 172 | self.optimizer_d = optim.RMSprop(self.model_d.parameters(),lr=lr) 173 | self.l1 = l1 174 | self.l2 = l2 175 | if use_gpu: 176 | self.model_g = self.model_g.cuda() 177 | self.model_d = self.model_d.cuda() 178 | self.celltype_weights = self.celltype_weights.cuda() 179 | 180 | @staticmethod 181 | def kl_divergence(y_true, y_pred, dim=0): 182 | y_pred = torch.clip(y_pred, torch.finfo(torch.float32).eps) 183 | y_true = y_true.to(y_pred.dtype) 184 | y_true = torch.nan_to_num(torch.div(y_true, y_true.sum(dim, keepdims=True)),0) 185 | y_pred = torch.nan_to_num(torch.div(y_pred, y_pred.sum(dim, keepdims=True)),0) 186 | y_true = torch.clip(y_true, torch.finfo(torch.float32).eps, 1) 187 | y_pred = torch.clip(y_pred, torch.finfo(torch.float32).eps, 1) 188 | return torch.mul(y_true, torch.log(torch.nan_to_num(torch.div(y_true, y_pred)))).mean(dim) 189 | 190 | def train_model_g(self,d_l,simi_l): 191 | self.model_g.train() 192 | self.optimizer_g.zero_grad() 193 | encoded, decoded = self.model_g(self.graph[0],self.graph[1:]) 194 | y_disc = self.model_d(encoded) 195 | d_loss = F.cross_entropy(self.slice_class_onehot, y_disc) 196 | decoded_mask = decoded[self.train_idx] 197 | x_mask = self.graph[0][self.train_idx] 198 | simi_loss = 
-torch.mean(torch.sum(encoded[self.nb_mask[0]] * encoded[self.nb_mask[1]], dim=1)) + torch.mean(torch.abs(encoded[self.nb_mask[0]]-encoded[self.nb_mask[1]])) 199 | g_loss = -torch.sum(self.celltype_weights*F.cosine_similarity(x_mask, decoded_mask,dim=0))+torch.sum(self.celltype_weights*self.kl_divergence(x_mask, decoded_mask, dim=0)) + simi_l*simi_loss 200 | 201 | total_loss = g_loss - d_l*d_loss 202 | total_loss.backward() 203 | self.optimizer_g.step() 204 | return total_loss 205 | 206 | def train_model_d(self,): 207 | self.model_d.train() 208 | self.optimizer_d.zero_grad() 209 | encoded, decoded = self.model_g(self.graph[0],self.graph[1:]) 210 | y_disc = self.model_d(encoded) 211 | d_loss = F.cross_entropy(self.slice_class_onehot, y_disc) 212 | d_loss.backward() 213 | self.optimizer_d.step() 214 | return d_loss 215 | 216 | def test_model(self,d_l,simi_l): 217 | self.model_g.eval() 218 | self.model_d.eval() 219 | encoded, decoded = self.model_g(self.graph[0],self.graph[1:]) 220 | y_disc = self.model_d(encoded) 221 | d_loss = F.cross_entropy(self.slice_class_onehot, y_disc) 222 | decoded_mask = decoded[self.test_idx] 223 | x_mask = self.graph[0][self.test_idx] 224 | ll = torch.eq(torch.argmax(self.slice_class_onehot, -1), torch.argmax(y_disc, -1)) 225 | accuarcy = ll.to(torch.float32).mean() 226 | simi_loss = -torch.mean(torch.sum(encoded[self.nb_mask[0]] * encoded[self.nb_mask[1]], dim=1)) + torch.mean(torch.abs(encoded[self.nb_mask[0]]-encoded[self.nb_mask[1]])) 227 | g_loss = -torch.sum(self.celltype_weights*F.cosine_similarity(x_mask, decoded_mask,dim=0))+torch.sum(self.celltype_weights*self.kl_divergence(x_mask, decoded_mask, dim=0)) + simi_l*simi_loss 228 | total_loss = g_loss - d_l*d_loss 229 | db_loss = clustering(self.Cluster, encoded.cpu().detach().numpy()) 230 | return total_loss, g_loss, d_loss, accuarcy, simi_loss, db_loss, encoded, decoded 231 | 232 | def train( 233 | self, 234 | max_epochs=300, 235 | convergence=0.0001, 236 | db_convergence=0, 237 | early_stop_epochs=10, 238 | d_l=0.5, 239 | simi_l=None, 240 | g_step = 1, 241 | d_step = 1, 242 | plot_step=5, 243 | save_path=None, 244 | prefix=None 245 | ): 246 | """Training Splane model. 247 | 248 | Training Splane model for identification of uniform spatial domains in multiple slics. 249 | 250 | Args: 251 | max_steps: The max step of training. The training process will be stop when achive max step. 252 | convergence: The total loss threshold for early stop. 253 | db_convergence: The DBS threshold for early stop. 254 | early_stop_epochs: The max epochs of loss difference less than convergence. 255 | d_l: The weight of discriminator loss. 256 | simi_l: The weight of similarity loss. 257 | plot_step: The interval steps of training. 258 | save_path: A string representing the path directory where the model is saved. 259 | prefix: A string added to the prefix of file name of saved model. 
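        Example:
            A minimal usage sketch (argument values are illustrative; ``splane`` is assumed to be the ``SplaneModel`` returned by ``init_model``):

                splane.train(max_epochs=300, d_l=0.5, save_path='./Splane_models', prefix='demo')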
260 | 261 | Returns: 262 | ``None`` 263 | """ 264 | best_loss = np.inf 265 | best_db_loss = np.inf 266 | best_simi_loss = np.inf 267 | if simi_l is None: 268 | simi_l = 1/np.mean(self.morans_mean) 269 | print(f'Setting the weight of similarity loss to {simi_l:.3f}') 270 | 271 | if save_path is None: 272 | save_path = os.path.join(tempfile.gettempdir() ,'Splane_models_'+strftime("%Y%m%d%H%M%S",localtime())) 273 | if not os.path.exists(save_path): 274 | os.makedirs(save_path) 275 | early_stop_count = 0 276 | pbar = tqdm(range(max_epochs)) 277 | for epoch in pbar: 278 | for _ in range(g_step): 279 | train_total_loss = self.train_model_g(d_l=d_l, simi_l=simi_l) 280 | for _ in range(d_step): 281 | self.train_model_d() 282 | 283 | if epoch % plot_step == 0: 284 | test_total_loss, test_g_loss, test_d_loss, test_acc, simi_loss, db_loss, encoded, decoded = self.test_model(d_l=d_l, simi_l=simi_l) 285 | current_loss = test_g_loss.cpu().detach().numpy() 286 | current_db_loss = db_loss 287 | if (best_loss - current_loss > convergence) & (best_db_loss - current_db_loss > db_convergence): 288 | if best_loss > current_loss: 289 | best_loss = current_loss 290 | if best_db_loss > current_db_loss: 291 | best_db_loss = current_db_loss 292 | pbar.set_description("The best epoch {0} total loss={1:.3f} g loss={2:.3f} d loss={3:.3f} d acc={4:.3f} simi loss={5:.3f} db loss={6:.3f}".format(epoch, test_total_loss, test_g_loss, test_d_loss, test_acc, simi_loss, db_loss),refresh=True) 293 | old_best_path = self.best_path 294 | early_stop_count = 0 295 | if prefix is not None: 296 | self.best_path = os.path.join(save_path,prefix+'_'+f'Splane_weights_epoch{epoch}.h5') 297 | else: 298 | self.best_path = os.path.join(save_path,f'Splane_weights_epoch{epoch}.h5') 299 | if old_best_path is not None: 300 | if os.path.exists(old_best_path): 301 | os.remove(old_best_path) 302 | torch.save(self.model_g.state_dict(), self.best_path) 303 | else: 304 | early_stop_count += 1 305 | 306 | 307 | # print("Epoch {} train g loss={} g loss={} d loss={} acc={} simi loss={} db loss={}".format(epoch, test_total_loss, test_g_loss, test_d_loss, test_acc, simi_loss, db_loss)) 308 | if early_stop_count > early_stop_epochs: 309 | print('Stop trainning because of loss convergence') 310 | break 311 | 312 | def identify_spatial_domain(self,key=None,colors=None): 313 | """Identify Spaital domains. 314 | 315 | Identification of uniform spatial domains in multiple slics. 316 | 317 | Args: 318 | key: A column name to be saved in the `.obs` attribute of AnnData object of the ST data, representing the spatial domains. If not provided, the spatial domain will be saved as 'spatial_domain' in the `.obs` attribute. 319 | colors: A list of colors assigned to each spatial domain. 
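        Example:
            A minimal usage sketch (assumes ``train`` has already been called; the key name is illustrative). The resulting domain labels are written to ``.obs[key]`` of each input AnnData object:

                splane.identify_spatial_domain(key='spatial_domain')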
320 | 321 | Returns: 322 | ``None`` 323 | """ 324 | if colors is None: 325 | if self.n_clusters > 10: 326 | colors = [matplotlib.colors.to_hex(c) for c in sns.color_palette('tab20',n_colors=self.n_clusters)] 327 | else: 328 | colors = [matplotlib.colors.to_hex(c) for c in sns.color_palette('tab10',n_colors=self.n_clusters)] 329 | color_map = pd.DataFrame(colors,index=np.arange(self.n_clusters),columns=['color']) 330 | if key is None: 331 | key = 'spatial_domain' 332 | # self.model_g(self.graph[0], self.graph[1:]) 333 | self.model_g.load_state_dict(torch.load(self.best_path)) 334 | self.model_g.eval() 335 | encoded, decoded = self.model_g(self.graph[0], self.graph[1:]) 336 | clusters = self.Cluster.fit_predict(encoded.cpu().detach().numpy()) 337 | loc_index = 0 338 | for i in range(len(self.expr_ad_list)): 339 | if key in self.expr_ad_list[i].obs.columns: 340 | self.expr_ad_list[i].obs = self.expr_ad_list[i].obs.drop(columns=key) 341 | self.expr_ad_list[i].obs[key] = clusters[loc_index:loc_index+self.expr_ad_list[i].shape[0]] 342 | self.expr_ad_list[i].obs[key] = pd.Categorical(self.expr_ad_list[i].obs[key]) 343 | self.expr_ad_list[i].uns[f'{key}_colors'] = [color_map.loc[c,'color'] for c in self.expr_ad_list[i].obs[key].cat.categories] 344 | loc_index += self.expr_ad_list[i].shape[0] 345 | -------------------------------------------------------------------------------- /SPACEL/Splane/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import coo_matrix, block_diag 3 | import torch 4 | import torch.nn.functional as F 5 | from .pygcn_utils import * 6 | 7 | def get_graph_inputs(celltype_ad_list): 8 | print('Generating GNN inputs...') 9 | A_list = [] 10 | X_list = [] 11 | for celltype_ad in celltype_ad_list: 12 | X_tmp = np.matrix(celltype_ad.X,dtype='float32') 13 | X_list.append(X_tmp) 14 | A_list.append(coo_matrix(celltype_ad.obsp['spatial_distances'],dtype='float32')) 15 | 16 | X_raw = np.concatenate(X_list) 17 | class_index = 0 18 | slice_class = [] 19 | for A_tmp in A_list: 20 | slice_class = slice_class + [class_index]*A_tmp.shape[0] 21 | class_index += 1 22 | A = block_diag(A_list) 23 | nb_mask = np.argwhere(A > 0).T 24 | slice_class_onehot = F.one_hot(torch.tensor(slice_class)).float() 25 | return X_raw,A,nb_mask,slice_class_onehot 26 | 27 | def get_graph_kernel(features,adj,k=2): 28 | features_scaled = (features-features.mean(0))/features.std(0) 29 | features_scaled = torch.tensor(features_scaled) 30 | SYM_NORM = False # symmetric (True) vs. 
left-only (False) normalization 31 | L = normalized_laplacian(adj, SYM_NORM) 32 | L_scaled = rescale_laplacian(L) 33 | T_k = chebyshev_polynomial(L_scaled, k) 34 | support = k + 1 35 | graph = [features_scaled]+T_k 36 | for _i in range(len(graph))[1:]: 37 | graph[_i] = sparse_mx_to_torch_sparse_tensor(graph[_i]) 38 | return features_scaled, graph, support 39 | 40 | def split_train_test_idx(X,train_prop): 41 | rand_idx = np.random.permutation(X.shape[0]) 42 | train_idx = rand_idx[:int(len(rand_idx)*train_prop)] 43 | test_idx = rand_idx[int(len(rand_idx)*train_prop):] 44 | return train_idx, test_idx -------------------------------------------------------------------------------- /SPACEL/Splane/kegra/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuKunLab/SPACEL/02801c2dcb18cbf4ffefdc1352a81314c571fe85/SPACEL/Splane/kegra/__init__.py -------------------------------------------------------------------------------- /SPACEL/Splane/kegra/gnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from tensorflow.keras import activations, initializers, constraints 4 | from tensorflow.keras import regularizers 5 | from tensorflow.keras.layers import Layer 6 | import tensorflow.keras.backend as K 7 | import tensorflow as tf 8 | 9 | 10 | class GraphConvolution(Layer): 11 | """Basic graph convolution layer as in https://arxiv.org/abs/1609.02907""" 12 | def __init__(self, units, support=1, 13 | activation=None, 14 | use_bias=True, 15 | kernel_initializer='glorot_uniform', 16 | bias_initializer='zeros', 17 | kernel_regularizer=None, 18 | bias_regularizer=None, 19 | activity_regularizer=None, 20 | kernel_constraint=None, 21 | bias_constraint=None, 22 | **kwargs): 23 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 24 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 25 | super(GraphConvolution, self).__init__(**kwargs) 26 | self.units = units 27 | self.activation = activations.get(activation) 28 | self.use_bias = use_bias 29 | self.kernel_initializer = initializers.get(kernel_initializer) 30 | self.bias_initializer = initializers.get(bias_initializer) 31 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 32 | self.bias_regularizer = regularizers.get(bias_regularizer) 33 | self.activity_regularizer = regularizers.get(activity_regularizer) 34 | self.kernel_constraint = constraints.get(kernel_constraint) 35 | self.bias_constraint = constraints.get(bias_constraint) 36 | self.supports_masking = True 37 | 38 | self.support = support 39 | assert support >= 1 40 | 41 | def compute_output_shape(self, input_shapes): 42 | features_shape = input_shapes[0] 43 | output_shape = (features_shape[0], self.units) 44 | return output_shape # (batch_size, output_dim) 45 | 46 | def build(self, input_shapes): 47 | features_shape = input_shapes[0] 48 | assert len(features_shape) == 2 49 | input_dim = features_shape[1] 50 | 51 | self.kernel = self.add_weight(shape=(input_dim * self.support, 52 | self.units), 53 | initializer=self.kernel_initializer, 54 | name='kernel', 55 | regularizer=self.kernel_regularizer, 56 | constraint=self.kernel_constraint) 57 | if self.use_bias: 58 | self.bias = self.add_weight(shape=(self.units,), 59 | initializer=self.bias_initializer, 60 | name='bias', 61 | regularizer=self.bias_regularizer, 62 | constraint=self.bias_constraint) 63 | else: 64 | self.bias = None 65 | self.built = True 66 | 67 | def call(self, inputs, 
mask=None): 68 | features = inputs[0] 69 | basis = inputs[1:] 70 | 71 | supports = list() 72 | for i in range(self.support): 73 | # supports.append(K.dot(basis[i], features)) 74 | # supports.append(K.dot(tf.sparse.to_dense(basis[i]), features)) 75 | # supports.append(K.dot(tf.sparse.to_dense(basis[i]), tf.sparse.to_dense(features))) 76 | supports.append(tf.sparse.sparse_dense_matmul(basis[i], features)) 77 | supports = K.concatenate(supports, axis=1) 78 | output = K.dot(supports, self.kernel) 79 | 80 | if self.use_bias: 81 | output += self.bias 82 | return self.activation(output) 83 | 84 | def get_config(self): 85 | config = {'units': self.units, 86 | 'support': self.support, 87 | 'activation': activations.serialize(self.activation), 88 | 'use_bias': self.use_bias, 89 | 'kernel_initializer': initializers.serialize( 90 | self.kernel_initializer), 91 | 'bias_initializer': initializers.serialize( 92 | self.bias_initializer), 93 | 'kernel_regularizer': regularizers.serialize( 94 | self.kernel_regularizer), 95 | 'bias_regularizer': regularizers.serialize( 96 | self.bias_regularizer), 97 | 'activity_regularizer': regularizers.serialize( 98 | self.activity_regularizer), 99 | 'kernel_constraint': constraints.serialize( 100 | self.kernel_constraint), 101 | 'bias_constraint': constraints.serialize(self.bias_constraint) 102 | } 103 | 104 | base_config = super(GraphConvolution, self).get_config() 105 | return dict(list(base_config.items()) + list(config.items())) 106 | -------------------------------------------------------------------------------- /SPACEL/Splane/kegra/gnn_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import scipy.sparse as sp 4 | import numpy as np 5 | from scipy.sparse.linalg.eigen.arpack import eigsh, ArpackNoConvergence 6 | 7 | 8 | def encode_onehot(labels): 9 | classes = set(labels) 10 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)} 11 | labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) 12 | return labels_onehot 13 | 14 | 15 | def load_data(path="data/cora/", dataset="cora"): 16 | """Load citation network dataset (cora only for now)""" 17 | print('Loading {} dataset...'.format(dataset)) 18 | 19 | idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), dtype=np.dtype(str)) 20 | features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32) 21 | labels = encode_onehot(idx_features_labels[:, -1]) 22 | 23 | # build graph 24 | idx = np.array(idx_features_labels[:, 0], dtype=np.int32) 25 | idx_map = {j: i for i, j in enumerate(idx)} 26 | edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32) 27 | edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), 28 | dtype=np.int32).reshape(edges_unordered.shape) 29 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 30 | shape=(labels.shape[0], labels.shape[0]), dtype=np.float32) 31 | 32 | # build symmetric adjacency matrix 33 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 34 | 35 | print('Dataset has {} nodes, {} edges, {} features.'.format(adj.shape[0], edges.shape[0], features.shape[1])) 36 | 37 | return features.todense(), adj, labels 38 | 39 | 40 | def normalize_adj(adj, symmetric=True): 41 | if symmetric: 42 | d = sp.diags(np.power(np.array(adj.sum(1)), -0.5).flatten(), 0) 43 | a_norm = adj.dot(d).transpose().dot(d).tocsr() 44 | else: 45 | d = 
sp.diags(np.power(np.array(adj.sum(1)), -1).flatten(), 0) 46 | a_norm = d.dot(adj).tocsr() 47 | return a_norm 48 | 49 | 50 | def preprocess_adj(adj, symmetric=True): 51 | adj = adj + sp.eye(adj.shape[0]) 52 | adj = normalize_adj(adj, symmetric) 53 | return adj 54 | 55 | 56 | def sample_mask(idx, l): 57 | mask = np.zeros(l) 58 | mask[idx] = 1 59 | return np.array(mask, dtype=np.bool) 60 | 61 | 62 | def get_splits(y): 63 | idx_train = range(140) 64 | idx_val = range(200, 500) 65 | idx_test = range(500, 1500) 66 | y_train = np.zeros(y.shape, dtype=np.int32) 67 | y_val = np.zeros(y.shape, dtype=np.int32) 68 | y_test = np.zeros(y.shape, dtype=np.int32) 69 | y_train[idx_train] = y[idx_train] 70 | y_val[idx_val] = y[idx_val] 71 | y_test[idx_test] = y[idx_test] 72 | train_mask = sample_mask(idx_train, y.shape[0]) 73 | return y_train, y_val, y_test, idx_train, idx_val, idx_test, train_mask 74 | 75 | 76 | def categorical_crossentropy(preds, labels): 77 | return np.mean(-np.log(np.extract(labels, preds))) 78 | 79 | 80 | def accuracy(preds, labels): 81 | return np.mean(np.equal(np.argmax(labels, 1), np.argmax(preds, 1))) 82 | 83 | 84 | def evaluate_preds(preds, labels, indices): 85 | 86 | split_loss = list() 87 | split_acc = list() 88 | 89 | for y_split, idx_split in zip(labels, indices): 90 | split_loss.append(categorical_crossentropy(preds[idx_split], y_split[idx_split])) 91 | split_acc.append(accuracy(preds[idx_split], y_split[idx_split])) 92 | 93 | return split_loss, split_acc 94 | 95 | 96 | def normalized_laplacian(adj, symmetric=True): 97 | adj_normalized = normalize_adj(adj, symmetric) 98 | laplacian = sp.eye(adj.shape[0]) - adj_normalized 99 | return laplacian 100 | 101 | 102 | def rescale_laplacian(laplacian): 103 | try: 104 | print('Calculating largest eigenvalue of normalized graph Laplacian...') 105 | largest_eigval = eigsh(laplacian, 1, which='LM', return_eigenvectors=False)[0] 106 | except ArpackNoConvergence: 107 | print('Eigenvalue calculation did not converge! Using largest_eigval=2 instead.') 108 | largest_eigval = 2 109 | 110 | scaled_laplacian = (2. / largest_eigval) * laplacian - sp.eye(laplacian.shape[0]) 111 | return scaled_laplacian 112 | 113 | 114 | def chebyshev_polynomial(X, k): 115 | """Calculate Chebyshev polynomials up to order k. 
Return a list of sparse matrices.""" 116 | print("Calculating Chebyshev polynomials up to order {}...".format(k)) 117 | 118 | T_k = list() 119 | T_k.append(sp.eye(X.shape[0]).tocsr()) 120 | T_k.append(X) 121 | 122 | def chebyshev_recurrence(T_k_minus_one, T_k_minus_two, X): 123 | X_ = sp.csr_matrix(X, copy=True) 124 | return 2 * X_.dot(T_k_minus_one) - T_k_minus_two 125 | 126 | for i in range(2, k+1): 127 | T_k.append(chebyshev_recurrence(T_k[-1], T_k[-2], X)) 128 | 129 | return T_k 130 | 131 | 132 | def sparse_to_tuple(sparse_mx): 133 | if not sp.isspmatrix_coo(sparse_mx): 134 | sparse_mx = sparse_mx.tocoo() 135 | coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose() 136 | values = sparse_mx.data 137 | shape = sparse_mx.shape 138 | return coords, values, shape -------------------------------------------------------------------------------- /SPACEL/Splane/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | import torch 5 | import squidpy as sq 6 | from .graph import get_graph_inputs,get_graph_kernel,split_train_test_idx 7 | from .utils import generate_celltype_ad_list,cal_celltype_weight 8 | from .base_model import SplaneModel 9 | 10 | def init_model( 11 | expr_ad_list:list, 12 | n_clusters:int, 13 | k:int=2, 14 | use_weight=True, 15 | train_prop:float=0.5, 16 | n_neighbors=6, 17 | min_prop=0.01, 18 | lr:float=3e-3, 19 | l1:float=0.01, 20 | l2:float=0.01, 21 | latent_dim:int=16, 22 | hidden_dims:int=64, 23 | gnn_dropout:float=0.8, 24 | simi_neighbor=1, 25 | use_gpu=None, 26 | seed=42 27 | )->SplaneModel: 28 | """Initialize Splane model. 29 | 30 | Build the model, then set the data and parameters. 31 | 32 | Args: 33 | expr_ad_list: A list of AnnData objects of spatial transcriptomic data as the model input. 34 | n_clusters: The number of clusters in the model output. 35 | k: The order of neighbors of a spot for graph construction. 36 | use_weight: If True, the Moran's I of each cell type proportion is used to weight the loss. 37 | train_prop: The proportion of the training set. 38 | n_neighbors: The number of neighbors for graph construction. 39 | lr: Learning rate of training. 40 | latent_dim: The dimension of latent features. It equals the number of nodes of the bottleneck layer. 41 | hidden_dims: The number of nodes of hidden layers. 42 | gnn_dropout: The dropout rate of the GNN model. 43 | simi_neighbor: The order of neighbors used for the similarity loss. If None, it equals the order used to construct the graph. 44 | seed: Random number seed. 45 | Returns: 46 | An initialized ``SplaneModel`` object ready for training. 
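    Example:
        A minimal usage sketch of the Splane workflow (variable names and parameter values are illustrative):

            from SPACEL.Splane.model import init_model
            splane = init_model(expr_ad_list, n_clusters=7, k=2, n_neighbors=6)
            splane.train(max_epochs=300)
            splane.identify_spatial_domain()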
47 | """ 48 | 49 | print('Setting global seed:', seed) 50 | random.seed(seed) 51 | np.random.seed(seed) 52 | torch.manual_seed(seed) 53 | torch.cuda.manual_seed(seed) 54 | torch.backends.cudnn.deterministic = True 55 | torch.backends.cudnn.benchmark = False 56 | 57 | if use_gpu is None: 58 | if torch.cuda.is_available(): 59 | use_gpu = True 60 | else: 61 | use_gpu = False 62 | 63 | for expr_ad in expr_ad_list: 64 | if 'spatial_connectivities' not in expr_ad.obsp.keys(): 65 | sq.gr.spatial_neighbors(expr_ad,coord_type='grid',n_neighs=n_neighbors) 66 | celltype_ad_list = generate_celltype_ad_list(expr_ad_list,min_prop=min_prop) 67 | celltype_weights,morans_mean = cal_celltype_weight(celltype_ad_list) 68 | kept_ind = celltype_weights > 0 69 | if not use_weight: 70 | celltype_weights = np.ones(len(celltype_weights))/len(celltype_weights) 71 | X,A,nb_mask,slice_class_onehot = get_graph_inputs(celltype_ad_list) 72 | X_filtered, graph, support = get_graph_kernel(X[:,kept_ind],A,k=k) 73 | celltype_weights = celltype_weights[kept_ind] 74 | morans_mean = morans_mean[kept_ind] 75 | if simi_neighbor == 1: 76 | nb_mask = nb_mask 77 | elif simi_neighbor == None: 78 | nb_mask = np.array(np.where(graph[-1].to_dense())!=0) 79 | nb_mask = nb_mask[:,nb_mask[0] != nb_mask[1]] 80 | else: 81 | raise ValueError('simi_neighbor must be 1 or None.') 82 | train_idx,test_idx = split_train_test_idx(X,train_prop=0.5) 83 | if use_gpu: 84 | for i in range(len(graph)): 85 | graph[i] = graph[i].cuda() 86 | slice_class_onehot = slice_class_onehot.cuda() 87 | return SplaneModel( 88 | expr_ad_list, 89 | n_clusters, 90 | X_filtered, 91 | graph, 92 | support, 93 | slice_class_onehot, 94 | nb_mask, 95 | train_idx, 96 | test_idx, 97 | celltype_weights, 98 | morans_mean, 99 | lr, 100 | l1, 101 | l2, 102 | latent_dim, 103 | hidden_dims, 104 | gnn_dropout, 105 | use_gpu 106 | ) 107 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/LICENCE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 Thomas Kipf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
-------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/README.md: -------------------------------------------------------------------------------- 1 | Graph Convolutional Networks in PyTorch 2 | ==== 3 | 4 | PyTorch implementation of Graph Convolutional Networks (GCNs) for semi-supervised classification [1]. 5 | 6 | For a high-level introduction to GCNs, see: 7 | 8 | Thomas Kipf, [Graph Convolutional Networks](http://tkipf.github.io/graph-convolutional-networks/) (2016) 9 | 10 | ![Graph Convolutional Networks](figure.png) 11 | 12 | Note: There are subtle differences between the TensorFlow implementation in https://github.com/tkipf/gcn and this PyTorch re-implementation. This re-implementation serves as a proof of concept and is not intended for reproduction of the results reported in [1]. 13 | 14 | This implementation makes use of the Cora dataset from [2]. 15 | 16 | ## Installation 17 | 18 | ```python setup.py install``` 19 | 20 | ## Requirements 21 | 22 | * PyTorch 0.4 or 0.5 23 | * Python 2.7 or 3.6 24 | 25 | ## Usage 26 | 27 | ```python train.py``` 28 | 29 | ## References 30 | 31 | [1] [Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016](https://arxiv.org/abs/1609.02907) 32 | 33 | [2] [Sen et al., Collective Classification in Network Data, AI Magazine 2008](http://linqs.cs.umd.edu/projects/projects/lbc/) 34 | 35 | ## Cite 36 | 37 | Please cite our paper if you use this code in your own work: 38 | 39 | ``` 40 | @article{kipf2016semi, 41 | title={Semi-Supervised Classification with Graph Convolutional Networks}, 42 | author={Kipf, Thomas N and Welling, Max}, 43 | journal={arXiv preprint arXiv:1609.02907}, 44 | year={2016} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/data/cora/README: -------------------------------------------------------------------------------- 1 | This directory contains the a selection of the Cora dataset (www.research.whizbang.com/data). 2 | 3 | The Cora dataset consists of Machine Learning papers. These papers are classified into one of the following seven classes: 4 | Case_Based 5 | Genetic_Algorithms 6 | Neural_Networks 7 | Probabilistic_Methods 8 | Reinforcement_Learning 9 | Rule_Learning 10 | Theory 11 | 12 | The papers were selected in a way such that in the final corpus every paper cites or is cited by atleast one other paper. There are 2708 papers in the whole corpus. 13 | 14 | After stemming and removing stopwords we were left with a vocabulary of size 1433 unique words. All words with document frequency less than 10 were removed. 15 | 16 | 17 | THE DIRECTORY CONTAINS TWO FILES: 18 | 19 | The .content file contains descriptions of the papers in the following format: 20 | 21 | + 22 | 23 | The first entry in each line contains the unique string ID of the paper followed by binary values indicating whether each word in the vocabulary is present (indicated by 1) or absent (indicated by 0) in the paper. Finally, the last entry in the line contains the class label of the paper. 24 | 25 | The .cites file contains the citation graph of the corpus. Each line describes a link in the following format: 26 | 27 | 28 | 29 | Each line contains two paper IDs. The first entry is the ID of the paper being cited and the second ID stands for the paper which contains the citation. The direction of the link is from right to left. If a line is represented by "paper1 paper2" then the link is "paper2->paper1". 
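A minimal sketch of reading the two files with NumPy, mirroring the approach in pygcn/utils.py (file paths are illustrative):

	import numpy as np
	content = np.genfromtxt("cora.content", dtype=str)     # per row: paper ID, 1433 binary word indicators, class label
	cites = np.genfromtxt("cora.cites", dtype=np.int32)    # per row: cited paper ID, citing paper ID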
-------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuKunLab/SPACEL/02801c2dcb18cbf4ffefdc1352a81314c571fe85/SPACEL/Splane/pygcn/figure.png -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/pygcn/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | from .layers import * 5 | from .models import * 6 | from .utils import * -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/pygcn/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from torch.nn.parameter import Parameter 6 | from torch.nn.modules.module import Module 7 | 8 | 9 | class GraphConvolution(Module): 10 | """ 11 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 12 | """ 13 | 14 | def __init__(self, in_features, out_features, bias=True): 15 | super(GraphConvolution, self).__init__() 16 | self.in_features = in_features 17 | self.out_features = out_features 18 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 19 | if bias: 20 | self.bias = Parameter(torch.FloatTensor(out_features)) 21 | else: 22 | self.register_parameter('bias', None) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self): 26 | stdv = 1. / math.sqrt(self.weight.size(1)) 27 | self.weight.data.uniform_(-stdv, stdv) 28 | if self.bias is not None: 29 | self.bias.data.uniform_(-stdv, stdv) 30 | 31 | def forward(self, input, adj): 32 | support = torch.mm(input, self.weight) 33 | output = torch.spmm(adj, support) 34 | if self.bias is not None: 35 | return output + self.bias 36 | else: 37 | return output 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + ' (' \ 41 | + str(self.in_features) + ' -> ' \ 42 | + str(self.out_features) + ')' 43 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/pygcn/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from pygcn.layers import GraphConvolution 4 | 5 | 6 | class GCN(nn.Module): 7 | def __init__(self, nfeat, nhid, nclass, dropout): 8 | super(GCN, self).__init__() 9 | 10 | self.gc1 = GraphConvolution(nfeat, nhid) 11 | self.gc2 = GraphConvolution(nhid, nclass) 12 | self.dropout = dropout 13 | 14 | def forward(self, x, adj): 15 | x = F.relu(self.gc1(x, adj)) 16 | x = F.dropout(x, self.dropout, training=self.training) 17 | x = self.gc2(x, adj) 18 | return F.log_softmax(x, dim=1) 19 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/pygcn/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import time 5 | import argparse 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | from pygcn.utils import load_data, accuracy 13 | from pygcn.models import GCN 14 | 15 | # Training settings 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--no-cuda', action='store_true', default=False, 18 | 
help='Disables CUDA training.') 19 | parser.add_argument('--fastmode', action='store_true', default=False, 20 | help='Validate during training pass.') 21 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 22 | parser.add_argument('--epochs', type=int, default=200, 23 | help='Number of epochs to train.') 24 | parser.add_argument('--lr', type=float, default=0.01, 25 | help='Initial learning rate.') 26 | parser.add_argument('--weight_decay', type=float, default=5e-4, 27 | help='Weight decay (L2 loss on parameters).') 28 | parser.add_argument('--hidden', type=int, default=16, 29 | help='Number of hidden units.') 30 | parser.add_argument('--dropout', type=float, default=0.5, 31 | help='Dropout rate (1 - keep probability).') 32 | 33 | args = parser.parse_args() 34 | args.cuda = not args.no_cuda and torch.cuda.is_available() 35 | 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | if args.cuda: 39 | torch.cuda.manual_seed(args.seed) 40 | 41 | # Load data 42 | adj, features, labels, idx_train, idx_val, idx_test = load_data() 43 | 44 | # Model and optimizer 45 | model = GCN(nfeat=features.shape[1], 46 | nhid=args.hidden, 47 | nclass=labels.max().item() + 1, 48 | dropout=args.dropout) 49 | optimizer = optim.Adam(model.parameters(), 50 | lr=args.lr, weight_decay=args.weight_decay) 51 | 52 | if args.cuda: 53 | model.cuda() 54 | features = features.cuda() 55 | adj = adj.cuda() 56 | labels = labels.cuda() 57 | idx_train = idx_train.cuda() 58 | idx_val = idx_val.cuda() 59 | idx_test = idx_test.cuda() 60 | 61 | 62 | def train(epoch): 63 | t = time.time() 64 | model.train() 65 | optimizer.zero_grad() 66 | output = model(features, adj) 67 | loss_train = F.nll_loss(output[idx_train], labels[idx_train]) 68 | acc_train = accuracy(output[idx_train], labels[idx_train]) 69 | loss_train.backward() 70 | optimizer.step() 71 | 72 | if not args.fastmode: 73 | # Evaluate validation set performance separately, 74 | # deactivates dropout during validation run. 
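# (with --fastmode this block is skipped, so the training-pass output computed above, with dropout active, is reused for the validation metrics)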
75 | model.eval() 76 | output = model(features, adj) 77 | 78 | loss_val = F.nll_loss(output[idx_val], labels[idx_val]) 79 | acc_val = accuracy(output[idx_val], labels[idx_val]) 80 | print('Epoch: {:04d}'.format(epoch+1), 81 | 'loss_train: {:.4f}'.format(loss_train.item()), 82 | 'acc_train: {:.4f}'.format(acc_train.item()), 83 | 'loss_val: {:.4f}'.format(loss_val.item()), 84 | 'acc_val: {:.4f}'.format(acc_val.item()), 85 | 'time: {:.4f}s'.format(time.time() - t)) 86 | 87 | 88 | def test(): 89 | model.eval() 90 | output = model(features, adj) 91 | loss_test = F.nll_loss(output[idx_test], labels[idx_test]) 92 | acc_test = accuracy(output[idx_test], labels[idx_test]) 93 | print("Test set results:", 94 | "loss= {:.4f}".format(loss_test.item()), 95 | "accuracy= {:.4f}".format(acc_test.item())) 96 | 97 | 98 | # Train model 99 | t_total = time.time() 100 | for epoch in range(args.epochs): 101 | train(epoch) 102 | print("Optimization Finished!") 103 | print("Total time elapsed: {:.4f}s".format(time.time() - t_total)) 104 | 105 | # Testing 106 | test() 107 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/pygcn/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | 5 | 6 | def encode_onehot(labels): 7 | classes = set(labels) 8 | classes_dict = {c: np.identity(len(classes))[i, :] for i, c in 9 | enumerate(classes)} 10 | labels_onehot = np.array(list(map(classes_dict.get, labels)), 11 | dtype=np.int32) 12 | return labels_onehot 13 | 14 | 15 | def load_data(path="../data/cora/", dataset="cora"): 16 | """Load citation network dataset (cora only for now)""" 17 | print('Loading {} dataset...'.format(dataset)) 18 | 19 | idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), 20 | dtype=np.dtype(str)) 21 | features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32) 22 | labels = encode_onehot(idx_features_labels[:, -1]) 23 | 24 | # build graph 25 | idx = np.array(idx_features_labels[:, 0], dtype=np.int32) 26 | idx_map = {j: i for i, j in enumerate(idx)} 27 | edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), 28 | dtype=np.int32) 29 | edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), 30 | dtype=np.int32).reshape(edges_unordered.shape) 31 | adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 32 | shape=(labels.shape[0], labels.shape[0]), 33 | dtype=np.float32) 34 | 35 | # build symmetric adjacency matrix 36 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 37 | 38 | features = normalize(features) 39 | adj = normalize(adj + sp.eye(adj.shape[0])) 40 | 41 | idx_train = range(140) 42 | idx_val = range(200, 500) 43 | idx_test = range(500, 1500) 44 | 45 | features = torch.FloatTensor(np.array(features.todense())) 46 | labels = torch.LongTensor(np.where(labels)[1]) 47 | # adj = sparse_mx_to_torch_sparse_tensor(adj) 48 | 49 | idx_train = torch.LongTensor(idx_train) 50 | idx_val = torch.LongTensor(idx_val) 51 | idx_test = torch.LongTensor(idx_test) 52 | 53 | return adj, features, labels, idx_train, idx_val, idx_test 54 | 55 | 56 | def normalize(mx): 57 | """Row-normalize sparse matrix""" 58 | rowsum = np.array(mx.sum(1)) 59 | r_inv = np.power(rowsum, -1).flatten() 60 | r_inv[np.isinf(r_inv)] = 0. 
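# the reset above handles all-zero rows: their row sum is 0, so power(-1) yields inf, and zeroing r_inv leaves such rows unnormalized instead of propagating inf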
61 | r_mat_inv = sp.diags(r_inv) 62 | mx = r_mat_inv.dot(mx) 63 | return mx 64 | 65 | 66 | def accuracy(output, labels): 67 | preds = output.max(1)[1].type_as(labels) 68 | correct = preds.eq(labels).double() 69 | correct = correct.sum() 70 | return correct / len(labels) 71 | 72 | 73 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 74 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 75 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 76 | indices = torch.from_numpy( 77 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 78 | values = torch.from_numpy(sparse_mx.data) 79 | shape = torch.Size(sparse_mx.shape) 80 | return torch.sparse.FloatTensor(indices, values, shape) 81 | -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='pygcn', 5 | version='0.1', 6 | description='Graph Convolutional Networks in PyTorch', 7 | author='Thomas Kipf', 8 | author_email='thomas.kipf@gmail.com', 9 | url='https://tkipf.github.io', 10 | download_url='https://github.com/tkipf/pygcn', 11 | license='MIT', 12 | install_requires=['numpy', 13 | 'torch', 14 | 'scipy' 15 | ], 16 | package_data={'pygcn': ['README.md']}, 17 | packages=find_packages()) -------------------------------------------------------------------------------- /SPACEL/Splane/pygcn_utils.py: -------------------------------------------------------------------------------- 1 | import scipy.sparse as sp 2 | import numpy as np 3 | from scipy.sparse.linalg import eigsh, ArpackNoConvergence 4 | import torch 5 | 6 | def normalize_adj(adj, symmetric=True): 7 | if symmetric: 8 | d = sp.diags(np.power(np.array(adj.sum(1)), -0.5).flatten(), 0) 9 | a_norm = adj.dot(d).transpose().dot(d).tocsr() 10 | else: 11 | d = sp.diags(np.power(np.array(adj.sum(1)), -1).flatten(), 0) 12 | a_norm = d.dot(adj).tocsr() 13 | return a_norm 14 | 15 | def normalized_laplacian(adj, symmetric=True): 16 | adj_normalized = normalize_adj(adj, symmetric) 17 | laplacian = sp.eye(adj.shape[0]) - adj_normalized 18 | return laplacian 19 | 20 | 21 | def rescale_laplacian(laplacian): 22 | try: 23 | print('Calculating largest eigenvalue of normalized graph Laplacian...') 24 | largest_eigval = eigsh(laplacian, 1, which='LM', return_eigenvectors=False)[0] 25 | except ArpackNoConvergence: 26 | print('Eigenvalue calculation did not converge! Using largest_eigval=2 instead.') 27 | largest_eigval = 2 28 | 29 | scaled_laplacian = (2. / largest_eigval) * laplacian - sp.eye(laplacian.shape[0]) 30 | return scaled_laplacian 31 | 32 | 33 | def chebyshev_polynomial(X, k): 34 | """Calculate Chebyshev polynomials up to order k. 
Return a list of sparse matrices.""" 35 | print("Calculating Chebyshev polynomials up to order {}...".format(k)) 36 | 37 | T_k = list() 38 | T_k.append(sp.eye(X.shape[0]).tocsr()) 39 | T_k.append(X) 40 | 41 | def chebyshev_recurrence(T_k_minus_one, T_k_minus_two, X): 42 | X_ = sp.csr_matrix(X, copy=True) 43 | return 2 * X_.dot(T_k_minus_one) - T_k_minus_two 44 | 45 | for i in range(2, k+1): 46 | T_k.append(chebyshev_recurrence(T_k[-1], T_k[-2], X)) 47 | 48 | return T_k 49 | 50 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 51 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 52 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 53 | indices = torch.from_numpy( 54 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 55 | values = torch.from_numpy(sparse_mx.data) 56 | shape = torch.Size(sparse_mx.shape) 57 | return torch.sparse.FloatTensor(indices, values, shape) -------------------------------------------------------------------------------- /SPACEL/Splane/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import anndata 3 | import numpy as np 4 | from scipy.sparse import coo_matrix 5 | from sklearn.metrics import davies_bouldin_score 6 | 7 | def one_hot_encode(labels, unique_labels=None): 8 | if unique_labels is None: 9 | unique_labels = np.unique(labels) 10 | num_classes = len(unique_labels) 11 | label_map = {label: i for i, label in enumerate(unique_labels)} 12 | encoded = np.eye(num_classes)[np.array([label_map[label] for label in labels])] 13 | return encoded, unique_labels 14 | 15 | def add_cell_type_composition(ad, prop_df=None, celltype_anno=None, all_celltypes=None): 16 | """Add cell type composition. 17 | 18 | Adding cell type compostion to AnnData object of spatial transcriptomic data as Splane input. 19 | 20 | Args: 21 | ad: A AnnData object of spatial transcriptomic data as Splane input. 22 | prop_df: A DataFrame of cell type composition used for spot-based spatial transcriptomic data. 23 | celltype_anno: A list containing the cell type annotations for each cell in the single-cell resolution spatial transcriptomic data. This parameter is not used if `prof_ad` is provided. 24 | all_celltypes: A list of all cell types present in all slices. This parameter is used when a single slice does not cover all cell types in the dataset. 
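        Example:
            A minimal sketch for spot-based data (``st_ad`` is one slice's AnnData object and ``prop_df`` is a spots-by-cell-types proportion DataFrame, e.g. produced by Spoint deconvolution; the names are illustrative):

                add_cell_type_composition(st_ad, prop_df=prop_df)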
25 | 26 | Returns: 27 | ``None`` 28 | """ 29 | if prop_df is not None: 30 | if all_celltypes is not None: 31 | prop_df.loc[:,np.setdiff1d(all_celltypes, prop_df.columns)] = 0 32 | ad.obs[prop_df.columns] = prop_df.values 33 | ad.uns['celltypes'] = prop_df.columns 34 | elif celltype_anno is not None: 35 | encoded, unique_celltypes = one_hot_encode(celltype_anno, all_celltypes) 36 | ad.obs[unique_celltypes] = encoded 37 | ad.uns['celltypes'] = unique_celltypes 38 | else: 39 | raise ValueError("prop_df and celltype_anno can not both be None.") 40 | 41 | # From scanpy 42 | def _morans_i_mtx( 43 | g_data: np.ndarray, 44 | g_indices: np.ndarray, 45 | g_indptr: np.ndarray, 46 | X: np.ndarray, 47 | ) -> np.ndarray: 48 | M, N = X.shape 49 | assert N == len(g_indptr) - 1 50 | W = g_data.sum() 51 | out = np.zeros(M, dtype=np.float_) 52 | for k in range(M): 53 | x = X[k, :] 54 | out[k] = _morans_i_vec_W(g_data, g_indices, g_indptr, x, W) 55 | return out 56 | 57 | # From scanpy 58 | def _morans_i_vec_W( 59 | g_data: np.ndarray, 60 | g_indices: np.ndarray, 61 | g_indptr: np.ndarray, 62 | x: np.ndarray, 63 | W: np.float_, 64 | ) -> float: 65 | z = x - x.mean() 66 | z2ss = (z * z).sum() 67 | N = len(x) 68 | inum = 0.0 69 | 70 | for i in range(N): 71 | s = slice(g_indptr[i], g_indptr[i + 1]) 72 | i_indices = g_indices[s] 73 | i_data = g_data[s] 74 | inum += (i_data * z[i_indices]).sum() * z[i] 75 | 76 | return len(x) / W * inum / z2ss 77 | 78 | def fill_low_prop(ad,min_prop): 79 | mtx = ad.X 80 | mtx[mtx < min_prop] = 0 81 | ad.X = mtx 82 | return ad 83 | 84 | def cal_celltype_moran(ad): 85 | moran_vals = _morans_i_mtx( 86 | ad.obsp['spatial_connectivities'].data, 87 | ad.obsp['spatial_connectivities'].indices, 88 | ad.obsp['spatial_connectivities'].indptr, 89 | ad.X.T 90 | ) 91 | ad.uns['moran_vals'] = np.nan_to_num(moran_vals) 92 | 93 | def cal_celltype_weight(ad_list): 94 | print('Calculating cell type weights...') 95 | for ad in ad_list: 96 | cal_celltype_moran(ad) 97 | moran_min=-1 98 | morans = ad_list[0].uns['moran_vals'].copy() 99 | for i, ad in enumerate(ad_list[1:]): 100 | morans += ad.uns['moran_vals'].copy() 101 | morans_mean = morans/len(ad_list) 102 | celltype_weights = morans_mean/morans_mean.sum() 103 | return celltype_weights, morans_mean 104 | 105 | def generate_celltype_ad_list(expr_ad_list,min_prop): 106 | celltype_ad_list = [] 107 | for expr_ad in expr_ad_list: 108 | celltype_ad = anndata.AnnData(expr_ad.obs[[c for c in expr_ad.uns['celltypes']]]) 109 | celltype_ad.obs = expr_ad.obs 110 | celltype_ad.obsm =expr_ad.obsm 111 | celltype_ad.obsp = expr_ad.obsp 112 | celltype_ad = fill_low_prop(celltype_ad,min_prop) 113 | celltype_ad_list.append(celltype_ad) 114 | return celltype_ad_list 115 | 116 | def clustering(Cluster, feature): 117 | predict_labels = Cluster.fit_predict(feature) 118 | db = davies_bouldin_score(feature, predict_labels) 119 | return db 120 | 121 | def split_ad(ad,by): 122 | ad_list = [] 123 | for s in np.unique(ad.obs[by]): 124 | ad_split = ad[ad.obs[by] == s].copy() 125 | ad_list.append(ad_split) 126 | return ad_list 127 | 128 | -------------------------------------------------------------------------------- /SPACEL/Spoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_downsample import * 2 | from .model import * 3 | from .data_utils import * 4 | from .spatial_simulation import * 5 | from .base_model import * 6 | -------------------------------------------------------------------------------- 
/SPACEL/Spoint/base_model.py: -------------------------------------------------------------------------------- 1 | from . import model 2 | from . import data_utils 3 | import numpy as np 4 | import pandas as pd 5 | import scanpy as sc 6 | import scvi 7 | import anndata 8 | from . import metrics 9 | import matplotlib.pyplot as plt 10 | import os 11 | import tempfile 12 | from copy import deepcopy 13 | import logging 14 | import itertools 15 | from functools import partial 16 | from tqdm import tqdm 17 | from time import strftime, localtime 18 | from scipy.sparse import issparse 19 | from sklearn.decomposition import PCA 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.optim as optim 25 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, BatchSampler 26 | from torch.nn.parameter import Parameter 27 | from torch.nn.modules.module import Module 28 | from torch.utils.data.dataloader import default_collate 29 | 30 | def compute_kernel(x, y): 31 | x_size = x.size(0) 32 | y_size = y.size(0) 33 | dim = x.size(1) 34 | 35 | tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim) 36 | tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim) 37 | 38 | kernel = torch.exp(-torch.square(tiled_x - tiled_y).mean(dim=2) / dim) 39 | return kernel 40 | 41 | def compute_mmd(x, y): 42 | x_kernel = compute_kernel(x, x) 43 | y_kernel = compute_kernel(y, y) 44 | xy_kernel = compute_kernel(x, y) 45 | 46 | mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean() 47 | return mmd 48 | 49 | class PredictionModel(nn.Module): 50 | def __init__( 51 | self, 52 | input_dims, 53 | latent_dims, 54 | hidden_dims, 55 | celltype_dims, 56 | dropout 57 | ): 58 | super(PredictionModel, self).__init__() 59 | 60 | self.encoder = nn.Sequential( 61 | nn.Linear(input_dims, hidden_dims), 62 | nn.LeakyReLU(), 63 | nn.LayerNorm(hidden_dims), 64 | nn.Dropout(dropout), 65 | nn.Linear(hidden_dims, latent_dims), 66 | ) 67 | self.decoder = nn.Sequential( 68 | nn.Linear(celltype_dims, hidden_dims), 69 | nn.LeakyReLU(), 70 | nn.LayerNorm(hidden_dims), 71 | nn.Linear(hidden_dims, hidden_dims), 72 | nn.LeakyReLU(), 73 | nn.LayerNorm(hidden_dims), 74 | nn.Linear(hidden_dims, hidden_dims), 75 | nn.LeakyReLU(), 76 | nn.LayerNorm(hidden_dims), 77 | nn.Linear(hidden_dims, input_dims) 78 | ) 79 | self.pred = nn.Sequential( 80 | nn.Linear(latent_dims, hidden_dims), 81 | nn.LeakyReLU(), 82 | nn.LayerNorm(hidden_dims), 83 | nn.Dropout(dropout), 84 | nn.Linear(hidden_dims, celltype_dims), 85 | nn.Softmax(dim=1), 86 | ) 87 | 88 | nn.init.kaiming_normal_(self.encoder[0].weight) 89 | nn.init.kaiming_normal_(self.encoder[4].weight) 90 | nn.init.kaiming_normal_(self.decoder[0].weight) 91 | nn.init.kaiming_normal_(self.decoder[3].weight) 92 | nn.init.kaiming_normal_(self.decoder[6].weight) 93 | nn.init.xavier_uniform_(self.decoder[-1].weight) 94 | nn.init.kaiming_normal_(self.pred[0].weight) 95 | nn.init.xavier_uniform_(self.pred[4].weight) 96 | 97 | def forward(self, x): 98 | z = self.encoder(x) 99 | pred = self.pred(z) 100 | decoded = self.decoder(pred) 101 | return z, pred, decoded 102 | 103 | class SpointModel(): 104 | def __init__( 105 | self, 106 | st_ad, 107 | sm_ad, 108 | clusters, 109 | used_genes, 110 | spot_names, 111 | use_rep, 112 | st_batch_key=None, 113 | scvi_layers=2, 114 | scvi_latent=64, 115 | scvi_gene_likelihood='zinb', 116 | scvi_dispersion='gene-batch', 117 | latent_dims=32, 118 | hidden_dims=512, 119 | infer_losses=['kl','cos'], 120 | l1=0.01, 121 | l2=0.01, 122 | 
sm_lr=3e-4, 123 | st_lr=3e-5, 124 | use_gpu=None, 125 | seed=42 126 | ): 127 | if ((use_gpu is None) or (use_gpu is True)) and (torch.cuda.is_available()): 128 | self.device = 'cuda' 129 | else: 130 | self.device = 'cpu' 131 | self.use_gpu = use_gpu 132 | 133 | self.st_ad = st_ad 134 | self.sm_ad = sm_ad 135 | self.scvi_dims=64 136 | self.spot_names = spot_names 137 | self.used_genes = used_genes 138 | self.clusters = clusters 139 | self.st_batch_key = st_batch_key 140 | self.scvi_layers = scvi_layers 141 | self.scvi_latent = scvi_latent 142 | self.scvi_gene_likelihood = scvi_gene_likelihood 143 | self.scvi_dispersion = scvi_dispersion 144 | self.kl_infer_loss_func = partial(self.kl_divergence, dim=1) 145 | self.kl_rec_loss_func = partial(self.kl_divergence, dim=1) 146 | self.cosine_infer_loss_func = partial(F.cosine_similarity, dim=1) 147 | self.cosine_rec_loss_func = partial(F.cosine_similarity, dim=1) 148 | self.rmse_loss_func = self.rmse 149 | self.infer_losses = infer_losses 150 | self.mmd_loss = compute_mmd 151 | self.l1 = l1 152 | self.l2 = l2 153 | self.use_rep = use_rep 154 | if use_rep == 'scvi': 155 | self.feature_dims = scvi_latent 156 | elif use_rep == 'X': 157 | self.feature_dims = st_ad.shape[1] 158 | elif use_rep == 'pca': 159 | self.feature_dims = 50 160 | else: 161 | raise ValueError('use_rep must be one of scvi, pca and X.') 162 | self.latent_dims = latent_dims 163 | self.hidden_dims = hidden_dims 164 | self.sm_lr = sm_lr 165 | self.st_lr = st_lr 166 | self.init_model() 167 | self.st_data = None 168 | self.sm_data = None 169 | self.sm_labels = None 170 | self.best_path = None 171 | self.history = pd.DataFrame(columns = ['sm_train_rec_loss','sm_train_infer_loss','sm_test_rec_loss','sm_test_infer_loss','st_train_rec_loss','st_test_rec_loss','st_train_mmd_loss','st_test_mmd_loss','is_best']) 172 | self.batch_size = None 173 | self.seed = seed 174 | 175 | @staticmethod 176 | def rmse(y_true, y_pred): 177 | mse = F.mse_loss(y_pred, y_true) 178 | rmse = torch.sqrt(mse) 179 | return rmse 180 | 181 | @staticmethod 182 | def kl_divergence(y_true, y_pred, dim=0): 183 | y_pred = torch.clip(y_pred, torch.finfo(torch.float32).eps) 184 | y_true = y_true.to(y_pred.dtype) 185 | y_true = torch.nan_to_num(torch.div(y_true, y_true.sum(dim, keepdims=True)),0) 186 | y_pred = torch.nan_to_num(torch.div(y_pred, y_pred.sum(dim, keepdims=True)),0) 187 | y_true = torch.clip(y_true, torch.finfo(torch.float32).eps, 1) 188 | y_pred = torch.clip(y_pred, torch.finfo(torch.float32).eps, 1) 189 | return torch.mul(y_true, torch.log(torch.nan_to_num(torch.div(y_true, y_pred)))).mean(dim) 190 | 191 | def init_model(self): 192 | self.model = PredictionModel(self.feature_dims,self.latent_dims,self.hidden_dims,len(self.clusters),0.8).to(self.device) 193 | self.sm_optimizer = optim.Adam(list(self.model.encoder.parameters())+list(self.model.pred.parameters()),lr=self.sm_lr) 194 | self.st_optimizer = optim.Adam(list(self.model.encoder.parameters())+list(self.model.decoder.parameters()),lr=self.st_lr) 195 | 196 | def get_scvi_latent( 197 | self, 198 | n_layers=None, 199 | n_latent=None, 200 | gene_likelihood=None, 201 | dispersion=None, 202 | max_epochs=100, 203 | early_stopping=True, 204 | batch_size=4096, 205 | ): 206 | if self.st_batch_key is not None: 207 | if 'simulated' in self.st_ad.obs[self.st_batch_key]: 208 | raise ValueError(f'obs[{self.st_batch_key}] cannot include "real".') 209 | self.st_ad.obs["batch"] = self.st_ad.obs[self.st_batch_key].astype(str) 210 | self.sm_ad.obs["batch"] = 'simulated' 211 | 
else: 212 | self.st_ad.obs["batch"] = 'real' 213 | self.sm_ad.obs["batch"] = 'simulated' 214 | 215 | adata = sc.concat([self.st_ad,self.sm_ad]) 216 | adata.layers["counts"] = adata.X.copy() 217 | 218 | scvi.model.SCVI.setup_anndata( 219 | adata, 220 | layer="counts", 221 | batch_key="batch" 222 | ) 223 | if n_layers is None: 224 | n_layers = self.scvi_layers 225 | if n_latent is None: 226 | n_latent = self.scvi_latent 227 | if gene_likelihood is None: 228 | gene_likelihood = self.scvi_gene_likelihood 229 | if dispersion is None: 230 | dispersion = self.scvi_dispersion 231 | vae = scvi.model.SCVI(adata, n_layers=n_layers, n_latent=n_latent, gene_likelihood=gene_likelihood,dispersion=dispersion) 232 | vae.train(max_epochs=max_epochs,early_stopping=early_stopping,batch_size=batch_size,use_gpu=self.use_gpu) 233 | adata.obsm["X_scVI"] = vae.get_latent_representation() 234 | 235 | st_scvi_ad = anndata.AnnData(adata[adata.obs['batch'] != 'simulated'].obsm["X_scVI"]) 236 | sm_scvi_ad = anndata.AnnData(adata[adata.obs['batch'] == 'simulated'].obsm["X_scVI"]) 237 | 238 | st_scvi_ad.obs = self.st_ad.obs 239 | st_scvi_ad.obsm = self.st_ad.obsm 240 | 241 | sm_scvi_ad.obs = self.sm_ad.obs 242 | sm_scvi_ad.obsm = self.sm_ad.obsm 243 | 244 | sm_scvi_ad = data_utils.check_data_type(sm_scvi_ad) 245 | st_scvi_ad = data_utils.check_data_type(st_scvi_ad) 246 | 247 | self.sm_data = sm_scvi_ad.X 248 | self.sm_labels = sm_scvi_ad.obsm['label'].values 249 | self.st_data = st_scvi_ad.X 250 | 251 | return sm_scvi_ad,st_scvi_ad 252 | 253 | 254 | def build_dataset(self, batch_size, device=None): 255 | if device is None: 256 | device = self.device 257 | x_train,y_train,x_test,y_test = data_utils.split_shuffle_data(np.array(self.sm_data,dtype=np.float32),np.array(self.sm_labels,dtype=np.float32)) 258 | 259 | x_train = torch.tensor(x_train).to(device) 260 | y_train = torch.tensor(y_train).to(device) 261 | x_test = torch.tensor(x_test).to(device) 262 | y_test = torch.tensor(y_test).to(device) 263 | st_data = torch.tensor(self.st_data).to(device) 264 | 265 | self.sm_train_ds = TensorDataset(x_train, y_train) 266 | self.sm_test_ds = TensorDataset(x_test,y_test) 267 | self.st_ds = TensorDataset(st_data) 268 | 269 | self.sm_train_batch_size = min(len(self.sm_train_ds), batch_size) 270 | self.sm_test_batch_size = min(len(self.sm_test_ds), batch_size) 271 | self.st_batch_size = min(len(self.st_ds), batch_size) 272 | 273 | g = torch.Generator() 274 | g.manual_seed(self.seed) 275 | self.sm_train_sampler = BatchSampler(RandomSampler(self.sm_train_ds, generator=g), batch_size=self.sm_train_batch_size, drop_last=True) 276 | self.sm_test_sampler = BatchSampler(RandomSampler(self.sm_test_ds, generator=g), batch_size=self.sm_test_batch_size, drop_last=True) 277 | self.st_sampler = BatchSampler(RandomSampler(self.st_ds, generator=g), batch_size=self.st_batch_size, drop_last=True) 278 | 279 | def train_st(self, sm_data, st_data, rec_w=1, m_w=1): 280 | self.model.train() 281 | self.st_optimizer.zero_grad() 282 | sm_latent, sm_predictions, sm_rec_data = self.model(sm_data) 283 | st_latent, _, st_rec_data = self.model(st_data) 284 | sm_rec_loss = self.kl_rec_loss_func(sm_data, sm_rec_data).mean() - self.cosine_rec_loss_func(sm_data, sm_rec_data).mean() 285 | st_rec_loss = self.kl_rec_loss_func(st_data, st_rec_data).mean() - self.cosine_rec_loss_func(st_data, st_rec_data).mean() 286 | mmd_loss = self.mmd_loss(sm_latent, st_latent) 287 | loss = rec_w*sm_rec_loss + rec_w*st_rec_loss + m_w*mmd_loss 288 | loss.backward() 289 | 
self.st_optimizer.step() 290 | return loss, sm_rec_loss, st_rec_loss, mmd_loss 291 | 292 | def train_sm(self, sm_data, sm_labels, infer_w=1): 293 | self.model.train() 294 | self.sm_optimizer.zero_grad() 295 | sm_latent, sm_predictions, sm_rec_data = self.model(sm_data) 296 | infer_loss = 0 297 | for loss in self.infer_losses: 298 | if loss == 'kl': 299 | infer_loss += self.kl_infer_loss_func(sm_labels, sm_predictions).mean() 300 | elif loss == 'cos': 301 | infer_loss -= self.cosine_infer_loss_func(sm_labels, sm_predictions).mean() 302 | elif loss == 'rmse': 303 | infer_loss += self.rmse_loss_func(sm_labels, sm_predictions) 304 | loss = infer_w*infer_loss 305 | loss.backward() 306 | self.sm_optimizer.step() 307 | return loss, infer_loss 308 | 309 | def test_st(self, sm_data, st_data, rec_w=1, m_w=1): 310 | self.model.eval() 311 | sm_latent, sm_predictions, sm_rec_data = self.model(sm_data) 312 | st_latent, _, st_rec_data = self.model(st_data) 313 | sm_rec_loss = self.kl_rec_loss_func(sm_data, sm_rec_data).mean() - self.cosine_rec_loss_func(sm_data, sm_rec_data).mean() 314 | st_rec_loss = self.kl_rec_loss_func(st_data, st_rec_data).mean() - self.cosine_rec_loss_func(st_data, st_rec_data).mean() 315 | mmd_loss = self.mmd_loss(sm_latent, st_latent) 316 | loss = rec_w*sm_rec_loss + rec_w*st_rec_loss + m_w*mmd_loss 317 | return loss, sm_rec_loss, st_rec_loss, mmd_loss 318 | 319 | def test_sm(self, sm_data, sm_labels, infer_w=1): 320 | self.model.eval() 321 | sm_latent, sm_predictions, sm_rec_data = self.model(sm_data) 322 | infer_loss = 0 323 | for loss in self.infer_losses: 324 | if loss == 'kl': 325 | infer_loss += self.kl_infer_loss_func(sm_labels, sm_predictions).mean() 326 | elif loss == 'cos': 327 | infer_loss -= self.cosine_infer_loss_func(sm_labels, sm_predictions).mean() 328 | elif loss == 'rmse': 329 | infer_loss += self.rmse_loss_func(sm_labels, sm_predictions) 330 | loss = infer_w*infer_loss 331 | return loss, infer_loss 332 | 333 | def train_model_by_step( 334 | self, 335 | max_steps=5000, 336 | save_mode='all', 337 | save_path=None, 338 | prefix=None, 339 | sm_step=10, 340 | st_step=10, 341 | test_step_gap=1, 342 | convergence=0.001, 343 | early_stop=True, 344 | early_stop_max=2000, 345 | sm_lr=None, 346 | st_lr=None, 347 | rec_w=1, 348 | infer_w=1, 349 | m_w=1, 350 | ): 351 | if len(self.history) > 0: 352 | best_ind = np.where(self.history['is_best'] == 'True')[0][-1] 353 | best_loss = self.history['sm_test_infer_loss'][best_ind] 354 | best_rec_loss = self.history['st_test_rec_loss'][best_ind] 355 | else: 356 | best_loss = np.inf 357 | best_rec_loss = np.inf 358 | early_stop_count = 0 359 | if sm_lr is not None: 360 | for g in self.sm_optimizer.param_groups: 361 | g['lr'] = sm_lr 362 | if st_lr is not None: 363 | for g in self.st_optimizer.param_groups: 364 | g['lr'] = st_lr 365 | 366 | pbar = tqdm(range(max_steps)) 367 | sm_trainr_iter = itertools.cycle(self.sm_train_sampler) 368 | sm_test_iter = itertools.cycle(self.sm_test_sampler) 369 | st_iter = itertools.cycle(self.st_sampler) 370 | sm_train_shuffle_step = max(int(len(self.sm_train_ds)/(self.sm_train_batch_size*sm_step)),1) 371 | sm_test_shuffle_step = max(int(len(self.sm_test_ds)/(self.sm_test_batch_size*sm_step)),1) 372 | st_shuffle_step = max(int(len(self.st_ds)/(self.st_batch_size*st_step)),1) 373 | for step in pbar: 374 | if step % sm_train_shuffle_step == 0: 375 | sm_train_iter = itertools.cycle(self.sm_train_sampler) 376 | if step % sm_test_shuffle_step == 0: 377 | sm_test_iter = itertools.cycle(self.sm_test_sampler) 
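# rebuilding these cycled iterators at the shuffle-step intervals re-runs the RandomSampler after a full pass over the data, giving a fresh shuffle; itertools.cycle alone would keep repeating its cached first-pass order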
378 | if step % st_shuffle_step == 0: 379 | st_iter = itertools.cycle(self.st_sampler) 380 | 381 | st_exp = self.st_ds[next(st_iter)][0] 382 | sm_exp, sm_proportion = self.sm_train_ds[next(sm_train_iter)] 383 | for i in range(st_step): 384 | st_train_total_loss, sm_train_rec_loss, st_train_rec_loss, st_train_mmd_loss = self.train_st(sm_exp, st_exp, rec_w=rec_w, m_w=m_w) 385 | for i in range(sm_step): 386 | sm_train_total_loss, sm_train_infer_loss = self.train_sm(sm_exp, sm_proportion, infer_w=infer_w) 387 | 388 | if step % test_step_gap == 0: 389 | sm_test_exp, sm_test_proportion = self.sm_test_ds[next(sm_test_iter)] 390 | st_test_total_loss, sm_test_rec_loss, st_test_rec_loss, st_test_mmd_loss = self.test_st(sm_test_exp, st_exp, rec_w=rec_w, m_w=m_w) 391 | sm_test_total_loss, sm_test_infer_loss = self.test_sm(sm_test_exp, sm_test_proportion, infer_w=infer_w) 392 | 393 | current_infer_loss = sm_test_infer_loss.item() 394 | 395 | best_flag='False' 396 | if best_loss - current_infer_loss > convergence: 397 | if best_loss > current_infer_loss: 398 | best_loss = current_infer_loss 399 | best_flag='True' 400 | # print('### Update best model') 401 | early_stop_count = 0 402 | old_best_path = self.best_path 403 | if prefix is not None: 404 | self.best_path = os.path.join(save_path,prefix+'_'+f'celleagle_weights_step{step}.h5') 405 | else: 406 | self.best_path = os.path.join(save_path,f'celleagle_weights_step{step}.h5') 407 | if save_mode == 'best': 408 | if old_best_path is not None: 409 | if os.path.exists(old_best_path): 410 | os.remove(old_best_path) 411 | torch.save(self.model.state_dict(), self.best_path) 412 | else: 413 | early_stop_count += 1 414 | 415 | if save_mode == 'all': 416 | if prefix is not None: 417 | self.best_path = os.path.join(save_path,prefix+'_'+f'celleagle_weights_step{step}.h5') 418 | else: 419 | self.best_path = os.path.join(save_path,f'celleagle_weights_step{step}.h5') 420 | torch.save(self.model.state_dict(), self.best_path) 421 | 422 | self.history = pd.concat([ 423 | self.history, 424 | pd.DataFrame({ 425 | 'sm_train_infer_loss':sm_train_infer_loss.item(), 426 | 'sm_train_rec_loss':sm_train_rec_loss.item(), 427 | 'sm_test_rec_loss':sm_test_rec_loss.item(), 428 | 'sm_test_infer_loss':sm_test_infer_loss.item(), 429 | 'st_train_rec_loss':st_train_rec_loss.item(), 430 | 'st_test_rec_loss':st_test_rec_loss.item(), 431 | 'st_train_mmd_loss':st_train_mmd_loss.item(), 432 | 'st_test_mmd_loss':st_test_mmd_loss.item(), 433 | 'is_best':best_flag 434 | },index=[0]) 435 | ]).reset_index(drop=True) 436 | 437 | pbar.set_description(f"Step {step + 1}: Test inference loss={sm_test_infer_loss.item():.3f}",refresh=True) 438 | 439 | if (early_stop_count > early_stop_max) and early_stop: 440 | print('Stop training because of loss convergence') 441 | break 442 | 443 | def train_model( 444 | self, 445 | max_steps=5000, 446 | save_mode='all', 447 | save_path=None, 448 | prefix=None, 449 | sm_step=10, 450 | st_step=10, 451 | test_step_gap=1, 452 | convergence=0.001, 453 | early_stop=False, 454 | early_stop_max=2000, 455 | sm_lr=None, 456 | st_lr=None, 457 | batch_size=1024, 458 | rec_w=1, 459 | infer_w=1, 460 | m_w=1, 461 | ): 462 | """Training Spoint model. 463 | 464 | Training Spoint model. 465 | 466 | Args: 467 | max_steps: The maximum number of training steps. Training stops once this number of steps is reached. 468 | save_mode: A string that determines how the model is saved. It must be either 'best' or 'all'. 469 | save_path: A string representing the path directory where the model is saved. 
470 | prefix: A string added to the prefix of file name of saved model. 471 | convergence: The threshold of early stop. 472 | early_stop: If True, turn on early stop. 473 | early_stop_max: The max steps of loss difference less than convergence. 474 | sm_lr: Learning rate for simulated data. 475 | st_lr: Learning rate for spatial transcriptomic data. 476 | disc_lr: Learning rate of discriminator. 477 | batch_size: Batch size of the data be feeded in model once. 478 | rec_w: The weight of reconstruction loss. 479 | infer_w: The weig ht of inference loss. 480 | m_w: The weight of MMD loss. 481 | 482 | Returns: 483 | ``None`` 484 | """ 485 | self.init_model() 486 | self.train_model_by_step( 487 | max_steps=max_steps, 488 | save_mode=save_mode, 489 | save_path=save_path, 490 | prefix=prefix, 491 | sm_step=sm_step, 492 | st_step=st_step, 493 | test_step_gap=test_step_gap, 494 | convergence=convergence, 495 | early_stop=early_stop, 496 | early_stop_max=early_stop_max, 497 | sm_lr=sm_lr, 498 | st_lr=st_lr, 499 | rec_w=rec_w, 500 | infer_w=infer_w, 501 | m_w=m_w 502 | ) 503 | 504 | def train( 505 | self, 506 | max_steps=5000, 507 | save_mode='all', 508 | save_path=None, 509 | prefix=None, 510 | sm_step=10, 511 | st_step=10, 512 | test_step_gap=1, 513 | convergence=0.001, 514 | early_stop=False, 515 | early_stop_max=2000, 516 | sm_lr=None, 517 | st_lr=None, 518 | batch_size=1024, 519 | rec_w=1, 520 | infer_w=1, 521 | m_w=1, 522 | scvi_max_epochs=100, 523 | scvi_early_stopping=True, 524 | scvi_batch_size=4096, 525 | ): 526 | """Training Spoint model. 527 | 528 | Obtain latent feature from scVI then feed in Spoint model for training. 529 | 530 | Args: 531 | max_steps: The max step of training. The training process will be stop when achive max step. 532 | save_mode: A string determinates how the model is saved. It must be one of 'best' and 'all'. 533 | save_path: A string representing the path directory where the model is saved. 534 | prefix: A string added to the prefix of file name of saved model. 535 | convergence: The threshold of early stop. 536 | early_stop: If True, turn on early stop. 537 | early_stop_max: The max steps of loss difference less than convergence. 538 | sm_lr: Learning rate for simulated data. 539 | st_lr: Learning rate for spatial transcriptomic data. 540 | batch_size: Batch size of the data be feeded in model once. 541 | rec_w: The weight of reconstruction loss. 542 | infer_w: The weig ht of inference loss. 543 | m_w: The weight of MMD loss. 544 | scvi_max_epochs: The max epoch of scVI. 545 | scvi_batch_size: The batch size of scVI. 
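Note: If ``save_path`` is None, model weights are written to a timestamped ``Spoint_models_*`` directory under the system temporary folder.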
546 | Returns: 547 | ``None`` 548 | """ 549 | if save_path is None: 550 | save_path = os.path.join(tempfile.gettempdir() ,'Spoint_models_'+strftime("%Y%m%d%H%M%S",localtime())) 551 | if not os.path.exists(save_path): 552 | os.makedirs(save_path) 553 | self.get_scvi_latent(max_epochs=scvi_max_epochs, early_stopping=scvi_early_stopping, batch_size=scvi_batch_size) 554 | self.build_dataset(batch_size) 555 | self.train_model( 556 | max_steps=max_steps, 557 | save_mode=save_mode, 558 | save_path=save_path, 559 | prefix=prefix, 560 | sm_step=sm_step, 561 | st_step=st_step, 562 | test_step_gap=test_step_gap, 563 | convergence=convergence, 564 | early_stop=early_stop, 565 | early_stop_max=early_stop_max, 566 | sm_lr=sm_lr, 567 | st_lr=st_lr, 568 | rec_w=rec_w, 569 | infer_w=infer_w, 570 | m_w=m_w 571 | ) 572 | 573 | def eval_model(self,model_path=None,use_best_model=True,batch_size=4096,metric='pcc'): 574 | if metric=='pcc': 575 | metric_name = 'PCC' 576 | func = metrics.pcc 577 | if metric=='spcc': 578 | metric_name = 'SPCC' 579 | func = metrics.spcc 580 | if metric=='mae': 581 | metric_name = 'MAE' 582 | func = metrics.mae 583 | if metric=='js': 584 | metric_name = 'JS' 585 | func = metrics.js 586 | if metric=='rmse': 587 | metric_name = 'RMSE' 588 | func = metrics.rmse 589 | if metric=='ssim': 590 | metric_name = 'SSIM' 591 | func = metrics.ssim 592 | 593 | if model_path is not None: 594 | self.model.load_state_dict(torch.load(model_path)) 595 | elif use_best_model: 596 | self.model.load_state_dict(torch.load(self.best_path)) 597 | self.model.eval() 598 | pre = [] 599 | prop = [] 600 | for exp_batch, prop_batch in self.sm_test_dataloader: 601 | latent_tmp, pre_tmp, _ = self.model(exp_batch) 602 | pre.extend(pre_tmp.cpu().detach().numpy()) 603 | prop.extend(prop_batch.cpu().detach().numpy()) 604 | pre = np.array(pre) 605 | prop = np.array(prop) 606 | metric_list = [] 607 | for i,c in enumerate(self.clusters): 608 | metric_list.append(func(pre[:,i],prop[:,i])) 609 | print('### Evaluate model with simulation data') 610 | for i in range(len(metric_list)): 611 | print(f'{metric_name} of {self.clusters[i]}, {metric_list[i]}') 612 | 613 | def plot_training_history(self,save=None,return_fig=False,show=True,dpi=300): 614 | if len(self.history) > 0: 615 | fig, ax = plt.subplots() 616 | plt.plot(np.arange(len(self.history)), self.history['sm_test_infer_loss'], label='sm_test_infer_loss') 617 | plt.plot(np.arange(len(self.history)), self.history['st_test_rec_loss'], label='st_test_rec_loss') 618 | plt.xlabel('Epochs') 619 | plt.ylabel('Losses') 620 | plt.title('Training history') 621 | plt.legend() 622 | if save is not None: 623 | plt.savefig(save,bbox_inches='tight',dpi=dpi) 624 | if show: 625 | plt.show() 626 | plt.close() 627 | if return_fig: 628 | return fig 629 | else: 630 | print('History is empty, train the model first') 631 | 632 | 633 | def deconv_spatial(self,st_data=None,min_prop=0.01,model_path=None,use_best_model=True,add_obs=True,add_uns=True):
 634 | """Deconvolute spatial transcriptomic data. 635 | 636 | Use the well-trained Spoint model to predict the cell type proportion of spots in spatial transcriptomic data. 637 | 638 | Args: 639 | st_data: An AnnData object of spatial transcriptomic data to be deconvoluted. 640 | min_prop: A threshold value below which the predicted value will be set to 0. 641 | model_path: A string representing the path of the saved model file. 642 | use_best_model: If True, the model with the lowest loss will be used, otherwise, the last trained model will be used. 
643 | add_obs: If True, the predicted results will be writen to the obs of input AnnData object of spatial transcriptomic data. 644 | add_uns: If True, the name of predicted cell types will be writen to the uns of input AnnData object of spatial transcriptomic data 645 | 646 | Returns: 647 | A ``DataFrame`` contained deconvoluted results. Each row representing a spot, and each column representing a cell type. 648 | """ 649 | if st_data is None: 650 | st_data = self.st_data 651 | st_data = torch.tensor(st_data).to(self.device) 652 | if model_path is not None: 653 | self.model.load_state_dict(torch.load(model_path)) 654 | elif use_best_model: 655 | self.model.load_state_dict(torch.load(self.best_path)) 656 | self.model.to(self.device) 657 | self.model.eval() 658 | latent, pre, _ = self.model(st_data) 659 | pre = pre.cpu().detach().numpy() 660 | pre[pre < min_prop] = 0 661 | pre = pd.DataFrame(pre,columns=self.clusters,index=self.st_ad.obs_names) 662 | self.st_ad.obs[pre.columns] = pre.values 663 | self.st_ad.uns['celltypes'] = list(pre.columns) 664 | return pre -------------------------------------------------------------------------------- /SPACEL/Spoint/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | 4 | @numba.njit 5 | def random_dropout(cell_expr,max_rate): 6 | non_zero_mask = np.where(cell_expr!=0)[0] 7 | zero_mask = np.random.choice(non_zero_mask,int(len(non_zero_mask)*np.float32(np.random.uniform(0,max_rate)))) 8 | cell_expr[zero_mask] = 0 9 | return cell_expr 10 | 11 | @numba.njit 12 | def random_scale(cell_expr,max_val): 13 | scale_factor = np.float32(1+np.random.uniform(-max_val,max_val)) 14 | cell_expr = cell_expr*scale_factor 15 | return cell_expr 16 | 17 | @numba.njit 18 | def random_shift(cell_expr,kth): 19 | shift_value = np.random.choice(np.array([1,0,-1]),1)[0]*np.unique(cell_expr)[int(np.random.uniform(0,kth)*len(np.unique(cell_expr)))] 20 | cell_expr[cell_expr != 0] = cell_expr[cell_expr != 0]+shift_value 21 | cell_expr[cell_expr < 0] = 0 22 | return cell_expr 23 | 24 | @numba.njit(parallel=True) 25 | def random_augment(mtx,max_rate=0.8,max_val=0.8,kth=0.2): 26 | for i in numba.prange(mtx.shape[0]): 27 | random_dropout(mtx[i,:],max_rate=max_rate) 28 | random_scale(mtx[i,:],max_val=max_val) 29 | random_shift(mtx[i,:],kth=kth) 30 | return mtx 31 | 32 | @numba.njit 33 | def random_augmentation_cell(cell_expr,max_rate=0.8,max_val=0.8,kth=0.2): 34 | cell_expr = random_dropout(cell_expr,max_rate=max_rate) 35 | cell_expr = random_scale(cell_expr,max_val=max_val) 36 | cell_expr = random_shift(cell_expr,kth=kth) 37 | return cell_expr -------------------------------------------------------------------------------- /SPACEL/Spoint/data_downsample.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | import multiprocessing as mp 4 | from functools import partial 5 | import random 6 | 7 | # Cite from https://github.com/numba/numba-examples 8 | @numba.jit(nopython=True, parallel=True) 9 | def get_bin_edges(a, bins): 10 | bin_edges = np.zeros((bins+1,), dtype=np.float32) 11 | a_min = a.min() 12 | a_max = a.max() 13 | delta = (a_max - a_min) / bins 14 | for i in numba.prange(bin_edges.shape[0]): 15 | bin_edges[i] = a_min + i * delta 16 | 17 | bin_edges[-1] = a_max # Avoid roundoff error on last point 18 | return bin_edges 19 | 20 | # Modified from https://github.com/numba/numba-examples 21 | @numba.jit(nopython=True, 
parallel=False) 22 | def compute_bin(x, bin_edges): 23 | # assuming uniform bins for now 24 | n = bin_edges.shape[0] - 1 25 | a_max = bin_edges[-1] 26 | # special case to mirror NumPy behavior for last bin 27 | if x == a_max: 28 | return n - 1 # a_max always in last bin 29 | bin = np.searchsorted(bin_edges, x)-1 30 | if bin < 0 or bin >= n: 31 | return None 32 | else: 33 | return bin 34 | 35 | # Modified from https://github.com/numba/numba-examples 36 | @numba.jit(nopython=True, parallel=False) 37 | def numba_histogram(a, bin_edges): 38 | hist = np.zeros((bin_edges.shape[0] - 1,), dtype=np.intp) 39 | for x in a.flat: 40 | bin = compute_bin(x, bin_edges) 41 | if bin is not None: 42 | hist[int(bin)] += 1 43 | return hist, bin_edges 44 | 45 | 46 | # Modified from https://rdrr.io/bioc/scRecover/src/R/countsSampling.R 47 | # Downsample cell reads to a fraction 48 | @numba.jit(nopython=True, parallel=True) 49 | def downsample_cell(cell_counts,fraction): 50 | n = np.floor(np.sum(cell_counts) * fraction) 51 | readsGet = np.sort(np.random.choice(np.arange(np.sum(cell_counts)), np.intp(n), replace=False)) 52 | cumCounts = np.concatenate((np.array([0]),np.cumsum(cell_counts))) 53 | counts_new = numba_histogram(readsGet,cumCounts)[0] 54 | counts_new = counts_new.astype(np.float32) 55 | return counts_new 56 | 57 | def downsample_cell_python(cell_counts,fraction): 58 | n = np.floor(np.sum(cell_counts) * fraction) 59 | readsGet = np.sort(random.sample(range(np.intp(np.sum(cell_counts))), np.intp(n))) 60 | cumCounts = np.concatenate((np.array([0]),np.cumsum(cell_counts))) 61 | counts_new = numba_histogram(readsGet,cumCounts)[0] 62 | counts_new = counts_new.astype(np.float32) 63 | return counts_new 64 | 65 | @numba.jit(nopython=True, parallel=True) 66 | def downsample_per_cell(cell_counts,new_cell_counts): 67 | n = new_cell_counts 68 | if n < np.sum(cell_counts): 69 | readsGet = np.sort(np.random.choice(np.arange(np.sum(cell_counts)), np.intp(n), replace=False)) 70 | cumCounts = np.concatenate((np.array([0]),np.cumsum(cell_counts))) 71 | counts_new = numba_histogram(readsGet,cumCounts)[0] 72 | counts_new = counts_new.astype(np.float32) 73 | return counts_new 74 | else: 75 | return cell_counts.astype(np.float32) 76 | 77 | def downsample_per_cell_python(param): 78 | cell_counts,new_cell_counts = param[0],param[1] 79 | n = new_cell_counts 80 | if n < np.sum(cell_counts): 81 | readsGet = np.sort(random.sample(range(np.intp(np.sum(cell_counts))), np.intp(n))) 82 | cumCounts = np.concatenate((np.array([0]),np.cumsum(cell_counts))) 83 | counts_new = numba_histogram(readsGet,cumCounts)[0] 84 | counts_new = counts_new.astype(np.float32) 85 | return counts_new 86 | else: 87 | return cell_counts.astype(np.float32) 88 | 89 | def downsample_matrix_by_cell(matrix,per_cell_counts,n_cpus=None,numba_end=True): 90 | if numba_end: 91 | downsample_func = downsample_per_cell 92 | else: 93 | downsample_func = downsample_per_cell_python 94 | if n_cpus is not None: 95 | with mp.Pool(n_cpus) as p: 96 | matrix_ds = p.map(downsample_func, zip(matrix,per_cell_counts)) 97 | else: 98 | matrix_ds = [downsample_func(c,per_cell_counts[i]) for i,c in enumerate(matrix)] 99 | return np.array(matrix_ds) 100 | 101 | # ps. slow speed. 
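# Downsample the whole matrix to a global fraction of its total counts by flattening it into a single pool of reads and redistributing them (hence the slow speed noted above).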
102 | def downsample_matrix_total(matrix,fraction): 103 | matrix_flat = matrix.reshape(-1) 104 | matrix_flat_ds = downsample_cell(matrix_flat,fraction) 105 | matrix_ds = matrix_flat_ds.reshape(matrix.shape) 106 | return matrix_ds 107 | 108 | -------------------------------------------------------------------------------- /SPACEL/Spoint/data_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import scanpy as sc 4 | import anndata 5 | from scipy.sparse import issparse,csr_matrix 6 | from sklearn.preprocessing import normalize 7 | from . import data_downsample 8 | from . import data_augmentation 9 | from . import spatial_simulation 10 | import logging 11 | # logging.basicConfig(level=print, 12 | # format='%(asctime)s %(levelname)s %(message)s', 13 | # datefmt='%m-%d %H:%M') 14 | # logging.getLogger().setLevel(print) 15 | 16 | from rpy2.robjects.packages import importr 17 | import rpy2.robjects as robjects 18 | 19 | def normalize_adata(ad,target_sum=None): 20 | ad_norm = sc.pp.normalize_total(ad,inplace=False,target_sum=1e4) 21 | ad_norm = sc.pp.log1p(ad_norm['X']) 22 | # ad_norm = sc.pp.scale(ad_norm) 23 | # ad_norm = normalize(ad_norm,axis=1) 24 | ad.layers['norm'] = ad_norm 25 | return ad 26 | 27 | def normalize_mtx(mtx,target_sum): 28 | mtx = mtx[mtx.sum(1)!=0,:] 29 | mtx = np.nan_to_num(np.log1p((mtx.T*target_sum/mtx.sum(axis=1)).T)) 30 | mtx = (mtx-mtx.min(axis=1,keepdims=True))/(mtx.max(axis=1,keepdims=True)-mtx.min(axis=1,keepdims=True)) 31 | mtx = normalize(mtx,axis=1) 32 | return mtx 33 | 34 | # Find differentially expressed marker genes for each cell type of the single-cell reference 35 | def find_sc_markers(sc_ad, celltype_key, layer='norm', deg_method=None, log2fc_min=0.5, pval_cutoff=0.01, n_top_markers=200, pct_diff=None, pct_min=0.1): 36 | print('### Finding marker genes...') 37 | 38 | # Filter out cell types that contain only one cell. 
39 | filtered_celltypes = list(sc_ad.obs[celltype_key].value_counts()[(sc_ad.obs[celltype_key].value_counts() == 1).values].index) 40 | if len(filtered_celltypes) > 0: 41 | sc_ad = sc_ad[sc_ad.obs[~(sc_ad.obs[celltype_key].isin(filtered_celltypes))].index,:].copy() 42 | print(f'### Filter cluster contain only one sample: {filtered_celltypes}') 43 | 44 | sc.tl.rank_genes_groups(sc_ad, groupby=celltype_key, pts=True, layer=layer, use_raw=False, method=deg_method) 45 | marker_genes_dfs = [] 46 | for c in np.unique(sc_ad.obs[celltype_key]): 47 | tmp_marker_gene_df = sc.get.rank_genes_groups_df(sc_ad, group=c, pval_cutoff=pval_cutoff, log2fc_min=log2fc_min) 48 | if (tmp_marker_gene_df.empty is not True): 49 | tmp_marker_gene_df.index = tmp_marker_gene_df.names 50 | tmp_marker_gene_df.loc[:,celltype_key] = c 51 | if pct_diff is not None: 52 | pct_diff_genes = sc_ad.var_names[np.where((sc_ad.uns['rank_genes_groups']['pts'][c]-sc_ad.uns['rank_genes_groups']['pts_rest'][c]) > pct_diff)] 53 | tmp_marker_gene_df = tmp_marker_gene_df.loc[np.intersect1d(pct_diff_genes, tmp_marker_gene_df.index),:] 54 | if pct_min is not None: 55 | # pct_min_genes = sc_ad.var_names[np.where((sc_ad.uns['rank_genes_groups']['pts'][c]) > pct_min)] 56 | tmp_marker_gene_df = tmp_marker_gene_df[tmp_marker_gene_df['pct_nz_group'] > pct_min] 57 | if n_top_markers is not None: 58 | tmp_marker_gene_df = tmp_marker_gene_df.sort_values('logfoldchanges',ascending=False) 59 | tmp_marker_gene_df = tmp_marker_gene_df.iloc[:n_top_markers,:] 60 | marker_genes_dfs.append(tmp_marker_gene_df) 61 | marker_gene_df = pd.concat(marker_genes_dfs,axis=0) 62 | print(marker_gene_df[celltype_key].value_counts()) 63 | all_marker_genes = np.unique(marker_gene_df.names) 64 | return all_marker_genes 65 | 66 | # 计算空间HVG 67 | def find_st_hvg(st_ad,n_top_hvg=None): 68 | print('### Finding HVG in spatial...') 69 | sc.pp.highly_variable_genes(st_ad,n_top_genes=n_top_hvg,flavor='seurat_v3') 70 | return st_ad.var_names[st_ad.var['highly_variable'] == True] 71 | 72 | # 73 | def filter_model_genes( 74 | sc_ad, 75 | st_ad, 76 | celltype_key=None, 77 | used_genes=None, 78 | sc_genes=None, 79 | st_genes=None, 80 | layer='norm', 81 | deg_method=None, 82 | log2fc_min=0.5, 83 | pval_cutoff=0.01, 84 | n_top_markers:int=200, 85 | n_top_hvg=None, 86 | pct_diff=None, 87 | pct_min=0.1, 88 | ): 89 | overlaped_genes = np.intersect1d(sc_ad.var_names,st_ad.var_names) 90 | sc_ad = sc_ad[:,overlaped_genes].copy() 91 | st_ad = st_ad[:,overlaped_genes].copy() 92 | if used_genes is None: 93 | if st_genes is None: 94 | if n_top_hvg is None: 95 | st_genes = st_ad.var_names 96 | else: 97 | st_genes = find_st_hvg(st_ad, n_top_hvg) 98 | if sc_genes is None: 99 | sc_ad = sc_ad[:, st_genes].copy() 100 | sc_genes = find_sc_markers(sc_ad, celltype_key, layer, deg_method, log2fc_min, pval_cutoff, n_top_markers, pct_diff, pct_min) 101 | used_genes = np.intersect1d(sc_genes,st_genes) 102 | sc_ad = sc_ad[:,used_genes].copy() 103 | st_ad = st_ad[:,used_genes].copy() 104 | sc.pp.filter_cells(sc_ad,min_genes=1) 105 | sc.pp.filter_cells(st_ad,min_genes=1) 106 | print(f'### Used gene numbers: {len(used_genes)}') 107 | return sc_ad, st_ad 108 | 109 | def check_data_type(ad): 110 | if issparse(ad.X): 111 | ad.X = ad.X.toarray() 112 | if ad.X.dtype != np.float32: 113 | ad.X =ad.X.astype(np.float32) 114 | return ad 115 | 116 | def 
generate_sm_adata(sc_ad,num_sample,celltype_key,n_threads,cell_counts,clusters_mean,cells_mean,cells_min,cells_max,cell_sample_counts,cluster_sample_counts,ncell_sample_list,cluster_sample_list): 117 | sm_data,sm_labels = spatial_simulation.generate_simulation_data(sc_ad,num_sample=num_sample,celltype_key=celltype_key,downsample_fraction=None,data_augmentation=False,n_cpus=n_threads,cell_counts=cell_counts,clusters_mean=clusters_mean,cells_mean=cells_mean,cells_min=cells_min,cells_max=cells_max,cell_sample_counts=cell_sample_counts,cluster_sample_counts=cluster_sample_counts,ncell_sample_list=ncell_sample_list,cluster_sample_list=cluster_sample_list) 118 | sm_data_mtx = csr_matrix(sm_data) 119 | sm_ad = anndata.AnnData(sm_data_mtx) 120 | sm_ad.var.index = sc_ad.var_names 121 | sm_labels = (sm_labels.T/sm_labels.sum(axis=1)).T 122 | sm_ad.obsm['label'] = pd.DataFrame(sm_labels,columns=np.array(sc_ad.obs[celltype_key].value_counts().index.values),index=sm_ad.obs_names) 123 | return sm_ad 124 | 125 | def downsample_sm_spot_counts(sm_ad,st_ad,n_threads): 126 | fitdistrplus = importr('fitdistrplus') 127 | lib_sizes = robjects.FloatVector(np.array(st_ad.X.sum(1)).reshape(-1)) 128 | res = fitdistrplus.fitdist(lib_sizes,'lnorm') 129 | loc = res[0][0] 130 | scale = res[0][1] 131 | 132 | sm_mtx_count = sm_ad.X.toarray() 133 | sample_cell_counts = np.random.lognormal(loc,scale,sm_ad.shape[0]) 134 | sm_mtx_count_lb = data_downsample.downsample_matrix_by_cell(sm_mtx_count,sample_cell_counts.astype(np.int64), n_cpus=n_threads, numba_end=False) 135 | sm_ad.X = csr_matrix(sm_mtx_count_lb) 136 | 137 | def split_shuffle_data(X,Y,shuffle=True,proportion=0.8): 138 | if shuffle: 139 | reind = np.random.permutation(len(X)) 140 | X = X[reind] 141 | Y = Y[reind] 142 | X_train = X[:int(len(X)*proportion)] 143 | Y_train = Y[:int(len(Y)*proportion)] 144 | X_test = X[int(len(X)*proportion):] 145 | Y_test = Y[int(len(Y)*proportion):] 146 | return X_train,Y_train,X_test,Y_test 147 | -------------------------------------------------------------------------------- /SPACEL/Spoint/metrics.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import pearsonr, entropy, spearmanr 2 | from scipy.spatial.distance import jensenshannon 3 | from sklearn.metrics import mean_squared_error 4 | import numpy as np 5 | 6 | def pcc(x1,x2): 7 | return pearsonr(x1,x2)[0] 8 | 9 | def spcc(x1,x2): 10 | return spearmanr(x1,x2)[0] 11 | 12 | def rmse(x1,x2): 13 | return mean_squared_error(x1,x2,squared=False) 14 | 15 | def mae(x1,x2): 16 | return np.mean(np.abs(x1-x2)) 17 | 18 | def js(x1,x2): 19 | return jensenshannon(x1,x2) 20 | 21 | def kl(x1,x2): 22 | entropy(x1, x2) 23 | 24 | def ssim(im1,im2,M=1): 25 | im1, im2 = im1/im1.max(), im2/im2.max() 26 | mu1 = im1.mean() 27 | mu2 = im2.mean() 28 | sigma1 = np.sqrt(((im1 - mu1) ** 2).mean()) 29 | sigma2 = np.sqrt(((im2 - mu2) ** 2).mean()) 30 | sigma12 = ((im1 - mu1) * (im2 - mu2)).mean() 31 | k1, k2, L = 0.01, 0.03, M 32 | C1 = (k1*L) ** 2 33 | C2 = (k2*L) ** 2 34 | C3 = C2/2 35 | l12 = (2*mu1*mu2 + C1)/(mu1 ** 2 + mu2 ** 2 + C1) 36 | c12 = (2*sigma1*sigma2 + C2)/(sigma1 ** 2 + sigma2 ** 2 + C2) 37 | s12 = (sigma12 + C3)/(sigma1*sigma2 + C3) 38 | ssim = l12 * c12 * s12 39 | return ssim -------------------------------------------------------------------------------- /SPACEL/Spoint/model.py: -------------------------------------------------------------------------------- 1 | from . import data_utils 2 | from . import base_model 3 | from . 
import data_downsample 4 | from . import data_augmentation 5 | from . import spatial_simulation 6 | import numpy as np 7 | import torch 8 | from scipy.sparse import csr_matrix 9 | import numba 10 | import logging 11 | import random 12 | # logging.basicConfig(level=print, 13 | # format='%(asctime)s %(levelname)s %(message)s', 14 | # datefmt='%m-%d %H:%M') 15 | # logging.getLogger().setLevel(print) 16 | 17 | def init_model( 18 | sc_ad, 19 | st_ad, 20 | celltype_key, 21 | sc_genes=None, 22 | st_genes=None, 23 | used_genes=None, 24 | deg_method:str='wilcoxon', 25 | n_top_markers:int=200, 26 | n_top_hvg:int=None, 27 | log2fc_min=0.5, 28 | pval_cutoff=0.01, 29 | pct_diff=None, 30 | pct_min=0.1, 31 | use_rep='scvi', 32 | st_batch_key=None, 33 | sm_size:int=500000, 34 | cell_counts=None, 35 | clusters_mean=None, 36 | cells_mean=10, 37 | cells_min=1, 38 | cells_max=20, 39 | cell_sample_counts=None, 40 | cluster_sample_counts=None, 41 | ncell_sample_list=None, 42 | cluster_sample_list=None, 43 | scvi_layers=2, 44 | scvi_latent=128, 45 | scvi_gene_likelihood='zinb', 46 | scvi_dispersion='gene-batch', 47 | latent_dims=128, 48 | hidden_dims=512, 49 | infer_losses=['kl','cos'], 50 | n_threads=4, 51 | seed=42, 52 | use_gpu=None 53 | ): 54 | """Initialize Spoint model. 55 | 56 | Given specific data and parameters to initialize Spoint model. 57 | 58 | Args: 59 | sc_ad: An AnnData object representing single cell reference. 60 | st_ad: An AnnData object representing spatial transcriptomic data. 61 | celltype_key: A string representing cell types annotation columns in obs of single cell reference. 62 | sc_genes: A sequence of strings containing genes of single cell reference used in Spoint model. Only used when ``used_genes`` is None. 63 | st_genes: A sequence of strings containing genes of spatial transcriptomic data used in Spoint model. Only used when ``used_genes`` is None. 64 | used_genes: A sequence of strings containing genes used in Spoint model. 65 | deg_method: A string passed to method parameter of scanpy.tl.rank_genes_groups. 66 | n_top_markers: The number of differential expressed genes in each cell type of single cell reference used in Spoint model. 67 | n_top_hvg: The number of highly variable genes of spatial transcriptomic data used in Spoint model. 68 | log2fc_min: The threshold of log2 fold-change used for filtering differential expressed genes of single cell reference. 69 | pval_cutoff: The threshold of p-value used for filtering differential expressed genes of single cell reference. 70 | pct_min: The threshold of precentage of expressed cells used for filtering differential expressed genes of single cell reference. 71 | st_batch_key: A column name in obs of spatial transcriptomic data representing batch groups of spatial transcriptomic data. 72 | sm_size: The number of simulated spots. 73 | hiddem_dims: The number of nodes of hidden layers in Spoint model. 74 | n_threads: The number of cpu core used for parallel. 75 | 76 | Returns: 77 | A ``SpointModel`` object. 
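Example:
    A minimal usage sketch; ``sc_ad``, ``st_ad`` and the ``'celltype'`` column are placeholder inputs standing in for the user's own AnnData objects and annotation key::

        model = init_model(sc_ad, st_ad, celltype_key='celltype')
        model.train()
        proportions = model.deconv_spatial()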
78 | """ 79 | 80 | print('Setting global seed:', seed) 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | 84 | torch.manual_seed(seed) 85 | torch.cuda.manual_seed(seed) 86 | torch.backends.cudnn.deterministic = True 87 | torch.backends.cudnn.benchmark = False 88 | spatial_simulation.numba_set_seed(seed) 89 | numba.set_num_threads(n_threads) 90 | 91 | sc_ad = data_utils.normalize_adata(sc_ad,target_sum=1e4) 92 | st_ad = data_utils.normalize_adata(st_ad,target_sum=1e4) 93 | sc_ad, st_ad = data_utils.filter_model_genes( 94 | sc_ad, 95 | st_ad, 96 | celltype_key=celltype_key, 97 | deg_method=deg_method, 98 | n_top_markers=n_top_markers, 99 | n_top_hvg=n_top_hvg, 100 | used_genes=used_genes, 101 | sc_genes=sc_genes, 102 | st_genes=st_genes, 103 | log2fc_min=log2fc_min, 104 | pval_cutoff=pval_cutoff, 105 | pct_diff=pct_diff, 106 | pct_min=pct_min 107 | ) 108 | sm_ad = data_utils.generate_sm_adata(sc_ad,num_sample=sm_size,celltype_key=celltype_key,n_threads=n_threads,cell_counts=cell_counts,clusters_mean=clusters_mean,cells_mean=cells_mean,cells_min=cells_min,cells_max=cells_max,cell_sample_counts=cell_sample_counts,cluster_sample_counts=cluster_sample_counts,ncell_sample_list=ncell_sample_list,cluster_sample_list=cluster_sample_list) 109 | data_utils.downsample_sm_spot_counts(sm_ad,st_ad,n_threads=n_threads) 110 | 111 | model = base_model.SpointModel( 112 | st_ad, 113 | sm_ad, 114 | clusters = np.array(sm_ad.obsm['label'].columns), 115 | spot_names = np.array(st_ad.obs_names), 116 | used_genes = np.array(st_ad.var_names), 117 | use_rep=use_rep, 118 | st_batch_key=st_batch_key, 119 | scvi_layers=scvi_layers, 120 | scvi_latent=scvi_latent, 121 | scvi_gene_likelihood=scvi_gene_likelihood, 122 | scvi_dispersion=scvi_dispersion, 123 | latent_dims=latent_dims, 124 | hidden_dims=hidden_dims, 125 | infer_losses=infer_losses, 126 | use_gpu=use_gpu, 127 | seed=seed 128 | ) 129 | return model -------------------------------------------------------------------------------- /SPACEL/Spoint/spatial_simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import numba as nb 4 | from numba import jit 5 | import collections 6 | import random 7 | from .data_downsample import downsample_cell,downsample_matrix_by_cell 8 | from .data_augmentation import random_augment,random_augmentation_cell 9 | import logging 10 | # logging.basicConfig(level=print, 11 | # format='%(asctime)s %(levelname)s %(message)s', 12 | # datefmt='%m-%d %H:%M') 13 | # logging.getLogger().setLevel(print) 14 | 15 | # 汇总每个spot的细胞数,统计细胞数的分布 16 | def count_cell_counts(cell_counts): 17 | cell_counts = np.array(cell_counts.values,dtype=int).reshape(-1) 18 | counts_list = np.array(np.histogram(cell_counts,range=[0,np.max(cell_counts)+1],bins=np.max(cell_counts)+1)[0],dtype=int) 19 | counts_index = np.array((np.histogram(cell_counts,range=[0,np.max(cell_counts)+1],bins=np.max(cell_counts)+1)[1][:-1]),dtype=int) 20 | counts_df = pd.DataFrame(counts_list,index=counts_index,columns=['count'],dtype=np.int32) 21 | counts_df = counts_df[(counts_df['count'] != 0) & (counts_df.index != 0)] 22 | count_sum = 0 23 | for i in np.arange(len(counts_df)): 24 | count_sum += counts_df.iloc[i].values 25 | if count_sum > counts_df.values.sum()*0.99: 26 | counts_df_filtered = counts_df.iloc[:i+1,:] 27 | break 28 | return counts_df_filtered 29 | 30 | 31 | @nb.njit 32 | def numba_set_seed(seed): 33 | np.random.seed(seed) 34 | random.seed(seed) 35 | 36 | # 对某个axis调用numpy函数(numba版本) 37 
| @nb.njit 38 | def np_apply_along_axis(func1d, axis, arr): 39 | assert arr.ndim == 2 40 | assert axis in [0, 1] 41 | if axis == 0: 42 | result = np.empty(arr.shape[1], dtype=arr.dtype) 43 | for i in range(len(result)): 44 | result[i] = func1d(arr[:, i]) 45 | else: 46 | result = np.empty(arr.shape[0], dtype=arr.dtype) 47 | for i in range(len(result)): 48 | result[i] = func1d(arr[i, :]) 49 | return result 50 | 51 | # 对某个axis计算均值(numba版本) 52 | @nb.njit 53 | def np_mean(array, axis): 54 | return np_apply_along_axis(np.mean, axis, array) 55 | 56 | # 对某个axis计算加和(numba版本) 57 | @nb.njit 58 | def np_sum(array, axis): 59 | return np_apply_along_axis(np.sum, axis, array) 60 | 61 | # 根据参数采样单细胞数据,生成模拟spot(numba版本) 62 | @jit(nopython=True,parallel=True) 63 | def sample_cell(param_list,cluster_p,clusters,cluster_id,sample_exp,sample_cluster,cell_p_balanced,downsample_fraction=None,data_augmentation=True,max_rate=0.8,max_val=0.8,kth=0.2): 64 | exp = np.empty((len(param_list), sample_exp.shape[1]),dtype=np.float32) 65 | density = np.empty((len(param_list), sample_cluster.shape[1]),dtype=np.float32) 66 | 67 | for i in nb.prange(len(param_list)): 68 | params = param_list[i] 69 | num_cell = params[0] 70 | num_cluster = params[1] 71 | used_clusters = clusters[np.searchsorted(np.cumsum(cluster_p), np.random.rand(num_cluster), side="right")] 72 | cluster_mask = np.array([False]*len(cluster_id)) 73 | for c in used_clusters: 74 | cluster_mask = (cluster_id==c)|(cluster_mask) 75 | # print('cluster_mask',cluster_mask) 76 | # print('used_clusters',used_clusters) 77 | used_cell_ind = np.where(cluster_mask)[0] 78 | used_cell_p = cell_p_balanced[cluster_mask] 79 | used_cell_p = used_cell_p/used_cell_p.sum() 80 | sampled_cells = used_cell_ind[np.searchsorted(np.cumsum(used_cell_p), np.random.rand(num_cell), side="right")] 81 | combined_exp = np_sum(sample_exp[sampled_cells,:],axis=0).astype(np.float32) 82 | if data_augmentation: 83 | combined_exp = random_augmentation_cell(combined_exp,max_rate=max_rate,max_val=max_val,kth=kth) 84 | if downsample_fraction is not None: 85 | combined_exp = downsample_cell(combined_exp, downsample_fraction) 86 | combined_clusters = np_sum(sample_cluster[cluster_id[sampled_cells]],axis=0).astype(np.float32) 87 | exp[i,:] = combined_exp 88 | density[i,:] = combined_clusters 89 | return exp,density 90 | 91 | @jit(nopython=True,parallel=True) 92 | def sample_cell_from_clusters(cluster_sample_list,ncell_sample_list,cluster_id,sample_exp,sample_cluster,cell_p_balanced,downsample_fraction=None,data_augmentation=True,max_rate=0.8,max_val=0.8,kth=0.2): 93 | exp = np.empty((len(cluster_sample_list), sample_exp.shape[1]),dtype=np.float32) 94 | density = np.empty((len(cluster_sample_list), sample_cluster.shape[1]),dtype=np.float32) 95 | for i in nb.prange(len(cluster_sample_list)): 96 | used_clusters = np.where(cluster_sample_list[i] == 1)[0] 97 | num_cell = ncell_sample_list[i] 98 | cluster_mask = np.array([False]*len(cluster_id)) 99 | for c in used_clusters: 100 | cluster_mask = (cluster_id==c)|(cluster_mask) 101 | used_cell_ind = np.where(cluster_mask)[0] 102 | used_cell_p = cell_p_balanced[cluster_mask] 103 | used_cell_p = used_cell_p/used_cell_p.sum() 104 | sampled_cells = used_cell_ind[np.searchsorted(np.cumsum(used_cell_p), np.random.rand(num_cell), side="right")] 105 | combined_exp = np_sum(sample_exp[sampled_cells,:],axis=0).astype(np.float32) 106 | if data_augmentation: 107 | combined_exp = random_augmentation_cell(combined_exp,max_rate=max_rate,max_val=max_val,kth=kth) 108 | if 
downsample_fraction is not None: 109 | combined_exp = downsample_cell(combined_exp, downsample_fraction) 110 | combined_clusters = np_sum(sample_cluster[cluster_id[sampled_cells]],axis=0).astype(np.float32) 111 | exp[i,:] = combined_exp 112 | density[i,:] = combined_clusters 113 | return exp,density 114 | 115 | def init_sample_prob(sc_ad,celltype_key): 116 | print('### Initializing sample probability') 117 | sc_ad.uns['celltype2num'] = pd.DataFrame( 118 | np.arange(len(sc_ad.obs[celltype_key].value_counts())).T, 119 | index=sc_ad.obs[celltype_key].value_counts().index.values, 120 | columns=['celltype_num'] 121 | ) 122 | sc_ad.obs['celltype_num'] = [sc_ad.uns['celltype2num'].loc[c,'celltype_num'] for c in sc_ad.obs[celltype_key]] 123 | cluster_p_unbalance = sc_ad.obs['celltype_num'].value_counts()/sc_ad.obs['celltype_num'].value_counts().sum() 124 | cluster_p_sqrt = np.sqrt(sc_ad.obs['celltype_num'].value_counts())/np.sqrt(sc_ad.obs['celltype_num'].value_counts()).sum() 125 | cluster_p_balance = pd.Series( 126 | np.ones(len(sc_ad.obs['celltype_num'].value_counts()))/len(sc_ad.obs['celltype_num'].value_counts()), 127 | index=sc_ad.obs['celltype_num'].value_counts().index 128 | ) 129 | # cluster_p_balance = np.ones(len(sc_ad.obs['celltype_num'].value_counts()))/len(sc_ad.obs['celltype_num'].value_counts()) 130 | cell_p_balanced = [1/cluster_p_unbalance[c] for c in sc_ad.obs['celltype_num']] 131 | cell_p_balanced = np.array(cell_p_balanced)/np.array(cell_p_balanced).sum() 132 | sc_ad.obs['cell_p_balanced'] = cell_p_balanced 133 | sc_ad.uns['cluster_p_balance'] = cluster_p_balance 134 | sc_ad.uns['cluster_p_sqrt'] = cluster_p_sqrt 135 | sc_ad.uns['cluster_p_unbalance'] = cluster_p_unbalance 136 | return sc_ad 137 | 138 | # 将表达矩阵转化成array 139 | def generate_sample_array(sc_ad, used_genes): 140 | if used_genes is not None: 141 | sc_df = sc_ad.to_df().loc[:,used_genes] 142 | else: 143 | sc_df = sc_ad.to_df() 144 | return sc_df.values 145 | 146 | # 从均匀分布中获取每个spot采样的细胞数和细胞类型数 147 | def get_param_from_uniform(num_sample,cells_min=None,cells_max=None,clusters_min=None,clusters_max=None): 148 | 149 | cell_count = np.asarray(np.ceil(np.random.uniform(int(cells_min),int(cells_max),size=num_sample)),dtype=int) 150 | cluster_count = np.asarray(np.ceil(np.clip(np.random.uniform(clusters_min,clusters_max,size=num_sample),1,cell_count)),dtype=int) 151 | return cell_count, cluster_count 152 | 153 | # 从高斯分布中获取每个spot采样的细胞数和细胞类型数 154 | def get_param_from_gaussian(num_sample,cells_min=None,cells_max=None,cells_mean=None,cells_std=None,clusters_mean=None,clusters_std=None): 155 | 156 | cell_count = np.asarray(np.ceil(np.clip(np.random.normal(cells_mean,cells_std,size=num_sample),int(cells_min),int(cells_max))),dtype=int) 157 | cluster_count = np.asarray(np.ceil(np.clip(np.random.normal(clusters_mean,clusters_std,size=num_sample),1,cell_count)),dtype=int) 158 | return cell_count,cluster_count 159 | 160 | # 从用空间数据估计的cell counts中获取每个spot采样的细胞数和细胞类型数 161 | def get_param_from_cell_counts( 162 | num_sample, 163 | cell_counts, 164 | cluster_sample_mode='gaussian', 165 | cells_min=None,cells_max=None, 166 | cells_mean=None,cells_std=None, 167 | clusters_mean=None,clusters_std=None, 168 | clusters_min=None,clusters_max=None 169 | ): 170 | cell_count = np.asarray(np.ceil(np.clip(np.random.normal(cells_mean,cells_std,size=num_sample),int(cells_min),int(cells_max))),dtype=int) 171 | if cluster_sample_mode == 'gaussian': 172 | cluster_count = 
np.asarray(np.ceil(np.clip(np.random.normal(clusters_mean,clusters_std,size=num_sample),1,cell_count)),dtype=int) 173 | elif cluster_sample_mode == 'uniform': 174 | cluster_count = np.asarray(np.ceil(np.clip(np.random.uniform(clusters_min,clusters_max,size=num_sample),1,cell_count)),dtype=int) 175 | else: 176 | raise TypeError('Not correct sample method.') 177 | return cell_count,cluster_count 178 | 179 | # 获取每个cluster的采样概率 180 | def get_cluster_sample_prob(sc_ad,mode): 181 | if mode == 'unbalance': 182 | cluster_p = sc_ad.uns['cluster_p_unbalance'].values 183 | elif mode == 'balance': 184 | cluster_p = sc_ad.uns['cluster_p_balance'].values 185 | elif mode == 'sqrt': 186 | cluster_p = sc_ad.uns['cluster_p_sqrt'].values 187 | else: 188 | raise TypeError('Balance argument must be one of [ None, banlance, sqrt ].') 189 | return cluster_p 190 | 191 | def cal_downsample_fraction(sc_ad,st_ad,celltype_key=None): 192 | st_counts_median = np.median(st_ad.X.sum(axis=1)) 193 | simulated_st_data, simulated_st_labels = generate_simulation_data(sc_ad,num_sample=10000,celltype_key=celltype_key,balance_mode=['unbalance']) 194 | simulated_st_counts_median = np.median(simulated_st_data.sum(axis=1)) 195 | if st_counts_median < simulated_st_counts_median: 196 | fraction = st_counts_median / simulated_st_counts_median 197 | print(f'### Simulated data downsample fraction: {fraction}') 198 | return fraction 199 | else: 200 | return None 201 | 202 | # 生成模拟数据 203 | def generate_simulation_data( 204 | sc_ad, 205 | celltype_key, 206 | num_sample: int, 207 | used_genes=None, 208 | balance_mode=['unbalance','sqrt','balance'], 209 | cell_sample_method='gaussian', 210 | cluster_sample_method='gaussian', 211 | cell_counts=None, 212 | downsample_fraction=None, 213 | data_augmentation=True, 214 | max_rate=0.8,max_val=0.8,kth=0.2, 215 | cells_min=1,cells_max=20, 216 | cells_mean=10,cells_std=5, 217 | clusters_mean=None,clusters_std=None, 218 | clusters_min=None,clusters_max=None, 219 | cell_sample_counts=None,cluster_sample_counts=None, 220 | ncell_sample_list=None, 221 | cluster_sample_list=None, 222 | n_cpus=None 223 | ): 224 | if not 'cluster_p_unbalance' in sc_ad.uns: 225 | sc_ad = init_sample_prob(sc_ad,celltype_key) 226 | num_sample_per_mode = num_sample//len(balance_mode) 227 | cluster_ordered = np.array(sc_ad.obs['celltype_num'].value_counts().index) 228 | cluster_num = len(cluster_ordered) 229 | cluster_id = sc_ad.obs['celltype_num'].values 230 | cluster_mask = np.eye(cluster_num) 231 | if (cell_sample_counts is None) or (cluster_sample_counts is None): 232 | if cell_counts is not None: 233 | cells_mean = np.mean(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)]) 234 | cells_std = np.std(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)]) 235 | cells_min = int(np.min(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)])) 236 | cells_max = int(np.max(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)])) 237 | if clusters_mean is None: 238 | clusters_mean = cells_mean/2 239 | if clusters_std is None: 240 | clusters_std = cells_std/2 241 | if clusters_min is None: 242 | clusters_min = cells_min 243 | if clusters_max is None: 244 | clusters_max = np.min((cells_max//2,cluster_num)) 245 | 246 | if cell_counts is not None: 247 | cell_sample_counts, cluster_sample_counts = 
get_param_from_cell_counts(num_sample_per_mode,cell_counts,cluster_sample_method,cells_mean=cells_mean,cells_std=cells_std,cells_max=cells_max,cells_min=cells_min,clusters_mean=clusters_mean,clusters_std=clusters_std,clusters_min=clusters_min,clusters_max=clusters_max) 248 | elif cell_sample_method == 'gaussian': 249 | cell_sample_counts, cluster_sample_counts = get_param_from_gaussian(num_sample_per_mode,cells_mean=cells_mean,cells_std=cells_std,cells_max=cells_max,cells_min=cells_min,clusters_mean=clusters_mean,clusters_std=clusters_std) 250 | elif cell_sample_method == 'uniform': 251 | cell_sample_counts, cluster_sample_counts = get_param_from_uniform(num_sample_per_mode,cells_max=cells_max,cells_min=cells_min,clusters_min=clusters_min,clusters_max=clusters_max) 252 | else: 253 | raise TypeError('Not correct sample method.') 254 | if cluster_sample_list is None or ncell_sample_list is None: 255 | params = np.array(list(zip(cell_sample_counts, cluster_sample_counts))) 256 | 257 | sample_data_list = [] 258 | sample_labels_list = [] 259 | for b in balance_mode: 260 | print(f'### Genetating simulated spatial data using scRNA data with mode: {b}') 261 | cluster_p = get_cluster_sample_prob(sc_ad,b) 262 | if downsample_fraction is not None: 263 | if downsample_fraction > 0.035: 264 | sample_data,sample_labels = sample_cell( 265 | param_list=params, 266 | cluster_p=cluster_p, 267 | clusters=cluster_ordered, 268 | cluster_id=cluster_id, 269 | sample_exp=generate_sample_array(sc_ad,used_genes), 270 | sample_cluster=cluster_mask, 271 | cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 272 | downsample_fraction=downsample_fraction, 273 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 274 | ) 275 | else: 276 | sample_data,sample_labels = sample_cell( 277 | param_list=params, 278 | cluster_p=cluster_p, 279 | clusters=cluster_ordered, 280 | cluster_id=cluster_id, 281 | sample_exp=generate_sample_array(sc_ad,used_genes), 282 | sample_cluster=cluster_mask, 283 | cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 284 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 285 | ) 286 | # logging.warning('### Downsample data with python backend') 287 | sample_data = downsample_matrix_by_cell(sample_data, downsample_fraction, n_cpus=n_cpus, numba_end=False) 288 | else: 289 | sample_data,sample_labels = sample_cell( 290 | param_list=params, 291 | cluster_p=cluster_p, 292 | clusters=cluster_ordered, 293 | cluster_id=cluster_id, 294 | sample_exp=generate_sample_array(sc_ad,used_genes), 295 | sample_cluster=cluster_mask, 296 | cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 297 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 298 | ) 299 | # if data_augmentation: 300 | # sample_data = random_augment(sample_data) 301 | sample_data_list.append(sample_data) 302 | sample_labels_list.append(sample_labels) 303 | else: 304 | sample_data_list = [] 305 | sample_labels_list = [] 306 | for b in balance_mode: 307 | print(f'### Genetating simulated spatial data using scRNA data with mode: {b}') 308 | cluster_p = get_cluster_sample_prob(sc_ad,b) 309 | if downsample_fraction is not None: 310 | if downsample_fraction > 0.035: 311 | sample_data,sample_labels = sample_cell_from_clusters( 312 | cluster_sample_list=cluster_sample_list, 313 | ncell_sample_list=ncell_sample_list, 314 | cluster_id=cluster_id, 315 | sample_exp=generate_sample_array(sc_ad,used_genes), 316 | sample_cluster=cluster_mask, 317 | 
cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 318 | downsample_fraction=downsample_fraction, 319 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 320 | ) 321 | else: 322 | sample_data,sample_labels = sample_cell_from_clusters( 323 | cluster_sample_list=cluster_sample_list, 324 | ncell_sample_list=ncell_sample_list, 325 | cluster_id=cluster_id, 326 | sample_exp=generate_sample_array(sc_ad,used_genes), 327 | sample_cluster=cluster_mask, 328 | cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 329 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 330 | ) 331 | # logging.warning('### Downsample data with python backend') 332 | sample_data = downsample_matrix_by_cell(sample_data, downsample_fraction, n_cpus=n_cpus, numba_end=False) 333 | else: 334 | sample_data,sample_labels = sample_cell_from_clusters( 335 | cluster_sample_list=cluster_sample_list, 336 | ncell_sample_list=ncell_sample_list, 337 | cluster_id=cluster_id, 338 | sample_exp=generate_sample_array(sc_ad,used_genes), 339 | sample_cluster=cluster_mask, 340 | cell_p_balanced=sc_ad.obs['cell_p_balanced'].values, 341 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 342 | ) 343 | sample_data_list.append(sample_data) 344 | sample_labels_list.append(sample_labels) 345 | return np.concatenate(sample_data_list), np.concatenate(sample_labels_list) 346 | 347 | @jit(nopython=True,parallel=True) 348 | def sample_cell_exp(cell_counts,sample_exp,cell_p,downsample_fraction=None,data_augmentation=True,max_rate=0.8,max_val=0.8,kth=0.2): 349 | exp = np.empty((len(cell_counts), sample_exp.shape[1]),dtype=np.float32) 350 | ind = np.zeros((len(cell_counts), np.max(cell_counts)),dtype=np.int32) 351 | cell_ind = np.arange(sample_exp.shape[0]) 352 | for i in nb.prange(len(cell_counts)): 353 | num_cell = cell_counts[i] 354 | sampled_cells=cell_ind[np.searchsorted(np.cumsum(cell_p), np.random.rand(num_cell), side="right")] 355 | combined_exp=np_sum(sample_exp[sampled_cells,:],axis=0).astype(np.float64) 356 | # print(combined_exp.dtype) 357 | if downsample_fraction is not None: 358 | combined_exp = downsample_cell(combined_exp, downsample_fraction) 359 | if data_augmentation: 360 | combined_exp = random_augmentation_cell(combined_exp,max_rate=max_rate,max_val=max_val,kth=kth) 361 | exp[i,:] = combined_exp 362 | ind[i,:cell_counts[i]] = sampled_cells + 1 363 | return exp,ind 364 | 365 | def generate_simulation_st_data( 366 | st_ad, 367 | num_sample: int, 368 | used_genes=None, 369 | balance_mode=['unbalance'], 370 | cell_sample_method='gaussian', 371 | cell_counts=None, 372 | downsample_fraction=None, 373 | data_augmentation=True, 374 | max_rate=0.8,max_val=0.8,kth=0.2, 375 | cells_min=1,cells_max=10, 376 | cells_mean=5,cells_std=3, 377 | ): 378 | print('### Genetating simulated data using spatial data') 379 | cell_p = np.ones(len(st_ad))/len(st_ad) 380 | if cell_counts is not None: 381 | cells_mean = np.mean(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)]) 382 | cells_std = np.std(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)]) 383 | cells_min = int(np.min(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)])) 384 | cells_max = int(np.max(np.sort(cell_counts)[int(len(cell_counts)*0.05):int(len(cell_counts)*0.95)])) 385 | elif cell_sample_method == 'gaussian': 386 | cell_counts = 
np.asarray(np.ceil(np.clip(np.random.normal(cells_mean,cells_std,size=num_sample),int(cells_min),int(cells_max))),dtype=int) 387 | elif cell_sample_method == 'uniform': 388 | cell_counts = np.asarray(np.ceil(np.random.uniform(int(cells_min),int(cells_max),size=num_sample)),dtype=int) 389 | else: 390 | raise TypeError('Not correct sample method.') 391 | 392 | sample_data,sample_ind = sample_cell_exp( 393 | cell_counts=cell_counts, 394 | sample_exp=generate_sample_array(st_ad,used_genes), 395 | cell_p=cell_p, 396 | downsample_fraction=downsample_fraction, 397 | data_augmentation=data_augmentation,max_rate=max_rate,max_val=max_val,kth=kth, 398 | ) 399 | # if data_augmentation: 400 | # sample_data = random_augment(sample_data) 401 | return sample_data,sample_ind 402 | -------------------------------------------------------------------------------- /SPACEL/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import pandas as pd 3 | import logging 4 | from ._version import __version__ 5 | 6 | warnings.simplefilter(action='ignore', category=FutureWarning) 7 | warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning) -------------------------------------------------------------------------------- /SPACEL/_version.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = "1.1.8" -------------------------------------------------------------------------------- /SPACEL/setting.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | 5 | def auto_cuda_device(): 6 | out = subprocess.getoutput('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free') 7 | if len(out.split('\n')) > 1: 8 | memory_available = [int(x.split()[2]) for x in out.split('\n')] 9 | max_idx = np.where(memory_available == np.max(memory_available))[0] 10 | os.environ['CUDA_VISIBLE_DEVICES']=str(np.random.permutation(max_idx)[0]) 11 | print('Using GPU:',os.environ['CUDA_VISIBLE_DEVICES']) 12 | else: 13 | raise ValueError('Invalid output from nvidia-smi.') 14 | 15 | def set_environ_seed(seed=42): 16 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 17 | os.environ['PYTHONHASHSEED']=str(seed) 18 | print('Setting environment seed:',seed) 19 | 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
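# For example, `make html` is forwarded to Sphinx as `sphinx-build -M html "$(SOURCEDIR)" "$(BUILDDIR)"`.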
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/img/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuKunLab/SPACEL/02801c2dcb18cbf4ffefdc1352a81314c571fe85/docs/_static/img/figure1.png -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ```{eval-rst} 2 | .. module:: SPACEL 3 | ``` 4 | 5 | ```{eval-rst} 6 | .. automodule:: SPACEL 7 | :noindex: 8 | ``` 9 | # API 10 | 11 | ## Spoint: 12 | 13 | ```{eval-rst} 14 | .. module:: SPACEL.Spoint 15 | ``` 16 | 17 | ```{eval-rst} 18 | .. currentmodule:: SPACEL 19 | ``` 20 | 21 | ```{eval-rst} 22 | .. autosummary:: 23 | :toctree: generated/ 24 | 25 | Spoint.model.init_model 26 | Spoint.base_model.SpointModel.train 27 | Spoint.base_model.SpointModel.deconv_spatial 28 | ``` 29 | 30 | ## Splane: 31 | 32 | ```{eval-rst} 33 | .. module:: SPACEL.Splane 34 | ``` 35 | 36 | ```{eval-rst} 37 | .. currentmodule:: SPACEL 38 | ``` 39 | 40 | ```{eval-rst} 41 | .. autosummary:: 42 | :toctree: generated/ 43 | 44 | Splane.utils.add_cell_type_composition 45 | Splane.model.init_model 46 | Splane.base_model.SplaneModel.train 47 | Splane.base_model.SplaneModel.identify_spatial_domain 48 | ``` 49 | 50 | ## Scube: 51 | 52 | ```{eval-rst} 53 | .. module:: SPACEL.Scube 54 | ``` 55 | 56 | ```{eval-rst} 57 | .. currentmodule:: SPACEL 58 | ``` 59 | 60 | ```{eval-rst} 61 | .. autosummary:: 62 | :toctree: generated/ 63 | 64 | Scube.alignment.align 65 | Scube.plot.plot_stacked_slices 66 | Scube.plot.plot_3d 67 | Scube.gpr.GPRmodel 68 | Scube.gpr.GPRmodel.train 69 | Scube.gpr.GPRmodel.plot_gpr_expr 70 | ``` 71 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath("..")) 4 | from SPACEL import Spoint 5 | from SPACEL import Splane 6 | from SPACEL import Scube 7 | # Configuration file for the Sphinx documentation builder. 
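# The parent directory is added to sys.path above so that Sphinx autodoc can import the SPACEL modules directly from the repository checkout.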
8 | 9 | # -- Project information 10 | 11 | project = 'SPACEL' 12 | author = 'Hao Xu' 13 | 14 | release = '1.1.8' 15 | version = '1.1.8' 16 | 17 | # -- General configuration 18 | exclude_patterns = ['_build', '.DS_Store', '**.ipynb_checkpoints'] 19 | extensions = [ 20 | 'myst_parser', 21 | 'sphinx.ext.duration', 22 | 'sphinx.ext.doctest', 23 | 'sphinx.ext.autodoc', 24 | 'sphinx.ext.autosummary', 25 | 'sphinx.ext.intersphinx', 26 | 'sphinx.ext.coverage', 27 | 'sphinx.ext.mathjax', 28 | 'sphinx.ext.napoleon', 29 | 'sphinx_autodoc_typehints', 30 | 'nbsphinx', 31 | ] 32 | 33 | intersphinx_mapping = { 34 | 'python': ('https://docs.python.org/3/', None), 35 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 36 | } 37 | intersphinx_disabled_domains = ['std'] 38 | 39 | templates_path = ['_templates'] 40 | 41 | # -- Options for HTML output 42 | 43 | html_theme = 'sphinx_rtd_theme' 44 | 45 | # -- Options for EPUB output 46 | epub_show_urls = 'footnote' -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | [![Documentation Status](https://readthedocs.org/projects/spacel/badge/?version=latest)](https://spacel.readthedocs.io/en/latest/?badge=latest) 2 | ![PyPI](https://img.shields.io/pypi/v/SPACEL) 3 | 4 | # SPACEL: characterizing spatial transcriptome architectures by deep-learning 5 | 6 | ```{image} _static/img/figure1.png 7 | :width: 900px 8 | ``` 9 | SPACEL (**SP**atial **A**rchitecture **C**haracterization by d**E**ep **L**earning) is a Python package of deep-learning-based methods for ST data analysis. SPACEL consists of three modules: 10 | 11 | - Spoint embedded a multiple-layer perceptron with a probabilistic model to deconvolute cell type composition for each spot on single ST slice. 12 | - Splane employs a graph convolutional network approach and an adversarial learning algorithm to identify uniform spatial domains that are transcriptomically and spatially coherent across multiple ST slices. 13 | - Scube automatically transforms the spatial coordinate systems of consecutive slices and stacks them together to construct a three-dimensional (3D) alignment of the tissue. 14 | 15 | ## Content 16 | * {doc}`Installation ` 17 | * {doc}`Tutorials ` 18 | * {doc}`Spoint tutorial: Deconvolution of cell type composition on human brain Visium dataset ` 19 | * {doc}`Splane tutorial: Identify uniform spatial domain on human breast cancer Visium dataset ` 20 | * {doc}`Splane&Scube tutorial (1/2): Identify uniform spatial domain on human brain MERFISH dataset ` 21 | * {doc}`Splane&Scube tutorial (2/2): Alignment of consecutive ST slices on human brain MERFISH dataset ` 22 | * {doc}`Scube tutorial: Alignment of consecutive ST slices on mouse embryo Stereo-seq dataset ` 23 | * {doc}`Scube tutorial: 3D expression modeling with gaussian process regression ` 24 | * {doc}`SPACEL workflow (1/3): Deconvolution by Spoint on mouse brain ST dataset ` 25 | * {doc}`SPACEL workflow (2/3): Identification of spatial domain by Splane on mouse brain ST dataset ` 26 | * {doc}`SPACEL workflow (3/3): Alignment 3D tissue by Scube on mouse brain ST dataset ` 27 | * {doc}`API ` 28 | 29 | ## Latest updates 30 | ### Version 1.1.8 2024-07-23 31 | #### Fixed Bugs 32 | - Fixed the conflict between the optax version and Python 3.8. 33 | 34 | ### Version 1.1.7 2024-01-16 35 | #### Fixed Bugs 36 | - Fixed a variable reference error in function `identify_spatial_domain`. 
Thanks to @tobias-zehnde for the contribution. 37 | 38 | ### Version 1.1.6 2023-07-27 39 | #### Fixed Bugs 40 | - Fixed a bug regarding the similarity loss weight hyperparameter `simi_l`, which in the previous version did not affect the loss value. 41 | 42 | ```{toctree} 43 | :hidden: true 44 | :maxdepth: 1 45 | 46 | installation 47 | tutorials 48 | api 49 | ``` 50 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | **Note**: The current version of SPACEL only supports Linux and macOS, not the Windows platform. 5 | 6 | To install `SPACEL`, you need to install [PyTorch](https://pytorch.org) with GPU support first. If you don't need GPU acceleration, you can just skip the installation of `cudnn` and `cudatoolkit`. 7 | * Create a conda environment for `SPACEL`: 8 | ``` 9 | conda env create -f environment.yml 10 | ``` 11 | or 12 | ``` 13 | conda create -n SPACEL -c conda-forge -c default cudatoolkit=10.2 python=3.8 rpy2 r-base r-fitdistrplus 14 | ``` 15 | You must choose the correct `PyTorch`, `cudnn`, and `cudatoolkit` versions depending on your graphics driver version. 16 | 17 | Note: If you want to run the 3D expression GPR model in Scube, you need to install the [Open3D](http://www.open3d.org/docs/release/) Python library first. 18 | 19 | ## Installation 20 | * Install `SPACEL`: 21 | ``` 22 | pip install SPACEL 23 | ``` 24 | * Test whether [PyTorch](https://pytorch.org) is available with GPU support: 25 | ``` 26 | python 27 | >>> import torch 28 | >>> torch.cuda.is_available() 29 | ``` 30 | If these commands do not return `True`, please check your GPU driver version and `cudatoolkit` version. For more details, see [CUDA Toolkit Major Component Versions](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions). 31 | 32 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | -r ../requirements.txt 2 | sphinx 3 | sphinx-rtd-theme 4 | myst-parser 5 | sphinx-autodoc-typehints 6 | nbsphinx -------------------------------------------------------------------------------- /docs/tutorials.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | * {doc}`Spoint tutorial: Deconvolution of cell type composition on human brain Visium dataset ` 3 | * {doc}`Splane tutorial: Identify uniform spatial domain on human breast cancer Visium dataset ` 4 | * {doc}`Splane&Scube tutorial (1/2): Identify uniform spatial domain on mouse brain MERFISH dataset ` 5 | * {doc}`Splane&Scube tutorial (2/2): Alignment of consecutive ST slices on mouse brain MERFISH dataset ` 6 | * {doc}`Scube tutorial: Alignment of consecutive ST slices on mouse embryo Stereo-seq dataset ` 7 | * {doc}`Scube tutorial: 3D expression modeling with Gaussian process regression ` 8 | * {doc}`SPACEL workflow (1/3): Deconvolution by Spoint on mouse brain ST dataset ` 9 | * {doc}`SPACEL workflow (2/3): Identification of spatial domain by Splane on mouse brain ST dataset ` 10 | * {doc}`SPACEL workflow (3/3): Alignment of 3D tissue by Scube on mouse brain ST dataset ` 11 | ```{toctree} 12 | :hidden: true 13 | :maxdepth: 1 14 | 15 | tutorials/Visium_human_DLPFC_Spoint.ipynb 16 | tutorials/Visium_human_breast_cancer_Splane.ipynb 17 | tutorials/MERFISH_mouse_brain_Splane.ipynb 18 | tutorials/MERFISH_mouse_brain_Scube.ipynb 19 | tutorials/Stereo-seq_Scube.ipynb 20 | tutorials/STARmap_mouse_brain_GPR.ipynb 21 | tutorials/ST_mouse_brain_Spoint.ipynb 22 | tutorials/ST_mouse_brain_Splane.ipynb 23 | tutorials/ST_mouse_brain_Scube.ipynb 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /docs/tutorials/MERFISH_mouse_brain_Splane.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4054f846-fc71-41ca-8bd4-83a30a3a485f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Splane&Scube tutorial (1/2): Identify uniform spatial domain on mouse brain MERFISH dataset" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "7a586b80-766e-47e3-a9f2-fe4dff2781c6", 14 | "metadata": {}, 15 | "source": [ 16 | "July 2023" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "b2623f99-4eac-43b0-8349-997f97d7e2c4", 22 | "metadata": {}, 23 | "source": [ 24 | "Dataset: 33 MERFISH slices of mouse brain ([here](https://zenodo.org/record/8167488))" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "348496e8-7ee6-486d-aeb8-e8d85d4a5119", 30 | "metadata": {}, 31 | "source": [ 32 | "## Data preprocessing" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 1, 38 | "id": "808753f6-5d7c-4ab7-be34-c0a05911f039", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from SPACEL.setting import set_environ_seed\n", 43 | "set_environ_seed(42)\n", 44 | "from SPACEL import Splane\n", 45 | "import scanpy as sc\n", 46 | "import 
numpy as np\n", 47 | "import pandas as pd\n", 48 | "import matplotlib" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "cf34fcfa-8daf-43f9-be8b-d202bd2eedfb", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "st_merfish = sc.read_h5ad('../data/merfish_mouse_brain/merfish_mouse_brain.h5ad')" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "47a25442-24b3-4da5-97fc-513514d57d58", 64 | "metadata": {}, 65 | "source": [ 66 | "Here, we will incorporate the cell type composition predicted by **Spoint** into the spatial anndata object for subsequent spatial domain identification in **Splane** using the `add_cell_type_composition` function. This function takes a DataFrame containing the cell type composition matrix as input for spot-based spatial transcriptomic data or a series of cell type annotations as input for single-cell resolution spatial transcriptomic data." 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "dfee4531-b779-4e38-b2a7-dba8f1eb8fc1", 73 | "metadata": { 74 | "scrolled": true, 75 | "tags": [] 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "Splane.utils.add_cell_type_composition(st_merfish, celltype_anno=st_merfish.obs['label'])\n", 80 | "adata_list = Splane.utils.split_ad(st_merfish,'slice_id')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "cd58d1e1-12f8-40bb-bc15-b7b053c8a686", 86 | "metadata": {}, 87 | "source": [ 88 | "## Training Splane model" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "075e28b6-615d-4b1b-80ad-2294b31d11f6", 94 | "metadata": {}, 95 | "source": [ 96 | "In this step, we initialize the Splane model by ``Splane.init_model(...)`` using the anndata object list as input. The ``n_clusters`` parameter determines the number of spatial domains to be identified. The ``k`` parameter controls the degree of neighbors considered in the model, with a larger ``k`` value resulting in more emphasis on global structure rather than local structure. The ``gnn_dropout`` parameter influences the level of smoothness in the model’s predictions, with a higher ``gnn_dropout`` value resulting in a smoother output that accommodates the sparsity of the spatial transcriptomics data. " 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "370b34d0-9a7e-44ca-bf3e-4a15898a83b9", 102 | "metadata": {}, 103 | "source": [ 104 | "We train the model by ``splane.train(...)`` to obtain latent feature of each spots/cells. The parameter ``d_l`` affects the level of batch effect correction between slices. By default, ``d_l`` is ``0.2`` for spatial transcriptomics data with single cell resolution." 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "id": "877a6eae-fbf7-4225-b5e1-aa5dd9511d6d", 110 | "metadata": {}, 111 | "source": [ 112 | "Then, we can identify the spatial domain to which each spot/cell belongs by ``splane.identify_spatial_domain(...)``. By default, the results will be saved in ``spatial_domain`` column in ``.obs``. If the key parameter is provided, the results will be saved in ``.obs[key]``." 
113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "id": "2190f4f8-655d-421c-8f40-dab283a0624d", 119 | "metadata": { 120 | "tags": [] 121 | }, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "Setting environment seed: 42\n", 128 | "Setting global seed: 42\n", 129 | "Calculating cell type weights...\n", 130 | "Generating GNN inputs...\n", 131 | "Calculating largest eigenvalue of normalized graph Laplacian...\n", 132 | "Calculating Chebyshev polynomials up to order 2...\n" 133 | ] 134 | }, 135 | { 136 | "name": "stderr", 137 | "output_type": "stream", 138 | "text": [ 139 | "The best epoch 115 total loss=-16.317 g loss=-15.619 d loss=3.488 d acc=0.060 simi loss=-0.997 db loss=0.614: 17%|█▋ | 170/1000 [7:43:09<37:41:19, 163.47s/it]" 140 | ] 141 | }, 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "Stop trainning because of loss convergence\n" 147 | ] 148 | }, 149 | { 150 | "name": "stderr", 151 | "output_type": "stream", 152 | "text": [ 153 | "\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "splane_model = Splane.init_model(adata_list, n_clusters=7,use_gpu=False,n_neighbors=25, gnn_dropout=0.5)\n", 159 | "splane_model.train(d_l=0.2)\n", 160 | "splane_model.identify_spatial_domain()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "id": "99ed6b52-11a3-4699-a704-5caa07d26ca1", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "sc.concat(adata_list).write(f'../data/merfish_mouse_brain/merfish_mouse_brain.h5ad')" 171 | ] 172 | } 173 | ], 174 | "metadata": { 175 | "kernelspec": { 176 | "display_name": "python38-spacel-pytorch", 177 | "language": "python", 178 | "name": "python38-spacel-pytorch" 179 | }, 180 | "language_info": { 181 | "codemirror_mode": { 182 | "name": "ipython", 183 | "version": 3 184 | }, 185 | "file_extension": ".py", 186 | "mimetype": "text/x-python", 187 | "name": "python", 188 | "nbconvert_exporter": "python", 189 | "pygments_lexer": "ipython3", 190 | "version": "3.8.16" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 5 195 | } 196 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: SPACEL 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - cudatoolkit=10.2.* 7 | - python=3.8 8 | - pip 9 | - rpy2 10 | - r-base 11 | - r-fitdistrplus 12 | -------------------------------------------------------------------------------- /readthedocs_environment.yml: -------------------------------------------------------------------------------- 1 | name: SPACEL 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - cudatoolkit=10.2.* 7 | - python=3.8 8 | - pip 9 | - rpy2 10 | - r-base 11 | - r-fitdistrplus 12 | - pip: 13 | - -r requirements.txt 14 | - -r docs/requirements.txt 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | scikit-learn 3 | scanpy 4 | scvi-tools<=0.20.3 5 | squidpy<1.3 6 | torch<=1.13 7 | gpytorch 8 | optax<=0.1.7 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", 
"r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | d = {} 7 | with open("SPACEL/_version.py") as f: 8 | exec(f.read(), d) 9 | 10 | setup( 11 | name="SPACEL", 12 | version=d["__version__"], 13 | author="Hao Xu", 14 | author_email="xuhaoustc@mail.ustc.edu.cn", 15 | description="SPACEL: characterizing spatial transcriptome architectures by deep-learning", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/QuKunLab/SPACEL", 19 | packages=find_packages(), 20 | python_requires=">=3.8", 21 | install_requires=[ 22 | "pip", 23 | "squidpy<1.3", 24 | "scvi-tools<=0.20.3", 25 | "scikit-learn", 26 | "scanpy", 27 | "numba", 28 | "torch<=1.13", 29 | "gpytorch", 30 | "optax<=0.1.7", 31 | ], 32 | ) --------------------------------------------------------------------------------