├── LICENSE
├── README.md
├── STAligner
│   ├── STALIGNER.py
│   ├── ST_utils.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── STALIGNER.cpython-38.pyc
│   │   ├── ST_utils.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── gat_conv.cpython-38.pyc
│   │   ├── mnn_utils.cpython-38.pyc
│   │   └── train_STAligner.cpython-38.pyc
│   ├── gat_conv.py
│   ├── mnn_utils.py
│   └── train_STAligner.py
├── Tutorials
│   ├── Tutorial_3D_alignment.ipynb
│   ├── Tutorial_Cross_Platforms.ipynb
│   ├── Tutorial_DLPFC.ipynb
│   ├── Tutorial_DLPFC_12S.ipynb
│   └── Tutorial_embryo.ipynb
├── __init__.py
├── requirement.txt
├── requirement_for_macOS.txt
└── setup.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Xiang Zhou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# STAligner

![STAligner_Overview](https://github.com/zhoux85/STAligner/assets/31464727/1358f6b0-75ed-4bdd-9d0b-257788dff73a)


## Overview

STAligner is designed for alignment and integration of spatially resolved transcriptomics data.

**a**. STAligner first normalizes the expression profiles of all spots and constructs a spatial neighbor network from the spatial coordinates. It then uses a graph attention auto-encoder to extract spatially aware embeddings, and constructs spot triplets from the current embeddings to guide alignment by attracting similar spots and separating dissimilar spots across slices. A triplet loss updates the spot embeddings to reduce the distance from each anchor to its positive spot and increase the distance from the anchor to its negative spot. Triplet construction and auto-encoder training are alternated iteratively until batch-corrected embeddings are obtained. **b**. STAligner can be applied to integrate ST datasets, achieving alignment and simultaneous identification of spatial domains, across different biological samples in (**a**), technological platforms (I), developmental (embryonic) stages (II), disease conditions (III), and consecutive slices of a tissue for 3D slice alignment (IV).



## Installation
The STAligner package is built on the Python libraries [Scanpy](https://scanpy.readthedocs.io/en/stable/), [PyTorch](https://pytorch.org/) and the [PyG](https://github.com/pyg-team/pytorch_geometric) (*PyTorch Geometric*) framework, and can run on a GPU (recommended) or CPU.



First clone the repository.

```
git clone https://github.com/zhoux85/STAligner.git
cd STAligner
```

It's recommended to create a separate conda environment for running STAligner:

```
#create an environment called env_STAligner
conda create -n env_STAligner python=3.8

#activate your environment
conda activate env_STAligner
```

Install all the required packages.

For Linux
```
pip install -r requirement.txt
```
For macOS
```
pip install -r requirement_for_macOS.txt
```

The mclust algorithm requires the rpy2 package (Python) and the mclust package (R). See https://pypi.org/project/rpy2/ and https://cran.r-project.org/web/packages/mclust/index.html for details.

The torch-geometric library is also required; see the installation steps at https://github.com/pyg-team/pytorch_geometric#installation

Install STAligner.

```
python setup.py build
python setup.py install
```



## Tutorials

Five step-by-step tutorials showing how to use STAligner are included in the `Tutorials` folder and at https://staligner.readthedocs.io/en/latest/.

- Tutorial 1: Integrating 4 adjacent DLPFC slices (10x Visium)
- Tutorial 2: Integrating all 12 DLPFC slices from 3 adult samples (10x Visium)
- Tutorial 3: Integrating slices across sequencing platforms (Slide-seqV2 and Stereo-seq)
- Tutorial 4: Integrating 4 mouse embryo slices sampled at the time stages of E9.5, E10.5, E11.5, and E12.5 (Stereo-seq)
- Tutorial 5: Spatial domain guided 3D slice alignment (Slide-seq)



## Support

If you have any questions, please feel free to contact us at [xzhou@amss.ac.cn](mailto:xzhou@amss.ac.cn).



## Citation
Zhou, X., Dong, K. & Zhang, S. Integrating spatial transcriptomics data across different conditions, technologies and developmental stages. Nat Comput Sci 3, 894–906 (2023). https://doi.org/10.1038/s43588-023-00528-w


--------------------------------------------------------------------------------

/STAligner/STALIGNER.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
cudnn.deterministic = True
cudnn.benchmark = True
import torch.nn.functional as F
from .gat_conv import GATConv

class STAligner(torch.nn.Module):
    """Graph attention auto-encoder.

    conv1 -> conv2 encode spot features into the embedding h2; conv3 -> conv4
    decode h2 back to feature space. The decoder reuses the transposed encoder
    weights and conv1's attention coefficients (tied weights).
    """
    def __init__(self, hidden_dims):
        super(STAligner, self).__init__()

        [in_dim, num_hidden, out_dim] = hidden_dims
        self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)

    def forward(self, features, edge_index):

        # Encoder
        h1 = F.elu(self.conv1(features, edge_index))
        h2 = self.conv2(h1, edge_index, attention=False)
        # Tie decoder weights to the transposed encoder weights
        self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
        self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
        self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
        self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
        # Decoder (conv3 reuses the attention computed by conv1)
        h3 = F.elu(self.conv3(h2, edge_index, attention=True,
                              tied_attention=self.conv1.attentions))
        h4 = self.conv4(h3, edge_index, attention=False)

        return h2, h4  # embedding, reconstructed features
--------------------------------------------------------------------------------
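The triplet objective described in the README overview can be illustrated with a few lines of NumPy. This is a minimal sketch under stated assumptions: it uses the standard margin-based hinge form max(0, ‖a − p‖ − ‖a − n‖ + margin) with Euclidean distances, and the helper name `triplet_loss`, the `margin=1.0` default and the toy arrays are hypothetical choices for illustration, not code taken from `train_STAligner.py` (which is not shown here).

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean margin-based triplet loss over rows of (n, d) embedding arrays.

    Hypothetical helper for illustration; the margin value is an assumption.
    """
    # Euclidean distance from each anchor to its positive and negative spot
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    # Hinge: the term is zero once the negative is at least `margin`
    # farther from the anchor than the positive is
    return float(np.mean(np.maximum(0.0, d_pos - d_neg + margin)))


# Toy embeddings: two anchors with one positive and one negative each
anchor = np.array([[0.0, 0.0], [1.0, 1.0]])
positive = np.array([[0.0, 1.0], [1.0, 1.0]])
negative = np.array([[3.0, 0.0], [1.0, 1.5]])
print(triplet_loss(anchor, positive, negative))  # prints 0.25
```

The first triplet already satisfies the margin and contributes zero; the second contributes 0.5 because its negative is only 0.5 away, so minimizing the loss would push that negative farther from the anchor. PyTorch provides `torch.nn.functional.triplet_margin_loss`, which computes essentially the same hinge in differentiable form.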

/STAligner/ST_utils.py:
--------------------------------------------------------------------------------

import pandas as pd
import numpy as np
import sklearn.neighbors
import networkx as nx
from .mnn_utils import create_dictionary_mnn

def match_cluster_labels(true_labels,est_labels):
    """Relabel est_labels to best agree with true_labels.

    Uses a maximum-overlap bipartite matching between the two label sets.
    Returns integer indices into the sorted categories of true_labels;
    unmatched estimated labels get new indices starting at len(org_cat).
    """
    true_labels_arr = np.array(list(true_labels))
    est_labels_arr = np.array(list(est_labels))
    org_cat = list(np.sort(list(pd.unique(true_labels))))
    est_cat = list(np.sort(list(pd.unique(est_labels))))
    B = nx.Graph()
    B.add_nodes_from([i+1 for i in range(len(org_cat))], bipartite=0)
    B.add_nodes_from([-j-1 for j in range(len(est_cat))], bipartite=1)
    for i in range(len(org_cat)):
        for j in range(len(est_cat)):
            # overlap count; negated so a minimum-weight matching maximizes overlap
            weight = np.sum((true_labels_arr==org_cat[i])* (est_labels_arr==est_cat[j]))
            B.add_edge(i+1,-j-1, weight=-weight)
    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)
    # match = minimum_weight_full_matching(B)
    if len(org_cat)>=len(est_cat):
        return np.array([match[-est_cat.index(c)-1]-1 for c in est_labels_arr])
    else:
        unmatched = [c for c in est_cat if not (-est_cat.index(c)-1) in match.keys()]
        l = []
        for c in est_labels_arr:
            if (-est_cat.index(c)-1) in match:
                l.append(match[-est_cat.index(c)-1]-1)
            else:
                l.append(len(org_cat)+unmatched.index(c))
        return np.array(l)


def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None,
                    max_neigh=50, model='Radius', verbose=True):
    """\
    Construct the spatial neighbor networks.

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    rad_cutoff
        Radius cutoff when model='Radius'.
    k_cutoff
        The number of nearest neighbors when model='KNN'.
    max_neigh
        Number of nearest neighbors queried per spot before filtering. In
        'Radius' mode, spots beyond the max_neigh nearest are never connected,
        even if they lie within rad_cutoff.
    model
        The network construction model. When model=='Radius', the spot is
        connected to spots whose distance is less than rad_cutoff. When
        model=='KNN', the spot is connected to its first k_cutoff nearest
        neighbors.
    verbose
        Whether to print graph statistics.

    Returns
    -------
    The spatial network (edge list) is saved in adata.uns['Spatial_Net'] and
    the sparse adjacency matrix with self-loops in adata.uns['adj'].
    """

    assert (model in ['Radius', 'KNN'])
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    coor.columns = ['imagerow', 'imagecol']

    nbrs = sklearn.neighbors.NearestNeighbors(
        n_neighbors=max_neigh + 1, algorithm='ball_tree').fit(coor)
    distances, indices = nbrs.kneighbors(coor)
    # column 0 is each spot itself, so it is dropped
    if model == 'KNN':
        indices = indices[:, 1:k_cutoff + 1]
        distances = distances[:, 1:k_cutoff + 1]
    if model == 'Radius':
        indices = indices[:, 1:]
        distances = distances[:, 1:]

    KNN_list = []
    for it in range(indices.shape[0]):
        KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :])))
    KNN_df = pd.concat(KNN_list)
    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']

    Spatial_Net = KNN_df.copy()
    if model == 'Radius':
        # .copy() avoids pandas' SettingWithCopyWarning on the assignments below
        Spatial_Net = KNN_df.loc[KNN_df['Distance'] < rad_cutoff,].copy()
    id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    # self_loops = pd.DataFrame(zip(Spatial_Net['Cell1'].unique(), Spatial_Net['Cell1'].unique(),
    #                  [0] * len((Spatial_Net['Cell1'].unique())))) ###add self loops
    # self_loops.columns = ['Cell1', 'Cell2', 'Distance']
    # Spatial_Net = pd.concat([Spatial_Net, self_loops], axis=0)

    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))
    adata.uns['Spatial_Net'] = Spatial_Net

    # Build a sparse adjacency matrix (with self-loops) from the edge list.
    # Cell indices come from adata.obs.index, so adata.X may be dense or sparse.
    cells = np.array(adata.obs.index)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))

    G_df = Spatial_Net.copy()
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
    G = G + sp.eye(G.shape[0])  # self-loop
    adata.uns['adj'] = G


def Stats_Spatial_Net(adata):
    import matplotlib.pyplot as plt
    Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]
    Mean_edge = Num_edge / adata.shape[0]
    # Series.value_counts() replaces the deprecated top-level pd.value_counts()
    plot_df = adata.uns['Spatial_Net']['Cell1'].value_counts().value_counts()
    plot_df = plot_df / adata.shape[0]
    fig, ax = plt.subplots(figsize=[3, 2])
    plt.ylabel('Percentage')
    plt.xlabel('')
    plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge)
    ax.bar(plot_df.index, plot_df)
    plt.show()


def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=666):
    """\
    Clustering using the mclust algorithm.
    The parameters are the same as those in the R package mclust.
    The cluster labels are stored in adata.obs['mclust'] as a categorical column.
    """

    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri
    rpy2.robjects.numpy2ri.activate()
    r_random_seed = robjects.r['set.seed']
    r_random_seed(random_seed)
    rmclust = robjects.r['Mclust']

    res = rmclust(adata.obsm[used_obsm], num_cluster, modelNames)
    mclust_res = np.array(res[-2])

    adata.obs['mclust'] = mclust_res
    adata.obs['mclust'] = adata.obs['mclust'].astype('int')
    adata.obs['mclust'] = adata.obs['mclust'].astype('category')
    return adata

import scipy.sparse as sp
def prepare_graph_data(adj):
    # adapted from preprocess_adj_bias
    num_nodes = adj.shape[0]
    # adj = adj + sp.eye(num_nodes)# self-loop ##new !!
    #data = adj.tocoo().data
    #adj[adj > 0.0] = 1.0
    if not sp.isspmatrix_coo(adj):
        adj = adj.tocoo()
    adj = adj.astype(np.float32)
    indices = np.vstack((adj.col, adj.row)).transpose()

    # adj = normalize(adj, norm="l1")

    return (adj, indices, adj.data, adj.shape)


def prune_spatial_Net(Graph_df, label):
    print('------Pruning the graph...')
    print('%d edges before pruning.' %Graph_df.shape[0])
    pro_labels_dict = dict(zip(list(label.index), label))
    Graph_df['Cell1_label'] = Graph_df['Cell1'].map(pro_labels_dict)
    Graph_df['Cell2_label'] = Graph_df['Cell2'].map(pro_labels_dict)
    # keep only edges whose two endpoints share the same label
    Graph_df = Graph_df.loc[Graph_df['Cell1_label']==Graph_df['Cell2_label'],]
    print('%d edges after pruning.' %Graph_df.shape[0])
    return Graph_df


# https://github.com/ClayFlannigan/icp
def best_fit_transform(A, B):
    '''
    Calculates the least-squares best-fit transform that maps corresponding points A to B in m spatial dimensions
    Input:
      A: Nxm numpy array of corresponding points
      B: Nxm numpy array of corresponding points
    Returns:
      T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B
      R: mxm rotation matrix
      t: mx1 translation vector
    '''

    # assert A.shape == B.shape

    # get number of dimensions
    m = A.shape[1]

    # translate points to their centroids
    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    AA = A - centroid_A
    BB = B - centroid_B

    # rotation matrix
    H = np.dot(AA.T, BB)
    U, S, Vt = np.linalg.svd(H)
    R = np.dot(Vt.T, U.T)

    # special reflection case
    if np.linalg.det(R) < 0:
        Vt[m-1,:] *= -1
        R = np.dot(Vt.T, U.T)

    # translation
    t = centroid_B.T - np.dot(R,centroid_A.T)

    # homogeneous transformation
    T = np.identity(m+1)
    T[:m, :m] = R
    T[:m, m] = t

    return T, R, t

def ICP_align(adata_concat, adata_target, adata_ref, slice_target, slice_ref, landmark_domain, plot_align=False):
    ### find MNN pairs in the landmark domain with knn=1
    adata_slice1 = adata_target[adata_target.obs['louvain'].isin(landmark_domain)]
    adata_slice2 = adata_ref[adata_ref.obs['louvain'].isin(landmark_domain)]


    batch_pair = adata_concat[adata_concat.obs['batch_name'].isin([slice_target, slice_ref]) & adata_concat.obs['louvain'].isin(landmark_domain)]
    mnn_dict = create_dictionary_mnn(batch_pair, use_rep='STAligner', batch_name='batch_name', k=1, iter_comb=None, verbose=0)
    adata_1 = batch_pair[batch_pair.obs['batch_name']==slice_target]
    adata_2 = batch_pair[batch_pair.obs['batch_name']==slice_ref]

    anchor_list = []
    positive_list = []
    for batch_pair_name in mnn_dict.keys():
        for anchor in mnn_dict[batch_pair_name].keys():
            positive_spot = mnn_dict[batch_pair_name][anchor][0]
            ### the anchor must be in the target slice and the positive in the reference slice
            if anchor in adata_1.obs_names and positive_spot in adata_2.obs_names:
                anchor_list.append(anchor)
                positive_list.append(positive_spot)

    batch_as_dict = dict(zip(list(adata_concat.obs_names), range(0, adata_concat.shape[0])))
    anchor_ind = list(map(lambda _: batch_as_dict[_], anchor_list))
    positive_ind = list(map(lambda _: batch_as_dict[_], positive_list))
    anchor_arr = adata_concat.obsm['STAligner'][anchor_ind, ]
    positive_arr = adata_concat.obsm['STAligner'][positive_ind, ]
    # embedding-space Euclidean distance of each anchor/positive pair
    dist_list = np.linalg.norm(anchor_arr - positive_arr, axis=1)


    key_points_src = np.array(anchor_list)[dist_list < np.percentile(dist_list, 50)] ## remove remote outliers
    key_points_dst = np.array(positive_list)[dist_list < np.percentile(dist_list, 50)]
    #print(len(anchor_list), len(key_points_src))

    coor_src = adata_slice1.obsm["spatial"] ## to_be_aligned
    coor_dst = adata_slice2.obsm["spatial"] ## reference_points

    ## index number
    MNN_ind_src = [list(adata_1.obs_names).index(key_points_src[ii]) for ii in range(len(key_points_src))]
    MNN_ind_dst = [list(adata_2.obs_names).index(key_points_dst[ii]) for ii in range(len(key_points_dst))]


    ####### ICP alignment
    init_pose = None
    max_iterations = 100
    tolerance = 0.001

    coor_used = coor_src ## Batch_list[1][Batch_list[1].obs['annotation']==2].obsm["spatial"]
    coor_all = adata_target.obsm["spatial"].copy()
    coor_used = np.concatenate([coor_used, np.expand_dims(np.ones(coor_used.shape[0]), axis=1)], axis=1).T
    coor_all = np.concatenate([coor_all, np.expand_dims(np.ones(coor_all.shape[0]), axis=1)], axis=1).T
    A = coor_src ## to_be_aligned
    B = coor_dst ## reference_points

    m = A.shape[1] # get number of dimensions

    # make points homogeneous, copy them to maintain the originals
    src = np.ones((m + 1, A.shape[0]))
    dst = np.ones((m + 1, B.shape[0]))
    src[:m, :] = np.copy(A.T)
    dst[:m, :] = np.copy(B.T)

    # apply the initial pose estimation
    if init_pose is not None:
        src = np.dot(init_pose, src)
    prev_error = 0

    for ii in range(max_iterations + 1):
        # MNN pairs under the current estimate of the target coordinates
        p1 = src[:m, MNN_ind_src].T
        p2 = dst[:m, MNN_ind_dst].T
        T, _, _ = best_fit_transform(p1, p2)  ## compute the transformation matrix based on MNNs
        # mean Euclidean distance between paired MNN spots
        distances = np.mean(np.linalg.norm(p1 - p2, axis=1))

        # update the current source
        src = np.dot(T, src)
        coor_used = np.dot(T, coor_used)
        coor_all = np.dot(T, coor_all)

        # check error
        mean_error = np.mean(distances)
        # print(mean_error)
        if np.abs(prev_error - mean_error) < tolerance:
            break
        prev_error = mean_error

    aligned_points = coor_used.T # target spots in the landmark_domain
    aligned_points_all = coor_all.T # all points in the slice

    if plot_align:
        import matplotlib.pyplot as plt
        plt.rcParams["figure.figsize"] = (3, 3)
        fig, ax = plt.subplots(1, 2, figsize=(8, 3), gridspec_kw={'wspace': 0.5, 'hspace': 0.1})
        ax[0].scatter(adata_slice2.obsm["spatial"][:, 0], adata_slice2.obsm["spatial"][:, 1],
                      c="blue", s=1)
        ax[0].set_title('Reference '+slice_ref, size=14)
        ax[1].scatter(aligned_points[:, 0], aligned_points[:, 1],
                      c="blue", s=1)
        ax[1].set_title('Target '+slice_target, size=14)

        plt.axis("equal")
        # plt.axis("off")
        plt.show()

    #adata_target.obsm["spatial"] = aligned_points_all[:,:2]
    return aligned_points_all[:,:2]



# https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors
def nearest_neighbor(src, dst):
    '''
    Find the nearest (Euclidean) neighbor in dst for each point in src
    Input:
        src: Nxm array of points
        dst: Nxm array of points
    Output:
        distances: Euclidean distances of the nearest neighbor
        indices: dst indices of the nearest neighbor
    '''

    # assert src.shape == dst.shape
    neigh = sklearn.neighbors.NearestNeighbors(n_neighbors=1)
    neigh.fit(dst)
    distances, indices = neigh.kneighbors(src, return_distance=True)
    return distances.ravel(), indices.ravel()


--------------------------------------------------------------------------------

/STAligner/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
# Author: Xiang Zhou
# File Name: __init__.py
# Description:
"""

__author__ = "Xiang Zhou"
__email__ = "xzhou@amss.ac.cn"

from .ST_utils import match_cluster_labels, Cal_Spatial_Net, Stats_Spatial_Net, mclust_R, ICP_align
from .mnn_utils import create_dictionary_mnn
from .train_STAligner import train_STAligner, train_STAligner_subgraph
--------------------------------------------------------------------------------

/STAligner/__pycache__/STALIGNER.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/STALIGNER.cpython-38.pyc
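`best_fit_transform`, which `ICP_align` calls on the MNN pairs of a landmark domain, solves the least-squares rigid alignment (Kabsch / orthogonal Procrustes) problem: center both point sets, take the SVD of the cross-covariance H = AᵀB, set R = V Uᵀ (flipping the last singular direction if det(R) < 0 to avoid a reflection), then t = centroid_B − R·centroid_A. The standalone sketch below re-implements that step as a hypothetical `rigid_fit` and checks that it recovers a known 2-D rotation and translation; the angle, offset and random points are made up for illustration.

```python
import numpy as np


def rigid_fit(A, B):
    """Least-squares R, t such that R @ a + t ~= b for paired rows of A and B.

    Hypothetical name; same math as best_fit_transform, without the
    homogeneous matrix.
    """
    centroid_A, centroid_B = A.mean(axis=0), B.mean(axis=0)
    H = (A - centroid_A).T @ (B - centroid_B)   # m x m cross-covariance
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:                    # reflection: flip last axis
        Vt[-1, :] *= -1
        R = Vt.T @ U.T
    t = centroid_B - R @ centroid_A
    return R, t


# Known transform: rotate by 30 degrees, then shift by (5, -2)
theta = np.deg2rad(30.0)
R_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
t_true = np.array([5.0, -2.0])

rng = np.random.default_rng(0)
A = rng.uniform(0, 100, size=(50, 2))  # e.g. spot coordinates in a target slice
B = A @ R_true.T + t_true              # the same spots in the reference frame

R, t = rigid_fit(A, B)
print(np.allclose(R, R_true), np.allclose(t, t_true))  # True True
```

With noise-free pairs a single fit recovers the transform exactly. In `ICP_align` the fit is repeated on the MNN pairs and the resulting homogeneous transform is applied to every spot of the target slice, stopping once the mean paired distance changes by less than the tolerance.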
--------------------------------------------------------------------------------

/STAligner/__pycache__/ST_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/ST_utils.cpython-38.pyc
--------------------------------------------------------------------------------

/STAligner/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------

/STAligner/__pycache__/gat_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/gat_conv.cpython-38.pyc
--------------------------------------------------------------------------------

/STAligner/__pycache__/mnn_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/mnn_utils.cpython-38.pyc
--------------------------------------------------------------------------------

/STAligner/__pycache__/train_STAligner.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhoux85/STAligner/85ca4ce34072a37b949733aa55a3d3854ed89038/STAligner/__pycache__/train_STAligner.cpython-38.pyc
--------------------------------------------------------------------------------

/STAligner/gat_conv.py:
--------------------------------------------------------------------------------
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
import torch.nn as nn
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\sigma\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\sigma\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)},

    where :math:`\sigma` is the logistic sigmoid. (The original GAT applies a
    LeakyReLU here; this modified layer uses a sigmoid, see :meth:`message`.)

    Args:
        in_channels (int): Size of each input sample. The weight is stored as
            an explicit :obj:`in_channels x (heads * out_channels)` matrix, so
            tuples and :obj:`-1` (lazy initialization) are not supported.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. Unused, since a sigmoid replaces the LeakyReLU.
            (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): Accepted for API compatibility but unused; the
            layer never learns an additive bias. (default: :obj:`True`)
        prune_weight (float, optional): If nonzero, the output is
            :obj:`(1 - prune_weight)` times the propagation over
            :obj:`edge_index` plus :obj:`prune_weight` times the propagation
            over :obj:`prune_edge_index`. (default: :obj:`0.0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True,
                 prune_weight: float = 0.0, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # In case we are operating in bipartite graphs, we apply separate
        # transformations 'lin_src' and 'lin_dst' to source and target nodes:
        # if isinstance(in_channels, int):
        #     self.lin_src = Linear(in_channels, heads * out_channels,
        #                           bias=False, weight_initializer='glorot')
        #     self.lin_dst = self.lin_src
        # else:
        #     self.lin_src = Linear(in_channels[0], heads * out_channels, False,
        #                           weight_initializer='glorot')
        #     self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
        #                           weight_initializer='glorot')

        # Shared weight matrix for source and target nodes
        self.lin_src = nn.Parameter(torch.zeros(size=(in_channels, heads * out_channels)))
        nn.init.xavier_normal_(self.lin_src.data, gain=1.414)
        self.lin_dst = self.lin_src


        # The learnable parameters to compute attention coefficients:
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))
        nn.init.xavier_normal_(self.att_src.data, gain=1.414)
        nn.init.xavier_normal_(self.att_dst.data, gain=1.414)

        # if bias and concat:
        #     self.bias = Parameter(torch.Tensor(heads * out_channels))
        # elif bias and not concat:
        #     self.bias = Parameter(torch.Tensor(out_channels))
        # else:
        #     self.register_parameter('bias', None)

        self._alpha = None
        self.attentions = None

        self.prune_weight = prune_weight
        # self.reset_parameters()

    # def reset_parameters(self):
    #     self.lin_src.reset_parameters()
    #     self.lin_dst.reset_parameters()
    #     glorot(self.att_src)
    #     glorot(self.att_dst)
    #     # zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, prune_edge_index: Adj = None,
                size: Size = None, return_attention_weights=None, attention=True, tied_attention = None):
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
            attention (bool, optional): If set to :obj:`False`, skip message
                passing and return the linearly transformed features
                (averaged over heads). (default: :obj:`True`)
            tied_attention (tuple, optional): Precomputed
                :obj:`(alpha_src, alpha_dst)` node-level attention terms to
                reuse instead of computing them from :obj:`att_src` and
                :obj:`att_dst`. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            # x_src = x_dst = self.lin_src(x).view(-1, H, C)
            x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            # lin_src/lin_dst are plain Parameters here, so use a matrix product
            x_src = torch.mm(x_src, self.lin_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = torch.mm(x_dst, self.lin_dst).view(-1, H, C)

        x = (x_src, x_dst)

        if not attention:
            return x[0].mean(dim=1)
            # return x[0].view(-1, self.heads * self.out_channels)

        if tied_attention is None:
            # Next, we compute node-level attention coefficients, both for source
            # and target nodes (if present):
            alpha_src = (x_src * self.att_src).sum(dim=-1)
            alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
            alpha = (alpha_src, alpha_dst)
            self.attentions = alpha
        else:
            alpha = tied_attention

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        if self.prune_weight == 0:
            # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
            out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
        else:
            out = (1-self.prune_weight)*self.propagate(edge_index, x=x, alpha=alpha, size=size)+\
                  self.prune_weight*self.propagate(prune_edge_index, x=x, alpha=alpha, size=size)

        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        # if self.bias is not None:
        #     out += self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        # alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = torch.sigmoid(alpha)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # Save for later use.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
--------------------------------------------------------------------------------
/STAligner/mnn_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd

from sklearn.neighbors import NearestNeighbors
from annoy import AnnoyIndex
import itertools
import networkx as nx
import hnswlib

# Modified from https://github.com/lkmklsmn/insct
def create_dictionary_mnn(adata, use_rep, batch_name, k=50, save_on_disk=True, approx=True, verbose=1, iter_comb=None):
    """Find mutual nearest neighbors (MNNs) between pairs of batches in ``adata.obsm[use_rep]``.

    Returns a dict keyed by ``"<batch_i>_<batch_j>"``; each value maps a spot name
    to the list of its MNN partners in the other batch of that pair.
    """

    cell_names = adata.obs_names

    batch_list = adata.obs[batch_name]
    datasets = []
    datasets_pcs = []
    cells = []
    for i in batch_list.unique():
        datasets.append(adata[batch_list == i])
        datasets_pcs.append(adata[batch_list == i].obsm[use_rep])
        cells.append(cell_names[batch_list == i])

    batch_name_df = pd.DataFrame(np.array(batch_list.unique()))
    mnns = dict()

    if iter_comb is None:
        iter_comb = list(itertools.combinations(range(len(cells)), 2))
    for comb in iter_comb:
        i = comb[0]
        j = comb[1]
        key_name1 = batch_name_df.loc[comb[0]].values[0] + "_" + batch_name_df.loc[comb[1]].values[0]
        # One entry per slice pair, so MNNs found for an earlier pair are not overwritten in the multi-slice setting.
        mnns[key_name1] = {}
        if verbose > 0:
            print('Processing datasets {}'.format((i, j)))

        new = list(cells[j])
        ref = list(cells[i])

        ds1 = adata[new].obsm[use_rep]
        ds2 = adata[ref].obsm[use_rep]
        names1 = new
        names2 = ref
        # If k > 1, one point in ds1 may have multiple MNN points in ds2.
        match = mnn(ds1, ds2, names1, names2, knn=k, save_on_disk=save_on_disk, approx=approx)

        # Build an undirected graph whose edges are the MNN pairs; the CSR row of
        # each node lists that node's MNN partners in the other slice.
        G = nx.Graph()
        G.add_edges_from(match)
        node_names = np.array(G.nodes)
        anchors = list(node_names)
        adj = nx.adjacency_matrix(G)
        tmp = np.split(adj.indices, adj.indptr[1:-1])

        for i in range(0, len(anchors)):
            key = anchors[i]
            partner_idx = tmp[i]
            names = list(node_names[partner_idx])
            mnns[key_name1][key] = names
    return mnns

def validate_sparse_labels(Y):
    if not zero_indexed(Y):
        raise ValueError('Ensure that your labels are zero-indexed')
    if not consecutive_indexed(Y):
        raise ValueError('Ensure that your labels are indexed consecutively')


def zero_indexed(Y):
    if min(abs(Y)) != 0:
        return False
    return True


def consecutive_indexed(Y):
    """ Assumes that Y is zero-indexed. """
    n_classes = len(np.unique(Y[Y != np.array(-1)]))
    if max(Y) >= n_classes:
        return False
    return True


def nn_approx(ds1, ds2, names1, names2, knn=50):
    """Approximate kNN with hnswlib: pair each row of ds1 with its knn nearest rows of ds2 (by name)."""
    dim = ds2.shape[1]
    num_elements = ds2.shape[0]
    p = hnswlib.Index(space='l2', dim=dim)
    p.init_index(max_elements=num_elements, ef_construction=100, M=16)
    p.set_ef(10)
    p.add_items(ds2)
    ind, distances = p.knn_query(ds1, k=knn)
    match = set()
    for a, b in zip(range(ds1.shape[0]), ind):
        for b_i in b:
            match.add((names1[a], names2[b_i]))
    return match


def nn(ds1, ds2, names1, names2, knn=50, metric_p=2):
    """Exact kNN with scikit-learn (Minkowski metric with power metric_p): pair each row of ds1 with its knn nearest rows of ds2."""
    # Find nearest neighbors of first dataset.
    # NearestNeighbors only accepts keyword arguments in current scikit-learn releases.
    nn_ = NearestNeighbors(n_neighbors=knn, p=metric_p)
    nn_.fit(ds2)
    ind = nn_.kneighbors(ds1, return_distance=False)

    match = set()
    for a, b in zip(range(ds1.shape[0]), ind):
        for b_i in b:
            match.add((names1[a], names2[b_i]))

    return match


def nn_annoy(ds1, ds2, names1, names2, knn=20, metric='euclidean', n_trees=50, save_on_disk=True):
    """ Approximate kNN with an Annoy index: pair each row of ds1 with its knn nearest rows of ds2. """
    # Build index.
    index = AnnoyIndex(ds2.shape[1], metric=metric)
    if save_on_disk:
        index.on_disk_build('annoy.index')
    for i in range(ds2.shape[0]):
        index.add_item(i, ds2[i, :])
    index.build(n_trees)

    # Search index.
    ind = []
    for i in range(ds1.shape[0]):
        ind.append(index.get_nns_by_vector(ds1[i, :], knn, search_k=-1))
    ind = np.array(ind)

    # Match.
    match = set()
    for a, b in zip(range(ds1.shape[0]), ind):
        for b_i in b:
            match.add((names1[a], names2[b_i]))

    return match


def mnn(ds1, ds2, names1, names2, knn=20, save_on_disk=True, approx=True):
    """Return the mutual nearest-neighbor pairs (name1, name2): name2 is among the knn
    nearest neighbors of name1 in ds2 and name1 is among the knn nearest neighbors of
    name2 in ds1. save_on_disk is accepted but not used by the hnswlib or scikit-learn backends.
    """
    if approx:
        # Find nearest neighbors in first direction.
        # match1 is a set of (name in names1, name in names2) pairs: knn pairs per point of ds1.
        match1 = nn_approx(ds1, ds2, names1, names2, knn=knn)
        # Find nearest neighbors in second direction.
        match2 = nn_approx(ds2, ds1, names2, names1, knn=knn)
    else:
        match1 = nn(ds1, ds2, names1, names2, knn=knn)
        match2 = nn(ds2, ds1, names2, names1, knn=knn)
    # Compute mutual nearest neighbors.
    mutual = match1 & set([(b, a) for a, b in match2])

    return mutual
--------------------------------------------------------------------------------
/STAligner/train_STAligner.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
from tqdm import tqdm
import scipy.sparse as sp

from .mnn_utils import create_dictionary_mnn
from .STALIGNER import STAligner

import torch
import torch.backends.cudnn as cudnn

cudnn.deterministic = True
cudnn.benchmark = True
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

def train_STAGATE(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAligner',
                  gradient_clipping=5., weight_decay=0.0001, verbose=True,
                  random_seed=0, save_reconstrction=False, save_loss=False,
                  device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    """\
    Train a graph attention auto-encoder (without cross-slice alignment).

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    hidden_dims
        Hidden-layer sizes of the encoder; the last entry is the dimension of the latent embedding.
    n_epochs
        Number of total epochs in training.
    lr
        Learning rate for AdamOptimizer.
    key_added
        The latent embeddings are saved in adata.obsm[key_added].
    gradient_clipping
        Maximum gradient norm used for gradient clipping.
    weight_decay
        Weight decay for AdamOptimizer.
    verbose
        If True, print the size of the input matrix.
    random_seed
        Seed for the Python, NumPy and PyTorch random number generators.
    save_reconstrction
        If True, the reconstructed expression profiles are saved in adata.layers['STAGATE_ReX'].
    save_loss
        If True, the training loss is saved in adata.uns['STAGATE_loss'].
    device
        See torch.device.

    Returns
    -------
    AnnData
    """

    # seed_everything()
    seed = random_seed
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    adata.X = sp.csr_matrix(adata.X)

    if 'highly_variable' in adata.var.columns:
        adata_Vars = adata[:, adata.var['highly_variable']]
    else:
        adata_Vars = adata

    if verbose:
        print('Size of Input: ', adata_Vars.shape)
    if 'Spatial_Net' not in adata.uns.keys():
        raise ValueError("Spatial_Net does not exist! Run Cal_Spatial_Net first!")

    # NOTE: Transfer_pytorch_Data is not imported in this module; it must be available in scope.
    data = Transfer_pytorch_Data(adata_Vars)

    model = STAligner(hidden_dims=[data.x.shape[1]] + hidden_dims).to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    loss_list = []
    for epoch in tqdm(range(1, n_epochs + 1)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        loss = F.mse_loss(data.x, out)  # F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss_list.append(loss.item())  # store a float so the autograd graph is not kept alive
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()

    model.eval()
    with torch.no_grad():
        z, out = model(data.x, data.edge_index)

    STAGATE_rep = z.to('cpu').detach().numpy()
    adata.obsm[key_added] = STAGATE_rep

    if save_loss:
        adata.uns['STAGATE_loss'] = loss_list
    if save_reconstrction:
        ReX = out.to('cpu').detach().numpy()
        ReX[ReX < 0] = 0
        adata.layers['STAGATE_ReX'] = ReX

    return adata


def train_STAligner(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAligner',
                    gradient_clipping=5., weight_decay=0.0001, margin=1.0, verbose=False,
                    random_seed=666, iter_comb=None, knn_neigh=100,
                    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    """\
    Train the graph attention auto-encoder and use spot triplets across slices to perform batch correction in the embedding space.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. Must contain the slice labels in adata.obs['batch_name']
        and the spatial edge list in adata.uns['edgeList'].
    hidden_dims
        Hidden-layer sizes of the encoder; the last entry is the dimension of the latent embedding.
    n_epochs
        Total number of training epochs, including the first 500 epochs of STAGATE pre-training.
    lr
        Learning rate for AdamOptimizer.
    key_added
        The latent embeddings are saved in adata.obsm[key_added].
    gradient_clipping
        Maximum gradient norm used for gradient clipping.
    weight_decay
        Weight decay for AdamOptimizer.
    margin
        Margin of the triplet loss: a triplet stops contributing once the anchor-negative distance
        exceeds the anchor-positive distance by at least this amount.
        Larger values result in more aggressive correction.
    verbose
        If True, print the model and report when the spot triplets are updated.
    random_seed
        Seed for the Python, NumPy and PyTorch random number generators.
    iter_comb
        For multiple slices integration, we perform iterative pairwise integration. iter_comb is a list of
        (i, j) slice-index pairs that specifies the order of integration.
        For example, (0, 1) means slice 0 will be aligned with slice 1 as reference.
        If None, all pairs of slices are used.
    knn_neigh
        The number of nearest neighbors when constructing MNNs. If knn_neigh>1, points in one slice may have multiple MNN points in another slice.
    device
        See torch.device.

    Returns
    -------
    AnnData
    """

    # seed_everything()
    seed = random_seed
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    section_ids = np.array(adata.obs['batch_name'].unique())
    edgeList = adata.uns['edgeList']
    data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])),
                prune_edge_index=torch.LongTensor(np.array([])),
                x=torch.FloatTensor(adata.X.todense()))
    data = data.to(device)

    model = STAligner(hidden_dims=[data.x.shape[1], hidden_dims[0], hidden_dims[1]]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if verbose:
        print(model)

    print('Pretrain with STAGATE...')
    for epoch in tqdm(range(0, 500)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)

        loss = F.mse_loss(data.x, out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
        optimizer.step()

    with torch.no_grad():
        z, _ = model(data.x, data.edge_index)
    adata.obsm['STAGATE'] = z.cpu().detach().numpy()

    print('Train with STAligner...')
    for epoch in tqdm(range(500, n_epochs)):
        if epoch % 100 == 0 or epoch == 500:
            if verbose:
                print('Update spot triplets at epoch ' + str(epoch))
            adata.obsm['STAGATE'] = z.cpu().detach().numpy()

            # If knn_neigh>1, points in one slice may have multiple MNN points in another slice.
            # Not all points have MNN anchors.
            mnn_dict = create_dictionary_mnn(adata, use_rep='STAGATE', batch_name='batch_name', k=knn_neigh,
                                             iter_comb=iter_comb, verbose=0)

            anchor_ind = []
            positive_ind = []
            negative_ind = []
            for batch_pair in mnn_dict.keys():  # pairwise compare for multiple batches
                batchname_list = adata.obs['batch_name'][list(mnn_dict[batch_pair].keys())]
                # print("before add KNN pairs, len(mnn_dict[batch_pair]):",
                #       sum(adata_new.obs['batch_name'].isin(batchname_list.unique())), len(mnn_dict[batch_pair]))

                cellname_by_batch_dict = dict()
                for batch_id in range(len(section_ids)):
                    cellname_by_batch_dict[section_ids[batch_id]] = adata.obs_names[
                        adata.obs['batch_name'] == section_ids[batch_id]].values

                anchor_list = []
                positive_list = []
                negative_list = []
                for anchor in mnn_dict[batch_pair].keys():
                    anchor_list.append(anchor)
                    ## np.random.choice(mnn_dict[batch_pair][anchor])
                    positive_spot = mnn_dict[batch_pair][anchor][0]  # select the first positive spot
                    positive_list.append(positive_spot)
                    # The negative is a random spot from the anchor's own slice.
                    section_size = len(cellname_by_batch_dict[batchname_list[anchor]])
                    negative_list.append(
                        cellname_by_batch_dict[batchname_list[anchor]][np.random.randint(section_size)])

                batch_as_dict = dict(zip(list(adata.obs_names), range(0, adata.shape[0])))
                anchor_ind = np.append(anchor_ind, list(map(lambda _: batch_as_dict[_], anchor_list)))
                positive_ind = np.append(positive_ind, list(map(lambda _: batch_as_dict[_], positive_list)))
                negative_ind = np.append(negative_ind, list(map(lambda _: batch_as_dict[_], negative_list)))

        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        mse_loss = F.mse_loss(data.x, out)

        anchor_arr = z[anchor_ind,]
        positive_arr = z[positive_ind,]
        negative_arr = z[negative_ind,]

        triplet_loss = torch.nn.TripletMarginLoss(margin=margin, p=2, reduction='mean')
        tri_output = triplet_loss(anchor_arr, positive_arr, negative_arr)

        loss = mse_loss + tri_output
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()

    #
    model.eval()
    adata.obsm[key_added] = z.cpu().detach().numpy()
    return adata


def train_STAligner_subgraph(adata, hidden_dims=[512, 30], n_epochs=1000, lr=0.001, key_added='STAligner',
                             gradient_clipping=5., weight_decay=0.0001, margin=1.0, verbose=False,
                             random_seed=666, iter_comb=None, knn_neigh=100, Batch_list=None,
                             device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    """\
    Train the graph attention auto-encoder and use spot triplets across slices to perform batch correction in the embedding space.
    To handle large-scale data with multiple slices and reduce GPU memory usage, each slice is treated as a subgraph during training.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. Must contain the slice labels in adata.obs['batch_name'].
    hidden_dims
        Hidden-layer sizes of the encoder; the last entry is the dimension of the latent embedding.
    n_epochs
        Total number of training epochs, including the first 500 epochs of STAGATE pre-training.
    lr
        Learning rate for AdamOptimizer.
    key_added
        The latent embeddings are saved in adata.obsm[key_added].
    gradient_clipping
        Maximum gradient norm used for gradient clipping.
    weight_decay
        Weight decay for AdamOptimizer.
    margin
        Margin of the triplet loss: a triplet stops contributing once the anchor-negative distance
        exceeds the anchor-positive distance by at least this amount.
        Larger values result in more aggressive correction.
    verbose
        If True, print the model and report when the spot triplets are updated.
    random_seed
        Seed for the Python, NumPy and PyTorch random number generators.
    iter_comb
        For multiple slices integration, we perform iterative pairwise integration. iter_comb is a list of
        (i, j) slice-index pairs that specifies the order of integration.
        For example, (0, 1) means slice 0 will be aligned with slice 1 as reference.
    knn_neigh
        The number of nearest neighbors when constructing MNNs.
        If knn_neigh>1, points in one slice may have multiple MNN points in another slice.
    Batch_list
        Required. List of per-slice AnnData objects, in the order of adata.obs['batch_name'].unique(),
        whose rows concatenated in that order match the rows of adata. Each must store its spatial
        adjacency matrix in .uns['adj'].
    device
        See torch.device.

    Returns
    -------
    AnnData
    """

    # seed_everything()
    seed = random_seed
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    section_ids = np.array(adata.obs['batch_name'].unique())
    if iter_comb is None:
        # Default to all slice pairs, as create_dictionary_mnn does.
        import itertools
        iter_comb = list(itertools.combinations(range(len(section_ids)), 2))

    comm_gene = adata.var_names
    data_list = []
    for adata_tmp in Batch_list:
        adata_tmp = adata_tmp[:, comm_gene]
        edge_index = np.nonzero(adata_tmp.uns['adj'])
        data_list.append(Data(edge_index=torch.LongTensor(np.array([edge_index[0], edge_index[1]])),
                              prune_edge_index=torch.LongTensor(np.array([])),
                              x=torch.FloatTensor(adata_tmp.X.todense())))

    loader = DataLoader(data_list, batch_size=1, shuffle=True)

    model = STAligner(hidden_dims=[adata.X.shape[1], hidden_dims[0], hidden_dims[1]]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    if verbose:
        print(model)

    print('Pretrain with STAGATE...')
    for epoch in tqdm(range(0, 500)):
        for batch in loader:
            model.train()
            optimizer.zero_grad()
            batch = batch.to(device)
            z, out = model(batch.x, batch.edge_index)

            loss = F.mse_loss(batch.x, out)  # +adv_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

    with torch.no_grad():
        z_list = []
        for batch in data_list:
            z, _ = model.cpu()(batch.x, batch.edge_index)
            z_list.append(z.cpu().detach().numpy())
    adata.obsm['STAGATE'] = np.concatenate(z_list, axis=0)
    model = model.to(device)

    print('Train with STAligner...')
    for epoch in tqdm(range(500, n_epochs)):
        if epoch % 100 == 0 or epoch == 500:
            if verbose:
                print('Update spot triplets at epoch ' + str(epoch))

            with torch.no_grad():
                z_list = []
                for batch in data_list:
                    z, _ = model.cpu()(batch.x, batch.edge_index)
                    z_list.append(z.cpu().detach().numpy())
            adata.obsm['STAGATE'] = np.concatenate(z_list, axis=0)
            model = model.to(device)

            pair_data_list = []
            for comb in iter_comb:
                # print(comb)
                i, j = comb[0], comb[1]
                batch_pair = adata[adata.obs['batch_name'].isin([section_ids[i], section_ids[j]])]
                mnn_dict = create_dictionary_mnn(batch_pair, use_rep='STAGATE', batch_name='batch_name',
                                                 k=knn_neigh,
                                                 iter_comb=None, verbose=0)

                batchname_list = batch_pair.obs['batch_name']
                cellname_by_batch_dict = dict()
                for batch_id in range(len(section_ids)):
                    cellname_by_batch_dict[section_ids[batch_id]] = batch_pair.obs_names[
                        batch_pair.obs['batch_name'] == section_ids[batch_id]].values

                anchor_list = []
                positive_list = []
                negative_list = []
                for batch_pair_name in mnn_dict.keys():  # pairwise compare for multiple batches
                    for anchor in mnn_dict[batch_pair_name].keys():
                        anchor_list.append(anchor)
                        positive_spot = mnn_dict[batch_pair_name][anchor][0]
                        positive_list.append(positive_spot)
                        section_size = len(cellname_by_batch_dict[batchname_list[anchor]])
                        negative_list.append(
                            cellname_by_batch_dict[batchname_list[anchor]][np.random.randint(section_size)])

                batch_as_dict = dict(zip(list(batch_pair.obs_names), range(0, batch_pair.shape[0])))
                anchor_ind = list(map(lambda _: batch_as_dict[_], anchor_list))
                positive_ind = list(map(lambda _: batch_as_dict[_], positive_list))
                negative_ind = list(map(lambda _: batch_as_dict[_], negative_list))

                # Stack the two slice graphs: slice j's node indices are shifted by the number of
                # spots in slice i (this assumes slice i precedes slice j in adata, i.e. i < j,
                # so that the offsets match the row order of batch_pair).
                edge_list_1 = np.nonzero(Batch_list[i].uns['adj'])
                n_spots_1 = Batch_list[i].shape[0]
                edge_list_2 = np.nonzero(Batch_list[j].uns['adj'])
                edge_list_2 = (edge_list_2[0] + n_spots_1, edge_list_2[1] + n_spots_1)
                edge_list = [edge_list_1, edge_list_2]
                edge_pairs = [np.append(edge_list[0][0], edge_list[1][0]), np.append(edge_list[0][1], edge_list[1][1])]
                pair_data_list.append(Data(edge_index=torch.LongTensor(np.array([edge_pairs[0], edge_pairs[1]])),
                                           anchor_ind=torch.LongTensor(np.array(anchor_ind)),
                                           positive_ind=torch.LongTensor(np.array(positive_ind)),
                                           negative_ind=torch.LongTensor(np.array(negative_ind)),
                                           x=batch_pair.X))  # torch.FloatTensor(batch_pair.X.todense())

            # for temp in pair_data_list:
            #     temp.to(device)
            pair_loader = DataLoader(pair_data_list, batch_size=1, shuffle=True)

        for batch in pair_loader:
            model.train()
            optimizer.zero_grad()

            batch.x = torch.FloatTensor(batch.x[0].todense())
            batch = batch.to(device)
            z, out = model(batch.x, batch.edge_index)
            mse_loss = F.mse_loss(batch.x, out)

            anchor_arr = z[batch.anchor_ind,]
            positive_arr = z[batch.positive_ind,]
            negative_arr = z[batch.negative_ind,]

            triplet_loss = torch.nn.TripletMarginLoss(margin=margin, p=2, reduction='sum')
            tri_output = triplet_loss(anchor_arr, positive_arr, negative_arr)

            loss = mse_loss + tri_output
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            optimizer.step()

    #
    model.eval()
    with torch.no_grad():
        z_list = []
        for batch in data_list:
            z, _ = model.cpu()(batch.x, batch.edge_index)
            z_list.append(z.cpu().detach().numpy())
    adata.obsm[key_added] = np.concatenate(z_list, axis=0)
    return adata
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""
# Author: Xiang Zhou
# File Name: __init__.py
# Description:
"""

__author__ = "Xiang Zhou"
__email__ = "xzhou@amss.ac.cn"
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
numpy==1.20.3
anndata==0.7.8
pandas==1.2.3
scipy==1.5.3
matplotlib==3.5.1
louvain==0.7.1
scanpy==1.9.1
umap-learn==0.5.2
rpy2==2.9.5
seaborn==0.11.2
networkx==2.8.4
hnswlib==0.5.1
annoy==1.17.0
tqdm==4.64.1
torch==1.9.1
torch-geometric==2.0.3
torch-cluster==1.5.9
torch-scatter==2.0.9
torch-sparse==0.6.12

--------------------------------------------------------------------------------
/requirement_for_macOS.txt:
--------------------------------------------------------------------------------
numpy==1.24.4
anndata==0.9.2
pandas==2.0.3
scipy==1.5.3
matplotlib==3.7.2
louvain==0.8.1
scanpy==1.9.3
umap-learn==0.5.3
rpy2==3.5.13
seaborn==0.12.2
networkx==3.1
hnswlib==0.7.0
annoy==1.17.3
tqdm==4.66.1
torch==1.13.0
torch-geometric==2.3.1
torch-cluster==1.6.1
torch-scatter==2.1.1
torch-sparse==0.6.17
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import Command, find_packages, setup

__lib_name__ = "STAligner"
__lib_version__ = "1.0.0"
__description__ = "Integrating spatial transcriptomics data across different conditions, technologies, and developmental stages"
__url__ = "https://github.com/zhoux85/STAligner"
__author__ = "Xiang Zhou"
__author_email__ = "xzhou@amss.ac.cn"
__license__ = "MIT"
__keywords__ = ["spatial transcriptomics", "data integration", "Graph attention auto-encoder", "spatial domain", "three-dimensional reconstruction"]
__requires__ = ["requests",]

setup(
    name = __lib_name__,
    version = __lib_version__,
    description = __description__,
    url = __url__,
    author = __author__,
    author_email = __author_email__,
    license = __license__,
    packages = ['STAligner'],
    install_requires = __requires__,
    zip_safe = False,
    include_package_data = True,
)
--------------------------------------------------------------------------------
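The triplet construction used by `train_STAligner` is spread across `mnn_utils.py` and the training loop. As a reading aid, here is a minimal plain-Python sketch of the same idea on toy 2-D embeddings: exact k-nearest neighbors in both directions, the intersection that `mnn()` computes, the rule "first MNN partner as positive, random spot from the anchor's own slice as negative", and the mean triplet margin loss max(d(a, p) - d(a, n) + margin, 0) that `torch.nn.TripletMarginLoss(reduction='mean')` evaluates (ignoring the small `eps` PyTorch adds inside its distance). The helper names (`knn_pairs`, `mutual_pairs`, `build_triplets`, `triplet_margin_loss`) and the toy coordinates are hypothetical, not part of STAligner; exact kNN stands in for the hnswlib approximation, and for brevity only spots of the first slice act as anchors, whereas `create_dictionary_mnn` also returns anchors from the second slice.

```python
import math
import random


def knn_pairs(src, dst, k):
    """Return {(i, j)}: row j of dst is one of the k nearest (Euclidean) to row i of src."""
    pairs = set()
    for i, p in enumerate(src):
        order = sorted(range(len(dst)), key=lambda j: math.dist(p, dst[j]))
        pairs.update((i, j) for j in order[:k])
    return pairs


def mutual_pairs(ds1, ds2, k):
    """Keep (i, j) only if j is in i's kNN list AND i is in j's kNN list (like mnn())."""
    forward = knn_pairs(ds1, ds2, k)   # (index in ds1, index in ds2)
    backward = knn_pairs(ds2, ds1, k)  # (index in ds2, index in ds1)
    return forward & {(i, j) for j, i in backward}


def build_triplets(mnn, n_anchor_slice, rng):
    """(anchor, positive, negative) triples: positive is the first MNN partner,
    negative is a random spot from the anchor's own slice (hypothetical helper)."""
    partners = {}
    for i, j in sorted(mnn):
        partners.setdefault(i, []).append(j)
    return [(i, js[0], rng.randrange(n_anchor_slice)) for i, js in sorted(partners.items())]


def triplet_margin_loss(ds1, ds2, triplets, margin=1.0):
    """Mean of max(d(a, p) - d(a, n) + margin, 0); a and n index ds1, p indexes ds2."""
    losses = [max(math.dist(ds1[a], ds2[p]) - math.dist(ds1[a], ds1[n]) + margin, 0.0)
              for a, p, n in triplets]
    return sum(losses) / len(losses)


slice1 = [(0.0, 0.0), (5.0, 5.0), (10.0, 0.0)]   # toy embeddings of slice 1
slice2 = [(0.5, 0.0), (5.0, 5.5), (20.0, 20.0)]  # toy embeddings of slice 2
mnn = mutual_pairs(slice1, slice2, k=1)
print(sorted(mnn))  # [(0, 0), (1, 1)]
triplets = build_triplets(mnn, len(slice1), random.Random(0))
print(triplet_margin_loss(slice1, slice2, triplets))
```

With k=1 the two outlying spots, (10, 0) and (20, 20), have no mutual partner and drop out, which is the situation the "not all points have MNN anchors" comment in `train_STAligner` refers to. A triplet whose negative happens to be the anchor itself contributes d(a, p) + margin, so random negatives occasionally add a penalty even for well-aligned anchors.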