├── DLPFC_MULTISLICES.py ├── README.md ├── SpaMask.jpg ├── SpaMask ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── model.cpython-311.pyc │ ├── preprocess.cpython-311.pyc │ ├── spaMask.cpython-311.pyc │ └── utils.cpython-311.pyc ├── model.py ├── preprocess.py ├── spaMask.py └── utils.py ├── TutorialDonor.ipynb └── requirement.txt /DLPFC_MULTISLICES.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import ot 3 | import scipy.sparse 4 | import matplotlib 5 | import sklearn 6 | from matplotlib import cm 7 | import matplotlib.pyplot as plt 8 | import random 9 | import numpy as np 10 | from sklearn.neighbors import NearestNeighbors 11 | 12 | np.random.seed(42) 13 | random.seed(42) 14 | 15 | 16 | def align_spots(adata_st_list_input, # list of spatial transcriptomics datasets 17 | method="icp", # "icp" or "paste" 18 | data_type="Visium", 19 | # a spot has six nearest neighborhoods if "Visium", four nearest neighborhoods otherwise 20 | coor_key="spatial", # "spatial" for visium; key for the spatial coordinates used for alignment 21 | tol=0.01, # parameter for "icp" method; tolerance level 22 | test_all_angles=False, # parameter for "icp" method; whether to test multiple rotation angles or not 23 | plot=False, 24 | paste_alpha=0.1, 25 | paste_dissimilarity="kl" 26 | ): 27 | # Align coordinates of spatial transcriptomics 28 | 29 | # The first adata in the list is used as a reference for alignment 30 | adata_st_list = adata_st_list_input.copy() 31 | 32 | if plot: 33 | # Choose colors 34 | cmap = cm.get_cmap('rainbow', len(adata_st_list)) 35 | colors_list = [matplotlib.colors.rgb2hex(cmap(i)) for i in range(len(adata_st_list))] 36 | 37 | # Plot spots before alignment 38 | plt.figure(figsize=(5, 5)) 39 | plt.title("Before alignment") 40 | for i in range(len(adata_st_list)): 41 | plt.scatter(adata_st_list[i].obsm[coor_key][:, 0], 42 | adata_st_list[i].obsm[coor_key][:, 1], 43 | c=colors_list[i], 44 | label="Slice %d spots" % i, s=5., alpha=0.5) 45 | ax = plt.gca() 46 | ax.set_ylim(ax.get_ylim()[::-1]) 47 | plt.xticks([]) 48 | plt.yticks([]) 49 | plt.legend(loc=(1.02, .2), ncol=(len(adata_st_list) // 13 + 1)) 50 | plt.show() 51 | 52 | if (method == "icp") or (method == "ICP"): 53 | print("Using the Iterative Closest Point algorithm for alignemnt.") 54 | # Detect edges 55 | print("Detecting edges...") 56 | point_cloud_list = [] 57 | for adata in adata_st_list: 58 | # Use in-tissue spots only 59 | if 'in_tissue' in adata.obs.columns: 60 | adata = adata[adata.obs['in_tissue'] == 1] 61 | if data_type == "Visium": 62 | loc_x = adata.obs.loc[:, ["array_row"]] 63 | loc_x = np.array(loc_x) * np.sqrt(3) 64 | loc_y = adata.obs.loc[:, ["array_col"]] 65 | loc_y = np.array(loc_y) 66 | loc = np.concatenate((loc_x, loc_y), axis=1) 67 | pairwise_loc_distsq = np.sum((loc.reshape([1, -1, 2]) - loc.reshape([-1, 1, 2])) ** 2, axis=2) 68 | n_neighbors = np.sum(pairwise_loc_distsq < 5, axis=1) - 1 69 | edge = ((n_neighbors > 1) & (n_neighbors < 5)).astype(np.float32) 70 | else: 71 | loc_x = adata.obs.loc[:, ["array_row"]] 72 | loc_x = np.array(loc_x) 73 | loc_y = adata.obs.loc[:, ["array_col"]] 74 | loc_y = np.array(loc_y) 75 | loc = np.concatenate((loc_x, loc_y), axis=1) 76 | pairwise_loc_distsq = np.sum((loc.reshape([1, -1, 2]) - loc.reshape([-1, 1, 2])) ** 2, axis=2) 77 | min_distsq = np.sort(np.unique(pairwise_loc_distsq), axis=None)[1] 78 | n_neighbors = np.sum(pairwise_loc_distsq < (min_distsq * 3), axis=1) - 1 79 | edge = ((n_neighbors > 1) & (n_neighbors < 7)).astype(np.float32) 80 | point_cloud_list.append(adata.obsm[coor_key][edge == 1].copy()) 81 | 82 | # Align edges 83 | print("Aligning edges...") 84 | trans_list = [] 85 | adata_st_list[0].obsm["spatial_aligned"] = adata_st_list[0].obsm[coor_key].copy() 86 | # Calculate pairwise transformation matrices 87 | for i in range(len(adata_st_list) - 1): 88 | if test_all_angles == True: 89 | for angle in [0., np.pi * 1 / 3, np.pi * 2 / 3, np.pi, np.pi * 4 / 3, np.pi * 5 / 3]: 90 | R = np.array([[np.cos(angle), np.sin(angle), 0], 91 | [-np.sin(angle), np.cos(angle), 0], 92 | [0, 0, 1]]).T 93 | T, distances, _ = icp(transform(point_cloud_list[i + 1], R), point_cloud_list[i], tolerance=tol) 94 | if angle == 0: 95 | loss_best = np.mean(distances) 96 | angle_best = angle 97 | R_best = R 98 | T_best = T 99 | else: 100 | if np.mean(distances) < loss_best: 101 | loss_best = np.mean(distances) 102 | angle_best = angle 103 | R_best = R 104 | T_best = T 105 | T = T_best @ R_best 106 | else: 107 | T, _, _ = icp(point_cloud_list[i + 1], point_cloud_list[i], tolerance=tol) 108 | trans_list.append(T) 109 | # Tranform 110 | for i in range(len(adata_st_list) - 1): 111 | point_cloud_align = adata_st_list[i + 1].obsm[coor_key].copy() 112 | for T in trans_list[:(i + 1)][::-1]: 113 | point_cloud_align = transform(point_cloud_align, T) 114 | adata_st_list[i + 1].obsm["spatial_aligned"] = point_cloud_align 115 | 116 | elif (method == "paste") or (method == "PASTE"): 117 | print("Using PASTE algorithm for alignemnt.") 118 | # Align spots 119 | print("Aligning spots...") 120 | pis = [] 121 | # Calculate pairwise transformation matrices 122 | for i in range(len(adata_st_list) - 1): 123 | pi = pairwise_align_paste(adata_st_list[i], adata_st_list[i + 1], coor_key=coor_key, 124 | alpha=paste_alpha, dissimilarity=paste_dissimilarity) 125 | pis.append(pi) 126 | # Tranform 127 | S1, S2 = generalized_procrustes_analysis(adata_st_list[0].obsm[coor_key], 128 | adata_st_list[1].obsm[coor_key], 129 | pis[0]) 130 | adata_st_list[0].obsm["spatial_aligned"] = S1 131 | adata_st_list[1].obsm["spatial_aligned"] = S2 132 | for i in range(1, len(adata_st_list) - 1): 133 | S1, S2 = generalized_procrustes_analysis(adata_st_list[i].obsm["spatial_aligned"], 134 | adata_st_list[i + 1].obsm[coor_key], 135 | pis[i]) 136 | adata_st_list[i + 1].obsm["spatial_aligned"] = S2 137 | 138 | if plot: 139 | plt.figure(figsize=(5, 5)) 140 | plt.title("After alignment") 141 | for i in range(len(adata_st_list)): 142 | plt.scatter(adata_st_list[i].obsm["spatial_aligned"][:, 0], 143 | adata_st_list[i].obsm["spatial_aligned"][:, 1], 144 | c=colors_list[i], 145 | label="Slice %d spots" % i, s=5., alpha=0.5) 146 | ax = plt.gca() 147 | ax.set_ylim(ax.get_ylim()[::-1]) 148 | plt.xticks([]) 149 | plt.yticks([]) 150 | plt.legend(loc=(1.02, .2), ncol=(len(adata_st_list) // 13 + 1)) 151 | plt.show() 152 | 153 | return adata_st_list 154 | 155 | 156 | # Functions for the Iterative Closest Point algorithm 157 | # Credit to https://github.com/ClayFlannigan/icp 158 | def best_fit_transform(A, B): 159 | assert A.shape == B.shape 160 | 161 | # get number of dimensions 162 | m = A.shape[1] 163 | 164 | # translate points to their centroids 165 | centroid_A = np.mean(A, axis=0) 166 | centroid_B = np.mean(B, axis=0) 167 | AA = A - centroid_A 168 | BB = B - centroid_B 169 | 170 | # rotation matrix 171 | H = np.dot(AA.T, BB) 172 | U, S, Vt = np.linalg.svd(H) 173 | R = np.dot(Vt.T, U.T) 174 | 175 | # special reflection case 176 | if np.linalg.det(R) < 0: 177 | Vt[m - 1, :] *= -1 178 | R = np.dot(Vt.T, U.T) 179 | 180 | # translation 181 | t = centroid_B.T - np.dot(R, centroid_A.T) 182 | 183 | # homogeneous transformation 184 | T = np.identity(m + 1) 185 | T[:m, :m] = R 186 | T[:m, m] = t 187 | 188 | return T, R, t 189 | 190 | 191 | def nearest_neighbor(src, dst): 192 | ''' 193 | Find the nearest (Euclidean) neighbor in dst for each point in src 194 | Input: 195 | src: Nxm array of points 196 | dst: Nxm array of points 197 | Output: 198 | distances: Euclidean distances of the nearest neighbor 199 | indices: dst indices of the nearest neighbor 200 | ''' 201 | 202 | neigh = NearestNeighbors(n_neighbors=1) 203 | neigh.fit(dst) 204 | distances, indices = neigh.kneighbors(src, return_distance=True) 205 | return distances.ravel(), indices.ravel() 206 | 207 | 208 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001): 209 | ''' 210 | The Iterative Closest Point method: finds best-fit transform that maps points A on to points B 211 | Input: 212 | A: Nxm numpy array of source mD points 213 | B: Nxm numpy array of destination mD point 214 | init_pose: (m+1)x(m+1) homogeneous transformation 215 | max_iterations: exit algorithm after max_iterations 216 | tolerance: convergence criteria 217 | Output: 218 | T: final homogeneous transformation that maps A on to B 219 | distances: Euclidean distances (errors) of the nearest neighbor 220 | i: number of iterations to converge 221 | ''' 222 | 223 | # get number of dimensions 224 | m = A.shape[1] 225 | 226 | # make points homogeneous, copy them to maintain the originals 227 | src = np.ones((m + 1, A.shape[0])) 228 | dst = np.ones((m + 1, B.shape[0])) 229 | src[:m, :] = np.copy(A.T) 230 | dst[:m, :] = np.copy(B.T) 231 | 232 | # apply the initial pose estimation 233 | if init_pose is not None: 234 | src = np.dot(init_pose, src) 235 | 236 | prev_error = 0 237 | 238 | for i in range(max_iterations): 239 | # find the nearest neighbors between the current source and destination points 240 | distances, indices = nearest_neighbor(src[:m, :].T, dst[:m, :].T) 241 | 242 | # compute the transformation between the current source and nearest destination points 243 | T, _, _ = best_fit_transform(src[:m, :].T, dst[:m, indices].T) 244 | 245 | # update the current source 246 | src = np.dot(T, src) 247 | 248 | # check error 249 | mean_error = np.mean(distances) 250 | if np.abs(prev_error - mean_error) < tolerance: 251 | break 252 | prev_error = mean_error 253 | 254 | # calculate final transformation 255 | T, _, _ = best_fit_transform(A, src[:m, :].T) 256 | 257 | return T, distances, i 258 | 259 | 260 | def transform(point_cloud, T): 261 | point_cloud_align = np.ones((point_cloud.shape[0], 3)) 262 | point_cloud_align[:, 0:2] = np.copy(point_cloud) 263 | point_cloud_align = np.dot(T, point_cloud_align.T).T 264 | return point_cloud_align[:, :2] 265 | 266 | 267 | # Functions for the PASTE algorithm 268 | # Credit to https://github.com/raphael-group/paste 269 | 270 | ## Covert a sparse matrix into a dense np array 271 | to_dense_array = lambda X: X.toarray() if isinstance(X, scipy.sparse.csr.spmatrix) else np.array(X) 272 | 273 | ## Returns the data matrix or representation 274 | extract_data_matrix = lambda adata, rep: adata.X if rep is None else adata.obsm[rep] 275 | 276 | 277 | def intersect(lst1, lst2): 278 | temp = set(lst2) 279 | lst3 = [value for value in lst1 if value in temp] 280 | return lst3 281 | 282 | 283 | def kl_divergence_backend(X, Y): 284 | assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features." 285 | 286 | nx = ot.backend.get_backend(X, Y) 287 | 288 | X = X / nx.sum(X, axis=1, keepdims=True) 289 | Y = Y / nx.sum(Y, axis=1, keepdims=True) 290 | log_X = nx.log(X) 291 | log_Y = nx.log(Y) 292 | X_log_X = nx.einsum('ij,ij->i', X, log_X) 293 | X_log_X = nx.reshape(X_log_X, (1, X_log_X.shape[0])) 294 | D = X_log_X.T - nx.dot(X, log_Y.T) 295 | return nx.to_numpy(D) 296 | 297 | 298 | def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init=None, loss_fun='square_loss', alpha=0.5, armijo=False, 299 | log=False, numItermax=200, use_gpu=False, **kwargs): 300 | p, q = ot.utils.list_to_array(p, q) 301 | 302 | p0, q0, C10, C20, M0 = p, q, C1, C2, M 303 | nx = ot.backend.get_backend(p0, q0, C10, C20, M0) 304 | 305 | constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun) 306 | 307 | if G_init is None: 308 | G0 = p[:, None] * q[None, :] 309 | else: 310 | G0 = (1 / nx.sum(G_init)) * G_init 311 | if use_gpu: 312 | G0 = G0.cuda() 313 | 314 | def f(G): 315 | return ot.gromov.gwloss(constC, hC1, hC2, G) 316 | 317 | def df(G): 318 | return ot.gromov.gwggrad(constC, hC1, hC2, G) 319 | 320 | if log: 321 | res, log = ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, 322 | log=True, **kwargs) 323 | 324 | fgw_dist = log['loss'][-1] 325 | 326 | log['fgw_dist'] = fgw_dist 327 | log['u'] = log['u'] 328 | log['v'] = log['v'] 329 | return res, log 330 | 331 | else: 332 | return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, 333 | **kwargs) 334 | 335 | 336 | def pairwise_align_paste( 337 | sliceA, 338 | sliceB, 339 | alpha=0.1, 340 | dissimilarity='kl', 341 | use_rep=None, 342 | G_init=None, 343 | a_distribution=None, 344 | b_distribution=None, 345 | norm=False, 346 | numItermax=200, 347 | backend=ot.backend.NumpyBackend(), 348 | use_gpu=False, 349 | return_obj=False, 350 | verbose=False, 351 | gpu_verbose=False, 352 | coor_key="spatial", 353 | **kwargs): 354 | if use_gpu: 355 | try: 356 | import torch 357 | except: 358 | print("We currently only have gpu support for Pytorch. Please install torch.") 359 | 360 | if isinstance(backend, ot.backend.TorchBackend): 361 | if torch.cuda.is_available(): 362 | if gpu_verbose: 363 | print("gpu is available, using gpu.") 364 | else: 365 | if gpu_verbose: 366 | print("gpu is not available, resorting to torch cpu.") 367 | use_gpu = False 368 | else: 369 | print( 370 | "We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). Reverting to selected backend cpu.") 371 | use_gpu = False 372 | else: 373 | if gpu_verbose: 374 | print("Using selected backend cpu. If you want to use gpu, set use_gpu = True.") 375 | 376 | # subset for common genes 377 | common_genes = intersect(sliceA.var.index, sliceB.var.index) 378 | sliceA = sliceA[:, common_genes] 379 | sliceB = sliceB[:, common_genes] 380 | 381 | # Backend 382 | nx = backend 383 | 384 | # Calculate spatial distances 385 | coordinatesA = sliceA.obsm[coor_key].copy() 386 | coordinatesA = nx.from_numpy(coordinatesA) 387 | coordinatesB = sliceB.obsm[coor_key].copy() 388 | coordinatesB = nx.from_numpy(coordinatesB) 389 | 390 | if isinstance(nx, ot.backend.TorchBackend): 391 | coordinatesA = coordinatesA.float() 392 | coordinatesB = coordinatesB.float() 393 | D_A = ot.dist(coordinatesA, coordinatesA, metric='euclidean') 394 | D_B = ot.dist(coordinatesB, coordinatesB, metric='euclidean') 395 | 396 | if isinstance(nx, ot.backend.TorchBackend) and use_gpu: 397 | D_A = D_A.cuda() 398 | D_B = D_B.cuda() 399 | 400 | # Calculate expression dissimilarity 401 | A_X, B_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, use_rep))), nx.from_numpy( 402 | to_dense_array(extract_data_matrix(sliceB, use_rep))) 403 | 404 | if isinstance(nx, ot.backend.TorchBackend) and use_gpu: 405 | A_X = A_X.cuda() 406 | B_X = B_X.cuda() 407 | 408 | if dissimilarity.lower() == 'euclidean' or dissimilarity.lower() == 'euc': 409 | M = ot.dist(A_X, B_X) 410 | else: 411 | s_A = A_X + 0.01 412 | s_B = B_X + 0.01 413 | M = kl_divergence_backend(s_A, s_B) 414 | M = nx.from_numpy(M) 415 | 416 | if isinstance(nx, ot.backend.TorchBackend) and use_gpu: 417 | M = M.cuda() 418 | 419 | # init distributions 420 | if a_distribution is None: 421 | a = nx.ones((sliceA.shape[0],)) / sliceA.shape[0] 422 | else: 423 | a = nx.from_numpy(a_distribution) 424 | 425 | if b_distribution is None: 426 | b = nx.ones((sliceB.shape[0],)) / sliceB.shape[0] 427 | else: 428 | b = nx.from_numpy(b_distribution) 429 | 430 | if isinstance(nx, ot.backend.TorchBackend) and use_gpu: 431 | a = a.cuda() 432 | b = b.cuda() 433 | 434 | if norm: 435 | D_A /= nx.min(D_A[D_A > 0]) 436 | D_B /= nx.min(D_B[D_B > 0]) 437 | 438 | # Run OT 439 | if G_init is not None: 440 | G_init = nx.from_numpy(G_init) 441 | if isinstance(nx, ot.backend.TorchBackend): 442 | G_init = G_init.float() 443 | if use_gpu: 444 | G_init.cuda() 445 | pi, logw = my_fused_gromov_wasserstein(M, D_A, D_B, a, b, G_init=G_init, loss_fun='square_loss', alpha=alpha, 446 | log=True, numItermax=numItermax, verbose=verbose, use_gpu=use_gpu) 447 | pi = nx.to_numpy(pi) 448 | obj = nx.to_numpy(logw['fgw_dist']) 449 | if isinstance(backend, ot.backend.TorchBackend) and use_gpu: 450 | torch.cuda.empty_cache() 451 | 452 | if return_obj: 453 | return pi, obj 454 | return pi 455 | 456 | 457 | def generalized_procrustes_analysis(X, Y, pi, output_params=False, matrix=False): 458 | """ 459 | Finds and applies optimal rotation between spatial coordinates of two layers (may also do a reflection). 460 | Args: 461 | X: np array of spatial coordinates (ex: sliceA.obs['spatial']) 462 | Y: np array of spatial coordinates (ex: sliceB.obs['spatial']) 463 | pi: mapping between the two layers output by PASTE 464 | output_params: Boolean of whether to return rotation angle and translations along with spatial coordiantes. 465 | matrix: Boolean of whether to return the rotation as a matrix or an angle 466 | Returns: 467 | Aligned spatial coordinates of X, Y, rotation angle, translation of X, translation of Y 468 | """ 469 | assert X.shape[1] == 2 and Y.shape[1] == 2 470 | 471 | tX = pi.sum(axis=1).dot(X) 472 | tY = pi.sum(axis=0).dot(Y) 473 | X = X - tX 474 | Y = Y - tY 475 | H = Y.T.dot(pi.T.dot(X)) 476 | U, S, Vt = np.linalg.svd(H) 477 | R = Vt.T.dot(U.T) 478 | Y = R.dot(Y.T).T 479 | if output_params and not matrix: 480 | M = np.array([[0, -1], [1, 0]]) 481 | theta = np.arctan(np.trace(M.dot(H)) / np.trace(H)) 482 | return X, Y, theta, tX, tY 483 | elif output_params and matrix: 484 | return X, Y, R, tX, tY 485 | else: 486 | return X, Y 487 | 488 | 489 | # %% 490 | import anndata as ad 491 | import scipy.sparse 492 | from sklearn.metrics import pairwise_distances 493 | from sklearn.neighbors import NearestNeighbors 494 | import sklearn.neighbors 495 | import scipy.sparse as sp 496 | import numpy as np 497 | import pandas as pd 498 | 499 | 500 | def preprocess(adata_st_list, # list of spatial transcriptomics (ST) anndata objects 501 | section_ids=None, 502 | three_dim_coor=None, # if not None, use existing 3d coordinates in shape [# of total spots, 3] 503 | coor_key="spatial_aligned", # "spatial_aligned" by default 504 | rad_cutoff=None, # cutoff radius of spots for building graph 505 | rad_coef=1.5, # if rad_cutoff=None, rad_cutoff is the minimum distance between spots multiplies rad_coef 506 | k_cutoff=12, 507 | slice_dist_micron=None, # pairwise distances in micrometer for reconstructing z-axis 508 | c2c_dist=100, # center to center distance between nearest spots in micrometer 509 | model='KNN', 510 | ): 511 | assert (model in ['Radius', 'KNN']) 512 | adata_st = ad.concat(adata_st_list, label="slice_name", keys=section_ids) 513 | adata_st.obs['Ground Truth'] = adata_st.obs['Ground Truth'].astype('category') 514 | adata_st.obs["batch_name"] = adata_st.obs["slice_name"].astype('category') 515 | 516 | # Build a graph for spots across multiple slices 517 | print("Start building a graph...") 518 | 519 | # Build 3D coordinates 520 | if three_dim_coor is None: 521 | # The first adata in adata_list is used as a reference for computing cutoff radius of spots 522 | adata_st_ref = adata_st_list[0].copy() 523 | loc_ref = np.array(adata_st_ref.obsm[coor_key]) 524 | pair_dist_ref = pairwise_distances(loc_ref) 525 | min_dist_ref = np.sort(np.unique(pair_dist_ref), axis=None)[1] 526 | 527 | if rad_cutoff is None: 528 | # The radius is computed base on the attribute "adata.obsm['spatial']" 529 | rad_cutoff = min_dist_ref * rad_coef 530 | print("Radius for graph connection is %.4f." % rad_cutoff) 531 | 532 | # Use the attribute "adata.obsm['spatial_aligned']" to build a global graph 533 | if slice_dist_micron is None: 534 | loc_xy = pd.DataFrame(adata_st.obsm['spatial_aligned']).values 535 | loc_z = np.zeros(adata_st.shape[0]) 536 | loc = np.concatenate([loc_xy, loc_z.reshape(-1, 1)], axis=1) 537 | else: 538 | if len(slice_dist_micron) != (len(adata_st_list) - 1): 539 | raise ValueError("The length of 'slice_dist_micron' should be the number of adatas - 1 !") 540 | else: 541 | loc_xy = pd.DataFrame(adata_st.obsm['spatial_aligned']).values 542 | loc_z = np.zeros(adata_st.shape[0]) 543 | dim = 0 544 | for i in range(len(slice_dist_micron)): 545 | dim += adata_st_list[i].shape[0] 546 | loc_z[dim:] += slice_dist_micron[i] * (min_dist_ref / c2c_dist) 547 | loc = np.concatenate([loc_xy, loc_z.reshape(-1, 1)], axis=1) 548 | 549 | # If 3D coordinates already exists 550 | else: 551 | if rad_cutoff is None: 552 | raise ValueError("Please specify 'rad_cutoff' for finding 3D neighbors!") 553 | loc = three_dim_coor 554 | 555 | loc = pd.DataFrame(loc) 556 | loc.index = adata_st.obs.index 557 | loc.columns = ['x', 'y', 'z'] 558 | 559 | if model == 'Radius': 560 | nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(loc) 561 | distances, indices = nbrs.radius_neighbors(loc, return_distance=True) 562 | KNN_list = [] 563 | for it in range(indices.shape[0]): 564 | KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it]))) 565 | 566 | if model == 'KNN': 567 | nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(loc) 568 | distances, indices = nbrs.kneighbors(loc) 569 | KNN_list = [] 570 | for it in range(indices.shape[0]): 571 | KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) 572 | 573 | KNN_df = pd.concat(KNN_list) 574 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 575 | 576 | Spatial_Net = KNN_df.copy() 577 | Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0,] 578 | id_cell_trans = dict(zip(range(loc.shape[0]), np.array(loc.index), )) 579 | Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans) 580 | Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans) 581 | 582 | print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata_st.n_obs)) 583 | print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata_st.n_obs)) 584 | 585 | adata_st.uns['Spatial_Net'] = Spatial_Net 586 | return adata_st 587 | 588 | 589 | # %% 590 | import shutil 591 | import warnings 592 | 593 | warnings.filterwarnings('ignore') 594 | 595 | import os 596 | 597 | import torch 598 | import SpaMask as stm 599 | from pathlib import Path 600 | import scanpy as sc 601 | from sklearn import metrics 602 | # %% 603 | from sklearn.decomposition import PCA 604 | 605 | 606 | def load_adata(section_ids, k_cutoff, rad_cutoff, model, n_top_genes): 607 | Batch_list = [] 608 | for section_id in section_ids: 609 | print(section_id) 610 | input_dir = os.path.join('D:\\project\\datasets\\DLPFC\\', section_id) 611 | adata = sc.read_visium(path=input_dir, count_file=section_id + '_filtered_feature_bc_matrix.h5', 612 | load_images=True) 613 | adata.var_names_make_unique(join="++") 614 | 615 | # read the annotation 616 | Ann_df = pd.read_csv(os.path.join(input_dir, section_id + '_truth.txt'), sep='\t', header=None, index_col=0) 617 | Ann_df.columns = ['Ground Truth'] 618 | Ann_df[Ann_df.isna()] = "unknown" 619 | adata.obs['Ground Truth'] = Ann_df.loc[adata.obs_names, 'Ground Truth'].astype('category') 620 | 621 | # make spot name unique 622 | adata.obs_names = [x + '_' + section_id for x in adata.obs_names] 623 | 624 | # stm.Cal_Spatial_Net(adata, rad_cutoff=150) 625 | adata.var_names_make_unique() 626 | adata.layers['count'] = adata.X.toarray() 627 | sc.pp.filter_genes(adata, min_cells=50) 628 | sc.pp.filter_genes(adata, min_counts=10) 629 | sc.pp.normalize_total(adata, target_sum=1e6) 630 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=n_top_genes) 631 | adata = adata[:, adata.var['highly_variable'] == True] 632 | sc.pp.scale(adata) 633 | adata = adata[:, adata.var['highly_variable']] 634 | Batch_list.append(adata) 635 | 636 | # %% 637 | Batch_list = align_spots(Batch_list, method='icp', plot=False) 638 | # %% 639 | adata_st = preprocess(Batch_list, section_ids=section_ids, k_cutoff=k_cutoff, rad_cutoff=rad_cutoff, model=model, 640 | slice_dist_micron=[10, 10, 10]) 641 | adata_X = PCA(n_components=200, random_state=42).fit_transform(adata_st.X) 642 | adata_st.obsm['feat'] = adata_X 643 | 644 | return adata_st 645 | 646 | 647 | def train_one(args, adata, section_ids, num_clusters, ARI_list): 648 | # %% 649 | net = stm.spaMask.SPAMASK(adata, 650 | tissue_name='Donor', 651 | num_clusters=num_clusters, 652 | device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), 653 | learning_rate=args.learning_rate, 654 | weight_decay=args.weight_decay, 655 | max_epoch=args.max_epoch, 656 | gradient_clipping=args.gradient_clipping, 657 | feat_mask_rate=args.feat_mask_rate, 658 | edge_drop_rate=args.edge_drop_rate, 659 | hidden_dim=args.hidden_dim, 660 | latent_dim=args.latent_dim, 661 | bn=args.bn, 662 | att_dropout_rate=args.att_dropout_rate, 663 | fc_dropout_rate=args.fc_dropout_rate, 664 | use_token=args.use_token, 665 | rep_loss=args.rep_loss, 666 | rel_loss=args.rel_loss, 667 | alpha=args.alpha, 668 | lam=args.lam, 669 | random_seed=args.seed, 670 | nps=args.nps) 671 | net.train() 672 | # %% 673 | method = "kmeans" 674 | net.process(method=method) 675 | 676 | adata = net.get_adata() 677 | sub_adata = adata[~pd.isnull(adata.obs['Ground Truth'])] 678 | ARI = metrics.adjusted_rand_score(sub_adata.obs['Ground Truth'], sub_adata.obs[method]) 679 | print(f"total ARI:{ARI}") 680 | for name in section_ids: 681 | sub_adata_tmp = sub_adata[sub_adata.obs['batch_name'] == name] 682 | ARI = metrics.adjusted_rand_score(sub_adata_tmp.obs['Ground Truth'], sub_adata_tmp.obs[method]) 683 | print(f"{name} ARI:{round(ARI, 4)}") 684 | ARI_list.append(ARI) 685 | 686 | return ARI_list, adata 687 | 688 | 689 | args = stm.utils.build_args() 690 | args.hidden_dim, args.latent_dim = 512, 256 691 | args.max_epoch = 1000 692 | args.lam = 2 693 | args.feat_mask_rate = 0.5 694 | args.edge_drop_rate = 0.2 695 | args.top_genes = 5000 696 | args.rad_cutoff = 200 697 | args.k_cutoff = 21 698 | args.model = 'KNN' 699 | 700 | slices_list = ['151673', '151674', '151675', '151676'] 701 | num_clusters = 7 702 | 703 | adata = load_adata(slices_list, k_cutoff=args.k_cutoff, rad_cutoff=args.rad_cutoff, model=args.model, 704 | n_top_genes=args.top_genes) 705 | 706 | 707 | ARI_list = [] 708 | ARI_list, adata = train_one(args, adata, slices_list, num_clusters, ARI_list) 709 | 710 | ARI = np.median(ARI_list) 711 | 712 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpaMask: Dual Masking Graph Autoencoder with Contrastive Learning for Spatial Transcriptomics 2 | ## 🔥 Introduction 3 | Understanding the spatial locations of cell within tissues is crucial for unraveling the organization of cellular diversity. Recent advancements in spatial resolved transcriptomics (SRT) have enabled the analysis of gene expression while preserving the spatial context within tissues. Spatial domain characterization is a critical first step in SRT data analysis, providing the foundation for subsequent analyses and insights into biological implications. Graph neural networks (GNNs) have emerged as a common tool for addressing this challenge due to the structural nature of SRT data. However, current graph-based deep learning approaches often overlook the instability caused by the high sparsity of SRT data. **Masking mechanisms**, as an effective self-supervised learning strategy, can enhance the robustness of these models. To this end, we propose **SpaMask, dual masking graph autoencoder with contrastive learning for SRT analysis**. Unlike previous GNNs, SpaMask masks a portion of spot nodes and spot-to-spot edges to enhance its performance and robustness. SpaMask combines **Masked Graph Autoencoders (MGAE) and Masked Graph Contrastive Learning (MGCL)** modules, with MGAE using node masking to leverage spatial neighbors for improved clustering accuracy, while MGCL applies edge masking to create a contrastive loss framework that tightens embeddings of adjacent nodes based on spatial proximity and feature similarity. We conducted a comprehensive evaluation of SpaMask on **eight datasets from five different platforms**. Compared to existing methods, SpaMask achieves superior clustering accuracy and effective batch correction. 4 | 5 | ![SpaMask.jpg](SpaMask.jpg) 6 | 7 | ## 🌐 Data 8 | - All public datasets used in this paper are available at [Zenodo](https://zenodo.org/records/14062665) 9 | 10 | ## 🔬 Setup 11 | - `pip install -r requirement.txt` 12 | 13 | ## 🚀 Get Started 14 | We provided codes for reproducing the experiments of the paper, and comprehensive tutorials for using SpaMask. 15 | - Please see `TutorialDonor.ipynb`. 16 | 17 | 18 | ## 🔥Citing 19 |

The corresponding BiBTeX citation are given below:

20 |
21 | @article{min2025spamask,
22 |   title={SpaMask: Dual masking graph autoencoder with contrastive learning for spatial transcriptomics},
23 |   author={Min, Wenwen and Fang, Donghai and Chen, Jinyu and Zhang, Shihua},
24 |   journal={PLOS Computational Biology},
25 |   volume={21},
26 |   number={4},
27 |   pages={e1012881},
28 |   year={2025},
29 |   publisher={Public Library of Science San Francisco, CA USA}
30 | }
31 | 
32 | 33 | ## Article link 34 | 35 | - [https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012881](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012881) 36 | -------------------------------------------------------------------------------- /SpaMask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask.jpg -------------------------------------------------------------------------------- /SpaMask/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import load_feat, Cal_Spatial_Net, Transfer_pytorch_Data 2 | from .utils import fix_seed, Stats_Spatial_Net, mclust_R, save_args_to_file 3 | from .model import stMask_model 4 | from .spaMask import SPAMASK 5 | 6 | __all__ = [ 7 | "load_feat", 8 | "Cal_Spatial_Net", 9 | "Transfer_pytorch_Data", 10 | "fix_seed", 11 | "Stats_Spatial_Net", 12 | "mclust_R", 13 | "save_args_to_file", 14 | ] 15 | -------------------------------------------------------------------------------- /SpaMask/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /SpaMask/__pycache__/model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/model.cpython-311.pyc -------------------------------------------------------------------------------- /SpaMask/__pycache__/preprocess.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/preprocess.cpython-311.pyc -------------------------------------------------------------------------------- /SpaMask/__pycache__/spaMask.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/spaMask.cpython-311.pyc -------------------------------------------------------------------------------- /SpaMask/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /SpaMask/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | import torch.nn.functional as F 4 | import torch 5 | from torch import nn 6 | from torch_geometric.nn import ( 7 | TransformerConv, 8 | LayerNorm, 9 | Linear, 10 | GCNConv, 11 | SAGEConv, 12 | GATConv, 13 | GINConv, 14 | GATv2Conv, 15 | global_add_pool, 16 | global_mean_pool, 17 | global_max_pool 18 | ) 19 | 20 | try: 21 | import torch_cluster # noqa 22 | 23 | random_walk = torch.ops.torch_cluster.random_walk 24 | except ImportError: 25 | random_walk = None 26 | from torch_geometric.utils.num_nodes import maybe_num_nodes 27 | from torch_geometric.utils import to_undirected, sort_edge_index 28 | from torch_geometric.utils import add_self_loops, negative_sampling, degree 29 | 30 | def create_activation(name): 31 | if name == "relu": 32 | return nn.ReLU() 33 | elif name == "gelu": 34 | return nn.GELU() 35 | elif name == "prelu": 36 | return nn.PReLU() 37 | elif name is None: 38 | return nn.Identity() 39 | elif name == "elu": 40 | return nn.ELU() 41 | else: 42 | raise NotImplementedError(f"{name} is not implemented.") 43 | 44 | 45 | 46 | class Encoder(nn.Module): 47 | def __init__(self, input_dim, hidden_dim, latent_dim, bn=True, dropout_rate=.1, act="prelu", bias=True): 48 | super().__init__() 49 | bn = nn.BatchNorm1d if bn else nn.Identity 50 | # self.conv1 = GCNConv(in_channels=input_dim, out_channels=hidden_dim, heads=1, dropout=dropout_rate, concat=False, bias=bias) 51 | # self.bn1 = bn(hidden_dim * 1) 52 | # self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=latent_dim, heads=1, dropout=dropout_rate, concat=False, bias=bias) 53 | # self.bn2 = bn(latent_dim * 1) 54 | self.conv1 = GCNConv(in_channels=input_dim, out_channels=hidden_dim) 55 | self.bn1 = bn(hidden_dim * 1) 56 | self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=latent_dim) 57 | self.bn2 = bn(latent_dim * 1) 58 | self.activation = create_activation(act) 59 | 60 | def forward(self, x, edge_index): 61 | h = self.activation(self.bn2(self.conv2(self.activation(self.bn1(self.conv1(x, edge_index))), edge_index))) 62 | return h 63 | 64 | class FeatureDecoder(nn.Module): 65 | def __init__(self, latent_dim, output_dim, dropout_rate=.1, act="prelu", bias=True): 66 | super().__init__() 67 | # self.conv1 = GCNConv(in_channels=latent_dim, out_channels=output_dim, heads=1, dropout=dropout_rate, concat=False, bias=bias) 68 | self.conv1 = GCNConv(in_channels=latent_dim, out_channels=output_dim) 69 | self.activation = create_activation(act) 70 | 71 | def forward(self, x, edge_index): 72 | h = self.conv1(x, edge_index) 73 | return h 74 | 75 | class TopologyDecoder(nn.Module): 76 | def __init__(self, input_dim, latent_dim, output_dim=1, dropout_rate=0.5, act="relu"): 77 | super().__init__() 78 | self.fc1 = Linear(in_channels=input_dim, out_channels=latent_dim) 79 | self.fc1.reset_parameters() 80 | self.fc2 = Linear(in_channels=latent_dim, out_channels=output_dim) 81 | self.fc2.reset_parameters() 82 | self.d_drop = nn.Dropout(dropout_rate) 83 | self.activation = create_activation(act) 84 | 85 | def forward(self, x, edge_index): 86 | h = x[edge_index[0]] * x[edge_index[1]] 87 | h = self.fc2(self.activation(self.fc1(self.d_drop(h)))) 88 | return h 89 | 90 | class stMask_model(nn.Module): 91 | def __init__(self, features_dims, bn=False, att_dropout_rate=.2, fc_dropout_rate=.5, use_token=True, alpha=2, edge_drop_rate=0.3, feat_mask_rate=0.3, rep_loss="cse",rel_loss="ce"): 92 | super().__init__() 93 | [input_dim, hidden_dim, latent_dim, output_dim] = features_dims 94 | self.encoder = Encoder(input_dim, hidden_dim, latent_dim, bn=bn, dropout_rate=att_dropout_rate, act="prelu", bias=True) 95 | 96 | self.use_token = use_token 97 | if self.use_token: 98 | self.enc_mask_token = nn.Parameter(torch.zeros(1, input_dim)) 99 | self.encoder_to_decoder = nn.Linear(latent_dim, latent_dim, bias=False) 100 | nn.init.xavier_uniform_(self.encoder_to_decoder.weight) 101 | self.feat_deocder = FeatureDecoder(latent_dim, output_dim, dropout_rate=att_dropout_rate, act="prelu", bias=True) 102 | self.topo_decoder = TopologyDecoder(latent_dim, 2*latent_dim, 1, fc_dropout_rate) 103 | 104 | 105 | self.feat_loss = self.setup_loss_fn(rep_loss, alpha) 106 | self.edge_loss = self.setup_loss_fn(rel_loss) 107 | 108 | self.edge_drop_rate = edge_drop_rate 109 | self.feat_mask_rate = feat_mask_rate 110 | 111 | def forward(self, data): 112 | x = data.x 113 | edge_index = data.edge_index 114 | num_nodes = data.num_nodes 115 | 116 | use_mask_x, mask_nodes = self.mask_feature(x, self.feat_mask_rate) 117 | remaining_edges, masked_edges = self.dropout_edge(edge_index, self.edge_drop_rate) 118 | 119 | rep_x = self.encoder(use_mask_x, edge_index) 120 | rep_e = self.encoder(x, remaining_edges) 121 | 122 | # remasking feats 123 | rec_x = self.encoder_to_decoder(rep_x) 124 | rec_x[mask_nodes] = 0 125 | rec_x = self.feat_deocder(rec_x, edge_index) 126 | feat_loss = self.feat_loss(x[mask_nodes], rec_x[mask_nodes]) 127 | 128 | # sampling neg edges 129 | aug_edge_index, _ = add_self_loops(edge_index) 130 | neg_edges = self.random_negative_sampler( 131 | aug_edge_index, 132 | num_nodes=num_nodes, 133 | num_neg_samples=masked_edges.view(2, -1).size(1), 134 | ).view_as(masked_edges) 135 | 136 | pos_edge = self.topo_decoder(rep_e, masked_edges) 137 | neg_edge = self.topo_decoder(rep_e, neg_edges) 138 | topo_loss = self.ce_loss(pos_edge, neg_edge) 139 | 140 | return feat_loss, topo_loss 141 | 142 | def setup_loss_fn(self, loss_fn, alpha_l=2): 143 | if loss_fn == "mse": 144 | criterion = nn.MSELoss() 145 | elif loss_fn == "cse": 146 | criterion = partial(self.sce_loss, alpha=alpha_l) 147 | elif loss_fn == "ce": 148 | criterion = partial(self.ce_loss) 149 | else: 150 | raise NotImplementedError 151 | return criterion 152 | 153 | def sce_loss(self, x, y, alpha=3): 154 | x = F.normalize(x, p=2, dim=-1) 155 | y = F.normalize(y, p=2, dim=-1) 156 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 157 | loss = loss.mean() 158 | return loss 159 | 160 | def ce_loss(self, pos_out, neg_out): 161 | pos_loss = F.binary_cross_entropy(pos_out.sigmoid(), torch.ones_like(pos_out)) 162 | neg_loss = F.binary_cross_entropy(neg_out.sigmoid(), torch.zeros_like(neg_out)) 163 | return pos_loss + neg_loss 164 | 165 | def mask_feature(self, x, feat_mask_rate=0.3): 166 | num_nodes = x.shape[0] 167 | perm = torch.randperm(num_nodes, device=x.device) 168 | # random masking 169 | num_mask_nodes = int(feat_mask_rate * num_nodes) 170 | mask_nodes = perm[: num_mask_nodes] 171 | keep_nodes = perm[num_mask_nodes:] 172 | out_x = x.clone() 173 | if self.use_token: 174 | out_x[mask_nodes] += self.enc_mask_token 175 | else: 176 | out_x[mask_nodes] = 0.0 177 | return out_x, mask_nodes #, keep_nodes 178 | 179 | def mask_features(self, x, feat_mask_rate=0.3): 180 | mask_nodes = torch.empty((x.size(0),), dtype=torch.float32, device=x.device).uniform_(0, 1) < feat_mask_rate 181 | mask_x = x.clone() 182 | mask_x[mask_nodes] = 0 183 | if self.use_token: 184 | mask_x[mask_nodes] += self.enc_mask_token 185 | return mask_x, mask_nodes 186 | 187 | def dropout_edge(self, my_edge_index, edge_drop_rate=0.3): 188 | edge_index = my_edge_index.clone() 189 | p = torch.zeros(edge_index.shape[1]).to(edge_index.device) + 1 - edge_drop_rate 190 | stay = torch.bernoulli(p).to(torch.bool) 191 | mask = ~stay 192 | remaining_edges, masked_edges = edge_index[:, stay], edge_index[:, mask] 193 | remaining_edges = to_undirected(remaining_edges) 194 | return remaining_edges, masked_edges 195 | 196 | def random_negative_sampler(self, edge_index, num_nodes, num_neg_samples): 197 | neg_edges = torch.randint(0, num_nodes, size=(2, num_neg_samples)).to(edge_index) 198 | return neg_edges 199 | 200 | @torch.no_grad() 201 | def embed(self, data): 202 | x = data.x 203 | edge_index = data.edge_index 204 | h = self.encoder(x, edge_index) 205 | return h 206 | 207 | @torch.no_grad() 208 | def recon(self, data): 209 | x = data.x 210 | edge_index = data.edge_index 211 | h = self.encoder(x, edge_index) 212 | rec = self.encoder_to_decoder(h) 213 | rec = self.feat_deocder(rec, edge_index) 214 | return h, rec 215 | 216 | 217 | @torch.no_grad() 218 | def embed_masking(self, data): 219 | x = data.x 220 | edge_index = data.edge_index 221 | use_mask_x, mask_nodes = self.mask_feature(x, self.feat_mask_rate) 222 | remaining_edges, masked_edges = self.dropout_edge(edge_index, self.edge_drop_rate) 223 | rep_x = self.encoder(use_mask_x, edge_index) 224 | rep_e = self.encoder(x, remaining_edges) 225 | return rep_x, rep_e 226 | -------------------------------------------------------------------------------- /SpaMask/preprocess.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | import os 4 | import torch 5 | import random 6 | import numpy as np 7 | import scanpy as sc 8 | import pandas as pd 9 | import scipy.sparse as sp 10 | from sklearn.neighbors import NearestNeighbors 11 | import sklearn.neighbors 12 | from torch_geometric.data import Data 13 | from pathlib import Path 14 | 15 | def prefilter_genes(adata,min_counts=None,max_counts=None,min_cells=10,max_cells=None): 16 | if min_cells is None and min_counts is None and max_cells is None and max_counts is None: 17 | raise ValueError('Provide one of min_counts, min_genes, max_counts or max_genes.') 18 | id_tmp=np.asarray([True]*adata.shape[1],dtype=bool) 19 | id_tmp=np.logical_and(id_tmp,sc.pp.filter_genes(adata.X,min_cells=min_cells)[0]) if min_cells is not None else id_tmp 20 | id_tmp=np.logical_and(id_tmp,sc.pp.filter_genes(adata.X,max_cells=max_cells)[0]) if max_cells is not None else id_tmp 21 | id_tmp=np.logical_and(id_tmp,sc.pp.filter_genes(adata.X,min_counts=min_counts)[0]) if min_counts is not None else id_tmp 22 | id_tmp=np.logical_and(id_tmp,sc.pp.filter_genes(adata.X,max_counts=max_counts)[0]) if max_counts is not None else id_tmp 23 | adata._inplace_subset_var(id_tmp) 24 | 25 | def load_feat(adata, top_genes=3000, model="pca"): 26 | assert (model in ['pca', 'hvg', 'other']) 27 | if model == "pca": 28 | adata.var_names_make_unique() 29 | if isinstance(adata.X, np.ndarray): 30 | adata.layers['count'] = adata.X 31 | else: 32 | adata.layers['count'] = adata.X.toarray() 33 | sc.pp.filter_genes(adata, min_cells=50) 34 | sc.pp.filter_genes(adata, min_counts=10) 35 | sc.pp.normalize_total(adata, target_sum=1e6) 36 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=top_genes) 37 | adata = adata[:, adata.var['highly_variable'] == True] 38 | sc.pp.scale(adata) 39 | from sklearn.decomposition import PCA # sklearn PCA is used because PCA in scanpy is not stable. 40 | adata_X = PCA(n_components=200, random_state=42).fit_transform(adata.X) 41 | adata.obsm['feat'] = adata_X 42 | print(f"adata.obsm['feat'].shape:{adata.obsm['feat'].shape}") 43 | 44 | elif model == "hvg": 45 | # Expression data preprocessing 46 | adata.var_names_make_unique() 47 | prefilter_genes(adata, min_cells=3) # avoiding all genes are zeros 48 | # Normalization 49 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=top_genes) 50 | sc.pp.normalize_total(adata, target_sum=1e4) 51 | sc.pp.log1p(adata) 52 | adata.X = sp.csr_matrix(adata.X) 53 | adata_Vars = adata[:, adata.var['highly_variable']] 54 | # sc.pp.scale(adata) 55 | adata.obsm['feat'] = adata_Vars.X[:, ] 56 | 57 | elif model == "other": 58 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=top_genes) 59 | sc.pp.normalize_total(adata, target_sum=1e4) 60 | sc.pp.log1p(adata) 61 | sc.pp.scale(adata, zero_center=False, max_value=10) 62 | adata.X = sp.csr_matrix(adata.X) 63 | adata_Vars = adata[:, adata.var['highly_variable']] 64 | adata.obsm['feat'] = adata_Vars.X[:, ] 65 | 66 | return adata 67 | 68 | 69 | 70 | def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True): 71 | assert (model in ['Radius', 'KNN']) 72 | if verbose: 73 | print('------Calculating spatial graph...') 74 | coor = pd.DataFrame(adata.obsm['spatial']) 75 | coor.index = adata.obs.index 76 | coor.columns = ['imagerow', 'imagecol'] 77 | 78 | if model == 'Radius': 79 | nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor) 80 | distances, indices = nbrs.radius_neighbors(coor, return_distance=True) 81 | KNN_list = [] 82 | for it in range(indices.shape[0]): 83 | KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it]))) 84 | 85 | if model == 'KNN': 86 | nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor) 87 | distances, indices = nbrs.kneighbors(coor) 88 | KNN_list = [] 89 | for it in range(indices.shape[0]): 90 | KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) 91 | 92 | KNN_df = pd.concat(KNN_list) 93 | KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] 94 | 95 | Spatial_Net = KNN_df.copy() 96 | Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0,] 97 | id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), )) 98 | Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans) 99 | Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans) 100 | if verbose: 101 | print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs)) 102 | print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs)) 103 | 104 | adata.uns['Spatial_Net'] = Spatial_Net 105 | # ######### 106 | # X = pd.DataFrame(adata.X.toarray()[:, ], index=adata.obs.index, columns=adata.var.index) 107 | # cells = np.array(X.index) 108 | # cells_id_tran = dict(zip(cells, range(cells.shape[0]))) 109 | # if 'Spatial_Net' not in adata.uns.keys(): 110 | # raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!") 111 | # 112 | # Spatial_Net = adata.uns['Spatial_Net'] 113 | # G_df = Spatial_Net.copy() 114 | # G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) 115 | # G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) 116 | # G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs)) 117 | # G = G + sp.eye(G.shape[0]) # self-loop 118 | # adata.uns['adj'] = G 119 | return adata 120 | 121 | 122 | def Transfer_pytorch_Data(adata, weightless=True): 123 | if weightless: 124 | return weightless_undirected_graph(adata) 125 | else: 126 | return powered_undirected_graph(adata) 127 | 128 | 129 | def weightless_undirected_graph(adata): 130 | G_df = adata.uns['Spatial_Net'].copy() 131 | cells = np.array(adata.obs_names) 132 | cells_id_tran = dict(zip(cells, range(cells.shape[0]))) 133 | G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) 134 | G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) 135 | G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs)) 136 | G = G + sp.eye(G.shape[0]) 137 | edgeList = np.nonzero(G) 138 | if type(adata.obsm['feat']) == np.ndarray: 139 | data = Data(edge_index=torch.LongTensor(np.array( 140 | [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.obsm['feat'])) # .todense() 141 | else: 142 | data = Data(edge_index=torch.LongTensor(np.array( 143 | [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.obsm['feat'].todense())) # .todense() 144 | return data 145 | 146 | def powered_undirected_graph(adata): 147 | pass 148 | 149 | if __name__ == '__main__': 150 | # sample name 151 | sample_name = '151676' 152 | n_clusters = 5 if sample_name in ['151669', '151670', '151671', '151672'] else 7 153 | # path 154 | data_root = Path("D:\\project\\datasets\\DLPFC\\") 155 | count_file = sample_name + "_filtered_feature_bc_matrix.h5" 156 | adata = sc.read_visium(data_root / sample_name, count_file=count_file) 157 | adata = load_feat(adata, model="pca") 158 | print(adata.obsm['feat'].shape) -------------------------------------------------------------------------------- /SpaMask/spaMask.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from .preprocess import load_feat, Cal_Spatial_Net, Transfer_pytorch_Data 8 | from .utils import fix_seed, Stats_Spatial_Net, mclust_R, Kmeans_cluster 9 | from .model import stMask_model 10 | 11 | class SPAMASK: 12 | def __init__(self, 13 | adata, 14 | tissue_name="BRCA", 15 | num_clusters=20, 16 | top_genes=4000, 17 | genes_model="hvg", # 'pca', 'hvg' 18 | rad_cutoff=300, 19 | k_cutoff=12, 20 | graph_model='Radius', # 'Radius', 'KNN' 21 | device=torch.device('cpu'), 22 | learning_rate=0.001, 23 | weight_decay=2e-4, 24 | max_epoch=1500, 25 | gradient_clipping=5, 26 | feat_mask_rate=0.3, 27 | edge_drop_rate=0.6, 28 | hidden_dim=512, 29 | latent_dim=256, 30 | bn=True, 31 | att_dropout_rate=0.2, 32 | fc_dropout_rate=0.5, 33 | use_token=True, 34 | alpha=2, 35 | rep_loss="cse", 36 | rel_loss="ce", 37 | lam=1.4, 38 | random_seed=2024, 39 | nps=30, 40 | ): 41 | 42 | self.__adata = adata.copy() 43 | self.__tissue_name = tissue_name 44 | self.__top_genes = top_genes 45 | self.__genes_model = genes_model 46 | self.__rad_cutoff = rad_cutoff 47 | self.__k_cutoff = k_cutoff 48 | self.__graph_model = graph_model 49 | self.__device = device 50 | self.__learning_rate = learning_rate 51 | self.__weight_decay = weight_decay 52 | self.__max_epoch = max_epoch 53 | self.__gradient_clipping = gradient_clipping 54 | self.__feat_mask_rate = feat_mask_rate 55 | self.__edge_drop_rate = edge_drop_rate 56 | self.__hidden_dim = hidden_dim 57 | self.__latent_dim = latent_dim 58 | self.__bn = bn 59 | self.__att_dropout_rate = att_dropout_rate 60 | self.__fc_dropout_rate = fc_dropout_rate 61 | self.__use_token = use_token 62 | self.__alpha = alpha 63 | self.__rep_loss = rep_loss 64 | self.__rel_loss = rel_loss 65 | self.__lam = lam 66 | self.__nps = nps 67 | 68 | 69 | fix_seed(random_seed) 70 | 71 | if 'highly_variable' not in self.__adata.var.keys() and 'feat' not in adata.obsm.keys(): 72 | self.__adata = load_feat(self.__adata, top_genes=self.__top_genes, model=self.__genes_model) 73 | 74 | if 'Spatial_Net' not in self.__adata.uns.keys(): 75 | Cal_Spatial_Net(self.__adata, rad_cutoff=self.__rad_cutoff, k_cutoff=self.__k_cutoff, model=self.__graph_model) 76 | 77 | self.num_clusters = num_clusters # 5 if self.tissue_name in ['151669', '151670', '151671', '151672'] else 7 78 | print(self.__adata.obsm['feat'].shape) 79 | 80 | def train(self): 81 | data = Transfer_pytorch_Data(self.__adata).to(self.__device) 82 | output_dim = input_dim = data.x.shape[-1] 83 | features_dims = [input_dim, self.__hidden_dim, self.__latent_dim, output_dim] 84 | self.model = stMask_model(features_dims, bn=self.__bn, 85 | att_dropout_rate=self.__att_dropout_rate,fc_dropout_rate=self.__fc_dropout_rate, 86 | use_token=self.__use_token, alpha=self.__alpha, 87 | edge_drop_rate=self.__edge_drop_rate, feat_mask_rate=self.__feat_mask_rate, 88 | rep_loss=self.__rep_loss,rel_loss=self.__rel_loss).to(self.__device) 89 | 90 | self.optimizer = torch.optim.Adam(self.model.parameters(), self.__learning_rate, weight_decay=self.__weight_decay) 91 | 92 | y_pred_last = None 93 | epoch_iter = tqdm(range(self.__max_epoch)) 94 | for epoch in epoch_iter: 95 | self.model.train() 96 | self.optimizer.zero_grad() 97 | 98 | feat_loss, topo_loss = self.model(data) 99 | 100 | loss = feat_loss + topo_loss * self.__lam 101 | loss.backward() 102 | gradient_clipping = self.__gradient_clipping 103 | if gradient_clipping > 1: 104 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping) 105 | self.optimizer.step() 106 | epoch_iter.set_description(f"Dataset_Name:{self.__tissue_name}, Ep {epoch}: train loss:{loss.item():.4f}") 107 | 108 | def process(self, method="kmeans"): 109 | data = Transfer_pytorch_Data(self.__adata).to(self.__device) 110 | with torch.no_grad(): 111 | self.model.eval() 112 | h, z = self.model.recon(data=data) 113 | rep = h.to('cpu').detach().numpy() 114 | rec = z.to('cpu').detach().numpy() 115 | if rep.shape[-1] > 64: 116 | from sklearn.decomposition import PCA 117 | pca = PCA(n_components=self.__nps) 118 | rep = pca.fit_transform(rep) 119 | self.__adata.obsm["eval_pred"] = rep 120 | self.__adata.obsm["eval_recon"] = rec 121 | 122 | if method == "mclust": 123 | mclust_R(self.__adata, num_cluster=self.num_clusters, used_obsm="eval_pred", key_added_pred=method) 124 | elif method == "kmeans": 125 | Kmeans_cluster(self.__adata, num_cluster=self.num_clusters, used_obsm="eval_pred", key_added_pred=method) 126 | 127 | 128 | 129 | def show_Stats_Spatial_Net(self): 130 | Stats_Spatial_Net(self.__adata) 131 | 132 | def save_model_dict(self, save_model_file): 133 | torch.save({'state_dict': self.model.state_dict()}, save_model_file) 134 | print('Saving model to %s' % save_model_file) 135 | 136 | def save_model(self, save_model_file): 137 | torch.save(self.model, save_model_file) 138 | print('Saving model to %s' % save_model_file) 139 | 140 | def load_model_dict(self, save_model_file): 141 | saved_state_dict = torch.load(save_model_file) 142 | self.model.load_state_dict(saved_state_dict['state_dict']) 143 | print('Loading model from %s' % save_model_file) 144 | 145 | def load_model(self, save_model_file): 146 | self.model = torch.load(save_model_file) 147 | print('Loading model from %s' % save_model_file) 148 | 149 | 150 | def get_adata(self): 151 | return self.__adata 152 | 153 | -------------------------------------------------------------------------------- /SpaMask/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | from sklearn.cluster import KMeans 8 | from scipy.spatial import distance 9 | from sklearn.metrics import adjusted_mutual_info_score,normalized_mutual_info_score,completeness_score,fowlkes_mallows_score, homogeneity_score 10 | 11 | from sklearn.metrics.cluster import v_measure_score, adjusted_rand_score 12 | 13 | # the location of R (used for the mclust clustering) 14 | os.environ['R_HOME'] = 'D:/software/R/R-4.3.2' 15 | os.environ['R_USER'] = 'D:/software/anaconda/anaconda3/envs/pt20cu118/Lib/site-packages/rpy2' 16 | 17 | 18 | def fix_seed(seed=2024): 19 | import random 20 | import torch 21 | from torch.backends import cudnn 22 | 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | cudnn.deterministic = True 30 | cudnn.benchmark = False 31 | 32 | os.environ['PYTHONHASHSEED'] = str(seed) 33 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 34 | 35 | 36 | def Stats_Spatial_Net(adata): 37 | Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0] 38 | Mean_edge = Num_edge / adata.shape[0] 39 | plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1'])) 40 | plot_df = plot_df / adata.shape[0] 41 | fig, ax = plt.subplots(figsize=[3, 2]) 42 | plt.ylabel('Percentage') 43 | plt.xlabel('') 44 | plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge) 45 | ax.bar(plot_df.index, plot_df) 46 | 47 | def Kmeans_cluster(adata, num_cluster, used_obsm='model_pred', key_added_pred="kmeans", random_seed=2024): 48 | np.random.seed(random_seed) 49 | cluster_model = KMeans(n_clusters=num_cluster, init='k-means++', n_init=100, max_iter=1000, tol=1e-6) 50 | cluster_labels = cluster_model.fit_predict(adata.obsm[used_obsm]) 51 | adata.obs[key_added_pred] = cluster_labels 52 | adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('int') 53 | adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('category') 54 | return adata 55 | 56 | def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='model_pred', key_added_pred="mclust", random_seed=2024): 57 | np.random.seed(random_seed) 58 | import rpy2.robjects as robjects 59 | robjects.r.library("mclust") 60 | 61 | import rpy2.robjects.numpy2ri 62 | rpy2.robjects.numpy2ri.activate() 63 | r_random_seed = robjects.r['set.seed'] 64 | r_random_seed(random_seed) 65 | rmclust = robjects.r['Mclust'] 66 | 67 | res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames) 68 | mclust_res = np.array(res[-2]) 69 | 70 | adata.obs[key_added_pred] = mclust_res 71 | adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('int') 72 | adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('category') 73 | return adata 74 | 75 | 76 | def build_args(): 77 | import argparse 78 | parser = argparse.ArgumentParser(description="stMask") 79 | parser.add_argument("--model_name", type=str, default="SpaMask") 80 | parser.add_argument("--seed", type=int, default=2023) 81 | parser.add_argument("--tissue_name", type=str, default="151507") 82 | 83 | parser.add_argument("--top_genes", type=int, default=2000) 84 | parser.add_argument("--genes_model", type=str, default="pca") 85 | parser.add_argument("--rad_cutoff", type=int, default=200) 86 | parser.add_argument("--k_cutoff", type=int, default=12) 87 | parser.add_argument("--graph_model", type=str, default="KNN") 88 | 89 | parser.add_argument('--nps', type=int, default=30) 90 | parser.add_argument('--gradient_clipping', type=float, default=5.) 91 | parser.add_argument("--need_refine", action='store_true', default=False) 92 | 93 | # 各模型的训练设置 94 | parser.add_argument("--learning_rate", type=float, default=0.001) 95 | parser.add_argument("--weight_decay", type=float, default=2e-4) 96 | parser.add_argument("--max_epoch", type=int, default=500, help="number of training epochs") 97 | 98 | # ST params 99 | parser.add_argument("--edge_drop_rate", type=float, default=0.4) 100 | parser.add_argument("--feat_mask_rate", type=float, default=0.3) 101 | 102 | parser.add_argument("--hidden_dim", type=int, default=512) 103 | parser.add_argument("--latent_dim", type=int, default=256) 104 | 105 | parser.add_argument('--bn', action='store_true', default=True) 106 | parser.add_argument("--att_dropout_rate", type=float, default=.2) 107 | parser.add_argument("--fc_dropout_rate", type=float, default=.5) 108 | parser.add_argument("--use_token", action='store_true', default=True) 109 | parser.add_argument("--rep_loss", type=str, default="cse") 110 | parser.add_argument("--rel_loss", type=str, default="ce") 111 | parser.add_argument("--alpha", type=float, default=2.0) 112 | 113 | parser.add_argument("--lam", type=float, default=2) 114 | args = parser.parse_args(args=[]) 115 | return args 116 | 117 | 118 | def measureClusteringTrueLabel(labels_true, labels_pred): 119 | ari = adjusted_rand_score(labels_true, labels_pred) 120 | ami = adjusted_mutual_info_score(labels_true, labels_pred) 121 | nmi = normalized_mutual_info_score(labels_true, labels_pred) 122 | cs = completeness_score(labels_true, labels_pred) 123 | fms = fowlkes_mallows_score(labels_true, labels_pred) 124 | vms = v_measure_score(labels_true, labels_pred) 125 | hs = homogeneity_score(labels_true, labels_pred) 126 | return ari, ami, nmi, cs, fms, vms, hs 127 | 128 | 129 | def refine(adata, pred, shape="hexagon"): 130 | sample_id = adata.obs.index.tolist() 131 | dis = distance.cdist(adata.obsm['spatial'], adata.obsm['spatial'], 'euclidean') 132 | refined_pred = [] 133 | pred = pd.DataFrame({"pred": pred}, index=sample_id) 134 | dis_df = pd.DataFrame(dis, index=sample_id, columns=sample_id) 135 | if shape == "hexagon": 136 | num_nbs = 6 137 | elif shape == "square": 138 | num_nbs = 4 139 | else: 140 | print("Shape not recongized, shape='hexagon' for Visium data, 'square' for ST data.") 141 | for i in range(len(sample_id)): 142 | index = sample_id[i] 143 | dis_tmp = dis_df.loc[index, :].sort_values() 144 | nbs = dis_tmp[0:num_nbs + 1] 145 | nbs_pred = pred.loc[nbs.index, "pred"] 146 | self_pred = pred.loc[index, "pred"] 147 | v_c = nbs_pred.value_counts() 148 | if (v_c.loc[self_pred] < num_nbs / 2) and (np.max(v_c) > num_nbs / 2): 149 | refined_pred.append(v_c.idxmax()) 150 | else: 151 | refined_pred.append(self_pred) 152 | return refined_pred 153 | 154 | 155 | def save_args_to_file(args, filename): 156 | with open(filename, 'w') as file: 157 | file.write('Parsed Arguments:\n') 158 | for arg, value in vars(args).items(): 159 | arg_info = f"{arg}: {value}\n" 160 | file.write(arg_info) -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | aiobotocore==2.5.4 2 | aiohttp==3.9.5 3 | aioitertools==0.11.0 4 | aiosignal==1.3.1 5 | alembic==1.13.1 6 | anndata==0.10.7 7 | annotated-types==0.7.0 8 | annoy==1.17.3 9 | antlr4-python3-runtime==4.9.3 10 | anyio==3.5.0 11 | argon2-cffi==21.3.0 12 | argon2-cffi-bindings==21.2.0 13 | array_api_compat==1.6 14 | asciitree==0.3.3 15 | asttokens==2.0.5 16 | attrs==22.1.0 17 | backcall==0.2.0 18 | bbknn==1.3.9 19 | beautifulsoup4==4.12.2 20 | bleach==4.1.0 21 | botocore==1.31.17 22 | certifi==2024.2.2 23 | cffi==1.15.1 24 | charset-normalizer==3.3.2 25 | click==8.1.7 26 | click-plugins==1.1.1 27 | cligj==0.7.2 28 | cloudpickle==3.0.0 29 | colorama==0.4.6 30 | colorcet==3.1.0 31 | colorlog==6.8.2 32 | comm==0.1.2 33 | contourpy==1.2.1 34 | cycler==0.12.1 35 | Cython==3.0.10 36 | dask==2024.5.0 37 | dask-expr==1.1.0 38 | dask-image==2023.8.1 39 | datashader==0.16.1 40 | debugpy==1.6.7 41 | decorator==5.1.1 42 | defusedxml==0.7.1 43 | Deprecated==1.2.14 44 | dgl==2.2.1 45 | distributed==2024.5.0 46 | docrep==0.3.2 47 | einops==0.7.0 48 | entrypoints==0.4 49 | executing==0.8.3 50 | faiss-cpu==1.8.0 51 | fasteners==0.19 52 | fastjsonschema==2.16.2 53 | fbpca==1.0 54 | filelock==3.9.0 55 | fiona==1.9.6 56 | fonttools==4.51.0 57 | frozenlist==1.4.1 58 | fsspec==2023.6.0 59 | geomloss==0.2.6 60 | geopandas==0.14.4 61 | geosketch==1.2 62 | greenlet==3.0.3 63 | gseapy==1.1.3 64 | h5py==3.11.0 65 | harmonypy==0.0.9 66 | hnswlib==0.8.0 67 | huggingface-hub==0.22.2 68 | idna==3.4 69 | igraph==0.11.4 70 | imageio==2.34.1 71 | importlib_metadata==7.1.0 72 | inflect==7.2.1 73 | intervaltree==3.1.0 74 | ipykernel==6.25.0 75 | ipython==8.15.0 76 | ipython-genutils==0.2.0 77 | jedi==0.18.1 78 | Jinja2==3.1.2 79 | jmespath==1.0.1 80 | joblib==1.4.0 81 | jsonschema==4.17.3 82 | jsonschema-specifications==2023.12.1 83 | jupyter_client==7.4.9 84 | jupyter_core==5.3.0 85 | jupyter-server==1.23.4 86 | jupyterlab-pygments==0.1.2 87 | kiwisolver==1.4.5 88 | latexcodec==3.0.0 89 | lazy_loader==0.4 90 | legacy-api-wrap==1.4 91 | leidenalg==0.10.2 92 | lightning-utilities==0.11.2 93 | llvmlite==0.42.0 94 | locket==1.0.0 95 | louvain==0.8.2 96 | lxml==4.9.3 97 | Mako==1.3.5 98 | markdown-it-py==3.0.0 99 | MarkupSafe==2.1.1 100 | matplotlib==3.8.4 101 | matplotlib-inline==0.1.6 102 | matplotlib-scalebar==0.8.1 103 | mdurl==0.1.2 104 | mistune==0.8.4 105 | more-itertools==10.2.0 106 | mpmath==1.3.0 107 | msgpack==1.0.8 108 | multidict==6.0.5 109 | multipledispatch==1.0.0 110 | multiscale_spatial_image==0.11.2 111 | natsort==8.4.0 112 | nbclassic==0.5.5 113 | nbclient==0.5.13 114 | nbconvert==6.5.4 115 | nbformat==5.9.2 116 | nest-asyncio==1.5.6 117 | networkx==3.2.1 118 | notebook==6.5.4 119 | notebook_shim==0.2.2 120 | numba==0.59.1 121 | numcodecs==0.12.1 122 | numpy==1.26.3 123 | ome-zarr==0.8.3 124 | omegaconf==2.3.0 125 | omnipath==1.0.8 126 | opencv-python==4.9.0.80 127 | optuna==3.6.1 128 | packaging==23.1 129 | pandas==2.2.2 130 | pandocfilters==1.5.0 131 | param==2.1.0 132 | parso==0.8.3 133 | partd==1.4.1 134 | patsy==0.5.6 135 | pickleshare==0.7.5 136 | pillow==10.2.0 137 | PIMS==0.6.1 138 | pip==23.2.1 139 | platformdirs==3.10.0 140 | POT==0.9.3 141 | prometheus-client==0.14.1 142 | prompt-toolkit==3.0.36 143 | protobuf==5.26.1 144 | psutil==5.9.0 145 | pure-eval==0.2.2 146 | pyarrow==15.0.2 147 | pybtex==0.24.0 148 | pycparser==2.21 149 | pyct==0.5.0 150 | pydantic==2.7.3 151 | pydantic_core==2.18.4 152 | pydot==2.0.0 153 | pygeos==0.14 154 | Pygments==2.15.1 155 | pynndescent==0.5.12 156 | pyparsing==3.1.2 157 | pyproj==3.6.1 158 | pyrsistent==0.18.0 159 | python-dateutil==2.8.2 160 | pytorch-lightning==2.2.2 161 | pytz==2024.1 162 | pywin32==305.1 163 | pywinpty==2.0.10 164 | PyYAML==6.0.1 165 | pyzmq==23.2.0 166 | ray==2.10.0 167 | referencing==0.34.0 168 | requests==2.31.0 169 | rich==13.7.1 170 | rpds-py==0.18.0 171 | rpy2==3.5.16 172 | s3fs==2023.6.0 173 | safetensors==0.4.3 174 | scanorama==1.7.4 175 | scanpy==1.10.1 176 | scib==1.1.5 177 | scikit-image==0.23.2 178 | scikit-learn==1.4.2 179 | scikit-misc==0.3.1 180 | scipy==1.13.0 181 | seaborn==0.13.2 182 | Send2Trash==1.8.0 183 | session-info==1.0.0 184 | setuptools==68.0.0 185 | shapely==2.0.4 186 | six==1.16.0 187 | slicerator==1.1.0 188 | sniffio==1.2.0 189 | sortedcontainers==2.4.0 190 | soupsieve==2.4 191 | spatial_image==0.3.0 192 | spatialdata==0.0.15 193 | SQLAlchemy==2.0.30 194 | squidpy==1.4.1 195 | stack-data==0.2.0 196 | statsmodels==0.14.2 197 | stdlib-list==0.10.0 198 | sympy==1.12 199 | taming-transformers==0.0.1 200 | tblib==3.0.0 201 | tensorboardX==2.6.2.2 202 | terminado==0.17.1 203 | texttable==1.7.0 204 | threadpoolctl==3.4.0 205 | tifffile==2024.5.3 206 | timm==0.9.16 207 | tinycss2==1.2.1 208 | toolz==0.12.1 209 | torch==2.2.2+cu118 210 | torch_cluster==1.6.3+pt22cu118 211 | torch-fidelity==0.3.0 212 | torch_geometric==2.5.2 213 | torch_scatter==2.1.2+pt22cu118 214 | torch_sparse==0.6.18+pt22cu118 215 | torch_spline_conv==1.2.2+pt22cu118 216 | torchaudio==2.2.2+cu118 217 | torchdata==0.7.1 218 | torchmetrics==1.3.2 219 | torchvision==0.17.2+cu118 220 | tornado==6.3.2 221 | tqdm==4.66.2 222 | traitlets==5.7.1 223 | typeguard==4.2.1 224 | typing_extensions==4.11.0 225 | tzdata==2024.1 226 | tzlocal==5.2 227 | umap-learn==0.5.6 228 | urllib3==1.26.18 229 | validators==0.28.1 230 | wcwidth==0.2.5 231 | webencodings==0.5.1 232 | websocket-client==0.58.0 233 | wheel==0.41.2 234 | wrapt==1.16.0 235 | xarray==2023.12.0 236 | xarray-dataclasses==1.7.0 237 | xarray-datatree==0.0.14 238 | xarray-schema==0.0.3 239 | xarray-spatial==0.4.0 240 | yarl==1.9.4 241 | zarr==2.17.2 242 | zict==3.0.0 243 | zipp==3.18.1 244 | --------------------------------------------------------------------------------