├── LICENSE ├── README.md ├── README.rst ├── STAGATE ├── STAGATE.py ├── Train_STAGATE.py ├── __init__.py ├── __pycache__ │ ├── STAGATE.cpython-37.pyc │ ├── STGATE.cpython-37.pyc │ ├── Train_STAGATE.cpython-37.pyc │ ├── Train_STGATE.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── model.cpython-37.pyc │ └── utils.cpython-37.pyc ├── model.py └── utils.py ├── STAGATE_Overview.png ├── requirement.txt ├── requirements_STAGATE_tensorflow_version.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kangning Dong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STAGATE 2 | [![DOI](https://zenodo.org/badge/398185411.svg)](https://zenodo.org/badge/latestdoi/398185411) 3 | 4 | ![](https://github.com/QIFEIDKN/STAGATE/blob/main/STAGATE_Overview.png) 5 | 6 | ## News 7 | 2022.03.05 STAGATE based on pyG (PyTorch Geometric) framework is availble at [STAGATE_pyG](https://github.com/QIFEIDKN/STAGATE_pyG). 8 | 9 | Benefit from the optimization of the pyG package for training graph neural networks, it is more than 10x faster than STAGATE based on the tensorflow1 framework, and can use a batch training strategy to deal with large-scale data. 10 | 11 | ## Overview 12 | STAGATE is designed for spatial clustering and denoising expressions of spatial resolved transcriptomics (ST) data. 13 | 14 | STAGATE learns low-dimensional latent embeddings with both spatial information and gene expressions via a graph attention auto-encoder. The method adopts an attention mechanism in the middle layer of the encoder and decoder, which adaptively learns the edge weights of spatial neighbor networks, and further uses them to update the spot representation by collectively aggregating information from its neighbors. The latent embeddings and the reconstructed expression profiles can be used to downstream tasks such as spatial domain identification, visualization, spatial trajectory inference, data denoising and 3D expression domain extraction. 15 | 16 | ## Getting started 17 | See [Documentation and Tutorials](https://stagate.readthedocs.io/en/latest/index.html). 18 | 19 | ## Software dependencies 20 | scanpy 21 | 22 | tensorflow==1.15.0 23 | 24 | ## Installation 25 | cd STAGATE-main 26 | 27 | python setup.py build 28 | 29 | python setup.py install 30 | 31 | ## Citation 32 | Dong, Kangning, and Shihua Zhang. "Deciphering spatial domains from spatially resolved transcriptomics with an adaptive graph attention auto-encoder." Nature Communications 13.1 (2022): 1-12. 33 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /STAGATE/STAGATE.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import scipy.sparse as sp 3 | import numpy as np 4 | from .model import GATE 5 | from tqdm import tqdm 6 | 7 | class STAGATE(): 8 | 9 | def __init__(self, hidden_dims, alpha, n_epochs=500, lr=0.0001, 10 | gradient_clipping=5, nonlinear=True, weight_decay=0.0001, 11 | verbose=True, random_seed=2020): 12 | np.random.seed(random_seed) 13 | tf.set_random_seed(random_seed) 14 | self.loss_list = [] 15 | self.lr = lr 16 | self.n_epochs = n_epochs 17 | self.gradient_clipping = gradient_clipping 18 | self.build_placeholders() 19 | self.verbose = verbose 20 | self.alpha = alpha 21 | self.gate = GATE(hidden_dims, alpha, nonlinear, weight_decay) 22 | self.loss, self.H, self.C, self.ReX = self.gate(self.A, self.prune_A, self.X) 23 | self.optimize(self.loss) 24 | self.build_session() 25 | 26 | def build_placeholders(self): 27 | self.A = tf.sparse_placeholder(dtype=tf.float32) 28 | self.prune_A = tf.sparse_placeholder(dtype=tf.float32) 29 | self.X = tf.placeholder(dtype=tf.float32) 30 | 31 | def build_session(self, gpu= True): 32 | config = tf.ConfigProto() 33 | config.gpu_options.allow_growth = True 34 | if gpu == False: 35 | config.intra_op_parallelism_threads = 0 36 | config.inter_op_parallelism_threads = 0 37 | self.session = tf.Session(config=config) 38 | self.session.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 39 | 40 | def optimize(self, loss): 41 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 42 | gradients, variables = zip(*optimizer.compute_gradients(loss)) 43 | gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clipping) 44 | self.train_op = optimizer.apply_gradients(zip(gradients, variables)) 45 | 46 | def __call__(self, A, prune_A, X): 47 | for epoch in tqdm(range(self.n_epochs)): 48 | self.run_epoch(epoch, A, prune_A, X) 49 | 50 | def run_epoch(self, epoch, A, prune_A, X): 51 | 52 | loss, _ = self.session.run([self.loss, self.train_op], 53 | feed_dict={self.A: A, 54 | self.prune_A: prune_A, 55 | self.X: X}) 56 | self.loss_list.append(loss) 57 | #if self.verbose: 58 | # print("Epoch: %s, Loss: %.4f" % (epoch, loss)) 59 | return loss 60 | 61 | def infer(self, A, prune_A, X): 62 | H, C, ReX = self.session.run([self.H, self.C, self.ReX], 63 | feed_dict={self.A: A, 64 | self.prune_A: prune_A, 65 | self.X: X}) 66 | 67 | return H, self.Conbine_Atten_l(C), self.loss_list, ReX 68 | 69 | def Conbine_Atten_l(self, input): 70 | if self.alpha == 0: 71 | return [sp.coo_matrix((input[layer][1], (input[layer][0][:, 0], input[layer][0][:, 1])), shape=(input[layer][2][0], input[layer][2][1])) for layer in input] 72 | else: 73 | Att_C = [sp.coo_matrix((input['C'][layer][1], (input['C'][layer][0][:, 0], input['C'][layer][0][:, 1])), shape=(input['C'][layer][2][0], input['C'][layer][2][1])) for layer in input['C']] 74 | Att_pruneC = [sp.coo_matrix((input['prune_C'][layer][1], (input['prune_C'][layer][0][:, 0], input['prune_C'][layer][0][:, 1])), shape=(input['prune_C'][layer][2][0], input['prune_C'][layer][2][1])) for layer in input['prune_C']] 75 | return [self.alpha*Att_pruneC[layer] + (1-self.alpha)*Att_C[layer] for layer in input['C']] 76 | -------------------------------------------------------------------------------- /STAGATE/Train_STAGATE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from .STAGATE import STAGATE 4 | import tensorflow.compat.v1 as tf 5 | import pandas as pd 6 | import scanpy as sc 7 | 8 | def train_STAGATE(adata, hidden_dims=[512, 30], alpha=0, n_epochs=500, lr=0.0001, key_added='STAGATE', 9 | gradient_clipping=5, nonlinear=True, weight_decay=0.0001,verbose=True, 10 | random_seed=2020, pre_labels=None, pre_resolution=0.2, 11 | save_attention=False, save_loss=False, save_reconstrction=False): 12 | """\ 13 | Training graph attention auto-encoder. 14 | 15 | Parameters 16 | ---------- 17 | adata 18 | AnnData object of scanpy package. 19 | hidden_dims 20 | The dimension of the encoder. 21 | alpha 22 | The weight of cell type-aware spatial neighbor network. 23 | n_epochs 24 | Number of total epochs in training. 25 | lr 26 | Learning rate for AdamOptimizer. 27 | key_added 28 | The latent embeddings are saved in adata.obsm[key_added]. 29 | gradient_clipping 30 | Gradient Clipping. 31 | nonlinear 32 | If True, the nonlinear avtivation is performed. 33 | weight_decay 34 | Weight decay for AdamOptimizer. 35 | pre_labels 36 | The key in adata.obs for the manually designate the pre-clustering results. Only used when alpha>0. 37 | pre_resolution 38 | The resolution parameter of sc.tl.louvain for the pre-clustering. Only used when alpha>0 and per_labels==None. 39 | save_attention 40 | If True, the weights of the attention layers are saved in adata.uns['STAGATE_attention'] 41 | save_loss 42 | If True, the training loss is saved in adata.uns['STAGATE_loss']. 43 | save_reconstrction 44 | If True, the reconstructed expression profiles are saved in adata.layers['STAGATE_ReX']. 45 | 46 | Returns 47 | ------- 48 | AnnData 49 | """ 50 | 51 | tf.reset_default_graph() 52 | np.random.seed(random_seed) 53 | tf.set_random_seed(random_seed) 54 | if 'highly_variable' in adata.var.columns: 55 | adata_Vars = adata[:, adata.var['highly_variable']] 56 | else: 57 | adata_Vars = adata 58 | X = pd.DataFrame(adata_Vars.X.toarray()[:, ], index=adata_Vars.obs.index, columns=adata_Vars.var.index) 59 | if verbose: 60 | print('Size of Input: ', adata_Vars.shape) 61 | cells = np.array(X.index) 62 | cells_id_tran = dict(zip(cells, range(cells.shape[0]))) 63 | if 'Spatial_Net' not in adata.uns.keys(): 64 | raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!") 65 | Spatial_Net = adata.uns['Spatial_Net'] 66 | G_df = Spatial_Net.copy() 67 | G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) 68 | G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) 69 | G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs)) 70 | G_tf = prepare_graph_data(G) 71 | 72 | 73 | trainer = STAGATE(hidden_dims=[X.shape[1]] + hidden_dims, alpha=alpha, 74 | n_epochs=n_epochs, lr=lr, gradient_clipping=gradient_clipping, 75 | nonlinear=nonlinear,weight_decay=weight_decay, verbose=verbose, 76 | random_seed=random_seed) 77 | if alpha == 0: 78 | trainer(G_tf, G_tf, X) 79 | embeddings, attentions, loss, ReX= trainer.infer(G_tf, G_tf, X) 80 | else: 81 | G_df = Spatial_Net.copy() 82 | if pre_labels==None: 83 | if verbose: 84 | print('------Pre-clustering using louvain with resolution=%.2f' %pre_resolution) 85 | sc.tl.pca(adata, svd_solver='arpack') 86 | sc.pp.neighbors(adata) 87 | sc.tl.louvain(adata, resolution=pre_resolution, key_added='expression_louvain_label') 88 | pre_labels = 'expression_louvain_label' 89 | prune_G_df = prune_spatial_Net(G_df, adata.obs[pre_labels]) 90 | prune_G_df['Cell1'] = prune_G_df['Cell1'].map(cells_id_tran) 91 | prune_G_df['Cell2'] = prune_G_df['Cell2'].map(cells_id_tran) 92 | prune_G = sp.coo_matrix((np.ones(prune_G_df.shape[0]), (prune_G_df['Cell1'], prune_G_df['Cell2']))) 93 | prune_G_tf = prepare_graph_data(prune_G) 94 | prune_G_tf = (prune_G_tf[0], prune_G_tf[1], G_tf[2]) 95 | trainer(G_tf, prune_G_tf, X) 96 | embeddings, attentions, loss, ReX = trainer.infer(G_tf, prune_G_tf, X) 97 | cell_reps = pd.DataFrame(embeddings) 98 | cell_reps.index = cells 99 | 100 | adata.obsm[key_added] = cell_reps.loc[adata.obs_names, ].values 101 | if save_attention: 102 | adata.uns['STAGATE_attention'] = attentions 103 | if save_loss: 104 | adata.uns['STAGATE_loss'] = loss 105 | if save_reconstrction: 106 | ReX = pd.DataFrame(ReX, index=X.index, columns=X.columns) 107 | ReX[ReX<0] = 0 108 | adata.layers['STAGATE_ReX'] = ReX.values 109 | return adata 110 | 111 | 112 | def prune_spatial_Net(Graph_df, label): 113 | print('------Pruning the graph...') 114 | print('%d edges before pruning.' %Graph_df.shape[0]) 115 | pro_labels_dict = dict(zip(list(label.index), label)) 116 | Graph_df['Cell1_label'] = Graph_df['Cell1'].map(pro_labels_dict) 117 | Graph_df['Cell2_label'] = Graph_df['Cell2'].map(pro_labels_dict) 118 | Graph_df = Graph_df.loc[Graph_df['Cell1_label']==Graph_df['Cell2_label'],] 119 | print('%d edges after pruning.' %Graph_df.shape[0]) 120 | return Graph_df 121 | 122 | 123 | def prepare_graph_data(adj): 124 | # adapted from preprocess_adj_bias 125 | num_nodes = adj.shape[0] 126 | adj = adj + sp.eye(num_nodes)# self-loop 127 | #data = adj.tocoo().data 128 | #adj[adj > 0.0] = 1.0 129 | if not sp.isspmatrix_coo(adj): 130 | adj = adj.tocoo() 131 | adj = adj.astype(np.float32) 132 | indices = np.vstack((adj.col, adj.row)).transpose() 133 | return (indices, adj.data, adj.shape) 134 | 135 | def recovery_Imputed_Count(adata, size_factor): 136 | assert('ReX' in adata.uns) 137 | temp_df = adata.uns['ReX'].copy() 138 | sf = size_factor.loc[temp_df.index] 139 | temp_df = np.expm1(temp_df) 140 | temp_df = (temp_df.T * sf).T 141 | adata.uns['ReX_Count'] = temp_df 142 | return adata 143 | -------------------------------------------------------------------------------- /STAGATE/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | # Author: Kangning Dong 4 | # File Name: __init__.py 5 | # Description: 6 | """ 7 | 8 | __author__ = "Kangning Dong" 9 | __email__ = "dongkangning16@mails.ucas.ac.cn" 10 | 11 | from .Train_STAGATE import train_STAGATE 12 | from .utils import Cal_Spatial_Net, Stats_Spatial_Net, mclust_R, Cal_Spatial_Net_3D 13 | -------------------------------------------------------------------------------- /STAGATE/__pycache__/STAGATE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/STAGATE.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/STGATE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/STGATE.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/Train_STAGATE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/Train_STAGATE.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/Train_STGATE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/Train_STGATE.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /STAGATE/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | 3 | class GATE(): 4 | 5 | def __init__(self, hidden_dims, alpha=0.8, nonlinear=True, weight_decay=0.0001): 6 | self.n_layers = len(hidden_dims) - 1 7 | self.alpha = alpha 8 | self.W, self.v, self.prune_v = self.define_weights(hidden_dims) 9 | self.C = {} 10 | self.prune_C = {} 11 | self.nonlinear = nonlinear 12 | self.weight_decay = weight_decay 13 | 14 | def __call__(self, A, prune_A, X): 15 | # Encoder 16 | H = X 17 | for layer in range(self.n_layers): 18 | H = self.__encoder(A, prune_A, H, layer) 19 | if self.nonlinear: 20 | if layer != self.n_layers-1: 21 | H = tf.nn.elu(H) 22 | # Final node representations 23 | self.H = H 24 | 25 | # Decoder 26 | for layer in range(self.n_layers - 1, -1, -1): 27 | H = self.__decoder(H, layer) 28 | if self.nonlinear: 29 | if layer != 0: 30 | H = tf.nn.elu(H) 31 | X_ = H 32 | 33 | # The reconstruction loss of node features 34 | features_loss = tf.sqrt(tf.reduce_sum(tf.reduce_sum(tf.pow(X - X_, 2)))) 35 | 36 | for layer in range(self.n_layers): 37 | weight_decay_loss = 0 38 | weight_decay_loss += tf.multiply(tf.nn.l2_loss(self.W[layer]), self.weight_decay, name='weight_loss') 39 | 40 | # Total loss 41 | self.loss = features_loss + weight_decay_loss 42 | 43 | if self.alpha == 0: 44 | self.Att_l = self.C 45 | else: 46 | #self.Att_l = {x: (1-self.alpha)*self.C[x] + self.alpha*self.prune_C[x] for x in self.C.keys()} 47 | self.Att_l = {'C': self.C, 'prune_C': self.prune_C} 48 | return self.loss, self.H, self.Att_l, X_ 49 | 50 | 51 | def __encoder(self, A, prune_A, H, layer): 52 | H = tf.matmul(H, self.W[layer]) 53 | if layer == self.n_layers-1: 54 | return H 55 | self.C[layer] = self.graph_attention_layer(A, H, self.v[layer], layer) 56 | if self.alpha == 0: 57 | return tf.sparse_tensor_dense_matmul(self.C[layer], H) 58 | else: 59 | self.prune_C[layer] = self.graph_attention_layer(prune_A, H, self.prune_v[layer], layer) 60 | return (1-self.alpha)*tf.sparse_tensor_dense_matmul(self.C[layer], H) + self.alpha*tf.sparse_tensor_dense_matmul(self.prune_C[layer], H) 61 | 62 | 63 | def __decoder(self, H, layer): 64 | H = tf.matmul(H, self.W[layer], transpose_b=True) 65 | if layer == 0: 66 | return H 67 | if self.alpha == 0: 68 | return tf.sparse_tensor_dense_matmul(self.C[layer-1], H) 69 | else: 70 | return (1-self.alpha)*tf.sparse_tensor_dense_matmul(self.C[layer-1], H) + self.alpha*tf.sparse_tensor_dense_matmul(self.prune_C[layer-1], H) 71 | 72 | 73 | def define_weights(self, hidden_dims): 74 | W = {} 75 | for i in range(self.n_layers): 76 | W[i] = tf.get_variable("W%s" % i, shape=(hidden_dims[i], hidden_dims[i+1])) 77 | 78 | Ws_att = {} 79 | for i in range(self.n_layers-1): 80 | v = {} 81 | v[0] = tf.get_variable("v%s_0" % i, shape=(hidden_dims[i+1], 1)) 82 | v[1] = tf.get_variable("v%s_1" % i, shape=(hidden_dims[i+1], 1)) 83 | 84 | Ws_att[i] = v 85 | if self.alpha == 0: 86 | return W, Ws_att, None 87 | prune_Ws_att = {} 88 | for i in range(self.n_layers-1): 89 | prune_v = {} 90 | prune_v[0] = tf.get_variable("prune_v%s_0" % i, shape=(hidden_dims[i+1], 1)) 91 | prune_v[1] = tf.get_variable("prune_v%s_1" % i, shape=(hidden_dims[i+1], 1)) 92 | 93 | prune_Ws_att[i] = prune_v 94 | 95 | return W, Ws_att, prune_Ws_att 96 | 97 | def graph_attention_layer(self, A, M, v, layer): 98 | 99 | with tf.variable_scope("layer_%s"% layer): 100 | f1 = tf.matmul(M, v[0]) 101 | f1 = A * f1 102 | f2 = tf.matmul(M, v[1]) 103 | f2 = A * tf.transpose(f2, [1, 0]) 104 | logits = tf.sparse_add(f1, f2) 105 | 106 | unnormalized_attentions = tf.SparseTensor(indices=logits.indices, 107 | values=tf.nn.sigmoid(logits.values), 108 | dense_shape=logits.dense_shape) 109 | attentions = tf.sparse_softmax(unnormalized_attentions) 110 | 111 | attentions = tf.SparseTensor(indices=attentions.indices, 112 | values=attentions.values, 113 | dense_shape=attentions.dense_shape) 114 | 115 | return attentions -------------------------------------------------------------------------------- /STAGATE/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sklearn.neighbors 4 | 5 | 6 | def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True): 7 | """\ 8 | Construct the spatial neighbor networks. 9 | 10 | Parameters 11 | ---------- 12 | adata 13 | AnnData object of scanpy package. 14 | rad_cutoff 15 | radius cutoff when model='Radius' 16 | k_cutoff 17 | The number of nearest neighbors when model='KNN' 18 | model 19 | The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. 20 | 21 | Returns 22 | ------- 23 | The spatial networks are saved in adata.uns['Spatial_Net'] 24 | """ 25 | 26 | assert(model in ['Radius', 'KNN']) 27 | if verbose: 28 | print('------Calculating spatial graph...') 29 | coor = pd.DataFrame(adata.obsm['spatial']) 30 | coor.index = adata.obs.index 31 | coor.columns = ['imagerow', 'imagecol'] 32 | 33 | if model == 'Radius': 34 | nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor) 35 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 36 | KNN_list = [] 37 | for it in range(indices.shape[0]): 38 | KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it]))) 39 | 40 | if model == 'KNN': 41 | nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor) 42 | distances, indices = nbrs.kneighbors(coor) 43 | KNN_list = [] 44 | for it in range(indices.shape[0]): 45 | KNN_list.append(pd.DataFrame(zip([it]*indices.shape[1],indices[it,:], distances[it,:]))) 46 | 47 | KNN_df = pd.concat(KNN_list) 48 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 49 | 50 | Spatial_Net = KNN_df.copy() 51 | Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,] 52 | id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), )) 53 | Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans) 54 | Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans) 55 | if verbose: 56 | print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs)) 57 | print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs)) 58 | 59 | adata.uns['Spatial_Net'] = Spatial_Net 60 | 61 | 62 | def Cal_Spatial_Net_3D(adata, rad_cutoff_2D, rad_cutoff_Zaxis, 63 | key_section='Section_id', section_order=None, verbose=True): 64 | """\ 65 | Construct the spatial neighbor networks. 66 | 67 | Parameters 68 | ---------- 69 | adata 70 | AnnData object of scanpy package. 71 | rad_cutoff_2D 72 | radius cutoff for 2D SNN construction. 73 | rad_cutoff_Zaxis 74 | radius cutoff for 2D SNN construction for consturcting SNNs between adjacent sections. 75 | key_section 76 | The columns names of section_ID in adata.obs. 77 | section_order 78 | The order of sections. The SNNs between adjacent sections are constructed according to this order. 79 | 80 | Returns 81 | ------- 82 | The 3D spatial networks are saved in adata.uns['Spatial_Net']. 83 | """ 84 | adata.uns['Spatial_Net_2D'] = pd.DataFrame() 85 | adata.uns['Spatial_Net_Zaxis'] = pd.DataFrame() 86 | num_section = np.unique(adata.obs[key_section]).shape[0] 87 | if verbose: 88 | print('Radius used for 2D SNN:', rad_cutoff_2D) 89 | print('Radius used for SNN between sections:', rad_cutoff_Zaxis) 90 | for temp_section in np.unique(adata.obs[key_section]): 91 | if verbose: 92 | print('------Calculating 2D SNN of section ', temp_section) 93 | temp_adata = adata[adata.obs[key_section] == temp_section, ] 94 | Cal_Spatial_Net( 95 | temp_adata, rad_cutoff=rad_cutoff_2D, verbose=False) 96 | temp_adata.uns['Spatial_Net']['SNN'] = temp_section 97 | if verbose: 98 | print('This graph contains %d edges, %d cells.' % 99 | (temp_adata.uns['Spatial_Net'].shape[0], temp_adata.n_obs)) 100 | print('%.4f neighbors per cell on average.' % 101 | (temp_adata.uns['Spatial_Net'].shape[0]/temp_adata.n_obs)) 102 | adata.uns['Spatial_Net_2D'] = pd.concat( 103 | [adata.uns['Spatial_Net_2D'], temp_adata.uns['Spatial_Net']]) 104 | for it in range(num_section-1): 105 | section_1 = section_order[it] 106 | section_2 = section_order[it+1] 107 | if verbose: 108 | print('------Calculating SNN between adjacent section %s and %s.' % 109 | (section_1, section_2)) 110 | Z_Net_ID = section_1+'-'+section_2 111 | temp_adata = adata[adata.obs[key_section].isin( 112 | [section_1, section_2]), ] 113 | Cal_Spatial_Net( 114 | temp_adata, rad_cutoff=rad_cutoff_Zaxis, verbose=False) 115 | spot_section_trans = dict( 116 | zip(temp_adata.obs.index, temp_adata.obs[key_section])) 117 | temp_adata.uns['Spatial_Net']['Section_id_1'] = temp_adata.uns['Spatial_Net']['Cell1'].map( 118 | spot_section_trans) 119 | temp_adata.uns['Spatial_Net']['Section_id_2'] = temp_adata.uns['Spatial_Net']['Cell2'].map( 120 | spot_section_trans) 121 | used_edge = temp_adata.uns['Spatial_Net'].apply( 122 | lambda x: x['Section_id_1'] != x['Section_id_2'], axis=1) 123 | temp_adata.uns['Spatial_Net'] = temp_adata.uns['Spatial_Net'].loc[used_edge, ] 124 | temp_adata.uns['Spatial_Net'] = temp_adata.uns['Spatial_Net'].loc[:, [ 125 | 'Cell1', 'Cell2', 'Distance']] 126 | temp_adata.uns['Spatial_Net']['SNN'] = Z_Net_ID 127 | if verbose: 128 | print('This graph contains %d edges, %d cells.' % 129 | (temp_adata.uns['Spatial_Net'].shape[0], temp_adata.n_obs)) 130 | print('%.4f neighbors per cell on average.' % 131 | (temp_adata.uns['Spatial_Net'].shape[0]/temp_adata.n_obs)) 132 | adata.uns['Spatial_Net_Zaxis'] = pd.concat( 133 | [adata.uns['Spatial_Net_Zaxis'], temp_adata.uns['Spatial_Net']]) 134 | adata.uns['Spatial_Net'] = pd.concat( 135 | [adata.uns['Spatial_Net_2D'], adata.uns['Spatial_Net_Zaxis']]) 136 | if verbose: 137 | print('3D SNN contains %d edges, %d cells.' % 138 | (adata.uns['Spatial_Net'].shape[0], adata.n_obs)) 139 | print('%.4f neighbors per cell on average.' % 140 | (adata.uns['Spatial_Net'].shape[0]/adata.n_obs)) 141 | 142 | def Stats_Spatial_Net(adata): 143 | import matplotlib.pyplot as plt 144 | Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0] 145 | Mean_edge = Num_edge/adata.shape[0] 146 | plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1'])) 147 | plot_df = plot_df/adata.shape[0] 148 | fig, ax = plt.subplots(figsize=[3,2]) 149 | plt.ylabel('Percentage') 150 | plt.xlabel('') 151 | plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge) 152 | ax.bar(plot_df.index, plot_df) 153 | 154 | def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020): 155 | """\ 156 | Clustering using the mclust algorithm. 157 | The parameters are the same as those in the R package mclust. 158 | """ 159 | 160 | np.random.seed(random_seed) 161 | import rpy2.robjects as robjects 162 | robjects.r.library("mclust") 163 | 164 | import rpy2.robjects.numpy2ri 165 | rpy2.robjects.numpy2ri.activate() 166 | r_random_seed = robjects.r['set.seed'] 167 | r_random_seed(random_seed) 168 | rmclust = robjects.r['Mclust'] 169 | 170 | res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames) 171 | mclust_res = np.array(res[-2]) 172 | 173 | adata.obs['mclust'] = mclust_res 174 | adata.obs['mclust'] = adata.obs['mclust'].astype('int') 175 | adata.obs['mclust'] = adata.obs['mclust'].astype('category') 176 | return adata 177 | -------------------------------------------------------------------------------- /STAGATE_Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QIFEIDKN/STAGATE/48ce7f874c83a9f1f68187be00370181261ab7c5/STAGATE_Overview.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.2 2 | pandas>=1.1.4 3 | scipy>=1.4.1 4 | scikit-learn>=0.23.2 5 | tqdm>=4.61.0 6 | tensorflow==1.15.0 7 | matplotlib>=3.3.3 8 | scanpy>=1.6.1 -------------------------------------------------------------------------------- /requirements_STAGATE_tensorflow_version.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.1 2 | anndata==0.7.6 3 | annoy==1.17.0 4 | anyio==3.5.0 5 | argcomplete==1.12.3 6 | argon2-cffi==21.3.0 7 | argon2-cffi-bindings==21.2.0 8 | asciitree==0.3.3 9 | astor==0.8.1 10 | attrs==21.2.0 11 | Babel==2.10.1 12 | backcall==0.2.0 13 | backports.zoneinfo==0.2.1 14 | bbknn==1.5.1 15 | beautifulsoup4==4.11.1 16 | bleach==5.0.0 17 | cached-property==1.5.2 18 | certifi==2021.10.8 19 | cffi==1.15.0 20 | charset-normalizer==2.0.12 21 | cloudpickle==2.0.0 22 | cycler==0.10.0 23 | Cython==0.29.25 24 | dask==2022.1.0 25 | dask-image==2021.12.0 26 | debugpy==1.5.1 27 | decorator==5.1.0 28 | defusedxml==0.7.1 29 | docrep==0.3.2 30 | entrypoints==0.3 31 | fasteners==0.17.2 32 | fastjsonschema==2.15.3 33 | fsspec==2022.1.0 34 | gast==0.2.2 35 | google-pasta==0.2.0 36 | grpcio==1.41.0 37 | h5py==3.4.0 38 | harmonypy==0.0.5 39 | hnswlib==0.5.2 40 | idna==3.3 41 | igraph==0.9.8 42 | imageio==2.13.5 43 | importlib-metadata==4.8.1 44 | importlib-resources==5.7.1 45 | inflect==5.3.0 46 | iniconfig==1.1.1 47 | intervaltree==3.1.0 48 | ipykernel==6.6.0 49 | ipython==7.30.1 50 | ipython-genutils==0.2.0 51 | ivis==1.7.2 52 | jedi==0.18.1 53 | Jinja2==3.0.3 54 | joblib==1.1.0 55 | json5==0.9.6 56 | jsonschema==4.4.0 57 | jupyter-client==7.1.0 58 | jupyter-core==4.9.1 59 | jupyter-server==1.17.0 60 | jupyterlab==3.3.4 61 | jupyterlab-pygments==0.2.2 62 | jupyterlab-server==2.13.0 63 | Keras-Applications==1.0.8 64 | Keras-Preprocessing==1.1.2 65 | kiwisolver==1.3.2 66 | kneed==0.7.0 67 | leidenalg==0.8.8 68 | llvmlite==0.37.0 69 | louvain==0.7.0 70 | Markdown==3.3.4 71 | MarkupSafe==2.0.1 72 | matplotlib==3.4.3 73 | matplotlib-inline==0.1.3 74 | mistune==0.8.4 75 | morphops==0.1.13 76 | natsort==7.1.1 77 | nbclassic==0.3.7 78 | nbclient==0.6.0 79 | nbconvert==6.5.0 80 | nbformat==5.3.0 81 | nest-asyncio==1.5.4 82 | networkx==2.6.3 83 | notebook==6.4.11 84 | notebook-shim==0.1.0 85 | numba==0.54.1 86 | numcodecs==0.9.1 87 | numexpr==2.7.3 88 | numpy==1.20.0 89 | omnipath==1.0.5 90 | opt-einsum==3.3.0 91 | packaging==21.0 92 | pandas==1.3.3 93 | pandocfilters==1.5.0 94 | parso==0.8.3 95 | partd==1.2.0 96 | patsy==0.5.2 97 | pexpect==4.8.0 98 | pickleshare==0.7.5 99 | Pillow==8.3.2 100 | PIMS==0.5 101 | pluggy==1.0.0 102 | prometheus-client==0.14.1 103 | prompt-toolkit==3.0.24 104 | protobuf==3.18.1 105 | ptyprocess==0.7.0 106 | py==1.11.0 107 | pycparser==2.21 108 | Pygments==2.10.0 109 | pynndescent==0.5.4 110 | pyparsing==2.4.7 111 | pyrsistent==0.18.1 112 | pytest==6.2.5 113 | python-dateutil==2.8.2 114 | python-igraph==0.9.8 115 | pytz==2021.3 116 | pytz-deprecation-shim==0.1.0.post0 117 | PyWavelets==1.2.0 118 | pyzmq==22.3.0 119 | requests==2.27.1 120 | rpy2==3.1.0 121 | scanpy==1.8.1 122 | scikit-image==0.19.1 123 | scikit-learn==1.0 124 | scikit-misc==0.1.4 125 | scipy==1.7.1 126 | seaborn==0.11.2 127 | Send2Trash==1.8.0 128 | simplegeneric==0.8.1 129 | sinfo==0.3.4 130 | six==1.16.0 131 | sklearn==0.0 132 | sniffio==1.2.0 133 | sortedcontainers==2.4.0 134 | soupsieve==2.3.2.post1 135 | spatial-eggplant==0.1 136 | squidpy==1.1.2 137 | STAGATE==1.0.1 138 | statsmodels==0.13.0 139 | stdlib-list==0.8.0 140 | tables==3.6.1 141 | tensorboard==1.15.0 142 | tensorflow-estimator==1.15.1 143 | tensorflow-gpu==1.15.2 144 | termcolor==1.1.0 145 | terminado==0.13.3 146 | texttable==1.6.4 147 | threadpoolctl==3.0.0 148 | tifffile==2021.11.2 149 | tinycss2==1.1.1 150 | toml==0.10.2 151 | toolz==0.11.2 152 | tornado==6.1 153 | tqdm==4.62.3 154 | traitlets==5.1.1 155 | typing-extensions==3.10.0.2 156 | tzdata==2021.5 157 | tzlocal==4.1 158 | umap-learn==0.5.1 159 | urllib3==1.26.9 160 | wcwidth==0.2.5 161 | webencodings==0.5.1 162 | websocket-client==1.3.2 163 | Werkzeug==2.0.2 164 | wrapt==1.13.1 165 | xarray==0.20.2 166 | xlrd==1.2.0 167 | zarr==2.11.0a2 168 | zipp==3.6.0 169 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import Command, find_packages, setup 2 | 3 | __lib_name__ = "STAGATE" 4 | __lib_version__ = "1.0.1" 5 | __description__ = "Deciphering spatial domains from spatially resolved transcriptomics with adaptive graph attention auto-encoder" 6 | __url__ = "https://github.com/QIFEIDKN/STAGATE" 7 | __author__ = "Kangning Dong" 8 | __author_email__ = "dongkangning16@mails.ucas.ac.cn" 9 | __license__ = "MIT" 10 | __keywords__ = ["spatial transcriptomics", "Deep learning", "Graph attention auto-encoder"] 11 | __requires__ = ["requests",] 12 | 13 | with open("README.rst", "r", encoding="utf-8") as f: 14 | __long_description__ = f.read() 15 | 16 | setup( 17 | name = __lib_name__, 18 | version = __lib_version__, 19 | description = __description__, 20 | url = __url__, 21 | author = __author__, 22 | author_email = __author_email__, 23 | license = __license__, 24 | packages = ['STAGATE'], 25 | install_requires = __requires__, 26 | zip_safe = False, 27 | include_package_data = True, 28 | long_description = __long_description__ 29 | ) 30 | --------------------------------------------------------------------------------