├── Data
│   ├── GeneratingExampleData.py
│   └── example_config.yaml
├── DataProcess
│   ├── DataProcess.py
│   ├── geneList.pkl
│   └── myUtil.py
├── LICENSE
├── MANIFEST.in
├── Output
│   └── training.log
├── README.md
├── Supp_code
│   ├── GI_experiment
│   │   ├── GI_evaluation.py
│   │   ├── Gears_results.pkl
│   │   ├── STAMP_results.pkl
│   │   ├── Science_data.csv
│   │   └── Truth_results.pkl
│   ├── Modified_CPA
│   │   ├── _api.py
│   │   ├── _data.py
│   │   ├── _metrics.py
│   │   ├── _model.py
│   │   ├── _module.py
│   │   ├── _plotting.py
│   │   ├── _task.py
│   │   └── _utils.py
│   └── Modified_GEARS
│       ├── data_utils.py
│       ├── gears.py
│       ├── inference.py
│       ├── model.py
│       ├── pertdata.py
│       └── utils.py
├── Tutorial
│   └── tutorial_for_training.py.ipynb
├── environment.yml
├── img
│   └── framework.png
├── requirements.txt
├── setup.py
└── stamp
    ├── DataSet.py
    ├── Modules.py
    ├── STAMP.py
    ├── __init__.py
    ├── utils.py
    └── version.py
/Data/GeneratingExampleData.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import scanpy as sc
 3 | import anndata as ad
 4 | import joblib
 5 | import scipy
 6 | import os
 7 | 
 8 | np.random.seed(888)
 9 | 
10 | # generating the training example data
11 | train_data_template_x = (np.random.rand(1000,5000) > 0.5) * 1
12 | train_data_layer_1 = (np.random.rand(1000,5000) > 0.5) * 1
13 | train_data_layer_2 = (np.random.rand(1000,5000) > 0.5) * 1
14 | train_data_layer_3 = np.random.rand(1000,5000)
15 | train_data = ad.AnnData(X = scipy.sparse.csr_matrix(train_data_template_x))
16 | train_data.layers['level1'] = scipy.sparse.csr_matrix(train_data_layer_1)
17 | train_data.layers['level2'] = scipy.sparse.csr_matrix(train_data_layer_2)
18 | train_data.layers['level3'] = scipy.sparse.csr_matrix(train_data_layer_3)
19 | train_data.obs.index = [f"Gene{i+1}" for i in range(1000)]
20 | train_data.var.index = [f"Gene{i+1}" for i in range(5000)]
21 | 
22 | # generating the validation example data
23 | val_data_template_x = (np.random.rand(100,5000) > 0.5) * 1
24 | val_data_layer_1 = (np.random.rand(100,5000) > 0.5) * 1
25 | val_data_layer_2 = (np.random.rand(100,5000) > 0.5) * 1
26 | val_data_layer_3 = np.random.rand(100,5000)
27 | val_data = ad.AnnData(X = scipy.sparse.csr_matrix(val_data_template_x))
28 | val_data.layers['level1'] = scipy.sparse.csr_matrix(val_data_layer_1)
29 | val_data.layers['level2'] = scipy.sparse.csr_matrix(val_data_layer_2)
30 | val_data.layers['level3'] = scipy.sparse.csr_matrix(val_data_layer_3)
31 | val_data.obs.index = [f"Gene{i+1}" for i in range(1000,1100)]
32 | val_data.var.index = [f"Gene{i+1}" for i in range(5000)]
33 | 
34 | # generating the testing example data
35 | test_data_template_x = (np.random.rand(200,5000) > 0.5) * 1
36 | test_data_layer_1 = (np.random.rand(200,5000) > 0.5) * 1
37 | test_data_layer_2 = (np.random.rand(200,5000) > 0.5) * 1
38 | test_data_layer_3 = np.random.rand(200,5000)
39 | test_data = ad.AnnData(X = scipy.sparse.csr_matrix(test_data_template_x))
40 | test_data.layers['level1'] = scipy.sparse.csr_matrix(test_data_layer_1)
41 | test_data.layers['level2'] = scipy.sparse.csr_matrix(test_data_layer_2)
42 | test_data.layers['level3'] = scipy.sparse.csr_matrix(test_data_layer_3)
43 | test_data.obs.index = [f"Gene{i+1},Gene{i+2}" for i in range(1100,1300)]
44 | test_data.var.index = [f"Gene{i+1}" for i in range(5000)]
45 | 
46 | # generating the top 40 DEGs for testing example data
47 | test_data_top40 = ad.AnnData(X = scipy.sparse.csr_matrix(test_data_template_x))
48 | test_data_layer_1_top40 = np.zeros_like(np.random.rand(200,5000))
49 | for i in range(200):
50 |     test_data_layer_1_top40[i][np.random.choice(5000,40,replace=False)]=1
51 | test_data_top40.layers['level1'] = scipy.sparse.csr_matrix(test_data_layer_1_top40)
52 | test_data_top40.layers['level2'] = scipy.sparse.csr_matrix(test_data_layer_2)
53 | test_data_top40.layers['level3'] = scipy.sparse.csr_matrix(test_data_layer_3)
54 | test_data_top40.obs.index = [f"Gene{i+1},Gene{i+2}" for i in range(1100,1300)]
55 | test_data_top40.var.index = [f"Gene{i+1}" for i in range(5000)]
56 | 
 57 | # generating the gene embedding matrix; the gene embedding order must be consistent with the gene order of data.var
58 | gene_embs = np.random.rand(5000, 512).astype('float32')
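# (Illustrative sanity check, not part of the original script: the number of embedding
#  rows is assumed to equal the number of genes in data.var, so that row i of
#  gene_embs corresponds to the i-th entry of var.index.)
assert gene_embs.shape[0] == train_data.shape[1] == len(train_data.var.index)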
59 | 
60 | if not os.path.exists("./Data"):
61 |     os.makedirs("./Data")
62 | 
63 | train_data.write("./Data/example_train.h5ad")
64 | val_data.write("./Data/example_val.h5ad")
65 | test_data.write("./Data/example_test.h5ad")
66 | test_data_top40.write("./Data/example_test_top40.h5ad")
67 | joblib.dump(gene_embs, "./Data/example_gene_embeddings.pkl")
--------------------------------------------------------------------------------
/Data/example_config.yaml:
--------------------------------------------------------------------------------
 1 | dataset:
 2 |   Training_dataset: ./Data/example_train.h5ad
 3 |   Validation_dataset: ./Data/example_val.h5ad
 4 |   Testing_dataset: ./Data/example_test.h5ad
 5 |   Gene_embedding: ./Data/example_gene_embeddings.pkl
 6 | Train:
 7 |   Sampling:
 8 |     batch_size: 64
 9 |     sample_shuffle: True
10 |   
11 |   Model_Parameter:
12 |     First_level:
13 |       in_features: 512
14 |       out_features: 5000
15 | 
16 |     Second_level:
17 |       hid1_features_2: 128
18 |       hid2_features_2: 64
19 |       hid3_features_2: 32
20 |     Third_level:
21 |       in_feature_3: 32
22 |       hid1_features_3: 16
23 |       hid2_features_3: 8
24 |     device: cuda
25 | 
26 |   Trainer_parameter:
27 |     random_seed: 888
28 |     epoch: 10
29 |     learning_rate: 0.0001
30 |   
31 |   output_dir: ./Output
32 | 
33 | Inference:
34 |   Sampling:
35 |     batch_size: 256
36 |     sample_shuffle: False
--------------------------------------------------------------------------------
/DataProcess/DataProcess.py:
--------------------------------------------------------------------------------
  1 | import sys
  2 | sys.path.append('/home//project/GW_PerturbSeq')
  3 | from myUtil import *
  4 | import scanpy as sc
  5 | from tqdm import tqdm
  6 | import warnings
  7 | warnings.filterwarnings('ignore')
  8 | 
  9 | ### Data preprocessing, and obtaining the ground-truth differentially expressed genes, etc.
 10 | def fun1(adata):
 11 |     adata = adata[~adata.obs['gene'].isin(['None', 'CTRL'])]
 12 |     tmp = [i.split(',') for i in adata.obs['gene']]
 13 |     tmp = [i for i in tmp if i != 'CTRL' and i != 'None']
 14 |     tmp = np.unique([i for j in tmp for i in j])
 15 |     return tmp
 16 | 
 17 | 
 18 | def fun2(adata):
 19 |     mylist = []
 20 |     for gene, cell in zip(adata.obs['gene'], adata.obs_names):
 21 |         if gene == 'CTRL':
 22 |             mylist.append(cell)
 23 |         else:
 24 |             genes = gene.split(',')
 25 |             tmp = [True if i in adata.var_names else False for i in genes]
 26 |             if np.all(tmp): mylist.append(cell)
 27 |     adata1 = adata[mylist, :]
 28 |     return adata1
 29 | 
 30 | def fun3(x):   ### keep only perturbations with at most two genes
 31 |     xs = x.split(',')
 32 |     if len(xs) >= 3: return False
 33 |     else: return True
 34 | 
 35 | 
 36 | def fun4(adata):
 37 |     with open('/home/project/GW_PerturbSeq/geneEmbedding/geneList.pkl', 'rb') as fin:
 38 |         geneList = pickle.load(fin)
 39 |     tmp = adata.var_names.isin(geneList)  ### genes must have a gene embedding
 40 |     adata = adata[:, tmp]
 41 |     adata1 = fun2(adata)   ### perturbed genes must be in the expression profile and have a gene embedding
 42 |     adata1.write_h5ad('raw.h5ad')
 43 | 
 44 | 
 45 | ### Initial processing so the data can go through standard preprocessing; used for the three genome-wide datasets from the paper
 46 | def QCsample1(dirName, fileName='Raw.h5ad'):
 47 |     os.chdir(dirName)
 48 |     adata = sc.read_h5ad(fileName)
 49 |     adata = adata[:, ~adata.var['gene_name'].duplicated()]  ## before replacing the index, remove duplicates first, otherwise later steps raise errors
 50 |     adata.var['gene_id'] = adata.var.index
 51 |     adata.var.set_index('gene_name', inplace=True)
 52 |     adata.obs['gene'].replace({'non-targeting': 'CTRL'}, inplace=True)
 53 |     fun4(adata)
 54 | 
 55 | '''  Make sure every gene has an embedding and the perturbed genes are among the expressed genes
 56 | adata = sc.read_h5ad('raw.h5ad')
 57 | [i for i in adata.obs['gene'].unique() if i not in geneList]
 58 | [i for i in adata.var_names if i not in geneList]
 59 | [i for i in adata.obs['gene'].unique() if i not in adata.var_names]
 60 | '''
 61 | 
 62 | ####  Feng Zhang transcription factor (TFatlas) data
 63 | def QCsample2(dirName):
 64 |     os.chdir(dirName)
 65 |     adata = sc.read_h5ad('Raw.h5ad')
 66 |     adata.var['gene_id'] = adata.var.index
 67 |     adata.var['gene_name'] = adata.var.index
 68 |     adata.var.set_index('gene_name', inplace=True)
 69 |     fun4(adata)
 70 | 
 71 | #### Perturb-CITE-seq
 72 | def QCsample3(dirName):
 73 |     os.chdir(dirName)
 74 |     adata = sc.read_h5ad("raw1.h5ad")
 75 |     adata.var['gene_id'] = adata.var.index
 76 |     adata.var['gene_name'] = adata.var.index
 77 |     adata.var.set_index('gene_name', inplace=True)
 78 |     tmp = adata.obs['gene'].apply(lambda x: fun3(x))   ### keep at most two-gene perturbations
 79 |     adata = adata[tmp]
 80 |     fun4(adata)
 81 | 
 82 | ### Compute differentially expressed genes for each dataset
 83 | ### Data preprocessing
 84 | def f_preData(dirName):
 85 |     os.chdir(dirName)
 86 |     adata1 = sc.read_h5ad('raw.h5ad')
 87 | 
 88 |     filterNoneNums, filterCells, filterMT, filterMinNums, adata = preData(adata1, filterNone=True, minNums = 30, shuffle=False, filterCom=False,  seed = 42, mtpercent = 10,  min_genes = 200, domaxNums=500, doNor = True)
 89 |     print (filterNoneNums, filterCells, filterMT, filterMinNums)
 90 | 
 91 |     sc.pp.highly_variable_genes(adata, subset=False, n_top_genes=5000)  ### consistent with GEARS, keep 5000 HVGs
 92 |     hvgs = list(adata.var_names[adata.var['highly_variable']])
 93 |     trainGene = fun1(adata)   ### get the list of perturbed genes, excluding CTRL
 94 |     trainGene = [i for i in trainGene if i in adata.var_names]  ### perturbed genes must be kept in the expression profile
 95 |     keepGene = list(set(trainGene + hvgs))
 96 |     adata = adata[:, keepGene]
 97 |     adata1 = adata1[adata.obs_names, keepGene]
 98 |     
 99 |     adata = fun2(adata)   ### ensure again that all perturbations are among expressed genes; filter out those that are not
100 |     adata1 = fun2(adata1) ### ensure again that all perturbations are among expressed genes; filter out those that are not
101 |     adata.write_h5ad('filterNor.h5ad')
102 |     adata1.write_h5ad('filterRaw.h5ad')
103 | 
104 | 
105 | 
106 | 
107 | dirNames = [
108 |     #'/home//project/GW_PerturbSeq/anndata/K562_GW',
109 |     '/home//project/GW_PerturbSeq/anndata/K562_GW_subset',
110 |     '/home//project/GW_PerturbSeq/anndata/K562_essential',
111 |     '/home//project/GW_PerturbSeq/anndata/RPE1_essential',
112 |     '/home//project/GW_PerturbSeq/anndata/TFatlas',
113 | 
114 |     '/home//project/GW_PerturbSeq/anndata_combination/PRJNA551220',
115 |     '/home//project/GW_PerturbSeq/anndata_combination/Perturb-CITE-seq',
116 |     '/home//project/GW_PerturbSeq/anndata_combination/PRJNA787633'
117 | ]
118 | 
119 | #dirName = '/home//project/GW_PerturbSeq/anndata/K562_GW_subset'; domaxNums = True; minNums = 30; ###QCsample1(dirName)   ###   ###  CRISPRi ### keep at most 50 cells per perturbation
120 | dirName = '/home//project/GW_PerturbSeq/anndata/K562_GW'; domaxNums = True; minNums = 30; ##QCsample1(dirName)   ###   ###  CRISPRi
121 | #dirName = '/home//project/GW_PerturbSeq/anndata/K562_essential'; domaxNums = True; minNums = 30; #QCsample1(dirName)
122 | #dirName = '/home//project/GW_PerturbSeq/anndata/RPE1_essential'; domaxNums = True; minNums = 30; #QCsample1(dirName)
123 | #dirName = '/home//project/GW_PerturbSeq/anndata/TFatlas'; domaxNums = True; minNums = 30; #QCsample2(dirName)   #Embryonic stem cells    Activation
124 | 
125 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/PRJNA551220'; domaxNums = True; minNums = 30; #QCsample3(dirName)   ### combinatorial perturbation data   ####   K562   Activation
126 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/Perturb-CITE-seq'; domaxNums = True; minNums = 5; #QCsample3(dirName)  ####  CRISPRKO   melanoma cell
127 | #dirName = '/home//project/GW_PerturbSeq/anndata_combination/PRJNA787633'; domaxNums = True; minNums = 5; #QCsample3(dirName)  ###  T cell    CRISPRa
128 | 
129 | ### Ttest, wilcoxon, edgeR
130 | 
131 | if __name__ == '__main__':
132 |     print ('hello, world')
133 |     f_preData(dirName)
134 |     getDEG("/home//project/GW_PerturbSeq/anndata/K562_GW", method='wilcoxon')
135 |     
136 |     for dirName in dirNames:
137 |         for method in tqdm(['Ttest', 'wilcoxon', 'edgeR']):
138 |             f_getDEG1(dirName, topDeg=0, method=method); f_getDEG1(dirName, topDeg=20, method=method)
139 |             f_getDEG2(dirName, topDeg=0, method=method); f_getDEG2(dirName, topDeg=20, method=method)
140 |             f_getDEG3(dirName, topDeg=0, method=method); f_getDEG3(dirName, topDeg=20, method=method)
141 |             mergeLevelData(dirName, topDeg=0, method=method); mergeLevelData(dirName, topDeg=20, method=method)
--------------------------------------------------------------------------------
/DataProcess/geneList.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/DataProcess/geneList.pkl
--------------------------------------------------------------------------------
/DataProcess/myUtil.py:
--------------------------------------------------------------------------------
  1 | import numpy as np, pandas as pd, scanpy as sc
  2 | import os, sys, pickle, joblib, re, torch, subprocess
  3 | from itertools import combinations
  4 | from collections import Counter
  5 | import seaborn as sns
  6 | from collections import defaultdict, OrderedDict
  7 | import matplotlib.pyplot as plt
  8 | from scipy import stats
  9 | from sklearn.metrics.pairwise import cosine_similarity
 10 | from scipy.stats import pearsonr, spearmanr
 11 | from sklearn.metrics import precision_recall_curve
 12 | from sklearn.metrics import average_precision_score
 13 | from sklearn.metrics import roc_curve, roc_auc_score, mean_squared_error
 14 | from scipy import sparse
 15 | import scipy
 16 | import anndata as ad
 17 | from tqdm import tqdm
 18 | from multiprocessing import Pool
 19 | 
 20 | 
 21 | class multidict(dict):
 22 |     def __getitem__(self, item):
 23 |         try:
 24 |             return dict.__getitem__(self, item)
 25 |         except KeyError:
 26 |             value = self[item] = type(self)()
 27 |             return value
 28 | 
 29 | 
 30 | def convertAnn2Matrix(adata, dirName='./'):
 31 |     a  = pd.DataFrame(adata.var)
 32 |     a['index'] = a.index
 33 |     a = a[['gene_ids', 'index']]
 34 |     a.to_csv("{}/genes.tsv".format(dirName),  sep = "\t", index = False, header=False)
 35 |     pd.DataFrame(adata.obs.index).to_csv("{}/barcodes.tsv".format(dirName), sep = "\t", index = False, header=False)
 36 |     if not sparse.issparse(adata.X): adata.X = sparse.csr_matrix(adata.X) ### convert to a sparse matrix
 37 |     adata.X = adata.X.astype(np.int32)  ### convert to integers
 38 |     scipy.io.mmwrite("{}/matrix.mtx".format(dirName), adata.X.T)
 39 | 
 40 | 
 41 | 
 42 | 
 43 | def getsubgroup(x = 'combo_seen0', seed = 1):
 44 |     tmp = '../../gears/data/train/splits/train_simulation_{}_0.8_subgroup.pkl'.format(seed)
 45 |     with open(tmp, 'rb') as fin:
 46 |         subgroup_split = pickle.load(fin)
 47 |         test_subgroup = subgroup_split['test_subgroup'][x]
 48 |         mylist = []
 49 |         if 'combo' in x:
 50 |             for i in test_subgroup:
 51 |                 mylist.append(','.join(sorted(i.split('+'))))
 52 |         else:
 53 |             for i in test_subgroup:
 54 |                 mylist.append(i.split('+')[0])
 55 |     return mylist
 56 | 
 57 | 
 58 | def getsubgroup_single(seed = 1):   ### single perturbations
 59 |     tmp1 = '../../gears/data/train/splits/train_simulation_{}_0.8.pkl'.format(seed)
 60 |     tmp2 = '../../../gears/data/train/splits/train_simulation_{}_0.8.pkl'.format(seed)
 61 | 
 62 |     if os.path.isfile(tmp1):
 63 |         tmp = tmp1
 64 |     else:
 65 |         tmp = tmp2
 66 |     with open(tmp, 'rb') as fin:
 67 |         subgroup_split = pickle.load(fin)
 68 |         mylist = []
 69 |         for i in subgroup_split['test']:
 70 |             mylist.append(i.split('+')[0])
 71 |     return mylist
 72 | 
 73 | 
 74 | 
 75 | 
 76 | def wilcoxonFun(gene):
 77 |     sc.tl.rank_genes_groups(adata, 'gene', groups=[gene], reference='CTRL', method= 'wilcoxon')
 78 |     result = adata.uns['rank_genes_groups']
 79 |     groups = result['names'].dtype.names
 80 |     final_result = pd.DataFrame({group + '_' + key: result[key][group] for group in groups for key in ['names', 'pvals_adj', 'logfoldchanges', 'scores']})
 81 |     for group in groups:
 82 |         tmp1 = group + '_' + 'foldchanges'
 83 |         tmp2 = group + '_' + 'logfoldchanges'
 84 |         final_result[tmp1] = 2 ** final_result[tmp2]  ### convert logfoldchange to foldchange
 85 |         final_result.drop(labels=[tmp2], inplace=True, axis=1)
 86 |     return final_result
 87 | 
 88 | def getDEG(dirName, method='Ttest'):
 89 |     os.chdir(dirName)
 90 |     global adata
 91 |     fileout = '{}_DEG.tsv'.format(method)
 92 |     #adata = sc.read_h5ad('filterNor.h5ad')
 93 |     adata = sc.read_h5ad('filterNor_subset.h5ad')
 94 |     if 'log1p' not in adata.uns:
 95 |         adata.uns['log1p'] = {}
 96 |     adata.uns['log1p']["base"] = None
 97 |     if scipy.sparse.issparse(adata.X):
 98 |         adata.X = adata.X.toarray()
 99 |     adata.X += .1
100 |     genes = [i for i in set(adata.obs['gene']) if i != 'CTRL']
101 |     if method == 'Ttest':
102 |         sc.tl.rank_genes_groups(adata, 'gene', groups=genes[:], reference='CTRL', method= 't-test')
103 |         result = adata.uns['rank_genes_groups']
104 |         groups = result['names'].dtype.names
105 |         final_result = pd.DataFrame({group + '_' + key: result[key][group] for group in groups for key in ['names', 'pvals_adj', 'logfoldchanges', 'scores']})
106 |         for group in groups:
107 |             tmp1 = group + '_' + 'foldchanges'
108 |             tmp2 = group + '_' + 'logfoldchanges'
109 |             final_result[tmp1] = 2 ** final_result[tmp2]  ### convert logfoldchange to foldchange
110 |             final_result.drop(labels=[tmp2], inplace=True, axis=1)
111 |         final_result.sort_index(axis=1, inplace=True)
112 |         final_result.to_csv(fileout, sep='\t', index=False)
113 |     elif method == 'wilcoxon':
114 |         result = myPool(wilcoxonFun, genes, processes=7)
115 |         final_result = pd.concat(result, axis=1)
116 |         final_result.sort_index(axis=1, inplace=True)
117 |         final_result.to_csv(fileout, sep='\t', index=False)
118 | 
119 | 
120 | 
121 | 
122 | ### Build a 2-D matrix from the DEGs: columns are the expressed genes, rows are the perturbed genes
123 | ###   First-level task. *****************
124 | def f_getDEG1(dirName, topDeg=0, pvalue = 0.01, method='Ttest'):
125 |     os.chdir(dirName)
126 |     if topDeg == 0:
127 |         filein = '{}_DEG.tsv'.format(method);  fileout = '{}_DEG_binary.tsv'.format(method)
128 |     else:
129 |         filein = '{}_DEG.tsv'.format(method);  fileout = '{}_DEG_binary_topDeg{}.tsv'.format(method, topDeg)
130 |     if os.path.isfile(fileout):
131 |         tmp = pd.read_csv(fileout, sep='\t')
132 |         if tmp.shape[0] >= 10: return   ### output already exists, skip recomputation
133 |     dat = pd.read_csv(filein, sep='\t')
134 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
135 |     expGene = list(dat.iloc[:, 1])
136 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0)
137 |     for pertGene1 in pertGene:
138 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_pvals_adj'.format(pertGene1); 
139 |         if topDeg == 0:
140 |             expGene1 = [i for i, j in zip(dat[tmp1], dat[tmp2]) if j <= pvalue]
141 |         else:
142 |             expGene1 = list(dat[tmp1][:20]) + list(dat[tmp1][-20:])  ### take the top 20 up-regulated and top 20 down-regulated genes
143 |         binaryMat.loc[pertGene1, expGene1] = 1
144 |     binaryMat.sort_index(axis=0, inplace=True)  ### sort rows
145 |     binaryMat.sort_index(axis=1, inplace=True)  ### sort columns
146 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
147 | 
148 | 
149 | ### Classify into up-/down-regulated  ###   Second-level task. *****************
150 | def f_getDEG2(dirName, topDeg=0, method='Ttest'):
151 |     os.chdir(dirName)
152 |     if topDeg ==0:   ### not strictly necessary, since the predicted values do not change with topDeg
153 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_UpDown.tsv'.format(method)
154 |     else:
155 |         filein = '{}_DEG.tsv'.format(method);  fileout = '{}_DEG_UpDown_topDeg{}.tsv'.format(method, topDeg)
156 |     if os.path.isfile(fileout):
157 |         tmp = pd.read_csv(fileout, sep='\t')
158 |         if tmp.shape[0] >= 10: return   ### output already exists, skip recomputation
159 |     dat = pd.read_csv(filein, sep='\t')
160 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
161 |     expGene = list(dat.iloc[:, 1])
162 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0)
163 |     for pertGene1 in pertGene:
164 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_foldchanges'.format(pertGene1)
165 |         expGene1 = [i for i, j in zip(dat[tmp1], dat[tmp2]) if j >= 1]
166 |         binaryMat.loc[pertGene1, expGene1] = 1
167 |     binaryMat.sort_index(axis=0, inplace=True)  ### sort rows
168 |     binaryMat.sort_index(axis=1, inplace=True)  ### sort columns
169 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
170 | 
171 | 
172 | def f_getDEG3(dirName, topDeg = 0, method='Ttest'):
173 |     os.chdir(dirName)
174 |     if topDeg == 0:   ### not strictly necessary, since the predicted values do not change with topDeg
175 |         filein = '{}_DEG.tsv'.format(method); fileout = '{}_DEG_foldchange.tsv'.format(method)
176 |     else:
177 |         filein = '{}_DEG.tsv'.format(method);  fileout = '{}_DEG_foldchange_topDeg{}.tsv'.format(method, topDeg)
178 |     if os.path.isfile(fileout):
179 |         tmp = pd.read_csv(fileout, sep='\t')
180 |         if tmp.shape[0] >= 10: return   ### output already exists, skip recomputation
181 |     dat = pd.read_csv(filein, sep='\t')
182 |     pertGene = list(set([i.split('_')[0] for i in dat.columns]))
183 |     expGene = list(dat.iloc[:, 1])
184 |     expGene = sorted(expGene)
185 |     binaryMat = pd.DataFrame(columns= expGene, index=pertGene, data=0.0)
186 |     for pertGene1 in pertGene:
187 |         tmp1 = '{}_names'.format(pertGene1); tmp2 = '{}_foldchanges'.format(pertGene1)
188 |         tmp3 = dat[[tmp1, tmp2]]
189 |         expGene1 = list(tmp3.sort_values(tmp1)[tmp2])
190 |         binaryMat.loc[pertGene1, :] = expGene1
191 |     binaryMat.sort_index(axis=0, inplace=True)
192 |     binaryMat.sort_index(axis=1, inplace=True)
193 |     binaryMat.to_csv(fileout, index=True, sep='\t', header=True)
194 | 
195 | 
196 | def mergeLevelData(dirName, topDeg=0, method='Ttest'):
197 |     os.chdir(dirName)
198 |     if topDeg == 0:
199 |         filein1 = '{}_DEG_binary.tsv'.format(method)
200 |         filein2 = '{}_DEG_UpDown.tsv'.format('Ttest')
201 |         filein3 = '{}_DEG_foldchange.tsv'.format("Ttest")
202 |         fileout = '{}_merge.h5ad'.format(method)
203 |     else:
204 |         filein1 = '{}_DEG_binary_topDeg{}.tsv'.format(method, topDeg)
205 |         filein2 = '{}_DEG_UpDown_topDeg{}.tsv'.format("Ttest", topDeg)
206 |         filein3 = '{}_DEG_foldchange_topDeg{}.tsv'.format("Ttest", topDeg)
207 |         fileout = '{}_merge_topDeg{}.h5ad'.format(method, topDeg)
208 | 
209 |     
210 |     dat1 = pd.read_csv(filein1, sep='\t', index_col=0)
211 |     dat2 = pd.read_csv(filein2, sep='\t', index_col=0)
212 |     dat2 = dat2.loc[dat1.index, dat1.columns]
213 |     
214 |     dat3 = pd.read_csv(filein3, sep='\t', index_col=0)
215 |     dat3 = dat3.loc[dat1.index, dat1.columns]
216 |     adata = ad.AnnData(X=sparse.csr_matrix(dat1.values), obs=pd.DataFrame(index=dat1.index), var=pd.DataFrame(index=dat1.columns))
217 |     adata.layers['level1'] = sparse.csr_matrix(dat1.values)
218 |     adata.layers['level2'] = sparse.csr_matrix(dat2.values)
219 |     adata.layers['level3'] = sparse.csr_matrix(dat3.values)
220 |     adata.write_h5ad(fileout)
221 | 
222 | 
223 | 
224 | def preData(adata, filterNone=True, minNums = 30, shuffle=True, filterCom=False,  seed = 42, mtpercent = 10,  min_genes = 200, domaxNums=False, doNor=True, min_cells=3):  #### for testing clustering, it is best not to sort
225 |     if domaxNums: maxNums = 50
226 |     adata.var_names = adata.var_names.astype(str)  ### ensure gene names are strings
227 |     adata.var_names_make_unique()
228 |     adata = adata[~adata.obs.index.duplicated()]
229 |     if filterCom:
230 |         tmp = adata.obs['gene'].apply(lambda x: True if ',' not in x else False);  adata = adata[tmp]
231 |     if filterNone:
232 |         adata = adata[adata.obs["gene"] != "None"]
233 |     filterNoneNums = adata.shape[0]
234 |     sc.pp.filter_cells(adata, min_genes=min_genes)
235 |     sc.pp.filter_genes(adata, min_cells= min_cells)
236 |     filterCells = adata.shape[0]
237 | 
238 |     if np.any([True if i.startswith('mt-') else False for i in adata.var_names]):
239 |         adata.var['mt'] = adata.var_names.str.startswith('mt-')
240 |     else:
241 |         adata.var['mt'] = adata.var_names.str.startswith('MT-')
242 |     sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
243 |     if sum(adata.obs['pct_counts_mt'] < 10) / adata.shape[0] <=0.5: mtpercent = 15
244 |     adata = adata[adata.obs.pct_counts_mt < mtpercent, :]
245 |     filterMT = adata.shape[0]
246 |     tmp = adata.obs['gene'].value_counts()  
247 |     tmp_bool = tmp >= minNums
248 |     genes = list(tmp[tmp_bool].index)
249 |     if 'CTRL' not in genes: genes += ['CTRL']
250 |     adata = adata[adata.obs['gene'].isin(genes), :]
251 |     if domaxNums:
252 |         adata1 = adata[adata.obs['gene'] == 'CTRL']
253 |         genes = adata.obs['gene'].unique()
254 |         tmp = [adata[adata.obs['gene'] == i][:maxNums] for i in genes if i !='CTRL']
255 |         adata2 = ad.concat(tmp)
256 |         adata = ad.concat([adata1, adata2])
257 | 
258 |     filterMinNums = adata.shape[0]
259 |     
260 |     if doNor:
261 |         sc.pp.normalize_total(adata, target_sum=1e4)
262 |         sc.pp.log1p(adata)
263 |     adata = adata[adata.obs.sort_values(by='gene').index,:]
264 |     if shuffle:
265 |         tmp = list(adata.obs.index)
266 |         np.random.seed(seed); np.random.shuffle(tmp); adata = adata[tmp]
267 |     return filterNoneNums, filterCells, filterMT, filterMinNums, adata
268 | 
269 | 
270 | 
271 | def myPool(func, mylist, processes):
272 |     with Pool(processes) as pool:
273 |         results = list(tqdm(pool.imap(func, mylist), total=len(mylist)))
274 |     return results
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune Data
2 | include README.md
3 | include requirements.txt
4 | include LICENSE
--------------------------------------------------------------------------------
/Output/training.log:
--------------------------------------------------------------------------------
 1 | [2024-01-08 14:47:37,272][STAMP.py][line:134][INFO] Epoch:[1/10]	steps:16	loss1_train:0.79759	loss2_train:0.69352	loss3_train:0.32161	loss1_test:0.77941	loss2_test:0.69337	loss3_test:0.29287	time:4.355
 2 | [2024-01-08 14:47:41,487][STAMP.py][line:134][INFO] Epoch:[2/10]	steps:16	loss1_train:0.77924	loss2_train:0.69322	loss3_train:0.25942	loss1_test:0.77926	loss2_test:0.69336	loss3_test:0.21642	time:3.902
 3 | [2024-01-08 14:47:45,568][STAMP.py][line:134][INFO] Epoch:[3/10]	steps:16	loss1_train:0.77899	loss2_train:0.69377	loss3_train:0.17493	loss1_test:0.77886	loss2_test:0.69381	loss3_test:0.13191	time:3.771
 4 | [2024-01-08 14:47:49,320][STAMP.py][line:134][INFO] Epoch:[4/10]	steps:16	loss1_train:0.77779	loss2_train:0.69335	loss3_train:0.10727	loss1_test:0.77657	loss2_test:0.69318	loss3_test:0.08846	time:3.413
 5 | [2024-01-08 14:47:53,320][STAMP.py][line:134][INFO] Epoch:[5/10]	steps:16	loss1_train:0.77411	loss2_train:0.69317	loss3_train:0.08467	loss1_test:0.77420	loss2_test:0.69318	loss3_test:0.08353	time:3.686
 6 | [2024-01-08 14:47:56,933][STAMP.py][line:134][INFO] Epoch:[6/10]	steps:16	loss1_train:0.77264	loss2_train:0.69316	loss3_train:0.08380	loss1_test:0.77355	loss2_test:0.69318	loss3_test:0.08373	time:3.277
 7 | [2024-01-08 14:48:00,993][STAMP.py][line:134][INFO] Epoch:[7/10]	steps:16	loss1_train:0.77234	loss2_train:0.69316	loss3_train:0.08357	loss1_test:0.77358	loss2_test:0.69318	loss3_test:0.08341	time:3.747
 8 | [2024-01-08 14:48:04,797][STAMP.py][line:134][INFO] Epoch:[8/10]	steps:16	loss1_train:0.77223	loss2_train:0.69316	loss3_train:0.08344	loss1_test:0.77350	loss2_test:0.69318	loss3_test:0.08340	time:3.507
 9 | [2024-01-08 14:48:08,271][STAMP.py][line:134][INFO] Epoch:[9/10]	steps:16	loss1_train:0.77211	loss2_train:0.69316	loss3_train:0.08344	loss1_test:0.77351	loss2_test:0.69318	loss3_test:0.08340	time:3.156
10 | [2024-01-08 14:48:11,990][STAMP.py][line:134][INFO] Epoch:[10/10]	steps:16	loss1_train:0.77197	loss2_train:0.69316	loss3_train:0.08344	loss1_test:0.77352	loss2_test:0.69318	loss3_test:0.08340	time:3.421
11 | [2024-01-08 14:48:11,996][STAMP.py][line:145][INFO] finish training!
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # STAMP: Toward subtask decomposition-based learning and benchmarking for genetic perturbation outcome prediction and beyond
  2 | [Benchmark datasets on Zenodo](https://zenodo.org/records/12779567)
  3 | ## Introduction 
  4 | This repository hosts the official implementation of STAMP, a method that predicts genetic perturbation outcomes from single-cell RNA-sequencing data of perturbational screens via subtask decomposition learning. STAMP can be applied to three challenges in this area, i.e. (1) predicting single genetic perturbation outcomes, (2) predicting multiple genetic perturbation outcomes and (3) predicting genetic perturbation outcomes across cell lines.
  5 | 
  6 | 
  7 | ## Installation
  8 | Our experiments were conducted with Python 3.9.7 and CUDA 11.4.
  9 | 
 10 | We recommend using Anaconda / Miniconda to create a dedicated conda environment for STAMP. You can create a Python environment with the following command:
 11 | ```bash
 12 | conda create -n stamp python==3.9.7
 13 | ```
 14 | 
 15 | Then, you can activate the environment using:
 16 | ```bash
 17 | conda activate stamp
 18 | ```
 19 | Install PyTorch with the following command:
 20 | ```bash
 21 | conda install pytorch==1.10.2 torchvision==0.11.3 torchaudio==0.10.2 -c pytorch
 22 | ```
 23 | Then install STAMP from source:
 24 | ```bash
 25 | pip install .
 26 | ```
 27 | or you can install it from [PyPI](https://pypi.org/project/cell-stamp/):
 28 | ```
 29 | pip install cell-stamp
 30 | ```
 31 | 
 32 | ## Example data
 33 | We have made available the code to generate example data, which serves as a practical illustration for training and testing the STAMP model. Additionally, for guidance on configuring the training process of STAMP, we provide an example config file at `./Data/example_config.yaml`.
 34 | ```bash
 35 | python ./Data/GeneratingExampleData.py
 36 | ```
 37 | The example *.h5ad data file has three distinct layers, namely 'level1', 'level2', and 'level3'. The 'level1' layer is a binary matrix, where '0' represents non-differentially expressed genes (non-DEGs) and '1' indicates differentially expressed genes (DEGs). Similarly, 'level2' is another binary matrix, denoting down-regulated genes with '0' and up-regulated genes with '1'. Lastly, the 'level3' layer is a matrix that quantifies the magnitude of gene expression changes.
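For example, after running the generation script above you can inspect these layers (a minimal sketch, assuming the example files were written to `./Data` as in `GeneratingExampleData.py`):
```python
import scanpy as sc

adata = sc.read_h5ad("./Data/example_train.h5ad")   # 1000 perturbations x 5000 genes
print(adata.layers["level1"].max())                 # 1 marks a DEG, 0 a non-DEG
print(adata.layers["level2"].max())                 # 1 marks up-regulation, 0 down-regulation
print(adata.layers["level3"][:3, :3].toarray())     # magnitude of expression change
```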
 38 | 
 39 | ## Real demo data
 40 | We have uploaded all benchmark datasets to Zenodo; they can be obtained from [here](https://zenodo.org/records/12779567). Please download all of these files into the `./Data` directory and refer to `tutorial_for_training.py.ipynb` in the `./Tutorial` directory. This tutorial uses one fold of the RPE1_essential dataset as an example to perform model training and testing and to check the loss curves during training.
 41 | #### Note: Users are encouraged to change the path of each data file in `Config.yaml` to match their own machines.
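For example, the data paths live in the `dataset` block of the config (a minimal sketch mirroring the keys and example paths of `./Data/example_config.yaml`; replace the values with the locations of the downloaded files on your machine):
```yaml
dataset:
  Training_dataset: ./Data/example_train.h5ad
  Validation_dataset: ./Data/example_val.h5ad
  Testing_dataset: ./Data/example_test.h5ad
  Gene_embedding: ./Data/example_gene_embeddings.pkl
```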
 42 | 
 43 | ## Core API interface for model training
 44 | Using this API, you can train and test STAMP on your own perturbation datasets with a few lines of code.
 45 | ```python
 46 | from stamp import STAMP, load_config
 47 | import scanpy as sc
 48 | 
 49 | # load config file
 50 | config = load_config("./Data/example_config.yaml")
 51 | 
 52 | # set up and train a STAMP
 53 | model = STAMP(config)
 54 | model.train()
 55 | 
 56 | # load trained model
 57 | model.load_pretrained(f"{config['Train']['output_dir']}/trained_models")
 58 | 
 59 | # use trained model to predict unseen perturbations
 60 | model.prediction(config['dataset']['Testing_dataset'], combo_test = True)
 61 | 
 62 | # use trained model to predict unseen perturbations; considering Top 40 DEGs
 63 | # Top 40 DEGs consisting of Top 20 up-regulation genes and Top 20 down-regulation genes
 64 | 
 65 | # load Top 40 test data
 66 | top_40_data = sc.read_h5ad("./Data/example_test_top40.h5ad")
 67 | 
 68 | # prediction
 69 | model.prediction(top_40_data, combo_test = True)
 70 | ```
 71 | ## Core API interface for model fine-tuning
 72 | Using this API, you can fine-tune and test STAMP on your own perturbation datasets with a few lines of code.
 73 | ```python
 74 | from stamp import STAMP, load_config
 75 | import scanpy as sc
 76 | 
 77 | # load config file (we use the example config used for model training to illustrate this)
 78 | config = load_config("./Data/example_config.yaml")
 79 | 
 80 | # set up STAMP
 81 | model = STAMP(config)
 82 | 
 83 | # load pre-trained model
 84 | model.load_pretrained(f"{config['Train']['output_dir']}/trained_models")
 85 | 
 86 | # fine-tuning model
 87 | model.finetuning()
 88 | 
 89 | # use fine-tuned model to predict unseen perturbations
 90 | model.prediction(config['dataset']['Testing_dataset'], combo_test = False)
 91 | 
 92 | # use fine-tuned model to predict unseen perturbations; considering Top 40 DEGs
 93 | # Top 40 DEGs consisting of Top 20 up-regulation genes and Top 20 down-regulation genes
 94 | 
 95 | # load Top 40 test data
 96 | top_40_data = sc.read_h5ad("./Data/example_test_top40.h5ad")
 97 | 
 98 | # prediction
 99 | model.prediction(top_40_data, combo_test = False)
100 | ```
101 | ## Citation
102 | Yicheng Gao, Zhiting Wei, Qi Liu et al. *Toward subtask decomposition-based learning and benchmarking for genetic perturbation outcome prediction and beyond*, Nature Computational Science, 2024.
103 | ## Contacts
104 | bm2-lab@tongji.edu.cn
105 | 
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/GI_evaluation.py:
--------------------------------------------------------------------------------
  1 | import joblib
  2 | import numpy as np
  3 | from sklearn import tree
  4 | import matplotlib.pyplot as plt
  5 | from sklearn.tree import plot_tree
  6 | from scipy import stats
  7 | import pandas as pd
  8 | np.random.seed(0)
  9 | GIs = {
 10 |     'NEOMORPHIC': ['CBL+TGFBR2',
 11 |                   'KLF1+TGFBR2',
 12 |                   'MAP2K6+SPI1',
 13 |                   'SAMD1+TGFBR2',
 14 |                   'TGFBR2+C19orf26',
 15 |                   'TGFBR2+ETS2',
 16 |                   'CBL+UBASH3A',
 17 |                   'CEBPE+KLF1',
 18 |                   'DUSP9+MAPK1',
 19 |                   'FOSB+PTPN12',
 20 |                   'PLK4+STIL',
 21 |                   'PTPN12+OSR2',
 22 |                   'ZC3HAV1+CEBPE'],
 23 |     'ADDITIVE': ['BPGM+SAMD1',
 24 |                 'CEBPB+MAPK1',
 25 |                 'CEBPB+OSR2',
 26 |                 'DUSP9+PRTG',
 27 |                 'FOSB+OSR2',
 28 |                 'IRF1+SET',
 29 |                 'MAP2K3+ELMSAN1',
 30 |                 'MAP2K6+ELMSAN1',
 31 |                 'POU3F2+FOXL2',
 32 |                 'RHOXF2BB+SET',
 33 |                 'SAMD1+PTPN12',
 34 |                 'SAMD1+UBASH3B',
 35 |                 'SAMD1+ZBTB1',
 36 |                 'SGK1+TBX2',
 37 |                 'TBX3+TBX2',
 38 |                 'ZBTB10+SNAI1'],
 39 |     'EPISTASIS': ['AHR+KLF1',
 40 |                  'MAPK1+TGFBR2',
 41 |                  'TGFBR2+IGDCC3',
 42 |                  'TGFBR2+PRTG',
 43 |                  'UBASH3B+OSR2',
 44 |                  'DUSP9+ETS2',
 45 |                  'KLF1+CEBPA',
 46 |                  'MAP2K6+IKZF3',
 47 |                  'ZC3HAV1+CEBPA'],
 48 |     'REDUNDANT': ['CDKN1C+CDKN1A',
 49 |                  'MAP2K3+MAP2K6',
 50 |                  'CEBPB+CEBPA',
 51 |                  'CEBPE+CEBPA',
 52 |                  'CEBPE+SPI1',
 53 |                  'ETS2+MAPK1',
 54 |                  'FOSB+CEBPE',
 55 |                  'FOXA3+FOXA1'],
 56 |     'POTENTIATION': ['CNN1+UBASH3A',
 57 |                     'ETS2+MAP7D1',
 58 |                     'FEV+CBFA2T3',
 59 |                     'FEV+ISL2',
 60 |                     'FEV+MAP7D1',
 61 |                     'PTPN12+UBASH3A'],
 62 |     'SYNERGY_SIMILAR_PHENO':['CBL+CNN1',
 63 |                             'CBL+PTPN12',
 64 |                             'CBL+PTPN9',
 65 |                             'CBL+UBASH3B',
 66 |                             'FOXA3+FOXL2',
 67 |                             'FOXA3+HOXB9',
 68 |                             'FOXL2+HOXB9',
 69 |                             'UBASH3B+CNN1',
 70 |                             'UBASH3B+PTPN12',
 71 |                             'UBASH3B+PTPN9',
 72 |                             'UBASH3B+ZBTB25'],
 73 |     'SYNERGY_DISSIMILAR_PHENO': ['AHR+FEV',
 74 |                                 'DUSP9+SNAI1',
 75 |                                 'FOXA1+FOXF1',
 76 |                                 'FOXA1+FOXL2',
 77 |                                 'FOXA1+HOXB9',
 78 |                                 'FOXF1+FOXL2',
 79 |                                 'FOXF1+HOXB9',
 80 |                                 'FOXL2+MEIS1',
 81 |                                 'IGDCC3+ZBTB25',
 82 |                                 'POU3F2+CBFA2T3',
 83 |                                 'PTPN12+ZBTB25',
 84 |                                 'SNAI1+DLX2',
 85 |                                 'SNAI1+UBASH3B'],
 86 |     'SUPPRESSOR': ['CEBPB+PTPN12',
 87 |                   'CEBPE+CNN1',
 88 |                   'CEBPE+PTPN12',
 89 |                   'CNN1+MAPK1',
 90 |                   'ETS2+CNN1',
 91 |                   'ETS2+IGDCC3',
 92 |                   'ETS2+PRTG',
 93 |                   'FOSB+UBASH3B',
 94 |                   'IGDCC3+MAPK1',
 95 |                   'LYL1+CEBPB',
 96 |                   'MAPK1+PRTG',
 97 |                   'PTPN12+SNAI1']
 98 | }
 99 | GIs['SYNERGY'] = GIs['SYNERGY_DISSIMILAR_PHENO'] + GIs['SYNERGY_SIMILAR_PHENO'] + GIs['POTENTIATION']
100 | 
101 | all_results = joblib.load("./STAMP_results.pkl")
102 | all_results_truth = joblib.load("./Truth_results.pkl")
103 | all_results_gears = joblib.load("./Gears_results.pkl")
104 | science_data = pd.read_csv("./Science_data.csv", sep = ',')
105 | all_results_science = {}
106 | for idx, name in enumerate(science_data['name']):
107 |     all_results_science[(','.join(name.split('_')))] = {}
108 |     all_results_science[(','.join(name.split('_')))]['magnitude'] = science_data['ts_norm2'][idx]
109 |     all_results_science[(','.join(name.split('_')))]['model_fit'] = science_data['ts_linear_dcor'][idx]
110 |     all_results_science[(','.join(name.split('_')))]['equality_of_contribution'] = science_data['dcor_ratio'][idx]
111 |     all_results_science[(','.join(name.split('_')))]['Similarity'] = science_data['dcor'][idx]
112 |     
113 |     
114 | def calculate_metric(all_results, top_k = 10):
115 |     mags = [all_results[i]['magnitude'] for i in all_results]
116 |     model_fits = [all_results[i]['model_fit'] for i in all_results]
117 |     equality_of_contributions = [all_results[i]['equality_of_contribution'] for i in all_results]
118 |     Similaritys = [all_results[i]['Similarity'] for i in all_results]
119 |     top10_synergy = np.array(list(all_results.keys()))[np.argsort(mags)[::-1][:top_k]]
120 |     top10_precision_synergy = len(set(top10_synergy).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']]))))
121 |     top10_precision_synergy /= top_k
122 |     top10_suppressor = np.array(list(all_results.keys()))[np.argsort(mags)[:top_k]]
123 |     top10_precision_suppressor = len(set(top10_suppressor).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']]))))
124 |     top10_precision_suppressor /= top_k
125 |     top10_neomorphism = np.array(list(all_results.keys()))[np.argsort(model_fits)[:top_k]]
126 |     top10_precision_neomorphism = len(set(top10_neomorphism).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']]))))
127 |     top10_precision_neomorphism /= top_k
128 |     top10_redundant = np.array(list(all_results.keys()))[np.argsort(Similaritys)[::-1][:top_k]]
129 |     top10_precision_redundant = len(set(top10_redundant).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']]))))
130 |     top10_precision_redundant /= 8
131 |     top10_epistasis = np.array(list(all_results.keys()))[np.argsort(equality_of_contributions)[:top_k]]
132 |     top10_precision_epistasis = len(set(top10_epistasis).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']]))))
133 |     top10_precision_epistasis /= 9
134 |     return top10_precision_synergy, top10_precision_suppressor, top10_precision_neomorphism, top10_precision_redundant, top10_precision_epistasis
135 | 
136 | print("Science_ori,Top10 precision:",calculate_metric(all_results_science))
137 | print("Ground_truth,Top10 precision:",calculate_metric(all_results_truth))
138 | print("STAMP,Top10 precision:",calculate_metric(all_results))
139 | print("GEARs,Top10 precision:",calculate_metric(all_results_gears))
140 | 
141 | def calculate_metric_top10acc(all_results,all_results_truth, top_k = 10):
142 |     mags = [all_results[i]['magnitude'] for i in all_results]
143 |     model_fits = [all_results[i]['model_fit'] for i in all_results]
144 |     equality_of_contributions = [all_results[i]['equality_of_contribution'] for i in all_results]
145 |     Similaritys = [all_results[i]['Similarity'] for i in all_results]
146 |     mags_truth = [all_results_truth[i]['magnitude'] for i in all_results_truth]
147 |     model_fits_truth = [all_results_truth[i]['model_fit'] for i in all_results_truth]
148 |     equality_of_contributions_truth = [all_results_truth[i]['equality_of_contribution'] for i in all_results_truth]
149 |     Similaritys_truth = [all_results_truth[i]['Similarity'] for i in all_results_truth]
150 |     
151 |     top10_synergy = np.array(list(all_results.keys()))[np.argsort(mags)[::-1][:top_k]]
152 |     top10_acc_synergy = set(top10_synergy).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']])))
153 |     top10_synergy_truth = np.array(list(all_results_truth.keys()))[np.argsort(mags_truth)[::-1][:top_k]]
154 |     top10_acc_synergy_truth = set(top10_synergy_truth).intersection(set([(',').join(i.split("+")) for i in GIs['SYNERGY']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SYNERGY']])))
155 |     top10_acc_synergy = len(top10_acc_synergy.intersection(top10_acc_synergy_truth))/len(top10_acc_synergy_truth)
156 |     
157 |     top10_suppressor = np.array(list(all_results.keys()))[np.argsort(mags)[:top_k]]
158 |     top10_acc_suppressor = set(top10_suppressor).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']])))
159 |     top10_suppressor_truth = np.array(list(all_results_truth.keys()))[np.argsort(mags_truth)[:top_k]]
160 |     top10_acc_suppressor_truth = set(top10_suppressor_truth).intersection(set([(',').join(i.split("+")) for i in GIs['SUPPRESSOR']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['SUPPRESSOR']])))
161 |     try:
162 |         top10_acc_suppressor = len(top10_acc_suppressor.intersection(top10_acc_suppressor_truth))/len(top10_acc_suppressor_truth)
163 |     except:
164 |         top10_acc_suppressor=0
165 |     
166 |     top10_neomorphism = np.array(list(all_results.keys()))[np.argsort(model_fits)[:top_k]]
167 |     top10_acc_neomorphism = set(top10_neomorphism).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']])))
168 |     top10_neomorphism_truth = np.array(list(all_results_truth.keys()))[np.argsort(model_fits_truth)[:top_k]]  # rank the ground truth by its own model-fit values
169 |     top10_acc_neomorphism_truth = set(top10_neomorphism_truth).intersection(set([(',').join(i.split("+")) for i in GIs['NEOMORPHIC']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['NEOMORPHIC']])))
170 |     top10_acc_neomorphism = len(top10_acc_neomorphism.intersection(top10_acc_neomorphism_truth))/len(top10_acc_neomorphism_truth)
171 |     
172 |     top10_redundant = np.array(list(all_results.keys()))[np.argsort(Similaritys)[::-1][:top_k]]
173 |     top10_acc_redundant = set(top10_redundant).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']])))
174 |     top10_redundant_truth = np.array(list(all_results_truth.keys()))[np.argsort(Similaritys_truth)[::-1][:top_k]]  # rank the ground truth by its own similarity values
175 |     top10_acc_redundant_truth = set(top10_redundant_truth).intersection(set([(',').join(i.split("+")) for i in GIs['REDUNDANT']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['REDUNDANT']])))
176 |     try:
177 |         top10_acc_redundant = len(top10_acc_redundant.intersection(top10_acc_redundant_truth))/len(top10_acc_redundant_truth)
178 |     except:
179 |         top10_acc_redundant=0
180 |     
181 |     top10_epistasis = np.array(list(all_results.keys()))[np.argsort(equality_of_contributions)[:top_k]]
182 |     top10_acc_epistasis = set(top10_epistasis).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']])))
183 |     top10_epistasis_truth = np.array(list(all_results_truth.keys()))[np.argsort(equality_of_contributions_truth)[:top_k]]  # rank the ground truth by its own equality-of-contribution values
184 |     top10_acc_epistasis_truth = set(top10_epistasis_truth).intersection(set([(',').join(i.split("+")) for i in GIs['EPISTASIS']]).union(set([(',').join(i.split("+")[::-1]) for i in GIs['EPISTASIS']])))
185 |     top10_acc_epistasis = len(top10_acc_epistasis.intersection(top10_acc_epistasis_truth))/len(top10_acc_epistasis_truth)
186 |     
187 |     return top10_acc_synergy, top10_acc_suppressor, top10_acc_neomorphism, top10_acc_redundant, top10_acc_epistasis
188 | 
189 | # print("STAMP,Top10 acc:",calculate_metric_top10acc(all_results, all_results_truth))
190 | # print("GEARs,Top10 acc:",calculate_metric_top10acc(all_results_gears, all_results_truth))
191 | 
192 | def array_generation(results):
193 |     X = []
194 |     Y = []
195 |     for idx, combo_gene in enumerate(results):
196 |         if ('+').join(combo_gene.split(",")) in GIs['SYNERGY'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['SYNERGY']:
197 |             Y.append(0)
198 |         elif ('+').join(combo_gene.split(",")) in GIs['SUPPRESSOR'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['SUPPRESSOR']:
199 |             Y.append(1)
200 |         elif ('+').join(combo_gene.split(",")) in GIs['NEOMORPHIC'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['NEOMORPHIC']:
201 |             Y.append(2)
202 |         elif ('+').join(combo_gene.split(",")) in GIs['REDUNDANT'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['REDUNDANT']:
203 |             Y.append(3)
204 |         elif ('+').join(combo_gene.split(",")) in GIs['EPISTASIS'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['EPISTASIS']:
205 |             Y.append(4)
206 |         elif ('+').join(combo_gene.split(",")) in GIs['ADDITIVE'] or ('+').join(combo_gene.split(",")[::-1]) in GIs['ADDITIVE']:
207 |             Y.append(5)
208 |         else:
209 |             continue
210 |         X.append([results[combo_gene][feature] for feature in results[combo_gene]])
211 |     return np.array(X), np.array(Y)     
212 | 
213 | def acc_cal(X_STAMP,Y_STAMP):
214 |     tmp = clf.predict_proba(X_STAMP)
215 |     acc_SYNERGY = (Y_STAMP[Y_STAMP==0]==(tmp.argmax(axis=1)[Y_STAMP==0])).mean()
216 |     acc_SUPPRESSOR = (Y_STAMP[Y_STAMP==1]==(tmp.argmax(axis=1)[Y_STAMP==1])).mean()
217 |     acc_NEOMORPHIC = (Y_STAMP[Y_STAMP==2]==(tmp.argmax(axis=1)[Y_STAMP==2])).mean()
218 |     acc_REDUNDANT = (Y_STAMP[Y_STAMP==3]==(tmp.argmax(axis=1)[Y_STAMP==3])).mean()
219 |     acc_EPISTASIS = (Y_STAMP[Y_STAMP==4]==(tmp.argmax(axis=1)[Y_STAMP==4])).mean()
220 |     acc_Additive = (Y_STAMP[Y_STAMP==5]==(tmp.argmax(axis=1)[Y_STAMP==5])).mean()
221 |     return (acc_SYNERGY,acc_SUPPRESSOR, acc_NEOMORPHIC, acc_REDUNDANT, acc_EPISTASIS, acc_Additive)
222 | 
223 | X_truth,Y_truth = array_generation(all_results_truth)
224 | 
225 | X_truth = (X_truth-X_truth.mean(axis=0))/(X_truth.std(axis=0))
226 | 
227 | clf = tree.DecisionTreeClassifier(random_state=42, max_depth=6,min_samples_leaf=8, min_samples_split=8, max_leaf_nodes=6)
228 | clf = clf.fit(X_truth, Y_truth)
229 | plt.figure(figsize=(12, 8))
230 | plot_tree(clf, filled=True, feature_names=["magnitude", "model_fit", "equality_of_contribution", "Similarity"], class_names=["SYNERGY", "SUPPRESSOR", "NEOMORPHIC", "REDUNDANT", "EPISTASIS","Additive"])
231 | plt.savefig("./test_tree.png")
232 | 
233 | print("Ground_truth, accuracy",acc_cal(X_truth, Y_truth))
234 | 
235 | X_STAMP,Y_STAMP = array_generation(all_results)
236 | X_STAMP = (X_STAMP-X_STAMP.mean(axis=0))/(X_STAMP.std(axis=0))
237 | print("STAMP, accuracy",acc_cal(X_STAMP, Y_STAMP))
238 | 
239 | 
240 | X_GEARs,Y_GEARs = array_generation(all_results_gears)
241 | X_GEARs = (X_GEARs-X_GEARs.mean(axis=0))/(X_GEARs.std(axis=0))
242 | print("GEARs, accuracy",acc_cal(X_GEARs, Y_GEARs))
243 | 
244 | ##### Random test
245 | acc_SYNERGY = 0
246 | acc_SUPPRESSOR = 0
247 | acc_NEOMORPHIC = 0
248 | acc_REDUNDANT = 0
249 | acc_EPISTASIS = 0
250 | acc_Additive = 0
251 | for i in range(10):
252 |     random_idx = list(range(len(Y_truth)))
253 |     np.random.shuffle(random_idx)
254 |     random_Y_truth = Y_truth[random_idx]
255 |     acc_SYNERGY += (Y_truth[Y_truth==0]==(random_Y_truth[Y_truth==0])).mean()
256 |     acc_SUPPRESSOR += (Y_truth[Y_truth==1]==(random_Y_truth[Y_truth==1])).mean()
257 |     acc_NEOMORPHIC += (Y_truth[Y_truth==2]==(random_Y_truth[Y_truth==2])).mean()
258 |     acc_REDUNDANT += (Y_truth[Y_truth==3]==(random_Y_truth[Y_truth==3])).mean()
259 |     acc_EPISTASIS += (Y_truth[Y_truth==4]==(random_Y_truth[Y_truth==4])).mean()
260 |     acc_Additive += (Y_truth[Y_truth==5]==(random_Y_truth[Y_truth==5])).mean()
261 | acc_SYNERGY /= i+1;acc_SUPPRESSOR /= i+1;acc_NEOMORPHIC/=i+1;acc_REDUNDANT/=i+1;acc_EPISTASIS/=i+1;acc_Additive/=i+1
262 | print("Random, accuracy",acc_SYNERGY,acc_SUPPRESSOR,acc_NEOMORPHIC,acc_REDUNDANT,acc_EPISTASIS,acc_Additive)
263 | 
264 | 
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/Gears_results.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/Gears_results.pkl
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/STAMP_results.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/STAMP_results.pkl
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/Science_data.csv:
--------------------------------------------------------------------------------
  1 | name,UMI_double,UMI_first,UMI_second,emap,ts_coef_first,ts_coef_second,de_double,de_first,de_second,ts_norm2,abs_log_ts_ratio,ts_linear_dcor,ts_score,dcor,dcor_singles,dcor_ratio
  2 | AHR_FEV,8636.844697,11797.92902,11839.87342,-5.027742939,1.324721494,0.897119607,3.410777233,3.069298012,3.110926242,1.599909568,0.169274234,0.848803361,0.291686949,0.802051082,0.423128473,0.949834231
  3 | AHR_KLF1,15603.35437,11797.92902,16965.72671,8.108849223,0.287289094,0.961618784,2.556302501,3.069298012,2.626340367,1.003616315,0.524683799,0.77444383,0.643809829,0.564353295,0.228726987,0.325484782
  4 | BPGM_SAMD1,11727.04583,13666.33842,12627.92697,0.133637871,0.881589627,0.562934492,3.030194785,2.45484486,2.346352974,1.045990206,0.194808612,0.826855941,0.824285302,0.776432808,0.428246225,0.996113432
  5 | BPGM_ZBTB1,13104.41696,13666.33842,13298.50174,-0.165732221,0.572838581,0.648653854,2.737987326,2.45484486,2.652246341,0.865387694,0.053980743,0.828207669,0.851846555,0.786469111,0.378112583,0.846149338
  6 | CBL_CNN1,10464.53472,13697.12082,13551.7975,-11.92419707,1.236388617,0.796307064,3.335457901,2.604226053,2.586587305,1.470633113,0.19107443,0.767189707,0.581215221,0.759021105,0.651102721,0.971246728
  7 | CBL_PTPN12,11755.97665,13697.12082,13775.92073,-9.292766305,1.027412249,0.488572154,3.005180513,2.604226053,1.857332496,1.137663693,0.322816028,0.787531436,0.604664938,0.763691494,0.616919891,0.899828956
  8 | CBL_PTPN9,11971.68803,13697.12082,14425.0804,-8.874028919,1.090314344,0.702850075,3.087781418,2.604226053,1.954242509,1.297221491,0.19068903,0.786436654,0.67898102,0.767071313,0.59465612,0.939267985
  9 | CBL_TGFBR2,13510.5641,13697.12082,15409.96646,-7.848298921,0.987539067,0.318175444,2.021189299,2.604226053,0.84509804,1.037530251,0.491887627,0.694873968,0.550328636,0.673294068,0.165872108,0.340666729
 10 | CBL_UBASH3A,11734.24,13697.12082,14870.69841,-12.5880076,1.219414513,0.39278773,1.397940009,2.604226053,0.698970004,1.281114341,0.491993446,0.565911049,0.319847012,0.55715236,0.498398153,0.646665863
 11 | CBL_UBASH3B,10905.36196,13697.12082,14311.16764,-17.37367685,1.189134451,1.209160918,3.339252634,2.604226053,2.389166084,1.695910042,0.00725314,0.81169121,0.607281853,0.806653002,0.722777602,0.971507213
 12 | CDKN1B_CDKN1A,16469.32653,15793.41045,16124.92366,1.298753859,0.811738956,0.552621113,2.626340367,2.691081492,2.589949601,0.981992988,0.166988915,0.858919528,0.857994563,0.852215814,0.779759604,0.986648095
 13 | CDKN1C_CDKN1A,16215.125,16374.58182,16124.92366,1.043144218,0.619520525,0.546857729,2.480006943,2.632457292,2.589949601,0.826352865,0.054181345,0.846051153,0.778233899,0.844506245,0.807302633,0.985484102
 14 | CDKN1C_CDKN1B,16601.72152,16374.58182,15320.97241,0.706050967,0.810262567,0.388682174,2.334453751,2.632457292,1.86923172,0.898665266,0.319031152,0.786032642,0.716122276,0.782493728,0.708938972,0.863180227
 15 | CEBPB_CEBPA,7563.773333,11799,7722.624031,-1.742055369,0.197434739,0.902550261,3.262213705,3.046885191,3.461348434,0.923892553,0.660047827,0.899737549,0.950724936,0.897086826,0.77120447,0.816624294
 16 | CEBPB_MAPK1,11963.90208,11799,13065.41995,0.043571624,0.443488455,0.718158591,3.103461622,3.046885191,2.790988475,0.844057919,0.209338042,0.888166659,0.868354231,0.850716029,0.533365168,0.97015814
 17 | CEBPB_OSR2,11856.60638,11799,13988.30155,-1.254239492,0.818998884,0.644165689,2.972202838,3.046885191,2.541579244,1.041973419,0.104285722,0.878186153,0.822441522,0.847253341,0.461574835,0.812212918
 18 | CEBPB_PTPN12,12005.34211,11799,13775.92073,5.129323636,0.549778509,0.379569796,2.890979597,3.046885191,1.857332496,0.668079067,0.160896113,0.797909156,0.672672467,0.756725077,0.228562525,0.577042676
 19 | CEBPE_CEBPA,6785.266667,12149.16702,7722.624031,0.137461117,0.2941997,0.906974965,3.468051791,3.258397804,3.461348434,0.953497274,0.488953073,0.946981921,0.973038243,0.945129984,0.796302841,0.843371476
 20 | CEBPE_CEBPB,8777.954955,12149.16702,11186.25275,-1.345570837,0.711181374,0.706802479,3.242789809,3.258397804,3.020775488,1.002670779,0.00268231,0.837563188,0.844607024,0.8383896,0.819268136,0.979346064
 21 | CEBPE_CNN1,12539.13402,12149.16702,13551.7975,1.575725753,0.516728029,0.556403091,2.618048097,3.258397804,2.586587305,0.759336721,0.032127514,0.792517361,0.60565655,0.738101034,0.243477827,0.708260752
 22 | CEBPE_KLF1,15994.29049,12149.16702,16965.72671,8.112668533,0.198732644,0.607721817,2.193124598,3.258397804,2.626340367,0.639390703,0.485435618,0.683678832,0.725060609,0.442425592,0.308756986,0.5598106
 23 | CEBPE_PTPN12,12346.59937,12149.16702,13775.92073,4.805238489,0.486174479,0.427682646,2.84509804,3.258397804,1.857332496,0.64751685,0.055670529,0.794011277,0.666559708,0.737870109,0.208562737,0.668670689
 24 | CEBPE_RUNX1T1,13164.67282,12149.16702,15136.08146,7.186707799,0.528024425,0.192067308,3.15715444,3.258397804,1,0.561871555,0.439200563,0.898611354,0.92359612,0.895592766,0.387950582,0.484840303
 25 | CEBPE_SPI1,7469.125714,12149.16702,12152.25893,-0.647642772,1.296935668,0.438928258,3.434888121,3.258397804,3.271376872,1.369196896,0.470524893,0.871711071,0.824827936,0.854669,0.724255829,0.901574566
 26 | CNN1_MAPK1,12634.65538,13734.63136,13065.41995,2.744813323,0.516854892,0.726869999,2.833784375,2.033423755,2.790988475,0.891896281,0.148088113,0.80036494,0.722336562,0.712052998,0.141684934,0.475171325
 27 | CNN1_UBASH3A,10354.79167,13734.63136,14870.69841,-12.62252645,1.172934016,1.394486602,3.158060794,2.033423755,0.698970004,1.822187446,0.075140764,0.625393712,0.420910238,0.600296595,0.33481356,0.822338643
 28 | DUSP9_ETS2,11765.28653,11420.06042,13646.32012,5.73928537,0.745848619,0.251165296,3.262688344,3.31868927,2.627365857,0.787003284,0.472691058,0.902021621,0.909803082,0.86340842,0.177495732,0.312562378
 29 | DUSP9_IGDCC3,11007.87307,11420.06042,14162.75336,-2.855596189,0.872373698,0.558357698,3.245018871,3.31868927,2.480006943,1.035760198,0.193790055,0.892098107,0.887314206,0.853389245,0.408061986,0.730936563
 30 | DUSP9_KLF1,12910.53438,11420.06042,16965.72671,-8.279382679,0.860066154,0.548829397,2.90579588,3.31868927,2.626340367,1.020258544,0.195094492,0.84462104,0.405028022,0.767668309,0.157677804,0.45161062
 31 | DUSP9_MAPK1,13607.44939,11420.06042,13065.41995,10.97294961,0.390445231,0.22762557,2.037426498,3.31868927,2.790988475,0.451952297,0.234339077,0.719977164,0.632241426,0.641788904,0.160302921,0.610895593
 32 | DUSP9_PRTG,11112.5,11420.06042,14082.03378,-0.274218251,0.741686659,0.768578988,3.142389466,3.31868927,2.539076099,1.068088367,0.01546804,0.880315779,0.888419445,0.840874145,0.431651033,0.863363575
 33 | DUSP9_SNAI1,9051.970238,11420.06042,12942.9625,-5.106041145,0.920403475,1.210886763,3.297541668,3.31868927,2.305351369,1.520983007,0.119125282,0.814706915,0.67332464,0.744778431,0.276644636,0.885213216
 34 | ETS2_CEBPE,12600.60625,13415.224,10803.07067,4.08893643,0.554168023,0.547054798,2.943988875,2.334453751,3.412124406,0.778698369,0.005610631,0.881355459,0.874423873,0.849920391,0.397292999,0.715649331
 35 | ETS2_CNN1,13301.58599,13415.224,13551.7975,2.963442466,0.531611948,0.694234984,2.976808337,2.334453751,2.586587305,0.874398923,0.115911762,0.799242773,0.686112904,0.659071405,0.171328119,0.712042781
 36 | ETS2_IGDCC3,13966.67123,13415.224,14162.75336,3.517946458,0.364204027,0.646645865,2.588831726,2.334453751,2.480006943,0.742155945,0.249321762,0.769855918,0.707286895,0.684466017,0.097831906,0.464417285
 37 | ETS2_IKZF3,11333.86598,13415.224,10792.22488,3.213049525,0.433618386,0.783590974,3.315760491,2.334453751,3.197831693,0.895566703,0.256981736,0.91081426,0.958604252,0.895268569,0.285972009,0.454071094
 38 | ETS2_MAP7D1,14029.75849,13415.224,16034.69553,-8.130058286,0.901764675,1.142253745,2.72427587,2.334453751,1.204119983,1.455308609,0.102669372,0.673798982,0.622982588,0.573413405,0.097997848,0.788321061
 39 | ETS2_MAPK1,10452.27778,13415.224,13065.41995,-2.492788735,0.467608915,1.069031241,3.360025089,2.334453751,2.790988475,1.166827276,0.359107614,0.87005775,0.875464014,0.868083617,0.782637405,0.892759196
 40 | ETS2_PRTG,13322.32486,13415.224,14082.03378,0.913348856,0.369143357,0.78031689,2.73239376,2.334453751,2.539076099,0.86322724,0.32507595,0.805695362,0.754610063,0.713911424,0.107060691,0.40320754
 41 | FEV_CBFA2T3,7040.738562,12598.44509,13696.46914,-9.174221894,1.464067071,1.173544895,3.399327532,2.546542663,2.11058971,1.876352848,0.096061264,0.736163457,0.590764104,0.665060987,0.158151012,0.636881675
 42 | FEV_ISL2,9070.136029,12598.44509,14545.00939,-6.76006398,1.074089588,0.938481999,3.363799945,2.546542663,0.903089987,1.426329873,0.05861456,0.760891753,0.549812377,0.732349249,0.202657432,0.532004074
 43 | FEV_MAP7D1,11403.55285,12598.44509,16034.69553,-17.98743136,0.799132772,1.14711789,3.122543524,2.546542663,1.204119983,1.398031703,0.156989112,0.724961339,0.659692891,0.660944997,0.101899662,0.605546768
 44 | FOSB_CEBPB,8741.094118,14237.66534,11186.25275,-0.838076988,0.642261329,1.181266678,3.167317335,2.357934847,3.020775488,1.344578216,0.26463618,0.826199929,0.826870444,0.826074518,0.684825928,0.801750062
 45 | FOSB_CEBPE,9682.746479,14237.66534,10803.07067,-4.770079913,0.342593158,1.140074278,3.23350376,2.357934847,3.412124406,1.190436656,0.522154463,0.880114063,0.902458792,0.879180986,0.703304378,0.772134051
 46 | FOSB_IKZF3,12068.60784,14237.66534,10792.22488,6.979696085,0.455352082,0.552411675,2.871572936,2.357934847,3.197831693,0.715893971,0.083915522,0.852862126,0.93357552,0.829234283,0.340114162,0.601480227
 47 | FOSB_OSR2,12804.59877,14237.66534,13988.30155,-2.286219529,0.94285204,0.84158268,2.897077003,2.357934847,2.541579244,1.263816196,0.049346756,0.843472206,0.717954019,0.797591593,0.383793997,0.969500562
 48 | FOSB_PTPN12,13597.76632,14237.66534,13775.92073,6.08485282,0.407665944,0.537729177,2.204119983,2.357934847,1.857332496,0.67479196,0.12025917,0.667084148,0.545325012,0.60012142,0.144863939,0.620875
 49 | FOSB_UBASH3B,13442.14325,14237.66534,14311.16764,5.691658063,0.552437883,0.532223703,2.307496038,2.357934847,2.389166084,0.767104742,0.01618924,0.749868517,0.680892429,0.67896995,0.206272503,0.806498428
 50 | FOXA1_FOXF1,11799.84388,13616.223,14341.4794,-3.404568293,1.096066242,0.689994223,2.959518377,2.320146286,1.204119983,1.295165331,0.200991348,0.736810827,0.627424694,0.720773123,0.525740726,0.821520243
 51 | FOXA1_FOXL2,10116.68715,13616.223,11533.13278,-1.762576429,0.555985598,1.002491852,3.105169428,2.320146286,2.79518459,1.146346326,0.256017309,0.786560694,0.687887234,0.777181775,0.579837524,0.810366005
 52 | FOXA1_HOXB9,9851.755639,13616.223,11882.16406,-3.677702592,0.835984464,1.084173293,3.322012439,2.320146286,2.666517981,1.369051406,0.112900499,0.812205276,0.703061854,0.796701741,0.539119633,0.82919151
 53 | FOXA3_FOXA1,11293.43094,12535.74328,13077.36418,-0.895729844,1.227342903,0.507016092,3.027349608,2.980457892,2.301029996,1.327944245,0.383944172,0.865852363,0.828298989,0.858789426,0.725421573,0.861392296
 54 | FOXA3_FOXF1,11358.5274,12535.74328,14341.4794,-1.690503757,1.096812933,0.279254637,2.758911892,2.980457892,1.204119983,1.131804649,0.59413217,0.787445666,0.671841143,0.781515531,0.557535841,0.689736249
 55 | FOXA3_FOXL2,8891.40708,12535.74328,11533.13278,-0.947739512,1.093592622,0.908210321,3.181271772,2.980457892,2.79518459,1.421545219,0.080669139,0.786489903,0.668007277,0.780517965,0.694711141,0.98975022
 56 | FOXA3_HOXB9,9656.863787,12535.74328,11882.16406,-2.790099564,1.137353739,0.797940651,3.389520466,2.980457892,2.666517981,1.389346181,0.153924969,0.843231765,0.718681802,0.833232915,0.680176106,0.960565661
 57 | FOXF1_FOXL2,9703.034682,14135.87277,11533.13278,-2.28736857,0.786630895,1.233076742,3.224533063,1.949390007,2.79518459,1.462623129,0.195219107,0.782206312,0.682867138,0.778978448,0.602816606,0.811885278
 58 | FOXF1_HOXB9,10817.24855,14135.87277,11882.16406,-2.890810768,0.876665612,0.930986539,2.996073654,1.949390007,2.666517981,1.278780095,0.02610943,0.791389107,0.690310743,0.776576923,0.508089544,0.812075189
 59 | FOXL2NB_FOXL2,11012.59184,11064.30415,11533.13278,1.779331866,0.716285078,0.361353683,1.414973348,2.931966115,2.79518459,0.802272271,0.297153419,0.647031936,0.515983634,0.647738214,0.782128897,0.925137448
 60 | FOXL2_HOXB9,8482.770115,11907.10891,11882.16406,-1.329540145,1.004384228,0.898307234,3.04060234,2.902002891,2.666517981,1.347495293,0.048474987,0.808048675,0.71022185,0.800533292,0.679832039,0.998586688
 61 | FOXL2_MEIS1,10409.43503,11907.10891,13716.44872,-2.39977222,0.921020134,0.766142188,3.015778756,2.902002891,1.707570176,1.198020008,0.079959747,0.79626318,0.699111897,0.777135414,0.487731079,0.791939297
 62 | IGDCC3_MAPK1,14003.10239,13874.75352,13065.41995,4.300250641,0.455597543,0.438698989,2.431363764,2.629409599,2.790988475,0.632476027,0.01641474,0.729522831,0.609531652,0.651376959,0.163937861,0.926295594
 63 | IGDCC3_PRTG,12710.94444,13874.75352,14082.03378,1.608980686,0.586107352,0.634582591,2.346352974,2.629409599,2.539076099,0.863838465,0.034510984,0.782351469,0.807184237,0.781076535,0.772179546,0.991147957
 64 | IGDCC3_ZBTB25,12853.12121,13874.75352,14683.2996,-6.576366626,1.110245165,0.783264841,2.733999287,2.629409599,2.136720567,1.358730341,0.151510258,0.806253481,0.761795215,0.787228576,0.524281876,0.833457845
 65 | IRF1_SET,12962.80926,15671.99698,13421.56475,-6.988736995,0.697645225,0.795742337,3.195622944,2.764176132,2.908485019,1.058260236,0.057137839,0.880045613,0.838025145,0.825844081,0.325779767,0.881354135
 66 | JUN_CEBPA,6258.328571,11627.30213,7722.624031,-0.911147039,0.362276276,0.986779584,3.314288661,2.86332286,3.461348434,1.05117936,0.435180261,0.906509709,0.944949647,0.90102418,0.437827672,0.527494198
 67 | JUN_CEBPB,8086.211538,11627.30213,11186.25275,-1.561250115,0.663613429,0.896706075,2.86569606,2.86332286,3.020775488,1.115555722,0.130734946,0.775955397,0.71301994,0.760118685,0.511400601,0.806377417
 68 | KLF1_CEBPA,10018.12987,16941.35607,7722.624031,10.07008471,0.191513545,0.723248907,3.411282913,2.833784375,3.461348434,0.748175393,0.57708829,0.91467045,0.946912391,0.909806086,0.156143605,0.132127382
 69 | KLF1_CLDN6,13293.6186,16941.35607,12557.41065,-7.140357162,0.731443935,0.753947434,2.788875116,2.833784375,3.005180513,1.050450838,0.013160024,0.749898908,0.755937195,0.652890816,0.220085605,0.51273542
 70 | KLF1_COL2A1,15956.68526,16941.35607,14324.93392,-7.527070526,0.625799336,0.803624774,2.763427994,2.833784375,2.669316881,1.018546801,0.108618218,0.769249034,0.623724153,0.703331008,0.095874437,0.464221231
 71 | KLF1_FOXA1,14289.12074,16941.35607,13077.36418,-4.660041825,0.733820937,0.82342367,2.5774918,2.833784375,2.301029996,1.102959523,0.050033248,0.702063063,0.656328187,0.56665259,0.17001451,0.828928502
 72 | KLF1_MAP2K6,18223.34234,16941.35607,16847.97619,-1.245393582,0.774651018,0.432913551,2.26245109,2.833784375,2.113943352,0.887411033,0.252704915,0.810674387,0.800950346,0.770662068,0.409285776,0.746892754
 73 | KLF1_TGFBR2,16120.23194,16941.35607,15409.96646,-8.882663739,0.746960196,0.30598057,2.181843588,2.833784375,0.84509804,0.807201117,0.387603611,0.701582403,0.40229939,0.678914207,0.095663258,0.293754476
 74 | LHX1_ELMSAN1,12547.65814,12104.16667,16978.7907,-6.676777988,0.841852226,0.441000502,2.988558957,2.705007959,1.851258349,0.950366568,0.28079678,0.74837991,0.638229829,0.666713304,0.328110716,0.154103481
 75 | LYL1_CEBPB,12723.45912,13887.6,11186.25275,6.734497397,0.468963309,0.499630597,2.426511261,1.633468456,3.020775488,0.685242526,0.027510162,0.749679529,0.569130773,0.729852975,0.29445261,0.558705996
 76 | LYL1_IER5L,10272.14991,13887.6,11958.24103,-5.971554097,0.301680001,0.985819281,3.421274791,1.633468456,2.681241237,1.030946399,0.514250787,0.813960909,0.771745767,0.809115129,0.489713229,0.609163059
 77 | MAP2K3_ELMSAN1,20565.01446,19002.51092,16978.7907,-4.616741496,0.907377257,0.769829495,3.177824972,2.795880017,1.851258349,1.189945771,0.071393343,0.853404464,0.724937032,0.837294085,0.56677764,0.861671302
 78 | MAP2K3_IKZF3,14750.99091,19002.51092,10792.22488,9.537746863,0.248636517,0.61311005,2.617000341,2.795880017,3.197831693,0.661607173,0.391973521,0.791968058,0.360803035,0.763459396,0.418720863,0.257814491
 79 | MAP2K3_MAP2K6,19073.95633,19002.51092,16847.97619,-5.182336519,0.55500612,0.432476343,2.790988475,2.795880017,2.113943352,0.703610389,0.108335416,0.876210209,0.83625186,0.871670331,0.73734692,0.942076989
 80 | MAP2K3_SLC38A2,18358.99327,19002.51092,15759.45357,-1.534254356,0.768911252,0.437098605,2.599883072,2.795880017,1.230448921,0.884465774,0.245296796,0.82213848,0.757333072,0.796772502,0.356687704,0.62343575
 81 | MAP2K6_ELMSAN1,18832.58904,18377.71625,16978.7907,-3.780749567,0.856571607,0.55083486,2.754348336,2.324282455,1.851258349,1.018397742,0.191742257,0.821666947,0.748537531,0.799020209,0.500132649,0.828879281
 82 | MAP2K6_IKZF3,13342.98805,18377.71625,10792.22488,8.961105885,0.251362864,0.665459856,2.84135947,2.324282455,3.197831693,0.711350905,0.422820746,0.846312187,0.704434855,0.828969486,0.352050343,0.240160044
 83 | MAP2K6_SPI1,15388.49174,18377.71625,12152.25893,7.691201953,0.196648506,0.459466711,2.290034611,2.324282455,3.271376872,0.499780245,0.3685634,0.711306252,0.728633493,0.687616749,0.413254143,0.285389156
 84 | MAPK1_IKZF3,11846.40152,13076.2509,10792.22488,3.415792474,0.415440051,0.600631396,3.105510185,2.59439255,3.197831693,0.730307134,0.160099667,0.861870425,0.927883881,0.831736857,0.341522917,0.643318459
 85 | MAPK1_PRTG,14127.37368,13076.2509,14082.03378,1.350838838,0.410133394,0.562854605,2.664641976,2.59439255,2.539076099,0.696429973,0.137471091,0.742901276,0.589695206,0.646810737,0.131344164,0.884243716
 86 | MAPK1_TGFBR2,13754.95887,13076.2509,15409.96646,3.95567623,0.746940705,0.136640751,2.636487896,2.59439255,0.84509804,0.759335968,0.737705889,0.806045323,0.801448703,0.796959163,0.233062177,0.206835257
 87 | PLK4_STIL,12173.63636,14449.81443,12756.88652,-2.655264906,0.419465981,0.752439348,1.556302501,0,0.903089987,0.861461944,0.253774753,0.524852422,0.324729848,0.499191032,0.40318767,0.776153436
 88 | POU3F2_CBFA2T3,10481.31188,12106.46154,13696.46914,-6.602857974,1.171567589,0.621729535,3.0923697,2.58546073,2.11058971,1.326317545,0.27516585,0.793951694,0.706853576,0.750946609,0.35377785,0.740505724
 89 | POU3F2_FOXL2,8979.702797,12106.46154,11533.13278,-3.005394022,0.92270723,1.020388228,3.401400541,2.58546073,2.79518459,1.375711004,0.043701516,0.8408346,0.771167876,0.82431596,0.583569552,0.919215676
 90 | PRDM1_CBFA2T3,7352.633028,10256.42857,13696.46914,-2.535841296,1.224280108,0.499653669,3.315340477,3.245759356,2.11058971,1.322314475,0.389211712,0.815008139,0.682092841,0.801459347,0.36569377,0.549864483
 91 | PTPN12_OSR2,14426.18375,14014.87629,13988.30155,7.367622216,0.288280185,0.332547939,1.633468456,1.826074803,2.541579244,0.440106347,0.062039468,0.575267573,0.497319864,0.484016602,0.135709611,0.574755643
 92 | PTPN12_PTPN9,11696.46102,14014.87629,14425.0804,-9.005757018,0.7383842,0.868063435,3.157456768,1.826074803,1.954242509,1.139625094,0.070269068,0.776336812,0.655984144,0.752076023,0.4861733,0.962740284
 93 | PTPN12_SNAI1,12319.45902,14014.87629,12942.9625,10.07555502,0.600056214,0.573352956,2.701567985,1.826074803,2.305351369,0.829940403,0.019769882,0.757667528,0.656193036,0.699233463,0.226777577,0.782388504
 94 | PTPN12_UBASH3A,11319.51765,14014.87629,14870.69841,-11.66040496,1.07642032,1.018797608,3.157758886,1.826074803,0.698970004,1.482102989,0.023893971,0.707980549,0.478099864,0.684474996,0.351565613,0.728704695
 95 | PTPN12_ZBTB25,11876.8677,14014.87629,14683.2996,-8.056528667,0.751759062,0.983985647,3.067070856,1.826074803,2.136720567,1.238292954,0.116910092,0.794687457,0.6872563,0.768537561,0.494841647,0.942712472
 96 | RHOXF2B_SET,10323.2397,10930.4069,13421.56475,7.164357199,0.705755666,0.609668871,3.300378065,3.207634367,2.908485019,0.932623822,0.063560353,0.872215805,0.826291702,0.840901018,0.490519694,0.953460818
 97 | RHOXF2B_ZBTB25,11005.28293,10930.4069,14683.2996,6.847164341,0.763576583,0.802227653,3.127428778,3.207634367,2.136720567,1.107528061,0.021445027,0.851454028,0.716207176,0.819650761,0.361491225,0.682564476
 98 | SAMD1_PTPN12,10546.71575,11822.51587,13775.92073,-6.398557523,0.909073745,0.551394305,3.285782274,2.913283902,1.857332496,1.063226577,0.217136839,0.828934693,0.704138241,0.798401282,0.431435321,0.759486087
 99 | SAMD1_TGFBR2,11972.27119,11822.51587,15409.96646,-8.107851289,0.869725389,0.43445474,1.322219295,2.913283902,0.84509804,0.972200171,0.301437608,0.640896979,0.599548983,0.632411771,0.080093324,0.228092849
100 | SAMD1_UBASH3B,10741.57416,11822.51587,14311.16764,-7.528617025,0.893324788,0.917398582,3.211921084,2.913283902,2.389166084,1.280487929,0.011548679,0.84098202,0.778241998,0.825687121,0.545809017,0.814472575
101 | SAMD1_ZBTB1,10795.85366,11822.51587,13298.50174,0.372634559,0.749627701,0.827827763,3.046495164,2.913283902,2.652246341,1.116799219,0.043094361,0.849807895,0.796000237,0.816243375,0.473419466,0.959725238
102 | SET_CEBPE,13110.97543,14035.45117,10803.07067,3.821361642,0.624992472,0.203585476,3.104145551,3.083860801,3.412124406,0.657314716,0.487127994,0.855534698,0.784660789,0.773792775,0.396066676,0.735155459
103 | SET_KLF1,16486.39926,14035.45117,16965.72671,-4.790630364,0.666004631,0.724769436,2.924795996,3.083860801,2.626340367,0.984303258,0.036722621,0.836322527,0.697481086,0.686002014,0.236300229,0.580657286
104 | SGK1_TBX2,12025.19512,12336.787,13914.72881,1.310275508,0.804936225,0.731411338,3.14082218,3.10720997,2.434568904,1.087605108,0.041599784,0.868834244,0.865735206,0.842236927,0.51071007,0.959354312
105 | SGK1_TBX3,12739.70782,12336.787,14505,-0.389360046,0.784773698,0.647871406,2.894316063,3.10720997,2.133538908,1.017647835,0.083255627,0.890050364,0.736125208,0.923203425,0.370183246,0.682986951
106 | SNAI1_DLX2,9787,13686.24545,12835.45172,-4.38912823,1.14532526,0.882836892,3.098643726,2.285557309,2.591064607,1.446088147,0.113048366,0.744400927,0.626867608,0.721699525,0.499316035,0.873888934
107 | SNAI1_UBASH3B,12653.79297,13686.24545,14311.16764,6.9698056,0.736916087,0.891443369,2.855519156,2.285557309,2.389166084,1.156596991,0.082675721,0.818085084,0.767350742,0.773401166,0.351883827,0.822914157
108 | TBX3_TBX2,13807.30031,14376.44637,13914.72881,-0.952148354,0.735713499,0.64424836,3.217220656,2.117271296,2.434568904,0.977921419,0.057655402,0.900807612,0.892832212,0.88290534,0.35702695,0.86670143
109 | TGFBR2_CBARP,12145.62409,16065.11483,14362.77407,-7.115548769,0.090276718,0.850719808,2.742725131,0.301029996,1.716003344,0.855496393,0.974210783,0.537434279,0.398571272,0.503398657,0.10339338,0.209380757
110 | TGFBR2_ETS2,14768.47636,16065.11483,13646.32012,5.818769871,0.116972048,0.567208074,1.755874856,0.301029996,2.627365857,0.579143729,0.685660311,0.633545702,0.582562968,0.596019738,0.238684094,0.152343388
111 | TGFBR2_IGDCC3,13191.4542,16065.11483,14162.75336,0.485365858,0.122216034,0.760712652,2.549003262,0.301029996,2.480006943,0.770467714,0.794092454,0.756707278,0.724232417,0.742370505,0.172887454,0.262692258
112 | TGFBR2_PRTG,12900.80769,16065.11483,14082.03378,1.630267601,0.051310753,0.923956999,2.544068044,0.301029996,2.539076099,0.925380641,1.255443368,0.784791742,0.823568472,0.768317952,0.203975856,0.248194348
113 | UBASH3B_CNN1,10725.07324,14063.23404,13551.7975,-16.58980658,1.212246283,0.962860571,3.3586961,2.378397901,2.586587305,1.548109017,0.100027458,0.808383898,0.682219701,0.798315986,0.621193342,0.96783611
114 | UBASH3B_OSR2,14884.37075,14063.23404,13988.30155,6.65485367,-0.01838437,0.367629526,2.11058971,2.378397901,2.541579244,0.368088921,1.300961639,0.682917797,0.622222396,0.644576921,0.260723913,0.315819702
115 | UBASH3B_PTPN12,11332.46154,14063.23404,13775.92073,-10.76779039,1.021372069,0.595685121,3.122215878,2.378397901,1.857332496,1.18238812,0.234167225,0.764947361,0.591454917,0.741417699,0.587460103,0.948035612
116 | UBASH3B_PTPN9,12555.31915,14063.23404,14425.0804,-7.66800543,1.289197478,0.653431269,3.098643726,2.378397901,1.954242509,1.445338217,0.295119534,0.812567054,0.711899225,0.793137969,0.664629088,0.936463678
117 | UBASH3B_UBASH3A,14188.25352,14063.23404,14870.69841,2.623744933,0.913865214,0.410892151,2.683947131,2.378397901,0.698970004,1.001989017,0.347154301,0.78296664,0.726350132,0.760406899,0.477881348,0.709040053
118 | UBASH3B_ZBTB25,12408.12946,14063.23404,14683.2996,-8.06503663,1.02071028,0.687041565,2.895974732,2.378397901,2.136720567,1.230396517,0.171919477,0.790346045,0.59809173,0.780251964,0.663206049,0.981034076
119 | ZBTB10_DLX2,8448.364865,10531.65517,12835.45172,-3.74965872,1.066279561,0.846798705,3.13481437,2.988558957,2.591064607,1.361624085,0.1000909,0.792719301,0.683952394,0.781513116,0.400261218,0.617034869
120 | ZBTB10_ELMSAN1,10288.49068,10531.65517,16978.7907,-7.543480386,1.009903688,0.172350471,3.121231455,2.988558957,1.851258349,1.024504828,0.767867484,0.849333398,0.809231506,0.843521274,0.39078376,0.356840993
121 | ZBTB10_PTPN12,10124.58261,10531.65517,13775.92073,7.752122403,0.728830991,0.455160448,3.254064453,2.988558957,1.857332496,0.859282053,0.204462315,0.844756886,0.750872545,0.82377239,0.237415929,0.452185145
122 | ZBTB10_SNAI1,8498.470588,10531.65517,12942.9625,-1.716184098,0.935797128,0.63951727,3.14113609,2.988558957,2.305351369,1.133445457,0.165329431,0.828663341,0.681733092,0.817607998,0.549498791,0.790953551
123 | ZC3HAV1_CEBPA,7989.677419,16377.05505,7722.624031,-9.826651018,0.318894432,0.891369331,3.36078269,2.045322979,3.461348434,0.946695803,0.446410751,0.912331818,0.940605339,0.909385192,0.236326219,0.194054803
124 | ZC3HAV1_CEBPE,15794.74566,16377.05505,10803.07067,8.288154134,0.63493671,0.215502289,2.149219113,2.045322979,3.412124406,0.670511643,0.46927855,0.690145499,0.720944273,0.486556144,0.21748292,0.968025272
125 | ZC3HAV1_HOXC13,12174.1418,16377.05505,10738.57813,6.368610928,0.53804059,0.604288531,3.135450699,2.045322979,3.015778756,0.809105868,0.050429311,0.791498403,0.68866069,0.725703468,0.251890652,0.193133306
126 | ZNF318_FOXL2,12446.00508,15384.52865,11533.13278,-3.079212783,0.816195732,0.732563732,2.609594409,1.763427994,2.79518459,1.09673383,0.046948906,0.762681444,0.670256568,0.723900023,0.217552433,0.577758247
127 | 
--------------------------------------------------------------------------------
/Supp_code/GI_experiment/Truth_results.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/Supp_code/GI_experiment/Truth_results.pkl
--------------------------------------------------------------------------------
/Supp_code/Modified_CPA/_data.py:
--------------------------------------------------------------------------------
 1 | from typing import Optional
 2 | 
 3 | from scvi import settings
 4 | from scvi.data import AnnDataManager
 5 | from scvi.dataloaders import DataSplitter, AnnDataLoader
 6 | from scvi.model._utils import parse_use_gpu_arg
 7 | 
 8 | 
 9 | class AnnDataSplitter(DataSplitter):
10 |     def __init__(
11 |             self,
12 |             adata_manager: AnnDataManager,
13 |             train_indices,
14 |             valid_indices,
15 |             test_indices,
16 |             use_gpu: bool = False,
17 |             **kwargs,
18 |     ):
19 |         super().__init__(adata_manager)
20 |         self.data_loader_kwargs = kwargs
21 |         self.use_gpu = use_gpu
22 |         self.train_idx = train_indices
23 |         self.val_idx = valid_indices
24 |         self.test_idx = test_indices
25 | 
26 |     def setup(self, stage: Optional[str] = None):
27 |         accelerator, _, self.device = parse_use_gpu_arg(
28 |             self.use_gpu, return_device=True
29 |         )
30 |         self.pin_memory = (
31 |             True
32 |             if (settings.dl_pin_memory_gpu_training and accelerator == "gpu")
33 |             else False
34 |         )
35 | 
36 |     def train_dataloader(self):
37 |         if len(self.train_idx) > 0:
38 |             return AnnDataLoader(
39 |                 self.adata_manager,
40 |                 indices=self.train_idx,
41 |                 shuffle=True,
42 |                 pin_memory=self.pin_memory,
43 |                 **self.data_loader_kwargs,
44 |             )
45 |         else:
46 |             pass
47 | 
48 |     def val_dataloader(self):
49 |         if len(self.val_idx) > 0:
50 |             data_loader_kwargs = self.data_loader_kwargs.copy()
51 |             # if len(self.valid_indices < 4096):
52 |             #     data_loader_kwargs.update({'batch_size': len(self.valid_indices)})
53 |             # else:
54 |             #     data_loader_kwargs.update({'batch_size': 2048})
55 |             return AnnDataLoader(
56 |                 self.adata_manager,
57 |                 indices=self.val_idx,
58 |                 shuffle=True,
59 |                 pin_memory=self.pin_memory,
60 |                 **data_loader_kwargs,
61 |             )
62 |         else:
63 |             pass
64 | 
65 |     def test_dataloader(self):
66 |         if len(self.test_idx) > 0:
67 |             return AnnDataLoader(
68 |                 self.adata_manager,
69 |                 indices=self.test_idx,
70 |                 shuffle=True,
71 |                 pin_memory=self.pin_memory,
72 |                 **self.data_loader_kwargs,
73 |             )
74 |         else:
75 |             pass
76 | 
--------------------------------------------------------------------------------
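A minimal usage sketch for the AnnDataSplitter above (not part of the repository; the AnnDataManager name, the split fractions, and the batch size are illustrative assumptions):

import numpy as np

# assume `adata_manager` is an scvi-tools AnnDataManager already registered on an AnnData object
n_obs = adata_manager.adata.n_obs
idx = np.random.permutation(n_obs)
train_idx = idx[: int(0.8 * n_obs)]
valid_idx = idx[int(0.8 * n_obs): int(0.9 * n_obs)]
test_idx = idx[int(0.9 * n_obs):]

splitter = AnnDataSplitter(
    adata_manager,
    train_indices=train_idx,
    valid_indices=valid_idx,
    test_indices=test_idx,
    use_gpu=False,
    batch_size=128,   # forwarded to AnnDataLoader through **kwargs
)
splitter.setup()
train_loader = splitter.train_dataloader()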
/Supp_code/Modified_CPA/_metrics.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from scipy.stats import entropy
 3 | from sklearn.neighbors import NearestNeighbors
 4 | from sklearn.preprocessing import LabelEncoder
 5 | 
 6 | 
 7 | def knn_purity(data, labels: np.ndarray, n_neighbors=30):
 8 |     """Computes KNN Purity for ``data`` given the labels.
 9 |         Parameters
10 |         ----------
11 |         data: np.ndarray
12 |             Numpy ndarray of data
13 |         labels: np.ndarray
14 |             Numpy ndarray of labels
15 |         n_neighbors: int
16 |             Number of nearest neighbors.
17 |         Returns
18 |         -------
19 |         score: float
20 |             KNN purity score. A float between 0 and 1.
21 |     """
22 |     labels = LabelEncoder().fit_transform(labels.ravel())
23 | 
24 |     nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data)
25 |     indices = nbrs.kneighbors(data, return_distance=False)[:, 1:]
26 |     neighbors_labels = np.vectorize(lambda i: labels[i])(indices)
27 | 
28 |     # per-cell purity scores
29 |     scores = ((neighbors_labels - labels.reshape(-1, 1)) == 0).mean(axis=1)
30 |     res = [
31 |         np.mean(scores[labels == i]) for i in np.unique(labels)
32 |     ]  # per cell-type purity
33 | 
34 |     return np.mean(res)
35 | 
36 | 
37 | def entropy_batch_mixing(data, labels,
38 |                          n_neighbors=50, n_pools=50, n_samples_per_pool=100):
39 |     """Computes the Entropy of Batch Mixing (EBM) metric for ``data`` given the batch labels.
40 |         Parameters
41 |         ----------
42 |         data: np.ndarray
43 |             Numpy ndarray of data
44 |         labels: np.ndarray
45 |             Numpy ndarray of batch labels
46 |         n_neighbors: int
47 |             Number of nearest neighbors.
48 |         n_pools: int
49 |             Number of EBM computation which will be averaged.
50 |         n_samples_per_pool: int
51 |             Number of samples to be used in each pool of execution.
52 |         Returns
53 |         -------
54 |         score: float
55 |             EBM score. A float between zero and one.
56 |     """
57 | 
58 |     def __entropy_from_indices(indices, n_cat):
59 |         return entropy(np.array(np.unique(indices, return_counts=True)[1].astype(np.int32)), base=n_cat)
60 | 
61 |     n_cat = len(np.unique(labels))
62 |     # print(f'Calculating EBM with n_cat = {n_cat}')
63 | 
64 |     neighbors = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data)
65 |     indices = neighbors.kneighbors(data, return_distance=False)[:, 1:]
66 |     batch_indices = np.vectorize(lambda i: labels[i])(indices)
67 | 
68 |     entropies = np.apply_along_axis(__entropy_from_indices, axis=1, arr=batch_indices, n_cat=n_cat)
69 | 
70 |     # average n_pools entropy results where each result is an average of n_samples_per_pool random samples.
71 |     if n_pools == 1:
72 |         score = np.mean(entropies)
73 |     else:
74 |         score = np.mean([
75 |             np.mean(entropies[np.random.choice(len(entropies), size=n_samples_per_pool)])
76 |             for _ in range(n_pools)
77 |         ])
78 | 
79 |     return score
80 | 
--------------------------------------------------------------------------------
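A self-contained sketch of how the two metrics above might be called (synthetic inputs only; the array shapes and label counts are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(0)
latent = rng.normal(size=(300, 16))      # e.g. 300 cells embedded in a 16-d latent space
labels = rng.integers(0, 3, size=300)    # e.g. three perturbation or batch categories

purity = knn_purity(latent, labels, n_neighbors=15)
ebm = entropy_batch_mixing(latent, labels, n_neighbors=15, n_pools=10, n_samples_per_pool=50)
print(f"kNN purity: {purity:.3f}  entropy of batch mixing: {ebm:.3f}")

Higher purity means neighbors tend to share the same label, while a higher EBM score indicates better mixing across batches.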
/Supp_code/Modified_CPA/_module.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import torch
  3 | import torch.nn as nn
  4 | from scvi import settings
  5 | from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
  6 | from scvi.module import Classifier
  7 | from scvi.module.base import BaseModuleClass, auto_move_data
  8 | from scvi.nn import Encoder, DecoderSCVI
  9 | from torch.distributions import Normal
 10 | from torch.distributions.kl import kl_divergence as kl
 11 | from torchmetrics.functional import accuracy, pearson_corrcoef, r2_score
 12 | 
 13 | from ._metrics import knn_purity
 14 | from ._utils import PerturbationNetwork, VanillaEncoder, CPA_REGISTRY_KEYS
 15 | 
 16 | from typing import Optional
 17 | 
 18 | 
 19 | class CPAModule(BaseModuleClass):
 20 |     """
 21 |     CPA module using Gaussian/NegativeBinomial/Zero-InflatedNegativeBinomial Likelihood
 22 | 
 23 |     Parameters
 24 |     ----------
 25 |         n_genes: int
 26 |             Number of input genes
 27 |         n_perts: int,
 28 |             Number of total unique perturbations
 29 |         covars_encoder: dict
 30 |             Dictionary of covariates with keys as each covariate name and values as
 31 |                 unique values of the corresponding covariate
 32 |         n_latent: int
 33 |             dimensionality of the latent space
 34 |         recon_loss: str
 35 |             Autoencoder loss (either "gauss", "nb" or "zinb")
 36 |         doser_type: str
 37 |             Type of dosage network (either "logsigm", "sigm", or "linear")
 38 |         n_hidden_encoder: int
 39 |             Number of hidden units in encoder
 40 |         n_layers_encoder: int
 41 |             Number of layers in encoder
 42 |         n_hidden_decoder: int
 43 |             Number of hidden units in decoder
 44 |         n_layers_decoder: int
 45 |             Number of layers in decoder
 46 |         n_hidden_doser: int
 47 |             Number of hidden units in dosage network
 48 |         n_layers_doser: int
 49 |             Number of layers in dosage network
 50 |         use_batch_norm_encoder: bool
 51 |             Whether to use batch norm in encoder
 52 |         use_layer_norm_encoder: bool
 53 |             Whether to use layer norm in encoder
 54 |         use_batch_norm_decoder: bool
 55 |             Whether to use batch norm in decoder
 56 |         use_layer_norm_decoder: bool
 57 |             Whether to use layer norm in decoder
 58 |         dropout_rate_encoder: float
 59 |             Dropout rate in encoder
 60 |         dropout_rate_decoder: float
 61 |             Dropout rate in decoder
 62 |         variational: bool
 63 |             Whether to use variational inference
 64 |         seed: int
 65 |             Random seed
 66 |     """
 67 | 
 68 |     def __init__(self,
 69 |                  n_genes: int,
 70 |                  n_perts: int,
 71 |                  covars_encoder: dict,
 72 |                  drug_embeddings: Optional[np.ndarray] = None,
 73 |                  n_latent: int = 128,
 74 |                  recon_loss: str = "nb",
 75 |                  doser_type: str = "logsigm",
 76 |                  n_hidden_encoder: int = 256,
 77 |                  n_layers_encoder: int = 3,
 78 |                  n_hidden_decoder: int = 256,
 79 |                  n_layers_decoder: int = 3,
 80 |                  n_hidden_doser: int = 128,
 81 |                  n_layers_doser: int = 2,
 82 |                  use_batch_norm_encoder: bool = True,
 83 |                  use_layer_norm_encoder: bool = False,
 84 |                  use_batch_norm_decoder: bool = True,
 85 |                  use_layer_norm_decoder: bool = False,
 86 |                  dropout_rate_encoder: float = 0.0,
 87 |                  dropout_rate_decoder: float = 0.0,
 88 |                  variational: bool = False,
 89 |                  seed: int = 0,
 90 |                  ):
 91 |         super().__init__()
 92 | 
 93 |         recon_loss = recon_loss.lower()
 94 |         assert recon_loss in ['gauss', 'zinb', 'nb']
 95 | 
 96 |         torch.manual_seed(seed)
 97 |         np.random.seed(seed)
 98 |         settings.seed = seed
 99 | 
100 |         self.n_genes = n_genes
101 |         self.n_perts = n_perts
102 |         self.n_latent = n_latent
103 |         self.recon_loss = recon_loss
104 |         self.doser_type = doser_type
105 |         self.variational = variational
106 | 
107 |         self.covars_encoder = covars_encoder
108 | 
109 |         if variational:
110 |             self.encoder = Encoder(
111 |                 n_genes,
112 |                 n_latent,
113 |                 var_activation=nn.Softplus(),
114 |                 n_hidden=n_hidden_encoder,
115 |                 n_layers=n_layers_encoder,
116 |                 use_batch_norm=use_batch_norm_encoder,
117 |                 use_layer_norm=use_layer_norm_encoder,
118 |                 dropout_rate=dropout_rate_encoder,
119 |                 activation_fn=nn.ReLU,
120 |                 return_dist=True,
121 |             )
122 |         else:
123 |             self.encoder = VanillaEncoder(
124 |                 n_input=n_genes,
125 |                 n_output=n_latent,
126 |                 n_cat_list=[],
127 |                 n_hidden=n_hidden_encoder,
128 |                 n_layers=n_layers_encoder,
129 |                 use_batch_norm=use_batch_norm_encoder,
130 |                 use_layer_norm=use_layer_norm_encoder,
131 |                 dropout_rate=dropout_rate_encoder,
132 |                 activation_fn=nn.ReLU,
133 |                 output_activation='linear',
134 |             )
135 | 
136 |         # Decoder components
137 |         if self.recon_loss in ['zinb', 'nb']:
138 |             # setup the parameters of your generative model, as well as your inference model
139 |             self.px_r = torch.nn.Parameter(torch.randn(self.n_genes))
140 | 
141 |             # decoder goes from n_latent-dimensional space to n_input-d data
142 |             self.decoder = DecoderSCVI(
143 |                 n_input=n_latent,
144 |                 n_output=n_genes,
145 |                 n_layers=n_layers_decoder,
146 |                 n_hidden=n_hidden_decoder,
147 |                 use_batch_norm=use_batch_norm_decoder,
148 |                 use_layer_norm=use_layer_norm_decoder,
149 |             )
150 | 
151 |         elif recon_loss == "gauss":
152 |             self.decoder = Encoder(n_input=n_latent,
153 |                                    n_output=n_genes,
154 |                                    n_layers=n_layers_decoder,
155 |                                    n_hidden=n_hidden_decoder,
156 |                                    dropout_rate=dropout_rate_decoder,
157 |                                    use_batch_norm=use_batch_norm_decoder,
158 |                                    use_layer_norm=use_layer_norm_decoder,
159 |                                    var_activation=None,
160 |                                    )
161 | 
162 |         else:
163 |             raise Exception('Invalid Loss function for Autoencoder')
164 | 
165 |         # Embeddings
166 |         # 1. Drug Network
167 |         self.pert_network = PerturbationNetwork(n_perts=n_perts,
168 |                                                 n_latent=n_latent,
169 |                                                 doser_type=doser_type,
170 |                                                 n_hidden=n_hidden_doser,
171 |                                                 n_layers=n_layers_doser,
172 |                                                 drug_embeddings=drug_embeddings,
173 |                                                 )
174 | 
175 |         # 2. Covariates Embedding
176 |         self.covars_embeddings = nn.ModuleDict(
177 |             {
178 |                 key: torch.nn.Embedding(len(unique_covars), n_latent)
179 |                 for key, unique_covars in self.covars_encoder.items()
180 |             }
181 |         )
182 | 
183 |         self.metrics = {
184 |             'pearson_r': pearson_corrcoef,
185 |             'r2_score': r2_score
186 |         }
187 | 
188 |     def mixup_data(self, tensors, alpha: float = 0.0, opt=False):
189 |         """
190 |             Returns the tensors dict augmented with mixed inputs and the mixup lambda
191 |         """
192 |         alpha = max(0.0, alpha)
193 | 
194 |         if alpha == 0.0:
195 |             mixup_lambda = 1.0
196 |         else:
197 |             mixup_lambda = np.random.beta(alpha, alpha)
198 | 
199 |         x = tensors[CPA_REGISTRY_KEYS.X_KEY]
200 |         y_perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY]
201 |         perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS]
202 |         perturbations_dosages = tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES]
203 | 
204 |         batch_size = x.size()[0]
205 |         index = torch.randperm(batch_size).to(x.device)
206 | 
207 |         mixed_x = mixup_lambda * x + (1. - mixup_lambda) * x[index, :]
208 | 
209 |         tensors[CPA_REGISTRY_KEYS.X_KEY] = mixed_x
210 |         tensors[CPA_REGISTRY_KEYS.X_KEY + '_true'] = x
211 |         tensors[CPA_REGISTRY_KEYS.X_KEY + '_mixup'] = x[index]
212 |         tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY + '_mixup'] = y_perturbations[index]
213 |         tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup'] = perturbations[index]
214 |         tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'] = perturbations_dosages[index]
215 | 
216 |         for covar, encoder in self.covars_encoder.items():
217 |             tensors[covar + '_mixup'] = tensors[covar][index]
218 | 
219 |         return tensors, mixup_lambda
220 | 
221 |     def _get_inference_input(self, tensors):
222 |         x = tensors[CPA_REGISTRY_KEYS.X_KEY]  # batch_size, n_genes
223 |         perts = {
224 |             'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS],
225 |             'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS + '_mixup']
226 |         }
227 |         perts_doses = {
228 |             'true': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES],
229 |             'mixup': tensors[CPA_REGISTRY_KEYS.PERTURBATIONS_DOSAGES + '_mixup'],
230 |         }
231 | 
232 |         covars_dict = dict()
233 |         for covar, unique_covars in self.covars_encoder.items():
234 |             encoded_covars = tensors[covar].view(-1, )  # (batch_size,)
235 |             encoded_covars_mixup = tensors[covar + '_mixup'].view(-1, )  # (batch_size,)
236 |             covars_dict[covar] = encoded_covars
237 |             covars_dict[covar + '_mixup'] = encoded_covars_mixup
238 | 
239 |         return dict(
240 |             x=x,
241 |             perts=perts,
242 |             perts_doses=perts_doses,
243 |             covars_dict=covars_dict,
244 |         )
245 | 
246 |     @auto_move_data
247 |     def inference(
248 |             self,
249 |             x,
250 |             perts,
251 |             perts_doses,
252 |             covars_dict,
253 |             mixup_lambda: float = 1.0,
254 |             n_samples: int = 1,
255 |             covars_to_add: Optional[list] = None,
256 |     ):
257 |         batch_size = x.shape[0]
258 | 
259 |         if self.recon_loss in ['nb', 'zinb']:
260 |             # log the input to the variational distribution for numerical stability
261 |             x_ = torch.log(1 + x)
262 | 
263 |             library = torch.log(x.sum(1)).unsqueeze(1)
264 |         else:
265 |             x_ = x
266 |             library = None
267 | 
268 |         if self.variational:
269 |             qz, z_basal = self.encoder(x_)
270 |         else:
271 |             qz, z_basal = None, self.encoder(x_)
272 | 
273 |         if self.variational and n_samples > 1:
274 |             sampled_z = qz.sample((n_samples,))
275 |             z_basal = self.encoder.z_transformation(sampled_z)
276 |             if self.recon_loss in ['nb', 'zinb']:
277 |                 library = library.unsqueeze(0).expand(
278 |                     (n_samples, library.size(0), library.size(1))
279 |                 )
280 | 
281 |         z_pert_true = self.pert_network(perts['true'], perts_doses['true'])
282 |         if mixup_lambda < 1.0:
283 |             z_pert_mixup = self.pert_network(perts['mixup'], perts_doses['mixup'])
284 |             z_pert = mixup_lambda * z_pert_true + (1. - mixup_lambda) * z_pert_mixup
285 |         else:
286 |             z_pert = z_pert_true
287 | 
288 |         z_covs = torch.zeros_like(z_basal)  # ([n_samples,] batch_size, n_latent)
289 |         z_covs_wo_batch = torch.zeros_like(z_basal)  # ([n_samples,] batch_size, n_latent)
290 | 
291 |         batch_key = CPA_REGISTRY_KEYS.BATCH_KEY
292 |         
293 |         if covars_to_add is None:
294 |             covars_to_add = list(self.covars_encoder.keys())
295 |             
296 |         for covar, encoder in self.covars_encoder.items():
297 |             if covar in covars_to_add:
298 |                 z_cov = self.covars_embeddings[covar](covars_dict[covar].long())
299 |                 if len(encoder) > 1:
300 |                     z_cov_mixup = self.covars_embeddings[covar](covars_dict[covar + '_mixup'].long())
301 |                     z_cov = mixup_lambda * z_cov + (1. - mixup_lambda) * z_cov_mixup
302 |                 z_cov = z_cov.view(batch_size, self.n_latent)  # batch_size, n_latent
303 |                 z_covs += z_cov
304 |                 
305 |                 if covar != batch_key:
306 |                     z_covs_wo_batch += z_cov
307 | 
308 |         z = z_basal + z_pert + z_covs
309 |         z_corrected = z_basal + z_pert + z_covs_wo_batch
310 |         z_no_pert = z_basal + z_covs
311 |         z_no_pert_corrected = z_basal + z_covs_wo_batch
312 | 
313 |         return dict(
314 |             z=z,
315 |             z_corrected=z_corrected,
316 |             z_no_pert=z_no_pert,
317 |             z_no_pert_corrected=z_no_pert_corrected,
318 |             z_basal=z_basal,
319 |             z_covs=z_covs,
320 |             z_pert=z_pert.sum(dim=1),
321 |             library=library,
322 |             qz=qz,
323 |             mixup_lambda=mixup_lambda,
324 |         )
325 | 
326 |     def _get_generative_input(self, tensors, inference_outputs, **kwargs):
327 |         if 'latent' in kwargs.keys():
328 |             if kwargs['latent'] in inference_outputs.keys(): # z, z_corrected, z_no_pert, z_no_pert_corrected, z_basal
329 |                 z = inference_outputs[kwargs['latent']]
330 |             else:
331 |                 raise Exception('Invalid latent space')
332 |         else:
333 |             z = inference_outputs["z"]
334 |         library = inference_outputs['library']
335 | 
336 |         return dict(
337 |             z=z,
338 |             library=library,
339 |         )
340 | 
341 |     @auto_move_data
342 |     def generative(
343 |             self,
344 |             z,
345 |             library=None,
346 |     ):
347 |         if self.recon_loss == 'nb':
348 |             px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library)
349 |             px_r = torch.exp(self.px_r)
350 | 
351 |             px = NegativeBinomial(mu=px_rate, theta=px_r)
352 | 
353 |         elif self.recon_loss == 'zinb':
354 |             px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library)
355 |             px_r = torch.exp(self.px_r)
356 | 
357 |             px = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout)
358 | 
359 |         else:
360 |             px_mean, px_var, x_pred = self.decoder(z)
361 | 
362 |             px = Normal(loc=px_mean, scale=px_var.sqrt())
363 | 
364 |         pz = Normal(torch.zeros_like(z), torch.ones_like(z))
365 |         return dict(px=px, pz=pz)
366 | 
367 |     def loss(self, tensors, inference_outputs, generative_outputs):
368 |         """Computes the reconstruction loss (AE) or the ELBO (VAE)"""
369 |         x = tensors[CPA_REGISTRY_KEYS.X_KEY]
370 | 
371 |         px = generative_outputs['px']
372 |         recon_loss = -px.log_prob(x).sum(dim=-1).mean()
373 | 
374 |         if self.variational:
375 |             qz = inference_outputs["qz"]
376 |             pz = generative_outputs['pz']
377 | 
378 |             kl_divergence_z = kl(qz, pz).sum(dim=1)
379 |             kl_loss = kl_divergence_z.mean()
380 |         else:
381 |             # no KL term when the encoder is deterministic
382 |             kl_loss = torch.zeros_like(recon_loss)
383 | 
384 |         return recon_loss, kl_loss
385 | 
386 |     def r2_metric(self, tensors, inference_outputs, generative_outputs, mode: str = 'direct'):
387 |         mode = mode.lower()
388 |         assert mode in ['direct']
389 | 
390 |         x = tensors[CPA_REGISTRY_KEYS.X_KEY]  # batch_size, n_genes
391 |         indices = tensors[CPA_REGISTRY_KEYS.CATEGORY_KEY].view(-1,)
392 | 
393 |         unique_indices = indices.unique()
394 | 
395 |         r2_mean = 0.0
396 |         r2_var = 0.0
397 | 
398 |         px = generative_outputs['px']
399 |         for ind in unique_indices:
400 |             i_mask = indices == ind
401 | 
402 |             x_i = x[i_mask, :]
403 |             if self.recon_loss == 'gauss':
404 |                 x_pred_mean = px.loc[i_mask, :]
405 |                 x_pred_var = px.scale[i_mask, :] ** 2
406 | 
407 |                 if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys():
408 |                     deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :]
409 | 
410 |                     x_i *= deg_mask
411 |                     x_pred_mean *= deg_mask
412 |                     x_pred_var *= deg_mask
413 | 
414 |                 x_pred_mean = torch.nan_to_num(x_pred_mean, nan=0, posinf=1e3, neginf=-1e3)
415 |                 x_pred_var = torch.nan_to_num(x_pred_var, nan=0, posinf=1e3, neginf=-1e3)
416 | 
417 |                 r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred_mean.mean(0), x_i.mean(0)),
418 |                                         nan=0.0).item()
419 |                 r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred_var.mean(0), x_i.var(0)),
420 |                                         nan=0.0).item()
421 | 
422 |             elif self.recon_loss in ['nb', 'zinb']:
423 |                 x_i = torch.log(1 + x_i)
424 |                 x_pred = px.mu[i_mask, :]
425 |                 x_pred = torch.log(1 + x_pred)
426 | 
427 |                 x_pred = torch.nan_to_num(x_pred, nan=0, posinf=1e3, neginf=-1e3)
428 | 
429 |                 if CPA_REGISTRY_KEYS.DEG_MASK_R2 in tensors.keys():
430 |                     deg_mask = tensors[f'{CPA_REGISTRY_KEYS.DEG_MASK_R2}'][i_mask, :]
431 | 
432 |                     x_i *= deg_mask
433 |                     x_pred *= deg_mask
434 | 
435 |                 r2_mean += torch.nan_to_num(self.metrics['r2_score'](x_pred.mean(0), x_i.mean(0)),
436 |                                         nan=0.0).item()
437 |                 r2_var += torch.nan_to_num(self.metrics['r2_score'](x_pred.var(0), x_i.var(0)),
438 |                                         nan=0.0).item()
439 | 
440 |         n_unique_indices = len(unique_indices)
441 |         return r2_mean / n_unique_indices, r2_var / n_unique_indices
442 | 
443 |     def disentanglement(self, tensors, inference_outputs, generative_outputs, linear=True):
444 |         z_basal = inference_outputs['z_basal'].detach().cpu().numpy()
445 |         z = inference_outputs['z'].detach().cpu().numpy()
446 | 
447 |         perturbations = tensors[CPA_REGISTRY_KEYS.PERTURBATION_KEY].view(-1, )
448 |         perturbations_names = perturbations.detach().cpu().numpy()
449 | 
450 |         knn_basal = knn_purity(z_basal, perturbations_names.ravel(),
451 |                                n_neighbors=min(perturbations_names.shape[0] - 1, 30))
452 |         knn_after = knn_purity(z, perturbations_names.ravel(),
453 |                                n_neighbors=min(perturbations_names.shape[0] - 1, 30))
454 | 
455 |         for covar, unique_covars in self.covars_encoder.items():
456 |             if len(unique_covars) > 1:
457 |                 target_covars = tensors[f'{covar}'].detach().cpu().numpy()
458 | 
459 |                 knn_basal += knn_purity(z_basal, target_covars.ravel(),
460 |                                         n_neighbors=min(target_covars.shape[0] - 1, 30))
461 | 
462 |                 knn_after += knn_purity(z, target_covars.ravel(),
463 |                                         n_neighbors=min(target_covars.shape[0] - 1, 30))
464 | 
465 |         return knn_basal, knn_after
466 | 
467 |     def get_expression(self, tensors, n_samples=1, covars_to_add=None, latent='z'):
468 |         """Computes reconstructed gene expression and the latent representations.
469 | 
470 |         Works with both the Gaussian and the (zero-inflated) negative binomial likelihoods.
471 | 
472 |         Parameters
473 |         ----------
474 |         tensors : dict
475 |             Considered inputs
476 | 
477 |         """
478 |         tensors, _ = self.mixup_data(tensors, alpha=0.0)
479 | 
480 |         inference_outputs, generative_outputs = self.forward(
481 |             tensors,
482 |             inference_kwargs={'n_samples': n_samples, 'covars_to_add': covars_to_add},
483 |             get_generative_input_kwargs={'latent': latent},
484 |             compute_loss=False,
485 |         )
486 | 
487 |         z = inference_outputs['z']
488 |         z_corrected = inference_outputs['z_corrected']
489 |         z_no_pert = inference_outputs['z_no_pert']
490 |         z_no_pert_corrected = inference_outputs['z_no_pert_corrected']
491 |         z_basal = inference_outputs['z_basal']
492 | 
493 |         px = generative_outputs['px']
494 | 
495 |         if self.recon_loss == 'gauss':
496 |             output_key = 'loc'
497 |         else:
498 |             output_key = 'mu'
499 | 
500 |         reconstruction = getattr(px, output_key)
501 | 
502 |         return dict(
503 |             px=reconstruction,
504 |             z=z,
505 |             z_corrected=z_corrected,
506 |             z_no_pert=z_no_pert,
507 |             z_no_pert_corrected=z_no_pert_corrected,
508 |             z_basal=z_basal,
509 |         )
510 | 
511 |     def get_pert_embeddings(self, tensors, **inference_kwargs):
512 |         inputs = self._get_inference_input(tensors)
513 |         drugs = inputs['perts']
514 |         doses = inputs['perts_doses']
515 | 
516 |         return self.pert_network(drugs, doses)
517 | 
--------------------------------------------------------------------------------
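For orientation, a hedged sketch of constructing CPAModule on its own (toy sizes; the covariate vocabulary and dimensions are assumptions, and in the actual pipeline the module is created and trained through the accompanying _model.py/_task.py wrappers):

# illustrative configuration only, not taken from the repository
covars_encoder = {"cell_type": ["K562", "A549"]}   # covariate name -> unique values

module = CPAModule(
    n_genes=2000,
    n_perts=50,
    covars_encoder=covars_encoder,
    n_latent=64,
    recon_loss="gauss",    # or "nb" / "zinb"
    doser_type="linear",
    variational=False,
)
print(module)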
/Supp_code/Modified_CPA/_utils.py:
--------------------------------------------------------------------------------
  1 | from typing import List, Dict
  2 | 
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | from scvi.distributions import NegativeBinomial
  7 | 
  8 | from scvi.nn import FCLayers
  9 | from torch.distributions import Normal
 10 | from typing import Optional
 11 | 
 12 | 
 13 | class _REGISTRY_KEYS:
 14 |     X_KEY: str = "X"
 15 |     X_CTRL_KEY: str = None
 16 |     BATCH_KEY: str = None
 17 |     CATEGORY_KEY: str = "cpa_category"
 18 |     PERTURBATION_KEY: str = None
 19 |     PERTURBATION_DOSAGE_KEY: str = None
 20 |     PERTURBATIONS: str = "perts"
 21 |     PERTURBATIONS_DOSAGES: str = "perts_doses"
 22 |     SIZE_FACTOR_KEY: str = "size_factor"
 23 |     CAT_COV_KEYS: List[str] = []
 24 |     MAX_COMB_LENGTH: int = 2
 25 |     CONTROL_KEY: str = None
 26 |     DEG_MASK: str = None
 27 |     DEG_MASK_R2: str = None
 28 |     PADDING_IDX: int = 0
 29 | 
 30 | 
 31 | CPA_REGISTRY_KEYS = _REGISTRY_KEYS()
 32 | 
 33 | 
 34 | class VanillaEncoder(nn.Module):
 35 |     def __init__(
 36 |             self,
 37 |             n_input,
 38 |             n_output,
 39 |             n_hidden,
 40 |             n_layers,
 41 |             n_cat_list,
 42 |             use_layer_norm=True,
 43 |             use_batch_norm=False,
 44 |             output_activation: str = 'linear',
 45 |             dropout_rate: float = 0.1,
 46 |             activation_fn=nn.ReLU,
 47 |     ):
 48 |         super().__init__()
 49 |         self.n_output = n_output
 50 |         self.output_activation = output_activation
 51 | 
 52 |         self.network = FCLayers(
 53 |             n_in=n_input,
 54 |             n_out=n_hidden,
 55 |             n_cat_list=n_cat_list,
 56 |             n_layers=n_layers,
 57 |             n_hidden=n_hidden,
 58 |             use_layer_norm=use_layer_norm,
 59 |             use_batch_norm=use_batch_norm,
 60 |             dropout_rate=dropout_rate,
 61 |             activation_fn=activation_fn,
 62 |         )
 63 |         self.z = nn.Linear(n_hidden, n_output)
 64 | 
 65 |     def forward(self, inputs, *cat_list):
 66 |         if self.output_activation == 'linear':
 67 |             z = self.z(self.network(inputs, *cat_list))
 68 |         elif self.output_activation == 'relu':
 69 |             z = F.relu(self.z(self.network(inputs, *cat_list)))
 70 |         else:
 71 |             raise ValueError(f'Unknown output activation: {self.output_activation}')
 72 |         return z
 73 | 
 74 | 
 75 | class GeneralizedSigmoid(nn.Module):
 76 |     """
 77 |     Sigmoid, log-sigmoid or linear functions for encoding dose-response for
 78 |     drug perturbations.
 79 |     """
 80 | 
 81 |     def __init__(self, n_drugs, non_linearity='sigmoid'):
 82 |         """Sigmoid modeling of continuous variable.
 83 |         Params
 84 |         ------
 85 |         non_linearity : str (default: sigmoid)
 86 |             One of logsigm, sigm; any other value falls back to a linear response.
 87 |         """
 88 |         super(GeneralizedSigmoid, self).__init__()
 89 |         self.non_linearity = non_linearity
 90 |         self.n_drugs = n_drugs
 91 | 
 92 |         self.beta = torch.nn.Parameter(
 93 |             torch.ones(1, n_drugs),
 94 |             requires_grad=True
 95 |         )
 96 |         self.bias = torch.nn.Parameter(
 97 |             torch.zeros(1, n_drugs),
 98 |             requires_grad=True
 99 |         )
100 | 
101 |         self.vmap = None
102 | 
103 |     def forward(self, x, y):
104 |         """
105 |             Parameters
106 |             ----------
107 |             x: (batch_size, max_comb_len)
108 |             y: (batch_size, max_comb_len)
109 |         """
110 |         y = y.long()
111 |         if self.non_linearity == 'logsigm':
112 |             bias = self.bias[0][y]
113 |             beta = self.beta[0][y]
114 |             c0 = bias.sigmoid()
115 |             return (torch.log1p(x) * beta + bias).sigmoid() - c0
116 |         elif self.non_linearity == 'sigm':
117 |             bias = self.bias[0][y]
118 |             beta = self.beta[0][y]
119 |             c0 = bias.sigmoid()
120 |             return (x * beta + bias).sigmoid() - c0
121 |         else:
122 |             return x
123 | 
124 |     def one_drug(self, x, i):
125 |         if self.non_linearity == 'logsigm':
126 |             c0 = self.bias[0][i].sigmoid()
127 |             return (torch.log1p(x) * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
128 |         elif self.non_linearity == 'sigm':
129 |             c0 = self.bias[0][i].sigmoid()
130 |             return (x * self.beta[0][i] + self.bias[0][i]).sigmoid() - c0
131 |         else:
132 |             return x
133 | 
134 | 
135 | class PerturbationNetwork(nn.Module):
136 |     def __init__(self,
137 |                  n_perts,
138 |                  n_latent,
139 |                  doser_type='logsigm',
140 |                  n_hidden=None,
141 |                  n_layers=None,
142 |                  dropout_rate: float = 0.0,
143 |                  drug_embeddings=None,):
144 |         super().__init__()
145 |         self.n_latent = n_latent
146 |         
147 |         if drug_embeddings is not None:
148 |             self.pert_embedding = drug_embeddings
149 |             self.pert_transformation = nn.Linear(drug_embeddings.embedding_dim, n_latent)
150 |             self.use_rdkit = True
151 |         else:
152 |             self.use_rdkit = False
153 |             self.pert_embedding = nn.Embedding(n_perts, n_latent, padding_idx=CPA_REGISTRY_KEYS.PADDING_IDX)
154 |             
155 |         self.doser_type = doser_type
156 |         if self.doser_type == 'mlp':
157 |             self.dosers = nn.ModuleList()
158 |             for _ in range(n_perts):
159 |                 self.dosers.append(
160 |                     FCLayers(
161 |                         n_in=1,
162 |                         n_out=1,
163 |                         n_hidden=n_hidden,
164 |                         n_layers=n_layers,
165 |                         use_batch_norm=False,
166 |                         use_layer_norm=True,
167 |                         dropout_rate=dropout_rate
168 |                     )
169 |                 )
170 |         else:
171 |             self.dosers = GeneralizedSigmoid(n_perts, non_linearity=self.doser_type)
172 | 
173 |     def forward(self, perts, dosages):
174 |         """
175 |             perts: (batch_size, max_comb_len)
176 |             dosages: (batch_size, max_comb_len)
177 |         """
178 |         bs, max_comb_len = perts.shape
179 |         perts = perts.long()
180 |         scaled_dosages = self.dosers(dosages, perts)  # (batch_size, max_comb_len)
181 | 
182 |         drug_embeddings = self.pert_embedding(perts)  # (batch_size, max_comb_len, n_drug_emb_dim)
183 | 
184 |         if self.use_rdkit:
185 |             drug_embeddings = self.pert_transformation(drug_embeddings.view(bs * max_comb_len, -1)).view(bs, max_comb_len, -1)
186 | 
187 |         z_drugs = torch.einsum('bm,bme->bme', [scaled_dosages, drug_embeddings])  # (batch_size, max_comb_len, n_latent)
188 | 
189 |         z_drugs = torch.einsum('bmn,bm->bmn', z_drugs, (perts != CPA_REGISTRY_KEYS.PADDING_IDX).int()).sum(dim=1)  # zero out padded slots, then sum over the combination
190 | 
191 |         return z_drugs # (batch_size, n_latent)
192 | 
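# Shape walk-through for the forward pass above (sizes are illustrative):
#   perts, dosages            : (bs=4, max_comb_len=2)
#   scaled_dosages            : (4, 2)
#   pert_embedding(perts)     : (4, 2, n_latent)
#   'bm,bme->bme' einsum      : (4, 2, n_latent)   each embedding scaled by its dose
#   padding mask + sum(dim=1) : (4, n_latent)      one latent vector per cell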
193 | class FocalLoss(nn.Module):
194 |     """ Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
195 | 
196 |     Focal Loss, as described in https://arxiv.org/abs/1708.02002.
197 |     It is essentially an enhancement to cross entropy loss and is
198 |     useful for classification tasks when there is a large class imbalance.
199 |     x is expected to contain raw, unnormalized scores for each class.
200 |     y is expected to contain class labels.
201 |     Shape:
202 |         - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
203 |         - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
204 |     """
205 | 
206 |     def __init__(self,
207 |                  alpha: Optional[torch.Tensor] = None,
208 |                  gamma: float = 2.,
209 |                  reduction: str = 'mean',
210 |                  ):
211 |         """
212 |         Args:
213 |             alpha (Tensor, optional): Weights for each class. Defaults to None.
214 |             gamma (float, optional): A constant, as described in the paper.
215 |                 Defaults to 2.
216 |             reduction (str, optional): 'mean', 'sum' or 'none'.
217 |                 Defaults to 'mean'.
218 |         """
219 |         if reduction not in ('mean', 'sum', 'none'):
220 |             raise ValueError(
221 |                 'Reduction must be one of: "mean", "sum", "none".')
222 | 
223 |         super().__init__()
224 |         self.alpha = alpha
225 |         self.gamma = gamma
226 |         self.reduction = reduction
227 | 
228 |         self.nll_loss = nn.NLLLoss(
229 |             weight=alpha, reduction='none')
230 | 
231 |     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
232 |         if len(y_true) == 0:
233 |             return torch.tensor(0.)
234 | 
235 |         # compute weighted cross entropy term: -alpha * log(pt)
236 |         # (alpha is already part of self.nll_loss)
237 |         log_p = F.log_softmax(y_pred, dim=-1)
238 |         ce = self.nll_loss(log_p, y_true)
239 | 
240 |         # get true class column from each row
241 |         all_rows = torch.arange(len(y_pred))
242 |         log_pt = log_p[all_rows, y_true]
243 | 
244 |         # compute focal term: (1 - pt)^gamma
245 |         pt = log_pt.exp()
246 |         focal_term = (1 - pt) ** self.gamma
247 | 
248 |         # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
249 |         loss = focal_term * ce
250 | 
251 |         if self.reduction == 'mean':
252 |             loss = loss.mean()
253 |         elif self.reduction == 'sum':
254 |             loss = loss.sum()
255 | 
256 |         return loss
--------------------------------------------------------------------------------
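A minimal usage sketch for the FocalLoss module above (the class count, weights, and tensors below are illustrative and assume FocalLoss is importable from this file):

import torch

alpha = torch.tensor([0.2, 0.3, 0.5])                 # per-class weights for a hypothetical 3-class task
criterion = FocalLoss(alpha=alpha, gamma=2., reduction='mean')
logits = torch.randn(8, 3)                            # raw, unnormalized scores
labels = torch.randint(0, 3, (8,))                    # integer class labels
loss = criterion(logits, labels)                      # scalar focal loss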
/Supp_code/Modified_GEARS/data_utils.py:
--------------------------------------------------------------------------------
  1 | import pandas as pd
  2 | import numpy as np
  3 | import scanpy as sc
  4 | from random import shuffle
  5 | sc.settings.verbosity = 0
  6 | from tqdm import tqdm
  7 | import requests
  8 | import os, sys
  9 | 
 10 | import warnings
 11 | warnings.filterwarnings("ignore")
 12 | 
 13 | from .utils import parse_single_pert, parse_combo_pert, parse_any_pert, print_sys
 14 | 
 15 | def rank_genes_groups_by_cov(
 16 |     adata,
 17 |     groupby,
 18 |     control_group,
 19 |     covariate,
 20 |     pool_doses=False,
 21 |     n_genes=50,
 22 |     rankby_abs=True,
 23 |     key_added='rank_genes_groups_cov',
 24 |     return_dict=False,
 25 | ):
 26 | 
 27 |     gene_dict = {}
 28 |     cov_categories = adata.obs[covariate].unique()
 29 |     for cov_cat in cov_categories:
 30 |         #name of the control group in the groupby obs column
 31 |         control_group_cov = '_'.join([cov_cat, control_group])
 32 | 
 33 |         #subset adata to cells belonging to a covariate category
 34 |         adata_cov = adata[adata.obs[covariate]==cov_cat]
 35 | 
 36 |         #compute DEGs
 37 |         sc.tl.rank_genes_groups(
 38 |             adata_cov,
 39 |             groupby=groupby,
 40 |             reference=control_group_cov,
 41 |             rankby_abs=rankby_abs,
 42 |             n_genes=n_genes,
 43 |             use_raw=False
 44 |         )
 45 | 
 46 |         #add entries to dictionary of gene sets
 47 |         de_genes = pd.DataFrame(adata_cov.uns['rank_genes_groups']['names'])
 48 |         for group in de_genes:
 49 |             gene_dict[group] = de_genes[group].tolist()
 50 | 
 51 |     adata.uns[key_added] = gene_dict
 52 | 
 53 |     if return_dict:
 54 |         return gene_dict
 55 | 
 56 |     
 57 | def get_DE_genes(adata, skip_calc_de):
 58 |     adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1')
 59 |     adata.obs.loc[:, 'control'] = adata.obs.condition.apply(lambda x: 0 if len(x.split('+')) == 2 else 1)
 60 |     adata.obs.loc[:, 'condition_name'] =  adata.obs.apply(lambda x: '_'.join([x.cell_type, x.condition, x.dose_val]), axis = 1) 
 61 |     
 62 |     adata.obs = adata.obs.astype('category')
 63 |     if not skip_calc_de:
 64 |         rank_genes_groups_by_cov(adata, 
 65 |                          groupby='condition_name', 
 66 |                          covariate='cell_type', 
 67 |                          control_group='ctrl_1', 
 68 |                          n_genes=len(adata.var),
 69 |                          key_added = 'rank_genes_groups_cov_all')
 70 |     return adata
 71 | 
 72 | def get_dropout_non_zero_genes(adata):
 73 |     
 74 |     # calculate mean expression for each condition
 75 |     unique_conditions = adata.obs.condition.unique()
 76 |     conditions2index = {}
 77 |     for i in unique_conditions:
 78 |         conditions2index[i] = np.where(adata.obs.condition == i)[0]
 79 | 
 80 |     condition2mean_expression = {}
 81 |     for i, j in conditions2index.items():
 82 |         condition2mean_expression[i] = np.mean(adata.X[j], axis = 0)
 83 |     pert_list = np.array(list(condition2mean_expression.keys()))
 84 |     mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1])
 85 |     ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]]
 86 |     
 87 |     ## in silico modeling and upperbounding
 88 |     pert2pert_full_id = dict(adata.obs[['condition', 'condition_name']].values)
 89 |     pert_full_id2pert = dict(adata.obs[['condition_name', 'condition']].values)
 90 | 
 91 |     gene_id2idx = dict(zip(adata.var.index.values, range(len(adata.var))))
 92 |     gene_idx2id = dict(zip(range(len(adata.var)), adata.var.index.values))
 93 | 
 94 |     non_zeros_gene_idx = {}
 95 |     top_non_dropout_de_20 = {}
 96 |     top_non_zero_de_20 = {}
 97 |     non_dropout_gene_idx = {}
 98 | 
 99 |     for pert in adata.uns['rank_genes_groups_cov_all'].keys():
100 |         p = pert_full_id2pert[pert]
101 |         X = np.mean(adata[adata.obs.condition == p].X, axis = 0)
102 | 
103 |         non_zero = np.where(np.array(X)[0] != 0)[0]
104 |         zero = np.where(np.array(X)[0] == 0)[0]
105 |         true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0])
106 |         non_dropouts = np.concatenate((non_zero, true_zeros))
107 | 
108 |         top = adata.uns['rank_genes_groups_cov_all'][pert]
109 |         gene_idx_top = [gene_id2idx[i] for i in top]
110 | 
111 |         non_dropout_20 = [i for i in gene_idx_top if i in non_dropouts][:20]
112 |         non_dropout_20_gene_id = [gene_idx2id[i] for i in non_dropout_20]
113 | 
114 |         non_zero_20 = [i for i in gene_idx_top if i in non_zero][:20]
115 |         non_zero_20_gene_id = [gene_idx2id[i] for i in non_zero_20]
116 | 
117 |         non_zeros_gene_idx[pert] = np.sort(non_zero)
118 |         non_dropout_gene_idx[pert] = np.sort(non_dropouts)
119 |         top_non_dropout_de_20[pert] = np.array(non_dropout_20_gene_id)
120 |         top_non_zero_de_20[pert] = np.array(non_zero_20_gene_id)
121 |         
122 |     non_zero = np.where(np.array(X)[0] != 0)[0]
123 |     zero = np.where(np.array(X)[0] == 0)[0]
124 |     true_zeros = np.intersect1d(zero, np.where(np.array(ctrl)[0] == 0)[0])
125 |     non_dropouts = np.concatenate((non_zero, true_zeros))
126 |     
127 |     adata.uns['top_non_dropout_de_20'] = top_non_dropout_de_20
128 |     adata.uns['non_dropout_gene_idx'] = non_dropout_gene_idx
129 |     adata.uns['non_zeros_gene_idx'] = non_zeros_gene_idx
130 |     adata.uns['top_non_zero_de_20'] = top_non_zero_de_20
131 |     
132 |     return adata
133 | 
134 | 
135 | class DataSplitter():
136 |     """
137 |     Class for handling data splitting. It generates new data splits and
138 |     records the assignment as a new column in the AnnData obs dataframe.
139 |     """
140 |     def __init__(self, adata, split_type='single', seen=0):
141 |         self.adata = adata
142 |         self.split_type = split_type
143 |         self.seen = seen
144 | 
145 |     def split_data(self, test_size=0.1, test_pert_genes=None,
146 |                    test_perts=None, split_name='split', seed=None, val_size = 0.1,
147 |                    train_gene_set_size = 0.75, combo_seen2_train_frac = 0.75, only_test_set_perts = False):
148 |         """
149 |         Split the dataset and add the split assignment as a column to adata.obs
150 |         Note: split categories are train, val, test
151 |         """
152 |         np.random.seed(seed=seed)
153 |         unique_perts = [p for p in self.adata.obs['condition'].unique() if
154 |                         p != 'ctrl']
155 |         
156 |         if self.split_type == 'simulation':
157 |             train, test, test_subgroup = self.get_simulation_split(unique_perts,
158 |                                                                   train_gene_set_size,
159 |                                                                   combo_seen2_train_frac, 
160 |                                                                   seed, test_perts, only_test_set_perts)
161 |             train, val, val_subgroup = self.get_simulation_split(train,
162 |                                                                   0.9,
163 |                                                                   0.9,
164 |                                                                   seed)
165 |             ## adding back ctrl to train...
166 |             train.append('ctrl')
167 |         elif self.split_type == 'simulation_single':
168 |             train, test, test_subgroup = self.get_simulation_split_single(unique_perts,
169 |                                                                   train_gene_set_size,
170 |                                                                   seed, test_perts, only_test_set_perts)
171 |             train, val, val_subgroup = self.get_simulation_split_single(train,
172 |                                                                   0.9,
173 |                                                                   seed)
174 |         elif self.split_type == 'no_test':
175 |             print('test_pert_genes',str(test_pert_genes))
176 |             print('test_perts',str(test_perts))
177 |             
178 |             train, val = self.get_split_list(unique_perts,
179 |                                           test_pert_genes=test_pert_genes,
180 |                                           test_perts=test_perts,
181 |                                           test_size=test_size)      
182 |         else:
183 |             train, test = self.get_split_list(unique_perts,
184 |                                           test_pert_genes=test_pert_genes,
185 |                                           test_perts=test_perts,
186 |                                           test_size=test_size)
187 |             
188 |             train, val = self.get_split_list(train, test_size=val_size)
189 | 
190 |         map_dict = {x: 'train' for x in train}
191 |         map_dict.update({x: 'val' for x in val})
192 |         if self.split_type != 'no_test':
193 |             map_dict.update({x: 'test' for x in test})
194 |         map_dict.update({'ctrl': 'train'})
195 | 
196 |         self.adata.obs[split_name] = self.adata.obs['condition'].map(map_dict)
197 | 
198 |         if self.split_type == 'simulation':
199 |             return self.adata, {'test_subgroup': test_subgroup, 
200 |                                 'val_subgroup': val_subgroup
201 |                                }
202 |         else:
203 |             return self.adata
204 |     
205 |     def get_simulation_split_single(self, pert_list, train_gene_set_size = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False):
206 |         unique_pert_genes = self.get_genes_from_perts(pert_list)
207 |         
208 |         pert_train = []
209 |         pert_test = []
210 |         np.random.seed(seed=seed)
211 |         
212 |         if only_test_set_perts and (test_set_perts is not None):
213 |             ood_genes = np.array(test_set_perts)
214 |             train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes)
215 |         else:
216 |             ## a pre-specified list of genes
217 |             train_gene_candidates = np.random.choice(unique_pert_genes,
218 |                                                     int(len(unique_pert_genes) * train_gene_set_size), replace = False)
219 | 
220 |             if test_set_perts is not None:
221 |                 num_overlap = len(np.intersect1d(train_gene_candidates, test_set_perts))
222 |                 train_gene_candidates = train_gene_candidates[~np.isin(train_gene_candidates, test_set_perts)]
223 |                 ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts))
224 |                 train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False)
225 |                 train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition))
226 |                 
227 |             ## ood genes
228 |             ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)  
229 |         
230 |         pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single')
231 |         unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single')
232 |         assert len(unseen_single) + len(pert_single_train) == len(pert_list)
233 |         
234 |         return pert_single_train, unseen_single, {'unseen_single': unseen_single}
235 |     
236 |     def get_simulation_split(self, pert_list, train_gene_set_size = 0.85, combo_seen2_train_frac = 0.85, seed = 1, test_set_perts = None, only_test_set_perts = False):
237 |         
238 |         unique_pert_genes = self.get_genes_from_perts(pert_list)
239 |         
240 |         pert_train = []
241 |         pert_test = []
242 |         np.random.seed(seed=seed)
243 |         
244 |         if only_test_set_perts and (test_set_perts is not None):
245 |             ood_genes = np.array(test_set_perts)
246 |             train_gene_candidates = np.setdiff1d(unique_pert_genes, ood_genes)
247 |         else:
248 |             ## a pre-specified list of genes
249 |             train_gene_candidates = np.random.choice(unique_pert_genes,
250 |                                                     int(len(unique_pert_genes) * train_gene_set_size), replace = False)
251 | 
252 |             if test_set_perts is not None:
253 |                 num_overlap = len(np.intersect1d(train_gene_candidates, test_set_perts))
254 |                 train_gene_candidates = train_gene_candidates[~np.isin(train_gene_candidates, test_set_perts)]
255 |                 ood_genes_exclude_test_set = np.setdiff1d(unique_pert_genes, np.union1d(train_gene_candidates, test_set_perts))
256 |                 train_set_addition = np.random.choice(ood_genes_exclude_test_set, num_overlap, replace = False)
257 |                 train_gene_candidates = np.concatenate((train_gene_candidates, train_set_addition))
258 |                 
259 |             ## ood genes
260 |             ood_genes = np.setdiff1d(unique_pert_genes, train_gene_candidates)                
261 |         
262 |         pert_single_train = self.get_perts_from_genes(train_gene_candidates, pert_list,'single')
263 |         pert_combo = self.get_perts_from_genes(train_gene_candidates, pert_list,'combo')
264 |         pert_train.extend(pert_single_train)
265 |         
266 |         ## the combo set with one of them in OOD
267 |         combo_seen1 = [x for x in pert_combo if len([t for t in x.split('+') if
268 |                                      t in train_gene_candidates]) == 1]
269 |         pert_test.extend(combo_seen1)
270 |         
271 |         pert_combo = np.setdiff1d(pert_combo, combo_seen1)
272 |         ## randomly sample the combo seen 2 as a test set, the rest in training set
273 |         np.random.seed(seed=seed)
274 |         pert_combo_train = np.random.choice(pert_combo, int(len(pert_combo) * combo_seen2_train_frac), replace = False)
275 |        
276 |         combo_seen2 = np.setdiff1d(pert_combo, pert_combo_train).tolist()
277 |         pert_test.extend(combo_seen2)
278 |         pert_train.extend(pert_combo_train)
279 |         
280 |         ## unseen single
281 |         unseen_single = self.get_perts_from_genes(ood_genes, pert_list, 'single')
282 |         combo_ood = self.get_perts_from_genes(ood_genes, pert_list, 'combo')
283 |         pert_test.extend(unseen_single)
284 |         
285 |         ## keep only the seen-0 combos here, since seen-1 is handled above
286 |         combo_seen0 = [x for x in combo_ood if len([t for t in x.split('+') if
287 |                                      t in train_gene_candidates]) == 0]
288 |         pert_test.extend(combo_seen0)
289 |         assert len(combo_seen1) + len(combo_seen0) + len(unseen_single) + len(pert_train) + len(combo_seen2) == len(pert_list)
290 | 
291 |         return pert_train, pert_test, {'combo_seen0': combo_seen0,
292 |                                        'combo_seen1': combo_seen1,
293 |                                        'combo_seen2': combo_seen2,
294 |                                        'unseen_single': unseen_single}
295 |         
296 |     def get_split_list(self, pert_list, test_size=0.1,
297 |                        test_pert_genes=None, test_perts=None,
298 |                        hold_outs=True):
299 |         """
300 |         Splits a given perturbation list into train and test with no shared
301 |         perturbations
302 |         """
303 | 
304 |         single_perts = [p for p in pert_list if 'ctrl' in p and p != 'ctrl']
305 |         combo_perts = [p for p in pert_list if 'ctrl' not in p]
306 |         unique_pert_genes = self.get_genes_from_perts(pert_list)
307 |         hold_out = []
308 | 
309 |         if test_pert_genes is None:
310 |             test_pert_genes = np.random.choice(unique_pert_genes,
311 |                                         int(len(single_perts) * test_size))
312 | 
313 |         # Only single unseen genes (in test set)
314 |         # Train contains both single and combos
315 |         if self.split_type == 'single' or self.split_type == 'single_only':
316 |             test_perts = self.get_perts_from_genes(test_pert_genes, pert_list,
317 |                                                    'single')
318 |             if self.split_type == 'single_only':
319 |                 # Discard all combos
320 |                 hold_out = combo_perts
321 |             else:
322 |                 # Discard only those combos which contain test genes
323 |                 hold_out = self.get_perts_from_genes(test_pert_genes, pert_list,
324 |                                                      'combo')
325 | 
326 |         elif self.split_type == 'combo':
327 |             if self.seen == 0:
328 |                 # NOTE: This can reduce the dataset size!
329 |                 # To prevent this, set 'hold_outs' to False; this will cause
330 |                 # the test set to have some perturbations with 1 gene seen
331 |                 single_perts = self.get_perts_from_genes(test_pert_genes,
332 |                                                          pert_list, 'single')
333 |                 combo_perts = self.get_perts_from_genes(test_pert_genes,
334 |                                                         pert_list, 'combo')
335 | 
336 |                 if hold_outs:
337 |                     # Hold out combos that contain any seen (training) gene
338 |                     hold_out = [t for t in combo_perts if
339 |                                 len([t for t in t.split('+') if
340 |                                      t not in test_pert_genes]) > 0]
341 |                 combo_perts = [c for c in combo_perts if c not in hold_out]
342 |                 test_perts = single_perts + combo_perts
343 | 
344 |             elif self.seen == 1:
345 |                 # NOTE: This can reduce the dataset size!
346 |                 # To prevent this, set 'hold_outs' to False; this will cause
347 |                 # the test set to have some perturbations with 2 genes seen
348 |                 single_perts = self.get_perts_from_genes(test_pert_genes,
349 |                                                          pert_list, 'single')
350 |                 combo_perts = self.get_perts_from_genes(test_pert_genes,
351 |                                                         pert_list, 'combo')
352 | 
353 |                 if hold_outs:
354 |                     # This just checks that none of the combos have 2 seen genes
355 |                     hold_out = [t for t in combo_perts if
356 |                                 len([t for t in t.split('+') if
357 |                                      t not in test_pert_genes]) > 1]
358 |                 combo_perts = [c for c in combo_perts if c not in hold_out]
359 |                 test_perts = single_perts + combo_perts
360 | 
361 |             elif self.seen == 2:
362 |                 if test_perts is None:
363 |                     test_perts = np.random.choice(combo_perts,
364 |                                      int(len(combo_perts) * test_size))       
365 |                 else:
366 |                     test_perts = np.array(test_perts)
367 |         else:
368 |             if test_perts is None:
369 |                 test_perts = np.random.choice(combo_perts,
370 |                                     int(len(combo_perts) * test_size))
371 |         
372 |         train_perts = [p for p in pert_list if (p not in test_perts)
373 |                                         and (p not in hold_out)]
374 |         return train_perts, test_perts
375 | 
376 |     def get_perts_from_genes(self, genes, pert_list, type_='both'):
377 |         """
378 |         Returns all single/combo/both perturbations that include a gene from the given list
379 |         """
380 | 
381 |         single_perts = [p for p in pert_list if ('ctrl' in p) and (p != 'ctrl')]
382 |         combo_perts = [p for p in pert_list if 'ctrl' not in p]
383 |         
384 |         perts = []
385 |         
386 |         if type_ == 'single':
387 |             pert_candidate_list = single_perts
388 |         elif type_ == 'combo':
389 |             pert_candidate_list = combo_perts
390 |         elif type_ == 'both':
391 |             pert_candidate_list = pert_list
392 |             
393 |         for p in pert_candidate_list:
394 |             for g in genes:
395 |                 if g in parse_any_pert(p):
396 |                     perts.append(p)
397 |                     break
398 |         return perts
399 | 
400 |     def get_genes_from_perts(self, perts):
401 |         """
402 |         Returns list of genes involved in a given perturbation list
403 |         """
404 | 
405 |         if type(perts) is str:
406 |             perts = [perts]
407 |         gene_list = [p.split('+') for p in np.unique(perts)]
408 |         gene_list = [item for sublist in gene_list for item in sublist]
409 |         gene_list = [g for g in gene_list if g != 'ctrl']
410 |         return np.unique(gene_list)
--------------------------------------------------------------------------------
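A minimal sketch of how the DataSplitter above is intended to be used on an AnnData object whose obs has a 'condition' column (the adata variable is assumed to exist; parameter values are illustrative):

splitter = DataSplitter(adata, split_type='simulation')
adata, subgroups = splitter.split_data(train_gene_set_size=0.75,
                                       combo_seen2_train_frac=0.75,
                                       seed=1)
# adata.obs['split'] now holds 'train' / 'val' / 'test', and
# subgroups['test_subgroup'] lists the combo_seen0/1/2 and unseen_single perturbations.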
/Supp_code/Modified_GEARS/model.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | from torch.nn import Sequential, Linear, ReLU
  5 | 
  6 | from torch_geometric.nn import SGConv
  7 | 
  8 | class MLP(torch.nn.Module):
  9 | 
 10 |     def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
 11 |         """
 12 |         Multi-layer perceptron
 13 |         :param sizes: list of sizes of the layers
 14 |         :param batch_norm: whether to use batch normalization
 15 |         :param last_layer_act: activation function of the last layer
 16 | 
 17 |         """
 18 |         super(MLP, self).__init__()
 19 |         layers = []
 20 |         for s in range(len(sizes) - 1):
 21 |             layers = layers + [
 22 |                 torch.nn.Linear(sizes[s], sizes[s + 1]),
 23 |                 torch.nn.BatchNorm1d(sizes[s + 1])
 24 |                 if batch_norm and s < len(sizes) - 1 else None,
 25 |                 torch.nn.ReLU()
 26 |             ]
 27 | 
 28 |         layers = [l for l in layers if l is not None][:-1]
 29 |         self.activation = last_layer_act
 30 |         self.network = torch.nn.Sequential(*layers)
 31 |         self.relu = torch.nn.ReLU()
 32 |     def forward(self, x):
 33 |         return self.network(x)
 34 | 
 35 | 
 36 | class First_Sub_task_MLP(torch.nn.Module):
 37 | 
 38 |     def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
 39 |         """
 40 |         Multi-layer perceptron
 41 |         :param sizes: list of sizes of the layers
 42 |         :param batch_norm: whether to use batch normalization
 43 |         :param last_layer_act: activation function of the last layer
 44 | 
 45 |         """
 46 |         super().__init__()
 47 |         layers = []
 48 |         for s in range(len(sizes) - 1):
 49 |             layers = layers + [
 50 |                 torch.nn.Linear(sizes[s], sizes[s + 1]),
 51 |                 torch.nn.ReLU()
 52 |             ]
 53 |         layers = [l for l in layers if l is not None][:-1]
 54 |         self.activation = last_layer_act
 55 |         self.network = torch.nn.Sequential(*layers)
 56 |     def forward(self, x):
 57 |         return self.network(x)
 58 | 
 59 | class GEARS_Model(torch.nn.Module):
 60 |     """
 61 |     GEARS model
 62 | 
 63 |     """
 64 | 
 65 |     def __init__(self, args):
 66 |         """
 67 |         :param args: arguments dictionary
 68 |         """
 69 | 
 70 |         super(GEARS_Model, self).__init__()
 71 |         self.args = args       
 72 |         self.num_genes = args['num_genes']
 73 |         self.num_perts = args['num_perts']
 74 |         hidden_size = args['hidden_size']
 75 |         self.uncertainty = args['uncertainty']
 76 |         self.num_layers = args['num_go_gnn_layers']
 77 |         self.indv_out_hidden_size = args['decoder_hidden_size']
 78 |         self.num_layers_gene_pos = args['num_gene_gnn_layers']
 79 |         self.no_perturb = args['no_perturb']
 80 |         self.pert_emb_lambda = 0.2
 81 |         
 82 |         # perturbation positional embedding added only to the perturbed genes
 83 |         self.pert_w = nn.Linear(1, hidden_size)
 84 |            
 85 |         # gene/global perturbation embedding dictionary lookup
 86 |         self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
 87 |         self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True)
 88 |         
 89 |         # transformation layer
 90 |         self.emb_trans = nn.ReLU()
 91 |         self.pert_base_trans = nn.ReLU()
 92 |         self.transform = nn.ReLU()
 93 |         self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')
 94 |         self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')
 95 |         self.first_sub_task = First_Sub_task_MLP([hidden_size, hidden_size // 2, 1], last_layer_act='Sigmoid')
 96 |         
 97 |         
 98 |         # gene co-expression GNN
 99 |         self.G_coexpress = args['G_coexpress'].to(args['device'])
100 |         self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device'])
101 | 
102 |         self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
103 |         self.layers_emb_pos = torch.nn.ModuleList()
104 |         for i in range(1, self.num_layers_gene_pos + 1):
105 |             self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1))
106 |         
107 |         ### perturbation gene ontology GNN
108 |         self.G_sim = args['G_go'].to(args['device'])
109 |         self.G_sim_weight = args['G_go_weight'].to(args['device'])
110 | 
111 |         self.sim_layers = torch.nn.ModuleList()
112 |         for i in range(1, self.num_layers + 1):
113 |             self.sim_layers.append(SGConv(hidden_size, hidden_size, 1))
114 |         
115 |         # decoder shared MLP
116 |         self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear')
117 |         
118 |         # gene specific decoder
119 |         self.indv_w1 = nn.Parameter(torch.rand(self.num_genes,
120 |                                                hidden_size, 1))
121 |         self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1))
122 |         self.act = nn.ReLU()
123 |         nn.init.xavier_normal_(self.indv_w1)
124 |         nn.init.xavier_normal_(self.indv_b1)
125 |         
126 |         # Cross gene MLP
127 |         self.cross_gene_state = MLP([self.num_genes, hidden_size,
128 |                                      hidden_size])
129 |         # final gene specific decoder
130 |         self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes,
131 |                                            hidden_size+1))
132 |         self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes))
133 |         nn.init.xavier_normal_(self.indv_w2)
134 |         nn.init.xavier_normal_(self.indv_b2)
135 |         
136 |         # batchnorms
137 |         self.bn_emb = nn.BatchNorm1d(hidden_size)
138 |         self.bn_pert_base = nn.BatchNorm1d(hidden_size)
139 |         self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size)
140 |         
141 |         # uncertainty mode
142 |         if self.uncertainty:
143 |             self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear')
144 |         
145 |     def forward(self, data):
146 |         """
147 |         Forward pass of the model
148 |         """
149 |         x, pert_idx = data.x, data.pert_idx
150 |         if self.no_perturb:
151 |             out = x.reshape(-1,1)
152 |             out = torch.split(torch.flatten(out), self.num_genes)           
153 |             return torch.stack(out)
154 |         else:
155 |             num_graphs = len(data.batch.unique())
156 | 
157 |             ## get base gene embeddings
158 |             emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device']))        
159 |             emb = self.bn_emb(emb)
160 |             base_emb = self.emb_trans(emb)        
161 | 
162 |             pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device']))
163 |             for idx, layer in enumerate(self.layers_emb_pos):
164 |                 pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight)
165 |                 if idx < len(self.layers_emb_pos) - 1:
166 |                     pos_emb = pos_emb.relu()
167 | 
168 |             base_emb = base_emb + 0.2 * pos_emb
169 |             base_emb = self.emb_trans_v2(base_emb)
170 | 
171 |             ## get perturbation index and embeddings
172 | 
173 |             pert_index = []
174 |             for idx, i in enumerate(pert_idx):
175 |                 for j in i:
176 |                     if j != -1:
177 |                         pert_index.append([idx, j])
178 |             pert_index = torch.tensor(pert_index).T
179 | 
180 |             pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device']))        
181 | 
182 |             ## augment global perturbation embedding with GNN
183 |             for idx, layer in enumerate(self.sim_layers):
184 |                 pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight)
185 |                 if idx < self.num_layers - 1:
186 |                     pert_global_emb = pert_global_emb.relu()
187 | 
188 |             ## add global perturbation embedding to each gene in each cell in the batch
189 |             base_emb = base_emb.reshape(num_graphs, self.num_genes, -1)
190 |             
191 |             first_sub_task_out = None
192 |             if pert_index.shape[0] != 0:
193 |                 ### in case all samples in the batch are controls, then there is no indexing for pert_index.
194 |                 pert_track = {}
195 |                 for i, j in enumerate(pert_index[0]):
196 |                     if j.item() in pert_track:
197 |                         pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]]
198 |                     else:
199 |                         pert_track[j.item()] = pert_global_emb[pert_index[1][i]]
200 | 
201 |                 if len(list(pert_track.values())) > 0:
202 |                     if len(list(pert_track.values())) == 1:
203 |                         # circumvent when batch size = 1 with single perturbation and cannot feed into MLP
204 |                         emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2))
205 |                     else:
206 |                         emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))
207 | 
208 |                     for idx, j in enumerate(pert_track.keys()):
209 |                         base_emb[j] = base_emb[j] + emb_total[idx]
210 |             # print(emb_total.size()) || output shape: [batch-num_control, hidden_dim: 64(default)]
211 |                 
212 |             #print(base_emb.size()) || output shape: [batch, gene_num, hidden_dim]
213 |             base_emb = base_emb.reshape(num_graphs * self.num_genes, -1)
214 |             # print(base_emb.size()) || output shape: [batch * gene_num, hidden_dim]
215 |             base_emb = self.bn_pert_base(base_emb)
216 | 
217 |             first_sub_task_out = base_emb.reshape(num_graphs, self.num_genes, -1)
218 |             first_sub_task_out = torch.nn.Sigmoid()(self.first_sub_task(first_sub_task_out))
219 |             
220 |             ## apply the first MLP
221 |             base_emb = self.transform(base_emb)        
222 |             out = self.recovery_w(base_emb)
223 |             out = out.reshape(num_graphs, self.num_genes, -1)
224 |             out = out.unsqueeze(-1) * self.indv_w1
225 |             w = torch.sum(out, axis = 2)
226 |             out = w + self.indv_b1 # print(out.size()) || output shape :[batch, gene_exp_dim, 1]
227 |             
228 |             ## second sub-task output
229 |             second_sub_task_out = torch.nn.Sigmoid()(out)
230 |             
231 |             # Cross gene
232 |             cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2))
233 |             cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes)
234 | 
235 |             cross_gene_embed = cross_gene_embed.reshape([num_graphs,self.num_genes, -1])
236 |             cross_gene_out = torch.cat([out, cross_gene_embed], 2)
237 | 
238 |             # original code
239 |             # cross_gene_out = cross_gene_out * self.indv_w2
240 |             # cross_gene_out = torch.sum(cross_gene_out, axis=2)
241 |             # out = cross_gene_out + self.indv_b2        
242 |             # out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1)
243 |             # out = torch.split(torch.flatten(out), self.num_genes)
244 |             
245 |             # Modification code
246 |             cross_gene_out = cross_gene_out * self.indv_w2
247 |             cross_gene_out = torch.sum(cross_gene_out, axis=2)
248 |             # out = torch.exp(cross_gene_out + self.indv_b2)
249 |             # out = out.reshape(num_graphs * self.num_genes, -1)
250 |             out = cross_gene_out + self.indv_b2
251 |             # out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1)
252 |            
253 |             # fold_change_version
254 |             # out = torch.exp(out)
255 |             # our_fc = out.reshape(num_graphs * self.num_genes, -1)
256 |             # our_fc = torch.split(torch.flatten(our_fc), self.num_genes)
257 |             # our_fc = torch.stack(our_fc)   
258 |             
259 |                    
260 |             out = out.reshape(num_graphs * self.num_genes, -1) * x.reshape(-1,1)
261 |             out = torch.split(torch.flatten(out), self.num_genes)
262 |             third_sub_task_out = torch.stack(out) 
263 |                        
264 |             # print(third_sub_task_out.size())
265 |             ## uncertainty head
266 |             # if self.uncertainty:
267 |             #     out_logvar = self.uncertainty_w(base_emb)
268 |             #     out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes)
269 |             #     return torch.stack(out), torch.stack(out_logvar)
270 |             
271 |             # original code
272 |             # return torch.stack(out)
273 |             
274 |             # print(first_sub_task_out.size(), second_sub_task_out.size(), third_sub_task_out.size())|| [batch_size, num_gene, 1], [batch_size, num_gene, 1], [batch_size, num_gene]
275 |             return first_sub_task_out, second_sub_task_out, third_sub_task_out
276 |         
277 | 
--------------------------------------------------------------------------------
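A minimal shape check for the MLP builder defined in model.py above (sizes are illustrative; MLP is assumed to be importable from that file). With batch_norm=True each layer is Linear -> BatchNorm1d -> ReLU, and the trailing ReLU is dropped so the final layer applies no activation:

import torch

mlp = MLP([64, 128, 64])     # Linear(64,128), BN, ReLU, Linear(128,64), BN
x = torch.randn(8, 64)       # (batch, features)
print(mlp(x).shape)          # torch.Size([8, 64])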
/Supp_code/Modified_GEARS/pertdata.py:
--------------------------------------------------------------------------------
  1 | from torch_geometric.data import Data
  2 | import torch
  3 | import numpy as np
  4 | import pickle
  5 | from torch_geometric.data import DataLoader
  6 | import os
  7 | import scanpy as sc
  8 | from tqdm import tqdm
  9 | 
 10 | import warnings
 11 | warnings.filterwarnings("ignore")
 12 | sc.settings.verbosity = 0
 13 | 
 14 | from .data_utils import get_DE_genes, get_dropout_non_zero_genes, DataSplitter
 15 | from .utils import print_sys, zip_data_download_wrapper, dataverse_download,\
 16 |                   filter_pert_in_go, get_genes_from_perts, tar_data_download_wrapper
 17 | 
 18 | class PertData:
 19 |     """
 20 |     Class for loading and processing perturbation data
 21 | 
 22 |     Attributes
 23 |     ----------
 24 |     data_path: str
 25 |         Path to save/load data
 26 |     gene_set_path: str
 27 |         Path to gene set to use for perturbation graph
 28 |     default_pert_graph: bool
 29 |         Whether to use default perturbation graph or not
 30 |     dataset_name: str
 31 |         Name of dataset
 32 |     dataset_path: str
 33 |         Path to dataset
 34 |     adata: AnnData
 35 |         AnnData object containing dataset
 36 |     dataset_processed: dict
 37 |         Processed dataset of PyG cell-graph objects, keyed by perturbation condition
 38 |     ctrl_adata: AnnData
 39 |         AnnData object containing control samples
 40 |     gene_names: list
 41 |         List of gene names
 42 |     node_map: dict
 43 |         Dictionary mapping gene names to indices
 44 |     split: str
 45 |         Split type
 46 |     seed: int
 47 |         Seed for splitting
 48 |     subgroup: str
 49 |         Subgroup for splitting
 50 |     train_gene_set_size: float
 51 |         Fraction of perturbed genes to use for training
 52 | 
 53 |     """
 54 |     
 55 |     def __init__(self, data_path, 
 56 |                  gene_set_path=None, 
 57 |                  default_pert_graph=True):
 58 |         """
 59 |         Parameters
 60 |         ----------
 61 | 
 62 |         data_path: str
 63 |             Path to save/load data
 64 |         gene_set_path: str
 65 |             Path to gene set to use for perturbation graph
 66 |         default_pert_graph: bool
 67 |             Whether to use default perturbation graph or not
 68 | 
 69 |         """
 70 | 
 71 |         
 72 |         # Dataset/Dataloader attributes
 73 |         self.data_path = data_path
 74 |         self.default_pert_graph = default_pert_graph
 75 |         self.gene_set_path = gene_set_path
 76 |         self.dataset_name = None
 77 |         self.dataset_path = None
 78 |         self.adata = None
 79 |         self.dataset_processed = None
 80 |         self.ctrl_adata = None
 81 |         self.gene_names = []
 82 |         self.node_map = {}
 83 | 
 84 |         # Split attributes
 85 |         self.split = None
 86 |         self.seed = None
 87 |         self.subgroup = None
 88 |         self.train_gene_set_size = None
 89 |         
 90 |         if not os.path.exists(self.data_path):
 91 |             os.mkdir(self.data_path)
 92 |         server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417'
 93 |         dataverse_download(server_path,
 94 |                            os.path.join(self.data_path, 'gene2go_all.pkl'))
 95 |         with open(os.path.join(self.data_path, 'gene2go_all.pkl'), 'rb') as f:
 96 |             self.gene2go = pickle.load(f)
 97 |     
 98 |     def set_pert_genes(self):
 99 |         """
100 |         Set the list of genes that can be perturbed and are to be included in 
101 |         perturbation graph
102 |         """
103 |         
104 |         if self.gene_set_path is not None:
105 |             # If gene set specified for perturbation graph, use that
106 |             path_ = self.gene_set_path
107 |             self.default_pert_graph = False
108 |             with open(path_, 'rb') as f:
109 |                 essential_genes = pickle.load(f)
110 |             
111 |         elif self.default_pert_graph is False:
112 |             # Use a smaller perturbation graph 
113 |             all_pert_genes = get_genes_from_perts(self.adata.obs['condition'])
114 |             essential_genes = list(self.adata.var['gene_name'].values)
115 |             essential_genes += all_pert_genes
116 |             
117 |         else:
118 |             # Otherwise, use a large set of genes to create perturbation graph
119 |             server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934320'
120 |             path_ = os.path.join(self.data_path,
121 |                                      'essential_all_data_pert_genes.pkl')
122 |             dataverse_download(server_path, path_)
123 |             with open(path_, 'rb') as f:
124 |                 essential_genes = pickle.load(f)
125 |     
126 |         gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go}
127 | 
128 |         self.pert_names = np.unique(list(gene2go.keys()))
129 |         self.node_map_pert = {x: it for it, x in enumerate(self.pert_names)}
130 |             
131 |     def load(self, data_name = None, data_path = None):
132 |         """
133 |         Load existing dataloader
134 |         Use data_name for loading 'norman', 'adamson', 'dixit' datasets
135 |         For other datasets use data_path
136 | 
137 |         Parameters
138 |         ----------
139 |         data_name: str
140 |             Name of dataset
141 |         data_path: str
142 |             Path to dataset
143 | 
144 |         Returns
145 |         -------
146 |         None
147 | 
148 |         """
149 |         
150 |         if data_name in ['norman', 'adamson', 'dixit', 
151 |                          'replogle_k562_essential', 
152 |                          'replogle_rpe1_essential']:
153 |             ## load from harvard dataverse
154 |             if data_name == 'norman':
155 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154020'
156 |             elif data_name == 'adamson':
157 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154417'
158 |             elif data_name == 'dixit':
159 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/6154416'
160 |             elif data_name == 'replogle_k562_essential':
161 |                 ## Note: This is not the complete dataset and has been filtered
162 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/7458695'
163 |             elif data_name == 'replogle_rpe1_essential':
164 |                 ## Note: This is not the complete dataset and has been filtered
165 |                 url = 'https://dataverse.harvard.edu/api/access/datafile/7458694'
166 |             data_path = os.path.join(self.data_path, data_name)
167 |             zip_data_download_wrapper(url, data_path, self.data_path)
168 |             self.dataset_name = data_path.split('/')[-1]
169 |             self.dataset_path = data_path
170 |             adata_path = os.path.join(data_path, 'perturb_processed.h5ad')
171 |             self.adata = sc.read_h5ad(adata_path)
172 | 
173 |         elif os.path.exists(data_path):
174 |             adata_path = os.path.join(data_path, 'perturb_processed.h5ad')
175 |             self.adata = sc.read_h5ad(adata_path)
176 |             self.dataset_name = data_path.split('/')[-1]
177 |             self.dataset_path = data_path
178 |         else:
179 |             raise ValueError("data attribute is either norman, adamson, dixit "
180 |                              "replogle_k562 or replogle_rpe1 "
181 |                              "or a path to an h5ad file")
182 |         
183 |         self.set_pert_genes()
184 |         print_sys('These perturbations are not in the GO graph and their '
185 |                   'effects can thus not be predicted')
186 |         not_in_go_pert = np.array(self.adata.obs[
187 |                                   self.adata.obs.condition.apply(
188 |                                   lambda x:not filter_pert_in_go(x,
189 |                                         self.pert_names))].condition.unique())
190 |         print_sys(not_in_go_pert)
191 |         
192 |         filter_go = self.adata.obs[self.adata.obs.condition.apply(
193 |                               lambda x: filter_pert_in_go(x, self.pert_names))]
194 |         self.adata = self.adata[filter_go.index.values, :]
195 |         pyg_path = os.path.join(data_path, 'data_pyg')
196 |         if not os.path.exists(pyg_path):
197 |             os.mkdir(pyg_path)
198 |         dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl')
199 |                 
200 |         if os.path.isfile(dataset_fname):
201 |             print_sys("Local copy of pyg dataset is detected. Loading...")
202 |             self.dataset_processed = pickle.load(open(dataset_fname, "rb"))        
203 |             print_sys("Done!")
204 |         else:
205 |             self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl']
206 |             self.gene_names = self.adata.var.gene_name
207 |             
208 |             
209 |             print_sys("Creating pyg object for each cell in the data...")
210 |             self.create_dataset_file()
211 |             print_sys("Saving new dataset pyg object at " + dataset_fname) 
212 |             pickle.dump(self.dataset_processed, open(dataset_fname, "wb"))    
213 |             print_sys("Done!")
214 |             
215 |     def new_data_process(self, dataset_name,
216 |                          adata = None,
217 |                          skip_calc_de = False, sub_task_data = None):
218 |         """
219 |         Process new dataset
220 | 
221 |         Parameters
222 |         ----------
223 |         dataset_name: str
224 |             Name of dataset
225 |         adata: AnnData object
226 |             AnnData object containing gene expression data
227 |         skip_calc_de: bool
228 |             If True, skip differential expression calculation
229 | 
230 |         Returns
231 |         -------
232 |         None
233 | 
234 |         """
235 |         
236 |         if 'condition' not in adata.obs.columns.values:
237 |             raise ValueError("Please specify condition")
238 |         if 'gene_name' not in adata.var.columns.values:
239 |             raise ValueError("Please specify gene name")
240 |         if 'cell_type' not in adata.obs.columns.values:
241 |             raise ValueError("Please specify cell type")
242 |         
243 |         dataset_name = dataset_name.lower()
244 |         self.dataset_name = dataset_name
245 |         save_data_folder = os.path.join(self.data_path, dataset_name)
246 |         
247 |         if not os.path.exists(save_data_folder):
248 |             os.mkdir(save_data_folder)
249 |         self.dataset_path = save_data_folder
250 |         self.adata = get_DE_genes(adata, skip_calc_de)
251 |         if not skip_calc_de:
252 |             self.adata = get_dropout_non_zero_genes(self.adata)
253 |         self.adata.write_h5ad(os.path.join(save_data_folder, 'perturb_processed.h5ad'))
254 |         
255 |         self.set_pert_genes()
256 |         self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl']
257 |         self.gene_names = self.adata.var.gene_name
258 |         pyg_path = os.path.join(save_data_folder, 'data_pyg')
259 |         if not os.path.exists(pyg_path):
260 |             os.mkdir(pyg_path)
261 |         dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl')
262 |         print_sys("Creating pyg object for each cell in the data...")
263 |         self.create_dataset_file(sub_task_data)
264 |         print_sys("Saving new dataset pyg object at " + dataset_fname) 
265 |         pickle.dump(self.dataset_processed, open(dataset_fname, "wb"))    
266 |         print_sys("Done!")
267 |         
268 |     def prepare_split(self, split = 'simulation', 
269 |                       seed = 1, 
270 |                       train_gene_set_size = 0.75,
271 |                       combo_seen2_train_frac = 0.75,
272 |                       combo_single_split_test_set_fraction = 0.1,
273 |                       test_perts = None,
274 |                       only_test_set_perts = False,
275 |                       test_pert_genes = None,
276 |                       split_dict_path=None):
277 | 
278 |         """
279 |         Prepare splits for training and testing
280 | 
281 |         Parameters
282 |         ----------
283 |         split: str
284 |             Type of split to use. Currently, we support 'simulation',
285 |             'simulation_single', 'combo_seen0', 'combo_seen1', 'combo_seen2',
286 |             'single', 'no_test', 'no_split', 'custom'
287 |         seed: int
288 |             Random seed
289 |         train_gene_set_size: float
290 |             Fraction of genes to use for training
291 |         combo_seen2_train_frac: float
292 |             Fraction of combo seen2 perturbations to use for training
293 |         combo_single_split_test_set_fraction: float
294 |             Fraction of combo single perturbations to use for testing
295 |         test_perts: list
296 |             List of perturbations to use for testing
297 |         only_test_set_perts: bool
298 |             If True, only use test set perturbations for testing
299 |         test_pert_genes: list
300 |             List of genes to use for testing
301 |         split_dict_path: str
302 |             Path to dictionary used for custom split. Sample format:
303 |                 {'train': [X, Y], 'val': [P, Q], 'test': [Z]}
304 | 
305 |         Returns
306 |         -------
307 |         None
308 | 
309 |         """
310 |         available_splits = ['simulation', 'simulation_single', 'combo_seen0',
311 |                             'combo_seen1', 'combo_seen2', 'single', 'no_test',
312 |                             'no_split', 'custom']
313 |         if split not in available_splits:
314 |             raise ValueError('currently, we only support ' + ','.join(available_splits))
315 |         self.split = split
316 |         self.seed = seed
317 |         self.subgroup = None
318 |         self.train_gene_set_size = train_gene_set_size
319 |         
320 |         split_folder = os.path.join(self.dataset_path, 'splits')
321 |         if not os.path.exists(split_folder):
322 |             os.mkdir(split_folder)
323 |         split_file = self.dataset_name + '_' + split + '_' + str(seed) + '_' \
324 |                                        +  str(train_gene_set_size) + '.pkl'
325 |         split_path = os.path.join(split_folder, split_file)
326 |         
327 |         if test_perts:
328 |             split_path = split_path[:-4] + '_' + test_perts + '.pkl'
329 |         
330 |         if os.path.exists(split_path):
331 |             print_sys("Local copy of split is detected. Loading...")
332 |             set2conditions = pickle.load(open(split_path, "rb"))
333 |             if split == 'simulation':
334 |                 subgroup_path = split_path[:-4] + '_subgroup.pkl'
335 |                 subgroup = pickle.load(open(subgroup_path, "rb"))
336 |                 self.subgroup = subgroup
337 |         else:
338 |             print_sys("Creating new splits....")
339 |             if test_perts:
340 |                 test_perts = test_perts.split('_')
341 |                     
342 |             if split in ['simulation', 'simulation_single']:
343 |                 # simulation split
344 |                 DS = DataSplitter(self.adata, split_type=split)
345 |                 
346 |                 adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, 
347 |                                                 combo_seen2_train_frac = combo_seen2_train_frac,
348 |                                                 seed=seed,
349 |                                                 test_perts = test_perts,
350 |                                                 only_test_set_perts = only_test_set_perts
351 |                                                )
352 |                 subgroup_path = split_path[:-4] + '_subgroup.pkl'
353 |                 pickle.dump(subgroup, open(subgroup_path, "wb"))
354 |                 self.subgroup = subgroup
355 |                 
356 |             elif split[:5] == 'combo':
357 |                 # combo perturbation
358 |                 split_type = 'combo'
359 |                 seen = int(split[-1])
360 | 
361 |                 if test_pert_genes:
362 |                     test_pert_genes = test_pert_genes.split('_')
363 |                 
364 |                 DS = DataSplitter(self.adata, split_type=split_type, seen=int(seen))
365 |                 adata = DS.split_data(test_size=combo_single_split_test_set_fraction,
366 |                                       test_perts=test_perts,
367 |                                       test_pert_genes=test_pert_genes,
368 |                                       seed=seed)
369 | 
370 |             elif split == 'single':
371 |                 # single perturbation
372 |                 DS = DataSplitter(self.adata, split_type=split)
373 |                 adata = DS.split_data(test_size=combo_single_split_test_set_fraction,
374 |                                       seed=seed)
375 | 
376 |             elif split == 'no_test':
377 |                 # no test set
378 |                 DS = DataSplitter(self.adata, split_type=split)
379 |                 adata = DS.split_data(seed=seed)
380 |             
381 |             elif split == 'no_split':
382 |                 # no split
383 |                 adata = self.adata
384 |                 adata.obs['split'] = 'test'
385 |                 
386 |             elif split == 'custom':
387 |                 adata = self.adata
388 |                 try:
389 |                     with open(split_dict_path, 'rb') as f:
390 |                         split_dict = pickle.load(f)
391 |                 except:
392 |                     raise ValueError('Please set split_dict_path for custom split')
393 |                 adata.obs['split'] = adata.obs['condition'].map(split_dict)
394 |                 
395 |                 
396 |             
397 |             set2conditions = dict(adata.obs.groupby('split').agg({'condition':
398 |                                                         lambda x: x}).condition)
399 |             set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} 
400 |             pickle.dump(set2conditions, open(split_path, "wb"))
401 |             print_sys("Saving new splits at " + split_path)
402 |             
403 |         self.set2conditions = set2conditions
404 | 
405 |         if split == 'simulation':
406 |             print_sys('Simulation split test composition:')
407 |             for i,j in subgroup['test_subgroup'].items():
408 |                 print_sys(i + ':' + str(len(j)))
409 |         print_sys("Done!")
410 |         
411 |     def get_dataloader(self, batch_size, test_batch_size = None):
412 |         """
413 |         Get dataloaders for training and testing
414 | 
415 |         Parameters
416 |         ----------
417 |         batch_size: int
418 |             Batch size for training
419 |         test_batch_size: int
420 |             Batch size for testing
421 | 
422 |         Returns
423 |         -------
424 |         dict
425 |             Dictionary of dataloaders; returned directly only for the 'no_split' setting, otherwise the loaders are stored in self.dataloader
426 | 
427 |         """
428 |         if test_batch_size is None:
429 |             test_batch_size = batch_size
430 |             
431 |         self.node_map = {x: it for it, x in enumerate(self.adata.var.gene_name)}
432 |         self.gene_names = self.adata.var.gene_name
433 |        
434 |         # Create cell graphs
435 |         cell_graphs = {}
436 |         if self.split == 'no_split':
437 |             i = 'test'
438 |             cell_graphs[i] = []
439 |             for p in self.set2conditions[i]:
440 |                 if p != 'ctrl':
441 |                     cell_graphs[i].extend(self.dataset_processed[p])
442 |                 
443 |             print_sys("Creating dataloaders....")
444 |             # Set up dataloaders
445 |             test_loader = DataLoader(cell_graphs['test'],
446 |                                 batch_size=test_batch_size, shuffle=False)
447 | 
448 |             print_sys("Dataloaders created...")
449 |             return {'test_loader': test_loader}
450 |         else:
451 |             if self.split =='no_test':
452 |                 splits = ['train','val']
453 |             else:
454 |                 splits = ['train','val','test']
455 |             for i in splits:
456 |                 cell_graphs[i] = []
457 |                 for p in self.set2conditions[i]:
458 |                     cell_graphs[i].extend(self.dataset_processed[p])
459 | 
460 |             print_sys("Creating dataloaders....")
461 |             
462 |             # Set up dataloaders
463 |             train_loader = DataLoader(cell_graphs['train'],
464 |                                 batch_size=batch_size, shuffle=True, drop_last = True)
465 |             val_loader = DataLoader(cell_graphs['val'],
466 |                                 batch_size=batch_size, shuffle=True)
467 |             
468 |             if self.split !='no_test':
469 |                 test_loader = DataLoader(cell_graphs['test'],
470 |                                 batch_size=test_batch_size, shuffle=False)
471 |                 self.dataloader =  {'train_loader': train_loader,
472 |                                     'val_loader': val_loader,
473 |                                     'test_loader': test_loader}
474 | 
475 |             else: 
476 |                 self.dataloader =  {'train_loader': train_loader,
477 |                                     'val_loader': val_loader}
478 |             print_sys("Done!")
479 | 
480 |     def get_pert_idx(self, pert_category):
481 |         """
482 |         Get perturbation index for a given perturbation category
483 | 
484 |         Parameters
485 |         ----------
486 |         pert_category: str
487 |             Perturbation category
488 | 
489 |         Returns
490 |         -------
491 |         list
492 |             List of perturbation indices
493 | 
494 |         """
495 |         try:
496 |             pert_idx = [np.where(p == self.pert_names)[0][0]
497 |                     for p in pert_category.split('+')
498 |                     if p != 'ctrl']
499 |         except:
500 |             print_sys('Perturbation gene(s) not found in pert_names: ' + pert_category)
501 |             pert_idx = None
502 |             
503 |         return pert_idx
504 | 
505 |     def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None, y_1 = None, y_2 = None, y_3 = None):
506 |         """
507 |         Create a cell graph from a given cell
508 | 
509 |         Parameters
510 |         ----------
511 |         X: np.ndarray
512 |             Gene expression matrix
513 |         y: np.ndarray
514 |             Label vector
515 |         de_idx: np.ndarray
516 |             DE gene indices
517 |         pert: str
518 |             Perturbation category
519 |         pert_idx: list
520 |             List of perturbation indices
521 | 
522 |         Returns
523 |         -------
524 |         torch_geometric.data.Data
525 |             Cell graph to be used in dataloader
526 | 
527 |         """
528 |             
529 |         feature_mat = torch.Tensor(X).T
530 |         if pert_idx is None:
531 |             pert_idx = [-1]
532 |         return Data(x=feature_mat, pert_idx=pert_idx,
533 |                     y=torch.Tensor(y), de_idx=de_idx, pert=pert, y_1 = torch.Tensor(y_1), y_2 = torch.Tensor(y_2), y_3 = torch.Tensor(y_3))
534 | 
535 |     def create_cell_graph_dataset(self, split_adata, pert_category,
536 |                                   num_samples=1, sub_tasks_data = None):
537 |         """
538 |         Combine cell graphs to create a dataset of cell graphs
539 | 
540 |         Parameters
541 |         ----------
542 |         split_adata: anndata.AnnData
543 |             Annotated data matrix
544 |         pert_category: str
545 |             Perturbation category
546 |         num_samples: int
547 |             Number of samples to create per perturbed cell (i.e. number of
548 |             control cells to map to each perturbed cell)
549 | 
550 |         Returns
551 |         -------
552 |         list
553 |             List of cell graphs
554 | 
555 |         """
556 |         # Modification: optionally load the sub-task (three-level) label AnnData
557 |         if sub_tasks_data is not None:
558 |             sub_tasks_data = sc.read_h5ad(sub_tasks_data)
559 |             gene_idx_dict = dict(zip(list(sub_tasks_data.obs.index), list(range(len(sub_tasks_data.obs.index)))))
560 |         
561 |         num_de_genes = 20        
562 |         adata_ = split_adata[split_adata.obs['condition'] == pert_category]
563 |         if 'rank_genes_groups_cov_all' in adata_.uns:
564 |             de_genes = adata_.uns['rank_genes_groups_cov_all']
565 |             de = True
566 |         else:
567 |             de = False
568 |             num_de_genes = 1
569 |         Xs = []
570 |         ys = []
571 |         
572 |         # Modification: containers for the three sub-task label levels
573 |         y1 = []
574 |         y2 = []
575 |         y3 = []
576 |         # When considering a non-control perturbation
577 |         if pert_category != 'ctrl':
578 |             # Get the indices of applied perturbation
579 |             pert_idx = self.get_pert_idx(pert_category)
580 |             
581 |             y_sub_tasks = sub_tasks_data[gene_idx_dict[pert_category]] # Modification: sub-task labels for this perturbation
582 |             
583 |             # Store list of genes that are most differentially expressed for testing
584 |             pert_de_category = adata_.obs['condition_name'][0]
585 |             if de:
586 |                 de_idx = np.where(adata_.var_names.isin(
587 |                 np.array(de_genes[pert_de_category][:num_de_genes])))[0]
588 |             else:
589 |                 de_idx = [-1] * num_de_genes
590 |             for cell_z in adata_.X:
591 |                 # Use samples from control as basal expression
592 |                 ctrl_samples = self.ctrl_adata[np.random.randint(0,
593 |                                         len(self.ctrl_adata), num_samples), :]
594 |                 for c in ctrl_samples.X:
595 |                     Xs.append(c)
596 |                     ys.append(cell_z)
597 |                     y1.append(y_sub_tasks.layers['level1'].astype('float32'))
598 |                     y2.append(y_sub_tasks.layers['level2'].astype('float32'))
599 |                     y3.append(y_sub_tasks.layers['level3'].astype('float32'))
600 |         # When considering a control perturbation
601 |         else:
602 |             y_sub_tasks = sub_tasks_data[gene_idx_dict[pert_category]]
603 |             pert_idx = None
604 |             de_idx = [-1] * num_de_genes
605 |             for cell_z in adata_.X:
606 |                 Xs.append(cell_z)
607 |                 ys.append(cell_z)
608 |                 y1.append(y_sub_tasks.layers['level1'].astype('float32'))
609 |                 y2.append(y_sub_tasks.layers['level2'].astype('float32'))
610 |                 y3.append(y_sub_tasks.layers['level3'].astype('float32'))
611 |         # Create cell graphs
612 |         cell_graphs = []
613 |         for X, y, y_1, y_2, y_3 in zip(Xs, ys, y1, y2, y3):
614 |             cell_graphs.append(self.create_cell_graph(X.toarray(),
615 |                                 y.toarray(), de_idx, pert_category, pert_idx, y_1.toarray(), y_2.toarray(), y_3.toarray()))
616 | 
617 |         return cell_graphs
618 | 
619 |     def create_dataset_file(self, sub_tasks_data = None):
620 |         """
621 |         Create dataset file for each perturbation condition
622 |         """
623 |         print_sys("Creating dataset file...")
624 |         self.dataset_processed = {}
625 |         for p in tqdm(self.adata.obs['condition'].unique()):
626 |             self.dataset_processed[p] = self.create_cell_graph_dataset(self.adata, p, sub_tasks_data = sub_tasks_data)
627 |         print_sys("Done!")
628 | 
--------------------------------------------------------------------------------
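Note (on the file above): a minimal usage sketch of the two public entry points shown here. The object name pert_data and the preceding data-loading / split-preparation steps are assumptions about the surrounding class (not shown in full here); the sub-task label path is a placeholder.

    # Hypothetical sketch: pert_data is an instance of the modified PertData-style class above,
    # with its AnnData loaded and a train/val/test split already prepared.
    pert_data.create_dataset_file(sub_tasks_data='sub_task_labels.h5ad')  # build per-condition cell graphs
    pert_data.get_dataloader(batch_size=32, test_batch_size=128)          # populates self.dataloader
    train_loader = pert_data.dataloader['train_loader']
    val_loader = pert_data.dataloader['val_loader']
    test_loader = pert_data.dataloader['test_loader']
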
/Supp_code/Modified_GEARS/utils.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import numpy as np
  3 | import pandas as pd
  4 | import networkx as nx
  5 | from tqdm import tqdm
  6 | import pickle
  7 | import sys, os
  8 | import requests
  9 | from torch_geometric.data import Data
 10 | from zipfile import ZipFile
 11 | import tarfile
 12 | from sklearn.linear_model import TheilSenRegressor
 13 | from dcor import distance_correlation
 14 | from multiprocessing import Pool
 15 | 
 16 | def parse_single_pert(i):
 17 |     a = i.split('+')[0]
 18 |     b = i.split('+')[1]
 19 |     if a == 'ctrl':
 20 |         pert = b
 21 |     else:
 22 |         pert = a
 23 |     return pert
 24 | 
 25 | def parse_combo_pert(i):
 26 |     return i.split('+')[0], i.split('+')[1]
 27 | 
 28 | def combine_res(res_1, res_2):
 29 |     res_out = {}
 30 |     for key in res_1:
 31 |         res_out[key] = np.concatenate([res_1[key], res_2[key]])
 32 |     return res_out
 33 | 
 34 | def parse_any_pert(p):
 35 |     if ('ctrl' in p) and (p != 'ctrl'):
 36 |         return [parse_single_pert(p)]
 37 |     elif 'ctrl' not in p:
 38 |         out = parse_combo_pert(p)
 39 |         return [out[0], out[1]]
 40 | 
 41 | def np_pearson_cor(x, y):
 42 |     xv = x - x.mean(axis=0)
 43 |     yv = y - y.mean(axis=0)
 44 |     xvss = (xv * xv).sum(axis=0)
 45 |     yvss = (yv * yv).sum(axis=0)
 46 |     result = np.matmul(xv.transpose(), yv) / np.sqrt(np.outer(xvss, yvss))
 47 |     # bound the values to -1 to 1 in the event of precision issues
 48 |     return np.maximum(np.minimum(result, 1.0), -1.0)
 49 | 
 50 | 
 51 | def dataverse_download(url, save_path):
 52 |     """
 53 |     Dataverse download helper with progress bar
 54 | 
 55 |     Args:
 56 |         url (str): the url of the dataset
 57 |         save_path (str): the path to save the dataset
 58 |     """
 59 |     
 60 |     if os.path.exists(save_path):
 61 |         print_sys('Found local copy...')
 62 |     else:
 63 |         print_sys("Downloading...")
 64 |         response = requests.get(url, stream=True)
 65 |         total_size_in_bytes= int(response.headers.get('content-length', 0))
 66 |         block_size = 1024
 67 |         progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
 68 |         with open(save_path, 'wb') as file:
 69 |             for data in response.iter_content(block_size):
 70 |                 progress_bar.update(len(data))
 71 |                 file.write(data)
 72 |         progress_bar.close()
 73 | 
 74 |         
 75 | def zip_data_download_wrapper(url, save_path, data_path):
 76 |     """
 77 |     Wrapper for zip file download
 78 | 
 79 |     Args:
 80 |         url (str): the url of the dataset
 81 |         save_path (str): the path where the file is downloaded
 82 |         data_path (str): the path to save the extracted dataset
 83 |     """
 84 | 
 85 |     if os.path.exists(save_path):
 86 |         print_sys('Found local copy...')
 87 |     else:
 88 |         dataverse_download(url, save_path + '.zip')
 89 |         print_sys('Extracting zip file...')
 90 |         with ZipFile((save_path + '.zip'), 'r') as zip:
 91 |             zip.extractall(path = data_path)
 92 |         print_sys("Done!")  
 93 |         
 94 | def tar_data_download_wrapper(url, save_path, data_path):
 95 |     """
 96 |     Wrapper for tar file download
 97 | 
 98 |     Args:
 99 |         url (str): the url of the dataset
100 |         save_path (str): the path where the file is downloaded
101 |         data_path (str): the path to save the extracted dataset
102 | 
103 |     """
104 | 
105 |     if os.path.exists(save_path):
106 |         print_sys('Found local copy...')
107 |     else:
108 |         dataverse_download(url, save_path + '.tar.gz')
109 |         print_sys('Extracting tar file...')
110 |         with tarfile.open(save_path  + '.tar.gz') as tar:
111 |             tar.extractall(path= data_path)
112 |         print_sys("Done!")  
113 |         
114 | def get_go_auto(gene_list, data_path, data_name):
115 |     """
116 |     Get gene ontology data
117 | 
118 |     Args:
119 |         gene_list (list): list of gene names
120 |         data_path (str): the path to save the extracted dataset
121 |         data_name (str): the name of the dataset
122 | 
123 |     Returns:
124 |         df_edge_list (pd.DataFrame): gene ontology edge list
125 |     """
126 |     go_path = os.path.join(data_path, data_name, 'go.csv')
127 |     
128 |     if os.path.exists(go_path):
129 |         return pd.read_csv(go_path)
130 |     else:
131 |         ## download gene2go.pkl
132 |         if not os.path.exists(os.path.join(data_path, 'gene2go.pkl')):
133 |             server_path = 'https://dataverse.harvard.edu/api/access/datafile/6153417'
134 |             dataverse_download(server_path, os.path.join(data_path, 'gene2go.pkl'))
135 |         with open(os.path.join(data_path, 'gene2go.pkl'), 'rb') as f:
136 |             gene2go = pickle.load(f)
137 | 
138 |         gene2go = {i: list(gene2go[i]) for i in gene_list if i in gene2go}
139 |         edge_list = []
140 |         for g1 in tqdm(gene2go.keys()):
141 |             for g2 in gene2go.keys():
142 |                 edge_list.append((g1, g2, len(np.intersect1d(gene2go[g1],
143 |                    gene2go[g2]))/len(np.union1d(gene2go[g1], gene2go[g2]))))
144 | 
145 |         edge_list_filter = [i for i in edge_list if i[2] > 0]
146 |         further_filter = [i for i in edge_list if i[2] > 0.1]
147 |         df_edge_list = pd.DataFrame(further_filter).rename(columns = {0: 'gene1',
148 |                                                                       1: 'gene2',
149 |                                                                       2: 'score'})
150 | 
151 |         df_edge_list = df_edge_list.rename(columns = {'gene1': 'source',
152 |                                                       'gene2': 'target',
153 |                                                       'score': 'importance'})
154 |         df_edge_list.to_csv(go_path, index = False)        
155 |         return df_edge_list
156 | 
157 | class GeneSimNetwork():
158 |     """
159 |     GeneSimNetwork class
160 | 
161 |     Args:
162 |         edge_list (pd.DataFrame): edge list of the network
163 |         gene_list (list): list of gene names
164 |         node_map (dict): dictionary mapping gene names to node indices
165 | 
166 |     Attributes:
167 |         edge_index (torch.Tensor): edge index of the network
168 |         edge_weight (torch.Tensor): edge weight of the network
169 |         G (nx.DiGraph): networkx graph object
170 |     """
171 |     def __init__(self, edge_list, gene_list, node_map):
172 |         """
173 |         Initialize GeneSimNetwork class
174 |         """
175 | 
176 |         self.edge_list = edge_list
177 |         self.G = nx.from_pandas_edgelist(self.edge_list, source='source',
178 |                         target='target', edge_attr=['importance'],
179 |                         create_using=nx.DiGraph())    
180 |         self.gene_list = gene_list
181 |         for n in self.gene_list:
182 |             if n not in self.G.nodes():
183 |                 self.G.add_node(n)
184 |         
185 |         edge_index_ = [(node_map[e[0]], node_map[e[1]]) for e in
186 |                       self.G.edges]
187 |         self.edge_index = torch.tensor(edge_index_, dtype=torch.long).T
188 |         #self.edge_weight = torch.Tensor(self.edge_list['importance'].values)
189 |         
190 |         edge_attr = nx.get_edge_attributes(self.G, 'importance') 
191 |         importance = np.array([edge_attr[e] for e in self.G.edges])
192 |         self.edge_weight = torch.Tensor(importance)
193 | 
194 | def get_GO_edge_list(args):
195 |     """
196 |     Get gene ontology edge list
197 |     """
198 |     g1, gene2go = args
199 |     edge_list = []
200 |     for g2 in gene2go.keys():
201 |         score = len(gene2go[g1].intersection(gene2go[g2])) / len(
202 |             gene2go[g1].union(gene2go[g2]))
203 |         if score > 0.1:
204 |             edge_list.append((g1, g2, score))
205 |     return edge_list
206 |         
207 | def make_GO(data_path, pert_list, data_name, num_workers=25, save=True):
208 |     """
209 |     Creates Gene Ontology graph from a custom set of genes
210 |     """
211 | 
212 |     fname = './data/go_essential_' + data_name + '.csv'
213 |     if os.path.exists(fname):
214 |         return pd.read_csv(fname)
215 | 
216 |     with open(os.path.join(data_path, 'gene2go_all.pkl'), 'rb') as f:
217 |         gene2go = pickle.load(f)
218 |     gene2go = {i: gene2go[i] for i in pert_list}
219 | 
220 |     print('Creating custom GO graph, this can take a few minutes')
221 |     with Pool(num_workers) as p:
222 |         all_edge_list = list(
223 |             tqdm(p.imap(get_GO_edge_list, ((g, gene2go) for g in gene2go.keys())),
224 |                       total=len(gene2go.keys())))
225 |     edge_list = []
226 |     for i in all_edge_list:
227 |         edge_list = edge_list + i
228 | 
229 |     df_edge_list = pd.DataFrame(edge_list).rename(
230 |         columns={0: 'source', 1: 'target', 2: 'importance'})
231 |     
232 |     if save:
233 |         print('Saving edge_list to file')
234 |         df_edge_list.to_csv(fname, index=False)
235 | 
236 |     return df_edge_list
237 | 
238 | def get_similarity_network(network_type, adata, threshold, k,
239 |                            data_path, data_name, split, seed, train_gene_set_size,
240 |                            set2conditions, default_pert_graph=True, pert_list=None):
241 |     
242 |     if network_type == 'co-express':
243 |         df_out = get_coexpression_network_from_train(adata, threshold, k,
244 |                                                      data_path, data_name, split,
245 |                                                      seed, train_gene_set_size,
246 |                                                      set2conditions)
247 |     elif network_type == 'go':
248 |         if default_pert_graph:
249 |             server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319'
250 |             tar_data_download_wrapper(server_path, 
251 |                                      os.path.join(data_path, 'go_essential_all'),
252 |                                      data_path)
253 |             df_jaccard = pd.read_csv(os.path.join(data_path, 
254 |                                      'go_essential_all/go_essential_all.csv'))
255 | 
256 |         else:
257 |             df_jaccard = make_GO(data_path, pert_list, data_name)
258 | 
259 |         df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1,
260 |                                     ['importance'])).reset_index(drop = True)
261 | 
262 |     return df_out
263 | 
264 | def get_coexpression_network_from_train(adata, threshold, k, data_path,
265 |                                         data_name, split, seed, train_gene_set_size,
266 |                                         set2conditions):
267 |     """
268 |     Infer co-expression network from training data
269 | 
270 |     Args:
271 |         adata (anndata.AnnData): anndata object
272 |         threshold (float): threshold for co-expression
273 |         k (int): number of edges to keep
274 |         data_path (str): path to data
275 |         data_name (str): name of dataset
276 |         split (str): split of dataset
277 |         seed (int): seed for random number generator
278 |         train_gene_set_size (int): size of training gene set
279 |         set2conditions (dict): dictionary of perturbations to conditions
280 |     """
281 |     
282 |     fname = os.path.join(os.path.join(data_path, data_name), split + '_'  +
283 |                          str(seed) + '_' + str(train_gene_set_size) + '_' +
284 |                          str(threshold) + '_' + str(k) +
285 |                          '_co_expression_network.csv')
286 |     
287 |     if os.path.exists(fname):
288 |         return pd.read_csv(fname)
289 |     else:
290 |         gene_list = [f for f in adata.var.gene_name.values]
291 |         idx2gene = dict(zip(range(len(gene_list)), gene_list)) 
292 |         X = adata.X
293 |         train_perts = set2conditions['train']
294 |         X_tr = X[np.isin(adata.obs.condition, [i for i in train_perts if 'ctrl' in i])]
295 |         gene_list = adata.var['gene_name'].values
296 | 
297 |         X_tr = X_tr.toarray()
298 |         out = np_pearson_cor(X_tr, X_tr)
299 |         out[np.isnan(out)] = 0
300 |         out = np.abs(out)
301 | 
302 |         out_sort_idx = np.argsort(out)[:, -(k + 1):]
303 |         out_sort_val = np.sort(out)[:, -(k + 1):]
304 | 
305 |         df_g = []
306 |         for i in range(out_sort_idx.shape[0]):
307 |             target = idx2gene[i]
308 |             for j in range(out_sort_idx.shape[1]):
309 |                 df_g.append((idx2gene[out_sort_idx[i, j]], target, out_sort_val[i, j]))
310 | 
311 |         df_g = [i for i in df_g if i[2] > threshold]
312 |         df_co_expression = pd.DataFrame(df_g).rename(columns = {0: 'source',
313 |                                                                 1: 'target',
314 |                                                                 2: 'importance'})
315 |         df_co_expression.to_csv(fname, index = False)
316 |         return df_co_expression
317 |     
318 | def filter_pert_in_go(condition, pert_names):
319 |     """
320 |     Filter perturbations in GO graph
321 | 
322 |     Args:
323 |         condition (str): whether condition is 'ctrl' or not
324 |         pert_names (list): list of perturbations
325 |     """
326 | 
327 |     if condition == 'ctrl':
328 |         return True
329 |     else:
330 |         cond1 = condition.split('+')[0]
331 |         cond2 = condition.split('+')[1]
332 |         num_ctrl = (cond1 == 'ctrl') + (cond2 == 'ctrl')
333 |         num_in_perts = (cond1 in pert_names) + (cond2 in pert_names)
334 |         if num_ctrl + num_in_perts == 2:
335 |             return True
336 |         else:
337 |             return False
338 |         
339 | def uncertainty_loss_fct(pred, logvar, y, perts, reg = 0.1, ctrl = None,
340 |                          direction_lambda = 1e-3, dict_filter = None):
341 |     """
342 |     Uncertainty loss function
343 | 
344 |     Args:
345 |         pred (torch.tensor): predicted values
346 |         logvar (torch.tensor): log variance
347 |         y (torch.tensor): true values
348 |         perts (list): list of perturbations
349 |         reg (float): regularization parameter
350 |         ctrl (torch.tensor): control expression values used in the direction loss
351 |         direction_lambda (float): direction loss weight hyperparameter
352 |         dict_filter (dict): dictionary of perturbations to conditions
353 | 
354 |     """
355 |     gamma = 2                     
356 |     perts = np.array(perts)
357 |     losses = torch.tensor(0.0, requires_grad=True).to(pred.device)
358 |     for p in set(perts):
359 |         if p!= 'ctrl':
360 |             retain_idx = dict_filter[p]
361 |             pred_p = pred[np.where(perts==p)[0]][:, retain_idx]
362 |             y_p = y[np.where(perts==p)[0]][:, retain_idx]
363 |             logvar_p = logvar[np.where(perts==p)[0]][:, retain_idx]
364 |         else:
365 |             pred_p = pred[np.where(perts==p)[0]]
366 |             y_p = y[np.where(perts==p)[0]]
367 |             logvar_p = logvar[np.where(perts==p)[0]]
368 |                          
369 |         # uncertainty based loss
370 |         losses += torch.sum((pred_p - y_p)**(2 + gamma) + reg * torch.exp(
371 |             -logvar_p)  * (pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1]
372 |                          
373 |         # direction loss                 
374 |         if p!= 'ctrl':
375 |             losses += torch.sum(direction_lambda *
376 |                                 (torch.sign(y_p - ctrl[retain_idx]) -
377 |                                  torch.sign(pred_p - ctrl[retain_idx]))**2)/\
378 |                                  pred_p.shape[0]/pred_p.shape[1]
379 |         else:
380 |             losses += torch.sum(direction_lambda *
381 |                                 (torch.sign(y_p - ctrl) -
382 |                                  torch.sign(pred_p - ctrl))**2)/\
383 |                                  pred_p.shape[0]/pred_p.shape[1]
384 |             
385 |     return losses/(len(set(perts)))
386 | 
387 | 
388 | def loss_fct(pred, y, perts, ctrl = None, direction_lambda = 1e-3, dict_filter = None):
389 |     """
390 |     Main MSE Loss function, includes direction loss
391 | 
392 |     Args:
393 |         pred (torch.tensor): predicted values
394 |         y (torch.tensor): true values
395 |         perts (list): list of perturbations
396 |         ctrl (torch.tensor): control expression values used in the direction loss
397 |         direction_lambda (float): direction loss weight hyperparameter
398 |         dict_filter (dict): dictionary of perturbations to conditions
399 | 
400 |     """
401 |     gamma = 2
402 |     mse_p = torch.nn.MSELoss()
403 |     perts = np.array(perts)
404 |     losses = torch.tensor(0.0, requires_grad=True).to(pred.device)
405 | 
406 |     for p in set(perts):
407 |         pert_idx = np.where(perts == p)[0]
408 |         
409 |         # during training, we remove the all zero genes into calculation of loss.
410 |         # this gives a cleaner direction loss. empirically, the performance stays the same.
411 |         if p!= 'ctrl':
412 |             retain_idx = dict_filter[p]
413 |             pred_p = pred[pert_idx][:, retain_idx]
414 |             y_p = y[pert_idx][:, retain_idx]
415 |         else:
416 |             pred_p = pred[pert_idx]
417 |             y_p = y[pert_idx]
418 |         losses = losses + torch.sum((pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1]
419 |                          
420 |         ## direction loss
421 |         if (p!= 'ctrl'):
422 |             losses = losses + torch.sum(direction_lambda *
423 |                                 (torch.sign(y_p - ctrl[retain_idx]) -
424 |                                  torch.sign(pred_p - ctrl[retain_idx]))**2)/\
425 |                                  pred_p.shape[0]/pred_p.shape[1]
426 |         else:
427 |             losses = losses + torch.sum(direction_lambda * (torch.sign(y_p - ctrl) -
428 |                                                 torch.sign(pred_p - ctrl))**2)/\
429 |                                                 pred_p.shape[0]/pred_p.shape[1]
430 |     return losses/(len(set(perts)))
431 | 
432 | def loss_sub_task(pred, y_1, y_2, y_3):
433 |     first_sub_task_out = pred[0].squeeze(-1)
434 |     second_sub_task_out = pred[1].squeeze(-1)
435 |     # third_sub_task_out = pred[2]
436 |     
437 |     level_1_output = y_1
438 |     level_2_output = y_2
439 |     # level_3_output = y_3
440 |     
441 |     # binary loss
442 |     with torch.no_grad():
443 |         all_num = level_1_output.shape[0] * level_1_output.shape[1]
444 |         DE_num = torch.sum(level_1_output.sum())
445 |         if DE_num == 0:
446 |             DE_num = 1
447 |         loss_weights = (all_num - DE_num) / DE_num
448 |     loss_binary = torch.nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1)
449 |     loss_1 = loss_binary(first_sub_task_out.squeeze(-1), level_1_output)
450 |     
451 |     
452 |     with torch.no_grad():
453 |         mask = level_1_output!=0
454 |     with torch.no_grad():
455 |         up_num = torch.sum(mask * level_2_output)
456 |         all_num = torch.sum(mask)
457 |         if up_num == 0:
458 |             up_num = 1
459 |         weights = mask * level_2_output * all_num / up_num
460 |         weights[weights > 0] -= 1
461 |     loss_binary = torch.nn.BCELoss(weight = weights + mask, reduction = 'sum')
462 |     loss_2 = loss_binary(second_sub_task_out, level_2_output) / torch.sum(weights + mask)
463 | 
464 |     # loss_3 = torch.sum((mask * (third_sub_task_out.squeeze(-1)-third_sub_task_out) ** 2)) / torch.sum(mask)
465 |     
466 |     # return loss_1, loss_2, loss_3, (first_sub_task_out, second_sub_task_out, third_sub_task_out), mask
467 |     return loss_1, loss_2, (first_sub_task_out, second_sub_task_out), mask
468 | 
469 | def print_sys(s):
470 |     """system print
471 | 
472 |     Args:
473 |         s (str): the string to print
474 |     """
475 |     print(s, flush = True, file = sys.stderr)
476 |     
477 | def create_cell_graph_for_prediction(X, pert_idx, pert_gene):
478 |     """
479 |     Create a perturbation specific cell graph for inference
480 | 
481 |     Args:
482 |         X (np.array): gene expression matrix
483 |         pert_idx (list): list of perturbation indices
484 |         pert_gene (list): list of perturbations
485 | 
486 |     """
487 | 
488 |     if pert_idx is None:
489 |         pert_idx = [-1]
490 |     return Data(x=torch.Tensor(X).T, pert_idx = pert_idx, pert=pert_gene)
491 |     
492 | 
493 | def create_cell_graph_dataset_for_prediction(pert_gene, ctrl_adata, gene_names,
494 |                                              device, num_samples = 300):
495 |     """
496 |     Create a perturbation specific cell graph dataset for inference
497 | 
498 |     Args:
499 |         pert_gene (list): list of perturbations
500 |         ctrl_adata (anndata): control anndata
501 |         gene_names (list): list of gene names
502 |         device (torch.device): device to use
503 |         num_samples (int): number of samples to use for inference (default: 300)
504 | 
505 |     """
506 | 
507 |     # Get the indices (and signs) of applied perturbation
508 |     pert_idx = [np.where(p == np.array(gene_names))[0][0] for p in pert_gene]
509 | 
510 |     Xs = ctrl_adata[np.random.randint(0, len(ctrl_adata), num_samples), :].X.toarray()
511 |     # Create cell graphs
512 |     cell_graphs = [create_cell_graph_for_prediction(X, pert_idx, pert_gene).to(device) for X in Xs]
513 |     return cell_graphs
514 | 
515 | ##
516 | ##GI related utils
517 | ##
518 | 
519 | def get_coeffs(singles_expr, first_expr, second_expr, double_expr):
520 |     """
521 |     Get coefficients for GI calculation
522 | 
523 |     Args:
524 |         singles_expr (np.array): single perturbation expression
525 |         first_expr (np.array): first perturbation expression
526 |         second_expr (np.array): second perturbation expression
527 |         double_expr (np.array): double perturbation expression
528 | 
529 |     """
530 |     results = {}
531 |     results['ts'] = TheilSenRegressor(fit_intercept=False,
532 |                           max_subpopulation=1e5,
533 |                           max_iter=1000,
534 |                           random_state=1000)   
535 |     X = singles_expr
536 |     y = double_expr
537 |     results['ts'].fit(X, y.ravel())
538 |     Zts = results['ts'].predict(X)
539 |     results['c1'] = results['ts'].coef_[0]
540 |     results['c2'] = results['ts'].coef_[1]
541 |     results['mag'] = np.sqrt((results['c1']**2 + results['c2']**2))
542 |     
543 |     results['dcor'] = distance_correlation(singles_expr, double_expr)
544 |     results['dcor_singles'] = distance_correlation(first_expr, second_expr)
545 |     results['dcor_first'] = distance_correlation(first_expr, double_expr)
546 |     results['dcor_second'] = distance_correlation(second_expr, double_expr)
547 |     results['corr_fit'] = np.corrcoef(Zts.flatten(), double_expr.flatten())[0,1]
548 |     results['dominance'] = np.abs(np.log10(results['c1']/results['c2']))
549 |     results['eq_contr'] = np.min([results['dcor_first'], results['dcor_second']])/\
550 |                         np.max([results['dcor_first'], results['dcor_second']])
551 |     
552 |     return results
553 | 
554 | def get_GI_params(preds, combo):
555 |     """
556 |     Get GI parameters
557 | 
558 |     Args:
559 |         preds (dict): dictionary of predictions
560 |         combo (list): list of perturbations
561 | 
562 |     """
563 |     singles_expr = np.array([preds[combo[0]], preds[combo[1]]]).T
564 |     first_expr = np.array(preds[combo[0]]).T
565 |     second_expr = np.array(preds[combo[1]]).T
566 |     double_expr = np.array(preds[combo[0]+'_'+combo[1]]).T
567 |     
568 |     return get_coeffs(singles_expr, first_expr, second_expr, double_expr)
569 | 
570 | def get_GI_genes_idx(adata, GI_gene_file):
571 |     """
572 |     Optional: Reads a file containing a list of GI genes (usually those
573 |     with high mean expression)
574 | 
575 |     Args:
576 |         adata (anndata): anndata object
577 |         GI_gene_file (str): file containing GI genes (generally corresponds
578 |         to genes with high mean expression)
579 |     """
580 |     # Genes used for linear model fitting
581 |     GI_genes = np.load(GI_gene_file, allow_pickle=True)
582 |     GI_genes_idx = np.where([g in GI_genes for g in adata.var.gene_name.values])[0]
583 |     
584 |     return GI_genes_idx
585 | 
586 | def get_mean_control(adata):
587 |     """
588 |     Get mean control expression
589 |     """
590 |     mean_ctrl_exp = adata[adata.obs['condition'] == 'ctrl'].to_df().mean()
591 |     return mean_ctrl_exp
592 | 
593 | def get_genes_from_perts(perts):
594 |     """
595 |     Returns list of genes involved in a given perturbation list
596 |     """
597 | 
598 |     if type(perts) is str:
599 |         perts = [perts]
600 |     gene_list = [p.split('+') for p in np.unique(perts)]
601 |     gene_list = [item for sublist in gene_list for item in sublist]
602 |     gene_list = [g for g in gene_list if g != 'ctrl']
603 |     return list(np.unique(gene_list))
604 | 
--------------------------------------------------------------------------------
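Note (on the file above): np_pearson_cor returns the column-wise Pearson correlation matrix between two arrays that share their first (observation) dimension, with values clipped to [-1, 1]. A minimal, self-contained sanity check on toy data (not taken from the repository):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(100, 5))   # 100 cells x 5 genes
    y = rng.normal(size=(100, 3))   # 100 cells x 3 genes

    corr = np_pearson_cor(x, y)     # shape (5, 3)
    assert corr.shape == (5, 3)
    # entry (i, j) matches np.corrcoef of column i of x with column j of y
    assert np.isclose(corr[0, 0], np.corrcoef(x[:, 0], y[:, 0])[0, 1])
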
/environment.yml:
--------------------------------------------------------------------------------
  1 | name: stamp
  2 | channels:
  3 |   - pytorch
  4 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  5 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  6 |   - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  7 |   - defaults
  8 | dependencies:
  9 |   - _libgcc_mutex=0.1=main
 10 |   - _openmp_mutex=5.1=1_gnu
 11 |   - blas=1.0=mkl
 12 |   - bzip2=1.0.8=h5eee18b_6
 13 |   - ca-certificates=2024.3.11=h06a4308_0
 14 |   - cudatoolkit=11.3.1=h2bc3f7f_2
 15 |   - ffmpeg=4.3=hf484d3e_0
 16 |   - freetype=2.12.1=h4a9f257_0
 17 |   - gmp=6.2.1=h295c915_3
 18 |   - gnutls=3.6.15=he1e5248_0
 19 |   - intel-openmp=2021.4.0=h06a4308_3561
 20 |   - jpeg=9e=h5eee18b_1
 21 |   - lame=3.100=h7b6447c_0
 22 |   - lcms2=2.12=h3be6417_0
 23 |   - ld_impl_linux-64=2.38=h1181459_1
 24 |   - lerc=3.0=h295c915_0
 25 |   - libdeflate=1.17=h5eee18b_1
 26 |   - libffi=3.3=he6710b0_2
 27 |   - libgcc-ng=11.2.0=h1234567_1
 28 |   - libgomp=11.2.0=h1234567_1
 29 |   - libiconv=1.16=h5eee18b_3
 30 |   - libidn2=2.3.4=h5eee18b_0
 31 |   - libpng=1.6.39=h5eee18b_0
 32 |   - libstdcxx-ng=11.2.0=h1234567_1
 33 |   - libtasn1=4.19.0=h5eee18b_0
 34 |   - libtiff=4.5.1=h6a678d5_0
 35 |   - libunistring=0.9.10=h27cfd23_0
 36 |   - libuv=1.44.2=h5eee18b_0
 37 |   - libwebp-base=1.3.2=h5eee18b_0
 38 |   - lz4-c=1.9.4=h6a678d5_1
 39 |   - mkl=2021.4.0=h06a4308_640
 40 |   - mkl-service=2.4.0=py39h7f8727e_0
 41 |   - mkl_fft=1.3.1=py39hd3c417c_0
 42 |   - mkl_random=1.2.2=py39h51133e4_0
 43 |   - ncurses=6.4=h6a678d5_0
 44 |   - nettle=3.7.3=hbbd107a_1
 45 |   - numpy=1.24.3=py39h14f4228_0
 46 |   - numpy-base=1.24.3=py39h31eccc5_0
 47 |   - openh264=2.1.1=h4ff587b_0
 48 |   - openjpeg=2.4.0=h3ad879b_0
 49 |   - openssl=1.1.1w=h7f8727e_0
 50 |   - pillow=10.3.0=py39h5eee18b_0
 51 |   - pip=24.0=py39h06a4308_0
 52 |   - python=3.9.7=h12debd9_1
 53 |   - pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0
 54 |   - pytorch-mutex=1.0=cuda
 55 |   - readline=8.2=h5eee18b_0
 56 |   - setuptools=69.5.1=py39h06a4308_0
 57 |   - six=1.16.0=pyhd3eb1b0_1
 58 |   - sqlite=3.45.3=h5eee18b_0
 59 |   - tk=8.6.14=h39e8969_0
 60 |   - torchaudio=0.10.2=py39_cu113
 61 |   - torchvision=0.11.3=py39_cu113
 62 |   - typing_extensions=4.11.0=py39h06a4308_0
 63 |   - wheel=0.43.0=py39h06a4308_0
 64 |   - xz=5.4.6=h5eee18b_1
 65 |   - zlib=1.2.13=h5eee18b_1
 66 |   - zstd=1.5.5=hc292b87_2
 67 |   - pip:
 68 |     - anndata==0.10.7
 69 |     - array-api-compat==1.7.1
 70 |     - asttokens==2.4.1
 71 |     - comm==0.2.2
 72 |     - contourpy==1.2.1
 73 |     - cycler==0.12.1
 74 |     - dcor==0.6
 75 |     - debugpy==1.8.1
 76 |     - decorator==5.1.1
 77 |     - exceptiongroup==1.2.1
 78 |     - executing==2.0.1
 79 |     - fonttools==4.53.0
 80 |     - get-annotations==0.1.2
 81 |     - h5py==3.11.0
 82 |     - importlib-metadata==7.2.0
 83 |     - importlib-resources==6.4.0
 84 |     - ipykernel==6.29.4
 85 |     - ipython==8.18.1
 86 |     - jedi==0.19.1
 87 |     - joblib==1.4.2
 88 |     - jupyter-client==8.6.2
 89 |     - jupyter-core==5.7.2
 90 |     - kiwisolver==1.4.5
 91 |     - legacy-api-wrap==1.4
 92 |     - llvmlite==0.42.0
 93 |     - matplotlib==3.9.0
 94 |     - matplotlib-inline==0.1.7
 95 |     - natsort==8.4.0
 96 |     - nest-asyncio==1.6.0
 97 |     - networkx==3.2.1
 98 |     - numba==0.59.1
 99 |     - packaging==24.1
100 |     - pandas==2.2.2
101 |     - parso==0.8.4
102 |     - patsy==0.5.6
103 |     - pexpect==4.9.0
104 |     - platformdirs==4.2.2
105 |     - prompt-toolkit==3.0.47
106 |     - psutil==6.0.0
107 |     - ptyprocess==0.7.0
108 |     - pure-eval==0.2.2
109 |     - pygments==2.18.0
110 |     - pynndescent==0.5.12
111 |     - pyparsing==3.1.2
112 |     - python-dateutil==2.9.0.post0
113 |     - pytz==2024.1
114 |     - pyyaml==6.0.2rc1
115 |     - pyzmq==26.0.3
116 |     - scanpy==1.10.1
117 |     - scikit-learn==1.5.0
118 |     - scipy==1.13.1
119 |     - seaborn==0.13.2
120 |     - session-info==1.0.0
121 |     - stack-data==0.6.3
122 |     - stamp==0.1.0
123 |     - statsmodels==0.14.2
124 |     - stdlib-list==0.10.0
125 |     - threadpoolctl==3.5.0
126 |     - tornado==6.4.1
127 |     - tqdm==4.66.4
128 |     - traitlets==5.14.3
129 |     - tzdata==2024.1
130 |     - umap-learn==0.5.6
131 |     - wcwidth==0.2.13
132 |     - zipp==3.19.2
133 | prefix: /home/gaoyicheng/anaconda3/envs/stamp
134 | 
--------------------------------------------------------------------------------
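Note (on the file above): assuming a standard conda installation, this environment would typically be created and activated with "conda env create -f environment.yml" followed by "conda activate stamp". The trailing prefix line records the original author's local environment path and can safely be removed when recreating the environment on another machine.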
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bm2-lab/STAMP/f40e2887c013647bb01570bd61bdfa01d6a05b00/img/framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy==1.24.3
 2 | pandas==2.2.2
 3 | tqdm==4.66.4
 4 | scikit-learn==1.5.0
 5 | torch
 6 | scanpy==1.10.1
 7 | dcor==0.6
 8 | scipy==1.13.1
 9 | joblib==1.4.2
10 | pyyaml
11 | setuptools==69.5.1
12 | 
--------------------------------------------------------------------------------
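Note (on the file above): the pip-level dependencies can alternatively be installed with "pip install -r requirements.txt". torch is left unpinned here, so the installed version should be chosen to match the local CUDA or CPU setup (the conda environment above pins pytorch 1.10.2 built against CUDA 11.3).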
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | from os import path
 3 | from io import open
 4 | 
 5 | # get __version__ from version.py
 6 | ver_file = path.join('stamp', 'version.py')
 7 | with open(ver_file) as f:
 8 |     exec(f.read())
 9 | 
10 | this_directory = path.abspath(path.dirname(__file__))
11 | 
12 | # read the contents of README.md
13 | def readme():
14 |     with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
15 |         return f.read()
16 | 
17 | # read the contents of requirements.txt
18 | with open(path.join(this_directory, 'requirements.txt'),
19 |           encoding='utf-8') as f:
20 |     requirements = f.read().splitlines()
21 | 
22 | setup(
23 |     name='cell-stamp',
24 |     version=__version__,
25 |     description='STAMP for genetic perturbation prediction',
26 |     long_description=readme(), 
27 |     long_description_content_type='text/markdown',
28 |     url='https://github.com/bm2-lab/STAMP',
29 |     author='Yicheng Gao, Zhiting Wei, Qi Liu',
30 |     packages=find_packages(),
31 |     zip_safe=False,
32 |     include_package_data=True,
33 |     install_requires=requirements,
34 |     license='GPL-3.0 license'
35 | )
--------------------------------------------------------------------------------
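Note (on the file above): running "pip install ." from the repository root installs the package via this setup script; the distribution is named cell-stamp, while the importable package is stamp (i.e. "from stamp import STAMP").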
/stamp/DataSet.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | from torch.utils.data import Dataset
 4 | import joblib
 5 | import scanpy as sc
 6 | import numpy as np
 7 | 
 8 | class PerturbDataSet(Dataset):
 9 |     def __init__(self, data_dir, gene_embedding_notebook):
10 |         
11 |         # read data
12 |         try:
13 |             data = sc.read_h5ad(data_dir)
14 |         except:
15 |             data = data_dir
16 |         
17 |         # obtain three levels data
18 |         self.FL_data = data.layers['level1']
19 |         self.SL_data = data.layers['level2']
20 |         self.TL_data = data.layers['level3']
21 |         
22 |         # obtain the gene names
23 |         self.gene_name = list(data.var.index)
24 |         
25 |         # obtain perturb genes
26 |         self.perturb_genes = list(data.obs.index)
27 |         
28 |         # obtain gene embedding matrix
29 |         self.gene_embedding_notebook = gene_embedding_notebook
30 |         
31 |     def __getitem__(self, item):
32 |         
33 |         target_gene = self.perturb_genes[item]
34 | 
35 |         pertub_embeds = 0
36 |         for idx, t in enumerate(target_gene.split(',')):
37 |             target_gene_index = self.gene_name.index(t)
38 |             pertub_embeds += self.gene_embedding_notebook[target_gene_index]
39 |         pertub_embeds /= idx + 1
40 |         
41 |         FL_output = self.FL_data[item].toarray()
42 |         
43 |         SL_output = self.SL_data[item].toarray()
44 |         
45 |         TL_output = self.TL_data[item].toarray()
46 |         
47 |         return (target_gene, pertub_embeds), (FL_output, SL_output, TL_output)
48 |     
49 |     def __len__(self, ):
50 |         return len(self.perturb_genes)
51 | 
52 | if __name__ == '__main__':
53 |     # Example usage: "Path" is a placeholder for an .h5ad file; the gene embedding
54 |     # notebook must provide one embedding per gene in the data's var index.
55 |     gene_embedding_notebook = np.random.rand(5000, 128).astype('float32')
56 |     dataset = PerturbDataSet("Path", gene_embedding_notebook)
57 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size = 64, shuffle = True)
58 |     for idx, batch_x in enumerate(dataloader):
59 |         print(idx, batch_x)
--------------------------------------------------------------------------------
/stamp/Modules.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from torch import nn
  3 | import numpy as np
  4 | import math
  5 | import matplotlib.pyplot as plt
  6 | from torch.nn import functional as F
  7 | 
  8 | class Bayes_first_level_layer(nn.Module):
  9 |     def __init__(self, in_features, out_features, bias = True):
 10 |         super().__init__()
 11 |         self.in_features = in_features
 12 |         self.out_features = out_features
 13 |         
 14 |         self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
 15 |         self.weight_sigma_log = nn.Parameter(torch.Tensor(out_features, in_features))
 16 |         self.bias = bias
 17 |         
 18 |         if self.bias:
 19 |             self.bias_mu = nn.Parameter(torch.Tensor(out_features))
 20 |             self.bias_sigma_log = nn.Parameter(torch.Tensor(out_features))
 21 |             
 22 |         self.activation = nn.Softmax(dim = 1)
 23 |         
 24 |     def forward(self, X):
 25 |         weight = self.weight_mu + torch.randn_like(self.weight_sigma_log) * torch.exp(self.weight_sigma_log)
 26 |         if self.bias:
 27 |             bias = self.bias_mu + torch.randn_like(self.bias_sigma_log) * torch.exp(self.bias_sigma_log)
 28 |         else:
 29 |             bias = None
 30 |         hidden_states =  F.linear(X, weight, bias)
 31 |         output = self.activation(hidden_states)
 32 |         return output
 33 | 
 34 | 
 35 | class First_level_layer(nn.Module):
 36 |     def __init__(self, in_features, out_features, bias = True):
 37 |         super().__init__()
 38 |         self.in_features = in_features
 39 |         self.out_features = out_features
 40 |         
 41 |         self.mapping1 = nn.ModuleList([nn.Linear(self.in_features, 1024), nn.ReLU(), nn.Dropout(0.9),
 42 |                                        nn.Linear(1024, 2048), nn.ReLU(), nn.Dropout(0.9),
 43 |                                        nn.Linear(2048, self.out_features)])
 44 |             
 45 |         self.activation = nn.Sigmoid()
 46 | 
 47 |     def forward(self, X):
 48 |         hidden_states = X
 49 |         for idx, h in enumerate(self.mapping1):
 50 |             hidden_states =  h(hidden_states)
 51 |         output = self.activation(hidden_states)
 52 |         return output
 53 |     
 54 | class First_level_layer_concate(nn.Module):
 55 |     def __init__(self, hid1_features = 128, hid2_features = 64, hid3_features = 32, gene_embedding_notebook = None):
 56 |         super().__init__()
 57 |         self.hid1_features = hid1_features
 58 |         self.hid2_features = hid2_features
 59 |         self.hid3_features = hid3_features
 60 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
 61 |         self.gene_embedding_notebook = gene_embedding_notebook
 62 |         
 63 |         self.mapping1 = nn.ModuleList([nn.Linear(self.gene_embedding_notebook.shape[1]*2, self.hid1_features), nn.ReLU(), 
 64 |                                        nn.Linear(self.hid1_features, self.hid2_features), nn.ReLU(), 
 65 |                                        nn.Linear(self.hid2_features, self.hid3_features), nn.ReLU()] 
 66 |                                       )
 67 |         self.mapping1_head = nn.Linear(self.hid3_features, 1)
 68 |         
 69 |         self.activation = nn.Sigmoid()
 70 |         
 71 |     def forward(self, pertub_genes_embeds):
 72 |         
 73 |         pertub_genes_embeds = pertub_genes_embeds.unsqueeze(1).expand(pertub_genes_embeds.shape[0],  self.gene_embedding_notebook.shape[0], -1)
 74 |         expanded_notebook = self.gene_embedding_notebook.expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
 75 |         hids = torch.cat([pertub_genes_embeds, expanded_notebook.to(pertub_genes_embeds.device)], dim = -1)
 76 |         for idx, h in enumerate(self.mapping1):
 77 |             hids = h(hids)
 78 |         hids_head = self.mapping1_head(hids)
 79 |         output = self.activation(hids_head)
 80 |         
 81 |         return output
 82 | 
 83 | class Second_level_layer(nn.Module):
 84 |     def __init__(self, hid1_features = 128, hid2_features = 64, hid3_features = 32, gene_embedding_notebook = None):
 85 |         super().__init__()
 86 |         self.hid1_features = hid1_features
 87 |         self.hid2_features = hid2_features
 88 |         self.hid3_features = hid3_features
 89 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
 90 |         self.gene_embedding_notebook = gene_embedding_notebook
 91 |         
 92 |         self.mapping2 = nn.ModuleList([nn.Linear(self.gene_embedding_notebook.shape[1] * 2, self.hid1_features), nn.ReLU(), 
 93 |                                        nn.Linear(self.hid1_features, self.hid2_features), nn.ReLU(), 
 94 |                                        nn.Linear(self.hid2_features, self.hid3_features), nn.ReLU()] 
 95 |                                       )
 96 |         self.mapping2_head = nn.Linear(self.hid3_features, 1)
 97 |         
 98 |         self.activation = nn.Sigmoid()
 99 |         
100 |     def forward(self, X, pertub_genes_embeds):
101 |         
102 |         with torch.no_grad():
103 |             mask = X==0
104 |         pertub_genes_embeds = pertub_genes_embeds.unsqueeze(1).expand(pertub_genes_embeds.shape[0],  self.gene_embedding_notebook.shape[0], -1)
105 |         expanded_notebook = self.gene_embedding_notebook.expand(pertub_genes_embeds.shape[0], self.gene_embedding_notebook.shape[0], -1)
106 |         hids = torch.cat([pertub_genes_embeds, expanded_notebook.to(X.device)], dim = -1)
107 |         for idx, h in enumerate(self.mapping2):
108 |             hids = h(hids)
109 |         hids_head = self.mapping2_head(hids)
110 |         output_second_level = self.activation(hids_head)
111 |         
112 |         return output_second_level, ~mask, hids
113 | 
114 | class Third_level_layer(nn.Module):
115 |     def __init__(self, in_features = 32, hid1_features = 16, hid2_features = 8):
116 |         super().__init__()
117 |         self.in_features = in_features
118 |         self.hid1_features = hid1_features
119 |         self.hid2_features = hid2_features
120 |         self.mapping3 = nn.ModuleList([nn.Linear(in_features, hid1_features), 
121 |                                        nn.Linear(hid1_features, hid2_features)])
122 |         self.mapping3_head = nn.Linear(hid2_features, 1)
123 |         self.activation = nn.LeakyReLU()
124 |     
125 |     def forward(self, X, mask = None):
126 |         hids = X
127 |         for idx, h in enumerate(self.mapping3):
128 |             hids = h(hids)
129 |         hids_head = self.mapping3_head(hids)
130 |         # output_third_level = self.activation(hids_head)
131 |         output_third_level = torch.exp(hids_head)
132 |         return output_third_level, mask
133 | 
134 | class TaskCombineLayer_multi_task(nn.Module):
135 |     def __init__(self, in_features, out_features, 
136 |                  hid1_features_2 = 128, hid2_features_2 = 64, hid3_features_2 = 32,
137 |                  in_feature_3 = 32, hid1_features_3 = 16, hid2_features_3 = 8, 
138 |                  gene_embedding_notebook = None ,bias = True):
139 |         super().__init__()
140 |         
141 |         # you can set gene embeddings as a learnable parameters
142 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
143 |         self.gene_embedding_notebook = gene_embedding_notebook
144 |         self.first_level_layer = First_level_layer(in_features, out_features)
145 |         self.second_level_layer = Second_level_layer(hid1_features_2, hid2_features_2, hid3_features_2, gene_embedding_notebook = self.gene_embedding_notebook)
146 |         self.third_level_layer = Third_level_layer(in_feature_3, hid1_features_3, hid2_features_3)
147 |         
148 |     def forward(self, X, 
149 |                 level_1_output, 
150 |                 level_2_output, 
151 |                 level_3_output
152 |                 ) :
153 |         output_1 = self.first_level_layer(X)
154 |         # binary loss
155 |         with torch.no_grad():
156 |             all_num = level_1_output.shape[0] * level_1_output.shape[1]
157 |             DE_num = torch.sum(level_1_output.sum())
158 |             if DE_num == 0:
159 |                 DE_num = 1
160 |             loss_weights = (all_num - DE_num) / DE_num
161 |         loss_binary = nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1)
162 |         loss_1 = loss_binary(output_1.squeeze(-1), level_1_output)
163 | 
164 |         output_2, mask, hids = self.second_level_layer(level_1_output, X)
165 |         with torch.no_grad():
166 |             up_num = torch.sum(mask * level_2_output)
167 |             all_num = torch.sum(mask)
168 |             if up_num == 0:
169 |                 up_num = 1
170 |             weights = mask * level_2_output * all_num / up_num
171 |             weights[weights > 0] -= 1
172 |             new_weights = weights + mask
173 |             if all_num <= up_num:
174 |                 all_num = up_num + 1
175 |             new_weights[new_weights==1] = all_num /(all_num - up_num)
176 | 
177 |         loss_binary = nn.BCELoss(weight = new_weights, reduction = 'sum')
178 |         loss_2 = loss_binary(output_2.squeeze(-1), level_2_output) / torch.sum(new_weights)
179 | 
180 |         output_3, mask = self.third_level_layer(hids, mask)
181 |         loss_3 = torch.sum((mask * (output_3.squeeze(-1)-level_3_output) ** 2)) / torch.sum(mask)
182 |         
183 |         return loss_1, loss_2, loss_3, (output_1, output_2, output_3), mask
184 | 
185 | class TaskCombineLayer_multi_task_concate(nn.Module):
186 |     def __init__(self, hid1_features_1 = 128, hid2_features_1 = 64, hid3_features_1 = 32,
187 |                  hid1_features_2 = 128, hid2_features_2 = 64, hid3_features_2 = 32,
188 |                  in_feature_3 = 32, hid1_features_3 = 16, hid2_features_3 = 8, 
189 |                  gene_embedding_notebook = None ,bias = True):
190 |         super().__init__()
191 |         
192 |         # self.gene_embedding_notebook = nn.Parameter(gene_embedding_notebook)
193 |         self.gene_embedding_notebook = gene_embedding_notebook
194 |         self.first_level_layer = First_level_layer_concate(hid1_features_1, hid2_features_1, hid3_features_1, gene_embedding_notebook = self.gene_embedding_notebook)
195 |         self.second_level_layer = Second_level_layer(hid1_features_2, hid2_features_2, hid3_features_2, gene_embedding_notebook = self.gene_embedding_notebook)
196 |         self.third_level_layer = Third_level_layer(in_feature_3, hid1_features_3, hid2_features_3)
197 |         
198 |     def forward(self, X, 
199 |                 level_1_output, 
200 |                 level_2_output, 
201 |                 level_3_output
202 |                 ) :
203 |         output_1 = self.first_level_layer(X)
204 |         # binary loss
205 |         with torch.no_grad():
206 |             all_num = level_1_output.shape[0] * level_1_output.shape[1]
207 |             DE_num = torch.sum(level_1_output.sum())
208 |             if DE_num == 0:
209 |                 DE_num = 1
210 |             loss_weights = (all_num - DE_num) / DE_num
211 |         loss_binary = nn.BCELoss(weight = level_1_output * loss_weights / 4 + 1)
212 |         loss_1 = loss_binary(output_1.squeeze(-1), level_1_output)
213 |         
214 |         output_2, mask, hids = self.second_level_layer(level_1_output, X)
215 |         with torch.no_grad():
216 |             up_num = torch.sum(mask * level_2_output)
217 |             all_num = torch.sum(mask)
218 |             if up_num == 0:
219 |                 up_num = 1
220 |             weights = mask * level_2_output * all_num / up_num
221 |             weights[weights > 0] -= 1
222 |             new_weights = weights + mask
223 |             if all_num <= up_num:
224 |                 all_num = up_num + 1
225 |             new_weights[new_weights==1] = all_num /(all_num - up_num)
226 | 
227 |         loss_binary = nn.BCELoss(weight = new_weights, reduction = 'sum')
228 |         loss_2 = loss_binary(output_2.squeeze(-1), level_2_output) / torch.sum(new_weights)
229 | 
230 |         output_3, mask = self.third_level_layer(hids, mask)
231 |         loss_3 = torch.sum((mask * (output_3.squeeze(-1)-level_3_output) ** 2)) / torch.sum(mask)
232 |         
233 |         return loss_1, loss_2, loss_3, (output_1, output_2, output_3), mask
--------------------------------------------------------------------------------
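Note (on the file above): a minimal forward-pass sketch for TaskCombineLayer_multi_task_concate on random toy tensors. All shapes and values are placeholders: the embedding notebook needs one row per gene, and the three label tensors are batch x n_genes, matching the layer definitions above.

    import torch

    n_genes, embed_dim, batch = 50, 32, 4
    notebook = torch.randn(n_genes, embed_dim)              # toy gene-embedding notebook
    model = TaskCombineLayer_multi_task_concate(gene_embedding_notebook=notebook)

    pert_embeds = torch.randn(batch, embed_dim)             # one embedding per perturbed-gene set
    y1 = (torch.rand(batch, n_genes) > 0.5).float()         # level 1: differentially expressed or not
    y2 = (torch.rand(batch, n_genes) > 0.5).float() * y1    # level 2: up/down, defined where level 1 is 1
    y3 = torch.rand(batch, n_genes)                         # level 3: change magnitude

    loss_1, loss_2, loss_3, (out_1, out_2, out_3), mask = model(pert_embeds, y1, y2, y3)
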
/stamp/__init__.py:
--------------------------------------------------------------------------------
1 | from .STAMP import STAMP
2 | from .DataSet import PerturbDataSet
3 | from .utils import load_config
--------------------------------------------------------------------------------
/stamp/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | # load config file
3 | def load_config(config_file):
4 |     with open(config_file) as file:
5 |         config = yaml.safe_load(file)
6 |     return config
--------------------------------------------------------------------------------
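Note (on the file above): load_config simply wraps yaml.safe_load and returns the parsed configuration as a plain Python dict, e.g. config = load_config('path/to/config.yaml') (the path is a placeholder).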
/stamp/version.py:
--------------------------------------------------------------------------------
1 | """
2 | STAMP version file
3 | """
4 | __version__ = '0.1.2'
--------------------------------------------------------------------------------