├── Data ├── GeneratingExampleData.py └── example_config.yaml ├── DataProcess ├── DataProcess.py ├── geneList.pkl └── myUtil.py ├── LICENSE ├── MANIFEST.in ├── Output └── training.log ├── README.md ├── Supp_code ├── GI_experiment │ ├── GI_evaluation.py │ ├── Gears_results.pkl │ ├── STAMP_results.pkl │ ├── Science_data.csv │ └── Truth_results.pkl ├── Modified_CPA │ ├── _api.py │ ├── _data.py │ ├── _metrics.py │ ├── _model.py │ ├── _module.py │ ├── _plotting.py │ ├── _task.py │ └── _utils.py └── Modified_GEARS │ ├── data_utils.py │ ├── gears.py │ ├── inference.py │ ├── model.py │ ├── pertdata.py │ └── utils.py ├── Tutorial └── tutorial_for_training.py.ipynb ├── environment.yml ├── img └── framework.png ├── requirements.txt ├── setup.py └── stamp ├── DataSet.py ├── Modules.py ├── STAMP.py ├── __init__.py ├── utils.py └── version.py /Data/GeneratingExampleData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scanpy as sc 3 | import anndata as ad 4 | import joblib 5 | import scipy 6 | import os 7 | 8 | np.random.seed(888) 9 | 10 | # generating the training example data 11 | train_data_template_x = (np.random.rand(1000,5000) > 0.5) * 1 12 | train_data_layer_1 = (np.random.rand(1000,5000) > 0.5) * 1 13 | train_data_layer_2 = (np.random.rand(1000,5000) > 0.5) * 1 14 | train_data_layer_3 = np.random.rand(1000,5000) 15 | train_data = ad.AnnData(X = scipy.sparse.csr_matrix(train_data_template_x)) 16 | train_data.layers['level1'] = scipy.sparse.csr_matrix(train_data_layer_1) 17 | train_data.layers['level2'] = scipy.sparse.csr_matrix(train_data_layer_2) 18 | train_data.layers['level3'] = scipy.sparse.csr_matrix(train_data_layer_3) 19 | train_data.obs.index = [f"Gene{i+1}" for i in range(1000)] 20 | train_data.var.index = [f"Gene{i+1}" for i in range(5000)] 21 | 22 | # generating the validation example data 23 | val_data_template_x = (np.random.rand(100,5000) > 0.5) * 1 24 | val_data_layer_1 = (np.random.rand(100,5000) > 0.5) * 1 25 | val_data_layer_2 = (np.random.rand(100,5000) > 0.5) * 1 26 | val_data_layer_3 = np.random.rand(100,5000) 27 | val_data = ad.AnnData(X = scipy.sparse.csr_matrix(val_data_template_x)) 28 | val_data.layers['level1'] = scipy.sparse.csr_matrix(val_data_layer_1) 29 | val_data.layers['level2'] = scipy.sparse.csr_matrix(val_data_layer_2) 30 | val_data.layers['level3'] = scipy.sparse.csr_matrix(val_data_layer_3) 31 | val_data.obs.index = [f"Gene{i+1}" for i in range(1000,1100)] 32 | val_data.var.index = [f"Gene{i+1}" for i in range(5000)] 33 | 34 | # generating the testing example data 35 | test_data_template_x = (np.random.rand(200,5000) > 0.5) * 1 36 | test_data_layer_1 = (np.random.rand(200,5000) > 0.5) * 1 37 | test_data_layer_2 = (np.random.rand(200,5000) > 0.5) * 1 38 | test_data_layer_3 = np.random.rand(200,5000) 39 | test_data = ad.AnnData(X = scipy.sparse.csr_matrix(test_data_template_x)) 40 | test_data.layers['level1'] = scipy.sparse.csr_matrix(test_data_layer_1) 41 | test_data.layers['level2'] = scipy.sparse.csr_matrix(test_data_layer_2) 42 | test_data.layers['level3'] = scipy.sparse.csr_matrix(test_data_layer_3) 43 | test_data.obs.index = [f"Gene{i+1},Gene{i+2}" for i in range(1100,1300)] 44 | test_data.var.index = [f"Gene{i+1}" for i in range(5000)] 45 | 46 | # generating the top 40 DEGs for testing example data 47 | test_data_top40 = ad.AnnData(X = scipy.sparse.csr_matrix(test_data_template_x)) 48 | test_data_layer_1_top40 = np.zeros_like(np.random.rand(200,5000)) 49 | for i in range(200): 50 | 
test_data_layer_1_top40[i][np.random.choice(5000,40,replace=False)] = 1
51 | test_data_top40.layers['level1'] = scipy.sparse.csr_matrix(test_data_layer_1_top40)
52 | test_data_top40.layers['level2'] = scipy.sparse.csr_matrix(test_data_layer_2)
53 | test_data_top40.layers['level3'] = scipy.sparse.csr_matrix(test_data_layer_3)
54 | test_data_top40.obs.index = [f"Gene{i+1},Gene{i+2}" for i in range(1100,1300)]
55 | test_data_top40.var.index = [f"Gene{i+1}" for i in range(5000)]
56 |
57 | # generating the gene embedding matrix; the embedding order must be consistent with the gene order of data.var
58 | gene_embs = np.random.rand(5000, 512).astype('float32')
59 |
60 | if not os.path.exists("./Data"):
61 |     os.makedirs("./Data")
62 |
63 | train_data.write("./Data/example_train.h5ad")
64 | val_data.write("./Data/example_val.h5ad")
65 | test_data.write("./Data/example_test.h5ad")
66 | test_data_top40.write("./Data/example_test_top40.h5ad")
67 | joblib.dump(gene_embs, "./Data/example_gene_embeddings.pkl")
--------------------------------------------------------------------------------
/Data/example_config.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 |   Training_dataset: ./Data/example_train.h5ad
3 |   Validation_dataset: ./Data/example_val.h5ad
4 |   Testing_dataset: ./Data/example_test.h5ad
5 |   Gene_embedding: ./Data/example_gene_embeddings.pkl
6 | Train:
7 |   Sampling:
8 |     batch_size: 64
9 |     sample_shuffle: True
10 |
11 |   Model_Parameter:
12 |     First_level:
13 |       in_features: 512
14 |       out_features: 5000
15 |
16 |     Second_level:
17 |       hid1_features_2: 128
18 |       hid2_features_2: 64
19 |       hid3_features_2: 32
20 |     Third_level:
21 |       in_feature_3: 32
22 |       hid1_features_3: 16
23 |       hid2_features_3: 8
24 |   device: cuda
25 |
26 |   Trainer_parameter:
27 |     random_seed: 888
28 |     epoch: 10
29 |     learning_rate: 0.0001
30 |
31 |   output_dir: ./Output
32 |
33 | Inference:
34 |   Sampling:
35 |     batch_size: 256
36 |     sample_shuffle: False
--------------------------------------------------------------------------------
/DataProcess/DataProcess.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home//project/GW_PerturbSeq')
3 | from myUtil import *
4 | import scanpy as sc
5 | from tqdm import tqdm
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 |
9 | ### Data preprocessing, and deriving the ground-truth differentially expressed genes, etc.
10 | def fun1(adata):
11 |     adata = adata[~adata.obs['gene'].isin(['None', 'CTRL'])]
12 |     tmp = [i.split(',') for i in adata.obs['gene']]
13 |     tmp = [i for i in tmp if i != 'CTRL' and i != 'None']
14 |     tmp = np.unique([i for j in tmp for i in j])
15 |     return tmp
16 |
17 |
18 | def fun2(adata):
19 |     mylist = []
20 |     for gene, cell in zip(adata.obs['gene'], adata.obs_names):
21 |         if gene == 'CTRL':
22 |             mylist.append(cell)
23 |         else:
24 |             genes = gene.split(',')
25 |             tmp = [True if i in adata.var_names else False for i in genes]
26 |             if np.all(tmp): mylist.append(cell)
27 |     adata1 = adata[mylist, :]
28 |     return adata1
29 |
30 | def fun3(x):  ### keep at most two-gene combinatorial perturbations
31 |     xs = x.split(',')
32 |     if len(xs) >= 3: return False
33 |     else: return True
34 |
35 |
36 | def fun4(adata):
37 |     with open('/home/project/GW_PerturbSeq/geneEmbedding/geneList.pkl', 'rb') as fin:
38 |         geneList = pickle.load(fin)
39 |     tmp = adata.var_names.isin(geneList)  ### genes must have a gene embedding
40 |     adata = adata[:, tmp]
41 |     adata1 = fun2(adata)  ### perturbed genes must be in the expression profile and have a gene embedding
42 |     adata1.write_h5ad('raw.h5ad')
43 |
44 |
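# --- Added illustration (hedged; not part of the original pipeline) ---
# obs['gene'] encodes perturbations as comma-separated gene symbols, with
# 'CTRL' marking control cells. A minimal sketch of what fun1/fun2 do:
def _label_convention_demo():
    import anndata as _ad, numpy as _np, pandas as _pd
    demo = _ad.AnnData(X=_np.zeros((3, 2)),
                       obs=_pd.DataFrame({'gene': ['CTRL', 'KLF1', 'CBL,CNN1']}, index=['c1', 'c2', 'c3']),
                       var=_pd.DataFrame(index=['KLF1', 'CBL']))
    print(fun1(demo))            # unique perturbed genes: ['CBL' 'CNN1' 'KLF1']
    print(fun2(demo).obs_names)  # keeps c1 (CTRL) and c2; drops c3, since CNN1 is absent from var_names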
45 | ### Initial processing so the data can go through standardized preprocessing; used for the three datasets of the genome-wide paper
46 | def QCsample1(dirName, fileName='Raw.h5ad'):
47 |     os.chdir(dirName)
48 |     adata = sc.read_h5ad(fileName)
49 |     adata = adata[:, ~adata.var['gene_name'].duplicated()]  ## before swapping the index, drop duplicates first, otherwise later steps raise errors
50 |     adata.var['gene_id'] = adata.var.index
51 |     adata.var.set_index('gene_name', inplace=True)
52 |     adata.obs['gene'].replace({'non-targeting': 'CTRL'}, inplace=True)
53 |     fun4(adata)
54 |
55 | ''' sanity checks: make sure every gene has an embedding and the perturbed genes are among the expressed genes
56 | adata = sc.read_h5ad('raw.h5ad')
57 | [i for i in adata.obs['gene'].unique() if i not in geneList]
58 | [i for i in adata.var_names if i not in geneList]
59 | [i for i in adata.obs['gene'].unique() if i not in adata.var_names]
60 | '''
61 |
62 | #### Feng Zhang's transcription factor (TFatlas) data
63 | def QCsample2(dirName):
64 |     os.chdir(dirName)
65 |     adata = sc.read_h5ad('Raw.h5ad')
66 |     adata.var['gene_id'] = adata.var.index
67 |     adata.var['gene_name'] = adata.var.index
68 |     adata.var.set_index('gene_name', inplace=True)
69 |     fun4(adata)
70 |
71 | #### Perturb-CITE-seq
72 | def QCsample3(dirName):
73 |     os.chdir(dirName)
74 |     adata = sc.read_h5ad("raw1.h5ad")
75 |     adata.var['gene_id'] = adata.var.index
76 |     adata.var['gene_name'] = adata.var.index
77 |     adata.var.set_index('gene_name', inplace=True)
78 |     tmp = adata.obs['gene'].apply(lambda x: fun3(x))  ### keep at most two perturbations
79 |     adata = adata[tmp]
80 |     fun4(adata)
81 |
82 | ### Compute differentially expressed genes for each dataset
83 | ### Data preprocessing
84 | def f_preData(dirName):
85 |     os.chdir(dirName)
86 |     adata1 = sc.read_h5ad('raw.h5ad')
87 |
88 |     filterNoneNums, filterCells, filterMT, filterMinNums, adata = preData(adata1, filterNone=True, minNums = 30, shuffle=False, filterCom=False, seed = 42, mtpercent = 10, min_genes = 200, domaxNums=500, doNor = True)  ### domaxNums is treated as a truthy flag; preData caps each perturbation at 50 cells
89 |     print (filterNoneNums, filterCells, filterMT, filterMinNums)
90 |
91 |     sc.pp.highly_variable_genes(adata, subset=False, n_top_genes=5000)  ### consistent with GEARS: keep 5000 HVGs
92 |     hvgs = list(adata.var_names[adata.var['highly_variable']])
93 |     trainGene = fun1(adata)  ### get the list of perturbed genes, excluding CTRL
94 |     trainGene = [i for i in trainGene if i in adata.var_names]  ### perturbed genes must remain in the expression profile
95 |     keepGene = list(set(trainGene + hvgs))
96 |     adata = adata[:, keepGene]
97 |     adata1 = adata1[adata.obs_names, keepGene]
98 |
99 |     adata = fun2(adata)  ### make sure again that all perturbed genes are expressed; filter out cells otherwise
100 |     adata1 = fun2(adata1)  ### make sure again that all perturbed genes are expressed; filter out cells otherwise
101 |     adata.write_h5ad('filterNor.h5ad')
102 |     adata1.write_h5ad('filterRaw.h5ad')
103 |
104 |
105 |
106 |
107 | dirNames = [
108 |     #'/home//project/GW_PerturbSeq/anndata/K562_GW',
109 |     '/home//project/GW_PerturbSeq/anndata/K562_GW_subset',
110 |     '/home//project/GW_PerturbSeq/anndata/K562_essential',
111 |     '/home//project/GW_PerturbSeq/anndata/RPE1_essential',
112 |     '/home//project/GW_PerturbSeq/anndata/TFatlas',
113 |
114 |     '/home//project/GW_PerturbSeq/anndata_combination/PRJNA551220',
115 |     '/home//project/GW_PerturbSeq/anndata_combination/Perturb-CITE-seq',
116 |     '/home//project/GW_PerturbSeq/anndata_combination/PRJNA787633'
117 | ]
118 |
119 | #dirName = '/home//project/GW_PerturbSeq/anndata/K562_GW_subset'; domaxNums = True; minNums = 30;  ###QCsample1(dirName)  ### CRISPRi  ### keep at most 50 cells per perturbation
120 | dirName = '/home//project/GW_PerturbSeq/anndata/K562_GW'; domaxNums = True; minNums = 30;  ##QCsample1(dirName)  ### CRISPRi
121 | #dirName = '/home//project/GW_PerturbSeq/anndata/K562_essential'; domaxNums = True; minNums = 30;  #QCsample1(dirName)
122 | #dirName = '/home//project/GW_PerturbSeq/anndata/RPE1_essential'; domaxNums = True; minNums = 30;  #QCsample1(dirName)
123 | #dirName = '/home//project/GW_PerturbSeq/anndata/TFatlas'; domaxNums = True; minNums = 30;  #QCsample2(dirName)  # Embryonic stem cells, Activation
124 |
125 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/PRJNA551220'; domaxNums = True; minNums = 30;  #QCsample3(dirName)  ### combinatorial perturbation data  #### K562, Activation
126 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/Perturb-CITE-seq'; domaxNums = True; minNums = 5;  #QCsample3(dirName)  #### CRISPR KO, melanoma cells
127 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/PRJNA787633'; domaxNums = True; minNums = 5;  #QCsample3(dirName)  ### T cells, CRISPRa
128 |
129 | ### Ttest, wilcoxon, edgeR
130 |
131 | if __name__ == '__main__':
132 |     print ('hello, world')
133 |     f_preData(dirName)
134 |     getDEG("/home//project/GW_PerturbSeq/anndata/K562_GW", method='wilcoxon')
135 |
136 |     for dirName in dirNames:
137 |         for method in tqdm(['Ttest', 'wilcoxon', 'edgeR']):
138 |             f_getDEG1(dirName, topDeg=0, method=method); f_getDEG1(dirName, topDeg=20, method=method)
139 |             f_getDEG2(dirName, topDeg=0, method=method); f_getDEG2(dirName, topDeg=20, method=method)
140 |             f_getDEG3(dirName, topDeg=0, method=method); f_getDEG3(dirName, topDeg=20, method=method)
141 |             mergeLevelData(dirName, topDeg=0, method=method); mergeLevelData(dirName, topDeg=20, method=method)
--------------------------------------------------------------------------------
/DataProcess/geneList.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/DataProcess/geneList.pkl
--------------------------------------------------------------------------------
/DataProcess/myUtil.py:
--------------------------------------------------------------------------------
1 | import numpy as np, pandas as pd, scanpy as sc
2 | import os, sys, pickle, joblib, re, torch, subprocess
3 | from itertools import combinations
4 | from collections import Counter
5 | import seaborn as sns
6 | from collections import defaultdict, OrderedDict
7 | import matplotlib.pyplot as plt
8 | from scipy import stats
9 | from sklearn.metrics.pairwise import cosine_similarity
10 | from scipy.stats import pearsonr, spearmanr
11 | from sklearn.metrics import precision_recall_curve
12 | from sklearn.metrics import average_precision_score
13 | from sklearn.metrics import roc_curve, roc_auc_score, mean_squared_error
14 | from scipy import sparse
15 | import scipy
16 | import anndata as ad
17 | from tqdm import tqdm
18 | from multiprocessing import Pool
19 |
20 |
21 | class multidict(dict):
22 |     def __getitem__(self, item):
23 |         try:
24 |             return dict.__getitem__(self, item)
25 |         except KeyError:
26 |             value = self[item] = type(self)()
27 |             return value
28 |
29 |
30 | def convertAnn2Matrix(adata, dirName='./'):
31 |     a = pd.DataFrame(adata.var)
32 |     a['index'] = a.index
33 |     a = a[['gene_ids', 'index']]
34 |     a.to_csv("{}/genes.tsv".format(dirName), sep = "\t", index = False, header=False)
35 |     pd.DataFrame(adata.obs.index).to_csv("{}/barcodes.tsv".format(dirName), sep = "\t", index = False, header=False)
36 |     if not sparse.issparse(adata.X): adata.X = sparse.csr_matrix(adata.X)  ### convert to a sparse matrix
37 |     adata.X = adata.X.astype(np.int32)  ### convert to integers
38 |     scipy.io.mmwrite("{}/matrix.mtx".format(dirName), adata.X.T)
39 |
40 |
41 |
42 |
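# --- Added usage sketch (hedged; not part of the original utilities) ---
# convertAnn2Matrix above exports an AnnData object as a CellRanger-style
# triplet (genes.tsv, barcodes.tsv, matrix.mtx) with integer counts, so the
# counts can be consumed outside Python (presumably by the edgeR DEG step):
#     convertAnn2Matrix(sc.read_h5ad('filterRaw.h5ad'), dirName='./edgeR_input')
# ('./edgeR_input' is an illustrative path, not one used elsewhere in this repo.)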
43 | def getsubgroup(x = 'combo_seen0', seed = 1):
44 |     tmp = '../../gears/data/train/splits/train_simulation_{}_0.8_subgroup.pkl'.format(seed)
45 |     with open(tmp, 'rb') as fin:
46 |         subgroup_split = pickle.load(fin)
47 |     test_subgroup = subgroup_split['test_subgroup'][x]
48 |     mylist = []
49 |     if 'combo' in x:
50 |         for i in test_subgroup:
51 |             mylist.append(','.join(sorted(i.split('+'))))
52 |     else:
53 |         for i in test_subgroup:
54 |             mylist.append(i.split('+')[0])
55 |     return mylist
56 |
57 |
58 | def getsubgroup_single(seed = 1):  ### single perturbations
59 |     tmp1 = '../../gears/data/train/splits/train_simulation_{}_0.8.pkl'.format(seed)
60 |     tmp2 = '../../../gears/data/train/splits/train_simulation_{}_0.8.pkl'.format(seed)
61 |
62 |     if os.path.isfile(tmp1):
63 |         tmp = tmp1
64 |     else:
65 |         tmp = tmp2
66 |     with open(tmp, 'rb') as fin:
67 |         subgroup_split = pickle.load(fin)
68 |     mylist = []
69 |     for i in subgroup_split['test']:
70 |         mylist.append(i.split('+')[0])
71 |     return mylist
72 |
73 |
74 |
75 |
76 | def wilcoxonFun(gene):  ### note: relies on the module-global `adata` set in getDEG
77 |     sc.tl.rank_genes_groups(adata, 'gene', groups=[gene], reference='CTRL', method= 'wilcoxon')
78 |     result = adata.uns['rank_genes_groups']
79 |     groups = result['names'].dtype.names
80 |     final_result = pd.DataFrame({group + '_' + key: result[key][group] for group in groups for key in ['names', 'pvals_adj', 'logfoldchanges', 'scores']})
81 |     for group in groups:
82 |         tmp1 = group + '_' + 'foldchanges'
83 |         tmp2 = group + '_' + 'logfoldchanges'
84 |         final_result[tmp1] = 2 ** final_result[tmp2]  ### convert logfoldchange to foldchange
85 |         final_result.drop(labels=[tmp2], inplace=True, axis=1)
86 |     return final_result
87 |
88 | def getDEG(dirName, method='Ttest'):
89 |     os.chdir(dirName)
90 |     global adata
91 |     fileout = '{}_DEG.tsv'.format(method)
92 |     #adata = sc.read_h5ad('filterNor.h5ad')
93 |     adata = sc.read_h5ad('filterNor_subset.h5ad')
94 |     if 'log1p' not in adata.uns:
95 |         adata.uns['log1p'] = {}
96 |     adata.uns['log1p']["base"] = None
97 |     if scipy.sparse.issparse(adata.X):
98 |         adata.X = adata.X.toarray()
99 |     adata.X += .1  ### add a small pseudo-count
100 |     genes = [i for i in set(adata.obs['gene']) if i != 'CTRL']
101 |     if method == 'Ttest':
102 |         sc.tl.rank_genes_groups(adata, 'gene', groups=genes[:], reference='CTRL', method= 't-test')
103 |         result = adata.uns['rank_genes_groups']
104 |         groups = result['names'].dtype.names
105 |         final_result = pd.DataFrame({group + '_' + key: result[key][group] for group in groups for key in ['names', 'pvals_adj', 'logfoldchanges', 'scores']})
106 |         for group in groups:
107 |             tmp1 = group + '_' + 'foldchanges'
108 |             tmp2 = group + '_' + 'logfoldchanges'
109 |             final_result[tmp1] = 2 ** final_result[tmp2]  ### convert logfoldchange to foldchange
110 |             final_result.drop(labels=[tmp2], inplace=True, axis=1)
111 |         final_result.sort_index(axis=1, inplace=True)
112 |         final_result.to_csv(fileout, sep='\t', index=False)
113 |     elif method == 'wilcoxon':
114 |         result = myPool(wilcoxonFun, genes, processes=7)
115 |         final_result = pd.concat(result, axis=1)
116 |         final_result.sort_index(axis=1, inplace=True)
117 |         final_result.to_csv(fileout, sep='\t', index=False)
118 |
119 |
120 |
121 |
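# --- Added illustration (hedged; not part of the original utilities) ---
# getDEG writes a wide table with three columns per perturbation:
#   <pert>_names, <pert>_pvals_adj, <pert>_foldchanges
# The f_getDEG* functions below each reduce this table to one
# perturbation-by-gene matrix per STAMP subtask level, e.g.:
def _deg_table_demo():
    import pandas as _pd
    dat = _pd.DataFrame({'KLF1_names': ['HBG1', 'GATA1'],
                         'KLF1_pvals_adj': [0.001, 0.5],
                         'KLF1_foldchanges': [2.0, 0.9]})
    is_deg = dat['KLF1_pvals_adj'] <= 0.01   # level 1: DEG or not
    is_up = dat['KLF1_foldchanges'] >= 1     # level 2: up- vs down-regulated
    print(is_deg.values, is_up.values)       # [ True False] [ True False]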
122 | ### Build a 2-D matrix from the DEG table: columns are expressed genes, rows are perturbed genes.
123 | ### First-level task. *****************
124 | def f_getDEG1(dirName, topDeg=0, pvalue = 0.01, method='Ttest'):
125 |     os.chdir(dirName)
126 |     if topDeg == 0:
127 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_binary.tsv'.format(method)
128 |     else:
129 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_binary_topDeg{}.tsv'.format(method, topDeg)
130 |     if os.path.isfile(fileout):
131 |         tmp = pd.read_csv(fileout, sep='\t')
132 |         if tmp.shape[0] >= 10: return  ### skip recomputation if a non-trivial result already exists
133 |     dat = pd.read_csv(filein, sep='\t')
134 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
135 |     expGene = list(dat.iloc[:, 1])
136 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0)
137 |     for pertGene1 in pertGene:
138 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_pvals_adj'.format(pertGene1);
139 |         if topDeg == 0:
140 |             expGene1 = [i for i, j in zip(dat[tmp1], dat[tmp2]) if j <= pvalue]
141 |         else:
142 |             expGene1 = list(dat[tmp1][:20]) + list(dat[tmp1][-20:])  ### take 20 up-regulated and 20 down-regulated genes each
143 |         binaryMat.loc[pertGene1, expGene1] = 1
144 |     binaryMat.sort_index(axis=0, inplace=True)  ### sort rows
145 |     binaryMat.sort_index(axis=1, inplace=True)  ### sort columns
146 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
147 |
148 |
149 | ### Classify DEGs into up- and down-regulated. ### Second-level task. *****************
150 | def f_getDEG2(dirName, topDeg=0, method='Ttest'):
151 |     os.chdir(dirName)
152 |     if topDeg ==0:  ### not strictly necessary, since the values do not change with topDeg
153 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_UpDown.tsv'.format(method)
154 |     else:
155 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_UpDown_topDeg{}.tsv'.format(method, topDeg)
156 |     if os.path.isfile(fileout):
157 |         tmp = pd.read_csv(fileout, sep='\t')
158 |         if tmp.shape[0] >= 10: return  ### skip recomputation if a non-trivial result already exists
159 |     dat = pd.read_csv(filein, sep='\t')
160 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
161 |     expGene = list(dat.iloc[:, 1])
162 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0)
163 |     for pertGene1 in pertGene:
164 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_foldchanges'.format(pertGene1)
165 |         expGene1 = [i for i, j in zip(dat[tmp1], dat[tmp2]) if j >= 1]  ### foldchange >= 1 means up-regulated
166 |         binaryMat.loc[pertGene1, expGene1] = 1
167 |     binaryMat.sort_index(axis=0, inplace=True)  ### sort rows
168 |     binaryMat.sort_index(axis=1, inplace=True)  ### sort columns
169 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
170 |
171 | ### Foldchange magnitudes. ### Third-level task. *****************
172 | def f_getDEG3(dirName, topDeg = 0, method='Ttest'):
173 |     os.chdir(dirName)
174 |     if topDeg == 0:  ### not strictly necessary, since the values do not change with topDeg
175 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_foldchange.tsv'.format(method)
176 |     else:
177 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_foldchange_topDeg{}.tsv'.format(method, topDeg)
178 |     if os.path.isfile(fileout):
179 |         tmp = pd.read_csv(fileout, sep='\t')
180 |         if tmp.shape[0] >= 10: return  ### skip recomputation if a non-trivial result already exists
181 |     dat = pd.read_csv(filein, sep='\t')
182 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
183 |     expGene = list(dat.iloc[:, 1])
184 |     expGene = sorted(expGene)
185 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0.0)
186 |     for pertGene1 in pertGene:
187 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_foldchanges'.format(pertGene1)
188 |         tmp3 = dat[[tmp1, tmp2]]
189 |         expGene1 = list(tmp3.sort_values(tmp1)[tmp2])
190 |         binaryMat.loc[pertGene1, :] = expGene1
191 |     binaryMat.sort_index(axis=0, inplace=True)
192 |     binaryMat.sort_index(axis=1, inplace=True)
193 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
194 |
195 |
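# --- Added note (hedged) ---
# mergeLevelData below stacks the three per-level matrices produced by
# f_getDEG1/2/3 into a single AnnData whose layers ('level1', 'level2',
# 'level3') match the input format STAMP expects (see the README's
# "Example data" section).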
196 | def mergeLevelData(dirName, topDeg=0, method='Ttest'):
197 |     os.chdir(dirName)
198 |     if topDeg == 0:
199 |         filein1 = '{}_DEG_binary.tsv'.format(method)
200 |         filein2 = '{}_DEG_UpDown.tsv'.format('Ttest')  ### levels 2 and 3 are always taken from the t-test results
201 |         filein3 = '{}_DEG_foldchange.tsv'.format("Ttest")
202 |         fileout = '{}_merge.h5ad'.format(method)
203 |     else:
204 |         filein1 = '{}_DEG_binary_topDeg{}.tsv'.format(method, topDeg)
205 |         filein2 = '{}_DEG_UpDown_topDeg{}.tsv'.format("Ttest", topDeg)
206 |         filein3 = '{}_DEG_foldchange_topDeg{}.tsv'.format("Ttest", topDeg)
207 |         fileout = '{}_merge_topDeg{}.h5ad'.format(method, topDeg)
208 |
209 |
210 |     dat1 = pd.read_csv(filein1, sep='\t', index_col=0)
211 |     dat2 = pd.read_csv(filein2, sep='\t', index_col=0)
212 |     dat2 = dat2.loc[dat1.index, dat1.columns]
213 |
214 |     dat3 = pd.read_csv(filein3, sep='\t', index_col=0)
215 |     dat3 = dat3.loc[dat1.index, dat1.columns]
216 |     adata = ad.AnnData(X=sparse.csr_matrix(dat1.values), obs=pd.DataFrame(index=dat1.index), var=pd.DataFrame(index=dat1.columns))
217 |     adata.layers['level1'] = sparse.csr_matrix(dat1.values)
218 |     adata.layers['level2'] = sparse.csr_matrix(dat2.values)
219 |     adata.layers['level3'] = sparse.csr_matrix(dat3.values)
220 |     adata.write_h5ad(fileout)
221 |
222 |
223 |
224 | def preData(adata, filterNone=True, minNums = 30, shuffle=True, filterCom=False, seed = 42, mtpercent = 10, min_genes = 200, domaxNums=False, doNor=True, min_cells=3):  #### for testing clustering, it is best not to sort
225 |     if domaxNums: maxNums = 50
226 |     adata.var_names.astype(str)
227 |     adata.var_names_make_unique()
228 |     adata = adata[~adata.obs.index.duplicated()]
229 |     if filterCom:
230 |         tmp = adata.obs['gene'].apply(lambda x: True if ',' not in x else False); adata = adata[tmp]
231 |     if filterNone:
232 |         adata = adata[adata.obs["gene"] != "None"]
233 |     filterNoneNums = adata.shape[0]
234 |     sc.pp.filter_cells(adata, min_genes=min_genes)
235 |     sc.pp.filter_genes(adata, min_cells= min_cells)
236 |     filterCells = adata.shape[0]
237 |
238 |     if np.any([True if i.startswith('mt-') else False for i in adata.var_names]):
239 |         adata.var['mt'] = adata.var_names.str.startswith('mt-')
240 |     else:
241 |         adata.var['mt'] = adata.var_names.str.startswith('MT-')
242 |     sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
243 |     if sum(adata.obs['pct_counts_mt'] < 10) / adata.shape[0] <=0.5: mtpercent = 15
244 |     adata = adata[adata.obs.pct_counts_mt < mtpercent, :]
245 |     filterMT = adata.shape[0]
246 |     tmp = adata.obs['gene'].value_counts()
247 |     tmp_bool = tmp >= minNums
248 |     genes = list(tmp[tmp_bool].index)
249 |     if 'CTRL' not in genes: genes += ['CTRL']
250 |     adata = adata[adata.obs['gene'].isin(genes), :]
251 |     if domaxNums:
252 |         adata1 = adata[adata.obs['gene'] == 'CTRL']
253 |         genes = adata.obs['gene'].unique()
254 |         tmp = [adata[adata.obs['gene'] == i][:maxNums] for i in genes if i !='CTRL']
255 |         adata2 = ad.concat(tmp)
256 |         adata = ad.concat([adata1, adata2])
257 |
258 |     filterMinNums = adata.shape[0]
259 |
260 |     if doNor:
261 |         sc.pp.normalize_total(adata, target_sum=1e4)
262 |         sc.pp.log1p(adata)
263 |     adata = adata[adata.obs.sort_values(by='gene').index,:]
264 |     if shuffle:
265 |         tmp = list(adata.obs.index)
266 |         np.random.seed(seed); np.random.shuffle(tmp); adata = adata[tmp]
267 |     return filterNoneNums, filterCells, filterMT, filterMinNums, adata
268 |
269 |
270 |
271 | def myPool(func, mylist, processes):
272 |     with Pool(processes) as pool:
273 |         results = list(tqdm(pool.imap(func, mylist), total=len(mylist)))
274 |     return results
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune Data
2 | include README.md
3 | include requirements.txt
4 | include LICENSE
--------------------------------------------------------------------------------
/Output/training.log:
--------------------------------------------------------------------------------
1 | [2024-01-08 14:47:37,272][STAMP.py][line:134][INFO] Epoch:[1/10] steps:16 loss1_train:0.79759 loss2_train:0.69352 loss3_train:0.32161 loss1_test:0.77941 loss2_test:0.69337 loss3_test:0.29287 time:4.355
2 | [2024-01-08 14:47:41,487][STAMP.py][line:134][INFO] Epoch:[2/10] steps:16 loss1_train:0.77924 loss2_train:0.69322 loss3_train:0.25942 loss1_test:0.77926 loss2_test:0.69336 loss3_test:0.21642 time:3.902
3 | [2024-01-08 14:47:45,568][STAMP.py][line:134][INFO] Epoch:[3/10] steps:16 loss1_train:0.77899 loss2_train:0.69377 loss3_train:0.17493 loss1_test:0.77886 loss2_test:0.69381 loss3_test:0.13191 time:3.771
4 | [2024-01-08 14:47:49,320][STAMP.py][line:134][INFO] Epoch:[4/10] steps:16 loss1_train:0.77779 loss2_train:0.69335 loss3_train:0.10727 loss1_test:0.77657 loss2_test:0.69318 loss3_test:0.08846 time:3.413
5 | [2024-01-08 14:47:53,320][STAMP.py][line:134][INFO] Epoch:[5/10] steps:16 loss1_train:0.77411 loss2_train:0.69317 loss3_train:0.08467 loss1_test:0.77420 loss2_test:0.69318 loss3_test:0.08353 time:3.686
6 | [2024-01-08 14:47:56,933][STAMP.py][line:134][INFO] Epoch:[6/10] steps:16 loss1_train:0.77264 loss2_train:0.69316 loss3_train:0.08380 loss1_test:0.77355 loss2_test:0.69318 loss3_test:0.08373 time:3.277
7 | [2024-01-08 14:48:00,993][STAMP.py][line:134][INFO] Epoch:[7/10] steps:16 loss1_train:0.77234 loss2_train:0.69316 loss3_train:0.08357 loss1_test:0.77358 loss2_test:0.69318 loss3_test:0.08341 time:3.747
8 | [2024-01-08 14:48:04,797][STAMP.py][line:134][INFO] Epoch:[8/10] steps:16 loss1_train:0.77223 loss2_train:0.69316 loss3_train:0.08344 loss1_test:0.77350 loss2_test:0.69318 loss3_test:0.08340 time:3.507
9 | [2024-01-08 14:48:08,271][STAMP.py][line:134][INFO] Epoch:[9/10] steps:16 loss1_train:0.77211 loss2_train:0.69316 loss3_train:0.08344 loss1_test:0.77351 loss2_test:0.69318 loss3_test:0.08340 time:3.156
10 | [2024-01-08 14:48:11,990][STAMP.py][line:134][INFO] Epoch:[10/10] steps:16 loss1_train:0.77197 loss2_train:0.69316 loss3_train:0.08344 loss1_test:0.77352 loss2_test:0.69318 loss3_test:0.08340 time:3.421
11 | [2024-01-08 14:48:11,996][STAMP.py][line:145][INFO] finish training!
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STAMP: Toward subtask decomposition-based learning and benchmarking for genetic perturbation outcome prediction and beyond
2 | [![DOI image](https://zenodo.org/badge/DOI/10.5281/zenodo.12779567.svg)](https://zenodo.org/records/12779567)
3 | ## Introduction
4 | This repository hosts the official implementation of STAMP, a method that predicts perturbation outcomes from single-cell RNA-sequencing data of perturbational experimental screens via subtask decomposition learning. STAMP can be applied to three challenges in this area: (1) predicting single genetic perturbation outcomes, (2) predicting multiple genetic perturbation outcomes and (3) predicting genetic perturbation outcomes across cell lines.

5 | <p align="center"><img src="./img/framework.png" alt="STAMP framework"></p>
6 |
7 | ## Installation
8 | Our experiments were conducted with Python 3.9.7 and CUDA 11.4.
9 |
10 | We recommend using Anaconda / Miniconda to create a conda environment for using STAMP. You can create a python environment using the following command:
11 | ```
12 | conda create -n stamp python==3.9.7
13 | ```
14 |
15 | Then, you can activate the environment using:
16 | ```
17 | conda activate stamp
18 | ```
19 | Install PyTorch with the following command:
20 | ```
21 | conda install pytorch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2 -c pytorch
22 | ```
23 | Then, install STAMP from source:
24 | ```
25 | pip install .
26 | ```
27 | or you can install it from [PyPI](https://pypi.org/project/cell-stamp/):
28 | ```
29 | pip install cell-stamp
30 | ```
31 |
32 | ## Example data
33 | We have made available the code needed to generate example data, serving as a practical illustration for training and testing the STAMP model. Additionally, for guidance on configuring the training process of STAMP, we offer an example config file located at `./Data/example_config.yaml`.
34 | ```
35 | python ./Data/GeneratingExampleData.py
36 | ```
37 | Each example *.h5ad data file has three distinct layers, namely 'level1', 'level2', and 'level3'. The 'level1' layer is a binary matrix, where '0' represents non-differentially expressed genes (non-DEGs) and '1' indicates differentially expressed genes (DEGs). Similarly, 'level2' is another binary matrix, denoting down-regulated genes with '0' and up-regulated genes with '1'. Lastly, the 'level3' layer is a matrix that quantifies the magnitude of gene expression changes. A minimal sketch for inspecting these layers is given after the training example below.
38 |
39 | ## Real demo data
40 | We have uploaded all benchmark datasets to Zenodo; they can be obtained from [here](https://zenodo.org/records/12779567). Please download all these files into the `./Data` directory and refer to `tutorial_for_training.py.ipynb` in the `./Tutorial` directory. This tutorial uses one fold of the RPE1_essential dataset as an example to perform model training and testing, and to check the loss curves during training.
41 | #### Note: Users are encouraged to change the path of each dataset in the config YAML based on their own machines.
42 |
43 | ## Core API interface for model training
44 | Using this API, you can train and test STAMP on your own perturbation datasets using a few lines of code.
45 | ```python
46 | from stamp import STAMP, load_config
47 | import scanpy as sc
48 |
49 | # load config file
50 | config = load_config("./Data/example_config.yaml")
51 |
52 | # set up and train a STAMP model
53 | model = STAMP(config)
54 | model.train()
55 |
56 | # load trained model
57 | model.load_pretrained(f"{config['Train']['output_dir']}/trained_models")
58 |
59 | # use trained model to predict unseen perturbations
60 | model.prediction(config['dataset']['Testing_dataset'], combo_test = True)
61 |
62 | # use trained model to predict unseen perturbations, considering the Top 40 DEGs
63 | # (the Top 40 DEGs consist of the Top 20 up-regulated and Top 20 down-regulated genes)
64 |
65 | # load Top 40 test data
66 | top_40_data = sc.read_h5ad("./Data/example_test_top40.h5ad")
67 |
68 | # prediction
69 | model.prediction(top_40_data, combo_test = True)
70 | ```
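As a quick check of the input format, here is a minimal sketch (assuming the example files generated by `python ./Data/GeneratingExampleData.py` above) for inspecting the three layers described in the Example data section:
```python
import scanpy as sc

# assumes ./Data/example_train.h5ad was produced by GeneratingExampleData.py
adata = sc.read_h5ad("./Data/example_train.h5ad")

print(adata)                       # obs = perturbations, var = genes
print(adata.layers["level1"][:3])  # binary: 1 = DEG, 0 = non-DEG
print(adata.layers["level2"][:3])  # binary: 1 = up-regulated, 0 = down-regulated
print(adata.layers["level3"][:3])  # continuous: magnitude of expression change
```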
71 | ## Core API interface for model fine-tuning
72 | Using this API, you can fine-tune and test STAMP on your own perturbation datasets using a few lines of code.
73 | ```python
74 | from stamp import STAMP, load_config
75 | import scanpy as sc
76 |
77 | # load config file (we use the example config used for model training to illustrate this)
78 | config = load_config("./Data/example_config.yaml")
79 |
80 | # set up STAMP
81 | model = STAMP(config)
82 |
83 | # load pre-trained model
84 | model.load_pretrained(f"{config['Train']['output_dir']}/trained_models")
85 |
86 | # fine-tune the model
87 | model.finetuning()
88 |
89 | # use fine-tuned model to predict unseen perturbations
90 | model.prediction(config['dataset']['Testing_dataset'], combo_test = False)
91 |
92 | # use fine-tuned model to predict unseen perturbations, considering the Top 40 DEGs
93 | # (the Top 40 DEGs consist of the Top 20 up-regulated and Top 20 down-regulated genes)
94 |
95 | # load Top 40 test data
96 | top_40_data = sc.read_h5ad("./Data/example_test_top40.h5ad")
97 |
98 | # prediction
99 | model.prediction(top_40_data, combo_test = False)
100 | ```
101 | ## Citation
102 | Yicheng Gao, Zhiting Wei, Qi Liu et al. *Toward subtask decomposition-based learning and benchmarking for genetic perturbation outcome prediction and beyond*, Nature Computational Science, 2024.
103 | ## Contacts
104 | bm2-lab@tongji.edu.cn
105 |
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/GI_evaluation.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | import numpy as np
3 | from sklearn import tree
4 | import matplotlib.pyplot as plt
5 | from sklearn.tree import plot_tree
6 | from scipy import stats
7 | import pandas as pd
8 | np.random.seed(0)
9 | GIs = {
10 |     'NEOMORPHIC': ['CBL+TGFBR2',
11 |                    'KLF1+TGFBR2',
12 |                    'MAP2K6+SPI1',
13 |                    'SAMD1+TGFBR2',
14 |                    'TGFBR2+C19orf26',
15 |                    'TGFBR2+ETS2',
16 |                    'CBL+UBASH3A',
17 |                    'CEBPE+KLF1',
18 |                    'DUSP9+MAPK1',
19 |                    'FOSB+PTPN12',
20 |                    'PLK4+STIL',
21 |                    'PTPN12+OSR2',
22 |                    'ZC3HAV1+CEBPE'],
23 |     'ADDITIVE': ['BPGM+SAMD1',
24 |                  'CEBPB+MAPK1',
25 |                  'CEBPB+OSR2',
26 |                  'DUSP9+PRTG',
27 |                  'FOSB+OSR2',
28 |                  'IRF1+SET',
29 |                  'MAP2K3+ELMSAN1',
30 |                  'MAP2K6+ELMSAN1',
31 |                  'POU3F2+FOXL2',
32 |                  'RHOXF2BB+SET',
33 |                  'SAMD1+PTPN12',
34 |                  'SAMD1+UBASH3B',
35 |                  'SAMD1+ZBTB1',
36 |                  'SGK1+TBX2',
37 |                  'TBX3+TBX2',
38 |                  'ZBTB10+SNAI1'],
39 |     'EPISTASIS': ['AHR+KLF1',
40 |                   'MAPK1+TGFBR2',
41 |                   'TGFBR2+IGDCC3',
42 |                   'TGFBR2+PRTG',
43 |                   'UBASH3B+OSR2',
44 |                   'DUSP9+ETS2',
45 |                   'KLF1+CEBPA',
46 |                   'MAP2K6+IKZF3',
47 |                   'ZC3HAV1+CEBPA'],
48 |     'REDUNDANT': ['CDKN1C+CDKN1A',
49 |                   'MAP2K3+MAP2K6',
50 |                   'CEBPB+CEBPA',
51 |                   'CEBPE+CEBPA',
52 |                   'CEBPE+SPI1',
53 |                   'ETS2+MAPK1',
54 |                   'FOSB+CEBPE',
55 |                   'FOXA3+FOXA1'],
56 |     'POTENTIATION': ['CNN1+UBASH3A',
57 |                      'ETS2+MAP7D1',
58 |                      'FEV+CBFA2T3',
59 |                      'FEV+ISL2',
60 |                      'FEV+MAP7D1',
61 |                      'PTPN12+UBASH3A'],
62 |     'SYNERGY_SIMILAR_PHENO': ['CBL+CNN1',
63 |                               'CBL+PTPN12',
64 |                               'CBL+PTPN9',
65 |                               'CBL+UBASH3B',
66 |                               'FOXA3+FOXL2',
67 |                               'FOXA3+HOXB9',
68 |                               'FOXL2+HOXB9',
69 |                               'UBASH3B+CNN1',
70 |                               'UBASH3B+PTPN12',
71 |                               'UBASH3B+PTPN9',
72 |                               'UBASH3B+ZBTB25'],
73 |     'SYNERGY_DISSIMILAR_PHENO': ['AHR+FEV',
74 |                                  'DUSP9+SNAI1',
75 |                                  'FOXA1+FOXF1',
76 |                                  'FOXA1+FOXL2',
77 |                                  'FOXA1+HOXB9',
78 |                                  'FOXF1+FOXL2',
79 |                                  'FOXF1+HOXB9',
80 |                                  'FOXL2+MEIS1',
81 |                                  'IGDCC3+ZBTB25',
82 |                                  'POU3F2+CBFA2T3',
83 |                                  'PTPN12+ZBTB25',
84 |                                  'SNAI1+DLX2',
85 |                                  'SNAI1+UBASH3B'],
86 |     'SUPPRESSOR': ['CEBPB+PTPN12',
87 |                    'CEBPE+CNN1',
88 |                    'CEBPE+PTPN12',
89 |                    'CNN1+MAPK1',
90 |                    'ETS2+CNN1',
91 |                    'ETS2+IGDCC3',
92 |                    'ETS2+PRTG',
93 |                    'FOSB+UBASH3B',
94 |                    'IGDCC3+MAPK1',
95 |                    'LYL1+CEBPB',
96 |                    'MAPK1+PRTG',
97 |                    'PTPN12+SNAI1']
98 | }
99 | GIs['SYNERGY'] = GIs['SYNERGY_DISSIMILAR_PHENO'] + GIs['SYNERGY_SIMILAR_PHENO'] + GIs['POTENTIATION']
100 |
101 | all_results = joblib.load("./STAMP_results.pkl")
102 | all_results_truth = joblib.load("./Truth_results.pkl")
103 | all_results_gears = joblib.load("./Gears_results.pkl")
104 | science_data = pd.read_csv("./Science_data.csv", sep = ',')
105 | all_results_science = {}
106 | for idx, name in enumerate(science_data['name']):
107 |     all_results_science[(','.join(name.split('_')))] = {}
108 |     all_results_science[(','.join(name.split('_')))]['magnitude'] = science_data['ts_norm2'][idx]
109 |     all_results_science[(','.join(name.split('_')))]['model_fit'] = science_data['ts_linear_dcor'][idx]
110 |     all_results_science[(','.join(name.split('_')))]['equality_of_contribution'] = science_data['dcor_ratio'][idx]
111 |     all_results_science[(','.join(name.split('_')))]['Similarity'] = science_data['dcor'][idx]
112 |
113 |
114 | def calculate_metric(all_results, top_k = 10):
115 |     mags = [all_results[i]['magnitude'] for i in all_results]
116 |     model_fits = [all_results[i]['model_fit'] for i in all_results]
117 |     equality_of_contributions = [all_results[i]['equality_of_contribution'] for i in all_results]
118 |     Similaritys = [all_results[i]['Similarity'] for i in all_results]
119 |     top10_synergy = np.array(list(all_results.keys()))[np.argsort(mags)[::-1][:top_k]]
120 |     top10_precision_synergy = len(set(top10_synergy).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']]))))
121 |     top10_precision_synergy /= top_k
122 |     top10_suppressor = np.array(list(all_results.keys()))[np.argsort(mags)[:top_k]]
123 |     top10_precision_suppressor = len(set(top10_suppressor).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']]))))
124 |     top10_precision_suppressor /= top_k
125 |     top10_neomorphism = np.array(list(all_results.keys()))[np.argsort(model_fits)[:top_k]]
126 |     top10_precision_neomorphism = len(set(top10_neomorphism).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']]))))
127 |     top10_precision_neomorphism /= top_k
128 |     top10_redundant = np.array(list(all_results.keys()))[np.argsort(Similaritys)[::-1][:top_k]]
129 |     top10_precision_redundant = len(set(top10_redundant).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']]))))
130 |     top10_precision_redundant /= 8  ### only 8 REDUNDANT pairs exist, so normalize by the class size instead of top_k
131 |     top10_epistasis = np.array(list(all_results.keys()))[np.argsort(equality_of_contributions)[:top_k]]
132 |     top10_precision_epistasis = len(set(top10_epistasis).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']]))))
133 |     top10_precision_epistasis /= 9  ### only 9 EPISTASIS pairs exist, so normalize by the class size instead of top_k
134 |     return top10_precision_synergy, top10_precision_suppressor, top10_precision_neomorphism, top10_precision_redundant, top10_precision_epistasis
135 |
136 | print("Science_ori,Top10 precision:",calculate_metric(all_results_science))
137 | print("Ground_truth,Top10 precision:",calculate_metric(all_results_truth))
138 | print("STAMP,Top10 precision:",calculate_metric(all_results))
139 | print("GEARs,Top10 precision:",calculate_metric(all_results_gears))
140 |
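# --- Added summary (hedged; GI_RANKING is an illustrative name, unused elsewhere) ---
# Ranking conventions implied by calculate_metric above (and reused below):
GI_RANKING = {
    'SYNERGY':    ('magnitude', 'largest values'),
    'SUPPRESSOR': ('magnitude', 'smallest values'),
    'NEOMORPHIC': ('model_fit', 'smallest values, i.e. worst linear fit'),
    'REDUNDANT':  ('Similarity', 'largest values'),
    'EPISTASIS':  ('equality_of_contribution', 'smallest values'),
}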
141 | def calculate_metric_top10acc(all_results, all_results_truth, top_k = 10):
142 |     mags = [all_results[i]['magnitude'] for i in all_results]
143 |     model_fits = [all_results[i]['model_fit'] for i in all_results]
144 |     equality_of_contributions = [all_results[i]['equality_of_contribution'] for i in all_results]
145 |     Similaritys = [all_results[i]['Similarity'] for i in all_results]
146 |     mags_truth = [all_results_truth[i]['magnitude'] for i in all_results_truth]
147 |     model_fits_truth = [all_results_truth[i]['model_fit'] for i in all_results_truth]
148 |     equality_of_contributions_truth = [all_results_truth[i]['equality_of_contribution'] for i in all_results_truth]
149 |     Similaritys_truth = [all_results_truth[i]['Similarity'] for i in all_results_truth]
150 |
151 |     top10_synergy = np.array(list(all_results.keys()))[np.argsort(mags)[::-1][:top_k]]
152 |     top10_acc_synergy = set(top10_synergy).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']])))
153 |     top10_synergy_truth = np.array(list(all_results_truth.keys()))[np.argsort(mags_truth)[::-1][:top_k]]
154 |     top10_acc_synergy_truth = set(top10_synergy_truth).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']])))
155 |     top10_acc_synergy = len(top10_acc_synergy.intersection(top10_acc_synergy_truth))/len(top10_acc_synergy_truth)
156 |
157 |     top10_suppressor = np.array(list(all_results.keys()))[np.argsort(mags)[:top_k]]
158 |     top10_acc_suppressor = set(top10_suppressor).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']])))
159 |     top10_suppressor_truth = np.array(list(all_results_truth.keys()))[np.argsort(mags_truth)[:top_k]]
160 |     top10_acc_suppressor_truth = set(top10_suppressor_truth).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']])))
161 |     try:
162 |         top10_acc_suppressor = len(top10_acc_suppressor.intersection(top10_acc_suppressor_truth))/len(top10_acc_suppressor_truth)
163 |     except ZeroDivisionError:
164 |         top10_acc_suppressor = 0
165 |
166 |     top10_neomorphism = np.array(list(all_results.keys()))[np.argsort(model_fits)[:top_k]]
167 |     top10_acc_neomorphism = set(top10_neomorphism).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']])))
168 |     top10_neomorphism_truth = np.array(list(all_results_truth.keys()))[np.argsort(model_fits_truth)[:top_k]]
169 |     top10_acc_neomorphism_truth = set(top10_neomorphism_truth).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']])))
170 |     top10_acc_neomorphism = len(top10_acc_neomorphism.intersection(top10_acc_neomorphism_truth))/len(top10_acc_neomorphism_truth)
171 |
172 |     top10_redundant = np.array(list(all_results.keys()))[np.argsort(Similaritys)[::-1][:top_k]]
173 |     top10_acc_redundant = set(top10_redundant).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']])))
174 |     top10_redundant_truth = np.array(list(all_results_truth.keys()))[np.argsort(Similaritys_truth)[::-1][:top_k]]
175 |     top10_acc_redundant_truth = set(top10_redundant_truth).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']])))
176 |     try:
177 |         top10_acc_redundant = len(top10_acc_redundant.intersection(top10_acc_redundant_truth))/len(top10_acc_redundant_truth)
178 |     except ZeroDivisionError:
179 |         top10_acc_redundant = 0
180 |
181 |     top10_epistasis = np.array(list(all_results.keys()))[np.argsort(equality_of_contributions)[:top_k]]
182 |     top10_acc_epistasis = set(top10_epistasis).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']])))
183 |     top10_epistasis_truth = np.array(list(all_results_truth.keys()))[np.argsort(equality_of_contributions_truth)[:top_k]]
184 |     top10_acc_epistasis_truth = set(top10_epistasis_truth).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']])))
185 |     top10_acc_epistasis = len(top10_acc_epistasis.intersection(top10_acc_epistasis_truth))/len(top10_acc_epistasis_truth)
186 |
187 |     return top10_acc_synergy, top10_acc_suppressor, top10_acc_neomorphism, top10_acc_redundant, top10_acc_epistasis
188 |
189 | # print("STAMP,Top10 acc:",calculate_metric_top10acc(all_results, all_results_truth))
190 | # print("GEARs,Top10 acc:",calculate_metric_top10acc(all_results_gears, all_results_truth))
191 |
192 | def array_generation(results):
193 |     X = []
194 |     Y = []
195 |     for idx, combo_gene in enumerate(results):
196 |         if ('+').join(combo_gene.split(",")) in GIs['SYNERGY'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['SYNERGY']:
197 |             Y.append(0)
198 |         elif ('+').join(combo_gene.split(",")) in GIs['SUPPRESSOR'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['SUPPRESSOR']:
199 |             Y.append(1)
200 |         elif ('+').join(combo_gene.split(",")) in GIs['NEOMORPHIC'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['NEOMORPHIC']:
201 |             Y.append(2)
202 |         elif ('+').join(combo_gene.split(",")) in GIs['REDUNDANT'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['REDUNDANT']:
203 |             Y.append(3)
204 |         elif ('+').join(combo_gene.split(",")) in GIs['EPISTASIS'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['EPISTASIS']:
205 |             Y.append(4)
206 |         elif ('+').join(combo_gene.split(",")) in GIs['ADDITIVE'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['ADDITIVE']:
207 |             Y.append(5)
208 |         else:
209 |             continue
210 |         X.append([results[combo_gene][feature] for feature in results[combo_gene]])
211 |     return np.array(X), np.array(Y)
212 |
213 | def acc_cal(X_STAMP,Y_STAMP):
214 |     tmp = clf.predict_proba(X_STAMP)  ### uses the module-level decision tree `clf` fitted below on the ground-truth features
215 |     acc_SYNERGY = (Y_STAMP[Y_STAMP==0]==(tmp.argmax(axis=1)[Y_STAMP==0])).mean()
216 |     acc_SUPPRESSOR = (Y_STAMP[Y_STAMP==1]==(tmp.argmax(axis=1)[Y_STAMP==1])).mean()
217 |     acc_NEOMORPHIC = (Y_STAMP[Y_STAMP==2]==(tmp.argmax(axis=1)[Y_STAMP==2])).mean()
218 |     acc_REDUNDANT = (Y_STAMP[Y_STAMP==3]==(tmp.argmax(axis=1)[Y_STAMP==3])).mean()
219 |     acc_EPISTASIS = (Y_STAMP[Y_STAMP==4]==(tmp.argmax(axis=1)[Y_STAMP==4])).mean()
220 |     acc_Additive = (Y_STAMP[Y_STAMP==5]==(tmp.argmax(axis=1)[Y_STAMP==5])).mean()
221 |     return (acc_SYNERGY, acc_SUPPRESSOR, acc_NEOMORPHIC, acc_REDUNDANT, acc_EPISTASIS, acc_Additive)
222 |
223 | X_truth,Y_truth = array_generation(all_results_truth)
224 |
225 | X_truth = (X_truth-X_truth.mean(axis=0))/(X_truth.std(axis=0))
226 |
227 | clf = tree.DecisionTreeClassifier(random_state=42, max_depth=6, min_samples_leaf=8, min_samples_split=8, max_leaf_nodes=6)
228 | clf = clf.fit(X_truth, Y_truth)
229 | plt.figure(figsize=(12, 8))
230 | plot_tree(clf, filled=True, 
feature_names=["magnitude", "model_fit", "equality_of_contribution", "Similarity"], class_names=["SYNERGY", "SUPPRESSOR", "NEOMORPHIC", "REDUNDANT", "EPISTASIS","Additive"]) 231 | plt.savefig("./test_tree.png") 232 | 233 | print("Ground_truth, accuracy",acc_cal(X_truth, Y_truth)) 234 | 235 | X_STAMP,Y_STAMP = array_generation(all_results) 236 | X_STAMP = (X_STAMP-X_STAMP.mean(axis=0))/(X_STAMP.std(axis=0)) 237 | print("STAMP, accuracy",acc_cal(X_STAMP, Y_STAMP)) 238 | 239 | 240 | X_GEARs,Y_GEARs = array_generation(all_results_gears) 241 | X_GEARs = (X_GEARs-X_GEARs.mean(axis=0))/(X_GEARs.std(axis=0)) 242 | print("GEARs, accuracy",acc_cal(X_GEARs, Y_GEARs)) 243 | 244 | ##### Random test 245 | acc_SYNERGY = 0 246 | acc_SUPPRESSOR = 0 247 | acc_NEOMORPHIC = 0 248 | acc_REDUNDANT = 0 249 | acc_EPISTASIS = 0 250 | acc_Additive = 0 251 | for i in range(10): 252 | random_idx = list(range(len(Y_truth))) 253 | np.random.shuffle(random_idx) 254 | random_Y_truth = Y_truth[random_idx] 255 | acc_SYNERGY += (Y_truth[Y_truth==0]==(random_Y_truth[Y_truth==0])).mean() 256 | acc_SUPPRESSOR += (Y_truth[Y_truth==1]==(random_Y_truth[Y_truth==1])).mean() 257 | acc_NEOMORPHIC += (Y_truth[Y_truth==2]==(random_Y_truth[Y_truth==2])).mean() 258 | acc_REDUNDANT += (Y_truth[Y_truth==3]==(random_Y_truth[Y_truth==3])).mean() 259 | acc_EPISTASIS += (Y_truth[Y_truth==4]==(random_Y_truth[Y_truth==4])).mean() 260 | acc_Additive += (Y_truth[Y_truth==5]==(random_Y_truth[Y_truth==5])).mean() 261 | acc_SYNERGY /= i+1;acc_SUPPRESSOR /= i+1;acc_NEOMORPHIC/=i+1;acc_REDUNDANT/=i+1;acc_EPISTASIS/=i+1;acc_Additive/=i+1 262 | print("Random, accuracy",acc_SYNERGY,acc_SUPPRESSOR,acc_NEOMORPHIC,acc_REDUNDANT,acc_EPISTASIS,acc_Additive) 263 | 264 | -------------------------------------------------------------------------------- /Supp_code/GI_experiment/Gears_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/Gears_results.pkl -------------------------------------------------------------------------------- /Supp_code/GI_experiment/STAMP_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/STAMP_results.pkl -------------------------------------------------------------------------------- /Supp_code/GI_experiment/Science_data.csv: -------------------------------------------------------------------------------- 1 | name,UMI_double,UMI_first,UMI_second,emap,ts_coef_first,ts_coef_second,de_double,de_first,de_second,ts_norm2,abs_log_ts_ratio,ts_linear_dcor,ts_score,dcor,dcor_singles,dcor_ratio 2 | AHR_FEV,8636.844697,11797.92902,11839.87342,-5.027742939,1.324721494,0.897119607,3.410777233,3.069298012,3.110926242,1.599909568,0.169274234,0.848803361,0.291686949,0.802051082,0.423128473,0.949834231 3 | AHR_KLF1,15603.35437,11797.92902,16965.72671,8.108849223,0.287289094,0.961618784,2.556302501,3.069298012,2.626340367,1.003616315,0.524683799,0.77444383,0.643809829,0.564353295,0.228726987,0.325484782 4 | BPGM_SAMD1,11727.04583,13666.33842,12627.92697,0.133637871,0.881589627,0.562934492,3.030194785,2.45484486,2.346352974,1.045990206,0.194808612,0.826855941,0.824285302,0.776432808,0.428246225,0.996113432 5 | 
BPGM_ZBTB1,13104.41696,13666.33842,13298.50174,-0.165732221,0.572838581,0.648653854,2.737987326,2.45484486,2.652246341,0.865387694,0.053980743,0.828207669,0.851846555,0.786469111,0.378112583,0.846149338 6 | CBL_CNN1,10464.53472,13697.12082,13551.7975,-11.92419707,1.236388617,0.796307064,3.335457901,2.604226053,2.586587305,1.470633113,0.19107443,0.767189707,0.581215221,0.759021105,0.651102721,0.971246728 7 | CBL_PTPN12,11755.97665,13697.12082,13775.92073,-9.292766305,1.027412249,0.488572154,3.005180513,2.604226053,1.857332496,1.137663693,0.322816028,0.787531436,0.604664938,0.763691494,0.616919891,0.899828956 8 | CBL_PTPN9,11971.68803,13697.12082,14425.0804,-8.874028919,1.090314344,0.702850075,3.087781418,2.604226053,1.954242509,1.297221491,0.19068903,0.786436654,0.67898102,0.767071313,0.59465612,0.939267985 9 | CBL_TGFBR2,13510.5641,13697.12082,15409.96646,-7.848298921,0.987539067,0.318175444,2.021189299,2.604226053,0.84509804,1.037530251,0.491887627,0.694873968,0.550328636,0.673294068,0.165872108,0.340666729 10 | CBL_UBASH3A,11734.24,13697.12082,14870.69841,-12.5880076,1.219414513,0.39278773,1.397940009,2.604226053,0.698970004,1.281114341,0.491993446,0.565911049,0.319847012,0.55715236,0.498398153,0.646665863 11 | CBL_UBASH3B,10905.36196,13697.12082,14311.16764,-17.37367685,1.189134451,1.209160918,3.339252634,2.604226053,2.389166084,1.695910042,0.00725314,0.81169121,0.607281853,0.806653002,0.722777602,0.971507213 12 | CDKN1B_CDKN1A,16469.32653,15793.41045,16124.92366,1.298753859,0.811738956,0.552621113,2.626340367,2.691081492,2.589949601,0.981992988,0.166988915,0.858919528,0.857994563,0.852215814,0.779759604,0.986648095 13 | CDKN1C_CDKN1A,16215.125,16374.58182,16124.92366,1.043144218,0.619520525,0.546857729,2.480006943,2.632457292,2.589949601,0.826352865,0.054181345,0.846051153,0.778233899,0.844506245,0.807302633,0.985484102 14 | CDKN1C_CDKN1B,16601.72152,16374.58182,15320.97241,0.706050967,0.810262567,0.388682174,2.334453751,2.632457292,1.86923172,0.898665266,0.319031152,0.786032642,0.716122276,0.782493728,0.708938972,0.863180227 15 | CEBPB_CEBPA,7563.773333,11799,7722.624031,-1.742055369,0.197434739,0.902550261,3.262213705,3.046885191,3.461348434,0.923892553,0.660047827,0.899737549,0.950724936,0.897086826,0.77120447,0.816624294 16 | CEBPB_MAPK1,11963.90208,11799,13065.41995,0.043571624,0.443488455,0.718158591,3.103461622,3.046885191,2.790988475,0.844057919,0.209338042,0.888166659,0.868354231,0.850716029,0.533365168,0.97015814 17 | CEBPB_OSR2,11856.60638,11799,13988.30155,-1.254239492,0.818998884,0.644165689,2.972202838,3.046885191,2.541579244,1.041973419,0.104285722,0.878186153,0.822441522,0.847253341,0.461574835,0.812212918 18 | CEBPB_PTPN12,12005.34211,11799,13775.92073,5.129323636,0.549778509,0.379569796,2.890979597,3.046885191,1.857332496,0.668079067,0.160896113,0.797909156,0.672672467,0.756725077,0.228562525,0.577042676 19 | CEBPE_CEBPA,6785.266667,12149.16702,7722.624031,0.137461117,0.2941997,0.906974965,3.468051791,3.258397804,3.461348434,0.953497274,0.488953073,0.946981921,0.973038243,0.945129984,0.796302841,0.843371476 20 | CEBPE_CEBPB,8777.954955,12149.16702,11186.25275,-1.345570837,0.711181374,0.706802479,3.242789809,3.258397804,3.020775488,1.002670779,0.00268231,0.837563188,0.844607024,0.8383896,0.819268136,0.979346064 21 | CEBPE_CNN1,12539.13402,12149.16702,13551.7975,1.575725753,0.516728029,0.556403091,2.618048097,3.258397804,2.586587305,0.759336721,0.032127514,0.792517361,0.60565655,0.738101034,0.243477827,0.708260752 22 | 
CEBPE_KLF1,15994.29049,12149.16702,16965.72671,8.112668533,0.198732644,0.607721817,2.193124598,3.258397804,2.626340367,0.639390703,0.485435618,0.683678832,0.725060609,0.442425592,0.308756986,0.5598106 23 | CEBPE_PTPN12,12346.59937,12149.16702,13775.92073,4.805238489,0.486174479,0.427682646,2.84509804,3.258397804,1.857332496,0.64751685,0.055670529,0.794011277,0.666559708,0.737870109,0.208562737,0.668670689 24 | CEBPE_RUNX1T1,13164.67282,12149.16702,15136.08146,7.186707799,0.528024425,0.192067308,3.15715444,3.258397804,1,0.561871555,0.439200563,0.898611354,0.92359612,0.895592766,0.387950582,0.484840303 25 | CEBPE_SPI1,7469.125714,12149.16702,12152.25893,-0.647642772,1.296935668,0.438928258,3.434888121,3.258397804,3.271376872,1.369196896,0.470524893,0.871711071,0.824827936,0.854669,0.724255829,0.901574566 26 | CNN1_MAPK1,12634.65538,13734.63136,13065.41995,2.744813323,0.516854892,0.726869999,2.833784375,2.033423755,2.790988475,0.891896281,0.148088113,0.80036494,0.722336562,0.712052998,0.141684934,0.475171325 27 | CNN1_UBASH3A,10354.79167,13734.63136,14870.69841,-12.62252645,1.172934016,1.394486602,3.158060794,2.033423755,0.698970004,1.822187446,0.075140764,0.625393712,0.420910238,0.600296595,0.33481356,0.822338643 28 | DUSP9_ETS2,11765.28653,11420.06042,13646.32012,5.73928537,0.745848619,0.251165296,3.262688344,3.31868927,2.627365857,0.787003284,0.472691058,0.902021621,0.909803082,0.86340842,0.177495732,0.312562378 29 | DUSP9_IGDCC3,11007.87307,11420.06042,14162.75336,-2.855596189,0.872373698,0.558357698,3.245018871,3.31868927,2.480006943,1.035760198,0.193790055,0.892098107,0.887314206,0.853389245,0.408061986,0.730936563 30 | DUSP9_KLF1,12910.53438,11420.06042,16965.72671,-8.279382679,0.860066154,0.548829397,2.90579588,3.31868927,2.626340367,1.020258544,0.195094492,0.84462104,0.405028022,0.767668309,0.157677804,0.45161062 31 | DUSP9_MAPK1,13607.44939,11420.06042,13065.41995,10.97294961,0.390445231,0.22762557,2.037426498,3.31868927,2.790988475,0.451952297,0.234339077,0.719977164,0.632241426,0.641788904,0.160302921,0.610895593 32 | DUSP9_PRTG,11112.5,11420.06042,14082.03378,-0.274218251,0.741686659,0.768578988,3.142389466,3.31868927,2.539076099,1.068088367,0.01546804,0.880315779,0.888419445,0.840874145,0.431651033,0.863363575 33 | DUSP9_SNAI1,9051.970238,11420.06042,12942.9625,-5.106041145,0.920403475,1.210886763,3.297541668,3.31868927,2.305351369,1.520983007,0.119125282,0.814706915,0.67332464,0.744778431,0.276644636,0.885213216 34 | ETS2_CEBPE,12600.60625,13415.224,10803.07067,4.08893643,0.554168023,0.547054798,2.943988875,2.334453751,3.412124406,0.778698369,0.005610631,0.881355459,0.874423873,0.849920391,0.397292999,0.715649331 35 | ETS2_CNN1,13301.58599,13415.224,13551.7975,2.963442466,0.531611948,0.694234984,2.976808337,2.334453751,2.586587305,0.874398923,0.115911762,0.799242773,0.686112904,0.659071405,0.171328119,0.712042781 36 | ETS2_IGDCC3,13966.67123,13415.224,14162.75336,3.517946458,0.364204027,0.646645865,2.588831726,2.334453751,2.480006943,0.742155945,0.249321762,0.769855918,0.707286895,0.684466017,0.097831906,0.464417285 37 | ETS2_IKZF3,11333.86598,13415.224,10792.22488,3.213049525,0.433618386,0.783590974,3.315760491,2.334453751,3.197831693,0.895566703,0.256981736,0.91081426,0.958604252,0.895268569,0.285972009,0.454071094 38 | ETS2_MAP7D1,14029.75849,13415.224,16034.69553,-8.130058286,0.901764675,1.142253745,2.72427587,2.334453751,1.204119983,1.455308609,0.102669372,0.673798982,0.622982588,0.573413405,0.097997848,0.788321061 39 | 
ETS2_MAPK1,10452.27778,13415.224,13065.41995,-2.492788735,0.467608915,1.069031241,3.360025089,2.334453751,2.790988475,1.166827276,0.359107614,0.87005775,0.875464014,0.868083617,0.782637405,0.892759196 40 | ETS2_PRTG,13322.32486,13415.224,14082.03378,0.913348856,0.369143357,0.78031689,2.73239376,2.334453751,2.539076099,0.86322724,0.32507595,0.805695362,0.754610063,0.713911424,0.107060691,0.40320754 41 | FEV_CBFA2T3,7040.738562,12598.44509,13696.46914,-9.174221894,1.464067071,1.173544895,3.399327532,2.546542663,2.11058971,1.876352848,0.096061264,0.736163457,0.590764104,0.665060987,0.158151012,0.636881675 42 | FEV_ISL2,9070.136029,12598.44509,14545.00939,-6.76006398,1.074089588,0.938481999,3.363799945,2.546542663,0.903089987,1.426329873,0.05861456,0.760891753,0.549812377,0.732349249,0.202657432,0.532004074 43 | FEV_MAP7D1,11403.55285,12598.44509,16034.69553,-17.98743136,0.799132772,1.14711789,3.122543524,2.546542663,1.204119983,1.398031703,0.156989112,0.724961339,0.659692891,0.660944997,0.101899662,0.605546768 44 | FOSB_CEBPB,8741.094118,14237.66534,11186.25275,-0.838076988,0.642261329,1.181266678,3.167317335,2.357934847,3.020775488,1.344578216,0.26463618,0.826199929,0.826870444,0.826074518,0.684825928,0.801750062 45 | FOSB_CEBPE,9682.746479,14237.66534,10803.07067,-4.770079913,0.342593158,1.140074278,3.23350376,2.357934847,3.412124406,1.190436656,0.522154463,0.880114063,0.902458792,0.879180986,0.703304378,0.772134051 46 | FOSB_IKZF3,12068.60784,14237.66534,10792.22488,6.979696085,0.455352082,0.552411675,2.871572936,2.357934847,3.197831693,0.715893971,0.083915522,0.852862126,0.93357552,0.829234283,0.340114162,0.601480227 47 | FOSB_OSR2,12804.59877,14237.66534,13988.30155,-2.286219529,0.94285204,0.84158268,2.897077003,2.357934847,2.541579244,1.263816196,0.049346756,0.843472206,0.717954019,0.797591593,0.383793997,0.969500562 48 | FOSB_PTPN12,13597.76632,14237.66534,13775.92073,6.08485282,0.407665944,0.537729177,2.204119983,2.357934847,1.857332496,0.67479196,0.12025917,0.667084148,0.545325012,0.60012142,0.144863939,0.620875 49 | FOSB_UBASH3B,13442.14325,14237.66534,14311.16764,5.691658063,0.552437883,0.532223703,2.307496038,2.357934847,2.389166084,0.767104742,0.01618924,0.749868517,0.680892429,0.67896995,0.206272503,0.806498428 50 | FOXA1_FOXF1,11799.84388,13616.223,14341.4794,-3.404568293,1.096066242,0.689994223,2.959518377,2.320146286,1.204119983,1.295165331,0.200991348,0.736810827,0.627424694,0.720773123,0.525740726,0.821520243 51 | FOXA1_FOXL2,10116.68715,13616.223,11533.13278,-1.762576429,0.555985598,1.002491852,3.105169428,2.320146286,2.79518459,1.146346326,0.256017309,0.786560694,0.687887234,0.777181775,0.579837524,0.810366005 52 | FOXA1_HOXB9,9851.755639,13616.223,11882.16406,-3.677702592,0.835984464,1.084173293,3.322012439,2.320146286,2.666517981,1.369051406,0.112900499,0.812205276,0.703061854,0.796701741,0.539119633,0.82919151 53 | FOXA3_FOXA1,11293.43094,12535.74328,13077.36418,-0.895729844,1.227342903,0.507016092,3.027349608,2.980457892,2.301029996,1.327944245,0.383944172,0.865852363,0.828298989,0.858789426,0.725421573,0.861392296 54 | FOXA3_FOXF1,11358.5274,12535.74328,14341.4794,-1.690503757,1.096812933,0.279254637,2.758911892,2.980457892,1.204119983,1.131804649,0.59413217,0.787445666,0.671841143,0.781515531,0.557535841,0.689736249 55 | FOXA3_FOXL2,8891.40708,12535.74328,11533.13278,-0.947739512,1.093592622,0.908210321,3.181271772,2.980457892,2.79518459,1.421545219,0.080669139,0.786489903,0.668007277,0.780517965,0.694711141,0.98975022 56 | 
FOXA3_HOXB9,9656.863787,12535.74328,11882.16406,-2.790099564,1.137353739,0.797940651,3.389520466,2.980457892,2.666517981,1.389346181,0.153924969,0.843231765,0.718681802,0.833232915,0.680176106,0.960565661 57 | FOXF1_FOXL2,9703.034682,14135.87277,11533.13278,-2.28736857,0.786630895,1.233076742,3.224533063,1.949390007,2.79518459,1.462623129,0.195219107,0.782206312,0.682867138,0.778978448,0.602816606,0.811885278 58 | FOXF1_HOXB9,10817.24855,14135.87277,11882.16406,-2.890810768,0.876665612,0.930986539,2.996073654,1.949390007,2.666517981,1.278780095,0.02610943,0.791389107,0.690310743,0.776576923,0.508089544,0.812075189 59 | FOXL2NB_FOXL2,11012.59184,11064.30415,11533.13278,1.779331866,0.716285078,0.361353683,1.414973348,2.931966115,2.79518459,0.802272271,0.297153419,0.647031936,0.515983634,0.647738214,0.782128897,0.925137448 60 | FOXL2_HOXB9,8482.770115,11907.10891,11882.16406,-1.329540145,1.004384228,0.898307234,3.04060234,2.902002891,2.666517981,1.347495293,0.048474987,0.808048675,0.71022185,0.800533292,0.679832039,0.998586688 61 | FOXL2_MEIS1,10409.43503,11907.10891,13716.44872,-2.39977222,0.921020134,0.766142188,3.015778756,2.902002891,1.707570176,1.198020008,0.079959747,0.79626318,0.699111897,0.777135414,0.487731079,0.791939297 62 | IGDCC3_MAPK1,14003.10239,13874.75352,13065.41995,4.300250641,0.455597543,0.438698989,2.431363764,2.629409599,2.790988475,0.632476027,0.01641474,0.729522831,0.609531652,0.651376959,0.163937861,0.926295594 63 | IGDCC3_PRTG,12710.94444,13874.75352,14082.03378,1.608980686,0.586107352,0.634582591,2.346352974,2.629409599,2.539076099,0.863838465,0.034510984,0.782351469,0.807184237,0.781076535,0.772179546,0.991147957 64 | IGDCC3_ZBTB25,12853.12121,13874.75352,14683.2996,-6.576366626,1.110245165,0.783264841,2.733999287,2.629409599,2.136720567,1.358730341,0.151510258,0.806253481,0.761795215,0.787228576,0.524281876,0.833457845 65 | IRF1_SET,12962.80926,15671.99698,13421.56475,-6.988736995,0.697645225,0.795742337,3.195622944,2.764176132,2.908485019,1.058260236,0.057137839,0.880045613,0.838025145,0.825844081,0.325779767,0.881354135 66 | JUN_CEBPA,6258.328571,11627.30213,7722.624031,-0.911147039,0.362276276,0.986779584,3.314288661,2.86332286,3.461348434,1.05117936,0.435180261,0.906509709,0.944949647,0.90102418,0.437827672,0.527494198 67 | JUN_CEBPB,8086.211538,11627.30213,11186.25275,-1.561250115,0.663613429,0.896706075,2.86569606,2.86332286,3.020775488,1.115555722,0.130734946,0.775955397,0.71301994,0.760118685,0.511400601,0.806377417 68 | KLF1_CEBPA,10018.12987,16941.35607,7722.624031,10.07008471,0.191513545,0.723248907,3.411282913,2.833784375,3.461348434,0.748175393,0.57708829,0.91467045,0.946912391,0.909806086,0.156143605,0.132127382 69 | KLF1_CLDN6,13293.6186,16941.35607,12557.41065,-7.140357162,0.731443935,0.753947434,2.788875116,2.833784375,3.005180513,1.050450838,0.013160024,0.749898908,0.755937195,0.652890816,0.220085605,0.51273542 70 | KLF1_COL2A1,15956.68526,16941.35607,14324.93392,-7.527070526,0.625799336,0.803624774,2.763427994,2.833784375,2.669316881,1.018546801,0.108618218,0.769249034,0.623724153,0.703331008,0.095874437,0.464221231 71 | KLF1_FOXA1,14289.12074,16941.35607,13077.36418,-4.660041825,0.733820937,0.82342367,2.5774918,2.833784375,2.301029996,1.102959523,0.050033248,0.702063063,0.656328187,0.56665259,0.17001451,0.828928502 72 | KLF1_MAP2K6,18223.34234,16941.35607,16847.97619,-1.245393582,0.774651018,0.432913551,2.26245109,2.833784375,2.113943352,0.887411033,0.252704915,0.810674387,0.800950346,0.770662068,0.409285776,0.746892754 73 | 
KLF1_TGFBR2,16120.23194,16941.35607,15409.96646,-8.882663739,0.746960196,0.30598057,2.181843588,2.833784375,0.84509804,0.807201117,0.387603611,0.701582403,0.40229939,0.678914207,0.095663258,0.293754476 74 | LHX1_ELMSAN1,12547.65814,12104.16667,16978.7907,-6.676777988,0.841852226,0.441000502,2.988558957,2.705007959,1.851258349,0.950366568,0.28079678,0.74837991,0.638229829,0.666713304,0.328110716,0.154103481 75 | LYL1_CEBPB,12723.45912,13887.6,11186.25275,6.734497397,0.468963309,0.499630597,2.426511261,1.633468456,3.020775488,0.685242526,0.027510162,0.749679529,0.569130773,0.729852975,0.29445261,0.558705996 76 | LYL1_IER5L,10272.14991,13887.6,11958.24103,-5.971554097,0.301680001,0.985819281,3.421274791,1.633468456,2.681241237,1.030946399,0.514250787,0.813960909,0.771745767,0.809115129,0.489713229,0.609163059 77 | MAP2K3_ELMSAN1,20565.01446,19002.51092,16978.7907,-4.616741496,0.907377257,0.769829495,3.177824972,2.795880017,1.851258349,1.189945771,0.071393343,0.853404464,0.724937032,0.837294085,0.56677764,0.861671302 78 | MAP2K3_IKZF3,14750.99091,19002.51092,10792.22488,9.537746863,0.248636517,0.61311005,2.617000341,2.795880017,3.197831693,0.661607173,0.391973521,0.791968058,0.360803035,0.763459396,0.418720863,0.257814491 79 | MAP2K3_MAP2K6,19073.95633,19002.51092,16847.97619,-5.182336519,0.55500612,0.432476343,2.790988475,2.795880017,2.113943352,0.703610389,0.108335416,0.876210209,0.83625186,0.871670331,0.73734692,0.942076989 80 | MAP2K3_SLC38A2,18358.99327,19002.51092,15759.45357,-1.534254356,0.768911252,0.437098605,2.599883072,2.795880017,1.230448921,0.884465774,0.245296796,0.82213848,0.757333072,0.796772502,0.356687704,0.62343575 81 | MAP2K6_ELMSAN1,18832.58904,18377.71625,16978.7907,-3.780749567,0.856571607,0.55083486,2.754348336,2.324282455,1.851258349,1.018397742,0.191742257,0.821666947,0.748537531,0.799020209,0.500132649,0.828879281 82 | MAP2K6_IKZF3,13342.98805,18377.71625,10792.22488,8.961105885,0.251362864,0.665459856,2.84135947,2.324282455,3.197831693,0.711350905,0.422820746,0.846312187,0.704434855,0.828969486,0.352050343,0.240160044 83 | MAP2K6_SPI1,15388.49174,18377.71625,12152.25893,7.691201953,0.196648506,0.459466711,2.290034611,2.324282455,3.271376872,0.499780245,0.3685634,0.711306252,0.728633493,0.687616749,0.413254143,0.285389156 84 | MAPK1_IKZF3,11846.40152,13076.2509,10792.22488,3.415792474,0.415440051,0.600631396,3.105510185,2.59439255,3.197831693,0.730307134,0.160099667,0.861870425,0.927883881,0.831736857,0.341522917,0.643318459 85 | MAPK1_PRTG,14127.37368,13076.2509,14082.03378,1.350838838,0.410133394,0.562854605,2.664641976,2.59439255,2.539076099,0.696429973,0.137471091,0.742901276,0.589695206,0.646810737,0.131344164,0.884243716 86 | MAPK1_TGFBR2,13754.95887,13076.2509,15409.96646,3.95567623,0.746940705,0.136640751,2.636487896,2.59439255,0.84509804,0.759335968,0.737705889,0.806045323,0.801448703,0.796959163,0.233062177,0.206835257 87 | PLK4_STIL,12173.63636,14449.81443,12756.88652,-2.655264906,0.419465981,0.752439348,1.556302501,0,0.903089987,0.861461944,0.253774753,0.524852422,0.324729848,0.499191032,0.40318767,0.776153436 88 | POU3F2_CBFA2T3,10481.31188,12106.46154,13696.46914,-6.602857974,1.171567589,0.621729535,3.0923697,2.58546073,2.11058971,1.326317545,0.27516585,0.793951694,0.706853576,0.750946609,0.35377785,0.740505724 89 | POU3F2_FOXL2,8979.702797,12106.46154,11533.13278,-3.005394022,0.92270723,1.020388228,3.401400541,2.58546073,2.79518459,1.375711004,0.043701516,0.8408346,0.771167876,0.82431596,0.583569552,0.919215676 90 | 
PRDM1_CBFA2T3,7352.633028,10256.42857,13696.46914,-2.535841296,1.224280108,0.499653669,3.315340477,3.245759356,2.11058971,1.322314475,0.389211712,0.815008139,0.682092841,0.801459347,0.36569377,0.549864483 91 | PTPN12_OSR2,14426.18375,14014.87629,13988.30155,7.367622216,0.288280185,0.332547939,1.633468456,1.826074803,2.541579244,0.440106347,0.062039468,0.575267573,0.497319864,0.484016602,0.135709611,0.574755643 92 | PTPN12_PTPN9,11696.46102,14014.87629,14425.0804,-9.005757018,0.7383842,0.868063435,3.157456768,1.826074803,1.954242509,1.139625094,0.070269068,0.776336812,0.655984144,0.752076023,0.4861733,0.962740284 93 | PTPN12_SNAI1,12319.45902,14014.87629,12942.9625,10.07555502,0.600056214,0.573352956,2.701567985,1.826074803,2.305351369,0.829940403,0.019769882,0.757667528,0.656193036,0.699233463,0.226777577,0.782388504 94 | PTPN12_UBASH3A,11319.51765,14014.87629,14870.69841,-11.66040496,1.07642032,1.018797608,3.157758886,1.826074803,0.698970004,1.482102989,0.023893971,0.707980549,0.478099864,0.684474996,0.351565613,0.728704695 95 | PTPN12_ZBTB25,11876.8677,14014.87629,14683.2996,-8.056528667,0.751759062,0.983985647,3.067070856,1.826074803,2.136720567,1.238292954,0.116910092,0.794687457,0.6872563,0.768537561,0.494841647,0.942712472 96 | RHOXF2B_SET,10323.2397,10930.4069,13421.56475,7.164357199,0.705755666,0.609668871,3.300378065,3.207634367,2.908485019,0.932623822,0.063560353,0.872215805,0.826291702,0.840901018,0.490519694,0.953460818 97 | RHOXF2B_ZBTB25,11005.28293,10930.4069,14683.2996,6.847164341,0.763576583,0.802227653,3.127428778,3.207634367,2.136720567,1.107528061,0.021445027,0.851454028,0.716207176,0.819650761,0.361491225,0.682564476 98 | SAMD1_PTPN12,10546.71575,11822.51587,13775.92073,-6.398557523,0.909073745,0.551394305,3.285782274,2.913283902,1.857332496,1.063226577,0.217136839,0.828934693,0.704138241,0.798401282,0.431435321,0.759486087 99 | SAMD1_TGFBR2,11972.27119,11822.51587,15409.96646,-8.107851289,0.869725389,0.43445474,1.322219295,2.913283902,0.84509804,0.972200171,0.301437608,0.640896979,0.599548983,0.632411771,0.080093324,0.228092849 100 | SAMD1_UBASH3B,10741.57416,11822.51587,14311.16764,-7.528617025,0.893324788,0.917398582,3.211921084,2.913283902,2.389166084,1.280487929,0.011548679,0.84098202,0.778241998,0.825687121,0.545809017,0.814472575 101 | SAMD1_ZBTB1,10795.85366,11822.51587,13298.50174,0.372634559,0.749627701,0.827827763,3.046495164,2.913283902,2.652246341,1.116799219,0.043094361,0.849807895,0.796000237,0.816243375,0.473419466,0.959725238 102 | SET_CEBPE,13110.97543,14035.45117,10803.07067,3.821361642,0.624992472,0.203585476,3.104145551,3.083860801,3.412124406,0.657314716,0.487127994,0.855534698,0.784660789,0.773792775,0.396066676,0.735155459 103 | SET_KLF1,16486.39926,14035.45117,16965.72671,-4.790630364,0.666004631,0.724769436,2.924795996,3.083860801,2.626340367,0.984303258,0.036722621,0.836322527,0.697481086,0.686002014,0.236300229,0.580657286 104 | SGK1_TBX2,12025.19512,12336.787,13914.72881,1.310275508,0.804936225,0.731411338,3.14082218,3.10720997,2.434568904,1.087605108,0.041599784,0.868834244,0.865735206,0.842236927,0.51071007,0.959354312 105 | SGK1_TBX3,12739.70782,12336.787,14505,-0.389360046,0.784773698,0.647871406,2.894316063,3.10720997,2.133538908,1.017647835,0.083255627,0.890050364,0.736125208,0.923203425,0.370183246,0.682986951 106 | SNAI1_DLX2,9787,13686.24545,12835.45172,-4.38912823,1.14532526,0.882836892,3.098643726,2.285557309,2.591064607,1.446088147,0.113048366,0.744400927,0.626867608,0.721699525,0.499316035,0.873888934 107 | 
SNAI1_UBASH3B,12653.79297,13686.24545,14311.16764,6.9698056,0.736916087,0.891443369,2.855519156,2.285557309,2.389166084,1.156596991,0.082675721,0.818085084,0.767350742,0.773401166,0.351883827,0.822914157 108 | TBX3_TBX2,13807.30031,14376.44637,13914.72881,-0.952148354,0.735713499,0.64424836,3.217220656,2.117271296,2.434568904,0.977921419,0.057655402,0.900807612,0.892832212,0.88290534,0.35702695,0.86670143 109 | TGFBR2_CBARP,12145.62409,16065.11483,14362.77407,-7.115548769,0.090276718,0.850719808,2.742725131,0.301029996,1.716003344,0.855496393,0.974210783,0.537434279,0.398571272,0.503398657,0.10339338,0.209380757 110 | TGFBR2_ETS2,14768.47636,16065.11483,13646.32012,5.818769871,0.116972048,0.567208074,1.755874856,0.301029996,2.627365857,0.579143729,0.685660311,0.633545702,0.582562968,0.596019738,0.238684094,0.152343388 111 | TGFBR2_IGDCC3,13191.4542,16065.11483,14162.75336,0.485365858,0.122216034,0.760712652,2.549003262,0.301029996,2.480006943,0.770467714,0.794092454,0.756707278,0.724232417,0.742370505,0.172887454,0.262692258 112 | TGFBR2_PRTG,12900.80769,16065.11483,14082.03378,1.630267601,0.051310753,0.923956999,2.544068044,0.301029996,2.539076099,0.925380641,1.255443368,0.784791742,0.823568472,0.768317952,0.203975856,0.248194348 113 | UBASH3B_CNN1,10725.07324,14063.23404,13551.7975,-16.58980658,1.212246283,0.962860571,3.3586961,2.378397901,2.586587305,1.548109017,0.100027458,0.808383898,0.682219701,0.798315986,0.621193342,0.96783611 114 | UBASH3B_OSR2,14884.37075,14063.23404,13988.30155,6.65485367,-0.01838437,0.367629526,2.11058971,2.378397901,2.541579244,0.368088921,1.300961639,0.682917797,0.622222396,0.644576921,0.260723913,0.315819702 115 | UBASH3B_PTPN12,11332.46154,14063.23404,13775.92073,-10.76779039,1.021372069,0.595685121,3.122215878,2.378397901,1.857332496,1.18238812,0.234167225,0.764947361,0.591454917,0.741417699,0.587460103,0.948035612 116 | UBASH3B_PTPN9,12555.31915,14063.23404,14425.0804,-7.66800543,1.289197478,0.653431269,3.098643726,2.378397901,1.954242509,1.445338217,0.295119534,0.812567054,0.711899225,0.793137969,0.664629088,0.936463678 117 | UBASH3B_UBASH3A,14188.25352,14063.23404,14870.69841,2.623744933,0.913865214,0.410892151,2.683947131,2.378397901,0.698970004,1.001989017,0.347154301,0.78296664,0.726350132,0.760406899,0.477881348,0.709040053 118 | UBASH3B_ZBTB25,12408.12946,14063.23404,14683.2996,-8.06503663,1.02071028,0.687041565,2.895974732,2.378397901,2.136720567,1.230396517,0.171919477,0.790346045,0.59809173,0.780251964,0.663206049,0.981034076 119 | ZBTB10_DLX2,8448.364865,10531.65517,12835.45172,-3.74965872,1.066279561,0.846798705,3.13481437,2.988558957,2.591064607,1.361624085,0.1000909,0.792719301,0.683952394,0.781513116,0.400261218,0.617034869 120 | ZBTB10_ELMSAN1,10288.49068,10531.65517,16978.7907,-7.543480386,1.009903688,0.172350471,3.121231455,2.988558957,1.851258349,1.024504828,0.767867484,0.849333398,0.809231506,0.843521274,0.39078376,0.356840993 121 | ZBTB10_PTPN12,10124.58261,10531.65517,13775.92073,7.752122403,0.728830991,0.455160448,3.254064453,2.988558957,1.857332496,0.859282053,0.204462315,0.844756886,0.750872545,0.82377239,0.237415929,0.452185145 122 | ZBTB10_SNAI1,8498.470588,10531.65517,12942.9625,-1.716184098,0.935797128,0.63951727,3.14113609,2.988558957,2.305351369,1.133445457,0.165329431,0.828663341,0.681733092,0.817607998,0.549498791,0.790953551 123 | 
ZC3HAV1_CEBPA,7989.677419,16377.05505,7722.624031,-9.826651018,0.318894432,0.891369331,3.36078269,2.045322979,3.461348434,0.946695803,0.446410751,0.912331818,0.940605339,0.909385192,0.236326219,0.194054803 124 | ZC3HAV1_CEBPE,15794.74566,16377.05505,10803.07067,8.288154134,0.63493671,0.215502289,2.149219113,2.045322979,3.412124406,0.670511643,0.46927855,0.690145499,0.720944273,0.486556144,0.21748292,0.968025272 125 | ZC3HAV1_HOXC13,12174.1418,16377.05505,10738.57813,6.368610928,0.53804059,0.604288531,3.135450699,2.045322979,3.015778756,0.809105868,0.050429311,0.791498403,0.68866069,0.725703468,0.251890652,0.193133306 126 | ZNF318_FOXL2,12446.00508,15384.52865,11533.13278,-3.079212783,0.816195732,0.732563732,2.609594409,1.763427994,2.79518459,1.09673383,0.046948906,0.762681444,0.670256568,0.723900023,0.217552433,0.577758247 127 | -------------------------------------------------------------------------------- /Supp_code/GI_experiment/Truth_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/Truth_results.pkl -------------------------------------------------------------------------------- /Supp_code/Modified_CPA/_data.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from scvi import settings 4 | from scvi.data import AnnDataManager 5 | from scvi.dataloaders import DataSplitter, AnnDataLoader 6 | from scvi.model._utils import parse_use_gpu_arg 7 | 8 | 9 | class AnnDataSplitter(DataSplitter): 10 | def __init__( 11 | self, 12 | adata_manager: AnnDataManager, 13 | train_indices, 14 | valid_indices, 15 | test_indices, 16 | use_gpu: bool = False, 17 | **kwargs, 18 | ): 19 | super().__init__(adata_manager) 20 | self.data_loader_kwargs = kwargs 21 | self.use_gpu = use_gpu 22 | self.train_idx = train_indices 23 | self.val_idx = valid_indices 24 | self.test_idx = test_indices 25 | 26 | def setup(self, stage: Optional[str] = None): 27 | accelerator, _, self.device = parse_use_gpu_arg( 28 | self.use_gpu, return_device=True 29 | ) 30 | self.pin_memory = ( 31 | True 32 | if (settings.dl_pin_memory_gpu_training and accelerator == "gpu") 33 | else False 34 | ) 35 | 36 | def train_dataloader(self): 37 | if len(self.train_idx) > 0: 38 | return AnnDataLoader( 39 | self.adata_manager, 40 | indices=self.train_idx, 41 | shuffle=True, 42 | pin_memory=self.pin_memory, 43 | **self.data_loader_kwargs, 44 | ) 45 | else: 46 | pass 47 | 48 | def val_dataloader(self): 49 | if len(self.val_idx) > 0: 50 | data_loader_kwargs = self.data_loader_kwargs.copy() 51 | # if len(self.valid_indices < 4096): 52 | # data_loader_kwargs.update({'batch_size': len(self.valid_indices)}) 53 | # else: 54 | # data_loader_kwargs.update({'batch_size': 2048}) 55 | return AnnDataLoader( 56 | self.adata_manager, 57 | indices=self.val_idx, 58 | shuffle=True, 59 | pin_memory=self.pin_memory, 60 | **data_loader_kwargs, 61 | ) 62 | else: 63 | pass 64 | 65 | def test_dataloader(self): 66 | if len(self.test_idx) > 0: 67 | return AnnDataLoader( 68 | self.adata_manager, 69 | indices=self.test_idx, 70 | shuffle=True, 71 | pin_memory=self.pin_memory, 72 | **self.data_loader_kwargs, 73 | ) 74 | else: 75 | pass 76 | -------------------------------------------------------------------------------- /Supp_code/Modified_CPA/_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as 
np 2 | from scipy.stats import entropy 3 | from sklearn.neighbors import NearestNeighbors 4 | from sklearn.preprocessing import LabelEncoder 5 | 6 | 7 | def knn_purity(data, labels: np.ndarray, n_neighbors=30): 8 | """Computes KNN Purity for ``data`` given the labels. 9 | Parameters 10 | ---------- 11 | data 12 | Numpy ndarray of data 13 | labels 14 | Numpy ndarray of labels 15 | n_neighbors: int 16 | Number of nearest neighbors. 17 | Returns 18 | ------- 19 | score: float 20 | KNN purity score. A float between 0 and 1. 21 | """ 22 | labels = LabelEncoder().fit_transform(labels.ravel()) 23 | 24 | nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data) 25 | indices = nbrs.kneighbors(data, return_distance=False)[:, 1:] 26 | neighbors_labels = np.vectorize(lambda i: labels[i])(indices) 27 | 28 | # per-cell purity scores 29 | scores = ((neighbors_labels - labels.reshape(-1, 1)) == 0).mean(axis=1) 30 | res = [ 31 | np.mean(scores[labels == i]) for i in np.unique(labels) 32 | ] # per cell-type purity 33 | 34 | return np.mean(res) 35 | 36 | 37 | def entropy_batch_mixing(data, labels, 38 | n_neighbors=50, n_pools=50, n_samples_per_pool=100): 39 | """Computes the Entropy of Batch Mixing (EBM) metric for ``data`` given the batch labels. 40 | Parameters 41 | ---------- 42 | data 43 | Numpy ndarray of data 44 | labels 45 | Numpy ndarray of labels 46 | n_neighbors: int 47 | Number of nearest neighbors. 48 | n_pools: int 49 | Number of EBM computations which will be averaged. 50 | n_samples_per_pool: int 51 | Number of samples to be used in each pool of execution. 52 | Returns 53 | ------- 54 | score: float 55 | EBM score. A float between zero and one. 56 | """ 57 | 58 | def __entropy_from_indices(indices, n_cat): 59 | return entropy(np.array(np.unique(indices, return_counts=True)[1].astype(np.int32)), base=n_cat) 60 | 61 | n_cat = len(np.unique(labels)) 62 | # print(f'Calculating EBM with n_cat = {n_cat}') 63 | 64 | neighbors = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data) 65 | indices = neighbors.kneighbors(data, return_distance=False)[:, 1:] 66 | batch_indices = np.vectorize(lambda i: labels[i])(indices) 67 | 68 | entropies = np.apply_along_axis(__entropy_from_indices, axis=1, arr=batch_indices, n_cat=n_cat) 69 | 70 | # average n_pools entropy results where each result is an average of n_samples_per_pool random samples.
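# A short note on the pooling implemented below (this restates the code that
# follows, it adds no behavior): when n_pools == 1 the score is simply the mean
# neighborhood entropy over all cells; otherwise each pool draws
# n_samples_per_pool cell indices at random (with replacement, the
# np.random.choice default) and the final EBM score is the mean of the n_pools
# pool averages, a bootstrap-style estimate of the same per-cell average.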
71 | if n_pools == 1: 72 | score = np.mean(entropies) 73 | else: 74 | score = np.mean([ 75 | np.mean(entropies[np.random.choice(len(entropies), size=n_samples_per_pool)]) 76 | for _ in range(n_pools) 77 | ]) 78 | 79 | return score 80 | -------------------------------------------------------------------------------- /Supp_code/Modified_CPA/_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scvi import settings 5 | from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial 6 | from scvi.module import Classifier 7 | from scvi.module.base import BaseModuleClass, auto_move_data 8 | from scvi.nn import Encoder, DecoderSCVI 9 | from torch.distributions import Normal 10 | from torch.distributions.kl import kl_divergence as kl 11 | from torchmetrics.functional import accuracy, pearson_corrcoef, r2_score 12 | 13 | from ._metrics import knn_purity 14 | from ._utils import PerturbationNetwork, VanillaEncoder, CPA_REGISTRY_KEYS 15 | 16 | from typing import Optional 17 | 18 | 19 | class CPAModule(BaseModuleClass): 20 | """ 21 | CPA module using Gaussian/NegativeBinomial/Zero-InflatedNegativeBinomial Likelihood 22 | 23 | Parameters 24 | ---------- 25 | n_genes: int 26 | Number of input genes 27 | n_perts: int, 28 | Number of total unique perturbations 29 | covars_encoder: dict 30 | Dictionary of covariates with keys as each covariate name and values as 31 | unique values of the corresponding covariate 32 | n_latent: int 33 | dimensionality of the latent space 34 | recon_loss: str 35 | Autoencoder loss (either "gauss", "nb" or "zinb") 36 | doser_type: str 37 | Type of dosage network (either "logsigm", "sigm", or "linear") 38 | n_hidden_encoder: int 39 | Number of hidden units in encoder 40 | n_layers_encoder: int 41 | Number of layers in encoder 42 | n_hidden_decoder: int 43 | Number of hidden units in decoder 44 | n_layers_decoder: int 45 | Number of layers in decoder 46 | n_hidden_doser: int 47 | Number of hidden units in dosage network 48 | n_layers_doser: int 49 | Number of layers in dosage network 50 | use_batch_norm_encoder: bool 51 | Whether to use batch norm in encoder 52 | use_layer_norm_encoder: bool 53 | Whether to use layer norm in encoder 54 | use_batch_norm_decoder: bool 55 | Whether to use batch norm in decoder 56 | use_layer_norm_decoder: bool 57 | Whether to use layer norm in decoder 58 | dropout_rate_encoder: float 59 | Dropout rate in encoder 60 | dropout_rate_decoder: float 61 | Dropout rate in decoder 62 | variational: bool 63 | Whether to use variational inference 64 | seed: int 65 | Random seed 66 | """ 67 | 68 | def __init__(self, 69 | n_genes: int, 70 | n_perts: int, 71 | covars_encoder: dict, 72 | drug_embeddings: Optional[np.ndarray] = None, 73 | n_latent: int = 128, 74 | recon_loss: str = "nb", 75 | doser_type: str = "logsigm", 76 | n_hidden_encoder: int = 256, 77 | n_layers_encoder: int = 3, 78 | n_hidden_decoder: int = 256, 79 | n_layers_decoder: int = 3, 80 | n_hidden_doser: int = 128, 81 | n_layers_doser: int = 2, 82 | use_batch_norm_encoder: bool = True, 83 | use_layer_norm_encoder: bool = False, 84 | use_batch_norm_decoder: bool = True, 85 | use_layer_norm_decoder: bool = False, 86 | dropout_rate_encoder: float = 0.0, 87 | dropout_rate_decoder: float = 0.0, 88 | variational: bool = False, 89 | seed: int = 0, 90 | ): 91 | super().__init__() 92 | 93 | recon_loss = recon_loss.lower() 94 | assert recon_loss in ['gauss', 'zinb', 'nb'] 95 | 96 
| torch.manual_seed(seed) 97 | np.random.seed(seed) 98 | settings.seed = seed 99 | 100 | self.n_genes = n_genes 101 | self.n_perts = n_perts 102 | self.n_latent = n_latent 103 | self.recon_loss = recon_loss 104 | self.doser_type = doser_type 105 | self.variational = variational 106 | 107 | self.covars_encoder = covars_encoder 108 | 109 | if variational: 110 | self.encoder = Encoder( 111 | n_genes, 112 | n_latent, 113 | var_activation=nn.Softplus(), 114 | n_hidden=n_hidden_encoder, 115 | n_layers=n_layers_encoder, 116 | use_batch_norm=use_batch_norm_encoder, 117 | use_layer_norm=use_layer_norm_encoder, 118 | dropout_rate=dropout_rate_encoder, 119 | activation_fn=nn.ReLU, 120 | return_dist=True, 121 | ) 122 | else: 123 | self.encoder = VanillaEncoder( 124 | n_input=n_genes, 125 | n_output=n_latent, 126 | n_cat_list=[], 127 | n_hidden=n_hidden_encoder, 128 | n_layers=n_layers_encoder, 129 | use_batch_norm=use_batch_norm_encoder, 130 | use_layer_norm=use_layer_norm_encoder, 131 | dropout_rate=dropout_rate_encoder, 132 | activation_fn=nn.ReLU, 133 | output_activation='linear', 134 | ) 135 | 136 | # Decoder components 137 | if self.recon_loss in ['zinb', 'nb']: 138 | # setup the parameters of your generative model, as well as your inference model 139 | self.px_r = torch.nn.Parameter(torch.randn(self.n_genes)) 140 | 141 | # decoder goes from n_latent-dimensional space to n_input-d data 142 | self.decoder = DecoderSCVI( 143 | n_input=n_latent, 144 | n_output=n_genes, 145 | n_layers=n_layers_decoder, 146 | n_hidden=n_hidden_decoder, 147 | use_batch_norm=use_batch_norm_decoder, 148 | use_layer_norm=use_layer_norm_decoder, 149 | ) 150 | 151 | elif recon_loss == "gauss": 152 | self.decoder = Encoder(n_input=n_latent, 153 | n_output=n_genes, 154 | n_layers=n_layers_decoder, 155 | n_hidden=n_hidden_decoder, 156 | dropout_rate=dropout_rate_decoder, 157 | use_batch_norm=use_batch_norm_decoder, 158 | use_layer_norm=use_layer_norm_decoder, 159 | var_activation=None, 160 | ) 161 | 162 | else: 163 | raise Exception('Invalid Loss function for Autoencoder') 164 | 165 | # Embeddings 166 | # 1. Drug Network 167 | self.pert_network = PerturbationNetwork(n_perts=n_perts, 168 | n_latent=n_latent, 169 | doser_type=doser_type, 170 | n_hidden=n_hidden_doser, 171 | n_layers=n_layers_doser, 172 | drug_embeddings=drug_embeddings, 173 | ) 174 | 175 | # 2. Covariates Embedding 176 | self.covars_embeddings = nn.ModuleDict( 177 | { 178 | key: torch.nn.Embedding(len(unique_covars), n_latent) 179 | for key, unique_covars in self.covars_encoder.items() 180 | } 181 | ) 182 | 183 | self.metrics = { 184 | 'pearson_r': pearson_corrcoef, 185 | 'r2_score': r2_score 186 | } 187 | 188 | def mixup_data(self, tensors, alpha: float = 0.0, opt=False): 189 | """ 190 | Returns mixed inputs, pairs of targets, and lambda 191 | """ 192 | alpha = max(0.0, alpha) 193 | 194 | if alpha == 0.0: 195 | mixup_lambda = 1.0 196 | else: 197 | mixup_lambda = np.random.beta(alpha, alpha) 198 | 199 | x = tensors[CPA_REGISTRY_KEYS.X_KEY] 200 | y_perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY] 201 | perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS] 202 | perturbations_dosages = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES] 203 | 204 | batch_size = x.size()[0] 205 | index = torch.randperm(batch_size).to(x.device) 206 | 207 | mixed_x = mixup_lambda * x + (1. 
- mixup_lambda) * x[index, :] 208 | 209 | tensors[CPA_REGISTRY_KEYS.X_KEY] = mixed_x 210 | tensors[CPA_REGISTRY_KEYS.X_KEY + '_true'] = x 211 | tensors[CPA_REGISTRY_KEYS.X_KEY + '_mixup'] = x[index] 212 | tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY + '_mixup'] = y_perturbations[index] 213 | tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup'] = perturbations[index] 214 | tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'] = perturbations_dosages[index] 215 | 216 | for covar, encoder in self.covars_encoder.items(): 217 | tensors[covar + '_mixup'] = tensors[covar][index] 218 | 219 | return tensors, mixup_lambda 220 | 221 | def _get_inference_input(self, tensors): 222 | x = tensors[CPA_REGISTRY_KEYS.X_KEY] # batch_size, n_genes 223 | perts = { 224 | 'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS], 225 | 'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup'] 226 | } 227 | perts_doses = { 228 | 'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES], 229 | 'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'], 230 | } 231 | 232 | covars_dict = dict() 233 | for covar, unique_covars in self.covars_encoder.items(): 234 | encoded_covars = tensors[covar].view(-1, ) # (batch_size,) 235 | encoded_covars_mixup = tensors[covar + '_mixup'].view(-1, ) # (batch_size,) 236 | covars_dict[covar] = encoded_covars 237 | covars_dict[covar + '_mixup'] = encoded_covars_mixup 238 | 239 | return dict( 240 | x=x, 241 | perts=perts, 242 | perts_doses=perts_doses, 243 | covars_dict=covars_dict, 244 | ) 245 | 246 | @auto_move_data 247 | def inference( 248 | self, 249 | x, 250 | perts, 251 | perts_doses, 252 | covars_dict, 253 | mixup_lambda: float = 1.0, 254 | n_samples: int = 1, 255 | covars_to_add: Optional[list] = None, 256 | ): 257 | batch_size = x.shape[0] 258 | 259 | if self.recon_loss in ['nb', 'zinb']: 260 | # log the input to the variational distribution for numerical stability 261 | x_ = torch.log(1 + x) 262 | 263 | library = torch.log(x.sum(1)).unsqueeze(1) 264 | else: 265 | x_ = x 266 | library = None # library size is only needed for the nb/zinb likelihoods 267 | 268 | if self.variational: 269 | qz, z_basal = self.encoder(x_) 270 | else: 271 | qz, z_basal = None, self.encoder(x_) 272 | 273 | if self.variational and n_samples > 1: 274 | sampled_z = qz.sample((n_samples,)) 275 | z_basal = self.encoder.z_transformation(sampled_z) 276 | if self.recon_loss in ['nb', 'zinb']: 277 | library = library.unsqueeze(0).expand( 278 | (n_samples, library.size(0), library.size(1)) 279 | ) 280 | 281 | z_pert_true = self.pert_network(perts['true'], perts_doses['true']) 282 | if mixup_lambda < 1.0: 283 | z_pert_mixup = self.pert_network(perts['mixup'], perts_doses['mixup']) 284 | z_pert = mixup_lambda * z_pert_true + (1. - mixup_lambda) * z_pert_mixup 285 | else: 286 | z_pert = z_pert_true 287 | 288 | z_covs = torch.zeros_like(z_basal) # ([n_samples,] batch_size, n_latent) 289 | z_covs_wo_batch = torch.zeros_like(z_basal) # ([n_samples,] batch_size, n_latent) 290 | 291 | batch_key = CPA_REGISTRY_KEYS.BATCH_KEY 292 | 293 | if covars_to_add is None: 294 | covars_to_add = list(self.covars_encoder.keys()) 295 | 296 | for covar, encoder in self.covars_encoder.items(): 297 | if covar in covars_to_add: 298 | z_cov = self.covars_embeddings[covar](covars_dict[covar].long()) 299 | if len(encoder) > 1: 300 | z_cov_mixup = self.covars_embeddings[covar](covars_dict[covar + '_mixup'].long()) 301 | z_cov = mixup_lambda * z_cov + (1.
- mixup_lambda) * z_cov_mixup 302 | z_cov = z_cov.view(batch_size, self.n_latent) # batch_size, n_latent 303 | z_covs += z_cov 304 | 305 | if covar != batch_key: 306 | z_covs_wo_batch += z_cov 307 | 308 | z = z_basal + z_pert + z_covs 309 | z_corrected = z_basal + z_pert + z_covs_wo_batch 310 | z_no_pert = z_basal + z_covs 311 | z_no_pert_corrected = z_basal + z_covs_wo_batch 312 | 313 | return dict( 314 | z=z, 315 | z_corrected=z_corrected, 316 | z_no_pert=z_no_pert, 317 | z_no_pert_corrected=z_no_pert_corrected, 318 | z_basal=z_basal, 319 | z_covs=z_covs, 320 | z_pert=z_pert.sum(dim=1), 321 | library=library, 322 | qz=qz, 323 | mixup_lambda=mixup_lambda, 324 | ) 325 | 326 | def _get_generative_input(self, tensors, inference_outputs, **kwargs): 327 | if 'latent' in kwargs.keys(): 328 | if kwargs['latent'] in inference_outputs.keys(): # z, z_corrected, z_no_pert, z_no_pert_corrected, z_basal 329 | z = inference_outputs[kwargs['latent']] 330 | else: 331 | raise Exception('Invalid latent space') 332 | else: 333 | z = inference_outputs["z"] 334 | library = inference_outputs['library'] 335 | 336 | return dict( 337 | z=z, 338 | library=library, 339 | ) 340 | 341 | @auto_move_data 342 | def generative( 343 | self, 344 | z, 345 | library=None, 346 | ): 347 | if self.recon_loss == 'nb': 348 | px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) 349 | px_r = torch.exp(self.px_r) 350 | 351 | px = NegativeBinomial(mu=px_rate, theta=px_r) 352 | 353 | elif self.recon_loss == 'zinb': 354 | px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) 355 | px_r = torch.exp(self.px_r) 356 | 357 | px = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) 358 | 359 | else: 360 | px_mean, px_var, x_pred = self.decoder(z) 361 | 362 | px = Normal(loc=px_mean, scale=px_var.sqrt()) 363 | 364 | pz = Normal(torch.zeros_like(z), torch.ones_like(z)) 365 | return dict(px=px, pz=pz) 366 | 367 | def loss(self, tensors, inference_outputs, generative_outputs): 368 | """Computes the reconstruction loss (AE) or the ELBO (VAE)""" 369 | x = tensors[CPA_REGISTRY_KEYS.X_KEY] 370 | 371 | px = generative_outputs['px'] 372 | recon_loss = -px.log_prob(x).sum(dim=-1).mean() 373 | 374 | if self.variational: 375 | qz = inference_outputs["qz"] 376 | pz = generative_outputs['pz'] 377 | 378 | kl_divergence_z = kl(qz, pz).sum(dim=1) 379 | kl_loss = kl_divergence_z.mean() 380 | else: 381 | # plain autoencoder: no KL term 382 | kl_loss = torch.zeros_like(recon_loss) 383 | 384 | return recon_loss, kl_loss 385 | 386 | def r2_metric(self, tensors, inference_outputs, generative_outputs, mode: str = 'direct'): 387 | mode = mode.lower() 388 | assert mode in ['direct'] 389 | 390 | x = tensors[CPA_REGISTRY_KEYS.X_KEY] # batch_size, n_genes 391 | indices = tensors[CPA_REGISTRY_KEYS.CATEGORY_KEY].view(-1,) 392 | 393 | unique_indices = indices.unique() 394 | 395 | r2_mean = 0.0 396 | r2_var = 0.0 397 | 398 | px = generative_outputs['px'] 399 | for ind in unique_indices: 400 | i_mask = indices == ind 401 | 402 | x_i = x[i_mask, :] 403 | if self.recon_loss == 'gauss': 404 | x_pred_mean = px.loc[i_mask, :] 405 | x_pred_var = px.scale[i_mask, :] ** 2 406 | 407 | if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys(): 408 | deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :] 409 | 410 | x_i *= deg_mask 411 | x_pred_mean *= deg_mask 412 | x_pred_var *= deg_mask 413 | 414 | x_pred_mean = torch.nan_to_num(x_pred_mean, nan=0, posinf=1e3, neginf=-1e3) 415 | x_pred_var = torch.nan_to_num(x_pred_var, nan=0,
posinf=1e3, neginf=-1e3) 416 | 417 | r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred_mean.mean(0), x_i.mean(0)), 418 | nan=0.0).item() 419 | r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred_var.mean(0), x_i.var(0)), 420 | nan=0.0).item() 421 | 422 | elif self.recon_loss in ['nb', 'zinb']: 423 | x_i = torch.log(1 + x_i) 424 | x_pred = px.mu[i_mask, :] 425 | x_pred = torch.log(1 + x_pred) 426 | 427 | x_pred = torch.nan_to_num(x_pred, nan=0, posinf=1e3, neginf=-1e3) 428 | 429 | if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys(): 430 | deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :] 431 | 432 | x_i *= deg_mask 433 | x_pred *= deg_mask 434 | 435 | r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred.mean(0), x_i.mean(0)), 436 | nan=0.0).item() 437 | r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred.var(0), x_i.var(0)), 438 | nan=0.0).item() 439 | 440 | n_unique_indices = len(unique_indices) 441 | return r2_mean / n_unique_indices, r2_var / n_unique_indices 442 | 443 | def disentanglement(self, tensors, inference_outputs, generative_outputs, linear=True): 444 | z_basal = inference_outputs['z_basal'].detach().cpu().numpy() 445 | z = inference_outputs['z'].detach().cpu().numpy() 446 | 447 | perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY].view(-1, ) 448 | perturbations_names = perturbations.detach().cpu().numpy() 449 | 450 | knn_basal = knn_purity(z_basal, perturbations_names.ravel(), 451 | n_neighbors=min(perturbations_names.shape[0] - 1, 30)) 452 | knn_after = knn_purity(z, perturbations_names.ravel(), 453 | n_neighbors=min(perturbations_names.shape[0] - 1, 30)) 454 | 455 | for covar, unique_covars in self.covars_encoder.items(): 456 | if len(unique_covars) > 1: 457 | target_covars = tensors[f'{covar}'].detach().cpu().numpy() 458 | 459 | knn_basal += knn_purity(z_basal, target_covars.ravel(), 460 | n_neighbors=min(target_covars.shape[0] - 1, 30)) 461 | 462 | knn_after += knn_purity(z, target_covars.ravel(), 463 | n_neighbors=min(target_covars.shape[0] - 1, 30)) 464 | 465 | return knn_basal, knn_after 466 | 467 | def get_expression(self, tensors, n_samples=1, covars_to_add=None, latent='z'): 468 | """Computes gene expression means and std. 469 | 470 | Only implemented for the gaussian likelihood. 
471 | 472 | Parameters 473 | ---------- 474 | tensors : dict 475 | Considered inputs 476 | 477 | """ 478 | tensors, _ = self.mixup_data(tensors, alpha=0.0) 479 | 480 | inference_outputs, generative_outputs = self.forward( 481 | tensors, 482 | inference_kwargs={'n_samples': n_samples, 'covars_to_add': covars_to_add}, 483 | get_generative_input_kwargs={'latent': latent}, 484 | compute_loss=False, 485 | ) 486 | 487 | z = inference_outputs['z'] 488 | z_corrected = inference_outputs['z_corrected'] 489 | z_no_pert = inference_outputs['z_no_pert'] 490 | z_no_pert_corrected = inference_outputs['z_no_pert_corrected'] 491 | z_basal = inference_outputs['z_basal'] 492 | 493 | px = generative_outputs['px'] 494 | 495 | if self.recon_loss == 'gauss': 496 | output_key = 'loc' 497 | else: 498 | output_key = 'mu' 499 | 500 | reconstruction = getattr(px, output_key) 501 | 502 | return dict( 503 | px=reconstruction, 504 | z=z, 505 | z_corrected=z_corrected, 506 | z_no_pert=z_no_pert, 507 | z_no_pert_corrected=z_no_pert_corrected, 508 | z_basal=z_basal, 509 | ) 510 | 511 | def get_pert_embeddings(self, tensors, **inference_kwargs): 512 | inputs = self._get_inference_input(tensors) 513 | drugs = inputs['perts'] 514 | doses = inputs['perts_doses'] 515 | 516 | return self.pert_network(drugs, doses) 517 | -------------------------------------------------------------------------------- /Supp_code/Modified_CPA/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from scvi.distributions import NegativeBinomial 7 | 8 | from scvi.nn import FCLayers 9 | from torch.distributions import Normal 10 | from typing import Optional 11 | 12 | 13 | class _REGISTRY_KEYS: 14 | X_KEY: str = "X" 15 | X_CTRL_KEY: str = None 16 | BATCH_KEY: str = None 17 | CATEGORY_KEY: str = "cpa_category" 18 | PERTURBATION_KEY: str = None 19 | PERTURBATION_DOSAGE_KEY: str = None 20 | PERTURBATIONS: str = "perts" 21 | PERTURBATIONS_DOSAGES: str = "perts_doses" 22 | SIZE_FACTOR_KEY: str = "size_factor" 23 | CAT_COV_KEYS: List[str] = [] 24 | MAX_COMB_LENGTH: int = 2 25 | CONTROL_KEY: str = None 26 | DEG_MASK: str = None 27 | DEG_MASK_R2: str = None 28 | PADDING_IDX: int = 0 29 | 30 | 31 | CPA_REGISTRY_KEYS = _REGISTRY_KEYS() 32 | 33 | 34 | class VanillaEncoder(nn.Module): 35 | def __init__( 36 | self, 37 | n_input, 38 | n_output, 39 | n_hidden, 40 | n_layers, 41 | n_cat_list, 42 | use_layer_norm=True, 43 | use_batch_norm=False, 44 | output_activation: str = 'linear', 45 | dropout_rate: float = 0.1, 46 | activation_fn=nn.ReLU, 47 | ): 48 | super().__init__() 49 | self.n_output = n_output 50 | self.output_activation = output_activation 51 | 52 | self.network = FCLayers( 53 | n_in=n_input, 54 | n_out=n_hidden, 55 | n_cat_list=n_cat_list, 56 | n_layers=n_layers, 57 | n_hidden=n_hidden, 58 | use_layer_norm=use_layer_norm, 59 | use_batch_norm=use_batch_norm, 60 | dropout_rate=dropout_rate, 61 | activation_fn=activation_fn, 62 | ) 63 | self.z = nn.Linear(n_hidden, n_output) 64 | 65 | def forward(self, inputs, *cat_list): 66 | if self.output_activation == 'linear': 67 | z = self.z(self.network(inputs, *cat_list)) 68 | elif self.output_activation == 'relu': 69 | z = F.relu(self.z(self.network(inputs, *cat_list))) 70 | else: 71 | raise ValueError(f'Unknown output activation: {self.output_activation}') 72 | return z 73 | 74 | 75 | class GeneralizedSigmoid(nn.Module): 76 | """ 77 | Sigmoid, log-sigmoid or 
linear functions for encoding dose-response for 78 | drug perturbations. 79 | """ 80 | 81 | def __init__(self, n_drugs, non_linearity='sigmoid'): 82 | """Sigmoid modeling of continuous variable. 83 | Params 84 | ------ 85 | non_linearity : str (default: sigmoid) 86 | One of logsigm, sigm; any other value applies no non-linearity (linear). 87 | """ 88 | super(GeneralizedSigmoid, self).__init__() 89 | self.non_linearity = non_linearity 90 | self.n_drugs = n_drugs 91 | 92 | self.beta = torch.nn.Parameter( 93 | torch.ones(1, n_drugs), 94 | requires_grad=True 95 | ) 96 | self.bias = torch.nn.Parameter( 97 | torch.zeros(1, n_drugs), 98 | requires_grad=True 99 | ) 100 | 101 | self.vmap = None 102 | 103 | def forward(self, x, y): 104 | """ 105 | Parameters 106 | ---------- 107 | x: (batch_size, max_comb_len) 108 | y: (batch_size, max_comb_len) 109 | """ 110 | y = y.long() 111 | if self.non_linearity == 'logsigm': 112 | bias = self.bias[0][y] 113 | beta = self.beta[0][y] 114 | c0 = bias.sigmoid() 115 | return (torch.log1p(x) * beta + bias).sigmoid() - c0 116 | elif self.non_linearity == 'sigm': 117 | bias = self.bias[0][y] 118 | beta = self.beta[0][y] 119 | c0 = bias.sigmoid() 120 | return (x * beta + bias).sigmoid() - c0 121 | else: 122 | return x 123 | 124 | def one_drug(self, x, i): 125 | if self.non_linearity == 'logsigm': 126 | c0 = self.bias[0][i].sigmoid() 127 | return (torch.log1p(x) * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0 128 | elif self.non_linearity == 'sigm': 129 | c0 = self.bias[0][i].sigmoid() 130 | return (x * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0 131 | else: 132 | return x 133 | 134 | 135 | class PerturbationNetwork(nn.Module): 136 | def __init__(self, 137 | n_perts, 138 | n_latent, 139 | doser_type='logsigm', 140 | n_hidden=None, 141 | n_layers=None, 142 | dropout_rate: float = 0.0, 143 | drug_embeddings=None,): 144 | super().__init__() 145 | self.n_latent = n_latent 146 | 147 | if drug_embeddings is not None: 148 | self.pert_embedding = drug_embeddings 149 | self.pert_transformation = nn.Linear(drug_embeddings.embedding_dim, n_latent) 150 | self.use_rdkit = True 151 | else: 152 | self.use_rdkit = False 153 | self.pert_embedding = nn.Embedding(n_perts, n_latent, padding_idx=CPA_REGISTRY_KEYS.PADDING_IDX) 154 | 155 | self.doser_type = doser_type 156 | if self.doser_type == 'mlp': 157 | self.dosers = nn.ModuleList() 158 | for _ in range(n_perts): 159 | self.dosers.append( 160 | FCLayers( 161 | n_in=1, 162 | n_out=1, 163 | n_hidden=n_hidden, 164 | n_layers=n_layers, 165 | use_batch_norm=False, 166 | use_layer_norm=True, 167 | dropout_rate=dropout_rate 168 | ) 169 | ) 170 | else: 171 | self.dosers = GeneralizedSigmoid(n_perts, non_linearity=self.doser_type) 172 | 173 | def forward(self, perts, dosages): 174 | """ 175 | perts: (batch_size, max_comb_len) 176 | dosages: (batch_size, max_comb_len) 177 | """ 178 | bs, max_comb_len = perts.shape 179 | perts = perts.long() 180 | scaled_dosages = self.dosers(dosages, perts) # (batch_size, max_comb_len) 181 | 182 | drug_embeddings = self.pert_embedding(perts) # (batch_size, max_comb_len, n_drug_emb_dim) 183 | 184 | if self.use_rdkit: 185 | drug_embeddings = self.pert_transformation(drug_embeddings.view(bs * max_comb_len, -1)).view(bs, max_comb_len, -1) 186 | 187 | z_drugs = torch.einsum('bm,bme->bme', [scaled_dosages, drug_embeddings]) # (batch_size, max_comb_len, n_latent) 188 | 189 | z_drugs = torch.einsum('bmn,bm->bmn', z_drugs, (perts != CPA_REGISTRY_KEYS.PADDING_IDX).int()).sum(dim=1) # zero out padding perts (e.g. for single perturbations), then sum over the combination 190 | 191 | return z_drugs # (batch_size, n_latent) 192 | 193 | class
FocalLoss(nn.Module): 194 | """ Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py 195 | 196 | Focal Loss, as described in https://arxiv.org/abs/1708.02002. 197 | It is essentially an enhancement to cross entropy loss and is 198 | useful for classification tasks when there is a large class imbalance. 199 | x is expected to contain raw, unnormalized scores for each class. 200 | y is expected to contain class labels. 201 | Shape: 202 | - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0. 203 | - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0. 204 | """ 205 | 206 | def __init__(self, 207 | alpha: Optional[torch.Tensor] = None, 208 | gamma: float = 2., 209 | reduction: str = 'mean', 210 | ): 211 | """ 212 | Args: 213 | alpha (Tensor, optional): Weights for each class. Defaults to None. 214 | gamma (float, optional): A constant, as described in the paper. 215 | Defaults to 2. 216 | reduction (str, optional): 'mean', 'sum' or 'none'. 217 | Defaults to 'mean'. 218 | """ 219 | if reduction not in ('mean', 'sum', 'none'): 220 | raise ValueError( 221 | 'Reduction must be one of: "mean", "sum", "none".') 222 | 223 | super().__init__() 224 | self.alpha = alpha 225 | self.gamma = gamma 226 | self.reduction = reduction 227 | 228 | self.nll_loss = nn.NLLLoss( 229 | weight=alpha, reduction='none') 230 | 231 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 232 | if len(y_true) == 0: 233 | return torch.tensor(0.) 234 | 235 | # compute weighted cross entropy term: -alpha * log(pt) 236 | # (alpha is already part of self.nll_loss) 237 | log_p = F.log_softmax(y_pred, dim=-1) 238 | ce = self.nll_loss(log_p, y_true) 239 | 240 | # get true class column from each row 241 | all_rows = torch.arange(len(y_pred)) 242 | log_pt = log_p[all_rows, y_true] 243 | 244 | # compute focal term: (1 - pt)^gamma 245 | pt = log_pt.exp() 246 | focal_term = (1 - pt) ** self.gamma 247 | 248 | # the full loss: -alpha * ((1 - pt)^gamma) * log(pt) 249 | loss = focal_term * ce 250 | 251 | if self.reduction == 'mean': 252 | loss = loss.mean() 253 | elif self.reduction == 'sum': 254 | loss = loss.sum() 255 | 256 | return loss -------------------------------------------------------------------------------- /Supp_code/Modified_GEARS/data_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import scanpy as sc 4 | from random import shuffle 5 | sc.settings.verbosity = 0 6 | from tqdm import tqdm 7 | import requests 8 | import os, sys 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | from .utils import parse_single_pert, parse_combo_pert, parse_any_pert, print_sys 14 | 15 | def rank_genes_groups_by_cov( 16 | adata, 17 | groupby, 18 | control_group, 19 | covariate, 20 | pool_doses=False, 21 | n_genes=50, 22 | rankby_abs=True, 23 | key_added='rank_genes_groups_cov', 24 | return_dict=False, 25 | ): 26 | 27 | gene_dict = {} 28 | cov_categories = adata.obs[covariate].unique() 29 | for cov_cat in cov_categories: 30 | #name of the control group in the groupby obs column 31 | control_group_cov = '_'.join([cov_cat, control_group]) 32 | 33 | #subset adata to cells belonging to a covariate category 34 | adata_cov = adata[adata.obs[covariate]==cov_cat] 35 | 36 | #compute DEGs 37 | sc.tl.rank_genes_groups( 38 | adata_cov, 39 | groupby=groupby, 40 | reference=control_group_cov, 41 | rankby_abs=rankby_abs, 42 | n_genes=n_genes, 43 | use_raw=False 44 | )
45 | 46 | #add entries to dictionary of gene sets 47 | de_genes = pd.DataFrame(adata_cov.uns['rank_genes_groups']['names']) 48 | for group in de_genes: 49 | gene_dict[group] = de_genes[group].tolist() 50 | 51 | adata.uns[key_added] = gene_dict 52 | 53 | if return_dict: 54 | return gene_dict 55 | 56 | 57 | def get_DE_genes(adata, skip_calc_de): 58 | adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1') 59 | adata.obs.loc[:, 'control'] = adata.obs.condition.apply(lambda x: 0 if len(x.split('+')) == 2 else 1) 60 | adata.obs.loc[:, 'condition_name'] = adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1) 61 | 62 | adata.obs = adata.obs.astype('category') 63 | if not skip_calc_de: 64 | rank_genes_groups_by_cov(adata, 65 | groupby='condition_name', 66 | covariate='cell_type', 67 | control_group='ctrl_1', 68 | n_genes=len(adata.var), 69 | key_added = 'rank_genes_groups_cov_all') 70 | return adata 71 | 72 | def get_dropout_non_zero_genes(adata): 73 | 74 | # calculate mean expression for each condition 75 | unique_conditions = adata.obs.condition.unique() 76 | conditions2index = {} 77 | for i in unique_conditions: 78 | conditions2index[i] = np.where(adata.obs.condition == i)[0] 79 | 80 | condition2mean_expression = {} 81 | for i, j in conditions2index.items(): 82 | condition2mean_expression[i] = np.mean(adata.X[j], axis = 0) 83 | pert_list = np.array(list(condition2mean_expression.keys())) 84 | mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1]) 85 | ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]] 86 | 87 | ## in silico modeling and upperbounding 88 | pert2pert_full_id = dict(adata.obs[['condition', 'condition_name']].values) 89 | pert_full_id2pert = dict(adata.obs[['condition_name', 'condition']].values) 90 | 91 | gene_id2idx = dict(zip(adata.var.index.values, range(len(adata.var)))) 92 | gene_idx2id = dict(zip(range(len(adata.var)), adata.var.index.values)) 93 | 94 | non_zeros_gene_idx = {} 95 | top_non_dropout_de_20 = {} 96 | top_non_zero_de_20 = {} 97 | non_dropout_gene_idx = {} 98 | 99 | for pert in adata.uns['rank_genes_groups_cov_all'].keys(): 100 | p = pert_full_id2pert[pert] 101 | X = np.mean(adata[adata.obs.condition == p].X, axis = 0) 102 | 103 | non_zero = np.where(np.array(X)[0] != 0)[0] 104 | zero = np.where(np.array(X)[0] == 0)[0] 105 | true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0]) 106 | non_dropouts = np.concatenate((non_zero, true_zeros)) 107 | 108 | top = adata.uns['rank_genes_groups_cov_all'][pert] 109 | gene_idx_top = [gene_id2idx[i] for i in top] 110 | 111 | non_dropout_20 = [i for i in gene_idx_top if i in non_dropouts][:20] 112 | non_dropout_20_gene_id = [gene_idx2id[i] for i in non_dropout_20] 113 | 114 | non_zero_20 = [i for i in gene_idx_top if i in non_zero][:20] 115 | non_zero_20_gene_id = [gene_idx2id[i] for i in non_zero_20] 116 | 117 | non_zeros_gene_idx[pert] = np.sort(non_zero) 118 | non_dropout_gene_idx[pert] = np.sort(non_dropouts) 119 | top_non_dropout_de_20[pert] = np.array(non_dropout_20_gene_id) 120 | top_non_zero_de_20[pert] = np.array(non_zero_20_gene_id) 121 | 122 | non_zero = np.where(np.array(X)[0] != 0)[0] 123 | zero = np.where(np.array(X)[0] == 0)[0] 124 | true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0]) 125 | non_dropouts = np.concatenate((non_zero, true_zeros)) 126 | 127 | adata.uns['top_non_dropout_de_20'] = 
top_non_dropout_de_20 128 | adata.uns['non_dropout_gene_idx'] = non_dropout_gene_idx 129 | adata.uns['non_zeros_gene_idx'] = non_zeros_gene_idx 130 | adata.uns['top_non_zero_de_20'] = top_non_zero_de_20 131 | 132 | return adata 133 | 134 | 135 | class DataSplitter(): 136 | """ 137 | Class for handling data splitting. This class is able to generate new 138 | data splits and assign them as a new attribute to the data file. 139 | """ 140 | def __init__(self, adata, split_type='single', seen=0): 141 | self.adata = adata 142 | self.split_type = split_type 143 | self.seen = seen 144 | 145 | def split_data(self, test_size=0.1, test_pert_genes=None, 146 | test_perts=None, split_name='split', seed=None, val_size = 0.1, 147 | train_gene_set_size = 0.75, combo_seen2_train_frac = 0.75, only_test_set_perts = False): 148 | """ 149 | Split dataset and adds split as a column to the dataframe 150 | Note: split categories are train, val, test 151 | """ 152 | np.random.seed(seed=seed) 153 | unique_perts = [p for p in self.adata.obs['condition'].unique() if 154 | p != 'ctrl'] 155 | 156 | if self.split_type == 'simulation': 157 | train, test, test_subgroup = self.get_simulation_split(unique_perts, 158 | train_gene_set_size, 159 | combo_seen2_train_frac, 160 | seed, test_perts, only_test_set_perts) 161 | train, val, val_subgroup = self.get_simulation_split(train, 162 | 0.9, 163 | 0.9, 164 | seed) 165 | ## adding back ctrl to train... 166 | train.append('ctrl') 167 | elif self.split_type == 'simulation_single': 168 | train, test, test_subgroup = self.get_simulation_split_single(unique_perts, 169 | train_gene_set_size, 170 | seed, test_perts, only_test_set_perts) 171 | train, val, val_subgroup = self.get_simulation_split_single(train, 172 | 0.9, 173 | seed) 174 | elif self.split_type == 'no_test': 175 | print('test_pert_genes',str(test_pert_genes)) 176 | print('test_perts',str(test_perts)) 177 | 178 | train, val = self.get_split_list(unique_perts, 179 | test_pert_genes=test_pert_genes, 180 | test_perts=test_perts, 181 | test_size=test_size) 182 | else: 183 | train, test = self.get_split_list(unique_perts, 184 | test_pert_genes=test_pert_genes, 185 | test_perts=test_perts, 186 | test_size=test_size) 187 | 188 | train, val = self.get_split_list(train, test_size=val_size) 189 | 190 | map_dict = {x: 'train' for x in train} 191 | map_dict.update({x: 'val' for x in val}) 192 | if self.split_type != 'no_test': 193 | map_dict.update({x: 'test' for x in test}) 194 | map_dict.update({'ctrl': 'train'}) 195 | 196 | self.adata.obs[split_name] = self.adata.obs['condition'].map(map_dict) 197 | 198 | if self.split_type == 'simulation': 199 | return self.adata, {'test_subgroup': test_subgroup, 200 | 'val_subgroup': val_subgroup 201 | } 202 | else: 203 | return self.adata 204 | 205 | def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False): 206 | unique_pert_genes = self.get_genes_from_perts(pert_list) 207 | 208 | pert_train = [] 209 | pert_test = [] 210 | np.random.seed(seed=seed) 211 | 212 | if only_test_set_perts and (test_set_perts is not None): 213 | ood_genes = np.array(test_set_perts) 214 | train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes) 215 | else: 216 | ## a pre-specified list of genes 217 | train_gene_candidates = np.random.choice(unique_pert_genes, 218 | int(len(unique_pert_genes) * train_gene_set_size), replace = False) 219 | 220 | if test_set_perts is not None: 221 | num_overlap = 
len(np.intersect1d(train_gene_candidates, test_set_perts)) 222 | train_gene_candidates = train_gene_candidates[~np.isin(train_gene_candidates, test_set_perts)] 223 | ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts)) 224 | train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False) 225 | train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition)) 226 | 227 | ## ood genes 228 | ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) 229 | 230 | pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') 231 | unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single') 232 | assert len(unseen_single) + len(pert_single_train) == len(pert_list) 233 | 234 | return pert_single_train, unseen_single, {'unseen_single': unseen_single} 235 | 236 | def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen2_train_frac = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False): 237 | 238 | unique_pert_genes = self.get_genes_from_perts(pert_list) 239 | 240 | pert_train = [] 241 | pert_test = [] 242 | np.random.seed(seed=seed) 243 | 244 | if only_test_set_perts and (test_set_perts is not None): 245 | ood_genes = np.array(test_set_perts) 246 | train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes) 247 | else: 248 | ## a pre-specified list of genes 249 | train_gene_candidates = np.random.choice(unique_pert_genes, 250 | int(len(unique_pert_genes) * train_gene_set_size), replace = False) 251 | 252 | if test_set_perts is not None: 253 | num_overlap = len(np.intersect1d(train_gene_candidates, test_set_perts)) 254 | train_gene_candidates = train_gene_candidates[~np.isin(train_gene_candidates, test_set_perts)] 255 | ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts)) 256 | train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False) 257 | train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition)) 258 | 259 | ## ood genes 260 | ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates) 261 | 262 | pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single') 263 | pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list,'combo') 264 | pert_train.extend(pert_single_train) 265 | 266 | ## the combo set with one of them in OOD 267 | combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if 268 | t in train_gene_candidates]) == 1] 269 | pert_test.extend(combo_seen1) 270 | 271 | pert_combo = np.setdiff1d(pert_combo, combo_seen1) 272 | ## randomly sample the combo seen 2 as a test set, the rest in training set 273 | np.random.seed(seed=seed) 274 | pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False) 275 | 276 | combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train).tolist() 277 | pert_test.extend(combo_seen2) 278 | pert_train.extend(pert_combo_train) 279 | 280 | ## unseen single 281 | unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single') 282 | combo_ood = self.get_perts_from_genes(ood_genes, pert_list, 'combo') 283 | pert_test.extend(unseen_single) 284 | 285 | ## here only keeps the seen 0, since seen 1 is tackled above 286 | combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if 287 | t in train_gene_candidates]) == 0] 288 | 
pert_test.extend(combo_seen0) 289 | assert len(combo_seen1) + len(combo_seen0) + len(unseen_single) + len(pert_train) + len(combo_seen2) == len(pert_list) 290 | 291 | return pert_train, pert_test, {'combo_seen0': combo_seen0, 292 | 'combo_seen1': combo_seen1, 293 | 'combo_seen2': combo_seen2, 294 | 'unseen_single': unseen_single} 295 | 296 | def get_split_list(self, pert_list, test_size=0.1, 297 | test_pert_genes=None, test_perts=None, 298 | hold_outs=True): 299 | """ 300 | Splits a given perturbation list into train and test with no shared 301 | perturbations 302 | """ 303 | 304 | single_perts = [p for p in pert_list if 'ctrl' in p and p != 'ctrl'] 305 | combo_perts = [p for p in pert_list if 'ctrl' not in p] 306 | unique_pert_genes = self.get_genes_from_perts(pert_list) 307 | hold_out = [] 308 | 309 | if test_pert_genes is None: 310 | test_pert_genes = np.random.choice(unique_pert_genes, 311 | int(len(single_perts) * test_size)) 312 | 313 | # Only single unseen genes (in test set) 314 | # Train contains both single and combos 315 | if self.split_type == 'single' or self.split_type == 'single_only': 316 | test_perts = self.get_perts_from_genes(test_pert_genes, pert_list, 317 | 'single') 318 | if self.split_type == 'single_only': 319 | # Discard all combos 320 | hold_out = combo_perts 321 | else: 322 | # Discard only those combos which contain test genes 323 | hold_out = self.get_perts_from_genes(test_pert_genes, pert_list, 324 | 'combo') 325 | 326 | elif self.split_type == 'combo': 327 | if self.seen == 0: 328 | # NOTE: This can reduce the dataset size! 329 | # To prevent this set 'holdouts' to False, this will cause 330 | # the test set to have some perturbations with 1 gene seen 331 | single_perts = self.get_perts_from_genes(test_pert_genes, 332 | pert_list, 'single') 333 | combo_perts = self.get_perts_from_genes(test_pert_genes, 334 | pert_list, 'combo') 335 | 336 | if hold_outs: 337 | # This just checks that none of the combos have 2 seen genes 338 | hold_out = [t for t in combo_perts if 339 | len([t for t in t.split('+') if 340 | t not in test_pert_genes]) > 0] 341 | combo_perts = [c for c in combo_perts if c not in hold_out] 342 | test_perts = single_perts + combo_perts 343 | 344 | elif self.seen == 1: 345 | # NOTE: This can reduce the dataset size! 
346 | # To prevent this set 'holdouts' to False, this will cause 347 | # the test set to have some perturbations with 2 genes seen 348 | single_perts = self.get_perts_from_genes(test_pert_genes, 349 | pert_list, 'single') 350 | combo_perts = self.get_perts_from_genes(test_pert_genes, 351 | pert_list, 'combo') 352 | 353 | if hold_outs: 354 | # This just checks that none of the combos have 2 seen genes 355 | hold_out = [t for t in combo_perts if 356 | len([t for t in t.split('+') if 357 | t not in test_pert_genes]) > 1] 358 | combo_perts = [c for c in combo_perts if c not in hold_out] 359 | test_perts = single_perts + combo_perts 360 | 361 | elif self.seen == 2: 362 | if test_perts is None: 363 | test_perts = np.random.choice(combo_perts, 364 | int(len(combo_perts) * test_size)) 365 | else: 366 | test_perts = np.array(test_perts) 367 | else: 368 | if test_perts is None: 369 | test_perts = np.random.choice(combo_perts, 370 | int(len(combo_perts) * test_size)) 371 | 372 | train_perts = [p for p in pert_list if (p not in test_perts) 373 | and (p not in hold_out)] 374 | return train_perts, test_perts 375 | 376 | def get_perts_from_genes(self, genes, pert_list, type_='both'): 377 | """ 378 | Returns all single/combo/both perturbations that include a gene 379 | """ 380 | 381 | single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')] 382 | combo_perts = [p for p in pert_list if 'ctrl' not in p] 383 | 384 | perts = [] 385 | 386 | if type_ == 'single': 387 | pert_candidate_list = single_perts 388 | elif type_ == 'combo': 389 | pert_candidate_list = combo_perts 390 | elif type_ == 'both': 391 | pert_candidate_list = pert_list 392 | 393 | for p in pert_candidate_list: 394 | for g in genes: 395 | if g in parse_any_pert(p): 396 | perts.append(p) 397 | break 398 | return perts 399 | 400 | def get_genes_from_perts(self, perts): 401 | """ 402 | Returns list of genes involved in a given perturbation list 403 | """ 404 | 405 | if type(perts) is str: 406 | perts = [perts] 407 | gene_list = [p.split('+') for p in np.unique(perts)] 408 | gene_list = [item for sublist in gene_list for item in sublist] 409 | gene_list = [g for g in gene_list if g != 'ctrl'] 410 | return np.unique(gene_list) -------------------------------------------------------------------------------- /Supp_code/Modified_GEARS/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Sequential, Linear, ReLU 5 | 6 | from torch_geometric.nn import SGConv 7 | 8 | class MLP(torch.nn.Module): 9 | 10 | def __init__(self, sizes, batch_norm=True, last_layer_act="linear"): 11 | """ 12 | Multi-layer perceptron 13 | :param sizes: list of sizes of the layers 14 | :param batch_norm: whether to use batch normalization 15 | :param last_layer_act: activation function of the last layer 16 | 17 | """ 18 | super(MLP, self).__init__() 19 | layers = [] 20 | for s in range(len(sizes) - 1): 21 | layers = layers + [ 22 | torch.nn.Linear(sizes[s], sizes[s + 1]), 23 | torch.nn.BatchNorm1d(sizes[s + 1]) 24 | if batch_norm and s < len(sizes) - 1 else None, 25 | torch.nn.ReLU() 26 | ] 27 | 28 | layers = [l for l in layers if l is not None][:-1] 29 | self.activation = last_layer_act 30 | self.network = torch.nn.Sequential(*layers) 31 | self.relu = torch.nn.ReLU() 32 | def forward(self, x): 33 | return self.network(x) 34 | 35 | 36 | class First_Sub_task_MLP(torch.nn.Module): 37 | 38 | def __init__(self, sizes, 
batch_norm=True, last_layer_act="linear"):
39 |         """
40 |         Multi-layer perceptron
41 |         :param sizes: list of sizes of the layers
42 |         :param batch_norm: whether to use batch normalization (unused in this variant)
43 |         :param last_layer_act: activation function of the last layer
44 | 
45 |         """
46 |         super().__init__()
47 |         layers = []
48 |         for s in range(len(sizes) - 1):
49 |             layers = layers + [
50 |                 torch.nn.Linear(sizes[s], sizes[s + 1]),
51 |                 torch.nn.ReLU()
52 |             ]
53 |         layers = [l for l in layers if l is not None][:-1]
54 |         self.activation = last_layer_act
55 |         self.network = torch.nn.Sequential(*layers)
56 |     def forward(self, x):
57 |         return self.network(x)
58 | 
59 | class GEARS_Model(torch.nn.Module):
60 |     """
61 |     GEARS model
62 | 
63 |     """
64 | 
65 |     def __init__(self, args):
66 |         """
67 |         :param args: arguments dictionary
68 |         """
69 | 
70 |         super(GEARS_Model, self).__init__()
71 |         self.args = args
72 |         self.num_genes = args['num_genes']
73 |         self.num_perts = args['num_perts']
74 |         hidden_size = args['hidden_size']
75 |         self.uncertainty = args['uncertainty']
76 |         self.num_layers = args['num_go_gnn_layers']
77 |         self.indv_out_hidden_size = args['decoder_hidden_size']
78 |         self.num_layers_gene_pos = args['num_gene_gnn_layers']
79 |         self.no_perturb = args['no_perturb']
80 |         self.pert_emb_lambda = 0.2
81 | 
82 |         # perturbation positional embedding added only to the perturbed genes
83 |         self.pert_w = nn.Linear(1, hidden_size)
84 | 
85 |         # gene/global perturbation embedding dictionary lookup
86 |         self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
87 |         self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True)
88 | 
89 |         # transformation layer
90 |         self.emb_trans = nn.ReLU()
91 |         self.pert_base_trans = nn.ReLU()
92 |         self.transform = nn.ReLU()
93 |         self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')
94 |         self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')
95 |         self.first_sub_task = First_Sub_task_MLP([hidden_size, hidden_size // 2, 1], last_layer_act='Sigmoid')
96 | 
97 | 
98 |         # gene co-expression GNN
99 |         self.G_coexpress = args['G_coexpress'].to(args['device'])
100 |         self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device'])
101 | 
102 |         self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
103 |         self.layers_emb_pos = torch.nn.ModuleList()
104 |         for i in range(1, self.num_layers_gene_pos + 1):
105 |             self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1))
106 | 
107 |         ### perturbation gene ontology GNN
108 |         self.G_sim = args['G_go'].to(args['device'])
109 |         self.G_sim_weight = args['G_go_weight'].to(args['device'])
110 | 
111 |         self.sim_layers = torch.nn.ModuleList()
112 |         for i in range(1, self.num_layers + 1):
113 |             self.sim_layers.append(SGConv(hidden_size, hidden_size, 1))
114 | 
115 |         # decoder shared MLP
116 |         self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear')
117 | 
118 |         # gene specific decoder
119 |         self.indv_w1 = nn.Parameter(torch.rand(self.num_genes,
120 |                                                hidden_size, 1))
121 |         self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1))
122 |         self.act = nn.ReLU()
123 |         nn.init.xavier_normal_(self.indv_w1)
124 |         nn.init.xavier_normal_(self.indv_b1)
125 | 
126 |         # Cross gene MLP
127 |         self.cross_gene_state = MLP([self.num_genes, hidden_size,
128 |                                      hidden_size])
129 |         # final gene specific decoder
130 |         self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes,
131 |                                                hidden_size+1))
132 |         self.indv_b2 = nn.Parameter(torch.rand(1,
self.num_genes)) 133 | nn.init.xavier_normal_(self.indv_w2) 134 | nn.init.xavier_normal_(self.indv_b2) 135 | 136 | # batchnorms 137 | self.bn_emb = nn.BatchNorm1d(hidden_size) 138 | self.bn_pert_base = nn.BatchNorm1d(hidden_size) 139 | self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size) 140 | 141 | # uncertainty mode 142 | if self.uncertainty: 143 | self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear') 144 | 145 | def forward(self, data): 146 | """ 147 | Forward pass of the model 148 | """ 149 | x, pert_idx = data.x, data.pert_idx 150 | if self.no_perturb: 151 | out = x.reshape(-1,1) 152 | out = torch.split(torch.flatten(out), self.num_genes) 153 | return torch.stack(out) 154 | else: 155 | num_graphs = len(data.batch.unique()) 156 | 157 | ## get base gene embeddings 158 | emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) 159 | emb = self.bn_emb(emb) 160 | base_emb = self.emb_trans(emb) 161 | 162 | pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) 163 | for idx, layer in enumerate(self.layers_emb_pos): 164 | pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight) 165 | if idx < len(self.layers_emb_pos) - 1: 166 | pos_emb = pos_emb.relu() 167 | 168 | base_emb = base_emb + 0.2 * pos_emb 169 | base_emb = self.emb_trans_v2(base_emb) 170 | 171 | ## get perturbation index and embeddings 172 | 173 | pert_index = [] 174 | for idx, i in enumerate(pert_idx): 175 | for j in i: 176 | if j != -1: 177 | pert_index.append([idx, j]) 178 | pert_index = torch.tensor(pert_index).T 179 | 180 | pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device'])) 181 | 182 | ## augment global perturbation embedding with GNN 183 | for idx, layer in enumerate(self.sim_layers): 184 | pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight) 185 | if idx < self.num_layers - 1: 186 | pert_global_emb = pert_global_emb.relu() 187 | 188 | ## add global perturbation embedding to each gene in each cell in the batch 189 | base_emb = base_emb.reshape(num_graphs, self.num_genes, -1) 190 | 191 | first_sub_task_out = None 192 | if pert_index.shape[0] != 0: 193 | ### in case all samples in the batch are controls, then there is no indexing for pert_index. 
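                # pert_track maps each sample index in the batch to the sum of the
                # GNN-refined embeddings of its perturbed genes (a two-gene combo
                # therefore contributes the sum of both gene embeddings); the fused
                # result from the pert_fuse MLP is then added to every gene's base
                # embedding of that sample.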
194 |                 pert_track = {}
195 |                 for i, j in enumerate(pert_index[0]):
196 |                     if j.item() in pert_track:
197 |                         pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]]
198 |                     else:
199 |                         pert_track[j.item()] = pert_global_emb[pert_index[1][i]]
200 | 
201 |                 if len(list(pert_track.values())) > 0:
202 |                     if len(list(pert_track.values())) == 1:
203 |                         # workaround: a single perturbed sample cannot pass through the BatchNorm layers inside pert_fuse, so duplicate it
204 |                         emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2))
205 |                     else:
206 |                         emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))
207 | 
208 |                     for idx, j in enumerate(pert_track.keys()):
209 |                         base_emb[j] = base_emb[j] + emb_total[idx]
210 |                     # print(emb_total.size()) || output shape: [batch-num_control, hidden_dim: 64(default)]
211 | 
212 |             # print(base_emb.size()) || output shape: [batch, gene_num, hidden_dim]
213 |             base_emb = base_emb.reshape(num_graphs * self.num_genes, -1)
214 |             # print(base_emb.size()) || output shape: [batch * gene_num, hidden_dim]
215 |             base_emb = self.bn_pert_base(base_emb)
216 | 
217 |             first_sub_task_out = base_emb.reshape(num_graphs, self.num_genes, -1)
218 |             first_sub_task_out = torch.nn.Sigmoid()(self.first_sub_task(first_sub_task_out))
219 | 
220 |             ## apply the first MLP
221 |             base_emb = self.transform(base_emb)
222 |             out = self.recovery_w(base_emb)
223 |             out = out.reshape(num_graphs, self.num_genes, -1)
224 |             out = out.unsqueeze(-1) * self.indv_w1
225 |             w = torch.sum(out, axis = 2)
226 |             out = w + self.indv_b1  # print(out.size()) || output shape: [batch, gene_exp_dim, 1]
227 | 
228 |             ## second sub-task output
229 |             second_sub_task_out = torch.nn.Sigmoid()(out)
230 | 
231 |             # Cross gene
232 |             cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2))
233 |             cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes)
234 | 
235 |             cross_gene_embed = cross_gene_embed.reshape([num_graphs, self.num_genes, -1])
236 |             cross_gene_out = torch.cat([out, cross_gene_embed], 2)
237 | 
238 |             # original code
239 |             # cross_gene_out = cross_gene_out * self.indv_w2
240 |             # cross_gene_out = torch.sum(cross_gene_out, axis=2)
241 |             # out = cross_gene_out + self.indv_b2
242 |             # out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1)
243 |             # out = torch.split(torch.flatten(out), self.num_genes)
244 | 
245 |             # Modification code
246 |             cross_gene_out = cross_gene_out * self.indv_w2
247 |             cross_gene_out = torch.sum(cross_gene_out, axis=2)
248 |             # out = torch.exp(cross_gene_out + self.indv_b2)
249 |             # out = out.reshape(num_graphs * self.num_genes, -1)
250 |             out = cross_gene_out + self.indv_b2
251 |             # out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1)
252 | 
253 |             # fold_change_version
254 |             # out = torch.exp(out)
255 |             # our_fc = out.reshape(num_graphs * self.num_genes, -1)
256 |             # our_fc = torch.split(torch.flatten(our_fc), self.num_genes)
257 |             # our_fc = torch.stack(our_fc)
258 | 
259 | 
260 |             out = out.reshape(num_graphs * self.num_genes, -1) * x.reshape(-1,1)
261 |             out = torch.split(torch.flatten(out), self.num_genes)
262 |             third_sub_task_out = torch.stack(out)
263 | 
264 |             # print(third_sub_task_out.size())
265 |             ## uncertainty head
266 |             # if self.uncertainty:
267 |             #     out_logvar = self.uncertainty_w(base_emb)
268 |             #     out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes)
269 |             #     return torch.stack(out), torch.stack(out_logvar)
270 | 
271 |             # original code
272 |             # return torch.stack(out)
273 | 
274 |             # print(first_sub_task_out.size(),
second_sub_task_out.size(), third_sub_task_out.size())|| [batch_size, num_gene, 1], [batch_size, num_gene, 1], [batch_size, num_gene] 275 | return first_sub_task_out, second_sub_task_out, third_sub_task_out 276 | 277 | -------------------------------------------------------------------------------- /Supp_code/Modified_GEARS/pertdata.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data 2 | import torch 3 | import numpy as np 4 | import pickle 5 | from torch_geometric.data import DataLoader 6 | import os 7 | import scanpy as sc 8 | from tqdm import tqdm 9 | 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | sc.settings.verbosity = 0 13 | 14 | from .data_utils import get_DE_genes, get_dropout_non_zero_genes, DataSplitter 15 | from .utils import print_sys, zip_data_download_wrapper, dataverse_download,\ 16 | filter_pert_in_go, get_genes_from_perts, tar_data_download_wrapper 17 | 18 | class PertData: 19 | """ 20 | Class for loading and processing perturbation data 21 | 22 | Attributes 23 | ---------- 24 | data_path: str 25 | Path to save/load data 26 | gene_set_path: str 27 | Path to gene set to use for perturbation graph 28 | default_pert_graph: bool 29 | Whether to use default perturbation graph or not 30 | dataset_name: str 31 | Name of dataset 32 | dataset_path: str 33 | Path to dataset 34 | adata: AnnData 35 | AnnData object containing dataset 36 | dataset_processed: bool 37 | Whether dataset has been processed or not 38 | ctrl_adata: AnnData 39 | AnnData object containing control samples 40 | gene_names: list 41 | List of gene names 42 | node_map: dict 43 | Dictionary mapping gene names to indices 44 | split: str 45 | Split type 46 | seed: int 47 | Seed for splitting 48 | subgroup: str 49 | Subgroup for splitting 50 | train_gene_set_size: int 51 | Number of genes to use for training 52 | 53 | """ 54 | 55 | def __init__(self, data_path, 56 | gene_set_path=None, 57 | default_pert_graph=True): 58 | """ 59 | Parameters 60 | ---------- 61 | 62 | data_path: str 63 | Path to save/load data 64 | gene_set_path: str 65 | Path to gene set to use for perturbation graph 66 | default_pert_graph: bool 67 | Whether to use default perturbation graph or not 68 | 69 | """ 70 | 71 | 72 | # Dataset/Dataloader attributes 73 | self.data_path = data_path 74 | self.default_pert_graph = default_pert_graph 75 | self.gene_set_path = gene_set_path 76 | self.dataset_name = None 77 | self.dataset_path = None 78 | self.adata = None 79 | self.dataset_processed = None 80 | self.ctrl_adata = None 81 | self.gene_names = [] 82 | self.node_map = {} 83 | 84 | # Split attributes 85 | self.split = None 86 | self.seed = None 87 | self.subgroup = None 88 | self.train_gene_set_size = None 89 | 90 | if not os.path.exists(self.data_path): 91 | os.mkdir(self.data_path) 92 | server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417' 93 | dataverse_download(server_path, 94 | os.path.join(self.data_path, 'gene2go_all.pkl')) 95 | with open(os.path.join(self.data_path, 'gene2go_all.pkl'), 'rb') as f: 96 | self.gene2go = pickle.load(f) 97 | 98 | def set_pert_genes(self): 99 | """ 100 | Set the list of genes that can be perturbed and are to be included in 101 | perturbation graph 102 | """ 103 | 104 | if self.gene_set_path is not None: 105 | # If gene set specified for perturbation graph, use that 106 | path_ = self.gene_set_path 107 | self.default_pert_graph = False 108 | with open(path_, 'rb') as f: 109 | essential_genes = pickle.load(f) 
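            # For reference, a compatible gene_set_path file is simply a pickled
            # list of gene symbols; a minimal sketch (with a hypothetical path
            # 'my_gene_set.pkl'):
            #     import pickle
            #     with open('my_gene_set.pkl', 'wb') as f:
            #         pickle.dump(['GENE1', 'GENE2', 'GENE3'], f)
            # Only genes that also appear in the gene2go mapping end up in
            # self.pert_names below.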
110 | 
111 |         elif self.default_pert_graph is False:
112 |             # Use a smaller perturbation graph
113 |             all_pert_genes = get_genes_from_perts(self.adata.obs['condition'])
114 |             essential_genes = list(self.adata.var['gene_name'].values)
115 |             essential_genes += all_pert_genes
116 | 
117 |         else:
118 |             # Otherwise, use a large set of genes to create perturbation graph
119 |             server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934320'
120 |             path_ = os.path.join(self.data_path,
121 |                                  'essential_all_data_pert_genes.pkl')
122 |             dataverse_download(server_path, path_)
123 |             with open(path_, 'rb') as f:
124 |                 essential_genes = pickle.load(f)
125 | 
126 |         gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go}
127 | 
128 |         self.pert_names = np.unique(list(gene2go.keys()))
129 |         self.node_map_pert = {x: it for it, x in enumerate(self.pert_names)}
130 | 
131 |     def load(self, data_name = None, data_path = None):
132 |         """
133 |         Load an existing dataloader.
134 |         Use data_name to load one of the bundled datasets ('norman', 'adamson', 'dixit',
135 |         'replogle_k562_essential', 'replogle_rpe1_essential'); for other datasets use data_path.
136 | 
137 |         Parameters
138 |         ----------
139 |         data_name: str
140 |             Name of dataset
141 |         data_path: str
142 |             Path to dataset
143 | 
144 |         Returns
145 |         -------
146 |         None
147 | 
148 |         """
149 | 
150 |         if data_name in ['norman', 'adamson', 'dixit',
151 |                          'replogle_k562_essential',
152 |                          'replogle_rpe1_essential']:
153 |             ## load from harvard dataverse
154 |             if data_name == 'norman':
155 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154020'
156 |             elif data_name == 'adamson':
157 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154417'
158 |             elif data_name == 'dixit':
159 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154416'
160 |             elif data_name == 'replogle_k562_essential':
161 |                 ## Note: This is not the complete dataset and has been filtered
162 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/7458695'
163 |             elif data_name == 'replogle_rpe1_essential':
164 |                 ## Note: This is not the complete dataset and has been filtered
165 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/7458694'
166 |             data_path = os.path.join(self.data_path, data_name)
167 |             zip_data_download_wrapper(url, data_path, self.data_path)
168 |             self.dataset_name = data_path.split('/')[-1]
169 |             self.dataset_path = data_path
170 |             adata_path = os.path.join(data_path, 'perturb_processed.h5ad')
171 |             self.adata = sc.read_h5ad(adata_path)
172 | 
173 |         elif os.path.exists(data_path):
174 |             adata_path = os.path.join(data_path, 'perturb_processed.h5ad')
175 |             self.adata = sc.read_h5ad(adata_path)
176 |             self.dataset_name = data_path.split('/')[-1]
177 |             self.dataset_path = data_path
178 |         else:
179 |             raise ValueError("data_name must be one of norman, adamson, dixit, "
180 |                              "replogle_k562_essential or replogle_rpe1_essential; "
181 |                              "otherwise pass a data_path containing perturb_processed.h5ad")
182 | 
183 |         self.set_pert_genes()
184 |         print_sys('These perturbations are not in the GO graph and their '
185 |                   'effects can thus not be predicted')
186 |         not_in_go_pert = np.array(self.adata.obs[
187 |                                   self.adata.obs.condition.apply(
188 |                                   lambda x: not filter_pert_in_go(x,
189 |                                                                   self.pert_names))].condition.unique())
190 |         print_sys(not_in_go_pert)
191 | 
192 |         filter_go = self.adata.obs[self.adata.obs.condition.apply(
193 |             lambda x: filter_pert_in_go(x, self.pert_names))]
194 |         self.adata = self.adata[filter_go.index.values, :]
195 |         pyg_path = os.path.join(data_path, 'data_pyg')
196 |         if not os.path.exists(pyg_path):
197 |             os.mkdir(pyg_path)
198 |         dataset_fname =
os.path.join(pyg_path, 'cell_graphs.pkl') 199 | 200 | if os.path.isfile(dataset_fname): 201 | print_sys("Local copy of pyg dataset is detected. Loading...") 202 | self.dataset_processed = pickle.load(open(dataset_fname, "rb")) 203 | print_sys("Done!") 204 | else: 205 | self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] 206 | self.gene_names = self.adata.var.gene_name 207 | 208 | 209 | print_sys("Creating pyg object for each cell in the data...") 210 | self.create_dataset_file() 211 | print_sys("Saving new dataset pyg object at " + dataset_fname) 212 | pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) 213 | print_sys("Done!") 214 | 215 | def new_data_process(self, dataset_name, 216 | adata = None, 217 | skip_calc_de = False, sub_task_data = None): 218 | """ 219 | Process new dataset 220 | 221 | Parameters 222 | ---------- 223 | dataset_name: str 224 | Name of dataset 225 | adata: AnnData object 226 | AnnData object containing gene expression data 227 | skip_calc_de: bool 228 | If True, skip differential expression calculation 229 | 230 | Returns 231 | ------- 232 | None 233 | 234 | """ 235 | 236 | if 'condition' not in adata.obs.columns.values: 237 | raise ValueError("Please specify condition") 238 | if 'gene_name' not in adata.var.columns.values: 239 | raise ValueError("Please specify gene name") 240 | if 'cell_type' not in adata.obs.columns.values: 241 | raise ValueError("Please specify cell type") 242 | 243 | dataset_name = dataset_name.lower() 244 | self.dataset_name = dataset_name 245 | save_data_folder = os.path.join(self.data_path, dataset_name) 246 | 247 | if not os.path.exists(save_data_folder): 248 | os.mkdir(save_data_folder) 249 | self.dataset_path = save_data_folder 250 | self.adata = get_DE_genes(adata, skip_calc_de) 251 | if not skip_calc_de: 252 | self.adata = get_dropout_non_zero_genes(self.adata) 253 | self.adata.write_h5ad(os.path.join(save_data_folder, 'perturb_processed.h5ad')) 254 | 255 | self.set_pert_genes() 256 | self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] 257 | self.gene_names = self.adata.var.gene_name 258 | pyg_path = os.path.join(save_data_folder, 'data_pyg') 259 | if not os.path.exists(pyg_path): 260 | os.mkdir(pyg_path) 261 | dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl') 262 | print_sys("Creating pyg object for each cell in the data...") 263 | self.create_dataset_file(sub_task_data) 264 | print_sys("Saving new dataset pyg object at " + dataset_fname) 265 | pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) 266 | print_sys("Done!") 267 | 268 | def prepare_split(self, split = 'simulation', 269 | seed = 1, 270 | train_gene_set_size = 0.75, 271 | combo_seen2_train_frac = 0.75, 272 | combo_single_split_test_set_fraction = 0.1, 273 | test_perts = None, 274 | only_test_set_perts = False, 275 | test_pert_genes = None, 276 | split_dict_path=None): 277 | 278 | """ 279 | Prepare splits for training and testing 280 | 281 | Parameters 282 | ---------- 283 | split: str 284 | Type of split to use. 
Currently, we support 'simulation', 285 | 'simulation_single', 'combo_seen0', 'combo_seen1', 'combo_seen2', 286 | 'single', 'no_test', 'no_split', 'custom' 287 | seed: int 288 | Random seed 289 | train_gene_set_size: float 290 | Fraction of genes to use for training 291 | combo_seen2_train_frac: float 292 | Fraction of combo seen2 perturbations to use for training 293 | combo_single_split_test_set_fraction: float 294 | Fraction of combo single perturbations to use for testing 295 | test_perts: list 296 | List of perturbations to use for testing 297 | only_test_set_perts: bool 298 | If True, only use test set perturbations for testing 299 | test_pert_genes: list 300 | List of genes to use for testing 301 | split_dict_path: str 302 | Path to dictionary used for custom split. Sample format: 303 | {'train': [X, Y], 'val': [P, Q], 'test': [Z]} 304 | 305 | Returns 306 | ------- 307 | None 308 | 309 | """ 310 | available_splits = ['simulation', 'simulation_single', 'combo_seen0', 311 | 'combo_seen1', 'combo_seen2', 'single', 'no_test', 312 | 'no_split', 'custom'] 313 | if split not in available_splits: 314 | raise ValueError('currently, we only support ' + ','.join(available_splits)) 315 | self.split = split 316 | self.seed = seed 317 | self.subgroup = None 318 | self.train_gene_set_size = train_gene_set_size 319 | 320 | split_folder = os.path.join(self.dataset_path, 'splits') 321 | if not os.path.exists(split_folder): 322 | os.mkdir(split_folder) 323 | split_file = self.dataset_name + '_' + split + '_' + str(seed) + '_' \ 324 | + str(train_gene_set_size) + '.pkl' 325 | split_path = os.path.join(split_folder, split_file) 326 | 327 | if test_perts: 328 | split_path = split_path[:-4] + '_' + test_perts + '.pkl' 329 | 330 | if os.path.exists(split_path): 331 | print_sys("Local copy of split is detected. 
Loading...") 332 | set2conditions = pickle.load(open(split_path, "rb")) 333 | if split == 'simulation': 334 | subgroup_path = split_path[:-4] + '_subgroup.pkl' 335 | subgroup = pickle.load(open(subgroup_path, "rb")) 336 | self.subgroup = subgroup 337 | else: 338 | print_sys("Creating new splits....") 339 | if test_perts: 340 | test_perts = test_perts.split('_') 341 | 342 | if split in ['simulation', 'simulation_single']: 343 | # simulation split 344 | DS = DataSplitter(self.adata, split_type=split) 345 | 346 | adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, 347 | combo_seen2_train_frac = combo_seen2_train_frac, 348 | seed=seed, 349 | test_perts = test_perts, 350 | only_test_set_perts = only_test_set_perts 351 | ) 352 | subgroup_path = split_path[:-4] + '_subgroup.pkl' 353 | pickle.dump(subgroup, open(subgroup_path, "wb")) 354 | self.subgroup = subgroup 355 | 356 | elif split[:5] == 'combo': 357 | # combo perturbation 358 | split_type = 'combo' 359 | seen = int(split[-1]) 360 | 361 | if test_pert_genes: 362 | test_pert_genes = test_pert_genes.split('_') 363 | 364 | DS = DataSplitter(self.adata, split_type=split_type, seen=int(seen)) 365 | adata = DS.split_data(test_size=combo_single_split_test_set_fraction, 366 | test_perts=test_perts, 367 | test_pert_genes=test_pert_genes, 368 | seed=seed) 369 | 370 | elif split == 'single': 371 | # single perturbation 372 | DS = DataSplitter(self.adata, split_type=split) 373 | adata = DS.split_data(test_size=combo_single_split_test_set_fraction, 374 | seed=seed) 375 | 376 | elif split == 'no_test': 377 | # no test set 378 | DS = DataSplitter(self.adata, split_type=split) 379 | adata = DS.split_data(seed=seed) 380 | 381 | elif split == 'no_split': 382 | # no split 383 | adata = self.adata 384 | adata.obs['split'] = 'test' 385 | 386 | elif split == 'custom': 387 | adata = self.adata 388 | try: 389 | with open(split_dict_path, 'rb') as f: 390 | split_dict = pickle.load(f) 391 | except: 392 | raise ValueError('Please set split_dict_path for custom split') 393 | adata.obs['split'] = adata.obs['condition'].map(split_dict) 394 | 395 | 396 | 397 | set2conditions = dict(adata.obs.groupby('split').agg({'condition': 398 | lambda x: x}).condition) 399 | set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} 400 | pickle.dump(set2conditions, open(split_path, "wb")) 401 | print_sys("Saving new splits at " + split_path) 402 | 403 | self.set2conditions = set2conditions 404 | 405 | if split == 'simulation': 406 | print_sys('Simulation split test composition:') 407 | for i,j in subgroup['test_subgroup'].items(): 408 | print_sys(i + ':' + str(len(j))) 409 | print_sys("Done!") 410 | 411 | def get_dataloader(self, batch_size, test_batch_size = None): 412 | """ 413 | Get dataloaders for training and testing 414 | 415 | Parameters 416 | ---------- 417 | batch_size: int 418 | Batch size for training 419 | test_batch_size: int 420 | Batch size for testing 421 | 422 | Returns 423 | ------- 424 | dict 425 | Dictionary of dataloaders 426 | 427 | """ 428 | if test_batch_size is None: 429 | test_batch_size = batch_size 430 | 431 | self.node_map = {x: it for it, x in enumerate(self.adata.var.gene_name)} 432 | self.gene_names = self.adata.var.gene_name 433 | 434 | # Create cell graphs 435 | cell_graphs = {} 436 | if self.split == 'no_split': 437 | i = 'test' 438 | cell_graphs[i] = [] 439 | for p in self.set2conditions[i]: 440 | if p != 'ctrl': 441 | cell_graphs[i].extend(self.dataset_processed[p]) 442 | 443 | print_sys("Creating 
dataloaders....") 444 | # Set up dataloaders 445 | test_loader = DataLoader(cell_graphs['test'], 446 | batch_size=batch_size, shuffle=False) 447 | 448 | print_sys("Dataloaders created...") 449 | return {'test_loader': test_loader} 450 | else: 451 | if self.split =='no_test': 452 | splits = ['train','val'] 453 | else: 454 | splits = ['train','val','test'] 455 | for i in splits: 456 | cell_graphs[i] = [] 457 | for p in self.set2conditions[i]: 458 | cell_graphs[i].extend(self.dataset_processed[p]) 459 | 460 | print_sys("Creating dataloaders....") 461 | 462 | # Set up dataloaders 463 | train_loader = DataLoader(cell_graphs['train'], 464 | batch_size=batch_size, shuffle=True, drop_last = True) 465 | val_loader = DataLoader(cell_graphs['val'], 466 | batch_size=batch_size, shuffle=True) 467 | 468 | if self.split !='no_test': 469 | test_loader = DataLoader(cell_graphs['test'], 470 | batch_size=batch_size, shuffle=False) 471 | self.dataloader = {'train_loader': train_loader, 472 | 'val_loader': val_loader, 473 | 'test_loader': test_loader} 474 | 475 | else: 476 | self.dataloader = {'train_loader': train_loader, 477 | 'val_loader': val_loader} 478 | print_sys("Done!") 479 | 480 | def get_pert_idx(self, pert_category): 481 | """ 482 | Get perturbation index for a given perturbation category 483 | 484 | Parameters 485 | ---------- 486 | pert_category: str 487 | Perturbation category 488 | 489 | Returns 490 | ------- 491 | list 492 | List of perturbation indices 493 | 494 | """ 495 | try: 496 | pert_idx = [np.where(p == self.pert_names)[0][0] 497 | for p in pert_category.split('+') 498 | if p != 'ctrl'] 499 | except: 500 | print(pert_category) 501 | pert_idx = None 502 | 503 | return pert_idx 504 | 505 | def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None, y_1 = None, y_2 = None, y_3 = None): 506 | """ 507 | Create a cell graph from a given cell 508 | 509 | Parameters 510 | ---------- 511 | X: np.ndarray 512 | Gene expression matrix 513 | y: np.ndarray 514 | Label vector 515 | de_idx: np.ndarray 516 | DE gene indices 517 | pert: str 518 | Perturbation category 519 | pert_idx: list 520 | List of perturbation indices 521 | 522 | Returns 523 | ------- 524 | torch_geometric.data.Data 525 | Cell graph to be used in dataloader 526 | 527 | """ 528 | 529 | feature_mat = torch.Tensor(X).T 530 | if pert_idx is None: 531 | pert_idx = [-1] 532 | return Data(x=feature_mat, pert_idx=pert_idx, 533 | y=torch.Tensor(y), de_idx=de_idx, pert=pert, y_1 = torch.Tensor(y_1), y_2 = torch.Tensor(y_2), y_3 = torch.Tensor(y_3)) 534 | 535 | def create_cell_graph_dataset(self, split_adata, pert_category, 536 | num_samples=1, sub_tasks_data = None): 537 | """ 538 | Combine cell graphs to create a dataset of cell graphs 539 | 540 | Parameters 541 | ---------- 542 | split_adata: anndata.AnnData 543 | Annotated data matrix 544 | pert_category: str 545 | Perturbation category 546 | num_samples: int 547 | Number of samples to create per perturbed cell (i.e. 
number of
548 |             control cells to map to each perturbed cell)
549 | 
550 |         Returns
551 |         -------
552 |         list
553 |             List of cell graphs
554 | 
555 |         """
556 |         # Modification: load the sub-task label file (an .h5ad with 'level1'-'level3' layers)
557 |         if sub_tasks_data is not None:
558 |             sub_tasks_data = sc.read_h5ad(sub_tasks_data)
559 |             gene_idx_dict = dict(zip(list(sub_tasks_data.obs.index), list(range(len(sub_tasks_data.obs.index)))))
560 | 
561 |         num_de_genes = 20
562 |         adata_ = split_adata[split_adata.obs['condition'] == pert_category]
563 |         if 'rank_genes_groups_cov_all' in adata_.uns:
564 |             de_genes = adata_.uns['rank_genes_groups_cov_all']
565 |             de = True
566 |         else:
567 |             de = False
568 |             num_de_genes = 1
569 |         Xs = []
570 |         ys = []
571 | 
572 |         # Modification: per-cell containers for the three sub-task labels
573 |         y1 = []
574 |         y2 = []
575 |         y3 = []
576 |         # When considering a non-control perturbation
577 |         if pert_category != 'ctrl':
578 |             # Get the indices of applied perturbation
579 |             pert_idx = self.get_pert_idx(pert_category)
580 | 
581 |             y_sub_tasks = sub_tasks_data[gene_idx_dict[pert_category]]  # Modification: sub-task labels for this condition
582 | 
583 |             # Store list of genes that are most differentially expressed for testing
584 |             pert_de_category = adata_.obs['condition_name'][0]
585 |             if de:
586 |                 de_idx = np.where(adata_.var_names.isin(
587 |                     np.array(de_genes[pert_de_category][:num_de_genes])))[0]
588 |             else:
589 |                 de_idx = [-1] * num_de_genes
590 |             for cell_z in adata_.X:
591 |                 # Use samples from control as basal expression
592 |                 ctrl_samples = self.ctrl_adata[np.random.randint(0,
593 |                                                len(self.ctrl_adata), num_samples), :]
594 |                 for c in ctrl_samples.X:
595 |                     Xs.append(c)
596 |                     ys.append(cell_z)
597 |                     y1.append(y_sub_tasks.layers['level1'].astype('float32'))
598 |                     y2.append(y_sub_tasks.layers['level2'].astype('float32'))
599 |                     y3.append(y_sub_tasks.layers['level3'].astype('float32'))
600 |         # When considering a control perturbation
601 |         else:
602 |             y_sub_tasks = sub_tasks_data[gene_idx_dict[pert_category]]
603 |             pert_idx = None
604 |             de_idx = [-1] * num_de_genes
605 |             for cell_z in adata_.X:
606 |                 Xs.append(cell_z)
607 |                 ys.append(cell_z)
608 |                 y1.append(y_sub_tasks.layers['level1'].astype('float32'))
609 |                 y2.append(y_sub_tasks.layers['level2'].astype('float32'))
610 |                 y3.append(y_sub_tasks.layers['level3'].astype('float32'))
611 |         # Create cell graphs
612 |         cell_graphs = []
613 |         for X, y, y_1, y_2, y_3 in zip(Xs, ys, y1, y2, y3):
614 |             cell_graphs.append(self.create_cell_graph(X.toarray(),
615 |                                y.toarray(), de_idx, pert_category, pert_idx, y_1.toarray(), y_2.toarray(), y_3.toarray()))
616 | 
617 |         return cell_graphs
618 | 
619 |     def create_dataset_file(self, sub_tasks_data = None):
620 |         """
621 |         Create dataset file for each perturbation condition
622 |         """
623 |         print_sys("Creating dataset file...")
624 |         self.dataset_processed = {}
625 |         for p in tqdm(self.adata.obs['condition'].unique()):
626 |             self.dataset_processed[p] = self.create_cell_graph_dataset(self.adata, p, sub_tasks_data = sub_tasks_data)
627 |         print_sys("Done!")
628 | 
--------------------------------------------------------------------------------
/Supp_code/Modified_GEARS/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import networkx as nx
5 | from tqdm import tqdm
6 | import pickle
7 | import sys, os
8 | import requests
9 | from torch_geometric.data import Data
10 | from zipfile import ZipFile
11 | import tarfile
12 | from sklearn.linear_model import TheilSenRegressor
13 | from dcor import distance_correlation
14 | from
multiprocessing import Pool
15 | 
16 | def parse_single_pert(i):
17 |     a = i.split('+')[0]
18 |     b = i.split('+')[1]
19 |     if a == 'ctrl':
20 |         pert = b
21 |     else:
22 |         pert = a
23 |     return pert
24 | 
25 | def parse_combo_pert(i):
26 |     return i.split('+')[0], i.split('+')[1]
27 | 
28 | def combine_res(res_1, res_2):
29 |     res_out = {}
30 |     for key in res_1:
31 |         res_out[key] = np.concatenate([res_1[key], res_2[key]])
32 |     return res_out
33 | 
34 | def parse_any_pert(p):
35 |     if ('ctrl' in p) and (p != 'ctrl'):
36 |         return [parse_single_pert(p)]
37 |     elif 'ctrl' not in p:
38 |         out = parse_combo_pert(p)
39 |         return [out[0], out[1]]
40 | 
41 | def np_pearson_cor(x, y):
42 |     xv = x - x.mean(axis=0)
43 |     yv = y - y.mean(axis=0)
44 |     xvss = (xv * xv).sum(axis=0)
45 |     yvss = (yv * yv).sum(axis=0)
46 |     result = np.matmul(xv.transpose(), yv) / np.sqrt(np.outer(xvss, yvss))
47 |     # bound the values to -1 to 1 in the event of precision issues
48 |     return np.maximum(np.minimum(result, 1.0), -1.0)
49 | 
50 | 
51 | def dataverse_download(url, save_path):
52 |     """
53 |     Dataverse download helper with progress bar
54 | 
55 |     Args:
56 |         url (str): the url of the dataset
57 |         save_path (str): the path to save the dataset
58 |     """
59 | 
60 |     if os.path.exists(save_path):
61 |         print_sys('Found local copy...')
62 |     else:
63 |         print_sys("Downloading...")
64 |         response = requests.get(url, stream=True)
65 |         total_size_in_bytes = int(response.headers.get('content-length', 0))
66 |         block_size = 1024
67 |         progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
68 |         with open(save_path, 'wb') as file:
69 |             for data in response.iter_content(block_size):
70 |                 progress_bar.update(len(data))
71 |                 file.write(data)
72 |         progress_bar.close()
73 | 
74 | 
75 | def zip_data_download_wrapper(url, save_path, data_path):
76 |     """
77 |     Wrapper for zip file download
78 | 
79 |     Args:
80 |         url (str): the url of the dataset
81 |         save_path (str): the path where the file is downloaded
82 |         data_path (str): the path to save the extracted dataset
83 |     """
84 | 
85 |     if os.path.exists(save_path):
86 |         print_sys('Found local copy...')
87 |     else:
88 |         dataverse_download(url, save_path + '.zip')
89 |         print_sys('Extracting zip file...')
90 |         with ZipFile((save_path + '.zip'), 'r') as zip:
91 |             zip.extractall(path = data_path)
92 |         print_sys("Done!")
93 | 
94 | def tar_data_download_wrapper(url, save_path, data_path):
95 |     """
96 |     Wrapper for tar file download
97 | 
98 |     Args:
99 |         url (str): the url of the dataset
100 |         save_path (str): the path where the file is downloaded
101 |         data_path (str): the path to save the extracted dataset
102 | 
103 |     """
104 | 
105 |     if os.path.exists(save_path):
106 |         print_sys('Found local copy...')
107 |     else:
108 |         dataverse_download(url, save_path + '.tar.gz')
109 |         print_sys('Extracting tar file...')
110 |         with tarfile.open(save_path + '.tar.gz') as tar:
111 |             tar.extractall(path = data_path)
112 |         print_sys("Done!")
113 | 
114 | def get_go_auto(gene_list, data_path, data_name):
115 |     """
116 |     Get gene ontology data
117 | 
118 |     Args:
119 |         gene_list (list): list of gene names
120 |         data_path (str): the path to save the extracted dataset
121 |         data_name (str): the name of the dataset
122 | 
123 |     Returns:
124 |         df_edge_list (pd.DataFrame): gene ontology edge list
125 |     """
126 |     go_path = os.path.join(data_path, data_name, 'go.csv')
127 | 
128 |     if os.path.exists(go_path):
129 |         return pd.read_csv(go_path)
130 |     else:
131 |         ## download gene2go.pkl
132 |         if not os.path.exists(os.path.join(data_path, 'gene2go.pkl')):
133
| server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417' 134 | dataverse_download(server_path, os.path.join(data_path, 'gene2go.pkl')) 135 | with open(os.path.join(data_path, 'gene2go.pkl'), 'rb') as f: 136 | gene2go = pickle.load(f) 137 | 138 | gene2go = {i: list(gene2go[i]) for i in gene_list if i in gene2go} 139 | edge_list = [] 140 | for g1 in tqdm(gene2go.keys()): 141 | for g2 in gene2go.keys(): 142 | edge_list.append((g1, g2, len(np.intersect1d(gene2go[g1], 143 | gene2go[g2]))/len(np.union1d(gene2go[g1], gene2go[g2])))) 144 | 145 | edge_list_filter = [i for i in edge_list if i[2] > 0] 146 | further_filter = [i for i in edge_list if i[2] > 0.1] 147 | df_edge_list = pd.DataFrame(further_filter).rename(columns = {0: 'gene1', 148 | 1: 'gene2', 149 | 2: 'score'}) 150 | 151 | df_edge_list = df_edge_list.rename(columns = {'gene1': 'source', 152 | 'gene2': 'target', 153 | 'score': 'importance'}) 154 | df_edge_list.to_csv(go_path, index = False) 155 | return df_edge_list 156 | 157 | class GeneSimNetwork(): 158 | """ 159 | GeneSimNetwork class 160 | 161 | Args: 162 | edge_list (pd.DataFrame): edge list of the network 163 | gene_list (list): list of gene names 164 | node_map (dict): dictionary mapping gene names to node indices 165 | 166 | Attributes: 167 | edge_index (torch.Tensor): edge index of the network 168 | edge_weight (torch.Tensor): edge weight of the network 169 | G (nx.DiGraph): networkx graph object 170 | """ 171 | def __init__(self, edge_list, gene_list, node_map): 172 | """ 173 | Initialize GeneSimNetwork class 174 | """ 175 | 176 | self.edge_list = edge_list 177 | self.G = nx.from_pandas_edgelist(self.edge_list, source='source', 178 | target='target', edge_attr=['importance'], 179 | create_using=nx.DiGraph()) 180 | self.gene_list = gene_list 181 | for n in self.gene_list: 182 | if n not in self.G.nodes(): 183 | self.G.add_node(n) 184 | 185 | edge_index_ = [(node_map[e[0]], node_map[e[1]]) for e in 186 | self.G.edges] 187 | self.edge_index = torch.tensor(edge_index_, dtype=torch.long).T 188 | #self.edge_weight = torch.Tensor(self.edge_list['importance'].values) 189 | 190 | edge_attr = nx.get_edge_attributes(self.G, 'importance') 191 | importance = np.array([edge_attr[e] for e in self.G.edges]) 192 | self.edge_weight = torch.Tensor(importance) 193 | 194 | def get_GO_edge_list(args): 195 | """ 196 | Get gene ontology edge list 197 | """ 198 | g1, gene2go = args 199 | edge_list = [] 200 | for g2 in gene2go.keys(): 201 | score = len(gene2go[g1].intersection(gene2go[g2])) / len( 202 | gene2go[g1].union(gene2go[g2])) 203 | if score > 0.1: 204 | edge_list.append((g1, g2, score)) 205 | return edge_list 206 | 207 | def make_GO(data_path, pert_list, data_name, num_workers=25, save=True): 208 | """ 209 | Creates Gene Ontology graph from a custom set of genes 210 | """ 211 | 212 | fname = './data/go_essential_' + data_name + '.csv' 213 | if os.path.exists(fname): 214 | return pd.read_csv(fname) 215 | 216 | with open(os.path.join(data_path, 'gene2go_all.pkl'), 'rb') as f: 217 | gene2go = pickle.load(f) 218 | gene2go = {i: gene2go[i] for i in pert_list} 219 | 220 | print('Creating custom GO graph, this can take a few minutes') 221 | with Pool(num_workers) as p: 222 | all_edge_list = list( 223 | tqdm(p.imap(get_GO_edge_list, ((g, gene2go) for g in gene2go.keys())), 224 | total=len(gene2go.keys()))) 225 | edge_list = [] 226 | for i in all_edge_list: 227 | edge_list = edge_list + i 228 | 229 | df_edge_list = pd.DataFrame(edge_list).rename( 230 | columns={0: 'source', 1: 
'target', 2: 'importance'}) 231 | 232 | if save: 233 | print('Saving edge_list to file') 234 | df_edge_list.to_csv(fname, index=False) 235 | 236 | return df_edge_list 237 | 238 | def get_similarity_network(network_type, adata, threshold, k, 239 | data_path, data_name, split, seed, train_gene_set_size, 240 | set2conditions, default_pert_graph=True, pert_list=None): 241 | 242 | if network_type == 'co-express': 243 | df_out = get_coexpression_network_from_train(adata, threshold, k, 244 | data_path, data_name, split, 245 | seed, train_gene_set_size, 246 | set2conditions) 247 | elif network_type == 'go': 248 | if default_pert_graph: 249 | server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319' 250 | tar_data_download_wrapper(server_path, 251 | os.path.join(data_path, 'go_essential_all'), 252 | data_path) 253 | df_jaccard = pd.read_csv(os.path.join(data_path, 254 | 'go_essential_all/go_essential_all.csv')) 255 | 256 | else: 257 | df_jaccard = make_GO(data_path, pert_list, data_name) 258 | 259 | df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1, 260 | ['importance'])).reset_index(drop = True) 261 | 262 | return df_out 263 | 264 | def get_coexpression_network_from_train(adata, threshold, k, data_path, 265 | data_name, split, seed, train_gene_set_size, 266 | set2conditions): 267 | """ 268 | Infer co-expression network from training data 269 | 270 | Args: 271 | adata (anndata.AnnData): anndata object 272 | threshold (float): threshold for co-expression 273 | k (int): number of edges to keep 274 | data_path (str): path to data 275 | data_name (str): name of dataset 276 | split (str): split of dataset 277 | seed (int): seed for random number generator 278 | train_gene_set_size (int): size of training gene set 279 | set2conditions (dict): dictionary of perturbations to conditions 280 | """ 281 | 282 | fname = os.path.join(os.path.join(data_path, data_name), split + '_' + 283 | str(seed) + '_' + str(train_gene_set_size) + '_' + 284 | str(threshold) + '_' + str(k) + 285 | '_co_expression_network.csv') 286 | 287 | if os.path.exists(fname): 288 | return pd.read_csv(fname) 289 | else: 290 | gene_list = [f for f in adata.var.gene_name.values] 291 | idx2gene = dict(zip(range(len(gene_list)), gene_list)) 292 | X = adata.X 293 | train_perts = set2conditions['train'] 294 | X_tr = X[np.isin(adata.obs.condition, [i for i in train_perts if 'ctrl' in i])] 295 | gene_list = adata.var['gene_name'].values 296 | 297 | X_tr = X_tr.toarray() 298 | out = np_pearson_cor(X_tr, X_tr) 299 | out[np.isnan(out)] = 0 300 | out = np.abs(out) 301 | 302 | out_sort_idx = np.argsort(out)[:, -(k + 1):] 303 | out_sort_val = np.sort(out)[:, -(k + 1):] 304 | 305 | df_g = [] 306 | for i in range(out_sort_idx.shape[0]): 307 | target = idx2gene[i] 308 | for j in range(out_sort_idx.shape[1]): 309 | df_g.append((idx2gene[out_sort_idx[i, j]], target, out_sort_val[i, j])) 310 | 311 | df_g = [i for i in df_g if i[2] > threshold] 312 | df_co_expression = pd.DataFrame(df_g).rename(columns = {0: 'source', 313 | 1: 'target', 314 | 2: 'importance'}) 315 | df_co_expression.to_csv(fname, index = False) 316 | return df_co_expression 317 | 318 | def filter_pert_in_go(condition, pert_names): 319 | """ 320 | Filter perturbations in GO graph 321 | 322 | Args: 323 | condition (str): whether condition is 'ctrl' or not 324 | pert_names (list): list of perturbations 325 | """ 326 | 327 | if condition == 'ctrl': 328 | return True 329 | else: 330 | cond1 = condition.split('+')[0] 331 | cond2 = condition.split('+')[1] 332 | 
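        # The check below passes a condition only when each of its two tokens is
        # either 'ctrl' or a perturbation gene present in the GO graph, e.g.
        # 'GENE1+ctrl' and 'GENE1+GENE2' pass when GENE1 (and GENE2) are in
        # pert_names; anything else is filtered out upstream in load().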
num_ctrl = (cond1 == 'ctrl') + (cond2 == 'ctrl')
333 |         num_in_perts = (cond1 in pert_names) + (cond2 in pert_names)
334 |         if num_ctrl + num_in_perts == 2:
335 |             return True
336 |         else:
337 |             return False
338 | 
339 | def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None,
340 |                          direction_lambda = 1e-3, dict_filter = None):
341 |     """
342 |     Uncertainty loss function
343 | 
344 |     Args:
345 |         pred (torch.tensor): predicted values
346 |         logvar (torch.tensor): log variance
347 |         y (torch.tensor): true values
348 |         perts (list): list of perturbations
349 |         reg (float): regularization parameter
350 |         ctrl (torch.tensor): mean control expression
351 |         direction_lambda (float): direction loss weight hyperparameter
352 |         dict_filter (dict): maps each perturbation to the gene indices retained in the loss
353 | 
354 |     """
355 |     gamma = 2
356 |     perts = np.array(perts)
357 |     losses = torch.tensor(0.0, requires_grad=True).to(pred.device)
358 |     for p in set(perts):
359 |         if p != 'ctrl':
360 |             retain_idx = dict_filter[p]
361 |             pred_p = pred[np.where(perts==p)[0]][:, retain_idx]
362 |             y_p = y[np.where(perts==p)[0]][:, retain_idx]
363 |             logvar_p = logvar[np.where(perts==p)[0]][:, retain_idx]
364 |         else:
365 |             pred_p = pred[np.where(perts==p)[0]]
366 |             y_p = y[np.where(perts==p)[0]]
367 |             logvar_p = logvar[np.where(perts==p)[0]]
368 | 
369 |         # uncertainty based loss
370 |         losses += torch.sum((pred_p - y_p)**(2 + gamma) + reg * torch.exp(
371 |                   -logvar_p) * (pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1]
372 | 
373 |         # direction loss
374 |         if p != 'ctrl':
375 |             losses += torch.sum(direction_lambda *
376 |                       (torch.sign(y_p - ctrl[retain_idx]) -
377 |                        torch.sign(pred_p - ctrl[retain_idx]))**2)/\
378 |                       pred_p.shape[0]/pred_p.shape[1]
379 |         else:
380 |             losses += torch.sum(direction_lambda *
381 |                       (torch.sign(y_p - ctrl) -
382 |                        torch.sign(pred_p - ctrl))**2)/\
383 |                       pred_p.shape[0]/pred_p.shape[1]
384 | 
385 |     return losses/(len(set(perts)))
386 | 
387 | 
388 | def loss_fct(pred, y, perts, ctrl = None, direction_lambda = 1e-3, dict_filter = None):
389 |     """
390 |     Main MSE-style loss function, includes the direction loss
391 | 
392 |     Args:
393 |         pred (torch.tensor): predicted values
394 |         y (torch.tensor): true values
395 |         perts (list): list of perturbations
396 |         ctrl (torch.tensor): mean control expression
397 |         direction_lambda (float): direction loss weight hyperparameter
398 |         dict_filter (dict): maps each perturbation to the gene indices retained in the loss
399 | 
400 |     """
401 |     gamma = 2
402 |     mse_p = torch.nn.MSELoss()
403 |     perts = np.array(perts)
404 |     losses = torch.tensor(0.0, requires_grad=True).to(pred.device)
405 | 
406 |     for p in set(perts):
407 |         pert_idx = np.where(perts == p)[0]
408 | 
409 |         # during training, we exclude the all-zero genes from the loss calculation.
410 |         # this gives a cleaner direction loss; empirically, the performance stays the same.
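        # With gamma = 2 the residual is raised to the 4th power (the GEARS-style
        # "autofocus" loss): an error of 0.5 contributes 0.5**4 = 0.0625 while an
        # error of 2 contributes 16, so training focuses on genes with large
        # deviations. The direction term below penalizes a predicted change whose
        # sign relative to control disagrees with the observed change.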
411 | if p!= 'ctrl': 412 | retain_idx = dict_filter[p] 413 | pred_p = pred[pert_idx][:, retain_idx] 414 | y_p = y[pert_idx][:, retain_idx] 415 | else: 416 | pred_p = pred[pert_idx] 417 | y_p = y[pert_idx] 418 | losses = losses + torch.sum((pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] 419 | 420 | ## direction loss 421 | if (p!= 'ctrl'): 422 | losses = losses + torch.sum(direction_lambda * 423 | (torch.sign(y_p - ctrl[retain_idx]) - 424 | torch.sign(pred_p - ctrl[retain_idx]))**2)/\ 425 | pred_p.shape[0]/pred_p.shape[1] 426 | else: 427 | losses = losses + torch.sum(direction_lambda * (torch.sign(y_p - ctrl) - 428 | torch.sign(pred_p - ctrl))**2)/\ 429 | pred_p.shape[0]/pred_p.shape[1] 430 | return losses/(len(set(perts))) 431 | 432 | def loss_sub_task(pred, y_1, y_2, y_3): 433 | first_sub_task_out = pred[0].squeeze(-1) 434 | second_sub_task_out = pred[1].squeeze(-1) 435 | # third_sub_task_out = pred[2] 436 | 437 | level_1_output = y_1 438 | level_2_output = y_2 439 | # level_3_output = y_3 440 | 441 | # binary loss 442 | with torch.no_grad(): 443 | all_num = level_1_output.shape[0] * level_1_output.shape[1] 444 | DE_num = torch.sum(level_1_output.sum()) 445 | if DE_num == 0: 446 | DE_num = 1 447 | loss_weights = (all_num - DE_num) / DE_num 448 | loss_binary = torch.nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1) 449 | loss_1 = loss_binary(first_sub_task_out.squeeze(-1), level_1_output) 450 | 451 | 452 | with torch.no_grad(): 453 | mask = level_1_output!=0 454 | with torch.no_grad(): 455 | up_num = torch.sum(mask * level_2_output) 456 | all_num = torch.sum(mask) 457 | if up_num == 0: 458 | up_num = 1 459 | weights = mask * level_2_output * all_num / up_num 460 | weights[weights > 0] -= 1 461 | loss_binary = torch.nn.BCELoss(weight = weights + mask, reduction = 'sum') 462 | loss_2 = loss_binary(second_sub_task_out, level_2_output) / torch.sum(weights + mask) 463 | 464 | # loss_3 = torch.sum((mask * (third_sub_task_out.squeeze(-1)-third_sub_task_out) ** 2)) / torch.sum(mask) 465 | 466 | # return loss_1, loss_2, loss_3, (first_sub_task_out, second_sub_task_out, third_sub_task_out), mask 467 | return loss_1, loss_2, (first_sub_task_out, second_sub_task_out), mask 468 | 469 | def print_sys(s): 470 | """system print 471 | 472 | Args: 473 | s (str): the string to print 474 | """ 475 | print(s, flush = True, file = sys.stderr) 476 | 477 | def create_cell_graph_for_prediction(X, pert_idx, pert_gene): 478 | """ 479 | Create a perturbation specific cell graph for inference 480 | 481 | Args: 482 | X (np.array): gene expression matrix 483 | pert_idx (list): list of perturbation indices 484 | pert_gene (list): list of perturbations 485 | 486 | """ 487 | 488 | if pert_idx is None: 489 | pert_idx = [-1] 490 | return Data(x=torch.Tensor(X).T, pert_idx = pert_idx, pert=pert_gene) 491 | 492 | 493 | def create_cell_graph_dataset_for_prediction(pert_gene, ctrl_adata, gene_names, 494 | device, num_samples = 300): 495 | """ 496 | Create a perturbation specific cell graph dataset for inference 497 | 498 | Args: 499 | pert_gene (list): list of perturbations 500 | ctrl_adata (anndata): control anndata 501 | gene_names (list): list of gene names 502 | device (torch.device): device to use 503 | num_samples (int): number of samples to use for inference (default: 300) 504 | 505 | """ 506 | 507 | # Get the indices (and signs) of applied perturbation 508 | pert_idx = [np.where(p == np.array(gene_names))[0][0] for p in pert_gene] 509 | 510 | Xs = ctrl_adata[np.random.randint(0, len(ctrl_adata), 
num_samples), :].X.toarray() 511 | # Create cell graphs 512 | cell_graphs = [create_cell_graph_for_prediction(X, pert_idx, pert_gene).to(device) for X in Xs] 513 | return cell_graphs 514 | 515 | ## 516 | ##GI related utils 517 | ## 518 | 519 | def get_coeffs(singles_expr, first_expr, second_expr, double_expr): 520 | """ 521 | Get coefficients for GI calculation 522 | 523 | Args: 524 | singles_expr (np.array): single perturbation expression 525 | first_expr (np.array): first perturbation expression 526 | second_expr (np.array): second perturbation expression 527 | double_expr (np.array): double perturbation expression 528 | 529 | """ 530 | results = {} 531 | results['ts'] = TheilSenRegressor(fit_intercept=False, 532 | max_subpopulation=1e5, 533 | max_iter=1000, 534 | random_state=1000) 535 | X = singles_expr 536 | y = double_expr 537 | results['ts'].fit(X, y.ravel()) 538 | Zts = results['ts'].predict(X) 539 | results['c1'] = results['ts'].coef_[0] 540 | results['c2'] = results['ts'].coef_[1] 541 | results['mag'] = np.sqrt((results['c1']**2 + results['c2']**2)) 542 | 543 | results['dcor'] = distance_correlation(singles_expr, double_expr) 544 | results['dcor_singles'] = distance_correlation(first_expr, second_expr) 545 | results['dcor_first'] = distance_correlation(first_expr, double_expr) 546 | results['dcor_second'] = distance_correlation(second_expr, double_expr) 547 | results['corr_fit'] = np.corrcoef(Zts.flatten(), double_expr.flatten())[0,1] 548 | results['dominance'] = np.abs(np.log10(results['c1']/results['c2'])) 549 | results['eq_contr'] = np.min([results['dcor_first'], results['dcor_second']])/\ 550 | np.max([results['dcor_first'], results['dcor_second']]) 551 | 552 | return results 553 | 554 | def get_GI_params(preds, combo): 555 | """ 556 | Get GI parameters 557 | 558 | Args: 559 | preds (dict): dictionary of predictions 560 | combo (list): list of perturbations 561 | 562 | """ 563 | singles_expr = np.array([preds[combo[0]], preds[combo[1]]]).T 564 | first_expr = np.array(preds[combo[0]]).T 565 | second_expr = np.array(preds[combo[1]]).T 566 | double_expr = np.array(preds[combo[0]+'_'+combo[1]]).T 567 | 568 | return get_coeffs(singles_expr, first_expr, second_expr, double_expr) 569 | 570 | def get_GI_genes_idx(adata, GI_gene_file): 571 | """ 572 | Optional: Reads a file containing a list of GI genes (usually those 573 | with high mean expression) 574 | 575 | Args: 576 | adata (anndata): anndata object 577 | GI_gene_file (str): file containing GI genes (generally corresponds 578 | to genes with high mean expression) 579 | """ 580 | # Genes used for linear model fitting 581 | GI_genes = np.load(GI_gene_file, allow_pickle=True) 582 | GI_genes_idx = np.where([g in GI_genes for g in adata.var.gene_name.values])[0] 583 | 584 | return GI_genes_idx 585 | 586 | def get_mean_control(adata): 587 | """ 588 | Get mean control expression 589 | """ 590 | mean_ctrl_exp = adata[adata.obs['condition'] == 'ctrl'].to_df().mean() 591 | return mean_ctrl_exp 592 | 593 | def get_genes_from_perts(perts): 594 | """ 595 | Returns list of genes involved in a given perturbation list 596 | """ 597 | 598 | if type(perts) is str: 599 | perts = [perts] 600 | gene_list = [p.split('+') for p in np.unique(perts)] 601 | gene_list = [item for sublist in gene_list for item in sublist] 602 | gene_list = [g for g in gene_list if g != 'ctrl'] 603 | return list(np.unique(gene_list)) 604 | -------------------------------------------------------------------------------- /environment.yml: 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: stamp
2 | channels:
3 |   - pytorch
4 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
5 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
6 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
7 |   - defaults
8 | dependencies:
9 |   - _libgcc_mutex=0.1=main
10 |   - _openmp_mutex=5.1=1_gnu
11 |   - blas=1.0=mkl
12 |   - bzip2=1.0.8=h5eee18b_6
13 |   - ca-certificates=2024.3.11=h06a4308_0
14 |   - cudatoolkit=11.3.1=h2bc3f7f_2
15 |   - ffmpeg=4.3=hf484d3e_0
16 |   - freetype=2.12.1=h4a9f257_0
17 |   - gmp=6.2.1=h295c915_3
18 |   - gnutls=3.6.15=he1e5248_0
19 |   - intel-openmp=2021.4.0=h06a4308_3561
20 |   - jpeg=9e=h5eee18b_1
21 |   - lame=3.100=h7b6447c_0
22 |   - lcms2=2.12=h3be6417_0
23 |   - ld_impl_linux-64=2.38=h1181459_1
24 |   - lerc=3.0=h295c915_0
25 |   - libdeflate=1.17=h5eee18b_1
26 |   - libffi=3.3=he6710b0_2
27 |   - libgcc-ng=11.2.0=h1234567_1
28 |   - libgomp=11.2.0=h1234567_1
29 |   - libiconv=1.16=h5eee18b_3
30 |   - libidn2=2.3.4=h5eee18b_0
31 |   - libpng=1.6.39=h5eee18b_0
32 |   - libstdcxx-ng=11.2.0=h1234567_1
33 |   - libtasn1=4.19.0=h5eee18b_0
34 |   - libtiff=4.5.1=h6a678d5_0
35 |   - libunistring=0.9.10=h27cfd23_0
36 |   - libuv=1.44.2=h5eee18b_0
37 |   - libwebp-base=1.3.2=h5eee18b_0
38 |   - lz4-c=1.9.4=h6a678d5_1
39 |   - mkl=2021.4.0=h06a4308_640
40 |   - mkl-service=2.4.0=py39h7f8727e_0
41 |   - mkl_fft=1.3.1=py39hd3c417c_0
42 |   - mkl_random=1.2.2=py39h51133e4_0
43 |   - ncurses=6.4=h6a678d5_0
44 |   - nettle=3.7.3=hbbd107a_1
45 |   - numpy=1.24.3=py39h14f4228_0
46 |   - numpy-base=1.24.3=py39h31eccc5_0
47 |   - openh264=2.1.1=h4ff587b_0
48 |   - openjpeg=2.4.0=h3ad879b_0
49 |   - openssl=1.1.1w=h7f8727e_0
50 |   - pillow=10.3.0=py39h5eee18b_0
51 |   - pip=24.0=py39h06a4308_0
52 |   - python=3.9.7=h12debd9_1
53 |   - pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0
54 |   - pytorch-mutex=1.0=cuda
55 |   - readline=8.2=h5eee18b_0
56 |   - setuptools=69.5.1=py39h06a4308_0
57 |   - six=1.16.0=pyhd3eb1b0_1
58 |   - sqlite=3.45.3=h5eee18b_0
59 |   - tk=8.6.14=h39e8969_0
60 |   - torchaudio=0.10.2=py39_cu113
61 |   - torchvision=0.11.3=py39_cu113
62 |   - typing_extensions=4.11.0=py39h06a4308_0
63 |   - wheel=0.43.0=py39h06a4308_0
64 |   - xz=5.4.6=h5eee18b_1
65 |   - zlib=1.2.13=h5eee18b_1
66 |   - zstd=1.5.5=hc292b87_2
67 |   - pip:
68 |       - anndata==0.10.7
69 |       - array-api-compat==1.7.1
70 |       - asttokens==2.4.1
71 |       - comm==0.2.2
72 |       - contourpy==1.2.1
73 |       - cycler==0.12.1
74 |       - dcor==0.6
75 |       - debugpy==1.8.1
76 |       - decorator==5.1.1
77 |       - exceptiongroup==1.2.1
78 |       - executing==2.0.1
79 |       - fonttools==4.53.0
80 |       - get-annotations==0.1.2
81 |       - h5py==3.11.0
82 |       - importlib-metadata==7.2.0
83 |       - importlib-resources==6.4.0
84 |       - ipykernel==6.29.4
85 |       - ipython==8.18.1
86 |       - jedi==0.19.1
87 |       - joblib==1.4.2
88 |       - jupyter-client==8.6.2
89 |       - jupyter-core==5.7.2
90 |       - kiwisolver==1.4.5
91 |       - legacy-api-wrap==1.4
92 |       - llvmlite==0.42.0
93 |       - matplotlib==3.9.0
94 |       - matplotlib-inline==0.1.7
95 |       - natsort==8.4.0
96 |       - nest-asyncio==1.6.0
97 |       - networkx==3.2.1
98 |       - numba==0.59.1
99 |       - packaging==24.1
100 |       - pandas==2.2.2
101 |       - parso==0.8.4
102 |       - patsy==0.5.6
103 |       - pexpect==4.9.0
104 |       - platformdirs==4.2.2
105 |       - prompt-toolkit==3.0.47
106 |       - psutil==6.0.0
107 |       - ptyprocess==0.7.0
108 |       - pure-eval==0.2.2
109 |       - pygments==2.18.0
110 |       - pynndescent==0.5.12
111 |       - pyparsing==3.1.2
112 |       - python-dateutil==2.9.0.post0
113 |       - pytz==2024.1
114 |       - pyyaml==6.0.2rc1
115 |       - pyzmq==26.0.3
116 |       - scanpy==1.10.1
117 |       - scikit-learn==1.5.0
118 |       - scipy==1.13.1
119 |       - seaborn==0.13.2
120 |       - session-info==1.0.0
121 |       - stack-data==0.6.3
122 |       - stamp==0.1.0
123 |       - statsmodels==0.14.2
124 |       - stdlib-list==0.10.0
125 |       - threadpoolctl==3.5.0
126 |       - tornado==6.4.1
127 |       - tqdm==4.66.4
128 |       - traitlets==5.14.3
129 |       - tzdata==2024.1
130 |       - umap-learn==0.5.6
131 |       - wcwidth==0.2.13
132 |       - zipp==3.19.2
133 | prefix: /home/gaoyicheng/anaconda3/envs/stamp
134 | 
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/img/framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.3
2 | pandas==2.2.2
3 | tqdm==4.66.4
4 | scikit-learn==1.5.0
5 | torch
6 | scanpy==1.10.1
7 | dcor==0.6
8 | scipy==1.13.1
9 | joblib==1.4.2
10 | pyyaml
11 | setuptools==69.5.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from os import path
3 | from io import open
4 | 
5 | # get __version__ from version.py
6 | ver_file = path.join('stamp', 'version.py')
7 | with open(ver_file) as f:
8 |     exec(f.read())
9 | 
10 | this_directory = path.abspath(path.dirname(__file__))
11 | 
12 | # read the contents of README.md
13 | def readme():
14 |     with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
15 |         return f.read()
16 | 
17 | # read the contents of requirements.txt
18 | with open(path.join(this_directory, 'requirements.txt'),
19 |           encoding='utf-8') as f:
20 |     requirements = f.read().splitlines()
21 | 
22 | setup(
23 |     name='cell-stamp',
24 |     version=__version__,
25 |     description='STAMP for genetic perturbation prediction',
26 |     long_description=readme(),
27 |     long_description_content_type='text/markdown',
28 |     url='https://github.com/bm2-lab/STAMP',
29 |     author='Yicheng Gao, Zhiting Wei, Qi Liu',
30 |     packages=find_packages(),
31 |     zip_safe=False,
32 |     include_package_data=True,
33 |     install_requires=requirements,
34 |     license='GPL-3.0 license'
35 | )
--------------------------------------------------------------------------------
/stamp/DataSet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 | import joblib
5 | import scanpy as sc
6 | 
7 | 
8 | class PerturbDataSet(Dataset):
9 |     def __init__(self, data_dir, gene_embedding_notebook):
10 | 
11 |         # read data: accept a path to an .h5ad file or an AnnData object
12 |         try:
13 |             data = sc.read_h5ad(data_dir)
14 |         except Exception:
15 |             data = data_dir
16 | 
17 |         # obtain three levels data
18 |         self.FL_data = data.layers['level1']
19 |         self.SL_data = data.layers['level2']
20 |         self.TL_data = data.layers['level3']
21 | 
22 |         # obtain the gene names
23 |         self.gene_name = list(data.var.index)
24 | 
25 |         # obtain perturb genes
26 |         self.perturb_genes = list(data.obs.index)
27 | 
28 |         # obtain gene embedding matrix
29 |         self.gene_embedding_notebook = gene_embedding_notebook
30 | 
31 |     def __getitem__(self, item):
32 | 
33 |         target_gene = self.perturb_genes[item]
34 | 
35 |         pertub_embeds = 0  # mean embedding over all genes in this (possibly combinatorial) perturbation
36 |         for idx, t in enumerate(target_gene.split(',')):
37 |             target_gene_index = self.gene_name.index(t)
38 |             pertub_embeds += self.gene_embedding_notebook[target_gene_index]
39 |         pertub_embeds /= idx + 1
40 | 
41 |         FL_output = self.FL_data[item].toarray()
42 | 
43 |         SL_output = self.SL_data[item].toarray()
44 | 
45 |         TL_output = self.TL_data[item].toarray()
46 | 
47 |         return (target_gene, pertub_embeds), (FL_output, SL_output, TL_output)
48 | 
49 |     def __len__(self, ):
50 |         return len(self.perturb_genes)
51 | 
52 | if __name__ == '__main__':
53 |     dataset = PerturbDataSet("Path", joblib.load("Path"))  # placeholder paths for the .h5ad file and the gene embeddings
54 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size = 64, shuffle = True)
55 |     for idx, batch_x in enumerate(dataloader):
56 |         print(idx, batch_x)
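A short usage sketch for PerturbDataSet (illustrative; the paths point at the example files written by Data/GeneratingExampleData.py):

import joblib
import torch

gene_embs = joblib.load("./Data/example_gene_embeddings.pkl")   # (n_genes, emb_dim), same gene order as data.var
dataset = PerturbDataSet("./Data/example_train.h5ad", gene_embs)
loader = torch.utils.data.DataLoader(dataset, batch_size = 64, shuffle = True)
(names, embeds), (lvl1, lvl2, lvl3) = next(iter(loader))        # embeds: (64, emb_dim); lvl*: (64, 1, n_genes)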
--------------------------------------------------------------------------------
/stamp/Modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import math
5 | import matplotlib.pyplot as plt
6 | from torch.nn import functional as F
7 | 
8 | class Bayes_first_level_layer(nn.Module):
9 |     def __init__(self, in_features, out_features, bias = True):
10 |         super().__init__()
11 |         self.in_features = in_features
12 |         self.out_features = out_features
13 | 
14 |         self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
15 |         self.weight_sigma_log = nn.Parameter(torch.Tensor(out_features, in_features))
16 |         self.bias = bias
17 | 
18 |         if self.bias:
19 |             self.bias_mu = nn.Parameter(torch.Tensor(out_features))
20 |             self.bias_sigma_log = nn.Parameter(torch.Tensor(out_features))
21 | 
22 |         self.activation = nn.Softmax(dim = 1)
23 | 
24 |     def forward(self, X):
25 |         weight = self.weight_mu + torch.randn_like(self.weight_sigma_log) * torch.exp(self.weight_sigma_log)  # reparameterization: sample W ~ N(mu, sigma^2)
26 |         if self.bias:
27 |             bias = self.bias_mu + torch.randn_like(self.bias_sigma_log) * torch.exp(self.bias_sigma_log)
28 |         else:
29 |             bias = None
30 |         hidden_states = F.linear(X, weight, bias)
31 |         output = self.activation(hidden_states)
32 |         return output
33 | 
34 | 
35 | class First_level_layer(nn.Module):
36 |     def __init__(self, in_features, out_features, bias = True):
37 |         super().__init__()
38 |         self.in_features = in_features
39 |         self.out_features = out_features
40 | 
41 |         self.mapping1 = nn.ModuleList([nn.Linear(self.in_features, 1024), nn.ReLU(), nn.Dropout(0.9),
42 |                                        nn.Linear(1024, 2048), nn.ReLU(), nn.Dropout(0.9),
43 |                                        nn.Linear(2048, self.out_features)])
44 | 
45 |         self.activation = nn.Sigmoid()
46 | 
47 |     def forward(self, X):
48 |         hidden_states = X
49 |         for idx, h in enumerate(self.mapping1):
50 |             hidden_states = h(hidden_states)
51 |         output = self.activation(hidden_states)
52 |         return output
53 | 
54 | class First_level_layer_concate(nn.Module):
55 |     def __init__(self, hid1_features = 128, hid2_features = 64, hid3_features = 32, gene_embedding_notebook = None):
56 |         super().__init__()
57 |         self.hid1_features = hid1_features
58 |         self.hid2_features = hid2_features
59 |         self.hid3_features = hid3_features
60 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
61 |         self.gene_embedding_notebook = gene_embedding_notebook
62 | 
63 |         self.mapping1 = nn.ModuleList([nn.Linear(self.gene_embedding_notebook.shape[1]*2, self.hid1_features), nn.ReLU(),
64 |                                        nn.Linear(self.hid1_features, self.hid2_features), nn.ReLU(),
65 |                                        nn.Linear(self.hid2_features, self.hid3_features), nn.ReLU()]
66 |                                       )
67 |         self.mapping1_head = nn.Linear(self.hid3_features, 1)
68 | 
69 |         self.activation = nn.Sigmoid()
70 | 
71 |     def forward(self, pertub_genes_embeds):
72 | 
73 |         pertub_genes_embeds = pertub_genes_embeds.unsqueeze(1).expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
74 |         expanded_notebook = self.gene_embedding_notebook.expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
75 |         hids = torch.cat([pertub_genes_embeds, expanded_notebook.to(pertub_genes_embeds.device)], dim = -1)
76 |         for idx, h in enumerate(self.mapping1):
77 |             hids = h(hids)
78 |         hids_head = self.mapping1_head(hids)
79 |         output = self.activation(hids_head)
80 | 
81 |         return output
82 | 
83 | class Second_level_layer(nn.Module):
84 |     def __init__(self, hid1_features = 128, hid2_features = 64, hid3_features = 32, gene_embedding_notebook = None):
85 |         super().__init__()
86 |         self.hid1_features = hid1_features
87 |         self.hid2_features = hid2_features
88 |         self.hid3_features = hid3_features
89 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
90 |         self.gene_embedding_notebook = gene_embedding_notebook
91 | 
92 |         self.mapping2 = nn.ModuleList([nn.Linear(self.gene_embedding_notebook.shape[1] * 2, self.hid1_features), nn.ReLU(),
93 |                                        nn.Linear(self.hid1_features, self.hid2_features), nn.ReLU(),
94 |                                        nn.Linear(self.hid2_features, self.hid3_features), nn.ReLU()]
95 |                                       )
96 |         self.mapping2_head = nn.Linear(self.hid3_features, 1)
97 | 
98 |         self.activation = nn.Sigmoid()
99 | 
100 |     def forward(self, X, pertub_genes_embeds):
101 | 
102 |         with torch.no_grad():
103 |             mask = X == 0
104 |         pertub_genes_embeds = pertub_genes_embeds.unsqueeze(1).expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
105 |         expanded_notebook = self.gene_embedding_notebook.expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
106 |         hids = torch.cat([pertub_genes_embeds, expanded_notebook.to(X.device)], dim = -1)
107 |         for idx, h in enumerate(self.mapping2):
108 |             hids = h(hids)
109 |         hids_head = self.mapping2_head(hids)
110 |         output_second_level = self.activation(hids_head)
111 | 
112 |         return output_second_level, ~mask, hids
113 | 
114 | class Third_level_layer(nn.Module):
115 |     def __init__(self, in_features = 32, hid1_features = 16, hid2_features = 8):
116 |         super().__init__()
117 |         self.in_features = in_features
118 |         self.hid1_features = hid1_features
119 |         self.hid2_features = hid2_features
120 |         self.mapping3 = nn.ModuleList([nn.Linear(in_features, hid1_features),
121 |                                        nn.Linear(hid1_features, hid2_features)])
122 |         self.mapping3_head = nn.Linear(hid2_features, 1)
123 |         self.activation = nn.LeakyReLU()
124 | 
125 |     def forward(self, X, mask = None):
126 |         hids = X
127 |         for idx, h in enumerate(self.mapping3):
128 |             hids = h(hids)
129 |         hids_head = self.mapping3_head(hids)
130 |         # output_third_level = self.activation(hids_head)
131 |         output_third_level = torch.exp(hids_head)  # exponential keeps the predicted magnitude positive
132 |         return output_third_level, mask
133 | 
134 | class TaskCombineLayer_multi_task(nn.Module):
135 |     def __init__(self, in_features, out_features,
136 |                  hid1_features_2 = 128, hid2_features_2 = 64, hid3_features_2 = 32,
137 |                  in_feature_3 = 32, hid1_features_3 = 16, hid2_features_3 = 8,
138 |                  gene_embedding_notebook = None, bias = True):
139 |         super().__init__()
140 | 
141 |         # you can set the gene embeddings as learnable parameters
142 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
143 |         self.gene_embedding_notebook = gene_embedding_notebook
144 |         self.first_level_layer = First_level_layer(in_features, out_features)
145 |         self.second_level_layer = Second_level_layer(hid1_features_2, hid2_features_2, hid3_features_2, gene_embedding_notebook = self.gene_embedding_notebook)
146 |         self.third_level_layer = Third_level_layer(in_feature_3, hid1_features_3, hid2_features_3)
147 | 
148 |     def forward(self, X,
149 |                 level_1_output,
150 |                 level_2_output,
151 |                 level_3_output
152 |                 ):
153 |         output_1 = self.first_level_layer(X)
154 |         # binary loss: up-weight the rare DE genes
155 |         with torch.no_grad():
156 |             all_num = level_1_output.shape[0] * level_1_output.shape[1]
157 |             DE_num = level_1_output.sum()
158 |             if DE_num == 0:
159 |                 DE_num = 1
160 |             loss_weights = (all_num - DE_num) / DE_num
161 |         loss_binary = nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1)
162 |         loss_1 = loss_binary(output_1.squeeze(-1), level_1_output)
163 | 
164 |         output_2, mask, hids = self.second_level_layer(level_1_output, X)  # the level-1 labels supply the DE mask for the second level
165 |         with torch.no_grad():
166 |             up_num = torch.sum(mask * level_2_output)
167 |             all_num = torch.sum(mask)
168 |             if up_num == 0:
169 |                 up_num = 1
170 |             weights = mask * level_2_output * all_num / up_num
171 |             weights[weights > 0] -= 1
172 |             new_weights = weights + mask
173 |             if all_num <= up_num:
174 |                 all_num = up_num + 1
175 |             new_weights[new_weights==1] = all_num / (all_num - up_num)
176 | 
177 |         loss_binary = nn.BCELoss(weight = new_weights, reduction = 'sum')
178 |         loss_2 = loss_binary(output_2.squeeze(-1), level_2_output) / torch.sum(new_weights)
179 | 
180 |         output_3, mask = self.third_level_layer(hids, mask)
181 |         loss_3 = torch.sum((mask * (output_3.squeeze(-1) - level_3_output) ** 2)) / torch.sum(mask)
182 | 
183 |         return loss_1, loss_2, loss_3, (output_1, output_2, output_3), mask
184 | 
185 | class TaskCombineLayer_multi_task_concate(nn.Module):
186 |     def __init__(self, hid1_features_1 = 128, hid2_features_1 = 64, hid3_features_1 = 32,
187 |                  hid1_features_2 = 128, hid2_features_2 = 64, hid3_features_2 = 32,
188 |                  in_feature_3 = 32, hid1_features_3 = 16, hid2_features_3 = 8,
189 |                  gene_embedding_notebook = None, bias = True):
190 |         super().__init__()
191 | 
192 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
193 |         self.gene_embedding_notebook = gene_embedding_notebook
194 |         self.first_level_layer = First_level_layer_concate(hid1_features_1, hid2_features_1, hid3_features_1, gene_embedding_notebook = self.gene_embedding_notebook)
195 |         self.second_level_layer = Second_level_layer(hid1_features_2, hid2_features_2, hid3_features_2, gene_embedding_notebook = self.gene_embedding_notebook)
196 |         self.third_level_layer = Third_level_layer(in_feature_3, hid1_features_3, hid2_features_3)
197 | 
198 |     def forward(self, X,
199 |                 level_1_output,
200 |                 level_2_output,
201 |                 level_3_output
202 |                 ):
203 |         output_1 = self.first_level_layer(X)
204 |         # binary loss: up-weight the rare DE genes
205 |         with torch.no_grad():
206 |             all_num = level_1_output.shape[0] * level_1_output.shape[1]
207 |             DE_num = level_1_output.sum()
208 |             if DE_num == 0:
209 |                 DE_num = 1
210 |             loss_weights = (all_num - DE_num) / DE_num
211 |         loss_binary = nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1)
212 |         loss_1 = loss_binary(output_1.squeeze(-1), level_1_output)
213 | 
214 |         output_2, mask, hids = self.second_level_layer(level_1_output, X)
215 |         with torch.no_grad():
216 |             up_num = torch.sum(mask * level_2_output)
217 |             all_num = torch.sum(mask)
218 |             if up_num == 0:
219 |                 up_num = 1
220 |             weights = mask * level_2_output * all_num / up_num
221 |             weights[weights > 0] -= 1
222 |             new_weights = weights + mask
223 |             if all_num <= up_num:
224 |                 all_num = up_num + 1
225 |             new_weights[new_weights==1] = all_num / (all_num - up_num)
226 | 
227 |         loss_binary = nn.BCELoss(weight = new_weights, reduction = 'sum')
228 |         loss_2 = loss_binary(output_2.squeeze(-1), level_2_output) / torch.sum(new_weights)
229 | 
230 |         output_3, mask = self.third_level_layer(hids, mask)
231 |         loss_3 = torch.sum((mask * (output_3.squeeze(-1) - level_3_output) ** 2)) / torch.sum(mask)
232 | 
233 |         return loss_1, loss_2, loss_3, (output_1, output_2, output_3), mask
--------------------------------------------------------------------------------
/stamp/__init__.py:
--------------------------------------------------------------------------------
1 | from .STAMP import STAMP
2 | from .DataSet import PerturbDataSet
3 | from .utils import load_config
--------------------------------------------------------------------------------
/stamp/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | # load config file
3 | def load_config(config_file):
4 |     with open(config_file) as file:
5 |         config = yaml.safe_load(file)
6 |     return config
--------------------------------------------------------------------------------
/stamp/version.py:
--------------------------------------------------------------------------------
1 | """
2 | STAMP version file
3 | """
4 | __version__ = '0.1.2'
--------------------------------------------------------------------------------
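A minimal end-to-end sketch tying the pieces above together (illustrative only: it wires the example files and config produced under ./Data into a DataLoader, and stops short of training, since the training loop itself lives in stamp/STAMP.py and is not reproduced here):

import joblib
import torch
from stamp import PerturbDataSet, load_config

config = load_config("./Data/example_config.yaml")
gene_embs = joblib.load(config["dataset"]["Gene_embedding"])
train_set = PerturbDataSet(config["dataset"]["Training_dataset"], gene_embs)
loader = torch.utils.data.DataLoader(train_set,
                                     batch_size = config["Train"]["Sampling"]["batch_size"],
                                     shuffle = config["Train"]["Sampling"]["sample_shuffle"])
print(len(train_set), "training perturbations,", len(loader), "batches per epoch")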