├── LICENSE
├── README.md
├── STitch3D
│   ├── __init__.py
│   ├── align_tools.py
│   ├── model.py
│   ├── networks.py
│   └── utils.py
├── demos
│   ├── Overview.jpg
│   ├── mouse_brain_hpc.gif
│   └── mouse_brain_layers.gif
├── environment.yml
└── plot3D_func.R

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 YangLabHKUST
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STitch3D
2 | [![DOI](https://zenodo.org/badge/567093619.svg)](https://zenodo.org/badge/latestdoi/567093619) [![PyPI](https://img.shields.io/pypi/v/stitch3d?color=green)](https://pypi.python.org/pypi/stitch3d/) [![PyPi license](https://badgen.net/pypi/license/stitch3d/)](https://pypi.org/project/stitch3d/) [![Downloads](https://static.pepy.tech/personalized-badge/stitch3d?period=total&units=international_system&left_color=grey&right_color=orange&left_text=downloads)](https://pepy.tech/project/stitch3d) [![Stars](https://img.shields.io/github/stars/YangLabHKUST/STitch3D?logo=GitHub&color=yellow)](https://github.com/YangLabHKUST/STitch3D/stargazers)
3 | 
4 | *Construction of a 3D whole organism spatial atlas by joint modelling of multiple slices with deep neural networks*
5 | 
6 | An effective and efficient 3D analysis method for spatial transcriptomics data.
7 | 
8 | Check out our manuscript in Nature Machine Intelligence:
9 | + [Nature Machine Intelligence website](https://www.nature.com/articles/s42256-023-00734-1)
10 | + [Read fulltext link](https://rdcu.be/doZ9u)
11 | + [Preprint in bioRxiv](https://doi.org/10.1101/2023.02.02.526814)
12 | 
13 | ![STitch3D\_pipeline](demos/Overview.jpg)
14 | 
15 | We developed STitch3D, a deep learning-based method for 3D reconstruction of tissues or whole organisms. Briefly, STitch3D characterizes complex tissue architectures by borrowing information across multiple 2D tissue slices and integrating them with a paired single-cell RNA-sequencing atlas.
16 | 
17 | With innovations in model design, STitch3D enables two critical 3D analyses: first, STitch3D detects 3D spatial tissue regions related to biological functions, for example, cortical layer structures in the brain; second, STitch3D infers 3D spatial distributions of fine-grained cell types in tissues, substantially improving the spatial resolution of sequencing-based ST approaches. The output of STitch3D can be further used for various downstream tasks, such as inference of spatial trajectories, denoising of spatial gene expression patterns, identification of genes enriched in specific biologically meaningful regions, and detection of cell-type gradients in newly generated virtual slices.
18 | 
19 | An example: STitch3D reconstructed the adult mouse brain, detected the 3D layer organization of the cerebral cortex, and inferred curve-shaped distributions of four hippocampal neuron types in the three cornu ammonis areas and the dentate gyrus.
20 | 
21 | ![layers](demos/mouse_brain_layers.gif) ![hpc](demos/mouse_brain_hpc.gif)
22 | 
23 | Installation
24 | ------------
25 | * STitch3D can be installed from PyPI:
26 | ```bash
27 | pip install stitch3d
28 | ```
29 | * Alternatively, STitch3D can be downloaded from GitHub:
30 | ```bash
31 | git clone https://github.com/YangLabHKUST/STitch3D.git
32 | cd STitch3D
33 | conda config --set channel_priority strict
34 | conda env update -f environment.yml
35 | conda activate stitch3d
36 | ```
37 | Installation normally takes less than ten minutes. We have tested the package on Linux (Ubuntu 18.04.5 LTS). Software dependencies are listed on [this website](https://stitch3d-tutorial.readthedocs.io/en/latest/installation.html#software-dependencies).
38 | 
39 | Tutorials and reproducibility
40 | -----------------------------
41 | We provide code for reproducing the experiments of the paper "Construction of a 3D whole organism spatial atlas by joint modelling of multiple slices with deep neural networks", as well as comprehensive tutorials for using STitch3D. Please check the [tutorial website](https://stitch3d-tutorial.readthedocs.io/en/latest/index.html) for more details.
42 | 
43 | Interactive 3D results
44 | ----------------------
45 | Interactive 3D analysis results from STitch3D are available on the [website](https://stitch3d-tutorial.readthedocs.io/en/latest/index.html).
46 | 
47 | Reference
48 | ----------------------
49 | Gefei Wang, Jia Zhao, Yan Yan, Yang Wang, Angela Ruohao Wu, Can Yang. Construction of a 3D whole organism spatial atlas by joint modelling of multiple slices with deep neural networks. Nature Machine Intelligence 5, 1200–1213 (2023).
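
Quick start
----------------------
A minimal sketch of a typical STitch3D run, based on the function signatures in this repository. The input AnnData objects, the cell-type column name, and the inter-slice distances below are placeholders for your own data; see the tutorials for complete, dataset-specific examples:

```python
import STitch3D

# adata_st_list: list of AnnData objects, one per ST slice (in slice order)
# adata_ref: reference scRNA-seq AnnData with cell-type labels in .obs["celltype"]
adata_st_list = STitch3D.utils.align_spots(adata_st_list, method="icp", plot=True)

# slice_dist_micron: one distance per adjacent pair of slices (length = #slices - 1)
adata_st, adata_basis = STitch3D.utils.preprocess(adata_st_list, adata_ref,
                                                  celltype_ref_col="celltype",
                                                  slice_dist_micron=[10., 10., 10.])

model = STitch3D.model.Model(adata_st, adata_basis)
model.train()
# returns one AnnData per slice, with cell-type proportions added to .obs
adata_st_decon_list = model.eval(adata_st_list, save=True, output_path="./results")
```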
50 | 
--------------------------------------------------------------------------------
/STitch3D/__init__.py:
--------------------------------------------------------------------------------
1 | import STitch3D.utils
2 | import STitch3D.model
3 | import STitch3D.networks
4 | import STitch3D.align_tools
--------------------------------------------------------------------------------
/STitch3D/align_tools.py:
--------------------------------------------------------------------------------
1 | import ot
2 | import numpy as np
3 | import anndata as ad
4 | from anndata import AnnData
5 | import torch
6 | import scipy.sparse
7 | from sklearn.neighbors import NearestNeighbors
8 | 
9 | 
10 | # Functions for the Iterative Closest Point algorithm
11 | # Credit to https://github.com/ClayFlannigan/icp
12 | def best_fit_transform(A, B):
13 |     '''
14 |     Calculates the least-squares best-fit transform that maps corresponding points A to B in m spatial dimensions
15 |     Input:
16 |       A: Nxm numpy array of corresponding points
17 |       B: Nxm numpy array of corresponding points
18 |     Returns:
19 |       T: (m+1)x(m+1) homogeneous transformation matrix that maps A onto B
20 |       R: mxm rotation matrix
21 |       t: mx1 translation vector
22 |     '''
23 | 
24 |     assert A.shape == B.shape
25 | 
26 |     # get number of dimensions
27 |     m = A.shape[1]
28 | 
29 |     # translate points to their centroids
30 |     centroid_A = np.mean(A, axis=0)
31 |     centroid_B = np.mean(B, axis=0)
32 |     AA = A - centroid_A
33 |     BB = B - centroid_B
34 | 
35 |     # rotation matrix
36 |     H = np.dot(AA.T, BB)
37 |     U, S, Vt = np.linalg.svd(H)
38 |     R = np.dot(Vt.T, U.T)
39 | 
40 |     # special reflection case
41 |     if np.linalg.det(R) < 0:
42 |         Vt[m-1,:] *= -1
43 |         R = np.dot(Vt.T, U.T)
44 | 
45 |     # translation
46 |     t = centroid_B.T - np.dot(R,centroid_A.T)
47 | 
48 |     # homogeneous transformation
49 |     T = np.identity(m+1)
50 |     T[:m, :m] = R
51 |     T[:m, m] = t
52 | 
53 |     return T, R, t
54 | 
55 | def nearest_neighbor(src, dst):
56 |     '''
57 |     Find the nearest (Euclidean) neighbor in dst for each point in src
58 |     Input:
59 |         src: Nxm array of points
60 |         dst: Nxm array of points
61 |     Output:
62 |         distances: Euclidean distances of the nearest neighbor
63 |         indices: dst indices of the nearest neighbor
64 |     '''
65 | 
66 |     neigh = NearestNeighbors(n_neighbors=1)
67 |     neigh.fit(dst)
68 |     distances, indices = neigh.kneighbors(src, return_distance=True)
69 |     return distances.ravel(), indices.ravel()
70 | 
71 | 
72 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001):
73 |     '''
74 |     The Iterative Closest Point method: finds the best-fit transform that maps points A onto points B
75 |     Input:
76 |         A: Nxm numpy array of source mD points
77 |         B: Nxm numpy array of destination mD points
78 |         init_pose: (m+1)x(m+1) homogeneous transformation
79 |         max_iterations: exit algorithm after max_iterations
80 |         tolerance: convergence criteria
81 |     Output:
82 |         T: final homogeneous transformation that maps A onto B
83 |         distances: Euclidean distances (errors) of the nearest neighbor
84 |         i: number of iterations to converge
85 |     '''
86 | 
87 |     # get number of dimensions
88 |     m = A.shape[1]
89 | 
90 |     # make points homogeneous, copy them to maintain the originals
91 |     src = np.ones((m+1,A.shape[0]))
92 |     dst = np.ones((m+1,B.shape[0]))
93 |     src[:m,:] = np.copy(A.T)
94 |     dst[:m,:] = np.copy(B.T)
95 | 
96 |     # apply the initial pose estimation
97 |     if init_pose is not None:
98 |         src = np.dot(init_pose, src)
99 | 
100 |     prev_error = 0
101 | 
102 |     for i in range(max_iterations):
103 |         # find the nearest neighbors between the current source and destination points
104 |         distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T)
105 | 
106 |         # compute the transformation between the current source and nearest destination points
107 |         T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T)
108 | 
109 |         # update the current source
110 |         src = np.dot(T, src)
111 | 
112 |         # check error
113 |         mean_error = np.mean(distances)
114 |         if np.abs(prev_error - mean_error) < tolerance:
115 |             break
116 |         prev_error = mean_error
117 | 
118 |     # calculate final transformation
119 |     T,_,_ = best_fit_transform(A, src[:m,:].T)
120 | 
121 |     return T, distances, i
122 | 
123 | 
124 | def transform(point_cloud, T):
125 |     point_cloud_align = np.ones((point_cloud.shape[0], 3))
126 |     point_cloud_align[:,0:2] = np.copy(point_cloud)
127 |     point_cloud_align = np.dot(T, point_cloud_align.T).T
128 |     return point_cloud_align[:, :2]
129 | 
130 | 
131 | # Functions for the PASTE algorithm
132 | # Credit to https://github.com/raphael-group/paste
133 | 
134 | ## Convert a sparse matrix into a dense np array
135 | to_dense_array = lambda X: X.toarray() if scipy.sparse.issparse(X) else np.array(X)
136 | 
137 | ## Returns the data matrix or representation
138 | extract_data_matrix = lambda adata,rep: adata.X if rep is None else adata.obsm[rep]
139 | 
140 | def intersect(lst1, lst2):
141 |     """
142 |     Gets and returns intersection of two lists.
143 |     Args:
144 |         lst1: List
145 |         lst2: List
146 |     Returns:
147 |         lst3: List of common elements.
148 |     """
149 | 
150 |     temp = set(lst2)
151 |     lst3 = [value for value in lst1 if value in temp]
152 |     return lst3
153 | 
154 | def kl_divergence_backend(X, Y):
155 |     """
156 |     Returns pairwise KL divergence (over all pairs of samples) of two matrices X and Y.
157 |     Takes advantage of POT backend to speed up computation.
158 |     Args:
159 |         X: np array with dim (n_samples by n_features)
160 |         Y: np array with dim (m_samples by n_features)
161 |     Returns:
162 |         D: np array with dim (n_samples by m_samples). Pairwise KL divergence matrix.
163 |     """
164 |     assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."
165 | 
166 |     nx = ot.backend.get_backend(X,Y)
167 | 
168 |     X = X/nx.sum(X,axis=1, keepdims=True)
169 |     Y = Y/nx.sum(Y,axis=1, keepdims=True)
170 |     log_X = nx.log(X)
171 |     log_Y = nx.log(Y)
172 |     X_log_X = nx.einsum('ij,ij->i',X,log_X)
173 |     X_log_X = nx.reshape(X_log_X,(1,X_log_X.shape[0]))
174 |     D = X_log_X.T - nx.dot(X,log_Y.T)
175 |     return nx.to_numpy(D)
176 | 
177 | 
178 | def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init = None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, numItermax=200, use_gpu = False, **kwargs):
179 |     """
180 |     Adapted fused_gromov_wasserstein with the added capability of defining a G_init (initial mapping).
181 |     Also added capability of utilizing different POT backends to speed up computation.
182 | 
183 |     For more info, see: https://pythonot.github.io/gen_modules/ot.gromov.html
184 |     """
185 | 
186 |     p, q = ot.utils.list_to_array(p, q)
187 | 
188 |     p0, q0, C10, C20, M0 = p, q, C1, C2, M
189 |     nx = ot.backend.get_backend(p0, q0, C10, C20, M0)
190 | 
191 |     constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun)
192 | 
193 |     if G_init is None:
194 |         G0 = p[:, None] * q[None, :]
195 |     else:
196 |         G0 = (1/nx.sum(G_init)) * G_init
197 |         if use_gpu:
198 |             G0 = G0.cuda()
199 | 
200 |     def f(G):
201 |         return ot.gromov.gwloss(constC, hC1, hC2, G)
202 | 
203 |     def df(G):
204 |         return ot.gromov.gwggrad(constC, hC1, hC2, G)
205 | 
206 |     if log:
207 |         res, log = ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, numItermax=numItermax, log=True, **kwargs)
208 | 
209 |         fgw_dist = log['loss'][-1]
210 | 
211 |         log['fgw_dist'] = fgw_dist
212 |         log['u'] = log['u']
213 |         log['v'] = log['v']
214 |         return res, log
215 | 
216 |     else:
217 |         return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, numItermax=numItermax, **kwargs)
218 | 
219 | 
220 | def pairwise_align_paste(
221 |     sliceA,
222 |     sliceB,
223 |     alpha = 0.1,
224 |     dissimilarity ='kl',
225 |     use_rep = None,
226 |     G_init = None,
227 |     a_distribution = None,
228 |     b_distribution = None,
229 |     norm = False,
230 |     numItermax = 200,
231 |     backend = ot.backend.NumpyBackend(),
232 |     use_gpu = False,
233 |     return_obj = False,
234 |     verbose = False,
235 |     gpu_verbose = False,
236 |     coor_key = "spatial",
237 |     **kwargs):
238 |     """
239 |     Calculates and returns the optimal alignment of two slices.
240 | 
241 |     Args:
242 |         sliceA: Slice A to align.
243 |         sliceB: Slice B to align.
244 |         alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
245 |         dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
246 |         use_rep: If ``None``, uses ``slice.X`` to calculate dissimilarity between spots, otherwise uses the representation given by ``slice.obsm[use_rep]``.
247 |         G_init (array-like, optional): Initial mapping to be used in FGW-OT, otherwise default is uniform mapping.
248 |         a_distribution (array-like, optional): Distribution of sliceA spots, otherwise default is uniform.
249 |         b_distribution (array-like, optional): Distribution of sliceB spots, otherwise default is uniform.
250 |         numItermax: Max number of iterations during FGW-OT.
251 |         norm: If ``True``, scales spatial distances such that neighboring spots are at distance 1. Otherwise, spatial distances remain unchanged.
252 |         backend: Type of backend to run calculations. For list of backends available on system: ``ot.backend.get_backend_list()``.
253 |         use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for PyTorch.
254 |         return_obj: If ``True``, additionally returns objective function output of FGW-OT.
255 |         verbose: If ``True``, FGW-OT is verbose.
256 |         gpu_verbose: If ``True``, print whether gpu is being used to user.
257 |         coor_key: The key of spatial coordinates, ``spatial`` for default.
258 | 
259 |     Returns:
260 |         - Alignment of spots.
261 |         If ``return_obj = True``, additionally returns:
262 | 
263 |         - Objective function output of FGW-OT.
264 |     """
265 | 
266 |     # Determine if gpu or cpu is being used
267 |     if use_gpu:
268 |         try:
269 |             import torch
270 |         except:
271 |             print("We currently only have gpu support for PyTorch. Please install torch.")
272 | 
273 |         if isinstance(backend,ot.backend.TorchBackend):
274 |             if torch.cuda.is_available():
275 |                 if gpu_verbose:
276 |                     print("gpu is available, using gpu.")
277 |             else:
278 |                 if gpu_verbose:
279 |                     print("gpu is not available, resorting to torch cpu.")
280 |                 use_gpu = False
281 |         else:
282 |             print("We currently only have gpu support for PyTorch, please set backend = ot.backend.TorchBackend(). Reverting to selected backend cpu.")
283 |             use_gpu = False
284 |     else:
285 |         if gpu_verbose:
286 |             print("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")
287 | 
288 |     # subset for common genes
289 |     common_genes = intersect(sliceA.var.index, sliceB.var.index)
290 |     sliceA = sliceA[:, common_genes]
291 |     sliceB = sliceB[:, common_genes]
292 | 
293 |     # Backend
294 |     nx = backend
295 | 
296 |     # Calculate spatial distances
297 |     coordinatesA = sliceA.obsm[coor_key].copy()
298 |     coordinatesA = nx.from_numpy(coordinatesA)
299 |     coordinatesB = sliceB.obsm[coor_key].copy()
300 |     coordinatesB = nx.from_numpy(coordinatesB)
301 | 
302 |     if isinstance(nx,ot.backend.TorchBackend):
303 |         coordinatesA = coordinatesA.float()
304 |         coordinatesB = coordinatesB.float()
305 |     D_A = ot.dist(coordinatesA,coordinatesA, metric='euclidean')
306 |     D_B = ot.dist(coordinatesB,coordinatesB, metric='euclidean')
307 | 
308 |     if isinstance(nx,ot.backend.TorchBackend) and use_gpu:
309 |         D_A = D_A.cuda()
310 |         D_B = D_B.cuda()
311 | 
312 |     # Calculate expression dissimilarity
313 |     A_X, B_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA,use_rep))), nx.from_numpy(to_dense_array(extract_data_matrix(sliceB,use_rep)))
314 | 
315 |     if isinstance(nx,ot.backend.TorchBackend) and use_gpu:
316 |         A_X = A_X.cuda()
317 |         B_X = B_X.cuda()
318 | 
319 |     if dissimilarity.lower()=='euclidean' or dissimilarity.lower()=='euc':
320 |         M = ot.dist(A_X,B_X)
321 |     else:
322 |         s_A = A_X + 0.01
323 |         s_B = B_X + 0.01
324 |         M = kl_divergence_backend(s_A, s_B)
325 |         M = nx.from_numpy(M)
326 | 
327 |     if isinstance(nx,ot.backend.TorchBackend) and use_gpu:
328 |         M = M.cuda()
329 | 
330 |     # init distributions
331 |     if a_distribution is None:
332 |         a = nx.ones((sliceA.shape[0],))/sliceA.shape[0]
333 |     else:
334 |         a = nx.from_numpy(a_distribution)
335 | 
336 |     if b_distribution is None:
337 |         b = nx.ones((sliceB.shape[0],))/sliceB.shape[0]
338 |     else:
339 |         b = nx.from_numpy(b_distribution)
340 | 
341 |     if isinstance(nx,ot.backend.TorchBackend) and use_gpu:
342 |         a = a.cuda()
343 |         b = b.cuda()
344 | 
345 |     if norm:
346 |         D_A /= nx.min(D_A[D_A>0])
347 |         D_B /= nx.min(D_B[D_B>0])
348 | 
349 |     # Run OT
350 |     if G_init is not None:
351 |         G_init = nx.from_numpy(G_init)
352 |         if isinstance(nx,ot.backend.TorchBackend):
353 |             G_init = G_init.float()
354 |             if use_gpu:
355 |                 G_init = G_init.cuda()
356 |     pi, logw = my_fused_gromov_wasserstein(M, D_A, D_B, a, b, G_init = G_init, loss_fun='square_loss', alpha= alpha, log=True, numItermax=numItermax, verbose=verbose, use_gpu = use_gpu)
357 |     pi = nx.to_numpy(pi)
358 |     obj = nx.to_numpy(logw['fgw_dist'])
359 |     if isinstance(backend,ot.backend.TorchBackend) and use_gpu:
360 |         torch.cuda.empty_cache()
361 | 
362 |     if return_obj:
363 |         return pi, obj
364 |     return pi
365 | 
366 | 
367 | def generalized_procrustes_analysis(X, Y, pi, output_params=False, matrix=False):
368 |     """
369 |     Finds and applies the optimal rotation between spatial coordinates of two layers (may also do a reflection).
370 |     Args:
371 |         X: np array of spatial coordinates (e.g., sliceA.obsm['spatial'])
372 |         Y: np array of spatial coordinates (e.g., sliceB.obsm['spatial'])
373 |         pi: mapping between the two layers output by PASTE
374 |         output_params: Boolean of whether to return rotation angle and translations along with spatial coordinates.
375 |         matrix: Boolean of whether to return the rotation as a matrix or an angle
376 |     Returns:
377 |         Aligned spatial coordinates of X, Y, rotation angle, translation of X, translation of Y
378 |     """
379 |     assert X.shape[1] == 2 and Y.shape[1] == 2
380 | 
381 |     tX = pi.sum(axis=1).dot(X)
382 |     tY = pi.sum(axis=0).dot(Y)
383 |     X = X - tX
384 |     Y = Y - tY
385 |     H = Y.T.dot(pi.T.dot(X))
386 |     U, S, Vt = np.linalg.svd(H)
387 |     R = Vt.T.dot(U.T)
388 |     Y = R.dot(Y.T).T
389 |     if output_params and not matrix:
390 |         M = np.array([[0,-1],[1,0]])
391 |         theta = np.arctan(np.trace(M.dot(H))/np.trace(H))
392 |         return X, Y, theta, tX, tY
393 |     elif output_params and matrix:
394 |         return X, Y, R, tX, tY
395 |     else:
396 |         return X, Y
--------------------------------------------------------------------------------
/STitch3D/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import numpy as np
6 | import pandas as pd
7 | import scipy.sparse
8 | from tqdm import tqdm
9 | import os
10 | from STitch3D.networks import *
11 | 
12 | 
13 | class Model():
14 | 
15 |     def __init__(self, adata_st, adata_basis,
16 |                  hidden_dims=[512, 128],
17 |                  n_heads=1,
18 |                  slice_emb_dim=16,
19 |                  coef_fe=0.1,
20 |                  training_steps=20000,
21 |                  lr=2e-3,
22 |                  seed=1234,
23 |                  distribution="Poisson"
24 |                  ):
25 | 
26 |         self.training_steps = training_steps
27 | 
28 |         self.adata_st = adata_st
29 |         self.celltypes = list(adata_basis.obs.index)
30 | 
31 |         # add device
32 |         self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
33 | 
34 |         # set random seed
35 |         torch.manual_seed(seed)
36 |         np.random.seed(seed)
37 |         if torch.cuda.is_available():
38 |             torch.cuda.manual_seed_all(seed)
39 |         torch.backends.cudnn.benchmark = True
40 | 
41 |         self.hidden_dims = [adata_st.shape[1]] + hidden_dims
42 |         self.n_celltype = adata_basis.shape[0]
43 |         self.n_slices = len(sorted(set(adata_st.obs["slice"].values)))
44 | 
45 |         # build model
46 |         if distribution == "Poisson":
47 |             self.net = DeconvNet(hidden_dims=self.hidden_dims,
48 |                                  n_celltypes=self.n_celltype,
49 |                                  n_slices=self.n_slices,
50 |                                  n_heads=n_heads,
51 |                                  slice_emb_dim=slice_emb_dim,
52 |                                  coef_fe=coef_fe,
53 |                                  ).to(self.device)
54 |         else: # Negative Binomial distribution
55 |             self.net = DeconvNet_NB(hidden_dims=self.hidden_dims,
56 |                                     n_celltypes=self.n_celltype,
57 |                                     n_slices=self.n_slices,
58 |                                     n_heads=n_heads,
59 |                                     slice_emb_dim=slice_emb_dim,
60 |                                     coef_fe=coef_fe,
61 |                                     ).to(self.device)
62 | 
63 |         self.optimizer = optim.Adamax(list(self.net.parameters()), lr=lr)
64 | 
65 |         # read data
66 |         if scipy.sparse.issparse(adata_st.X):
67 |             self.X = torch.from_numpy(adata_st.X.toarray()).float().to(self.device)
68 |         else:
69 |             self.X = torch.from_numpy(adata_st.X).float().to(self.device)
70 |         self.A = torch.from_numpy(np.array(adata_st.obsm["graph"])).float().to(self.device)
71 |         self.Y = torch.from_numpy(np.array(adata_st.obsm["count"])).float().to(self.device)
72 |         self.lY = torch.from_numpy(np.array(adata_st.obs["library_size"].values.reshape(-1, 1))).float().to(self.device)
73 |         self.slice = torch.from_numpy(np.array(adata_st.obs["slice"].values)).long().to(self.device)
74 |         self.basis = torch.from_numpy(np.array(adata_basis.X)).float().to(self.device)
75 | 
76 |     def train(self, report_loss=True, step_interval=2000):
77 |         self.net.train()
78 |         for step in tqdm(range(self.training_steps)):
79 |             loss = self.net(adj_matrix=self.A,
80 |                             node_feats=self.X,
81 |                             count_matrix=self.Y,
82 |                             library_size=self.lY,
83 |                             slice_label=self.slice,
84 |                             basis=self.basis)
85 |             self.optimizer.zero_grad()
86 |             loss.backward()
87 |             self.optimizer.step()
88 | 
89 |             if report_loss:
90 |                 if not step % step_interval:
91 |                     print("Step: %s, Loss: %.4f, d_loss: %.4f, f_loss: %.4f" % (step, loss.item(), self.net.decon_loss.item(), self.net.features_loss.item()))
92 | 
93 | 
94 |     def eval(self, adata_st_list_raw, save=False, output_path="./results"):
95 |         self.net.eval()
96 |         self.Z, self.beta, self.alpha, self.gamma = self.net.evaluate(self.A, self.X, self.slice)
97 | 
98 |         if save == True:
99 |             if not os.path.exists(output_path):
100 |                 os.makedirs(output_path)
101 | 
102 |         # add learned representations to full ST adata object
103 |         embeddings = self.Z.detach().cpu().numpy()
104 |         cell_reps = pd.DataFrame(embeddings)
105 |         cell_reps.index = self.adata_st.obs.index
106 |         self.adata_st.obsm['latent'] = cell_reps.loc[self.adata_st.obs_names, ].values
107 |         if save == True:
108 |             cell_reps.to_csv(os.path.join(output_path, "representation.csv"))
109 | 
110 |         # add deconvolution results to original anndata objects
111 |         b = self.beta.detach().cpu().numpy()
112 |         n_spots = 0
113 |         adata_st_decon_list = []
114 |         for i, adata_st_i in enumerate(adata_st_list_raw):
115 |             adata_st_i.obs.index = adata_st_i.obs.index + "-slice%d" % i
116 |             decon_res = pd.DataFrame(b[n_spots:(n_spots+adata_st_i.shape[0]), :],
117 |                                      columns=self.celltypes)
118 |             decon_res.index = adata_st_i.obs.index
119 |             adata_st_i.obs = adata_st_i.obs.join(decon_res)
120 |             n_spots += adata_st_i.shape[0]
121 |             adata_st_decon_list.append(adata_st_i)
122 | 
123 |             if save == True:
124 |                 decon_res.to_csv(os.path.join(output_path, "prop_slice%d.csv" % i))
125 |                 adata_st_i.write(os.path.join(output_path, "res_adata_slice%d.h5ad" % i))
126 | 
127 |         # Save 3d coordinates
128 |         if save == True:
129 |             coor_3d = pd.DataFrame(data=self.adata_st.obsm['3D_coor'], index=self.adata_st.obs.index, columns=['x', 'y', 'z'])
130 |             coor_3d.to_csv(os.path.join(output_path, "3D_coordinates.csv"))
131 | 
132 |         return adata_st_decon_list
133 | 
134 | 
135 |     def cells_to_spatial(self, adata_ref_input,
136 |                          celltype_ref_col="celltype", # column of adata_ref_input.obs for cell type information
137 |                          celltype_ref=None, # specify cell types to use for deconvolution
138 |                          target_num=None, # target number of cells per spot; if None, inferred as total cells / total spots
139 |                          save=False,
140 |                          lam_sim=0.1,
141 |                          lam_num=1e-3,
142 |                          lam_M=1,
143 |                          lr=2e-3, training_steps_M=20000, report_loss=True, step_interval=2000, output_path="./results"):
144 | 
145 |         import scanpy as sc
146 | 
147 |         # When mapping cells to spatial locations, the reference dataset needs to be
148 |         # processed in the same way as when constructing the cell-type basis matrix
149 | 
150 |         adata_ref = adata_ref_input.copy()
151 |         adata_ref.var_names_make_unique()
152 |         # Remove mt-genes
153 |         adata_ref = adata_ref[:, np.array(~adata_ref.var.index.isna())
154 |                               & np.array(~adata_ref.var_names.str.startswith("mt-"))
155 |                               & np.array(~adata_ref.var_names.str.startswith("MT-"))]
156 |         if celltype_ref is not None:
157 |             if not isinstance(celltype_ref, list):
158 |                 raise ValueError("'celltype_ref' must be a list!")
159 |             else:
160 |                 adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
161 |         else:
162 |             celltype_counts = adata_ref.obs[celltype_ref_col].value_counts()
163 |             celltype_ref = list(celltype_counts.index[celltype_counts > 1])
164 |             adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
165 | 
166 |         # Remove cells and genes with 0 counts
167 |         sc.pp.filter_cells(adata_ref, min_genes=1)
168 |         sc.pp.filter_genes(adata_ref, min_cells=1)
169 | 
170 |         adata_ref = adata_ref[:, self.adata_st.var.index]
171 | 
172 |         celltype_list = list(sorted(set(adata_ref.obs[celltype_ref_col].values.astype(str))))
173 |         if scipy.sparse.issparse(adata_ref.X):
174 |             ref_counts = adata_ref.X.toarray()
175 |         else:
176 |             ref_counts = adata_ref.X
177 | 
178 |         # Generate count matrix for single cells
179 |         ref_counts = torch.from_numpy(ref_counts).to(torch.float32).to(self.device) # N_cells x G
180 | 
181 |         celltype_onehot = np.zeros((adata_ref.shape[0], len(celltype_list)))
182 |         for i in range(adata_ref.shape[0]):
183 |             celltype_onehot[i, celltype_list.index(list(adata_ref.obs[celltype_ref_col].values)[i])] += 1.
184 | 
185 |         # Generate one-hot cell-type matrix for single cells
186 |         celltype_onehot = torch.from_numpy(celltype_onehot).to(torch.float32).to(self.device) # N_cells x C
187 | 
188 |         # Generate adjusted expression matrix for spatial spots
189 |         Y_adjusted = (torch.matmul(self.beta, self.basis) * self.lY).detach() # N_spots x G
190 |         beta = self.beta.detach() # N_spots x C
191 | 
192 |         M = torch.zeros(adata_ref.shape[0], self.Y.shape[0]) # N_cells x N_spots
193 |         M = M.to(self.device)
194 |         M.requires_grad = True
195 | 
196 |         if target_num is None:
197 |             # default: the average number of cells per spot implied by the data
198 |             target_num = adata_ref.shape[0] / self.Y.shape[0]
199 | 
200 |         self.optimizer_M = optim.Adamax([M], lr=lr)
201 | 
202 |         for step in tqdm(range(training_steps_M)):
203 |             M_hat = F.softmax(M, dim=1) # N_cells x N_spots
204 | 
205 |             generated_spots = torch.matmul(torch.transpose(M_hat, 0, 1), ref_counts) # N_spots x G
206 |             loss_sim_spots = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=1))
207 |             loss_sim_genes = - torch.mean(F.cosine_similarity(Y_adjusted, generated_spots, dim=0))
208 | 
209 |             generated_spots_prop = torch.matmul(torch.transpose(M_hat, 0, 1), celltype_onehot) # N_spots x C
210 |             generated_spots_prop = generated_spots_prop / torch.sum(M_hat, axis=0).view(-1, 1) # Normalize generated proportions
211 |             loss_prop = torch.mean(torch.sum((generated_spots_prop - beta) ** 2, dim=1))
212 | 
213 |             # regularizers
214 |             reg_cell_num = torch.mean((torch.sum(M_hat, axis=0) - target_num)**2)
215 |             reg_M = -torch.mean(M_hat * torch.log(M_hat))
216 |             loss_M = loss_prop + lam_sim * (loss_sim_spots + loss_sim_genes) + lam_num * reg_cell_num + lam_M * reg_M
217 |             self.optimizer_M.zero_grad()
218 |             loss_M.backward()
219 |             self.optimizer_M.step()
220 | 
221 |             if report_loss:
222 |                 if not step % step_interval:
223 |                     print("Step: %s, Loss: %.4f, proportion_loss: %.4f, spot_sim_loss: %.4f, cell_num_reg: %.4f, M_reg: %.4f" %
224 |                           (step, loss_M.item(), loss_prop.item(), (loss_sim_spots + loss_sim_genes).item(), reg_cell_num.item(), reg_M.item()))
225 | 
226 |         M_hat = F.softmax(M, dim=1)
227 |         self.M_hat = M_hat.detach().cpu().numpy()
228 |         # Hard-assign each cell to its most probable spot and inherit that spot's coordinates
229 |         adata_ref.obsm['spatial_aligned'] = self.adata_st[np.argmax(self.M_hat, axis=1)].obsm['spatial_aligned']
230 |         adata_ref.obsm['3D_coor'] = self.adata_st[np.argmax(self.M_hat, axis=1)].obsm['3D_coor']
231 |         adata_ref.obs['slice'] = self.adata_st[np.argmax(self.M_hat, axis=1)].obs['slice'].values
232 | 
233 |         return adata_ref
234 | 
235 | 
--------------------------------------------------------------------------------
/STitch3D/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import numpy as np
6 | 
7 | 
8 | class DenseLayer(nn.Module):
9 | 
10 |     def __init__(self,
11 |                  c_in, # dimensionality of input features
12 |                  c_out, # dimensionality of output features
13 |                  zero_init=False, # initialize weights as zeros; use Xavier uniform init if zero_init=False
14 |                  ):
15 | 
16 |         super().__init__()
17 | 
18 |         self.linear = nn.Linear(c_in, c_out)
19 | 
20 |         # Initialization
21 |         if zero_init:
22 |             nn.init.zeros_(self.linear.weight.data)
23 |         else:
24 |             nn.init.uniform_(self.linear.weight.data, -np.sqrt(6 / (c_in + c_out)), np.sqrt(6 / (c_in + c_out)))
25 |         nn.init.zeros_(self.linear.bias.data)
26 | 
27 |     def forward(self,
28 |                 node_feats, # input node features
29 |                 ):
30 | 
31 |         node_feats = self.linear(node_feats)
32 | 
33 |         return node_feats
34 | 
35 | 
36 | class GATSingleHead(nn.Module):
37 | 
38 |     def __init__(self,
39 |                  c_in, # dimensionality of input features
40 |                  c_out, # dimensionality of output features
41 |                  temp=1, # temperature parameter
42 |                  ):
43 | 
44 |         super().__init__()
45 | 
46 |         self.linear = nn.Linear(c_in, c_out)
47 |         self.v0 = nn.Parameter(torch.Tensor(c_out, 1))
48 |         self.v1 = nn.Parameter(torch.Tensor(c_out, 1))
49 |         self.temp = temp
50 | 
51 |         # Initialization
52 |         nn.init.uniform_(self.linear.weight.data, -np.sqrt(6 / (c_in + c_out)), np.sqrt(6 / (c_in + c_out)))
53 |         nn.init.zeros_(self.linear.bias.data)
54 |         nn.init.uniform_(self.v0.data, -np.sqrt(6 / (c_out + 1)), np.sqrt(6 / (c_out + 1)))
55 |         nn.init.uniform_(self.v1.data, -np.sqrt(6 / (c_out + 1)), np.sqrt(6 / (c_out + 1)))
56 | 
57 |     def forward(self,
58 |                 node_feats, # input node features
59 |                 adj_matrix, # adjacency matrix including self-connections
60 |                 ):
61 | 
62 |         # Apply the linear layer and compute attention logits over graph edges
63 |         node_feats = self.linear(node_feats)
64 |         f1 = torch.matmul(node_feats, self.v0)
65 |         f2 = torch.matmul(node_feats, self.v1)
66 |         attn_logits = adj_matrix * (f1 + f2.T)
67 |         unnormalized_attentions = (torch.sigmoid(attn_logits) - 0.5).to_sparse()
68 |         attn_probs = torch.sparse.softmax(unnormalized_attentions / self.temp, dim=1)
69 |         attn_probs = attn_probs.to_dense()
70 |         node_feats = torch.matmul(attn_probs, node_feats)
71 | 
72 |         return node_feats
73 | 
74 | 
75 | class GATMultiHead(nn.Module):
76 | 
77 |     def __init__(self,
78 |                  c_in, # dimensionality of input features
79 |                  c_out, # dimensionality of output features
80 |                  n_heads=1, # number of attention heads
81 |                  concat_heads=True, # concatenate attention heads or not
82 |                  ):
83 | 
84 |         super().__init__()
85 | 
86 |         self.n_heads = n_heads
87 |         self.concat_heads = concat_heads
88 |         if self.concat_heads:
89 |             assert c_out % n_heads == 0, "The number of output features should be divisible by the number of heads."
90 |             c_out = c_out // n_heads
91 | 
92 |         self.block = nn.ModuleList()
93 |         for i_block in range(self.n_heads):
94 |             self.block.append(GATSingleHead(c_in=c_in, c_out=c_out))
95 | 
96 |     def forward(self,
97 |                 node_feats, # input node features
98 |                 adj_matrix, # adjacency matrix including self-connections
99 |                 ):
100 | 
101 |         res = []
102 |         for i_block in range(self.n_heads):
103 |             res.append(self.block[i_block](node_feats, adj_matrix))
104 | 
105 |         if self.concat_heads:
106 |             node_feats = torch.cat(res, dim=1)
107 |         else:
108 |             node_feats = torch.mean(torch.stack(res, dim=0), dim=0)
109 | 
110 |         return node_feats
111 | 
112 | 
113 | class DeconvNet(nn.Module):
114 | 
115 |     def __init__(self,
116 |                  hidden_dims, # dimensionality of hidden layers
117 |                  n_celltypes, # number of cell types
118 |                  n_slices, # number of slices
119 |                  n_heads, # number of attention heads
120 |                  slice_emb_dim, # dimensionality of slice id embedding
121 |                  coef_fe,
122 |                  ):
123 | 
124 |         super().__init__()
125 | 
126 |         # define layers
127 |         # encoder layers
128 |         self.encoder_layer1 = GATMultiHead(hidden_dims[0], hidden_dims[1], n_heads=n_heads, concat_heads=True)
129 |         self.encoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[2])
130 |         # decoder layers
131 |         self.decoder_layer1 = GATMultiHead(hidden_dims[2] + slice_emb_dim, hidden_dims[1], n_heads=n_heads, concat_heads=True)
132 |         self.decoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[0])
133 |         # deconvolution layers
134 |         self.deconv_alpha_layer = DenseLayer(hidden_dims[2] + slice_emb_dim, 1, zero_init=True)
135 |         self.deconv_beta_layer = DenseLayer(hidden_dims[2], n_celltypes, zero_init=True)
136 | 
137 |         self.gamma = nn.Parameter(torch.Tensor(n_slices, hidden_dims[0]).zero_())
138 | 
139 |         self.slice_emb = nn.Embedding(n_slices, slice_emb_dim)
140 | 
141 |         self.coef_fe = coef_fe
142 | 
143 |     def forward(self,
144 |                 adj_matrix, # adjacency matrix including self-connections
145 |                 node_feats, # input node features
146 |                 count_matrix, # gene expression counts
147 |                 library_size, # library size (based on Y)
148 |                 slice_label, # slice label
149 |                 basis, # basis matrix
150 |                 ):
151 |         # encoder
152 |         Z = self.encoder(adj_matrix, node_feats)
153 | 
154 |         # deconvolutioner
155 |         slice_label_emb = self.slice_emb(slice_label)
156 |         beta, alpha = self.deconvolutioner(Z, slice_label_emb)
157 | 
158 |         # decoder
159 |         node_feats_recon = self.decoder(adj_matrix, Z, slice_label_emb)
160 | 
161 |         # reconstruction loss of node features
162 |         self.features_loss = torch.mean(torch.sqrt(torch.sum(torch.pow(node_feats-node_feats_recon, 2), axis=1)))
163 | 
164 |         # deconvolution loss
165 |         log_lam = torch.log(torch.matmul(beta, basis) + 1e-6) + alpha + self.gamma[slice_label]
166 |         lam = torch.exp(log_lam)
167 |         self.decon_loss = - torch.mean(torch.sum(count_matrix *
168 |                           (torch.log(library_size + 1e-6) + log_lam) - library_size * lam, axis=1))
169 | 
170 |         # Total loss
171 |         loss = self.decon_loss + self.coef_fe * self.features_loss
172 | 
173 |         return loss
174 | 
175 |     def evaluate(self, adj_matrix, node_feats, slice_label):
176 |         slice_label_emb = self.slice_emb(slice_label)
177 |         # encoder
178 |         Z = self.encoder(adj_matrix, node_feats)
179 | 
180 |         # deconvolutioner
181 |         beta, alpha = self.deconvolutioner(Z, slice_label_emb)
182 | 
183 |         return Z, beta, alpha, self.gamma
184 | 
185 |     def encoder(self, adj_matrix, node_feats):
186 |         H = node_feats
187 |         H = F.elu(self.encoder_layer1(H, adj_matrix))
188 |         Z = self.encoder_layer2(H)
189 |         return Z
190 | 
191 |     def decoder(self, adj_matrix, Z, slice_label_emb):
192 |         H = torch.cat((Z, slice_label_emb), axis=1)
193 |         H = F.elu(self.decoder_layer1(H, adj_matrix))
194 |         X_recon = self.decoder_layer2(H)
195 |         return X_recon
196 | 
197 |     def deconvolutioner(self, Z, slice_label_emb):
198 |         beta = self.deconv_beta_layer(F.elu(Z))
199 |         beta = F.softmax(beta, dim=1)
200 |         H = F.elu(torch.cat((Z, slice_label_emb), axis=1))
201 |         alpha = self.deconv_alpha_layer(H)
202 |         return beta, alpha
203 | 
204 | 
205 | class DeconvNet_NB(nn.Module):
206 | 
207 |     def __init__(self,
208 |                  hidden_dims, # dimensionality of hidden layers
209 |                  n_celltypes, # number of cell types
210 |                  n_slices, # number of slices
211 |                  n_heads, # number of attention heads
212 |                  slice_emb_dim, # dimensionality of slice id embedding
213 |                  coef_fe,
214 |                  ):
215 | 
216 |         super().__init__()
217 | 
218 |         # define layers
219 |         # encoder layers
220 |         self.encoder_layer1 = GATMultiHead(hidden_dims[0], hidden_dims[1], n_heads=n_heads, concat_heads=True)
221 |         self.encoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[2])
222 |         # decoder layers
223 |         self.decoder_layer1 = GATMultiHead(hidden_dims[2] + slice_emb_dim, hidden_dims[1], n_heads=n_heads, concat_heads=True)
224 |         self.decoder_layer2 = DenseLayer(hidden_dims[1], hidden_dims[0])
225 |         # deconvolution layers
226 |         self.deconv_alpha_layer = DenseLayer(hidden_dims[2] + slice_emb_dim, 1, zero_init=True)
227 |         self.deconv_beta_layer = DenseLayer(hidden_dims[2], n_celltypes, zero_init=True)
228 | 
229 |         self.gamma = nn.Parameter(torch.Tensor(n_slices, hidden_dims[0]).zero_())
230 |         self.logtheta = nn.Parameter(5. * torch.ones(n_slices, hidden_dims[0]))
231 | 
232 |         self.slice_emb = nn.Embedding(n_slices, slice_emb_dim)
233 | 
234 |         self.coef_fe = coef_fe
235 | 
236 |     def forward(self,
237 |                 adj_matrix, # adjacency matrix including self-connections
238 |                 node_feats, # input node features
239 |                 count_matrix, # gene expression counts
240 |                 library_size, # library size (based on Y)
241 |                 slice_label, # slice label
242 |                 basis, # basis matrix
243 |                 ):
244 |         # encoder
245 |         Z = self.encoder(adj_matrix, node_feats)
246 | 
247 |         # deconvolutioner
248 |         slice_label_emb = self.slice_emb(slice_label)
249 |         beta, alpha = self.deconvolutioner(Z, slice_label_emb)
250 | 
251 |         # decoder
252 |         node_feats_recon = self.decoder(adj_matrix, Z, slice_label_emb)
253 | 
254 |         # reconstruction loss of node features
255 |         self.features_loss = torch.mean(torch.sqrt(torch.sum(torch.pow(node_feats-node_feats_recon, 2), axis=1)))
256 | 
257 |         # deconvolution loss
258 |         log_lam = torch.log(torch.matmul(beta, basis) + 1e-6) + alpha + self.gamma[slice_label]
259 |         lam = torch.exp(log_lam)
260 |         theta = torch.exp(self.logtheta)
261 |         self.decon_loss = - torch.mean(torch.sum(torch.lgamma(count_matrix + theta[slice_label] + 1e-6) -
262 |                           torch.lgamma(theta[slice_label] + 1e-6) +
263 |                           theta[slice_label] * torch.log(theta[slice_label] + 1e-6) -
264 |                           theta[slice_label] * torch.log(theta[slice_label] + library_size * lam + 1e-6) +
265 |                           count_matrix * torch.log(library_size * lam + 1e-6) -
266 |                           count_matrix * torch.log(theta[slice_label] + library_size * lam + 1e-6), axis=1))
267 | 
268 |         # Total loss
269 |         loss = self.decon_loss + self.coef_fe * self.features_loss
270 | 
271 |         return loss
272 | 
273 |     def evaluate(self, adj_matrix, node_feats, slice_label):
274 |         slice_label_emb = self.slice_emb(slice_label)
275 |         # encoder
276 |         Z = self.encoder(adj_matrix, node_feats)
277 | 
278 |         # deconvolutioner
279 |         beta, alpha = self.deconvolutioner(Z, slice_label_emb)
280 | 
281 |         return Z, beta, alpha, self.gamma
282 | 
283 |     def encoder(self, adj_matrix, node_feats):
284 |         H = node_feats
285 |         H = F.elu(self.encoder_layer1(H, adj_matrix))
286 |         Z = self.encoder_layer2(H)
287 |         return Z
288 | 
289 |     def decoder(self, adj_matrix, Z, slice_label_emb):
290 |         H = torch.cat((Z, slice_label_emb), axis=1)
291 |         H = F.elu(self.decoder_layer1(H, adj_matrix))
292 |         X_recon = self.decoder_layer2(H)
293 |         return X_recon
294 | 
295 |     def deconvolutioner(self, Z, slice_label_emb):
296 |         beta = self.deconv_beta_layer(F.elu(Z))
297 |         beta = F.softmax(beta, dim=1)
298 |         H = F.elu(torch.cat((Z, slice_label_emb), axis=1))
299 |         alpha = self.deconv_alpha_layer(H)
300 |         return beta, alpha
301 | 
302 | 
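# Note on the two likelihoods implemented above: in DeconvNet, decon_loss is the
# negative Poisson log-likelihood (up to an additive constant in the counts) of
#     y_ig ~ Poisson(l_i * lam_ig),
#     log lam_ig = log(sum_c beta_ic * b_cg) + alpha_i + gamma_{s(i),g},
# where beta are cell-type proportions, b is the basis matrix, alpha is a spot
# effect, and gamma is a slice-by-gene effect. DeconvNet_NB replaces the Poisson
# with a negative binomial whose per-slice, per-gene dispersion is
# theta = exp(logtheta); its decon_loss is likewise the negative log-likelihood
# up to an additive constant in the counts.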
--------------------------------------------------------------------------------
/STitch3D/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scanpy as sc
3 | import anndata as ad
4 | import pandas as pd
5 | import scipy.sparse
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | from STitch3D.align_tools import *
9 | from sklearn.neighbors import NearestNeighbors
10 | from sklearn.metrics import pairwise_distances
11 | from matplotlib import cm
12 | 
13 | 
14 | def align_spots(adata_st_list_input, # list of spatial transcriptomics datasets
15 |                 method="icp", # "icp" or "paste"
16 |                 data_type="Visium", # a spot has six nearest neighbors if "Visium", four nearest neighbors otherwise
17 |                 coor_key="spatial", # "spatial" for Visium; key for the spatial coordinates used for alignment
18 |                 tol=0.01, # parameter for "icp" method; tolerance level
19 |                 test_all_angles=False, # parameter for "icp" method; whether to test multiple rotation angles or not
20 |                 plot=False,
21 |                 paste_alpha=0.1,
22 |                 paste_dissimilarity="kl"
23 |                 ):
24 |     # Align coordinates of spatial transcriptomics
25 | 
26 |     # The first adata in the list is used as a reference for alignment
27 |     adata_st_list = adata_st_list_input.copy()
28 | 
29 |     if plot:
30 |         # Choose colors
31 |         cmap = cm.get_cmap('rainbow', len(adata_st_list))
32 |         colors_list = [matplotlib.colors.rgb2hex(cmap(i)) for i in range(len(adata_st_list))]
33 | 
34 |         # Plot spots before alignment
35 |         plt.figure(figsize=(5, 5))
36 |         plt.title("Before alignment")
37 |         for i in range(len(adata_st_list)):
38 |             plt.scatter(adata_st_list[i].obsm[coor_key][:, 0],
39 |                         adata_st_list[i].obsm[coor_key][:, 1],
40 |                         c=colors_list[i],
41 |                         label="Slice %d spots" % i, s=5., alpha=0.5)
42 |         ax = plt.gca()
43 |         ax.set_ylim(ax.get_ylim()[::-1])
44 |         plt.xticks([])
45 |         plt.yticks([])
46 |         plt.legend(loc=(1.02, .2), ncol=(len(adata_st_list)//13 + 1))
47 |         plt.show()
48 | 
49 | 
50 |     if (method == "icp") or (method == "ICP"):
51 |         print("Using the Iterative Closest Point algorithm for alignment.")
52 |         # Detect edges
53 |         print("Detecting edges...")
54 |         point_cloud_list = []
55 |         for adata in adata_st_list:
56 |             # Use in-tissue spots only
57 |             if 'in_tissue' in adata.obs.columns:
58 |                 adata = adata[adata.obs['in_tissue'] == 1]
59 |             if data_type == "Visium":
60 |                 loc_x = adata.obs.loc[:, ["array_row"]]
61 |                 loc_x = np.array(loc_x) * np.sqrt(3)
62 |                 loc_y = adata.obs.loc[:, ["array_col"]]
63 |                 loc_y = np.array(loc_y)
64 |                 loc = np.concatenate((loc_x, loc_y), axis=1)
65 |                 pairwise_loc_distsq = np.sum((loc.reshape([1,-1,2]) - loc.reshape([-1,1,2])) ** 2, axis=2)
66 |                 n_neighbors = np.sum(pairwise_loc_distsq < 5, axis=1) - 1
67 |                 edge = ((n_neighbors > 1) & (n_neighbors < 5)).astype(np.float32)
68 |             else:
69 |                 loc_x = adata.obs.loc[:, ["array_row"]]
70 |                 loc_x = np.array(loc_x)
71 |                 loc_y = adata.obs.loc[:, ["array_col"]]
72 |                 loc_y = np.array(loc_y)
73 |                 loc = np.concatenate((loc_x, loc_y), axis=1)
74 |                 pairwise_loc_distsq = np.sum((loc.reshape([1,-1,2]) - loc.reshape([-1,1,2])) ** 2, axis=2)
75 |                 min_distsq = np.sort(np.unique(pairwise_loc_distsq), axis=None)[1]
76 |                 n_neighbors = np.sum(pairwise_loc_distsq < (min_distsq * 3), axis=1) - 1
77 |                 edge = ((n_neighbors > 1) & (n_neighbors < 7)).astype(np.float32)
78 |             point_cloud_list.append(adata.obsm[coor_key][edge == 1].copy())
79 | 
80 |         # Align edges
81 |         print("Aligning edges...")
82 |         trans_list = []
83 |         adata_st_list[0].obsm["spatial_aligned"] = adata_st_list[0].obsm[coor_key].copy()
84 |         # Calculate pairwise transformation matrices
85 |         for i in range(len(adata_st_list) - 1):
86 |             if test_all_angles == True:
87 |                 for angle in [0., np.pi * 1 / 3, np.pi * 2 / 3, np.pi, np.pi * 4 / 3, np.pi * 5 / 3]:
88 |                     R = np.array([[np.cos(angle), np.sin(angle), 0],
89 |                                   [-np.sin(angle), np.cos(angle), 0],
90 |                                   [0, 0, 1]]).T
91 |                     T, distances, _ = icp(transform(point_cloud_list[i+1], R), point_cloud_list[i], tolerance=tol)
92 |                     if angle == 0:
93 |                         loss_best = np.mean(distances)
94 |                         angle_best = angle
95 |                         R_best = R
96 |                         T_best = T
97 |                     else:
98 |                         if np.mean(distances) < loss_best:
99 |                             loss_best = np.mean(distances)
100 |                             angle_best = angle
101 |                             R_best = R
102 |                             T_best = T
103 |                 T = T_best @ R_best
104 |             else:
105 |                 T, _, _ = icp(point_cloud_list[i+1], point_cloud_list[i], tolerance=tol)
106 |             trans_list.append(T)
107 |         # Transform
108 |         for i in range(len(adata_st_list) - 1):
109 |             point_cloud_align = adata_st_list[i+1].obsm[coor_key].copy()
110 |             for T in trans_list[:(i+1)][::-1]:
111 |                 point_cloud_align = transform(point_cloud_align, T)
112 |             adata_st_list[i+1].obsm["spatial_aligned"] = point_cloud_align
113 | 
114 |     elif (method == "paste") or (method == "PASTE"):
115 |         print("Using the PASTE algorithm for alignment.")
116 |         # Align spots
117 |         print("Aligning spots...")
118 |         pis = []
119 |         # Calculate pairwise transformation matrices
120 |         for i in range(len(adata_st_list) - 1):
121 |             pi = pairwise_align_paste(adata_st_list[i], adata_st_list[i+1], coor_key=coor_key,
122 |                                       alpha = paste_alpha, dissimilarity = paste_dissimilarity)
123 |             pis.append(pi)
124 |         # Transform
125 |         S1, S2 = generalized_procrustes_analysis(adata_st_list[0].obsm[coor_key],
126 |                                                  adata_st_list[1].obsm[coor_key],
127 |                                                  pis[0])
128 |         adata_st_list[0].obsm["spatial_aligned"] = S1
129 |         adata_st_list[1].obsm["spatial_aligned"] = S2
130 |         for i in range(1, len(adata_st_list) - 1):
131 |             S1, S2 = generalized_procrustes_analysis(adata_st_list[i].obsm["spatial_aligned"],
132 |                                                      adata_st_list[i+1].obsm[coor_key],
133 |                                                      pis[i])
134 |             adata_st_list[i+1].obsm["spatial_aligned"] = S2
135 | 
136 |     if plot:
137 |         plt.figure(figsize=(5, 5))
138 |         plt.title("After alignment")
139 |         for i in range(len(adata_st_list)):
140 |             plt.scatter(adata_st_list[i].obsm["spatial_aligned"][:, 0],
141 |                         adata_st_list[i].obsm["spatial_aligned"][:, 1],
142 |                         c=colors_list[i],
143 |                         label="Slice %d spots" % i, s=5., alpha=0.5)
144 |         ax = plt.gca()
145 |         ax.set_ylim(ax.get_ylim()[::-1])
146 |         plt.xticks([])
147 |         plt.yticks([])
148 |         plt.legend(loc=(1.02, .2), ncol=(len(adata_st_list)//13 + 1))
149 |         plt.show()
150 | 
151 |     return adata_st_list
152 | 
153 | 
154 | def preprocess(adata_st_list_input, # list of spatial transcriptomics (ST) anndata objects
155 |                adata_ref_input, # reference single-cell anndata object
156 |                celltype_ref_col="celltype", # column of adata_ref_input.obs for cell type information
157 |                sample_col=None, # column of adata_ref_input.obs for batch labels
158 |                celltype_ref=None, # specify cell types to use for deconvolution
159 |                n_hvg_group=500, # number of highly variable genes for reference anndata
160 |                three_dim_coor=None, # if not None, use existing 3d coordinates in shape [# of total spots, 3]
161 |                coor_key="spatial_aligned", # "spatial_aligned" by default
162 |                rad_cutoff=None, # cutoff radius of spots for building graph
163 |                rad_coef=1.1, # if rad_cutoff=None, rad_cutoff is the minimum distance between spots multiplied by rad_coef
164 |                slice_dist_micron=None, # pairwise distances in micrometers for reconstructing the z-axis
165 |                prune_graph_cos=False, # prune graph connections according to cosine similarity
166 |                cos_threshold=0.5, # threshold for pruning graph connections
167 |                c2c_dist=100, # center-to-center distance between nearest spots in micrometers
168 |                ):
169 | 
170 |     adata_st_list = adata_st_list_input.copy()
171 | 
172 |     print("Finding highly variable genes...")
173 |     adata_ref = adata_ref_input.copy()
174 |     adata_ref.var_names_make_unique()
175 |     # Remove mt-genes
176 |     adata_ref = adata_ref[:, np.array(~adata_ref.var.index.isna())
177 |                           & np.array(~adata_ref.var_names.str.startswith("mt-"))
178 |                           & np.array(~adata_ref.var_names.str.startswith("MT-"))]
179 |     if celltype_ref is not None:
180 |         if not isinstance(celltype_ref, list):
181 |             raise ValueError("'celltype_ref' must be a list!")
182 |         else:
183 |             adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
184 |     else:
185 |         celltype_counts = adata_ref.obs[celltype_ref_col].value_counts()
186 |         celltype_ref = list(celltype_counts.index[celltype_counts > 1])
187 |         adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
188 | 
189 |     # Remove cells and genes with 0 counts
190 |     sc.pp.filter_cells(adata_ref, min_genes=1)
191 |     sc.pp.filter_genes(adata_ref, min_cells=1)
192 | 
193 |     # Concatenate ST adatas
194 |     for i in range(len(adata_st_list)):
195 |         adata_st_new = adata_st_list[i].copy()
196 |         adata_st_new.var_names_make_unique()
197 |         # Remove mt-genes
198 |         adata_st_new = adata_st_new[:, (np.array(~adata_st_new.var.index.str.startswith("mt-"))
199 |                                         & np.array(~adata_st_new.var.index.str.startswith("MT-")))]
200 |         adata_st_new.obs.index = adata_st_new.obs.index + "-slice%d" % i
201 |         adata_st_new.obs['slice'] = i
202 |         if i == 0:
203 |             adata_st = adata_st_new
204 |         else:
205 |             genes_shared = adata_st.var.index.intersection(adata_st_new.var.index)
206 |             adata_st = adata_st[:, genes_shared].concatenate(adata_st_new[:, genes_shared], index_unique=None)
207 | 
208 |     adata_st.obs["slice"] = adata_st.obs["slice"].values.astype(int)
209 | 
210 |     # Take gene intersection
211 |     genes = list(adata_st.var.index.intersection(adata_ref.var.index))
212 |     adata_ref = adata_ref[:, genes]
213 |     adata_st = adata_st[:, genes]
214 | 
215 |     # Select hvgs
216 |     adata_ref_log = adata_ref.copy()
217 |     sc.pp.log1p(adata_ref_log)
218 |     hvgs = select_hvgs(adata_ref_log, celltype_ref_col=celltype_ref_col, num_per_group=n_hvg_group)
219 | 
220 |     print("%d highly variable genes selected." % len(hvgs))
221 |     adata_ref = adata_ref[:, hvgs]
222 | 
223 |     print("Calculating basis for deconvolution...")
224 |     sc.pp.filter_cells(adata_ref, min_genes=1)
225 |     sc.pp.normalize_total(adata_ref, target_sum=1)
226 |     celltype_list = list(sorted(set(adata_ref.obs[celltype_ref_col].values.astype(str))))
227 | 
228 |     basis = np.zeros((len(celltype_list), len(adata_ref.var.index)))
229 |     if sample_col is not None:
230 |         sample_list = list(sorted(set(adata_ref.obs[sample_col].values.astype(str))))
231 |         for i in range(len(celltype_list)):
232 |             c = celltype_list[i]
233 |             tmp_list = []
234 |             for j in range(len(sample_list)):
235 |                 s = sample_list[j]
236 |                 tmp = adata_ref[(adata_ref.obs[celltype_ref_col].values.astype(str) == c) &
237 |                                 (adata_ref.obs[sample_col].values.astype(str) == s), :].X
238 |                 if scipy.sparse.issparse(tmp):
239 |                     tmp = tmp.toarray()
240 |                 if tmp.shape[0] >= 3:
241 |                     tmp_list.append(np.mean(tmp, axis=0).reshape((-1)))
242 |             tmp_mean = np.mean(tmp_list, axis=0)
243 |             if scipy.sparse.issparse(tmp_mean):
244 |                 tmp_mean = tmp_mean.toarray()
245 |             print("%d batches are used for computing the basis vector of cell type <%s>." % (len(tmp_list), c))
246 |             basis[i, :] = tmp_mean
247 |     else:
248 |         for i in range(len(celltype_list)):
249 |             c = celltype_list[i]
250 |             tmp = adata_ref[adata_ref.obs[celltype_ref_col].values.astype(str) == c, :].X
251 |             if scipy.sparse.issparse(tmp):
252 |                 tmp = tmp.toarray()
253 |             basis[i, :] = np.mean(tmp, axis=0).reshape((-1))
254 | 
255 |     adata_basis = ad.AnnData(X=basis)
256 |     df_gene = pd.DataFrame({"gene": adata_ref.var.index})
257 |     df_gene = df_gene.set_index("gene")
258 |     df_celltype = pd.DataFrame({"celltype": celltype_list})
259 |     df_celltype = df_celltype.set_index("celltype")
260 |     adata_basis.obs = df_celltype
261 |     adata_basis.var = df_gene
262 |     adata_basis = adata_basis[~np.isnan(adata_basis.X[:, 0])]
263 | 
264 |     print("Preprocessing ST data...")
265 |     # Store counts and library sizes for Poisson modeling
266 |     st_mtx = adata_st[:, hvgs].X.copy()
267 |     if scipy.sparse.issparse(st_mtx):
268 |         st_mtx = st_mtx.toarray()
269 |     adata_st.obsm["count"] = st_mtx
270 |     st_library_size = np.sum(st_mtx, axis=1)
271 |     adata_st.obs["library_size"] = st_library_size
272 | 
273 |     # Normalize ST data
274 |     sc.pp.normalize_total(adata_st, target_sum=1e4)
275 |     sc.pp.log1p(adata_st)
276 |     adata_st = adata_st[:, hvgs]
277 |     if scipy.sparse.issparse(adata_st.X):
278 |         adata_st.X = adata_st.X.toarray()
279 | 
280 |     # Build a graph for spots across multiple slices
281 |     print("Start building a graph...")
282 | 
283 |     # Build 3D coordinates
284 |     if three_dim_coor is None:
285 | 
286 |         # The first adata in adata_st_list is used as a reference for computing the cutoff radius of spots
287 |         adata_st_ref = adata_st_list[0].copy()
288 |         loc_ref = np.array(adata_st_ref.obsm[coor_key])
289 |         pair_dist_ref = pairwise_distances(loc_ref)
290 |         min_dist_ref = np.sort(np.unique(pair_dist_ref), axis=None)[1]
291 | 
292 |         if rad_cutoff is None:
293 |             # The radius is computed based on the attribute "adata.obsm['spatial']"
294 |             rad_cutoff = min_dist_ref * rad_coef
295 |         print("Radius for graph connection is %.4f." % rad_cutoff)
296 | 
297 |         # Use the attribute "adata.obsm['spatial_aligned']" to build a global graph
298 |         if slice_dist_micron is None:
299 |             loc_xy = pd.DataFrame(adata_st.obsm['spatial_aligned']).values
300 |             loc_z = np.zeros(adata_st.shape[0])
301 |             loc = np.concatenate([loc_xy, loc_z.reshape(-1, 1)], axis=1)
302 |         else:
303 |             if len(slice_dist_micron) != (len(adata_st_list) - 1):
304 |                 raise ValueError("The length of 'slice_dist_micron' should be the number of adatas minus 1!")
305 |             else:
306 |                 loc_xy = pd.DataFrame(adata_st.obsm['spatial_aligned']).values
307 |                 loc_z = np.zeros(adata_st.shape[0])
308 |                 dim = 0
309 |                 for i in range(len(slice_dist_micron)):
310 |                     dim += adata_st_list[i].shape[0]
311 |                     loc_z[dim:] += slice_dist_micron[i] * (min_dist_ref / c2c_dist)
312 |                 loc = np.concatenate([loc_xy, loc_z.reshape(-1, 1)], axis=1)
313 | 
314 |     # If 3D coordinates already exist
315 |     else:
316 |         if rad_cutoff is None:
317 |             raise ValueError("Please specify 'rad_cutoff' for finding 3D neighbors!")
318 |         loc = three_dim_coor
319 | 
320 |     pair_dist = pairwise_distances(loc)
321 |     G = (pair_dist < rad_cutoff).astype(float)
322 | 
323 |     if prune_graph_cos:
324 |         pair_dist_cos = pairwise_distances(adata_st.X, metric="cosine") # 1 - cosine_similarity
325 |         G_cos = (pair_dist_cos < (1 - cos_threshold)).astype(float)
326 |         G = G * G_cos
327 | 
328 |     print('%.4f neighbors per spot on average.' % (np.mean(np.sum(G, axis=1)) - 1))
329 |     adata_st.obsm["graph"] = G
330 |     adata_st.obsm["3D_coor"] = loc
331 | 
332 |     return adata_st, adata_basis
333 | 
334 | 
335 | def select_hvgs(adata_ref, celltype_ref_col, num_per_group=200):
336 |     sc.tl.rank_genes_groups(adata_ref, groupby=celltype_ref_col, method="t-test", key_added="ttest", use_raw=False)
337 |     markers_df = pd.DataFrame(adata_ref.uns['ttest']['names']).iloc[0:num_per_group, :]
338 |     genes = sorted(list(np.unique(markers_df.melt().value.values)))
339 |     return genes
340 | 
341 | 
342 | def calculate_impubasis(adata_st_input, # ST anndata object (one of the outputs of STitch3D.utils.preprocess)
343 |                         adata_ref_input, # reference single-cell anndata object (raw data)
344 |                         celltype_ref_col="celltype", # column of adata_ref_input.obs for cell type information
345 |                         sample_col=None, # column of adata_ref_input.obs for batch labels
346 |                         celltype_ref=None, # specify cell types to use for deconvolution
347 |                         ):
348 | 
349 |     adata_ref = adata_ref_input.copy()
350 |     adata_ref.var_names_make_unique()
351 |     # Remove mt-genes
352 |     adata_ref = adata_ref[:, np.array(~adata_ref.var.index.isna())
353 |                           & np.array(~adata_ref.var_names.str.startswith("mt-"))
354 |                           & np.array(~adata_ref.var_names.str.startswith("MT-"))]
355 |     if celltype_ref is not None:
356 |         if not isinstance(celltype_ref, list):
357 |             raise ValueError("'celltype_ref' must be a list!")
358 |         else:
359 |             adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
360 |     else:
361 |         celltype_counts = adata_ref.obs[celltype_ref_col].value_counts()
362 |         celltype_ref = list(celltype_counts.index[celltype_counts > 1])
363 |         adata_ref = adata_ref[[(t in celltype_ref) for t in adata_ref.obs[celltype_ref_col].values.astype(str)], :]
364 | 
365 |     # Remove cells and genes with 0 counts
366 |     sc.pp.filter_cells(adata_ref, min_genes=1)
367 |     sc.pp.filter_genes(adata_ref, min_cells=1)
368 | 
369 |     # Calculate single cell library sizes
370 |     hvgs = adata_st_input.var.index
371 |     adata_ref_ls = adata_ref[:, hvgs]
372 |     sc.pp.filter_cells(adata_ref_ls, min_genes=1)
    adata_ref = adata_ref[adata_ref_ls.obs.index, :]
    # ref_ls: library sizes of the reference cells, computed over the HVGs only
    if scipy.sparse.issparse(adata_ref_ls.X):
        ref_ls = np.sum(adata_ref_ls.X.toarray(), axis=1).reshape((-1, 1))
        adata_ref.obsm["forimpu"] = adata_ref.X.toarray() / ref_ls
    else:
        ref_ls = np.sum(adata_ref_ls.X, axis=1).reshape((-1, 1))
        adata_ref.obsm["forimpu"] = adata_ref.X / ref_ls

    # Calculate basis for imputation
    celltype_list = list(sorted(set(adata_ref.obs[celltype_ref_col].values.astype(str))))
    basis_impu = np.zeros((len(celltype_list), len(adata_ref.var.index)))
    if sample_col is not None:
        # Average within each batch first, then across batches with at least 3 cells of the type
        sample_list = list(sorted(set(adata_ref.obs[sample_col].values.astype(str))))
        for i in range(len(celltype_list)):
            c = celltype_list[i]
            tmp_list = []
            for j in range(len(sample_list)):
                s = sample_list[j]
                tmp = adata_ref[(adata_ref.obs[celltype_ref_col].values.astype(str) == c) &
                                (adata_ref.obs[sample_col].values.astype(str) == s), :].obsm["forimpu"]
                if scipy.sparse.issparse(tmp):
                    tmp = tmp.toarray()
                if tmp.shape[0] >= 3:
                    tmp_list.append(np.mean(tmp, axis=0).reshape((-1)))
            tmp_mean = np.mean(tmp_list, axis=0)
            if scipy.sparse.issparse(tmp_mean):
                tmp_mean = tmp_mean.toarray()
            print("%d batches are used for computing the basis vector of cell type <%s>." % (len(tmp_list), c))
            basis_impu[i, :] = tmp_mean
    else:
        for i in range(len(celltype_list)):
            c = celltype_list[i]
            tmp = adata_ref[adata_ref.obs[celltype_ref_col].values.astype(str) == c, :].obsm["forimpu"]
            if scipy.sparse.issparse(tmp):
                tmp = tmp.toarray()
            basis_impu[i, :] = np.mean(tmp, axis=0).reshape((-1))

    adata_basis_impu = ad.AnnData(X=basis_impu)
    df_gene = pd.DataFrame({"gene": adata_ref.var.index})
    df_gene = df_gene.set_index("gene")
    df_celltype = pd.DataFrame({"celltype": celltype_list})
    df_celltype = df_celltype.set_index("celltype")
    adata_basis_impu.obs = df_celltype
    adata_basis_impu.var = df_gene
    # Drop cell types whose basis vector could not be computed (NaN rows)
    adata_basis_impu = adata_basis_impu[~np.isnan(adata_basis_impu.X[:, 0])]
    return adata_basis_impu

--------------------------------------------------------------------------------
/demos/Overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangLabHKUST/STitch3D/8da29785f6637c496e6660b44c5184ea434a7ec0/demos/Overview.jpg
--------------------------------------------------------------------------------
/demos/mouse_brain_hpc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangLabHKUST/STitch3D/8da29785f6637c496e6660b44c5184ea434a7ec0/demos/mouse_brain_hpc.gif
--------------------------------------------------------------------------------
/demos/mouse_brain_layers.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangLabHKUST/STitch3D/8da29785f6637c496e6660b44c5184ea434a7ec0/demos/mouse_brain_layers.gif
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: stitch3d
channels:
  - conda-forge
  - pytorch
  - bioconda
  - defaults
dependencies:
  - python>=3.7
  - pytorch>=1.6.0,<=1.13.1
  - scanpy=1.7.2
  - anndata=0.7.6
  - pandas=1.1.5
  - numpy>=1.19.0
  - louvain=0.7.0
  - leidenalg>=0.7.0
  - umap-learn>=0.4.6
  - pot>=0.8.0
  - numba>=0.49.1
  - matplotlib<3.7
--------------------------------------------------------------------------------
/plot3D_func.R:
--------------------------------------------------------------------------------
plot3D_proportions <- function(directory,
                               celltypes,
                               celltype_colors,
                               um=c(1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1),
                               axis_rescale=c(1,1,1),
                               spot_radius=0.5,
                               alpha_threshold=0.2,
                               alpha_background=0.02){

  # Load cell-type proportions and 3D coordinates of spots
  file_list <- list.files(path = directory)
  n_slice <- sum(unlist(lapply(file_list, function(x){startsWith(x, "prop_slice")})))
  # Cell-type proportions
  for (i in (0:(n_slice-1))){
    prop <- read.table(paste0(directory, "/prop_slice", i, ".csv"), sep=",", header=TRUE)
    if (i == 0){
      prop_all <- prop
    }else{
      prop_all <- rbind(prop_all, prop)
    }
  }
  colnames(prop_all)[1] <- "spot"
  prop_all <- prop_all[, c("spot", celltypes)]
  # 3D coordinates
  coor_3d <- read.table(paste0(directory, "/3D_coordinates.csv"), sep=",", header=TRUE)
  colnames(coor_3d)[1] <- "spot"

  spots.table <- merge(coor_3d, prop_all, by=c("spot"))

  # Set colors and alpha values for spots
  if (length(celltypes) > 1){
    prop <- spots.table[, celltypes]
  }else{
    prop <- data.frame(ct = spots.table[, celltypes])
    colnames(prop) <- celltypes
  }
  # Dominant cell type per spot (ties resolved by first occurrence)
  prop$max_prop <- apply(prop, 1, max)
  prop$max_celltype <- apply(prop, 1, function(x){celltypes[which.max(x)]})

  prop$color <- "gray"
  prop$alpha <- prop$max_prop

  # Background spots: maximum proportion below the threshold
  prop$alpha[prop$max_prop <= alpha_threshold] <- alpha_background

  # Target cell types
  for (c in 1:length(celltypes)){
    prop$color[(prop$max_prop > alpha_threshold) & (as.vector(prop$max_celltype) == celltypes[c])] <- celltype_colors[c]
  }

  # 3D plot
  open3d(windowRect = c(0, 0, 720, 720))
  par3d(FOV = 30)  # perspective projection
  view3d(userMatrix = matrix(um, byrow=TRUE, nrow=4))
  spheres3d(spots.table$x*axis_rescale[1], spots.table$y*axis_rescale[2], spots.table$z*axis_rescale[3],
            col = prop$color, radius=spot_radius,
            alpha = prop$alpha)
  decorate3d()
}


plot3D_clusters <- function(directory,
                            clusters,
                            cluster_colors,
                            um=c(1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1),
                            axis_rescale=c(1,1,1),
                            spot_radius=0.5,
                            alpha_threshold=0.2,
                            alpha_background=0.02){

  # Load cluster assignments and 3D coordinates of spots
  cluster_df <- read.table(paste0(directory, "/clustering_result.csv"), sep=",", header=TRUE)
  colnames(cluster_df) <- c("spot", "cluster")
  coor_3d <- read.table(paste0(directory, "/3D_coordinates.csv"), sep=",", header=TRUE)
  colnames(coor_3d)[1] <- "spot"

  spots.table <- merge(coor_3d, cluster_df, by=c("spot"))
  spots.table$cluster <- as.integer(spots.table$cluster)

  # 3D plot
  open3d(windowRect = c(0, 0, 720, 720))
  par3d(FOV = 30)  # perspective projection
  view3d(userMatrix = matrix(um, byrow=TRUE, nrow=4))
  for (c in 1:length(clusters)){
    spheres3d(spots.table[spots.table$cluster==clusters[c], ]$x*axis_rescale[1],
              spots.table[spots.table$cluster==clusters[c], ]$y*axis_rescale[2],
              spots.table[spots.table$cluster==clusters[c], ]$z*axis_rescale[3],
              col = cluster_colors[c], radius=spot_radius,
              alpha=1)
  }
  # Background: all spots in gray with low alpha
  spheres3d(spots.table$x*axis_rescale[1], spots.table$y*axis_rescale[2], spots.table$z*axis_rescale[3],
            col='gray', radius=spot_radius,
            alpha=alpha_background)
  decorate3d()
}
--------------------------------------------------------------------------------
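Usage sketches
--------------
The z-offset construction in `STitch3D.utils` above adds each inter-slice distance to every spot below the current slice boundary, so the offsets accumulate into a cumulative sum, rescaled from microns to the aligned coordinate units. A minimal numeric illustration of that loop, with toy spot counts and an assumed rescaling factor:

```python
import numpy as np

# Toy setup: three stacked slices with 4, 3 and 5 spots, 10 microns apart
n_spots = [4, 3, 5]
slice_dist_micron = [10.0, 10.0]
min_dist_ref, c2c_dist = 2.0, 100.0  # assumed rescaling: 100 microns ~ 2 coordinate units

loc_z = np.zeros(sum(n_spots))
dim = 0
for i in range(len(slice_dist_micron)):
    dim += n_spots[i]
    # Every spot below the current slice boundary gets a further offset
    loc_z[dim:] += slice_dist_micron[i] * (min_dist_ref / c2c_dist)

print(loc_z)  # [0. 0. 0. 0. 0.2 0.2 0.2 0.4 0.4 0.4 0.4 0.4]
```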
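The 3D neighbor graph itself is a simple distance threshold: spots within `rad_cutoff` of each other are connected, and `prune_graph_cos` optionally removes edges between transcriptionally dissimilar spots. A self-contained sketch of the same logic on random data (all shapes and threshold values here are illustrative only):

```python
import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
loc = rng.uniform(0, 100, size=(500, 3))  # toy 3D spot coordinates (x, y, z)
expr = rng.poisson(1.0, size=(500, 50))   # toy expression matrix (spots x genes)

rad_cutoff = 12.0    # distance threshold for connecting spots
cos_threshold = 0.5  # keep edges with cosine similarity above this value

# Connect spots closer than rad_cutoff (every spot is also connected to itself)
G = (pairwise_distances(loc) < rad_cutoff).astype(float)

# Optional pruning: drop edges between transcriptionally dissimilar spots
pair_dist_cos = pairwise_distances(expr, metric="cosine")  # 1 - cosine_similarity
G *= (pair_dist_cos < (1 - cos_threshold)).astype(float)

# Subtract 1 to exclude the self-loop, as in utils.py
print('%.4f neighbors per cell on average.' % (np.mean(np.sum(G, axis=1)) - 1))
```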
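`select_hvgs` and `calculate_impubasis` are typically called around the preprocessing step: markers are selected from the reference, preprocessing restricts the ST data to those genes, and the imputation basis is computed afterwards. A sketch of the call sequence; `reference_atlas.h5ad` and the `"celltype"`/`"donor"` columns are placeholders for a real single-cell reference:

```python
import scanpy as sc
import STitch3D

# Placeholder reference atlas with cell-type labels in .obs["celltype"]
adata_ref = sc.read_h5ad("reference_atlas.h5ad")

# Top 200 marker genes per cell type, ranked by t-test (see select_hvgs above)
genes = STitch3D.utils.select_hvgs(adata_ref, celltype_ref_col="celltype", num_per_group=200)

# ... run STitch3D.utils.preprocess(...) here to obtain adata_st restricted to these genes ...

# Cell-type basis for imputation; "donor" stands in for a real batch column
adata_basis_impu = STitch3D.utils.calculate_impubasis(adata_st, adata_ref,
                                                      celltype_ref_col="celltype",
                                                      sample_col="donor")
```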
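`plot3D_proportions` and `plot3D_clusters` read plain CSV files from a results directory: `3D_coordinates.csv` (spot IDs plus x/y/z columns), one `prop_slice<i>.csv` per slice with cell-type proportions, and `clustering_result.csv` with one spot/cluster pair per row. One way such files could be written from Python, assuming cell-type proportions are available as hypothetical per-slice DataFrames `prop_df_list[i]` and cluster labels in a hypothetical `adata_st.obs["cluster"]` column:

```python
import pandas as pd

# adata_st: output of STitch3D.utils.preprocess; obsm["3D_coor"] is set in utils.py above.
# The R functions rename the first CSV column to "spot", so the spot index is written out.
coor = pd.DataFrame(adata_st.obsm["3D_coor"], index=adata_st.obs.index, columns=["x", "y", "z"])
coor.to_csv("results/3D_coordinates.csv")

# Hypothetical per-slice proportion DataFrames (spots x cell types)
for i, prop_df in enumerate(prop_df_list):
    prop_df.to_csv("results/prop_slice%d.csv" % i)

# Hypothetical cluster labels stored in .obs
adata_st.obs[["cluster"]].to_csv("results/clustering_result.csv")
```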