├── .DS_Store
├── .idea
├── .gitignore
├── DeePathNet.iml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── other.xml
└── vcs.xml
├── R
├── mixOmics.R
├── mixOmicsTCGA.R
└── moCluster.R
├── README.md
├── configs
├── .DS_Store
├── ccle_gdsc_intersection3
│ ├── .DS_Store
│ └── mutation_cnv_rna
│ │ ├── deepathnet_allgenes_mutation_cnv_rna.json
│ │ ├── ec_en_allgenes_drug_mutation_cnv_rna.json
│ │ ├── ec_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── moCluster_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── move_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── pca_rf_allgenes_drug_mutation_cnv_rna.json
│ │ └── scvaeit_rf_allgenes_drug_mutation_cnv_rna.json
├── sanger_gdsc_intersection_noprot
│ ├── .DS_Store
│ └── mutation_cnv_rna
│ │ ├── deepathnet_allgenes_mutation_cnv_rna.json
│ │ ├── ec_en_allgenes_drug_mutation_cnv_rna.json
│ │ ├── ec_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── moCluster_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── move_rf_allgenes_drug_mutation_cnv_rna.json
│ │ ├── pca_rf_allgenes_drug_mutation_cnv_rna.json
│ │ └── scvaeit_rf_allgenes_drug_mutation_cnv_rna.json
├── sanger_train_ccle_test_gdsc
│ ├── .DS_Store
│ ├── mutation_cnv_rna
│ │ ├── deepathnet_mutation_cnv_rna.json
│ │ ├── deepathnet_mutation_cnv_rna_ds_100.json
│ │ ├── deepathnet_mutation_cnv_rna_ds_200.json
│ │ ├── deepathnet_mutation_cnv_rna_ds_300.json
│ │ ├── deepathnet_mutation_cnv_rna_ds_400.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_100.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_200.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_300.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_400.json
│ │ ├── ec_mlp_all_genes_mutation_cnv_rna.json
│ │ └── ec_rf_all_genes_mutation_cnv_rna.json
│ └── mutation_cnv_rna_prot
│ │ ├── deepathnet_mutation_cnv_rna_prot.json
│ │ ├── deepathnet_mutation_cnv_rna_prot_random_control.json
│ │ └── ec_rf_all_genes_mutation_cnv_rna_prot.json
├── tcga_all_cancer_types
│ ├── .DS_Store
│ └── mutation_cnv_rna
│ │ ├── deepathnet_mutation_cnv_rna.json
│ │ ├── deepathnet_mutation_cnv_rna_ds_2k.json
│ │ ├── deepathnet_mutation_cnv_rna_random_control.json
│ │ └── deepathnet_mutation_cnv_rna_random_control_ds_2k.json
├── tcga_brca_subtypes
│ ├── .DS_Store
│ └── mutation_cnv_rna
│ │ └── deepathnet_mutation_cnv_rna.json
└── tcga_train_cptac_test_brca
│ └── cnv_rna
│ ├── deepathnet_cnv_rna.json
│ ├── deepathnet_cnv_rna_random_control.json
│ └── ec_rf_cnv_rna.json
├── figures
├── Figure1.pdf
├── Figure1.png
├── supp_fig1.pdf
└── supp_fig1.png
├── models
├── cancer_type_pretrained.pth
└── drug_response_pretrained.pth
├── requirements.txt
└── scripts
├── .DS_Store
├── baseline_ec_cv.py
├── baseline_independent_test.py
├── cancer_type_baseline_23cancertypes.py
├── cancer_type_baseline_brca.py
├── cancer_type_baseline_brca_validation.py
├── deepathnet_cv.py
├── deepathnet_independent_test.py
├── model_transformer_lrp.py
├── models.py
├── transformer_explantion_cancer_type.py
├── transformer_explantion_drug_response.py
├── transformer_shap_cancer_type.py
├── transformer_shap_drug_response.py
└── utils
├── __init__.py
├── layers_ours.py
├── lr_scheduler.py
└── training_prepare.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/DeePathNet.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/R/mixOmics.R:
--------------------------------------------------------------------------------
1 | # Baseline: fits a mixOmics block.spls model on mutation/CNV/RNA blocks to
2 | # predict drug response, and scores every drug on 25 train/validation splits.
3 | library(mixOmics)
4 | library(Metrics)
5 |
6 | # input_data = "sanger"
7 | # target = "gdsc"
8 |
9 | # args = commandArgs(trailingOnly=TRUE)
10 | #
11 | # input_data = args[1]
12 | # target = args[2]
13 | # mode = args[3]
14 |
15 | input_data = "ccle"
16 | target = "ctd2"
17 | mode = "allgenes"
18 |
19 | # Number of components fitted by block.spls; also embedded in the output file name.
20 | N_comp = 50
21 |
22 | # Load the three omics blocks and the split tables for the chosen data set.
23 | if (input_data == "sanger"){
24 |   mutation = read.csv("./data/processed/omics/sanger_df_mutation_drug.csv.gz", row.names = 1)
25 |   cnv = read.csv("./data/processed/omics/sanger_df_cnv_drug.csv.gz", row.names = 1)
26 |   rna = read.csv("./data/processed/omics/sanger_df_rna_drug.csv.gz", row.names = 1)
27 |   train_splits = read.csv("./data/meta/25splits_train_sanger.csv", stringsAsFactors = F)
28 |   val_splits = read.csv("./data/meta/25splits_val_sanger.csv", stringsAsFactors = F)
29 |
30 | } else {
31 |   if (target == "gdsc"){
32 |     train_splits = read.csv("./data/meta/25splits_train_ccle_gdsc.csv", stringsAsFactors = F)
33 |     val_splits = read.csv("./data/meta/25splits_val_ccle_gdsc.csv", stringsAsFactors = F)
34 |     mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug.csv.gz", row.names = 1)
35 |     cnv = read.csv("./data/processed/omics/ccle_df_cnv_drug.csv.gz", row.names = 1)
36 |     rna = read.csv("./data/processed/omics/ccle_df_rna_drug.csv.gz", row.names = 1)
37 |   } else {
38 |     mutation = read.csv("./data/processed/omics/ccle_df_mutation_ctd2.csv.gz", row.names = 1)
39 |     cnv = read.csv("./data/processed/omics/ccle_df_cnv_ctd2.csv.gz", row.names = 1)
40 |     rna = read.csv("./data/processed/omics/ccle_df_rna_ctd2.csv.gz", row.names = 1)
41 |     train_splits = read.csv("./data/meta/25splits_train_ccle_ctd2.csv", stringsAsFactors = F)
42 |     val_splits = read.csv("./data/meta/25splits_val_ccle_ctd2.csv", stringsAsFactors = F)
43 |   }
44 | }
45 |
46 | # Load the drug response table (samples as rows, one column per drug).
47 | if (target == "gdsc"){
48 |   if (input_data == "sanger"){
49 |     target_data = read.csv("./data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", row.names = 1)
50 |   } else {
51 |     target_data = read.csv("./data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", row.names = 1)
52 |   }
53 | } else {
54 |   if (input_data == "sanger"){
55 |     target_data = read.csv("./data/processed/drug/sanger_ctd2_min400.csv.gz", row.names = 1)
56 |   } else {
57 |     target_data = read.csv("./data/processed/drug/ccle_ctd2_min600.csv.gz", row.names = 1)
58 |   }
59 |
60 | }
61 |
62 | print(input_data)
63 | print(target)
64 |
65 | # Convert the response table to a numeric matrix, keeping sample IDs as row names.
66 | drug_samples = rownames(target_data)
67 | target_data = sapply(target_data, as.numeric)
68 | rownames(target_data) = drug_samples
69 |
70 | lc_cancer_genes = read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F)
71 |
72 | # Keep only samples present in every omics block and in the response table.
73 | # Indexing all tables with this one ordered vector guarantees that row k of
74 | # each block and of the response refers to the same sample.
75 | common_samples = Reduce(intersect, list(rownames(mutation), rownames(cnv),
76 |                                         rownames(rna), drug_samples))
77 |
78 | if (mode != "allgenes"){
79 |   # Restrict every block to the genes listed in lc_cancer_genes$cancer_genes.
80 |   mutation = mutation[, colnames(mutation) %in% lc_cancer_genes$cancer_genes, drop = FALSE]
81 |   cnv = cnv[, colnames(cnv) %in% lc_cancer_genes$cancer_genes, drop = FALSE]
82 |   rna = rna[, colnames(rna) %in% lc_cancer_genes$cancer_genes, drop = FALSE]
83 | }
84 |
85 | mutation = mutation[common_samples, , drop = FALSE]
86 | cnv = cnv[common_samples, , drop = FALSE]
87 | rna = rna[common_samples, , drop = FALSE]
88 | target_data = target_data[common_samples, , drop = FALSE]
89 |
90 |
91 | res_df <- data.frame(drug_id=character(),
92 |                      run=character(),
93 |                      corr=numeric(),
94 |                      r2=numeric(),
95 |                      mae=numeric(),
96 |                      rmse=numeric(),
97 |                      time=numeric(),
98 |                      stringsAsFactors=FALSE)
99 |
100 | for (i in 1:25){
101 |   # Column i of the split tables lists the sample IDs of split i.
102 |   train_samples = common_samples[common_samples %in% train_splits[, i]]
103 |   val_samples = common_samples[common_samples %in% val_splits[, i]]
104 |
105 |   train_data = list(mutation = mutation[train_samples, , drop = FALSE],
106 |                     CNV = cnv[train_samples, , drop = FALSE],
107 |                     RNA = rna[train_samples, , drop = FALSE])
108 |
109 |   val_data = list(mutation = mutation[val_samples, , drop = FALSE],
110 |                   CNV = cnv[val_samples, , drop = FALSE],
111 |                   RNA = rna[val_samples, , drop = FALSE])
112 |
113 |   start_time <- Sys.time()
114 |   model = block.spls(train_data, target_data[train_samples, , drop = FALSE], ncomp = N_comp)
115 |   target_val = target_data[val_samples, , drop = FALSE]
116 |   pred = predict(model, val_data)
117 |   # Indexed below as [sample, drug, component].
118 |   avg_pred = pred$AveragedPredict
119 |   end_time <- Sys.time()
120 |   # Fix the unit explicitly: a bare difftime picks secs/mins/hours on its own,
121 |   # which would make the recorded times incomparable between runs.
122 |   time_taken = as.numeric(difftime(end_time, start_time, units = "secs"))
123 |
124 |   for (drug_idx in seq_len(ncol(target_data))){
125 |     curr_pred = avg_pred[, drug_idx, N_comp]
126 |     curr_target = target_val[, drug_idx]
127 |     # Replace missing predictions by the mean prediction for this drug.
128 |     curr_pred[is.na(curr_pred)] = mean(curr_pred, na.rm = TRUE)
129 |     # Score only the samples that have a measured response for this drug.
130 |     observed = !is.na(curr_target)
131 |     curr_pred = curr_pred[observed]
132 |     curr_target = curr_target[observed]
133 |     pcorr = cor(curr_pred, curr_target, method="pearson")
134 |     r2 = 1 - sum((curr_target - curr_pred)^2) / sum((curr_target - mean(curr_target))^2)
135 |     mae = mean(abs(curr_target - curr_pred))
136 |     rmse_drug = rmse(curr_target, curr_pred)
137 |
138 |     # Append as a list (not c()) so the metric columns stay numeric.
139 |     res_df[nrow(res_df) + 1,] = list(colnames(target_data)[drug_idx], paste0("cv_", i-1),
140 |                                      pcorr, r2, mae, rmse_drug, time_taken)
141 |   }
142 |
143 |   print(paste0("run ", i, " completed"))
144 |   flush.console()
145 | }
146 |
147 | filename = paste0("./work_dirs/mixOmics/", input_data, "_mutation_cnv_rna_", target, "_", mode, "_", N_comp, "_comp.csv")
148 | # Create the output directory if needed so the finished runs are not lost at write time.
149 | dir.create(dirname(filename), recursive = TRUE, showWarnings = FALSE)
150 | write.csv(res_df, filename, row.names = F)
151 |
--------------------------------------------------------------------------------
/R/mixOmicsTCGA.R:
--------------------------------------------------------------------------------
1 | # Baseline: fits a mixOmics block.splsda classifier on mutation/CNV/RNA blocks
2 | # and reports accuracy, macro F1 and multiclass AUC on 25 train/validation splits.
3 | #
4 | # Usage: Rscript R/mixOmicsTCGA.R <input_data> <mode>
5 | #   input_data: "tcga_23_cancer_types" loads the 23-cancer-type files; any
6 | #               other value loads the tcga_brca files.
7 | #   mode:       label that is only used in the output file name.
8 | library(mixOmics)
9 | library(pROC)
10 | library(mltest)
11 |
12 | # input_data = "tcga_23_cancer_types"
13 | # input_data = "tcga_brca"
14 |
15 | # target = "all"
16 | # target = "brca_subtypes"
17 |
18 | args = commandArgs(trailingOnly=TRUE)
19 | # Fail early with a readable message instead of an NA comparison error below.
20 | if (length(args) < 2){
21 |   stop("Usage: Rscript R/mixOmicsTCGA.R <input_data> <mode>")
22 | }
23 |
24 | input_data = args[1]
25 |
26 | mode = args[2]
27 |
28 | # Number of components fitted by block.splsda; also embedded in the output file name.
29 | N_comp = 50
30 |
31 | if (input_data == "tcga_23_cancer_types") {
32 |   mutation = read.csv("./data/processed/omics/tcga_23_cancer_types_mutation.csv.gz", row.names = 1)
33 |   cnv = read.csv("./data/processed/omics/tcga_23_cancer_types_cnv.csv.gz", row.names = 1)
34 |   rna = read.csv("./data/processed/omics/tcga_23_cancer_types_rna.csv.gz", row.names = 1)
35 |   target_data = read.csv("./data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", row.names = 1)
36 |   train_splits = read.csv("./data/meta/25splits_train_tcga.csv", stringsAsFactors = F)
37 |   val_splits = read.csv("./data/meta/25splits_val_tcga.csv", stringsAsFactors = F)
38 | } else {
39 |   mutation = read.csv("./data/processed/omics/tcga_brca_mutation.csv.gz", row.names = 1)
40 |   cnv = read.csv("./data/processed/omics/tcga_brca_cnv.csv.gz", row.names = 1)
41 |   rna = read.csv("./data/processed/omics/tcga_brca_rna.csv.gz", row.names = 1)
42 |   target_data = read.csv("./data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", row.names = 1)
43 |   train_splits = read.csv("./data/meta/25splits_train_tcga_brca.csv", stringsAsFactors = F)
44 |   val_splits = read.csv("./data/meta/25splits_val_tcga_brca.csv", stringsAsFactors = F)
45 | }
46 |
47 | # Missing omics values are treated as 0.
48 | mutation[is.na(mutation)] <- 0
49 | cnv[is.na(cnv)] <- 0
50 | rna[is.na(rna)] <- 0
51 | samples = rownames(target_data)
52 | target_data = sapply(target_data, as.numeric)
53 | rownames(target_data) = samples
54 |
55 | lc_cancer_genes = read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F)
56 |
57 | # Keep only samples present in every omics block and in the label table.
58 | # Indexing all tables with this one ordered vector guarantees that row k of
59 | # each block and entry k of the labels refer to the same sample.
60 | common_samples = Reduce(intersect, list(rownames(mutation), rownames(cnv),
61 |                                         rownames(rna), samples))
62 |
63 | # mutation = mutation[common_samples, colnames(mutation) %in% lc_cancer_genes$cancer_genes]
64 | # cnv = cnv[common_samples, colnames(cnv) %in% lc_cancer_genes$cancer_genes]
65 | # rna = rna[common_samples, colnames(rna) %in% lc_cancer_genes$cancer_genes]
66 | mutation = mutation[common_samples, , drop = FALSE]
67 | cnv = cnv[common_samples, , drop = FALSE]
68 | rna = rna[common_samples, , drop = FALSE]
69 | # NOTE(review): the code below indexes target_data by name as a vector, i.e.
70 | # it assumes the label file has a single label column - confirm.
71 | target_data = target_data[common_samples,]
72 |
73 |
74 | res_df <- data.frame(run=character(),
75 |                      top1_acc=numeric(),
76 |                      f1=numeric(),
77 |                      roc_auc=numeric(),
78 |                      time=numeric(),
79 |                      stringsAsFactors=FALSE)
80 |
81 | for (i in 1:25){
82 |   # Column i of the split tables lists the sample IDs of split i.
83 |   train_samples = common_samples[common_samples %in% train_splits[, i]]
84 |   val_samples = common_samples[common_samples %in% val_splits[, i]]
85 |
86 |   train_data = list(mutation = mutation[train_samples, , drop = FALSE],
87 |                     CNV = cnv[train_samples, , drop = FALSE],
88 |                     RNA = rna[train_samples, , drop = FALSE])
89 |
90 |   val_data = list(mutation = mutation[val_samples, , drop = FALSE],
91 |                   CNV = cnv[val_samples, , drop = FALSE],
92 |                   RNA = rna[val_samples, , drop = FALSE])
93 |
94 |   start_time <- Sys.time()
95 |   model = block.splsda(train_data, target_data[train_samples], ncomp = N_comp)
96 |   target_val = target_data[val_samples]
97 |   pred = predict(model, val_data)
98 |   pred_values = as.numeric(pred$AveragedPredict.class$max.dist[,N_comp])
99 |   end_time <- Sys.time()
100 |   # Fix the unit explicitly: a bare difftime picks secs/mins/hours on its own,
101 |   # which would make the recorded times incomparable between runs.
102 |   time_taken = as.numeric(difftime(end_time, start_time, units = "secs"))
103 |
104 |   auc = multiclass.roc(target_val,pred_values)$auc
105 |
106 |   # Give both factors the same explicit level set. Overwriting levels(y_pred)
107 |   # with levels(y_true) would rename classes by position and silently
108 |   # mislabel predictions whenever the two sets of observed classes differ.
109 |   class_levels = sort(unique(c(target_val, pred_values)))
110 |   y_true = factor(target_val, levels = class_levels)
111 |   y_pred = factor(pred_values, levels = class_levels)
112 |   test_res = ml_test(y_pred, y_true, output.as.table = FALSE)
113 |   # Macro F1: unweighted mean over classes, skipping classes with undefined F1.
114 |   f1_macro = mean(na.omit(test_res$F1))
115 |   accuracy = test_res$accuracy
116 |
117 |   # Append as a list (not c()) so the metric columns stay numeric.
118 |   res_df[nrow(res_df) + 1,] = list(paste0("cv_", i-1), accuracy, f1_macro,
119 |                                    as.numeric(auc), time_taken)
120 |
121 |   print(paste0("run ", i, " completed"))
122 |   flush.console()
123 | }
124 |
125 | filename = paste0("./work_dirs/mixOmics/", input_data, "_mutation_cnv_rna_" , mode, "_", N_comp, "_comp.csv")
126 | # Create the output directory if needed so the finished runs are not lost at write time.
127 | dir.create(dirname(filename), recursive = TRUE, showWarnings = FALSE)
128 | write.csv(res_df, filename, row.names = F)
129 |
--------------------------------------------------------------------------------
/R/moCluster.R:
--------------------------------------------------------------------------------
1 | # Computes a joint low-dimensional embedding of the TCGA mutation, CNV and RNA
2 | # tables with mbpca() and writes the per-sample factor scores to a CSV file.
3 | library(mogsa)
4 |
5 | # mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug.csv.gz", row.names = 1)
6 | # cnv = read.csv("./data/processed/omics/ccle_df_cnv_drug.csv.gz", row.names = 1)
7 | # rna = read.csv("./data/processed/omics/ccle_df_rna_drug.csv.gz", row.names = 1)
8 |
9 | # mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug_ctd2.csv.gz", row.names = 1)
10 | # cnv = read.csv("./data/processed/omics/ccle_df_cnv_drug_ctd2.csv.gz", row.names = 1)
11 | # rna = read.csv("./data/processed/omics/ccle_df_rna_drug_ctd2.csv.gz", row.names = 1)
12 |
13 | # Read one omics table (first column as row names) and replace missing values by 0.
14 | load_omics <- function(path) {
15 |   df <- read.csv(path, row.names = 1)
16 |   df[is.na(df)] <- 0
17 |   df
18 | }
19 |
20 | mutation <- load_omics("./data/processed/omics/tcga_23_cancer_types_mutation.csv.gz")
21 | cnv <- load_omics("./data/processed/omics/tcga_23_cancer_types_cnv.csv.gz")
22 | rna <- load_omics("./data/processed/omics/tcga_23_cancer_types_rna.csv.gz")
23 |
24 | lc_cancer_genes <- read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F)
25 | # Bare top-level expression, kept on purpose: its value is auto-printed when
26 | # the file is run as a script.
27 | lc_cancer_genes$cancer_genes
28 |
29 | # Capture the sample IDs before the numeric conversion below.
30 | common_samples <- rownames(mutation)
31 |
32 | # mutation = mutation[, colnames(mutation) %in% lc_cancer_genes$cancer_genes]
33 | # cnv = cnv[, colnames(cnv) %in% lc_cancer_genes$cancer_genes]
34 | # rna = rna[, colnames(rna) %in% lc_cancer_genes$cancer_genes]
35 |
36 | # Coerce every column of each table to numeric.
37 | mutation <- sapply(mutation, as.numeric)
38 | cnv <- sapply(cnv, as.numeric)
39 | rna <- sapply(rna, as.numeric)
40 |
41 | # mbpca() receives the tables transposed (features as rows, samples as columns).
42 | mo.combined <- lapply(list(mutation, cnv, rna), t)
43 | # mo.combined = list(t(cnv), t(rna), t(protein))
44 | # mo.combined = list(t(methy), t(rna), t(protein))
45 | # mo.combined = list(t(rna), t(protein))
46 | print(length(mo.combined))
47 |
48 | moa <- mbpca(
49 |   mo.combined,
50 |   ncomp = 200,
51 |   k = "all",
52 |   method = "globalScore",
53 |   option = "lambda1",
54 |   center = TRUE,
55 |   scale = FALSE,
56 |   moa = TRUE,
57 |   svd.solver = "fast",
58 |   maxiter = 1000
59 | )
60 |
61 | # One row of factor scores per sample; move the ID column to the front.
62 | res.df <- as.data.frame(moa@fac.scr)
63 | res.df$Cell_line <- common_samples
64 | n_col <- dim(res.df)[2]
65 | res.df <- res.df[, c(n_col, 1:(n_col - 1))]
66 | write.table(res.df, "./data/DR/moCluster/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv", sep = ",", quote = F, row.names = F)
67 |
68 | # tmp = read.csv("./data/DR/moCluster/ccle_mutation_cnv_rna_gdsc.csv", sep = ",")
69 | # tmp = tmp[, c(2:(dim(tmp)[2]-1), dim(tmp)[2])]
70 | # res.df = tmp
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeePathNet
2 | [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
3 |
4 | ![DeePathNet overview](figures/Figure1.png)
5 |
6 | Transformer-based deep learning integrates multi-omic data with cancer pathways.
7 | Cai, et al., 2023
8 |
9 | ## Overview
10 |
11 | DeePathNet is a transformer-based deep learning tool that integrates multi-omic data to improve predictions for cancer subtyping and drug response. It combines pathway-level information with deep learning to enhance precision in oncology research.
12 |
13 | ## Features
14 | - **Multi-omic data integration** using a transformer architecture
15 | - **Support for pathway-level feature importance analysis**
16 | - **Pre-trained models** for cancer type classification and drug response prediction
17 | - **Cross-validation and independent test scripts** for model evaluation
18 |
19 | ## Setting up the Coding Environment
20 |
21 | To ensure compatibility and avoid potential issues, it is recommended to use Python 3.8 and PyTorch 1.10. Below are detailed instructions to set up the coding environment:
22 |
23 | 1. **Install Anaconda**
24 | - Follow the [Anaconda installation guide](https://docs.anaconda.com/free/anaconda/install/index.html) for your operating system.
25 |
26 | 2. **Create a Virtual Environment**
27 | - Once Anaconda is installed, create and activate a virtual environment:
28 | ```bash
29 | conda create -n deepathnet_env python=3.8 anaconda
30 | conda activate deepathnet_env
31 | ```
32 |
33 | 3. **Install PyTorch**
34 | - Install the appropriate version of PyTorch based on your hardware:
35 | - **For CUDA-enabled systems**:
36 | ```bash
37 | pip install torch==1.10.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
38 | ```
39 | - **For CPU-only systems**:
40 | ```bash
41 | pip install torch==1.10.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
42 | ```
43 |
44 | 4. **Install Additional Dependencies**
45 | - DeePathNet requires several Python packages to run effectively. Create a `requirements.txt` file with the following contents:
46 | ```plaintext
47 | torch==1.10.0
48 | torchvision
49 | torchaudio
50 | numpy
51 | pandas
52 | scikit-learn
53 | matplotlib
54 | seaborn
55 | json5
56 | scipy
57 | tqdm
58 | ```
59 | - Install all dependencies:
60 | ```bash
61 | pip install -r requirements.txt
62 | ```
63 |
64 | ## Loading Pre-trained DeePathNet Model(s) with Test Data
65 |
66 | DeePathNet provides pre-trained models and test datasets to facilitate model evaluation:
67 |
68 | 1. **Download Pre-trained Models and Test Data**
69 | - Access pre-trained models and test data from the [Figshare repository](https://doi.org/10.6084/m9.figshare.24137619).
70 | - Save the files to local directories, such as `models/` for models and `data/` for test data.
71 |
72 | 2. **Configure Paths for Models and Test Data**
73 | - Update the paths in the configuration file, such as:
74 | ```json
75 | {
76 | "model": "DeePathNet",
77 | "pretrained_model_path": "models/deepathnet_pretrained.pth",
78 | "test_data_path": "data/test_data.csv",
79 | "output_dir": "results/",
80 | ...
81 | }
82 | ```
83 | - Ensure that the paths in `pretrained_model_path` and `test_data_path` are correctly set to the local files.
84 |
85 | 3. **Load and Run Pre-trained DeePathNet Model**
86 | - DeePathNet can be run using the `deepathnet_independent_test.py` script, which loads a pre-trained model and performs inference:
87 | ```bash
88 | python scripts/deepathnet_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json
89 | ```
90 | - This command will load the pre-trained model and run inference on the specified test dataset. The results will be saved to the designated output directory as specified in the configuration file.
91 |
92 | ## Running the Inference Step to Generate Predictions
93 |
94 | The following examples demonstrate how to generate predictions using DeePathNet for various tasks:
95 |
96 | 1. **Predict Drug Response**
97 | - To predict drug response (IC50 values), run:
98 | ```bash
99 | python scripts/deepathnet_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json
100 | ```
101 | - This script reads the configuration, loads the pre-trained model, and performs inference on the test dataset.
102 |
103 | 2. **Classify Cancer Types**
104 | - For cancer type classification:
105 | ```bash
106 | python scripts/deepathnet_cv.py configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json
107 | ```
108 | - This script performs cross-validation using a specified dataset and configuration file.
109 |
110 | 3. **Breast Cancer Subtyping**
111 | - For breast cancer subtyping:
112 | ```bash
113 | python scripts/deepathnet_independent_test.py configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna.json
114 | ```
115 |
116 | ### Output Description
117 |
118 | - The predictions generated by DeePathNet (e.g., IC50 values or cancer subtypes) are saved in the output directory defined in the configuration file.
119 | - The output includes performance metrics, predictions, and optional feature importance scores.
120 |
121 | ## Running Baseline Comparisons
122 |
123 | To compare DeePathNet with baseline models like moCluster and mixOmics, use the provided scripts:
124 |
125 | 1. **moCluster Baseline Comparison**
126 | ```bash
127 | python scripts/baseline_ec_cv.py configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json
128 | ```
129 | 2. **Cancer Type Baseline Comparison**
130 | ```bash
131 | python scripts/cancer_type_baseline_23cancertypes.py
132 | ```
133 |
134 | ## Running Feature Importance Analysis
135 |
136 | DeePathNet supports pathway-level and gene-level feature importance analysis:
137 |
138 | 1. **Pathway-level Feature Importance**
139 | ```bash
140 | python scripts/transformer_explantion_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json
141 | ```
142 |
143 | 2. **Gene-level Feature Importance**
144 | ```bash
145 | python scripts/transformer_shap_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json
146 | ```
147 |
148 | ## Data Input
149 |
150 | The input files for DeePathNet should have samples as rows and features as columns. Features should be formatted with an underscore separating the gene name and the omic data type (e.g., `GeneA_RNA`). For example:
151 |
152 | | Sample | GeneA_RNA | GeneA_PROT | GeneB_RNA | GeneB_PROT |
153 | |------------|-----------|------------|-----------|------------|
154 | | Cell_lineA | 10 | 8 | 2 | 3 |
155 | | Cell_lineB | 15 | 12 | 1 | 2 |
156 | | Cell_lineC | 5 | 3 | 10 | 8 |
157 |
158 | ## Data Output
159 |
160 | The output includes predictions such as:
161 | - Drug response (IC50 values)
162 | - Cancer types/subtypes
163 | - Feature importance scores for interpretability
164 |
165 | ## Troubleshooting and Raising Issues
166 |
167 | We recommend using the specified Python and PyTorch versions for compatibility. If issues arise, please open a ticket in the Issues tab with details about your setup, the steps you followed, and error logs.
168 |
169 | ## Contact
170 |
171 | For more information, please contact the study authors via the associated publication.
--------------------------------------------------------------------------------
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/.DS_Store
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/ccle_gdsc_intersection3/.DS_Store
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "regression",
10 | "seed": 1,
11 | "num_repeat": 5,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 400,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": false,
26 | "lr": 3e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "suffix": "_allgenes",
31 | "saved_model": ""
32 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/ec_en_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "en",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/ec_en/mutation_cnv_rna",
9 | "data_type": ["mutation","cnv","rna"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "params_grid": {
9 | "n_estimators": [400, 600, 800],
10 | "max_features": ["sqrt"],
11 | "n_jobs": [4]
12 | },
13 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/ec_rf/mutation_cnv_rna",
14 | "data_type": ["mutation","cnv","rna"],
15 | "task": "regression"
16 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/moCluster/ccle_mutation_cnv_rna_gdsc_allgenes.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/moCluster_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/move_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/MOVE/ccle_mutation_cnv_rna_200factor.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/move_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/pca_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/pca/ccle_mutation_cnv_rna_gdsc_allgenes.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/pca_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/ccle_gdsc_intersection3/mutation_cnv_rna/scvaeit_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/scVAEIT/ccle_scvaeit_latent_200factor.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/scvaeit_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/sanger_gdsc_intersection_noprot/.DS_Store
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "regression",
10 | "seed": 1,
11 | "num_repeat": 5,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 400,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": false,
26 | "lr": 1e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "suffix": "_all_genes",
31 | "saved_model": ""
32 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_en_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "en",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/ec_en/mutation_cnv_rna",
9 | "data_type": ["mutation","cnv","rna"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "params_grid": {
9 | "n_estimators": [400, 600, 800],
10 | "max_features": ["sqrt"],
11 | "n_jobs": [4]
12 | },
13 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/ec_rf/mutation_cnv_rna",
14 | "data_type": ["mutation","cnv","rna"],
15 | "task": "regression"
16 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/moCluster/sanger_mutation_cnv_rna_allgenes.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/moCluster_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/move_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/MOVE/sanger_mutation_cnv_rna_200factor.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/move_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/pca_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/pca/sanger_mutation_cnv_rna_allgenes.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/pca_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/scvaeit_rf_allgenes_drug_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/DR/scVAEIT/sanger_scvaeit_latent_200factor.csv",
3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "model": "rf",
5 | "num_repeat": 5,
6 | "cv": 5,
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/scvaeit_rf/mutation_cnv_rna",
9 | "data_type": ["DR"],
10 | "task": "regression"
11 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/sanger_train_ccle_test_gdsc/.DS_Store
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 1,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": true,
30 | "save_scores": false,
31 | "drop_last": true,
32 | "suffix": "_DeePathNet",
33 | "saved_model": ""
34 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_100.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 100,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_ds_100",
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_200.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 200,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_ds_200",
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_300.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 300,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_ds_300",
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_400.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 400,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_ds_400",
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": false,
30 | "save_scores": true,
31 | "drop_last": true,
32 | "suffix": "_DeePathNet_random_control",
33 | "random_control": true,
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_100.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 100,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_random_control_ds_100",
34 | "random_control": true,
35 | "saved_model": ""
36 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_200.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 200,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_random_control_ds_200",
34 | "random_control": true,
35 | "saved_model": ""
36 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_300.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 300,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_random_control_ds_300",
34 | "random_control": true,
35 | "saved_model": ""
36 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_400.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna",
10 | "data_type": ["mutation", "cnv", "rna"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 400,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "downsample": 400,
30 | "save_checkpoints": false,
31 | "save_scores": true,
32 | "drop_last": true,
33 | "suffix": "_DeePathNet_random_control_ds_400",
34 | "random_control": true,
35 | "saved_model": ""
36 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/ec_mlp_all_genes_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "mlp",
8 | "seed": 1,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_mlp/mutation_cnv_rna",
10 | "data_type": [
11 | "mutation",
12 | "cnv",
13 | "rna"
14 | ],
15 | "task": "regression"
16 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/ec_rf_all_genes_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv",
6 | "model": "rf",
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_rf/mutation_cnv_rna",
9 | "data_type": [
10 | "mutation",
11 | "cnv",
12 | "rna"
13 | ],
14 | "task": "regression"
15 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna_prot",
10 | "data_type": ["mutation", "cnv", "rna", "prot"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 1,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 500,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": false,
30 | "save_scores": true,
31 | "drop_last": true,
32 | "suffix": "_DeePathNet",
33 | "saved_model": ""
34 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot_random_control.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna_prot",
10 | "data_type": ["mutation", "cnv", "rna", "prot"],
11 | "task": "regression",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 500,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": false,
30 | "save_scores": true,
31 | "drop_last": true,
32 | "suffix": "_DeePathNet_random_control",
33 | "random_control": true,
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/ec_rf_all_genes_mutation_cnv_rna_prot.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv",
6 | "model": "rf",
7 | "seed": 1,
8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_rf/mutation_cnv_rna_prot",
9 | "data_type": [
10 | "mutation",
11 | "cnv",
12 | "rna",
13 | "prot"
14 | ],
15 | "task": "regression"
16 | }
--------------------------------------------------------------------------------
/configs/tcga_all_cancer_types/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/tcga_all_cancer_types/.DS_Store
--------------------------------------------------------------------------------
/configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "multiclass",
10 | "seed": 1,
11 | "num_repeat": 5,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 600,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": true,
26 | "lr": 1e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "suffix": "_23_types_cancer_genes",
31 | "saved_model": ""
32 | }
--------------------------------------------------------------------------------
/configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_2k.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "multiclass",
10 | "seed": 1,
11 | "num_repeat": 2,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 600,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": true,
26 | "lr": 1e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "downsample": 2000,
31 | "suffix": "_23_types_cancer_genes_ds_2k",
32 | "saved_model": ""
33 | }
--------------------------------------------------------------------------------
/configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "multiclass",
10 | "seed": 1,
11 | "num_repeat": 5,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 600,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": true,
26 | "lr": 1e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "suffix": "_23_types_cancer_genes_random_control",
31 | "random_control": true,
32 | "saved_model": ""
33 | }
--------------------------------------------------------------------------------
/configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_2k.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "multiclass",
10 | "seed": 1,
11 | "num_repeat": 2,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 600,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": true,
26 | "lr": 1e-5,
27 | "save_checkpoints": false,
28 | "save_scores": true,
29 | "drop_last": true,
30 | "downsample": 2000,
31 | "suffix": "_23_types_cancer_genes_random_control_ds_2k",
32 | "random_control": true,
33 | "saved_model": ""
34 | }
--------------------------------------------------------------------------------
/configs/tcga_brca_subtypes/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/tcga_brca_subtypes/.DS_Store
--------------------------------------------------------------------------------
/configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_mutation_cnv_rna_log2.csv.gz",
3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv",
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
5 | "model": "DeePathNet",
6 | "do_cv": true,
7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_brca_subtypes/DeePathNet/mutation_cnv_rna",
8 | "data_type": ["mutation", "cnv", "rna"],
9 | "task": "multiclass",
10 | "seed": 1,
11 | "num_repeat": 5,
12 | "batch_size": 100,
13 | "num_workers": 1,
14 | "log_freq": 20,
15 | "num_of_epochs": 300,
16 | "dim": 512,
17 | "mlp_ratio": 2,
18 | "out_mlp_ratio": 8,
19 | "heads": 16,
20 | "depth": 2,
21 | "dropout": 0,
22 | "emb_dropout": 0,
23 | "pathway_dropout": 0.5,
24 | "weight_decay": 1e-5,
25 | "cancer_only": true,
26 | "lr": 1e-5,
27 | "save_checkpoints": true,
28 | "save_scores": false,
29 | "drop_last": true,
30 | "suffix": "_DeePathNet",
31 | "saved_model": "202202081127_Cancer_type.pth"
32 | }
--------------------------------------------------------------------------------
/configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/DeePathNet/cnv_rna",
10 | "data_type": ["cnv", "rna"],
11 | "task": "multiclass",
12 | "seed": 1,
13 | "num_repeat": 1,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 1000,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": false,
30 | "save_scores": true,
31 | "drop_last": true,
32 | "suffix": "_cancer_genes",
33 | "saved_model": ""
34 | }
--------------------------------------------------------------------------------
/configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna_random_control.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "DeePathNet",
8 | "do_cv": true,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/DeePathNet/cnv_rna",
10 | "data_type": ["cnv", "rna"],
11 | "task": "multiclass",
12 | "seed": 1,
13 | "num_repeat": 10,
14 | "batch_size": 100,
15 | "num_workers": 1,
16 | "log_freq": 20,
17 | "num_of_epochs": 1000,
18 | "dim": 512,
19 | "mlp_ratio": 2,
20 | "out_mlp_ratio": 8,
21 | "heads": 16,
22 | "depth": 2,
23 | "dropout": 0,
24 | "emb_dropout": 0,
25 | "pathway_dropout": 0.5,
26 | "weight_decay": 1e-5,
27 | "cancer_only": true,
28 | "lr": 1e-5,
29 | "save_checkpoints": false,
30 | "save_scores": true,
31 | "drop_last": true,
32 | "suffix": "_cancer_genes_random_control",
33 | "random_control": true,
34 | "saved_model": ""
35 | }
--------------------------------------------------------------------------------
/configs/tcga_train_cptac_test_brca/cnv_rna/ec_rf_cnv_rna.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv",
3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv",
4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv",
5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv",
6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv",
7 | "model": "rf",
8 | "seed": 1,
9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/ec_rf/cnv_rna",
10 | "data_type": [
11 | "cnv",
12 | "rna"
13 | ],
14 | "task": "multiclass"
15 | }
--------------------------------------------------------------------------------
/figures/Figure1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/Figure1.pdf
--------------------------------------------------------------------------------
/figures/Figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/Figure1.png
--------------------------------------------------------------------------------
/figures/supp_fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/supp_fig1.pdf
--------------------------------------------------------------------------------
/figures/supp_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/supp_fig1.png
--------------------------------------------------------------------------------
/models/cancer_type_pretrained.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d328d6b7a29d640eb521dfe05aa4ec59c4ab14737ad0adb6afd55b44087245ff
3 | size 92638145
4 |
--------------------------------------------------------------------------------
/models/drug_response_pretrained.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:083f2286bf1f9807d07a370bab3b515a18a84cffb0628d109cd128240ddfd4d8
3 | size 123718849
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | numpy
4 | pandas
5 | scikit-learn
6 | matplotlib
7 | seaborn
8 | scipy
9 | tqdm
10 |
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/baseline_ec_cv.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run the baseline model for the early concatenation methods using cross-validation for drug response prediction.
3 | E.g. python scripts/baseline_ec_cv.py configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json
4 | """
5 |
6 | import json
7 | import logging
8 | import os
9 | import sys
10 | import warnings
11 | from datetime import datetime
12 | from time import time
13 |
14 | import numpy as np
15 | import pandas as pd
16 | from scipy.stats import pearsonr
17 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
18 | from sklearn.linear_model import ElasticNet
19 | from sklearn.linear_model import LinearRegression, LogisticRegression
20 | from sklearn.metrics import (
21 | mean_squared_error,
22 | r2_score,
23 | mean_absolute_error,
24 | roc_auc_score,
25 | accuracy_score,
26 | )
27 | from sklearn.model_selection import KFold
28 | from sklearn.neural_network import MLPRegressor, MLPClassifier
29 | from sklearn.svm import SVR, SVC
30 | from tqdm import trange
31 | from xgboost import XGBClassifier, XGBRegressor
32 |
warnings.filterwarnings(action="ignore", category=UserWarning)

# Minute-resolution timestamp used to name this run's log and result files.
STAMP = datetime.today().strftime("%Y%m%d%H%M")
# Sentinel substituted for missing target values further down the script.
OUTPUT_NA_NUM = -100

config_file = sys.argv[1]

# Load the model configs. The file is read exactly once through a context
# manager so the handle is closed and the same text can be logged and echoed.
with open(config_file, "r") as config_handle:
    config_text = config_handle.read()
configs = json.loads(config_text)

log_suffix = f"{config_file.split('/')[-1].replace('.json', '')}"
# os.makedirs is portable and does not route the path through a shell.
os.makedirs(configs["work_dir"], exist_ok=True)

data_file = configs["data_file"]
target_file = configs["target_file"]
data_type = configs["data_type"]

log_file = f"{STAMP}_{log_suffix}.log"
logger = logging.getLogger("baseline_ec")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler(os.path.join(configs["work_dir"], log_file))
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)

# Record the exact configuration used for this run.
logger.info(config_text)
print(config_text)

seed = configs["seed"]
# "cv" is optional: fall back to 5 folds when the config does not specify it.
cv = KFold(n_splits=configs.get("cv", 5), shuffle=True, random_state=seed)

# Estimators selectable through configs["model"]; the task decides whether the
# classifier or the regressor flavour of each key is used.
if configs["task"].lower() == "classification":
    model_dict = {
        "lr": LogisticRegression(n_jobs=-1, solver="saga"),
        "rf": RandomForestClassifier(n_jobs=40, max_features="sqrt"),
        "svm": SVC(),
        # NOTE(review): ElasticNet is a regressor and has no predict_proba, so
        # this entry looks unusable for classification -- confirm intent.
        "en": ElasticNet(),
        "svm-linear": SVC(kernel="linear"),
        "mlp": MLPClassifier(),
        "xgb": XGBClassifier(),
    }

else:
    model_dict = {
        "lr": LinearRegression(),
        "rf": RandomForestRegressor(n_jobs=40, max_features="sqrt"),
        "svm": SVR(),
        "en": ElasticNet(),
        "svm-linear": SVR(kernel="linear"),
        "mlp": MLPRegressor(),
        "xgb": XGBRegressor(),
    }
86 |
# Target table; its columns are the prediction targets handled one at a time.
data_target = pd.read_csv(target_file, index_col=0)

# Input feature table; the token before the first "_" of a column name is
# treated as the gene symbol.
data_input = pd.read_csv(data_file, index_col=0)
genes = np.unique([x.split("_")[0] for x in data_input.columns])
# Set copy for O(1) membership tests while filtering pathway gene lists.
gene_set = set(genes)

if "pathway_file" in configs:
    pathway_dict = {}
    pathway_df = pd.read_csv(configs["pathway_file"])
    if "min_cancer_publication" in configs:
        pathway_df = pathway_df[
            pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
        ]
        logger.info(
            f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
        )
    if "max_gene_num" in configs:
        pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
        logger.info(f"Filtering pathway with GeneNumber < {configs['max_gene_num']}")
    if "min_gene_num" in configs:
        pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
        logger.info(f"Filtering pathway with GeneNumber > {configs['min_gene_num']}")

    # Keep only the pathway genes that are actually present in the input data.
    pathway_df["genes"] = pathway_df["genes"].map(
        lambda x: "|".join([gene for gene in x.split("|") if gene in gene_set])
    )

    for index, row in pathway_df.iterrows():
        # Pathways left without any measured gene are skipped.
        if row["genes"]:
            pathway_dict[row["name"]] = row["genes"].split("|")
    cancer_genes = {y for x in pathway_df["genes"].values for y in x.split("|")}
    # Restrict the features to pathway genes, always keeping "tissue" columns.
    data_input = data_input[
        [
            x
            for x in data_input.columns
            if (x.split("_")[0] in cancer_genes) or (x.split("_")[0] == "tissue")
        ]
    ]

if data_type[0] != "DR":
    # Keep a column when its first or second "_"-separated token names a
    # requested data type. Slicing with [:2] avoids the IndexError that
    # indexing token [1] raises on a column name without an underscore.
    data_input = data_input[
        [
            x
            for x in data_input.columns
            if any(token in data_type for token in x.split("_")[:2])
        ]
    ]

logger.info(f"Input data shape: {data_input.shape}")
logger.info(f"Target data shape: {data_target.shape}")

# Missing features become 0; missing targets get the sentinel so that they can
# be recognised and excluded later.
data_input = data_input.fillna(0)
data_target = data_target.fillna(OUTPUT_NA_NUM)

# Targets occupy the first num_targets columns of merged_df, features the rest.
merged_df = pd.merge(data_target, data_input, on="Cell_line")
cell_lines_all = data_input.index.values
num_targets = data_target.shape[1]

count = 0
num_repeat = configs.get("num_repeat", 1)
feature_df_list = []
score_df_list = []
time_df_list = []
logger.info(f"Merged df shape: {merged_df.shape}")
149 |
# Repeated K-fold cross-validation: one model is fitted per target and fold.
# The fold count honours configs["cv"] (default 5) instead of a hard-coded 5.
n_splits = configs.get("cv", 5)
is_classification = configs["task"].lower() == "classification"

for n in range(num_repeat):
    # A different shuffle for every repeat, reproducible through the seed.
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=(seed + n))
    for cell_lines_train_index, cell_lines_val_index in cv.split(cell_lines_all):
        start_time = time()

        # The fold split does not depend on the target, so slice it only once.
        train_lines = np.asarray(cell_lines_all)[cell_lines_train_index]
        val_lines = np.asarray(cell_lines_all)[cell_lines_val_index]
        merged_df_train = merged_df[merged_df.index.isin(train_lines)]
        merged_df_val = merged_df[merged_df.index.isin(val_lines)]
        features_train = merged_df_train.iloc[:, num_targets:]
        features_val = merged_df_val.iloc[:, num_targets:]

        for i in trange(num_targets):
            # Drop the rows whose value for this target is the NA sentinel.
            y_train = merged_df_train.iloc[:, i]
            train_mask = y_train != OUTPUT_NA_NUM
            X_train = features_train[train_mask]
            y_train = y_train[train_mask]

            y_val = merged_df_val.iloc[:, i]
            val_mask = y_val != OUTPUT_NA_NUM
            X_val = features_val[val_mask]
            y_val = y_val[val_mask]

            model = model_dict[configs["model"]]

            model.fit(X_train, y_train)
            y_pred = model.predict(X_val)
            if is_classification:
                y_confs = model.predict_proba(X_val)
                # Column 1 is used as the positive-class score (binary case).
                val_auc = roc_auc_score(y_val, y_confs[:, 1])
                val_acc = accuracy_score(y_val, y_pred)
                score_dict = {
                    "target": merged_df_train.columns[i],
                    "run": f"cv_{count}",
                    "auc": val_auc,
                    "acc": val_acc,
                }
            else:
                val_mae = mean_absolute_error(y_val, y_pred)
                # RMSE computed explicitly: the `squared=False` keyword was
                # deprecated and later removed from mean_squared_error.
                val_rmse = np.sqrt(mean_squared_error(y_val, y_pred))
                val_r2 = r2_score(y_val, y_pred)
                val_corr = pearsonr(y_val, y_pred)[0]
                score_dict = {
                    "drug_id": merged_df_train.columns[i],
                    "run": f"cv_{count}",
                    "mae": val_mae,
                    "rmse": val_rmse,
                    "r2": val_r2,
                    "corr": val_corr,
                }
            score_df_list.append(score_dict)

        # Wall-clock time of the whole fold (all targets).
        seconds_elapsed = time() - start_time
        logger.info(f"cv_{count}: {seconds_elapsed} seconds")
        time_df_list.append({"run": f"cv_{count}", "time": seconds_elapsed})
        count += 1

time_df_list = pd.DataFrame(time_df_list)
logger.info("All finished.")
score_df = pd.DataFrame(score_df_list)
time_df_list.to_csv(f"{configs['work_dir']}/time_{STAMP}_{log_suffix}.csv", index=False)

# Scores are saved unless the config explicitly sets "save_scores" to false.
if "save_scores" not in configs or configs["save_scores"]:
    score_df.to_csv(
        f"{configs['work_dir']}/scores_{STAMP}_{log_suffix}.csv", index=False
    )
# Only numeric columns have a median; the identifier columns are strings.
numeric_columns = score_df.select_dtypes(include=["number"])
print(numeric_columns.median())
221 |
--------------------------------------------------------------------------------
/scripts/baseline_independent_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run the baseline model for the early concatenation methods on the independent test set for drug response prediction.
3 | E.g. python scripts/baseline_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/ec_rf_all_genes_mutation_cnv_rna_prot.json
4 | """
5 | import json
6 | import logging
7 | import os
8 | import sys
9 | from datetime import datetime
10 |
11 | import numpy as np
12 | import pandas as pd
13 | from scipy.stats import pearsonr
14 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
15 | from sklearn.linear_model import ElasticNet
16 | from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, roc_auc_score, accuracy_score
17 | from sklearn.neural_network import MLPRegressor, MLPClassifier
18 | from sklearn.svm import SVR, SVC
19 | from tqdm import trange
20 |
# Minute-resolution timestamp used to name this run's log and result files.
STAMP = datetime.today().strftime("%Y%m%d%H%M")
# Sentinel substituted for missing target values further down the script.
OUTPUT_NA_NUM = -100

config_file = sys.argv[1]

# Load the model configs. The file is read exactly once through a context
# manager so the handle is closed and the same text can be logged and echoed.
with open(config_file, "r") as config_handle:
    config_text = config_handle.read()
configs = json.loads(config_text)

log_suffix = f"{config_file.split('/')[-1].replace('.json', '')}"
# os.makedirs is portable and does not route the path through a shell.
os.makedirs(configs["work_dir"], exist_ok=True)

log_file = f"{STAMP}_{log_suffix}.log"
logger = logging.getLogger("baseline_ec")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler(os.path.join(configs["work_dir"], log_file))
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)

# Record the exact configuration used for this run.
logger.info(config_text)
print(config_text)

# NOTE(review): this seed is not passed to any estimator below and the
# config's own "seed" entry is ignored -- confirm whether that is intended.
seed = 12345

# Estimators selectable through configs["model"]; the task decides whether the
# classifier or the regressor flavour of each key is used.
if configs["task"].lower() == "classification":
    model_dict = {
        "rf": RandomForestClassifier(n_jobs=100),
        "svm": SVC(),
        # NOTE(review): ElasticNet is a regressor and has no predict_proba, so
        # this entry looks unusable for classification -- confirm intent.
        "en": ElasticNet(),
        "svm-linear": SVC(kernel="linear"),
        "mlp": MLPClassifier(verbose=True),
    }

else:
    model_dict = {
        "rf": RandomForestRegressor(n_jobs=100),
        "svm": SVR(),
        "en": ElasticNet(),
        "svm-linear": SVR(kernel="linear"),
        "mlp": MLPRegressor(verbose=True),
    }
59 |
# Separate training and independent-test tables, all indexed by first column.
data_target_train = pd.read_csv(configs["target_file_train"], index_col=0)
data_input_train = pd.read_csv(configs["data_file_train"], index_col=0)
data_target_test = pd.read_csv(configs["target_file_test"], index_col=0)
data_input_test = pd.read_csv(configs["data_file_test"], index_col=0)
data_type = configs["data_type"]

# Restrict both inputs to their shared features. An ordered list (training
# column order) is used instead of a set: pandas no longer accepts a set as an
# indexer, and set iteration order would make the feature order vary by run.
test_columns = set(data_input_test.columns)
common_features = [x for x in data_input_train.columns if x in test_columns]
data_input_train = data_input_train[common_features]
data_input_test = data_input_test[common_features]

# The token before the first "_" of a column name is treated as the gene.
genes = np.unique([x.split("_")[0] for x in data_input_train.columns])
# Set copy for O(1) membership tests while filtering pathway gene lists.
gene_set = set(genes)

if "pathway_file" in configs:
    pathway_dict = {}
    pathway_df = pd.read_csv(configs["pathway_file"])
    if "min_cancer_publication" in configs:
        pathway_df = pathway_df[
            pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
        ]
        logger.info(
            f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
        )
    if "max_gene_num" in configs:
        pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
        logger.info(f"Filtering pathway with GeneNumber < {configs['max_gene_num']}")
    if "min_gene_num" in configs:
        pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
        logger.info(f"Filtering pathway with GeneNumber > {configs['min_gene_num']}")

    # Keep only the pathway genes that are actually present in the input data.
    pathway_df["genes"] = pathway_df["genes"].map(
        lambda x: "|".join([gene for gene in x.split("|") if gene in gene_set])
    )

    for index, row in pathway_df.iterrows():
        # Skip pathways left without any measured gene; "".split("|") would
        # otherwise store a bogus [""] entry.
        if row["genes"]:
            pathway_dict[row["name"]] = row["genes"].split("|")
    cancer_genes = {y for x in pathway_df["genes"].values for y in x.split("|")}
    # Restrict the features to pathway genes, always keeping "tissue" columns.
    # Both tables share the same columns, so one selection serves both.
    pathway_columns = [
        x
        for x in data_input_train.columns
        if (x.split("_")[0] in cancer_genes) or (x.split("_")[0] == "tissue")
    ]
    data_input_train = data_input_train[pathway_columns]
    data_input_test = data_input_test[pathway_columns]

if data_type[0] != "DR":
    # Keep a column when its first or second "_"-separated token names a
    # requested data type. Slicing with [:2] avoids the IndexError that
    # indexing token [1] raises on a column name without an underscore.
    typed_columns = [
        x
        for x in data_input_train.columns
        if any(token in data_type for token in x.split("_")[:2])
    ]
    data_input_train = data_input_train[typed_columns]
    data_input_test = data_input_test[typed_columns]

logger.info(f"Input trainning data shape: {data_input_train.shape}")
logger.info(f"Input trainning target shape: {data_target_train.shape}")
logger.info(f"Input test data shape: {data_input_test.shape}")
logger.info(f"Input test target shape: {data_target_test.shape}")

# Missing features become 0; missing targets get the sentinel so that they can
# be recognised and excluded later.
data_input_train = data_input_train.fillna(0)
data_target_train = data_target_train.fillna(OUTPUT_NA_NUM)
data_input_test = data_input_test.fillna(0)
data_target_test = data_target_test.fillna(OUTPUT_NA_NUM)

# Targets occupy the first num_targets columns of each merged frame.
# NOTE(review): num_targets comes from the training targets only; this assumes
# the test target file has the same columns in the same order -- confirm.
train_df = pd.merge(data_target_train, data_input_train, on="Cell_line")
test_df = pd.merge(data_target_test, data_input_test, on="Cell_line")
num_targets = data_target_train.shape[1]
114 |
feature_df_list = []
score_df_list = []
params_df_list = []

is_classification = configs["task"].lower() == "classification"
# The feature columns are the same for every target, so slice them only once.
features_train = train_df.iloc[:, num_targets:]
features_test = test_df.iloc[:, num_targets:]

# Fit on the training set and score on the independent test set, per target.
for i in trange(num_targets):
    # Drop the rows whose value for this target is the NA sentinel.
    y_train = train_df.iloc[:, i]
    train_mask = y_train != OUTPUT_NA_NUM
    X_train = features_train[train_mask]
    y_train = y_train[train_mask]

    y_test = test_df.iloc[:, i]
    test_mask = y_test != OUTPUT_NA_NUM
    X_test = features_test[test_mask]
    y_test = y_test[test_mask]
    logger.info(f"Running {train_df.columns[i]} Train:{X_train.shape} Test:{X_test.shape}")

    model = model_dict[configs["model"]]

    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    if is_classification:
        y_confs = model.predict_proba(X_test)
        # Column 1 is used as the positive-class score (binary case).
        val_auc = roc_auc_score(y_test, y_confs[:, 1])
        val_acc = accuracy_score(y_test, y_pred)
        score_dict = {
            "target": train_df.columns[i],
            "auc": val_auc,
            "acc": val_acc,
        }
    else:
        val_mae = mean_absolute_error(y_test, y_pred)
        # RMSE computed explicitly: the `squared=False` keyword was deprecated
        # and later removed from mean_squared_error.
        val_rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        val_r2 = r2_score(y_test, y_pred)
        val_corr = pearsonr(y_test, y_pred)[0]
        score_dict = {
            "drug_id": train_df.columns[i],
            "mae": val_mae,
            "rmse": val_rmse,
            "r2": val_r2,
            "corr": val_corr,
        }
    score_df_list.append(score_dict)

logger.info("All finished.")
score_df = pd.DataFrame(score_df_list)
# numeric_only skips the string identifier column, which would otherwise make
# DataFrame.median raise on recent pandas versions.
logger.info(score_df.median(numeric_only=True))
# Scores are saved unless the config explicitly sets "save_scores" to false.
if "save_scores" not in configs or configs["save_scores"]:
    score_df.to_csv(f"{configs['work_dir']}/scores_{STAMP}_{log_suffix}.csv", index=False)
156 |
157 |
--------------------------------------------------------------------------------
/scripts/cancer_type_baseline_23cancertypes.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run baseline models for cancer type classification using cross validation
3 | E.g. python cancer_type_baseline_23cancertypes.py
4 | """
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.ensemble import RandomForestClassifier
9 | from sklearn.linear_model import LogisticRegression, RidgeClassifier
10 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
11 | from sklearn.model_selection import KFold
12 | from sklearn.naive_bayes import GaussianNB
13 | from sklearn.neighbors import KNeighborsClassifier
14 | from sklearn.svm import SVC
15 | from sklearn.tree import DecisionTreeClassifier
16 | from sklearn.neural_network import MLPClassifier
17 | from tqdm import tqdm
18 | from xgboost import XGBClassifier
19 | import os
20 |
# Number of times the K-fold cross-validation is repeated.
NUM_REPEAT = 2
# Number of folds per repeat.
num_fold = 5

# Base random seed; repeat n shuffles with seed + n.
seed = 1

# Maps a classifier name to a factory building a fresh, unfitted estimator.
_CLASSIFIER_FACTORIES = {
    "RF": lambda: RandomForestClassifier(n_jobs=40, max_features="sqrt"),
    "XGB": lambda: XGBClassifier(),
    "LR": lambda: LogisticRegression(n_jobs=40, solver="saga"),
    "KNN": lambda: KNeighborsClassifier(),
    "NB": lambda: GaussianNB(),
    "MLP": lambda: MLPClassifier(verbose=True),
    "DT": lambda: DecisionTreeClassifier(),
}


def run_model(input_df, clf_name, data_type=("mutation", "cnv", "rna")):
    """Evaluate one baseline classifier with repeated K-fold cross-validation.

    The labels are taken from the module-level ``target_df``, which is merged
    with the features on ``"Cell_line"``; it must be defined before this
    function is called.

    Args:
        input_df: Feature table indexed by sample. Unless ``data_type[0]`` is
            ``"DR"``, only columns whose first or second ``"_"``-separated
            name token is listed in ``data_type`` are kept.
        clf_name: One of ``"RF"``, ``"XGB"``, ``"LR"``, ``"KNN"``, ``"NB"``,
            ``"MLP"`` or ``"DT"``.
        data_type: Data types to keep; pass ``["DR"]`` to use every column of
            ``input_df`` unchanged.

    Returns:
        A DataFrame with one row per fold and the columns ``run``, ``acc``,
        ``f1`` (macro-averaged) and ``roc_auc`` (NaN when it cannot be
        computed for a fold).

    Raises:
        ValueError: If ``clf_name`` is not a supported classifier name.
    """
    # Fail fast, before any data is processed, on an unsupported name.
    if clf_name not in _CLASSIFIER_FACTORIES:
        raise ValueError(
            f"Unknown classifier {clf_name!r}; "
            f"expected one of {sorted(_CLASSIFIER_FACTORIES)}"
        )
    make_classifier = _CLASSIFIER_FACTORIES[clf_name]

    if data_type[0] != "DR":
        # Slicing with [:2] avoids an IndexError on names without underscore.
        input_df = input_df[
            [
                x
                for x in input_df.columns
                if any(token in data_type for token in x.split("_")[:2])
            ]
        ]
    # After the merge, the first num_of_features columns are the features and
    # the remaining ones come from target_df.
    num_of_features = input_df.shape[1]
    cell_lines_all = input_df.index.values

    count = 0
    clf_results = []
    for n in range(NUM_REPEAT):
        cv = KFold(n_splits=num_fold, shuffle=True, random_state=(seed + n))

        for cell_lines_train_index, cell_lines_val_index in tqdm(
            cv.split(cell_lines_all), total=num_fold
        ):
            train_lines = np.asarray(cell_lines_all)[cell_lines_train_index]
            val_lines = np.asarray(cell_lines_all)[cell_lines_val_index]

            merged_df_train = pd.merge(
                input_df[input_df.index.isin(train_lines)], target_df, on=["Cell_line"]
            )
            merged_df_val = pd.merge(
                input_df[input_df.index.isin(val_lines)], target_df, on=["Cell_line"]
            )

            # NOTE(review): flattening the label block assumes target_df has
            # a single label column -- confirm.
            X_train = merged_df_train.iloc[:, :num_of_features]
            y_train = merged_df_train.iloc[:, num_of_features:].values.flatten()
            X_val = merged_df_val.iloc[:, :num_of_features]
            y_true = merged_df_val.iloc[:, num_of_features:].values.flatten()

            # A fresh estimator for every fold.
            clf = make_classifier()
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_val)
            y_conf = clf.predict_proba(X_val)

            acc = accuracy_score(y_true, y_pred)
            f1 = f1_score(y_true, y_pred, average="macro")
            try:
                auc = roc_auc_score(y_true, y_conf, multi_class="ovo")
            except ValueError:
                # AUC is undefined for this fold (for instance when a class is
                # missing from it); record NaN and continue.
                auc = np.nan

            clf_results.append(
                {"run": f"cv_{count}", "acc": acc, "f1": f1, "roc_auc": auc}
            )

            count += 1
    return pd.DataFrame(clf_results)
93 |
94 |
95 | # input_df = pd.read_csv(
96 | # "../data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_union.csv",
97 | # index_col=0,
98 | # )
99 |
100 | # genes = np.unique(
101 | # ([x.split("_")[0] for x in input_df.columns if x.split("_")[0] != "tissue"])
102 | # )
103 |
# Labels for the 23 TCGA cancer types; run_model reads this table as a
# module-level global.
target_file = "../data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv"
target_df = pd.read_csv(target_file, index_col=0)
108 |
109 | # pathway_dict = {}
110 | # pathway_df = pd.read_csv(
111 | # "../data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv"
112 | # )
113 |
114 | # pathway_df["genes"] = pathway_df["genes"].map(
115 | # lambda x: "|".join([gene for gene in x.split("|") if gene in genes])
116 | # )
117 | # # pathway_df = pathway_df[pathway_df['Cancer_Publications'] > 50]
118 |
119 | # for index, row in pathway_df.iterrows():
120 | # if row["genes"]:
121 | # pathway_dict[row["name"]] = row["genes"].split("|")
122 |
123 | # cancer_genes = list(set([y for x in pathway_df["genes"].values for y in x.split("|")]))
124 | # non_cancer_genes = sorted(set(genes) - set(cancer_genes))
125 |
126 | # class_name_to_id = dict(
127 | # zip(
128 | # sorted(target_df.iloc[:, 0].unique()),
129 | # list(range(target_df.iloc[:, 0].unique().size)),
130 | # )
131 | # )
132 | # id_to_class_name = dict(
133 | # zip(
134 | # list(range(target_df.iloc[:, 0].unique().size)),
135 | # sorted(target_df.iloc[:, 0].unique()),
136 | # )
137 | # )
138 |
139 | # input_df_cancergenes = input_df[
140 | # [x for x in input_df.columns if (x.split("_")[0] in cancer_genes)]
141 | # ]
142 |
143 | # input_df_cancergenes = input_df_cancergenes.fillna(0)
144 | # input_df = input_df.fillna(0)
145 |
# Output directory for the result tables written below.
dir_path = "../results/tcga_all_cancer_types/"

# exist_ok=True creates the directory only when needed, without the race
# between a separate existence check and the creation.
os.makedirs(dir_path, exist_ok=True)
150 |
151 | # %% All Cancer types
152 | # print("Running LR")
153 | # lr_results_df = run_model(input_df, "LR")
154 | # lr_results_df.to_csv("../results/tcga_all_cancer_types/lr_results_allgenes.csv",
155 | # index=False)
156 | # print("Running MLP")
157 | # mlp_results_df = run_model(input_df_cancergenes, "MLP")
158 | # mlp_results_df.to_csv("../results/tcga_all_cancer_types/mlp_results.csv",
159 | # index=False)
160 | # print("Running RF")
161 | # rf_results_df = run_model(input_df, "RF")
162 | # rf_results_df.to_csv("../results/tcga_all_cancer_types/rf_mutation_cnv_rna_allgenes_results.csv",
163 | # index=False)
164 | # print("Running KNN")
165 | # knn_results_df = run_model(input_df, "KNN")
166 | # knn_results_df.to_csv("../results/tcga_all_cancer_types/knn_mutation_cnv_rna_allgenes.csv",
167 | # index=False)
168 | # print("Running DR")
169 | # dt_results_df = run_model(input_df, "DT")
170 | # dt_results_df.to_csv("../results/tcga_all_cancer_types/dt_mutation_cnv_rna_allgenes.csv",
171 | # index=False)
172 | # print("Running PCA")
173 | # pca_input_df = pd.read_csv(
174 | # "../data/DR/pca/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv",
175 | # index_col=0)
176 | # rf_pca_results_df = run_model(pca_input_df, "RF", data_type = ['DR'])
177 | # rf_pca_results_df.to_csv("../results/tcga_all_cancer_types/pca_rf_mutation_cnv_rna_allgenes.csv",
178 | # index=False)
179 | # print("Running moCluster")
180 | # moCluster_input_df = pd.read_csv(
181 | # "../data/DR/moCluster/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv",
182 | # index_col=0,
183 | # )
184 | # moCluster_results_df = run_model(moCluster_input_df, "RF", data_type=["DR"])
185 | # moCluster_results_df.to_csv(
186 | # "../results/tcga_all_cancer_types/moCluster_rf_mutation_cnv_rna_allgenes.csv",
187 | # index=False,
188 | # )
189 |
# Random-forest run on the MOVE 200-factor table.
move_factor_file = "../data/DR/MOVE/tcga_mutation_cnv_rna_200factor.csv"
move_input_df = pd.read_csv(move_factor_file, index_col=0)
move_results_df = run_model(move_input_df, "RF", data_type=["DR"])
move_results_file = (
    "../results/tcga_all_cancer_types/move_rf_mutation_cnv_rna_200factor.csv"
)
move_results_df.to_csv(move_results_file, index=False)

# Random-forest run on the scVAEIT 200-factor latent table.
scvaeit_factor_file = "../data/DR/scVAEIT/tcga_scvaeit_latent_200factor.csv"
scvaeit_input_df = pd.read_csv(scvaeit_factor_file, index_col=0)
scvaeit_results_df = run_model(scvaeit_input_df, "RF", data_type=["DR"])
scvaeit_results_file = (
    "../results/tcga_all_cancer_types/scvaeit_rf_mutation_cnv_rna_200factor.csv"
)
scvaeit_results_df.to_csv(scvaeit_results_file, index=False)
207 |
--------------------------------------------------------------------------------
/scripts/cancer_type_baseline_brca.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run baseline models for breast cancer subtype classification using cross validation
3 | E.g. python cancer_type_baseline_brca.py
4 | """
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.ensemble import RandomForestClassifier
9 | from sklearn.linear_model import LogisticRegression, RidgeClassifier
10 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
11 | from sklearn.model_selection import KFold
12 | from sklearn.naive_bayes import GaussianNB
13 | from sklearn.neighbors import KNeighborsClassifier
14 | from sklearn.svm import SVC
15 | from sklearn.tree import DecisionTreeClassifier
16 | from tqdm import tqdm
17 | from xgboost import XGBClassifier
18 |
# Cross-validation settings read by run_model().
seed = 1  # base random_state; run_model uses seed + repeat index for KFold
num_fold = 5  # n_splits passed to KFold
NUM_REPATE = 5  # how many times the whole K-fold split is repeated
23 |
24 |
def run_model(input_df, clf_name, data_type=("mutation", "cnv", "rna")):
    """Score one baseline classifier with repeated K-fold cross validation.

    Reads the module-level ``target_df`` (labels), ``NUM_REPATE``,
    ``num_fold`` and ``seed``.

    Args:
        input_df: Feature table. It is merged with ``target_df`` on
            ``Cell_line`` and split by its index values.
        clf_name: Key of the classifier to build; one of ``RF``, ``XGB``,
            ``LR``, ``ridge``, ``KNN``, ``SVML``, ``SVM``, ``NB`` or ``DT``.
        data_type: Tags used to select feature columns. A column is kept when
            its first or second ``_``-separated token is listed here. A
            sequence whose first element is ``"DR"`` disables the filtering.

    Returns:
        pandas.DataFrame with one row per fold and the columns ``run``,
        ``acc``, ``f1`` and ``roc_auc``.

    Raises:
        ValueError: If ``clf_name`` is not one of the supported keys.
    """
    # A fresh, unfitted estimator is built for every fold from these factories.
    classifier_factories = {
        "RF": lambda: RandomForestClassifier(n_jobs=40, max_features="sqrt"),
        "XGB": lambda: XGBClassifier(),
        "LR": lambda: LogisticRegression(n_jobs=40, solver="saga"),
        # NOTE(review): RidgeClassifier has no predict_proba, so this option
        # fails at the predict_proba call below - confirm whether it is used.
        "ridge": lambda: RidgeClassifier(),
        "KNN": lambda: KNeighborsClassifier(),
        "SVML": lambda: SVC(probability=True, kernel="linear"),
        # NOTE(review): identical to "SVML"; a non-linear kernel may have been
        # intended here - kept unchanged to preserve existing results.
        "SVM": lambda: SVC(probability=True, kernel="linear"),
        "NB": lambda: GaussianNB(),
        "DT": lambda: DecisionTreeClassifier(),
    }
    # Fail before any data is touched instead of inside the first fold.
    if clf_name not in classifier_factories:
        raise ValueError(
            f"Unknown classifier '{clf_name}'; "
            f"expected one of {sorted(classifier_factories)}"
        )
    make_classifier = classifier_factories[clf_name]

    if data_type[0] != "DR":
        # Only the first two tokens are inspected; slicing keeps column names
        # without an underscore from raising IndexError.
        selected_columns = [
            column
            for column in input_df.columns
            if any(token in data_type for token in column.split("_")[:2])
        ]
        input_df = input_df[selected_columns]

    num_of_features = input_df.shape[1]
    cell_lines_all = np.array(input_df.index.values)

    count = 0
    fold_scores = []
    for n in range(NUM_REPATE):
        # A different shuffle per repeat, reproducible through seed + n.
        cv = KFold(n_splits=num_fold, shuffle=True, random_state=(seed + n))

        for train_index, val_index in tqdm(cv.split(cell_lines_all), total=num_fold):
            train_lines = cell_lines_all[train_index]
            val_lines = cell_lines_all[val_index]

            merged_df_train = pd.merge(
                input_df[input_df.index.isin(train_lines)], target_df, on=["Cell_line"]
            )
            merged_df_val = pd.merge(
                input_df[input_df.index.isin(val_lines)], target_df, on=["Cell_line"]
            )

            # Features occupy the first num_of_features columns of the merged
            # frames; everything after them comes from target_df.
            x_train = merged_df_train.iloc[:, :num_of_features]
            y_train = merged_df_train.iloc[:, num_of_features:].values.flatten()
            x_val = merged_df_val.iloc[:, :num_of_features]
            y_true = merged_df_val.iloc[:, num_of_features:].values.flatten()

            clf = make_classifier()
            clf.fit(x_train, y_train)
            y_pred = clf.predict(x_val)
            y_conf = clf.predict_proba(x_val)

            fold_scores.append(
                {
                    "run": f"cv_{count}",
                    "acc": accuracy_score(y_true, y_pred),
                    "f1": f1_score(y_true, y_pred, average="macro"),
                    "roc_auc": roc_auc_score(y_true, y_conf, multi_class="ovo"),
                }
            )
            count += 1

    return pd.DataFrame(fold_scores)
91 |
92 |
input_df = pd.read_csv(
    "../data/processed/omics/tcga_brca_mutation_cnv_rna_log2.csv.gz", index_col=0
)

# The gene symbol is the first "_"-separated token of a feature column;
# columns whose first token is "tissue" are not genes.
genes = np.unique(
    [x.split("_")[0] for x in input_df.columns if x.split("_")[0] != "tissue"]
)
# Set copy for O(1) membership tests; `in` on the NumPy array scans it fully.
gene_lookup = set(genes)

# target_df = pd.read_csv(
# "../data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
# index_col=0)
target_df = pd.read_csv(
    "../data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", index_col=0
)

pathway_df = pd.read_csv(
    "../data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv"
)

# Keep only the pathway members that are present in the input table.
pathway_df["genes"] = pathway_df["genes"].map(
    lambda x: "|".join([gene for gene in x.split("|") if gene in gene_lookup])
)
# pathway_df = pathway_df[pathway_df['Cancer_Publications'] > 50]

# Pathway name -> member genes; pathways left without any gene are dropped.
pathway_dict = {
    name: gene_string.split("|")
    for name, gene_string in zip(pathway_df["name"], pathway_df["genes"])
    if gene_string
}

# "".split("|") yields [""], so the empty token is filtered out explicitly.
cancer_gene_lookup = {
    gene
    for gene_string in pathway_df["genes"].values
    for gene in gene_string.split("|")
    if gene
}
cancer_genes = list(cancer_gene_lookup)
non_cancer_genes = sorted(gene_lookup - cancer_gene_lookup)

# Class labels are numbered by their sorted order.
class_names = sorted(target_df.iloc[:, 0].unique())
class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
id_to_class_name = dict(enumerate(class_names))

input_df_cancergenes = input_df[
    [x for x in input_df.columns if x.split("_")[0] in cancer_gene_lookup]
]

input_df_cancergenes = input_df_cancergenes.fillna(0)
input_df = input_df.fillna(0)

# num_fold, NUM_REPATE and seed are defined once at the top of this script;
# the duplicate assignments that used to live here silently overrode them.
150 |
151 | # %% BRCA
152 | # print("Running LR")
153 | # lr_results_df = run_model(input_df, "LR")
154 | # lr_results_df.to_csv(
155 | # "../results/tcga_brca_subtype/lr_mutation_cnv_rna_allgenes.csv", index=False
156 | # )
157 | # print("Running RF")
158 | # rf_results_df = run_model(input_df, "RF")
159 | # rf_results_df.to_csv(
160 | # "../results/tcga_brca_subtype/rf_mutation_cnv_rna_allgenes.csv", index=False
161 | # )
162 | # print("Running KNN")
163 | # knn_results_df = run_model(input_df, "KNN")
164 | # knn_results_df.to_csv(
165 | # "../results/tcga_brca_subtype/knn_mutation_cnv_rna_allgenes.csv", index=False
166 | # )
167 | # print("Running DR")
168 | # dt_results_df = run_model(input_df, "DT")
169 | # dt_results_df.to_csv(
170 | # "../results/tcga_brca_subtype/dt_mutation_cnv_rna_allgenes.csv", index=False
171 | # )
172 | # print("Running PCA")
173 | # pca_input_df = pd.read_csv(
174 | # "../data/DR/pca/tcga_brca_mutation_cnv_rna_allgenes.csv", index_col=0
175 | # )
176 | # rf_pca_results_df = run_model(pca_input_df, "RF", data_type=["DR"])
177 | # rf_pca_results_df.to_csv(
178 | # "../results/tcga_brca_subtype/pca_rf_mutation_cnv_rna_allgenes.csv", index=False
179 | # )
180 | # print("Running moCluster")
181 | # moCluster_input_df = pd.read_csv(
182 | # "../data/DR/moCluster/tcga_brca_mutation_cnv_rna.csv", index_col=0
183 | # )
184 | # moCluster_results_df = run_model(moCluster_input_df, "RF", data_type=["DR"])
185 | # moCluster_results_df.to_csv(
186 | # "../results/tcga_brca_subtype/moCluster_rf_mutation_cnv_rna_allgenes.csv",
187 | # index=False,
188 | # )
189 |
import os

# DataFrame.to_csv does not create missing parent folders, and nothing else
# in this script creates the output directory, so make sure it exists.
results_dir = "../results/tcga_brca_subtype/"
os.makedirs(results_dir, exist_ok=True)

print("Running MOVE")
MOVE_input_df = pd.read_csv(
    "../data/DR/MOVE/tcga_brca_mutation_cnv_rna_200factor.csv", index_col=0
)
MOVE_results_df = run_model(MOVE_input_df, "RF", data_type=["DR"])
MOVE_results_df.to_csv(
    "../results/tcga_brca_subtype/MOVE_rf_mutation_cnv_rna_allgenes.csv", index=False
)

print("Running scVAEIT")
scVAEIT_input_df = pd.read_csv(
    "../data/DR/scVAEIT/tcga_brca_scvaeit_latent_200factor.csv", index_col=0
)
scVAEIT_results_df = run_model(scVAEIT_input_df, "RF", data_type=["DR"])
scVAEIT_results_df.to_csv(
    "../results/tcga_brca_subtype/scVAEIT_rf_mutation_cnv_rna_allgenes.csv", index=False
)
207 |
--------------------------------------------------------------------------------
/scripts/cancer_type_baseline_brca_validation.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run baseline models for breast cancer subtype classification using independent test set
3 | E.g. python cancer_type_baseline_brca_validation.py
4 | """
5 | import pandas as pd
6 | from sklearn.ensemble import RandomForestClassifier
7 | from sklearn.linear_model import LogisticRegression
8 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
9 | from xgboost import XGBClassifier
10 | from sklearn.neural_network import MLPClassifier
11 | import numpy as np
12 | import os
13 |
14 | seed = 1
15 |
def run_model(input_df_train, input_df_test, clf_name, data_type=('cnv', 'rna')):
    """Fit one baseline classifier on the training set and score it on the test set.

    Reads the module-level ``target_df_train`` and ``target_df_test``; the
    label is taken from the last column of each merged frame.

    Args:
        input_df_train: Training feature table, merged with
            ``target_df_train`` on ``Cell_line``.
        input_df_test: Test feature table, merged with ``target_df_test`` on
            ``Cell_line``.
        clf_name: ``RF``, ``XGB``, ``LR`` or ``MLP``.
        data_type: Tags used to select feature columns. A column is kept when
            its first or second ``_``-separated token is listed here. A
            sequence whose first element is ``'DR'`` disables the filtering.

    Returns:
        Tuple ``(scores, predictions)``. ``scores`` is a one-row DataFrame
        with ``run``, ``acc``, ``f1`` and ``roc_auc``. ``predictions`` holds
        ``y_pred``, ``y_true`` and one ``feature_<i>`` column per column of
        the ``predict_proba`` output.

    Raises:
        ValueError: If ``clf_name`` is not one of the supported keys.
        KeyError: If the test table lacks a feature column selected from the
            training table.
    """
    count = 0

    if clf_name == "RF":
        clf = RandomForestClassifier(n_jobs=40, max_features='sqrt')
    elif clf_name == "XGB":
        clf = XGBClassifier(n_jobs=60)
    elif clf_name == "LR":
        clf = LogisticRegression(n_jobs=40, solver='saga')
    elif clf_name == 'MLP':
        clf = MLPClassifier(verbose=True)
    else:
        raise ValueError(
            f"Unknown classifier '{clf_name}'; expected one of RF, XGB, LR, MLP"
        )

    if data_type[0] != 'DR':
        # Slicing to the first two tokens keeps column names without an
        # underscore from raising IndexError.
        feature_columns = [
            column for column in input_df_train.columns
            if any(token in data_type for token in column.split("_")[:2])
        ]
        input_df_train = input_df_train[feature_columns]
        # The test table must expose exactly the same features in the same
        # order; previously only the training table was filtered.
        input_df_test = input_df_test[feature_columns]

    merged_df_train = pd.merge(
        input_df_train,
        target_df_train,
        on=['Cell_line'])
    merged_df_val = pd.merge(input_df_test, target_df_test, on=['Cell_line'])

    # Every column but the last is a feature; the last one is the label.
    x_train = merged_df_train.iloc[:, :-1]
    y_train = merged_df_train.iloc[:, -1].values.flatten()
    x_val = merged_df_val.iloc[:, :-1]
    y_true = merged_df_val.iloc[:, -1].values.flatten()

    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_val)
    y_conf = clf.predict_proba(x_val)

    clf_results_df = pd.DataFrame([{
        'run': f'cv_{count}',
        'acc': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='macro'),
        'roc_auc': roc_auc_score(y_true, y_conf, multi_class='ovo')
    }])

    # Per-sample output: predicted label, true label and one score column
    # per column of y_conf.
    val_res_perclass = {'y_pred': y_pred, 'y_true': y_true}
    for i in range(y_conf.shape[1]):
        val_res_perclass[f"feature_{i}"] = y_conf[:, i]

    return clf_results_df, pd.DataFrame(val_res_perclass)
65 |
66 |
input_df_train = pd.read_csv(
    "../data/processed/omics/tcga_brca_as_validation.csv",
    index_col=0)
target_df_train = pd.read_csv(
    "../data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv",
    index_col=0)
input_df_test = pd.read_csv(
    "../data/processed/omics/cptac_as_validation.csv",
    index_col=0)
target_df_test = pd.read_csv(
    "../data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv",
    index_col=0)
# target_df_train = pd.read_csv(
# "../data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv",
# index_col=0)

# The gene symbol is the first "_"-separated token of a training column;
# columns whose first token is "tissue" are not genes.
genes = np.unique(
    [x.split("_")[0] for x in input_df_train.columns if x.split("_")[0] != 'tissue'])
# Set copy for O(1) membership tests; `in` on the NumPy array scans it fully.
gene_lookup = set(genes)

pathway_df = pd.read_csv(
    "../data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv"
)

# Keep only the pathway members that are present in the training table.
pathway_df['genes'] = pathway_df['genes'].map(
    lambda x: "|".join([gene for gene in x.split('|') if gene in gene_lookup]))
# pathway_df = pathway_df[pathway_df['Cancer_Publications'] > 50]

# Pathway name -> member genes; pathways left without any gene are dropped.
pathway_dict = {
    name: gene_string.split('|')
    for name, gene_string in zip(pathway_df['name'], pathway_df['genes'])
    if gene_string
}

# "".split("|") yields [""], so the empty token is filtered out explicitly.
cancer_gene_lookup = {
    gene
    for gene_string in pathway_df['genes'].values
    for gene in gene_string.split("|")
    if gene
}
cancer_genes = list(cancer_gene_lookup)
non_cancer_genes = sorted(gene_lookup - cancer_gene_lookup)

input_df_train_cancergenes = input_df_train[
    [x for x in input_df_train.columns if x.split("_")[0] in cancer_gene_lookup]]
input_df_train_cancergenes = input_df_train_cancergenes.fillna(0)

input_df_test_cancergenes = input_df_test[
    [x for x in input_df_test.columns if x.split("_")[0] in cancer_gene_lookup]]
input_df_test_cancergenes = input_df_test_cancergenes.fillna(0)

# Class labels of the training targets are numbered by their sorted order.
class_names = sorted(target_df_train.iloc[:, 0].unique())
class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
id_to_class_name = dict(enumerate(class_names))

input_df_train = input_df_train.fillna(0)
input_df_test = input_df_test.fillna(0)

dir_path = "../results/tcga_brca_subtype/"

# exist_ok=True makes the call idempotent and removes the race between an
# os.path.exists() test and the actual creation.
os.makedirs(dir_path, exist_ok=True)
120 |
121 | # %% BRCA
122 | # print("Running LR")
123 | # lr_results_df, all_val_df = run_model(input_df_train, input_df_test, "LR")
124 | # all_val_df.columns = [id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x for x in
125 | # all_val_df.columns]
126 | # lr_results_df.to_csv("../results/tcga_brca_subtype/lr_cnv_rna_allgenes_results_validation.csv",
127 | # index=False)
128 | # all_val_df.to_csv(f"../results/tcga_brca_subtype/lr_all_res_cnv_rna_allgenes_results_validation.csv.gz", index=False)
129 |
print("Running MLP")
mlp_results_df, all_val_df = run_model(
    input_df_train_cancergenes, input_df_test_cancergenes, "MLP"
)

# Replace every "feature_<i>" column name with the class name stored under i.
renamed_columns = []
for column_name in all_val_df.columns:
    if "feature_" in column_name:
        class_id = int(column_name.split("_")[-1])
        renamed_columns.append(id_to_class_name[class_id])
    else:
        renamed_columns.append(column_name)
all_val_df.columns = renamed_columns

mlp_scores_file = "../results/tcga_brca_subtype/mlp_cnv_rna_results_validation.csv"
mlp_predictions_file = (
    "../results/tcga_brca_subtype/mlp_all_res_cnv_rna_results_validation.csv.gz"
)
mlp_results_df.to_csv(mlp_scores_file, index=False)
all_val_df.to_csv(mlp_predictions_file, index=False)
137 |
138 | # print("Running RF")
139 | # rf_results_df, val_res_perclass = run_model(input_df_train, input_df_test, "RF")
140 | # all_val_df.columns = [id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x for x in
141 | # all_val_df.columns]
142 | # rf_results_df.to_csv("../results/tcga_brca_subtype/rf_cnv_rna_allgenes_results_validation.csv",
143 | # index=False)
144 | # all_val_df.to_csv(f"../results/tcga_brca_subtype/rf_all_res_cnv_rna_allgenes_results_validation.csv.gz", index=False)
145 |
146 | # print("Running XGB")
147 | # xgb_results_df, val_res_perclass = run_model(input_df_train, input_df_test, "XGB")
148 | # xgb_results_df.to_csv("../results/tcga_brca_subtype/xgb_cnv_rna_allgenes_validation.csv",
149 | # index=False)
150 |
151 |
--------------------------------------------------------------------------------
/scripts/deepathnet_cv.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run DeePathNet with cross validation for any task.
3 | E.g. python scripts/deepathnet_cv.py configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json
4 | """
5 | import json
6 | import sys
7 | from datetime import datetime
8 |
9 | import torch.optim
10 | from sklearn.model_selection import KFold, StratifiedKFold
11 | from torch.utils.data import DataLoader
12 |
13 | from model_transformer_lrp import DeePathNet
14 | from models import *
15 | from utils.training_prepare import prepare_data_cv
16 |
# Timestamp used in the names of the files written by this run.
STAMP = datetime.today().strftime("%Y%m%d%H%M")
proj_dir = "/home/scai/DeePathNet"
sys.path.extend([proj_dir])

config_file = sys.argv[1]
# load model configs; the context manager closes the file handle afterwards
with open(config_file, "r") as config_handle:
    configs = json.load(config_handle)

# Optional suffix appended to output file names.
log_suffix = configs.get("suffix", "")

seed = configs["seed"]
torch.manual_seed(seed)
# The random-control pathway sampling uses np.random.choice, so NumPy has to
# be seeded as well (the independent-test script already does this).
np.random.seed(seed)

BATCH_SIZE = configs["batch_size"]
NUM_WORKERS = 0
LOG_FREQ = configs["log_freq"]
NUM_EPOCHS = configs["num_of_epochs"]
device = "cuda" if torch.cuda.is_available() else "cpu"
37 |
38 |
def get_setup(genes_to_id, id_to_genes, target_dim):
    """Build the DeePathNet model, the loss and the optimizer for one run.

    Reads the module-level ``configs``, ``genes``, ``omics_types``,
    ``tissues``, ``logger`` and ``device``.

    Args:
        genes_to_id: Passed through to ``DeePathNet``.
        id_to_genes: Passed through to ``DeePathNet``.
        target_dim: Passed through to ``DeePathNet`` as its second argument.

    Returns:
        Tuple ``(model, criterion, optimizer, lr_scheduler)``; the scheduler
        is always ``None``.
    """

    def load_pathway(random_control=False):
        """Return ``(pathway_dict, non_cancer_genes)`` from the pathway file.

        With ``random_control`` every pathway keeps its size but its members
        are redrawn without replacement from the pooled pathway genes.
        """
        pathway_dict = {}
        pathway_df = pd.read_csv(configs["pathway_file"])
        if "min_cancer_publication" in configs:
            pathway_df = pathway_df[
                pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
            ]
            logger.info(
                f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
            )
        if "max_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber < {configs['max_gene_num']}"
            )
        if "min_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber > {configs['min_gene_num']}"
            )

        # Built once so every membership test below is O(1).
        gene_lookup = set(genes)
        pathway_df["genes"] = pathway_df["genes"].map(
            lambda x: "|".join([gene for gene in x.split("|") if gene in gene_lookup])
        )

        for index, row in pathway_df.iterrows():
            if row["genes"]:
                pathway_dict[row["name"]] = row["genes"].split("|")
        # "".split("|") yields [""]; the empty token is not a gene and must
        # not end up in the random-control pool.
        cancer_genes = {
            gene
            for gene_string in pathway_df["genes"].values
            for gene in gene_string.split("|")
            if gene
        }
        non_cancer_genes = gene_lookup - cancer_genes
        logger.info(
            f"Cancer genes:{len(cancer_genes)}\tNon-cancer genes:{len(non_cancer_genes)}"
        )
        if random_control:
            logger.info("Randomly select genes for each pathway")
            # Set iteration order of strings changes between interpreter runs
            # (hash randomisation); sorting gives a seeded draw a stable pool.
            candidate_genes = sorted(cancer_genes)
            for key in pathway_dict:
                pathway_dict[key] = list(
                    np.random.choice(
                        candidate_genes, len(pathway_dict[key]), replace=False
                    )
                )
        return pathway_dict, non_cancer_genes

    random_control = configs.get("random_control", False)
    pathway_dict, non_cancer_genes = load_pathway(random_control=random_control)
    model = DeePathNet(
        len(omics_types),
        target_dim,
        genes_to_id,
        id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        num_heads=configs["heads"],
        pathway_drop_rate=configs["pathway_dropout"],
        only_cancer_genes=configs["cancer_only"],
        tissues=tissues,
    )

    # Log the source of the module that DeePathNet was actually imported
    # from instead of a path that only exists on one machine.
    model_source_path = getattr(sys.modules.get(DeePathNet.__module__), "__file__", None)
    try:
        with open(model_source_path, "r") as source_file:
            logger.info(source_file.read())
    except (OSError, TypeError) as err:
        # Logging the source is informational only; never abort training.
        logger.info(f"Could not read model source {model_source_path}: {err}")

    logger.info(model)
    model = model.to(device)

    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(
        model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
    )

    logger.info(optimizer)

    lr_scheduler = None

    return model, criterion, optimizer, lr_scheduler
118 |
119 |
def run_experiment(
    merged_df_train, merged_df_test, val_score_dict, run="test", class_name_to_id=None
):
    """Train and validate one model on a single train/test split.

    The first ``num_of_features`` columns of both frames are the inputs and
    the remaining columns are the targets. Reads the module-level
    ``configs``, ``num_of_features``, ``omics_types``, ``with_tissue``,
    ``logger``, ``BATCH_SIZE``, ``NUM_WORKERS``, ``NUM_EPOCHS`` and ``STAMP``.

    Returns:
        Whatever ``train_loop`` returns for this split.
    """
    is_multiclass = configs["task"] == "multiclass"

    features_train = merged_df_train.iloc[:, :num_of_features]
    features_test = merged_df_test.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]
    test_target = merged_df_test.iloc[:, num_of_features:]

    # The two dataset classes share most arguments and differ in one keyword.
    if is_multiclass:
        dataset_cls = MultiOmicMulticlassDataset
        task_kwargs = {"class_name_to_id": class_name_to_id}
    else:
        dataset_cls = MultiOmicDataset
        task_kwargs = {"with_tissue": with_tissue}

    train_dataset = dataset_cls(
        features_train,
        train_target,
        mode="train",
        omics_types=omics_types,
        logger=logger,
        **task_kwargs,
    )
    test_dataset = dataset_cls(
        features_test,
        test_target,
        mode="val",
        omics_types=omics_types,
        logger=logger,
        **task_kwargs,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=configs["drop_last"],
        num_workers=NUM_WORKERS,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # One output per class for classification, one per target column otherwise.
    target_dim = len(class_name_to_id) if is_multiclass else train_target.shape[1]

    model, criterion, optimizer, lr_scheduler = get_setup(
        train_dataset.genes_to_id, train_dataset.id_to_genes, target_dim
    )

    # Names of the target columns, handed to the training loop.
    val_drug_ids = merged_df_test.columns[num_of_features:]
    return train_loop(
        NUM_EPOCHS,
        train_loader,
        test_loader,
        model,
        criterion,
        optimizer,
        logger,
        STAMP,
        configs,
        lr_scheduler,
        val_drug_ids,
        run=run,
        val_score_dict=val_score_dict,
    )
204 |
205 |
# Everything the run needs comes back from prepare_data_cv in one dict.
data_dict = prepare_data_cv(config_file, STAMP)

logger = data_dict["logger"]
val_score_dict = data_dict["val_score_dict"]

data_input_all = data_dict["data_input_all"]
data_target_all = data_dict["data_target_all"]
num_of_features = data_dict["num_of_features"]

genes = data_dict["genes"]
omics_types = data_dict["omics_types"]
tissues = data_dict["tissues"]
with_tissue = data_dict["with_tissue"]

# Index values of the input table; these are what the CV splitter shuffles.
cell_lines_all = data_input_all.index.values

# Label <-> integer id maps, only needed for classification. Ids follow the
# sorted order of the labels in the first target column.
class_name_to_id = None
id_to_class_name = None
if configs["task"] == "multiclass":
    sorted_class_names = sorted(data_target_all.iloc[:, 0].unique())
    class_name_to_id = {
        class_name: class_id
        for class_id, class_name in enumerate(sorted_class_names)
    }
    id_to_class_name = dict(enumerate(sorted_class_names))
233 |
count = 0
num_repeat = configs.get("num_repeat", 1)

all_val_df = []


def _merge_fold(train_index, val_index):
    """Return the (train, validation) feature+target frames of one split."""
    train_lines = np.array(cell_lines_all)[train_index]
    val_lines = np.array(cell_lines_all)[val_index]
    merged_train = pd.merge(
        data_input_all[data_input_all.index.isin(train_lines)],
        data_target_all,
        on=["Cell_line"],
    )
    merged_val = pd.merge(
        data_input_all[data_input_all.index.isin(val_lines)],
        data_target_all,
        on=["Cell_line"],
    )
    return merged_train, merged_val


if configs["save_checkpoints"]:
    # only run once if the purpose is for model explanation
    cv = KFold(n_splits=5, shuffle=True, random_state=seed)
    cell_lines_train_index, cell_lines_val_index = next(cv.split(cell_lines_all))
    merged_df_train, merged_df_val = _merge_fold(
        cell_lines_train_index, cell_lines_val_index
    )
    val_res = run_experiment(
        merged_df_train,
        merged_df_val,
        val_score_dict,
        run=f"cv_{count}",
        class_name_to_id=class_name_to_id,
    )
    all_val_df.append(val_res)
else:
    is_multiclass = configs["task"] == "multiclass"
    for n in range(num_repeat):
        # A different shuffle per repeat, reproducible through seed + n.
        if is_multiclass:
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=(seed + n))
            folds = cv.split(cell_lines_all, data_target_all.iloc[:, 0])
        else:
            cv = KFold(n_splits=5, shuffle=True, random_state=(seed + n))
            folds = cv.split(cell_lines_all)
        for cell_lines_train_index, cell_lines_val_index in folds:
            merged_df_train, merged_df_val = _merge_fold(
                cell_lines_train_index, cell_lines_val_index
            )

            # The class-coverage check only makes sense for class labels. For
            # continuous targets the unique values of train and validation
            # practically never match, which used to skip the fold.
            if is_multiclass:
                train_classes = set(merged_df_train.iloc[:, -1].unique())
                val_classes = set(merged_df_val.iloc[:, -1].unique())
                if train_classes != val_classes:
                    logger.info("Missing class in validation fold")
                    continue

            val_res = run_experiment(
                merged_df_train,
                merged_df_val,
                val_score_dict,
                run=f"cv_{count}",
                class_name_to_id=class_name_to_id,
            )
            all_val_df.append(val_res)
            count += 1
299 |
300 |
# Scores are written unless the config explicitly disables "save_scores".
if configs.get("save_scores", True):
    scores_file = f"{configs['work_dir']}/scores_{STAMP}{log_suffix}.csv.gz"
    val_score_df = pd.DataFrame(val_score_dict)
    val_score_df.to_csv(scores_file, index=False)
if configs["task"] == "multiclass":
    # Stack the results collected from every run into one table.
    all_val_df = pd.concat(all_val_df)

    # Translate integer class ids back to the original class names.
    for label_column in ("y_pred", "y_true"):
        all_val_df[label_column] = all_val_df[label_column].map(id_to_class_name)

    # "feature_<i>" columns are renamed to the class name stored under i.
    renamed_columns = []
    for column_name in all_val_df.columns:
        if "feature_" in column_name:
            class_id = int(column_name.split("_")[-1])
            renamed_columns.append(id_to_class_name[class_id])
        else:
            renamed_columns.append(column_name)
    all_val_df.columns = renamed_columns

    predictions_file = f"{configs['work_dir']}/all_val_res_{STAMP}{log_suffix}.csv.gz"
    all_val_df.to_csv(predictions_file, index=False)
317 |
318 | logger.info("Full training finished.")
319 |
--------------------------------------------------------------------------------
/scripts/deepathnet_independent_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run DeePathNet with independent test set for any task.
3 | E.g. python scripts/deepathnet_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json
4 | """
5 | import json
6 | import sys
7 | from datetime import datetime
8 | import numpy as np
9 | import torch.optim
10 | from torch.utils.data import DataLoader
11 |
12 | from utils.training_prepare import (
13 | prepare_data_independent_test,
14 | get_logger,
15 | get_score_dict,
16 | )
17 |
18 | from models import *
19 | from model_transformer_lrp import DeePathNet
20 |
# Timestamp used in the names of the files written by this run.
STAMP = datetime.today().strftime("%Y%m%d%H%M")

config_file = sys.argv[1]
# load model configs; the context manager closes the file handle afterwards
with open(config_file, "r") as config_handle:
    configs = json.load(config_handle)

# Optional suffix appended to output file names.
log_suffix = configs.get("suffix", "")

# Seed both torch and NumPy from the configured seed.
seed = configs["seed"]
torch.manual_seed(seed)
np.random.seed(seed)

BATCH_SIZE = configs["batch_size"]
NUM_WORKERS = 0
LOG_FREQ = configs["log_freq"]
NUM_EPOCHS = configs["num_of_epochs"]
device = "cuda" if torch.cuda.is_available() else "cpu"
RANDOM_CONTROL = configs.get("random_control", False)

logger = get_logger(config_file, STAMP)
43 |
44 |
def get_setup(genes_to_id, id_to_genes, target_dim, cv=0):
    """Build the DeePathNet model, the loss and the optimizer for one run.

    Reads the module-level ``configs``, ``genes``, ``omics_types``,
    ``tissues``, ``logger``, ``device``, ``RANDOM_CONTROL``, ``STAMP`` and
    ``log_suffix``. When ``RANDOM_CONTROL`` is set, the randomised pathway
    membership is also written to a JSON file in ``configs['work_dir']``.

    Args:
        genes_to_id: Passed through to ``DeePathNet``.
        id_to_genes: Passed through to ``DeePathNet``.
        target_dim: Passed through to ``DeePathNet`` as its second argument.
        cv: Run index; only used in the name of the random-control JSON file.

    Returns:
        Tuple ``(model, criterion, optimizer, lr_scheduler)``; the scheduler
        is always ``None``.
    """

    def load_pathway(random_control=False):
        """Return ``(pathway_dict, non_cancer_genes)`` from the pathway file.

        With ``random_control`` every pathway keeps its size but its members
        are redrawn without replacement from the pooled pathway genes.
        """
        pathway_dict = {}
        pathway_df = pd.read_csv(configs["pathway_file"])
        if "min_cancer_publication" in configs:
            pathway_df = pathway_df[
                pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
            ]
            logger.info(
                f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
            )
        if "max_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber < {configs['max_gene_num']}"
            )
        if "min_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber > {configs['min_gene_num']}"
            )

        # Built once so every membership test below is O(1).
        gene_lookup = set(genes)
        pathway_df["genes"] = pathway_df["genes"].map(
            lambda x: "|".join([gene for gene in x.split("|") if gene in gene_lookup])
        )

        for index, row in pathway_df.iterrows():
            if row["genes"]:
                pathway_dict[row["name"]] = row["genes"].split("|")
        # "".split("|") yields [""]; the empty token is not a gene and must
        # not end up in the random-control pool.
        cancer_genes = {
            gene
            for gene_string in pathway_df["genes"].values
            for gene in gene_string.split("|")
            if gene
        }
        non_cancer_genes = gene_lookup - cancer_genes
        logger.info(
            f"Cancer genes:{len(cancer_genes)}\tNon-cancer genes:{len(non_cancer_genes)}"
        )
        if random_control:
            logger.info("Randomly select genes for each pathway")
            # Set iteration order of strings changes between interpreter runs
            # (hash randomisation); sorting gives a seeded draw a stable pool.
            candidate_genes = sorted(cancer_genes)
            for key in pathway_dict:
                pathway_dict[key] = list(
                    np.random.choice(
                        candidate_genes, len(pathway_dict[key]), replace=False
                    )
                )
        return pathway_dict, non_cancer_genes

    pathway_dict, non_cancer_genes = load_pathway(random_control=RANDOM_CONTROL)
    if RANDOM_CONTROL:
        logger.info("Saving random control genes")
        with open(
            f"{configs['work_dir']}/random_genes_cv{cv}_{STAMP}{log_suffix}.json", "w"
        ) as f:
            json.dump(pathway_dict, f)

    model = DeePathNet(
        len(omics_types),
        target_dim,
        genes_to_id,
        id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        num_heads=configs["heads"],
        pathway_drop_rate=configs["pathway_dropout"],
        only_cancer_genes=configs["cancer_only"],
        tissues=tissues,
    )

    # Log the source of the module that DeePathNet was actually imported
    # from instead of a path that only exists on one machine.
    model_source_path = getattr(sys.modules.get(DeePathNet.__module__), "__file__", None)
    try:
        with open(model_source_path, "r") as source_file:
            logger.info(source_file.read())
    except (OSError, TypeError) as err:
        # Logging the source is informational only; never abort training.
        logger.info(f"Could not read model source {model_source_path}: {err}")

    logger.info(model)
    model = model.to(device)

    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(
        model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
    )

    logger.info(optimizer)

    lr_scheduler = None

    return model, criterion, optimizer, lr_scheduler
133 |
134 |
def run_experiment(
    merged_df_train,
    merged_df_test,
    val_score_dict,
    run="test",
    class_name_to_id=None,
    cv=0,
):
    """Train on the training frame and evaluate on the independent test frame.

    The first ``num_of_features`` columns of both frames are the inputs and
    the remaining columns are the targets. Reads the module-level
    ``configs``, ``num_of_features``, ``omics_types``, ``with_tissue``,
    ``logger``, ``BATCH_SIZE``, ``NUM_WORKERS``, ``NUM_EPOCHS`` and ``STAMP``.
    ``cv`` is forwarded to ``get_setup``.

    Returns:
        Whatever ``train_loop`` returns for this run.
    """
    is_multiclass = configs["task"] == "multiclass"

    features_train = merged_df_train.iloc[:, :num_of_features]
    features_test = merged_df_test.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]
    test_target = merged_df_test.iloc[:, num_of_features:]

    # The two dataset classes share most arguments and differ in one keyword.
    if is_multiclass:
        dataset_cls = MultiOmicMulticlassDataset
        task_kwargs = {"class_name_to_id": class_name_to_id}
    else:
        dataset_cls = MultiOmicDataset
        task_kwargs = {"with_tissue": with_tissue}

    train_dataset = dataset_cls(
        features_train,
        train_target,
        mode="train",
        omics_types=omics_types,
        logger=logger,
        **task_kwargs,
    )
    test_dataset = dataset_cls(
        features_test,
        test_target,
        mode="val",
        omics_types=omics_types,
        logger=logger,
        **task_kwargs,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=configs["drop_last"],
        num_workers=NUM_WORKERS,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
    )

    # One output per class for classification, one per target column otherwise.
    target_dim = len(class_name_to_id) if is_multiclass else train_target.shape[1]

    model, criterion, optimizer, lr_scheduler = get_setup(
        train_dataset.genes_to_id, train_dataset.id_to_genes, target_dim, cv=cv
    )

    # Names of the target columns, handed to the training loop.
    val_drug_ids = merged_df_test.columns[num_of_features:]
    return train_loop(
        NUM_EPOCHS,
        train_loader,
        test_loader,
        model,
        criterion,
        optimizer,
        logger,
        STAMP,
        configs,
        lr_scheduler,
        val_drug_ids,
        run=run,
        val_score_dict=val_score_dict,
    )
224 |
225 |
# Driver: repeat the independent-test experiment `num_repeat` times (once when
# the config does not set it), then export the collected results.
num_repeat = 1 if "num_repeat" not in configs else configs["num_repeat"]
# `count` mirrors the loop index; it is used as the data-preparation seed, the
# run label suffix and the `cv` value handed to run_experiment.
count = 0
all_val_df = []
val_score_dict = get_score_dict(config_file)
for n in range(num_repeat):
    data_dict = prepare_data_independent_test(config_file, STAMP, seed=count)
    data_input_train = data_dict["data_input_train"]
    data_target_train = data_dict["data_target_train"]
    data_input_test = data_dict["data_input_test"]
    data_target_test = data_dict["data_target_test"]
    # The names below are read as module-level globals inside run_experiment.
    num_of_features = data_dict["num_of_features"]
    genes = data_dict["genes"]
    omics_types = data_dict["omics_types"]
    with_tissue = data_dict["with_tissue"]
    tissues = data_dict["tissues"]

    class_name_to_id = None
    id_to_class_name = None
    if configs["task"] == "multiclass":
        # Class ids follow the sorted order of the labels found in the first
        # training-target column; the second dict is the inverse mapping.
        class_name_to_id = dict(
            zip(
                sorted(data_target_train.iloc[:, 0].unique()),
                list(range(data_target_train.iloc[:, 0].unique().size)),
            )
        )
        id_to_class_name = dict(
            zip(
                list(range(data_target_train.iloc[:, 0].unique().size)),
                sorted(data_target_train.iloc[:, 0].unique()),
            )
        )

    # Join inputs and targets on the cell-line key so rows stay aligned.
    merged_df_train = pd.merge(data_input_train, data_target_train, on=["Cell_line"])
    merged_df_test = pd.merge(data_input_test, data_target_test, on=["Cell_line"])

    val_res = run_experiment(
        merged_df_train,
        merged_df_test,
        val_score_dict,
        run=f"cv_{count}",
        class_name_to_id=class_name_to_id,
        cv=count,
    )
    all_val_df.append(val_res)
    count += 1

# Scores are written unless the config explicitly sets `save_scores` to a falsy value.
if "save_scores" not in configs or configs["save_scores"]:
    val_score_df = pd.DataFrame(val_score_dict)
    val_score_df.to_csv(
        f"{configs['work_dir']}/scores_{STAMP}{log_suffix}.csv.gz", index=False
    )
    # NOTE(review): this branch is reconstructed as nested under the save_scores
    # check; confirm against the original script's indentation.
    if configs["task"] == "multiclass":
        all_val_df = pd.concat(all_val_df)
        # Translate numeric class ids back to class names; this uses the
        # mapping built during the last repeat.
        all_val_df["y_pred"] = all_val_df["y_pred"].map(id_to_class_name)
        all_val_df["y_true"] = all_val_df["y_true"].map(id_to_class_name)
        # Columns containing "feature_" end in a class id; rename them to the class name.
        all_val_df.columns = [
            id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x
            for x in all_val_df.columns
        ]
        all_val_df.to_csv(
            f"{configs['work_dir']}/all_val_res_{STAMP}{log_suffix}.csv.gz", index=False
        )

logger.info("Full training finished.")
290 |
--------------------------------------------------------------------------------
/scripts/model_transformer_lrp.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import warnings
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 | from inspect import isfunction
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 | from utils.layers_ours import *
12 |
13 | np.random.seed(12345)
14 |
15 |
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Chain per-layer attention matrices into a single rollout matrix.

    The identity is added to every layer matrix to account for the residual
    connection, then the matrices from ``start_layer`` onwards are multiplied
    together with later layers applied on the left.

    Args:
        all_layer_matrices: Sequence of batched square matrices; batch size and
            token count are read from the first entry.
        start_layer: Index of the first layer included in the product.

    Returns:
        The batched product of the residual-augmented matrices.
    """
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]
    identity = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(reference.device)

    # adding residual consideration
    with_residual = [layer + identity for layer in all_layer_matrices]

    joint_attention = with_residual[start_layer]
    for layer_idx in range(start_layer + 1, len(with_residual)):
        joint_attention = with_residual[layer_idx].bmm(joint_attention)
    return joint_attention
28 |
29 |
class Mlp(nn.Module):
    """Two-layer feed-forward network with a relevance-propagation pass.

    The forward pass is ``fc1 -> activation -> dropout -> fc2 -> dropout``,
    using the same dropout module at both positions.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output; falls back to ``in_features`` when
            falsy.
        drop: Dropout rate handed to ``Dropout``.
        activation: ``'GELU'`` or ``'ReLU'``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., activation='GELU'):
        super().__init__()
        # Fail fast with a descriptive error before any layer is allocated.
        if activation not in ('GELU', 'ReLU'):
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'GELU' or 'ReLU'.")
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Linear(in_features, hidden_features)
        self.act = GELU() if activation == 'GELU' else ReLU()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = Dropout(drop)

    def forward(self, x):
        """Apply both linear layers with the activation and dropout in between."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

    def relprop(self, cam, **kwargs):
        """Propagate relevance ``cam`` back through dropout, fc2, activation and fc1."""
        cam = self.drop.relprop(cam, **kwargs)
        cam = self.fc2.relprop(cam, **kwargs)
        cam = self.act.relprop(cam, **kwargs)
        cam = self.fc1.relprop(cam, **kwargs)
        return cam
59 |
60 |
class Attention(nn.Module):
    """Multi-head self-attention that records what relevance propagation needs.

    During ``forward`` the value tensor and the attention map are stored on the
    module, and a hook captures the gradient of the attention map when the
    input requires grad. ``relprop`` walks the same operations in reverse and
    stores the relevance of the attention map and of the values.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = (dim // num_heads) ** -0.5

        # A = Q*K^T
        self.matmul1 = einsum('bhid,bhjd->bhij')
        # attn = A*V
        self.matmul2 = einsum('bhij,bhjd->bhid')

        self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = Dropout(attn_drop)
        self.proj = Linear(dim, dim)
        self.proj_drop = Dropout(proj_drop)
        self.softmax = Softmax(dim=-1)

        # Slots filled during forward / relprop and read back by the model-level relprop.
        self.attn_cam = None
        self.attn = None
        self.v = None
        self.v_cam = None
        self.attn_gradients = None

    def get_attn(self):
        """Return the attention map saved by the last forward pass."""
        return self.attn

    def save_attn(self, attn):
        """Store the attention map."""
        self.attn = attn

    def save_attn_cam(self, cam):
        """Store the relevance assigned to the attention map."""
        self.attn_cam = cam

    def get_attn_cam(self):
        """Return the relevance assigned to the attention map."""
        return self.attn_cam

    def get_v(self):
        """Return the value tensor saved by the last forward pass."""
        return self.v

    def save_v(self, v):
        """Store the value tensor."""
        self.v = v

    def save_v_cam(self, cam):
        """Store the relevance assigned to the values."""
        self.v_cam = cam

    def get_v_cam(self):
        """Return the relevance assigned to the values."""
        return self.v_cam

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention map (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradient of the attention map captured by the hook."""
        return self.attn_gradients

    def forward(self, x):
        """Compute self-attention over ``x`` (unpacked as three dimensions)."""
        _batch, _tokens, _width = x.shape
        heads = self.num_heads
        projected = self.qkv(x)
        q, k, v = rearrange(projected, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=heads)

        self.save_v(v)

        scores = self.matmul1([q, k]) * self.scale
        attn = self.attn_drop(self.softmax(scores))

        self.save_attn(attn)

        # Capture d(output)/d(attn) for gradient-weighted attribution.
        if x.requires_grad:
            attn.register_hook(self.save_attn_gradients)

        context = self.matmul2([attn, v])
        context = rearrange(context, 'b h n d -> b n (h d)')

        return self.proj_drop(self.proj(context))

    def relprop(self, cam, **kwargs):
        """Propagate relevance ``cam`` back to the input of the qkv projection."""
        cam = self.proj_drop.relprop(cam, **kwargs)
        cam = self.proj.relprop(cam, **kwargs)
        cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)

        # attn = A*V: the relevance of each operand is halved in place.
        (cam_attn, cam_v) = self.matmul2.relprop(cam, **kwargs)
        cam_attn /= 2
        cam_v /= 2

        self.save_v_cam(cam_v)
        self.save_attn_cam(cam_attn)

        cam_attn = self.attn_drop.relprop(cam_attn, **kwargs)
        cam_attn = self.softmax.relprop(cam_attn, **kwargs)

        # A = Q*K^T: same halving for the query and key relevance.
        (cam_q, cam_k) = self.matmul1.relprop(cam_attn, **kwargs)
        cam_q /= 2
        cam_k /= 2

        cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)

        return self.qkv.relprop(cam_qkv, **kwargs)
164 |
165 |
class Block(nn.Module):
    """Transformer block: normalised attention and MLP branches, each added to a skip path.

    Cloning and addition go through the ``Clone`` / ``Add`` layers so that
    ``relprop`` can send relevance back through both paths.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

        self.add1 = Add()
        self.add2 = Add()
        self.clone1 = Clone()
        self.clone2 = Clone()

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a skip connection."""
        skip, branch = self.clone1(x, 2)
        attended = self.attn(self.norm1(branch))
        x = self.add1([skip, attended])

        skip, branch = self.clone2(x, 2)
        transformed = self.mlp(self.norm2(branch))
        return self.add2([skip, transformed])

    def relprop(self, cam, **kwargs):
        """Propagate relevance ``cam`` back through the MLP sub-layer, then the attention sub-layer."""
        # MLP sub-layer, in reverse order of the forward pass.
        (skip_cam, branch_cam) = self.add2.relprop(cam, **kwargs)
        branch_cam = self.mlp.relprop(branch_cam, **kwargs)
        branch_cam = self.norm2.relprop(branch_cam, **kwargs)
        cam = self.clone2.relprop((skip_cam, branch_cam), **kwargs)

        # Attention sub-layer.
        (skip_cam, branch_cam) = self.add1.relprop(cam, **kwargs)
        branch_cam = self.attn.relprop(branch_cam, **kwargs)
        branch_cam = self.norm1.relprop(branch_cam, **kwargs)
        return self.clone1.relprop((skip_cam, branch_cam), **kwargs)
200 |
201 |
class LRP:
    """Run a forward/backward pass on a model and then its ``relprop`` method.

    The wrapped model is switched to evaluation mode on construction.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    def _model_device(self):
        """Return the device of the model's first parameter.

        Falls back to CUDA when available (the previously hard-coded target),
        otherwise CPU, for models that expose no parameters.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def generate_LRP(self, data, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
        """Compute relevance for output position(s) ``index`` of one batch.

        Args:
            data: ``(input, targets)`` or ``(input, tissue_x, targets)``; the
                targets are not used here.
            index: Output position(s) to explain; required.
            method: Attribution method forwarded to ``model.relprop``.
            is_ablation: Forwarded to ``model.relprop``.
            start_layer: Forwarded to ``model.relprop``.

        Returns:
            Whatever ``model.relprop`` returns.

        Raises:
            AssertionError: If ``index`` is ``None``.
            ValueError: If ``data`` does not hold 2 or 3 items.
        """
        assert index is not None, "index of the output to explain must be provided"

        # Inputs follow the model instead of assuming a GPU is present.
        device = self._model_device()
        if len(data) == 2:
            (features, _targets) = data
            output = self.model(features.float().to(device))
        elif len(data) == 3:
            (features, tissue_x, _targets) = data
            output = self.model(features.float().to(device), tissue_x.float().to(device))
        else:
            raise ValueError(f"data must hold 2 or 3 items, got {len(data)}")

        kwargs = {"alpha": 1}

        # One-hot selector over the last output dimension.
        one_hot_vector = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        one_hot = torch.from_numpy(one_hot_vector).requires_grad_(True)
        selected = torch.sum(one_hot.to(device) * output)

        # The backward pass populates the gradient hooks read by relprop.
        self.model.zero_grad()
        selected.backward(retain_graph=True)

        return self.model.relprop(torch.tensor(one_hot_vector).to(device), method=method, is_ablation=is_ablation,
                                  start_layer=start_layer, **kwargs)
232 |
233 |
234 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
235 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
236 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
237 | def norm_cdf(x):
238 | # Computes standard normal cumulative distribution function
239 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
240 |
241 | if (mean < a - 2 * std) or (mean > b + 2 * std):
242 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
243 | "The distribution of values may be incorrect.",
244 | stacklevel=2)
245 |
246 | with torch.no_grad():
247 | # Values are generated by using a truncated uniform distribution and
248 | # then using the inverse CDF for the normal distribution.
249 | # Get upper and lower cdf values
250 | l = norm_cdf((a - mean) / std)
251 | u = norm_cdf((b - mean) / std)
252 |
253 | # Uniformly fill tensor with values from [l, u], then translate to
254 | # [2l-1, 2u-1].
255 | tensor.uniform_(2 * l - 1, 2 * u - 1)
256 |
257 | # Use inverse cdf transform for normal distribution to get truncated
258 | # standard normal
259 | tensor.erfinv_()
260 |
261 | # Transform to proper mean, std
262 | tensor.mul_(std * math.sqrt(2.))
263 | tensor.add_(mean)
264 |
265 | # Clamp to ensure it's in the proper range
266 | tensor.clamp_(min=a, max=b)
267 | return tensor
268 |
269 |
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place from a normal distribution truncated to ``[a, b]``.

    Thin public wrapper around ``_no_grad_trunc_normal_``.

    Args:
        tensor: The ``torch.Tensor`` to fill.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower cut-off value.
        b: Upper cut-off value.

    Returns:
        The value returned by ``_no_grad_trunc_normal_``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
289 |
290 |
class PathwayDrop(nn.Module):
    """Zero out whole tokens at random while training.

    Each position along dimension 1 is dropped with probability ``p`` using a
    draw from NumPy's global RNG; a dropped position is zeroed for every
    sample in the batch. The input tensor is modified in place and returned.
    Outside training the input is returned untouched.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training:
            return x

        for token_idx in range(x.shape[1]):
            dropped = np.random.binomial(1, self.p)
            if dropped:
                x[:, token_idx, :] = 0
        return x
305 |
306 |
class DeePathNet(nn.Module):
    """Transformer over pathway tokens built from multi-omic gene features.

    Each pathway in ``pathway_dict`` gets its own linear layer that maps the
    flattened features of its genes to one ``embed_dim`` token. Unless
    ``only_cancer_genes`` is set, one extra token is built from
    ``non_cancer_genes``; when ``tissues`` is given, one more token is built
    from the tissue input. A class token is prepended, a learned embedding is
    added, the sequence goes through ``depth`` blocks, and the head reads the
    token at position 0.

    Args:
        num_omics: Number of omic values per gene.
        out_dim: Output width of the prediction head.
        gene_to_id: Mapping from gene name to its index along dimension 1 of
            the input.
        id_to_gene: Inverse of ``gene_to_id``; stored but not used in this class.
        pathway_dict: Mapping from pathway name to the genes it contains.
        non_cancer_genes: Genes gathered into the extra token.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        out_mlp_ratio: Hidden-width multiplier of the head; a plain linear
            head is used when it is not greater than 1.
        qkv_bias: Whether the qkv projection has a bias.
        drop_rate: Dropout rate inside the blocks.
        attn_drop_rate: Dropout rate on the attention map.
        pathway_drop_rate: Token drop probability of ``PathwayDrop``.
        only_cancer_genes: When true, the non-cancer token is omitted.
        tissues: Optional collection of tissue names; its length sets the
            input width of the tissue embedding.
    """

    def __init__(self, num_omics, out_dim, gene_to_id, id_to_gene, pathway_dict, non_cancer_genes, embed_dim=2048,
                 depth=2,
                 num_heads=1, mlp_ratio=2, out_mlp_ratio=4, qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
                 pathway_drop_rate=0,
                 only_cancer_genes=False, tissues=None):
        super().__init__()
        self.pathway_dict = pathway_dict
        self.pathway_layers = nn.ModuleDict()
        self.non_cancer_genes = sorted(non_cancer_genes)
        self.gene_to_id = gene_to_id
        self.id_to_gene = id_to_gene
        self.num_omics = num_omics
        self.only_cancer_genes = only_cancer_genes
        self.out_dim = out_dim
        self.tissues = tissues

        # One encoder per pathway: flattened (genes x omics) -> embed_dim.
        for key in self.pathway_dict:
            num_genes_in_pathway = len(self.pathway_dict[key])
            pathway_width = embed_dim
            self.pathway_layers[key] = nn.Linear(num_genes_in_pathway * num_omics, pathway_width)

        self.non_cancer_layer = nn.Linear(len(non_cancer_genes) * num_omics,
                                          embed_dim) if not only_cancer_genes else None

        # Extra tokens besides the pathways: class token, plus the optional
        # non-cancer and tissue tokens.
        pathway_embedding_num = 1
        if not only_cancer_genes:
            pathway_embedding_num += 1
        if self.tissues:
            pathway_embedding_num += 1

        self.pathway_embedding = nn.Parameter(torch.randn(1, len(pathway_dict) + pathway_embedding_num, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        self.tissue_embedding = nn.Linear(len(self.tissues), embed_dim) if self.tissues else None

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate)
            for i in range(depth)])

        self.norm = LayerNorm(embed_dim)
        self.pathway_drop = PathwayDrop(pathway_drop_rate)

        self.head = Mlp(embed_dim, int(embed_dim * out_mlp_ratio), out_dim,
                        activation="ReLU") if out_mlp_ratio > 1 else Linear(embed_dim, out_dim)

        self.pool = IndexSelect()
        self.add = Add()

        self.inp_grad = None

    def save_inp_grad(self, grad):
        """Store the gradient of the token sequence (used as a tensor hook)."""
        self.inp_grad = grad

    def get_inp_grad(self):
        """Return the gradient captured by ``save_inp_grad``."""
        return self.inp_grad

    def _init_weights(self, m):
        """Initialiser for ``nn.Linear`` / ``nn.LayerNorm`` modules (not applied in ``__init__``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @property
    def no_weight_decay(self):
        """Names of the parameters to exclude from weight decay."""
        return {'pathway_embedding', 'cls_token'}

    def forward(self, x, tissue_x=None, return_embedding=False):
        """Encode the input into tokens and predict from the class token.

        Args:
            x: Gene features indexed as ``x[:, gene_ids, :]``.
            tissue_x: Tissue input; required when the model was built with
                ``tissues``.
            return_embedding: When true, return the normalised token sequence
                instead of the head output.

        Returns:
            The head output, or the token sequence if ``return_embedding``.
        """
        B = x.shape[0]
        pathway_x = []
        for key in self.pathway_dict:
            gene_ids = [self.gene_to_id[gene] for gene in self.pathway_dict[key]]
            tmp = self.pathway_layers[key](
                x[:, gene_ids, :].reshape(-1, len(gene_ids) * self.num_omics))
            pathway_x.append(tmp)  # shape:[b,g,3] -> [b,g]
        pathway_x = torch.stack(pathway_x, dim=1)  # [b, p, p_embed]
        if self.only_cancer_genes:
            x = pathway_x
        else:
            non_cancer_gene_ids = sorted([self.gene_to_id[gene] for gene in self.non_cancer_genes])
            non_cancer_gene_x = self.non_cancer_layer(
                x[:, non_cancer_gene_ids, :].reshape(-1, len(non_cancer_gene_ids) * self.num_omics))
            non_cancer_gene_x = non_cancer_gene_x.unsqueeze(1)
            x = torch.cat([pathway_x, non_cancer_gene_x], dim=1)  # [b, p+1, p_embed]

        # Compare against None explicitly instead of relying on module truthiness.
        if self.tissue_embedding is not None:
            tissue_x = self.tissue_embedding(tissue_x)
            x = torch.cat([x, tissue_x.unsqueeze(1)], dim=1)
        x = self.pathway_drop(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.add([x, self.pathway_embedding])

        if x.requires_grad:
            x.register_hook(self.save_inp_grad)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        if return_embedding:
            return x
        # Read out the class token at position 0.
        x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
        x = x.squeeze(1)

        x = self.head(x)
        return x

    def relprop(self, cam=None, method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
        """Propagate relevance back through the model and return token attributions.

        Args:
            cam: Relevance at the model output.
            method: One of ``"rollout"``, ``"transformer_attribution"`` (alias
                ``"grad"``), ``"last_layer"``, ``"last_layer_attn"``,
                ``"second_layer"``, or ``"full"``. ``"full"`` needs a
                ``patch_embed`` module, which this class does not create. Any
                other value falls through and returns ``None``.
            is_ablation: For the ``"last_layer"`` / ``"second_layer"`` methods,
                multiply the attention relevance by its gradient.
            start_layer: First block included in the rollout.
            **kwargs: Forwarded to every layer's ``relprop``.

        Returns:
            The attribution tensor of the chosen method, with the class-token
            column removed for the attention-based methods.

        Raises:
            NotImplementedError: If ``method == "full"`` and the model has no
                ``patch_embed`` attribute.
        """
        # This class never defines `patch_embed`; report that clearly instead
        # of failing with an AttributeError after the whole backward walk.
        if method == "full" and not hasattr(self, "patch_embed"):
            raise NotImplementedError(
                "relprop(method='full') requires a `patch_embed` module, which DeePathNet does not define")

        cam = self.head.relprop(cam, **kwargs)
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)
        cam = self.norm.relprop(cam, **kwargs)
        for blk in reversed(self.blocks):
            cam = blk.relprop(cam, **kwargs)

        if method == "full":
            (cam, _) = self.add.relprop(cam, **kwargs)
            cam = cam[:, 1:]
            cam = self.patch_embed.relprop(cam, **kwargs)
            # sum on channels
            cam = cam.sum(dim=1)
            return cam

        elif method == "rollout":
            # cam rollout
            attn_cams = []
            for blk in self.blocks:
                attn_heads = blk.attn.get_attn_cam().clamp(min=0)
                avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
                attn_cams.append(avg_heads)
            cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
            cam = cam[:, 0, 1:]
            return cam

        # our method, method name grad is legacy
        elif method == "transformer_attribution" or method == "grad":
            cams = []
            for blk in self.blocks:
                grad = blk.attn.get_attn_gradients()
                cam = blk.attn.get_attn_cam()
                # Only the first sample of the batch is used.
                cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
                cam = cam.clamp(min=0).mean(dim=0)
                cams.append(cam.unsqueeze(0))
            rollout = compute_rollout_attention(cams, start_layer=start_layer)
            cam = rollout[:, 0, 1:]
            return cam

        elif method == "last_layer":
            cam = self.blocks[-1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[-1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "last_layer_attn":
            cam = self.blocks[-1].attn.get_attn()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam

        elif method == "second_layer":
            cam = self.blocks[1].attn.get_attn_cam()
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            if is_ablation:
                grad = self.blocks[1].attn.get_attn_gradients()
                grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
                cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cam = cam[0, 1:]
            return cam
503 |
--------------------------------------------------------------------------------
/scripts/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch.nn import Linear, LayerNorm, ReLU
3 | import pandas as pd
4 | import math
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torch.nn as nn
8 | import numpy as np
9 |
10 | from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, roc_auc_score, f1_score
11 |
12 | from tqdm import trange
13 | from scipy.stats import pearsonr
14 | import time
15 | from scipy.special import softmax
16 |
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sentinel that replaces NaN targets; entries equal to it are masked out of the metrics.
MISSING_NUM = -100
19 |
20 |
def logistic(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))`` of a tensor."""
    denominator = 1 + torch.exp(torch.neg(x))
    return 1 / denominator
23 |
24 |
def corr_loss(output, target):
    """Return ``50 * (1 - r)`` where ``r`` is the Pearson correlation of the two tensors.

    Both tensors are centred on their overall mean and correlated over all
    elements at once, so the loss ranges from 0 (r = 1) to 100 (r = -1).
    """
    centered_output = output - output.mean()
    centered_target = target - target.mean()

    covariance = (centered_output * centered_target).sum()
    output_norm = torch.sqrt((centered_output ** 2).sum())
    target_norm = torch.sqrt((centered_target ** 2).sum())

    return 50 * (1 - covariance / (output_norm * target_norm))
34 |
35 |
class OmicLinearLayer(nn.Module):
    """Per-feature weighted sum over the omics axis, plus a per-feature bias.

    The layer holds a ``(num_omics, size_in)`` weight and a ``(size_in,)``
    bias; ``forward`` multiplies the input by the transposed weight and sums
    over dimension 2.
    """

    def __init__(self, size_in, num_omics):
        super().__init__()
        self.size_in = size_in
        self.num_omics = num_omics
        self.weights = nn.Parameter(torch.Tensor(num_omics, size_in))
        self.bias = nn.Parameter(torch.Tensor(size_in))

        # Same initialisation recipe as a standard linear layer: Kaiming-uniform
        # weights, then a uniform bias bounded by 1/sqrt(fan_in).
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        weighted = x * self.weights.t()
        return weighted.sum(dim=2) + self.bias
56 |
57 |
class SingleOmicDataset(Dataset):
    """Dataset over one feature matrix and its targets.

    NaN features are replaced by 0 and NaN targets by ``MISSING_NUM``.
    """

    def __init__(self, data_df, purpose_data_df, mode, logger=None):
        assert mode in ['train', 'val', 'test']

        self.df = np.nan_to_num(data_df, nan=0)
        self.purpose_data = np.nan_to_num(purpose_data_df, nan=MISSING_NUM)

        n_samples, n_targets = self.df.shape[0], self.purpose_data.shape[0]
        assert n_samples == n_targets, f"{self.df.shape[0]}, {self.purpose_data.shape[0]}"
        self.mode = mode
        if logger:
            logger.info(f"mode: {mode}, df shape: {self.df.shape}, purpose_data shape: {self.purpose_data.shape}")

    def __getitem__(self, index):
        """ Returns: tuple (sample, target) """
        sample = self.df[index, :]
        # Targets may be a 2-D table (one row per sample) or a flat vector.
        has_target_columns = len(self.purpose_data.shape) > 1
        target = self.purpose_data[index, :] if has_target_columns else self.purpose_data[index]
        return sample, target

    def __len__(self):
        return self.df.shape[0]
84 |
85 |
def get_multiomic_df(df, omics_types):
    """Reshape a flat ``<gene>_<omic>`` table into a (samples, genes, omics) array.

    Columns whose name contains ``'tissue'`` are returned separately. Gene
    names are the text before the first underscore of the remaining columns,
    in sorted order. Gene/omic combinations missing from ``df`` are filled
    with zeros.

    ``'tissue'`` entries in ``omics_types`` are ignored for the gene array:
    the output has always been sized without them, but the per-gene fill used
    to include them and failed with a shape mismatch.

    Args:
        df: DataFrame with ``<gene>_<omic>`` columns and optional tissue columns.
        omics_types: Omic names, in the order they should appear on the last axis.

    Returns:
        Tuple ``(array, genes_to_id, id_to_genes, tissue_values)`` where
        ``array`` has shape ``(len(df), n_genes, n_omics)``, the two dicts map
        gene name <-> position on axis 1, and ``tissue_values`` holds the
        values of the tissue columns.
    """
    gene_columns = [col for col in df.columns if 'tissue' not in col]
    tissue_df = df[[col for col in df.columns if 'tissue' in col]].values
    genes = np.unique([col.split("_")[0] for col in gene_columns])

    # Tissue is not a per-gene measurement, so it never gets an omic slot.
    gene_omics = [omic for omic in omics_types if omic != 'tissue']

    # Gene-major, omic-minor column order, so a plain reshape yields
    # (samples, genes, omics). Missing combinations become zero columns.
    ordered_columns = [f"{gene}_{omic}" for gene in genes for omic in gene_omics]
    flat = df.reindex(columns=ordered_columns, fill_value=0.0).to_numpy(dtype=float)
    df_multiomic = flat.reshape(df.shape[0], len(genes), len(gene_omics))

    genes_to_id = dict(zip(genes, range(len(genes))))
    id_to_genes = dict(zip(range(len(genes)), genes))
    return df_multiomic, genes_to_id, id_to_genes, tissue_df
103 |
104 |
class MultiOmicDataset(Dataset):
    """Dataset of per-gene multi-omic arrays with optional tissue features.

    The flat feature table is converted with ``get_multiomic_df``; NaN targets
    are replaced by ``MISSING_NUM``.
    """

    def __init__(self, df, purpose_data_df, mode, omics_types, logger=None, with_tissue=False):
        assert mode in ['train', 'val', 'test']
        multiomic, genes_to_id, id_to_genes, tissue_values = get_multiomic_df(df, omics_types)
        self.df = multiomic
        self.genes_to_id = genes_to_id
        self.id_to_genes = id_to_genes
        self.tissue_df = tissue_values
        self.omics_types = omics_types
        self.purpose_data = np.nan_to_num(purpose_data_df, nan=MISSING_NUM)
        self.with_tissue = with_tissue

        n_samples, n_targets = self.df.shape[0], self.purpose_data.shape[0]
        assert n_samples == n_targets, f"{self.df.shape[0]}, {self.purpose_data.shape[0]}"
        self.mode = mode
        if logger:
            logger.info(f"mode: {mode}, df shape: {self.df.shape}, purpose_data shape: {self.purpose_data.shape}")

    def __getitem__(self, index):
        """ Returns: (sample, target), or (sample, tissue, target) when with_tissue is set """
        sample = self.df[index, :, :]  # (b, genes, omics)
        tissue = self.tissue_df[index, :]
        target = self.purpose_data[index, :]
        if not self.with_tissue:
            return sample, target
        return sample, tissue, target

    def __len__(self):
        return self.df.shape[0]
130 |
131 |
class MultiOmicMulticlassDataset(Dataset):
    """Dataset of per-gene multi-omic arrays with class-label targets.

    Features are converted with ``get_multiomic_df`` and NaNs replaced by 0;
    each target entry is translated to its id through ``class_name_to_id``
    when a sample is fetched.
    """

    def __init__(self, df, purpose_data_df, mode, omics_types, class_name_to_id, logger=None):
        assert mode in ['train', 'val', 'test']
        multiomic, genes_to_id, id_to_genes, tissue_values = get_multiomic_df(df, omics_types)
        self.genes_to_id = genes_to_id
        self.id_to_genes = id_to_genes
        self.tissue_df = tissue_values
        self.df = np.nan_to_num(multiomic, nan=0)
        self.omics_types = omics_types
        self.purpose_data = np.nan_to_num(purpose_data_df, nan=MISSING_NUM)
        self.class_name_to_id = class_name_to_id

        n_samples, n_targets = self.df.shape[0], self.purpose_data.shape[0]
        assert n_samples == n_targets, f"{self.df.shape[0]}, {self.purpose_data.shape[0]}"
        self.mode = mode
        if logger:
            logger.info(f"mode: {mode}, df shape: {self.df.shape}, purpose_data shape: {self.purpose_data.shape}")
            logger.info(self.class_name_to_id)

    def __getitem__(self, index):
        """ Returns: tuple (sample, target ids) """
        sample = self.df[index, :, :]  # (b, genes, omics)
        labels = self.purpose_data[index, :]
        label_ids = np.array([self.class_name_to_id[label] for label in labels])
        return sample, label_ids

    def __len__(self):
        return self.df.shape[0]
156 |
157 |
class AverageMeter:
    """Track the latest value, weighted sum, count and running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
175 |
176 |
def train(train_loader, model, criterion, optimizer, epoch, logger):
    """Run one training epoch and return the running mean of the per-batch median R2.

    For every batch the model output is overwritten with ``MISSING_NUM`` where
    the target is missing, so those positions do not contribute to the loss.
    Per-target R2 / MAE / RMSE / Pearson correlation are computed on observed
    entries only and their medians are tracked; a batch whose metrics raise
    ``ValueError`` is skipped for scoring but still trained on.

    Args:
        train_loader: Iterable yielding ``(input, targets)`` or
            ``(input, tissue_x, targets)`` batches.
        model: Model to train; called with the input (and tissue input).
        criterion: Loss taking ``(output, targets)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the log line.
        logger: Logger receiving the epoch summary.

    Returns:
        Mean over batches of the median per-target R2.

    Raises:
        ValueError: If a batch does not hold 2 or 3 items.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    avg_r2 = AverageMeter()
    avg_mae = AverageMeter()
    avg_rmse = AverageMeter()
    avg_corr = AverageMeter()

    model.train()

    end = time.time()
    lr_str = ''

    for data in train_loader:

        if len(data) == 2:
            (input_, targets) = data
            output = model(input_.float().to(device))
        elif len(data) == 3:
            (input_, tissue_x, targets) = data
            output = model(input_.float().to(device), tissue_x.float().to(device))
        else:
            raise ValueError(f"expected a batch of 2 or 3 items, got {len(data)}")

        # Make missing targets contribute zero error to the loss.
        output[targets == MISSING_NUM] = MISSING_NUM

        loss = criterion(output, targets.float().to(device))
        targets = targets.cpu().numpy()

        confs = output.detach().cpu().numpy()
        if not np.isinf(confs).any() and not np.isnan(confs).any():
            try:
                # Observed (target, prediction) pairs for every target column.
                observed = []
                for col in range(confs.shape[1]):
                    mask = targets[:, col] != MISSING_NUM
                    observed.append((targets[mask, col], confs[mask, col]))

                avg_r2.update(np.median([r2_score(t, p) for t, p in observed]))
                avg_mae.update(np.median([mean_absolute_error(t, p) for t, p in observed]))
                # Root of the MSE, so the value matches its "RMSE" label and validate().
                avg_rmse.update(np.median([np.sqrt(mean_squared_error(t, p)) for t, p in observed]))
                # Median over all targets, not just the first one.
                avg_corr.update(np.median([pearsonr(t, p)[0] for t, p in observed]))
            except ValueError:
                logger.info("skipping training score")

        losses.update(loss.item(), input_.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

    logger.info(f'{epoch} \t'
                f'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'corr {avg_corr.val:.4f} ({avg_corr.avg:.4f})\t'
                f'R2 {avg_r2.val:.4f} ({avg_r2.avg:.4f})\t'
                f'MAE {avg_mae.val:.4f} ({avg_mae.avg:.4f})\t'
                f'RMSE {avg_rmse.val:.4f} ({avg_rmse.avg:.4f})\t' + lr_str)

    return avg_r2.avg
244 |
245 |
def inference(data_loader, model):
    """Run the model over a loader and collect predictions and targets.

    The model is put in evaluation mode and gradients are disabled.

    Args:
        data_loader: Iterable yielding ``(input, target)`` or
            ``(input, tissue_type, target)`` batches.
        model: Model called with the input (and tissue input).

    Returns:
        Tuple ``(confs, targets)`` of NumPy arrays concatenated over batches;
        ``targets`` is ``None`` when no batch carried a target.

    Raises:
        ValueError: If a batch does not hold 2 or 3 items.
    """
    model.eval()

    all_confs, all_targets = [], []
    with torch.no_grad():
        for data in data_loader:
            if len(data) == 2:
                (input_, target) = data
                output = model(input_.float().to(device))
            elif len(data) == 3:
                (input_, tissue_type, target) = data
                output = model(input_.float().to(device), tissue_type.float().to(device))
            else:
                raise ValueError(f"expected a batch of 2 or 3 items, got {len(data)}")

            all_confs.append(output)

            if target is not None:
                all_targets.append(target)

    confs = torch.cat(all_confs).cpu().numpy()
    # Only convert targets when some were collected; otherwise report None
    # instead of calling .cpu() on None.
    targets = torch.cat(all_targets).cpu().numpy() if all_targets else None

    return confs, targets
274 |
275 |
def validate(val_loader, model, val_drug_ids, run=None, epoch=None, val_score_dict=None):
    """Score the model on the validation loader, one metric value per target.

    Metrics are computed per target column on entries whose target is not
    ``MISSING_NUM``. When ``val_score_dict`` is given, the per-target values
    are appended to it together with the target id, ``run`` and ``epoch``;
    nothing is appended if any metric fails, so the columns stay aligned.

    Args:
        val_loader: Loader passed to ``inference``.
        model: Model to evaluate.
        val_drug_ids: Identifiers of the target columns, in column order.
        run: Run label recorded with the scores.
        epoch: Epoch number recorded with the scores.
        val_score_dict: Optional dict of lists with a ``'drug_id'`` or
            ``'Gene'`` key plus ``'run'``, ``'epoch'``, ``'mae'``, ``'rmse'``,
            ``'corr'`` and ``'r2'``.

    Returns:
        Tuple ``(r2, mae, rmse, corr)`` of medians over targets, or four
        ``None`` values when the predictions contain inf or NaN.

    Raises:
        ValueError: If ``val_score_dict`` has neither a ``'drug_id'`` nor a
            ``'Gene'`` key.
    """
    confs, targets = inference(val_loader, model)

    r2_avg, mae_avg, rmse_avg, corr_avg = None, None, None, None
    if np.isinf(confs).any() or np.isnan(confs).any():
        return r2_avg, mae_avg, rmse_avg, corr_avg

    # Resolve the id column before doing any work so a bad dict fails early.
    id_key = None
    if val_score_dict is not None:
        if 'drug_id' in val_score_dict:
            id_key = 'drug_id'
        elif 'Gene' in val_score_dict:
            id_key = 'Gene'
        else:
            raise ValueError("val_score_dict must contain a 'drug_id' or 'Gene' key")

    # Observed (target, prediction) pairs for every target column.
    observed = []
    for col in range(confs.shape[1]):
        mask = targets[:, col] != MISSING_NUM
        observed.append((targets[mask, col], confs[mask, col]))

    r2 = [r2_score(t, p) for t, p in observed]
    r2_avg = np.median(r2)

    mae = [mean_absolute_error(t, p) for t, p in observed]
    mae_avg = np.median(mae)

    # sqrt(MSE) instead of `squared=False`, which recent scikit-learn no longer accepts.
    rmse = [np.sqrt(mean_squared_error(t, p)) for t, p in observed]
    rmse_avg = np.median(rmse)

    corr = [pearsonr(t, p)[0] for t, p in observed]
    corr_avg = np.median(corr)

    # Record everything only after all metrics succeeded.
    if val_score_dict is not None:
        val_score_dict[id_key].extend(val_drug_ids)
        val_score_dict['run'].extend([run] * len(val_drug_ids))
        val_score_dict['epoch'].extend([epoch] * len(val_drug_ids))
        val_score_dict['mae'].extend(mae)
        val_score_dict['rmse'].extend(rmse)
        val_score_dict['corr'].extend(corr)
        val_score_dict['r2'].extend(r2)

    return r2_avg, mae_avg, rmse_avg, corr_avg
313 |
314 |
def get_model_filename(drug_id):
    """Turn a drug identifier into a string usable in a checkpoint file name.

    ``;``, ``/``, ``+`` and ``,`` become underscores; spaces and parentheses
    are removed.
    """
    replacements = {char: '_' for char in ';/+,'}
    replacements.update({char: None for char in ' ()'})
    return drug_id.translate(str.maketrans(replacements))
324 |
325 |
def _save_best_checkpoint(model, configs, stamp, val_drug_ids):
    """Write ``model.state_dict()`` to the checkpoint path derived from the run settings."""
    if val_drug_ids is not None and len(val_drug_ids) == 1:
        # Single-target runs get the sanitised target name in the file name.
        model_path = f"{configs['work_dir']}/{stamp}_{get_model_filename(val_drug_ids[0])}.pth"
    else:
        model_path = f"{configs['work_dir']}/{stamp}{configs['suffix']}.pth"
    torch.save(model.state_dict(), model_path)


def train_loop(epochs, train_loader, val_loader, model, criterion, optimizer, logger, stamp,
               configs,
               lr_scheduler=None,
               val_drug_ids=None,
               run=None, val_score_dict=None, id_to_class_name=None):
    """Train and validate ``model`` for ``epochs`` epochs, checkpointing on improvement.

    ``configs['task']`` selects the routine:

    * ``'regression'``: ``train`` / ``validate``; a checkpoint is written when
      the validation R2 exceeds the best so far (initially 0.05).
    * ``'classification'``: ``train_cls`` / ``validate_cls`` with a fresh
      ``nn.BCEWithLogitsLoss`` (the ``criterion`` argument is replaced); tracked
      score is AUC (initially 0.71).
    * ``'multiclass'``: ``train_cls_multiclass`` / ``validate_cls_multiclass``
      with a fresh ``nn.CrossEntropyLoss``; tracked score is the mean of top-1
      and top-3 accuracy (initially 0.7).

    Checkpoints are only written when ``configs['save_checkpoints']`` is truthy.

    Args:
        epochs: Number of epochs; epochs are numbered from 1.
        train_loader, val_loader: Loaders passed to the train/validate routines.
        model: Model to optimise; its ``state_dict`` is what gets saved.
        criterion: Loss used by the regression branch only.
        optimizer: Optimizer passed to the training routine.
        logger: Logger receiving learning-rate and validation messages.
        stamp: Prefix of the checkpoint file name.
        configs: Dict read for ``'task'``, ``'save_checkpoints'``, ``'work_dir'``
            and ``'suffix'``.
        lr_scheduler: Optional scheduler, stepped once per epoch after training.
        val_drug_ids: Target identifiers forwarded to ``validate`` /
            ``validate_cls``; a single identifier is embedded in the checkpoint name.
        run: Run label forwarded to the validation routine.
        val_score_dict: Score accumulator forwarded to the validation routine.
        id_to_class_name: Unused; kept for interface compatibility.

    Returns:
        ``None`` for regression and classification. For multiclass, the
        concatenation of the per-epoch per-sample validation frames (an empty
        DataFrame when no epoch ran).

    Raises:
        ValueError: If ``configs['task']`` is not one of the three supported tasks.
    """
    task = configs['task']

    if task == 'regression':
        best_r2 = 0.05
        for epoch in trange(1, epochs + 1):
            if lr_scheduler:
                logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
            train(train_loader, model, criterion, optimizer, epoch, logger)
            if lr_scheduler:
                lr_scheduler.step()

            r2, mae, rmse, corr = validate(val_loader, model, val_drug_ids, run=run, epoch=epoch,
                                           val_score_dict=val_score_dict)

            # validate() returns None scores for non-finite predictions. Test for
            # None explicitly so that a legitimate 0.0 score is still reported.
            scored = not any(score is None for score in (r2, mae, rmse, corr))
            if scored:
                logger.info(f"Epoch {epoch} validation corr:{corr:4f}, R2:{r2:4f}, MAE:{mae:4f}, RMSE:{rmse:4f}")
            else:
                logger.info(f"Epoch {epoch} validation Inf")
            # Comparing a float with None raises TypeError, so skip unscored epochs.
            if configs['save_checkpoints'] and scored and best_r2 < r2:
                best_r2 = r2
                _save_best_checkpoint(model, configs, stamp, val_drug_ids)
        return None

    if task == 'classification':
        best_auc = 0.71
        criterion = nn.BCEWithLogitsLoss()
        for epoch in trange(1, epochs + 1):
            if lr_scheduler:
                # get_last_lr() is the supported accessor, as in the regression branch.
                logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
            train_cls(train_loader, model, criterion, optimizer, epoch, logger)
            if lr_scheduler:
                lr_scheduler.step()

            accuracy, auc = validate_cls(val_loader, model, val_drug_ids, run=run, epoch=epoch,
                                         val_score_dict=val_score_dict)
            scored = accuracy is not None and auc is not None
            if scored:
                logger.info(f"Epoch {epoch} validation accuracy:{accuracy:4f}, AUC:{auc:4f}")
            else:
                logger.info(f"Epoch {epoch} validation Inf")
            if configs['save_checkpoints'] and scored and auc > best_auc:
                best_auc = auc
                _save_best_checkpoint(model, configs, stamp, val_drug_ids)

        torch.cuda.empty_cache()
        return None

    if task == 'multiclass':
        best_avg_acc = 0.7
        criterion = nn.CrossEntropyLoss()
        all_val_res = []
        for epoch in trange(1, epochs + 1):
            if lr_scheduler:
                logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
            train_cls_multiclass(train_loader, model, criterion, optimizer, epoch, logger)
            if lr_scheduler:
                lr_scheduler.step()

            top1_acc, top3_acc, f1, roc_auc, val_res_perclass = validate_cls_multiclass(
                val_loader, model, run=run, epoch=epoch, val_score_dict=val_score_dict)
            all_val_res.append(val_res_perclass)
            logger.info(
                f"Epoch {epoch} validation top1_acc:{top1_acc:4f}, f1:{f1:4f}, AUC:{roc_auc:4f}")
            avg_acc = np.mean([top1_acc, top3_acc])
            if configs['save_checkpoints'] and avg_acc > best_avg_acc:
                best_avg_acc = avg_acc
                # val_drug_ids is not used by multiclass validation and may be None;
                # the helper then falls back to the suffix-based file name.
                _save_best_checkpoint(model, configs, stamp, val_drug_ids)

        torch.cuda.empty_cache()
        # pd.concat rejects an empty list, which happens when epochs == 0.
        return pd.concat(all_val_res) if all_val_res else pd.DataFrame()

    raise ValueError(f"Unsupported task: {task!r}")
436 |
437 |
def train_cls(train_loader, model, criterion, optimizer, epoch, logger):
    """Run one training epoch for the binary-classification task.

    Each batch is a ``(input, targets)`` pair or an ``(input, tissue, targets)``
    triple. Outputs at positions whose target equals ``MISSING_NUM`` are
    overwritten with ``MISSING_NUM`` before the loss is computed. Per-batch AUC
    and accuracy are the medians over output columns, ignoring missing targets.

    Returns:
        The running average of the per-batch median accuracy.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    auc_meter = AverageMeter()

    model.train()

    tick = time.time()
    lr_str = ''

    for batch in train_loader:
        n_parts = len(batch)
        if n_parts == 2:
            input_, targets = batch
            output = model(input_.float().to(device))
        elif n_parts == 3:
            input_, tissue_x, targets = batch
            output = model(input_.float().to(device), tissue_x.float().to(device))
        else:
            raise Exception

        # Pin the output to the sentinel wherever the target is missing.
        output[targets == MISSING_NUM] = MISSING_NUM

        loss = criterion(output, targets.float().to(device))
        targets = targets.cpu().numpy()

        confs = torch.sigmoid(output).detach().cpu().numpy()
        predicts = (confs > 0.5).astype(int)

        column_aucs = []
        for col in range(confs.shape[1]):
            present = targets[:, col] != MISSING_NUM
            column_aucs.append(roc_auc_score(targets[present, col], confs[present, col]))
        auc_meter.update(np.median(column_aucs))

        column_accs = []
        for col in range(predicts.shape[1]):
            present = targets[:, col] != MISSING_NUM
            column_accs.append(accuracy_score(targets[present, col], predicts[present, col]))
        accuracy_meter.update(np.median(column_accs))

        loss_meter.update(loss.data.item(), input_.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        now = time.time()
        timer.update(now - tick)
        tick = time.time()

    # One summary line per epoch: latest value followed by the running average.
    logger.info(f'{epoch} \t'
                f'time {timer.val:.3f} ({timer.avg:.3f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'avg_accuracy {accuracy_meter.val:.4f} ({accuracy_meter.avg:.4f})\t'
                f'avg_auc {auc_meter.val:.4f} ({auc_meter.avg:.4f})\t' + lr_str)

    return accuracy_meter.avg
493 |
494 |
def validate_cls(val_loader, model, val_drug_ids, run=None, epoch=None, val_score_dict=None):
    """Score a binary classifier on the validation loader, one metric pair per output column.

    Raw outputs from ``inference(val_loader, model)`` are passed through a
    sigmoid and thresholded at 0.5. For each output column, rows whose target
    equals ``MISSING_NUM`` are dropped before accuracy and ROC AUC are computed.

    Args:
        val_loader: Loader handed to ``inference``.
        model: Model handed to ``inference``.
        val_drug_ids: Identifiers recorded under ``'drug_id'`` or ``'Gene'``
            (presumably one per output column -- confirm against the caller).
        run: Run label, recorded once per identifier.
        epoch: Epoch label, recorded once per identifier.
        val_score_dict: Optional dict of lists that accumulates the per-column
            scores. It must hold a ``'drug_id'`` or ``'Gene'`` list together with
            ``'run'``, ``'epoch'``, ``'accuracy'`` and ``'auc'``. When ``None``,
            nothing is recorded.

    Returns:
        ``(accuracy, auc)`` medians over the output columns, or ``(None, None)``
        when any probability is infinite or NaN.

    Raises:
        RuntimeError: If ``val_score_dict`` has neither ``'drug_id'`` nor ``'Gene'``.
    """
    confs, targets = inference(val_loader, model)
    confs = torch.sigmoid(torch.from_numpy(confs)).numpy()

    if np.isinf(confs).any() or np.isnan(confs).any():
        return None, None

    id_key = None
    if val_score_dict is not None:
        if 'drug_id' in val_score_dict:
            id_key = 'drug_id'
        elif 'Gene' in val_score_dict:
            id_key = 'Gene'
        else:
            raise RuntimeError("val_score_dict must contain a 'drug_id' or 'Gene' key")

    predicts = (confs > 0.5).astype(int)

    accuracy, auc = [], []
    for i in range(confs.shape[1]):
        # Build the missing-value mask once per column and reuse it for both metrics.
        observed = targets[:, i] != MISSING_NUM
        accuracy.append(accuracy_score(targets[observed, i], predicts[observed, i]))
        auc.append(roc_auc_score(targets[observed, i], confs[observed, i]))

    # Record only after every metric succeeded so all lists stay the same length.
    if val_score_dict is not None:
        val_score_dict[id_key].extend(val_drug_ids)
        val_score_dict['run'].extend([run] * len(val_drug_ids))
        val_score_dict['epoch'].extend([epoch] * len(val_drug_ids))
        val_score_dict['accuracy'].extend(accuracy)
        val_score_dict['auc'].extend(auc)

    return np.median(accuracy), np.median(auc)
524 |
525 |
def train_cls_multiclass(train_loader, model, criterion, optimizer, epoch, logger):
    """Run one training epoch for the multiclass task.

    Each batch is an ``(input, targets)`` pair. Classes are ranked by softmax
    score; top-1 accuracy, top-3 accuracy and macro F1 are tracked per batch.
    ROC AUC is not computed during training.

    Returns:
        The running average of the per-batch top-1 accuracy.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top3_meter = AverageMeter()
    f1_meter = AverageMeter()
    avg_auc = AverageMeter()  # never updated; AUC tracking is disabled for training

    model.train()

    tick = time.time()
    lr_str = ''

    for batch in train_loader:
        input_, targets = batch
        logits = model(input_.float().to(device))

        loss = criterion(logits, targets.flatten().long().to(device))
        probabilities = torch.softmax(logits, dim=-1)
        confs = probabilities.detach().cpu().numpy()

        # Class indices ordered from most to least probable for every sample.
        ranking = np.argsort(-confs, axis=1)
        labels = targets.cpu().numpy().flatten()
        n_samples = labels.shape[0]

        best_guess = ranking[:, 0]
        top1_acc = np.sum(best_guess == labels) / n_samples
        in_top3 = np.any(ranking[:, :3] == np.expand_dims(labels, axis=1), axis=1)
        top3_acc = np.sum(in_top3) / n_samples
        f1 = f1_score(labels, best_guess, average='macro')

        top1_meter.update(top1_acc)
        top3_meter.update(top3_acc)
        f1_meter.update(f1)

        loss_meter.update(loss.data.item(), input_.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        now = time.time()
        timer.update(now - tick)
        tick = time.time()

    # One summary line per epoch: latest value followed by the running average.
    logger.info(f'{epoch} \t'
                f'time {timer.val:.3f} ({timer.avg:.3f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'avg_top1_acc {top1_meter.val:.4f} ({top1_meter.avg:.4f})\t'
                f'avg_f1 {f1_meter.val:.4f} ({f1_meter.avg:.4f})' + lr_str)

    return top1_meter.avg
575 |
576 |
def validate_cls_multiclass(val_loader, model, run=None, epoch=None, val_score_dict=None):
    """Score a multiclass model on the validation loader.

    Classes are ranked per sample from the raw outputs of
    ``inference(val_loader, model)``; softmax scores are used for ROC AUC and
    for the per-sample result frame.

    When the targets contain more than two distinct classes, top-3 accuracy and
    macro F1 are computed, and one-vs-one ROC AUC is computed only if every
    class index occurs among the targets (otherwise it is NaN). With two or
    fewer distinct classes, top-3 accuracy is fixed at 1 and F1 / ROC AUC are
    computed in their binary form using column 1 of the scores.

    Args:
        val_loader: Loader handed to ``inference``.
        model: Model handed to ``inference``.
        run: Run label stored with the scores.
        epoch: Epoch label stored with the scores.
        val_score_dict: Optional dict of lists with keys ``'run'``, ``'epoch'``,
            ``'top1_acc'``, ``'top3_acc'``, ``'f1'`` and ``'roc_auc'``; one value
            is appended to each. When ``None``, nothing is recorded.

    Returns:
        ``(top1_acc, top3_acc, f1, roc_auc, per_sample_frame)``, where the frame
        has ``run``, ``epoch``, ``y_pred``, ``y_true`` and one ``feature_<i>``
        column per class score.
    """
    confs, targets = inference(val_loader, model)
    predicts = np.argsort(-confs, axis=1)
    confs = softmax(confs, axis=-1)

    targets = targets.flatten()
    n_samples = targets.shape[0]
    top_choice = predicts[:, 0]
    top1_acc = np.sum(top_choice == targets) / n_samples

    if np.unique(targets).size > 2:
        top3_acc = np.sum(np.any(predicts[:, :3] == np.expand_dims(targets, axis=1), axis=1)) / n_samples
        f1 = f1_score(targets, top_choice, average='macro')
        # Each row of `predicts` is a full ranking, so its unique values are all
        # class indices; the comparison checks that every class is present in targets.
        if set(np.unique(predicts)) == set(np.unique(targets)):
            roc_auc = roc_auc_score(targets, confs, multi_class='ovo')
        else:
            roc_auc = np.nan
    else:
        top3_acc = 1
        f1 = f1_score(targets, top_choice)
        roc_auc = roc_auc_score(targets, confs[:, 1])

    val_res_perclass = {
        'run': [run] * len(predicts),
        'epoch': [epoch] * len(predicts),
        'y_pred': top_choice,
        'y_true': targets,
    }
    for i in range(confs.shape[1]):
        val_res_perclass[f"feature_{i}"] = confs[:, i]

    # Record only after every metric succeeded so all lists stay the same length.
    if val_score_dict is not None:
        val_score_dict['run'].append(run)
        val_score_dict['epoch'].append(epoch)
        val_score_dict['top1_acc'].append(top1_acc)
        val_score_dict['top3_acc'].append(top3_acc)
        val_score_dict['f1'].append(f1)
        val_score_dict['roc_auc'].append(roc_auc)

    return top1_acc, top3_acc, f1, roc_auc, pd.DataFrame(val_res_perclass)
614 |
--------------------------------------------------------------------------------
/scripts/transformer_explantion_cancer_type.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run pathway-level model explanation for cancer type
3 | E.g. python scripts/transformer_explantion_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json
4 | """
5 | import json
6 | import os
7 | import sys
8 |
9 | from sklearn.model_selection import KFold
10 |
11 | sys.path.append(os.getcwd() + '/..')
12 | from models import *
13 | from model_transformer_lrp import DeePathNet, LRP
14 | from torch.utils.data import DataLoader
15 | from tqdm import trange
16 |
seed = 12345
torch.manual_seed(seed)
OUTPUT_NA_NUM = -100

config_file = sys.argv[1]

# Load model configs; the context manager closes the file once it is parsed.
with open(config_file, 'r') as config_fh:
    configs = json.load(config_fh)
data_file = configs['data_file']
data_type = configs['data_type']

BATCH_SIZE = configs['batch_size']
NUM_WORKERS = 0
LOG_FREQ = configs['log_freq']
NUM_EPOCHS = configs['num_of_epochs']
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda'

data_target = pd.read_csv(configs['target_file'], low_memory=False, index_col=0)

data_input = pd.read_csv(data_file, index_col=0)
if data_type[0] != 'DR':
    # Keep columns whose first or second "_"-separated token names a requested data type.
    data_input = data_input[
        [x for x in data_input.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]]

omics_types = [x for x in data_type if x != 'tissue']

# Gene names are the first "_"-separated token of every non-tissue column.
genes = np.unique(([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != 'tissue']))

# Classes are the sorted distinct values of the first target column, numbered from 0.
class_names = sorted(data_target.iloc[:, 0].unique())
class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
id_to_class_name = dict(enumerate(class_names))

num_of_features = data_input.shape[1]

pathway_dict = {}
pathway_df = pd.read_csv(configs['pathway_file'])

# Restrict every pathway's "|"-separated gene list to genes present in the input data.
pathway_df['genes'] = pathway_df['genes'].map(
    lambda x: "|".join([gene for gene in x.split('|') if gene in genes]))
if 'min_cancer_publication' in configs:
    pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']]
if 'max_gene_num' in configs:
    pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']]
if 'min_gene_num' in configs:
    pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']]

for index, row in pathway_df.iterrows():
    pathway_dict[row['name']] = row['genes'].split('|')

cancer_genes = set([y for x in pathway_df['genes'].values for y in x.split("|")])
non_cancer_genes = sorted(set(genes) - set(cancer_genes))

# The first fold of a seeded 5-fold split provides the train/test partition.
cell_lines_all = data_input.index.values
cv = KFold(n_splits=5, shuffle=True, random_state=seed)
cell_lines_train_index, cell_lines_val_index = next(cv.split(cell_lines_all))
cell_lines_train = np.array(cell_lines_all)[cell_lines_train_index]
cell_lines_test = np.array(cell_lines_all)[cell_lines_val_index]

data_input_train = data_input[data_input.index.isin(cell_lines_train)]
data_input_test = data_input[data_input.index.isin(cell_lines_test)]
data_target_train = data_target[data_target.index.isin(cell_lines_train)]
data_target_test = data_target[data_target.index.isin(cell_lines_test)]
83 |
def run_lrp_cancer_type(merged_df_train):
    """Compute pathway-level LRP importances for every cancer type.

    The first ``num_of_features`` columns of ``merged_df_train`` are the model
    inputs and the remaining columns the targets. The saved DeePathNet weights
    are loaded from ``configs['work_dir']/configs['saved_model']`` and, for each
    class, the ``transformer_attribution`` relevance of every sample is summed
    per pathway.

    Args:
        merged_df_train: Feature columns followed by target columns.

    Returns:
        DataFrame with columns ``cancer_type``, ``pathway`` and ``importance``;
        one row per class/pathway pair.
    """
    train_df = merged_df_train.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]

    train_dataset = MultiOmicMulticlassDataset(train_df, train_target, mode='train', omics_types=omics_types,
                                               class_name_to_id=class_name_to_id, logger=None)

    # pathway_dict and non_cancer_genes are built once at module level from the
    # same pathway file and filters, so the file is not re-read here.
    model = DeePathNet(len(omics_types), len(class_name_to_id), train_dataset.genes_to_id,
                       train_dataset.id_to_genes,
                       pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'],
                       num_heads=configs['heads'],
                       mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'],
                       only_cancer_genes=configs['cancer_only'])
    model.load_state_dict(torch.load(f"{configs['work_dir']}/{configs['saved_model']}"))

    model.cuda()
    model.eval()

    attribution_generator = LRP(model)
    # Labels for the relevance vector: pathways first, then the optional extra entries.
    pathways = list(pathway_dict.keys())
    if not configs['cancer_only']:
        pathways += ['non_cancer']
    if 'tissue' in data_type:
        pathways += ['tissue']

    # The loader does not depend on the class, so it is built once and re-iterated.
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              num_workers=NUM_WORKERS)

    res_df_all = []
    for idx in trange(len(class_name_to_id)):
        transformer_attribution_all = []
        for data in train_loader:
            transformer_attribution = attribution_generator.generate_LRP(
                data, method="transformer_attribution", index=idx).detach().cpu().numpy()
            transformer_attribution_all.append(transformer_attribution[0, :])
        transformer_attribution_sum = np.sum(transformer_attribution_all, axis=0)
        res_df = pd.DataFrame({'pathway': pathways, 'importance': transformer_attribution_sum})
        res_df['cancer_type'] = id_to_class_name[idx]
        res_df_all.append(res_df)

    res_df_all = pd.concat(res_df_all)
    return res_df_all[['cancer_type', 'pathway', 'importance']]
150 |
151 |
merged_df_train = pd.merge(data_input_train, data_target_train, on=['Cell_line'])
res_df_all = run_lrp_cancer_type(merged_df_train)

# Swap only the trailing ".pth" for ".csv". A blanket replace('pth', 'csv')
# would also rewrite "pth" inside the model name itself (e.g. "depth").
saved_model_name = configs['saved_model']
if saved_model_name.endswith('.pth'):
    explanation_name = saved_model_name[:-len('.pth')] + '.csv'
else:
    explanation_name = saved_model_name.replace('pth', 'csv')
res_df_all.to_csv(f"{configs['work_dir']}/explanation_{explanation_name}", index=False)
print("finished")
156 |
--------------------------------------------------------------------------------
/scripts/transformer_explantion_drug_response.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from datetime import datetime
4 | import sys
5 | import logging
6 | import os
7 |
8 | import pandas as pd
9 |
10 | from models import *
11 | from model_transformer_lrp import DOIT_LRP, LRP
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm, trange
14 |
seed = 12345
torch.manual_seed(seed)
OUTPUT_NA_NUM = -100

config_file = sys.argv[1]
# Load model configs; the context manager closes the file once it is parsed.
with open(config_file, 'r') as config_fh:
    configs = json.load(config_fh)
data_file = configs['data_file']
data_type = configs['data_type']

BATCH_SIZE = configs['batch_size']
NUM_WORKERS = 0
LOG_FREQ = configs['log_freq']
NUM_EPOCHS = configs['num_of_epochs']
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_target = pd.read_csv(configs['target_file'], low_memory=False, index_col=0)
drug_ids = data_target.columns
# A non-empty configs['drug_id'] narrows the targets to that single drug.
if 'drug_id' in configs and configs['drug_id'] != "":
    data_target = data_target[configs['drug_id']]
    drug_ids = [configs['drug_id']]

data_input = pd.read_csv(data_file, index_col=0)
if data_type[0] != 'DR':
    # Keep columns whose first or second "_"-separated token names a requested data type.
    data_input = data_input[
        [x for x in data_input.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]]

with_tissue = 'tissue' in data_type
tissues = [x for x in data_input.columns if 'tissue' in x] if with_tissue else None
omics_types = [x for x in data_type if x != 'tissue']

# Train/test cell lines are listed one per line in the configured files.
with open(configs['train_cells']) as f:
    cell_lines_train = [line.rstrip() for line in f]
with open(configs['test_cells']) as f:
    cell_lines_test = [line.rstrip() for line in f]

# Gene names are the first "_"-separated token of every non-tissue column.
genes = np.unique(([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != 'tissue']))

data_target_train = data_target[data_target.index.isin(cell_lines_train)]
data_target_test = data_target[data_target.index.isin(cell_lines_test)]

num_of_features = data_input.shape[1]
data_input_train = data_input[data_input.index.isin(cell_lines_train)]
merged_df_train = pd.merge(data_input_train, data_target_train, on=['Cell_line'])
61 |
62 |
63 | # test_data = data_input_test
64 | # merged_df_test = pd.merge(test_data, data_target_test, on=['Cell_line'])
65 |
def run_lrp(merged_df_train, drug_id=None):
    """Compute pathway-level LRP importances for every drug column in the targets.

    The first ``num_of_features`` columns of ``merged_df_train`` are the model
    inputs and the remaining columns the drug targets. For each drug, only rows
    with a target value greater than 0 are explained, and the
    ``transformer_attribution`` relevance of those rows is summed per pathway.

    Args:
        merged_df_train: Feature columns followed by target columns.
        drug_id: When given, the per-drug checkpoint
            ``{work_dir}_{saved_model}/{saved_model}_{sanitised drug id}.pth`` is
            loaded; otherwise ``{work_dir}/{saved_model}`` is loaded.

    Returns:
        DataFrame with columns ``drug_id``, ``pathway`` and ``importance``, or
        ``None`` when ``drug_id`` is given and its checkpoint does not exist.
    """
    X_train = merged_df_train.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]

    full_dataset = MultiOmicDataset(X_train, train_target, mode='train', omics_types=omics_types, logger=None,
                                    with_tissue=with_tissue)

    pathway_dict = {}
    pathway_df = pd.read_csv(configs['pathway_file'])

    # Restrict every pathway's "|"-separated gene list to genes present in the input data.
    pathway_df['genes'] = pathway_df['genes'].map(
        lambda x: "|".join([gene for gene in x.split('|') if gene in genes]))
    if 'min_cancer_publication' in configs:
        pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']]
    if 'max_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']]
    if 'min_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']]

    for index, row in pathway_df.iterrows():
        pathway_dict[row['name']] = row['genes'].split('|')

    cancer_genes = set([y for x in pathway_df['genes'].values for y in x.split("|")])
    non_cancer_genes = sorted(set(genes) - set(cancer_genes))
    model = DOIT_LRP(len(omics_types), train_target.shape[1], full_dataset.genes_to_id,
                     full_dataset.id_to_genes,
                     pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'],
                     num_heads=configs['heads'],
                     mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'],
                     only_cancer_genes=configs['cancer_only'], tissues=tissues)

    if drug_id:
        drug_file_name = get_model_filename(drug_id)
        model_path = f"{configs['work_dir']}_{configs['saved_model']}/{configs['saved_model']}_{drug_file_name}.pth"
        if not os.path.exists(model_path):
            return None
    else:
        model_path = f"{configs['work_dir']}/{configs['saved_model']}"
    # Load onto the script-level `device` (which falls back to CPU) rather than
    # onto whatever device the checkpoint was saved from.
    model.load_state_dict(torch.load(model_path, map_location=device))

    model.to(device)
    model.eval()

    attribution_generator = LRP(model)
    # Labels for the relevance vector: pathways first, then the optional extra entries.
    pathways = list(pathway_dict.keys())
    if not configs['cancer_only']:
        pathways += ['non_cancer']
    if 'tissue' in data_type:
        pathways += ['tissue']

    res_df_all = []
    for drug_idx in range(train_target.shape[1]):
        drug = train_target.columns[drug_idx]
        # Only rows with a target value above 0 for this drug are explained.
        train_target_drug = train_target[train_target[drug] > 0]
        X_train_drug = X_train[X_train.index.isin(train_target_drug.index)]
        drug_dataset = MultiOmicDataset(X_train_drug, train_target_drug, mode='train', omics_types=omics_types,
                                        logger=None, with_tissue=with_tissue)
        drug_loader = DataLoader(drug_dataset,
                                 batch_size=1,
                                 num_workers=NUM_WORKERS)
        transformer_attribution_all = []
        for data in drug_loader:
            transformer_attribution = attribution_generator.generate_LRP(
                data, method="transformer_attribution", index=drug_idx).detach().cpu().numpy()
            transformer_attribution_all.append(transformer_attribution[0, :])
        transformer_attribution_sum = np.sum(transformer_attribution_all, axis=0)
        res_df = pd.DataFrame({'pathway': pathways, 'importance': transformer_attribution_sum})
        res_df['drug_id'] = drug
        res_df_all.append(res_df)

    res_df_all = pd.concat(res_df_all)
    return res_df_all[['drug_id', 'pathway', 'importance']]
144 |
145 |
if configs.get('all_single_mode'):
    # One checkpoint per drug: explain each drug with its own model and stack the results.
    res_df_all = []
    drug_ids = pd.read_csv(configs['drug_list'], index_col=0).index.values
    for drug_id in tqdm(drug_ids):
        merged_df_train = pd.merge(data_input_train, data_target_train[drug_id], on=['Cell_line'])
        relevance = run_lrp(merged_df_train, drug_id=drug_id)
        # run_lrp returns None when the drug's checkpoint file is missing.
        if relevance is not None:
            res_df_all.append(relevance)
    if not res_df_all:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(
            f"No per-drug checkpoint found under {configs['work_dir']}_{configs['saved_model']}")
    res_df_all = pd.concat(res_df_all)
    res_df_all.to_csv(
        f"{configs['work_dir']}_{configs['saved_model']}/explanation_{configs['saved_model']}.csv.gz",
        index=False)
else:
    # data_input_train was already built above from the same cell-line filter.
    merged_df_train = pd.merge(data_input_train, data_target_train, on=['Cell_line'])
    res_df_all = run_lrp(merged_df_train)
    # Swap only the trailing ".pth" for ".csv". A blanket replace('pth', 'csv')
    # would also rewrite "pth" inside the model name itself (e.g. "depth").
    saved_model_name = configs['saved_model']
    if saved_model_name.endswith('.pth'):
        explanation_name = saved_model_name[:-len('.pth')] + '.csv'
    else:
        explanation_name = saved_model_name.replace('pth', 'csv')
    res_df_all.to_csv(f"{configs['work_dir']}/explanation_{explanation_name}", index=False)
print("finished")
164 |
--------------------------------------------------------------------------------
/scripts/transformer_shap_cancer_type.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to run gene-level model explanation for cancer type
3 | E.g. python scripts/transformer_shap_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json
4 | """
5 |
6 | import json
7 | import os
8 | import sys
9 |
10 | from sklearn.model_selection import KFold
11 |
12 | sys.path.append(os.getcwd() + "/..")
13 | from models import *
14 | from model_transformer_lrp import DeePathNet
15 | from torch.utils.data import DataLoader
16 |
17 | import shap
18 | import time
19 |
seed = 12345
torch.manual_seed(seed)
OUTPUT_NA_NUM = -100

config_file = sys.argv[1]

# Optional second CLI argument; it is printed and embedded in the output file name.
mode = "grad" if len(sys.argv) <= 2 else sys.argv[2]
print(f"SHAP algo: {mode}")
# Load model configs; the context manager closes the file once it is parsed.
with open(config_file, "r") as config_fh:
    configs = json.load(config_fh)
data_file = configs["data_file"]
data_type = configs["data_type"]

BATCH_SIZE = configs["batch_size"]
NUM_WORKERS = 0
LOG_FREQ = configs["log_freq"]
NUM_EPOCHS = configs["num_of_epochs"]
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cuda"

data_target = pd.read_csv(configs["target_file"], low_memory=False, index_col=0)

data_input = pd.read_csv(data_file, index_col=0)
if data_type[0] != "DR":
    # Keep columns whose first or second "_"-separated token names a requested data type.
    data_input = data_input[
        [
            x
            for x in data_input.columns
            if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)
        ]
    ]

omics_types = [x for x in data_type if x != "tissue"]


# Gene names are the first "_"-separated token of every non-tissue column.
genes = np.unique(
    ([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != "tissue"])
)
# Classes are the sorted distinct values of the first target column, numbered from 0.
class_names = sorted(data_target.iloc[:, 0].unique())
class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
id_to_class_name = dict(enumerate(class_names))

num_of_features = data_input.shape[1]

pathway_dict = {}
pathway_df = pd.read_csv(configs["pathway_file"])

# Restrict every pathway's "|"-separated gene list to genes present in the input data.
pathway_df["genes"] = pathway_df["genes"].map(
    lambda x: "|".join([gene for gene in x.split("|") if gene in genes])
)
if "min_cancer_publication" in configs:
    pathway_df = pathway_df[
        pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
    ]
if "max_gene_num" in configs:
    pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
if "min_gene_num" in configs:
    pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]

for index, row in pathway_df.iterrows():
    pathway_dict[row["name"]] = row["genes"].split("|")

cancer_genes = set([y for x in pathway_df["genes"].values for y in x.split("|")])
non_cancer_genes = sorted(set(genes) - set(cancer_genes))

# The first fold of a seeded 5-fold split provides the train/test partition.
cell_lines_all = data_input.index.values
cv = KFold(n_splits=5, shuffle=True, random_state=seed)
cell_lines_train_index, cell_lines_val_index = next(cv.split(cell_lines_all))
cell_lines_train = np.array(cell_lines_all)[cell_lines_train_index]
cell_lines_test = np.array(cell_lines_all)[cell_lines_val_index]

data_input_train = data_input[data_input.index.isin(cell_lines_train)]
data_input_test = data_input[data_input.index.isin(cell_lines_test)]
data_target_train = data_target[data_target.index.isin(cell_lines_train)]
data_target_test = data_target[data_target.index.isin(cell_lines_test)]
104 |
105 |
def run_shap(merged_df_train, merged_df_test):
    """Compute gene-level SHAP importances per cancer type with ``shap.GradientExplainer``.

    The first ``num_of_features`` columns of each frame are the model inputs and
    the remaining columns the targets. One shuffled training batch (up to 600
    samples) is the explainer background; up to ``NUM_EXPLAINED`` test samples
    are explained. For each class, the mean absolute SHAP value over the
    explained samples is reported per gene and omics type.

    Args:
        merged_df_train: Training features followed by targets.
        merged_df_test: Test features followed by targets.

    Returns:
        DataFrame with columns ``cancer_type``, ``gene``, one column per entry of
        ``omics_types`` and ``sum`` (row-wise total of the omics columns).
    """
    X_train = merged_df_train.iloc[:, :num_of_features]
    X_test = merged_df_test.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]
    test_target = merged_df_test.iloc[:, num_of_features:]

    train_dataset = MultiOmicMulticlassDataset(
        X_train,
        train_target,
        mode="train",
        omics_types=omics_types,
        class_name_to_id=class_name_to_id,
        logger=None,
    )
    test_dataset = MultiOmicMulticlassDataset(
        X_test,
        test_target,
        mode="val",
        omics_types=omics_types,
        class_name_to_id=class_name_to_id,
        logger=None,
    )

    model = DeePathNet(
        len(omics_types),
        len(class_name_to_id),
        train_dataset.genes_to_id,
        train_dataset.id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        num_heads=configs["heads"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        only_cancer_genes=configs["cancer_only"],
    )
    # Map the checkpoint straight onto the target device.
    model.load_state_dict(
        torch.load(
            f"{configs['work_dir']}/{configs['saved_model']}", map_location=device
        )
    )

    model.to(device)
    model.eval()
    train_loader = DataLoader(
        train_dataset, batch_size=600, shuffle=True, num_workers=NUM_WORKERS
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    # Only the feature tensors are needed; the labels of both batches are discarded.
    background_input, _ = next(iter(train_loader))
    test_input, _ = next(iter(test_loader))

    NUM_EXPLAINED = 400
    N_SAMPLES = 50

    start = time.time()
    background = background_input.float().to(device)
    explainer = shap.GradientExplainer(model, background)
    shap_values = explainer.shap_values(
        test_input[:NUM_EXPLAINED, :, :].float().to(device), nsamples=N_SAMPLES
    )
    end = time.time()
    print(end - start)

    all_drug_gradients_summary = {"cancer_type": [], "gene": []}
    for omics in omics_types:
        all_drug_gradients_summary[omics] = []

    for idx in range(len(class_name_to_id)):
        cancer_type = id_to_class_name[idx]
        # Mean absolute attribution over the explained samples for this class.
        omics_shap_mean = np.mean(np.abs(shap_values[idx]), axis=0)

        all_drug_gradients_summary["cancer_type"].extend([cancer_type] * len(genes))
        all_drug_gradients_summary["gene"].extend(genes)
        for i, omics in enumerate(omics_types):
            all_drug_gradients_summary[omics].extend(omics_shap_mean[:, i])

    all_drug_gradients_summary = pd.DataFrame(all_drug_gradients_summary)
    all_drug_gradients_summary["sum"] = all_drug_gradients_summary.iloc[:, 2:].sum(
        axis=1
    )

    # Drop the references to the model and tensors before asking CUDA to free cached memory.
    del model, explainer, background, background_input
    torch.cuda.empty_cache()
    return all_drug_gradients_summary
201 |
202 |
merged_df_train = pd.merge(data_input_train, data_target_train, on=["Cell_line"])
merged_df_test = pd.merge(data_input_test, data_target_test, on=["Cell_line"])
all_drug_gradients_summary = run_shap(merged_df_train, merged_df_test)

# Swap only the trailing ".pth" for ".csv.gz". A blanket replace('pth', ...)
# would also rewrite "pth" inside the model name itself (e.g. "depth").
saved_model_name = configs["saved_model"]
if saved_model_name.endswith(".pth"):
    summary_name = saved_model_name[: -len(".pth")] + ".csv.gz"
else:
    summary_name = saved_model_name.replace("pth", "csv.gz")
all_drug_gradients_summary.to_csv(
    f"{configs['work_dir']}/shap{mode}_genes_{summary_name}",
    index=False,
)
210 |
--------------------------------------------------------------------------------
/scripts/transformer_shap_drug_response.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from datetime import datetime
4 | import sys
5 | import logging
6 | import os
7 | import sys
8 |
9 | import pandas as pd
10 |
11 | sys.path.append(os.getcwd() + '/..')
12 | from models import *
13 | from model_transformer_lrp import DOIT_LRP, LRP
14 | from torch.utils.data import DataLoader
15 | from tqdm import tqdm, trange
16 |
17 | import shap
18 | import time
19 |
# Seed torch's RNG before any model or data loader is created.
seed = 12345
torch.manual_seed(seed)
OUTPUT_NA_NUM = -100

# Command line: <script> CONFIG_JSON [MODE]; MODE defaults to 'grad'.
config_file = sys.argv[1]

mode = 'grad' if len(sys.argv) <= 2 else sys.argv[2]
print(f"SHAP algo: {mode}")
# load model configs (context manager so the file handle is closed right away)
with open(config_file, 'r') as config_handle:
    configs = json.load(config_handle)
data_file = configs['data_file']
data_type = configs['data_type']

BATCH_SIZE = configs['batch_size']
NUM_WORKERS = 0
LOG_FREQ = configs['log_freq']
NUM_EPOCHS = configs['num_of_epochs']
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda'

# Targets: one column per drug; optionally restricted to a single configured drug.
data_target = pd.read_csv(configs['target_file'], low_memory=False, index_col=0)
drug_ids = data_target.columns
if 'drug_id' in configs and configs['drug_id'] != "":
    data_target = data_target[configs['drug_id']]
    drug_ids = [configs['drug_id']]

# Inputs: keep only columns whose first or second "_"-separated token is a
# requested data type (column names are expected to contain an underscore).
data_input = pd.read_csv(data_file, index_col=0)
if data_type[0] != 'DR':
    data_input = data_input[
        [x for x in data_input.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]]

tissues = [x for x in data_input.columns if 'tissue' in x] if 'tissue' in data_type else None
omics_types = [x for x in data_type if x != 'tissue']
with_tissue = 'tissue' in data_type

with open(configs['train_cells']) as f:
    cell_lines_train = [line.rstrip() for line in f]
with open(configs['test_cells']) as f:
    cell_lines_test = [line.rstrip() for line in f]

# Sorted unique gene names taken from the prefix of every non-tissue column.
genes = np.unique(([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != 'tissue']))

data_target_train = data_target[
    data_target.index.isin(cell_lines_train)]
data_target_test = data_target[
    data_target.index.isin(cell_lines_test)]

num_of_features = data_input.shape[1]

pathway_dict = {}
pathway_df = pd.read_csv(configs['pathway_file'])

# Restrict every pathway to genes present in the input data. A set gives O(1)
# membership tests instead of scanning the numpy array once per pathway gene.
_pathway_gene_pool = set(genes)
pathway_df['genes'] = pathway_df['genes'].map(
    lambda x: "|".join([gene for gene in x.split('|') if gene in _pathway_gene_pool]))
if 'min_cancer_publication' in configs:
    pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']]
if 'max_gene_num' in configs:
    pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']]
if 'min_gene_num' in configs:
    pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']]

for index, row in pathway_df.iterrows():
    pathway_dict[row['name']] = row['genes'].split('|')

# Genes covered by at least one retained pathway vs. all remaining genes.
cancer_genes = set([y for x in pathway_df['genes'].values for y in x.split("|")])
non_cancer_genes = sorted(set(genes) - set(cancer_genes))

data_input_train = data_input[data_input.index.isin(cell_lines_train)]
data_input_test = data_input[data_input.index.isin(cell_lines_test)]
89 |
90 |
def run_shap(merged_df_train, merged_df_test, drug_ids=None):
    """Explain a trained DOIT_LRP drug-response model with SHAP's GradientExplainer.

    The whole training set is used as SHAP background and the first
    ``NUM_EXPLAINED`` rows of a shuffled test batch are explained. Absolute
    SHAP values are averaged over the explained samples, giving one importance
    per (drug, gene, omics type) and, when tissue inputs exist, one per
    (drug, tissue).

    Args:
        merged_df_train: Frame whose first ``num_of_features`` columns are the
            model inputs and whose remaining columns are the targets.
        merged_df_test: Same layout as ``merged_df_train`` for the test cells.
        drug_ids: Identifiers of the target columns. Exactly one id loads the
            per-drug checkpoint, several ids load the joint checkpoint.
            Defaults to the target column names of ``merged_df_train``.

    Returns:
        ``(gene_summary, tissue_summary)`` as DataFrames. ``tissue_summary``
        has no rows when the dataset carries no tissue input. ``(None, None)``
        is returned when the single-drug checkpoint file does not exist.

    Raises:
        ValueError: If the global ``mode`` is not ``'grad'`` or a batch does
            not consist of 2 or 3 tensors.
    """
    train_df = merged_df_train.iloc[:, :num_of_features]
    test_df = merged_df_test.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]
    test_target = merged_df_test.iloc[:, num_of_features:]

    if drug_ids is None:
        # The documented default used to crash on len(None); fall back to the
        # target columns, which are what the model predicts.
        drug_ids = list(train_target.columns)

    train_dataset = MultiOmicDataset(train_df, train_target, mode='train', omics_types=omics_types, logger=None,
                                     with_tissue=with_tissue)
    test_dataset = MultiOmicDataset(test_df, test_target, mode='val', omics_types=omics_types, logger=None,
                                    with_tissue=with_tissue)

    model = DOIT_LRP(len(omics_types), train_target.shape[1], train_dataset.genes_to_id,
                     train_dataset.id_to_genes,
                     pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'],
                     num_heads=configs['heads'],
                     mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'],
                     only_cancer_genes=configs['cancer_only'], tissues=tissues)
    if len(drug_ids) == 1:
        drug_file_name = get_model_filename(drug_ids[0])
        drug_path = f"{configs['work_dir']}_{configs['saved_model']}/{configs['saved_model']}_{drug_file_name}.pth"
        if not os.path.exists(drug_path):
            # No checkpoint for this drug: let the caller skip it.
            return None, None
        checkpoint_path = drug_path
    else:
        checkpoint_path = f"{configs['work_dir']}/{configs['saved_model']}"
    # map_location keeps loading independent of the device the checkpoint was saved from.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()

    # One batch holding the complete (shuffled) dataset per split.
    train_loader = DataLoader(train_dataset,
                              batch_size=len(train_dataset), shuffle=True,
                              num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_dataset,
                             batch_size=len(test_dataset), shuffle=True,
                             num_workers=NUM_WORKERS)

    data = next(iter(train_loader))
    test_data = next(iter(test_loader))

    tissue_x = None
    test_tissue_x = None
    if len(data) == 2:
        (train_input, targets) = data
        (test_input, test_targets) = test_data
    elif len(data) == 3:
        (train_input, tissue_x, targets) = data
        (test_input, test_tissue_x, test_targets) = test_data
    else:
        raise ValueError(f"Expected batches of 2 or 3 tensors, got {len(data)}")
    has_tissue = tissue_x is not None

    NUM_EXPLAINED = 400
    N_SAMPLES = 50

    if mode != 'grad':
        raise ValueError(f"Unsupported SHAP mode {mode!r}; only 'grad' is implemented")

    start = time.time()
    if has_tissue:
        background = [train_input.float().to(device), tissue_x.float().to(device)]
        explainer = shap.GradientExplainer(model, background)
        shap_values = explainer.shap_values(
            [test_input[:NUM_EXPLAINED, :, :].float().to(device),
             test_tissue_x[:NUM_EXPLAINED, :].float().to(device)],
            nsamples=N_SAMPLES)
    else:
        background = train_input.float().to(device)
        explainer = shap.GradientExplainer(model, background)
        shap_values = explainer.shap_values(test_input[:NUM_EXPLAINED, :, :].float().to(device),
                                            nsamples=N_SAMPLES)
    end = time.time()
    print(end - start)

    all_drug_gradients_summary = {'drug_id': [], 'gene': []}
    all_drug_gradients_tissue_summary = {'drug_id': [], 'tissue': [], 'importance': []}
    for target in omics_types:
        all_drug_gradients_summary[target] = []

    for drug_idx, drug_id in enumerate(drug_ids):
        # With several outputs the explainer result is indexed per output first.
        per_drug = shap_values[drug_idx] if len(drug_ids) > 1 else shap_values
        if has_tissue:
            omics_shap = per_drug[0]  # N x genes x num_omics
            tissue_shap = per_drug[1]  # N x tissue
        else:
            # Single model input: there is no per-input list to unpack.
            # NOTE(review): assumes the legacy shap layout (array per output,
            # bare array for one output) - confirm against the installed shap.
            omics_shap = per_drug  # N x genes x num_omics
            tissue_shap = None

        omics_shap_mean = np.mean(np.abs(omics_shap), axis=0)

        all_drug_gradients_summary['drug_id'].extend([drug_id] * len(genes))
        all_drug_gradients_summary['gene'].extend(genes)
        for i, omics_type in enumerate(omics_types):
            all_drug_gradients_summary[omics_type].extend(omics_shap_mean[:, i])

        if tissue_shap is not None:
            tissue_shap_mean = np.mean(np.abs(tissue_shap), axis=0)
            all_drug_gradients_tissue_summary['drug_id'].extend([drug_id] * len(tissues))
            all_drug_gradients_tissue_summary['tissue'].extend(tissues)
            all_drug_gradients_tissue_summary['importance'].extend(tissue_shap_mean)

    all_drug_gradients_summary = pd.DataFrame(all_drug_gradients_summary)
    # Total importance per gene across all omics columns (everything after drug_id, gene).
    all_drug_gradients_summary['sum'] = all_drug_gradients_summary.iloc[:, 2:].sum(axis=1)

    all_drug_gradients_tissue_summary = pd.DataFrame(all_drug_gradients_tissue_summary)

    # Drop the large references before asking CUDA to release cached memory.
    del model, explainer, train_input, targets
    torch.cuda.empty_cache()
    return all_drug_gradients_summary, all_drug_gradients_tissue_summary
198 |
199 |
# Entry point: either explain one checkpoint per drug ("all_single_mode") or
# explain the single joint model, then write gene and tissue importances.
if configs.get('all_single_mode'):
    drug_ids = pd.read_csv(configs['drug_list'], index_col=0).index.values
    all_drug_gradients_summary = []
    all_drug_gradients_tissue_summary = []
    for drug_id in tqdm(drug_ids):
        merged_df_train = pd.merge(data_input_train, data_target_train[drug_id], on=['Cell_line'])
        merged_df_test = pd.merge(data_input_test, data_target_test[drug_id], on=['Cell_line'])
        omics, tissue = run_shap(merged_df_train, merged_df_test, drug_ids=[drug_id])
        if omics is None:
            # run_shap found no checkpoint for this drug.
            continue
        all_drug_gradients_summary.append(omics)
        all_drug_gradients_tissue_summary.append(tissue)

    all_drug_gradients_summary = pd.concat(all_drug_gradients_summary)
    all_drug_gradients_tissue_summary = pd.concat(all_drug_gradients_tissue_summary)
    single_mode_dir = f"{configs['work_dir']}_{configs['saved_model']}"
    all_drug_gradients_summary.to_csv(
        f"{single_mode_dir}/shap{mode}_genes_{configs['saved_model']}.csv.gz",
        index=False)
    all_drug_gradients_tissue_summary.to_csv(
        f"{single_mode_dir}/shap{mode}_tissue_{configs['saved_model']}.csv.gz",
        index=False)
else:
    merged_df_train = pd.merge(data_input_train, data_target_train, on=['Cell_line'])
    merged_df_test = pd.merge(data_input_test, data_target_test, on=['Cell_line'])
    all_drug_gradients_summary, all_drug_gradients_tissue_summary = run_shap(
        merged_df_train, merged_df_test, drug_ids=drug_ids)
    joint_csv_name = configs['saved_model'].replace('pth', 'csv.gz')
    all_drug_gradients_summary.to_csv(
        f"{configs['work_dir']}/shap{mode}_genes_{joint_csv_name}",
        index=False)
    all_drug_gradients_tissue_summary.to_csv(
        f"{configs['work_dir']}/shap{mode}_tissue_{joint_csv_name}", index=False)
230 |
--------------------------------------------------------------------------------
/scripts/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import lr_scheduler
--------------------------------------------------------------------------------
/scripts/utils/layers_ours.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
# Names exported by ``from <this module> import *``. ``backward_hook``,
# ``RelProp`` and ``RelPropSimple`` are defined below but are not listed here.
__all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
           'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
           'LayerNorm', 'AddEye', 'Identity']
8 |
9 |
def safe_divide(a, b):
    """Divide ``a`` by ``b`` element-wise, yielding 0 wherever ``b`` is exactly 0.

    The denominator is shifted by 1e-9 (and nudged once more if that lands on
    zero) so the division itself never hits a zero denominator.
    """
    eps = 1e-9
    # clamp(min) + clamp(max) adds eps to every entry of b.
    shifted = b.clamp(min=eps) + b.clamp(max=eps)
    still_zero = shifted.eq(0).type(shifted.type())
    denominator = shifted + still_zero * eps
    nonzero_mask = b.ne(0).type(b.type())
    return a / denominator * nonzero_mask
14 |
15 |
def forward_hook(self, input, output):
    """Forward hook that caches a module's first input and its output.

    The first positional input (or every element of it, when it is a list or
    tuple) is detached and flagged as requiring grad, then stored on the module
    as ``X``; the output is stored untouched as ``Y``.
    """
    first = input[0]
    if type(first) in (list, tuple):
        cached = []
        self.X = cached
        for element in first:
            cached.append(element.detach().requires_grad_(True))
    else:
        self.X = first.detach()
        self.X.requires_grad = True

    self.Y = output
28 |
29 |
def backward_hook(self, grad_input, grad_output):
    """Store the gradients handed to a backward hook on the module itself.

    Not registered by any class in this module.
    """
    self.grad_input, self.grad_output = grad_input, grad_output
33 |
34 |
class RelProp(nn.Module):
    """Base class for layers taking part in relevance propagation.

    Construction registers ``forward_hook``, so each forward pass leaves the
    detached input in ``self.X`` and the output in ``self.Y``.
    """

    def __init__(self):
        super().__init__()
        # The hook is attached unconditionally (a training-mode guard is left disabled).
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Return the gradients of ``Z`` w.r.t. ``X`` with ``S`` as grad outputs; the graph is retained."""
        return torch.autograd.grad(Z, X, S, retain_graph=True)

    def relprop(self, R, alpha):
        """Default rule: hand the incoming relevance ``R`` back unchanged."""
        return R
47 |
48 |
class RelPropSimple(RelProp):
    """Relevance rule ``X * grad(Z, X, R / Z)`` for layers with one or two inputs."""

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        C = self.gradprop(Z, self.X, safe_divide(R, Z))

        if torch.is_tensor(self.X):
            return self.X * (C[0])
        # Non-tensor input: the hook cached a sequence; only its first two
        # entries receive relevance.
        return [self.X[0] * C[0], self.X[1] * C[1]]
62 |
63 |
class AddEye(RelPropSimple):
    """Add an identity matrix (broadcast over the leading dims) to its input."""

    # input of shape B, C, seq_len, seq_len
    def forward(self, input):
        identity = torch.eye(input.shape[2])
        return input + identity.expand_as(input).to(input.device)
68 |
69 |
class ReLU(nn.ReLU, RelProp):
    """``nn.ReLU`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
72 |
73 |
class GELU(nn.GELU, RelProp):
    """``nn.GELU`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
76 |
77 |
class Identity(nn.Identity, RelProp):
    """``nn.Identity`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
80 |
81 |
class Softmax(nn.Softmax, RelProp):
    """``nn.Softmax`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
84 |
85 |
class LayerNorm(nn.LayerNorm, RelProp):
    """``nn.LayerNorm`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
88 |
89 |
class Dropout(nn.Dropout, RelProp):
    """``nn.Dropout`` combined with ``RelProp``; relevance uses ``RelProp.relprop`` (returned unchanged)."""
92 |
93 |
class MaxPool2d(nn.MaxPool2d, RelPropSimple):
    """``nn.MaxPool2d`` whose relevance is redistributed by ``RelPropSimple.relprop``."""
96 |
97 |
class LayerNorm(nn.LayerNorm, RelProp):
    """``nn.LayerNorm`` combined with ``RelProp``.

    NOTE: this repeats the ``LayerNorm`` definition that appears earlier in this
    module with the same bases; this later binding is the one the name refers to.
    """
100 |
101 |
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
    """``nn.AdaptiveAvgPool2d`` whose relevance is redistributed by ``RelPropSimple.relprop``."""
104 |
105 |
class AvgPool2d(nn.AvgPool2d, RelPropSimple):
    """``nn.AvgPool2d`` whose relevance is redistributed by ``RelPropSimple.relprop``."""
108 |
109 |
class Add(RelPropSimple):
    """Element-wise sum of two inputs with a relevance rule that conserves ``R.sum()``."""

    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        """Split ``R`` between both operands.

        Each operand first gets ``X_i * grad``; the two results are then
        rescaled so their totals share ``R.sum()`` in proportion to the
        absolute value of each operand's summed contribution.
        """
        Z = self.forward(self.X)
        C = self.gradprop(Z, self.X, safe_divide(R, Z))

        a = self.X[0] * C[0]
        b = self.X[1] * C[1]

        a_magnitude = a.sum().abs()
        b_magnitude = b.sum().abs()
        total_relevance = R.sum()

        a_target = safe_divide(a_magnitude, a_magnitude + b_magnitude) * total_relevance
        b_target = safe_divide(b_magnitude, a_magnitude + b_magnitude) * total_relevance

        a = a * safe_divide(a_target, a.sum())
        b = b * safe_divide(b_target, b.sum())

        return [a, b]
134 |
135 |
class einsum(RelPropSimple):
    """Module wrapper around ``torch.einsum`` bound to one fixed equation string."""

    def __init__(self, equation):
        super().__init__()
        # Subscript equation applied to whatever operands forward() receives.
        self.equation = equation

    def forward(self, *operands):
        """Evaluate the stored equation on ``operands``."""
        return torch.einsum(self.equation, *operands)
143 |
144 |
class IndexSelect(RelProp):
    """``torch.index_select`` as a module that remembers ``dim`` and ``indices`` for relprop."""

    def forward(self, inputs, dim, indices):
        # Remember the selection so relprop can replay the same forward pass.
        self.dim = dim
        self.indices = indices

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        """Replay the selection on the cached input and return ``X * grad(Z, X, R / Z)``."""
        Z = self.forward(self.X, self.dim, self.indices)
        C = self.gradprop(Z, self.X, safe_divide(R, Z))

        if torch.is_tensor(self.X):
            return self.X * (C[0])
        return [self.X[0] * C[0], self.X[1] * C[1]]
164 |
165 |
class Clone(RelProp):
    """Fan one tensor out into ``num`` references and merge their relevances on the way back."""

    def forward(self, input, num):
        self.num = num
        # Every entry is the same object; nothing is copied.
        return [input for _ in range(num)]

    def relprop(self, R, alpha):
        """Combine the per-branch relevances in ``R`` into one relevance for the cached input."""
        Z = [self.X] * self.num
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        gradient = self.gradprop(Z, self.X, S)[0]

        return self.X * gradient
185 |
186 |
class Cat(RelProp):
    """``torch.cat`` as a module; relevance is split back onto each concatenated input."""

    def forward(self, inputs, dim):
        # Remember the axis so relprop can rebuild the concatenation.
        self.dim = dim
        return torch.cat(inputs, dim)

    def relprop(self, R, alpha):
        """Return one relevance tensor (``x * grad``) per cached input."""
        Z = self.forward(self.X, self.dim)
        C = self.gradprop(Z, self.X, safe_divide(R, Z))

        return [x * c for x, c in zip(self.X, C)]
202 |
203 |
class Sequential(nn.Sequential):
    """``nn.Sequential`` that can push relevance back through its children."""

    def relprop(self, R, alpha):
        """Apply each child's ``relprop`` from the last child to the first and return the result."""
        relevance = R
        for layer in list(self._modules.values())[::-1]:
            relevance = layer.relprop(relevance, alpha)
        return relevance
209 |
210 |
class BatchNorm2d(nn.BatchNorm2d, RelProp):
    """``nn.BatchNorm2d`` with a relevance rule based on its per-channel scale."""

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` through the normalisation.

        ``alpha`` is accepted for interface compatibility and is not used.
        The result is ``R * (X*scale) / (X*scale + 1e-9)``, i.e. the scale
        cancels and the output is approximately ``R`` wherever ``X*scale``
        is not tiny.
        """
        X = self.X
        # Per-channel scale reshaped to broadcast over (N, C, H, W).
        # NOTE(review): running_var is squared before eps is added; a standard
        # batch-norm scale would be weight / sqrt(running_var + eps) - confirm intent.
        weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
            (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
        Z = X * weight + 1e-9  # small offset keeps the division finite
        S = R / Z
        Ca = S * weight
        return X * (Ca)
222 |
223 |
class Linear(nn.Linear, RelProp):
    """``nn.Linear`` with an alpha-beta relevance rule (bias is ignored in relprop)."""

    def relprop(self, R, alpha):
        """Return ``alpha * activating - (alpha - 1) * inhibiting`` relevance for the cached input."""
        beta = alpha - 1
        pos_weight = self.weight.clamp(min=0)
        neg_weight = self.weight.clamp(max=0)
        pos_input = self.X.clamp(min=0)
        neg_input = self.X.clamp(max=0)

        def contribution(w1, w2, x1, x2):
            # Pair each input part with one weight part and share R over the
            # combined pre-activation of both pairings.
            Z1 = F.linear(x1, w1)
            Z2 = F.linear(x2, w2)
            S = safe_divide(R, Z1 + Z2)
            C1 = x1 * torch.autograd.grad(Z1, x1, S)[0]
            C2 = x2 * torch.autograd.grad(Z2, x2, S)[0]
            return C1 + C2

        # Same-sign pairings (w+ x+, w- x-) vs. mixed-sign pairings (w- x+, w+ x-).
        activator_relevances = contribution(pos_weight, neg_weight, pos_input, neg_input)
        inhibitor_relevances = contribution(neg_weight, pos_weight, pos_input, neg_input)

        return alpha * activator_relevances - beta * inhibitor_relevances
248 |
249 |
class Conv2d(nn.Conv2d, RelProp):
    """``nn.Conv2d`` with two relevance rules, chosen by the cached input's channel count."""

    def gradprop2(self, DY, weight):
        """Map ``DY`` back to input resolution with a transposed convolution using ``weight``."""
        Z = self.forward(self.X)

        # Rows a plain transposed conv would produce; the remainder becomes output_padding
        # (computed from dimension 2 only).
        covered = (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]
        output_padding = self.X.size()[2] - covered

        return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding,
                                  output_padding=output_padding)

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` to the cached input.

        Inputs with exactly 3 channels use a rule bounded by each sample's
        minimum and maximum value (``alpha`` unused); every other input uses
        the alpha-beta rule with ``beta = alpha - 1``.
        """
        pos_weight = torch.clamp(self.weight, min=0)
        neg_weight = torch.clamp(self.weight, max=0)

        if self.X.shape[1] == 3:
            X = self.X
            # Per-sample minimum / maximum, broadcast back to the full input shape.
            lowest = highest = X
            for axis in (1, 2, 3):
                lowest = torch.min(lowest, dim=axis, keepdim=True)[0]
                highest = torch.max(highest, dim=axis, keepdim=True)[0]
            L = X * 0 + lowest
            H = X * 0 + highest

            Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
                 torch.conv2d(L, pos_weight, bias=None, stride=self.stride, padding=self.padding) - \
                 torch.conv2d(H, neg_weight, bias=None, stride=self.stride, padding=self.padding) + 1e-9

            S = R / Za
            return (X * self.gradprop2(S, self.weight)
                    - L * self.gradprop2(S, pos_weight)
                    - H * self.gradprop2(S, neg_weight))

        beta = alpha - 1
        pos_input = torch.clamp(self.X, min=0)
        neg_input = torch.clamp(self.X, max=0)

        def contribution(w1, w2, x1, x2):
            # Unlike Linear.relprop, each pairing is normalised by its own pre-activation.
            Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
            Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
            C1 = x1 * self.gradprop(Z1, x1, safe_divide(R, Z1))[0]
            C2 = x2 * self.gradprop(Z2, x2, safe_divide(R, Z2))[0]
            return C1 + C2

        activator_relevances = contribution(pos_weight, neg_weight, pos_input, neg_input)
        inhibitor_relevances = contribution(neg_weight, pos_weight, pos_input, neg_input)

        return alpha * activator_relevances - beta * inhibitor_relevances
298 |
--------------------------------------------------------------------------------
/scripts/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
class FindLR(_LRScheduler):
    """
    Learning-rate range test: grows each group's rate exponentially from its
    base value to ``max_lr``, which is reached at step ``max_steps - 1``.

    inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
    """
    def __init__(self, optimizer, max_steps, max_lr=10):
        # Attributes must exist before the base constructor runs, because it
        # already performs an initial step that calls get_lr().
        self.max_steps = max_steps
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        progress = self.last_epoch / (self.max_steps - 1)
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr * (self.max_lr / base_lr) ** progress)
        return rates
16 |
17 |
class NoamLR(_LRScheduler):
    """
    Noam learning-rate schedule. The rate grows linearly during the first
    ``warmup_steps`` steps, reaches the base rate exactly at ``warmup_steps``
    and afterwards decays proportionally to the inverse square root of the
    step number.
    Parameters
    ----------
    warmup_steps: ``int``, required.
        The number of steps to linearly increase the learning rate.
    """
    def __init__(self, optimizer, warmup_steps):
        # Must be set before the base constructor, which already calls get_lr().
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        # Step 0 is treated like step 1 so the power below is well defined.
        step = max(1, self.last_epoch)
        decay_term = step ** (-0.5)
        warmup_term = step * self.warmup_steps ** (-1.5)
        scale = self.warmup_steps ** 0.5 * min(decay_term, warmup_term)
        return [base_lr * scale for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/scripts/utils/training_prepare.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.model_selection import train_test_split
8 |
9 |
def get_logger(config_file, STAMP):
    """Configure and return the ``"multi-drug"`` logger for one run.

    A file handler writing to ``<work_dir>/<STAMP><suffix>.log`` is added to
    the logger (``suffix`` is optional in the config). The raw config text is
    logged at INFO level and also printed to stdout.

    Args:
        config_file: Path to the JSON config; must contain ``work_dir``.
        STAMP: Run identifier used as the log file name prefix.

    Returns:
        The ``logging.Logger`` named ``"multi-drug"`` with level DEBUG.
    """
    # Read the file once and close it; the text is reused for parsing,
    # logging and printing.
    with open(config_file, "r") as config_handle:
        config_text = config_handle.read()
    configs = json.loads(config_text)

    log_suffix = configs.get("suffix", "")
    log_file = f"{STAMP}{log_suffix}.log"

    logger = logging.getLogger("multi-drug")
    logger.setLevel(logging.DEBUG)

    # Only a file handler is attached; console output is limited to the
    # print() below.
    # NOTE: every call adds another file handler to the same named logger.
    fh = logging.FileHandler(os.path.join(configs["work_dir"], log_file))
    fh.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(fh)

    logger.info(config_text)
    print(config_text)
    return logger
32 |
33 |
def prepare_data_cv(config_file, STAMP):
    """Load inputs/targets for cross-validation and bundle run metadata.

    Args:
        config_file: Path to a JSON config providing ``work_dir``,
            ``data_file``, ``target_file``, ``data_type`` and ``task``;
            ``drug_id`` and ``downsample`` are optional.
        STAMP: Run identifier; only used if the logger still has to be
            configured through ``get_logger``.

    Returns:
        dict with keys ``data_input_all``, ``data_target_all``,
        ``val_score_dict``, ``num_of_features``, ``genes``, ``omics_types``,
        ``with_tissue``, ``tissues`` and ``logger``.
    """
    with open(config_file, "r") as config_handle:
        configs = json.load(config_handle)

    # Portable replacement for shelling out to `mkdir -p`; also safe for
    # paths containing spaces or shell metacharacters.
    os.makedirs(configs["work_dir"], exist_ok=True)

    data_file = configs["data_file"]
    data_type = configs["data_type"]

    data_target = pd.read_csv(configs["target_file"], low_memory=False, index_col=0)
    if "drug_id" in configs and configs["drug_id"] != "":
        data_target = data_target[configs["drug_id"]]

    data_input = pd.read_csv(data_file, index_col=0)
    if data_type[0] != "DR":
        # Keep columns whose first or second "_"-separated token is a
        # requested data type (column names are expected to contain "_").
        data_input = data_input[
            [
                x
                for x in data_input.columns
                if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)
            ]
        ]

    if "downsample" in configs:
        # The "test" part of a stratified split is used as the downsampled set.
        _, test_idx = train_test_split(
            data_input.index,
            test_size=configs["downsample"],
            random_state=42,
            stratify=data_target,
        )
        data_input = data_input.loc[test_idx]
        data_target = data_target.loc[test_idx]

    tissues = (
        [x for x in data_input.columns if "tissue" in x]
        if "tissue" in data_type
        else None
    )
    omics_types = [x for x in data_type if x != "tissue"]
    with_tissue = "tissue" in data_type

    # Sorted unique gene names from the prefix of every non-tissue column.
    genes = np.unique(
        ([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != "tissue"])
    )

    num_of_features = data_input.shape[1]

    # Empty score table whose columns depend on the task type.
    if configs["task"] == "regression":
        score_columns = ("drug_id", "run", "epoch", "mae", "rmse", "corr", "r2")
    elif configs["task"] == "multiclass":
        score_columns = ("run", "epoch", "top1_acc", "top3_acc", "f1", "roc_auc")
    else:
        score_columns = ("drug_id", "run", "epoch", "accuracy", "auc")
    val_score_dict = {column: [] for column in score_columns}

    # The previous code returned an undefined name `logger` (NameError on
    # every call). Reuse the shared "multi-drug" logger and only configure it
    # via get_logger when nobody has attached a handler yet, which avoids
    # duplicate file handlers.
    logger = logging.getLogger("multi-drug")
    if not logger.handlers:
        logger = get_logger(config_file, STAMP)

    ret_dict = {
        "data_input_all": data_input,
        "data_target_all": data_target,
        "val_score_dict": val_score_dict,
        "num_of_features": num_of_features,
        "genes": genes,
        "omics_types": omics_types,
        "with_tissue": with_tissue,
        "tissues": tissues,
        "logger": logger,
    }
    return ret_dict
123 |
124 |
def _ordered_intersection(left, right):
    """Return the unique items of ``left`` that also occur in ``right``, in ``left`` order."""
    right_items = set(right)
    return list(dict.fromkeys(item for item in left if item in right_items))


def prepare_data_independent_test(config_file, STAMP, seed=1):
    """Load separate train/test inputs and targets aligned on shared features.

    Both input tables are reduced to the feature columns they have in common
    and to the cells that also appear in the matching target table. Column and
    row order follow the training (respectively test) file, so repeated runs
    see the same layout.

    Args:
        config_file: Path to a JSON config providing ``work_dir``,
            ``data_file_train``, ``data_file_test``, ``target_file_train``,
            ``target_file_test`` and ``data_type``; ``drug_id`` and
            ``downsample`` are optional.
        STAMP: Unused; kept for interface compatibility.
        seed: Random state for the optional training-set downsampling.

    Returns:
        dict with keys ``data_input_train``, ``data_target_train``,
        ``data_input_test``, ``data_target_test``, ``num_of_features``,
        ``genes``, ``omics_types``, ``with_tissue`` and ``tissues``.
    """
    with open(config_file, "r") as config_handle:
        configs = json.load(config_handle)

    # Portable replacement for shelling out to `mkdir -p`.
    os.makedirs(configs["work_dir"], exist_ok=True)

    data_file_train = configs["data_file_train"]
    data_file_test = configs["data_file_test"]
    data_type = configs["data_type"]

    data_target_train = pd.read_csv(
        configs["target_file_train"], low_memory=False, index_col=0
    )
    data_target_test = pd.read_csv(
        configs["target_file_test"], low_memory=False, index_col=0
    )

    if "drug_id" in configs and configs["drug_id"] != "":
        # NOTE(review): only the training targets are narrowed to the
        # configured drug; the test targets keep every column - confirm intent.
        data_target_train = data_target_train[configs["drug_id"]]

    data_input_train = pd.read_csv(data_file_train, index_col=0).fillna(0)
    data_input_test = pd.read_csv(data_file_test, index_col=0).fillna(0)

    # Order-preserving intersections: list(set(...)) made the feature and cell
    # order depend on string hashing, so it could differ from run to run.
    common_features = _ordered_intersection(
        data_input_train.columns, data_input_test.columns
    )
    common_cells_train = _ordered_intersection(
        data_input_train.index, data_target_train.index
    )
    common_cells_test = _ordered_intersection(
        data_input_test.index, data_target_test.index
    )
    data_input_train = data_input_train.loc[common_cells_train, common_features]
    data_target_train = data_target_train.loc[common_cells_train]
    data_input_test = data_input_test.loc[common_cells_test, common_features]
    data_target_test = data_target_test.loc[common_cells_test]

    if data_type[0] != "DR":
        # Keep columns whose first or second "_"-separated token is a
        # requested data type; both tables share the same columns here.
        data_input_train = data_input_train[
            [
                x
                for x in data_input_train.columns
                if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)
            ]
        ]
        data_input_test = data_input_test[
            [
                x
                for x in data_input_test.columns
                if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)
            ]
        ]

    if "downsample" in configs:
        # The "test" part of the split is used as the downsampled training set.
        _, test_idx = train_test_split(
            data_input_train.index,
            test_size=configs["downsample"],
            random_state=seed,
        )
        data_input_train = data_input_train.loc[test_idx]
        data_target_train = data_target_train.loc[test_idx]

    tissues = (
        [x for x in data_input_train.columns if "tissue" in x]
        if "tissue" in data_type
        else None
    )
    omics_types = [x for x in data_type if x != "tissue"]
    with_tissue = "tissue" in data_type

    # Sorted unique gene names from the prefix of every non-tissue column.
    genes = np.unique(
        (
            [
                x.split("_")[0]
                for x in data_input_train.columns
                if x.split("_")[0] != "tissue"
            ]
        )
    )

    num_of_features = data_input_train.shape[1]

    ret_dict = {
        "data_input_train": data_input_train,
        "data_target_train": data_target_train,
        "data_input_test": data_input_test,
        "data_target_test": data_target_test,
        "num_of_features": num_of_features,
        "genes": genes,
        "omics_types": omics_types,
        "with_tissue": with_tissue,
        "tissues": tissues,
    }
    return ret_dict
218 |
219 |
def get_score_dict(config_file):
    """Return an empty score table (column name -> empty list) for the configured task.

    Args:
        config_file: Path to a JSON config containing a ``task`` entry.
            ``"regression"`` and ``"multiclass"`` have dedicated columns; any
            other value gets the accuracy/AUC columns.

    Returns:
        dict mapping each score column to a new empty list.
    """
    # Context manager so the config file handle is closed after parsing.
    with open(config_file, "r") as config_handle:
        task = json.load(config_handle)["task"]

    if task == "regression":
        columns = ("drug_id", "run", "epoch", "mae", "rmse", "corr", "r2")
    elif task == "multiclass":
        columns = ("run", "epoch", "top1_acc", "top3_acc", "f1", "roc_auc")
    else:
        columns = ("drug_id", "run", "epoch", "accuracy", "auc")
    return {column: [] for column in columns}
250 |
--------------------------------------------------------------------------------