├── DLPFC_MULTISLICES.py
├── README.md
├── SpaMask.jpg
├── SpaMask
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-311.pyc
│ ├── model.cpython-311.pyc
│ ├── preprocess.cpython-311.pyc
│ ├── spaMask.cpython-311.pyc
│ └── utils.cpython-311.pyc
├── model.py
├── preprocess.py
├── spaMask.py
└── utils.py
├── TutorialDonor.ipynb
└── requirement.txt
/DLPFC_MULTISLICES.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import ot
3 | import scipy.sparse
4 | import matplotlib
5 | import sklearn
6 | from matplotlib import cm
7 | import matplotlib.pyplot as plt
8 | import random
9 | import numpy as np
10 | from sklearn.neighbors import NearestNeighbors
11 |
12 | np.random.seed(42)
13 | random.seed(42)
14 |
15 |
def _plot_slices(adata_st_list, obsm_key, colors_list, title):
    """Scatter-plot the coordinates stored under ``obsm_key`` for every slice."""
    plt.figure(figsize=(5, 5))
    plt.title(title)
    for i, adata in enumerate(adata_st_list):
        plt.scatter(adata.obsm[obsm_key][:, 0],
                    adata.obsm[obsm_key][:, 1],
                    c=colors_list[i],
                    label="Slice %d spots" % i, s=5., alpha=0.5)
    ax = plt.gca()
    # Flip the y axis so the plot follows image orientation.
    ax.set_ylim(ax.get_ylim()[::-1])
    plt.xticks([])
    plt.yticks([])
    plt.legend(loc=(1.02, .2), ncol=(len(adata_st_list) // 13 + 1))
    plt.show()


def align_spots(adata_st_list_input,  # list of spatial transcriptomics datasets
                method="icp",  # "icp" or "paste"
                data_type="Visium",
                # a spot has six nearest neighborhoods if "Visium", four nearest neighborhoods otherwise
                coor_key="spatial",  # "spatial" for visium; key for the spatial coordinates used for alignment
                tol=0.01,  # parameter for "icp" method; tolerance level
                test_all_angles=False,  # parameter for "icp" method; whether to test multiple rotation angles or not
                plot=False,
                paste_alpha=0.1,
                paste_dissimilarity="kl"
                ):
    """Align the spatial coordinates of consecutive slices.

    The first slice in the list is the reference. Aligned coordinates are
    written to ``obsm["spatial_aligned"]`` of every slice.

    Args:
        adata_st_list_input: list of AnnData-like slices. Only the list is
            copied (shallow copy); the slice objects themselves are updated
            in place.
        method: ``"icp"`` or ``"paste"`` (case-insensitive).
        data_type: ``"Visium"`` selects the hexagonal-grid edge detection;
            any other value uses the generic grid rule.
        coor_key: ``obsm`` key holding the coordinates to align.
        tol: convergence tolerance forwarded to ``icp``.
        test_all_angles: for ICP, additionally try six initial rotations
            (multiples of 60 degrees) and keep the one with the lowest mean
            nearest-neighbour distance.
        plot: if true, show scatter plots before and after alignment.
        paste_alpha: ``alpha`` forwarded to ``pairwise_align_paste``.
        paste_dissimilarity: ``dissimilarity`` forwarded to
            ``pairwise_align_paste``.

    Returns:
        The (shallow-copied) list of slices.

    Raises:
        ValueError: if ``method`` is neither ``"icp"`` nor ``"paste"``.
    """
    # Fail early on an unknown method instead of silently returning slices
    # without "spatial_aligned" (which only surfaced later as a KeyError).
    method_name = method.lower() if isinstance(method, str) else method
    if method_name not in ("icp", "paste"):
        raise ValueError("Unknown alignment method %r; expected 'icp' or 'paste'." % (method,))

    # The first adata in the list is used as a reference for alignment
    adata_st_list = adata_st_list_input.copy()
    n_slices = len(adata_st_list)

    if plot:
        # Choose colors. plt.get_cmap is used because matplotlib.cm.get_cmap
        # is no longer available in recent Matplotlib releases.
        cmap = plt.get_cmap('rainbow', n_slices)
        colors_list = [matplotlib.colors.rgb2hex(cmap(i)) for i in range(n_slices)]

        # Plot spots before alignment
        _plot_slices(adata_st_list, coor_key, colors_list, "Before alignment")

    if method_name == "icp":
        print("Using the Iterative Closest Point algorithm for alignemnt.")
        # Detect edges
        print("Detecting edges...")
        point_cloud_list = []
        for adata in adata_st_list:
            # Use in-tissue spots only
            if 'in_tissue' in adata.obs.columns:
                adata = adata[adata.obs['in_tissue'] == 1]
            loc_x = np.array(adata.obs.loc[:, ["array_row"]])
            loc_y = np.array(adata.obs.loc[:, ["array_col"]])
            if data_type == "Visium":
                # Rows are stretched by sqrt(3) before measuring distances.
                loc_x = loc_x * np.sqrt(3)
            loc = np.concatenate((loc_x, loc_y), axis=1)
            pairwise_loc_distsq = np.sum((loc.reshape([1, -1, 2]) - loc.reshape([-1, 1, 2])) ** 2, axis=2)
            if data_type == "Visium":
                n_neighbors = np.sum(pairwise_loc_distsq < 5, axis=1) - 1
                edge = ((n_neighbors > 1) & (n_neighbors < 5)).astype(np.float32)
            else:
                # Second-smallest unique squared distance (the smallest is 0, the self distance).
                min_distsq = np.sort(np.unique(pairwise_loc_distsq), axis=None)[1]
                n_neighbors = np.sum(pairwise_loc_distsq < (min_distsq * 3), axis=1) - 1
                edge = ((n_neighbors > 1) & (n_neighbors < 7)).astype(np.float32)
            point_cloud_list.append(adata.obsm[coor_key][edge == 1].copy())

        # Align edges
        print("Aligning edges...")
        trans_list = []
        adata_st_list[0].obsm["spatial_aligned"] = adata_st_list[0].obsm[coor_key].copy()
        # Calculate pairwise transformation matrices
        for i in range(n_slices - 1):
            if test_all_angles:
                loss_best = None
                for angle in [0., np.pi * 1 / 3, np.pi * 2 / 3, np.pi, np.pi * 4 / 3, np.pi * 5 / 3]:
                    R = np.array([[np.cos(angle), np.sin(angle), 0],
                                  [-np.sin(angle), np.cos(angle), 0],
                                  [0, 0, 1]]).T
                    T, distances, _ = icp(transform(point_cloud_list[i + 1], R), point_cloud_list[i], tolerance=tol)
                    loss = np.mean(distances)
                    if loss_best is None or loss < loss_best:
                        loss_best, R_best, T_best = loss, R, T
                # Compose the ICP result with the initial rotation it started from.
                T = T_best @ R_best
            else:
                T, _, _ = icp(point_cloud_list[i + 1], point_cloud_list[i], tolerance=tol)
            trans_list.append(T)
        # Transform: slice i+1 is mapped through every pairwise transform back to slice 0.
        for i in range(n_slices - 1):
            point_cloud_align = adata_st_list[i + 1].obsm[coor_key].copy()
            for T in trans_list[:(i + 1)][::-1]:
                point_cloud_align = transform(point_cloud_align, T)
            adata_st_list[i + 1].obsm["spatial_aligned"] = point_cloud_align

    else:  # method_name == "paste"
        print("Using PASTE algorithm for alignemnt.")
        # Align spots
        print("Aligning spots...")
        pis = []
        # Calculate pairwise transformation matrices
        for i in range(n_slices - 1):
            pi = pairwise_align_paste(adata_st_list[i], adata_st_list[i + 1], coor_key=coor_key,
                                      alpha=paste_alpha, dissimilarity=paste_dissimilarity)
            pis.append(pi)
        # Transform
        S1, S2 = generalized_procrustes_analysis(adata_st_list[0].obsm[coor_key],
                                                 adata_st_list[1].obsm[coor_key],
                                                 pis[0])
        adata_st_list[0].obsm["spatial_aligned"] = S1
        adata_st_list[1].obsm["spatial_aligned"] = S2
        for i in range(1, n_slices - 1):
            S1, S2 = generalized_procrustes_analysis(adata_st_list[i].obsm["spatial_aligned"],
                                                     adata_st_list[i + 1].obsm[coor_key],
                                                     pis[i])
            adata_st_list[i + 1].obsm["spatial_aligned"] = S2

    if plot:
        _plot_slices(adata_st_list, "spatial_aligned", colors_list, "After alignment")

    return adata_st_list
154 |
155 |
156 | # Functions for the Iterative Closest Point algorithm
157 | # Credit to https://github.com/ClayFlannigan/icp
def best_fit_transform(A, B):
    """Compute the least-squares rigid transform mapping points A onto points B.

    A and B must have the same shape (N points by m dimensions).

    Returns:
        T: (m+1)x(m+1) homogeneous transformation matrix
        R: mxm rotation matrix
        t: length-m translation vector
    """
    assert A.shape == B.shape

    ndim = A.shape[1]

    # Centre both point sets on their centroids.
    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    src_centered = A - centroid_A
    dst_centered = B - centroid_B

    # Rotation from the SVD of the cross-covariance matrix.
    U, _, Vt = np.linalg.svd(src_centered.T @ dst_centered)
    R = Vt.T @ U.T

    # A negative determinant means a reflection: flip the last row of Vt and rebuild R.
    if np.linalg.det(R) < 0:
        Vt[ndim - 1, :] *= -1
        R = Vt.T @ U.T

    t = centroid_B.T - R @ centroid_A.T

    # Assemble the homogeneous matrix.
    T = np.identity(ndim + 1)
    T[:ndim, :ndim] = R
    T[:ndim, ndim] = t

    return T, R, t
189 |
190 |
def nearest_neighbor(src, dst):
    '''
    For every point in src, look up its closest (Euclidean) point in dst.
    Input:
        src: Nxm array of query points
        dst: Nxm array of reference points
    Output:
        distances: flattened array with the distance to the nearest dst point
        indices: flattened array with the dst index of that nearest point
    '''
    model = NearestNeighbors(n_neighbors=1).fit(dst)
    nn_dist, nn_idx = model.kneighbors(src, return_distance=True)
    return nn_dist.ravel(), nn_idx.ravel()
206 |
207 |
def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001):
    '''
    Iterative Closest Point: estimate the transform that maps points A onto points B.
    Input:
        A: Nxm numpy array of source points
        B: Nxm numpy array of destination points
        init_pose: optional (m+1)x(m+1) homogeneous transform applied to A before iterating
        max_iterations: upper bound on the number of iterations
        tolerance: stop once the mean error changes by less than this between iterations
    Output:
        T: final homogeneous transformation mapping A onto B
        distances: nearest-neighbour distances from the last iteration
        i: index of the last iteration executed
    '''
    ndim = A.shape[1]

    # Homogeneous copies of both clouds (points stored column-wise); the inputs stay untouched.
    src = np.ones((ndim + 1, A.shape[0]))
    dst = np.ones((ndim + 1, B.shape[0]))
    src[:ndim, :] = np.copy(A.T)
    dst[:ndim, :] = np.copy(B.T)

    if init_pose is not None:
        src = init_pose @ src

    prev_error = 0

    for i in range(max_iterations):
        current = src[:ndim, :].T

        # Match every current source point with its closest destination point.
        distances, indices = nearest_neighbor(current, dst[:ndim, :].T)

        # Best rigid step for those matches, applied to the running source cloud.
        step, _, _ = best_fit_transform(current, dst[:ndim, indices].T)
        src = step @ src

        mean_error = np.mean(distances)
        converged = np.abs(prev_error - mean_error) < tolerance
        if converged:
            break
        prev_error = mean_error

    # Overall transform from the original A to its final position.
    T, _, _ = best_fit_transform(A, src[:ndim, :].T)

    return T, distances, i
258 |
259 |
def transform(point_cloud, T):
    """Apply the 3x3 homogeneous transform T to an Nx2 point cloud and return the Nx2 result."""
    n_points = point_cloud.shape[0]
    homogeneous = np.ones((n_points, 3))
    homogeneous[:, :2] = np.copy(point_cloud)
    moved = (T @ homogeneous.T).T
    return moved[:, :2]
265 |
266 |
267 | # Functions for the PASTE algorithm
268 | # Credit to https://github.com/raphael-group/paste
269 |
def to_dense_array(X):
    """Convert a SciPy sparse container to a dense NumPy array; wrap anything else with ``np.array``.

    ``scipy.sparse.issparse`` is used instead of the deprecated
    ``scipy.sparse.csr.spmatrix`` attribute path.
    """
    return X.toarray() if scipy.sparse.issparse(X) else np.array(X)


def extract_data_matrix(adata, rep):
    """Return ``adata.X`` when ``rep`` is None, otherwise the representation ``adata.obsm[rep]``."""
    return adata.X if rep is None else adata.obsm[rep]
275 |
276 |
def intersect(lst1, lst2):
    """Return the elements of lst1 that also occur in lst2, keeping lst1's order and duplicates."""
    members = set(lst2)
    return list(filter(lambda item: item in members, lst1))
281 |
282 |
def kl_divergence_backend(X, Y):
    """Pairwise KL divergence between the row-normalised rows of X and Y.

    Each row of X and Y is divided by its sum, then entry (i, j) of the
    result is sum_k X[i, k] * (log X[i, k] - log Y[j, k]). Computation uses
    the POT backend matching the inputs; the result is returned as a NumPy
    array.
    """
    assert X.shape[1] == Y.shape[1], "X and Y do not have the same number of features."

    nx = ot.backend.get_backend(X, Y)

    # Normalise every row so it sums to one.
    P = X / nx.sum(X, axis=1, keepdims=True)
    Q = Y / nx.sum(Y, axis=1, keepdims=True)

    # Self term sum_k P[i, k] * log P[i, k], stored as a row vector.
    self_term = nx.einsum('ij,ij->i', P, nx.log(P))
    self_term = nx.reshape(self_term, (1, self_term.shape[0]))

    # Cross term sum_k P[i, k] * log Q[j, k] for every (i, j) pair.
    cross_term = nx.dot(P, nx.log(Q).T)

    return nx.to_numpy(self_term.T - cross_term)
296 |
297 |
def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init=None, loss_fun='square_loss', alpha=0.5, armijo=False,
                                log=False, numItermax=200, use_gpu=False, **kwargs):
    """Solve the fused Gromov-Wasserstein problem with an optional initial coupling.

    Args:
        M: cross-domain cost matrix (scaled by ``1 - alpha`` before solving).
        C1, C2: intra-domain structure matrices.
        p, q: marginal distributions of the two domains.
        G_init: optional initial coupling; it is rescaled to sum to one.
            When None, the outer product of ``p`` and ``q`` is used.
        loss_fun: loss name forwarded to ``ot.gromov.init_matrix``.
        alpha: weight passed to the conditional-gradient solver as the
            regularisation coefficient.
        armijo: forwarded to the solver.
        log: if true, also return the solver log with an extra
            ``'fgw_dist'`` entry (the last recorded loss).
        numItermax: maximum number of solver iterations. Previously this
            argument was accepted but never forwarded; it is now passed on.
        use_gpu: if true, the initial coupling is moved with ``.cuda()``.
        **kwargs: extra keyword arguments forwarded to the solver.

    Returns:
        The coupling, or ``(coupling, log_dict)`` when ``log`` is true.
    """
    p, q = ot.utils.list_to_array(p, q)
    nx = ot.backend.get_backend(p, q, C1, C2, M)

    constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun)

    if G_init is None:
        G0 = p[:, None] * q[None, :]
    else:
        G0 = (1 / nx.sum(G_init)) * G_init
        if use_gpu:
            G0 = G0.cuda()

    def f(G):
        return ot.gromov.gwloss(constC, hC1, hC2, G)

    def df(G):
        return ot.gromov.gwggrad(constC, hC1, hC2, G)

    if log:
        res, log = ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC,
                                log=True, numItermax=numItermax, **kwargs)
        log['fgw_dist'] = log['loss'][-1]
        return res, log

    return ot.gromov.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC,
                        numItermax=numItermax, **kwargs)
334 |
335 |
def pairwise_align_paste(
        sliceA,
        sliceB,
        alpha=0.1,
        dissimilarity='kl',
        use_rep=None,
        G_init=None,
        a_distribution=None,
        b_distribution=None,
        norm=False,
        numItermax=200,
        backend=ot.backend.NumpyBackend(),
        use_gpu=False,
        return_obj=False,
        verbose=False,
        gpu_verbose=False,
        coor_key="spatial",
        **kwargs):
    """Compute the PASTE probabilistic mapping between the spots of two slices.

    Args:
        sliceA, sliceB: AnnData-like slices; they are subset to their common genes.
        alpha: weight forwarded to ``my_fused_gromov_wasserstein``.
        dissimilarity: ``'euclidean'``/``'euc'`` for ``ot.dist`` on the
            expression matrices; any other value uses the KL divergence on
            the matrices shifted by 0.01.
        use_rep: ``obsm`` key of the representation to compare, or None for ``X``.
        G_init: optional initial mapping (NumPy array).
        a_distribution, b_distribution: optional spot weights; uniform when None.
        norm: if true, divide each spatial distance matrix by its smallest positive entry.
        numItermax: iteration limit forwarded to the solver.
        backend: POT backend instance.
        use_gpu: request GPU execution (only honoured with a Torch backend and available CUDA).
        return_obj: if true, also return the objective value.
        verbose: forwarded to the solver.
        gpu_verbose: print messages about the GPU/CPU decision.
        coor_key: ``obsm`` key holding spatial coordinates.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        ``pi`` as a NumPy array, or ``(pi, obj)`` when ``return_obj`` is true.
    """
    if use_gpu:
        try:
            import torch
        except ImportError:  # only a missing torch is expected here; do not swallow other errors
            print("We currently only have gpu support for Pytorch. Please install torch.")

        if isinstance(backend, ot.backend.TorchBackend):
            if torch.cuda.is_available():
                if gpu_verbose:
                    print("gpu is available, using gpu.")
            else:
                if gpu_verbose:
                    print("gpu is not available, resorting to torch cpu.")
                use_gpu = False
        else:
            print(
                "We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). Reverting to selected backend cpu.")
            use_gpu = False
    else:
        if gpu_verbose:
            print("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")

    on_torch = isinstance(backend, ot.backend.TorchBackend)
    on_torch_gpu = on_torch and use_gpu

    # subset for common genes
    common_genes = intersect(sliceA.var.index, sliceB.var.index)
    sliceA = sliceA[:, common_genes]
    sliceB = sliceB[:, common_genes]

    # Backend
    nx = backend

    # Calculate spatial distances
    coordinatesA = nx.from_numpy(sliceA.obsm[coor_key].copy())
    coordinatesB = nx.from_numpy(sliceB.obsm[coor_key].copy())

    if on_torch:
        coordinatesA = coordinatesA.float()
        coordinatesB = coordinatesB.float()
    D_A = ot.dist(coordinatesA, coordinatesA, metric='euclidean')
    D_B = ot.dist(coordinatesB, coordinatesB, metric='euclidean')

    if on_torch_gpu:
        D_A = D_A.cuda()
        D_B = D_B.cuda()

    # Calculate expression dissimilarity
    A_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, use_rep)))
    B_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, use_rep)))

    if on_torch_gpu:
        A_X = A_X.cuda()
        B_X = B_X.cuda()

    if dissimilarity.lower() == 'euclidean' or dissimilarity.lower() == 'euc':
        M = ot.dist(A_X, B_X)
    else:
        # Shift by a small constant before taking logs inside the KL divergence.
        s_A = A_X + 0.01
        s_B = B_X + 0.01
        M = kl_divergence_backend(s_A, s_B)
        M = nx.from_numpy(M)

    if on_torch_gpu:
        M = M.cuda()

    # init distributions
    if a_distribution is None:
        a = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
    else:
        a = nx.from_numpy(a_distribution)

    if b_distribution is None:
        b = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
    else:
        b = nx.from_numpy(b_distribution)

    if on_torch_gpu:
        a = a.cuda()
        b = b.cuda()

    if norm:
        D_A /= nx.min(D_A[D_A > 0])
        D_B /= nx.min(D_B[D_B > 0])

    # Run OT
    if G_init is not None:
        G_init = nx.from_numpy(G_init)
        if on_torch:
            G_init = G_init.float()
            if use_gpu:
                # .cuda() returns a new tensor; the result must be kept.
                G_init = G_init.cuda()
    pi, logw = my_fused_gromov_wasserstein(M, D_A, D_B, a, b, G_init=G_init, loss_fun='square_loss', alpha=alpha,
                                           log=True, numItermax=numItermax, verbose=verbose, use_gpu=use_gpu)
    pi = nx.to_numpy(pi)
    obj = nx.to_numpy(logw['fgw_dist'])
    if on_torch_gpu:
        torch.cuda.empty_cache()

    if return_obj:
        return pi, obj
    return pi
455 |
456 |
def generalized_procrustes_analysis(X, Y, pi, output_params=False, matrix=False):
    """
    Centre two sets of 2D coordinates using the mapping weights and rotate Y onto X.

    No determinant correction is applied to the SVD result, so the fitted
    matrix may include a reflection.

    Args:
        X: np array of spatial coordinates of the first layer, shape (n, 2)
        Y: np array of spatial coordinates of the second layer, shape (m, 2)
        pi: (n, m) mapping between the two layers output by PASTE
        output_params: whether to also return the rotation and both translations
        matrix: with output_params, return the rotation as a matrix instead of an angle
    Returns:
        Centred X and aligned Y; with output_params also the rotation
        (angle or matrix), the translation of X and the translation of Y.
    """
    assert X.shape[1] == 2 and Y.shape[1] == 2

    # Weighted centroids: the marginals of pi act as point weights.
    row_mass = pi.sum(axis=1)
    col_mass = pi.sum(axis=0)
    tX = row_mass.dot(X)
    tY = col_mass.dot(Y)
    X_centered = X - tX
    Y_centered = Y - tY

    # Rotation from the SVD of the pi-weighted cross-covariance.
    H = Y_centered.T.dot(pi.T.dot(X_centered))
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T.dot(U.T)
    Y_aligned = R.dot(Y_centered.T).T

    if not output_params:
        return X_centered, Y_aligned
    if matrix:
        return X_centered, Y_aligned, R, tX, tY

    quarter_turn = np.array([[0, -1], [1, 0]])
    theta = np.arctan(np.trace(quarter_turn.dot(H)) / np.trace(H))
    return X_centered, Y_aligned, theta, tX, tY
487 |
488 |
489 | # %%
490 | import anndata as ad
491 | import scipy.sparse
492 | from sklearn.metrics import pairwise_distances
493 | from sklearn.neighbors import NearestNeighbors
494 | import sklearn.neighbors
495 | import scipy.sparse as sp
496 | import numpy as np
497 | import pandas as pd
498 |
499 |
def preprocess(adata_st_list,  # list of spatial transcriptomics (ST) anndata objects
               section_ids=None,
               three_dim_coor=None,  # if not None, use existing 3d coordinates in shape [# of total spots, 3]
               coor_key="spatial_aligned",  # "spatial_aligned" by default
               rad_cutoff=None,  # cutoff radius of spots for building graph
               rad_coef=1.5,  # if rad_cutoff=None, rad_cutoff is the minimum distance between spots multiplies rad_coef
               k_cutoff=12,
               slice_dist_micron=None,  # pairwise distances in micrometer for reconstructing z-axis
               c2c_dist=100,  # center to center distance between nearest spots in micrometer
               model='KNN',
               ):
    """Concatenate slices and build a spot neighbour graph across them.

    Args:
        adata_st_list: list of AnnData slices; each needs a ``'Ground Truth'``
            column in ``obs`` and, unless ``three_dim_coor`` is given,
            coordinates under ``obsm[coor_key]``.
        section_ids: keys passed to ``anndata.concat``; stored in ``obs['slice_name']``.
        three_dim_coor: optional precomputed 3D coordinates, one row per
            concatenated spot. Requires ``rad_cutoff``.
        coor_key: ``obsm`` key of the 2D coordinates. It is now used both for
            the reference slice and for the concatenated data (previously the
            latter was hard-coded to ``'spatial_aligned'``).
        rad_cutoff: neighbour radius for ``model='Radius'``; derived from the
            reference slice when None.
        rad_coef: multiplier applied to the smallest non-zero spot distance
            of the first slice when deriving ``rad_cutoff``.
        k_cutoff: number of neighbours for ``model='KNN'``.
        slice_dist_micron: distances between consecutive slices used to build
            the z axis; all slices are put at z = 0 when None.
        c2c_dist: divisor converting ``slice_dist_micron`` into coordinate units.
        model: ``'Radius'`` or ``'KNN'``.

    Returns:
        The concatenated AnnData with the edge table in ``uns['Spatial_Net']``
        (columns ``Cell1``, ``Cell2``, ``Distance``; zero-distance pairs removed).

    Raises:
        ValueError: if ``slice_dist_micron`` has the wrong length, or if
            ``three_dim_coor`` is given without ``rad_cutoff``.
    """
    assert (model in ['Radius', 'KNN'])
    adata_st = ad.concat(adata_st_list, label="slice_name", keys=section_ids)
    adata_st.obs['Ground Truth'] = adata_st.obs['Ground Truth'].astype('category')
    adata_st.obs["batch_name"] = adata_st.obs["slice_name"].astype('category')

    # Build a graph for spots across multiple slices
    print("Start building a graph...")

    # Build 3D coordinates
    if three_dim_coor is None:
        # The first adata in adata_list is used as a reference for computing cutoff radius of spots
        adata_st_ref = adata_st_list[0].copy()
        loc_ref = np.array(adata_st_ref.obsm[coor_key])
        pair_dist_ref = pairwise_distances(loc_ref)
        # Second-smallest unique distance (the smallest is the zero self distance).
        min_dist_ref = np.sort(np.unique(pair_dist_ref), axis=None)[1]

        if rad_cutoff is None:
            rad_cutoff = min_dist_ref * rad_coef
        print("Radius for graph connection is %.4f." % rad_cutoff)

        if slice_dist_micron is not None and len(slice_dist_micron) != (len(adata_st_list) - 1):
            raise ValueError("The length of 'slice_dist_micron' should be the number of adatas - 1 !")

        # Use the aligned coordinates of the concatenated data to build a global graph
        loc_xy = pd.DataFrame(adata_st.obsm[coor_key]).values
        loc_z = np.zeros(adata_st.shape[0])
        if slice_dist_micron is not None:
            # Every slice after the i-th is shifted by the i-th gap, so offsets accumulate.
            dim = 0
            for i in range(len(slice_dist_micron)):
                dim += adata_st_list[i].shape[0]
                loc_z[dim:] += slice_dist_micron[i] * (min_dist_ref / c2c_dist)
        loc = np.concatenate([loc_xy, loc_z.reshape(-1, 1)], axis=1)

    # If 3D coordinates already exists
    else:
        if rad_cutoff is None:
            raise ValueError("Please specify 'rad_cutoff' for finding 3D neighbors!")
        loc = three_dim_coor

    loc = pd.DataFrame(loc)
    loc.index = adata_st.obs.index
    loc.columns = ['x', 'y', 'z']

    if model == 'Radius':
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(loc)
        distances, indices = nbrs.radius_neighbors(loc, return_distance=True)
        KNN_list = []
        for it in range(indices.shape[0]):
            KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it])))

    if model == 'KNN':
        # k_cutoff + 1 because each spot is returned as its own nearest neighbour.
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(loc)
        distances, indices = nbrs.kneighbors(loc)
        KNN_list = []
        for it in range(indices.shape[0]):
            KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :])))

    KNN_df = pd.concat(KNN_list)
    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']

    Spatial_Net = KNN_df.copy()
    # Drop zero-distance pairs (self matches).
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0,]
    # Translate positional indices back to spot names.
    id_cell_trans = dict(zip(range(loc.shape[0]), np.array(loc.index), ))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)

    print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata_st.n_obs))
    print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata_st.n_obs))

    adata_st.uns['Spatial_Net'] = Spatial_Net
    return adata_st
587 |
588 |
589 | # %%
590 | import shutil
591 | import warnings
592 |
593 | warnings.filterwarnings('ignore')
594 |
595 | import os
596 |
597 | import torch
598 | import SpaMask as stm
599 | from pathlib import Path
600 | import scanpy as sc
601 | from sklearn import metrics
602 | # %%
603 | from sklearn.decomposition import PCA
604 |
605 |
def load_adata(section_ids, k_cutoff, rad_cutoff, model, n_top_genes,
               data_root='D:\\project\\datasets\\DLPFC\\', slice_dist=10):
    """Load, preprocess, align and merge several DLPFC Visium sections.

    For each section the function reads the Visium data and the
    ``<section>_truth.txt`` annotation, filters genes, normalises, keeps the
    highly variable genes and scales them. The sections are then aligned with
    ICP, merged by ``preprocess`` and a 200-component PCA is stored in
    ``obsm['feat']``.

    Args:
        section_ids: section identifiers; each is a sub-directory of ``data_root``.
        k_cutoff, rad_cutoff, model: forwarded to ``preprocess``.
        n_top_genes: number of highly variable genes kept per section.
        data_root: directory containing one folder per section. Defaults to
            the path that was previously hard-coded.
        slice_dist: distance assigned between each pair of consecutive
            sections (passed as ``slice_dist_micron``). One gap is generated
            per consecutive pair, so any number of sections works; the
            previous hard-coded ``[10, 10, 10]`` only fitted four sections.

    Returns:
        The merged AnnData produced by ``preprocess`` with ``obsm['feat']`` added.
    """
    Batch_list = []
    for section_id in section_ids:
        print(section_id)
        input_dir = os.path.join(data_root, section_id)
        adata = sc.read_visium(path=input_dir, count_file=section_id + '_filtered_feature_bc_matrix.h5',
                               load_images=True)
        adata.var_names_make_unique(join="++")

        # read the annotation
        Ann_df = pd.read_csv(os.path.join(input_dir, section_id + '_truth.txt'), sep='\t', header=None, index_col=0)
        Ann_df.columns = ['Ground Truth']
        Ann_df[Ann_df.isna()] = "unknown"
        adata.obs['Ground Truth'] = Ann_df.loc[adata.obs_names, 'Ground Truth'].astype('category')

        # make spot name unique
        adata.obs_names = [x + '_' + section_id for x in adata.obs_names]

        adata.var_names_make_unique()
        # Keep the raw counts: the seurat_v3 HVG selection below reads this layer.
        adata.layers['count'] = adata.X.toarray()
        sc.pp.filter_genes(adata, min_cells=50)
        sc.pp.filter_genes(adata, min_counts=10)
        sc.pp.normalize_total(adata, target_sum=1e6)
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=n_top_genes)
        # Copy so scaling works on real data rather than on a view of the
        # full object; every remaining gene is highly variable, so no second
        # subsetting is needed afterwards.
        adata = adata[:, adata.var['highly_variable'] == True].copy()
        sc.pp.scale(adata)
        Batch_list.append(adata)

    Batch_list = align_spots(Batch_list, method='icp', plot=False)

    # One gap per pair of consecutive sections.
    slice_dist_micron = [slice_dist] * (len(section_ids) - 1)
    adata_st = preprocess(Batch_list, section_ids=section_ids, k_cutoff=k_cutoff, rad_cutoff=rad_cutoff, model=model,
                          slice_dist_micron=slice_dist_micron)
    adata_X = PCA(n_components=200, random_state=42).fit_transform(adata_st.X)
    adata_st.obsm['feat'] = adata_X

    return adata_st
645 |
646 |
def train_one(args, adata, section_ids, num_clusters, ARI_list):
    """Train SpaMask on the merged data, cluster with k-means and report ARI.

    The overall ARI and one ARI per section are printed; the per-section
    values are appended to ``ARI_list`` (the list passed in is modified).

    Returns:
        ``(ARI_list, adata)`` where ``adata`` is the object returned by the
        trained network's ``get_adata()``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_options = dict(
        tissue_name='Donor',
        num_clusters=num_clusters,
        device=device,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        max_epoch=args.max_epoch,
        gradient_clipping=args.gradient_clipping,
        feat_mask_rate=args.feat_mask_rate,
        edge_drop_rate=args.edge_drop_rate,
        hidden_dim=args.hidden_dim,
        latent_dim=args.latent_dim,
        bn=args.bn,
        att_dropout_rate=args.att_dropout_rate,
        fc_dropout_rate=args.fc_dropout_rate,
        use_token=args.use_token,
        rep_loss=args.rep_loss,
        rel_loss=args.rel_loss,
        alpha=args.alpha,
        lam=args.lam,
        random_seed=args.seed,
        nps=args.nps,
    )
    net = stm.spaMask.SPAMASK(adata, **model_options)
    net.train()

    method = "kmeans"
    net.process(method=method)

    adata = net.get_adata()
    # Score only the spots that carry a ground-truth label.
    labelled = adata[~pd.isnull(adata.obs['Ground Truth'])]
    ARI = metrics.adjusted_rand_score(labelled.obs['Ground Truth'], labelled.obs[method])
    print(f"total ARI:{ARI}")
    for name in section_ids:
        section = labelled[labelled.obs['batch_name'] == name]
        ARI = metrics.adjusted_rand_score(section.obs['Ground Truth'], section.obs[method])
        print(f"{name} ARI:{round(ARI, 4)}")
        ARI_list.append(ARI)

    return ARI_list, adata
687 |
688 |
# Driver script: configure the run, load four DLPFC sections, train once and summarise ARI.

# Start from the project's default arguments and override the ones used for this experiment.
args = stm.utils.build_args()
args.hidden_dim, args.latent_dim = 512, 256
args.max_epoch = 1000
args.lam = 2
args.feat_mask_rate = 0.5
args.edge_drop_rate = 0.2
args.top_genes = 5000
args.rad_cutoff = 200
args.k_cutoff = 21
args.model = 'KNN'

# Sections to integrate and the number of clusters requested from the model.
slices_list = ['151673', '151674', '151675', '151676']
num_clusters = 7

adata = load_adata(slices_list, k_cutoff=args.k_cutoff, rad_cutoff=args.rad_cutoff, model=args.model,
                   n_top_genes=args.top_genes)


# train_one appends one ARI value per section to the list it is given.
ARI_list = []
ARI_list, adata = train_one(args, adata, slices_list, num_clusters, ARI_list)

# Median of the per-section ARI values (computed but not printed here).
ARI = np.median(ARI_list)
711 |
712 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpaMask: Dual Masking Graph Autoencoder with Contrastive Learning for Spatial Transcriptomics
2 | ## 🔥 Introduction
3 | Understanding the spatial locations of cells within tissues is crucial for unraveling the organization of cellular diversity. Recent advancements in spatially resolved transcriptomics (SRT) have enabled the analysis of gene expression while preserving the spatial context within tissues. Spatial domain characterization is a critical first step in SRT data analysis, providing the foundation for subsequent analyses and insights into biological implications. Graph neural networks (GNNs) have emerged as a common tool for addressing this challenge due to the structural nature of SRT data. However, current graph-based deep learning approaches often overlook the instability caused by the high sparsity of SRT data. **Masking mechanisms**, as an effective self-supervised learning strategy, can enhance the robustness of these models. To this end, we propose **SpaMask, a dual masking graph autoencoder with contrastive learning for SRT analysis**. Unlike previous GNNs, SpaMask masks a portion of spot nodes and spot-to-spot edges to enhance its performance and robustness. SpaMask combines **Masked Graph Autoencoder (MGAE) and Masked Graph Contrastive Learning (MGCL)** modules, with MGAE using node masking to leverage spatial neighbors for improved clustering accuracy, while MGCL applies edge masking to create a contrastive loss framework that tightens embeddings of adjacent nodes based on spatial proximity and feature similarity. We conducted a comprehensive evaluation of SpaMask on **eight datasets from five different platforms**. Compared to existing methods, SpaMask achieves superior clustering accuracy and effective batch correction.
4 |
5 | 
6 |
7 | ## 🌐 Data
8 | - All public datasets used in this paper are available at [Zenodo](https://zenodo.org/records/14062665)
9 |
10 | ## 🔬 Setup
11 | - `pip install -r requirement.txt`
12 |
13 | ## 🚀 Get Started
14 | We provide code for reproducing the experiments in the paper, along with a comprehensive tutorial for using SpaMask.
15 | - Please see `TutorialDonor.ipynb`.
16 |
17 |
18 | ## 🔥Citing
19 |
The corresponding BibTeX citation is given below:
20 |
21 | @article{min2025spamask,
22 | title={SpaMask: Dual masking graph autoencoder with contrastive learning for spatial transcriptomics},
23 | author={Min, Wenwen and Fang, Donghai and Chen, Jinyu and Zhang, Shihua},
24 | journal={PLOS Computational Biology},
25 | volume={21},
26 | number={4},
27 | pages={e1012881},
28 | year={2025},
29 | publisher={Public Library of Science San Francisco, CA USA}
30 | }
31 |
32 |
33 | ## Article link
34 |
35 | - [https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012881](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012881)
36 |
--------------------------------------------------------------------------------
/SpaMask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask.jpg
--------------------------------------------------------------------------------
/SpaMask/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocess import load_feat, Cal_Spatial_Net, Transfer_pytorch_Data
2 | from .utils import fix_seed, Stats_Spatial_Net, mclust_R, save_args_to_file
3 | from .model import stMask_model
4 | from .spaMask import SPAMASK
5 |
# Public API of the package. The model and trainer classes imported above are
# listed as well so that `from SpaMask import *` exposes them.
__all__ = [
    "load_feat",
    "Cal_Spatial_Net",
    "Transfer_pytorch_Data",
    "fix_seed",
    "Stats_Spatial_Net",
    "mclust_R",
    "save_args_to_file",
    "stMask_model",
    "SPAMASK",
]
15 |
--------------------------------------------------------------------------------
/SpaMask/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/SpaMask/__pycache__/model.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/model.cpython-311.pyc
--------------------------------------------------------------------------------
/SpaMask/__pycache__/preprocess.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/preprocess.cpython-311.pyc
--------------------------------------------------------------------------------
/SpaMask/__pycache__/spaMask.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/spaMask.cpython-311.pyc
--------------------------------------------------------------------------------
/SpaMask/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wenwenmin/SpaMask/89ba5ba647e02df7ea9d343032f4a2205fb8a034/SpaMask/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/SpaMask/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from functools import partial
3 | import torch.nn.functional as F
4 | import torch
5 | from torch import nn
6 | from torch_geometric.nn import (
7 | TransformerConv,
8 | LayerNorm,
9 | Linear,
10 | GCNConv,
11 | SAGEConv,
12 | GATConv,
13 | GINConv,
14 | GATv2Conv,
15 | global_add_pool,
16 | global_mean_pool,
17 | global_max_pool
18 | )
19 |
20 | try:
21 | import torch_cluster # noqa
22 |
23 | random_walk = torch.ops.torch_cluster.random_walk
24 | except ImportError:
25 | random_walk = None
26 | from torch_geometric.utils.num_nodes import maybe_num_nodes
27 | from torch_geometric.utils import to_undirected, sort_edge_index
28 | from torch_geometric.utils import add_self_loops, negative_sampling, degree
29 |
def create_activation(name):
    """Return a freshly constructed activation module for ``name``.

    Supported names are ``"relu"``, ``"gelu"``, ``"prelu"`` and ``"elu"``;
    ``None`` yields ``nn.Identity``.

    Raises:
        NotImplementedError: if ``name`` is none of the above.
    """
    if name is None:
        return nn.Identity()

    known_activations = (
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("prelu", nn.PReLU),
        ("elu", nn.ELU),
    )
    for key, factory in known_activations:
        if name == key:
            return factory()
    raise NotImplementedError(f"{name} is not implemented.")
43 |
44 |
45 |
class Encoder(nn.Module):
    """Two stacked ``GCNConv`` layers, each followed by a norm layer and an activation.

    Args:
        input_dim: size of the input node features.
        hidden_dim: output size of the first graph convolution.
        latent_dim: output size of the second graph convolution.
        bn: use ``nn.BatchNorm1d`` after each convolution when true, otherwise ``nn.Identity``.
        dropout_rate: accepted but not used by the layers built here.
        act: activation name passed to ``create_activation``.
        bias: accepted but not used by the layers built here.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, bn=True, dropout_rate=.1, act="prelu", bias=True):
        super().__init__()
        norm_layer = nn.BatchNorm1d if bn else nn.Identity
        self.conv1 = GCNConv(in_channels=input_dim, out_channels=hidden_dim)
        self.bn1 = norm_layer(hidden_dim)
        self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=latent_dim)
        self.bn2 = norm_layer(latent_dim)
        # A single activation module is shared by both layers.
        self.activation = create_activation(act)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = self.activation(self.bn1(hidden))
        latent = self.conv2(hidden, edge_index)
        return self.activation(self.bn2(latent))
63 |
class FeatureDecoder(nn.Module):
    """Single ``GCNConv`` layer mapping latent node embeddings to output features.

    ``dropout_rate`` and ``bias`` are accepted but not used. An activation
    module is created from ``act`` and stored on the instance, but ``forward``
    does not apply it.
    """

    def __init__(self, latent_dim, output_dim, dropout_rate=.1, act="prelu", bias=True):
        super().__init__()
        self.conv1 = GCNConv(in_channels=latent_dim, out_channels=output_dim)
        self.activation = create_activation(act)

    def forward(self, x, edge_index):
        """Return the raw output of the graph convolution."""
        return self.conv1(x, edge_index)
74 |
class TopologyDecoder(nn.Module):
    """Two-layer MLP that scores an edge from the product of its endpoint embeddings.

    Args:
        input_dim: size of the node embeddings.
        latent_dim: width of the hidden layer.
        output_dim: size of the score produced per edge.
        dropout_rate: dropout applied to the edge representation before the first layer.
        act: activation name passed to ``create_activation``.
    """

    def __init__(self, input_dim, latent_dim, output_dim=1, dropout_rate=0.5, act="relu"):
        super().__init__()
        # Build and reset each layer in turn; the order is kept so that seeded
        # initialisation draws the same random numbers.
        self.fc1 = Linear(in_channels=input_dim, out_channels=latent_dim)
        self.fc1.reset_parameters()
        self.fc2 = Linear(in_channels=latent_dim, out_channels=output_dim)
        self.fc2.reset_parameters()
        self.d_drop = nn.Dropout(dropout_rate)
        self.activation = create_activation(act)

    def forward(self, x, edge_index):
        """Return one score row per column of ``edge_index``."""
        source_emb = x[edge_index[0]]
        target_emb = x[edge_index[1]]
        edge_repr = self.d_drop(source_emb * target_emb)
        hidden = self.activation(self.fc1(edge_repr))
        return self.fc2(hidden)
89 |
class stMask_model(nn.Module):
    """Graph autoencoder trained with a feature-masking and an edge-masking objective.

    ``forward`` corrupts the input twice (masked node features, dropped
    edges), encodes both views with one shared ``Encoder`` and returns a
    feature-reconstruction loss and an edge-prediction loss.

    Args:
        features_dims: sequence ``[input_dim, hidden_dim, latent_dim, output_dim]``.
        bn: forwarded to ``Encoder``.
        att_dropout_rate: forwarded to ``Encoder`` and ``FeatureDecoder`` as ``dropout_rate``.
        fc_dropout_rate: dropout rate of ``TopologyDecoder``.
        use_token: when true a learnable ``enc_mask_token`` is created and added
            to masked rows; otherwise masked rows are set to zero.
        alpha: exponent used by ``sce_loss`` when ``rep_loss == "cse"``.
        edge_drop_rate: per-edge drop probability used by ``dropout_edge``.
        feat_mask_rate: fraction of nodes masked by ``mask_feature``.
        rep_loss: name of the feature loss (``"mse"``, ``"cse"`` or ``"ce"``).
        rel_loss: name stored as ``self.edge_loss`` (same choices).
    """

    def __init__(self, features_dims, bn=False, att_dropout_rate=.2, fc_dropout_rate=.5, use_token=True, alpha=2, edge_drop_rate=0.3, feat_mask_rate=0.3, rep_loss="cse", rel_loss="ce"):
        super().__init__()
        input_dim, hidden_dim, latent_dim, output_dim = features_dims
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim, bn=bn, dropout_rate=att_dropout_rate, act="prelu", bias=True)

        self.use_token = use_token
        if self.use_token:
            self.enc_mask_token = nn.Parameter(torch.zeros(1, input_dim))
        self.encoder_to_decoder = nn.Linear(latent_dim, latent_dim, bias=False)
        nn.init.xavier_uniform_(self.encoder_to_decoder.weight)
        # The attribute name (including its spelling) is part of the
        # state_dict keys, so it is kept as is.
        self.feat_deocder = FeatureDecoder(latent_dim, output_dim, dropout_rate=att_dropout_rate, act="prelu", bias=True)
        self.topo_decoder = TopologyDecoder(latent_dim, 2 * latent_dim, 1, fc_dropout_rate)

        self.feat_loss = self.setup_loss_fn(rep_loss, alpha)
        self.edge_loss = self.setup_loss_fn(rel_loss)

        self.edge_drop_rate = edge_drop_rate
        self.feat_mask_rate = feat_mask_rate

    def forward(self, data):
        """Return ``(feat_loss, topo_loss)`` for one graph.

        ``data`` must expose ``x``, ``edge_index`` and ``num_nodes``.
        """
        x = data.x
        edge_index = data.edge_index
        num_nodes = data.num_nodes

        # Random draws happen in this order: feature mask, edge drop, negatives.
        use_mask_x, mask_nodes = self.mask_feature(x, self.feat_mask_rate)
        remaining_edges, masked_edges = self.dropout_edge(edge_index, self.edge_drop_rate)

        rep_x = self.encoder(use_mask_x, edge_index)
        rep_e = self.encoder(x, remaining_edges)

        # Re-mask: zero the latent rows of masked nodes before decoding, then
        # score the reconstruction on those nodes only.
        rec_x = self.encoder_to_decoder(rep_x)
        rec_x[mask_nodes] = 0
        rec_x = self.feat_deocder(rec_x, edge_index)
        feat_loss = self.feat_loss(x[mask_nodes], rec_x[mask_nodes])

        # Sample as many random node pairs as there are dropped edges.
        aug_edge_index, _ = add_self_loops(edge_index)
        neg_edges = self.random_negative_sampler(
            aug_edge_index,
            num_nodes=num_nodes,
            num_neg_samples=masked_edges.view(2, -1).size(1),
        ).view_as(masked_edges)

        pos_edge = self.topo_decoder(rep_e, masked_edges)
        neg_edge = self.topo_decoder(rep_e, neg_edges)
        # NOTE(review): ``self.edge_loss`` is configured in ``__init__`` but the
        # edge objective always uses ``ce_loss`` here - confirm this is intended.
        topo_loss = self.ce_loss(pos_edge, neg_edge)

        return feat_loss, topo_loss

    def setup_loss_fn(self, loss_fn, alpha_l=2):
        """Map a loss name to a callable.

        Raises:
            NotImplementedError: if ``loss_fn`` is not ``"mse"``, ``"cse"`` or ``"ce"``.
        """
        if loss_fn == "mse":
            return nn.MSELoss()
        if loss_fn == "cse":
            return partial(self.sce_loss, alpha=alpha_l)
        if loss_fn == "ce":
            return partial(self.ce_loss)
        raise NotImplementedError(f"loss function {loss_fn!r} is not implemented.")

    def sce_loss(self, x, y, alpha=3):
        """Mean of ``(1 - cosine_similarity(x, y)) ** alpha`` over rows."""
        x = F.normalize(x, p=2, dim=-1)
        y = F.normalize(y, p=2, dim=-1)
        loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
        return loss.mean()

    def ce_loss(self, pos_out, neg_out):
        """Binary cross-entropy on raw scores: positives target 1, negatives target 0.

        The logits form is used instead of ``sigmoid`` followed by
        ``binary_cross_entropy`` so that large-magnitude scores do not saturate.
        """
        pos_loss = F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out))
        neg_loss = F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out))
        return pos_loss + neg_loss

    def mask_feature(self, x, feat_mask_rate=0.3):
        """Mask the features of ``int(feat_mask_rate * num_nodes)`` random nodes.

        Returns:
            ``(masked_copy_of_x, mask_nodes)`` where ``mask_nodes`` holds the
            indices of the masked rows. ``x`` itself is not modified.
        """
        num_nodes = x.shape[0]
        perm = torch.randperm(num_nodes, device=x.device)
        num_mask_nodes = int(feat_mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        out_x = x.clone()
        if self.use_token:
            # NOTE(review): the token is added on top of the original features
            # (they are not zeroed first, unlike ``mask_features``) - confirm
            # this is intended.
            out_x[mask_nodes] += self.enc_mask_token
        else:
            out_x[mask_nodes] = 0.0
        return out_x, mask_nodes

    def mask_features(self, x, feat_mask_rate=0.3):
        """Alternative masking: each node is masked independently with probability ``feat_mask_rate``.

        Masked rows are zeroed and, when ``use_token`` is set, the mask token
        is added afterwards. Returns ``(masked_copy_of_x, boolean_mask)``.
        """
        mask_nodes = torch.empty((x.size(0),), dtype=torch.float32, device=x.device).uniform_(0, 1) < feat_mask_rate
        mask_x = x.clone()
        mask_x[mask_nodes] = 0
        if self.use_token:
            mask_x[mask_nodes] += self.enc_mask_token
        return mask_x, mask_nodes

    def dropout_edge(self, my_edge_index, edge_drop_rate=0.3):
        """Randomly split edges into kept and dropped sets.

        Each edge is kept with probability ``1 - edge_drop_rate``. The kept
        edges are passed through ``to_undirected``.

        Returns:
            ``(remaining_edges, masked_edges)``.
        """
        edge_index = my_edge_index.clone()
        keep_prob = torch.zeros(edge_index.shape[1]).to(edge_index.device) + 1 - edge_drop_rate
        stay = torch.bernoulli(keep_prob).to(torch.bool)
        remaining_edges = to_undirected(edge_index[:, stay])
        masked_edges = edge_index[:, ~stay]
        return remaining_edges, masked_edges

    def random_negative_sampler(self, edge_index, num_nodes, num_neg_samples):
        """Draw ``num_neg_samples`` uniformly random node pairs.

        ``edge_index`` is only used to pick the dtype/device of the result;
        sampled pairs are not checked against existing edges.
        """
        return torch.randint(0, num_nodes, size=(2, num_neg_samples)).to(edge_index)

    @torch.no_grad()
    def embed(self, data):
        """Encode the unmasked graph and return the latent node embeddings."""
        return self.encoder(data.x, data.edge_index)

    @torch.no_grad()
    def recon(self, data):
        """Return ``(latent_embeddings, reconstructed_features)`` for the unmasked graph."""
        edge_index = data.edge_index
        h = self.encoder(data.x, edge_index)
        rec = self.feat_deocder(self.encoder_to_decoder(h), edge_index)
        return h, rec

    @torch.no_grad()
    def embed_masking(self, data):
        """Return the embeddings of the feature-masked view and of the edge-dropped view."""
        x = data.x
        edge_index = data.edge_index
        use_mask_x, _ = self.mask_feature(x, self.feat_mask_rate)
        remaining_edges, _ = self.dropout_edge(edge_index, self.edge_drop_rate)
        rep_x = self.encoder(use_mask_x, edge_index)
        rep_e = self.encoder(x, remaining_edges)
        return rep_x, rep_e
226 |
--------------------------------------------------------------------------------
/SpaMask/preprocess.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings('ignore')
3 | import os
4 | import torch
5 | import random
6 | import numpy as np
7 | import scanpy as sc
8 | import pandas as pd
9 | import scipy.sparse as sp
10 | from sklearn.neighbors import NearestNeighbors
11 | import sklearn.neighbors
12 | from torch_geometric.data import Data
13 | from pathlib import Path
14 |
def prefilter_genes(adata, min_counts=None, max_counts=None, min_cells=10, max_cells=None):
    """Subset ``adata`` in place to the genes passing every given threshold.

    Each non-``None`` threshold is evaluated with ``sc.pp.filter_genes`` on
    ``adata.X``; the first element of its result is used as a boolean gene
    mask, and the masks are AND-ed together.

    Raises:
        ValueError: if all four thresholds are ``None``.
    """
    thresholds = (
        ("min_cells", min_cells),
        ("max_cells", max_cells),
        ("min_counts", min_counts),
        ("max_counts", max_counts),
    )
    if all(value is None for _, value in thresholds):
        raise ValueError('Provide one of min_counts, min_cells, max_counts or max_cells.')

    keep = np.ones(adata.shape[1], dtype=bool)
    for option, value in thresholds:
        if value is not None:
            keep = np.logical_and(keep, sc.pp.filter_genes(adata.X, **{option: value})[0])
    adata._inplace_subset_var(keep)
24 |
def load_feat(adata, top_genes=3000, model="pca"):
    """Compute the node feature matrix and store it in ``adata.obsm['feat']``.

    Args:
        adata: annotated data object processed with scanpy.
        top_genes: number of highly variable genes to select.
        model: one of
            ``'pca'``   - filter genes, normalise, keep the highly variable
                          genes, scale, then reduce with sklearn PCA;
            ``'hvg'``   - normalise + log1p and use the highly variable genes;
            ``'other'`` - as ``'hvg'`` plus scaling (no centring, capped at 10)
                          and without the gene pre-filter.

    Returns:
        The processed object. In ``'pca'`` mode this is a new object restricted
        to the highly variable genes; otherwise it is the input object.
    """
    assert model in ['pca', 'hvg', 'other'], f"unknown model {model!r}; expected 'pca', 'hvg' or 'other'"
    if model == "pca":
        adata.var_names_make_unique()
        # Keep a dense copy of the input matrix in a layer; it is what the
        # seurat_v3 gene selection below reads.
        adata.layers['count'] = adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray()
        sc.pp.filter_genes(adata, min_cells=50)
        sc.pp.filter_genes(adata, min_counts=10)
        sc.pp.normalize_total(adata, target_sum=1e6)
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", layer='count', n_top_genes=top_genes)
        adata = adata[:, adata.var['highly_variable']]
        sc.pp.scale(adata)
        from sklearn.decomposition import PCA  # sklearn PCA is used because PCA in scanpy is not stable.
        # 200 components by default, but never more than the matrix allows
        # (small samples or a small ``top_genes`` would otherwise make PCA fail).
        n_components = min(200, adata.n_obs, adata.n_vars)
        adata.obsm['feat'] = PCA(n_components=n_components, random_state=42).fit_transform(adata.X)
        print(f"adata.obsm['feat'].shape:{adata.obsm['feat'].shape}")

    elif model == "hvg":
        adata.var_names_make_unique()
        prefilter_genes(adata, min_cells=3)  # avoiding all genes are zeros
        # Gene selection runs before normalisation.
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=top_genes)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        adata.X = sp.csr_matrix(adata.X)
        adata_Vars = adata[:, adata.var['highly_variable']]
        adata.obsm['feat'] = adata_Vars.X[:, ]

    elif model == "other":
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=top_genes)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        sc.pp.scale(adata, zero_center=False, max_value=10)
        adata.X = sp.csr_matrix(adata.X)
        adata_Vars = adata[:, adata.var['highly_variable']]
        adata.obsm['feat'] = adata_Vars.X[:, ]

    return adata
67 |
68 |
69 |
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
    """Build the spatial neighbour graph and store it in ``adata.uns['Spatial_Net']``.

    Args:
        adata: object whose ``obsm['spatial']`` holds two coordinate columns.
        rad_cutoff: radius used when ``model == 'Radius'``.
        k_cutoff: number of neighbours used when ``model == 'KNN'``
            (``k_cutoff + 1`` are queried; zero-distance pairs are dropped).
        model: ``'Radius'`` or ``'KNN'``.
        verbose: print progress and edge statistics.

    Returns:
        ``adata``; the stored table has columns ``Cell1``, ``Cell2`` and
        ``Distance``, with cells identified by their ``obs`` index labels.
    """
    assert (model in ['Radius', 'KNN'])
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    coor.columns = ['imagerow', 'imagecol']

    if model == 'Radius':
        finder = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
        distances, indices = finder.radius_neighbors(coor, return_distance=True)
        # One small frame per cell: (cell position, neighbour position, distance).
        KNN_list = [
            pd.DataFrame(zip([row] * indices[row].shape[0], indices[row], distances[row]))
            for row in range(indices.shape[0])
        ]
    elif model == 'KNN':
        finder = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor)
        distances, indices = finder.kneighbors(coor)
        KNN_list = [
            pd.DataFrame(zip([row] * indices.shape[1], indices[row, :], distances[row, :]))
            for row in range(indices.shape[0])
        ]

    KNN_df = pd.concat(KNN_list)
    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']

    # Drop zero-distance pairs (this removes each cell's match with itself).
    Spatial_Net = KNN_df[KNN_df['Distance'] > 0].copy()
    # Translate positional ids back to the obs index labels.
    position_to_label = dict(enumerate(np.array(coor.index)))
    for column in ('Cell1', 'Cell2'):
        Spatial_Net[column] = Spatial_Net[column].map(position_to_label)

    if verbose:
        n_edges = Spatial_Net.shape[0]
        print('The graph contains %d edges, %d cells.' % (n_edges, adata.n_obs))
        print('%.4f neighbors per cell on average.' % (n_edges / adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net
    return adata
120 |
121 |
def Transfer_pytorch_Data(adata, weightless=True):
    """Convert ``adata`` to a graph object via the weightless or the weighted builder.

    ``weightless=True`` dispatches to ``weightless_undirected_graph``, otherwise
    to ``powered_undirected_graph``; the builder's result is returned as is.
    """
    builder = weightless_undirected_graph if weightless else powered_undirected_graph
    return builder(adata)
127 |
128 |
def weightless_undirected_graph(adata):
    """Build an unweighted torch_geometric ``Data`` object from ``adata``.

    Edges come from ``adata.uns['Spatial_Net']`` (cell labels are mapped to
    their position in ``adata.obs_names``) plus one self-loop per cell. Node
    features are ``adata.obsm['feat']``; anything that is not a NumPy array is
    densified with ``.todense()`` first.

    Returns:
        ``Data`` with ``edge_index`` (long, 2 x E) and ``x`` (float).
    """
    G_df = adata.uns['Spatial_Net'].copy()
    cells = np.array(adata.obs_names)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
    G = G + sp.eye(G.shape[0])  # add self-loops
    rows, cols = np.nonzero(G)
    edge_index = torch.LongTensor(np.array([rows, cols]))

    feat = adata.obsm['feat']
    # isinstance (rather than an exact type comparison) so ndarray subclasses
    # are not sent down the sparse ``.todense()`` path.
    if not isinstance(feat, np.ndarray):
        feat = feat.todense()
    return Data(edge_index=edge_index, x=torch.FloatTensor(feat))
145 |
def powered_undirected_graph(adata):
    """Placeholder for a weighted-graph builder; not implemented.

    Raises:
        NotImplementedError: always. Failing here is clearer than returning
            ``None`` and letting the caller fail later on a missing graph.
    """
    raise NotImplementedError("powered_undirected_graph is not implemented; use the weightless graph instead.")
148 |
if __name__ == '__main__':
    # Ad-hoc check: read one Visium sample from a local folder, build its PCA
    # features and print their shape.
    sample_name = '151676'
    five_cluster_samples = ['151669', '151670', '151671', '151672']
    n_clusters = 5 if sample_name in five_cluster_samples else 7

    data_root = Path("D:\\project\\datasets\\DLPFC\\")
    count_file = f"{sample_name}_filtered_feature_bc_matrix.h5"
    adata = load_feat(sc.read_visium(data_root / sample_name, count_file=count_file), model="pca")
    print(adata.obsm['feat'].shape)
--------------------------------------------------------------------------------
/SpaMask/spaMask.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy as np
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from .preprocess import load_feat, Cal_Spatial_Net, Transfer_pytorch_Data
8 | from .utils import fix_seed, Stats_Spatial_Net, mclust_R, Kmeans_cluster
9 | from .model import stMask_model
10 |
class SPAMASK:
    """Training/inference wrapper around ``stMask_model`` for one AnnData object.

    The constructor copies ``adata``, seeds the random generators, and makes
    sure the copy carries node features (``obsm['feat']``) and a spatial graph
    (``uns['Spatial_Net']``), computing whichever is missing.

    Args:
        adata: input object; it is copied, the original is not modified.
        tissue_name: label shown in the training progress bar.
        num_clusters: number of clusters requested from the clustering step.
        top_genes, genes_model: forwarded to ``load_feat``.
        rad_cutoff, k_cutoff, graph_model: forwarded to ``Cal_Spatial_Net``.
        device: device used for the model and the graph data.
        learning_rate, weight_decay: Adam optimiser settings.
        max_epoch: number of training iterations.
        gradient_clipping: gradient-norm clip value; applied only when > 1.
        feat_mask_rate, edge_drop_rate, hidden_dim, latent_dim, bn,
        att_dropout_rate, fc_dropout_rate, use_token, alpha, rep_loss,
        rel_loss: forwarded to ``stMask_model``.
        lam: weight of the topology loss in the total loss.
        random_seed: seed passed to ``fix_seed``.
        nps: number of PCA components kept when the embedding is wider than 64.
    """

    def __init__(self,
                 adata,
                 tissue_name="BRCA",
                 num_clusters=20,
                 top_genes=4000,
                 genes_model="hvg",  # 'pca', 'hvg'
                 rad_cutoff=300,
                 k_cutoff=12,
                 graph_model='Radius',  # 'Radius', 'KNN'
                 device=torch.device('cpu'),
                 learning_rate=0.001,
                 weight_decay=2e-4,
                 max_epoch=1500,
                 gradient_clipping=5,
                 feat_mask_rate=0.3,
                 edge_drop_rate=0.6,
                 hidden_dim=512,
                 latent_dim=256,
                 bn=True,
                 att_dropout_rate=0.2,
                 fc_dropout_rate=0.5,
                 use_token=True,
                 alpha=2,
                 rep_loss="cse",
                 rel_loss="ce",
                 lam=1.4,
                 random_seed=2024,
                 nps=30,
                 ):

        self.__adata = adata.copy()
        self.__tissue_name = tissue_name
        self.__top_genes = top_genes
        self.__genes_model = genes_model
        self.__rad_cutoff = rad_cutoff
        self.__k_cutoff = k_cutoff
        self.__graph_model = graph_model
        self.__device = device
        self.__learning_rate = learning_rate
        self.__weight_decay = weight_decay
        self.__max_epoch = max_epoch
        self.__gradient_clipping = gradient_clipping
        self.__feat_mask_rate = feat_mask_rate
        self.__edge_drop_rate = edge_drop_rate
        self.__hidden_dim = hidden_dim
        self.__latent_dim = latent_dim
        self.__bn = bn
        self.__att_dropout_rate = att_dropout_rate
        self.__fc_dropout_rate = fc_dropout_rate
        self.__use_token = use_token
        self.__alpha = alpha
        self.__rep_loss = rep_loss
        self.__rel_loss = rel_loss
        self.__lam = lam
        self.__nps = nps

        fix_seed(random_seed)

        # Features are required below and by train()/process(), so compute them
        # whenever they are absent - even if 'highly_variable' is already set.
        if 'feat' not in self.__adata.obsm.keys():
            self.__adata = load_feat(self.__adata, top_genes=self.__top_genes, model=self.__genes_model)

        if 'Spatial_Net' not in self.__adata.uns.keys():
            Cal_Spatial_Net(self.__adata, rad_cutoff=self.__rad_cutoff, k_cutoff=self.__k_cutoff, model=self.__graph_model)

        self.num_clusters = num_clusters
        print(self.__adata.obsm['feat'].shape)

    def train(self):
        """Build the model and optimiser, then run ``max_epoch`` full-graph updates."""
        data = Transfer_pytorch_Data(self.__adata).to(self.__device)
        output_dim = input_dim = data.x.shape[-1]
        features_dims = [input_dim, self.__hidden_dim, self.__latent_dim, output_dim]
        self.model = stMask_model(features_dims, bn=self.__bn,
                                  att_dropout_rate=self.__att_dropout_rate, fc_dropout_rate=self.__fc_dropout_rate,
                                  use_token=self.__use_token, alpha=self.__alpha,
                                  edge_drop_rate=self.__edge_drop_rate, feat_mask_rate=self.__feat_mask_rate,
                                  rep_loss=self.__rep_loss, rel_loss=self.__rel_loss).to(self.__device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), self.__learning_rate, weight_decay=self.__weight_decay)

        epoch_iter = tqdm(range(self.__max_epoch))
        for epoch in epoch_iter:
            self.model.train()
            self.optimizer.zero_grad()

            feat_loss, topo_loss = self.model(data)

            loss = feat_loss + topo_loss * self.__lam
            loss.backward()
            gradient_clipping = self.__gradient_clipping
            if gradient_clipping > 1:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
            self.optimizer.step()
            epoch_iter.set_description(f"Dataset_Name:{self.__tissue_name}, Ep {epoch}: train loss:{loss.item():.4f}")

    def process(self, method="kmeans"):
        """Embed the data with the trained model and cluster the embedding.

        Stores the embedding in ``obsm['eval_pred']`` (PCA-reduced to ``nps``
        components when it is wider than 64) and the reconstruction in
        ``obsm['eval_recon']``. ``method`` selects ``"mclust"`` or ``"kmeans"``;
        any other value stores the embeddings without clustering.
        """
        data = Transfer_pytorch_Data(self.__adata).to(self.__device)
        with torch.no_grad():
            self.model.eval()
            h, z = self.model.recon(data=data)
            rep = h.to('cpu').detach().numpy()
            rec = z.to('cpu').detach().numpy()
            if rep.shape[-1] > 64:
                from sklearn.decomposition import PCA
                rep = PCA(n_components=self.__nps).fit_transform(rep)
            self.__adata.obsm["eval_pred"] = rep
            self.__adata.obsm["eval_recon"] = rec

        if method == "mclust":
            mclust_R(self.__adata, num_cluster=self.num_clusters, used_obsm="eval_pred", key_added_pred=method)
        elif method == "kmeans":
            Kmeans_cluster(self.__adata, num_cluster=self.num_clusters, used_obsm="eval_pred", key_added_pred=method)

    def show_Stats_Spatial_Net(self):
        """Plot neighbour-count statistics of the stored spatial graph."""
        Stats_Spatial_Net(self.__adata)

    def save_model_dict(self, save_model_file):
        """Save the model's ``state_dict`` under the key ``'state_dict'``."""
        torch.save({'state_dict': self.model.state_dict()}, save_model_file)
        print('Saving model to %s' % save_model_file)

    def save_model(self, save_model_file):
        """Save the whole model object."""
        torch.save(self.model, save_model_file)
        print('Saving model to %s' % save_model_file)

    def load_model_dict(self, save_model_file):
        """Load a ``state_dict`` checkpoint into the existing ``self.model``.

        ``self.model`` must already exist (it is created by ``train``).
        """
        # map_location lets a checkpoint saved on another device be loaded.
        saved_state_dict = torch.load(save_model_file, map_location=self.__device)
        self.model.load_state_dict(saved_state_dict['state_dict'])
        print('Loading model from %s' % save_model_file)

    def load_model(self, save_model_file):
        """Replace ``self.model`` with a model object loaded from file.

        NOTE: this unpickles the file; only load checkpoints from a trusted source.
        """
        self.model = torch.load(save_model_file, map_location=self.__device)
        print('Loading model from %s' % save_model_file)

    def get_adata(self):
        """Return the internal (copied and processed) AnnData object."""
        return self.__adata
152 |
153 |
--------------------------------------------------------------------------------
/SpaMask/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 |
7 | from sklearn.cluster import KMeans
8 | from scipy.spatial import distance
9 | from sklearn.metrics import adjusted_mutual_info_score,normalized_mutual_info_score,completeness_score,fowlkes_mallows_score, homogeneity_score
10 |
11 | from sklearn.metrics.cluster import v_measure_score, adjusted_rand_score
12 |
# The location of R (used for the mclust clustering).
# These are fallbacks only: values already present in the environment are
# kept, so importing this module does not overwrite a configured R setup.
os.environ.setdefault('R_HOME', 'D:/software/R/R-4.3.2')
os.environ.setdefault('R_USER', 'D:/software/anaconda/anaconda3/envs/pt20cu118/Lib/site-packages/rpy2')
16 |
17 |
def fix_seed(seed=2024):
    """Seed Python, NumPy and torch and request deterministic cuDNN behaviour.

    Also sets the ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG``
    environment variables.
    """
    import random
    import torch
    from torch.backends import cudnn

    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })

    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
34 |
35 |
def Stats_Spatial_Net(adata):
    """Bar-plot the distribution of neighbour counts in ``adata.uns['Spatial_Net']``.

    The x axis is the number of edges listed per ``Cell1`` entry, the y axis is
    the number of cells with that count divided by ``adata.shape[0]``. The mean
    number of edges per cell is shown in the title.
    """
    source_cells = adata.uns['Spatial_Net']['Cell1']
    Num_edge = source_cells.shape[0]
    Mean_edge = Num_edge / adata.shape[0]
    # Series.value_counts replaces the deprecated top-level pd.value_counts:
    # first edges per cell, then how many cells share each edge count.
    plot_df = source_cells.value_counts().value_counts()
    plot_df = plot_df / adata.shape[0]
    fig, ax = plt.subplots(figsize=[3, 2])
    plt.ylabel('Percentage')
    plt.xlabel('')
    plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge)
    ax.bar(plot_df.index, plot_df)
46 |
def Kmeans_cluster(adata, num_cluster, used_obsm='model_pred', key_added_pred="kmeans", random_seed=2024):
    """Cluster ``adata.obsm[used_obsm]`` with k-means and store the labels.

    The NumPy global seed is set to ``random_seed`` before fitting. Labels are
    written to ``adata.obs[key_added_pred]`` as an integer-valued categorical.

    Returns:
        ``adata`` (modified in place).
    """
    np.random.seed(random_seed)
    estimator = KMeans(n_clusters=num_cluster, init='k-means++', n_init=100, max_iter=1000, tol=1e-6)
    labels = estimator.fit_predict(adata.obsm[used_obsm])
    adata.obs[key_added_pred] = labels
    adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('int').astype('category')
    return adata
55 |
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='model_pred', key_added_pred="mclust", random_seed=2024):
    """Cluster ``adata.obsm[used_obsm]`` with R's ``Mclust`` through rpy2.

    Both the NumPy seed and R's ``set.seed`` are set to ``random_seed``. The
    second-to-last element of the ``Mclust`` result is taken as the label
    vector and written to ``adata.obs[key_added_pred]`` as an integer-valued
    categorical.

    Returns:
        ``adata`` (modified in place).
    """
    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri
    rpy2.robjects.numpy2ri.activate()
    robjects.r['set.seed'](random_seed)
    run_mclust = robjects.r['Mclust']

    embedding = rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm])
    result = run_mclust(embedding, num_cluster, modelNames)
    labels = np.array(result[-2])

    adata.obs[key_added_pred] = labels
    adata.obs[key_added_pred] = adata.obs[key_added_pred].astype('int').astype('category')
    return adata
74 |
75 |
def build_args():
    """Return an ``argparse.Namespace`` holding the default settings.

    The parser is always run on an empty argument list, so the command line is
    ignored and every option takes its default value.
    """
    import argparse

    option_specs = [
        ("--model_name", dict(type=str, default="SpaMask")),
        ("--seed", dict(type=int, default=2023)),
        ("--tissue_name", dict(type=str, default="151507")),

        ("--top_genes", dict(type=int, default=2000)),
        ("--genes_model", dict(type=str, default="pca")),
        ("--rad_cutoff", dict(type=int, default=200)),
        ("--k_cutoff", dict(type=int, default=12)),
        ("--graph_model", dict(type=str, default="KNN")),

        ('--nps', dict(type=int, default=30)),
        ('--gradient_clipping', dict(type=float, default=5.)),
        ("--need_refine", dict(action='store_true', default=False)),

        # Training settings shared by the models
        ("--learning_rate", dict(type=float, default=0.001)),
        ("--weight_decay", dict(type=float, default=2e-4)),
        ("--max_epoch", dict(type=int, default=500, help="number of training epochs")),

        # ST params
        ("--edge_drop_rate", dict(type=float, default=0.4)),
        ("--feat_mask_rate", dict(type=float, default=0.3)),

        ("--hidden_dim", dict(type=int, default=512)),
        ("--latent_dim", dict(type=int, default=256)),

        ('--bn', dict(action='store_true', default=True)),
        ("--att_dropout_rate", dict(type=float, default=.2)),
        ("--fc_dropout_rate", dict(type=float, default=.5)),
        ("--use_token", dict(action='store_true', default=True)),
        ("--rep_loss", dict(type=str, default="cse")),
        ("--rel_loss", dict(type=str, default="ce")),
        ("--alpha", dict(type=float, default=2.0)),

        ("--lam", dict(type=float, default=2)),
    ]

    parser = argparse.ArgumentParser(description="stMask")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args(args=[])
116 |
117 |
def measureClusteringTrueLabel(labels_true, labels_pred):
    """Score predicted cluster labels against reference labels.

    Returns:
        Tuple ``(ari, ami, nmi, cs, fms, vms, hs)``: adjusted Rand, adjusted
        mutual information, normalised mutual information, completeness,
        Fowlkes-Mallows, V-measure and homogeneity scores, in that order.
    """
    metrics = (
        adjusted_rand_score,
        adjusted_mutual_info_score,
        normalized_mutual_info_score,
        completeness_score,
        fowlkes_mallows_score,
        v_measure_score,
        homogeneity_score,
    )
    return tuple(metric(labels_true, labels_pred) for metric in metrics)
127 |
128 |
def refine(adata, pred, shape="hexagon"):
    """Smooth cluster labels using each spot's nearest spatial neighbours.

    For every spot the ``num_nbs + 1`` closest spots (by Euclidean distance on
    ``adata.obsm['spatial']``, the spot itself included) are collected. The
    spot's label is replaced by the most frequent label in that group when its
    own label occurs fewer than ``num_nbs / 2`` times and the most frequent
    label occurs more than ``num_nbs / 2`` times.

    Args:
        adata: object providing ``obs.index`` and ``obsm['spatial']``.
        pred: labels, one per spot, in ``adata.obs`` order.
        shape: ``"hexagon"`` (6 neighbours) or ``"square"`` (4 neighbours).

    Returns:
        List of refined labels in ``adata.obs`` order.

    Raises:
        ValueError: if ``shape`` is neither ``"hexagon"`` nor ``"square"``.
    """
    neighbor_counts = {"hexagon": 6, "square": 4}
    if shape not in neighbor_counts:
        # Fail before the distance matrix is built instead of printing a
        # message and crashing later on an undefined neighbour count.
        raise ValueError("Shape not recognized, shape='hexagon' for Visium data, 'square' for ST data.")
    num_nbs = neighbor_counts[shape]

    sample_id = adata.obs.index.tolist()
    dis = distance.cdist(adata.obsm['spatial'], adata.obsm['spatial'], 'euclidean')
    pred = pd.DataFrame({"pred": pred}, index=sample_id)
    dis_df = pd.DataFrame(dis, index=sample_id, columns=sample_id)

    refined_pred = []
    for index in sample_id:
        # Closest num_nbs + 1 spots; the spot itself is at distance 0.
        nbs = dis_df.loc[index, :].sort_values().iloc[:num_nbs + 1]
        nbs_pred = pred.loc[nbs.index, "pred"]
        self_pred = pred.loc[index, "pred"]
        v_c = nbs_pred.value_counts()
        if (v_c.loc[self_pred] < num_nbs / 2) and (v_c.max() > num_nbs / 2):
            refined_pred.append(v_c.idxmax())
        else:
            refined_pred.append(self_pred)
    return refined_pred
153 |
154 |
def save_args_to_file(args, filename):
    """Write the attributes of ``args`` to ``filename``, one ``name: value`` line each.

    The file is overwritten and starts with a ``Parsed Arguments:`` header line.
    """
    with open(filename, 'w') as file:
        file.write('Parsed Arguments:\n')
        file.writelines(f"{arg}: {value}\n" for arg, value in vars(args).items())
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | aiobotocore==2.5.4
2 | aiohttp==3.9.5
3 | aioitertools==0.11.0
4 | aiosignal==1.3.1
5 | alembic==1.13.1
6 | anndata==0.10.7
7 | annotated-types==0.7.0
8 | annoy==1.17.3
9 | antlr4-python3-runtime==4.9.3
10 | anyio==3.5.0
11 | argon2-cffi==21.3.0
12 | argon2-cffi-bindings==21.2.0
13 | array_api_compat==1.6
14 | asciitree==0.3.3
15 | asttokens==2.0.5
16 | attrs==22.1.0
17 | backcall==0.2.0
18 | bbknn==1.3.9
19 | beautifulsoup4==4.12.2
20 | bleach==4.1.0
21 | botocore==1.31.17
22 | certifi==2024.2.2
23 | cffi==1.15.1
24 | charset-normalizer==3.3.2
25 | click==8.1.7
26 | click-plugins==1.1.1
27 | cligj==0.7.2
28 | cloudpickle==3.0.0
29 | colorama==0.4.6
30 | colorcet==3.1.0
31 | colorlog==6.8.2
32 | comm==0.1.2
33 | contourpy==1.2.1
34 | cycler==0.12.1
35 | Cython==3.0.10
36 | dask==2024.5.0
37 | dask-expr==1.1.0
38 | dask-image==2023.8.1
39 | datashader==0.16.1
40 | debugpy==1.6.7
41 | decorator==5.1.1
42 | defusedxml==0.7.1
43 | Deprecated==1.2.14
44 | dgl==2.2.1
45 | distributed==2024.5.0
46 | docrep==0.3.2
47 | einops==0.7.0
48 | entrypoints==0.4
49 | executing==0.8.3
50 | faiss-cpu==1.8.0
51 | fasteners==0.19
52 | fastjsonschema==2.16.2
53 | fbpca==1.0
54 | filelock==3.9.0
55 | fiona==1.9.6
56 | fonttools==4.51.0
57 | frozenlist==1.4.1
58 | fsspec==2023.6.0
59 | geomloss==0.2.6
60 | geopandas==0.14.4
61 | geosketch==1.2
62 | greenlet==3.0.3
63 | gseapy==1.1.3
64 | h5py==3.11.0
65 | harmonypy==0.0.9
66 | hnswlib==0.8.0
67 | huggingface-hub==0.22.2
68 | idna==3.4
69 | igraph==0.11.4
70 | imageio==2.34.1
71 | importlib_metadata==7.1.0
72 | inflect==7.2.1
73 | intervaltree==3.1.0
74 | ipykernel==6.25.0
75 | ipython==8.15.0
76 | ipython-genutils==0.2.0
77 | jedi==0.18.1
78 | Jinja2==3.1.2
79 | jmespath==1.0.1
80 | joblib==1.4.0
81 | jsonschema==4.17.3
82 | jsonschema-specifications==2023.12.1
83 | jupyter_client==7.4.9
84 | jupyter_core==5.3.0
85 | jupyter-server==1.23.4
86 | jupyterlab-pygments==0.1.2
87 | kiwisolver==1.4.5
88 | latexcodec==3.0.0
89 | lazy_loader==0.4
90 | legacy-api-wrap==1.4
91 | leidenalg==0.10.2
92 | lightning-utilities==0.11.2
93 | llvmlite==0.42.0
94 | locket==1.0.0
95 | louvain==0.8.2
96 | lxml==4.9.3
97 | Mako==1.3.5
98 | markdown-it-py==3.0.0
99 | MarkupSafe==2.1.1
100 | matplotlib==3.8.4
101 | matplotlib-inline==0.1.6
102 | matplotlib-scalebar==0.8.1
103 | mdurl==0.1.2
104 | mistune==0.8.4
105 | more-itertools==10.2.0
106 | mpmath==1.3.0
107 | msgpack==1.0.8
108 | multidict==6.0.5
109 | multipledispatch==1.0.0
110 | multiscale_spatial_image==0.11.2
111 | natsort==8.4.0
112 | nbclassic==0.5.5
113 | nbclient==0.5.13
114 | nbconvert==6.5.4
115 | nbformat==5.9.2
116 | nest-asyncio==1.5.6
117 | networkx==3.2.1
118 | notebook==6.5.4
119 | notebook_shim==0.2.2
120 | numba==0.59.1
121 | numcodecs==0.12.1
122 | numpy==1.26.3
123 | ome-zarr==0.8.3
124 | omegaconf==2.3.0
125 | omnipath==1.0.8
126 | opencv-python==4.9.0.80
127 | optuna==3.6.1
128 | packaging==23.1
129 | pandas==2.2.2
130 | pandocfilters==1.5.0
131 | param==2.1.0
132 | parso==0.8.3
133 | partd==1.4.1
134 | patsy==0.5.6
135 | pickleshare==0.7.5
136 | pillow==10.2.0
137 | PIMS==0.6.1
138 | pip==23.2.1
139 | platformdirs==3.10.0
140 | POT==0.9.3
141 | prometheus-client==0.14.1
142 | prompt-toolkit==3.0.36
143 | protobuf==5.26.1
144 | psutil==5.9.0
145 | pure-eval==0.2.2
146 | pyarrow==15.0.2
147 | pybtex==0.24.0
148 | pycparser==2.21
149 | pyct==0.5.0
150 | pydantic==2.7.3
151 | pydantic_core==2.18.4
152 | pydot==2.0.0
153 | pygeos==0.14
154 | Pygments==2.15.1
155 | pynndescent==0.5.12
156 | pyparsing==3.1.2
157 | pyproj==3.6.1
158 | pyrsistent==0.18.0
159 | python-dateutil==2.8.2
160 | pytorch-lightning==2.2.2
161 | pytz==2024.1
162 | pywin32==305.1
163 | pywinpty==2.0.10
164 | PyYAML==6.0.1
165 | pyzmq==23.2.0
166 | ray==2.10.0
167 | referencing==0.34.0
168 | requests==2.31.0
169 | rich==13.7.1
170 | rpds-py==0.18.0
171 | rpy2==3.5.16
172 | s3fs==2023.6.0
173 | safetensors==0.4.3
174 | scanorama==1.7.4
175 | scanpy==1.10.1
176 | scib==1.1.5
177 | scikit-image==0.23.2
178 | scikit-learn==1.4.2
179 | scikit-misc==0.3.1
180 | scipy==1.13.0
181 | seaborn==0.13.2
182 | Send2Trash==1.8.0
183 | session-info==1.0.0
184 | setuptools==68.0.0
185 | shapely==2.0.4
186 | six==1.16.0
187 | slicerator==1.1.0
188 | sniffio==1.2.0
189 | sortedcontainers==2.4.0
190 | soupsieve==2.4
191 | spatial_image==0.3.0
192 | spatialdata==0.0.15
193 | SQLAlchemy==2.0.30
194 | squidpy==1.4.1
195 | stack-data==0.2.0
196 | statsmodels==0.14.2
197 | stdlib-list==0.10.0
198 | sympy==1.12
199 | taming-transformers==0.0.1
200 | tblib==3.0.0
201 | tensorboardX==2.6.2.2
202 | terminado==0.17.1
203 | texttable==1.7.0
204 | threadpoolctl==3.4.0
205 | tifffile==2024.5.3
206 | timm==0.9.16
207 | tinycss2==1.2.1
208 | toolz==0.12.1
209 | torch==2.2.2+cu118
210 | torch_cluster==1.6.3+pt22cu118
211 | torch-fidelity==0.3.0
212 | torch_geometric==2.5.2
213 | torch_scatter==2.1.2+pt22cu118
214 | torch_sparse==0.6.18+pt22cu118
215 | torch_spline_conv==1.2.2+pt22cu118
216 | torchaudio==2.2.2+cu118
217 | torchdata==0.7.1
218 | torchmetrics==1.3.2
219 | torchvision==0.17.2+cu118
220 | tornado==6.3.2
221 | tqdm==4.66.2
222 | traitlets==5.7.1
223 | typeguard==4.2.1
224 | typing_extensions==4.11.0
225 | tzdata==2024.1
226 | tzlocal==5.2
227 | umap-learn==0.5.6
228 | urllib3==1.26.18
229 | validators==0.28.1
230 | wcwidth==0.2.5
231 | webencodings==0.5.1
232 | websocket-client==0.58.0
233 | wheel==0.41.2
234 | wrapt==1.16.0
235 | xarray==2023.12.0
236 | xarray-dataclasses==1.7.0
237 | xarray-datatree==0.0.14
238 | xarray-schema==0.0.3
239 | xarray-spatial==0.4.0
240 | yarl==1.9.4
241 | zarr==2.17.2
242 | zict==3.0.0
243 | zipp==3.18.1
244 |
--------------------------------------------------------------------------------