├── Img
│   └── architecture.png
├── README.md
├── argument.py
├── embedder.py
├── main.py
├── misc
│   ├── graph_construction.py
│   └── utils.py
├── models
│   ├── __init__.py
│   └── scFP.py
└── run.sh

--------------------------------------------------------------------------------
/Img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junseok0207/scFP/b982d91b389f74b803d1300c3328d193e958f9fc/Img/architecture.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Single-cell RNA-seq data imputation using Feature Propagation
2 | 
3 | ![architecture](Img/architecture.png)
4 | 
5 | 
6 | 
7 | 
8 | The official source code for "Single-cell RNA-seq data imputation using Feature Propagation", accepted at the 2023 ICML Workshop on Computational Biology (Contributed talk, Best Paper).
9 |
10 | ## Overview
11 |
12 | While single-cell RNA sequencing provides an understanding of the transcriptome of individual cells, its high sparsity, often termed dropout, hampers the capture of significant cell-cell relationships. Here, we propose scFP (single-cell Feature Propagation), which directly propagates features, i.e., gene expression, in the raw feature space via a cell-cell graph. Specifically, it first obtains a warmed-up cell-gene matrix via Hard Feature Propagation, which fully utilizes known gene transcripts. Then, we refine the k-Nearest Neighbor (kNN) cell-cell graph with the warmed-up cell-gene matrix, followed by Soft Feature Propagation, which now allows known gene transcripts to be further denoised through their neighbors. Through extensive experiments on imputation and cell clustering tasks, we demonstrate that our proposed model, scFP, outperforms various recent imputation and clustering methods.
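In code, one round of feature propagation is just a sparse matrix product followed by a reset or blending step. Below is a minimal sketch of the two modes with a hypothetical `propagate` helper (a simplified illustration, not the exact pipeline in `models/scFP.py`), assuming a row-normalized cell-cell adjacency `adj` and a cell-gene matrix `x` whose dropout entries are zero:

```
import torch

def propagate(x, adj, n_iter=40, alpha=0.0, mask=True):
    # Non-zero entries of x are treated as observed transcripts.
    known = x != 0
    out, res = x.clone(), (1 - alpha) * x
    for _ in range(n_iter):
        out = adj @ out                # diffuse expression over the cell-cell graph
        if mask:
            out[known] = x[known]      # Hard FP: observed transcripts stay fixed
        else:
            out = alpha * out + res    # Soft FP: observed transcripts are denoised too
    return out
```

Hard FP (`mask=True`, `alpha=0`) produces the warmed-up matrix; the kNN graph is then rebuilt from it, and Soft FP (`mask=False`, `alpha=0.99` by default) yields the final imputation.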
13 |
14 |
15 |
16 | ## Requirements
17 | - Python version : 3.9.16
18 | - PyTorch version : 1.10.0
19 | - scanpy version : 1.9.3
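One possible way to set up a matching environment with pip (the repository does not prescribe an install method; `scikit-learn` and `munkres` are also imported by the code):

```
pip install torch==1.10.0 scanpy==1.9.3 scikit-learn munkres
```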
20 |
21 | ## Download data
22 |
23 | Create a directory to store the datasets:
24 | ```
25 | mkdir dataset
26 | ```
27 |
28 | You can download the preprocessed data [here](https://www.dropbox.com/sh/eaujyhthxjs0d5g/AADzvVv-h2yYWaoOfs1sybKea?dl=0).
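`main.py` loads one `.h5ad` file per dataset from this directory (`dataset/{name}.h5ad`, see `embedder.py`), so after downloading, the layout should look like:

```
dataset/baron_mouse.h5ad
dataset/mouse_es.h5ad
dataset/mouse_bladder.h5ad
dataset/zeisel.h5ad
dataset/baron_human.h5ad
```

(`shekhar.h5ad` is additionally used for the appendix experiment in `run.sh`.)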
29 |
30 | ## How to Run
31 |
32 | You can reproduce the results with the following commands:
33 | ```
34 | git clone https://github.com/Junseok0207/scFP.git
35 | cd scFP
36 | sh run.sh
37 | ```
38 |
39 | ## Hyperparameters
40 |
41 | `--name:`
42 | Name of the dataset.
43 | usage example: `--name baron_mouse`
44 |
45 | `--k:`
46 | Number of neighbors in the cell-cell graph.
47 | usage example: `--k 5`
48 |
49 | `--iter:`
50 | Number of iterations of feature propagation.
51 | usage example: `--iter 40`
52 |
53 | Using the above hyperparameters, you can run our model with the following command:
54 |
55 | ```
56 | python main.py --name baron_mouse --k 15 --iter 40
57 | ```
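
For the imputation benchmark (Table 1), `run.sh` additionally passes `--drop_rate`, which masks the given fraction of non-zero entries before training and reports RMSE, median L1 distance, and cosine similarity on the held-out values, e.g.:

```
python main.py --name baron_mouse --drop_rate 0.4
```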
58 |
59 |
--------------------------------------------------------------------------------
/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(s):
4 | if s not in {'False', 'True', 'false', 'true'}:
5 | raise ValueError('Not a valid boolean string')
6 | return (s == 'True') or (s == 'true')
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser()
10 |
11 | parser.add_argument('--name', type=str, default='baron_mouse', help='baron_mouse, mouse_es, mouse_bladder, zeisel, baron_human')
12 | parser.add_argument('--drop_rate', type=float, default=0.0)
13 | parser.add_argument('--n_runs', type=int, default=3)
14 |
15 | # FP
16 | parser.add_argument('--k', type=int, default=15)
17 | parser.add_argument('--iter', type=int, default=40)
18 | parser.add_argument('--alpha', type=float, default=0.99)
19 |
20 | # setting
21 |     parser.add_argument('--device', type=int, default=4)  # CUDA device index; CPU is used if CUDA is unavailable
22 | parser.add_argument('--seed', type=int, default=0)
23 |
24 |     # Preprocessing
25 |     parser.add_argument('--HVG', type=int, default=2000)  # number of highly variable genes to keep
26 |     parser.add_argument('--sf', action='store_true', default=True)  # per-cell size-factor normalization (default True; the flag cannot disable it)
27 |     parser.add_argument('--log', action='store_true', default=True)  # log1p-transform the counts (default True)
28 |     parser.add_argument('--normal', action='store_true', default=False)  # z-score scaling
29 |
30 | return parser.parse_known_args()
31 |
32 | def enumerateConfig(args):
33 | args_names = []
34 | args_vals = []
35 | for arg in vars(args):
36 | args_names.append(arg)
37 | args_vals.append(getattr(args, arg))
38 |
39 | return args_names, args_vals
40 |
41 | def printConfig(args):
42 | args_names, args_vals = enumerateConfig(args)
43 | st = ''
44 |     for name, val in zip(args_names, args_vals):
45 |         if val is False:  # skip flags that are explicitly False (keep numeric zeros)
46 |             continue
47 | st_ = "{} <- {} / ".format(name, val)
48 | st += st_
49 |
50 | return st[:-1]
51 |
52 | def config2string(args):
53 | args_names, args_vals = enumerateConfig(args)
54 | st = ''
55 |     for name, val in zip(args_names, args_vals):
56 |         if val is False:  # skip flags that are explicitly False (keep numeric zeros)
57 |             continue
58 | if name not in ['device']:
59 | st_ = "{}_{}_".format(name, val)
60 | st += st_
61 |
62 | return st[:-1]
--------------------------------------------------------------------------------
/embedder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import scanpy as sc
5 | import scipy.sparse
6 | from argument import printConfig, config2string
7 | from misc.utils import drop_data, imputation_error, cluster_acc
8 | from sklearn.cluster import KMeans
9 | from sklearn.metrics.cluster import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
10 | from sklearn.preprocessing import LabelEncoder
20 |
21 | class embedder:
22 | def __init__(self, args):
23 | self.args = args
24 | printConfig(args)
25 | self.config_str = config2string(args)
26 | self.device = f'cuda:{args.device}' if torch.cuda.is_available() else "cpu"
27 |
28 | self.data_path = f'dataset/{self.args.name}.h5ad'
29 | os.makedirs(os.path.dirname(self.data_path), exist_ok=True)
30 |
31 | self.result_path = f'result/{self.args.name}.txt'
32 | os.makedirs(os.path.dirname(self.result_path), exist_ok=True)
33 |
34 | self._init_dataset()
35 |
36 | def _init_dataset(self):
37 |
38 | self.adata = sc.read(self.data_path)
39 | if self.adata.obs['celltype'].dtype != int:
40 | self.label_encoding()
41 |
42 | self.preprocess(HVG=self.args.HVG, size_factors=self.args.sf, logtrans_input=self.args.log, normalize_input=self.args.normal)
43 | self.adata = drop_data(self.adata, rate=self.args.drop_rate)
44 |
45 | def label_encoding(self):
46 | label_encoder = LabelEncoder()
47 | celltype = self.adata.obs['celltype']
48 | celltype = label_encoder.fit_transform(celltype)
49 | self.adata.obs['celltype'] = celltype
50 |
51 | def preprocess(self, HVG=2000, size_factors=True, logtrans_input=True, normalize_input=False):
52 |
53 | sc.pp.filter_cells(self.adata, min_counts=1)
54 | sc.pp.filter_genes(self.adata, min_counts=1)
55 |
56 |         # variance-based HVG selection; X may be stored sparse or dense
57 |         X_dense = self.adata.X.todense() if scipy.sparse.issparse(self.adata.X) else self.adata.X
58 |         variance = np.asarray(X_dense.var(axis=0)).ravel()
59 |         hvg_gene_idx = np.argsort(variance)[-int(HVG):]
60 |         self.adata = self.adata[:, hvg_gene_idx]
59 |
60 | self.adata.raw = self.adata.copy()
61 |
62 | if size_factors:
63 | sc.pp.normalize_per_cell(self.adata)
64 | self.adata.obs['size_factors'] = self.adata.obs.n_counts / np.median(self.adata.obs.n_counts)
65 | else:
66 | self.adata.obs['size_factors'] = 1.0
67 |
68 | if logtrans_input:
69 | sc.pp.log1p(self.adata)
70 |
71 | if normalize_input:
72 | sc.pp.scale(self.adata)
73 |
74 |
75 | def evaluate(self):
76 |
77 | X_imputed = self.adata.obsm['denoised']
78 | if self.args.drop_rate != 0.0:
79 | X_test = self.adata.obsm["test"]
80 | drop_index = self.adata.uns['drop_index']
81 |
82 | rmse, median_l1_distance, cosine_similarity = imputation_error(X_imputed, X_test, drop_index)
83 |
84 | # clustering
85 | celltype = self.adata.obs['celltype'].values
86 | n_cluster = np.unique(celltype).shape[0]
87 |
88 | ### Imputed
89 | kmeans = KMeans(n_cluster, n_init=20, random_state=self.args.seed)
90 | y_pred = kmeans.fit_predict(X_imputed)
91 |
92 | imputed_silhouette = silhouette_score(X_imputed, y_pred)
93 | imputed_ari = adjusted_rand_score(celltype, y_pred)
94 | imputed_nmi = normalized_mutual_info_score(celltype, y_pred)
95 | imputed_ca, imputed_ma_f1, imputed_mi_f1 = cluster_acc(celltype, y_pred)
96 |
97 | ### Reduced
98 | reduced = self.adata.obsm['reduced']
99 | kmeans = KMeans(n_cluster, n_init=20, random_state=self.args.seed)
100 | y_pred = kmeans.fit_predict(reduced)
101 |
102 | reduced_silhouette = silhouette_score(reduced, y_pred)
103 | reduced_ari = adjusted_rand_score(celltype, y_pred)
104 | reduced_nmi = normalized_mutual_info_score(celltype, y_pred)
105 | reduced_ca, reduced_ma_f1, reduced_mi_f1 = cluster_acc(celltype, y_pred)
106 |
107 |
108 | print(f"Dataset: {self.args.name}, Drop: {self.args.drop_rate}, Alpha: {self.args.alpha}")
109 | print()
110 | if self.args.drop_rate != 0.0:
111 | print("RMSE : {:.4f} / Median L1 Dist : {:.4f} / Cos-Sim : {:.4f}".format(rmse, median_l1_distance, cosine_similarity))
112 |
113 | print("Imputed --> ARI : {:.4f} / NMI : {:.4f} / ca : {:.4f}\n".format(imputed_ari, imputed_nmi, imputed_ca))
114 | print("Reduced --> ARI : {:.4f} / NMI : {:.4f} / ca : {:.4f}\n".format(reduced_ari, reduced_nmi, reduced_ca))
115 |
116 | with open(self.result_path, 'a+') as f:
117 | f.write("{}\n".format(self.config_str))
118 | if self.args.drop_rate != 0.0:
119 | f.write("Rate {} -> RMSE : {:.4f} / Median L1 Dist : {:.4f} / Cos-Sim : {:.4f}\n".format(self.args.drop_rate, rmse, median_l1_distance, cosine_similarity))
120 | f.write("(Imputed) Rate {} -> ARI : {:.4f} / NMI : {:.4f} / Silhouette : {:.4f} / ca : {:.4f} / ma-f1 : {:.4f} / mi-f1 : {:.4f}\n".format(self.args.drop_rate, imputed_ari, imputed_nmi, imputed_silhouette, imputed_ca, imputed_ma_f1, imputed_mi_f1))
121 | f.write("(Reduced) Rate {} -> ARI : {:.4f} / NMI : {:.4f} / Silhouette : {:.4f} / ca : {:.4f} / ma-f1 : {:.4f} / mi-f1 : {:.4f}\n".format(self.args.drop_rate, reduced_ari, reduced_nmi, reduced_silhouette, reduced_ca, reduced_ma_f1, reduced_mi_f1))
122 | f.write("\n")
123 |
124 | if self.args.drop_rate != 0.0:
125 | return [rmse, median_l1_distance, cosine_similarity, imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca]
126 | else:
127 | return [imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca]
128 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from misc.utils import set_seed, set_filename, setup_logger
3 | from argument import parse_args
4 | from models import scFP_Trainer
5 | import numpy as np
6 | import datetime
6 |
7 | def main():
8 | args, _ = parse_args()
9 | torch.set_num_threads(3)
10 |
11 | rmse_list, median_l1_distance_list, cosine_similarity_list = [], [], []
12 | imputed_ari_list, imputed_nmi_list, imputed_ca_list = [], [], []
13 | reduced_ari_list, reduced_nmi_list, reduced_ca_list = [], [], []
14 |
15 | file = set_filename(args)
16 | logger = setup_logger('./', '-', file)
17 | for seed in range(0, args.n_runs):
18 | print(f'Seed: {seed}, Filename: {file}')
19 | set_seed(seed)
20 | args.seed = seed
21 |
22 |         embedder = scFP_Trainer(args)
24 |
25 | if (args.drop_rate != 0.0):
26 | [rmse, median_l1_distance, cosine_similarity, imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca] = embedder.train()
27 |
28 | rmse_list.append(rmse)
29 | median_l1_distance_list.append(median_l1_distance)
30 | cosine_similarity_list.append(cosine_similarity)
31 |
32 | else:
33 | [imputed_ari, imputed_nmi, imputed_ca], [reduced_ari, reduced_nmi, reduced_ca] = embedder.train()
34 |
35 | imputed_ari_list.append(imputed_ari)
36 | imputed_nmi_list.append(imputed_nmi)
37 | imputed_ca_list.append(imputed_ca)
38 |
39 | reduced_ari_list.append(reduced_ari)
40 | reduced_nmi_list.append(reduced_nmi)
41 | reduced_ca_list.append(reduced_ca)
42 |
43 | logger.info('')
44 | logger.info(datetime.datetime.now())
45 | logger.info(file)
46 | if args.drop_rate > 0.0:
47 | logger.info(f'-------------------- Drop Rate: {args.drop_rate} --------------------')
48 |         logger.info('[Averaged result] RMSE Median_L1 Cosine_Similarity')
49 |         logger.info('{:.4f}±{:.4f} {:.4f}±{:.4f} {:.4f}±{:.4f}'.format(np.mean(rmse_list), np.std(rmse_list), np.mean(median_l1_distance_list), np.std(median_l1_distance_list), np.mean(cosine_similarity_list), np.std(cosine_similarity_list)))
50 |     logger.info('[Averaged result] (Imputed) ARI NMI CA')
51 |     logger.info('{:.4f}±{:.4f} {:.4f}±{:.4f} {:.4f}±{:.4f}'.format(np.mean(imputed_ari_list), np.std(imputed_ari_list), np.mean(imputed_nmi_list), np.std(imputed_nmi_list), np.mean(imputed_ca_list), np.std(imputed_ca_list)))
52 |     logger.info('[Averaged result] (Reduced) ARI NMI CA')
53 |     logger.info('{:.4f}±{:.4f} {:.4f}±{:.4f} {:.4f}±{:.4f}'.format(np.mean(reduced_ari_list), np.std(reduced_ari_list), np.mean(reduced_nmi_list), np.std(reduced_nmi_list), np.mean(reduced_ca_list), np.std(reduced_ca_list)))
54 | logger.info('')
55 | logger.info(args)
56 | logger.info(f'=================================')
57 |
58 | if __name__ == "__main__":
59 | main()
--------------------------------------------------------------------------------
/misc/graph_construction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def top_k(raw_graph, K):
6 | values, indices = raw_graph.topk(k=int(K), dim=-1)
7 | assert torch.max(indices) < raw_graph.shape[1]
8 | mask = torch.zeros(raw_graph.shape).to(raw_graph.device)
9 | mask[torch.arange(raw_graph.shape[0]).view(-1, 1), indices] = 1.
10 |
11 | mask.requires_grad = False
12 | sparse_graph = raw_graph * mask
13 | return sparse_graph
14 |
15 | def post_processing(cur_raw_adj, add_self_loop=True, sym=True, gcn_norm=False):
16 |
17 | if add_self_loop:
18 | num_nodes = cur_raw_adj.size(0)
19 | cur_raw_adj = cur_raw_adj + torch.diag(torch.ones(num_nodes)).to(cur_raw_adj.device)
20 |
21 | if sym:
22 | cur_raw_adj = cur_raw_adj + cur_raw_adj.t()
23 | cur_raw_adj /= 2
24 |
25 | deg = cur_raw_adj.sum(1)
26 |
27 | if gcn_norm:
28 |
29 | deg_inv_sqrt = deg.pow_(-0.5)
30 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
31 | deg_inv_sqrt = torch.diag(deg_inv_sqrt)
32 |
33 | cur_adj = torch.mm(deg_inv_sqrt, cur_raw_adj)
34 | cur_adj = torch.mm(cur_adj, deg_inv_sqrt)
35 |
36 |     else:
37 | 
38 |         deg_inv = deg.pow_(-1)  # row normalization: D^{-1} A
39 |         deg_inv.masked_fill_(deg_inv == float('inf'), 0.)
40 |         deg_inv = torch.diag(deg_inv)
41 | 
42 |         cur_adj = torch.mm(deg_inv, cur_raw_adj)
43 |
44 | return cur_adj
45 |
46 | def knn_graph(embeddings, k, gcn_norm=False, sym=True):
47 |
48 | device = embeddings.device
49 | embeddings = F.normalize(embeddings, dim=1, p=2)
50 | similarity_graph = torch.mm(embeddings, embeddings.t())
51 |
52 |     X = top_k(similarity_graph.to(device), k + 1)  # k+1 because each cell's top match is itself
53 | similarity_graph = F.relu(X)
54 |
55 | cur_adj = post_processing(similarity_graph, gcn_norm=gcn_norm, sym=sym)
56 |
57 | sparse_adj = cur_adj.to_sparse()
58 | edge_index = sparse_adj.indices().detach()
59 | edge_weight = sparse_adj.values()
60 |
61 | return edge_index, edge_weight
62 |
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 | import random
5 | import numpy as np
6 | import scipy.sparse
7 | import scanpy as sc
8 | from sklearn import metrics
9 | from munkres import Munkres
10 | import logging
11 | import sys
12 | import torch.nn.functional as F
13 |
14 | def set_seed(seed=0):
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 | random.seed(seed)
20 | np.random.seed(seed)
21 |
22 | def filter_genes_cells(adata):
23 | """Remove empty cells and genes."""
24 |
25 | if "var_names_all" not in adata.uns:
26 | # fill in original var names before filtering
27 | adata.uns["var_names_all"] = adata.var.index.to_numpy()
28 | sc.pp.filter_genes(adata, min_cells=1)
29 | sc.pp.filter_cells(adata, min_counts=2)
30 |
31 |
32 | def drop_data(adata, rate, datatype='real'):
33 |
34 | X = adata.X
35 |
36 | if scipy.sparse.issparse(X):
37 | X = np.array(X.todense())
38 |
39 | if datatype == 'real':
40 | X_train = np.copy(X)
41 | i, j = np.nonzero(X)
42 |
43 | ix = np.random.choice(range(len(i)), int(
44 | np.floor(rate * len(i))), replace=False)
45 | X_train[i[ix], j[ix]] = 0.0
46 |
47 | drop_index = {'i':i, 'j':j, 'ix':ix}
48 | adata.uns['drop_index'] = drop_index
49 | adata.obsm["train"] = X_train
50 | adata.obsm["test"] = X
51 |
52 | # for training
53 | adata.raw.X[i[ix],j[ix]] = 0.0
54 |
55 | elif datatype == 'simul':
56 | adata.obsm["train"] = X
57 |
58 | return adata
59 |
60 |
61 | def cosine_similarity(x,y):
62 | x = F.normalize(x, dim=1, p=2)
63 | y = F.normalize(y, dim=1, p=2)
64 | cos_sim = torch.sum(torch.mul(x,y),1)
65 | return cos_sim
66 |
67 | def cos_sim(x,y):
68 | sim = np.dot(x,y)/(np.linalg.norm(x)*np.linalg.norm(y))
69 | return sim
70 |
71 | def imputation_error(X_hat, X, drop_index):
72 |
73 | i, j, ix = drop_index['i'], drop_index['j'], drop_index['ix']
74 |
75 | all_index = i[ix], j[ix]
76 | x, y = X_hat[all_index], X[all_index]
77 |
78 | squared_error = (x-y)**2
79 | absolute_error = np.abs(x - y)
80 |
81 |     rmse = np.sqrt(np.mean(squared_error))  # root-mean-square error (root of the mean, not mean of roots)
82 | median_l1_distance = np.median(absolute_error)
83 | cosine_similarity = cos_sim(x,y)
84 |
85 | return rmse, median_l1_distance, cosine_similarity
86 |
87 |
88 | def reset(value):
89 | if hasattr(value, 'reset_parameters'):
90 | value.reset_parameters()
91 | else:
92 | for child in value.children() if hasattr(value, 'children') else []:
93 | reset(child)
94 |
95 |
96 | def cluster_acc(y_true, y_pred):
97 |
98 |     y_true = y_true.astype(int)
101 |
102 | y_true = y_true - np.min(y_true)
103 | l1 = list(set(y_true))
104 | numclass1 = len(l1)
105 | l2 = list(set(y_pred))
106 | numclass2 = len(l2)
107 |
108 | ind = 0
109 | if numclass1 != numclass2:
110 | for i in l1:
111 | if i in l2:
112 | pass
113 | else:
114 | y_pred[ind] = i
115 | ind += 1
116 |
117 | l2 = list(set(y_pred))
118 | numclass2 = len(l2)
119 |
120 | if numclass1 != numclass2:
121 | print('n_cluster is not valid')
122 | return
123 |
124 | cost = np.zeros((numclass1, numclass2), dtype=int)
125 | for i, c1 in enumerate(l1):
126 | mps = [i1 for i1, e1 in enumerate(y_true) if e1 == c1]
127 | for j, c2 in enumerate(l2):
128 | mps_d = [i1 for i1 in mps if y_pred[i1] == c2]
129 | cost[i][j] = len(mps_d)
130 |
131 |     m = Munkres()
132 |     cost = (-cost).tolist()  # Munkres solves a minimum-cost assignment
133 |     indexes = m.compute(cost)
134 |
135 | new_predict = np.zeros(len(y_pred))
136 | for i, c in enumerate(l1):
137 | c2 = l2[indexes[i][1]]
138 | ai = [ind for ind, elm in enumerate(y_pred) if elm == c2]
139 | new_predict[ai] = c
140 |
141 |     acc = metrics.accuracy_score(y_true, new_predict)
142 | 
143 |     f1_macro = metrics.f1_score(y_true, new_predict, average='macro')
144 |     f1_micro = metrics.f1_score(y_true, new_predict, average='micro')
147 |
148 | return acc, f1_macro, f1_micro
149 |
150 |
151 |
152 | def setup_logger(save_dir, text, filename='log.txt'):
153 |     os.makedirs(save_dir, exist_ok=True)
154 |     logger = logging.getLogger(text)
155 |     logger.setLevel(logging.DEBUG)
158 | ch = logging.StreamHandler(stream=sys.stdout)
159 | ch.setLevel(logging.DEBUG)
160 | formatter = logging.Formatter("%(message)s")
161 | ch.setFormatter(formatter)
162 | logger.addHandler(ch)
163 | if save_dir:
164 | fh = logging.FileHandler(os.path.join(save_dir, filename))
165 | fh.setLevel(logging.DEBUG)
166 | fh.setFormatter(formatter)
167 | logger.addHandler(fh)
168 | logger.info("======================================================================================")
169 |
170 | return logger
171 |
172 | def set_filename(args):
173 |     runs = f'n_runs_{args.n_runs}'
175 | if args.drop_rate > 0.0:
176 | logs_path = f'logs_{runs}/imputation/{args.name}'
177 | else:
178 | logs_path = f'logs_{runs}/clustering/{args.name}'
179 |
180 | os.makedirs(logs_path, exist_ok=True)
181 |
182 | file = f'{logs_path}/scFP.txt'
183 |
184 | return file
185 |
186 | def get_gene(x):
187 | if 'symbol' in x.keys():
188 | return x['symbol']
189 | return ''
190 |
191 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .scFP import scFP_Trainer
--------------------------------------------------------------------------------
/models/scFP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from embedder import embedder
3 | from misc.graph_construction import knn_graph
4 | from sklearn.decomposition import PCA
7 |
8 |
9 | class scFP_Trainer(embedder):
10 | def __init__(self, args):
11 | embedder.__init__(self, args)
12 | self.args = args
13 | self.args.n_nodes = self.adata.obsm["train"].shape[0]
14 | self.args.n_feat = self.adata.obsm["train"].shape[1]
15 |
16 | def train(self):
17 |         cell_data = torch.Tensor(self.adata.obsm["train"]).to(self.device)
18 |
19 | # Hard FP
20 | print('Start Hard Feature Propagation ...!')
21 | edge_index, edge_weight = knn_graph(cell_data, self.args.k)
22 | self.model = FeaturePropagation(num_iterations=self.args.iter, mask=True, alpha=0.0)
23 | self.model = self.model.to(self.device)
24 |
25 | denoised_matrix = self.model(cell_data, edge_index, edge_weight)
26 |
27 | # Soft FP
28 | print('Start Soft Feature Propagation ...!')
29 | edge_index_new, edge_weight_new = knn_graph(denoised_matrix, self.args.k)
30 | self.model = FeaturePropagation(num_iterations=self.args.iter, mask=False, alpha=self.args.alpha)
31 | self.model = self.model.to(self.device)
32 |
33 | denoised_matrix = self.model(denoised_matrix, edge_index_new, edge_weight_new)
34 |
35 | # reduced
36 |         pca = PCA(n_components=32)
37 | denoised_matrix = denoised_matrix.detach().cpu().numpy()
38 | reduced = pca.fit_transform(denoised_matrix)
39 |
40 | self.adata.obsm['denoised'] = denoised_matrix
41 | self.adata.obsm['reduced'] = reduced
42 |
43 | return self.evaluate()
44 |
45 | class FeaturePropagation(torch.nn.Module):
46 | def __init__(self, num_iterations, mask, alpha=0.0):
47 | super(FeaturePropagation, self).__init__()
48 | self.num_iterations = num_iterations
49 | self.mask = mask
50 | self.alpha = alpha
51 |
52 |     def forward(self, x, edge_index, edge_weight):
53 |         original_x = x.clone()  # keep observed values for the Hard FP reset
54 |         nonzero_i, nonzero_j = torch.nonzero(x).t()
55 | 
56 |         out = x
57 |         n_nodes = x.shape[0]
58 |         adj = torch.sparse_coo_tensor(edge_index, edge_weight, size=(n_nodes, n_nodes)).float()
61 |
62 |
63 |         res = (1 - self.alpha) * out
64 |         for _ in range(self.num_iterations):
65 |             out = torch.sparse.mm(adj, out)
66 |             if self.mask:
67 |                 # Hard FP: reset known (non-zero) entries to their original values
68 |                 out[nonzero_i, nonzero_j] = original_x[nonzero_i, nonzero_j]
69 |             else:
70 |                 # Soft FP: blend the propagated signal with the original features
71 |                 out.mul_(self.alpha).add_(res)
72 | 
73 |         return out
72 |
73 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # Imputation (Table 1)
4 | for drop_rate in 0.2 0.4 0.8
5 | do
6 | python main.py --name baron_mouse --drop_rate $drop_rate
7 | python main.py --name mouse_es --drop_rate $drop_rate
8 | python main.py --name mouse_bladder --drop_rate $drop_rate
9 | python main.py --name zeisel --drop_rate $drop_rate
10 | python main.py --name baron_human --drop_rate $drop_rate
11 | done
12 |
13 |
14 | # Clustering (Table 2)
15 | python main.py --name baron_mouse
16 | python main.py --name mouse_es
17 | python main.py --name mouse_bladder
18 | python main.py --name zeisel
19 | python main.py --name baron_human
20 |
21 | # Appendix
22 | python main.py --name shekhar
--------------------------------------------------------------------------------