├── .DS_Store ├── .idea ├── .gitignore ├── DeePathNet.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml └── vcs.xml ├── R ├── mixOmics.R ├── mixOmicsTCGA.R └── moCluster.R ├── README.md ├── configs ├── .DS_Store ├── ccle_gdsc_intersection3 │ ├── .DS_Store │ └── mutation_cnv_rna │ │ ├── deepathnet_allgenes_mutation_cnv_rna.json │ │ ├── ec_en_allgenes_drug_mutation_cnv_rna.json │ │ ├── ec_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── moCluster_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── move_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── pca_rf_allgenes_drug_mutation_cnv_rna.json │ │ └── scvaeit_rf_allgenes_drug_mutation_cnv_rna.json ├── sanger_gdsc_intersection_noprot │ ├── .DS_Store │ └── mutation_cnv_rna │ │ ├── deepathnet_allgenes_mutation_cnv_rna.json │ │ ├── ec_en_allgenes_drug_mutation_cnv_rna.json │ │ ├── ec_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── moCluster_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── move_rf_allgenes_drug_mutation_cnv_rna.json │ │ ├── pca_rf_allgenes_drug_mutation_cnv_rna.json │ │ └── scvaeit_rf_allgenes_drug_mutation_cnv_rna.json ├── sanger_train_ccle_test_gdsc │ ├── .DS_Store │ ├── mutation_cnv_rna │ │ ├── deepathnet_mutation_cnv_rna.json │ │ ├── deepathnet_mutation_cnv_rna_ds_100.json │ │ ├── deepathnet_mutation_cnv_rna_ds_200.json │ │ ├── deepathnet_mutation_cnv_rna_ds_300.json │ │ ├── deepathnet_mutation_cnv_rna_ds_400.json │ │ ├── deepathnet_mutation_cnv_rna_random_control.json │ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_100.json │ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_200.json │ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_300.json │ │ ├── deepathnet_mutation_cnv_rna_random_control_ds_400.json │ │ ├── ec_mlp_all_genes_mutation_cnv_rna.json │ │ └── ec_rf_all_genes_mutation_cnv_rna.json │ └── mutation_cnv_rna_prot │ │ ├── deepathnet_mutation_cnv_rna_prot.json │ │ ├── deepathnet_mutation_cnv_rna_prot_random_control.json │ 
│ └── ec_rf_all_genes_mutation_cnv_rna_prot.json ├── tcga_all_cancer_types │ ├── .DS_Store │ └── mutation_cnv_rna │ │ ├── deepathnet_mutation_cnv_rna.json │ │ ├── deepathnet_mutation_cnv_rna_ds_2k.json │ │ ├── deepathnet_mutation_cnv_rna_random_control.json │ │ └── deepathnet_mutation_cnv_rna_random_control_ds_2k.json ├── tcga_brca_subtypes │ ├── .DS_Store │ └── mutation_cnv_rna │ │ └── deepathnet_mutation_cnv_rna.json └── tcga_train_cptac_test_brca │ └── cnv_rna │ ├── deepathnet_cnv_rna.json │ ├── deepathnet_cnv_rna_random_control.json │ └── ec_rf_cnv_rna.json ├── figures ├── Figure1.pdf ├── Figure1.png ├── supp_fig1.pdf └── supp_fig1.png ├── models ├── cancer_type_pretrained.pth └── drug_response_pretrained.pth ├── requirements.txt └── scripts ├── .DS_Store ├── baseline_ec_cv.py ├── baseline_independent_test.py ├── cancer_type_baseline_23cancertypes.py ├── cancer_type_baseline_brca.py ├── cancer_type_baseline_brca_validation.py ├── deepathnet_cv.py ├── deepathnet_independent_test.py ├── model_transformer_lrp.py ├── models.py ├── transformer_explantion_cancer_type.py ├── transformer_explantion_drug_response.py ├── transformer_shap_cancer_type.py ├── transformer_shap_drug_response.py └── utils ├── __init__.py ├── layers_ours.py ├── lr_scheduler.py └── training_prepare.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/DeePathNet.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /R/mixOmics.R: -------------------------------------------------------------------------------- 1 | library(mixOmics) 2 | library(Metrics) 3 | 4 | # input_data = "sanger" 5 | # target = "gdsc" 6 | 7 | # args = commandArgs(trailingOnly=TRUE) 8 | # 9 | # input_data = args[1] 10 | # target = args[2] 11 | # mode = args[3] 12 | 13 | input_data = "ccle" 14 | target = "ctd2" 15 | mode = "allgenes" 16 | 17 | if (input_data == "sanger"){ 18 | mutation = read.csv("./data/processed/omics/sanger_df_mutation_drug.csv.gz", 
row.names = 1) 19 | cnv = read.csv("./data/processed/omics/sanger_df_cnv_drug.csv.gz", row.names = 1) 20 | rna = read.csv("./data/processed/omics/sanger_df_rna_drug.csv.gz", row.names = 1) 21 | train_splits = read.csv("./data/meta/25splits_train_sanger.csv", stringsAsFactors = F) 22 | val_splits = read.csv("./data/meta/25splits_val_sanger.csv", stringsAsFactors = F) 23 | 24 | } else { 25 | if (target == "gdsc"){ 26 | train_splits = read.csv("./data/meta/25splits_train_ccle_gdsc.csv", stringsAsFactors = F) 27 | val_splits = read.csv("./data/meta/25splits_val_ccle_gdsc.csv", stringsAsFactors = F) 28 | mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug.csv.gz", row.names = 1) 29 | cnv = read.csv("./data/processed/omics/ccle_df_cnv_drug.csv.gz", row.names = 1) 30 | rna = read.csv("./data/processed/omics/ccle_df_rna_drug.csv.gz", row.names = 1) 31 | } else { 32 | mutation = read.csv("./data/processed/omics/ccle_df_mutation_ctd2.csv.gz", row.names = 1) 33 | cnv = read.csv("./data/processed/omics/ccle_df_cnv_ctd2.csv.gz", row.names = 1) 34 | rna = read.csv("./data/processed/omics/ccle_df_rna_ctd2.csv.gz", row.names = 1) 35 | train_splits = read.csv("./data/meta/25splits_train_ccle_ctd2.csv", stringsAsFactors = F) 36 | val_splits = read.csv("./data/meta/25splits_val_ccle_ctd2.csv", stringsAsFactors = F) 37 | } 38 | } 39 | 40 | if (target == "gdsc"){ 41 | if (input_data == "sanger"){ 42 | target_data = read.csv("./data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", row.names = 1) 43 | } else { 44 | target_data = read.csv("./data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", row.names = 1) 45 | } 46 | } else { 47 | if (input_data == "sanger"){ 48 | target_data = read.csv("./data/processed/drug/sanger_ctd2_min400.csv.gz", row.names = 1) 49 | } else { 50 | target_data = read.csv("./data/processed/drug/ccle_ctd2_min600.csv.gz", row.names = 1) 51 | } 52 | 53 | } 54 | 55 | print(input_data) 56 | print(target) 57 | 58 | drug_samples = 
rownames(target_data) 59 | target_data = sapply(target_data, as.numeric) 60 | rownames(target_data) = drug_samples 61 | 62 | lc_cancer_genes = read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F) 63 | 64 | 65 | 66 | common_samples = intersect(rownames(mutation), drug_samples) 67 | 68 | if (mode != "allgenes"){ 69 | mutation = mutation[common_samples, colnames(mutation) %in% lc_cancer_genes$cancer_genes] 70 | cnv = cnv[common_samples, colnames(cnv) %in% lc_cancer_genes$cancer_genes] 71 | rna = rna[common_samples, colnames(rna) %in% lc_cancer_genes$cancer_genes] 72 | } 73 | 74 | target_data = target_data[common_samples,] 75 | 76 | 77 | res_df <- data.frame(drug_id=character(), 78 | run=character(), 79 | corr=numeric(), 80 | r2=numeric(), 81 | mae=numeric(), 82 | rmse=numeric(), 83 | time=numeric(), 84 | stringsAsFactors=FALSE) 85 | 86 | for (i in 1:25){ 87 | train_samples = intersect(train_splits[, i], drug_samples) 88 | val_samples = intersect(val_splits[, i], drug_samples) 89 | 90 | mutation_train = mutation[rownames(mutation) %in% train_samples,] 91 | mutation_val = mutation[rownames(mutation) %in% val_samples,] 92 | 93 | cnv_train = cnv[rownames(cnv) %in% train_samples,] 94 | cnv_val = cnv[rownames(cnv) %in% val_samples,] 95 | 96 | rna_train = rna[rownames(rna) %in% train_samples,] 97 | rna_val = rna[rownames(rna) %in% val_samples,] 98 | 99 | train_data = list(mutation = mutation_train, 100 | CNV = cnv_train, 101 | RNA = rna_train) 102 | 103 | val_data = list(mutation = mutation_val, 104 | CNV = cnv_val, 105 | RNA = rna_val) 106 | 107 | start_time <- Sys.time() 108 | N_comp = 50 109 | model = block.spls(train_data, target_data[rownames(target_data) %in% train_samples, ], ncomp = N_comp) 110 | target_val = target_data[rownames(target_data) %in% val_samples, ] 111 | pred = predict(model, val_data) 112 | pred_values = pred$AveragedPredict[,,N_comp] 113 | end_time <- Sys.time() 114 | time_taken = as.numeric(difftime(end_time, start_time, units = "secs")) # fixed units: the "-" operator auto-selects secs/mins/hours per run, making the time column incomparable across runs 115 | 116 | for 
(drug_idx in 1:dim(target_data)[2]){ 117 | curr_pred = pred_values[,drug_idx] 118 | curr_target = target_val[,drug_idx] 119 | pred_mean = mean(na.omit(curr_pred)) 120 | curr_pred[is.na(curr_pred)] = pred_mean 121 | curr_pred = curr_pred[!is.na(curr_target)] 122 | curr_target = curr_target[!is.na(curr_target)] 123 | pcorr = cor(curr_pred, curr_target, method="pearson") 124 | r2 = 1 - sum((curr_target - curr_pred)^2) / sum((curr_target - mean(curr_target))^2) 125 | mae = mean(abs(curr_target - curr_pred)) 126 | rmse_drug = rmse(curr_target, curr_pred) 127 | 128 | next_row = c(colnames(target_data)[drug_idx], paste0("cv_", i-1), pcorr, r2, mae,rmse_drug, time_taken) 129 | res_df[nrow(res_df) + 1,] = next_row 130 | } 131 | 132 | print(paste0("run ", i, " completed")) 133 | flush.console() 134 | } 135 | 136 | filename = paste0("./work_dirs/mixOmics/", input_data, "_mutation_cnv_rna_", target, "_", mode, "_", N_comp, "_comp.csv") 137 | write.csv(res_df, filename, row.names = F) 138 | -------------------------------------------------------------------------------- /R/mixOmicsTCGA.R: -------------------------------------------------------------------------------- 1 | library(mixOmics) 2 | library(pROC) 3 | library(mltest) 4 | 5 | # input_data = "tcga_23_cancer_types" 6 | # input_data = "tcga_brca" 7 | 8 | # target = "all" 9 | # target = "brca_subtypes" 10 | 11 | args = commandArgs(trailingOnly=TRUE) 12 | 13 | input_data = args[1] 14 | 15 | mode = args[2] 16 | 17 | if (input_data == "tcga_23_cancer_types") { 18 | mutation = read.csv("./data/processed/omics/tcga_23_cancer_types_mutation.csv.gz", row.names = 1) 19 | cnv = read.csv("./data/processed/omics/tcga_23_cancer_types_cnv.csv.gz", row.names = 1) 20 | rna = read.csv("./data/processed/omics/tcga_23_cancer_types_rna.csv.gz", row.names = 1) 21 | target_data = read.csv("./data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", row.names = 1) 22 | train_splits = read.csv("./data/meta/25splits_train_tcga.csv", 
stringsAsFactors = F) 23 | val_splits = read.csv("./data/meta/25splits_val_tcga.csv", stringsAsFactors = F) 24 | } else { 25 | mutation = read.csv("./data/processed/omics/tcga_brca_mutation.csv.gz", row.names = 1) 26 | cnv = read.csv("./data/processed/omics/tcga_brca_cnv.csv.gz", row.names = 1) 27 | rna = read.csv("./data/processed/omics/tcga_brca_rna.csv.gz", row.names = 1) 28 | target_data = read.csv("./data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", row.names = 1) 29 | train_splits = read.csv("./data/meta/25splits_train_tcga_brca.csv", stringsAsFactors = F) 30 | val_splits = read.csv("./data/meta/25splits_val_tcga_brca.csv", stringsAsFactors = F) 31 | } 32 | 33 | mutation[is.na(mutation)] <- 0 34 | cnv[is.na(cnv)] <- 0 35 | rna[is.na(rna)] <- 0 36 | samples = rownames(target_data) 37 | target_data = sapply(target_data, as.numeric) 38 | rownames(target_data) = samples 39 | 40 | lc_cancer_genes = read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F) 41 | 42 | 43 | 44 | common_samples = intersect(rownames(mutation), samples) 45 | 46 | # mutation = mutation[common_samples, colnames(mutation) %in% lc_cancer_genes$cancer_genes] 47 | # cnv = cnv[common_samples, colnames(cnv) %in% lc_cancer_genes$cancer_genes] 48 | # rna = rna[common_samples, colnames(rna) %in% lc_cancer_genes$cancer_genes] 49 | target_data = target_data[common_samples,] 50 | 51 | 52 | res_df <- data.frame(run=character(), 53 | top1_acc=numeric(), 54 | f1=numeric(), 55 | roc_auc=numeric(), 56 | time=numeric(), 57 | stringsAsFactors=FALSE) 58 | i=1 59 | for (i in 1:25){ 60 | train_samples = intersect(train_splits[, i], samples) 61 | val_samples = intersect(val_splits[, i], samples) 62 | 63 | mutation_train = mutation[rownames(mutation) %in% train_samples,] 64 | mutation_val = mutation[rownames(mutation) %in% val_samples,] 65 | 66 | cnv_train = cnv[rownames(cnv) %in% train_samples,] 67 | cnv_val = cnv[rownames(cnv) %in% val_samples,] 68 | 69 | rna_train = rna[rownames(rna) 
%in% train_samples,] 70 | rna_val = rna[rownames(rna) %in% val_samples,] 71 | 72 | train_data = list(mutation = mutation_train, 73 | CNV = cnv_train, 74 | RNA = rna_train) 75 | 76 | val_data = list(mutation = mutation_val, 77 | CNV = cnv_val, 78 | RNA = rna_val) 79 | 80 | start_time <- Sys.time() 81 | N_comp = 50 82 | model = block.splsda(train_data, target_data[names(target_data) %in% train_samples], ncomp = N_comp) 83 | target_val = target_data[names(target_data) %in% val_samples] 84 | pred = predict(model, val_data) 85 | pred_values = as.numeric(pred$AveragedPredict.class$max.dist[,N_comp]) 86 | end_time <- Sys.time() 87 | time_taken = as.numeric(difftime(end_time, start_time, units = "secs")) # fixed units: the "-" operator auto-selects secs/mins/hours per run, making the time column incomparable across runs 88 | 89 | auc = multiclass.roc(target_val,pred_values)$auc 90 | 91 | class_levels = sort(unique(c(target_val, pred_values))) # one shared label set so factor levels are matched by value; assigning levels(y_pred) positionally relabels predictions when a class is never predicted 92 | y_true = factor(target_val, levels = class_levels) 93 | y_pred = factor(pred_values, levels = class_levels) 94 | test_res = ml_test(y_pred, y_true, output.as.table = FALSE) 95 | f1_macro = mean(na.omit(test_res$F1)) 96 | accuracy = test_res$accuracy 97 | 98 | next_row = c(paste0("cv_", i-1), accuracy, f1_macro, auc, time_taken) 99 | res_df[nrow(res_df) + 1,] = next_row 100 | 101 | print(paste0("run ", i, " completed")) 102 | flush.console() 103 | } 104 | 105 | filename = paste0("./work_dirs/mixOmics/", input_data, "_mutation_cnv_rna_" , mode, "_", N_comp, "_comp.csv") 106 | write.csv(res_df, filename, row.names = F) 107 | -------------------------------------------------------------------------------- /R/moCluster.R: -------------------------------------------------------------------------------- 1 | library(mogsa) 2 | 3 | # mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug.csv.gz", row.names = 1) 4 | # cnv = read.csv("./data/processed/omics/ccle_df_cnv_drug.csv.gz", row.names = 1) 5 | # rna = read.csv("./data/processed/omics/ccle_df_rna_drug.csv.gz", row.names = 1) 6 | 7 | # mutation = read.csv("./data/processed/omics/ccle_df_mutation_drug_ctd2.csv.gz", row.names = 1) 8 | # cnv = 
read.csv("./data/processed/omics/ccle_df_cnv_drug_ctd2.csv.gz", row.names = 1) 9 | # rna = read.csv("./data/processed/omics/ccle_df_rna_drug_ctd2.csv.gz", row.names = 1) 10 | 11 | mutation = read.csv("./data/processed/omics/tcga_23_cancer_types_mutation.csv.gz", row.names = 1) 12 | cnv = read.csv("./data/processed/omics/tcga_23_cancer_types_cnv.csv.gz", row.names = 1) 13 | rna = read.csv("./data/processed/omics/tcga_23_cancer_types_rna.csv.gz", row.names = 1) 14 | mutation[is.na(mutation)] <- 0 15 | cnv[is.na(cnv)] <- 0 16 | rna[is.na(rna)] <- 0 17 | 18 | lc_cancer_genes = read.csv("./data/meta/lc_cancer_genes.csv", stringsAsFactors = F) 19 | lc_cancer_genes$cancer_genes 20 | 21 | common_samples = rownames(mutation) 22 | 23 | # mutation = mutation[, colnames(mutation) %in% lc_cancer_genes$cancer_genes] 24 | # cnv = cnv[, colnames(cnv) %in% lc_cancer_genes$cancer_genes] 25 | # rna = rna[, colnames(rna) %in% lc_cancer_genes$cancer_genes] 26 | 27 | mutation <- sapply(mutation, as.numeric) 28 | cnv <- sapply(cnv, as.numeric) 29 | rna <- sapply(rna, as.numeric) 30 | 31 | mo.combined = list(t(mutation), t(cnv), t(rna)) 32 | # mo.combined = list(t(cnv), t(rna), t(protein)) 33 | # mo.combined = list(t(methy), t(rna), t(protein)) 34 | # mo.combined = list(t(rna), t(protein)) 35 | print(length(mo.combined)) 36 | 37 | moa <- mbpca(mo.combined, ncomp = 200, k = "all", method = "globalScore", 38 | option = "lambda1", center=TRUE, scale=FALSE, moa = TRUE, 39 | svd.solver = "fast", maxiter = 1000) 40 | 41 | res.df = as.data.frame(moa@fac.scr) 42 | res.df$Cell_line = common_samples 43 | res.df = res.df[, c(dim(res.df)[2], 1:(dim(res.df)[2]-1))] 44 | write.table(res.df, "./data/DR/moCluster/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv", sep = ",", quote = F, row.names = F) 45 | 46 | # tmp = read.csv("./data/DR/moCluster/ccle_mutation_cnv_rna_gdsc.csv", sep = ",") 47 | # tmp = tmp[, c(2:(dim(tmp)[2]-1), dim(tmp)[2])] 48 | # res.df = tmp 49 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeePathNet 2 | [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) 3 | 4 | ![Figure1](./figures/Figure1.png) 5 | 6 | Transformer-based deep learning integrates multi-omic data with cancer pathways. 7 | Cai, et al., 2023 8 | 9 | ## Overview 10 | 11 | DeePathNet is a transformer-based deep learning tool that integrates multi-omic data to improve predictions for cancer subtyping and drug response. It combines pathway-level information with deep learning to enhance precision in oncology research. 12 | 13 | ## Features 14 | - **Multi-omic data integration** using a transformer architecture 15 | - **Support for pathway-level feature importance analysis** 16 | - **Pre-trained models** for cancer type classification and drug response prediction 17 | - **Cross-validation and independent test scripts** for model evaluation 18 | 19 | ## Setting up the Coding Environment 20 | 21 | To ensure compatibility and avoid potential issues, it is recommended to use Python 3.8 and PyTorch 1.10. Below are detailed instructions to set up the coding environment: 22 | 23 | 1. **Install Anaconda** 24 | - Follow the [Anaconda installation guide](https://docs.anaconda.com/free/anaconda/install/index.html) for your operating system. 25 | 26 | 2. **Create a Virtual Environment** 27 | - Once Anaconda is installed, create and activate a virtual environment: 28 | ```bash 29 | conda create -n deepathnet_env python=3.8 anaconda 30 | conda activate deepathnet_env 31 | ``` 32 | 33 | 3. 
**Install PyTorch** 34 | - Install the appropriate version of PyTorch based on your hardware: 35 | - **For CUDA-enabled systems**: 36 | ```bash 37 | pip install torch==1.10.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 38 | ``` 39 | - **For CPU-only systems**: 40 | ```bash 41 | pip install torch==1.10.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 42 | ``` 43 | 44 | 4. **Install Additional Dependencies** 45 | - DeePathNet requires several Python packages to run effectively. Use the `requirements.txt` file at the repository root, or create one with at least the following contents: 46 | ```plaintext 47 | torch==1.10.0 48 | torchvision 49 | torchaudio 50 | numpy 51 | pandas 52 | scikit-learn 53 | matplotlib 54 | seaborn 55 | json5 56 | scipy 57 | tqdm 58 | ``` 59 | - Install all dependencies: 60 | ```bash 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | ## Loading Pre-trained DeePathNet Model(s) with Test Data 65 | 66 | DeePathNet provides pre-trained models and test datasets to facilitate model evaluation: 67 | 68 | 1. **Download Pre-trained Models and Test Data** 69 | - Access pre-trained models and test data from the [Figshare repository](https://doi.org/10.6084/m9.figshare.24137619). 70 | - Save the files to local directories, such as `models/` for models and `data/` for test data. 71 | 72 | 2. **Configure Paths for Models and Test Data** 73 | - Update the paths in the configuration file, such as: 74 | ```json 75 | { 76 | "model": "DeePathNet", 77 | "pretrained_model_path": "models/drug_response_pretrained.pth", 78 | "test_data_path": "data/test_data.csv", 79 | "output_dir": "results/", 80 | ... 81 | } 82 | ``` 83 | - Ensure that the paths in `pretrained_model_path` and `test_data_path` are correctly set to the local files. 84 | 85 | 3. 
**Load and Run Pre-trained DeePathNet Model** 86 | - DeePathNet can be run using the `deepathnet_independent_test.py` script, which loads a pre-trained model and performs inference: 87 | ```bash 88 | python scripts/deepathnet_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json 89 | ``` 90 | - This command will load the pre-trained model and run inference on the specified test dataset. The results will be saved to the designated output directory as specified in the configuration file. 91 | 92 | ## Running the Inference Step to Generate Predictions 93 | 94 | The following example demonstrates how to generate predictions using DeePathNet for various tasks: 95 | 96 | 1. **Predict Drug Response** 97 | - To predict drug response (IC50 values), run: 98 | ```bash 99 | python scripts/deepathnet_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json 100 | ``` 101 | - This script reads the configuration, loads the pre-trained model, and performs inference on the test dataset. 102 | 103 | 2. **Classify Cancer Types** 104 | - For cancer type classification: 105 | ```bash 106 | python scripts/deepathnet_cv.py configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json 107 | ``` 108 | - This script performs cross-validation using a specified dataset and configuration file. 109 | 110 | 3. **Breast Cancer Subtyping** 111 | - For breast cancer subtyping: 112 | ```bash 113 | python scripts/deepathnet_independent_test.py configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna.json 114 | ``` 115 | 116 | ### Output Description 117 | 118 | - The predictions generated by DeePathNet (e.g., IC50 values or cancer subtypes) are saved in the output directory defined in the configuration file. 119 | - The output includes performance metrics, predictions, and optional feature importance scores. 
120 | 121 | ## Running Baseline Comparisons 122 | 123 | To compare DeePathNet with baseline models like moCluster and mixOmics, use the provided scripts: 124 | 125 | 1. **moCluster Baseline Comparison** 126 | ```bash 127 | python scripts/baseline_ec_cv.py configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json 128 | ``` 129 | 2. **Cancer Type Baseline Comparison** 130 | ```bash 131 | python scripts/cancer_type_baseline_23cancertypes.py 132 | ``` 133 | 134 | ## Running Feature Importance Analysis 135 | 136 | DeePathNet supports pathway-level and gene-level feature importance analysis: 137 | 138 | 1. **Pathway-level Feature Importance** 139 | ```bash 140 | python scripts/transformer_explantion_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json 141 | ``` 142 | 143 | 2. **Gene-level Feature Importance** 144 | ```bash 145 | python scripts/transformer_shap_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json 146 | ``` 147 | 148 | ## Data Input 149 | 150 | The input files for DeePathNet should have samples as rows and features as columns. Features should be formatted with an underscore separating the gene name and the omic data type (e.g., `GeneA_RNA`). For example: 151 | 152 | | Sample | GeneA_RNA | GeneA_PROT | GeneB_RNA | GeneB_PROT | 153 | |------------|-----------|------------|-----------|------------| 154 | | Cell_lineA | 10 | 8 | 2 | 3 | 155 | | Cell_lineB | 15 | 12 | 1 | 2 | 156 | | Cell_lineC | 5 | 3 | 10 | 8 | 157 | 158 | ## Data Output 159 | 160 | The output includes predictions such as: 161 | - Drug response (IC50 values) 162 | - Cancer types/subtypes 163 | - Feature importance scores for interpretability 164 | 165 | ## Troubleshooting and Raising Issues 166 | 167 | We recommend using the specified Python and PyTorch versions for compatibility. 
If issues arise, please open a ticket in the Issues tab with details about your setup, the steps you followed, and error logs. 168 | 169 | ## Contact 170 | 171 | For more information, please contact the study authors via the associated publication. -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/.DS_Store -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/ccle_gdsc_intersection3/.DS_Store -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "regression", 10 | "seed": 1, 11 | "num_repeat": 5, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 400, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | 
"weight_decay": 1e-5, 25 | "cancer_only": false, 26 | "lr": 3e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "suffix": "_allgenes", 31 | "saved_model": "" 32 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/ec_en_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "en", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/ec_en/mutation_cnv_rna", 9 | "data_type": ["mutation","cnv","rna"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/ccle_df_intersection3_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "params_grid": { 9 | "n_estimators": [400, 600, 800], 10 | "max_features": ["sqrt"], 11 | "n_jobs": [4] 12 | }, 13 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/ec_rf/mutation_cnv_rna", 14 | "data_type": ["mutation","cnv","rna"], 15 | "task": "regression" 16 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"data_file": "/home/scai/DeePathNet/data/DR/moCluster/ccle_mutation_cnv_rna_gdsc_allgenes.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/moCluster_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/move_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/MOVE/ccle_mutation_cnv_rna_200factor.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/move_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/pca_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/pca/ccle_mutation_cnv_rna_gdsc_allgenes.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/pca_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/ccle_gdsc_intersection3/mutation_cnv_rna/scvaeit_rf_allgenes_drug_mutation_cnv_rna.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/scVAEIT/ccle_scvaeit_latent_200factor.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/ccle_gdsc_intersection3_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/ccle_gdsc_intersection3/scvaeit_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/sanger_gdsc_intersection_noprot/.DS_Store -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "regression", 10 | "seed": 1, 11 | "num_repeat": 5, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 400, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | 
"weight_decay": 1e-5, 25 | "cancer_only": false, 26 | "lr": 1e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "suffix": "_all_genes", 31 | "saved_model": "" 32 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_en_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "en", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/ec_en/mutation_cnv_rna", 9 | "data_type": ["mutation","cnv","rna"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "params_grid": { 9 | "n_estimators": [400, 600, 800], 10 | "max_features": ["sqrt"], 11 | "n_jobs": [4] 12 | }, 13 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/ec_rf/mutation_cnv_rna", 14 | "data_type": ["mutation","cnv","rna"], 15 | "task": "regression" 16 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/moCluster_rf_allgenes_drug_mutation_cnv_rna.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/moCluster/sanger_mutation_cnv_rna_allgenes.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/moCluster_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/move_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/MOVE/sanger_mutation_cnv_rna_200factor.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/move_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/pca_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/pca/sanger_mutation_cnv_rna_allgenes.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/pca_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } 
-------------------------------------------------------------------------------- /configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/scvaeit_rf_allgenes_drug_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/DR/scVAEIT/sanger_scvaeit_latent_200factor.csv", 3 | "target_file": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "model": "rf", 5 | "num_repeat": 5, 6 | "cv": 5, 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_gdsc_intersection_noprot/scvaeit_rf/mutation_cnv_rna", 9 | "data_type": ["DR"], 10 | "task": "regression" 11 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/sanger_train_ccle_test_gdsc/.DS_Store -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | 
"data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 1, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "save_checkpoints": true, 30 | "save_scores": false, 31 | "drop_last": true, 32 | "suffix": "_DeePathNet", 33 | "saved_model": "" 34 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_100.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 100, 30 | 
"save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_ds_100", 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_200.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 200, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_ds_200", 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_300.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": 
"/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 300, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_ds_300", 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_400.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": 
"/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 400, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_ds_400", 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | 
"num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "save_checkpoints": false, 30 | "save_scores": true, 31 | "drop_last": true, 32 | "suffix": "_DeePathNet_random_control", 33 | "random_control": true, 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_100.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 100, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_random_control_ds_100", 34 
| "random_control": true, 35 | "saved_model": "" 36 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_200.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 200, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_random_control_ds_200", 34 | "random_control": true, 35 | "saved_model": "" 36 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_300.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": 
"/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 300, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_random_control_ds_300", 34 | "random_control": true, 35 | "saved_model": "" 36 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_400.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | 
"pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna", 10 | "data_type": ["mutation", "cnv", "rna"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 400, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "downsample": 400, 30 | "save_checkpoints": false, 31 | "save_scores": true, 32 | "drop_last": true, 33 | "suffix": "_DeePathNet_random_control_ds_400", 34 | "random_control": true, 35 | "saved_model": "" 36 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/ec_mlp_all_genes_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "pathway_file": "/home/scai/DeepOmicIntegrate/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "mlp", 8 | "seed": 1, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_mlp/mutation_cnv_rna", 10 | "data_type": [ 11 | "mutation", 12 | "cnv", 13 | "rna" 14 | ], 15 | "task": "regression" 16 | } 
-------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna/ec_rf_all_genes_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_noprot_drug.csv.gz", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test.csv", 6 | "model": "rf", 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_rf/mutation_cnv_rna", 9 | "data_type": [ 10 | "mutation", 11 | "cnv", 12 | "rna" 13 | ], 14 | "task": "regression" 15 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna_prot", 10 | "data_type": ["mutation", "cnv", "rna", "prot"], 11 | "task": "regression", 12 | 
"seed": 1, 13 | "num_repeat": 1, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 500, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "save_checkpoints": false, 30 | "save_scores": true, 31 | "drop_last": true, 32 | "suffix": "_DeePathNet", 33 | "saved_model": "" 34 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/deepathnet_mutation_cnv_rna_prot_random_control.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/DeePathNet/mutation_cnv_rna_prot", 10 | "data_type": ["mutation", "cnv", "rna", "prot"], 11 | "task": "regression", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 500, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | 
"save_checkpoints": false, 30 | "save_scores": true, 31 | "drop_last": true, 32 | "suffix": "_DeePathNet_random_control", 33 | "random_control": true, 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- /configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/ec_rf_all_genes_mutation_cnv_rna_prot.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/sanger_df_intersection_mutation_cnv_rna_prot_drug.csv.gz", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/drug/sanger_gdsc_intersection_noprot_wide.csv.gz", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/drug/ccle_unique_as_test_mutation_cnv_rna_prot.csv", 6 | "model": "rf", 7 | "seed": 1, 8 | "work_dir": "/home/scai/DeePathNet/work_dirs/sanger_train_ccle_test_gdsc/ec_rf/mutation_cnv_rna_prot", 9 | "data_type": [ 10 | "mutation", 11 | "cnv", 12 | "rna", 13 | "prot" 14 | ], 15 | "task": "regression" 16 | } -------------------------------------------------------------------------------- /configs/tcga_all_cancer_types/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/tcga_all_cancer_types/.DS_Store -------------------------------------------------------------------------------- /configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", 
4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "multiclass", 10 | "seed": 1, 11 | "num_repeat": 5, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 600, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | "weight_decay": 1e-5, 25 | "cancer_only": true, 26 | "lr": 1e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "suffix": "_23_types_cancer_genes", 31 | "saved_model": "" 32 | } -------------------------------------------------------------------------------- /configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_ds_2k.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", 4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "multiclass", 10 | "seed": 1, 11 | "num_repeat": 2, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 600, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | "weight_decay": 1e-5, 25 | 
"cancer_only": true, 26 | "lr": 1e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "downsample": 2000, 31 | "suffix": "_23_types_cancer_genes_ds_2k", 32 | "saved_model": "" 33 | } -------------------------------------------------------------------------------- /configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", 4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "multiclass", 10 | "seed": 1, 11 | "num_repeat": 5, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 600, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | "weight_decay": 1e-5, 25 | "cancer_only": true, 26 | "lr": 1e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "suffix": "_23_types_cancer_genes_random_control", 31 | "random_control": true, 32 | "saved_model": "" 33 | } -------------------------------------------------------------------------------- /configs/tcga_all_cancer_types/mutation_cnv_rna/deepathnet_mutation_cnv_rna_random_control_ds_2k.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": 
"/home/scai/DeePathNet/data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_cancer_only.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", 4 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_all_cancer_types/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "multiclass", 10 | "seed": 1, 11 | "num_repeat": 2, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 600, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | "weight_decay": 1e-5, 25 | "cancer_only": true, 26 | "lr": 1e-5, 27 | "save_checkpoints": false, 28 | "save_scores": true, 29 | "drop_last": true, 30 | "downsample": 2000, 31 | "suffix": "_23_types_cancer_genes_random_control_ds_2k", 32 | "random_control": true, 33 | "saved_model": "" 34 | } -------------------------------------------------------------------------------- /configs/tcga_brca_subtypes/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/configs/tcga_brca_subtypes/.DS_Store -------------------------------------------------------------------------------- /configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_mutation_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_mutation_cnv_rna_log2.csv.gz", 3 | "target_file": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", 4 | "pathway_file": 
"/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 5 | "model": "DeePathNet", 6 | "do_cv": true, 7 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_brca_subtypes/DeePathNet/mutation_cnv_rna", 8 | "data_type": ["mutation", "cnv", "rna"], 9 | "task": "multiclass", 10 | "seed": 1, 11 | "num_repeat": 5, 12 | "batch_size": 100, 13 | "num_workers": 1, 14 | "log_freq": 20, 15 | "num_of_epochs": 300, 16 | "dim": 512, 17 | "mlp_ratio": 2, 18 | "out_mlp_ratio": 8, 19 | "heads": 16, 20 | "depth": 2, 21 | "dropout": 0, 22 | "emb_dropout": 0, 23 | "pathway_dropout": 0.5, 24 | "weight_decay": 1e-5, 25 | "cancer_only": true, 26 | "lr": 1e-5, 27 | "save_checkpoints": true, 28 | "save_scores": false, 29 | "drop_last": true, 30 | "suffix": "_DeePathNet", 31 | "saved_model": "202202081127_Cancer_type.pth" 32 | } -------------------------------------------------------------------------------- /configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/DeePathNet/cnv_rna", 10 | "data_type": ["cnv", "rna"], 11 | "task": "multiclass", 12 | "seed": 1, 13 | "num_repeat": 1, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 1000, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | 
"out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "save_checkpoints": false, 30 | "save_scores": true, 31 | "drop_last": true, 32 | "suffix": "_cancer_genes", 33 | "saved_model": "" 34 | } -------------------------------------------------------------------------------- /configs/tcga_train_cptac_test_brca/cnv_rna/deepathnet_cnv_rna_random_control.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "DeePathNet", 8 | "do_cv": true, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/DeePathNet/cnv_rna", 10 | "data_type": ["cnv", "rna"], 11 | "task": "multiclass", 12 | "seed": 1, 13 | "num_repeat": 10, 14 | "batch_size": 100, 15 | "num_workers": 1, 16 | "log_freq": 20, 17 | "num_of_epochs": 1000, 18 | "dim": 512, 19 | "mlp_ratio": 2, 20 | "out_mlp_ratio": 8, 21 | "heads": 16, 22 | "depth": 2, 23 | "dropout": 0, 24 | "emb_dropout": 0, 25 | "pathway_dropout": 0.5, 26 | "weight_decay": 1e-5, 27 | "cancer_only": true, 28 | "lr": 1e-5, 29 | "save_checkpoints": false, 30 | "save_scores": true, 31 | "drop_last": true, 32 | "suffix": "_cancer_genes_random_control", 33 | "random_control": true, 34 | "saved_model": "" 35 | } -------------------------------------------------------------------------------- 
/configs/tcga_train_cptac_test_brca/cnv_rna/ec_rf_cnv_rna.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file_train": "/home/scai/DeePathNet/data/processed/omics/tcga_brca_as_validation.csv", 3 | "target_file_train": "/home/scai/DeePathNet/data/processed/cancer_type/tcga_brca_mutation_cnv_rna_subtypes.csv", 4 | "data_file_test": "/home/scai/DeePathNet/data/processed/omics/cptac_as_validation.csv", 5 | "target_file_test": "/home/scai/DeePathNet/data/processed/cancer_type/cptac_brca_cnv_rna_subtypes_independent.csv", 6 | "pathway_file": "/home/scai/DeePathNet/data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv", 7 | "model": "rf", 8 | "seed": 1, 9 | "work_dir": "/home/scai/DeePathNet/work_dirs/tcga_train_cptac_test_brca/ec_rf/cnv_rna", 10 | "data_type": [ 11 | "cnv", 12 | "rna" 13 | ], 14 | "task": "multiclass" 15 | } -------------------------------------------------------------------------------- /figures/Figure1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/Figure1.pdf -------------------------------------------------------------------------------- /figures/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/Figure1.png -------------------------------------------------------------------------------- /figures/supp_fig1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/supp_fig1.pdf -------------------------------------------------------------------------------- /figures/supp_fig1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/figures/supp_fig1.png -------------------------------------------------------------------------------- /models/cancer_type_pretrained.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d328d6b7a29d640eb521dfe05aa4ec59c4ab14737ad0adb6afd55b44087245ff 3 | size 92638145 4 | -------------------------------------------------------------------------------- /models/drug_response_pretrained.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:083f2286bf1f9807d07a370bab3b515a18a84cffb0628d109cd128240ddfd4d8 3 | size 123718849 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | pandas 5 | scikit-learn 6 | matplotlib 7 | seaborn 8 | scipy 9 | tqdm 10 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CMRI-ProCan/DeePathNet/7aa373c8ef4d8c8873bf719e2e63363eff05e6dc/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/baseline_ec_cv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run the baseline model for the early concatenation methods using cross-validation for drug response prediction. 3 | E.g. 
python scripts/baseline_ec_cv.py configs/sanger_gdsc_intersection_noprot/mutation_cnv_rna/ec_rf_allgenes_drug_mutation_cnv_rna.json 4 | """ 5 | 6 | import json 7 | import logging 8 | import os 9 | import sys 10 | import warnings 11 | from datetime import datetime 12 | from time import time 13 | 14 | import numpy as np 15 | import pandas as pd 16 | from scipy.stats import pearsonr 17 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier 18 | from sklearn.linear_model import ElasticNet 19 | from sklearn.linear_model import LinearRegression, LogisticRegression 20 | from sklearn.metrics import ( 21 | mean_squared_error, 22 | r2_score, 23 | mean_absolute_error, 24 | roc_auc_score, 25 | accuracy_score, 26 | ) 27 | from sklearn.model_selection import KFold 28 | from sklearn.neural_network import MLPRegressor, MLPClassifier 29 | from sklearn.svm import SVR, SVC 30 | from tqdm import trange 31 | from xgboost import XGBClassifier, XGBRegressor 32 | 33 | warnings.filterwarnings(action="ignore", category=UserWarning) 34 | 35 | STAMP = datetime.today().strftime("%Y%m%d%H%M") 36 | OUTPUT_NA_NUM = -100 37 | 38 | config_file = sys.argv[1] 39 | 40 | # load model configs 41 | configs = json.load(open(config_file, "r")) 42 | 43 | log_suffix = f"{config_file.split('/')[-1].replace('.json', '')}" 44 | if not os.path.isdir(configs["work_dir"]): 45 | os.system(f"mkdir -p {configs['work_dir']}") 46 | 47 | data_file = configs["data_file"] 48 | target_file = configs["target_file"] 49 | data_type = configs["data_type"] 50 | 51 | log_file = f"{STAMP}_{log_suffix}.log" 52 | logger = logging.getLogger("baseline_ec") 53 | logger.setLevel(logging.DEBUG) 54 | fh = logging.FileHandler(os.path.join(configs["work_dir"], log_file)) 55 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 56 | fh.setFormatter(formatter) 57 | logger.addHandler(fh) 58 | 59 | logger.info(open(config_file, "r").read()) 60 | print(open(config_file, "r").read()) 61 | 
62 | seed = configs["seed"] 63 | cv = KFold(n_splits=configs["cv"], shuffle=True, random_state=seed) 64 | 65 | if configs["task"].lower() == "classification": 66 | model_dict = { 67 | "lr": LogisticRegression(n_jobs=-1, solver="saga"), 68 | "rf": RandomForestClassifier(n_jobs=40, max_features="sqrt"), 69 | "svm": SVC(), 70 | "en": ElasticNet(), 71 | "svm-linear": SVC(kernel="linear"), 72 | "mlp": MLPClassifier(), 73 | "xgb": XGBClassifier(), 74 | } 75 | 76 | else: 77 | model_dict = { 78 | "lr": LinearRegression(), 79 | "rf": RandomForestRegressor(n_jobs=40, max_features="sqrt"), 80 | "svm": SVR(), 81 | "en": ElasticNet(), 82 | "svm-linear": SVR(kernel="linear"), 83 | "mlp": MLPRegressor(), 84 | "xgb": XGBRegressor(), 85 | } 86 | 87 | data_target = pd.read_csv(target_file, index_col=0) 88 | 89 | data_input = pd.read_csv(data_file, index_col=0) 90 | genes = np.unique(([x.split("_")[0] for x in data_input.columns])) 91 | if "pathway_file" in configs: 92 | pathway_dict = {} 93 | pathway_df = pd.read_csv(configs["pathway_file"]) 94 | if "min_cancer_publication" in configs: 95 | pathway_df = pathway_df[ 96 | pathway_df["Cancer_Publications"] > configs["min_cancer_publication"] 97 | ] 98 | logger.info( 99 | f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}" 100 | ) 101 | if "max_gene_num" in configs: 102 | pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]] 103 | logger.info(f"Filtering pathway with GeneNumber < {configs['max_gene_num']}") 104 | if "min_gene_num" in configs: 105 | pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]] 106 | logger.info(f"Filtering pathway with GeneNumber > {configs['min_gene_num']}") 107 | 108 | pathway_df["genes"] = pathway_df["genes"].map( 109 | lambda x: "|".join([gene for gene in x.split("|") if gene in genes]) 110 | ) 111 | 112 | for index, row in pathway_df.iterrows(): 113 | if row["genes"]: 114 | pathway_dict[row["name"]] = row["genes"].split("|") 115 | 
cancer_genes = set([y for x in pathway_df["genes"].values for y in x.split("|")]) 116 | data_input = data_input[ 117 | [ 118 | x 119 | for x in data_input.columns 120 | if (x.split("_")[0] in cancer_genes) or (x.split("_")[0] == "tissue") 121 | ] 122 | ] 123 | 124 | if data_type[0] != "DR": 125 | data_input = data_input[ 126 | [ 127 | x 128 | for x in data_input.columns 129 | if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type) 130 | ] 131 | ] 132 | 133 | logger.info(f"Input data shape: {data_input.shape}") 134 | logger.info(f"Target data shape: {data_target.shape}") 135 | 136 | data_input = data_input.fillna(0) 137 | data_target = data_target.fillna(OUTPUT_NA_NUM) 138 | 139 | merged_df = pd.merge(data_target, data_input, on="Cell_line") 140 | cell_lines_all = data_input.index.values 141 | num_targets = data_target.shape[1] 142 | 143 | count = 0 144 | num_repeat = 1 if "num_repeat" not in configs else configs["num_repeat"] 145 | feature_df_list = [] 146 | score_df_list = [] 147 | time_df_list = [] 148 | logger.info(f"Merged df shape: {merged_df.shape}") 149 | 150 | for n in range(num_repeat): 151 | cv = KFold(n_splits=5, shuffle=True, random_state=(seed + n)) 152 | for cell_lines_train_index, cell_lines_val_index in cv.split(cell_lines_all): 153 | start_time = time() 154 | for i in trange(num_targets): 155 | train_lines = np.array(cell_lines_all)[cell_lines_train_index] 156 | val_lines = np.array(cell_lines_all)[cell_lines_val_index] 157 | merged_df_train = merged_df[merged_df.index.isin(train_lines)] 158 | merged_df_val = merged_df[merged_df.index.isin(val_lines)] 159 | 160 | y_train = merged_df_train.iloc[:, i] 161 | X_train = merged_df_train.iloc[:, num_targets:] 162 | X_train = X_train[(y_train != OUTPUT_NA_NUM)] 163 | y_train = y_train[y_train != OUTPUT_NA_NUM] 164 | 165 | y_val = merged_df_val.iloc[:, i] 166 | X_val = merged_df_val.iloc[:, num_targets:] 167 | X_val = X_val[(y_val != OUTPUT_NA_NUM)] 168 | y_val = y_val[y_val != OUTPUT_NA_NUM] 
169 | 170 | model = model_dict[configs["model"]] 171 | 172 | model.fit(X_train, y_train) 173 | y_pred = model.predict(X_val) 174 | sign = 1 if configs["task"].lower() == "classification" else -1 175 | seconds_elapsed = time() - start_time 176 | if configs["task"].lower() == "classification": 177 | y_confs = model.predict_proba(X_val) 178 | val_auc = roc_auc_score(y_val, y_confs[:, 1]) 179 | val_acc = accuracy_score(y_val, y_pred) 180 | score_dict = { 181 | "target": merged_df_train.columns[i], 182 | "run": f"cv_{count}", 183 | "auc": val_auc, 184 | "acc": val_acc, 185 | } 186 | else: 187 | val_mae = mean_absolute_error(y_val, y_pred) 188 | val_rmse = mean_squared_error(y_val, y_pred, squared=False) 189 | val_r2 = r2_score(y_val, y_pred) 190 | val_corr = pearsonr(y_val, y_pred)[0] 191 | score_dict = { 192 | "drug_id": merged_df_train.columns[i], 193 | "run": f"cv_{count}", 194 | "mae": val_mae, 195 | "rmse": val_rmse, 196 | "r2": val_r2, 197 | "corr": val_corr, 198 | } 199 | score_df_list.append(score_dict) 200 | 201 | # record feature importance if possible 202 | end_time = time() 203 | seconds_elapsed = end_time - start_time 204 | logger.info(f"cv_{count}: {seconds_elapsed} seconds") 205 | time_df_list.append({"run": f"cv_{count}", "time": seconds_elapsed}) 206 | count += 1 207 | time_df_list = pd.DataFrame(time_df_list) 208 | logger.info(f"All finished.") 209 | score_df = pd.DataFrame(score_df_list) 210 | # logger.info(score_df.median()) 211 | time_df_list.to_csv(f"{configs['work_dir']}/time_{STAMP}_{log_suffix}.csv", index=False) 212 | 213 | if "save_scores" not in configs or configs["save_scores"]: 214 | score_df.to_csv( 215 | f"{configs['work_dir']}/scores_{STAMP}_{log_suffix}.csv", index=False 216 | ) 217 | # Select only the numeric columns 218 | numeric_columns = score_df.select_dtypes(include=["number"]) 219 | # Compute the median of the numeric columns 220 | print(numeric_columns.median()) 221 | 
-------------------------------------------------------------------------------- /scripts/baseline_independent_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run the baseline model for the early concatenation methods on the independent test set for drug response prediction. 3 | E.g. python scripts/baseline_independent_test.py configs/sanger_train_ccle_test_gdsc/mutation_cnv_rna_prot/ec_rf_all_genes_mutation_cnv_rna_prot.json 4 | """ 5 | import json 6 | import logging 7 | import os 8 | import sys 9 | from datetime import datetime 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from scipy.stats import pearsonr 14 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier 15 | from sklearn.linear_model import ElasticNet 16 | from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, roc_auc_score, accuracy_score 17 | from sklearn.neural_network import MLPRegressor, MLPClassifier 18 | from sklearn.svm import SVR, SVC 19 | from tqdm import trange 20 | 21 | STAMP = datetime.today().strftime('%Y%m%d%H%M') 22 | OUTPUT_NA_NUM = -100 23 | 24 | config_file = sys.argv[1] 25 | 26 | # load model configs 27 | configs = json.load(open(config_file, 'r')) 28 | 29 | log_suffix = f"{config_file.split('/')[-1].replace('.json', '')}" 30 | if not os.path.isdir(configs['work_dir']): 31 | os.system(f"mkdir -p {configs['work_dir']}") 32 | 33 | log_file = f"{STAMP}_{log_suffix}.log" 34 | logger = logging.getLogger('baseline_ec') 35 | logger.setLevel(logging.DEBUG) 36 | fh = logging.FileHandler(os.path.join(configs['work_dir'], log_file)) 37 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 38 | fh.setFormatter(formatter) 39 | logger.addHandler(fh) 40 | 41 | logger.info(open(config_file, 'r').read()) 42 | print(open(config_file, 'r').read()) 43 | 44 | seed = 12345 45 | 46 | if configs['task'].lower() == 'classification': 47 | model_dict = 
{'rf': RandomForestClassifier(n_jobs=100), 48 | 'svm': SVC(), 49 | 'en': ElasticNet(), 50 | 'svm-linear': SVC(kernel='linear'), 51 | 'mlp': MLPClassifier(verbose=True)} 52 | 53 | else: 54 | model_dict = {'rf': RandomForestRegressor(n_jobs=100), 55 | 'svm': SVR(), 56 | 'en': ElasticNet(), 57 | 'svm-linear': SVR(kernel='linear'), 58 | 'mlp': MLPRegressor(verbose=True)} 59 | 60 | data_target_train = pd.read_csv(configs['target_file_train'], index_col=0) 61 | data_input_train = pd.read_csv(configs['data_file_train'], index_col=0) 62 | data_target_test = pd.read_csv(configs['target_file_test'], index_col=0) 63 | data_input_test = pd.read_csv(configs['data_file_test'], index_col=0) 64 | data_type = configs['data_type'] 65 | 66 | common_features = set(data_input_train.columns).intersection(data_input_test.columns) 67 | data_input_train = data_input_train[common_features] 68 | data_input_test = data_input_test[common_features] 69 | 70 | genes = np.unique(([x.split("_")[0] for x in data_input_train.columns])) 71 | if 'pathway_file' in configs: 72 | pathway_dict = {} 73 | pathway_df = pd.read_csv(configs['pathway_file']) 74 | if 'min_cancer_publication' in configs: 75 | pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']] 76 | logger.info(f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}") 77 | if 'max_gene_num' in configs: 78 | pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']] 79 | logger.info(f"Filtering pathway with GeneNumber < {configs['max_gene_num']}") 80 | if 'min_gene_num' in configs: 81 | pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']] 82 | logger.info(f"Filtering pathway with GeneNumber > {configs['min_gene_num']}") 83 | 84 | pathway_df['genes'] = pathway_df['genes'].map( 85 | lambda x: "|".join([gene for gene in x.split('|') if gene in genes])) 86 | 87 | for index, row in pathway_df.iterrows(): 88 | pathway_dict[row['name']] = 
row['genes'].split('|') 89 | cancer_genes = set([y for x in pathway_df['genes'].values for y in x.split("|")]) 90 | data_input_train = data_input_train[ 91 | [x for x in data_input_train.columns if (x.split("_")[0] in cancer_genes) or (x.split("_")[0] == 'tissue')]] 92 | data_input_test = data_input_test[ 93 | [x for x in data_input_test.columns if (x.split("_")[0] in cancer_genes) or (x.split("_")[0] == 'tissue')]] 94 | 95 | if data_type[0] != 'DR': 96 | data_input_train = data_input_train[ 97 | [x for x in data_input_train.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]] 98 | data_input_test = data_input_test[ 99 | [x for x in data_input_test.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]] 100 | 101 | logger.info(f"Input trainning data shape: {data_input_train.shape}") 102 | logger.info(f"Input trainning target shape: {data_target_train.shape}") 103 | logger.info(f"Input test data shape: {data_input_test.shape}") 104 | logger.info(f"Input test target shape: {data_target_test.shape}") 105 | 106 | data_input_train = data_input_train.fillna(0) 107 | data_target_train = data_target_train.fillna(OUTPUT_NA_NUM) 108 | data_input_test = data_input_test.fillna(0) 109 | data_target_test = data_target_test.fillna(OUTPUT_NA_NUM) 110 | 111 | train_df = pd.merge(data_target_train, data_input_train, on='Cell_line') 112 | test_df = pd.merge(data_target_test, data_input_test, on='Cell_line') 113 | num_targets = data_target_train.shape[1] 114 | 115 | feature_df_list = [] 116 | score_df_list = [] 117 | params_df_list = [] 118 | for i in trange(num_targets): 119 | y_train = train_df.iloc[:, i] 120 | X_train = train_df.iloc[:, num_targets:] 121 | X_train = X_train[(y_train != OUTPUT_NA_NUM)] 122 | y_train = y_train[y_train != OUTPUT_NA_NUM] 123 | 124 | y_test = test_df.iloc[:, i] 125 | X_test = test_df.iloc[:, num_targets:] 126 | X_test = X_test[(y_test != OUTPUT_NA_NUM)] 127 | y_test = y_test[y_test != OUTPUT_NA_NUM] 128 | 
logger.info(f"Running {train_df.columns[i]} Train:{X_train.shape} Test:{X_test.shape}") 129 | 130 | model = model_dict[configs['model']] 131 | 132 | model.fit(X_train, y_train) 133 | y_pred = model.predict(X_test) 134 | sign = 1 if configs['task'].lower() == 'classification' else -1 135 | if configs['task'].lower() == 'classification': 136 | y_confs = model.predict_proba(X_test) 137 | val_auc = roc_auc_score(y_test, y_confs[:, 1]) 138 | val_acc = accuracy_score(y_test, y_pred) 139 | score_dict = {'target': train_df.columns[i], 'auc': val_auc, 140 | 'acc': val_acc} 141 | else: 142 | val_mae = mean_absolute_error(y_test, y_pred) 143 | val_rmse = mean_squared_error(y_test, y_pred, squared=False) 144 | val_r2 = r2_score(y_test, y_pred) 145 | val_corr = pearsonr(y_test, y_pred)[0] 146 | score_dict = {'drug_id': train_df.columns[i], 147 | 'mae': val_mae, 'rmse': val_rmse, 148 | 'r2': val_r2, 'corr': val_corr} 149 | score_df_list.append(score_dict) 150 | 151 | logger.info(f"All finished.") 152 | score_df = pd.DataFrame(score_df_list) 153 | logger.info(score_df.median()) 154 | if 'save_scores' not in configs or configs['save_scores']: 155 | score_df.to_csv(f"{configs['work_dir']}/scores_{STAMP}_{log_suffix}.csv", index=False) 156 | 157 | -------------------------------------------------------------------------------- /scripts/cancer_type_baseline_23cancertypes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run baseline models for cancer type classification using cross validation 3 | E.g. 
python cancer_type_baseline_23cancertypes.py 4 | """ 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.ensemble import RandomForestClassifier 9 | from sklearn.linear_model import LogisticRegression, RidgeClassifier 10 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score 11 | from sklearn.model_selection import KFold 12 | from sklearn.naive_bayes import GaussianNB 13 | from sklearn.neighbors import KNeighborsClassifier 14 | from sklearn.svm import SVC 15 | from sklearn.tree import DecisionTreeClassifier 16 | from sklearn.neural_network import MLPClassifier 17 | from tqdm import tqdm 18 | from xgboost import XGBClassifier 19 | import os 20 | 21 | NUM_REPEAT = 2 22 | num_fold = 5 23 | 24 | seed = 1 25 | 26 | 27 | def run_model(input_df, clf_name, data_type=("mutation", "cnv", "rna")): 28 | count = 0 29 | clf_results_df = [] 30 | if data_type[0] != "DR": 31 | input_df = input_df[ 32 | [ 33 | x 34 | for x in input_df.columns 35 | if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type) 36 | ] 37 | ] 38 | num_of_features = input_df.shape[1] 39 | cell_lines_all = input_df.index.values 40 | for n in range(NUM_REPEAT): 41 | cv = KFold(n_splits=num_fold, shuffle=True, random_state=(seed + n)) 42 | 43 | for cell_lines_train_index, cell_lines_val_index in tqdm( 44 | cv.split(cell_lines_all), total=num_fold 45 | ): 46 | train_lines = np.array(cell_lines_all)[cell_lines_train_index] 47 | val_lines = np.array(cell_lines_all)[cell_lines_val_index] 48 | 49 | merged_df_train = pd.merge( 50 | input_df[input_df.index.isin(train_lines)], target_df, on=["Cell_line"] 51 | ) 52 | 53 | val_data = input_df[input_df.index.isin(val_lines)] 54 | 55 | merged_df_val = pd.merge(val_data, target_df, on=["Cell_line"]) 56 | if clf_name == "RF": 57 | clf = RandomForestClassifier(n_jobs=40, max_features="sqrt") 58 | elif clf_name == "XGB": 59 | clf = XGBClassifier() 60 | elif clf_name == "LR": 61 | clf = LogisticRegression(n_jobs=40, solver="saga") 62 | 
elif clf_name == "KNN": 63 | clf = KNeighborsClassifier() 64 | elif clf_name == "NB": 65 | clf = GaussianNB() 66 | elif clf_name == "MLP": 67 | clf = MLPClassifier(verbose=True) 68 | elif clf_name == "DT": 69 | clf = DecisionTreeClassifier() 70 | else: 71 | raise Exception 72 | clf.fit( 73 | merged_df_train.iloc[:, :num_of_features], 74 | merged_df_train.iloc[:, num_of_features:].values.flatten(), 75 | ) 76 | y_pred = clf.predict(merged_df_val.iloc[:, :num_of_features]) 77 | y_conf = clf.predict_proba(merged_df_val.iloc[:, :num_of_features]) 78 | y_true = merged_df_val.iloc[:, num_of_features:].values.flatten() 79 | acc = accuracy_score(y_true, y_pred) 80 | f1 = f1_score(y_true, y_pred, average="macro") 81 | try: 82 | auc = roc_auc_score(y_true, y_conf, multi_class="ovo") 83 | except: 84 | auc = np.nan 85 | 86 | clf_results_df.append( 87 | {"run": f"cv_{count}", "acc": acc, "f1": f1, "roc_auc": auc} 88 | ) 89 | 90 | count += 1 91 | clf_results_df = pd.DataFrame(clf_results_df) 92 | return clf_results_df 93 | 94 | 95 | # input_df = pd.read_csv( 96 | # "../data/processed/omics/tcga_23_cancer_types_mutation_cnv_rna_union.csv", 97 | # index_col=0, 98 | # ) 99 | 100 | # genes = np.unique( 101 | # ([x.split("_")[0] for x in input_df.columns if x.split("_")[0] != "tissue"]) 102 | # ) 103 | 104 | target_df = pd.read_csv( 105 | "../data/processed/cancer_type/tcga_23_cancer_types_mutation_cnv_rna.csv", 106 | index_col=0, 107 | ) 108 | 109 | # pathway_dict = {} 110 | # pathway_df = pd.read_csv( 111 | # "../data/graph_predefined/LCPathways/41568_2020_240_MOESM4_ESM.csv" 112 | # ) 113 | 114 | # pathway_df["genes"] = pathway_df["genes"].map( 115 | # lambda x: "|".join([gene for gene in x.split("|") if gene in genes]) 116 | # ) 117 | # # pathway_df = pathway_df[pathway_df['Cancer_Publications'] > 50] 118 | 119 | # for index, row in pathway_df.iterrows(): 120 | # if row["genes"]: 121 | # pathway_dict[row["name"]] = row["genes"].split("|") 122 | 123 | # cancer_genes = list(set([y 
for x in pathway_df["genes"].values for y in x.split("|")])) 124 | # non_cancer_genes = sorted(set(genes) - set(cancer_genes)) 125 | 126 | # class_name_to_id = dict( 127 | # zip( 128 | # sorted(target_df.iloc[:, 0].unique()), 129 | # list(range(target_df.iloc[:, 0].unique().size)), 130 | # ) 131 | # ) 132 | # id_to_class_name = dict( 133 | # zip( 134 | # list(range(target_df.iloc[:, 0].unique().size)), 135 | # sorted(target_df.iloc[:, 0].unique()), 136 | # ) 137 | # ) 138 | 139 | # input_df_cancergenes = input_df[ 140 | # [x for x in input_df.columns if (x.split("_")[0] in cancer_genes)] 141 | # ] 142 | 143 | # input_df_cancergenes = input_df_cancergenes.fillna(0) 144 | # input_df = input_df.fillna(0) 145 | 146 | dir_path = "../results/tcga_all_cancer_types/" 147 | 148 | if not os.path.exists(dir_path): 149 | os.makedirs(dir_path) 150 | 151 | # %% All Cancer types 152 | # print("Running LR") 153 | # lr_results_df = run_model(input_df, "LR") 154 | # lr_results_df.to_csv("../results/tcga_all_cancer_types/lr_results_allgenes.csv", 155 | # index=False) 156 | # print("Running MLP") 157 | # mlp_results_df = run_model(input_df_cancergenes, "MLP") 158 | # mlp_results_df.to_csv("../results/tcga_all_cancer_types/mlp_results.csv", 159 | # index=False) 160 | # print("Running RF") 161 | # rf_results_df = run_model(input_df, "RF") 162 | # rf_results_df.to_csv("../results/tcga_all_cancer_types/rf_mutation_cnv_rna_allgenes_results.csv", 163 | # index=False) 164 | # print("Running KNN") 165 | # knn_results_df = run_model(input_df, "KNN") 166 | # knn_results_df.to_csv("../results/tcga_all_cancer_types/knn_mutation_cnv_rna_allgenes.csv", 167 | # index=False) 168 | # print("Running DR") 169 | # dt_results_df = run_model(input_df, "DT") 170 | # dt_results_df.to_csv("../results/tcga_all_cancer_types/dt_mutation_cnv_rna_allgenes.csv", 171 | # index=False) 172 | # print("Running PCA") 173 | # pca_input_df = pd.read_csv( 174 | # 
"../data/DR/pca/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv", 175 | # index_col=0) 176 | # rf_pca_results_df = run_model(pca_input_df, "RF", data_type = ['DR']) 177 | # rf_pca_results_df.to_csv("../results/tcga_all_cancer_types/pca_rf_mutation_cnv_rna_allgenes.csv", 178 | # index=False) 179 | # print("Running moCluster") 180 | # moCluster_input_df = pd.read_csv( 181 | # "../data/DR/moCluster/tcga_23_cancer_types_mutation_cnv_rna_allgenes.csv", 182 | # index_col=0, 183 | # ) 184 | # moCluster_results_df = run_model(moCluster_input_df, "RF", data_type=["DR"]) 185 | # moCluster_results_df.to_csv( 186 | # "../results/tcga_all_cancer_types/moCluster_rf_mutation_cnv_rna_allgenes.csv", 187 | # index=False, 188 | # ) 189 | 190 | move_input_df = pd.read_csv( 191 | "../data/DR/MOVE/tcga_mutation_cnv_rna_200factor.csv", index_col=0 192 | ) 193 | move_results_df = run_model(move_input_df, "RF", data_type=["DR"]) 194 | move_results_df.to_csv( 195 | "../results/tcga_all_cancer_types/move_rf_mutation_cnv_rna_200factor.csv", 196 | index=False, 197 | ) 198 | 199 | scvaeit_input_df = pd.read_csv( 200 | "../data/DR/scVAEIT/tcga_scvaeit_latent_200factor.csv", index_col=0 201 | ) 202 | scvaeit_results_df = run_model(scvaeit_input_df, "RF", data_type=["DR"]) 203 | scvaeit_results_df.to_csv( 204 | "../results/tcga_all_cancer_types/scvaeit_rf_mutation_cnv_rna_200factor.csv", 205 | index=False, 206 | ) 207 | -------------------------------------------------------------------------------- /scripts/cancer_type_baseline_brca.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run baseline models for breast cancer subtype classification using cross validation 3 | E.g. 
def run_model(input_df, clf_name, data_type=("mutation", "cnv", "rna")):
    """Score one baseline classifier with repeated, shuffled K-fold cross validation.

    Relies on the module-level ``target_df`` (labels), ``NUM_REPATE`` (number of
    repeats), ``num_fold`` (folds per repeat) and ``seed`` (base random state;
    repeat ``n`` uses ``seed + n``).

    Args:
        input_df: Feature table. Its index supplies the samples that are split
            into folds, and it is joined to ``target_df`` on ``"Cell_line"``.
        clf_name: One of ``"RF"``, ``"XGB"``, ``"LR"``, ``"ridge"``, ``"KNN"``,
            ``"SVML"``, ``"SVM"``, ``"NB"`` or ``"DT"``.
        data_type: Omics layers to keep. A column is kept when the first or the
            second ``"_"``-separated token of its name is listed here. Pass a
            sequence whose first item is ``"DR"`` to use every column unchanged.

    Returns:
        A DataFrame with one row per evaluated fold and the columns ``run``,
        ``acc``, ``f1`` and ``roc_auc``.

    Raises:
        ValueError: If ``clf_name`` is not a supported classifier name.
    """
    classifier_factories = {
        "RF": lambda: RandomForestClassifier(n_jobs=40, max_features="sqrt"),
        "XGB": XGBClassifier,
        "LR": lambda: LogisticRegression(n_jobs=40, solver="saga"),
        "ridge": RidgeClassifier,
        "KNN": KNeighborsClassifier,
        "SVML": lambda: SVC(probability=True, kernel="linear"),
        # NOTE(review): "SVM" is configured exactly like "SVML" (linear kernel).
        # Kept as-is to preserve published results; confirm whether a
        # non-linear kernel was intended.
        "SVM": lambda: SVC(probability=True, kernel="linear"),
        "NB": GaussianNB,
        "DT": DecisionTreeClassifier,
    }
    # Fail fast, before any data is merged, with a message naming the options.
    if clf_name not in classifier_factories:
        raise ValueError(
            f"Unknown clf_name {clf_name!r}; "
            f"expected one of {sorted(classifier_factories)}"
        )
    make_classifier = classifier_factories[clf_name]

    if data_type[0] != "DR":

        def is_selected(column):
            # Columns without an underscore have no second token; treat that
            # as "no match" instead of raising IndexError.
            tokens = column.split("_")
            return tokens[0] in data_type or (
                len(tokens) > 1 and tokens[1] in data_type
            )

        input_df = input_df[[col for col in input_df.columns if is_selected(col)]]

    num_of_features = input_df.shape[1]
    cell_lines_all = np.asarray(input_df.index.values)

    fold_scores = []
    count = 0
    for n in range(NUM_REPATE):
        cv = KFold(n_splits=num_fold, shuffle=True, random_state=(seed + n))

        for train_index, val_index in tqdm(cv.split(cell_lines_all), total=num_fold):
            train_lines = cell_lines_all[train_index]
            val_lines = cell_lines_all[val_index]

            merged_df_train = pd.merge(
                input_df[input_df.index.isin(train_lines)], target_df, on=["Cell_line"]
            )
            merged_df_val = pd.merge(
                input_df[input_df.index.isin(val_lines)], target_df, on=["Cell_line"]
            )

            # After the merge the first ``num_of_features`` columns are the
            # features; everything after them comes from ``target_df``.
            x_train = merged_df_train.iloc[:, :num_of_features]
            y_train = merged_df_train.iloc[:, num_of_features:].values.flatten()
            x_val = merged_df_val.iloc[:, :num_of_features]
            y_true = merged_df_val.iloc[:, num_of_features:].values.flatten()

            clf = make_classifier()
            clf.fit(x_train, y_train)
            y_pred = clf.predict(x_val)
            y_conf = clf.predict_proba(x_val)

            fold_scores.append(
                {
                    "run": f"cv_{count}",
                    "acc": accuracy_score(y_true, y_pred),
                    "f1": f1_score(y_true, y_pred, average="macro"),
                    "roc_auc": roc_auc_score(y_true, y_conf, multi_class="ovo"),
                }
            )
            count += 1

    return pd.DataFrame(fold_scores)
row["genes"]: 121 | pathway_dict[row["name"]] = row["genes"].split("|") 122 | 123 | cancer_genes = list(set([y for x in pathway_df["genes"].values for y in x.split("|")])) 124 | non_cancer_genes = sorted(set(genes) - set(cancer_genes)) 125 | 126 | class_name_to_id = dict( 127 | zip( 128 | sorted(target_df.iloc[:, 0].unique()), 129 | list(range(target_df.iloc[:, 0].unique().size)), 130 | ) 131 | ) 132 | id_to_class_name = dict( 133 | zip( 134 | list(range(target_df.iloc[:, 0].unique().size)), 135 | sorted(target_df.iloc[:, 0].unique()), 136 | ) 137 | ) 138 | 139 | input_df_cancergenes = input_df[ 140 | [x for x in input_df.columns if (x.split("_")[0] in cancer_genes)] 141 | ] 142 | 143 | input_df_cancergenes = input_df_cancergenes.fillna(0) 144 | input_df = input_df.fillna(0) 145 | num_fold = 5 146 | 147 | NUM_REPATE = 5 148 | 149 | seed = 1 150 | 151 | # %% BRCA 152 | # print("Running LR") 153 | # lr_results_df = run_model(input_df, "LR") 154 | # lr_results_df.to_csv( 155 | # "../results/tcga_brca_subtype/lr_mutation_cnv_rna_allgenes.csv", index=False 156 | # ) 157 | # print("Running RF") 158 | # rf_results_df = run_model(input_df, "RF") 159 | # rf_results_df.to_csv( 160 | # "../results/tcga_brca_subtype/rf_mutation_cnv_rna_allgenes.csv", index=False 161 | # ) 162 | # print("Running KNN") 163 | # knn_results_df = run_model(input_df, "KNN") 164 | # knn_results_df.to_csv( 165 | # "../results/tcga_brca_subtype/knn_mutation_cnv_rna_allgenes.csv", index=False 166 | # ) 167 | # print("Running DR") 168 | # dt_results_df = run_model(input_df, "DT") 169 | # dt_results_df.to_csv( 170 | # "../results/tcga_brca_subtype/dt_mutation_cnv_rna_allgenes.csv", index=False 171 | # ) 172 | # print("Running PCA") 173 | # pca_input_df = pd.read_csv( 174 | # "../data/DR/pca/tcga_brca_mutation_cnv_rna_allgenes.csv", index_col=0 175 | # ) 176 | # rf_pca_results_df = run_model(pca_input_df, "RF", data_type=["DR"]) 177 | # rf_pca_results_df.to_csv( 178 | # 
"../results/tcga_brca_subtype/pca_rf_mutation_cnv_rna_allgenes.csv", index=False 179 | # ) 180 | # print("Running moCluster") 181 | # moCluster_input_df = pd.read_csv( 182 | # "../data/DR/moCluster/tcga_brca_mutation_cnv_rna.csv", index_col=0 183 | # ) 184 | # moCluster_results_df = run_model(moCluster_input_df, "RF", data_type=["DR"]) 185 | # moCluster_results_df.to_csv( 186 | # "../results/tcga_brca_subtype/moCluster_rf_mutation_cnv_rna_allgenes.csv", 187 | # index=False, 188 | # ) 189 | 190 | print("Running MOVE") 191 | MOVE_input_df = pd.read_csv( 192 | "../data/DR/MOVE/tcga_brca_mutation_cnv_rna_200factor.csv", index_col=0 193 | ) 194 | MOVE_results_df = run_model(MOVE_input_df, "RF", data_type=["DR"]) 195 | MOVE_results_df.to_csv( 196 | "../results/tcga_brca_subtype/MOVE_rf_mutation_cnv_rna_allgenes.csv", index=False 197 | ) 198 | 199 | print("Running scVAEIT") 200 | scVAEIT_input_df = pd.read_csv( 201 | "../data/DR/scVAEIT/tcga_brca_scvaeit_latent_200factor.csv", index_col=0 202 | ) 203 | scVAEIT_results_df = run_model(scVAEIT_input_df, "RF", data_type=["DR"]) 204 | scVAEIT_results_df.to_csv( 205 | "../results/tcga_brca_subtype/scVAEIT_rf_mutation_cnv_rna_allgenes.csv", index=False 206 | ) 207 | -------------------------------------------------------------------------------- /scripts/cancer_type_baseline_brca_validation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run baseline models for breast cancer subtype classification using independent test set 3 | E.g. 
def run_model(input_df_train, input_df_test, clf_name, data_type=('cnv', 'rna')):
    """Train one baseline classifier on the training set and score it on the test set.

    Relies on the module-level ``target_df_train`` and ``target_df_test``; each
    is joined to its feature table on ``"Cell_line"`` and the last column of
    the merged frame is used as the label.

    Args:
        input_df_train: Training feature table.
        input_df_test: Independent test feature table.
        clf_name: One of ``"RF"``, ``"XGB"``, ``"LR"`` or ``"MLP"``.
        data_type: Omics layers to keep. A column is kept when the first or the
            second ``"_"``-separated token of its name is listed here. Pass a
            sequence whose first item is ``"DR"`` to use every column unchanged.

    Returns:
        A tuple ``(scores, per_sample)``. ``scores`` is a one-row DataFrame with
        the columns ``run``, ``acc``, ``f1`` and ``roc_auc``. ``per_sample`` is
        a DataFrame with ``y_pred``, ``y_true`` and one ``feature_<i>`` column
        per column of ``predict_proba``.

    Raises:
        ValueError: If ``clf_name`` is not a supported classifier name.
    """
    classifier_factories = {
        "RF": lambda: RandomForestClassifier(n_jobs=40, max_features='sqrt'),
        "XGB": lambda: XGBClassifier(n_jobs=60),
        "LR": lambda: LogisticRegression(n_jobs=40, solver='saga'),
        "MLP": lambda: MLPClassifier(verbose=True),
    }
    if clf_name not in classifier_factories:
        raise ValueError(
            f"Unknown clf_name {clf_name!r}; "
            f"expected one of {sorted(classifier_factories)}"
        )
    clf = classifier_factories[clf_name]()

    if data_type[0] != 'DR':

        def is_selected(column):
            # Columns without an underscore have no second token; treat that
            # as "no match" instead of raising IndexError.
            tokens = column.split("_")
            return tokens[0] in data_type or (
                len(tokens) > 1 and tokens[1] in data_type
            )

        # The same filter must be applied to both tables: previously only the
        # training table was filtered, so the test table could keep columns
        # the fitted model had never seen.
        input_df_train = input_df_train[
            [col for col in input_df_train.columns if is_selected(col)]
        ]
        input_df_test = input_df_test[
            [col for col in input_df_test.columns if is_selected(col)]
        ]

    merged_df_train = pd.merge(input_df_train, target_df_train, on=['Cell_line'])
    merged_df_val = pd.merge(input_df_test, target_df_test, on=['Cell_line'])

    x_train = merged_df_train.iloc[:, :-1]
    y_train = merged_df_train.iloc[:, -1].values.flatten()
    x_val = merged_df_val.iloc[:, :-1]
    y_true = merged_df_val.iloc[:, -1].values.flatten()

    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_val)
    y_conf = clf.predict_proba(x_val)

    # There is a single train/test split, so the run label is always "cv_0".
    clf_results_df = pd.DataFrame(
        [
            {
                'run': 'cv_0',
                'acc': accuracy_score(y_true, y_pred),
                'f1': f1_score(y_true, y_pred, average='macro'),
                'roc_auc': roc_auc_score(y_true, y_conf, multi_class='ovo'),
            }
        ]
    )

    val_res_perclass = {'y_pred': y_pred, 'y_true': y_true}
    for i in range(y_conf.shape[1]):
        val_res_perclass[f"feature_{i}"] = y_conf[:, i]

    return clf_results_df, pd.DataFrame(val_res_perclass)
zip(sorted(target_df_train.iloc[:, 0].unique()), 109 | list(range(target_df_train.iloc[:, 0].unique().size)))) 110 | id_to_class_name = dict( 111 | zip(list(range(target_df_train.iloc[:, 0].unique().size)), sorted(target_df_train.iloc[:, 0].unique()))) 112 | 113 | input_df_train = input_df_train.fillna(0) 114 | input_df_test = input_df_test.fillna(0) 115 | 116 | dir_path = "../results/tcga_brca_subtype/" 117 | 118 | if not os.path.exists(dir_path): 119 | os.makedirs(dir_path) 120 | 121 | # %% BRCA 122 | # print("Running LR") 123 | # lr_results_df, all_val_df = run_model(input_df_train, input_df_test, "LR") 124 | # all_val_df.columns = [id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x for x in 125 | # all_val_df.columns] 126 | # lr_results_df.to_csv("../results/tcga_brca_subtype/lr_cnv_rna_allgenes_results_validation.csv", 127 | # index=False) 128 | # all_val_df.to_csv(f"../results/tcga_brca_subtype/lr_all_res_cnv_rna_allgenes_results_validation.csv.gz", index=False) 129 | 130 | print("Running MLP") 131 | mlp_results_df, all_val_df = run_model(input_df_train_cancergenes, input_df_test_cancergenes, "MLP") 132 | all_val_df.columns = [id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x for x in 133 | all_val_df.columns] 134 | mlp_results_df.to_csv("../results/tcga_brca_subtype/mlp_cnv_rna_results_validation.csv", 135 | index=False) 136 | all_val_df.to_csv(f"../results/tcga_brca_subtype/mlp_all_res_cnv_rna_results_validation.csv.gz", index=False) 137 | 138 | # print("Running RF") 139 | # rf_results_df, val_res_perclass = run_model(input_df_train, input_df_test, "RF") 140 | # all_val_df.columns = [id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x for x in 141 | # all_val_df.columns] 142 | # rf_results_df.to_csv("../results/tcga_brca_subtype/rf_cnv_rna_allgenes_results_validation.csv", 143 | # index=False) 144 | # all_val_df.to_csv(f"../results/tcga_brca_subtype/rf_all_res_cnv_rna_allgenes_results_validation.csv.gz", 
def get_setup(genes_to_id, id_to_genes, target_dim):
    """Build the DeePathNet model, loss, optimizer and LR scheduler for one run.

    Reads the module-level ``configs``, ``logger``, ``genes``, ``omics_types``,
    ``tissues`` and ``device``.

    Args:
        genes_to_id: Passed through to ``DeePathNet`` unchanged.
        id_to_genes: Passed through to ``DeePathNet`` unchanged.
        target_dim: Passed through to ``DeePathNet`` unchanged.

    Returns:
        A tuple ``(model, criterion, optimizer, lr_scheduler)``. ``criterion``
        is always ``nn.MSELoss()`` and ``lr_scheduler`` is always ``None``.
    """
    # Local import: only needed to log the model source below.
    import inspect

    def load_pathway(random_control=False):
        """Return ``(pathway_dict, non_cancer_genes)`` from ``configs["pathway_file"]``."""
        pathway_df = pd.read_csv(configs["pathway_file"])
        if "min_cancer_publication" in configs:
            pathway_df = pathway_df[
                pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
            ]
            logger.info(
                f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
            )
        if "max_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber < {configs['max_gene_num']}"
            )
        if "min_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber > {configs['min_gene_num']}"
            )

        # Build the lookup once so the per-gene membership test is O(1).
        known_genes = set(genes)
        pathway_df["genes"] = pathway_df["genes"].map(
            lambda x: "|".join([gene for gene in x.split("|") if gene in known_genes])
        )

        pathway_dict = {}
        for name, gene_string in zip(pathway_df["name"], pathway_df["genes"]):
            if gene_string:
                pathway_dict[name] = gene_string.split("|")

        # ``"".split("|")`` yields ``[""]``; skip it so a pathway with no
        # measured genes cannot add an empty "gene" to the cancer gene set.
        cancer_genes = {
            gene
            for gene_string in pathway_df["genes"].values
            for gene in gene_string.split("|")
            if gene
        }
        non_cancer_genes = known_genes - cancer_genes
        logger.info(
            f"Cancer genes:{len(cancer_genes)}\tNon-cancer genes:{len(non_cancer_genes)}"
        )
        if random_control:
            logger.info("Randomly select genes for each pathway")
            # Sort the pool: set iteration order for strings varies between
            # processes, which would make the draw irreproducible for a fixed
            # NumPy seed.
            candidates = sorted(cancer_genes)
            for key in pathway_dict:
                # Plain list, matching the independent-test script.
                pathway_dict[key] = list(
                    np.random.choice(candidates, len(pathway_dict[key]), replace=False)
                )
        return pathway_dict, non_cancer_genes

    random_control = configs.get("random_control", False)
    pathway_dict, non_cancer_genes = load_pathway(random_control=random_control)
    model = DeePathNet(
        len(omics_types),
        target_dim,
        genes_to_id,
        id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        num_heads=configs["heads"],
        pathway_drop_rate=configs["pathway_dropout"],
        only_cancer_genes=configs["cancer_only"],
        tissues=tissues,
    )

    # Log the source of the module that actually defines the model. This
    # replaces a hard-coded absolute path that was opened without being closed
    # and made the script crash on any other machine.
    try:
        logger.info(inspect.getsource(inspect.getmodule(DeePathNet)))
    except (OSError, TypeError):
        logger.warning("Could not read the DeePathNet model source for logging")

    logger.info(model)
    model = model.to(device)

    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(
        model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
    )

    logger.info(optimizer)

    lr_scheduler = None

    return model, criterion, optimizer, lr_scheduler
# ---------------------------------------------------------------------------
# Cross-validation driver: load the prepared data, run every fold and save
# the scores.
# ---------------------------------------------------------------------------
data_dict = prepare_data_cv(config_file, STAMP)
data_input_all = data_dict["data_input_all"]
data_target_all = data_dict["data_target_all"]
val_score_dict = data_dict["val_score_dict"]
# The names below are read as globals by get_setup() and run_experiment().
num_of_features = data_dict["num_of_features"]
genes = data_dict["genes"]
omics_types = data_dict["omics_types"]
with_tissue = data_dict["with_tissue"]
tissues = data_dict["tissues"]
logger = data_dict["logger"]
cell_lines_all = np.asarray(data_input_all.index.values)

is_multiclass = configs["task"] == "multiclass"

class_name_to_id = None
id_to_class_name = None
if is_multiclass:
    # Class ids follow the sorted order of the labels in the first target column.
    class_names = sorted(data_target_all.iloc[:, 0].unique())
    class_name_to_id = {name: idx for idx, name in enumerate(class_names)}
    id_to_class_name = dict(enumerate(class_names))


def _build_fold_frames(train_index, val_index):
    """Return (train, validation) frames of features joined with their targets."""
    train_lines = cell_lines_all[train_index]
    val_lines = cell_lines_all[val_index]
    fold_train = pd.merge(
        data_input_all[data_input_all.index.isin(train_lines)],
        data_target_all,
        on=["Cell_line"],
    )
    fold_val = pd.merge(
        data_input_all[data_input_all.index.isin(val_lines)],
        data_target_all,
        on=["Cell_line"],
    )
    return fold_train, fold_val


count = 0
num_repeat = 1 if "num_repeat" not in configs else configs["num_repeat"]

all_val_df = []

if configs["save_checkpoints"]:
    # only run once if the purpose is for model explanation
    cv = KFold(n_splits=5, shuffle=True, random_state=seed)
    cell_lines_train_index, cell_lines_val_index = next(cv.split(cell_lines_all))
    merged_df_train, merged_df_val = _build_fold_frames(
        cell_lines_train_index, cell_lines_val_index
    )
    val_res = run_experiment(
        merged_df_train,
        merged_df_val,
        val_score_dict,
        run=f"cv_{count}",
        class_name_to_id=class_name_to_id,
    )
    all_val_df.append(val_res)
else:
    for n in range(num_repeat):
        if is_multiclass:
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=(seed + n))
            folds = cv.split(cell_lines_all, data_target_all.iloc[:, 0])
        else:
            cv = KFold(n_splits=5, shuffle=True, random_state=(seed + n))
            folds = cv.split(cell_lines_all)
        for cell_lines_train_index, cell_lines_val_index in folds:
            merged_df_train, merged_df_val = _build_fold_frames(
                cell_lines_train_index, cell_lines_val_index
            )

            # The class-coverage check only makes sense for classification.
            # It used to run for every task; for non-classification targets
            # the sets of values in the last column practically always differ,
            # so every fold was skipped.
            if is_multiclass:
                unique_train = merged_df_train.iloc[:, -1].unique()
                unique_test = merged_df_val.iloc[:, -1].unique()
                if set(unique_train) != set(unique_test):
                    logger.info("Missing class in validation fold")
                    continue

            val_res = run_experiment(
                merged_df_train,
                merged_df_val,
                val_score_dict,
                run=f"cv_{count}",
                class_name_to_id=class_name_to_id,
            )
            all_val_df.append(val_res)
            count += 1


if "save_scores" not in configs or configs["save_scores"]:
    val_score_df = pd.DataFrame(val_score_dict)
    val_score_df.to_csv(
        f"{configs['work_dir']}/scores_{STAMP}{log_suffix}.csv.gz", index=False
    )
    # NOTE(review): the nesting of this multiclass block under the
    # "save_scores" check was reconstructed from a whitespace-collapsed
    # source - confirm against the repository.
    if is_multiclass:
        if all_val_df:
            all_val_df = pd.concat(all_val_df)
            all_val_df["y_pred"] = all_val_df["y_pred"].map(id_to_class_name)
            all_val_df["y_true"] = all_val_df["y_true"].map(id_to_class_name)
            all_val_df.columns = [
                id_to_class_name[int(x.split("_")[-1])] if "feature_" in x else x
                for x in all_val_df.columns
            ]
            all_val_df.to_csv(
                f"{configs['work_dir']}/all_val_res_{STAMP}{log_suffix}.csv.gz", index=False
            )
        else:
            # pd.concat([]) raises; report the empty run instead of crashing.
            logger.warning("No fold produced validation results; nothing to save")

logger.info("Full training finished.")
def get_setup(genes_to_id, id_to_genes, target_dim, cv=0):
    """Build the DeePathNet model, loss, optimizer and LR scheduler for one run.

    Reads the module-level ``configs``, ``logger``, ``genes``, ``omics_types``,
    ``tissues``, ``device``, ``RANDOM_CONTROL``, ``STAMP`` and ``log_suffix``.

    Args:
        genes_to_id: Passed through to ``DeePathNet`` unchanged.
        id_to_genes: Passed through to ``DeePathNet`` unchanged.
        target_dim: Passed through to ``DeePathNet`` unchanged.
        cv: Run index; only used in the file name of the saved random-control
            gene assignment.

    Returns:
        A tuple ``(model, criterion, optimizer, lr_scheduler)``. ``criterion``
        is always ``nn.MSELoss()`` and ``lr_scheduler`` is always ``None``.
    """
    # Local import: only needed to log the model source below.
    import inspect

    def load_pathway(random_control=False):
        """Return ``(pathway_dict, non_cancer_genes)`` from ``configs["pathway_file"]``."""
        pathway_df = pd.read_csv(configs["pathway_file"])
        if "min_cancer_publication" in configs:
            pathway_df = pathway_df[
                pathway_df["Cancer_Publications"] > configs["min_cancer_publication"]
            ]
            logger.info(
                f"Filtering pathway with Cancer_Publications > {configs['min_cancer_publication']}"
            )
        if "max_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] < configs["max_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber < {configs['max_gene_num']}"
            )
        if "min_gene_num" in configs:
            pathway_df = pathway_df[pathway_df["GeneNumber"] > configs["min_gene_num"]]
            logger.info(
                f"Filtering pathway with GeneNumber > {configs['min_gene_num']}"
            )

        # Build the lookup once so the per-gene membership test is O(1).
        known_genes = set(genes)
        pathway_df["genes"] = pathway_df["genes"].map(
            lambda x: "|".join([gene for gene in x.split("|") if gene in known_genes])
        )

        pathway_dict = {}
        for name, gene_string in zip(pathway_df["name"], pathway_df["genes"]):
            if gene_string:
                pathway_dict[name] = gene_string.split("|")

        # ``"".split("|")`` yields ``[""]``; skip it so a pathway with no
        # measured genes cannot add an empty "gene" to the cancer gene set.
        cancer_genes = {
            gene
            for gene_string in pathway_df["genes"].values
            for gene in gene_string.split("|")
            if gene
        }
        non_cancer_genes = known_genes - cancer_genes
        logger.info(
            f"Cancer genes:{len(cancer_genes)}\tNon-cancer genes:{len(non_cancer_genes)}"
        )
        if random_control:
            logger.info("Randomly select genes for each pathway")
            # Sort the pool: set iteration order for strings varies between
            # processes, which would make the draw irreproducible even though
            # the script seeds NumPy.
            candidates = sorted(cancer_genes)
            for key in pathway_dict:
                pathway_dict[key] = list(
                    np.random.choice(candidates, len(pathway_dict[key]), replace=False)
                )
        return pathway_dict, non_cancer_genes

    pathway_dict, non_cancer_genes = load_pathway(random_control=RANDOM_CONTROL)
    if RANDOM_CONTROL:
        logger.info("Saving random control genes")
        with open(
            f"{configs['work_dir']}/random_genes_cv{cv}_{STAMP}{log_suffix}.json", "w"
        ) as f:
            json.dump(pathway_dict, f)

    model = DeePathNet(
        len(omics_types),
        target_dim,
        genes_to_id,
        id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        num_heads=configs["heads"],
        pathway_drop_rate=configs["pathway_dropout"],
        only_cancer_genes=configs["cancer_only"],
        tissues=tissues,
    )

    # Log the source of the module that actually defines the model. This
    # replaces a hard-coded absolute path that was opened without being closed
    # and made the script crash on any other machine.
    try:
        logger.info(inspect.getsource(inspect.getmodule(DeePathNet)))
    except (OSError, TypeError):
        logger.warning("Could not read the DeePathNet model source for logging")

    logger.info(model)
    model = model.to(device)

    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(
        model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"]
    )

    logger.info(optimizer)

    lr_scheduler = None

    return model, criterion, optimizer, lr_scheduler
train_dataset.genes_to_id, train_dataset.id_to_genes, target_dim, cv=cv 204 | ) 205 | 206 | val_drug_ids = merged_df_test.columns[num_of_features:] 207 | val_res = train_loop( 208 | NUM_EPOCHS, 209 | train_loader, 210 | test_loader, 211 | model, 212 | criterion, 213 | optimizer, 214 | logger, 215 | STAMP, 216 | configs, 217 | lr_scheduler, 218 | val_drug_ids, 219 | run=run, 220 | val_score_dict=val_score_dict, 221 | ) 222 | 223 | return val_res 224 | 225 | 226 | num_repeat = 1 if "num_repeat" not in configs else configs["num_repeat"] 227 | count = 0 228 | all_val_df = [] 229 | val_score_dict = get_score_dict(config_file) 230 | for n in range(num_repeat): 231 | data_dict = prepare_data_independent_test(config_file, STAMP, seed=count) 232 | data_input_train = data_dict["data_input_train"] 233 | data_target_train = data_dict["data_target_train"] 234 | data_input_test = data_dict["data_input_test"] 235 | data_target_test = data_dict["data_target_test"] 236 | num_of_features = data_dict["num_of_features"] 237 | genes = data_dict["genes"] 238 | omics_types = data_dict["omics_types"] 239 | with_tissue = data_dict["with_tissue"] 240 | tissues = data_dict["tissues"] 241 | 242 | class_name_to_id = None 243 | id_to_class_name = None 244 | if configs["task"] == "multiclass": 245 | class_name_to_id = dict( 246 | zip( 247 | sorted(data_target_train.iloc[:, 0].unique()), 248 | list(range(data_target_train.iloc[:, 0].unique().size)), 249 | ) 250 | ) 251 | id_to_class_name = dict( 252 | zip( 253 | list(range(data_target_train.iloc[:, 0].unique().size)), 254 | sorted(data_target_train.iloc[:, 0].unique()), 255 | ) 256 | ) 257 | 258 | merged_df_train = pd.merge(data_input_train, data_target_train, on=["Cell_line"]) 259 | merged_df_test = pd.merge(data_input_test, data_target_test, on=["Cell_line"]) 260 | 261 | val_res = run_experiment( 262 | merged_df_train, 263 | merged_df_test, 264 | val_score_dict, 265 | run=f"cv_{count}", 266 | class_name_to_id=class_name_to_id, 267 | cv=count, 
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Chain per-layer matrices into one joint matrix, adding the identity to each.

    Every matrix in ``all_layer_matrices`` first has an identity matrix added
    (the residual term). Starting from layer ``start_layer``, each later layer
    is then batch-multiplied on the left of the running product.

    Args:
        all_layer_matrices: Sequence of tensors. The first one fixes the batch
            size (dimension 0) and the token count (dimension 1); ``bmm`` is
            applied, so each is expected to be a batch of square matrices.
        start_layer: Index of the first layer included in the product.

    Returns:
        The accumulated product as a tensor.
    """
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]

    # Identity broadcast over the batch and moved to the inputs' device.
    identity = torch.eye(num_tokens)
    identity = identity.expand(batch_size, num_tokens, num_tokens)
    identity = identity.to(reference.device)

    with_residual = [matrix + identity for matrix in all_layer_matrices]

    joint_attention = with_residual[start_layer]
    layer_index = start_layer + 1
    while layer_index < len(with_residual):
        joint_attention = with_residual[layer_index].bmm(joint_attention)
        layer_index += 1
    return joint_attention
def forward(self, x):
    """Multi-head self-attention forward pass.

    Besides computing the output, this stores the value tensor and the
    (post-dropout) attention weights on the module, and, when the input
    requires grad, registers a hook that stores the gradient of the
    attention weights.
    """
    # Unpacking doubles as a check that x has exactly three dimensions.
    _batch, _tokens, _channels = x.shape

    fused = self.qkv(x)
    query, key, value = rearrange(fused, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.num_heads)
    self.save_v(value)

    # Scaled dot-product scores, then softmax and attention dropout.
    scores = self.matmul1([query, key]) * self.scale
    weights = self.attn_drop(self.softmax(scores))
    self.save_attn(weights)

    if x.requires_grad:
        weights.register_hook(self.save_attn_gradients)

    # Weighted sum of values, heads merged back into the channel axis.
    context = rearrange(self.matmul2([weights, value]), 'b h n d -> b n (h d)')
    return self.proj_drop(self.proj(context))
class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP) with relevance propagation.

    Residual connections are built from explicit ``Clone``/``Add`` modules so
    that ``relprop`` can walk the same graph in reverse.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        # Creation order and attribute names are part of the checkpoint format.
        self.norm1 = LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = LayerNorm(dim, eps=1e-6)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, drop=drop)

        self.add1 = Add()
        self.add2 = Add()
        self.clone1 = Clone()
        self.clone2 = Clone()

    def forward(self, x):
        """Apply ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``."""
        skip, branch = self.clone1(x, 2)
        attended = self.attn(self.norm1(branch))
        x = self.add1([skip, attended])

        skip, branch = self.clone2(x, 2)
        transformed = self.mlp(self.norm2(branch))
        return self.add2([skip, transformed])

    def relprop(self, cam, **kwargs):
        """Propagate relevance ``cam`` backwards through the block.

        The MLP residual is undone first, then the attention residual, i.e.
        the reverse of ``forward``. ``kwargs`` are passed to every layer.
        """
        skip_cam, branch_cam = self.add2.relprop(cam, **kwargs)
        branch_cam = self.mlp.relprop(branch_cam, **kwargs)
        branch_cam = self.norm2.relprop(branch_cam, **kwargs)
        cam = self.clone2.relprop((skip_cam, branch_cam), **kwargs)

        skip_cam, branch_cam = self.add1.relprop(cam, **kwargs)
        branch_cam = self.attn.relprop(branch_cam, **kwargs)
        branch_cam = self.norm1.relprop(branch_cam, **kwargs)
        return self.clone1.relprop((skip_cam, branch_cam), **kwargs)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with draws from a normal truncated to [a, b].

    Copied from PyTorch. Sampling uses the inverse-CDF method described in
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf:
    draw uniformly between the CDF values of the bounds, map through the
    inverse error function, then shift/scale and clamp.

    Returns:
        The same ``tensor`` object.
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        # Cumulative distribution function of N(0, 1).
        return (1. + math.erf(value / sqrt_two)) / 2.

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        cdf_lower = standard_normal_cdf((a - mean) / std)
        cdf_upper = standard_normal_cdf((b - mean) / std)

        # Uniform in [2l-1, 2u-1], i.e. the CDF interval mapped onto erf's range.
        tensor.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)

        # Inverse CDF gives a truncated standard normal.
        tensor.erfinv_()

        # Rescale to the requested mean and standard deviation.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)

        # Guard against values that drifted just outside the bounds.
        tensor.clamp_(min=a, max=b)
    return tensor
class PathwayDrop(nn.Module):
    """Token-level dropout.

    In training mode every token along dim 1 is zeroed, for the whole batch,
    with probability ``p``. The input is modified in place and no rescaling
    is applied. In eval mode the input is returned untouched.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training:
            return x

        # One Bernoulli draw per token, in token order, from NumPy's global RNG.
        dropped_tokens = [
            token for token in range(x.shape[1])
            if np.random.binomial(1, self.p)
        ]
        for token in dropped_tokens:
            x[:, token, :] = 0
        return x
def _init_weights(self, m):
    """Initialise a single sub-module ``m``.

    ``nn.Linear``: truncated-normal weights (std 0.02) and zero bias when a
    bias exists. ``nn.LayerNorm``: zero bias and unit weight. Any other
    module type is left untouched.
    """
    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
        return
    if not isinstance(m, nn.Linear):
        return

    trunc_normal_(m.weight, std=.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
def relprop(self, cam=None, method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
    """Propagate relevance from the output back to the input tokens.

    The relevance ``cam`` is first pushed backwards through the head, the
    CLS pooling, the final norm and every block (which makes each attention
    layer store its relevance map). ``method`` then selects how the stored
    per-layer maps are combined.

    Args:
        cam: relevance at the model output (e.g. a one-hot row).
        method: one of
            ``"rollout"`` - head-averaged, clamped attention relevance of
            every block combined with ``compute_rollout_attention``;
            ``"transformer_attribution"`` / ``"grad"`` (legacy alias) -
            gradient-weighted attention relevance combined by rollout;
            ``"last_layer"`` / ``"second_layer"`` - relevance of
            ``blocks[-1]`` / ``blocks[1]`` only, multiplied by the attention
            gradient when ``is_ablation`` is true;
            ``"last_layer_attn"`` - raw attention of the last block.
        is_ablation: see ``"last_layer"`` / ``"second_layer"``.
        start_layer: first block included by the rollout-based methods.
        **kwargs: forwarded unchanged to every layer's ``relprop``.

    Returns:
        Relevance of every non-CLS token with respect to the CLS token.
        The rollout-based methods keep a leading batch axis; the
        single-layer methods return a 1-D tensor for the first sample.

    Raises:
        NotImplementedError: for ``method="full"``. That branch needs a
            ``patch_embed`` layer which this model does not define, so it
            previously failed with an ``AttributeError`` only after the
            whole backward pass had run.
        ValueError: for any other unknown ``method`` (previously ``None``
            was returned silently).
    """
    supported = ("rollout", "transformer_attribution", "grad",
                 "last_layer", "last_layer_attn", "second_layer")
    if method == "full":
        raise NotImplementedError(
            'relprop method "full" is not available: this model has no patch_embed layer '
            "to propagate relevance through")
    if method not in supported:
        raise ValueError(f"unknown relprop method {method!r}; expected one of {supported}")

    # Backward relevance pass; every block's attention stores its cam here.
    cam = self.head.relprop(cam, **kwargs)
    cam = cam.unsqueeze(1)
    cam = self.pool.relprop(cam, **kwargs)
    cam = self.norm.relprop(cam, **kwargs)
    for blk in reversed(self.blocks):
        cam = blk.relprop(cam, **kwargs)

    def first_sample_heads(tensor):
        # Keep only the first sample and view it as (heads, tokens, tokens).
        return tensor[0].reshape(-1, tensor.shape[-1], tensor.shape[-1])

    def single_layer_relevance(blk, weight_by_grad):
        layer_cam = first_sample_heads(blk.attn.get_attn_cam())
        if weight_by_grad:
            layer_cam = first_sample_heads(blk.attn.get_attn_gradients()) * layer_cam
        layer_cam = layer_cam.clamp(min=0).mean(dim=0)
        # Row 0 is the CLS token; drop its self-relevance.
        return layer_cam[0, 1:]

    if method == "rollout":
        attn_cams = []
        for blk in self.blocks:
            attn_heads = blk.attn.get_attn_cam().clamp(min=0)
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            attn_cams.append(avg_heads)
        rollout = compute_rollout_attention(attn_cams, start_layer=start_layer)
        return rollout[:, 0, 1:]

    if method in ("transformer_attribution", "grad"):
        cams = []
        for blk in self.blocks:
            grad = first_sample_heads(blk.attn.get_attn_gradients())
            layer_cam = grad * first_sample_heads(blk.attn.get_attn_cam())
            layer_cam = layer_cam.clamp(min=0).mean(dim=0)
            cams.append(layer_cam.unsqueeze(0))
        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        return rollout[:, 0, 1:]

    if method == "last_layer":
        return single_layer_relevance(self.blocks[-1], is_ablation)

    if method == "last_layer_attn":
        attn = first_sample_heads(self.blocks[-1].attn.get_attn())
        return attn.clamp(min=0).mean(dim=0)[0, 1:]

    # Only "second_layer" is left.
    return single_layer_relevance(self.blocks[1], is_ablation)
def get_multiomic_df(df, omics_types):
    """Pivot a wide ``<gene>_<omic>`` table into a (samples, genes, omics) array.

    Columns whose name contains ``'tissue'`` are split off and returned
    separately. Gene names are the part of a column name before the first
    underscore. Gene/omic combinations that have no column are filled with
    zeros.

    Args:
        df: DataFrame with one row per sample.
        omics_types: omic suffixes, in the order they should appear on the
            last axis. A ``'tissue'`` entry is ignored here, because tissue
            columns are not per-gene measurements.

    Returns:
        Tuple ``(array, genes_to_id, id_to_genes, tissue_values)`` where
        ``array`` has shape (n_samples, n_genes, n_omics) with genes in
        sorted order, the two dicts map gene name <-> index on axis 1, and
        ``tissue_values`` is the ndarray of the tissue columns.
    """
    gene_columns = [col for col in df.columns if 'tissue' not in col]
    tissue_df = df[[col for col in df.columns if 'tissue' in col]].values
    genes = np.unique([col.split("_")[0] for col in gene_columns])

    # Use the same omic list for column names and for the array width, so
    # that a 'tissue' entry can no longer cause a shape mismatch.
    gene_omics = [omic for omic in omics_types if omic != 'tissue']

    # Gene-major, omic-minor order: exactly the layout of the output array.
    ordered_columns = [f"{gene}_{omic}" for gene in genes for omic in gene_omics]

    # Set lookup keeps this linear in the number of columns.
    present = set(gene_columns)
    not_covered = [col for col in ordered_columns if col not in present]

    df_zeros = pd.DataFrame(np.zeros((df.shape[0], len(not_covered))), columns=not_covered, index=df.index)
    df_combined = pd.concat([df, df_zeros], axis=1)

    # One column selection + reshape instead of one selection per gene.
    df_multiomic = np.zeros((df.shape[0], len(genes), len(gene_omics)))
    df_multiomic[...] = df_combined[ordered_columns].values.reshape(df_multiomic.shape)

    genes_to_id = dict(zip(genes, range(len(genes))))
    id_to_genes = dict(zip(range(len(genes)), genes))
    return df_multiomic, genes_to_id, id_to_genes, tissue_df
def _median_target_scores(targets, confs):
    """Return the median (R2, MAE, RMSE, Pearson r) over all target columns.

    Entries whose target equals ``MISSING_NUM`` are excluded per column.
    Raises ``ValueError`` (from the metric functions) when a column does not
    have enough observed values.
    """
    r2, mae, rmse, corr = [], [], [], []
    for col in range(confs.shape[1]):
        observed = targets[:, col] != MISSING_NUM
        y_true, y_pred = targets[observed, col], confs[observed, col]
        r2.append(r2_score(y_true, y_pred))
        mae.append(mean_absolute_error(y_true, y_pred))
        # Explicit square root: a true RMSE that does not rely on the
        # `squared` keyword of mean_squared_error.
        rmse.append(np.sqrt(mean_squared_error(y_true, y_pred)))
        corr.append(pearsonr(y_true, y_pred)[0])
    return np.median(r2), np.median(mae), np.median(rmse), np.median(corr)


def train(train_loader, model, criterion, optimizer, epoch, logger):
    """Train ``model`` for one epoch on a regression task.

    Each batch is either ``(input, targets)`` or ``(input, tissue, targets)``.
    Predictions at positions where the target is ``MISSING_NUM`` are
    overwritten with ``MISSING_NUM`` before the loss is computed, so output
    and target are equal there.

    Per-batch training scores are tracked for logging only; a batch is left
    out of the score averages (but still trained on) when its predictions
    contain inf/NaN or a metric raises ``ValueError``.

    Returns:
        The mean over scored batches of the per-batch median R2.

    Raises:
        ValueError: if a batch has neither 2 nor 3 elements.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    avg_r2 = AverageMeter()
    avg_mae = AverageMeter()
    avg_rmse = AverageMeter()
    avg_corr = AverageMeter()

    model.train()

    end = time.time()
    lr_str = ''

    for data in train_loader:
        if len(data) == 2:
            input_, targets = data
            output = model(input_.float().to(device))
        elif len(data) == 3:
            input_, tissue_x, targets = data
            output = model(input_.float().to(device), tissue_x.float().to(device))
        else:
            raise ValueError(f"expected a batch of 2 or 3 elements, got {len(data)}")

        output[targets == MISSING_NUM] = MISSING_NUM

        loss = criterion(output, targets.float().to(device))
        targets = targets.cpu().numpy()

        confs = output.detach().cpu().numpy()
        if not np.isinf(confs).any() and not np.isnan(confs).any():
            try:
                r2, mae, rmse, corr = _median_target_scores(targets, confs)
            except ValueError:
                logger.info("skipping training score")
            else:
                # Update all meters together so their counts stay in sync.
                avg_r2.update(r2)
                avg_mae.update(mae)
                avg_rmse.update(rmse)
                avg_corr.update(corr)

        losses.update(loss.item(), input_.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

    logger.info(f'{epoch} \t'
                f'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'loss {losses.val:.4f} ({losses.avg:.4f})\t'
                f'corr {avg_corr.val:.4f} ({avg_corr.avg:.4f})\t'
                f'R2 {avg_r2.val:.4f} ({avg_r2.avg:.4f})\t'
                f'MAE {avg_mae.val:.4f} ({avg_mae.avg:.4f})\t'
                f'RMSE {avg_rmse.val:.4f} ({avg_rmse.avg:.4f})\t' + lr_str)

    return avg_r2.avg
''' 248 | model.eval() 249 | 250 | all_confs, all_targets = [], [] 251 | with torch.no_grad(): 252 | for i, data in enumerate(data_loader): 253 | if len(data) == 2: 254 | (input_, target) = data 255 | output = model(input_.float().to(device)) 256 | elif len(data) == 3: 257 | (input_, tissue_type, target) = data 258 | output = model(input_.float().to(device), tissue_type.float().to(device)) 259 | else: 260 | raise Exception 261 | 262 | # output = model(input_.float().to(device)) 263 | all_confs.append(output) 264 | 265 | if target is not None: 266 | all_targets.append(target) 267 | 268 | confs = torch.cat(all_confs) 269 | targets = torch.cat(all_targets) if len(all_targets) else None 270 | targets = targets.cpu().numpy() 271 | confs = confs.cpu().numpy() 272 | 273 | return confs, targets 274 | 275 | 276 | def validate(val_loader, model, val_drug_ids, run=None, epoch=None, val_score_dict=None): 277 | confs, targets = inference(val_loader, model) 278 | 279 | r2_avg, mae_avg, rmse_avg, corr_avg = None, None, None, None 280 | if not np.isinf(confs).any() and not np.isnan(confs).any(): 281 | if 'drug_id' in val_score_dict: 282 | val_score_dict['drug_id'].extend(val_drug_ids) 283 | elif 'Gene' in val_score_dict: 284 | val_score_dict['Gene'].extend(val_drug_ids) 285 | else: 286 | raise 287 | val_score_dict['run'].extend([run] * len(val_drug_ids)) 288 | val_score_dict['epoch'].extend([epoch] * len(val_drug_ids)) 289 | 290 | r2 = [r2_score(targets[targets[:, i] != MISSING_NUM, i], confs[targets[:, i] != MISSING_NUM, i]) 291 | for i in range(confs.shape[1])] 292 | r2_avg = np.median(r2) 293 | 294 | mae = [mean_absolute_error(targets[targets[:, i] != MISSING_NUM, i], confs[targets[:, i] != MISSING_NUM, i]) 295 | for i in range(confs.shape[1])] 296 | mae_avg = np.median(mae) 297 | 298 | rmse = [mean_squared_error(targets[targets[:, i] != MISSING_NUM, i], confs[targets[:, i] != MISSING_NUM, i], 299 | squared=False) 300 | for i in range(confs.shape[1])] 301 | rmse_avg = 
def _save_checkpoint(model, configs, stamp, val_drug_ids):
    """Write ``model.state_dict()`` into ``configs['work_dir']``.

    Runs with exactly one target put the sanitised target name in the file
    name; all other runs use ``configs['suffix']``.
    """
    if val_drug_ids is not None and len(val_drug_ids) == 1:
        model_path = f"{configs['work_dir']}/{stamp}_{get_model_filename(val_drug_ids[0])}.pth"
    else:
        model_path = f"{configs['work_dir']}/{stamp}{configs['suffix']}.pth"
    torch.save(model.state_dict(), model_path)


def train_loop(epochs, train_loader, val_loader, model, criterion, optimizer, logger, stamp,
               configs,
               lr_scheduler=None,
               val_drug_ids=None,
               run=None, val_score_dict=None, id_to_class_name=None):
    """Run training and per-epoch validation for ``configs['task']``.

    Supported tasks are ``'regression'``, ``'classification'`` and
    ``'multiclass'``. For the two classification tasks the ``criterion``
    argument is replaced by ``BCEWithLogitsLoss`` / ``CrossEntropyLoss``.

    When ``configs['save_checkpoints']`` is true, the model is saved each
    time the validation score improves on the best score so far. The best
    score starts at a fixed floor (R2 0.05, AUC 0.71, mean of top-1/top-3
    accuracy 0.7), so models below the floor are never saved. Epochs whose
    validation returned no score (``None``) are logged and skipped for
    checkpointing instead of raising ``TypeError``.

    Args:
        lr_scheduler: optional scheduler, stepped once per epoch.
        val_drug_ids: names of the target columns; may be ``None``.
        run, val_score_dict: passed through to the validation function.
        id_to_class_name: accepted for interface compatibility; unused here.

    Returns:
        ``None`` for regression and classification; for multiclass, the
        concatenation of the per-epoch validation frames.

    Raises:
        ValueError: if ``configs['task']`` is not a supported task.
    """
    task = configs['task']
    if task not in ('regression', 'classification', 'multiclass'):
        raise ValueError(f"unsupported task: {task!r}")

    if task == 'regression':
        best_r2 = 0.05
        for epoch in trange(1, epochs + 1):
            if lr_scheduler:
                logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
            train(train_loader, model, criterion, optimizer, epoch, logger)
            if lr_scheduler:
                lr_scheduler.step()

            r2, mae, rmse, corr = validate(val_loader, model, val_drug_ids, run=run, epoch=epoch,
                                           val_score_dict=val_score_dict)

            # `is not None` rather than truthiness: a score of 0.0 is valid.
            if all(score is not None for score in (r2, mae, rmse, corr)):
                logger.info(f"Epoch {epoch} validation corr:{corr:4f}, R2:{r2:4f}, MAE:{mae:4f}, RMSE:{rmse:4f}")
            else:
                logger.info(f"Epoch {epoch} validation Inf")
            if configs['save_checkpoints'] and r2 is not None and best_r2 < r2:
                best_r2 = r2
                _save_checkpoint(model, configs, stamp, val_drug_ids)
        return None

    if task == 'classification':
        best_auc = 0.71
        criterion = nn.BCEWithLogitsLoss()
        for epoch in trange(1, epochs + 1):
            if lr_scheduler:
                # get_last_lr() is the supported accessor, as in the regression branch.
                logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
            train_cls(train_loader, model, criterion, optimizer, epoch, logger)
            if lr_scheduler:
                lr_scheduler.step()

            accuracy, auc = validate_cls(val_loader, model, val_drug_ids, run=run, epoch=epoch,
                                         val_score_dict=val_score_dict)
            if accuracy is not None and auc is not None:
                logger.info(f"Epoch {epoch} validation accuracy:{accuracy:4f}, AUC:{auc:4f}")
            else:
                logger.info(f"Epoch {epoch} validation Inf")
            if configs['save_checkpoints'] and auc is not None and auc > best_auc:
                best_auc = auc
                _save_checkpoint(model, configs, stamp, val_drug_ids)

        torch.cuda.empty_cache()
        return None

    # task == 'multiclass'
    best_avg_acc = 0.7
    criterion = nn.CrossEntropyLoss()
    all_val_res = []
    for epoch in trange(1, epochs + 1):
        if lr_scheduler:
            logger.info(f"learning rate: {lr_scheduler.get_last_lr()}")
        train_cls_multiclass(train_loader, model, criterion, optimizer, epoch, logger)
        if lr_scheduler:
            lr_scheduler.step()

        top1_acc, top3_acc, f1, roc_auc, val_res_perclass = validate_cls_multiclass(
            val_loader, model, run=run, epoch=epoch, val_score_dict=val_score_dict)
        all_val_res.append(val_res_perclass)
        logger.info(
            f"Epoch {epoch} validation top1_acc:{top1_acc:4f}, f1:{f1:4f}, AUC:{roc_auc:4f}")
        avg_acc = np.mean([top1_acc, top3_acc])
        if configs['save_checkpoints'] and avg_acc > best_avg_acc:
            best_avg_acc = avg_acc
            _save_checkpoint(model, configs, stamp, val_drug_ids)

    torch.cuda.empty_cache()
    return pd.concat(all_val_res)
range(predicts.shape[1])])) 477 | 478 | losses.update(loss.data.item(), input_.size(0)) 479 | optimizer.zero_grad() 480 | loss.backward() 481 | optimizer.step() 482 | 483 | batch_time.update(time.time() - end) 484 | end = time.time() 485 | 486 | logger.info(f'{epoch} \t' 487 | f'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 488 | f'loss {losses.val:.4f} ({losses.avg:.4f})\t' 489 | f'avg_accuracy {avg_accuracy.val:.4f} ({avg_accuracy.avg:.4f})\t' 490 | f'avg_auc {avg_auc.val:.4f} ({avg_auc.avg:.4f})\t' + lr_str) 491 | 492 | return avg_accuracy.avg 493 | 494 | 495 | def validate_cls(val_loader, model, val_drug_ids, run=None, epoch=None, val_score_dict=None): 496 | confs, targets = inference(val_loader, model) 497 | confs = torch.from_numpy(confs) 498 | confs = torch.sigmoid(confs).numpy() 499 | 500 | predicts = (confs > 0.5).astype(int) 501 | acc_avg, auc_avg = None, None 502 | if not np.isinf(confs).any() and not np.isnan(confs).any(): 503 | if 'drug_id' in val_score_dict: 504 | val_score_dict['drug_id'].extend(val_drug_ids) 505 | elif 'Gene' in val_score_dict: 506 | val_score_dict['Gene'].extend(val_drug_ids) 507 | else: 508 | raise Exception 509 | 510 | val_score_dict['run'].extend([run] * len(val_drug_ids)) 511 | val_score_dict['epoch'].extend([epoch] * len(val_drug_ids)) 512 | accuracy = [accuracy_score(targets[targets[:, i] != MISSING_NUM, i], predicts[targets[:, i] != MISSING_NUM, i]) 513 | for i in range(predicts.shape[1])] 514 | acc_avg = np.median(accuracy) 515 | 516 | auc = [roc_auc_score(targets[targets[:, i] != MISSING_NUM, i], confs[targets[:, i] != MISSING_NUM, i]) 517 | for i in range(confs.shape[1])] 518 | auc_avg = np.median(auc) 519 | 520 | val_score_dict['accuracy'].extend(accuracy) 521 | val_score_dict['auc'].extend(auc) 522 | 523 | return acc_avg, auc_avg 524 | 525 | 526 | def train_cls_multiclass(train_loader, model, criterion, optimizer, epoch, logger): 527 | batch_time = AverageMeter() 528 | losses = AverageMeter() 529 | avg_top1_acc = 
AverageMeter() 530 | avg_top3_acc = AverageMeter() 531 | avg_f1 = AverageMeter() 532 | avg_auc = AverageMeter() 533 | 534 | model.train() 535 | 536 | end = time.time() 537 | lr_str = '' 538 | 539 | for i, data in enumerate(train_loader): 540 | (input_, targets) = data 541 | output = model(input_.float().to(device)) 542 | 543 | loss = criterion(output, targets.flatten().long().to(device)) 544 | output = torch.softmax(output, dim=-1) 545 | targets = targets.cpu().numpy() 546 | 547 | confs = output.detach().cpu().numpy() 548 | predicts = np.argsort(-confs, axis=1) 549 | targets = targets.flatten() 550 | top1_acc = np.sum((predicts[:, 0] == targets)) / targets.shape[0] 551 | top3_acc = np.sum(np.any(predicts[:, :3] == np.expand_dims(targets, axis=1), axis=1)) / targets.shape[0] 552 | f1 = f1_score(targets, predicts[:, 0], average='macro') 553 | # roc_auc = roc_auc_score(targets, confs, multi_class='ovo') 554 | 555 | avg_top1_acc.update(top1_acc) 556 | avg_top3_acc.update(top3_acc) 557 | avg_f1.update(f1) 558 | # avg_auc.update(roc_auc) 559 | 560 | losses.update(loss.data.item(), input_.size(0)) 561 | optimizer.zero_grad() 562 | loss.backward() 563 | optimizer.step() 564 | 565 | batch_time.update(time.time() - end) 566 | end = time.time() 567 | 568 | logger.info(f'{epoch} \t' 569 | f'time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 570 | f'loss {losses.val:.4f} ({losses.avg:.4f})\t' 571 | f'avg_top1_acc {avg_top1_acc.val:.4f} ({avg_top1_acc.avg:.4f})\t' 572 | f'avg_f1 {avg_f1.val:.4f} ({avg_f1.avg:.4f})' + lr_str) 573 | 574 | return avg_top1_acc.avg 575 | 576 | 577 | def validate_cls_multiclass(val_loader, model, run=None, epoch=None, val_score_dict=None): 578 | confs, targets = inference(val_loader, model) 579 | predicts = np.argsort(-confs, axis=1) 580 | confs = softmax(confs, axis=-1) 581 | val_res_perclass = {} 582 | 583 | val_score_dict['run'].append(run) 584 | val_score_dict['epoch'].append(epoch) 585 | targets = targets.flatten() 586 | top1_acc = 
np.sum((predicts[:, 0] == targets)) / targets.shape[0] 587 | 588 | if np.unique(targets).size > 2: 589 | top3_acc = np.sum(np.any(predicts[:, :3] == np.expand_dims(targets, axis=1), axis=1)) / targets.shape[0] 590 | f1 = f1_score(targets, predicts[:, 0], average='macro') 591 | unique_train, unique_test = np.unique(predicts), np.unique(targets) 592 | if set(unique_train) == set(unique_test): 593 | roc_auc = roc_auc_score(targets, confs, multi_class='ovo') 594 | else: 595 | roc_auc = np.nan 596 | else: 597 | top3_acc = 1 598 | f1 = f1_score(targets, predicts[:, 0]) 599 | roc_auc = roc_auc_score(targets, confs[:, 1]) 600 | 601 | val_res_perclass['run'] = [run] * len(predicts) 602 | val_res_perclass['epoch'] = [epoch] * len(predicts) 603 | val_res_perclass['y_pred'] = predicts[:, 0] 604 | val_res_perclass['y_true'] = targets 605 | for i in range(confs.shape[1]): 606 | val_res_perclass[f"feature_{i}"] = confs[:, i] 607 | 608 | val_score_dict['top1_acc'].append(top1_acc) 609 | val_score_dict['top3_acc'].append(top3_acc) 610 | val_score_dict['f1'].append(f1) 611 | val_score_dict['roc_auc'].append(roc_auc) 612 | 613 | return top1_acc, top3_acc, f1, roc_auc, pd.DataFrame(val_res_perclass) 614 | -------------------------------------------------------------------------------- /scripts/transformer_explantion_cancer_type.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to run pathway-level model explanation for cancer type 3 | E.g. 
python scripts/transformer_explantion_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json 4 | """ 5 | import json 6 | import os 7 | import sys 8 | 9 | from sklearn.model_selection import KFold 10 | 11 | sys.path.append(os.getcwd() + '/..') 12 | from models import * 13 | from model_transformer_lrp import DeePathNet, LRP 14 | from torch.utils.data import DataLoader 15 | from tqdm import trange 16 | 17 | seed = 12345 18 | torch.manual_seed(seed) 19 | OUTPUT_NA_NUM = -100 20 | 21 | config_file = sys.argv[1] 22 | 23 | # load model configs 24 | configs = json.load(open(config_file, 'r')) 25 | data_file = configs['data_file'] 26 | data_type = configs['data_type'] 27 | 28 | BATCH_SIZE = configs['batch_size'] 29 | NUM_WORKERS = 0 30 | LOG_FREQ = configs['log_freq'] 31 | NUM_EPOCHS = configs['num_of_epochs'] 32 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 33 | device = 'cuda' 34 | 35 | data_target = pd.read_csv(configs['target_file'], low_memory=False, index_col=0) 36 | 37 | data_input = pd.read_csv(data_file, index_col=0) 38 | if data_type[0] != 'DR': 39 | data_input = data_input[ 40 | [x for x in data_input.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]] 41 | 42 | omics_types = [x for x in data_type if x != 'tissue'] 43 | 44 | genes = np.unique(([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != 'tissue'])) 45 | class_name_to_id = dict( 46 | zip(sorted(data_target.iloc[:, 0].unique()), list(range(data_target.iloc[:, 0].unique().size)))) 47 | id_to_class_name = dict( 48 | zip(list(range(data_target.iloc[:, 0].unique().size)), sorted(data_target.iloc[:, 0].unique()))) 49 | 50 | 51 | num_of_features = data_input.shape[1] 52 | 53 | pathway_dict = {} 54 | pathway_df = pd.read_csv(configs['pathway_file']) 55 | 56 | pathway_df['genes'] = pathway_df['genes'].map( 57 | lambda x: "|".join([gene for gene in x.split('|') if gene in genes])) 58 | if 'min_cancer_publication' in 
def run_lrp_cancer_type(merged_df_train):
    """Compute pathway-level LRP importance for every cancer type.

    The saved model named in ``configs`` is loaded, then for each class the
    ``transformer_attribution`` relevance of every training sample is summed
    per pathway.

    Args:
        merged_df_train: DataFrame whose first ``num_of_features`` columns are
            the input features and whose remaining columns are the targets.

    Returns:
        DataFrame with columns ``cancer_type``, ``pathway`` and ``importance``.
    """
    train_df = merged_df_train.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]

    train_dataset = MultiOmicMulticlassDataset(train_df, train_target, mode='train', omics_types=omics_types,
                                               class_name_to_id=class_name_to_id, logger=None)

    # Keep only pathway genes that are present in the input data.
    known_genes = set(genes)
    pathway_df = pd.read_csv(configs['pathway_file'])
    pathway_df['genes'] = pathway_df['genes'].map(
        lambda x: "|".join([gene for gene in x.split('|') if gene in known_genes]))
    if 'min_cancer_publication' in configs:
        pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']]
    if 'max_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']]
    if 'min_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']]

    pathway_dict = {name: gene_str.split('|')
                    for name, gene_str in zip(pathway_df['name'], pathway_df['genes'])}

    cancer_genes = {gene for gene_str in pathway_df['genes'].values for gene in gene_str.split("|")}
    non_cancer_genes = sorted(set(genes) - cancer_genes)

    model = DeePathNet(len(omics_types), len(class_name_to_id), train_dataset.genes_to_id,
                       train_dataset.id_to_genes,
                       pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'],
                       num_heads=configs['heads'],
                       mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'],
                       only_cancer_genes=configs['cancer_only'])
    # Load and place the model on the device chosen at the top of the script.
    model.load_state_dict(
        torch.load(f"{configs['work_dir']}/{configs['saved_model']}", map_location=device))
    model.to(device)
    model.eval()

    attribution_generator = LRP(model)

    # One importance value is reported per entry of this list.
    pathways = list(pathway_dict.keys())
    if not configs['cancer_only']:
        pathways += ['non_cancer']
    if 'tissue' in data_type:
        pathways += ['tissue']

    # The loader is identical for every class, so it is built once.
    # batch_size=1: relevance is generated one sample at a time.
    train_loader = DataLoader(train_dataset,
                              batch_size=1,
                              num_workers=NUM_WORKERS)

    res_df_all = []
    for idx in trange(len(class_name_to_id)):
        attributions = [
            attribution_generator.generate_LRP(data,
                                               method="transformer_attribution",
                                               index=idx).detach().cpu().numpy()[0, :]
            for data in train_loader
        ]
        res_df = pd.DataFrame({'pathway': pathways, 'importance': np.sum(attributions, axis=0)})
        res_df['cancer_type'] = id_to_class_name[idx]
        res_df_all.append(res_df)

    res_df_all = pd.concat(res_df_all)
    return res_df_all[['cancer_type', 'pathway', 'importance']]
def run_lrp(merged_df_train, drug_id=None):
    """Compute pathway-level LRP importance for every drug in the target columns.

    For each drug, only the samples whose target value for that drug is
    greater than zero are explained; their ``transformer_attribution``
    relevance is summed per pathway.

    Args:
        merged_df_train: DataFrame whose first ``num_of_features`` columns are
            the input features and whose remaining columns are drug targets.
        drug_id: when truthy, the per-drug checkpoint for this id is loaded
            instead of ``configs['saved_model']``.

    Returns:
        DataFrame with columns ``drug_id``, ``pathway`` and ``importance``,
        or ``None`` when ``drug_id`` is given but its checkpoint is missing.
    """
    train_df = merged_df_train.iloc[:, :num_of_features]
    train_target = merged_df_train.iloc[:, num_of_features:]

    # Built on all samples; only its gene <-> id mappings are used below.
    full_dataset = MultiOmicDataset(train_df, train_target, mode='train', omics_types=omics_types, logger=None,
                                    with_tissue=with_tissue)

    # Keep only pathway genes that are present in the input data.
    known_genes = set(genes)
    pathway_df = pd.read_csv(configs['pathway_file'])
    pathway_df['genes'] = pathway_df['genes'].map(
        lambda x: "|".join([gene for gene in x.split('|') if gene in known_genes]))
    if 'min_cancer_publication' in configs:
        pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']]
    if 'max_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']]
    if 'min_gene_num' in configs:
        pathway_df = pathway_df[pathway_df['GeneNumber'] > configs['min_gene_num']]

    pathway_dict = {name: gene_str.split('|')
                    for name, gene_str in zip(pathway_df['name'], pathway_df['genes'])}

    cancer_genes = {gene for gene_str in pathway_df['genes'].values for gene in gene_str.split("|")}
    non_cancer_genes = sorted(set(genes) - cancer_genes)

    model = DOIT_LRP(len(omics_types), train_target.shape[1], full_dataset.genes_to_id,
                     full_dataset.id_to_genes,
                     pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'],
                     num_heads=configs['heads'],
                     mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'],
                     only_cancer_genes=configs['cancer_only'], tissues=tissues)

    if drug_id:
        drug_file_name = get_model_filename(drug_id)
        checkpoint = f"{configs['work_dir']}_{configs['saved_model']}/{configs['saved_model']}_{drug_file_name}.pth"
        if not os.path.exists(checkpoint):
            return None
    else:
        checkpoint = f"{configs['work_dir']}/{configs['saved_model']}"

    # Use the script-level ``device`` (which falls back to CPU) instead of an
    # unconditional .cuda(), and remap the checkpoint tensors accordingly.
    # NOTE(review): LRP.generate_LRP is not visible here; confirm it does not
    # itself require CUDA.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)
    model.eval()

    attribution_generator = LRP(model)

    # One importance value is reported per entry of this list.
    pathways = list(pathway_dict.keys())
    if not configs['cancer_only']:
        pathways += ['non_cancer']
    if 'tissue' in data_type:
        pathways += ['tissue']

    res_df_all = []
    for drug_idx, drug in enumerate(train_target.columns):
        # Restrict to rows with a positive target value for this drug.
        target_drug = train_target[train_target[drug] > 0]
        X_drug = train_df[train_df.index.isin(target_drug.index)]
        drug_dataset = MultiOmicDataset(X_drug, target_drug, mode='train', omics_types=omics_types,
                                        logger=None, with_tissue=with_tissue)
        # batch_size=1: relevance is generated one sample at a time.
        drug_loader = DataLoader(drug_dataset,
                                 batch_size=1,
                                 num_workers=NUM_WORKERS)
        attributions = [
            attribution_generator.generate_LRP(data,
                                               method="transformer_attribution",
                                               index=drug_idx).detach().cpu().numpy()[0, :]
            for data in drug_loader
        ]
        res_df = pd.DataFrame({'pathway': pathways, 'importance': np.sum(attributions, axis=0)})
        res_df['drug_id'] = drug
        res_df_all.append(res_df)

    res_df_all = pd.concat(res_df_all)
    return res_df_all[['drug_id', 'pathway', 'importance']]
python scripts/transformer_shap_cancer_type.py configs/tcga_brca_subtypes/mutation_cnv_rna/deepathnet_allgenes_mutation_cnv_rna.json 4 | """ 5 | 6 | import json 7 | import os 8 | import sys 9 | 10 | from sklearn.model_selection import KFold 11 | 12 | sys.path.append(os.getcwd() + "/..") 13 | from models import * 14 | from model_transformer_lrp import DeePathNet 15 | from torch.utils.data import DataLoader 16 | 17 | import shap 18 | import time 19 | 20 | seed = 12345 21 | torch.manual_seed(seed) 22 | OUTPUT_NA_NUM = -100 23 | 24 | config_file = sys.argv[1] 25 | 26 | mode = "grad" if len(sys.argv) <= 2 else sys.argv[2] 27 | print(f"SHAP algo: {mode}") 28 | # load model configs 29 | configs = json.load(open(config_file, "r")) 30 | data_file = configs["data_file"] 31 | data_type = configs["data_type"] 32 | 33 | BATCH_SIZE = configs["batch_size"] 34 | NUM_WORKERS = 0 35 | LOG_FREQ = configs["log_freq"] 36 | NUM_EPOCHS = configs["num_of_epochs"] 37 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | device = "cuda" 39 | 40 | data_target = pd.read_csv(configs["target_file"], low_memory=False, index_col=0) 41 | 42 | data_input = pd.read_csv(data_file, index_col=0) 43 | if data_type[0] != "DR": 44 | data_input = data_input[ 45 | [ 46 | x 47 | for x in data_input.columns 48 | if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type) 49 | ] 50 | ] 51 | 52 | omics_types = [x for x in data_type if x != "tissue"] 53 | 54 | 55 | genes = np.unique( 56 | ([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != "tissue"]) 57 | ) 58 | class_name_to_id = dict( 59 | zip( 60 | sorted(data_target.iloc[:, 0].unique()), 61 | list(range(data_target.iloc[:, 0].unique().size)), 62 | ) 63 | ) 64 | id_to_class_name = dict( 65 | zip( 66 | list(range(data_target.iloc[:, 0].unique().size)), 67 | sorted(data_target.iloc[:, 0].unique()), 68 | ) 69 | ) 70 | 71 | num_of_features = data_input.shape[1] 72 | 73 | pathway_dict = {} 74 | pathway_df = 
def run_shap(merged_df_train, merged_df_test):
    """Compute gene-level SHAP importance for every cancer type.

    One shuffled batch of up to 600 training samples is the background for a
    ``shap.GradientExplainer``; up to ``NUM_EXPLAINED`` test samples are
    explained. For each class the mean absolute SHAP value is reported per
    gene and omics type, plus their sum.

    Args:
        merged_df_train: DataFrame whose first ``num_of_features`` columns are
            the input features and whose remaining columns are the targets.
        merged_df_test: test-split DataFrame with the same layout.

    Returns:
        DataFrame with columns ``cancer_type``, ``gene``, one column per
        omics type and ``sum``.
    """
    NUM_EXPLAINED = 400
    N_SAMPLES = 50

    X_train = merged_df_train.iloc[:, :num_of_features]
    y_train = merged_df_train.iloc[:, num_of_features:]
    X_test = merged_df_test.iloc[:, :num_of_features]
    y_test = merged_df_test.iloc[:, num_of_features:]

    dataset_kwargs = dict(omics_types=omics_types, class_name_to_id=class_name_to_id, logger=None)
    train_dataset = MultiOmicMulticlassDataset(X_train, y_train, mode="train", **dataset_kwargs)
    test_dataset = MultiOmicMulticlassDataset(X_test, y_test, mode="val", **dataset_kwargs)

    model = DeePathNet(
        len(omics_types),
        len(class_name_to_id),
        train_dataset.genes_to_id,
        train_dataset.id_to_genes,
        pathway_dict,
        non_cancer_genes,
        embed_dim=configs["dim"],
        depth=configs["depth"],
        num_heads=configs["heads"],
        mlp_ratio=configs["mlp_ratio"],
        out_mlp_ratio=configs["out_mlp_ratio"],
        only_cancer_genes=configs["cancer_only"],
    )
    weights_path = f"{configs['work_dir']}/{configs['saved_model']}"
    model.load_state_dict(torch.load(weights_path))
    model.to(device)
    model.eval()

    train_loader = DataLoader(train_dataset, batch_size=600, shuffle=True, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True, num_workers=NUM_WORKERS)

    # Draw the training batch first, then the test batch; both loaders
    # shuffle, so the draw order is kept as is.
    train_x, train_y = next(iter(train_loader))
    test_x, _test_y = next(iter(test_loader))

    start = time.time()
    background = train_x.float().to(device)
    explainer = shap.GradientExplainer(model, background)
    explained = test_x[:NUM_EXPLAINED, :, :].float().to(device)
    shap_values = explainer.shap_values(explained, nsamples=N_SAMPLES)
    end = time.time()
    print(end - start)

    summary = {"cancer_type": [], "gene": []}
    summary.update({omics: [] for omics in omics_types})

    for class_idx in range(len(class_name_to_id)):
        # Mean |SHAP| over the explained samples for this class.
        mean_abs_shap = np.mean(np.abs(shap_values[class_idx]), axis=0)

        summary["cancer_type"].extend([id_to_class_name[class_idx]] * len(genes))
        summary["gene"].extend(genes)
        for omics_idx, omics in enumerate(omics_types):
            summary[omics].extend(mean_abs_shap[:, omics_idx])

    all_drug_gradients_summary = pd.DataFrame(summary)
    all_drug_gradients_summary["sum"] = all_drug_gradients_summary.iloc[:, 2:].sum(axis=1)

    # Drop local references before asking torch to empty the CUDA cache.
    del model, explainer, train_x, train_y
    torch.cuda.empty_cache()
    return all_drug_gradients_summary
| NUM_WORKERS = 0 35 | LOG_FREQ = configs['log_freq'] 36 | NUM_EPOCHS = configs['num_of_epochs'] 37 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | device = 'cuda' 39 | 40 | data_target = pd.read_csv(configs['target_file'], low_memory=False, index_col=0) 41 | drug_ids = data_target.columns 42 | if 'drug_id' in configs and configs['drug_id'] != "": 43 | data_target = data_target[configs['drug_id']] 44 | drug_ids = [configs['drug_id']] 45 | 46 | data_input = pd.read_csv(data_file, index_col=0) 47 | if data_type[0] != 'DR': 48 | data_input = data_input[ 49 | [x for x in data_input.columns if (x.split("_")[1] in data_type) or (x.split("_")[0] in data_type)]] 50 | 51 | tissues = [x for x in data_input.columns if 'tissue' in x] if 'tissue' in data_type else None 52 | omics_types = [x for x in data_type if x != 'tissue'] 53 | with_tissue = 'tissue' in data_type 54 | 55 | with open(configs['train_cells']) as f: 56 | cell_lines_train = [line.rstrip() for line in f] 57 | with open(configs['test_cells']) as f: 58 | cell_lines_test = [line.rstrip() for line in f] 59 | 60 | genes = np.unique(([x.split("_")[0] for x in data_input.columns if x.split("_")[0] != 'tissue'])) 61 | 62 | data_target_train = data_target[ 63 | data_target.index.isin(cell_lines_train)] 64 | data_target_test = data_target[ 65 | data_target.index.isin(cell_lines_test)] 66 | 67 | num_of_features = data_input.shape[1] 68 | 69 | pathway_dict = {} 70 | pathway_df = pd.read_csv(configs['pathway_file']) 71 | 72 | pathway_df['genes'] = pathway_df['genes'].map( 73 | lambda x: "|".join([gene for gene in x.split('|') if gene in genes])) 74 | if 'min_cancer_publication' in configs: 75 | pathway_df = pathway_df[pathway_df['Cancer_Publications'] > configs['min_cancer_publication']] 76 | if 'max_gene_num' in configs: 77 | pathway_df = pathway_df[pathway_df['GeneNumber'] < configs['max_gene_num']] 78 | if 'min_gene_num' in configs: 79 | pathway_df = pathway_df[pathway_df['GeneNumber'] > 
configs['min_gene_num']] 80 | 81 | for index, row in pathway_df.iterrows(): 82 | pathway_dict[row['name']] = row['genes'].split('|') 83 | 84 | cancer_genes = set([y for x in pathway_df['genes'].values for y in x.split("|")]) 85 | non_cancer_genes = sorted(set(genes) - set(cancer_genes)) 86 | 87 | data_input_train = data_input[data_input.index.isin(cell_lines_train)] 88 | data_input_test = data_input[data_input.index.isin(cell_lines_test)] 89 | 90 | 91 | def run_shap(merged_df_train, merged_df_test, drug_ids=None): 92 | train_df = merged_df_train.iloc[:, :num_of_features] 93 | test_df = merged_df_test.iloc[:, :num_of_features] 94 | train_target = merged_df_train.iloc[:, num_of_features:] 95 | test_target = merged_df_test.iloc[:, num_of_features:] 96 | 97 | X_train = train_df 98 | X_test = test_df 99 | 100 | train_dataset = MultiOmicDataset(X_train, train_target, mode='train', omics_types=omics_types, logger=None, 101 | with_tissue=with_tissue) 102 | test_dataset = MultiOmicDataset(X_test, test_target, mode='val', omics_types=omics_types, logger=None, 103 | with_tissue=with_tissue) 104 | 105 | model = DOIT_LRP(len(omics_types), train_target.shape[1], train_dataset.genes_to_id, 106 | train_dataset.id_to_genes, 107 | pathway_dict, non_cancer_genes, embed_dim=configs['dim'], depth=configs['depth'], 108 | num_heads=configs['heads'], 109 | mlp_ratio=configs['mlp_ratio'], out_mlp_ratio=configs['out_mlp_ratio'], 110 | only_cancer_genes=configs['cancer_only'], tissues=tissues) 111 | if len(drug_ids) == 1: 112 | drug_file_name = get_model_filename(drug_ids[0]) 113 | drug_path = f"{configs['work_dir']}_{configs['saved_model']}/{configs['saved_model']}_{drug_file_name}.pth" 114 | if not os.path.exists(drug_path): 115 | return None, None 116 | model.load_state_dict(torch.load(drug_path)) 117 | else: 118 | model.load_state_dict(torch.load(f"{configs['work_dir']}/{configs['saved_model']}")) 119 | model.to(device) 120 | model.eval() 121 | 122 | train_loader = 
DataLoader(train_dataset, 123 | batch_size=len(train_dataset), shuffle=True, 124 | num_workers=NUM_WORKERS) 125 | test_loader = DataLoader(test_dataset, 126 | batch_size=len(test_dataset), shuffle=True, 127 | num_workers=NUM_WORKERS) 128 | 129 | data = next(iter(train_loader)) 130 | test_data = next(iter(test_loader)) 131 | 132 | tissue_x = None 133 | test_tissue_x = None 134 | if len(data) == 2: 135 | (input, targets) = data 136 | (test_input, test_targets) = test_data 137 | elif len(data) == 3: 138 | (input, tissue_x, targets) = data 139 | (test_input, test_tissue_x, test_targets) = test_data 140 | else: 141 | raise Exception 142 | 143 | NUM_EXPLAINED = 400 144 | N_SAMPLES = 50 145 | 146 | start = time.time() 147 | if mode == 'grad': 148 | if tissue_x is not None: 149 | background = [input.float().to(device), tissue_x.float().to(device)] 150 | explainer = shap.GradientExplainer(model, background) 151 | shap_values = explainer.shap_values( 152 | [test_input[:NUM_EXPLAINED, :, :].float().to(device), 153 | test_tissue_x[:NUM_EXPLAINED, :].float().to(device)], 154 | nsamples=N_SAMPLES) 155 | else: 156 | background = input.float().to(device) 157 | explainer = shap.GradientExplainer(model, background) 158 | shap_values = explainer.shap_values(test_input[:NUM_EXPLAINED, :, :].float().to(device), nsamples=N_SAMPLES) 159 | else: 160 | raise Exception 161 | end = time.time() 162 | print(end - start) 163 | 164 | all_drug_gradients_summary = {'drug_id': [], 'gene': []} 165 | all_drug_gradients_tissue_summary = {'drug_id': [], 'tissue': [], 'importance': []} 166 | for target in omics_types: 167 | all_drug_gradients_summary[target] = [] 168 | 169 | for drug_idx in range(len(drug_ids)): 170 | drug_id = drug_ids[drug_idx] 171 | if len(drug_ids) > 1: 172 | omics_shap = shap_values[drug_idx][0] # N x genes x num_omics 173 | tissue_shap = shap_values[drug_idx][1] # N x tissue 174 | else: 175 | omics_shap = shap_values[0] # N x genes x num_omics 176 | tissue_shap = shap_values[1] # 
# Script entry: compute SHAP importance tables and write them out as gzipped CSV files.
if configs.get('all_single_mode'):
    # Single-drug mode: run the explainer once per drug listed in configs['drug_list']
    # and stack the per-drug tables.
    gene_tables = []
    tissue_tables = []
    drug_ids = pd.read_csv(configs['drug_list'], index_col=0).index.values
    for drug_id in tqdm(drug_ids):
        merged_df_train = pd.merge(data_input_train, data_target_train[drug_id], on=['Cell_line'])
        merged_df_test = pd.merge(data_input_test, data_target_test[drug_id], on=['Cell_line'])
        omics, tissue = run_shap(merged_df_train, merged_df_test, drug_ids=[drug_id])
        if omics is None:
            # Nothing was produced for this drug; leave it out of both tables.
            continue
        gene_tables.append(omics)
        tissue_tables.append(tissue)

    all_drug_gradients_summary = pd.concat(gene_tables)
    all_drug_gradients_tissue_summary = pd.concat(tissue_tables)
    genes_csv = f"{configs['work_dir']}_{configs['saved_model']}/shap{mode}_genes_{configs['saved_model']}.csv.gz"
    tissue_csv = f"{configs['work_dir']}_{configs['saved_model']}/shap{mode}_tissue_{configs['saved_model']}.csv.gz"
else:
    # Multi-drug mode: one explainer run covers every drug in `drug_ids`.
    merged_df_train = pd.merge(data_input_train, data_target_train, on=['Cell_line'])
    merged_df_test = pd.merge(data_input_test, data_target_test, on=['Cell_line'])
    all_drug_gradients_summary, all_drug_gradients_tissue_summary = run_shap(
        merged_df_train, merged_df_test, drug_ids=drug_ids)
    genes_csv = f"{configs['work_dir']}/shap{mode}_genes_{configs['saved_model'].replace('pth', 'csv.gz')}"
    tissue_csv = f"{configs['work_dir']}/shap{mode}_tissue_{configs['saved_model'].replace('pth', 'csv.gz')}"

all_drug_gradients_summary.to_csv(genes_csv, index=False)
all_drug_gradients_tissue_summary.to_csv(tissue_csv, index=False)
def safe_divide(a, b):
    """Divide ``a`` by ``b`` element-wise, yielding 0 wherever ``b`` is exactly 0.

    Args:
        a: Numerator tensor.
        b: Denominator tensor.

    Returns:
        ``a`` divided by a slightly shifted ``b``, masked to 0 where ``b == 0``.
    """
    # max(b, eps) + min(b, eps) == b + eps, so every denominator is shifted by eps.
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    # If the shift itself landed exactly on 0, nudge once more to keep the division finite.
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    """Forward hook that caches a module's input and output for relevance propagation.

    The first positional input is stored on ``self.X`` as detached tensor(s) with
    ``requires_grad`` enabled (a list when the input is a list/tuple of tensors);
    the module output is stored on ``self.Y``.
    """
    # isinstance (rather than an exact type match) also accepts list/tuple subclasses.
    if isinstance(input[0], (list, tuple)):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


def backward_hook(self, grad_input, grad_output):
    """Backward hook that stores the incoming and outgoing gradients on the module."""
    self.grad_input = grad_input
    self.grad_output = grad_output


class RelProp(nn.Module):
    """Base module for relevance propagation.

    Registers ``forward_hook`` on construction so that every forward call records
    its input on ``self.X``. The default ``relprop`` passes relevance through unchanged.
    """

    def __init__(self):
        super(RelProp, self).__init__()
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Return the gradients of ``Z`` w.r.t. ``X`` weighted by ``S`` (graph is retained)."""
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        """Return the relevance ``R`` unchanged; ``alpha`` is ignored here."""
        return R


class RelPropSimple(RelProp):
    """Relevance propagation that redistributes ``R`` in proportion to input * gradient."""

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` from the output back to the cached input(s).

        Args:
            R: Relevance at the module output.
            alpha: Unused by this rule; accepted for a uniform ``relprop`` signature.

        Returns:
            A tensor when the cached input is a tensor, otherwise a list holding one
            relevance tensor per cached input (any number of inputs is supported).
        """
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X):
            return self.X * C[0]
        # One relevance map per input; previously only the first two inputs were returned.
        return [x * c for x, c in zip(self.X, C)]
class Add(RelPropSimple):
    """Element-wise addition of two tensors with a sum-preserving relevance rule."""

    def forward(self, inputs):
        """Add the tensors in ``inputs`` (unpacked into ``torch.add``)."""
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        """Split relevance ``R`` between the two cached inputs.

        Each input first receives input * gradient relevance; the two maps are then
        rescaled so their totals share ``R.sum()`` in proportion to the absolute
        value of each map's own total. ``alpha`` is unused.

        Returns:
            A two-element list with the relevance for each input.
        """
        total = self.forward(self.X)
        grads = self.gradprop(total, self.X, safe_divide(R, total))

        raw = [self.X[0] * grads[0], self.X[1] * grads[1]]
        sums = [r.sum() for r in raw]

        magnitude = sums[0].abs() + sums[1].abs()
        budget = R.sum()

        # Rescale each map so that its total equals its share of the relevance budget.
        return [
            r * safe_divide(safe_divide(s.abs(), magnitude) * budget, s)
            for r, s in zip(raw, sums)
        ]
class BatchNorm2d(nn.BatchNorm2d, RelProp):
    """2-D batch normalisation with a relevance-propagation rule."""

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` through the per-channel scaling of the layer.

        Args:
            R: Relevance at the layer output.
            alpha: Unused by this rule; accepted for a uniform ``relprop`` signature.

        Returns:
            Relevance with respect to the cached input ``self.X``.
        """
        X = self.X
        # Effective per-channel scale gamma / sqrt(running_var + eps), broadcast over
        # (N, C, H, W). running_var already holds a variance, so it must not be
        # squared again before the square root.
        scale = self.weight.view(1, -1, 1, 1) / (
            (self.running_var.view(1, -1, 1, 1) + self.eps).pow(0.5))
        Z = X * scale + 1e-9  # small offset keeps the division below finite
        S = R / Z
        Ca = S * scale
        return X * Ca
max=0) 231 | 232 | def f(w1, w2, x1, x2): 233 | Z1 = F.linear(x1, w1) 234 | Z2 = F.linear(x2, w2) 235 | S1 = safe_divide(R, Z1 + Z2) 236 | S2 = safe_divide(R, Z1 + Z2) 237 | C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0] 238 | C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0] 239 | 240 | return C1 + C2 241 | 242 | activator_relevances = f(pw, nw, px, nx) 243 | inhibitor_relevances = f(nw, pw, px, nx) 244 | 245 | R = alpha * activator_relevances - beta * inhibitor_relevances 246 | 247 | return R 248 | 249 | 250 | class Conv2d(nn.Conv2d, RelProp): 251 | def gradprop2(self, DY, weight): 252 | Z = self.forward(self.X) 253 | 254 | output_padding = self.X.size()[2] - ( 255 | (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0]) 256 | 257 | return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding) 258 | 259 | def relprop(self, R, alpha): 260 | if self.X.shape[1] == 3: 261 | pw = torch.clamp(self.weight, min=0) 262 | nw = torch.clamp(self.weight, max=0) 263 | X = self.X 264 | L = self.X * 0 + \ 265 | torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 266 | keepdim=True)[0] 267 | H = self.X * 0 + \ 268 | torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3, 269 | keepdim=True)[0] 270 | Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \ 271 | torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \ 272 | torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9 273 | 274 | S = R / Za 275 | C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw) 276 | R = C 277 | else: 278 | beta = alpha - 1 279 | pw = torch.clamp(self.weight, min=0) 280 | nw = torch.clamp(self.weight, max=0) 281 | px = torch.clamp(self.X, min=0) 282 | nx = torch.clamp(self.X, max=0) 283 | 284 | def f(w1, w2, x1, x2): 285 | Z1 = 
class FindLR(_LRScheduler):
    """
    Learning-rate range test: grow the rate exponentially from each group's base
    learning rate towards ``max_lr`` over ``max_steps`` scheduler steps.

    inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html

    Parameters
    ----------
    optimizer: wrapped optimizer.
    max_steps: number of steps in the sweep; ``max_lr`` is reached at step ``max_steps - 1``.
    max_lr: learning rate at the end of the sweep.
    """
    def __init__(self, optimizer, max_steps, max_lr=10):
        self.max_steps = max_steps
        self.max_lr = max_lr
        super().__init__(optimizer)

    def get_lr(self):
        span = self.max_steps - 1
        # A one-step sweep has no interval to interpolate over; stay at the base rate
        # instead of dividing by zero (this is hit while the scheduler is constructed).
        progress = self.last_epoch / span if span else 0.0
        return [base_lr * ((self.max_lr / base_lr) ** progress)
                for base_lr in self.base_lrs]
def get_logger(config_file, STAMP):
    """Configure and return the shared ``"multi-drug"`` logger.

    The log file is ``<work_dir>/<STAMP><suffix>.log``, where ``work_dir`` and the
    optional ``suffix`` come from the JSON config. The raw config text is written
    to the log and printed to stdout.

    Args:
        config_file: Path to the JSON configuration file.
        STAMP: Prefix of the log file name.

    Returns:
        The ``logging.Logger`` named ``"multi-drug"``.
    """
    # Read the config once and close the file; the same text is parsed, logged and printed.
    with open(config_file, "r") as config_handle:
        config_text = config_handle.read()
    configs = json.loads(config_text)

    log_suffix = configs.get("suffix", "")
    log_path = os.path.join(configs["work_dir"], f"{STAMP}{log_suffix}.log")

    logger = logging.getLogger("multi-drug")
    logger.setLevel(logging.DEBUG)

    # The logger is process-global: attaching a second handler for the same file on a
    # repeated call would write every record twice.
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == os.path.abspath(log_path)
        for handler in logger.handlers
    )
    if not already_attached:
        fh = logging.FileHandler(log_path)
        fh.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(fh)

    logger.info(config_text)
    print(config_text)
    return logger


def prepare_data_cv(config_file, STAMP):
    """Load inputs and targets for cross-validation as described by a JSON config.

    Args:
        config_file: Path to the JSON configuration. Keys read here: ``work_dir``,
            ``data_file``, ``data_type``, ``target_file``, ``task`` and the optional
            ``drug_id`` and ``downsample``.
        STAMP: Not used by this function; kept so existing callers keep working.

    Returns:
        dict with keys ``data_input_all``, ``data_target_all``, ``val_score_dict``,
        ``num_of_features``, ``genes``, ``omics_types``, ``with_tissue``, ``tissues``
        and ``logger``.
    """
    with open(config_file, "r") as config_handle:
        configs = json.load(config_handle)

    # Portable replacement for shelling out to `mkdir -p`.
    os.makedirs(configs["work_dir"], exist_ok=True)

    data_file = configs["data_file"]
    data_type = configs["data_type"]

    data_target = pd.read_csv(configs["target_file"], low_memory=False, index_col=0)
    if "drug_id" in configs and configs["drug_id"] != "":
        data_target = data_target[configs["drug_id"]]

    data_input = pd.read_csv(data_file, index_col=0)
    if data_type[0] != "DR":
        # Keep a column when its first or second "_"-separated token names a selected
        # data type; slicing tolerates column names that contain no underscore.
        data_input = data_input[
            [
                x
                for x in data_input.columns
                if any(token in data_type for token in x.split("_")[:2])
            ]
        ]

    if "downsample" in configs:
        # The held-out side of the stratified split is the down-sampled subset.
        _, test_idx = train_test_split(
            data_input.index,
            test_size=configs["downsample"],
            random_state=42,
            stratify=data_target,
        )
        data_input = data_input.loc[test_idx]
        data_target = data_target.loc[test_idx]

    tissues = (
        [x for x in data_input.columns if "tissue" in x]
        if "tissue" in data_type
        else None
    )
    omics_types = [x for x in data_type if x != "tissue"]
    with_tissue = "tissue" in data_type

    genes = np.unique(
        [x.split("_")[0] for x in data_input.columns if x.split("_")[0] != "tissue"]
    )

    num_of_features = data_input.shape[1]

    if configs["task"] == "regression":
        val_score_dict = {
            "drug_id": [],
            "run": [],
            "epoch": [],
            "mae": [],
            "rmse": [],
            "corr": [],
            "r2": [],
        }
    elif configs["task"] == "multiclass":
        val_score_dict = {
            "run": [],
            "epoch": [],
            "top1_acc": [],
            "top3_acc": [],
            "f1": [],
            "roc_auc": [],
        }
    else:
        val_score_dict = {
            "drug_id": [],
            "run": [],
            "epoch": [],
            "accuracy": [],
            "auc": [],
        }

    # `logger` was previously an undefined name here (NameError on every call).
    # logging.getLogger returns the same "multi-drug" instance that get_logger
    # configures, without adding handlers or re-logging the config.
    logger = logging.getLogger("multi-drug")

    ret_dict = {
        "data_input_all": data_input,
        "data_target_all": data_target,
        "val_score_dict": val_score_dict,
        "num_of_features": num_of_features,
        "genes": genes,
        "omics_types": omics_types,
        "with_tissue": with_tissue,
        "tissues": tissues,
        "logger": logger,
    }
    return ret_dict
def _ordered_intersection(left, right):
    """Return the unique labels of ``left`` that also occur in ``right``, in ``left`` order."""
    right_labels = set(right)
    return [label for label in dict.fromkeys(left) if label in right_labels]


def prepare_data_independent_test(config_file, STAMP, seed=1):
    """Load separate training and test sets as described by a JSON config.

    Inputs are restricted to the features present in both input files, and each
    split is restricted to the cell lines that have both inputs and targets.

    Args:
        config_file: Path to the JSON configuration. Keys read here: ``work_dir``,
            ``data_file_train``, ``data_file_test``, ``target_file_train``,
            ``target_file_test``, ``data_type`` and the optional ``drug_id`` and
            ``downsample``.
        STAMP: Not used by this function; kept so existing callers keep working.
        seed: Random state of the optional down-sampling split.

    Returns:
        dict with keys ``data_input_train``, ``data_target_train``,
        ``data_input_test``, ``data_target_test``, ``num_of_features``, ``genes``,
        ``omics_types``, ``with_tissue`` and ``tissues``.
    """
    with open(config_file, "r") as config_handle:
        configs = json.load(config_handle)

    # Portable replacement for shelling out to `mkdir -p`.
    os.makedirs(configs["work_dir"], exist_ok=True)

    data_file_train = configs["data_file_train"]
    data_file_test = configs["data_file_test"]
    data_type = configs["data_type"]

    data_target_train = pd.read_csv(
        configs["target_file_train"], low_memory=False, index_col=0
    )
    data_target_test = pd.read_csv(
        configs["target_file_test"], low_memory=False, index_col=0
    )

    if "drug_id" in configs and configs["drug_id"] != "":
        # NOTE(review): only the training targets are narrowed to `drug_id`; the test
        # targets keep every column. Confirm this asymmetry is intended.
        data_target_train = data_target_train[configs["drug_id"]]

    data_input_train = pd.read_csv(data_file_train, index_col=0).fillna(0)
    data_input_test = pd.read_csv(data_file_test, index_col=0).fillna(0)

    # Intersections follow the order of the input files. The previous list(set(...))
    # form gave an arbitrary order, which for string labels varies between
    # interpreter runs and made the seeded down-sampling below irreproducible.
    common_features = _ordered_intersection(
        data_input_train.columns, data_input_test.columns
    )
    common_cells_train = _ordered_intersection(
        data_input_train.index, data_target_train.index
    )
    common_cells_test = _ordered_intersection(
        data_input_test.index, data_target_test.index
    )
    data_input_train = data_input_train.loc[common_cells_train, common_features]
    data_target_train = data_target_train.loc[common_cells_train]
    data_input_test = data_input_test.loc[common_cells_test, common_features]
    data_target_test = data_target_test.loc[common_cells_test]

    if data_type[0] != "DR":
        # Keep a column when its first or second "_"-separated token names a selected
        # data type; slicing tolerates column names that contain no underscore.
        selected = [
            x
            for x in data_input_train.columns
            if any(token in data_type for token in x.split("_")[:2])
        ]
        # Both frames share `common_features`, so one selection serves both.
        data_input_train = data_input_train[selected]
        data_input_test = data_input_test[selected]

    if "downsample" in configs:
        # The held-out side of the split is the down-sampled training subset.
        _, test_idx = train_test_split(
            data_input_train.index,
            test_size=configs["downsample"],
            random_state=seed,
        )
        data_input_train = data_input_train.loc[test_idx]
        data_target_train = data_target_train.loc[test_idx]

    tissues = (
        [x for x in data_input_train.columns if "tissue" in x]
        if "tissue" in data_type
        else None
    )
    omics_types = [x for x in data_type if x != "tissue"]
    with_tissue = "tissue" in data_type

    genes = np.unique(
        [
            x.split("_")[0]
            for x in data_input_train.columns
            if x.split("_")[0] != "tissue"
        ]
    )

    num_of_features = data_input_train.shape[1]

    ret_dict = {
        "data_input_train": data_input_train,
        "data_target_train": data_target_train,
        "data_input_test": data_input_test,
        "data_target_test": data_target_test,
        "num_of_features": num_of_features,
        "genes": genes,
        "omics_types": omics_types,
        "with_tissue": with_tissue,
        "tissues": tissues,
    }
    return ret_dict


def get_score_dict(config_file):
    """Return an empty score table (column name -> empty list) for the configured task.

    Args:
        config_file: Path to the JSON configuration; only ``task`` is read.

    Returns:
        dict of empty lists keyed by metric column. ``"regression"`` and
        ``"multiclass"`` have their own column sets; any other task gets the
        accuracy/auc columns.
    """
    with open(config_file, "r") as config_handle:
        configs = json.load(config_handle)

    task = configs["task"]
    if task == "regression":
        columns = ["drug_id", "run", "epoch", "mae", "rmse", "corr", "r2"]
    elif task == "multiclass":
        columns = ["run", "epoch", "top1_acc", "top3_acc", "f1", "roc_auc"]
    else:
        columns = ["drug_id", "run", "epoch", "accuracy", "auc"]
    # A fresh list per column so callers can append to each independently.
    return {column: [] for column in columns}