├── eval ├── __init__.py └── eval_core_base.py ├── figs ├── main_fig.webp └── sample_result.webp ├── .gitignore ├── annotation ├── 151507_biofeature.npy ├── 151508_biofeature.npy ├── 151509_biofeature.npy ├── 151510_biofeature.npy ├── 151669_biofeature.npy ├── 151670_biofeature.npy ├── 151671_biofeature.npy ├── 151672_biofeature.npy ├── 151673_biofeature.npy ├── 151674_biofeature.npy ├── 151675_biofeature.npy ├── 151676_biofeature.npy └── dataset_setting.csv ├── utils ├── _make_error_label.py ├── __init__.py ├── _refine_label.py ├── _aligned_accuracy_score.py ├── _cluster_map.py ├── _stable_cluster.py ├── _cluster.py └── _targeted_cluster.py ├── aug └── aug.py ├── downstream_task_demo ├── deconv.py ├── trajectory.ipynb ├── exp_main_net_modalbais.py └── exp_main_net_feat.py ├── README.md ├── model.py ├── preprocess.py ├── environment.yml ├── MUST.py ├── main_MuST.py └── dataloader └── stdata.py /eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figs/main_fig.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/figs/main_fig.webp -------------------------------------------------------------------------------- /figs/sample_result.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/figs/sample_result.webp -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */__pycache__/* 3 | data/ 4 | save_near_index/ 5 | save_processed_data/ 6 | wandb/ 7 | result/ -------------------------------------------------------------------------------- /annotation/151507_biofeature.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151507_biofeature.npy -------------------------------------------------------------------------------- /annotation/151508_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151508_biofeature.npy -------------------------------------------------------------------------------- /annotation/151509_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151509_biofeature.npy -------------------------------------------------------------------------------- /annotation/151510_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151510_biofeature.npy -------------------------------------------------------------------------------- /annotation/151669_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151669_biofeature.npy -------------------------------------------------------------------------------- /annotation/151670_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151670_biofeature.npy -------------------------------------------------------------------------------- /annotation/151671_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151671_biofeature.npy -------------------------------------------------------------------------------- 
/annotation/151672_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151672_biofeature.npy -------------------------------------------------------------------------------- /annotation/151673_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151673_biofeature.npy -------------------------------------------------------------------------------- /annotation/151674_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151674_biofeature.npy -------------------------------------------------------------------------------- /annotation/151675_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151675_biofeature.npy -------------------------------------------------------------------------------- /annotation/151676_biofeature.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zangzelin/code_Must/HEAD/annotation/151676_biofeature.npy -------------------------------------------------------------------------------- /utils/_make_error_label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_error_label(y_true, y_pred, wildcard=None): 5 | mask = y_true == y_pred 6 | if wildcard is not None: 7 | mask |= (y_true == wildcard) 8 | dic = {False: 0, True: 1} 9 | 10 | return np.vectorize(dic.__getitem__)(mask) 11 | -------------------------------------------------------------------------------- /annotation/dataset_setting.csv: 
-------------------------------------------------------------------------------- 1 | dataset,suffix,class_num 2 | 151507,biofeature,7 3 | 151508,biofeature,7 4 | 151509,biofeature,7 5 | 151510,biofeature,7 6 | 151669,biofeature,5 7 | 151670,biofeature,5 8 | 151671,biofeature,5 9 | 151672,biofeature,5 10 | 151673,biofeature,7 11 | 151674,biofeature,7 12 | 151675,biofeature,7 13 | 151676,biofeature,7 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Cluster 2 | from ._cluster import cluster 3 | from ._cluster_map import cluster_map 4 | from ._targeted_cluster import targeted_cluster 5 | from ._stable_cluster import stable_cluster 6 | from ._refine_label import refine_label 7 | 8 | # Metrics 9 | from ._aligned_accuracy_score import aligned_accuracy_score 10 | 11 | from ._make_error_label import make_error_label -------------------------------------------------------------------------------- /utils/_refine_label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Union 4 | 5 | # For Py3.7- compatibility 6 | try: 7 | from typing import Literal 8 | except ImportError: 9 | from typing_extensions import Literal 10 | 11 | 12 | def refine_label(label: np.ndarray, 13 | method: Union[Literal["hexagon"], str] = "hexagon", 14 | random_state: int = 0, 15 | corrds: np.ndarray = None, 16 | radius: int = 30,) -> np.ndarray: 17 | 18 | if method == "hexagon": 19 | if corrds is None: 20 | raise ValueError('Hexagonal refinement supposed to be based on corrdinations.') 21 | 22 | from sklearn.neighbors import NearestNeighbors 23 | nbrs = NearestNeighbors(radius=radius).fit(corrds) 24 | _, index = nbrs.radius_neighbors(corrds) 25 | 26 | refined_label = [] 27 | for vec in index: 28 | cnt = np.bincount(label[vec]) 29 | refined_label.append(np.argmax(cnt)) 30 | 31 | 
refined_label = np.array(refined_label, dtype=int) 32 | 33 | return refined_label -------------------------------------------------------------------------------- /utils/_aligned_accuracy_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from munkres import Munkres 3 | from sklearn.metrics import accuracy_score, confusion_matrix 4 | from ._cluster_map import cluster_map 5 | 6 | 7 | # TODO: str support 8 | def aligned_accuracy_score(true: np.ndarray, pred: np.ndarray, wildcard: int = None) -> np.ndarray: 9 | """ 10 | Provide accuracy score for unsupervised label, like cluster analysis. 11 | 12 | Currently, this method only support integer labels. 13 | 14 | Parameters 15 | ---------- 16 | true 1-d int array 17 | True label. 18 | pred 1-d int array 19 | Prediction. 20 | 21 | Returns 22 | ------- 23 | Accuracy score for matched unsupervised label. 24 | 25 | """ 26 | mask = true != wildcard 27 | true = true[mask] 28 | pred = pred[mask] 29 | pred = cluster_map(true, pred, wildcard=999) 30 | 31 | if len(np.unique(true)) != len(np.unique(pred)): 32 | raise ValueError(f"True label has {len(np.unique(true))} classes " 33 | f"but Prediction has {len(np.unique(pred))} classes!") 34 | 35 | return accuracy_score(true, pred) 36 | -------------------------------------------------------------------------------- /utils/_cluster_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from munkres import Munkres 3 | from sklearn.metrics import confusion_matrix 4 | 5 | 6 | def cluster_map(true: np.ndarray, pred: np.ndarray, wildcard: int = None) -> np.ndarray: 7 | mask = true != wildcard 8 | 9 | if len(np.unique(true[mask])) != len(np.unique(pred[mask])): 10 | raise ValueError(f"True label has {len(np.unique(true[mask]))} classes " 11 | f"but Prediction has {len(np.unique(pred[mask]))} classes!") 12 | 13 | # temporary replace dis-continuous label 14 | enc = 
dict([(i, j) for i, j in zip(np.unique(true), np.arange(len(np.unique(true))))]) 15 | dec = dict([(j, i) for i, j in enc.items()]) 16 | if np.any(np.unique(true) != np.arange(len(np.unique(true)))): 17 | true = np.vectorize(enc.get)(true) 18 | if np.any(np.unique(pred) != np.arange(len(np.unique(pred)))): 19 | pred = np.vectorize(enc.get)(pred) 20 | 21 | cm = len(true[mask]) - confusion_matrix(pred[mask], true[mask]) 22 | idx = Munkres().compute(cm) 23 | idx = dict(idx) 24 | 25 | pred = np.vectorize(idx.get)(pred) 26 | pred = np.vectorize(dec.get)(pred) 27 | 28 | return pred 29 | -------------------------------------------------------------------------------- /utils/_stable_cluster.py: -------------------------------------------------------------------------------- 1 | from ._cluster import cluster 2 | import numpy as np 3 | 4 | def stable_cluster(latent: np.ndarray, 5 | method: str = "louvain", 6 | attempt_num: int = 3, 7 | target_metric: str = "silhouette", 8 | n_neighbors: int = 10, 9 | n_clusters: int = 10, 10 | resolution: float = 0.5, 11 | pca_dim: int = None, 12 | mclust_model: str = 'EEE') -> np.ndarray: 13 | """ 14 | Easy clustering for clarity. 15 | 16 | Parameters 17 | ---------- 18 | latent: 2-d array 19 | 2-d array to be clustered. 20 | method: "louvain", "leiden", "kmeans" or "mclust" 21 | method used to cluster. 22 | random_state: int 23 | Seed for reproducibility. 24 | n_neighbors: int 25 | Only in Louvain or Leiden. Number of neighborhood to be discovered. 26 | n_clusters: int 27 | Only in KMeans. Number of cluster to be clustered. 28 | resolution: float 29 | A parameter value controlling the coarseness of the clustering. 30 | Higher values lead to more clusters. 31 | pca_dim: int 32 | Target dimesion after PCA, None for no PCA. 33 | mclust_model: str 34 | Model mclusted to use. 35 | 36 | Returns 37 | ------- 38 | label: 1-d array 39 | Clustered label as integers. 
40 | 41 | """ 42 | 43 | latent_original = latent 44 | if pca_dim is not None: 45 | from sklearn.decomposition import PCA 46 | latent = PCA(n_components=pca_dim, random_state=0).fit_transform(latent) 47 | 48 | metrics = [] 49 | labels = [] 50 | for seed in range(attempt_num): 51 | label = cluster(latent=latent, 52 | method=method, 53 | random_state=seed, 54 | n_neighbors=n_neighbors, 55 | n_clusters=n_clusters, 56 | resolution=0.5, 57 | pca_dim=None, 58 | mclust_model='EEE') 59 | if target_metric == "silhouette": 60 | from sklearn.metrics import silhouette_score 61 | metrics.append(silhouette_score(latent_original, label)) 62 | 63 | labels.append(label) 64 | 65 | metrics = np.array(metrics) 66 | if target_metric == "silhouette": 67 | return labels[np.argmax(metrics)].astype(int) 68 | else: 69 | return None -------------------------------------------------------------------------------- /utils/_cluster.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import numpy as np 3 | 4 | from typing import Optional, Union 5 | 6 | # For Py3.7- compatibility 7 | try: 8 | from typing import Literal 9 | except ImportError: 10 | from typing_extensions import Literal 11 | 12 | 13 | def cluster(latent: np.ndarray, 14 | method: Union[Literal["leiden", "louvain", "kmeans", 'mclust'], str] = "louvain", 15 | random_state: int = 0, 16 | n_neighbors: int = 10, 17 | n_clusters: int = 10, 18 | resolution: float = 0.5, 19 | pca_dim: int = None, 20 | mclust_model: str = 'EEE') -> np.ndarray: 21 | """ 22 | Easy clustering for clarity. 23 | 24 | Parameters 25 | ---------- 26 | latent: 2-d array 27 | 2-d array to be clustered. 28 | method: "louvain", "leiden", "kmeans" or "mclust" 29 | method used to cluster. 30 | random_state: int 31 | Seed for reproducibility. 32 | n_neighbors: int 33 | Only in Louvain or Leiden. Number of neighborhood to be discovered. 34 | n_clusters: int 35 | Only in KMeans. Number of cluster to be clustered. 
36 | resolution: float 37 | A parameter value controlling the coarseness of the clustering. 38 | Higher values lead to more clusters. 39 | pca_dim: int 40 | Target dimesion after PCA, None for no PCA. 41 | mclust_model: str 42 | Model mclusted to use. 43 | 44 | Returns 45 | ------- 46 | label: 1-d array 47 | Clustered label as integers. 48 | 49 | """ 50 | 51 | if pca_dim is not None: 52 | from sklearn.decomposition import PCA 53 | latent = PCA(n_components=pca_dim, random_state=random_state).fit_transform(latent) 54 | 55 | if method in ["louvain", "leiden"]: 56 | import scanpy as sc 57 | 58 | data = anndata.AnnData(latent, dtype=np.float32) 59 | sc.pp.neighbors(data, n_neighbors=n_neighbors, use_rep='X', random_state=random_state) 60 | if method == "louvain": 61 | sc.tl.louvain(data, resolution=resolution, random_state=random_state) 62 | label = data.obs["louvain"].to_numpy() 63 | elif method == "leiden": 64 | sc.tl.leiden(data, resolution=resolution, random_state=random_state) 65 | label = data.obs["leiden"].to_numpy() 66 | elif method == "kmeans": 67 | from sklearn.cluster import KMeans 68 | label = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(latent) 69 | elif method == "mclust": 70 | from rpy2.robjects import r, numpy2ri 71 | 72 | r.library("mclust") 73 | numpy2ri.activate() 74 | r_random_seed = r['set.seed'] 75 | r_random_seed(random_state) 76 | 77 | rmclust = r['Mclust'] 78 | res = rmclust(numpy2ri.numpy2rpy(latent), n_clusters, mclust_model) 79 | 80 | label = np.asarray(res[-2], dtype=int) - 1 81 | else: 82 | raise ValueError(f"Invalid Method {method}!") 83 | 84 | return label.astype(int) -------------------------------------------------------------------------------- /utils/_targeted_cluster.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | from typing import Optional, Union 4 | 5 | # For Py3.7- compatibility 6 | try: 7 | from typing import 
Literal 8 | except ImportError: 9 | from typing_extensions import Literal 10 | 11 | from ._cluster import cluster 12 | 13 | def targeted_cluster(latent: np.ndarray, 14 | target_n_clusters: int = None, 15 | max_iter: int = 500, 16 | method: Union[Literal["leiden", "louvain"], str] = "louvain", 17 | random_state: int = 0, 18 | n_neighbors: int = 12, 19 | resolution: float = 0.55, ) -> np.ndarray: 20 | """ 21 | Targeted clustering for specified cluster numbers. 22 | 23 | Parameters 24 | ---------- 25 | latent: 2-d array 26 | 2-d array to be clustered. 27 | target_n_clusters int 28 | Optional target cluster numbers. 29 | max_iter int 30 | Maximum iteration counts for specified cluster numbers. 31 | method: "louvain" or "leiden" 32 | method used to cluster. 33 | n_neighbors: int 34 | Number of neighborhood to be discovered. 35 | resolution: float 36 | A parameter value controlling the coarseness of the clustering. 37 | Higher values lead to more clusters. 38 | random_state: int 39 | Seed for reproducibility. 40 | 41 | Returns 42 | ------- 43 | label: 1-d array 44 | Clustered label as integers. 
45 | 46 | """ 47 | label = None 48 | visited = {} 49 | tbar = trange(max_iter) 50 | delta_res = 0.02 51 | for run_idx in tbar: 52 | label = cluster(latent=latent, 53 | method=method, 54 | random_state=random_state, 55 | n_neighbors=int(n_neighbors), 56 | resolution=resolution) 57 | 58 | class_num = len(np.unique(label)) 59 | diff = class_num - target_n_clusters 60 | tbar.set_description(f"[{class_num:2d}:{target_n_clusters:2d}] res:{resolution:.4f} n_nbrs:{n_neighbors:2d}") 61 | 62 | if diff == 0: 63 | break 64 | elif run_idx == max_iter - 1: 65 | raise RuntimeError("Hit iteration limit!") 66 | 67 | direct = 1 if diff > 0 else -1 68 | 69 | if np.abs(diff) > 7: 70 | n_neighbors += direct * 2 if n_neighbors > 4 else 0 71 | resolution -= direct * 0.1 if resolution > 0.1 else 0 72 | elif np.abs(diff) > 3: 73 | n_neighbors += direct * 1 if n_neighbors > 4 else 0 74 | resolution -= direct * 0.05 if resolution > 0.05 else 0 75 | else: 76 | if delta_res < 0.0001: 77 | delta_res = 0.02 78 | n_neighbors = max(n_neighbors + direct * 2, 1) 79 | elif resolution < 0.0001: 80 | resolution = 0.2 81 | n_neighbors = max(n_neighbors + direct * 2, 1) 82 | elif resolution > delta_res: 83 | resolution -= direct * delta_res 84 | if visited.get((n_neighbors, resolution), None) is True: 85 | delta_res /= 2 86 | elif visited.get((n_neighbors, resolution), None) is True: 87 | n_neighbors += direct * 3 if n_neighbors > 3 else 0 88 | else: 89 | raise RuntimeError(f"Unable to cluster as {target_n_clusters} classes!") 90 | 91 | visited[(n_neighbors, resolution)] = True 92 | 93 | return label 94 | -------------------------------------------------------------------------------- /aug/aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import joblib 3 | import logging 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.metrics import pairwise_distances 8 | 9 | def aug_near_mix(index, dataset, neighbors_index, k=10, random_t=0.1, 
device="cuda", ): 10 | r = ( 11 | torch.arange(start=0, end=index.shape[0]) * k 12 | + torch.randint(low=1, high=k, size=(index.shape[0],)) 13 | ).to(device) 14 | random_select_near_index = ( 15 | neighbors_index[index][:, :k].reshape((-1,))[r].long() 16 | ) 17 | random_select_near_data2 = dataset.data[random_select_near_index] 18 | random_rate = torch.rand(size=(index.shape[0], 1)).to(device) * random_t 19 | data_cuda_index = dataset.data[index].to(device) 20 | random_select_near_data2 = random_select_near_data2.to(device) 21 | 22 | return ( 23 | random_rate * random_select_near_data2 + (1 - random_rate) * data_cuda_index 24 | ) 25 | 26 | 27 | def aug_near_feautee_change(index, dataset, neighbors_index, k=10, random_t=0.99, device="cuda"): 28 | r = torch.arange(start=0, end=index.shape[0], device=device) * k + torch.randint(low=1, high=k, size=(index.shape[0],), device=device) 29 | 30 | random_select_near_index = ( 31 | neighbors_index[index][:, :k].reshape((-1,))[r].long() 32 | ) 33 | random_select_near_data2 = dataset[random_select_near_index] 34 | data_origin = dataset[index] 35 | random_rate = torch.rand(size=(1, data_origin.shape[1]), device=device) 36 | random_mask = (random_rate > random_t).reshape(-1).float() 37 | return random_select_near_data2 * random_mask + data_origin * (1 - random_mask) 38 | 39 | 40 | def aug_randn(index, dataset, neighbors_index=None, k=10, random_t=0.01, device="cuda"): 41 | data_origin = dataset[index] 42 | return ( 43 | data_origin 44 | + torch.randn(data_origin.shape, device=data_origin.device) * torch.var(dataset, dim=0) * random_t 45 | ) 46 | 47 | def cal_near_index(data, label=None, k=10, device="cuda", uselabel=False, modal=None, graphwithpca=False, dataset="placeholder", unique_str=""): 48 | filename = f"save_near_index/pca{graphwithpca}dataset{dataset}K{k}uselabel{uselabel}modal{modal}n{data.shape[0]}w{data.shape[1]}{unique_str}" 49 | 50 | os.makedirs("save_near_index", exist_ok=True) 51 | if not os.path.exists(filename): 
52 | X_rshaped = ( 53 | data.reshape( 54 | (data.shape[0], -1)).detach().cpu().numpy() 55 | ) 56 | # if graphwithpca and X_rshaped.shape[1]>50: 57 | # X_rshaped = PCA(n_components=50).fit_transform(X_rshaped) 58 | if not uselabel: 59 | # index = NNDescent(X_rshaped, n_jobs=-1) 60 | dis = pairwise_distances(X_rshaped) 61 | neighbors_index = dis.argsort(axis=1)[:, 1:k + 1] 62 | # print('X_rshaped', X_rshaped) 63 | # print('neighbors_index', neighbors_index) 64 | # neighbors_index, neighbors_dist = index.query(X_rshaped, k=k+1) 65 | # neighbors_index = neighbors_index[:,1:] 66 | else: 67 | dis = pairwise_distances(X_rshaped) 68 | M = np.repeat(label.reshape(1, -1), X_rshaped.shape[0], axis=0) 69 | dis[(M - M.T) != 0] = dis.max() + 1 70 | neighbors_index = dis.argsort(axis=1)[:, 1:k + 1] 71 | joblib.dump(value=neighbors_index, filename=filename) 72 | 73 | logging.debug(f"save data to {filename}") 74 | else: 75 | logging.debug(f"load data from {filename}") 76 | neighbors_index = joblib.load(filename) 77 | 78 | return torch.tensor(neighbors_index).to(device) -------------------------------------------------------------------------------- /downstream_task_demo/deconv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import seaborn 4 | import argparse 5 | import numpy as np 6 | import scanpy as sc 7 | import matplotlib.pyplot as plt 8 | from sklearn.linear_model import Lasso 9 | 10 | from dataloader.stdata import STData 11 | 12 | def Analysis_mix(label, emb): 13 | cell_mean_value_list = [] 14 | for c in range(np.max(label)+1): 15 | cell_mean_value_list.append(np.mean(emb[label==c], axis=0)) 16 | 17 | cell_mean_value_numpy = np.array(cell_mean_value_list) 18 | dis = [] 19 | for i in range(emb.shape[0]): 20 | label_c = label[i] 21 | dis.append(np.sqrt( 22 | np.sum((cell_mean_value_numpy[label_c] - emb[i])**2) 23 | )) 24 | dis = np.array(dis) 25 | 26 | return dis, cell_mean_value_list 27 | 28 | 29 | if 
__name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--num_fea', type=int, default=300) 32 | parser.add_argument('--result_dir', type=str, default='results/') 33 | parser.add_argument('--save_dir', type=str, default='deconv/') 34 | parser.add_argument('--cluster_method', type=str, default='mclust') 35 | parser.add_argument('--dataset', type=str, default='200115_08') 36 | parser.add_argument('--min_cells', type=int, default=0) 37 | parser.add_argument('--alpha', type=float, default=0.1) 38 | args = parser.parse_args() 39 | 40 | args.save_dir += args.dataset + '/' 41 | os.makedirs(args.save_dir, exist_ok=True) 42 | if args.cluster_method == 'mclust': 43 | args.cluster_method = 'main' 44 | 45 | wandb.init( 46 | project="MuST_deconv", 47 | entity="liliangyu", 48 | name=f'Deconv_{args.dataset}', 49 | config=args.__dict__, 50 | save_code=True, 51 | ) 52 | 53 | visium = STData(name=args.dataset, bio_norm=False) 54 | adata = visium.adata 55 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 56 | 57 | data = np.load(f"{args.result_dir}{args.dataset}/trans_input.npy")[:,:args.num_fea] 58 | emb = np.load(f"{args.result_dir}{args.dataset}/emb.npy") 59 | label = np.load(f'{args.result_dir}{args.dataset}/pred_{args.cluster_method}.npy') 60 | hvg = np.load(f'{args.result_dir}{args.dataset}/hvg.npy') 61 | 62 | print('dataset loaded') 63 | 64 | feature_name = adata.var_names[hvg].tolist()[:args.num_fea] 65 | 66 | dis, cell_mean_value_list = Analysis_mix(label, emb) 67 | 68 | weight_list = [] 69 | for i in range(emb.shape[0]): 70 | A = np.array(cell_mean_value_list).T # (20, 72) 71 | B = emb[i] # (1, 72) 72 | model = Lasso(alpha=args.alpha) 73 | model.fit(A,B) 74 | weight = model.coef_ 75 | weight_list.append(weight) 76 | 77 | weight = np.array(weight_list).T 78 | weight[weight<1e-5] = 0 79 | w_sum = weight.sum(axis=0) 80 | n_weight = weight / w_sum 81 | np.save(args.save_dir+'val.npy', weight) 82 | np.save(args.save_dir+'n_val.npy', n_weight) 
83 | 84 | fig = plt.figure(figsize=(20,5), dpi=300) 85 | seaborn.heatmap(weight, cmap='Reds',) 86 | plt.tight_layout() 87 | 88 | fig_sp = visium.px_plot_spatial(label, background_image=True, save_path=args.save_dir + 'fig_sp.png') 89 | fig_sp.write_html(args.save_dir + 'fig_sp.html') 90 | for i in np.unique(label): 91 | visium.px_plot_spatial_gene(weight[i], save_path=args.save_dir + f'sp_{i}.png') 92 | visium.px_plot_spatial_gene(n_weight[i], save_path=args.save_dir + f'spn_{i}.png') 93 | 94 | fig = plt.figure(figsize=(20,5), dpi=300) 95 | plt.title(f'Cluster {i}') 96 | seaborn.heatmap(weight[:, label==i], cmap='Reds',) 97 | plt.tight_layout() 98 | plt.savefig(args.save_dir + f'hm_{i}.png', dpi=200) 99 | plt.close() 100 | 101 | fig = plt.figure(figsize=(20,5), dpi=300) 102 | plt.title(f'Cluster {i}') 103 | seaborn.heatmap(n_weight[:, label==i], cmap='Reds',) 104 | plt.tight_layout() 105 | plt.savefig(args.save_dir + f'hmn_{i}.png', dpi=200) 106 | plt.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Must: Maximizing Latent Capacity of Spatial Transcriptomics Data 2 | 3 | ![Main Figure](figs/main_fig.webp) 4 | 5 | Spatial transcriptomics (ST) technologies have revolutionized the study of gene expression patterns in tissues by providing multimodality data in transcriptomic, spatial, and morphological, offering opportunities for understanding tissue biology beyond transcriptomics. However, we identify the modality bias phenomenon in ST data species, i.e., the inconsistent contribution of different modalities to the labels leads to a tendency for the analysis methods to retain the information of the dominant modality. How to mitigate the adverse effects of modality bias to satisfy various downstream tasks remains a fundamental challenge. 
This paper introduces Multiple-modality Structure Transformation (MuST), a novel methodology that tackles this challenge. MuST effectively integrates the multi-modality information contained in ST data into a uniform latent space that serves as a foundation for all downstream tasks. It learns intrinsic local structures through a topology discovery strategy and a topology fusion loss function, which resolve the inconsistencies among different modalities. Together, these topology-based and deep learning techniques provide a solid foundation for a variety of analytical tasks while coordinating the different modalities. The effectiveness of MuST is assessed in terms of both performance metrics and biological significance. The results show that it outperforms existing state-of-the-art methods, with clear advantages in the precision of identifying and preserving tissue structures and biomarkers. MuST offers a versatile toolkit for the intricate analysis of complex biological systems. 6 | 7 | The full paper can be downloaded from [arXiv](https://arxiv.org/abs/2401.07543). 8 | 9 | ## Configuring the Python environment 10 | 11 | We recommend using conda for configuration. You can refer to our `environment.yml` to configure the environment, or run `conda env create` directly: 12 | 13 | ```bash 14 | conda env create -f environment.yml 15 | ``` 16 | 17 | ## Run MuST 18 | 19 | You can run MuST with a single line of code to obtain a spatial clustering result and its latent embedding. 20 | 21 | ### Minimum replication 22 | 23 | The minimal replication can be run with the following command: 24 | 25 | ```bash 26 | python main_MuST.py 27 | ``` 28 | 29 | We use the V1_Adult_Mouse_Brain dataset (from 10x Visium) for demonstration. On our 64-core A100 machine, it usually takes about 4 minutes. A sample result is shown below: 30 | 31 | ![Sample Result](figs/sample_result.webp) 32 | 33 | ### Multi-platform support 34 | 35 | For the 10x Visium platform, datasets are downloaded and handled automatically.
36 | For the Stereo-seq and Slide-seq platforms, see the `Data Description` section of our paper to obtain the data. 37 | 38 | We provide a universal convention for handling datasets from multiple platforms. Some datasets need to be processed into a single file, whose path should be `data/<dataset_name>/data.h5ad`. 39 | 40 | ```bash 41 | python main_MuST.py --dataset=<dataset_name> 42 | ``` 43 | 44 | ### Other Specifications 45 | 46 | - **More clustering methods**: Louvain and Leiden are available as alternative clustering methods; the corresponding arguments are `plot_leiden` and `plot_louvain`. 47 | 48 | ## Explore MuST 49 | 50 | We recommend using `wandb` to log and view results. The wandb result URL is printed on the command line once the run has finished. For deeper insight, the structure of the MuST result directory is shown below: 51 | 52 | ```bash 53 | result 54 | |-- V1_Adult_Mouse_Brain 55 | |-- emb.npy # embedding 56 | |-- emb_2d.npy # 2D embedding suggested by umap 57 | |-- hvg.npy # highly variable genes selection mask 58 | |-- mclust.png # spatial clustering result by MClust 59 | |-- mclust_emb.png # 2D embedding colored by spatial clustering result 60 | |-- pred_main.npy # MClust result 61 | |-- setting.txt 62 | ``` 63 | 64 | We provide downstream task demos in `downstream_task_demo`. 65 | 66 | ## Cite us 67 | 68 | ```bib 69 | @misc{zang2024must, 70 | title={Must: Maximizing Latent Capacity of Spatial Transcriptomics Data}, 71 | author={Zelin Zang and Liangyu Li and Yongjie Xu and Chenrui Duan and Kai Wang and Yang You and Yi Sun and Stan Z.
Li}, 72 | year={2024}, 73 | eprint={2401.07543}, 74 | archivePrefix={arXiv}, 75 | primaryClass={cs.CE} 76 | } 77 | ``` 78 | 79 | -------------------------------------------------------------------------------- /downstream_task_demo/trajectory.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "import matplotlib.pyplot as pl\n", 12 | "from matplotlib import rcParams\n", 13 | "import scanpy as sc\n", 14 | "\n", 15 | "from dataloader.stdata import STData\n", 16 | "\n", 17 | "# sc.settings.verbosity = 3\n", 18 | "sc.settings.set_figure_params(dpi=400, frameon=False, figsize=(5, 5), facecolor='white') # low dpi (dots per inch) yields small inline figures\n", 19 | "# dic = {0: \"WM\", 1: \"Layer 1\", 2: \"Layer 2\", 3: \"Layer 3\", 4: \"Layer 4\", 5: \"Layer 5\", 6: \"Layer 6\"}\n", 20 | "dic = {0: \"0\", 1: \"1\", 2: \"2\", 3: \"3\", 4: \"4\", 5: \"5\", 6: \"6\", 7: \"7\", 8: \"8\"}" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "for dataset in ['200127_15']:\n", 30 | " dataset = str(dataset)\n", 31 | " load_dir = f'results/{dataset}/'\n", 32 | " visium = STData(dataset)\n", 33 | " emb = np.load(load_dir+'emb.npy')\n", 34 | " emb_2d = np.load(load_dir + 'emb_2d.npy')\n", 35 | " pred = np.load(load_dir+'pred_leiden.npy')\n", 36 | "\n", 37 | " visium.adata.obs['pred'] = np.vectorize(dic.__getitem__)(pred)\n", 38 | " visium.adata.obs['pred'] = visium.adata.obs['pred'].astype('category')\n", 39 | " visium.adata.obsm['emb'] = emb\n", 40 | " visium.adata.obsm['umap'] = emb_2d\n", 41 | " visium.adata.uns['pred_colors'] = ['#2E91E5', '#E15F99', '#1CA71C', '#FB0D0D',\n", 42 | " '#DA16FF', '#222A2A', '#B68100', '#750D86',\n", 43 | " '#EB663B', '#511CFB', 
'#00A08B']\n", 44 | " sc.pp.neighbors(visium.adata, n_neighbors=2, use_rep='emb')\n", 45 | " sc.tl.paga(visium.adata, groups='pred')\n", 46 | " sc.pl.paga_compare(visium.adata, legend_fontsize=0, frameon=False, size=20,\n", 47 | " legend_fontoutline=0, show=False)\n", 48 | " pl.savefig(load_dir + 'traj.png')\n", 49 | " extent = pl.gca().get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())\n", 50 | " pl.gcf().savefig(load_dir+'traj_single.png', bbox_inches=extent)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# for dataset in [ 151507, 151508, 151509, 151510, 151669, 151670, 151671, 151672, 151673, 151674, 151675, 151676]:\n", 60 | "for dataset in [151673]:\n", 61 | " dataset = str(dataset)\n", 62 | " load_dir = f'test_dlpfc/{dataset}/'\n", 63 | " visium = STData(dataset)\n", 64 | " emb = np.load(load_dir+'emb.npy')\n", 65 | " emb_2d = np.load(load_dir + 'emb_2d.npy')\n", 66 | " pred = np.load(load_dir+'pred_main.npy')\n", 67 | "\n", 68 | " visium.adata.obs['pred'] = np.vectorize(dic.__getitem__)(pred)\n", 69 | " visium.adata.obs['true'] = np.vectorize(dic.__getitem__)(visium.get_label())\n", 70 | " visium.adata.obs['true'] = visium.adata.obs['true'].astype('category')\n", 71 | " visium.adata.obs['pred'] = visium.adata.obs['pred'].astype('category')\n", 72 | " visium.adata.obsm['emb'] = emb\n", 73 | " visium.adata.obsm['umap'] = emb_2d\n", 74 | " visium.adata.uns['true_colors'] = ['#2E91E5', '#E15F99', '#1CA71C', '#FB0D0D', '#DA16FF', '#222A2A', '#B68100']\n", 75 | " sc.pp.neighbors(visium.adata, n_neighbors=10, use_rep='emb')\n", 76 | " sc.tl.paga(visium.adata, groups='true')\n", 77 | " sc.pl.paga_compare(visium.adata, legend_fontsize=0, frameon=False, size=20,\n", 78 | " legend_fontoutline=0, show=False, threshold=0.3)\n", 79 | " pl.savefig(load_dir + 'traj.png')\n", 80 | " extent = 
pl.gca().get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())\n", 81 | " pl.gcf().savefig(load_dir+'traj_single.png', bbox_inches=extent)\n", 82 | " _ = visium.px_plot_embedding(emb_2d, visium.get_label(), save_path=load_dir+'true_emb.png')" 83 | ] 84 | } 85 | ], 86 | "metadata": { 87 | "kernelspec": { 88 | "display_name": "mmeg", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.8.15" 103 | }, 104 | "orig_nbformat": 4 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 2 108 | } 109 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | from torch.nn.modules.module import Module 6 | 7 | 8 | class Encoder_MUST(Module): 9 | def __init__(self, in_features_a, in_features_b, out_features, graph_neigh, n_encoder_layer=1, dropout=0.0, act=F.relu, morph_trans_ratio=0.5, platform='10x', 10 | bn_type='bn', n_fusion_layer=1): 11 | super(Encoder_MUST, self).__init__() 12 | self.in_features_a = in_features_a 13 | self.in_features_b = in_features_b 14 | self.out_features = out_features 15 | self.graph_neigh = graph_neigh 16 | self.dropout = dropout 17 | self.act = act 18 | self.morph_trans_ratio = morph_trans_ratio 19 | self.platform=platform 20 | 21 | self.encoder_a = nn.Sequential() 22 | for i in range(n_encoder_layer): 23 | if i == 0: 24 | self.encoder_a.append(nn.Linear(self.out_features, self.in_features_a)) 25 | else: 26 | self.encoder_a.append(nn.Linear(self.out_features, self.out_features)) 27 | if self.in_features_b is not 
None: 28 | self.encoder_b = nn.Sequential() 29 | for i in range(n_encoder_layer): 30 | if i == 0: 31 | self.encoder_b.append(nn.Linear(self.out_features, self.in_features_b)) 32 | else: 33 | self.encoder_b.append(nn.Linear(self.out_features, self.out_features)) 34 | self.mlp_out = Parameter(torch.FloatTensor(self.out_features, self.in_features_a)) 35 | self.reset_parameters() 36 | 37 | if bn_type == 'bn': 38 | self.batch_norm = nn.BatchNorm1d(out_features) 39 | elif bn_type == 'none': 40 | self.batch_norm = nn.Identity(out_features) 41 | 42 | if n_fusion_layer == 1: 43 | self.mlp = nn.Linear(self.out_features, self.out_features,) 44 | else: 45 | self.mlp = nn.Sequential() 46 | for i in range(n_fusion_layer): 47 | self.mlp.append(nn.Linear(self.out_features, self.out_features)) 48 | 49 | self.sigm = nn.Sigmoid() 50 | print(self.encoder_a) 51 | print(self.mlp) 52 | print(self.batch_norm) 53 | 54 | def reset_parameters(self): 55 | for weight in self.encoder_a: 56 | torch.nn.init.xavier_uniform_(weight.weight) 57 | if self.in_features_b is not None: 58 | for weight in self.encoder_b: 59 | torch.nn.init.constant_(weight.weight, 0) 60 | torch.nn.init.xavier_uniform_(self.mlp_out) 61 | 62 | def head_fwd(self, encoder, data, adj): 63 | for i, weight in enumerate(encoder): 64 | if i == 0: 65 | z = F.dropout(data, self.dropout, self.training) 66 | z = torch.mm(z, weight.weight) 67 | z = torch.mm(adj, z) 68 | else: 69 | z = F.dropout(z, self.dropout, self.training) 70 | z = torch.mm(z, weight.weight) 71 | z = torch.mm(adj, z) 72 | 73 | return z 74 | 75 | def forward(self, feat_a, feat_b, adj): 76 | z1 = self.head_fwd(self.encoder_a, feat_a, adj) 77 | z2 = None 78 | if feat_b is not None: 79 | z2 = self.head_fwd(self.encoder_b, feat_b, adj) 80 | 81 | # hiden_emb = torch.concat([z1, z2], axis=1) 82 | if feat_b is not None: 83 | # import numpy as np 84 | # np.save('trans_emb.npy', z1.detach().cpu().numpy()) 85 | # np.save('morph_emb.npy', z2.detach().cpu().numpy()) 86 | # import 
pdb;pdb.set_trace() 87 | hiden_emb = z1 * self.morph_trans_ratio + z2 * (1 - self.morph_trans_ratio) # z1 trans z2 morph 88 | else: 89 | hiden_emb = z1 90 | hiden_emb2 = self.mlp(hiden_emb) 91 | hiden_emb2 = self.batch_norm(hiden_emb2) 92 | 93 | h = torch.mm(hiden_emb2, self.mlp_out) 94 | h = torch.sparse.mm(adj, h) 95 | 96 | return [hiden_emb, hiden_emb2], h, z1, z2 97 | 98 | def save(self, save_dir=''): 99 | torch.save(self.encoder_a, save_dir + 'encoder_a.pt') 100 | torch.save(self.mlp, save_dir + 'mlp.pt') 101 | torch.save(self.mlp_out, save_dir + 'mlp_out.pt') 102 | if self.in_features_b is not None: 103 | torch.save(self.encoder_b, save_dir + 'encoder_b.pt') 104 | 105 | def load(self, load_dir=''): 106 | self.encoder_a = torch.load(load_dir + 'encoder_a.pt') 107 | self.mlp = torch.load(load_dir + 'mlp.pt') 108 | self.mlp_out = torch.load(load_dir + 'mlp_out.pt') 109 | if self.platform == '10x': 110 | self.encoder_b = torch.load(load_dir + 'encoder_b.pt') 111 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import scanpy as sc 6 | import scipy.sparse as sp 7 | from torch.backends import cudnn 8 | from scipy.sparse.csc import csc_matrix 9 | from scipy.sparse.csr import csr_matrix 10 | from sklearn.neighbors import NearestNeighbors 11 | 12 | def permutation(feature): 13 | # fix_seed(FLAGS.random_seed) 14 | ids = np.arange(feature.shape[0]) 15 | ids = np.random.permutation(ids) 16 | feature_permutated = feature[ids] 17 | 18 | return feature_permutated 19 | 20 | def construct_interaction(adata, n_neighbors=3): 21 | import ot 22 | """Constructing spot-to-spot interactive graph""" 23 | position = adata.obsm['spatial'] 24 | 25 | # calculate distance matrix 26 | distance_matrix = ot.dist(position, position, metric='euclidean') 27 | n_spot = distance_matrix.shape[0] 
28 | 29 | adata.obsm['distance_matrix'] = distance_matrix 30 | 31 | # find k-nearest neighbors 32 | interaction = np.zeros([n_spot, n_spot]) 33 | for i in range(n_spot): 34 | vec = distance_matrix[i, :] 35 | distance = vec.argsort() 36 | for t in range(1, n_neighbors + 1): 37 | y = distance[t] 38 | interaction[i, y] = 1 39 | 40 | adata.obsm['graph_neigh'] = interaction 41 | 42 | #transform adj to symmetrical adj 43 | adj = interaction 44 | adj = adj + adj.T 45 | adj = np.where(adj>1, 1, adj) 46 | 47 | adata.obsm['adj'] = adj 48 | 49 | def construct_interaction_KNN(adata, n_neighbors=3): 50 | position = adata.obsm['spatial'] 51 | n_spot = position.shape[0] 52 | nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(position) 53 | _ , indices = nbrs.kneighbors(position) 54 | x = indices[:, 0].repeat(n_neighbors) 55 | y = indices[:, 1:].flatten() 56 | interaction = np.zeros([n_spot, n_spot]) 57 | interaction[x, y] = 1 58 | 59 | adata.obsm['graph_neigh'] = interaction 60 | 61 | #transform adj to symmetrical adj 62 | adj = interaction 63 | adj = adj + adj.T 64 | adj = np.where(adj>1, 1, adj) 65 | 66 | adata.obsm['adj'] = adj 67 | print('Graph constructed!') 68 | 69 | def preprocess(adata, min_cells=None, n_top_genes=3000, max_value=10, dataset=None): 70 | print("preprocessing " + dataset) 71 | 72 | if min_cells is not None: 73 | sc.pp.filter_genes(adata, min_cells=min_cells) 74 | print(f'filter mincells={min_cells}, shape {adata.shape}') 75 | sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=n_top_genes) 76 | sc.pp.normalize_total(adata, target_sum=1e4) 77 | sc.pp.log1p(adata) 78 | sc.pp.scale(adata, zero_center=False, max_value=max_value) 79 | 80 | return adata 81 | 82 | def get_feature(adata, deconvolution=False): 83 | if deconvolution: 84 | adata_Vars = adata 85 | else: 86 | adata_Vars = adata[:, adata.var['highly_variable']] 87 | 88 | if isinstance(adata_Vars.X, csc_matrix) or isinstance(adata_Vars.X, csr_matrix): 89 | feat = 
adata_Vars.X.toarray()[:, ] 90 | else: 91 | feat = adata_Vars.X[:, ] 92 | 93 | # data augmentation 94 | feat_a = permutation(feat) 95 | 96 | adata.obsm['feat'] = feat 97 | adata.obsm['feat_a'] = feat_a 98 | 99 | def normalize_adj(adj): 100 | """Symmetrically normalize adjacency matrix.""" 101 | adj = sp.coo_matrix(adj) 102 | rowsum = np.array(adj.sum(1)) 103 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 104 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 105 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 106 | adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt) 107 | return adj.toarray() 108 | 109 | def preprocess_adj(adj): 110 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" 111 | adj_normalized = normalize_adj(adj)+np.eye(adj.shape[0]) 112 | return adj_normalized 113 | 114 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 115 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 116 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 117 | indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 118 | values = torch.from_numpy(sparse_mx.data) 119 | shape = torch.Size(sparse_mx.shape) 120 | return torch.sparse.FloatTensor(indices, values, shape) 121 | 122 | def preprocess_adj_sparse(adj): 123 | adj = sp.coo_matrix(adj) 124 | adj_ = adj + sp.eye(adj.shape[0]) 125 | rowsum = np.array(adj_.sum(1)) 126 | degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten()) 127 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo() 128 | return sparse_mx_to_torch_sparse_tensor(adj_normalized) 129 | 130 | def fix_seed(seed): 131 | os.environ['PYTHONHASHSEED'] = str(seed) 132 | random.seed(seed) 133 | np.random.seed(seed) 134 | torch.manual_seed(seed) 135 | torch.cuda.manual_seed(seed) 136 | torch.cuda.manual_seed_all(seed) 137 | cudnn.deterministic = True 138 | cudnn.benchmark = False 139 | 140 | os.environ['PYTHONHASHSEED'] = str(seed) 141 | 
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 142 | 143 | 144 | -------------------------------------------------------------------------------- /downstream_task_demo/exp_main_net_modalbais.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | 5 | import numpy as np 6 | import wandb 7 | 8 | from dataloader.stdata import STData 9 | from sklearn.svm import SVC 10 | from sklearn.svm import SVR 11 | from sklearn.multioutput import MultiOutputRegressor 12 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 13 | import shap 14 | import xgboost as xgb 15 | from xgboost import XGBRegressor 16 | from xgboost import XGBClassifier 17 | import seaborn as sns 18 | import matplotlib.pyplot as plt 19 | 20 | def svc_train(data, emb, label): 21 | print('Training XGBClassifier') 22 | clf_svc = XGBClassifier( 23 | objective='multi:softprob', 24 | num_class=np.max(label)+1, 25 | eval_metric='mlogloss', 26 | nthread=55, 27 | tree_method="hist", 28 | ) 29 | 30 | clf_svc.fit(data, label) 31 | return clf_svc 32 | 33 | 34 | 35 | if __name__ == '__main__': 36 | import argparse 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--R_HOME', type=str, default='/root/miniconda3/lib/R') 39 | parser.add_argument("--wandb", type=str, default="online") 40 | 41 | # Datasets 42 | parser.add_argument('--dataset', type=str, default='151673') 43 | parser.add_argument('--sample', type=str, default='barcode') 44 | parser.add_argument('--n_top_genes', type=int, default=3000) 45 | parser.add_argument('--max_value', type=int, default=10) 46 | parser.add_argument('--crop_size', type=int, default=224) 47 | parser.add_argument('--preprocessed', type=int, default=0) 48 | parser.add_argument('--min_cells', type=int, default=50) 49 | 50 | # Augmentation 51 | parser.add_argument('--graphwithpca', type=bool, default=True) 52 | parser.add_argument('--uselabel', type=bool, default=False) 53 | 
parser.add_argument('--K_m0', type=int, default=50) 54 | parser.add_argument('--K_m1', type=int, default=50) 55 | 56 | # Cluster 57 | parser.add_argument('--cluster_using', type=str, default='gene_rec') # gene_rec 58 | parser.add_argument('--n_clusters', type=int, default=15) 59 | parser.add_argument('--radius', type=int, default=50) 60 | parser.add_argument('--cluster_refinement', type=int, default=0) 61 | 62 | # Model 63 | parser.add_argument('--learning_rate', type=float, default=0.001) 64 | parser.add_argument('--weight_decay', type=float, default=0.00) 65 | parser.add_argument('--epochs', type=int, default=1000) 66 | parser.add_argument('--dim_input', type=int, default=3000) 67 | parser.add_argument('--dim_output', type=int, default=64) 68 | parser.add_argument('--alpha', type=float, default=0.001) 69 | parser.add_argument('--beta', type=float, default=1) 70 | parser.add_argument('--aug_rate_0', type=float, default=0.1) 71 | parser.add_argument('--aug_rate_1', type=float, default=0.1) 72 | parser.add_argument('--v_latent', type=float, default=0.01) 73 | parser.add_argument('--theta', type=float, default=0.1) 74 | parser.add_argument('--random_seed', type=int, default=1) 75 | parser.add_argument('--n_encoder_layer', type=int, default=1, help='number of encoder layers') 76 | parser.add_argument('--n_fusion_layer', type=int, default=1) 77 | parser.add_argument('--bn_type', type=str, default='bn') 78 | parser.add_argument('--self_loop', type=int, default=0) 79 | parser.add_argument('--down_sample_rate', type=float, default=1) 80 | parser.add_argument('--morph_trans_ratio', type=float, default=1) 81 | parser.add_argument('--aug_method', type=str, default="near_mix") 82 | parser.add_argument('--run_dir', type=str, default=os.getenv('WANDB_RUN_DIR')) 83 | parser.add_argument('--model_dir', type=str, default='model/') 84 | parser.add_argument('--save_dir', type=str, default=None) 85 | parser.add_argument('--plot', type=int, default=0, help='Plot the result') 86 | 
parser.add_argument('--var_plot', type=int, default=0) 87 | parser.add_argument('--plot_louvain', type=int, default=1) 88 | parser.add_argument('--plot_leiden', type=int, default=0) 89 | 90 | parser.add_argument('--norm', type=str, default="none") 91 | parser.add_argument('--num_fea', type=int, default=3000) 92 | parser.add_argument('--num_sample', type=int, default=1500) 93 | parser.add_argument('--aim_label', type=int, default=1) 94 | parser.add_argument('--max_evals', type=int, default=5500) 95 | 96 | args = parser.parse_args() 97 | 98 | args.plot = bool(args.plot) 99 | args.var_plot = bool(args.var_plot) 100 | args.plot_louvain = bool(args.plot_louvain) 101 | args.plot_leiden = bool(args.plot_leiden) 102 | args.cluster_refinement = bool(args.cluster_refinement) 103 | args.preprocessed = bool(args.preprocessed) 104 | 105 | if args.run_dir is not None and not os.path.exists(args.run_dir): 106 | os.mkdir(args.run_dir) 107 | if args.save_dir is not None: 108 | args.save_dir += f"{args.dataset}/" 109 | os.makedirs(args.save_dir, exist_ok=True) 110 | 111 | wandb.init( 112 | project="MuST_modalbias", 113 | name=f'dataset{args.dataset}_num_fea{args.num_fea}_num_sample{args.num_sample}_aim_label{args.aim_label}_max_evals{args.max_evals}_sample{args.sample}', 114 | config=args, 115 | ) 116 | 117 | visium = STData(name=args.dataset, crop_size=args.crop_size, bio_norm=False, sample=args.sample) # Reset sample to get better results. 
118 | 119 | adata = visium.adata 120 | adata.uns["name"] = args.dataset 121 | 122 | 123 | emb = np.load(f'results/{args.dataset}/emb.npy') 124 | label = np.load(f'results/{args.dataset}/pred_main.npy') 125 | trans_dict = {v:i for i, v in enumerate(np.unique(label))} 126 | label = np.vectorize(trans_dict.__getitem__)(label) 127 | trans = np.load(f'results/{args.dataset}/trans_input.npy') 128 | hvg = np.load(f'results/{args.dataset}/hvg.npy') 129 | adata = adata[:, hvg] 130 | morph = visium.get_morph() 131 | loc = visium.get_coords() 132 | loc = np.concatenate([loc, (loc[:, 0] + loc[:, 1]).reshape(-1, 1), (loc[:, 0] - loc[:, 1]).reshape(-1, 1)], axis=1) 133 | if args.norm == 'std': 134 | loc = StandardScaler().fit_transform(loc) 135 | elif args.norm == 'minmax': 136 | loc = MinMaxScaler().fit_transform(loc) 137 | data = np.concatenate([trans, morph, loc], axis=1) 138 | 139 | feature_name = adata.var_names.tolist() 140 | 141 | print('explain') 142 | clf_svc = svc_train(data, emb, label) 143 | explainer = shap.Explainer(clf_svc, data, feature_names=feature_name[:args.num_fea]) 144 | shap_values = explainer(data) 145 | 146 | shap_values = np.abs(shap_values.values) 147 | 148 | imp_mor = shap_values[:, trans.shape[1] + morph.shape[1]:].max((1,2)) 149 | imp_gen = shap_values[:, :trans.shape[1] + morph.shape[1]].max((1,2)) 150 | 151 | mor_rate = imp_mor / (imp_mor + imp_gen) 152 | save_dir = 'res_multimodal/' 153 | os.makedirs(save_dir, exist_ok=True) 154 | np.save(save_dir + f'mor_rate_{args.dataset}_{args.norm}', mor_rate) 155 | fig = visium.px_plot_spatial_gene(mor_rate, background_image=False, dpi=400, save_path=save_dir + f'rate_{args.dataset}_{args.norm}.png') 156 | fig = visium.px_plot_spatial_gene(mor_rate, background_image=True, dpi=400, save_path=save_dir + f'rate_{args.dataset}_{args.norm}_bg.png') 157 | colors = ['#2E91E5', '#E15F99', '#1CA71C'] 158 | plt.figure(figsize=(3, 4), dpi=400) 159 | sns.violinplot(mor_rate, color=colors[0]) 160 | plt.xticks([]) 161 | 
plt.savefig(save_dir + f'violin_{args.dataset}_{args.norm}.png') 162 | wandb.log({'rate': fig}) 163 | 164 | wandb.finish() -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: MuST 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - _r-mutex=1.0.1=anacondar_1 11 | - binutils_impl_linux-64=2.36.1=h193b22a_2 12 | - binutils_linux-64=2.36=hf3e587d_33 13 | - blas=1.0=mkl 14 | - boltons=23.0.0=pyhd8ed1ab_0 15 | - brotlipy=0.7.0=py310h7f8727e_1002 16 | - bwidget=1.9.14=ha770c72_1 17 | - bzip2=1.0.8=h7b6447c_0 18 | - c-ares=1.18.1=h7f98852_0 19 | - ca-certificates=2023.7.22=hbcca054_0 20 | - cairo=1.16.0=hf32fb01_1 21 | - certifi=2023.7.22=pyhd8ed1ab_0 22 | - cffi=1.15.1=py310h5eee18b_3 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - conda=23.7.3=py310hff52083_0 25 | - conda-content-trust=0.1.3=py310h06a4308_0 26 | - conda-package-handling=1.9.0=py310h5eee18b_1 27 | - cryptography=38.0.1=py310h9ce1e76_0 28 | - cuda=11.6.1=0 29 | - cuda-cccl=11.6.55=hf6102b2_0 30 | - cuda-command-line-tools=11.6.2=0 31 | - cuda-compiler=11.6.2=0 32 | - cuda-cudart=11.6.55=he381448_0 33 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 34 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 35 | - cuda-cupti=11.6.124=h86345e5_0 36 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 37 | - cuda-driver-dev=11.6.55=0 38 | - cuda-gdb=12.0.140=0 39 | - cuda-libraries=11.6.1=0 40 | - cuda-libraries-dev=11.6.1=0 41 | - cuda-memcheck=11.8.86=0 42 | - cuda-nsight=12.0.140=0 43 | - cuda-nsight-compute=12.0.1=0 44 | - cuda-nvcc=11.6.124=hbba6d2d_0 45 | - cuda-nvdisasm=12.0.140=0 46 | - cuda-nvml-dev=11.6.55=haa9ef22_0 47 | - cuda-nvprof=12.0.146=0 48 | - cuda-nvprune=11.6.124=he22ec0a_0 49 | - cuda-nvrtc=11.6.124=h020bade_0 50 | - cuda-nvrtc-dev=11.6.124=h249d397_0 51 | - 
cuda-nvtx=11.6.124=h0630a44_0 52 | - cuda-nvvp=12.0.146=0 53 | - cuda-runtime=11.6.1=0 54 | - cuda-samples=11.6.101=h8efea70_0 55 | - cuda-sanitizer-api=12.0.140=0 56 | - cuda-toolkit=11.6.1=0 57 | - cuda-tools=11.6.1=0 58 | - cuda-visual-tools=11.6.1=0 59 | - curl=7.76.1=h979ede3_1 60 | - ffmpeg=4.3=hf484d3e_0 61 | - flit-core=3.6.0=pyhd3eb1b0_0 62 | - fontconfig=2.14.1=hef1e5e3_0 63 | - freetype=2.12.1=h4a9f257_0 64 | - fribidi=1.0.10=h36c2ea0_0 65 | - gcc_impl_linux-64=7.5.0=hda68d29_13 66 | - gcc_linux-64=7.5.0=h47867f9_33 67 | - gds-tools=1.5.1.14=0 68 | - gettext=0.19.8.1=h73d1719_1008 69 | - gfortran_impl_linux-64=7.5.0=h56cb351_20 70 | - gfortran_linux-64=7.5.0=h78c8a43_33 71 | - giflib=5.2.1=h5eee18b_1 72 | - glib=2.68.4=h9c3ff4c_1 73 | - glib-tools=2.68.4=h9c3ff4c_1 74 | - gmp=6.2.1=h295c915_3 75 | - gnutls=3.6.15=he1e5248_0 76 | - graphite2=1.3.13=h58526e2_1001 77 | - gsl=2.4=h14c3975_4 78 | - gxx_impl_linux-64=7.5.0=h64c220c_13 79 | - gxx_linux-64=7.5.0=h555fc39_33 80 | - harfbuzz=2.4.0=h37c48d4_1 81 | - icu=58.2=hf484d3e_1000 82 | - idna=3.4=py310h06a4308_0 83 | - intel-openmp=2021.4.0=h06a4308_3561 84 | - jpeg=9e=h7f8727e_0 85 | - jsonpatch=1.32=pyhd8ed1ab_0 86 | - jsonpointer=2.0=py_0 87 | - kernel-headers_linux-64=2.6.32=he073ed8_16 88 | - krb5=1.17.2=h926e7f8_0 89 | - lame=3.100=h7b6447c_0 90 | - lcms2=2.12=h3be6417_0 91 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 92 | - lerc=3.0=h295c915_0 93 | - libcublas=11.9.2.110=h5e84587_0 94 | - libcublas-dev=11.9.2.110=h5c901ab_0 95 | - libcufft=10.7.1.112=hf425ae0_0 96 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 97 | - libcufile=1.5.1.14=0 98 | - libcufile-dev=1.5.1.14=0 99 | - libcurand=10.3.1.124=0 100 | - libcurand-dev=10.3.1.124=0 101 | - libcurl=7.76.1=hc4aaa36_1 102 | - libcusolver=11.3.4.124=h33c3c4e_0 103 | - libcusparse=11.7.2.124=h7538f96_0 104 | - libcusparse-dev=11.7.2.124=hbbe9722_0 105 | - libdeflate=1.8=h7f8727e_5 106 | - libedit=3.1.20191231=he28a2e2_2 107 | - libev=4.33=h516909a_1 108 | - 
libffi=3.4.2=h6a678d5_6 109 | - libgcc-ng=11.2.0=h1234567_1 110 | - libgfortran-ng=7.5.0=h14aa051_20 111 | - libgfortran4=7.5.0=h14aa051_20 112 | - libglib=2.68.4=h174f98d_1 113 | - libgomp=11.2.0=h1234567_1 114 | - libiconv=1.16=h7f8727e_2 115 | - libidn2=2.3.2=h7f8727e_0 116 | - libnghttp2=1.43.0=h812cca2_1 117 | - libnpp=11.6.3.124=hd2722f0_0 118 | - libnpp-dev=11.6.3.124=h3c42840_0 119 | - libnvjpeg=11.6.2.124=hd473ad6_0 120 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 121 | - libpng=1.6.37=hbc83047_0 122 | - libssh2=1.10.0=ha56f1ee_2 123 | - libstdcxx-ng=11.2.0=h1234567_1 124 | - libtasn1=4.16.0=h27cfd23_0 125 | - libtiff=4.5.0=hecacb30_0 126 | - libunistring=0.9.10=h27cfd23_0 127 | - libuuid=1.41.5=h5eee18b_0 128 | - libwebp=1.2.4=h11a3e52_0 129 | - libwebp-base=1.2.4=h5eee18b_0 130 | - libxcb=1.15=h7f8727e_0 131 | - libxml2=2.9.14=h74e7548_0 132 | - lz4-c=1.9.4=h6a678d5_0 133 | - make=4.3=hd18ef5c_1 134 | - mkl=2021.4.0=h06a4308_640 135 | - mkl-service=2.4.0=py310h7f8727e_0 136 | - mkl_fft=1.3.1=py310hd6ae3a3_0 137 | - mkl_random=1.2.2=py310h00e6091_0 138 | - ncurses=6.3=h5eee18b_3 139 | - nettle=3.7.3=hbbd107a_1 140 | - nsight-compute=2022.4.1.6=0 141 | - openh264=2.1.1=h4ff587b_0 142 | - openssl=1.1.1v=h7f8727e_0 143 | - pango=1.42.4=h7062337_4 144 | - pcre=8.45=h9c3ff4c_0 145 | - pillow=9.3.0=py310hace64e9_1 146 | - pip=22.3.1=py310h06a4308_0 147 | - pixman=0.40.0=h36c2ea0_0 148 | - pluggy=1.0.0=py310h06a4308_1 149 | - pycosat=0.6.4=py310h5eee18b_0 150 | - pycparser=2.21=pyhd3eb1b0_0 151 | - pyopenssl=22.0.0=pyhd3eb1b0_0 152 | - pysocks=1.7.1=py310h06a4308_0 153 | - python=3.10.8=h7a1cb2a_1 154 | - python_abi=3.10=2_cp310 155 | - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0 156 | - pytorch-cuda=11.6=h867d48c_1 157 | - pytorch-mutex=1.0=cuda 158 | - r-base=3.6.1=haffb61f_2 159 | - readline=8.2=h5eee18b_0 160 | - requests=2.28.1=py310h06a4308_0 161 | - ruamel.yaml=0.17.21=py310h5eee18b_0 162 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 163 | - 
setuptools=65.5.0=py310h06a4308_0 164 | - six=1.16.0=pyhd3eb1b0_1 165 | - sqlite=3.40.0=h5082296_0 166 | - sysroot_linux-64=2.12=he073ed8_16 167 | - tk=8.6.12=h1ccaba5_0 168 | - tktable=2.10=hb7b940f_3 169 | - toolz=0.12.0=py310h06a4308_0 170 | - torchaudio=0.13.1=py310_cu116 171 | - torchvision=0.14.1=py310_cu116 172 | - tqdm=4.64.1=py310h06a4308_0 173 | - typing_extensions=4.4.0=py310h06a4308_0 174 | - tzdata=2022g=h04d1e81_0 175 | - urllib3=1.26.13=py310h06a4308_0 176 | - wheel=0.37.1=pyhd3eb1b0_0 177 | - xz=5.2.8=h5eee18b_0 178 | - zlib=1.2.13=h5eee18b_0 179 | - zstd=1.5.2=ha4553b6_0 180 | - pip: 181 | - aiohttp==3.8.3 182 | - aiosignal==1.3.1 183 | - anndata==0.8.0 184 | - appdirs==1.4.4 185 | - async-timeout==4.0.2 186 | - attrs==22.2.0 187 | - click==8.1.3 188 | - contourpy==1.0.7 189 | - cycler==0.11.0 190 | - docker-pycreds==0.4.0 191 | - exceptiongroup==1.1.3 192 | - fonttools==4.38.0 193 | - frozenlist==1.3.3 194 | - fsspec==2023.1.0 195 | - gitdb==4.0.10 196 | - gitpython==3.1.30 197 | - h5py==3.9.0 198 | - igraph==0.10.4 199 | - imageio==2.25.0 200 | - iniconfig==2.0.0 201 | - jinja2==3.1.2 202 | - joblib==1.2.0 203 | - kiwisolver==1.4.4 204 | - leidenalg==0.10.1 205 | - lightning-lite==1.8.6 206 | - lightning-utilities==0.6.0.post0 207 | - llvmlite==0.39.1 208 | - louvain==0.8.0 209 | - markupsafe==2.1.3 210 | - matplotlib==3.6.3 211 | - multidict==6.0.4 212 | - munkres==1.1.4 213 | - natsort==8.4.0 214 | - networkx==3.0 215 | - numba==0.56.4 216 | - numpy==1.22.3 217 | - packaging==23.0 218 | - pandas==1.4.2 219 | - pathtools==0.1.2 220 | - patsy==0.5.3 221 | - plotly==5.13.0 222 | - pot==0.8.2 223 | - protobuf==4.21.12 224 | - psutil==5.9.4 225 | - pynndescent==0.5.8 226 | - pyparsing==3.0.9 227 | - pytest==7.4.1 228 | - python-dateutil==2.8.2 229 | - python-igraph==0.10.4 230 | - pytorch-lightning==1.9.0 231 | - pytz==2022.7.1 232 | - pyyaml==6.0 233 | - rpy2==3.4.1 234 | - scanpy==1.9.1 235 | - scikit-learn==1.1.1 236 | - scikit-misc==0.1.4 237 | 
- scipy==1.8.1 238 | - seaborn==0.12.2 239 | - sentry-sdk==1.14.0 240 | - session-info==1.0.0 241 | - setproctitle==1.3.2 242 | - smmap==5.0.0 243 | - statsmodels==0.14.0 244 | - stdlib-list==0.9.0 245 | - tenacity==8.1.0 246 | - texttable==1.6.7 247 | - threadpoolctl==3.1.0 248 | - tomli==2.0.1 249 | - torchmetrics==0.11.1 250 | - tzlocal==5.0.1 251 | - umap-learn==0.5.3 252 | - wandb==0.13.9 253 | - yarl==1.8.2 254 | prefix: /root/miniconda3 255 | -------------------------------------------------------------------------------- /downstream_task_demo/exp_main_net_feat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | 5 | import numpy as np 6 | import scanpy as sc 7 | from umap import UMAP 8 | import wandb 9 | 10 | from dataloader.stdata import STData 11 | from sklearn.svm import SVC 12 | from sklearn.svm import SVR 13 | from sklearn.multioutput import MultiOutputRegressor 14 | import shap 15 | import xgboost as xgb 16 | from xgboost import XGBRegressor 17 | from xgboost import XGBClassifier 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def svc_train(data, emb, label): 22 | xgb.set_config(verbosity=3) 23 | print('Training XGBRegressor') 24 | bst = XGBRegressor( 25 | n_jobs=55, 26 | tree_method = "hist", 27 | ) 28 | num_round = 100 29 | bst.fit(data, emb) 30 | 31 | print('Training XGBClassifier') 32 | clf_svc = XGBClassifier( 33 | objective='multi:softprob', 34 | num_class=np.max(label)+1, 35 | eval_metric='mlogloss', 36 | nthread=55, 37 | tree_method="hist", 38 | ) 39 | 40 | clf_svc.fit(emb, label) 41 | return bst, clf_svc 42 | 43 | def svc_pre(input): 44 | return clf_svc.predict_proba(multioutput_regressor.predict(input))[:,args.aim_label] 45 | 46 | 47 | if __name__ == '__main__': 48 | import argparse 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--R_HOME', type=str, default='/root/miniconda3/lib/R') 51 | parser.add_argument("--wandb", type=str, 
default="online") 52 | 53 | # Datasets 54 | parser.add_argument('--dataset', type=str, default='151673') 55 | parser.add_argument('--sample', type=str, default='barcode') 56 | parser.add_argument('--n_top_genes', type=int, default=3000) 57 | parser.add_argument('--max_value', type=int, default=10) 58 | parser.add_argument('--crop_size', type=int, default=224) 59 | parser.add_argument('--preprocessed', type=int, default=0) 60 | parser.add_argument('--min_cells', type=int, default=0) 61 | 62 | # Augmentation 63 | parser.add_argument('--graphwithpca', type=bool, default=True) 64 | parser.add_argument('--uselabel', type=bool, default=False) 65 | parser.add_argument('--K_m0', type=int, default=50) 66 | parser.add_argument('--K_m1', type=int, default=50) 67 | 68 | # Cluster 69 | parser.add_argument('--cluster_using', type=str, default='gene_rec') # gene_rec 70 | parser.add_argument('--cluster_method', type=str, default='mclust') # gene_rec 71 | parser.add_argument('--n_clusters', type=int, default=15) 72 | parser.add_argument('--radius', type=int, default=50) 73 | parser.add_argument('--cluster_refinement', type=int, default=0) 74 | 75 | # Model 76 | parser.add_argument('--learning_rate', type=float, default=0.001) 77 | parser.add_argument('--weight_decay', type=float, default=0.00) 78 | parser.add_argument('--epochs', type=int, default=1000) 79 | parser.add_argument('--dim_input', type=int, default=3000) 80 | parser.add_argument('--dim_output', type=int, default=64) 81 | parser.add_argument('--alpha', type=float, default=0.001) 82 | parser.add_argument('--beta', type=float, default=1) 83 | parser.add_argument('--aug_rate_0', type=float, default=0.1) 84 | parser.add_argument('--aug_rate_1', type=float, default=0.1) 85 | parser.add_argument('--v_latent', type=float, default=0.01) 86 | parser.add_argument('--theta', type=float, default=0.1) 87 | parser.add_argument('--random_seed', type=int, default=1) 88 | parser.add_argument('--n_encoder_layer', type=int, default=1, 
help='number of encoder layers') 89 | parser.add_argument('--n_fusion_layer', type=int, default=1) 90 | parser.add_argument('--bn_type', type=str, default='bn') 91 | parser.add_argument('--self_loop', type=int, default=0) 92 | parser.add_argument('--down_sample_rate', type=float, default=1) 93 | parser.add_argument('--morph_trans_ratio', type=float, default=1) 94 | parser.add_argument('--aug_method', type=str, default="near_mix") 95 | parser.add_argument('--run_dir', type=str, default=os.getenv('WANDB_RUN_DIR')) 96 | parser.add_argument('--model_dir', type=str, default='test_dlpfc/') 97 | parser.add_argument('--result_dir', type=str, default='results/') 98 | parser.add_argument('--save_dir', type=str, default='exp/') 99 | parser.add_argument('--plot', type=int, default=0, help='Plot the result') 100 | parser.add_argument('--var_plot', type=int, default=0) 101 | parser.add_argument('--plot_louvain', type=int, default=1) 102 | parser.add_argument('--plot_leiden', type=int, default=0) 103 | 104 | parser.add_argument('--num_fea', type=int, default=3000) 105 | parser.add_argument('--num_sample', type=int, default=1500) 106 | parser.add_argument('--aim_label', type=int, default=1) 107 | parser.add_argument('--max_evals', type=int, default=6000) 108 | 109 | args = parser.parse_args() 110 | 111 | args.plot = bool(args.plot) 112 | args.var_plot = bool(args.var_plot) 113 | args.plot_louvain = bool(args.plot_louvain) 114 | args.plot_leiden = bool(args.plot_leiden) 115 | args.cluster_refinement = bool(args.cluster_refinement) 116 | args.preprocessed = bool(args.preprocessed) 117 | if args.cluster_method == 'mclust': 118 | args.cluster_method = 'main' 119 | 120 | if args.run_dir is not None and not os.path.exists(args.run_dir): 121 | os.mkdir(args.run_dir) 122 | if args.save_dir is not None: 123 | args.save_dir += f"{args.dataset}/{args.aim_label}/" 124 | os.makedirs(args.save_dir, exist_ok=True) 125 | os.makedirs(args.save_dir+'model/', exist_ok=True) 126 | with 
open(args.save_dir + 'setting.txt', 'w') as f: 127 | f.write(str(args)) 128 | 129 | wandb_agent = wandb.init( 130 | project="MuST_exp_feat", 131 | entity="liliangyu", 132 | config=args.__dict__, 133 | name='EXP-'+''.join(sys.argv[1:]), 134 | mode=args.wandb, 135 | save_code=True, 136 | dir=args.run_dir, 137 | ) 138 | 139 | visium = STData(name=args.dataset, crop_size=args.crop_size, bio_norm=False, sample=args.sample) # Reset sample to get better results. 140 | 141 | adata = visium.adata 142 | sc.pp.filter_genes(adata, min_cells=args.min_cells) 143 | adata.uns["name"] = args.dataset 144 | 145 | n_clusters = visium.get_annotation_class() 146 | if n_clusters is not None: 147 | warnings.warn("n_cluster rewritten due to known label") 148 | else: 149 | n_clusters = args.n_clusters 150 | 151 | # clustering & ARI 152 | full_label = np.load(f'{args.result_dir}{args.dataset}/pred_{args.cluster_method}.npy') 153 | data = np.load(f'{args.result_dir}{args.dataset}/trans_input.npy') 154 | emb = np.load(f'{args.result_dir}{args.dataset}/emb.npy') 155 | hvg = np.load(f'{args.result_dir}{args.dataset}/hvg.npy') 156 | feature_name = adata.var_names[hvg].tolist() 157 | 158 | random_numbers = np.random.choice(len(data), args.num_sample, replace=False) 159 | 160 | print('down sample the data') 161 | data = data[random_numbers] 162 | emb = emb[random_numbers] 163 | label = full_label[random_numbers] 164 | 165 | print('random select the feature') 166 | data = data[:, :args.num_fea] 167 | 168 | print('explain') 169 | multioutput_regressor, clf_svc = svc_train(data, emb, label) 170 | explainer = shap.Explainer(svc_pre, data, feature_names=feature_name[:args.num_fea]) 171 | shap_values = explainer(data, max_evals=args.max_evals) 172 | 173 | plt.figure(figsize=(15, 15)) 174 | shap.plots.heatmap(shap_values, max_display=20, instance_order=shap_values.sum(1)) 175 | plt.tight_layout() 176 | plt.savefig(args.save_dir + 'shap_heamap.png', dpi=300) 177 | wandb.log({'shap_heamap': 
wandb.Image(args.save_dir + 'shap_heamap.png')}) 178 | 179 | fig_sp = visium.px_plot_spatial(full_label, background_image=True, save_path=args.save_dir + 'fig_sp.png') 180 | fig_sp.write_html(args.save_dir + 'fig_sp.html') 181 | 182 | shap_array = shap_values.values 183 | mean_abs_shap_values = np.mean(np.abs(shap_array), axis=0) 184 | sorted_indices = np.argsort(-mean_abs_shap_values) 185 | top_features = [feature_name[i] for i in sorted_indices[:20]] 186 | print('top_features', top_features) 187 | for i, gene in enumerate(top_features): 188 | visium.px_plot_spatial_gene( 189 | gene, background_image=False, save_path=args.save_dir + f'gene_{gene}.png') 190 | wandb.log({f'gene_{i}_{gene}': wandb.Image(args.save_dir + f'gene_{gene}.png')}) 191 | 192 | with open(f'exp_{args.dataset}.txt', 'a') as f: 193 | f.write(f'{args.aim_label}: ' + str(top_features) + ',\n') 194 | 195 | wandb.finish() 196 | -------------------------------------------------------------------------------- /MUST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from preprocess import preprocess_adj, preprocess_adj_sparse, preprocess, construct_interaction, construct_interaction_KNN, get_feature, permutation, fix_seed 3 | import time 4 | import random 5 | import logging 6 | import warnings 7 | import numpy as np 8 | from model import Encoder_MUST 9 | from aug import aug 10 | from tqdm import tqdm, trange 11 | from torch import nn 12 | import torch.nn.functional as F 13 | import scipy 14 | from scipy.sparse import csc_matrix 15 | from scipy.sparse import csr_matrix 16 | import scipy.sparse as sp 17 | 18 | 19 | import pandas as pd 20 | import wandb 21 | 22 | class MUST(): 23 | def __init__(self, 24 | adata, 25 | morph, 26 | n_top_genes=3000, 27 | max_value=1, 28 | adata_sc=None, 29 | device=torch.device('cpu'), 30 | learning_rate=0.001, 31 | weight_decay=0.00, 32 | epochs=600, 33 | dim_input=3000, 34 | dim_output=64, 35 | random_seed=41, 36 | 
alpha=1, 37 | beta=1, 38 | theta=0.1, 39 | v_latent=0.01, 40 | datatype='10X', 41 | aug_rate_0=0.1, 42 | aug_rate_1=0.1, 43 | n_encoder_layer=1, 44 | n_fusion_layer=1, 45 | bn_type='bn', 46 | self_loop=1, 47 | morph_trans_ratio=0.5, 48 | graphwithpca=False, 49 | uselabel=False, 50 | K_m0=5, 51 | K_m1=5, 52 | aug_method="randn", 53 | unique_str="", 54 | preprocessed=False, 55 | down_sample_rate=0.1, 56 | min_cells=50, 57 | ): 58 | self.adata = adata.copy() 59 | self.morph = morph 60 | self.device = device 61 | self.learning_rate = learning_rate 62 | self.weight_decay = weight_decay 63 | self.epochs = epochs 64 | self.random_seed = random_seed 65 | self.alpha = alpha 66 | self.beta = beta 67 | self.theta = theta 68 | self.datatype = datatype 69 | self.v_latent = v_latent 70 | self.v_input = 100 71 | self.aug_rate_0 = aug_rate_0 72 | self.aug_rate_1 = aug_rate_1 73 | self.n_encoder_layer = n_encoder_layer 74 | self.n_fusion_layer = n_fusion_layer 75 | self.bn_type = bn_type 76 | self.self_loop = self_loop 77 | self.morph_trans_ratio = morph_trans_ratio 78 | self.graphwithpca=graphwithpca 79 | self.uselabel=uselabel 80 | self.K_m0=K_m0 81 | self.K_m1=K_m1 82 | self.aug_method=aug_method 83 | self.unique_str=unique_str 84 | self.down_sample_rate=down_sample_rate 85 | 86 | self.dataset = adata.uns["name"] 87 | 88 | fix_seed(self.random_seed) 89 | 90 | if not preprocessed and 'highly_variable' not in adata.var.keys(): 91 | if self.datatype == '10x': 92 | self.adata = preprocess(self.adata, n_top_genes=n_top_genes, max_value=max_value, dataset=self.dataset) 93 | else: 94 | self.adata = preprocess(self.adata, min_cells=min_cells, n_top_genes=n_top_genes, max_value=max_value, dataset=self.dataset) 95 | 96 | fix_seed(self.random_seed) 97 | 98 | if 'adj' not in adata.obsm.keys(): 99 | if self.datatype in ['stereo', 'slide']: 100 | construct_interaction_KNN(self.adata) 101 | else: 102 | construct_interaction(self.adata) 103 | 104 | if 'feat' not in adata.obsm.keys(): 105 | 
get_feature(self.adata) 106 | 107 | self.features = torch.FloatTensor( 108 | self.adata.obsm['feat'].copy()).to(self.device) 109 | print(self.features) 110 | self.features_a = torch.FloatTensor( 111 | self.adata.obsm['feat_a'].copy()).to(self.device) 112 | self.morph = morph 113 | if self.morph is not None: 114 | self.morph = torch.FloatTensor( 115 | morph).to(self.device) 116 | self.morph_a = torch.FloatTensor( 117 | permutation(morph)).to(self.device) 118 | self.adj = self.adata.obsm['adj'] + np.eye(self.adata.obsm['adj'].shape[0]) * self.self_loop 119 | 120 | self.graph_neigh = torch.FloatTensor( 121 | self.adata.obsm['graph_neigh'].copy() + np.eye(self.adj.shape[0]) * self.self_loop).to(self.device) 122 | self.neighbor_index_a = aug.cal_near_index(data=self.features, k=K_m1, uselabel=uselabel, graphwithpca=graphwithpca, device=self.device, modal='0', dataset=self.dataset, unique_str=unique_str) 123 | if self.morph is not None: 124 | self.neighbor_index_b = aug.cal_near_index(data=self.morph, k=K_m0, uselabel=uselabel, graphwithpca=graphwithpca, device=self.device, modal='1', dataset=self.dataset, unique_str=unique_str) 125 | 126 | if self.morph is not None: 127 | self.input_morph = morph 128 | self.input_trans = self.adata.obsm['feat'] 129 | 130 | self.dim_input_a = self.features.shape[1] 131 | if self.morph is not None: 132 | self.dim_input_b = self.morph.shape[1] 133 | else: 134 | self.dim_input_b = None 135 | self.dim_output = dim_output 136 | 137 | if self.datatype in ['Stereo', 'Slide']: 138 | # using sparse 139 | print('Building sparse matrix ...') 140 | self.adj = preprocess_adj_sparse(self.adj).to(self.device) 141 | else: 142 | # standard version 143 | self.adj = preprocess_adj(self.adj) 144 | self.adj_coo = sp.coo_matrix(self.adj) 145 | 146 | indices = torch.LongTensor([self.adj_coo.row, self.adj_coo.col]) 147 | values = torch.FloatTensor(self.adj_coo.data) 148 | shape = self.adj_coo.shape 149 | 150 | self.adj_sparse = torch.sparse_coo_tensor(indices, 
values, shape).to(self.device) 151 | 152 | 153 | def _TwowaydivergenceLoss(self, P_, Q_, select=None): 154 | 155 | EPS = 1e-5 156 | losssum1 = P_ * torch.log(Q_ + EPS) 157 | losssum2 = (1 - P_) * torch.log(1 - Q_ + EPS) 158 | losssum = -1 * (losssum1 + losssum2) 159 | 160 | return losssum.mean() 161 | 162 | def _DistanceSquared(self, x, y=None, metric="euclidean"): 163 | if metric == "euclidean": 164 | if y is not None: 165 | m, n = x.size(0), y.size(0) 166 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 167 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 168 | dist = xx + yy 169 | dist = torch.addmm(dist, mat1=x, mat2=y.t(), beta=1, alpha=-2) 170 | dist = dist.clamp(min=1e-12) 171 | else: 172 | m, n = x.size(0), x.size(0) 173 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 174 | yy = xx.t() 175 | dist = xx + yy 176 | dist = torch.addmm(dist, mat1=x, mat2=x.t(), beta=1, alpha=-2) 177 | dist = dist.clamp(min=1e-12) 178 | dist[torch.eye(dist.shape[0]) == 1] = 1e-12 179 | 180 | if metric == "cossim": 181 | input_a, input_b = x, x 182 | normalized_input_a = torch.nn.functional.normalize(input_a) 183 | normalized_input_b = torch.nn.functional.normalize(input_b) 184 | dist = torch.mm(normalized_input_a, normalized_input_b.T) 185 | dist *= -1 186 | dist += 1 187 | 188 | dist[torch.eye(dist.shape[0]) == 1] = 1e-12 189 | 190 | return dist 191 | 192 | def _CalGamma(self, v): 193 | 194 | a = scipy.special.gamma((v + 1) / 2) 195 | b = np.sqrt(v * np.pi) * scipy.special.gamma(v / 2) 196 | out = a / b 197 | 198 | return out 199 | 200 | def _Similarity(self, 201 | dist, 202 | gamma, 203 | v): 204 | dist_rho = dist 205 | dist_rho[dist_rho < 0] = 0 206 | Pij = ( 207 | gamma 208 | * torch.tensor(2 * 3.14) 209 | * gamma 210 | * torch.pow((1 + dist_rho / v), exponent=-1 * (v + 1)) 211 | ) 212 | return Pij 213 | 214 | def loss_manifold( 215 | self, 216 | input_data, 217 | latent_data, 218 | v_latent, 219 | metric='euclidean', 220 | ): 221 | 222 | data_1 = 
input_data[: input_data.shape[0] // 2] 223 | 224 | dis_P = self._DistanceSquared(data_1, metric=metric) 225 | latent_data_1 = latent_data[: input_data.shape[0] // 2] 226 | 227 | dis_P_2 = dis_P # + nndistance.reshape(1, -1) 228 | P_2 = self._Similarity(dist=dis_P_2, 229 | gamma=self._CalGamma(self.v_input), 230 | v=self.v_input, ) 231 | latent_data_2 = latent_data[(input_data.shape[0] // 2):] 232 | dis_Q_2 = self._DistanceSquared(latent_data_1, latent_data_2) 233 | Q_2 = self._Similarity( 234 | dist=dis_Q_2, 235 | gamma=self._CalGamma(v_latent), 236 | v=v_latent, 237 | ) 238 | loss_ce_2 = self._TwowaydivergenceLoss(P_=P_2, Q_=Q_2) 239 | return loss_ce_2 240 | 241 | def augmentation(self, fea, t=0.1): 242 | fea_rand = torch.randn(fea.shape, device=fea.device) * torch.var(fea, dim=0) * t 243 | return fea + fea_rand 244 | 245 | def train(self, verbose=True): 246 | if self.datatype in ['Stereo', 'Slide']: 247 | self.model = Encoder_sparse( 248 | self.dim_input, self.dim_output, self.graph_neigh).to(self.device) 249 | else: 250 | # self.model = Encoder(self.dim_input, self.dim_output, self.graph_neigh).to(self.device) 251 | self.model = Encoder_MUST( 252 | self.dim_input_a, 253 | self.dim_input_b, 254 | self.dim_output, 255 | self.graph_neigh, 256 | n_encoder_layer=self.n_encoder_layer, 257 | n_fusion_layer=self.n_fusion_layer, 258 | bn_type=self.bn_type, 259 | morph_trans_ratio=self.morph_trans_ratio, 260 | platform=self.datatype, 261 | ).to(self.device) 262 | 263 | self.optimizer = torch.optim.AdamW(self.model.parameters(), self.learning_rate, 264 | weight_decay=self.weight_decay) 265 | 266 | logging.info('Begin to train ST data...') 267 | self.model.train() 268 | self.adata.obsm['trans_input'] = self.features.detach().cpu().numpy() 269 | 270 | tmp_idx = torch.tensor(np.arange(self.features.shape[0]), dtype=int) # As augmentation index when batch is not adopted. 
271 | if verbose: 272 | tr = trange(self.epochs) 273 | else: 274 | tr = range(self.epochs) 275 | for epoch in tr: 276 | self.model.train() 277 | 278 | aug_func = getattr(aug, f"aug_{self.aug_method}") 279 | 280 | if self.morph is not None: 281 | self.morph_a = aug_func(index=tmp_idx, dataset=self.morph, neighbors_index=self.neighbor_index_b, 282 | k=self.K_m0, random_t=self.aug_rate_0, device=self.device) 283 | else: 284 | self.morph_a = None 285 | self.features_a = aug_func(index=tmp_idx, dataset=self.features, neighbors_index=self.neighbor_index_a, 286 | k=self.K_m1, random_t=self.aug_rate_1, device=self.device) 287 | hiden_feat_list, self.emb, __, __ = self.model( 288 | self.features, self.morph, self.adj_sparse) 289 | hiden_feat_list_a, self.emb_a, __, __ = self.model( 290 | self.features_a, self.morph_a, self.adj_sparse) 291 | 292 | [self.hiden_feat, self.hiden_feat_p] = hiden_feat_list 293 | [self.hiden_feat_a, self.hiden_feat_p_a] = hiden_feat_list_a 294 | 295 | down_sample_mask = torch.rand(self.hiden_feat.shape[0]) < self.down_sample_rate 296 | self.d_hiden_feat = self.hiden_feat[down_sample_mask] 297 | self.d_hiden_feat_p = self.hiden_feat_p[down_sample_mask] 298 | self.d_hiden_feat_a = self.hiden_feat_a[down_sample_mask] 299 | self.d_hiden_feat_p_a = self.hiden_feat_p_a[down_sample_mask] 300 | self.d_emb = self.emb[down_sample_mask] 301 | self.d_emb_a = self.emb_a[down_sample_mask] 302 | self.d_features = self.features[down_sample_mask] 303 | 304 | self.man_loss = self.loss_manifold( 305 | input_data=torch.cat([self.d_hiden_feat, self.d_hiden_feat_a], dim=0), 306 | latent_data=torch.cat([self.d_hiden_feat_p, self.d_hiden_feat_p_a], dim=0), 307 | v_latent=self.v_latent, 308 | ) 309 | # self.man_loss = 0 310 | self.feat_loss = F.mse_loss(self.d_features, self.d_emb) 311 | # self.feat_loss = torch.mean(self.d_features) 312 | 313 | loss = self.alpha*self.feat_loss + self.beta*self.man_loss 314 | # loss = self.alpha*self.feat_loss 315 | # loss = 
self.beta*self.man_loss 316 | 317 | wandb.log({'feat_loss': self.alpha*self.feat_loss, 318 | 'man_loss': self.beta*self.man_loss, 319 | 'all_loss': loss}) 320 | 321 | self.optimizer.zero_grad() 322 | loss.backward() 323 | self.optimizer.step() 324 | 325 | logging.info("Optimization finished for ST data!") 326 | 327 | with torch.no_grad(): 328 | self.model.eval() 329 | if self.datatype in ['Stereo', 'Slide']: 330 | self.emb_rec = self.model( 331 | self.features, self.morph, self.adj_sparse)[1] 332 | self.emb_rec = F.normalize( 333 | self.emb_rec, p=2, dim=1).detach().cpu().numpy() 334 | else: 335 | raw_output = self.model(self.features, self.morph, self.adj_sparse) 336 | self.lat = raw_output[0][1].detach().cpu().numpy() 337 | self.rec = raw_output[1].detach().cpu().numpy() 338 | self.trans_emb = raw_output[2].detach().cpu().numpy() 339 | self.morph_emb = raw_output[3].detach().cpu().numpy() if raw_output[3] is not None else None 340 | self.adata.obsm['emb'] = self.lat 341 | self.adata.obsm['gene_rec'] = self.rec 342 | 343 | return self.adata 344 | 345 | def discover_region(self): 346 | raw_output = self.model(self.features, self.morph, self.adj_sparse) 347 | self.lat = raw_output[0][1].detach().cpu().numpy() 348 | self.rec = raw_output[1].detach().cpu().numpy() 349 | self.trans_emb = raw_output[2].detach().cpu().numpy() 350 | self.morph_emb = raw_output[3].detach().cpu().numpy() if raw_output[3] is not None else None 351 | self.adata.obsm['emb'] = self.lat 352 | self.adata.obsm['gene_rec'] = self.rec 353 | 354 | return self.adata 355 | 356 | def save(self, save_dir=''): 357 | self.model.save(save_dir) 358 | 359 | def load(self, load_dir=''): 360 | self.model = Encoder_MUST( 361 | self.dim_input_a, 362 | self.dim_input_b, 363 | self.dim_output, 364 | self.graph_neigh, 365 | n_encoder_layer=self.n_encoder_layer, 366 | n_fusion_layer=self.n_fusion_layer, 367 | bn_type=self.bn_type, 368 | morph_trans_ratio=self.morph_trans_ratio, 369 | platform=self.datatype, 370 | 
).to(self.device) 371 | self.model.load(load_dir) -------------------------------------------------------------------------------- /main_MuST.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | 5 | import wandb 6 | import numpy as np 7 | from umap import UMAP 8 | 9 | import torch 10 | from MUST import MUST 11 | from dataloader.stdata import STData 12 | 13 | import eval.eval_core_base as ecb 14 | from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score, davies_bouldin_score 15 | from utils import cluster, refine_label, make_error_label, targeted_cluster, cluster_map, aligned_accuracy_score, stable_cluster 16 | 17 | if __name__ == '__main__': 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--R_HOME', type=str, default='/root/miniconda3/lib/R') 21 | parser.add_argument("--wandb", type=str, default="online") 22 | 23 | # Datasets 24 | parser.add_argument('--dataset', type=str, default='V1_Adult_Mouse_Brain') 25 | parser.add_argument('--sample', type=str, default='barcode') 26 | parser.add_argument('--n_top_genes', type=int, default=3000) 27 | parser.add_argument('--max_value', type=int, default=10) 28 | parser.add_argument('--crop_size', type=int, default=224) 29 | parser.add_argument('--preprocessed', type=int, default=0) 30 | parser.add_argument('--min_cells', type=int, default=50) 31 | parser.add_argument('--force_no_morph', type=int, default=0) 32 | 33 | # Augmentation 34 | parser.add_argument('--graphwithpca', type=bool, default=True) 35 | parser.add_argument('--uselabel', type=bool, default=False) 36 | parser.add_argument('--K_m0', type=int, default=7) 37 | parser.add_argument('--K_m1', type=int, default=7) 38 | 39 | # Cluster 40 | parser.add_argument('--cluster_using', type=str, default='gene_rec') # gene_rec 41 | parser.add_argument('--n_clusters', type=int, default=20) 42 | 
parser.add_argument('--radius', type=int, default=50) 43 | parser.add_argument('--cluster_refinement', type=int, default=0) 44 | 45 | # Model 46 | parser.add_argument('--learning_rate', type=float, default=0.001) 47 | parser.add_argument('--weight_decay', type=float, default=0.00) 48 | parser.add_argument('--epochs', type=int, default=600) 49 | parser.add_argument('--dim_input', type=int, default=3000) 50 | parser.add_argument('--dim_output', type=int, default=60) 51 | parser.add_argument('--alpha', type=float, default=0.0015) 52 | parser.add_argument('--beta', type=float, default=1) 53 | parser.add_argument('--aug_rate_0', type=float, default=0.1) 54 | parser.add_argument('--aug_rate_1', type=float, default=0.1) 55 | parser.add_argument('--v_latent', type=float, default=0.05) 56 | parser.add_argument('--theta', type=float, default=0.1) 57 | parser.add_argument('--random_seed', type=int, default=0) 58 | parser.add_argument('--n_encoder_layer', type=int, default=1) 59 | parser.add_argument('--n_fusion_layer', type=int, default=1) 60 | parser.add_argument('--bn_type', type=str, default='bn') 61 | parser.add_argument('--self_loop', type=int, default=0) 62 | parser.add_argument('--down_sample_rate', type=float, default=1) 63 | parser.add_argument('--morph_trans_ratio', type=float, default=1) 64 | parser.add_argument('--aug_method', type=str, default="near_mix") 65 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 66 | parser.add_argument('--run_dir', type=str, default=os.getenv('WANDB_RUN_DIR')) 67 | parser.add_argument('--save_dir', type=str, default="result/") 68 | parser.add_argument('--plot', type=int, default=1) 69 | parser.add_argument('--var_plot', type=int, default=0) # plot vairous cluster numbers 70 | parser.add_argument('--plot_louvain', type=int, default=0) 71 | parser.add_argument('--plot_leiden', type=int, default=0) 72 | args = parser.parse_args() 73 | 74 | args.plot = bool(args.plot) 75 | 
args.var_plot = bool(args.var_plot) 76 | args.plot_louvain = bool(args.plot_louvain) 77 | args.plot_leiden = bool(args.plot_leiden) 78 | args.cluster_refinement = bool(args.cluster_refinement) 79 | args.preprocessed = bool(args.preprocessed) 80 | 81 | if args.run_dir is not None and not os.path.exists(args.run_dir): 82 | os.mkdir(args.run_dir) 83 | if args.save_dir is not None: 84 | args.save_dir += f"{args.dataset}/" 85 | os.makedirs(args.save_dir, exist_ok=True) 86 | os.makedirs(args.save_dir+'model/', exist_ok=True) 87 | with open(args.save_dir + 'setting.txt', 'w') as f: 88 | f.write(str(args)) 89 | 90 | if not os.path.exists(args.R_HOME): 91 | raise EnvironmentError("R_HOME misconfigured. Run `Rscript -e 'R.home(component=\"home\")' ` and pass the output as R_HOME.") 92 | os.environ['R_HOME'] = args.R_HOME 93 | 94 | wandb_agent = wandb.init( 95 | project="MuST_main", 96 | entity="liliangyu", 97 | config=args.__dict__, 98 | name='MuST_'.join(sys.argv[1:]), 99 | mode=args.wandb, 100 | save_code=True, 101 | dir=args.run_dir, 102 | ) 103 | 104 | visium = STData(name=args.dataset, crop_size=args.crop_size, bio_norm=False, sample=args.sample) # Reset sample to get better results. 
105 | 106 | adata = visium.adata 107 | adata.uns["name"] = args.dataset 108 | 109 | n_clusters = visium.get_annotation_class() 110 | if n_clusters is not None: 111 | warnings.warn("n_cluster rewritten due to known label") 112 | else: 113 | n_clusters = args.n_clusters 114 | 115 | # define model 116 | use_morph = None if args.force_no_morph else visium.get_morph() 117 | 118 | model = MUST( 119 | adata, 120 | use_morph, 121 | n_top_genes=args.n_top_genes, 122 | max_value=args.max_value, 123 | device=args.device, 124 | random_seed=args.random_seed, 125 | learning_rate=args.learning_rate, 126 | weight_decay=args.weight_decay, 127 | epochs=args.epochs, 128 | dim_input=args.dim_input, 129 | dim_output=args.dim_output, 130 | alpha=args.alpha, 131 | beta=args.beta, 132 | v_latent=args.v_latent, 133 | theta=args.theta, 134 | aug_rate_0=args.aug_rate_0, 135 | aug_rate_1=args.aug_rate_1, 136 | n_encoder_layer=args.n_encoder_layer, 137 | n_fusion_layer=args.n_fusion_layer, 138 | bn_type=args.bn_type, 139 | self_loop=args.self_loop, 140 | morph_trans_ratio=args.morph_trans_ratio, 141 | graphwithpca=args.graphwithpca, 142 | uselabel=args.uselabel, 143 | K_m0=args.K_m0, 144 | K_m1=args.K_m1, 145 | aug_method=args.aug_method, 146 | unique_str=f"Crop{args.crop_size}", 147 | datatype=visium.platform, 148 | preprocessed=args.preprocessed, 149 | down_sample_rate=args.down_sample_rate, 150 | min_cells=args.min_cells, 151 | ) 152 | 153 | 154 | # train model 155 | print("INPUT GENE SHAPE: ", model.adata.shape, "HVG Selected: ", model.adata.var['highly_variable'].sum()) 156 | adata = model.train() 157 | 158 | wandb_logs = {} 159 | 160 | recon = adata.obsm['gene_rec'] 161 | emb = adata.obsm['emb'] 162 | cluster_emb = emb if args.cluster_using == 'emb' else recon 163 | if args.plot: 164 | if emb.shape[1] != 2: 165 | emb_2d = UMAP(n_components=2, random_state=args.random_seed).fit_transform(emb) 166 | else: 167 | emb_2d = emb 168 | 169 | if args.save_dir is not None: 170 | 
model.save(args.save_dir+'model/') 171 | # np.save(args.save_dir + 'recon.npy', adata.obsm['gene_rec']) 172 | np.save(args.save_dir + 'emb.npy', adata.obsm['emb']) 173 | # np.save(args.save_dir + 'trans_input.npy', adata.obsm['trans_input']) 174 | np.save(args.save_dir + 'emb_2d.npy', emb_2d) 175 | np.save(args.save_dir + 'hvg.npy', adata.var['highly_variable'].to_numpy()) 176 | # np.save(args.save_dir + 'trans_emb', model.trans_emb) 177 | # if model.morph_emb is not None: 178 | # np.save(args.save_dir + 'morph_emb', model.morph_emb) 179 | 180 | # kmeans_pred = cluster(cluster_emb, method="kmeans", n_clusters=n_clusters) 181 | if args.plot_louvain: 182 | louvain_pred = targeted_cluster(cluster_emb, method="louvain", target_n_clusters=n_clusters) 183 | if args.plot_leiden: 184 | leiden_pred = targeted_cluster(cluster_emb, method="leiden", target_n_clusters=n_clusters) 185 | # mclust_pred = cluster(cluster_emb, method="mclust", n_clusters=n_clusters, pca_dim=20) 186 | mclust_pred = stable_cluster(cluster_emb, method="mclust", n_clusters=n_clusters, pca_dim=20) 187 | if args.cluster_refinement: # Hexogonal Refinement 188 | print("REFINE!") 189 | # kmeans_pred = refine_label(kmeans_pred, corrds=visium.get_coords(), radius=args.radius) 190 | if args.plot_louvain: 191 | louvain_pred = refine_label(louvain_pred, corrds=visium.get_coords(), radius=args.radius) 192 | if args.plot_leiden: 193 | leiden_pred = refine_label(leiden_pred, corrds=visium.get_coords(), radius=args.radius) 194 | mclust_pred = refine_label(mclust_pred, corrds=visium.get_coords(), radius=args.radius) 195 | if visium.get_label() is not None: 196 | true = visium.get_label() 197 | # kmeans_pred = cluster_map(true, kmeans_pred, wildcard=999) 198 | if args.plot_louvain: 199 | louvain_pred = cluster_map(true, louvain_pred, wildcard=999) 200 | if args.plot_leiden: 201 | leiden_pred = cluster_map(true, leiden_pred, wildcard=999) 202 | mclust_pred = cluster_map(true, mclust_pred, wildcard=999) 203 | 204 | if 
args.save_dir is not None: 205 | np.save(args.save_dir + 'pred_main.npy', mclust_pred) 206 | if args.plot_louvain: 207 | np.save(args.save_dir + 'pred_louvain.npy', louvain_pred) 208 | if args.plot_leiden: 209 | np.save(args.save_dir + 'pred_leiden.npy', leiden_pred) 210 | 211 | ecb_e_trans = ecb.Eval(input=model.input_trans, latent=emb, label=mclust_pred, k=10) 212 | trans_mrre_zx, trans_mrre_xz = ecb_e_trans.E_mrre() 213 | trans_mrre = np.mean([trans_mrre_xz, trans_mrre_zx]) 214 | wandb_logs.update({ 215 | f"metrics/MRRE_trans_{10}": trans_mrre, 216 | f"metrics/MRRE_trans_xz_{10}": trans_mrre_xz, 217 | f"metrics/MRRE_trans_zx_{10}": trans_mrre_zx, 218 | f"metrics/mclust_sc": silhouette_score(emb, mclust_pred), 219 | f"metrics/mclust_db": davies_bouldin_score(emb, mclust_pred), 220 | }) 221 | if visium.get_morph() is not None: 222 | ecb_e_morph = ecb.Eval(input=visium.get_morph(), latent=emb, label=mclust_pred, k=10) 223 | 224 | morph_mrre_zx, morph_mrre_xz = ecb_e_morph.E_mrre() 225 | morph_mrre = np.mean([morph_mrre_xz, morph_mrre_zx]) 226 | mrre_xz = morph_mrre_xz + trans_mrre_xz 227 | mrre_zx = morph_mrre_zx + trans_mrre_zx 228 | 229 | wandb_logs.update({ 230 | f"metrics/MRRE_{10}": morph_mrre + trans_mrre, 231 | f"metrics/MRRE_morph_{10}": morph_mrre, 232 | f"metrics/MRRE_morph_xz_{10}": morph_mrre_xz, 233 | f"metrics/MRRE_xz_{10}": mrre_xz, 234 | f"metrics/MRRE_morph_zx_{10}": morph_mrre_zx, 235 | f"metrics/MRRE_zx_{10}": mrre_zx, 236 | }) 237 | 238 | if args.plot: 239 | # fig_spatial_plain_kmeans = visium.px_plot_spatial(kmeans_pred, background_image=False) 240 | if args.save_dir is None: 241 | fig_spatial_plain_mclust = visium.px_plot_spatial(mclust_pred, background_image=False) 242 | fig_emb_mclust = visium.px_plot_embedding(emb_2d, mclust_pred) 243 | if args.plot_louvain: 244 | fig_spatial_plain_louvain = visium.px_plot_spatial(louvain_pred, background_image=False) 245 | fig_emb_louvain = visium.px_plot_embedding(emb_2d, louvain_pred) 246 | if 
args.plot_leiden: 247 | fig_spatial_plain_leiden = visium.px_plot_spatial(leiden_pred, background_image=False) 248 | fig_emb_leiden = visium.px_plot_embedding(emb_2d, leiden_pred) 249 | else: 250 | fig_spatial_plain_mclust = visium.px_plot_spatial(mclust_pred, background_image=visium.platform == '10x', save_path=args.save_dir+'mclust.png') 251 | fig_emb_mclust = visium.px_plot_embedding(emb_2d, mclust_pred, save_path=args.save_dir+'mclust_emb.png') 252 | if args.plot_louvain: 253 | fig_spatial_plain_louvain = visium.px_plot_spatial(louvain_pred, background_image=visium.platform == '10x', save_path=args.save_dir+'louvain.png') 254 | fig_emb_louvain = visium.px_plot_embedding(emb_2d, louvain_pred, save_path=args.save_dir+'louvian_emb.png') 255 | if args.plot_leiden: 256 | fig_spatial_plain_leiden = visium.px_plot_spatial(leiden_pred, background_image=visium.platform == '10x', save_path=args.save_dir+'leiden.png') 257 | fig_emb_leiden = visium.px_plot_embedding(emb_2d, leiden_pred, save_path=args.save_dir+'leiden_emb.png') 258 | wandb_logs.update({ 259 | # "figs/plain Spatial Image - KMeans": fig_spatial_plain_kmeans, 260 | "figs/plain Spatial Image - mclust": fig_spatial_plain_mclust, 261 | 'figs/UMAP mclust': fig_emb_mclust, 262 | }) 263 | if args.plot_louvain: 264 | wandb_logs.update({ 265 | "figs/plain Spatial Image - Louvain": fig_spatial_plain_louvain, 266 | 'figs/UMAP louvain': fig_emb_louvain, 267 | }) 268 | if args.plot_leiden: 269 | wandb_logs.update({ 270 | "figs/plain Spatial Image - Leiden": fig_spatial_plain_leiden, 271 | 'figs/UMAP leiden': fig_emb_leiden, 272 | }) 273 | if args.var_plot: 274 | for var_cluster_num in [ 20 ]: 275 | var_mclust_pred = stable_cluster(emb, method="mclust", n_clusters=var_cluster_num, pca_dim=20) 276 | if args.save_dir is not None: 277 | np.save(args.save_dir + f'pred_{var_cluster_num}.npy', var_mclust_pred) 278 | fig_spatial_var_mclust = visium.px_plot_spatial(var_mclust_pred, background_image=True, 
save_path=args.save_dir+f'mclust_{var_cluster_num}.png') 279 | fig_embedding_var_mclust = visium.px_plot_embedding(emb_2d, var_mclust_pred, save_path=args.save_dir+f'mclust_{var_cluster_num}_emb.png') 280 | else: 281 | fig_spatial_var_mclust = visium.px_plot_spatial(var_mclust_pred, background_image=False) 282 | fig_embedding_var_mclust = visium.px_plot_embedding(emb_2d, var_mclust_pred) 283 | wandb_logs.update({ 284 | f"figs_variety/spatial_mclust_class{var_cluster_num}": fig_spatial_var_mclust, 285 | f"figs_variety/UMAP_mclust_class{var_cluster_num}": fig_embedding_var_mclust, 286 | }) 287 | 288 | if visium.get_label() is not None: 289 | mask = true != 999 290 | mclust_error = make_error_label(true, mclust_pred, wildcard=999) 291 | 292 | wandb_logs.update({ 293 | # f"metrics/k_means_acc": aligned_accuracy_score(true[mask], kmeans_pred[mask], wildcard=999), 294 | # f"metrics/k_means_ari": adjusted_rand_score(true[mask], kmeans_pred[mask]), 295 | # f"metrics/k_means_nmi": normalized_mutual_info_score(true[mask], kmeans_pred[mask]), 296 | # f"metrics/k_means_ami": adjusted_mutual_info_score(true[mask], kmeans_pred[mask]), 297 | # f"metrics/louvain_acc": aligned_accuracy_score(true[mask], louvain_pred[mask], wildcard=999), 298 | # f"metrics/louvain_ari": adjusted_rand_score(true[mask], louvain_pred[mask]), 299 | # f"metrics/louvain_nmi": normalized_mutual_info_score(true[mask], louvain_pred[mask]), 300 | # f"metrics/louvain_ami": adjusted_mutual_info_score(true[mask], louvain_pred[mask]), 301 | # f"metrics/mclust_acc": aligned_accuracy_score(true[mask], mclust_pred[mask], wildcard=999), 302 | f"metrics/mclust_ari": adjusted_rand_score(true[mask], mclust_pred[mask]), 303 | # f"metrics/mclust_nmi": normalized_mutual_info_score(true[mask], mclust_pred[mask]), 304 | # f"metrics/mclust_ami": adjusted_mutual_info_score(true[mask], mclust_pred[mask]), 305 | }) 306 | if args.plot: 307 | # fig_spatial_plain_mclust_error = visium.px_plot_spatial(mclust_error, 
background_image=False) 308 | # fig_spatial_plain_true = visium.px_plot_spatial(true, background_image=False) 309 | if args.save_dir is not None: 310 | fig_embedding_true = visium.px_plot_embedding(emb_2d, true, save_path=args.save_dir+'true_emb.png') 311 | else: 312 | fig_embedding_true = visium.px_plot_embedding(emb_2d, true) 313 | wandb_logs.update({ 314 | # f"figs/plain Spatial Image - True": fig_spatial_plain_true, 315 | # f"figs/plain Spatial Image - mclust Error": fig_spatial_plain_mclust_error, 316 | f'figs/UMAP true': fig_embedding_true, 317 | }) 318 | 319 | save_file_path = f'{wandb.run.dir[:-5]}flag.txt' 320 | wandb.log(wandb_logs) 321 | wandb.finish() 322 | 323 | with open(save_file_path, 'w') as f: 324 | f.write('finish run all') -------------------------------------------------------------------------------- /eval/eval_core_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.stats import spearmanr 3 | from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score 4 | from sklearn import metrics 5 | from sklearn.svm import SVC 6 | import numpy as np 7 | import random 8 | from sklearn.metrics import pairwise_distances 9 | 10 | import networkx as nx 11 | 12 | from sklearn.cluster import KMeans 13 | from munkres import Munkres 14 | 15 | from sklearn.metrics import confusion_matrix 16 | 17 | import sys 18 | sys.path.append("..") 19 | 20 | from utils import targeted_cluster 21 | 22 | 23 | def Curance_path_list(neighbour_input, distance_input, label): 24 | 25 | 26 | row = [] 27 | col = [] 28 | v = [] 29 | n_p, n_n = neighbour_input.shape 30 | for i in range(n_p): 31 | for j in range(n_n): 32 | row.append(i) 33 | col.append(neighbour_input[i,j]) 34 | v.append(distance_input[i,j]) 35 | 36 | G=nx.Graph() 37 | for i in range(0, n_p): 38 | G.add_node(i) 39 | for i in range(len(row)): 40 | G.add_weighted_edges_from([(row[i],col[i],v[i])]) 41 | 42 | # pos=nx.shell_layout(G) 43 | # 
nx.draw(G, pos,with_labels=True, node_color='white', edge_color='red', node_size=400, alpha=0.5 ) 44 | 45 | path_list = [] 46 | for i in range(5000): 47 | source = random.randint(a=0, b=n_p-1) 48 | source_label = label[source] 49 | list_with_same_label = np.array(range(n_p))[label==source_label] 50 | target = random.sample(list(list_with_same_label), 1) 51 | 52 | target = random.randint(a=0, b=n_p-1) 53 | try: 54 | path = nx.dijkstra_path(G, source=source, target=target) 55 | path_list.append(path) 56 | except: 57 | pass 58 | 59 | return path_list 60 | 61 | 62 | class Eval(): 63 | 64 | def __init__(self, input, latent, label, cuda=None, k=50) -> None: 65 | n = latent.shape[0] 66 | if n > 5000: 67 | random.seed(0) 68 | index = random.sample(range(n), 5000) 69 | self.k = k 70 | # self.k = int(k * 5000/n) 71 | # print('down sampling') 72 | else: 73 | index = range(n) 74 | self.k = k 75 | self.input = input.reshape(n, -1)[index] 76 | self.latent = latent.reshape(n, -1)[index] 77 | self.label = label[index] 78 | self.cuda = cuda 79 | # print('distance_input') 80 | self.distance_input = self._Distance_squared_CPU( 81 | self.input, self.input) 82 | # print('distance_latnet') 83 | self.distance_latnet = self._Distance_squared_CPU( 84 | self.latent, self.latent) 85 | # print('neighbour_input') 86 | self.neighbour_input, self.rank_input = self._neighbours_and_ranks(self.distance_input) 87 | # print('neighbour_latent') 88 | self.neighbour_latent, self.rank_latent = self._neighbours_and_ranks( 89 | self.distance_latnet) 90 | 91 | 92 | def _neighbours_and_ranks(self, distances): 93 | """ 94 | Inputs: 95 | - distances, distance matrix [n times n], 96 | - k, number of nearest neighbours to consider 97 | Returns: 98 | - neighbourhood, contains the sample indices (from 0 to n-1) of kth nearest neighbor of current sample [n times k] 99 | - ranks, contains the rank of each sample to each sample [n times n], whereas entry (i,j) gives the rank that sample j has to i (the how many 
'closest' neighbour j is to i) 100 | """ 101 | k = self.k 102 | # Warning: this is only the ordering of neighbours that we need to 103 | # extract neighbourhoods below. The ranking comes later! 104 | indices = np.argsort(distances, axis=-1, kind="stable") 105 | 106 | # Extract neighbourhoods. 107 | neighbourhood = indices[:, 1 : k + 1] 108 | 109 | # Convert this into ranks (finally) 110 | ranks = indices.argsort(axis=-1, kind="stable") 111 | # print(ranks) 112 | 113 | return neighbourhood, ranks 114 | 115 | 116 | def _Distance_squared_GPU(self, x, y, cuda=7): 117 | 118 | x = torch.tensor(x).cuda() 119 | y = torch.tensor(y).cuda() 120 | m, n = x.size(0), y.size(0) 121 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 122 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 123 | dist = xx + yy 124 | dist = torch.addmm(dist, mat1=x, mat2=y.t(),beta=1, alpha=-2) 125 | 126 | d = dist.clamp(min=1e-36) 127 | return np.sqrt(d.detach().cpu().numpy()) 128 | 129 | def _Distance_squared_CPU(self, x, y): 130 | x = torch.tensor(x) 131 | y = torch.tensor(y) 132 | m, n = x.size(0), y.size(0) 133 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 134 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 135 | dist = xx + yy 136 | # dist.addmm_(1, -2, x, y.t()) 137 | dist = torch.addmm(dist, mat1=x, mat2=y.t(),beta=1, alpha=-2) 138 | d = dist.clamp(min=1e-36) 139 | return d.detach().cpu().numpy() 140 | 141 | def _trustworthiness(self, X_neighbourhood, X_ranks, Z_neighbourhood, Z_ranks, n, k): 142 | """ 143 | Calculates the trustworthiness measure between the data space `X` 144 | and the latent space `Z`, given a neighbourhood parameter `k` for 145 | defining the extent of neighbourhoods. 146 | """ 147 | 148 | result = 0.0 149 | 150 | # Calculate number of neighbours that are in the $k$-neighbourhood 151 | # of the latent space but not in the $k$-neighbourhood of the data 152 | # space. 
153 | for row in range(X_ranks.shape[0]): 154 | missing_neighbours = np.setdiff1d( 155 | Z_neighbourhood[row], X_neighbourhood[row] 156 | ) 157 | 158 | for neighbour in missing_neighbours: 159 | result += X_ranks[row, neighbour] - k 160 | 161 | return 1 - 2 / (n * k * (2 * n - 3 * k - 1)) * result 162 | 163 | 164 | def E_Classifacation_SVC(self): 165 | 166 | from sklearn.preprocessing import StandardScaler 167 | 168 | method = SVC(kernel="linear", max_iter=90000) 169 | cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=1) 170 | # if 171 | n_scores = cross_val_score( 172 | method, 173 | StandardScaler().fit_transform(self.latent), 174 | self.label.astype(np.int32), 175 | scoring="accuracy", 176 | cv=cv, 177 | n_jobs=-1, 178 | ) 179 | 180 | return n_scores.mean() 181 | 182 | 183 | def E_Classifacation_rbfSVC(self): 184 | 185 | from sklearn.preprocessing import StandardScaler 186 | 187 | method = SVC() 188 | cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=1) 189 | # if 190 | n_scores = cross_val_score( 191 | method, 192 | StandardScaler().fit_transform(self.latent), 193 | self.label.astype(np.int32), 194 | scoring="accuracy", 195 | cv=cv, 196 | n_jobs=-1, 197 | ) 198 | 199 | return n_scores.mean() 200 | 201 | def E_Curance(self, use_all_data=False): 202 | # self.E_Curance_pre() 203 | 204 | # if self.label 205 | # label_ = self.label 206 | if use_all_data: 207 | label_ = np.array([0]*self.neighbour_input.shape[0]) 208 | else: 209 | label_ = self.label 210 | 211 | # print(label_) 212 | 213 | path_list = Curance_path_list(self.neighbour_input, self.distance_input, label_) 214 | 215 | # print(path_list) 216 | alpha_list = [] 217 | for path in path_list: 218 | if len(path)>3: 219 | # print(path) 220 | for i in range(len(path)-3): 221 | a_index = path[0] 222 | b_index = path[i+1] 223 | c_index = path[-1] 224 | # print([a_index,b_index,c_index]) 225 | 226 | v1 = self.latent[b_index] - self.latent[a_index] 227 | v2 = self.latent[c_index] - 
self.latent[b_index] 228 | cos_alpha = v1.dot(v2)/(np.linalg.norm(v1) * np.linalg.norm(v2)) 229 | alpha = np.arccos(cos_alpha) 230 | alpha_list.append(alpha) 231 | # print( alpha_list ) 232 | # print( alpha_list ) 233 | alpha_list = np.array(alpha_list) 234 | alpha_list = alpha_list[~np.isnan(alpha_list)] 235 | return np.mean( alpha_list ) 236 | 237 | 238 | def E_Curance_2(self, use_all_data=False): 239 | # self.E_Curance_pre() 240 | 241 | # if self.label 242 | # label_ = self.label 243 | if use_all_data: 244 | label_ = np.array([0]*self.neighbour_input.shape[0]) 245 | else: 246 | label_ = self.label 247 | 248 | # print(label_) 249 | 250 | path_list = Curance_path_list(self.neighbour_input, self.distance_input, label_) 251 | 252 | # print(path_list) 253 | alpha_list = [] 254 | for path in path_list: 255 | if len(path)>3: 256 | # print(path) 257 | for i in range(len(path)-3): 258 | a_index = path[i] 259 | b_index = path[i+1] 260 | c_index = path[i+2] 261 | 262 | v1 = self.latent[b_index] - self.latent[a_index] 263 | v2 = self.latent[c_index] - self.latent[b_index] 264 | cos_alpha = v1.dot(v2)/(np.linalg.norm(v1) * np.linalg.norm(v2)) 265 | alpha = np.arccos(cos_alpha) 266 | alpha_list.append(alpha) 267 | # print( alpha_list ) 268 | # print( alpha_list ) 269 | alpha_list = np.array(alpha_list) 270 | alpha_list = alpha_list[~np.isnan(alpha_list)] 271 | return np.mean( alpha_list ) 272 | 273 | def TestClassifacationKMeans(self, embedding, label, n_clusters=None): 274 | 275 | 276 | # l1 = list(set(label)) 277 | # numclass1 = len(l1) 278 | # predict_labels = KMeans(n_clusters=numclass1, random_state=0).fit_predict(embedding) 279 | 280 | # l2 = list(set(predict_labels)) 281 | # numclass2 = len(l2) 282 | 283 | # cost = np.zeros((numclass1, numclass2), dtype=int) 284 | # for i, c1 in enumerate(l1): 285 | # mps = [i1 for i1, e1 in enumerate(label) if e1 == c1] 286 | # for j, c2 in enumerate(l2): 287 | # mps_d = [i1 for i1 in mps if predict_labels[i1] == c2] 288 | # 
cost[i][j] = len(mps_d) 289 | 290 | # # match two clustering results by Munkres algorithm 291 | # m = Munkres() 292 | # cost = cost.__neg__().tolist() 293 | 294 | # indexes = m.compute(cost) 295 | 296 | # # get the match results 297 | # new_predict = np.zeros(len(predict_labels)) 298 | # for i, c in enumerate(l1): 299 | # # correponding label in l2: 300 | # c2 = l2[indexes[i][1]] 301 | 302 | # # ai is the index with label==c2 in the pred_label list 303 | # ai = [ind for ind, elm in enumerate(predict_labels) if elm == c2] 304 | # new_predict[ai] = int(c) 305 | 306 | # acc = metrics.accuracy_score(label, new_predict) 307 | 308 | true = label 309 | class_num = len(np.unique(true)) 310 | # pred = SpectralClustering(n_clusters=class_num, assign_labels='discretize').fit_predict(ins_emb) 311 | pred = KMeans(n_clusters=class_num if n_clusters is None else n_clusters, random_state=0).fit_predict(embedding) 312 | 313 | cnt = len(true) 314 | cm = cnt - confusion_matrix(pred, true) 315 | idxs = Munkres().compute(cm) 316 | idxs = dict(idxs) 317 | for i, num in enumerate(pred): 318 | pred[i] = idxs[num] 319 | 320 | self.k_means_pre = pred 321 | 322 | acc = metrics.accuracy_score(label, pred) 323 | 324 | return acc #, nmi, f1_macro, precision_macro, adjscore 325 | 326 | 327 | def E_Clasting_Kmeans(self, n_clusters=None): 328 | 329 | # from sklearn.preprocessing import StandardScaler 330 | 331 | # method = SVC(kernel="linear", max_iter=90000) 332 | # cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=1) 333 | # if 334 | # n_scores = cross_val_score( 335 | # method, 336 | # StandardScaler().fit_transform(self.latent), 337 | # self.label.astype(np.int32), 338 | # scoring="accuracy", 339 | # cv=cv, 340 | # n_jobs=-1 341 | # ) 342 | return self.TestClassifacationKMeans(self.latent, self.label.astype(np.int32), n_clusters=n_clusters) 343 | 344 | def E_Clasting_louvain(self, n_clusters=None): 345 | 346 | true_label = self.label.astype(np.int32) 347 | embedding = 
self.latent 348 | class_num = np.max(true_label) + 1 349 | # pred = SpectralClustering(n_clusters=class_num, assign_labels='discretize').fit_predict(ins_emb) 350 | # pred = KMeans(n_clusters=class_num if n_clusters is None else n_clusters, random_state=0).fit_predict(embedding) 351 | pred = targeted_cluster(embedding, target_n_clusters=11) 352 | 353 | cnt = len(true_label) 354 | cm = cnt - confusion_matrix(pred, true_label) 355 | idxs = Munkres().compute(cm) 356 | idxs = dict(idxs) 357 | for i, num in enumerate(pred): 358 | pred[i] = idxs[num] 359 | 360 | self.louvain_pre = pred 361 | 362 | acc = metrics.accuracy_score(true_label, pred) 363 | 364 | return acc #, nmi, f1_macro, precision_macro, adjscore 365 | 366 | def E_Classifacation_KNN(self): 367 | 368 | from sklearn.neighbors import KNeighborsClassifier 369 | method = KNeighborsClassifier(n_neighbors=3) 370 | cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=1) 371 | # if 372 | n_scores = cross_val_score( 373 | method, self.latent, self.label.astype(np.int32), scoring="accuracy", cv=cv, n_jobs=-1 374 | ) 375 | 376 | return n_scores.mean() 377 | 378 | def E_NNACC(self): 379 | 380 | indexNN = self.neighbour_latent[:, 0].reshape(-1) 381 | labelNN = self.label[indexNN] 382 | acc = (self.label == labelNN).sum() / self.label.shape[0] 383 | 384 | return acc 385 | 386 | def E_mrre(self, ): 387 | """ 388 | Calculates the mean relative rank error quality metric of the data 389 | space `X` with respect to the latent space `Z`, subject to its $k$ 390 | nearest neighbours. 391 | """ 392 | 393 | k=self.k 394 | 395 | X_neighbourhood, X_ranks = self.neighbour_input, self.rank_input 396 | Z_neighbourhood, Z_ranks = self.neighbour_latent, self.rank_latent 397 | 398 | n = self.distance_input.shape[0] 399 | 400 | # First component goes from the latent space to the data space, i.e. 401 | # the relative quality of neighbours in `Z`. 
402 | 403 | mrre_ZX = 0.0 404 | for row in range(n): 405 | for neighbour in Z_neighbourhood[row]: 406 | rx = X_ranks[row, neighbour] 407 | rz = Z_ranks[row, neighbour] 408 | 409 | mrre_ZX += abs(rx - rz) / rz 410 | 411 | # Second component goes from the data space to the latent space, 412 | # i.e. the relative quality of neighbours in `X`. 413 | 414 | mrre_XZ = 0.0 415 | for row in range(n): 416 | # Note that this uses a different neighbourhood definition! 417 | for neighbour in X_neighbourhood[row]: 418 | rx = X_ranks[row, neighbour] 419 | rz = Z_ranks[row, neighbour] 420 | 421 | # Note that this uses a different normalisation factor 422 | mrre_XZ += abs(rx - rz) / rx 423 | 424 | # Normalisation constant 425 | C = n * sum([abs(2 * j - n - 1) / j for j in range(1, k + 1)]) 426 | return mrre_ZX / C, mrre_XZ / C 427 | 428 | def E_distanceAUC(self,): 429 | 430 | disZN = (self.distance_latnet-self.distance_latnet.min())/(self.distance_latnet.max()-self.distance_latnet.min()) 431 | LRepeat = self.label.reshape(1,-1).repeat(self.distance_latnet.shape[0], axis=0) 432 | L = (LRepeat==LRepeat.T).reshape(-1) 433 | auc = metrics.roc_auc_score(1-L, disZN.reshape(-1)) 434 | 435 | return auc 436 | 437 | def E_trustworthiness(self): 438 | X_neighbourhood, X_ranks = self.neighbour_input, self.rank_input 439 | Z_neighbourhood, Z_ranks = self.neighbour_latent, self.rank_latent 440 | n = self.distance_input.shape[0] 441 | return self._trustworthiness( 442 | X_neighbourhood, X_ranks, Z_neighbourhood, Z_ranks, n, self.k 443 | ) 444 | 445 | def E_continuity(self): 446 | """ 447 | Calculates the continuity measure between the data space `X` and the 448 | latent space `Z`, given a neighbourhood parameter `k` for setting up 449 | the extent of neighbourhoods. 450 | 451 | This is just the 'flipped' variant of the 'trustworthiness' measure. 
452 | """ 453 | 454 | X_neighbourhood, X_ranks = self.neighbour_input, self.rank_input 455 | Z_neighbourhood, Z_ranks = self.neighbour_latent, self.rank_latent 456 | n = self.distance_input.shape[0] 457 | # Notice that the parameters have to be flipped here. 458 | return self._trustworthiness( 459 | Z_neighbourhood, Z_ranks, X_neighbourhood, X_ranks, n, self.k 460 | ) 461 | 462 | def E_Rscore(self): 463 | # n = self.distance_input.shape[0] 464 | import scipy 465 | r = scipy.stats.pearsonr(self.distance_input.reshape(-1), self.distance_latnet.reshape(-1)) 466 | # print(r) 467 | return r[0] 468 | 469 | def E_Dismatcher(self): 470 | emb, label = self.latent, self.label 471 | list_dis = [] 472 | for i in list(set(label)): 473 | p = emb[label==i] 474 | m = p.mean(axis=0)[None,:] 475 | list_dis.append(pairwise_distances(p, m).mean()) 476 | list_dis = np.array(list_dis) 477 | list_dis_norm=list_dis/list_dis.max() 478 | sort1 = np.argsort(list_dis_norm) 479 | # print('latent std:', list_dis_norm) 480 | # print('latent sort:', sort1) 481 | 482 | emb, label = self.input, self.label 483 | emb = emb.reshape(emb.shape[0],-1) 484 | list_dis = [] 485 | for i in list(set(label)): 486 | p = emb[label==i] 487 | m = p.mean(axis=0)[None,:] 488 | list_dis.append(pairwise_distances(p, m).mean()) 489 | list_dis = np.array(list_dis) 490 | list_dis_norm=list_dis/list_dis.max() 491 | sort2 = np.argsort(list_dis_norm) 492 | # print('latent std:', list_dis_norm) 493 | # print('latent sort:', sort2) 494 | 495 | 496 | v, s, t = 0, sort2.tolist(), sort1.tolist() 497 | for i in range(len(t)): 498 | if t[i] != s[i]: 499 | v = v + abs(t.index(s[i])-i) 500 | s_constant = (2.0/len(s)**2) 501 | 502 | return v * s_constant -------------------------------------------------------------------------------- /dataloader/stdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import warnings 4 | import numpy as np 5 | import scanpy as sc 
6 | from tqdm import tqdm, trange 7 | 8 | import PIL 9 | from PIL import Image 10 | 11 | from sklearn.preprocessing import MinMaxScaler 12 | from sklearn.decomposition import PCA 13 | 14 | 15 | class STData(): 16 | def __init__(self, name="V1_Adult_Mouse_Brain", sample='barcode', crop_size=None, use_quality="hires", use_adata=None, bio_norm=True, force_no_morph=False, 17 | gene_dim=1000, gene_decompn_method="HVG", seed=0, datapath="data/", tmppath="save_processed_data/", resize_as=None, true_wildcard=999): 18 | # Globals 19 | PIL.Image.MAX_IMAGE_PIXELS = 30000 * 30000 20 | self.seed = seed 21 | self.name = name 22 | self.force_no_morph = force_no_morph 23 | self.use_quality = use_quality 24 | self.true_wildcard = true_wildcard 25 | self.platform = '10x' 26 | 27 | self.datapath = datapath 28 | self.tmppath = tmppath 29 | os.makedirs(self.datapath, exist_ok=True) 30 | os.makedirs(self.tmppath, exist_ok=True) 31 | 32 | # pack dataset from existing adata 33 | if use_adata is None: 34 | if name in ["190921_19", "191007_07", "200127_15", "190926_06", "190926_02", "191204_01", 35 | "190921_21", "190926_01", "200306_03", "190926_03", "200115_08"]: 36 | self.platform = 'slide' 37 | self.adata = sc.read(f"{self.datapath}{name}/data.h5ad") 38 | elif name in ['MOB', 'MOB_3000', 'ME95', 'ME145', 'ME145_3000']: 39 | self.platform = 'stereo' 40 | self.adata = sc.read(f"{self.datapath}{name}/data.h5ad") 41 | else: 42 | try: 43 | self.adata = sc.read_visium(self.datapath + name) 44 | except FileNotFoundError: 45 | try: 46 | self.adata = sc.datasets.visium_sge(self.name, 47 | include_hires_tiff=True if use_quality == "fulres" else False) 48 | except urllib.error.HTTPError: 49 | raise NotImplementedError("Non-Implmented Dataset! 
You can put your own `self.adata` here") 50 | self.adata.var_names_make_unique() 51 | else: 52 | self.adata = use_adata 53 | 54 | # ditch non-mark spots if available 55 | if true_wildcard is not None and self.get_label() is not None: 56 | self.true = self.get_label() 57 | self.mask = self.true != true_wildcard 58 | if len(self.adata) != sum(self.mask): 59 | self.adata = self.adata[self.mask, :] 60 | self.true = self.true[self.mask] 61 | 62 | # sampling 63 | np.random.seed(self.seed) 64 | if isinstance(sample, int) and sample < len(self.adata): 65 | if sample == 15000: 66 | print("default sampling!") 67 | sc.pp.filter_cells(self.adata, min_genes=0) 68 | n_genes = self.adata.obs['n_genes'].to_numpy() 69 | idxs = np.argsort(n_genes) 70 | self.selected_idx = idxs[-sample:] 71 | self.adata = self.adata[self.selected_idx] 72 | elif isinstance(sample, str) and sample == 'barcode': 73 | try: 74 | selected = np.load(f"{self.datapath}{name}/selected_spots.npy") 75 | self.adata = self.adata[selected] 76 | except FileNotFoundError: 77 | warnings.warn('Try to use selected spots, but not found!') 78 | elif isinstance(sample, str) and sample.startswith('uni'): 79 | idxs = np.arange(len(self.adata)) 80 | np.random.shuffle(idxs) 81 | self.selected_idx = idxs[:eval(sample[3:])] 82 | self.adata = self.adata[self.selected_idx] 83 | 84 | # Gene-related 85 | self.gene_dim = gene_dim 86 | self.gene_decompn_method = gene_decompn_method 87 | 88 | if bio_norm: 89 | self.default_norm() 90 | 91 | # Image-related 92 | self.coords = np.asarray(self.adata.obsm["spatial"], dtype=np.float32) 93 | self.x_lim = (int(min(self.coords[:, 0])), int(max(self.coords[:, 0]))) 94 | self.y_lim = (int(max(self.coords[:, 1])), int(min(self.coords[:, 1]))) 95 | if self.platform == '10x': 96 | self.resize_as = resize_as 97 | self.crop_size = 224 if crop_size is None else crop_size 98 | 99 | if self.use_quality == "fulres": 100 | path = f"{self.datapath}/{self.name}/image.tif" 101 | self.image = 
np.asarray(Image.open(path), dtype=np.float32) 102 | self.image /= 255 103 | else: 104 | self.image = self.adata.uns["spatial"][name]["images"][use_quality] 105 | self.x_lim = (0, self.image.shape[1]) 106 | self.y_lim = (self.image.shape[0], 0) 107 | 108 | if use_quality != "fulres": 109 | self.coords *= self.adata.uns["spatial"][name]["scalefactors"][f"tissue_{use_quality}_scalef"] 110 | 111 | def default_norm(self): 112 | sc.pp.filter_genes(self.adata, min_cells=1) 113 | sc.pp.normalize_total(self.adata) 114 | sc.pp.log1p(self.adata) 115 | if self.gene_dim != -1 and self.gene_decompn_method == "HVG": 116 | sc.pp.highly_variable_genes(self.adata, n_top_genes=self.gene_dim) 117 | self.adata = self.adata[:, self.adata.var.highly_variable] 118 | 119 | def get_morph(self, use_saved_data=True): 120 | if self.platform in ['slide', 'stereo'] or self.force_no_morph: 121 | return None 122 | 123 | path = self.tmppath + f"morph_{self.name}_{self.use_quality}_{self.crop_size}_{self.resize_as}_{self.seed}_{self.true_wildcard}_{len(self.adata)}_ver230509.npy" 124 | if os.path.exists(path) and use_saved_data: 125 | return np.load(path, allow_pickle=True) 126 | 127 | import torch 128 | from torchvision.models import resnet50 129 | torch.manual_seed(self.seed) 130 | 131 | model = resnet50(pretrained=True) 132 | model.fc = torch.nn.Identity() 133 | model.cuda() 134 | model.eval() 135 | 136 | image = self.image 137 | morph_feature = [] 138 | if len(self.image.shape) == 2: 139 | image = self.image[:, :, np.newaxis] 140 | model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 141 | bias=False).cuda() 142 | for coord in tqdm(self.coords, desc="Extracting Morphological Feature"): 143 | left = max(int(coord[0] - self.crop_size / 2), 0) 144 | up = max(int(coord[1] - self.crop_size / 2), 0) 145 | # Approximate Upperbound 146 | right = min(left + self.crop_size, self.y_lim[0]) 147 | down = min(up + self.crop_size, self.x_lim[1]) 148 | pic = image[up:down, left:right] 149 | 
if self.resize_as is not None: 150 | from PIL import Image 151 | pic = Image.fromarray(np.uint8(pic * 255)) 152 | pic = pic.resize((self.resize_as, self.resize_as)) 153 | pic = np.asarray(pic).astype(np.float32) / 255 154 | pic = np.transpose(pic, (2, 0, 1)) 155 | feature = model(torch.tensor(np.expand_dims(pic, axis=0)).cuda()) 156 | morph_feature.append(feature.detach().cpu().numpy().squeeze()) 157 | 158 | morph_feature = PCA(n_components=500, random_state=self.seed).fit_transform(morph_feature) 159 | 160 | if use_saved_data: 161 | np.save(path, morph_feature.astype(np.float32)) 162 | 163 | return morph_feature.astype(np.float32) 164 | 165 | def get_spot_index(self): 166 | return self.adata.obs[["array_col", "array_row"]].to_numpy() 167 | 168 | def get_trans(self): 169 | try: 170 | trans = np.array(self.adata.X.todense(), dtype=np.float32) 171 | except AttributeError: # raised by non-sparse matrix 172 | trans = np.asarray(self.adata.X, dtype=np.float32) 173 | 174 | if self.gene_decompn_method == "PCA": 175 | trans = PCA(n_components=self.gene_dim, random_state=self.seed).fit_transform(trans) 176 | 177 | return trans 178 | 179 | def get_image(self): 180 | return self.image 181 | 182 | def get_annotation_class(self, suffix="biofeature"): 183 | import pandas as pd 184 | 185 | data = pd.read_csv("annotation/dataset_setting.csv", index_col="dataset") 186 | data.index = data.index.astype(str) 187 | try: 188 | data = data[data["suffix"] == suffix].loc[self.name] 189 | except KeyError: 190 | return None 191 | 192 | return int(data["class_num"]) 193 | 194 | def get_label(self, suffix="biofeature", use_saved=True, refine=None): 195 | if hasattr(self, 'true'): 196 | return self.true 197 | 198 | path = f"annotation/{self.name}_{suffix}" 199 | if use_saved: 200 | try: 201 | return np.load(path + ".npy") 202 | except FileNotFoundError: 203 | warnings.warn("Try to use true label, but not found!") 204 | 205 | from sklearn.cluster import KMeans 206 | 207 | # make true label from 
annotation 208 | try: 209 | im = np.array(Image.open(path + ".png")) 210 | except: 211 | return None 212 | 213 | n_clusters = self.get_annotation_class(suffix=suffix) 214 | coords = self.adata.obsm["spatial"].astype(np.float32) 215 | coords *= self.adata.uns["spatial"][self.name]["scalefactors"]["tissue_hires_scalef"] 216 | labels = [] 217 | for coord in coords: 218 | val = im[int(coord[1]), int(coord[0])] 219 | val = val[0] + val[1] * 255 + val[2] * 255 * 255 + val[3] * 255 * 255 * 255 220 | labels.append([val]) 221 | labels = np.array(labels) 222 | labels = KMeans(n_clusters=n_clusters, random_state=self.seed).fit_predict(labels) 223 | 224 | if refine == "hexagon": 225 | from sklearn.neighbors import NearestNeighbors 226 | 227 | nbrs = NearestNeighbors(radius=30).fit(self.get_coords()) 228 | idxs = nbrs.radius_neighbors(self.get_coords(), return_distance=False) 229 | 230 | for i, vec in enumerate(idxs): 231 | vec_label = np.vectorize(labels.__getitem__)(vec) 232 | if np.sum(vec_label == labels[i]) <= 1: 233 | labels[i] = np.argmax(np.bincount(vec_label)) 234 | 235 | np.save(path + ".npy", labels) 236 | 237 | return np.array(labels) 238 | 239 | def get_coords(self): 240 | return np.array(self.coords) 241 | 242 | def filter_sect(self, x1, x2, y1, y2): 243 | xs = self.coords[:, 0] 244 | ys = self.coords[:, 1] 245 | m_x = (x1 < xs) & (xs < x2) 246 | m_y = (y1 < ys) & (ys < y2) 247 | mask = m_x & m_y 248 | self.coords = self.coords[mask] 249 | self.adata = self.adata[mask, :] 250 | 251 | return mask 252 | 253 | def filter_region(self, label, regions): 254 | mask = np.array([False] * len(self.adata)) 255 | for r in regions: 256 | mask |= (label == r) 257 | self.coords = self.coords[mask] 258 | self.adata = self.adata[mask, :] 259 | 260 | return mask 261 | 262 | def find_HVG(self, label:np.ndarray, num_gene=50, flavor='seurat_v3'): 263 | res = {} 264 | for target_cluster in np.unique(label): 265 | uni_cluster_adata = self.adata[label == target_cluster, :] 266 | 
sc.pp.highly_variable_genes(uni_cluster_adata, n_top_genes=num_gene, inplace=True, flavor=flavor) 267 | res[target_cluster] = np.asarray(uni_cluster_adata[:, uni_cluster_adata.var.highly_variable].var_names).astype(str) 268 | 269 | return res 270 | 271 | def get_related_gene(self, label, c, ret_val=False, norm=False): 272 | gene = self.get_trans() 273 | direct = np.ones(len(self.adata)) 274 | direct[label != c] *= -1 275 | val = np.sum(gene.T * direct, axis=1) 276 | if norm: 277 | val /= np.sum(gene, axis=0) 278 | idxs = np.flip(np.argsort(val)) 279 | names = np.asarray(self.adata.var_names).astype(str) 280 | genes = names[idxs] 281 | 282 | if ret_val: 283 | return genes, val[idxs] 284 | else: 285 | return genes 286 | 287 | def get_plot_bg_image(self): 288 | bg_path = f'background_image/{self.name}.png' 289 | if os.path.exists(bg_path): 290 | return np.asarray(Image.open(bg_path), dtype=np.float32) / 255 291 | else: 292 | return self.image 293 | 294 | def px_plot_spatial(self, label: np.ndarray, background_image=False, save_path=None, base=None, dpi=400, title=None, opacity=None): 295 | import plotly.express as px 296 | import plotly.graph_objects as go 297 | 298 | if self.platform != '10x' and background_image: 299 | warnings.warn("no bg image for non-10x platform") 300 | background_image = False 301 | 302 | if opacity is None: 303 | opacity = .3 if background_image else 1 304 | title = f"{len(np.unique(label))} classes" if title is None else title 305 | 306 | # Save Figure 307 | if save_path is not None: 308 | import matplotlib.pyplot as plt 309 | s = 4 if self.platform == '10x' else 1 310 | if self.platform == '10x': 311 | figsize = (5, 5) 312 | elif self.platform == 'stereo': 313 | ratio = np.abs(self.y_lim[0] - self.y_lim[1]) / np.abs(self.x_lim[0] - self.x_lim[1]) 314 | if base is None and self.name == 'ME145': 315 | base = 10 316 | elif self.name == 'ME95': 317 | base = 3 318 | elif self.name == 'MOB': 319 | base = 6 320 | figsize = (base, base * ratio) 321 
| else: 322 | figsize = (7, 7) 323 | 324 | fig_plt, ax = plt.subplots(figsize=figsize, dpi=dpi) 325 | ax.set_xlim([self.x_lim[0], self.x_lim[1]]) 326 | ax.set_ylim([self.y_lim[0], self.y_lim[1]]) 327 | if background_image: 328 | ax.imshow(self.get_plot_bg_image()) 329 | ax.scatter(x=self.coords[:, 0], y=self.coords[:, 1], c=self._map_color(label), s=s, alpha=opacity) 330 | plt.axis('off') 331 | 332 | fig_plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 333 | plt.close() 334 | 335 | if background_image: 336 | if self.use_quality == "fulres": 337 | warnings.warn("Plotting with fulres may lead to a figure taking huge memory.") 338 | fig_plotly = px.imshow(self.get_plot_bg_image()) 339 | else: 340 | fig_plotly = go.Figure() 341 | 342 | for i in np.unique(label): 343 | mask = label == i 344 | idxs = np.arange(len(label))[mask] 345 | fig_plotly.add_trace(go.Scatter( 346 | mode='markers', 347 | x=self.coords[mask, 0], 348 | y=self.coords[mask, 1], 349 | opacity=opacity, 350 | marker=dict( 351 | color=self._map_color(label[mask]), 352 | size= 10 if self.platform == '10x' else 2 353 | ), 354 | hovertemplate="%{text}", 355 | text=[f"Spot: {self.adata.obs.index[i]}
" 356 | f"x: {self.coords[i, 0]:.2f}
" 357 | f"y: {self.coords[i, 1]:.2f}
" 358 | f"index: {i}" 359 | for i in idxs], 360 | showlegend=True, 361 | name=str(i) 362 | )) 363 | fig_plotly.update_layout({ 364 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 365 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 366 | }) 367 | fig_plotly.update_layout( 368 | title=title, 369 | xaxis_range=[self.x_lim[0], self.x_lim[1]], 370 | yaxis_range=[self.y_lim[0], self.y_lim[1]] 371 | ) 372 | 373 | return fig_plotly 374 | 375 | 376 | def px_plot_spatial_gene(self, gene, background_image=False, save_path=None, style='sequential', title=None, dpi=400): 377 | import plotly.express as px 378 | import plotly.graph_objects as go 379 | 380 | opacity = .5 if background_image else 1 381 | if isinstance(gene, str): 382 | try: 383 | data = np.asarray(self.adata[:, gene].X.todense()).ravel().astype(np.float32) 384 | except AttributeError: 385 | data = np.asarray(self.adata[:, gene].X).ravel().astype(np.float32) 386 | data /= data.max() 387 | else: 388 | data = gene 389 | gene = 'user_defined' 390 | 391 | title = gene if title is None else title 392 | 393 | # Save Figure 394 | if save_path is not None: 395 | import matplotlib.pyplot as plt 396 | s = 4 if self.platform == '10x' else 1 397 | if self.platform == '10x': 398 | figsize = (5, 5) 399 | elif self.platform == 'stereo': 400 | ratio = np.abs(self.y_lim[0] - self.y_lim[1]) / np.abs(self.x_lim[0] - self.x_lim[1]) 401 | if self.name == 'ME145': 402 | base = 8 403 | elif self.name == 'ME95': 404 | base = 3 405 | elif self.name == 'MOB': 406 | base = 6 407 | figsize = (base, base * ratio) 408 | else: 409 | figsize = (7, 7) 410 | 411 | fig_plt, ax = plt.subplots(figsize=figsize, dpi=dpi) 412 | ax.set_xlim([self.x_lim[0], self.x_lim[1]]) 413 | ax.set_ylim([self.y_lim[0], self.y_lim[1]]) 414 | if background_image: 415 | ax.imshow(self.get_plot_bg_image()) 416 | ax.scatter(x=self.coords[:, 0], y=self.coords[:, 1], c=data, s=s, alpha=opacity, cmap='RdPu' if style=='sequential' else 'seismic') # seismic 417 | plt.axis('off') 418 | 419 | 
fig_plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 420 | plt.close() 421 | 422 | if background_image: 423 | if self.use_quality == "fulres": 424 | warnings.warn("Plotting with fulres may lead to a figure taking huge memory.") 425 | fig_plotly = px.imshow(self.get_plot_bg_image()) 426 | else: 427 | fig_plotly = go.Figure() 428 | 429 | idxs = np.arange(len(data)) 430 | fig_plotly.add_trace(go.Scatter( 431 | mode='markers', 432 | x=self.coords[:, 0], 433 | y=self.coords[:, 1], 434 | opacity=opacity, 435 | marker=dict( 436 | color=data, 437 | colorscale="RdPu" if style=='sequential' else 'RdBu', # sequential RdPu diverging RdBu 438 | ), 439 | hovertemplate="%{text}", 440 | text=[f"Spot: {self.adata.obs.index[i]}
" 441 | f"Expr_ratio: {data[i]:.2f}
" 442 | f"x: {self.coords[i, 0]:.2f}
" 443 | f"y: {self.coords[i, 1]:.2f}
" 444 | f"index: {i}" 445 | for i in idxs], 446 | showlegend=True, 447 | name=gene, 448 | )) 449 | fig_plotly.update_layout({ 450 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 451 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 452 | }) 453 | fig_plotly.update_layout( 454 | title=title, 455 | xaxis_range=[0, self.x_lim], 456 | yaxis_range=[self.y_lim, 0] 457 | ) 458 | 459 | return fig_plotly 460 | 461 | 462 | def px_plot_spatial_massive_gene(self, genes, col_num=7, figsize=400, markersize=2, save_path=None, title=None, background_image=False, verbose=False): 463 | import plotly.graph_objects as go 464 | from plotly.subplots import make_subplots 465 | 466 | opacity = .5 if background_image else 1 467 | 468 | try: 469 | data = self.adata[:, genes].X.todense() 470 | except AttributeError: 471 | data = self.adata[:, genes].X 472 | data = np.asarray(data, dtype=np.float32) 473 | data /= data.max() 474 | 475 | row_num = np.ceil(len(genes)/col_num).astype(int).item() 476 | fig_plotly = make_subplots(rows=row_num, cols=col_num, 477 | subplot_titles=genes) 478 | 479 | if verbose: 480 | tr = trange(data.shape[1]) 481 | else: 482 | tr = range(data.shape[1]) 483 | for i in tr: 484 | fig_plotly.add_trace(go.Scatter( 485 | mode='markers', 486 | x=self.coords[:, 0], 487 | y=self.coords[:, 1], 488 | opacity=opacity, 489 | marker=dict( 490 | color=data[:, i], 491 | colorscale="Magenta", 492 | ), 493 | hovertemplate="%{text}", 494 | text=[f"Spot: {spot}
x: {coord[0]:.2f}
y: {coord[1]:.2f}" 495 | for spot, coord in zip(np.array(self.adata.obs.index), self.coords)], 496 | showlegend=True, 497 | name=genes[i], 498 | ), row=i // col_num + 1, col=i % col_num + 1) 499 | fig_plotly.update_layout({ 500 | "plot_bgcolor": "#ffffff", 501 | "paper_bgcolor": "#ffffff", 502 | }) 503 | fig_plotly.update_layout( 504 | width=figsize * col_num, 505 | height=figsize * row_num, 506 | title_text=title, 507 | ) 508 | fig_plotly.update_xaxes(visible=False) 509 | fig_plotly.update_yaxes(visible=False) 510 | 511 | if save_path is not None: 512 | if save_path.endswith('html'): 513 | fig_plotly.write_html(save_path) 514 | else: 515 | fig_plotly.write_image(save_path, scale=4) 516 | 517 | return fig_plotly 518 | 519 | 520 | def px_plot_graph_structure(self, edge_list: np.ndarray, background_image=False, title=None): 521 | import plotly.express as px 522 | import plotly.graph_objects as go 523 | 524 | lines = self.coords[edge_list] 525 | 526 | if background_image: 527 | if self.use_quality == "fulres": 528 | warnings.warn("Plotting with fulres may lead to a huge figure while taking huge memory.") 529 | fig_plotly = px.imshow(self.get_plot_bg_image()) 530 | else: 531 | fig_plotly = go.Figure() 532 | 533 | x = lines[:, :, 0].reshape(-1) 534 | y = lines[:, :, 1].reshape(-1) 535 | 536 | x = np.insert(x, np.arange(2, len(x), 2), np.nan) 537 | y = np.insert(y, np.arange(2, len(y), 2), np.nan) 538 | 539 | fig_plotly.add_trace(go.Scatter( 540 | mode='lines', 541 | x=x, 542 | y=y, 543 | line=dict( 544 | # color="black", 545 | ) 546 | )) 547 | fig_plotly.update_layout({ 548 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 549 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 550 | }) 551 | fig_plotly.update_layout( 552 | title=title, 553 | xaxis_range=[self.x_lim[0], self.x_lim[1]], 554 | yaxis_range=[self.y_lim[0], self.y_lim[1]] 555 | ) 556 | 557 | return fig_plotly 558 | 559 | 560 | def px_plot_embedding(self, latent, label, method="UMAP", save_path=None, title=None): 561 | import 
plotly.graph_objects as go 562 | 563 | title = f"{len(np.unique(label))} classes" if title is None else title 564 | 565 | if latent.shape[1] == 2: 566 | embedding = latent 567 | elif method == "UMAP": 568 | from umap import UMAP 569 | embedding = UMAP(n_components=2, random_state=self.seed).fit_transform(latent) 570 | elif method == "TSNE": 571 | from sklearn.manifold import TSNE 572 | embedding = TSNE(n_components=2, random_state=self.seed).fit_transform(latent) 573 | else: 574 | raise ValueError(f"{method} not valid!") 575 | 576 | # Save Figure 577 | if save_path is not None: 578 | import matplotlib.pyplot as plt 579 | figsize = (7, 7) if self.platform == '10x' else (10, 10) 580 | 581 | fig_plt, ax = plt.subplots(figsize=figsize, dpi=400) 582 | ax.scatter(x=embedding[:, 0], y=embedding[:, 1], c=self._map_color(label), s=1) 583 | plt.axis('off') 584 | 585 | fig_plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 586 | 587 | fig_plotly = go.Figure() 588 | for i in np.unique(label): 589 | mask = label == i 590 | fig_plotly.add_trace(go.Scatter( 591 | mode='markers', 592 | x=embedding[mask, 0], 593 | y=embedding[mask, 1], 594 | marker=dict( 595 | color=self._map_color(label[mask]), 596 | size=3 597 | ), 598 | hovertemplate="%{text}", 599 | text=[f"Spot: {spot}
x: {coord[0]:.2f}
y: {coord[1]:.2f}" 600 | for spot, coord in zip(np.array(self.adata.obs.index)[mask], self.coords[mask])], 601 | showlegend=True, 602 | name=str(i) 603 | )) 604 | fig_plotly.update_layout({ 605 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 606 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 607 | }) 608 | fig_plotly.update_layout( 609 | title=title, 610 | ) 611 | 612 | return fig_plotly 613 | 614 | 615 | def px_plot_embedding_gene(self, latent, gene, method="UMAP", save_path=None, title=None): 616 | import plotly.graph_objects as go 617 | 618 | title = gene if title is None else title 619 | data = np.asarray(self.adata[:, gene].X.todense()).ravel().astype(np.float32) 620 | data /= data.max() 621 | 622 | if latent.shape[1] == 2: 623 | embedding = latent 624 | elif method == "UMAP": 625 | from umap import UMAP 626 | embedding = UMAP(n_components=2, random_state=self.seed).fit_transform(latent) 627 | elif method == "TSNE": 628 | from sklearn.manifold import TSNE 629 | embedding = TSNE(n_components=2, random_state=self.seed).fit_transform(latent) 630 | else: 631 | raise ValueError(f"{method} not valid!") 632 | 633 | # Save Figure 634 | if save_path is not None: 635 | import matplotlib.pyplot as plt 636 | 637 | fig_plt, ax = plt.subplots(figsize=(5, 5), dpi=400) 638 | ax.scatter(x=embedding[:, 0], y=embedding[:, 1], c=data, s=2) 639 | plt.axis('off') 640 | 641 | fig_plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 642 | 643 | fig_plotly = go.Figure() 644 | fig_plotly.add_trace(go.Scatter( 645 | mode='markers', 646 | x=embedding[:, 0], 647 | y=embedding[:, 1], 648 | marker=dict( 649 | color=data, 650 | colorscale="PuRd", 651 | ), 652 | hovertemplate="%{text}", 653 | text=[f"Spot: {spot}
x: {coord[0]:.2f}
y: {coord[1]:.2f}" 654 | for spot, coord in zip(np.array(self.adata.obs.index), self.coords)], 655 | showlegend=True, 656 | name=gene 657 | )) 658 | fig_plotly.update_layout({ 659 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 660 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 661 | }) 662 | fig_plotly.update_layout( 663 | title=title, 664 | ) 665 | 666 | return fig_plotly 667 | 668 | def px_plot_embedding_massive_gene(self, latent, genes, col_num=7, figsize=400, markersize=2, method="TSNE", save_path=None, title=None): 669 | import plotly.graph_objects as go 670 | from plotly.subplots import make_subplots 671 | 672 | data = np.asarray(self.adata[:, genes].X.todense()).astype(np.float32) 673 | data /= data.max() 674 | 675 | if latent.shape[1] == 2: 676 | embedding = latent 677 | elif method == "UMAP": 678 | from umap import UMAP 679 | embedding = UMAP(n_components=2, random_state=self.seed).fit_transform(latent) 680 | elif method == "TSNE": 681 | from sklearn.manifold import TSNE 682 | embedding = TSNE(n_components=2, random_state=self.seed).fit_transform(latent) 683 | else: 684 | raise ValueError(f"{method} not valid!") 685 | 686 | row_num = np.ceil(len(genes)/col_num).astype(int).item() 687 | fig_plotly = make_subplots(rows=row_num, cols=col_num, 688 | subplot_titles=genes) 689 | for i in range(data.shape[1]): 690 | fig_plotly.add_trace(go.Scatter( 691 | mode='markers', 692 | x=embedding[:, 0], 693 | y=embedding[:, 1], 694 | marker=dict( 695 | color=data[:, i], 696 | colorscale="Magenta", 697 | size=markersize, 698 | ), 699 | hovertemplate="%{text}", 700 | text=[f"Spot: {spot}
x: {coord[0]:.2f}
y: {coord[1]:.2f}" 701 | for spot, coord in zip(np.array(self.adata.obs.index), self.coords)], 702 | showlegend=True, 703 | name=genes[i] 704 | ), row=i // col_num + 1, col=i % col_num + 1) 705 | fig_plotly.update_layout({ 706 | "plot_bgcolor": "#ffffff", 707 | "paper_bgcolor": "#ffffff", 708 | }) 709 | fig_plotly.update_layout( 710 | width=figsize * col_num, 711 | height=figsize * row_num, 712 | title_text=title, 713 | ) 714 | fig_plotly.update_xaxes(visible=False) 715 | fig_plotly.update_yaxes(visible=False) 716 | 717 | if save_path is not None: 718 | if save_path.endswith('html'): 719 | fig_plotly.write_html(save_path) 720 | else: 721 | fig_plotly.write_image(save_path, scale=4) 722 | 723 | return fig_plotly 724 | 725 | 726 | def _map_color(self, label: np.ndarray, style='p_d24') -> list: 727 | if style == 'bright': 728 | color_scheme = ['#FD3216', '#00FE35', '#6A76FC', '#FED4C4', 729 | '#FE00CE', '#0DF9FF', '#F6F926', '#FF9616', 730 | '#479B55', '#EEA6FB', '#DC587D', '#D626FF', 731 | '#6E899C', '#00B5F7', '#B68E00', '#C9FBE5', 732 | '#FF0092', '#22FFA7', '#E3EE9E', '#86CE00', 733 | '#BC7196', '#7E7DCD', '#FC6955', '#E48F72'] 734 | elif style == 'dark': 735 | color_scheme = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2", 736 | "#59a14f", "#edc949", "#af7aa1", "#ff9da7", 737 | "#9c755f", "#bab0ab", "#1b9e77", "#d95f02", 738 | "#7570b3", "#e7298a", "#66a61e", "#e6ab02", 739 | "#a6761d", "#666666"] 740 | elif style == 'p_d24': 741 | color_scheme = ['#2E91E5', '#E15F99', '#1CA71C', '#FB0D0D', 742 | '#DA16FF', '#222A2A', '#B68100', '#750D86', 743 | '#EB663B', '#511CFB', '#00A08B', '#FB00D1', 744 | '#FC0080', '#B2828D', '#6C7C32', '#778AAE', 745 | '#862A16', '#A777F1', '#620042', '#1616A7', 746 | '#DA60CA', '#6C4516', '#0D2A63', '#AF0038', 747 | '#D3D3D3'] 748 | 749 | return [color_scheme[i % len(color_scheme)] for i in label] 750 | 751 | def __len__(self): 752 | return len(self.adata) 753 | --------------------------------------------------------------------------------