├── Img
│   └── architecture.png
├── README.md
├── argument.py
├── embedder.py
├── main.py
├── misc
│   ├── graph_construction.py
│   └── utils.py
├── models
│   ├── __init__.py
│   └── scFP.py
└── run.sh

/Img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junseok0207/scFP/b982d91b389f74b803d1300c3328d193e958f9fc/Img/architecture.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Single-cell RNA-seq data imputation using Feature Propagation

<img src="Img/architecture.png" alt="scFP architecture" width="100%"/>

The official source code for "Single-cell RNA-seq data imputation using Feature Propagation", accepted at the 2023 ICML Workshop on Computational Biology (Contributed talk, Best Paper).

## Overview

While single-cell RNA sequencing provides an understanding of the transcriptome of individual cells, its high sparsity, often termed dropout, hampers the capture of significant cell-cell relationships. Here, we propose scFP (single-cell Feature Propagation), which directly propagates features, i.e., gene expression, in the raw feature space via a cell-cell graph. Specifically, scFP first obtains a warmed-up cell-gene matrix via Hard Feature Propagation, which fully utilizes the known gene transcripts. It then refines the k-Nearest Neighbor (kNN) cell-cell graph with the warmed-up cell-gene matrix and applies Soft Feature Propagation, which allows the known gene transcripts themselves to be further denoised through their neighbors. Through extensive experiments on imputation and cell clustering tasks, we demonstrate that our proposed model, scFP, outperforms various recent imputation and clustering methods.
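Both stages iterate the same diffusion step over the row-normalized kNN graph; they differ only in how known entries are treated. A minimal sketch of the update (illustrative only; the actual implementation is `FeaturePropagation` in `models/scFP.py`, which uses a sparse adjacency — here a dense one is assumed for brevity):

```python
import torch

def feature_propagation(x, adj, n_iters, alpha=0.0, hard=True):
    # x: (cells x genes) expression matrix, adj: row-normalized cell-cell kNN graph
    observed = x != 0
    out, res = x.clone(), (1 - alpha) * x
    for _ in range(n_iters):
        out = adj @ out                      # average expression over neighboring cells
        if hard:
            out[observed] = x[observed]      # Hard FP: clamp known transcripts
        else:
            out = alpha * out + res          # Soft FP: known entries are denoised too
    return out
```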
## Requirements

- Python : 3.9.16
- PyTorch : 1.10.0
- scanpy : 1.9.3

## Download data

Create a directory in which to save the datasets:
```
mkdir dataset
```

You can download the preprocessed data [here](https://www.dropbox.com/sh/eaujyhthxjs0d5g/AADzvVv-h2yYWaoOfs1sybKea?dl=0).
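Each dataset is an AnnData `.h5ad` file, read from `dataset/<name>.h5ad`, with raw counts in `adata.X` and ground-truth labels in `adata.obs['celltype']` (see `embedder.py`). A quick sanity check after downloading, assuming `scanpy` is installed:

```python
import scanpy as sc

adata = sc.read("dataset/baron_mouse.h5ad")
print(adata)                       # n_obs (cells) x n_vars (genes)
print(adata.obs["celltype"][:5])   # labels are used only for evaluation
```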
## How to Run

You can reproduce the results with the following commands:
```
git clone https://github.com/Junseok0207/scFP.git
cd scFP
sh run.sh
```

## Hyperparameters

`--name:`
Name of the dataset.
usage example: `--name baron_mouse`

`--k:`
Number of neighbors in the cell-cell graph.
usage example: `--k 5`

`--iter:`
Number of iterations of feature propagation.
usage example: `--iter 40`

Using the above hyperparameters, you can run our model with the following command:

```
python main.py --name baron_mouse --k 15 --iter 40
```
--------------------------------------------------------------------------------
/argument.py:
--------------------------------------------------------------------------------
import argparse

def str2bool(s):
    if s not in {'False', 'True', 'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s in ('True', 'true')

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', type=str, default='baron_mouse', help='baron_mouse, mouse_es, mouse_bladder, zeisel, baron_human')
    parser.add_argument('--drop_rate', type=float, default=0.0)
    parser.add_argument('--n_runs', type=int, default=3)

    # Feature propagation
    parser.add_argument('--k', type=int, default=15)
    parser.add_argument('--iter', type=int, default=40)
    parser.add_argument('--alpha', type=float, default=0.99)

    # Setting
    parser.add_argument('--device', type=int, default=4)
    parser.add_argument('--seed', type=int, default=0)

    # Preprocessing
    parser.add_argument('--HVG', type=int, default=2000)
    parser.add_argument('--sf', action='store_true', default=True)
    parser.add_argument('--log', action='store_true', default=True)
    parser.add_argument('--normal', action='store_true', default=False)

    return parser.parse_known_args()

def enumerateConfig(args):
    args_names = []
    args_vals = []
    for arg in vars(args):
        args_names.append(arg)
        args_vals.append(getattr(args, arg))

    return args_names, args_vals

def printConfig(args):
    args_names, args_vals = enumerateConfig(args)
    st = ''
    for name, val in zip(args_names, args_vals):
        if val == False:  # also skips 0-valued arguments, by design
            continue
        st_ = "{} <- {} / ".format(name, val)
        st += st_

    return st[:-1]

def config2string(args):
    args_names, args_vals = enumerateConfig(args)
    st = ''
    for name, val in zip(args_names, args_vals):
        if val == False:
            continue
        if name not in ['device']:
            st_ = "{}_{}_".format(name, val)
            st += st_

    return st[:-1]
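
if __name__ == "__main__":
    # Illustrative quick check only (not part of the pipeline):
    # print the default configuration string.
    args, _ = parse_args()
    print(printConfig(args))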
--------------------------------------------------------------------------------
/embedder.py:
--------------------------------------------------------------------------------
import os
import torch
import scanpy as sc
import numpy as np

from argument import printConfig, config2string
from misc.utils import drop_data, imputation_error, cluster_acc

from sklearn.cluster import KMeans
from sklearn.metrics.cluster import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
from sklearn.preprocessing import LabelEncoder


class embedder:
    def __init__(self, args):
        self.args = args
        printConfig(args)
        self.config_str = config2string(args)
        self.device = f'cuda:{args.device}' if torch.cuda.is_available() else "cpu"

        self.data_path = f'dataset/{self.args.name}.h5ad'
        os.makedirs(os.path.dirname(self.data_path), exist_ok=True)

        self.result_path = f'result/{self.args.name}.txt'
        os.makedirs(os.path.dirname(self.result_path), exist_ok=True)

        self._init_dataset()

    def _init_dataset(self):
        self.adata = sc.read(self.data_path)
        if self.adata.obs['celltype'].dtype != int:
            self.label_encoding()

        self.preprocess(HVG=self.args.HVG, size_factors=self.args.sf, logtrans_input=self.args.log, normalize_input=self.args.normal)
        self.adata = drop_data(self.adata, rate=self.args.drop_rate)

    def label_encoding(self):
        label_encoder = LabelEncoder()
        celltype = self.adata.obs['celltype']
        celltype = label_encoder.fit_transform(celltype)
        self.adata.obs['celltype'] = celltype

    def preprocess(self, HVG=2000, size_factors=True, logtrans_input=True, normalize_input=False):
        sc.pp.filter_cells(self.adata, min_counts=1)
        sc.pp.filter_genes(self.adata, min_counts=1)

        # keep the HVG most variable genes
        variance = np.array(self.adata.X.todense().var(axis=0))[0]
        hvg_gene_idx = np.argsort(variance)[-int(HVG):]
        self.adata = self.adata[:, hvg_gene_idx]

        self.adata.raw = self.adata.copy()

        if size_factors:
            sc.pp.normalize_per_cell(self.adata)
            self.adata.obs['size_factors'] = self.adata.obs.n_counts / np.median(self.adata.obs.n_counts)
        else:
            self.adata.obs['size_factors'] = 1.0

        if logtrans_input:
            sc.pp.log1p(self.adata)

        if normalize_input:
            sc.pp.scale(self.adata)

    def evaluate(self):
        X_imputed = self.adata.obsm['denoised']
        if self.args.drop_rate != 0.0:
            X_test = self.adata.obsm["test"]
            drop_index = self.adata.uns['drop_index']
            rmse, median_l1_distance, cosine_similarity = imputation_error(X_imputed, X_test, drop_index)

        # clustering
        celltype = self.adata.obs['celltype'].values
        n_cluster = np.unique(celltype).shape[0]

        ### Imputed (full denoised matrix)
        kmeans = KMeans(n_cluster, n_init=20, random_state=self.args.seed)
        y_pred = kmeans.fit_predict(X_imputed)

        imputed_silhouette = silhouette_score(X_imputed, y_pred)
        imputed_ari = adjusted_rand_score(celltype, y_pred)
        imputed_nmi = normalized_mutual_info_score(celltype, y_pred)
        imputed_ca, imputed_ma_f1, imputed_mi_f1 = cluster_acc(celltype, y_pred)

        ### Reduced (PCA embedding of the denoised matrix)
        reduced = self.adata.obsm['reduced']
        kmeans = KMeans(n_cluster, n_init=20, random_state=self.args.seed)
        y_pred = kmeans.fit_predict(reduced)

        reduced_silhouette = silhouette_score(reduced, y_pred)
        reduced_ari = adjusted_rand_score(celltype, y_pred)
        reduced_nmi = normalized_mutual_info_score(celltype, y_pred)
        reduced_ca, reduced_ma_f1, reduced_mi_f1 = cluster_acc(celltype, y_pred)

        print(f"Dataset: {self.args.name}, Drop: {self.args.drop_rate}, Alpha: {self.args.alpha}")
        print()
        if self.args.drop_rate != 0.0:
            print("RMSE : {:.4f} / Median L1 Dist : {:.4f} / Cos-Sim : {:.4f}".format(rmse, median_l1_distance, cosine_similarity))

        print("Imputed --> ARI : {:.4f} / NMI : {:.4f} / ca : {:.4f}\n".format(imputed_ari, imputed_nmi, imputed_ca))
        print("Reduced --> ARI : {:.4f} / NMI : {:.4f} / ca : {:.4f}\n".format(reduced_ari, reduced_nmi, reduced_ca))

        with open(self.result_path, 'a+') as f:
            f.write("{}\n".format(self.config_str))
            if self.args.drop_rate != 0.0:
                f.write("Rate {} -> RMSE : {:.4f} / Median L1 Dist : {:.4f} / Cos-Sim : {:.4f}\n".format(self.args.drop_rate, rmse, median_l1_distance, cosine_similarity))
            f.write("(Imputed) Rate {} -> ARI : {:.4f} / NMI : {:.4f} / Silhouette : {:.4f} / ca : {:.4f} / ma-f1 : {:.4f} / mi-f1 : {:.4f}\n".format(self.args.drop_rate, imputed_ari, imputed_nmi, imputed_silhouette, imputed_ca, imputed_ma_f1, imputed_mi_f1))
            f.write("(Reduced) Rate {} -> ARI : {:.4f} / NMI : {:.4f} / Silhouette : {:.4f} / ca : {:.4f} / ma-f1 : {:.4f} / mi-f1 : {:.4f}\n".format(self.args.drop_rate, reduced_ari, reduced_nmi, reduced_silhouette, reduced_ca, reduced_ma_f1, reduced_mi_f1))
            f.write("\n")

        if self.args.drop_rate != 0.0:
            return [rmse, median_l1_distance, cosine_similarity, imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca]
        else:
            return [imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import datetime
import numpy as np
import torch

from argument import parse_args
from misc.utils import set_seed, set_filename, setup_logger
from models import scFP_Trainer


def main():
    args, _ = parse_args()
    torch.set_num_threads(3)

    rmse_list, median_l1_distance_list, cosine_similarity_list = [], [], []
    imputed_ari_list, imputed_nmi_list, imputed_ca_list = [], [], []
    reduced_ari_list, reduced_nmi_list, reduced_ca_list = [], [], []

    file = set_filename(args)
    logger = setup_logger('./', '-', file)
    for seed in range(args.n_runs):
        print(f'Seed: {seed}, Filename: {file}')
        set_seed(seed)
        args.seed = seed

        embedder = scFP_Trainer(args)

        if args.drop_rate != 0.0:
            [rmse, median_l1_distance, cosine_similarity, imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca] = embedder.train()

            rmse_list.append(rmse)
            median_l1_distance_list.append(median_l1_distance)
            cosine_similarity_list.append(cosine_similarity)
        else:
            [imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca] = embedder.train()

        imputed_ari_list.append(imputed_ari)
        imputed_nmi_list.append(imputed_nmi)
        imputed_ca_list.append(imputed_ca)

        reduced_ari_list.append(reduced_ari)
        reduced_nmi_list.append(reduced_nmi)
        reduced_ca_list.append(reduced_ca)

    logger.info('')
    logger.info(datetime.datetime.now())
    logger.info(file)
    if args.drop_rate > 0.0:
        logger.info(f'-------------------- Drop Rate: {args.drop_rate} --------------------')
        logger.info('[Averaged result] RMSE Median_L1 Cosine_Similarity')
        logger.info('{:.4f}+{:.4f} {:.4f}+{:.4f} {:.4f}+{:.4f}'.format(np.mean(rmse_list), np.std(rmse_list), np.mean(median_l1_distance_list), np.std(median_l1_distance_list), np.mean(cosine_similarity_list), np.std(cosine_similarity_list)))
    logger.info('[Averaged result] (Imputed) ARI NMI CA')
    logger.info('{:.4f}+{:.4f} {:.4f}+{:.4f} {:.4f}+{:.4f}'.format(np.mean(imputed_ari_list), np.std(imputed_ari_list), np.mean(imputed_nmi_list), np.std(imputed_nmi_list), np.mean(imputed_ca_list), np.std(imputed_ca_list)))
    logger.info('[Averaged result] (Reduced) ARI NMI CA')
    logger.info('{:.4f}+{:.4f} {:.4f}+{:.4f} {:.4f}+{:.4f}'.format(np.mean(reduced_ari_list), np.std(reduced_ari_list), np.mean(reduced_nmi_list), np.std(reduced_nmi_list), np.mean(reduced_ca_list), np.std(reduced_ca_list)))
    logger.info('')
    logger.info(args)
    logger.info(f'=================================')

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/misc/graph_construction.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def top_k(raw_graph, K):
    # keep only the K largest similarities per row
    values, indices = raw_graph.topk(k=int(K), dim=-1)
    assert torch.max(indices) < raw_graph.shape[1]
    mask = torch.zeros(raw_graph.shape).to(raw_graph.device)
    mask[torch.arange(raw_graph.shape[0]).view(-1, 1), indices] = 1.

    mask.requires_grad = False
    sparse_graph = raw_graph * mask
    return sparse_graph

def post_processing(cur_raw_adj, add_self_loop=True, sym=True, gcn_norm=False):
    if add_self_loop:
        num_nodes = cur_raw_adj.size(0)
        cur_raw_adj = cur_raw_adj + torch.diag(torch.ones(num_nodes)).to(cur_raw_adj.device)

    if sym:
        cur_raw_adj = cur_raw_adj + cur_raw_adj.t()
        cur_raw_adj /= 2

    deg = cur_raw_adj.sum(1)

    if gcn_norm:
        # symmetric normalization: D^{-1/2} A D^{-1/2}
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        deg_inv_sqrt = torch.diag(deg_inv_sqrt)

        cur_adj = torch.mm(deg_inv_sqrt, cur_raw_adj)
        cur_adj = torch.mm(cur_adj, deg_inv_sqrt)
    else:
        # row normalization: D^{-1} A
        deg_inv = deg.pow_(-1)
        deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
        deg_inv = torch.diag(deg_inv)

        cur_adj = torch.mm(deg_inv, cur_raw_adj)

    return cur_adj

def knn_graph(embeddings, k, gcn_norm=False, sym=True):
    device = embeddings.device
    embeddings = F.normalize(embeddings, dim=1, p=2)
    similarity_graph = torch.mm(embeddings, embeddings.t())

    # k + 1 because each node's largest similarity is with itself
    X = top_k(similarity_graph.to(device), k + 1)
    similarity_graph = F.relu(X)

    cur_adj = post_processing(similarity_graph, gcn_norm=gcn_norm, sym=sym)

    sparse_adj = cur_adj.to_sparse()
    edge_index = sparse_adj.indices().detach()
    edge_weight = sparse_adj.values()

    return edge_index, edge_weight
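
if __name__ == "__main__":
    # Illustrative usage only (not part of the pipeline): build a kNN graph over
    # random "cell" embeddings and check the returned shapes.
    emb = torch.randn(100, 32)
    edge_index, edge_weight = knn_graph(emb, k=5)
    print(edge_index.shape, edge_weight.shape)  # (2, num_edges), (num_edges,)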
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import random
import numpy as np
import scipy.sparse
import scanpy as sc
from sklearn import metrics
from munkres import Munkres
import logging
import sys
import torch.nn.functional as F

def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)

def filter_genes_cells(adata):
    """Remove empty cells and genes."""
    if "var_names_all" not in adata.uns:
        # fill in original var names before filtering
        adata.uns["var_names_all"] = adata.var.index.to_numpy()
    sc.pp.filter_genes(adata, min_cells=1)
    sc.pp.filter_cells(adata, min_counts=2)


def drop_data(adata, rate, datatype='real'):
    X = adata.X
    if scipy.sparse.issparse(X):
        X = np.array(X.todense())

    if datatype == 'real':
        # randomly zero out a `rate` fraction of the observed (non-zero) entries
        X_train = np.copy(X)
        i, j = np.nonzero(X)
        ix = np.random.choice(range(len(i)), int(np.floor(rate * len(i))), replace=False)
        X_train[i[ix], j[ix]] = 0.0

        adata.uns['drop_index'] = {'i': i, 'j': j, 'ix': ix}
        adata.obsm["train"] = X_train
        adata.obsm["test"] = X

        # zero the same entries in the raw counts used for training
        adata.raw.X[i[ix], j[ix]] = 0.0

    elif datatype == 'simul':
        adata.obsm["train"] = X

    return adata


def cosine_similarity(x, y):
    x = F.normalize(x, dim=1, p=2)
    y = F.normalize(y, dim=1, p=2)
    cos_sim = torch.sum(torch.mul(x, y), 1)
    return cos_sim

def cos_sim(x, y):
    sim = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
    return sim

def imputation_error(X_hat, X, drop_index):
    # evaluate only on the entries that were artificially dropped
    i, j, ix = drop_index['i'], drop_index['j'], drop_index['ix']

    all_index = i[ix], j[ix]
    x, y = X_hat[all_index], X[all_index]

    squared_error = (x - y) ** 2
    absolute_error = np.abs(x - y)

    rmse = np.sqrt(np.mean(squared_error))  # root of the mean squared error
    median_l1_distance = np.median(absolute_error)
    cosine_similarity = cos_sim(x, y)

    return rmse, median_l1_distance, cosine_similarity


def reset(value):
    if hasattr(value, 'reset_parameters'):
        value.reset_parameters()
    else:
        for child in value.children() if hasattr(value, 'children') else []:
            reset(child)


def cluster_acc(y_true, y_pred):
    y_true = y_true.astype(int)
    y_true = y_true - np.min(y_true)
    l1 = list(set(y_true))
    numclass1 = len(l1)
    l2 = list(set(y_pred))
    numclass2 = len(l2)

    # if some true classes are missing from the prediction, assign them to
    # arbitrary points so both labelings have the same number of classes
    ind = 0
    if numclass1 != numclass2:
        for i in l1:
            if i not in l2:
                y_pred[ind] = i
                ind += 1

    l2 = list(set(y_pred))
    numclass2 = len(l2)

    if numclass1 != numclass2:
        print('n_cluster is not valid')
        return

    # build the confusion-count matrix and find the best one-to-one mapping
    # between predicted and true clusters (Hungarian algorithm via Munkres)
    cost = np.zeros((numclass1, numclass2), dtype=int)
    for i, c1 in enumerate(l1):
        mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
        for j, c2 in enumerate(l2):
            mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
            cost[i][j] = len(mps_d)

    m = Munkres()
    cost = cost.__neg__().tolist()
    indexes = m.compute(cost)

    # relabel predictions according to the optimal matching
    new_predict = np.zeros(len(y_pred))
    for i, c in enumerate(l1):
        c2 = l2[indexes[i][1]]
        ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
        new_predict[ai] = c

    acc = metrics.accuracy_score(y_true, new_predict)
    f1_macro = metrics.f1_score(y_true, new_predict, average='macro')
    f1_micro = metrics.f1_score(y_true, new_predict, average='micro')

    return acc, f1_macro, f1_micro


def setup_logger(save_dir, text, filename='log.txt'):
    os.makedirs(save_dir, exist_ok=True)
    logger = logging.getLogger(text)
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    logger.info("======================================================================================")

    return logger

def set_filename(args):
    runs = f'n_runs_{args.n_runs}'
    if args.drop_rate > 0.0:
        logs_path = f'logs_{runs}/imputation/{args.name}'
    else:
        logs_path = f'logs_{runs}/clustering/{args.name}'

    os.makedirs(logs_path, exist_ok=True)

    return f'{logs_path}/scFP.txt'

def get_gene(x):
    if 'symbol' in x.keys():
        return x['symbol']
    return ''
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .scFP import scFP_Trainer
--------------------------------------------------------------------------------
/models/scFP.py:
--------------------------------------------------------------------------------
import torch
from embedder import embedder
from misc.graph_construction import knn_graph
from sklearn.decomposition import PCA


class scFP_Trainer(embedder):
    def __init__(self, args):
        embedder.__init__(self, args)
        self.args = args
        self.args.n_nodes = self.adata.obsm["train"].shape[0]
        self.args.n_feat = self.adata.obsm["train"].shape[1]

    def train(self):
        cell_data = torch.Tensor(self.adata.obsm["train"]).to(self.device)

        # Hard FP: propagate with the observed entries clamped to their known values
        print('Start Hard Feature Propagation ...!')
        edge_index, edge_weight = knn_graph(cell_data, self.args.k)
        self.model = FeaturePropagation(num_iterations=self.args.iter, mask=True, alpha=0.0)
        self.model = self.model.to(self.device)

        denoised_matrix = self.model(cell_data, edge_index, edge_weight)

        # Soft FP: rebuild the kNN graph from the warmed-up matrix, then let the
        # known entries be denoised as well, damped by alpha
        print('Start Soft Feature Propagation ...!')
        edge_index_new, edge_weight_new = knn_graph(denoised_matrix, self.args.k)
        self.model = FeaturePropagation(num_iterations=self.args.iter, mask=False, alpha=self.args.alpha)
        self.model = self.model.to(self.device)

        denoised_matrix = self.model(denoised_matrix, edge_index_new, edge_weight_new)

        # reduced: 32-dimensional PCA embedding of the denoised matrix
        pca = PCA(n_components=32)
        denoised_matrix = denoised_matrix.detach().cpu().numpy()
        reduced = pca.fit_transform(denoised_matrix)

        self.adata.obsm['denoised'] = denoised_matrix
        self.adata.obsm['reduced'] = reduced

        return self.evaluate()


class FeaturePropagation(torch.nn.Module):
    def __init__(self, num_iterations, mask, alpha=0.0):
        super(FeaturePropagation, self).__init__()
        self.num_iterations = num_iterations
        self.mask = mask
        self.alpha = alpha

    def forward(self, x, edge_index, edge_weight):
        original_x = x.clone()
        nonzero_idx = torch.nonzero(x)
        nonzero_i, nonzero_j = nonzero_idx.t()

        out = x
        n_nodes = x.shape[0]
        adj = torch.sparse_coo_tensor(edge_index, edge_weight, size=(n_nodes, n_nodes)).float()

        res = (1 - self.alpha) * out  # residual term; `out` is still the input x here
        for _ in range(self.num_iterations):
            out = torch.sparse.mm(adj, out)  # average features over neighboring cells
            if self.mask:
                # Hard FP: out <- A @ out, then reset the observed entries
                out[nonzero_i, nonzero_j] = original_x[nonzero_i, nonzero_j]
            else:
                # Soft FP: out <- alpha * A @ out + (1 - alpha) * x
                out.mul_(self.alpha).add_(res)

        return out
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------

# Imputation (Table 1)
for drop_rate in 0.2 0.4 0.8
do
    python main.py --name baron_mouse --drop_rate $drop_rate
    python main.py --name mouse_es --drop_rate $drop_rate
    python main.py --name mouse_bladder --drop_rate $drop_rate
    python main.py --name zeisel --drop_rate $drop_rate
    python main.py --name baron_human --drop_rate $drop_rate
done

# Clustering (Table 2)
python main.py --name baron_mouse
python main.py --name mouse_es
python main.py --name mouse_bladder
python main.py --name zeisel
python main.py --name baron_human

# Appendix
python main.py --name shekhar
--------------------------------------------------------------------------------