├── data └── example ├── records └── log.txt ├── results ├── table.csv └── png │ ├── 6_PDR_comp.png │ ├── 5_PDR_result.png │ ├── 1_AI_DRP_model.png │ ├── 2_P-MDL_dataset.png │ ├── 3_P-DML_model_zoo.png │ └── 4_P-MDL_performance.png ├── png ├── .DS_Store ├── 6_PDR_comp.png ├── 5_PDR_result.png ├── 1_AI_DRP_model.png ├── 2_P-MDL_dataset.png ├── 3_P-DML_model_zoo.png └── 4_P-MDL_performance.png ├── code ├── __pycache__ │ ├── data.cpython-38.pyc │ ├── fine_tuning.cpython-38.pyc │ ├── encoder_decoder.cpython-38.pyc │ └── evaluation_utils.cpython-38.pyc ├── model_train │ ├── __pycache__ │ │ ├── ae.cpython-38.pyc │ │ ├── mlp.cpython-38.pyc │ │ ├── base_ae.cpython-38.pyc │ │ ├── dsn_ae.cpython-38.pyc │ │ ├── train_ae.cpython-38.pyc │ │ ├── train_dsn.cpython-38.pyc │ │ ├── train_ae_adv.cpython-38.pyc │ │ ├── train_ae_mmd.cpython-38.pyc │ │ ├── train_dsn_adv.cpython-38.pyc │ │ ├── train_dsn_mmd.cpython-38.pyc │ │ ├── train_dsrn_adv.cpython-38.pyc │ │ ├── train_dsrn_mmd.cpython-38.pyc │ │ └── loss_and_metrics.cpython-38.pyc │ ├── base_ae.py │ ├── mlp.py │ ├── loss_and_metrics.py │ ├── ae.py │ ├── train_ae_mmd.py │ ├── train_ae.py │ ├── dsn_ae.py │ ├── train_dsn.py │ ├── train_dsn_mmd.py │ ├── train_dsrn_mmd.py │ ├── train_dsn_adv.py │ ├── train_dsrn_adv.py │ └── train_ae_adv.py ├── benchmark.txt ├── params │ └── train_params.json ├── pdr_task.txt ├── encoder_decoder.py ├── run_pretrain_for_pdr.sh ├── run_pdr_task.sh ├── run_pretrain.sh ├── evaluation_utils.py ├── P_MDL.py ├── fine_tuning.py ├── PDR_task.py └── data.py └── README.md /data/example: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /records/log.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /results/table.csv: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /png/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/.DS_Store -------------------------------------------------------------------------------- /png/6_PDR_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/6_PDR_comp.png -------------------------------------------------------------------------------- /png/5_PDR_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/5_PDR_result.png -------------------------------------------------------------------------------- /png/1_AI_DRP_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/1_AI_DRP_model.png -------------------------------------------------------------------------------- /png/2_P-MDL_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/2_P-MDL_dataset.png -------------------------------------------------------------------------------- /png/3_P-DML_model_zoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/3_P-DML_model_zoo.png -------------------------------------------------------------------------------- /results/png/6_PDR_comp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/6_PDR_comp.png -------------------------------------------------------------------------------- /png/4_P-MDL_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/4_P-MDL_performance.png -------------------------------------------------------------------------------- /results/png/5_PDR_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/5_PDR_result.png -------------------------------------------------------------------------------- /results/png/1_AI_DRP_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/1_AI_DRP_model.png -------------------------------------------------------------------------------- /results/png/2_P-MDL_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/2_P-MDL_dataset.png -------------------------------------------------------------------------------- /results/png/3_P-DML_model_zoo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/3_P-DML_model_zoo.png -------------------------------------------------------------------------------- /code/__pycache__/data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/data.cpython-38.pyc -------------------------------------------------------------------------------- 
/results/png/4_P-MDL_performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/4_P-MDL_performance.png -------------------------------------------------------------------------------- /code/__pycache__/fine_tuning.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/fine_tuning.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/ae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/ae.cpython-38.pyc -------------------------------------------------------------------------------- /code/__pycache__/encoder_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/encoder_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /code/__pycache__/evaluation_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/evaluation_utils.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- 
/code/model_train/__pycache__/base_ae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/base_ae.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/dsn_ae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/dsn_ae.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_ae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_dsn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_ae_adv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae_adv.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_ae_mmd.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae_mmd.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_dsn_adv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn_adv.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_dsn_mmd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn_mmd.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_dsrn_adv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsrn_adv.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/train_dsrn_mmd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsrn_mmd.cpython-38.pyc -------------------------------------------------------------------------------- /code/model_train/__pycache__/loss_and_metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/loss_and_metrics.cpython-38.pyc -------------------------------------------------------------------------------- 
/code/benchmark.txt: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | Process < 25, finished last task! 3 | All-data pre-training (ADP) model runing! 4 | Start to run Test-pairwise pre-training (TPP) model! 5 | Processing next task: 8 20230718-053353 6 | Running task num: 26 20230718-054854 7 | -------------------------------------------------------------------------------- /code/params/train_params.json: -------------------------------------------------------------------------------- 1 | {"unlabeled": {"batch_size": 64, "lr": 0.0001, "pretrain_num_epochs": 500, 2 | "train_num_epochs": 1000, "alpha": 1.0, "classifier_hidden_dims": [64, 32]}, 3 | "labeled": {"classifier_hidden_dims": [64, 32], "batch_size": 256, "lr": 0.0001, 4 | "train_num_epochs":1500, "decay_coefficient": 0.1}, "encoder_hidden_dims": [512, 256], 5 | "latent_dim": 128, "dop": 0.1} 6 | -------------------------------------------------------------------------------- /code/pdr_task.txt: -------------------------------------------------------------------------------- 1 | nohup: ignoring input 2 | Processed over last task 3 | mkdir: cannot create directory ‘../record/pdr_task’: No such file or directory 4 | run_pdr_task.sh: line 24: ../record/pdr_task/7_2.txt: No such file or directory 5 | run_pdr_task.sh: line 24: ../record/pdr_task/7_3.txt: No such file or directory 6 | run_pdr_task.sh: line 24: ../record/pdr_task/7_4.txt: No such file or directory 7 | Can choose to wait for 5 min ~~~ 8 | run_pdr_task.sh: line 24: ../record/pdr_task/7_5.txt: No such file or directory 9 | -------------------------------------------------------------------------------- /code/model_train/base_ae.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from abc import abstractmethod 3 | from typing import List, Any 4 | 5 | from typing import TypeVar 6 | Tensor = TypeVar('torch.tensor') 7 | 8 | 9 | class 
BaseAE(nn.Module): 10 | def __init__(self) -> None: 11 | super(BaseAE, self).__init__() 12 | 13 | def encode(self, input: Tensor) -> List[Tensor]: 14 | raise NotImplementedError 15 | 16 | def decode(self, input: Tensor) -> List[Tensor]: 17 | raise NotImplementedError 18 | 19 | def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor: 20 | raise RuntimeWarning() 21 | 22 | def generate(self, x: Tensor, **kwargs) -> Tensor: 23 | raise NotImplementedError 24 | 25 | @abstractmethod 26 | def forward(self, *inputs: Tensor) -> Tensor: 27 | pass 28 | 29 | @abstractmethod 30 | def loss_function(self, *inputs: Any, **kwargs) -> Tensor: 31 | pass 32 | -------------------------------------------------------------------------------- /code/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from molecules import Molecules 5 | # from neural_fingerprint import NeuralFingerprint 6 | 7 | from typing import TypeVar 8 | Tensor = TypeVar('torch.tensor') 9 | 10 | 11 | class EncoderDecoder(nn.Module): 12 | 13 | def __init__(self, encoder, decoder, normalize_flag=False): 14 | super(EncoderDecoder, self).__init__() 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | self.normalize_flag = normalize_flag 18 | 19 | 20 | def forward(self, smiles ,input: Tensor) -> Tensor: 21 | encoded_cell = self.encode(input) 22 | if self.normalize_flag: 23 | encoded_cell = nn.functional.normalize(encoded_cell, p=2, dim=1) 24 | encoded_smiles = nn.functional.normalize(smiles, p=2, dim=1) 25 | output = self.decoder(torch.cat((encoded_cell, encoded_smiles), dim=-1)) 26 | 27 | return output 28 | 29 | def encode(self, input: Tensor) -> Tensor: 30 | return self.encoder(input) 31 | 32 | def decode(self, z: Tensor) -> Tensor: 33 | return self.decoder(z) 34 | 35 | -------------------------------------------------------------------------------- 
# File: code/run_pretrain_for_pdr.sh
#
# Launches P_MDL.py pre-training runs with store_dir "pdr", one batch of
# background jobs per method number (0..7). A new batch is only started once
# fewer than ${max_running} P_MDL processes are still alive.
#
# Usage (run from the code/ directory, paths below are relative to it):
#   nohup bash run_pretrain_for_pdr.sh 1>pdr_pretrain.txt 2>&1 &

# Start the next batch only when fewer than this many P_MDL processes run.
max_running=20

# next_task METHOD_NUM
#   Start one background P_MDL.py run per pretrain_num in 0..13 for the given
#   method number. stdout/stderr of each run go to
#   ../records/<aim>/<method_num>_<pretrain_num>.txt
function next_task(){
    select_drug_method="all"
    method_num=$1
    store_dir="pdr"
    # -p: create missing parents and stay quiet when the directory already
    # exists (next_task is called once per method, so it usually does).
    mkdir -p "../results/${store_dir}"

    aim=${store_dir}_TPP_${select_drug_method}
    # The repository (and the sibling run_*.sh scripts) use "records", not
    # "record"; with the wrong name every redirect below failed and no job
    # was ever launched.
    local log_dir="../records/${aim}"
    mkdir -p "${log_dir}"

    gpu_param=3 # can be set depending on the number of GPUs, here we use 4 GPUs = 12 / 3
    # NOTE(review): with i in 0..13 the last run (i=13) maps to device index 4,
    # i.e. a fifth GPU -- confirm this matches the intended hardware.
    local i gpu
    for i in {0..13};
    do
        # Same integer arithmetic as the former `expr` calls: (i - 1) / gpu_param,
        # truncated toward zero (so i=0 and i=1 both land on device 0).
        gpu=$(( (i - 1) / gpu_param ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u P_MDL.py --pretrain_num "$i" \
            --zero_shot_num 2 \
            --CCL_dataset gdsc1_rebalance \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "${log_dir}/${method_num}_${i}.txt" 2>&1 &
    done
}


for method in {0..7};
do
    while true
    do
        # Count live P_MDL processes. The former `awk "{print $2}"` stage was a
        # no-op (the shell expanded $2 to nothing), so it is dropped; the line
        # count is unchanged.
        process=$(ps -ef | grep P_MDL | grep -v "grep" | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        #nvidia-smi
        if (( process < max_running )); then
            # Message now reports the threshold that is actually tested.
            echo "Process < ${max_running}, finished last task!"
            next_task "$method"
            echo "Processing next task: $method ${current_time}"
            break
        else
            echo "Running task num: ${process} ${current_time}"
            sleep 12h
        fi
    done
done

echo "Finish all !!!!"
# File: code/run_pdr_task.sh
#
# Launches PDR_task.py runs with store_dir "pdr_task" as background jobs,
# waiting until no PDR_task process is alive before starting a method's batch.
#
# Usage (run from the code/ directory, paths below are relative to it):
#   nohup bash run_pdr_task.sh 1>pdr_task.txt 2>&1 &

# next_task METHOD_NUM
#   Start one background PDR_task.py run per pretrain_num in 2..20 for the
#   given method number. From pretrain_num 6 onwards each launch is preceded
#   by a 5 minute pause. stdout/stderr of each run go to
#   ../records/pdr_task/<method_num>_<pretrain_num>.txt
function next_task(){
    select_drug_method="all"
    method_num=$1
    store_dir="pdr_task"
    # -p: create missing parents and do not fail when the directory exists.
    mkdir -p "../results/${store_dir}"

    aim=${store_dir}
    local log_dir="../records/${aim}"
    mkdir -p "${log_dir}"

    gpu_num=5
    local i gpu
    for i in {2..20};
    do
        if [ "$i" -gt 5 ]; then
            echo "Can choose to wait for 5 min ~~~"
            sleep 5m
            echo "Sleep over. Now for $i"
        fi
        # Same integer arithmetic as the former `expr` calls.
        gpu=$(( (i - 1) / gpu_num ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u PDR_task.py --pretrain_num "$i" \
            --zero_shot_num 2 \
            --CCL_dataset gdsc1_rebalance \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "${log_dir}/${method_num}_${i}.txt" 2>&1 &
    done
}


for method in {7..7};
do
    while true
    do
        # Count live PDR_task processes. The former `awk "{print $2}"` stage
        # was a no-op (the shell expanded $2 to nothing), so it is dropped.
        process=$(ps -ef | grep PDR_task | grep -v "grep" | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        #nvidia-smi
        if (( process < 1 )); then
            echo "Processed over last task"
            # Only the method number is consumed by next_task; the previously
            # passed $CCL_type / $tcga_construction / $CCL_construction were
            # never defined in this script and have been removed.
            next_task "$method"
            echo "Processing next task: $method ${current_time}"
            sleep 2s
            break
        else
            echo "Last process running now ${current_time}"
            sleep 1h
        fi
    done
done
# Three stray `done` keywords used to follow here (leftovers of removed outer
# loops); they were a syntax error that aborted the script before the final
# message.

echo "Finish all !!!!"
# File: code/run_pretrain.sh
#
# Launches the benchmark P_MDL.py runs. For every method number (8 down to 1)
# it starts the All-data pre-training (ADP) runs, waits 15 minutes, then starts
# the Test-pairwise pre-training (TPP) runs. A new method is only started once
# fewer than 25 P_MDL processes are still alive.
#
# Usage (run from the code/ directory, paths below are relative to it):
#   nohup bash run_pretrain.sh 1>benchmark.txt 2>&1 &

# launch_run GPU PRETRAIN_NUM ZERO_SHOT_NUM LOG_FILE
#   Start a single background P_MDL.py run on the given device. Reads
#   select_drug_method, method_num and store_dir as set by next_task.
function launch_run(){
    local gpu=$1 pretrain_num=$2 zero_shot_num=$3 log_file=$4
    CUDA_VISIBLE_DEVICES=$gpu nohup python -u P_MDL.py --pretrain_num "$pretrain_num" \
        --zero_shot_num "$zero_shot_num" \
        --CCL_dataset gdsc1_raw \
        --select_drug_method "$select_drug_method" \
        --method_num "${method_num}" \
        --store_dir "$store_dir" \
        1> "$log_file" 2>&1 &
}

# next_task METHOD_NUM
#   Start the 13 ADP runs and, 15 minutes later, the 13 TPP runs for the given
#   method number. Logs go to ../records/<aim>/<method_num>_<i>.txt
function next_task(){
    select_drug_method="overlap"
    method_num=$1

    store_dir="benchmark"
    # -p: next_task runs once per method, so these directories usually exist
    # already; plain mkdir printed an error on every call after the first.
    mkdir -p "../results/${store_dir}"

    local i gpu

    # For All-data pre-training (ADP) model: pretrain_num is fixed to 0 and
    # zero_shot_num iterates over 1..13.
    aim=${store_dir}_ADP_${select_drug_method}
    mkdir -p "../records/${aim}"
    gpu_param=13 # can be set depending on the number of GPUs, here we use 4 GPUs = 12 / gpu_param
    # NOTE(review): with gpu_param=13 every run maps to device 0 -- confirm
    # that is intended.
    for i in {1..13};
    do
        gpu=$(( (i - 1) / gpu_param ))
        launch_run "$gpu" 0 "$i" "../records/${aim}/${method_num}_${i}.txt"
    done

    echo "All-data pre-training (ADP) model runing!"
    sleep 15m
    echo "Start to run Test-pairwise pre-training (TPP) model!"

    # For Test-pairwise pre-training (TPP) model: pretrain_num iterates over
    # 1..13 and zero_shot_num is fixed to 2.
    aim=${store_dir}_TPP_${select_drug_method}
    mkdir -p "../records/${aim}"
    gpu_param=13 # can be set depending on the number of GPUs, here we use 4 GPUs = 12 / gpu_param
    for i in {1..13};
    do
        gpu=$(( (i - 1) / gpu_param ))
        launch_run "$gpu" "$i" 2 "../records/${aim}/${method_num}_${i}.txt"
    done
}

for method in {8..1};
do
    while true
    do
        # Count live P_MDL processes. The former `awk "{print $2}"` stage was a
        # no-op (the shell expanded $2 to nothing), so it is dropped.
        process=$(ps -ef | grep P_MDL | grep -v "grep" | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        #nvidia-smi
        if (( process < 25 )); then
            echo "Process < 25, finished last task!"
            echo "Processing next task: $method ${current_time}"
            next_task "$method"
            break
        else
            echo "Running task num: ${process} ${current_time}"
            sleep 1h
        fi
    done
done

echo "Finish all !!!!"
-------------------------------------------------------------------------------- /code/model_train/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.autograd import Function 5 | 6 | from typing import List 7 | from typing import TypeVar 8 | 9 | Tensor = TypeVar('torch.tensor') 10 | 11 | class RevGrad(Function): 12 | @staticmethod 13 | def forward(ctx, input_, alpha_): 14 | ctx.save_for_backward(input_, alpha_) 15 | output = input_ 16 | return output 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): # pragma: no cover 20 | grad_input = None 21 | _, alpha_ = ctx.saved_tensors 22 | if ctx.needs_input_grad[0]: 23 | grad_input = -grad_output * alpha_ 24 | return grad_input, None 25 | 26 | 27 | revgrad = RevGrad.apply 28 | 29 | 30 | class RevGrad(torch.nn.Module): 31 | def __init__(self, alpha=1., *args, **kwargs): 32 | """ 33 | A gradient reversal layer. 34 | This layer has no parameters, and simply reverses the gradient 35 | in the backward pass. 
36 | """ 37 | super().__init__(*args, **kwargs) 38 | 39 | self._alpha = torch.tensor(alpha, requires_grad=False) 40 | 41 | def forward(self, input_): 42 | return revgrad(input_, self._alpha) 43 | 44 | 45 | class MLP(nn.Module): 46 | 47 | def __init__(self, input_dim: int, output_dim: int, hidden_dims: List = None, dop: float = 0.1, act_fn=nn.SELU, out_fn=None, gr_flag=False, **kwargs) -> None: 48 | super(MLP, self).__init__() 49 | self.output_dim = output_dim 50 | self.dop = dop 51 | 52 | if hidden_dims is None: 53 | hidden_dims = [32, 64, 128, 256, 512] 54 | 55 | modules = [] 56 | if gr_flag: 57 | modules.append(RevGrad()) 58 | 59 | modules.append( 60 | nn.Sequential( 61 | nn.Linear(input_dim, hidden_dims[0], bias=True), 62 | #nn.BatchNorm1d(hidden_dims[0]), 63 | act_fn(), 64 | nn.Dropout(self.dop) 65 | ) 66 | ) 67 | 68 | for i in range(len(hidden_dims) - 1): 69 | modules.append( 70 | nn.Sequential( 71 | nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True), 72 | #nn.BatchNorm1d(hidden_dims[i + 1]), 73 | act_fn(), 74 | nn.Dropout(self.dop) 75 | ) 76 | ) 77 | 78 | self.module = nn.Sequential(*modules) 79 | 80 | if out_fn is None: 81 | self.output_layer = nn.Sequential( 82 | nn.Linear(hidden_dims[-1], output_dim, bias=True) 83 | ) 84 | else: 85 | self.output_layer = nn.Sequential( 86 | nn.Linear(hidden_dims[-1], output_dim, bias=True), 87 | out_fn() 88 | ) 89 | 90 | 91 | 92 | def forward(self, input: Tensor) -> Tensor: 93 | embed = self.module(input) 94 | output = self.output_layer(embed) 95 | 96 | return output 97 | -------------------------------------------------------------------------------- /code/model_train/loss_and_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from functools import partial 4 | 5 | 6 | def cov(m, rowvar=False): 7 | """Estimate a covariance matrix given data. 8 | 9 | Covariance indicates the level to which two variables vary together. 
10 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 11 | then the covariance matrix element `C_{ij}` is the covariance of 12 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 13 | 14 | Args: 15 | m: A 1-D or 2-D array containing multiple variables and observations. 16 | Each row of `m` represents a variable, and each column a single 17 | observation of all those variables. 18 | rowvar: If `rowvar` is True, then each row represents a 19 | variable, with observations in the columns. Otherwise, the 20 | relationship is transposed: each column represents a variable, 21 | while the rows contain observations. 22 | 23 | Returns: 24 | The covariance matrix of the variables. 25 | """ 26 | if m.dim() > 2: 27 | raise ValueError('m has more than 2 dimensions') 28 | if m.dim() < 2: 29 | m = m.view(1, -1) 30 | if not rowvar and m.size(0) != 1: 31 | m = m.t() 32 | # m = m.type(torch.double) # uncomment this line if desired 33 | fact = 1.0 / (m.size(1) - 1) 34 | m -= torch.mean(m, dim=1, keepdim=True) 35 | mt = m.t() # if complex: mt = m.t().conj() 36 | return fact * m.matmul(mt).squeeze() 37 | 38 | 39 | def pairwise_distance(x, y): 40 | if not len(x.shape) == len(y.shape) == 2: 41 | raise ValueError('Both inputs should be matrices.') 42 | 43 | if x.shape[1] != y.shape[1]: 44 | raise ValueError('The number of features should be the same.') 45 | 46 | x = x.view(x.shape[0], x.shape[1], 1) 47 | y = torch.transpose(y, 0, 1) 48 | output = torch.sum((x - y) ** 2, 1) 49 | output = torch.transpose(output, 0, 1) 50 | 51 | return output 52 | 53 | 54 | def gaussian_kernel_matrix(x, y, sigmas): 55 | sigmas = sigmas.view(sigmas.shape[0], 1) 56 | beta = 1. / (2. 
* sigmas) 57 | dist = pairwise_distance(x, y).contiguous() 58 | dist_ = dist.view(1, -1) 59 | s = torch.matmul(beta, dist_) 60 | 61 | return torch.sum(torch.exp(-s), 0).view_as(dist) 62 | 63 | 64 | def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix): 65 | cost = torch.mean(kernel(x, x)) 66 | cost += torch.mean(kernel(y, y)) 67 | cost -= 2 * torch.mean(kernel(x, y)) 68 | 69 | return cost 70 | 71 | 72 | def mmd_loss(source_features, target_features, device): 73 | sigmas = [ 74 | 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100, 75 | 1e3, 1e4, 1e5, 1e6 76 | ] 77 | gaussian_kernel = partial( 78 | gaussian_kernel_matrix, sigmas=Variable(torch.Tensor(sigmas), requires_grad=False).to(device) 79 | ) 80 | 81 | loss_value = maximum_mean_discrepancy(source_features, target_features, kernel=gaussian_kernel) 82 | loss_value = loss_value 83 | 84 | return loss_value 85 | -------------------------------------------------------------------------------- /code/model_train/ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from base_ae import BaseAE 5 | 6 | from typing import List 7 | from typing import TypeVar 8 | Tensor = TypeVar('torch.tensor') 9 | 10 | class AE(BaseAE): 11 | 12 | def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List = None, dop: float = 0.1, noise_flag: bool = False, **kwargs) -> None: 13 | super(AE, self).__init__() 14 | self.latent_dim = latent_dim 15 | self.noise_flag = noise_flag 16 | self.dop = dop 17 | 18 | if hidden_dims is None: 19 | hidden_dims = [32, 64, 128, 256, 512] 20 | 21 | # build encoder 22 | modules = [] 23 | 24 | modules.append( 25 | nn.Sequential( 26 | nn.Linear(input_dim, hidden_dims[0], bias=True), 27 | #nn.BatchNorm1d(hidden_dims[0]), 28 | nn.ReLU(), 29 | nn.Dropout(self.dop) 30 | ) 31 | ) 32 | 33 | for i in range(len(hidden_dims) - 1): 34 | modules.append( 35 | 
nn.Sequential( 36 | nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True), 37 | #nn.BatchNorm1d(hidden_dims[i + 1]), 38 | nn.ReLU(), 39 | nn.Dropout(self.dop) 40 | ) 41 | ) 42 | modules.append(nn.Dropout(self.dop)) 43 | modules.append(nn.Linear(hidden_dims[-1], latent_dim, bias=True)) 44 | 45 | self.encoder = nn.Sequential(*modules) 46 | 47 | # build decoder 48 | modules = [] 49 | 50 | modules.append( 51 | nn.Sequential( 52 | nn.Linear(latent_dim, hidden_dims[-1], bias=True), 53 | #nn.BatchNorm1d(hidden_dims[-1]), 54 | nn.ReLU(), 55 | nn.Dropout(self.dop) 56 | ) 57 | ) 58 | 59 | hidden_dims.reverse() 60 | 61 | for i in range(len(hidden_dims) - 1): 62 | modules.append( 63 | nn.Sequential( 64 | nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True), 65 | #nn.BatchNorm1d(hidden_dims[i + 1]), 66 | nn.ReLU(), 67 | nn.Dropout(self.dop) 68 | ) 69 | ) 70 | self.decoder = nn.Sequential(*modules) 71 | 72 | self.final_layer = nn.Sequential( 73 | nn.Linear(hidden_dims[-1], hidden_dims[-1], bias=True), 74 | #nn.BatchNorm1d(hidden_dims[-1]), 75 | nn.ReLU(), 76 | nn.Dropout(self.dop), 77 | nn.Linear(hidden_dims[-1], input_dim) 78 | ) 79 | 80 | 81 | def encode(self, input: Tensor) -> Tensor: 82 | if self.noise_flag and self.training: 83 | latent_code = self.encoder(input+torch.randn_like(input, requires_grad=False) * 0.1) 84 | else: 85 | latent_code = self.encoder(input) 86 | 87 | return latent_code 88 | 89 | def decode(self, z: Tensor) -> Tensor: 90 | embed = self.decoder(z) 91 | outputs = self.final_layer(embed) 92 | 93 | return outputs 94 | 95 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 96 | z = self.encode(input) 97 | return [input, self.decode(z), z] 98 | 99 | def loss_function(self, *args, **kwargs) -> dict: 100 | input = args[0] 101 | recons = args[1] 102 | 103 | recons_loss = F.mse_loss(input, recons) 104 | loss = recons_loss 105 | 106 | return {'loss': loss, 'recons_loss': recons_loss} 107 | 108 | def sample(self, num_samples: int, current_device: 
int, **kwargs) -> Tensor: 109 | z = torch.randn(num_samples, self.latent_dim) 110 | 111 | z = z.to(current_device) 112 | samples = self.decode(z) 113 | 114 | return samples 115 | 116 | def generate(self, x: Tensor, **kwargs) -> Tensor: 117 | return self.forward(x)[1] 118 | -------------------------------------------------------------------------------- /code/model_train/train_ae_mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from evaluation_utils import eval_ae_epoch, model_save_check 4 | from collections import defaultdict 5 | from ae import AE 6 | from mlp import MLP 7 | from loss_and_metrics import mmd_loss 8 | from encoder_decoder import EncoderDecoder 9 | 10 | 11 | def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None): 12 | ae.zero_grad() 13 | ae.train() 14 | 15 | s_x = s_batch[0].to(device) 16 | t_x = t_batch[0].to(device) 17 | 18 | s_loss_dict = ae.loss_function(*ae(s_x)) 19 | t_loss_dict = ae.loss_function(*ae(t_x)) 20 | 21 | optimizer.zero_grad() 22 | m_loss = mmd_loss(source_features=ae.encode(s_x), target_features=ae.encode(t_x), device=device) 23 | loss = s_loss_dict['loss'] + t_loss_dict['loss'] + m_loss 24 | optimizer.zero_grad() 25 | 26 | loss.backward() 27 | optimizer.step() 28 | if scheduler is not None: 29 | scheduler.step() 30 | 31 | loss_dict = {k: v.cpu().detach().item() + t_loss_dict[k].cpu().detach().item() for k, v in s_loss_dict.items()} 32 | 33 | for k, v in loss_dict.items(): 34 | history[k].append(v) 35 | history['mmd_loss'].append(m_loss.cpu().detach().item()) 36 | 37 | return history 38 | 39 | 40 | def train_ae_mmd(s_dataloaders, t_dataloaders, **kwargs): 41 | """ 42 | 43 | :param s_dataloaders: 44 | :param t_dataloaders: 45 | :param kwargs: 46 | :return: 47 | """ 48 | s_train_dataloader = s_dataloaders[0] 49 | s_test_dataloader = s_dataloaders[1] 50 | 51 | t_train_dataloader = t_dataloaders[0] 52 | t_test_dataloader = 
t_dataloaders[1] 53 | 54 | autoencoder = AE(input_dim=kwargs['input_dim'], 55 | latent_dim=kwargs['latent_dim'], 56 | hidden_dims=kwargs['encoder_hidden_dims'], 57 | dop=kwargs['dop']).to(kwargs['device']) 58 | classifier = MLP(input_dim=kwargs['latent_dim'], 59 | output_dim=1, 60 | hidden_dims=kwargs['classifier_hidden_dims'], 61 | dop=kwargs['dop']).to(kwargs['device']) 62 | confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder, decoder=classifier).to(kwargs['device']) 63 | 64 | ae_eval_train_history = defaultdict(list) 65 | ae_eval_val_history = defaultdict(list) 66 | 67 | if kwargs['retrain_flag']: 68 | ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr']) 69 | 70 | 71 | # start autoencoder pretraining 72 | for epoch in range(int(kwargs['pretrain_num_epochs'])): 73 | if epoch % 50 == 0: 74 | print(f'----Autoencoder Pre-Training Epoch {epoch} ----') 75 | for step, s_batch in enumerate(s_train_dataloader): 76 | t_batch = next(iter(t_train_dataloader)) 77 | ae_eval_train_history = ae_train_step(ae=autoencoder, 78 | s_batch=s_batch, 79 | t_batch=t_batch, 80 | device=kwargs['device'], 81 | optimizer=ae_optimizer, 82 | history=ae_eval_train_history) 83 | 84 | ae_eval_val_history = eval_ae_epoch(model=autoencoder, 85 | data_loader=s_test_dataloader, 86 | device=kwargs['device'], 87 | history=ae_eval_val_history 88 | ) 89 | ae_eval_val_history = eval_ae_epoch(model=autoencoder, 90 | data_loader=t_test_dataloader, 91 | device=kwargs['device'], 92 | history=ae_eval_val_history 93 | ) 94 | for k in ae_eval_val_history: 95 | if k != 'best_index': 96 | ae_eval_val_history[k][-2] += ae_eval_val_history[k][-1] 97 | ae_eval_val_history[k].pop() 98 | # print some loss/metric messages 99 | if kwargs['es_flag']: 100 | save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss', 101 | tolerance_count=10) 102 | if save_flag: 103 | torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'ae.pt')) 
104 | if stop_flag: 105 | break 106 | 107 | if kwargs['es_flag']: 108 | autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt'))) 109 | 110 | else: 111 | try: 112 | autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt'))) 113 | except FileNotFoundError: 114 | raise Exception("No pre-trained encoder") 115 | 116 | return autoencoder.encoder, (ae_eval_train_history, 117 | ae_eval_val_history) 118 | -------------------------------------------------------------------------------- /code/model_train/train_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from evaluation_utils import eval_ae_epoch, model_save_check 4 | from collections import defaultdict 5 | from ae import AE 6 | from collections import OrderedDict 7 | 8 | 9 | def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None): 10 | ae.zero_grad() 11 | ae.train() 12 | 13 | s_x = s_batch[0].to(device) 14 | t_x = t_batch[0].to(device) 15 | 16 | s_loss_dict = ae.loss_function(*ae(s_x)) 17 | t_loss_dict = ae.loss_function(*ae(t_x)) 18 | 19 | optimizer.zero_grad() 20 | loss = s_loss_dict['loss'] + t_loss_dict['loss'] 21 | optimizer.zero_grad() 22 | 23 | loss.backward() 24 | optimizer.step() 25 | if scheduler is not None: 26 | scheduler.step() 27 | 28 | loss_dict = {k: v.cpu().detach().item() + t_loss_dict[k].cpu().detach().item() for k, v in s_loss_dict.items()} 29 | 30 | for k, v in loss_dict.items(): 31 | history[k].append(v) 32 | 33 | return history 34 | 35 | 36 | def train_ae(s_dataloaders, t_dataloaders, **kwargs): 37 | """ 38 | 39 | :param s_dataloaders: 40 | :param t_dataloaders: 41 | :param kwargs: 42 | :return: 43 | """ 44 | s_train_dataloader = s_dataloaders[0] 45 | s_test_dataloader = s_dataloaders[1] 46 | 47 | t_train_dataloader = t_dataloaders[0] 48 | t_test_dataloader = t_dataloaders[1] 49 | # print(kwargs) 50 | 51 | autoencoder = 
AE(input_dim=kwargs['input_dim'], 52 | latent_dim=kwargs['latent_dim'], 53 | hidden_dims=kwargs['encoder_hidden_dims'], 54 | noise_flag=False, 55 | dop=kwargs['dop']).to(kwargs['device']) 56 | 57 | 58 | ae_eval_train_history = defaultdict(list) 59 | ae_eval_val_history = defaultdict(list) 60 | 61 | if kwargs['retrain_flag']: 62 | ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr']) 63 | 64 | 65 | # start autoencoder pretraining 66 | for epoch in range(int(kwargs['train_num_epochs'])): 67 | if epoch % 50 == 0: 68 | print(f'----Autoencoder Training Epoch {epoch} ----') 69 | for step, s_batch in enumerate(s_train_dataloader): 70 | t_batch = next(iter(t_train_dataloader)) 71 | ae_eval_train_history = ae_train_step(ae=autoencoder, 72 | s_batch=s_batch, 73 | t_batch=t_batch, 74 | device=kwargs['device'], 75 | optimizer=ae_optimizer, 76 | history=ae_eval_train_history) 77 | 78 | ae_eval_val_history = eval_ae_epoch(model=autoencoder, 79 | data_loader=s_test_dataloader, 80 | device=kwargs['device'], 81 | history=ae_eval_val_history 82 | ) 83 | ae_eval_val_history = eval_ae_epoch(model=autoencoder, 84 | data_loader=t_test_dataloader, 85 | device=kwargs['device'], 86 | history=ae_eval_val_history 87 | ) 88 | for k in ae_eval_val_history: 89 | if k != 'best_index': 90 | ae_eval_val_history[k][-2] += ae_eval_val_history[k][-1] 91 | ae_eval_val_history[k].pop() 92 | # print some loss/metric messages 93 | if kwargs['es_flag']: 94 | save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss', 95 | tolerance_count=50) 96 | if save_flag: 97 | torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'ae.pt')) 98 | if stop_flag: 99 | break 100 | 101 | if kwargs['es_flag']: 102 | autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt'))) 103 | torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'ae.pt')) 104 | else: 105 | try: 106 | 
autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt'))) 107 | # except: 108 | loaded_model = torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')) 109 | # new_loaded_model = {key: val for key, val in loaded_model.items() if key in autoencoder.state_dict()} 110 | # new_loaded_model['shared_encoder.output_layer.0.weight'] = loaded_model[ 111 | # 'shared_encoder.output_layer.3.weight'] 112 | # new_loaded_model['shared_encoder.output_layer.0.bias'] = loaded_model[ 113 | # 'shared_encoder.output_layer.3.bias'] 114 | # new_loaded_model['decoder.output_layer.0.weight'] = loaded_model['decoder.output_layer.3.weight'] 115 | # new_loaded_model['decoder.output_layer.0.bias'] = loaded_model['decoder.output_layer.3.bias'] 116 | 117 | # corrected_model = OrderedDict({key: new_loaded_model[key] for key in autoencoder.state_dict()}) 118 | # autoencoder.load_state_dict(corrected_model) 119 | 120 | except FileNotFoundError: 121 | raise Exception("No pre-trained encoder") 122 | 123 | return autoencoder.encoder, (ae_eval_train_history, 124 | ae_eval_val_history) 125 | -------------------------------------------------------------------------------- /code/model_train/dsn_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from base_ae import BaseAE 5 | from typing import List 6 | 7 | from typing import TypeVar 8 | Tensor = TypeVar('torch.tensor') 9 | 10 | 11 | class DSNAE(BaseAE): 12 | 13 | def __init__(self, shared_encoder, decoder, input_dim: int, latent_dim: int, alpha: float = 1.0, 14 | hidden_dims: List = None, dop: float = 0.1, noise_flag: bool = False, norm_flag: bool = False, 15 | **kwargs) -> None: 16 | super(DSNAE, self).__init__() 17 | self.latent_dim = latent_dim 18 | self.alpha = alpha 19 | self.noise_flag = noise_flag 20 | self.dop = dop 21 | self.norm_flag = norm_flag 22 | 23 | if hidden_dims is None: 24 | 
hidden_dims = [32, 64, 128, 256, 512] 25 | 26 | self.shared_encoder = shared_encoder 27 | self.decoder = decoder 28 | # build encoder 29 | modules = [] 30 | 31 | modules.append( 32 | nn.Sequential( 33 | nn.Linear(input_dim, hidden_dims[0], bias=True), 34 | # nn.BatchNorm1d(hidden_dims[0]), 35 | nn.ReLU(), 36 | nn.Dropout(self.dop) 37 | ) 38 | ) 39 | 40 | for i in range(len(hidden_dims) - 1): 41 | modules.append( 42 | nn.Sequential( 43 | nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True), 44 | # nn.Dropout(0.1), 45 | # nn.BatchNorm1d(hidden_dims[i + 1]), 46 | nn.ReLU(), 47 | nn.Dropout(self.dop) 48 | ) 49 | ) 50 | modules.append(nn.Dropout(self.dop)) 51 | modules.append(nn.Linear(hidden_dims[-1], latent_dim, bias=True)) 52 | # modules.append(nn.LayerNorm(latent_dim, eps=1e-12, elementwise_affine=False)) 53 | 54 | self.private_encoder = nn.Sequential(*modules) 55 | 56 | # build decoder 57 | # modules = [] 58 | # 59 | # modules.append( 60 | # nn.Sequential( 61 | # nn.Linear(2 * latent_dim, hidden_dims[-1], bias=True), 62 | # # nn.Dropout(0.1), 63 | # nn.BatchNorm1d(hidden_dims[-1]), 64 | # nn.ReLU() 65 | # ) 66 | # ) 67 | # 68 | # hidden_dims.reverse() 69 | # 70 | # for i in range(len(hidden_dims) - 1): 71 | # modules.append( 72 | # nn.Sequential( 73 | # nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True), 74 | # nn.BatchNorm1d(hidden_dims[i + 1]), 75 | # # nn.Dropout(0.1), 76 | # nn.ReLU() 77 | # ) 78 | # ) 79 | # self.decoder = nn.Sequential(*modules) 80 | 81 | # self.final_layer = nn.Sequential( 82 | # nn.Linear(hidden_dims[-1], hidden_dims[-1], bias=True), 83 | # nn.BatchNorm1d(hidden_dims[-1]), 84 | # nn.ReLU(), 85 | # nn.Dropout(0.1), 86 | # nn.Linear(hidden_dims[-1], input_dim) 87 | # ) 88 | 89 | def p_encode(self, input: Tensor) -> Tensor: 90 | if self.noise_flag and self.training: 91 | latent_code = self.private_encoder(input + torch.randn_like(input, requires_grad=False) * 0.1) 92 | else: 93 | latent_code = self.private_encoder(input) 94 | 95 | if 
self.norm_flag: 96 | return F.normalize(latent_code, p=2, dim=1) 97 | else: 98 | return latent_code 99 | 100 | def s_encode(self, input: Tensor) -> Tensor: 101 | if self.noise_flag and self.training: 102 | latent_code = self.shared_encoder(input + torch.randn_like(input, requires_grad=False) * 0.1) 103 | else: 104 | latent_code = self.shared_encoder(input) 105 | if self.norm_flag: 106 | return F.normalize(latent_code, p=2, dim=1) 107 | else: 108 | return latent_code 109 | 110 | def encode(self, input: Tensor) -> Tensor: 111 | p_latent_code = self.p_encode(input) 112 | s_latent_code = self.s_encode(input) 113 | 114 | return torch.cat((p_latent_code, s_latent_code), dim=1) 115 | 116 | def decode(self, z: Tensor) -> Tensor: 117 | # embed = self.decoder(z) 118 | # outputs = self.final_layer(embed) 119 | outputs = self.decoder(z) 120 | 121 | return outputs 122 | 123 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 124 | z = self.encode(input) 125 | return [input, self.decode(z), z] 126 | 127 | def loss_function(self, *args, **kwargs) -> dict: 128 | input = args[0] 129 | recons = args[1] 130 | z = args[2] 131 | 132 | p_z = z[:, :z.shape[1] // 2] 133 | s_z = z[:, z.shape[1] // 2:] 134 | 135 | recons_loss = F.mse_loss(input, recons) 136 | 137 | s_l2_norm = torch.norm(s_z, p=2, dim=1, keepdim=True).detach() 138 | s_l2 = s_z.div(s_l2_norm.expand_as(s_z) + 1e-6) 139 | 140 | p_l2_norm = torch.norm(p_z, p=2, dim=1, keepdim=True).detach() 141 | p_l2 = p_z.div(p_l2_norm.expand_as(p_z) + 1e-6) 142 | 143 | ortho_loss = torch.mean((s_l2.t().mm(p_l2)).pow(2)) 144 | # ortho_loss = torch.square(torch.norm(torch.matmul(s_z.t(), p_z), p='fro')) 145 | # ortho_loss = torch.mean(torch.square(torch.diagonal(torch.matmul(p_z, s_z.t())))) 146 | # if recons_loss > ortho_loss: 147 | # loss = recons_loss + self.alpha * 0.1 * ortho_loss 148 | # else: 149 | loss = recons_loss + self.alpha * ortho_loss 150 | return {'loss': loss, 'recons_loss': recons_loss, 'ortho_loss': ortho_loss} 151 
| 152 | def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor: 153 | z = torch.randn(num_samples, self.latent_dim) 154 | z = z.to(current_device) 155 | samples = self.decode(z) 156 | 157 | return samples 158 | 159 | def generate(self, x: Tensor, **kwargs) -> Tensor: 160 | return self.forward(x)[1] 161 | -------------------------------------------------------------------------------- /code/model_train/train_dsn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | 4 | from dsn_ae import DSNAE 5 | from evaluation_utils import * 6 | from mlp import MLP 7 | 8 | 9 | def eval_dsnae_epoch(model, data_loader, device, history): 10 | """ 11 | 12 | :param model: 13 | :param data_loader: 14 | :param device: 15 | :param history: 16 | :return: 17 | """ 18 | model.eval() 19 | avg_loss_dict = defaultdict(float) 20 | for x_batch in data_loader: 21 | x_batch = x_batch[0].to(device) 22 | with torch.no_grad(): 23 | loss_dict = model.loss_function(*(model(x_batch))) 24 | for k, v in loss_dict.items(): 25 | avg_loss_dict[k] += v.cpu().detach().item() / len(data_loader) 26 | 27 | for k, v in avg_loss_dict.items(): 28 | history[k].append(v) 29 | return history 30 | 31 | 32 | def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None): 33 | s_dsnae.zero_grad() 34 | t_dsnae.zero_grad() 35 | s_dsnae.train() 36 | t_dsnae.train() 37 | 38 | s_x = s_batch[0].to(device) 39 | t_x = t_batch[0].to(device) 40 | 41 | s_loss_dict = s_dsnae.loss_function(*s_dsnae(s_x)) 42 | t_loss_dict = t_dsnae.loss_function(*t_dsnae(t_x)) 43 | 44 | optimizer.zero_grad() 45 | loss = s_loss_dict['loss'] + t_loss_dict['loss'] 46 | loss.backward() 47 | 48 | optimizer.step() 49 | if scheduler is not None: 50 | scheduler.step() 51 | loss_dict = {k: v.cpu().detach().item() + t_loss_dict[k].cpu().detach().item() for k, v in s_loss_dict.items()} 52 | 53 | for k, v in 
loss_dict.items(): 54 | history[k].append(v) 55 | 56 | return history 57 | 58 | 59 | def train_dsn(s_dataloaders, t_dataloaders, **kwargs): 60 | """ 61 | 62 | :param s_dataloaders: 63 | :param t_dataloaders: 64 | :param kwargs: 65 | :return: 66 | """ 67 | s_train_dataloader = s_dataloaders[0] 68 | s_test_dataloader = s_dataloaders[1] 69 | 70 | t_train_dataloader = t_dataloaders[0] 71 | t_test_dataloader = t_dataloaders[1] 72 | 73 | shared_encoder = MLP(input_dim=kwargs['input_dim'], 74 | output_dim=kwargs['latent_dim'], 75 | hidden_dims=kwargs['encoder_hidden_dims'], 76 | dop=kwargs['dop']).to(kwargs['device']) 77 | 78 | shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'], 79 | output_dim=kwargs['input_dim'], 80 | hidden_dims=kwargs['encoder_hidden_dims'][::-1], 81 | dop=kwargs['dop']).to(kwargs['device']) 82 | 83 | s_dsnae = DSNAE(shared_encoder=shared_encoder, 84 | decoder=shared_decoder, 85 | alpha=kwargs['alpha'], 86 | input_dim=kwargs['input_dim'], 87 | latent_dim=kwargs['latent_dim'], 88 | hidden_dims=kwargs['encoder_hidden_dims'], 89 | dop=kwargs['dop'], 90 | norm_flag=kwargs['norm_flag']).to(kwargs['device']) 91 | 92 | t_dsnae = DSNAE(shared_encoder=shared_encoder, 93 | decoder=shared_decoder, 94 | alpha=kwargs['alpha'], 95 | input_dim=kwargs['input_dim'], 96 | latent_dim=kwargs['latent_dim'], 97 | hidden_dims=kwargs['encoder_hidden_dims'], 98 | dop=kwargs['dop'], 99 | norm_flag=kwargs['norm_flag']).to(kwargs['device']) 100 | 101 | device = kwargs['device'] 102 | 103 | dsnae_train_history = defaultdict(list) 104 | dsnae_val_history = defaultdict(list) 105 | 106 | if kwargs['retrain_flag']: 107 | ae_params = [t_dsnae.private_encoder.parameters(), 108 | s_dsnae.private_encoder.parameters(), 109 | shared_decoder.parameters(), 110 | shared_encoder.parameters() 111 | ] 112 | 113 | ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr']) 114 | 115 | 116 | for epoch in range(int(kwargs['train_num_epochs'])): 117 | if epoch % 50 == 0: 118 | 
print(f'AE training epoch {epoch}') 119 | for step, s_batch in enumerate(s_train_dataloader): 120 | t_batch = next(iter(t_train_dataloader)) 121 | dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae, 122 | t_dsnae=t_dsnae, 123 | s_batch=s_batch, 124 | t_batch=t_batch, 125 | device=device, 126 | optimizer=ae_optimizer, 127 | history=dsnae_train_history) 128 | dsnae_val_history = eval_dsnae_epoch(model=s_dsnae, 129 | data_loader=s_test_dataloader, 130 | device=device, 131 | history=dsnae_val_history 132 | ) 133 | dsnae_val_history = eval_dsnae_epoch(model=t_dsnae, 134 | data_loader=t_test_dataloader, 135 | device=device, 136 | history=dsnae_val_history 137 | ) 138 | for k in dsnae_val_history: 139 | if k != 'best_index': 140 | dsnae_val_history[k][-2] += dsnae_val_history[k][-1] 141 | dsnae_val_history[k].pop() 142 | 143 | if kwargs['es_flag']: 144 | save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50) 145 | if save_flag: 146 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 's_dsnae.pt')) 147 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 't_dsnae.pt')) 148 | if stop_flag: 149 | break 150 | 151 | if kwargs['es_flag']: 152 | s_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 's_dsnae.pt'))) 153 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 't_dsnae.pt'))) 154 | 155 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 's_dsnae.pt')) 156 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 't_dsnae.pt')) 157 | 158 | else: 159 | try: 160 | loaded_model = torch.load(os.path.join(kwargs['model_save_folder'], 't_dsnae.pt')) 161 | print({key:val.shape for key,val in loaded_model.items()}) 162 | print({key:val.shape for key,val in t_dsnae.state_dict().items()}) 163 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 't_dsnae.pt'))) 164 | 
except FileNotFoundError: 165 | raise Exception("No pre-trained encoder") 166 | 167 | 168 | return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history) 169 | -------------------------------------------------------------------------------- /code/model_train/train_dsn_mmd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import chain 3 | 4 | from dsn_ae import DSNAE 5 | from evaluation_utils import * 6 | from mlp import MLP 7 | from loss_and_metrics import mmd_loss 8 | 9 | 10 | def eval_dsnae_epoch(model, data_loader, device, history): 11 | """ 12 | 13 | :param model: 14 | :param data_loader: 15 | :param device: 16 | :param history: 17 | :return: 18 | """ 19 | model.eval() 20 | avg_loss_dict = defaultdict(float) 21 | for x_batch in data_loader: 22 | x_batch = x_batch[0].to(device) 23 | with torch.no_grad(): 24 | loss_dict = model.loss_function(*(model(x_batch))) 25 | for k, v in loss_dict.items(): 26 | avg_loss_dict[k] += v.cpu().detach().item() / len(data_loader) 27 | 28 | for k, v in avg_loss_dict.items(): 29 | history[k].append(v) 30 | return history 31 | 32 | 33 | def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None): 34 | s_dsnae.zero_grad() 35 | t_dsnae.zero_grad() 36 | s_dsnae.train() 37 | t_dsnae.train() 38 | 39 | s_x = s_batch[0].to(device) 40 | t_x = t_batch[0].to(device) 41 | 42 | s_code = s_dsnae.s_encode(s_x) 43 | t_code = t_dsnae.s_encode(t_x) 44 | 45 | s_loss_dict = s_dsnae.loss_function(*s_dsnae(s_x)) 46 | t_loss_dict = t_dsnae.loss_function(*t_dsnae(t_x)) 47 | 48 | optimizer.zero_grad() 49 | m_loss = mmd_loss(source_features=s_code, target_features=t_code, device=device) 50 | loss = s_loss_dict['loss'] + t_loss_dict['loss'] + m_loss 51 | loss.backward() 52 | 53 | optimizer.step() 54 | if scheduler is not None: 55 | scheduler.step() 56 | loss_dict = {k: v.cpu().detach().item() + t_loss_dict[k].cpu().detach().item() for k, v 
in s_loss_dict.items()} 57 | 58 | for k, v in loss_dict.items(): 59 | history[k].append(v) 60 | 61 | history['mmd_loss'].append(m_loss.cpu().detach().item()) 62 | 63 | return history 64 | 65 | 66 | def train_dsn_mmd(s_dataloaders, t_dataloaders, **kwargs): 67 | """ 68 | 69 | :param s_dataloaders: 70 | :param t_dataloaders: 71 | :param kwargs: 72 | :return: 73 | """ 74 | s_train_dataloader = s_dataloaders[0] 75 | s_test_dataloader = s_dataloaders[1] 76 | 77 | t_train_dataloader = t_dataloaders[0] 78 | t_test_dataloader = t_dataloaders[1] 79 | 80 | shared_encoder = MLP(input_dim=kwargs['input_dim'], 81 | output_dim=kwargs['latent_dim'], 82 | hidden_dims=kwargs['encoder_hidden_dims'], 83 | dop=kwargs['dop']).to(kwargs['device']) 84 | 85 | shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'], 86 | output_dim=kwargs['input_dim'], 87 | hidden_dims=kwargs['encoder_hidden_dims'][::-1], 88 | dop=kwargs['dop']).to(kwargs['device']) 89 | 90 | s_dsnae = DSNAE(shared_encoder=shared_encoder, 91 | decoder=shared_decoder, 92 | alpha=kwargs['alpha'], 93 | input_dim=kwargs['input_dim'], 94 | latent_dim=kwargs['latent_dim'], 95 | hidden_dims=kwargs['encoder_hidden_dims'], 96 | dop=kwargs['dop'], 97 | norm_flag=False).to(kwargs['device']) 98 | 99 | t_dsnae = DSNAE(shared_encoder=shared_encoder, 100 | decoder=shared_decoder, 101 | alpha=kwargs['alpha'], 102 | input_dim=kwargs['input_dim'], 103 | latent_dim=kwargs['latent_dim'], 104 | hidden_dims=kwargs['encoder_hidden_dims'], 105 | dop=kwargs['dop'], 106 | norm_flag=False).to(kwargs['device']) 107 | 108 | 109 | device = kwargs['device'] 110 | 111 | dsnae_train_history = defaultdict(list) 112 | dsnae_val_history = defaultdict(list) 113 | 114 | if kwargs['retrain_flag']: 115 | ae_params = [t_dsnae.private_encoder.parameters(), 116 | s_dsnae.private_encoder.parameters(), 117 | shared_decoder.parameters(), 118 | shared_encoder.parameters() 119 | ] 120 | 121 | ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr']) 122 | 123 
| 124 | for epoch in range(int(kwargs['train_num_epochs'])): 125 | if epoch % 50 == 0: 126 | print(f'AE training epoch {epoch}') 127 | for step, s_batch in enumerate(s_train_dataloader): 128 | t_batch = next(iter(t_train_dataloader)) 129 | dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae, 130 | t_dsnae=t_dsnae, 131 | s_batch=s_batch, 132 | t_batch=t_batch, 133 | device=device, 134 | optimizer=ae_optimizer, 135 | history=dsnae_train_history) 136 | dsnae_val_history = eval_dsnae_epoch(model=s_dsnae, 137 | data_loader=s_test_dataloader, 138 | device=device, 139 | history=dsnae_val_history 140 | ) 141 | dsnae_val_history = eval_dsnae_epoch(model=t_dsnae, 142 | data_loader=t_test_dataloader, 143 | device=device, 144 | history=dsnae_val_history 145 | ) 146 | for k in dsnae_val_history: 147 | if k != 'best_index': 148 | dsnae_val_history[k][-2] += dsnae_val_history[k][-1] 149 | dsnae_val_history[k].pop() 150 | 151 | if kwargs['es_flag']: 152 | save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50) 153 | if save_flag: 154 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')) 155 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')) 156 | if stop_flag: 157 | break 158 | 159 | if kwargs['es_flag']: 160 | s_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt'))) 161 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt'))) 162 | 163 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')) 164 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')) 165 | 166 | else: 167 | try: 168 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt'))) 169 | except FileNotFoundError: 170 | raise Exception("No pre-trained encoder") 171 | 172 | return t_dsnae.shared_encoder, 
(dsnae_train_history, dsnae_val_history) 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Drug-Transfer-Learning 2 | Pre-clinical drug discovery (PDD) faces a low-efficiency dilemma. One of the reasons is the lack of cross-drug efficacy evaluation infrastructure at the patient level. Here we propose the Patient Multi-Drug Learning (P-MDL) task, and construct the P-MDL dataset and model zoo. The best P-MDL model, DSN-adv, achieves SOTA performance in all of the 13 tumor types compared with previous SOTA models. 3 | 4 | You can also check out our [podcast](https://www.ximalaya.com/sound/938380402) and watch our intro videos (in both [Chinese](https://www.bilibili.com/video/BV1ta2LBzEJd/?spm_id_from=333.1387.homepage.video_card.click&vd_source=3d478d07006b6eaf39bdb03ccbf5ec3a) and [English](https://youtu.be/G3_kJq5isG8)) for a quick overview. 5 | 6 | ## Features 7 | 8 | ### P-MDL task 9 | Artificial intelligence (AI) models used for drug response prediction (DRP) tasks are generally classified into Single-Drug Learning (SDL) and Multi-Drug Learning (MDL) paradigms. SDL paradigms have been adapted to the patient level and evaluate within-drug response, disregarding tumor types. However, there exist substantial differences in treatment response and survival outcomes among different tumor types, indicating that tumor type is a crucial confounding factor that cannot be overlooked when predicting drug response. Additionally, SDL paradigms fail to assess cross-drug response, while MDL paradigms are currently limited to the cell line level. Therefore, we propose the P-MDL approach, which aims to achieve a comprehensive view of drug response at the patient level. 10 | 11 |

12 | DRP paradigms and limitations 13 |

14 | 15 | 16 | 17 | ### P-MDL dataset 18 | We constructed the first P-MDL dataset from publicly available data. Only tumor types with relatively sufficient data were retained. In total, 13 tumor types were selected for the P-MDL dataset. 19 | 20 | 21 |

22 | P-MDL dataset spanning 13 tumor types 23 |

24 | 25 | 26 | ### P-MDL model zoo 27 | 28 | The P-MDL model zoo includes eight models employing different transfer learning methods: 29 | | **P-MDL models** | **Description** | 30 | | ------- | ------- | 31 | | **ae** | **Autoencoder** used to encode the gene expression profiles (GEPs) of both cell lines and patients. | 32 | | **ae-mmd** | **ae** model with an additional **mmd-loss**. | 33 | | **ae-adv** | **ae** model with an additional **adv-loss**. | 34 | | **dsn**| **Domain separation network** has been successfully applied in computer vision. Here, it is used for the GEPs encoding of cell lines and patients. | 35 | | **dsn-adv**| **dsn** model with an additional **adv-loss**. | 36 | | **dsrn** | A variant of the **dsn** model.| 37 | | **dsrn-mmd** | **dsrn** model with an additional **mmd-loss**. | 38 | | **dsrn-adv**| **dsrn** model with an additional **adv-loss**. | 39 | 40 | 49 | 50 | 51 | 52 | ### Model evaluation 53 | One of the P-MDL models (DSN-adv) outperforms all of the P-SDL and C-MDL models across all tumor types. 54 |

55 | P-MDL performance in 13 tumor types 56 |

57 | 58 | 59 | ### Patient Drug Response (PDR) prediction for PDD 60 | To further validate the P-MDL models and demonstrate their potential in PDD applications, the test-pairwise pre-trained DSN-adv model was used to screen 233 small molecules for patients of 13 tumor types. 61 | Taking tumor type COAD as an example, most drugs were ineffective, but a few drugs showed potential efficacy for over half of COAD patients. 62 | 63 | 64 |

65 | PDR score and analysis 66 |

67 | 68 | 69 | ## Installation 70 | 71 | To support basic usage of the P-MDL task, run the following commands: 72 | 73 | ```bash 74 | conda create -n P-MDL python=3.8 75 | conda activate P-MDL 76 | conda install -c conda-forge rdkit 77 | 78 | pip install torch 79 | pip install pandas 80 | pip install numpy 81 | pip install scikit-learn 82 | ``` 83 | 84 | ## Quick Start 85 | 86 | Here, we provide instructions for a quick start. 87 | 88 | 89 | 96 | 97 | ### Step 1: Data Preparation 98 | 99 | Download the data at Zenodo [here](https://zenodo.org/record/8021167). Move the downloaded files to the ``data/`` folder. 100 | 101 | ### Step 2: Training and Evaluation 102 | 103 | Run: 104 | 105 | ```bash 106 | cd code 107 | nohup bash run_pretrain.sh 1>benchmark.txt 2>&1 & 108 | less benchmark.txt 109 | ``` 110 | 111 | The ``benchmark.txt`` will look like the following: 112 | ```bash 113 | Processing next task: 8 20230718-043332 114 | All-data pre-training (ADP) model runing! 115 | Start to run Test-pairwise pre-training (TPP) model! 116 | ... 117 | ``` 118 | 119 | The bash file [run_pretrain.sh](./code/run_pretrain.sh) will run the script [P_MDL.py](./code/P_MDL.py) with the appropriate parameter settings. You can also find the model log output in the ``records/`` folder and the model evaluation results in the ``results/`` folder. 120 | 121 | ### Step 3: Infer the PDR scores for PDD application 122 | 123 | Run: 124 | 125 | ```bash 126 | cd code 127 | nohup bash run_pretrain_for_pdr.sh 1>pdr_pretrain.txt 2>&1 & 128 | less pdr_pretrain.txt 129 | ``` 130 | 131 | The ``pdr_pretrain.txt`` will look like the following: 132 | ```bash 133 | Processing next task: 8 20230718-043332 134 | All-data pre-training (ADP) model runing! 135 | Start to run Test-pairwise pre-training (TPP) model! 136 | ... 137 | ``` 138 | 139 | For the PDD application, we need to predict the efficacy of all drugs for every patient.
So here we set the parameter ``--select_drug_method all`` in the bash file [run_pretrain_for_pdr.sh](./code/run_pretrain_for_pdr.sh) to recover the model that can be used for all-drug response prediction. 140 | 141 | Then you can run the bash file [run_pdr_task.sh](./code/run_pdr_task.sh) by ```nohup bash run_pdr_task.sh 1>pdr_task.txt 2>&1 &```, which will call the script [PDR_task.py](./code/PDR_task.py) to predict the efficacy of all drugs for every patient. 142 | 143 | 144 | 145 | ## Contact Us 146 | 147 | As a pre-alpha version release, we are looking forward to user feedback to help us improve our framework. If you have any questions or suggestions, please open an issue or contact [wuyushuai0727@gmail.com](mailto:wuyushuai0727@gmail.com). 148 | 149 | 150 | ## Cite Us 151 | 152 | If you find our open-sourced code & models helpful to your research, please consider giving this repo a star🌟 and citing📑 the following article. Thank you for your support! 153 | ``` 154 | @misc{P_MDL_code, 155 | author={Yushuai Wu}, 156 | title={Code of Multi-Drug-Transfer-Learning}, 157 | year={2025}, 158 | howpublished={\url{https://github.com/wuys13/Multi-Drug-Transfer-Learning.git}} 159 | } 160 | ``` 161 | 162 | ## Contributing 163 | 164 | If you encounter problems, feel free to create an issue!
def eval_dsnae_epoch(model, data_loader, device, history):
    """Evaluate a DSN auto-encoder over one pass of ``data_loader``.

    Every entry of the dict returned by ``model.loss_function`` is averaged
    over the batches, and the epoch mean is appended to ``history`` under the
    same key.

    :param model: object exposing ``eval()``, ``__call__`` and ``loss_function``
    :param data_loader: sized iterable of batches; element 0 of a batch is the input
    :param device: device each input batch is moved to
    :param history: mapping from loss name to a list (e.g. ``defaultdict(list)``)
    :return: the same ``history`` object, with one value appended per loss name
    """
    model.eval()
    epoch_means = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            batch_losses = model.loss_function(*model(inputs))
        for name, value in batch_losses.items():
            epoch_means[name] += value.cpu().detach().item() / len(data_loader)

    for name, mean_value in epoch_means.items():
        history[name].append(mean_value)
    return history


def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Run one joint optimisation step for the source and target auto-encoders.

    The optimised quantity is the source loss plus the target loss plus the
    MMD loss between the two encodings.

    :param s_dsnae: source-domain auto-encoder
    :param t_dsnae: target-domain auto-encoder
    :param s_batch: source batch; element 0 is the input tensor
    :param t_batch: target batch; element 0 is the input tensor
    :param device: device the inputs are moved to (also passed to ``mmd_loss``)
    :param optimizer: optimizer stepped once per call
    :param history: mapping from loss name to a list; updated in place
    :param scheduler: optional scheduler, stepped after the optimizer
    :return: the same ``history`` object with the summed source+target loss
        terms and the ``'mmd_loss'`` value appended
    """
    for net in (s_dsnae, t_dsnae):
        net.zero_grad()
    for net in (s_dsnae, t_dsnae):
        net.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    # encode() is called before the full forward passes; this call order is
    # kept exactly as it was.
    source_code = s_dsnae.encode(source_x)
    target_code = t_dsnae.encode(target_x)

    source_losses = s_dsnae.loss_function(*s_dsnae(source_x))
    target_losses = t_dsnae.loss_function(*t_dsnae(target_x))

    optimizer.zero_grad()
    alignment_loss = mmd_loss(source_features=source_code, target_features=target_code, device=device)
    total_loss = source_losses['loss'] + target_losses['loss'] + alignment_loss
    total_loss.backward()

    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    # Log source + target for every loss term reported by the source model.
    summed_losses = {
        name: source_value.cpu().detach().item() + target_losses[name].cpu().detach().item()
        for name, source_value in source_losses.items()
    }
    for name, value in summed_losses.items():
        history[name].append(value)
    history['mmd_loss'].append(alignment_loss.cpu().detach().item())

    return history
50 == 0: 124 | print(f'AE training epoch {epoch}') 125 | for step, s_batch in enumerate(s_train_dataloader): 126 | t_batch = next(iter(t_train_dataloader)) 127 | dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae, #loss = s_loss_dict['loss'] + t_loss_dict['loss'] + m_loss 128 | t_dsnae=t_dsnae, 129 | s_batch=s_batch, 130 | t_batch=t_batch, 131 | device=device, 132 | optimizer=ae_optimizer, 133 | history=dsnae_train_history) 134 | dsnae_val_history = eval_dsnae_epoch(model=s_dsnae, 135 | data_loader=s_test_dataloader, 136 | device=device, 137 | history=dsnae_val_history 138 | ) 139 | dsnae_val_history = eval_dsnae_epoch(model=t_dsnae, 140 | data_loader=t_test_dataloader, 141 | device=device, 142 | history=dsnae_val_history 143 | ) 144 | for k in dsnae_val_history: 145 | if k != 'best_index': 146 | dsnae_val_history[k][-2] += dsnae_val_history[k][-1] 147 | dsnae_val_history[k].pop() 148 | 149 | 150 | if kwargs['es_flag']: 151 | save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50) 152 | if save_flag: 153 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')) 154 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')) 155 | if stop_flag: 156 | break 157 | 158 | if kwargs['es_flag']: 159 | s_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt'))) 160 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt'))) 161 | 162 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')) 163 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')) 164 | 165 | else: 166 | try: 167 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt'))) 168 | # if kwargs['norm_flag']: 169 | # loaded_model = torch.load(os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')) 170 | # 
def auprc(y_true, y_score):
    """Return the area under the precision-recall curve.

    The arguments are handed to ``precision_recall_curve`` positionally. The
    keyword ``probas_pred`` used before has been renamed to ``y_score`` in
    newer scikit-learn releases, so the positional form works on both.

    :param y_true: ground-truth binary labels
    :param y_score: predicted scores
    :return: area under the (recall, precision) curve
    """
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return auc(recall, precision)


def model_save_check(history, metric_name, tolerance_count=5, reset_count=1):
    """Decide from a metric history whether to checkpoint and whether to stop.

    ``history['best_index']`` is created (as 0) when missing and is moved to
    the latest position whenever the latest value is at least as good as the
    stored best. Metrics whose name ends with ``'loss'`` are minimised, all
    others are maximised.

    :param history: mapping holding ``history[metric_name]`` (a non-empty list
        of values) and, optionally, ``'best_index'``; updated in place
    :param metric_name: key of the metric to monitor
    :param tolerance_count: epochs without improvement tolerated per stage
    :param reset_count: multiplier applied to ``tolerance_count``
    :return: ``(save_flag, stop_flag)``
    """
    history.setdefault('best_index', 0)
    values = history[metric_name]
    best_index = history['best_index']
    latest = values[-1]

    if metric_name.endswith('loss'):
        improved = latest <= values[best_index]
    else:
        improved = latest >= values[best_index]
    save_flag = bool(improved)
    if save_flag:
        best_index = len(values) - 1
        history['best_index'] = best_index

    # Stopping is never signalled while the best value is still the very
    # first entry (best_index == 0).
    patience = tolerance_count * reset_count
    stop_flag = len(values) - best_index > patience and best_index > 0
    if stop_flag:
        print(f'The best epoch: {best_index} / {len(values) - 1}')
        print(f'Metric from first to stop: {values[0]} to {values[-1]}')

    return save_flag, stop_flag


def eval_ae_epoch(model, data_loader, device, history):
    """Evaluate an auto-encoder over one pass of ``data_loader``.

    Every entry of the dict returned by ``model.loss_function`` is averaged
    over the batches, and the epoch mean is appended to ``history``.

    :param model: object exposing ``eval()``, ``__call__`` and ``loss_function``
    :param data_loader: sized iterable of batches; element 0 of a batch is the input
    :param device: device each input batch is moved to
    :param history: mapping from loss name to a list (e.g. ``defaultdict(list)``)
    :return: the same ``history`` object, with one value appended per loss name
    """
    model.eval()
    epoch_means = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            batch_losses = model.loss_function(*model(inputs))
        for name, value in batch_losses.items():
            epoch_means[name] += value.cpu().detach().item() / len(data_loader)

    for name, mean_value in epoch_means.items():
        history[name].append(mean_value)
    return history
exit() 75 | y_preds = np.concatenate([y_preds, y_pred.cpu().detach().numpy().ravel()]) 76 | elif class_num == 1: 77 | y_preds = [] 78 | for x_gex ,x_smiles, y_batch in dataloader: 79 | x_gex = x_gex.to(device) 80 | x_smiles = x_smiles.to(device) 81 | y_batch = y_batch.to(device) 82 | with torch.no_grad(): 83 | y_truths = np.concatenate([y_truths, y_batch.cpu().detach().numpy().ravel()]) 84 | # y_pred = torch.sigmoid(classifier(x_smiles,x_gex)).detach() 85 | y_pred = classifier(x_smiles,x_gex).detach() 86 | y_preds.append( y_pred.cpu().detach().tolist() ) 87 | y_preds = [token for st in y_preds for token in st] 88 | y_preds = np.array(y_preds) 89 | 90 | else : 91 | y_preds = [] 92 | for x_gex ,x_smiles, y_batch in dataloader: 93 | x_gex = x_gex.to(device) 94 | x_smiles = x_smiles.to(device) 95 | y_batch = y_batch.to(device) 96 | with torch.no_grad(): 97 | y_truths = np.concatenate([y_truths, y_batch.cpu().detach().numpy().ravel()]) 98 | # y_pred = torch.sigmoid(classifier(x_smiles,x_gex)).detach() 99 | y_pred = nn.functional.softmax(classifier(x_smiles,x_gex),dim=1).detach() 100 | y_preds.append( y_pred.cpu().detach().tolist() ) 101 | y_preds = [token for st in y_preds for token in st] 102 | y_preds = np.array(y_preds) 103 | 104 | if class_num == 2: 105 | # y_preds = y_preds.gather(-1,nn.init.ones_(torch.zeros(y_preds.size(0), 1))) #for torch 106 | # print(y_preds) 107 | y_preds = y_preds[:,1] 108 | # print(y_preds) 109 | # exit() 110 | elif class_num > 2 : 111 | if test_flag :#做zero shot 112 | a = torch.from_numpy(y_preds) 113 | _, max_class = torch.max(a, -1) 114 | y_preds = torch.sigmoid(max_class).numpy() #x->3->[0,1] 115 | #numpy 116 | # max_class = np.argmax(y_preds,-1) 117 | # y_preds = sigmoid(max_class) 118 | else: 119 | y_preds = y_preds 120 | else: 121 | raise Exception("class_num error") 122 | 123 | # print("y_truths: ",y_truths) 124 | # print("y_preds: ",y_preds) 125 | if class_num == 1: 126 | if test_flag : 127 | # 
def evaluate_adv_classification_epoch(classifier, s_dataloader, t_dataloader, device, history):
    """Score a domain classifier on source (label 0) versus target (label 1) data.

    The sigmoid of the classifier output is taken as the probability of the
    target domain. Accuracy, AUROC, average precision, F1, log-loss and AUPRC
    are appended to ``history``.

    :param classifier: callable module; ``classifier(x)`` returns logits
    :param s_dataloader: iterable of source batches (input is element 0)
    :param t_dataloader: iterable of target batches (input is element 0)
    :param device: device the inputs are moved to
    :param history: mapping from metric name to a list; updated in place
    :return: the same ``history`` object
    """
    classifier.eval()
    # Seed with an empty float64 array so the result keeps the float64 dtype
    # that batch-by-batch concatenation onto ``np.array([])`` produces.
    truth_chunks = [np.array([])]
    pred_chunks = [np.array([])]
    for domain_label, dataloader in ((0.0, s_dataloader), (1.0, t_dataloader)):
        for batch in dataloader:
            x = batch[0].to(device)
            with torch.no_grad():
                probs = torch.sigmoid(classifier(x)).detach()
            truth_chunks.append(np.full(x.shape[0], domain_label))
            pred_chunks.append(probs.cpu().numpy().ravel())
    # One concatenate at the end instead of re-copying the arrays every batch.
    y_truths = np.concatenate(truth_chunks)
    y_preds = np.concatenate(pred_chunks)

    hard_preds = (y_preds > 0.5).astype('int')
    history['acc'].append(accuracy_score(y_true=y_truths, y_pred=hard_preds))
    history['auroc'].append(roc_auc_score(y_true=y_truths, y_score=y_preds))
    history['aps'].append(average_precision_score(y_true=y_truths, y_score=y_preds))
    history['f1'].append(f1_score(y_true=y_truths, y_pred=hard_preds))
    history['bce'].append(log_loss(y_true=y_truths, y_pred=y_preds))
    history['auprc'].append(auprc(y_true=y_truths, y_score=y_preds))

    return history


# patient drug score matrix
def predict_pdr_score(classifier, pdr_dataloader, device,
                      drug_list_path='../data/preprocessed_dat/drug_embedding/CCL_dataset/drug_smiles.csv'):
    """Predict a patient-by-drug response score matrix.

    :param classifier: callable module; ``classifier(x_smiles, x_gex)`` returns logits
    :param pdr_dataloader: pair ``(test_dataloader, patient_index)``; the
        loader yields ``(x_gex, x_smiles)`` batches
    :param device: device the inputs are moved to
    :param drug_list_path: CSV file with a ``Drug_name`` column used as the
        output columns; defaults to the project's CCL drug list
    :return: ``pandas.DataFrame`` of sigmoid scores indexed by
        ``patient_index``, one column per drug
    """
    test_dataloader, patient_index = pdr_dataloader
    classifier.eval()
    pred_chunks = [np.array([])]
    for x_gex, x_smiles in test_dataloader:
        x_gex = x_gex.to(device)
        x_smiles = x_smiles.to(device)
        with torch.no_grad():
            y_pred = torch.sigmoid(classifier(x_smiles, x_gex)).detach()
        pred_chunks.append(y_pred.cpu().numpy().ravel())
    y_preds = np.concatenate(pred_chunks)

    drug_list = pd.read_csv(drug_list_path)
    patient_num = len(patient_index)
    # The flat predictions are reshaped to one row per patient.
    # NOTE(review): this presumes the loader yields samples patient-major
    # with drugs in drug-list order - confirm against the data loader.
    y_preds = y_preds.reshape(patient_num, -1)
    output_df = pd.DataFrame(y_preds, index=patient_index, columns=drug_list['Drug_name'])

    return output_df


def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
    return 1 / (1 + np.exp(-x))
def dict_to_str(d):
    """Flatten a dict into ``key1_value1_key2_value2...`` (insertion order)."""
    return "_".join(f"{k}_{v}" for k, v in d.items())


def wrap_training_params(training_params, type='unlabeled'):
    """Merge the shared training parameters with one stage-specific section.

    :param training_params: dict with shared keys plus the nested
        ``'unlabeled'`` and ``'labeled'`` sections
    :param type: which nested section to merge in; its entries override
        shared keys of the same name
    :return: a new flat dict; ``training_params`` is not modified
    """
    aux_dict = {k: v for k, v in training_params.items() if k not in ['unlabeled', 'labeled']}
    aux_dict.update(training_params[type])

    return aux_dict


def safe_make_dir(new_folder_name):
    """Create ``new_folder_name`` (with parents) unless the path already exists."""
    if os.path.exists(new_folder_name):
        print(new_folder_name, 'exists!')
    else:
        # exist_ok tolerates another process creating the folder between the
        # check above and this call.
        os.makedirs(new_folder_name, exist_ok=True)


def main(args, update_params_dict):
    """Pre-train an encoder, fine-tune it per fold, and store the test metrics.

    Besides its arguments, this function reads the module-level names
    ``tumor_type`` and ``gex_features_df``, which the ``__main__`` section
    sets before calling it.

    :param args: parsed command-line arguments
    :param update_params_dict: hyper-parameter overrides for the
        ``'unlabeled'`` section of ``params/train_params.json``; also used to
        name the model folder and the result file
    :raises NotImplementedError: if ``args.method`` is not a known method
    """
    train_fns = {
        'ae': train_ae.train_ae,
        'ae_mmd': train_ae_mmd.train_ae_mmd,
        'ae_adv': train_ae_adv.train_ae_adv,
        'dsn': train_dsn.train_dsn,
        'dsn_mmd': train_dsn_mmd.train_dsn_mmd,
        'dsn_adv': train_dsn_adv.train_dsn_adv,
        'dsrn_mmd': train_dsrn_mmd.train_dsrn_mmd,
        'dsrn_adv': train_dsrn_adv.train_dsrn_adv,
    }
    if args.method not in train_fns:
        raise NotImplementedError("Not true method supplied!")
    train_fn = train_fns[args.method]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open('params/train_params.json', 'r') as f:
        training_params = json.load(f)

    training_params['unlabeled'].update(update_params_dict)
    param_str = dict_to_str(update_params_dict)

    source_dir = os.path.join(args.pretrain_dataset,
                              tumor_type,
                              args.select_drug_method)

    store_dir = os.path.join('../results', args.store_dir)
    # NOTE(review): the '_norm' suffix is applied regardless of
    # args.norm_flag - confirm this is intended.
    method_save_folder = os.path.join(store_dir, f'{args.method}_norm', source_dir)
    task_save_folder = os.path.join(f'{method_save_folder}', args.measurement)
    safe_make_dir(task_save_folder)  # unlabeled result save dir

    training_params.update(
        {
            'device': device,
            'input_dim': gex_features_df.shape[-1],
            'model_save_folder': os.path.join(method_save_folder, param_str),
            'es_flag': False,
            'retrain_flag': args.retrain_flag,
            'norm_flag': args.norm_flag
        })

    safe_make_dir(training_params['model_save_folder'])

    random.seed(2020)

    s_dataloaders, t_dataloaders = data.get_unlabeled_dataloaders(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['unlabeled']['batch_size'],
        ccle_only=False
    )

    encoder, historys = train_fn(s_dataloaders=s_dataloaders,
                                 t_dataloaders=t_dataloaders,
                                 **wrap_training_params(training_params, type='unlabeled'))
    print(' ')

    ft_evaluation_metrics = defaultdict(list)

    # Fine-tuning dataset
    labeled_dataloader_generator = data.get_finetune_dataloader_generator(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['labeled']['batch_size'],
        dataset=args.CCL_dataset,  # fine-tune dataset
        ccle_measurement=args.measurement,
        n_splits=args.n,
        q=2,
        tumor_type=tumor_type,
        label_type=args.label_type,
        select_drug_method=args.select_drug_method)

    # Where the per-fold fine-tuned classifiers are written. "overlap"
    # (benchmark) keeps a single copy per method; every other selection
    # method - including "random", which was left without a folder and
    # crashed with UnboundLocalError - uses the per-hyper-parameter folder.
    if args.select_drug_method == "overlap":
        save_folder = method_save_folder
    else:
        save_folder = training_params['model_save_folder']

    fold_count = 0
    for train_labeled_ccle_dataloader, test_labeled_ccle_dataloader, labeled_tcga_dataloader in labeled_dataloader_generator:
        ft_encoder = deepcopy(encoder)
        print(' ')
        print('Fold count = {}'.format(fold_count))
        target_classifier, ft_historys = fine_tuning.fine_tune_encoder_drug(
            encoder=ft_encoder,
            train_dataloader=train_labeled_ccle_dataloader,
            val_dataloader=test_labeled_ccle_dataloader,
            test_dataloader=labeled_tcga_dataloader,
            fold_count=fold_count,
            normalize_flag=args.norm_flag,
            metric_name=args.metric,
            task_save_folder=save_folder,
            drug_emb_dim=300,
            class_num=args.class_num,
            store_dir=method_save_folder,
            **wrap_training_params(training_params, type='labeled')
        )
        if fold_count == 0:
            # Report the function's own argument, not the module-level loop
            # variable of the __main__ section.
            print(f"Save target_classifier {update_params_dict}.")
            torch.save(target_classifier.state_dict(),
                       os.path.join(method_save_folder, f'save_classifier_{fold_count}.pt'))
        # ft_historys[-2] is the validation history, ft_historys[-1] the test
        # history; the test metric is read at the best validation epoch.
        best_index = ft_historys[-2]['best_index']
        ft_evaluation_metrics['best_index'].append(best_index)
        for metric in ['auroc']:
            ft_evaluation_metrics[metric].append(ft_historys[-1][metric][best_index])
        fold_count += 1
        print(' ')
        print(' ')

    with open(os.path.join(task_save_folder, f'{param_str}_ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
192 | 193 | train_group = parser.add_mutually_exclusive_group(required=False) 194 | train_group.add_argument('--train', dest='retrain_flag', action='store_true') 195 | train_group.add_argument('--no-train', dest='retrain_flag', action='store_false') 196 | parser.set_defaults(retrain_flag=True) 197 | # parser.set_defaults(retrain_flag=False) 198 | 199 | norm_group = parser.add_mutually_exclusive_group(required=False) 200 | norm_group.add_argument('--norm', dest='norm_flag', action='store_true') 201 | norm_group.add_argument('--no-norm', dest='norm_flag', action='store_false') 202 | parser.set_defaults(norm_flag=True) 203 | 204 | parser.add_argument('--label_type', default = "PFS",choices=["PFS","Imaging"]) 205 | 206 | parser.add_argument('--select_drug_method', default = "overlap",choices=["overlap","all","random"]) 207 | 208 | parser.add_argument('--store_dir',default = "benchmark") 209 | parser.add_argument('--select_gene_method',default = "Percent_sd",choices=["Percent_sd","HVG"]) 210 | parser.add_argument('--select_gene_num',default = 1000,type=int) 211 | 212 | parser.add_argument('--pretrain_dataset',default = "tcga", 213 | choices=["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc", 214 | "lgg", "luad", "lusc","paad", "read", "sarc", "skcm", "stad" 215 | ]) 216 | parser.add_argument('--tumor_type',default = "BRCA", 217 | choices=['TCGA','GBM', 'LGG', 'HNSC','KIRC','SARC','BRCA','STAD','CESC','SKCM','LUSC','LUAD','READ','COAD' 218 | ]) 219 | parser.add_argument('--CCL_dataset',default = 'gdsc1_raw', 220 | choices=['gdsc1_raw','gdsc1_rebalance']) 221 | parser.add_argument('--class_num',default = 0,type=int) 222 | 223 | args = parser.parse_args() 224 | if args.class_num == 1: 225 | # args.CCL_dataset = f"{args.CCL_dataset}_regression" 226 | print("Regression task. 
Use dataset:",args.CCL_dataset) 227 | 228 | params_grid = { 229 | "pretrain_num_epochs": [0, 100, 300], 230 | "train_num_epochs": [100, 200, 300, 500, 750, 1000, 1500, 2000, 2500, 3000], 231 | "dop": [0.0, 0.1] 232 | } 233 | 234 | 235 | Tumor_type_list = ["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc", 236 | "lgg", "luad", "lusc","paad", "read", "sarc", "skcm", "stad" 237 | ] 238 | if args.pretrain_num : 239 | args.pretrain_dataset = Tumor_type_list[args.pretrain_num] 240 | if args.zero_shot_num : 241 | args.tumor_type = [element.upper() for element in Tumor_type_list][args.zero_shot_num] 242 | # print(f'Tumor type: Select zero_shot_num: {Num}. Zero-shot dataset: {args.tumor_type}') 243 | if args.method_num : 244 | args.method = [ 245 | 'no','ae','dsn', 246 | 'ae_mmd','dsrn_mmd','dsn_mmd', 247 | 'ae_adv','dsrn_adv','dsn_adv'][args.method_num] 248 | 249 | # Test-pairwise pre-training set pretrain_dataset and test_dataset to the same tumor type in default 250 | tumor_type = args.pretrain_dataset.upper() 251 | if tumor_type == "TCGA" : 252 | tumor_type = args.tumor_type 253 | 254 | if args.method not in ['dsrn_adv', 'dsn_adv', 'ae_adv']: 255 | params_grid.pop('pretrain_num_epochs') 256 | 257 | keys, values = zip(*params_grid.items()) 258 | update_params_dict_list = [dict(zip(keys, v)) for v in itertools.product(*values)] 259 | 260 | 261 | gex_features_df = pd.read_csv(f'../data/pretrain_data/{args.pretrain_dataset}_pretrain_dataset.csv',index_col=0) 262 | CCL_tumor_type = 'all_CCL' 263 | 264 | print(f'Pretrain dataset: Patient({args.pretrain_dataset} {args.pretrain_num}) CCL( {CCL_tumor_type}). Select_gene_method: {args.select_gene_method}') 265 | print(f'Zero-shot dataset: {tumor_type}({args.zero_shot_num})') 266 | print(f'CCL_dataset: {args.CCL_dataset} Select_drug_method: {args.select_drug_method}') 267 | print(f'Store_dir: {args.store_dir} ') 268 | 269 | print(f'method: {args.method}({args.method_num}). 
label_type: {args.label_type}') 270 | param_num = 0 271 | 272 | # update_params_dict_list.reverse() 273 | 274 | for param_dict in update_params_dict_list: 275 | param_num = param_num + 1 276 | print(' ') 277 | print('##############################################################################') 278 | print(f'####### Param_num {param_num}/{len(update_params_dict_list)} #######') 279 | print('Param_dict: {}'.format(param_dict) ) 280 | print('##############################################################################') 281 | main(args=args, update_params_dict=param_dict) 282 | print("Finsh All !!!!!!!!!!!!!!!") 283 | -------------------------------------------------------------------------------- /code/fine_tuning.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./model_train') 3 | 4 | from evaluation_utils import evaluate_target_classification_epoch,model_save_check,predict_pdr_score 5 | from collections import defaultdict 6 | from itertools import chain 7 | from mlp import MLP 8 | from encoder_decoder import EncoderDecoder 9 | import os 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | def fine_tune_encoder_drug(encoder, train_dataloader, val_dataloader, fold_count,store_dir, task_save_folder,test_dataloader=None, 15 | metric_name='auroc', 16 | class_num = 2, 17 | normalize_flag=False, 18 | break_flag=False, 19 | test_df=None, 20 | drug_emb_dim=128, 21 | to_roughly_test = False, 22 | **kwargs): 23 | finetune_output_dim = class_num 24 | if finetune_output_dim == 0: 25 | finetune_output_dim = 1 26 | target_decoder = MLP(input_dim=kwargs['latent_dim'] + drug_emb_dim, 27 | output_dim=finetune_output_dim, 28 | hidden_dims=kwargs['classifier_hidden_dims']).to(kwargs['device']) 29 | 30 | target_classifier = EncoderDecoder(encoder=encoder, decoder=target_decoder, 31 | normalize_flag=normalize_flag).to(kwargs['device']) 32 | # target_decoder load to re-train faster 33 | 
target_classifier_file = os.path.join(store_dir, 'save_classifier_0.pt') 34 | if os.path.exists(target_classifier_file) and fold_count>0: 35 | print("Loading ",target_classifier_file) 36 | target_classifier.load_state_dict( 37 | torch.load(target_classifier_file)) 38 | else: 39 | to_roughly_test = False 40 | print("No target_classifier_file stored for use. Generate first!") 41 | # print(' ') 42 | 43 | if class_num == 0: 44 | classification_loss = nn.BCEWithLogitsLoss() 45 | elif class_num == 1: 46 | classification_loss = nn.MSELoss() ## nn.MAELoss() 47 | else: 48 | classification_loss = nn.CrossEntropyLoss() 49 | 50 | target_classification_train_history = defaultdict(list) 51 | target_classification_eval_train_history = defaultdict(list) 52 | target_classification_eval_val_history = defaultdict(list) 53 | target_classification_eval_test_history = defaultdict(list) 54 | 55 | encoder_module_indices = [i for i in range(len(list(encoder.modules()))) 56 | if str(list(encoder.modules())[i]).startswith('Linear')] 57 | 58 | reset_count = 1 59 | lr = kwargs['lr'] 60 | 61 | # target_classification_params = [target_classifier.decoder.parameters(),target_classifier.smiles_encoder.parameters()] #原来只更新decoder的层 62 | target_classification_params = [target_classifier.decoder.parameters()] #原来只更新decoder的层 63 | 64 | target_classification_optimizer = torch.optim.AdamW(chain(*target_classification_params), 65 | lr=lr) 66 | 67 | stop_flag_num = 0 68 | for epoch in range(kwargs['train_num_epochs']): 69 | # if epoch % 50 == 0: 70 | # print(f'Fine tuning epoch {epoch}') 71 | for step, batch in enumerate(train_dataloader): 72 | target_classification_train_history = classification_train_step_drug(model=target_classifier, 73 | batch=batch, 74 | loss_fn=classification_loss, 75 | device=kwargs['device'], 76 | optimizer=target_classification_optimizer, 77 | history=target_classification_train_history, 78 | class_num = class_num) 79 | target_classification_eval_train_history = 
evaluate_target_classification_epoch(classifier=target_classifier, 80 | dataloader=train_dataloader, 81 | device=kwargs['device'], 82 | history=target_classification_eval_train_history, 83 | class_num = class_num) 84 | target_classification_eval_val_history = evaluate_target_classification_epoch(classifier=target_classifier, 85 | dataloader=val_dataloader, 86 | device=kwargs['device'], 87 | history=target_classification_eval_val_history, 88 | class_num = class_num) 89 | 90 | if test_dataloader is not None: 91 | target_classification_eval_test_history = evaluate_target_classification_epoch(classifier=target_classifier, 92 | dataloader=test_dataloader, 93 | device=kwargs['device'], 94 | history=target_classification_eval_test_history, 95 | class_num = class_num, 96 | test_flag=True) 97 | save_flag, stop_flag = model_save_check(history=target_classification_eval_val_history, 98 | metric_name=metric_name, 99 | tolerance_count=5, # stop_flag once 5 epochs not better 100 | reset_count=reset_count) 101 | test_metric = target_classification_eval_test_history[metric_name][target_classification_eval_val_history['best_index']] 102 | if epoch % 50 == 0: 103 | print(f'Fine tuning epoch {epoch}. stop_flag_num: {stop_flag_num}. {test_metric}') 104 | # print(save_flag, stop_flag,stop_flag_num) 105 | if save_flag: 106 | torch.save(target_classifier.state_dict(), 107 | os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt')) # save model 108 | # print("Save model. ",test_metric,epoch) 109 | if to_roughly_test: 110 | print("To roughly test pass, just get the zero-shot metric at the begining for judge.") 111 | break 112 | pass 113 | if stop_flag: 114 | stop_flag_num = stop_flag_num+1 115 | print(' ') 116 | try: 117 | ind = encoder_module_indices.pop() 118 | print(f'Unfreezing Linear {ind} in the epoch {epoch}. 
{test_metric}') 119 | target_classifier.load_state_dict( 120 | torch.load(os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt'))) 121 | 122 | target_classification_params.append(list(target_classifier.encoder.modules())[ind].parameters()) 123 | lr = lr * kwargs['decay_coefficient'] 124 | target_classification_optimizer = torch.optim.AdamW(chain(*target_classification_params), lr=lr) 125 | reset_count += 1 126 | except Exception as e: 127 | print(e) 128 | print(test_metric) 129 | break 130 | # if stop_flag and not break_flag: 131 | # print(f'Unfreezing {epoch}') 132 | # target_classifier.load_state_dict( 133 | # torch.load(os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt'))) 134 | # 135 | # target_classification_params.append(target_classifier.encoder.shared_encoder.parameters()) 136 | # target_classification_params.append(target_classifier.encoder.private_encoder.parameters()) 137 | # 138 | # lr = lr * kwargs['decay_coefficient'] 139 | # target_classification_optimizer = torch.optim.AdamW(chain(*target_classification_params), lr=lr) 140 | # break_flag = True 141 | # stop_flag = False 142 | # if stop_flag and break_flag: 143 | # break 144 | 145 | target_classifier.load_state_dict( 146 | torch.load(os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt'))) 147 | 148 | 149 | return target_classifier, (target_classification_train_history, target_classification_eval_train_history, 150 | target_classification_eval_val_history, target_classification_eval_test_history)#, prediction_df 151 | 152 | 153 | def classification_train_step_drug(model, batch, loss_fn, device, optimizer, history, class_num,scheduler=None, clip=None): 154 | model.zero_grad() 155 | model.train() 156 | 157 | x_smiles = batch[1].to(device) 158 | x_gex = batch[0].to(device) 159 | y = batch[2].to(device) 160 | # print("smiles:",x_smiles) 161 | # print("gex:",x_gex) 162 | # print("label:",y) 163 | 164 | if class_num == 0: 165 | loss = 
loss_fn(model(x_smiles,x_gex), y.double().unsqueeze(1)) # 166 | else : 167 | loss = loss_fn(model(x_smiles,x_gex), y.float().unsqueeze(1)) # 168 | 169 | optimizer.zero_grad() 170 | loss.backward() 171 | if clip is not None: 172 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 173 | 174 | optimizer.step() 175 | if scheduler is not None: 176 | scheduler.step() 177 | 178 | history['ce'].append(loss.cpu().detach().item()) 179 | # history['bce'].append(loss.cpu().detach().item()) 180 | 181 | return history 182 | 183 | 184 | #predict pdr result 185 | def predict_pdr(encoder, fold_count, task_save_folder, test_dataloader=None, 186 | metric_name='auroc', 187 | class_num = 2, 188 | normalize_flag=False, 189 | break_flag=False, 190 | pdr_dataloader=None, 191 | drug_emb_dim=128, 192 | **kwargs): 193 | finetune_output_dim = class_num 194 | if finetune_output_dim == 0: 195 | finetune_output_dim = 1 196 | target_decoder = MLP(input_dim=kwargs['latent_dim'] + drug_emb_dim, 197 | output_dim=finetune_output_dim, 198 | hidden_dims=kwargs['classifier_hidden_dims']).to(kwargs['device']) 199 | 200 | target_classifier = EncoderDecoder(encoder=encoder, decoder=target_decoder, 201 | normalize_flag=normalize_flag).to(kwargs['device']) 202 | target_classifier.load_state_dict( 203 | torch.load(os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt'))) 204 | print('Sucessfully loaded target_classifier_{}'.format(fold_count)) 205 | 206 | target_classification_eval_test_history = defaultdict(list) 207 | if test_dataloader is not None: 208 | target_classification_eval_test_history = evaluate_target_classification_epoch_1(classifier=target_classifier, 209 | dataloader=test_dataloader, 210 | device=kwargs['device'], 211 | history=target_classification_eval_test_history, 212 | class_num = class_num, 213 | test_flag=True) 214 | 215 | 216 | prediction_df = None 217 | if pdr_dataloader is not None: 218 | prediction_df = predict_pdr_score(classifier=target_classifier, 
def wrap_training_params(training_params, type='unlabeled'):
    """Flatten the nested training parameters for one training phase.

    Every top-level entry except the ``'unlabeled'`` and ``'labeled'``
    sections is copied, then the entries of the section named by ``type`` are
    merged on top (section values win on key clashes).

    :param training_params: mapping containing at least the ``type`` section.
    :param type: name of the section to merge in.
    :return: a new flat dict; ``training_params`` is not modified.
    """
    section_names = ('unlabeled', 'labeled')
    flat = {}
    for key, value in training_params.items():
        if key in section_names:
            continue
        flat[key] = value
    flat.update(**training_params[type])
    return flat


def safe_make_dir(new_folder_name):
    """Create ``new_folder_name`` (with parents), or report that it already exists."""
    if os.path.exists(new_folder_name):
        print(new_folder_name, 'exists!')
        return
    os.makedirs(new_folder_name)


def dict_to_str(d):
    """Join the items of ``d`` as ``key_value`` pairs separated by underscores."""
    pairs = ("_".join((key, str(value))) for key, value in d.items())
    return "_".join(pairs)
def main(args, update_params_dict):
    """Obtain the unlabeled encoder for one hyper-parameter setting and run the
    five saved fold classifiers on the zero-shot and PDR dataloaders.

    Besides its arguments, this function reads the module-level names
    ``gex_features_df``, ``CCL_tumor_type`` and ``tumor_type`` that the
    ``__main__`` section assigns before calling it.

    :param args: parsed command-line namespace.
    :param update_params_dict: overrides merged into the ``'unlabeled'``
        section of the JSON training parameters; its string form also names
        the output files.
    :raises NotImplementedError: if ``args.method`` has no training function.
    """
    # The former 'dsrn' entry pointed at a ``train_dsrn`` module that this file
    # never imports, so selecting it could only end in a NameError; it now
    # falls through to the explicit NotImplementedError below.
    train_fns = {
        'ae': train_ae.train_ae,
        'ae_mmd': train_ae_mmd.train_ae_mmd,
        'ae_adv': train_ae_adv.train_ae_adv,
        'dsn': train_dsn.train_dsn,
        'dsn_mmd': train_dsn_mmd.train_dsn_mmd,
        'dsn_adv': train_dsn_adv.train_dsn_adv,
        'dsrn_mmd': train_dsrn_mmd.train_dsrn_mmd,
        'dsrn_adv': train_dsrn_adv.train_dsrn_adv,
    }
    try:
        train_fn = train_fns[args.method]
    except KeyError:
        raise NotImplementedError("Not true method supplied!") from None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open(os.path.join('model_save/train_params.json'), 'r') as f:
        training_params = json.load(f)

    # Only the 'unlabeled' section receives the overrides; the 'labeled'
    # section keeps the values read from the JSON file.
    training_params['unlabeled'].update(update_params_dict)
    # String form of the overrides, used to name this run's folders and files.
    param_str = dict_to_str(update_params_dict)

    pretrain_root = "tcga" if args.use_tcga_pretrain_model else args.pretrain_dataset
    source_dir = os.path.join(pretrain_root, args.tcga_construction,
                              args.select_gene_method,
                              CCL_tumor_type, args.CCL_construction,
                              tumor_type,
                              args.CCL_dataset, args.select_drug_method)

    method_folder = f'{args.method}_norm' if args.norm_flag else args.method
    method_save_folder = os.path.join('../results', args.store_dir, method_folder, source_dir)

    training_params.update(
        {
            'device': device,
            'input_dim': gex_features_df.shape[-1],
            'model_save_folder': os.path.join(method_save_folder, param_str),
            'es_flag': False,
            'retrain_flag': args.retrain_flag,
            'norm_flag': args.norm_flag
        })
    if args.pdtc_flag:
        task_save_folder = os.path.join(f'{method_save_folder}', 'predict', 'pdtc')
    else:
        task_save_folder = os.path.join(f'{method_save_folder}', 'predict')

    # Folder that receives the prediction CSVs and the evaluation JSON.
    safe_make_dir(task_save_folder)

    random.seed(2020)

    s_dataloaders, t_dataloaders = data.get_unlabeled_dataloaders(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['unlabeled']['batch_size'],
        ccle_only=False
    )

    # With the TCGA-pretrained setting all models live under the BRCA folder,
    # so the tumor-type path component is replaced by "BRCA".
    # NOTE(review): this path always uses the '<method>_norm' folder, whatever
    # ``args.norm_flag`` is — confirm that is intended.
    if args.use_tcga_pretrain_model:
        tcga_source_dir = os.path.join("tcga", args.tcga_construction,
                                       args.select_gene_method,
                                       CCL_tumor_type, args.CCL_construction,
                                       "BRCA",
                                       args.CCL_dataset, args.select_drug_method)
        training_params.update(
            {
                'model_save_folder': os.path.join(
                    os.path.join('../results', args.store_dir, f'{args.method}_norm', tcga_source_dir),
                    param_str),
            })
    print(f"model_save_folder:{training_params['model_save_folder']}")

    # Unlabeled training. With --no-train the training function instead loads
    # the pre-trained model from 'model_save_folder'. Each dataloader pair is
    # (train, test).
    encoder, historys = train_fn(s_dataloaders=s_dataloaders,
                                 t_dataloaders=t_dataloaders,
                                 **wrap_training_params(training_params, type='unlabeled'))

    if args.use_tcga_pretrain_model:
        # Keep only the patient samples of the requested tumor type.
        patient_sample_info = pd.read_csv("../data/preprocessed_dat/xena_sample_info.csv", index_col=0)
        patient_samples = gex_features_df.index.intersection(
            patient_sample_info.loc[patient_sample_info.tumor_type == tumor_type].index)
        gex_features_select = gex_features_df.loc[patient_samples]
    else:
        gex_features_select = gex_features_df

    # Data for the zero-shot evaluation.
    tcga_ZS_dataloader, _ = data.get_tcga_ZS_dataloaders(gex_features_df=gex_features_select,
                                                         label_type=args.label_type,
                                                         batch_size=training_params['labeled']['batch_size'],
                                                         q=2,
                                                         tumor_type=tumor_type)
    print("Generating the pdr dataloader......")
    # Data for the PDR prediction.
    pdr_dataloader = data.get_pdr_data_dataloaders(gex_features_df=gex_features_select,
                                                   batch_size=training_params['labeled']['batch_size'],
                                                   tumor_type=tumor_type)
    ft_evaluation_metrics = defaultdict(list)
    # One saved classifier per fold.
    for fold_count in range(5):
        ft_encoder = deepcopy(encoder)
        print(' ')
        print('Fold count = {}'.format(fold_count))

        # Fixed: the module is imported as ``fine_tuning``; the previous
        # ``fine_tuning1`` name does not exist and raised NameError.
        ft_historys, prediction_df = fine_tuning.predict_pdr(
            encoder=ft_encoder,
            pdr_dataloader=pdr_dataloader,
            test_dataloader=tcga_ZS_dataloader,
            fold_count=fold_count,
            normalize_flag=args.norm_flag,
            metric_name=args.metric,
            # The fold classifiers are read from the model folder, not from
            # the local ``task_save_folder``.
            task_save_folder=training_params['model_save_folder'],
            drug_emb_dim=300,
            # Must match the number of classes the classifiers were trained with.
            class_num=args.class_num,
            **wrap_training_params(training_params, type='labeled')
        )
        for metric in ['auroc', 'acc', 'aps', 'f1', 'auprc']:
            ft_evaluation_metrics[metric].append(ft_historys[metric])
        print(f'Saving results: {param_str} ; Fold count: {fold_count}')
        prediction_df.to_csv(os.path.join(task_save_folder, f'{param_str}__{fold_count}__predict.csv'))

    with open(os.path.join(task_save_folder, f'{param_str}_ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
if __name__ == '__main__':
    # Command-line entry point: parse the options, build the pre-training
    # dataset once, then call ``main`` for every hyper-parameter combination.
    parser = argparse.ArgumentParser('PDR task')
    parser.add_argument('--pretrain_num', default=None, type=int)
    parser.add_argument('--zero_shot_num', default=None, type=int)
    parser.add_argument('--method_num', default=None, type=int)

    parser.add_argument('--method', dest='method', nargs='?', default='dsn_adv',
                        choices=['ae', 'ae_mmd', 'ae_adv',
                                 'dsn', 'dsn_mmd', 'dsn_adv',
                                 'dsrn_mmd', 'dsrn_adv'])
    parser.add_argument('--metric', dest='metric', nargs='?', default='auroc', choices=['auroc', 'auprc'])

    parser.add_argument('--measurement', dest='measurement', nargs='?', default='AUC', choices=['AUC', 'LN_IC50'])

    parser.add_argument('--n', dest='n', nargs='?', type=int, default=5)

    train_group = parser.add_mutually_exclusive_group(required=False)
    train_group.add_argument('--train', dest='retrain_flag', action='store_true')
    train_group.add_argument('--no-train', dest='retrain_flag', action='store_false')
    parser.set_defaults(retrain_flag=True)

    parser.add_argument('--label_type', default="PFS", choices=["PFS", "Imaging"])

    parser.add_argument('--select_drug_method', default="overlap", choices=["overlap", "all", "random"])

    parser.add_argument('--store_dir', default="benchmark")
    parser.add_argument('--select_gene_method', default="Percent_sd", choices=["Percent_sd", "HVG"])
    parser.add_argument('--select_gene_num', default=1000, type=int)

    parser.add_argument('--pretrain_dataset', default="tcga",
                        choices=["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                                 "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                                 ])
    parser.add_argument('--tumor_type', default="BRCA",
                        choices=['TCGA', 'GBM', 'LGG', 'HNSC', 'KIRC', 'SARC', 'BRCA', 'STAD', 'CESC', 'SKCM', 'LUSC', 'LUAD', 'READ', 'COAD'
                                 ])
    parser.add_argument('--CCL_dataset', default='gdsc1_raw',
                        choices=['gdsc1_raw', 'gdsc1_rebalance'])
    parser.add_argument('--class_num', default=0, type=int)

    # NOTE(review): the code below and ``main`` read args.use_tcga_pretrain_model,
    # args.tcga_construction, args.CCL_construction, args.CCL_type,
    # args.norm_flag and args.pdtc_flag, none of which is declared on this
    # parser. Their intended defaults are not visible here — add the missing
    # arguments before running this script.
    args = parser.parse_args()

    params_grid = {
        "pretrain_num_epochs": [0, 100, 300],
        "train_num_epochs": [100, 200, 300, 500, 750, 1000, 1500, 2000, 2500, 3000],
        "dop": [0.0, 0.1]
    }

    Tumor_type_list = ["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                       "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                       ]
    method_list = ['ae', 'dsn',
                   'ae_mmd', 'dsrn_mmd', 'dsn_mmd',
                   'ae_adv', 'dsrn_adv', 'dsn_adv']

    # Numeric shortcuts override the named options. The checks compare against
    # None so that index 0 is honoured; a plain truthiness test silently
    # ignored ``--zero_shot_num 0`` and ``--method_num 0``.
    if args.pretrain_num is not None:
        args.pretrain_dataset = Tumor_type_list[args.pretrain_num]
    if args.zero_shot_num is not None:
        args.tumor_type = Tumor_type_list[args.zero_shot_num].upper()
    if args.method_num is not None:
        args.method = method_list[args.method_num]

    # The evaluated tumor type defaults to the pre-training cohort; only the
    # pan-cancer "tcga" pre-training uses args.tumor_type to choose it.
    # (For cross-cohort experiments, replace this with
    # ``tumor_type = args.tumor_type``.)
    tumor_type = args.pretrain_dataset.upper()
    if tumor_type == "TCGA":
        tumor_type = args.tumor_type

    # Only the adversarial methods take the pre-training epoch count.
    if args.method not in ['dsrn_adv', 'dsn_adv', 'ae_adv']:
        params_grid.pop('pretrain_num_epochs')

    keys, values = zip(*params_grid.items())
    update_params_dict_list = [dict(zip(keys, v)) for v in itertools.product(*values)]

    # Build the expression matrix used for unlabeled training; the returned
    # CCL tumor type is also part of the result folder hierarchy.
    if args.use_tcga_pretrain_model:
        patient_tumor_type = "tcga"
    else:
        patient_tumor_type = args.pretrain_dataset
    gex_features_df, CCL_tumor_type, all_ccle_gex, all_patient_gex = data.get_pretrain_dataset(
        patient_tumor_type=patient_tumor_type,
        CCL_type=args.CCL_type,
        tumor_type=tumor_type,
        tcga_construction=args.tcga_construction,
        CCL_construction=args.CCL_construction,
        gene_num=args.select_gene_num, select_gene_method=args.select_gene_method
    )

    print(f'Pretrain dataset: Patient({args.pretrain_dataset} {args.pretrain_num} {args.tcga_construction}) CCL({args.CCL_type} {CCL_tumor_type} {args.CCL_construction}). Select_gene_method: {args.select_gene_method}')
    print(f'Zero-shot dataset: {tumor_type}({args.zero_shot_num})')
    print(f'CCL_dataset: {args.CCL_dataset} Select_drug_method: {args.select_drug_method}')
    print(f'Store_dir: {args.store_dir} ')

    print(f'pdtc_flag: {args.pdtc_flag}. method: {args.method}({args.method_num}). label_type: {args.label_type}')

    for param_num, param_dict in enumerate(update_params_dict_list, start=1):
        print(' ')
        print('##############################################################################')
        print(f'####### Param_num {param_num}/{len(update_params_dict_list)} #######')
        print('Param_dict: {}'.format(param_dict) )
        print('##############################################################################')
        # Best effort across the grid: a failing setting is reported and the
        # sweep continues with the next one.
        try:
            main(args=args, update_params_dict=param_dict)
        except Exception as e:
            print(e)
            print(param_dict, param_num)
    print("Finsh All !!!!!!!!!!!!!!!")
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty of ``critic`` between two batches.

    Every row of ``real_samples`` is blended with the same row of
    ``fake_samples`` using one uniformly drawn weight per row. The penalty is
    the mean squared deviation from 1 of the L2 norm of the critic's gradient
    with respect to those blended inputs.

    :param critic: callable mapping a batch to one score per row.
    :param real_samples: batch whose first dimension sets the batch size.
    :param fake_samples: batch blended row-wise with ``real_samples``.
    :param device: device the blending weights and gradient seed are moved to.
    :return: scalar tensor, differentiable with respect to the critic.
    """
    batch_size = real_samples.shape[0]
    # One random blending weight per row.
    mix = torch.rand((batch_size, 1)).to(device)
    blended = (mix * real_samples + ((1 - mix) * fake_samples)).requires_grad_(True)
    scores = critic(blended)
    grad_seed = torch.ones((batch_size, 1)).to(device)
    # create_graph=True keeps the result differentiable for the critic update.
    grads = autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    row_norms = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((row_norms - 1) ** 2).mean()
def critic_dsn_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None,
                          clip=None, gp=None):
    """Update the critic once on a source/target batch pair.

    The loss is ``mean(critic(target code)) - mean(critic(source code))``,
    where the codes come from each autoencoder's ``s_encode``; when ``gp`` is
    given, ``gp`` times the gradient penalty is added.

    :param critic: module being trained (put in train mode).
    :param s_dsnae: source autoencoder (put in eval mode).
    :param t_dsnae: target autoencoder (put in eval mode).
    :param s_batch: indexable whose element 0 is the source input tensor.
    :param t_batch: indexable whose element 0 is the target input tensor.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer stepped once.
    :param history: mapping whose ``'critic_loss'`` list receives the loss.
    :param scheduler: optional scheduler stepped after the update.
    :param clip: optional bound; critic weights are clamped to ``[-clip, clip]``.
    :param gp: optional gradient-penalty coefficient.
    :return: ``history`` (the same object that was passed in).
    """
    for module in (critic, s_dsnae, t_dsnae):
        module.zero_grad()
    s_dsnae.eval()
    t_dsnae.eval()
    critic.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    source_code = s_dsnae.s_encode(source_x)
    target_code = t_dsnae.s_encode(target_x)

    loss = torch.mean(critic(target_code)) - torch.mean(critic(source_code))
    if gp is not None:
        penalty = compute_gradient_penalty(critic,
                                           real_samples=source_code,
                                           fake_samples=target_code,
                                           device=device)
        loss = loss + gp * penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Weight clipping is applied to the parameters after the optimizer step.
    if clip is not None:
        for param in critic.parameters():
            param.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['critic_loss'].append(loss.cpu().detach().item())
    return history


def gan_dsn_gen_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, alpha, history,
                           scheduler=None):
    """Update the autoencoders once against a fixed critic.

    The optimised loss is the sum of both autoencoders' ``'loss'`` entries plus
    ``alpha`` times the generator loss ``-mean(critic(t_dsnae.s_encode(t_x)))``.

    :param critic: critic module (put in eval mode).
    :param s_dsnae: source autoencoder (put in train mode).
    :param t_dsnae: target autoencoder (put in train mode).
    :param s_batch: indexable whose element 0 is the source input tensor.
    :param t_batch: indexable whose element 0 is the target input tensor.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer stepped once.
    :param alpha: weight of the generator loss.
    :param history: mapping of lists; receives the summed source+target value of
        every key of the source loss dict, plus ``'gen_loss'``.
    :param scheduler: optional scheduler stepped after the update.
    :return: ``history`` (the same object that was passed in).
    """
    for module in (critic, s_dsnae, t_dsnae):
        module.zero_grad()
    critic.eval()
    s_dsnae.train()
    t_dsnae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    target_code = t_dsnae.s_encode(target_x)

    optimizer.zero_grad()
    gen_loss = -torch.mean(critic(target_code))
    s_loss_dict = s_dsnae.loss_function(*s_dsnae(source_x))
    t_loss_dict = t_dsnae.loss_function(*t_dsnae(target_x))
    total_loss = (s_loss_dict['loss'] + t_loss_dict['loss']) + alpha * gen_loss
    # The original clears gradients a second time here; kept to preserve the
    # exact call sequence on the optimizer.
    optimizer.zero_grad()

    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed = {}
    for name, s_value in s_loss_dict.items():
        summed[name] = s_value.cpu().detach().item() + t_loss_dict[name].cpu().detach().item()
    for name, value in summed.items():
        history[name].append(value)
    history['gen_loss'].append(gen_loss.cpu().detach().item())

    return history
def train_dsn_adv(s_dataloaders, t_dataloaders, **kwargs):
    """Train, or load, the source/target DSN autoencoders with a WGAN-GP critic.

    With ``retrain_flag`` set, the autoencoders are first pre-trained for
    ``pretrain_num_epochs`` epochs, then trained adversarially for
    ``train_num_epochs`` epochs (one critic update per source batch, one
    autoencoder update every fifth batch); both state dicts are written to
    ``model_save_folder``. Otherwise the target autoencoder is restored from
    ``model_save_folder/a_t_dsnae.pt``.

    :param s_dataloaders: sequence ``(train_loader, test_loader)`` for the source.
    :param t_dataloaders: sequence ``(train_loader, test_loader)`` for the target.
    :param kwargs: reads ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``classifier_hidden_dims``, ``dop``, ``alpha``, ``norm_flag``,
        ``device``, ``lr``, ``retrain_flag``, ``es_flag``,
        ``pretrain_num_epochs``, ``train_num_epochs`` and ``model_save_folder``.
    :return: ``(t_dsnae.shared_encoder, (dsnae_train_history,
        dsnae_val_history, critic_train_history, gen_train_history))``.
    :raises Exception: if ``retrain_flag`` is false and no checkpoint exists.
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(device)

    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(device)

    # Both autoencoders share the encoder and decoder built above.
    s_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    t_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(device)

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    s_checkpoint = os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt')
    t_checkpoint = os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')

    if kwargs['retrain_flag']:
        # Make sure the checkpoints can be written; otherwise a missing folder
        # only surfaces at the first torch.save, after training has run.
        os.makedirs(kwargs['model_save_folder'], exist_ok=True)

        ae_params = [t_dsnae.private_encoder.parameters(),
                     s_dsnae.private_encoder.parameters(),
                     shared_decoder.parameters(),
                     shared_encoder.parameters()
                     ]
        t_ae_params = [t_dsnae.private_encoder.parameters(),
                       s_dsnae.private_encoder.parameters(),
                       shared_decoder.parameters(),
                       shared_encoder.parameters()
                       ]

        ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(), lr=kwargs['lr'])
        t_ae_optimizer = torch.optim.RMSprop(chain(*t_ae_params), lr=kwargs['lr'])

        # start dsnae pre-training
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                # A fresh iterator is created per step, so only its first
                # batch is ever used.
                t_batch = next(iter(t_train_dataloader))
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=t_batch,
                                                        device=device,
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)
            dsnae_val_history = eval_dsnae_epoch(model=s_dsnae,
                                                 data_loader=s_test_dataloader,
                                                 device=device,
                                                 history=dsnae_val_history
                                                 )
            dsnae_val_history = eval_dsnae_epoch(model=t_dsnae,
                                                 data_loader=t_test_dataloader,
                                                 device=device,
                                                 history=dsnae_val_history
                                                 )
            # Fold the target entry just appended into the source entry so
            # each epoch contributes one summed value per metric.
            for k in dsnae_val_history:
                if k != 'best_index':
                    dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
                    dsnae_val_history[k].pop()

            save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=20)
            if kwargs['es_flag']:
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_checkpoint)
                    torch.save(t_dsnae.state_dict(), t_checkpoint)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            s_dsnae.load_state_dict(torch.load(s_checkpoint, map_location=device))
            t_dsnae.load_state_dict(torch.load(t_checkpoint, map_location=device))

        # start GAN training
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                t_batch = next(iter(t_train_dataloader))
                critic_train_history = critic_dsn_train_step(critic=confounding_classifier,
                                                             s_dsnae=s_dsnae,
                                                             t_dsnae=t_dsnae,
                                                             s_batch=s_batch,
                                                             t_batch=t_batch,
                                                             device=device,
                                                             optimizer=classifier_optimizer,
                                                             history=critic_train_history,
                                                             # clip=0.1,
                                                             gp=10.0)
                # The autoencoders are updated once per five critic updates.
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_dsn_gen_train_step(critic=confounding_classifier,
                                                               s_dsnae=s_dsnae,
                                                               t_dsnae=t_dsnae,
                                                               s_batch=s_batch,
                                                               t_batch=t_batch,
                                                               device=device,
                                                               optimizer=t_ae_optimizer,
                                                               alpha=1.0,
                                                               history=gen_train_history)

        torch.save(s_dsnae.state_dict(), s_checkpoint)
        torch.save(t_dsnae.state_dict(), t_checkpoint)
    else:
        try:
            # map_location lets a checkpoint trained on GPU be restored on a
            # CPU-only host.
            t_dsnae.load_state_dict(torch.load(t_checkpoint, map_location=device))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history, critic_train_history, gen_train_history)
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Compute the WGAN-GP gradient penalty for ``critic``.

    Rows of ``real_samples`` and ``fake_samples`` are interpolated with one
    random weight per row; the result is the mean of
    ``(||d critic / d input||_2 - 1) ** 2`` over those interpolated rows.

    :param critic: callable mapping a batch to one score per row.
    :param real_samples: first batch; its first dimension is the batch size.
    :param fake_samples: second batch, interpolated row-wise with the first.
    :param device: device for the random weights and the gradient seed.
    :return: scalar penalty tensor.
    """
    n_rows = real_samples.shape[0]
    weight = torch.rand((n_rows, 1)).to(device)
    points = (weight * real_samples + ((1 - weight) * fake_samples)).requires_grad_(True)
    critic_out = critic(points)
    ones = torch.ones((n_rows, 1)).to(device)
    # Gradient of the critic scores with respect to the interpolated points.
    (input_grads,) = autograd.grad(
        outputs=critic_out,
        inputs=points,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    flat_grads = input_grads.view(input_grads.size(0), -1)
    deviation = flat_grads.norm(2, dim=1) - 1
    return (deviation ** 2).mean()


def critic_dsn_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None,
                          clip=None, gp=None):
    """Perform one critic update on a source/target batch pair.

    Codes are produced by each autoencoder's ``encode``; the loss is
    ``mean(critic(target code)) - mean(critic(source code))`` plus, when ``gp``
    is given, ``gp`` times the gradient penalty.

    :param critic: module being trained (put in train mode).
    :param s_dsnae: source autoencoder (put in eval mode).
    :param t_dsnae: target autoencoder (put in eval mode).
    :param s_batch: indexable whose element 0 is the source input tensor.
    :param t_batch: indexable whose element 0 is the target input tensor.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer stepped once.
    :param history: mapping whose ``'critic_loss'`` list receives the loss.
    :param scheduler: optional scheduler stepped after the update.
    :param clip: optional bound; critic weights are clamped to ``[-clip, clip]``.
    :param gp: optional gradient-penalty coefficient.
    :return: ``history`` (the same object that was passed in).
    """
    critic.zero_grad()
    s_dsnae.zero_grad()
    t_dsnae.zero_grad()
    for frozen in (s_dsnae, t_dsnae):
        frozen.eval()
    critic.train()

    src_input = s_batch[0].to(device)
    tgt_input = t_batch[0].to(device)

    src_code = s_dsnae.encode(src_input)
    tgt_code = t_dsnae.encode(tgt_input)

    tgt_score = torch.mean(critic(tgt_code))
    src_score = torch.mean(critic(src_code))
    loss = tgt_score - src_score

    if gp is not None:
        loss = loss + gp * compute_gradient_penalty(critic,
                                                    real_samples=src_code,
                                                    fake_samples=tgt_code,
                                                    device=device)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if clip is not None:
        # Clamp the critic weights in place after the step.
        for weight in critic.parameters():
            weight.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['critic_loss'].append(loss.cpu().detach().item())

    return history
def gan_dsn_gen_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, alpha, history,
                           scheduler=None):
    """Perform one autoencoder ("generator") update against a fixed critic.

    The generator loss is ``-mean(critic(t_dsnae.encode(t_x)))``: raising the
    critic's score on target codes lowers the loss. It is weighted by
    ``alpha`` and added to the sum of both autoencoders' ``'loss'`` entries.

    :param critic: critic module (put in eval mode).
    :param s_dsnae: source autoencoder (put in train mode).
    :param t_dsnae: target autoencoder (put in train mode).
    :param s_batch: indexable whose element 0 is the source input tensor.
    :param t_batch: indexable whose element 0 is the target input tensor.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer stepped once.
    :param alpha: weight of the generator loss.
    :param history: mapping of lists; receives the summed source+target value of
        every key of the source loss dict, plus ``'gen_loss'``.
    :param scheduler: optional scheduler stepped after the update.
    :return: ``history`` (the same object that was passed in).
    """
    critic.zero_grad()
    s_dsnae.zero_grad()
    t_dsnae.zero_grad()
    critic.eval()
    for trainable in (s_dsnae, t_dsnae):
        trainable.train()

    src_input = s_batch[0].to(device)
    tgt_input = t_batch[0].to(device)

    tgt_code = t_dsnae.encode(tgt_input)

    optimizer.zero_grad()
    gen_loss = -torch.mean(critic(tgt_code))
    s_loss_dict = s_dsnae.loss_function(*s_dsnae(src_input))
    t_loss_dict = t_dsnae.loss_function(*t_dsnae(tgt_input))
    # Total autoencoder loss of source and target, taken from the 'loss' entry.
    recons_loss = s_loss_dict['loss'] + t_loss_dict['loss']
    objective = recons_loss + alpha * gen_loss
    # Second gradient reset, as in the original call sequence.
    optimizer.zero_grad()

    objective.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    per_key_totals = {
        key: s_val.cpu().detach().item() + t_loss_dict[key].cpu().detach().item()
        for key, s_val in s_loss_dict.items()
    }
    for key in per_key_totals:
        history[key].append(per_key_totals[key])
    history['gen_loss'].append(gen_loss.cpu().detach().item())

    return history
def train_dsrn_adv(s_dataloaders, t_dataloaders, **kwargs):
    """Train, or load, the source/target DSN autoencoders with a WGAN-GP critic
    whose input is twice ``latent_dim`` wide.

    With ``retrain_flag`` set, the autoencoders are first pre-trained for
    ``pretrain_num_epochs`` epochs, then trained adversarially for
    ``train_num_epochs`` epochs (one critic update per source batch, one
    autoencoder update every fifth batch); both state dicts are written to
    ``model_save_folder``. Otherwise the target autoencoder is restored from
    ``model_save_folder/a_t_dsnae.pt``.

    :param s_dataloaders: sequence ``(train_loader, test_loader)`` for the source.
    :param t_dataloaders: sequence ``(train_loader, test_loader)`` for the target.
    :param kwargs: reads ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``classifier_hidden_dims``, ``dop``, ``alpha``, ``norm_flag``,
        ``device``, ``lr``, ``retrain_flag``, ``es_flag``,
        ``pretrain_num_epochs``, ``train_num_epochs`` and ``model_save_folder``.
    :return: ``(t_dsnae.shared_encoder, (dsnae_train_history,
        dsnae_val_history, critic_train_history, gen_train_history))``.
    :raises Exception: if ``retrain_flag`` is false and no checkpoint exists.
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(device)

    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(device)

    # Both autoencoders share the encoder and decoder built above.
    s_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    t_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    # The critic scores the output of ``DSNAE.encode``, hence the doubled width.
    confounding_classifier = MLP(input_dim=kwargs['latent_dim'] * 2,
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(device)

    ae_params = [t_dsnae.private_encoder.parameters(),
                 s_dsnae.private_encoder.parameters(),
                 shared_decoder.parameters(),
                 shared_encoder.parameters()
                 ]
    t_ae_params = [t_dsnae.private_encoder.parameters(),
                   s_dsnae.private_encoder.parameters(),
                   shared_decoder.parameters(),
                   shared_encoder.parameters()
                   ]

    ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
    classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(), lr=kwargs['lr'])
    t_ae_optimizer = torch.optim.RMSprop(chain(*t_ae_params), lr=kwargs['lr'])

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    s_checkpoint = os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt')
    t_checkpoint = os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')

    if kwargs['retrain_flag']:
        # Make sure the checkpoints can be written; otherwise a missing folder
        # only surfaces at the first torch.save, after training has run.
        os.makedirs(kwargs['model_save_folder'], exist_ok=True)

        # start dsnae pre-training
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                # A fresh iterator is created per step, so only its first
                # batch is ever used.
                t_batch = next(iter(t_train_dataloader))
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=t_batch,
                                                        device=device,
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)
            dsnae_val_history = eval_dsnae_epoch(model=s_dsnae,
                                                 data_loader=s_test_dataloader,
                                                 device=device,
                                                 history=dsnae_val_history
                                                 )
            dsnae_val_history = eval_dsnae_epoch(model=t_dsnae,
                                                 data_loader=t_test_dataloader,
                                                 device=device,
                                                 history=dsnae_val_history
                                                 )
            # Fold the target entry just appended into the source entry so
            # each epoch contributes one summed value per metric.
            for k in dsnae_val_history:
                if k != 'best_index':
                    dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
                    dsnae_val_history[k].pop()

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=20)
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_checkpoint)
                    torch.save(t_dsnae.state_dict(), t_checkpoint)
                if stop_flag:
                    break
        if kwargs['es_flag']:
            s_dsnae.load_state_dict(torch.load(s_checkpoint, map_location=device))
            t_dsnae.load_state_dict(torch.load(t_checkpoint, map_location=device))

        # start GAN training
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                t_batch = next(iter(t_train_dataloader))

                # Critic step: minimises
                # mean(critic(t_code)) - mean(critic(s_code)) + gp * penalty.
                critic_train_history = critic_dsn_train_step(critic=confounding_classifier,
                                                             s_dsnae=s_dsnae,
                                                             t_dsnae=t_dsnae,
                                                             s_batch=s_batch,
                                                             t_batch=t_batch,
                                                             device=device,
                                                             optimizer=classifier_optimizer,
                                                             history=critic_train_history,
                                                             # clip=0.1,
                                                             gp=10.0)
                # Generator step once per five critic steps: minimises the
                # negated critic score of the target codes plus the
                # autoencoder losses.
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_dsn_gen_train_step(critic=confounding_classifier,
                                                               s_dsnae=s_dsnae,
                                                               t_dsnae=t_dsnae,
                                                               s_batch=s_batch,
                                                               t_batch=t_batch,
                                                               device=device,
                                                               optimizer=t_ae_optimizer,
                                                               alpha=1.0,
                                                               history=gen_train_history)

        torch.save(s_dsnae.state_dict(), s_checkpoint)
        torch.save(t_dsnae.state_dict(), t_checkpoint)
    else:
        try:
            # map_location lets a checkpoint trained on GPU be restored on a
            # CPU-only host.
            t_dsnae.load_state_dict(torch.load(t_checkpoint, map_location=device))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history, critic_train_history, gen_train_history)
def _record_summed_losses(history, s_loss_dict, t_loss_dict):
    """Append the per-key sum of source and target losses to ``history``."""
    for k, v in s_loss_dict.items():
        history[k].append(v.cpu().detach().item() + t_loss_dict[k].cpu().detach().item())
    return history


def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Run one reconstruction-only optimisation step on a source and a target batch.

    :param ae: autoencoder; ``ae(x)`` is unpacked into ``ae.loss_function``, which
        returns a dict of tensors containing at least the key ``'loss'``
    :param s_batch: source batch; only element 0 (the features) is used
    :param t_batch: target batch; only element 0 (the features) is used
    :param device: device the batches are moved to
    :param optimizer: optimizer stepping the autoencoder parameters
    :param history: mapping of metric name -> list; one value per key is appended
    :param scheduler: optional LR scheduler, stepped once after the optimizer
    :return: the updated ``history``
    """
    ae.zero_grad()
    ae.train()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)

    s_loss_dict = ae.loss_function(*ae(s_x))
    t_loss_dict = ae.loss_function(*ae(t_x))
    loss = s_loss_dict['loss'] + t_loss_dict['loss']

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    return _record_summed_losses(history, s_loss_dict, t_loss_dict)


def classification_train_step(classifier, s_batch, t_batch, loss_fn, device, optimizer, history, scheduler=None,
                              clip=None):
    """Run one optimisation step of the domain classifier (source = 0, target = 1).

    :param classifier: model mapping features to one logit per sample; must expose
        ``.decoder`` when ``clip`` is given
    :param s_batch: source batch; element 0 is labelled 0
    :param t_batch: target batch; element 0 is labelled 1
    :param loss_fn: loss called as ``loss_fn(outputs, truths)``
    :param device: device the batches and labels are moved to
    :param optimizer: optimizer for the classifier parameters being trained
    :param history: mapping with a list under ``'bce'``; the step loss is appended
    :param scheduler: optional LR scheduler, stepped once after the optimizer
    :param clip: if given, ``classifier.decoder`` weights are clamped to
        ``[-clip, clip]`` after the optimizer step
    :return: the updated ``history``
    """
    classifier.zero_grad()
    classifier.train()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)
    outputs = torch.cat((classifier(s_x), classifier(t_x)), dim=0)
    truths = torch.cat((torch.zeros(s_x.shape[0], 1), torch.ones(t_x.shape[0], 1)), dim=0).to(device)
    loss = loss_fn(outputs, truths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if clip is not None:
        # Weight clipping (not gradient clipping): clamp the classifier head in place.
        for p in classifier.decoder.parameters():
            p.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['bce'].append(loss.cpu().detach().item())

    return history


def customized_ae_train_step(classifier, ae, s_batch, t_batch, loss_fn, alpha, device, optimizer, history,
                             scheduler=None):
    """Run one adversarial autoencoder step: reconstruct well while fooling the classifier.

    The optimised loss is ``recon(source) + recon(target) - alpha * adv_loss``, so the
    autoencoder is rewarded when the domain classifier's loss goes up.

    :param classifier: domain classifier, put in eval mode for this step
    :param ae: autoencoder being trained
    :param s_batch: source batch; only element 0 is used
    :param t_batch: target batch; only element 0 is used
    :param loss_fn: classification loss called as ``loss_fn(outputs, truths)``
    :param alpha: weight of the adversarial term
    :param device: device the batches and labels are moved to
    :param optimizer: optimizer that is stepped (the caller decides which parameters)
    :param history: mapping of metric name -> list; summed AE losses are appended
    :param scheduler: optional LR scheduler, stepped once after the optimizer
    :return: the updated ``history`` (the adversarial loss itself is not logged)
    """
    classifier.zero_grad()
    ae.zero_grad()
    classifier.eval()
    ae.train()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)

    outputs = torch.cat((classifier(s_x), classifier(t_x)), dim=0)
    truths = torch.cat((torch.zeros(s_x.shape[0], 1), torch.ones(t_x.shape[0], 1)), dim=0).to(device)
    adv_loss = loss_fn(outputs, truths)

    s_loss_dict = ae.loss_function(*ae(s_x))
    t_loss_dict = ae.loss_function(*ae(t_x))
    loss = s_loss_dict['loss'] + t_loss_dict['loss'] - alpha * adv_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if scheduler is not None:
        scheduler.step()

    return _record_summed_losses(history, s_loss_dict, t_loss_dict)


def _eval_ae_on_both_domains(autoencoder, s_loader, t_loader, device, history):
    """Evaluate the autoencoder on source then target and fold both into one entry.

    Each ``eval_ae_epoch`` call is expected to append one value per metric list
    (the merge below relies on that); the two newest values are summed in place so
    that ``history`` grows by a single entry per call of this helper.
    """
    history = eval_ae_epoch(model=autoencoder, data_loader=s_loader, device=device, history=history)
    history = eval_ae_epoch(model=autoencoder, data_loader=t_loader, device=device, history=history)
    for k in history:
        if k != 'best_index':
            history[k][-2] += history[k][-1]
            history[k].pop()
    return history


def train_ae_adv(s_dataloaders, t_dataloaders, **kwargs):
    """Train (or load) an autoencoder whose latent code is adversarially domain-confused.

    With ``retrain_flag`` set, three phases run: (1) autoencoder pre-training on
    reconstruction, (2) domain-classifier pre-training on the encoder output,
    (3) alternating epochs of adversarial autoencoder updates and classifier updates.
    The final autoencoder weights are saved as ``ae.pt``. Without ``retrain_flag`` the
    weights are loaded from ``ae.pt`` instead.

    :param s_dataloaders: ``(train_loader, test_loader)`` for the source domain
    :param t_dataloaders: ``(train_loader, test_loader)`` for the target domain
    :param kwargs: reads ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``classifier_hidden_dims``, ``dop``, ``device``, ``retrain_flag``,
        ``model_save_folder`` and, when retraining, ``lr``, ``pretrain_num_epochs``,
        ``train_num_epochs``, ``alpha``, ``es_flag``
    :return: ``(encoder, histories)`` where ``histories`` is the tuple
        ``(ae_train, ae_val, classifier_eval_train, classifier_eval_test, classifier_train)``
    :raises Exception: when ``retrain_flag`` is false and ``ae.pt`` does not exist
    """
    s_train_dataloader = s_dataloaders[0]
    s_test_dataloader = s_dataloaders[1]

    t_train_dataloader = t_dataloaders[0]
    t_test_dataloader = t_dataloaders[1]

    device = kwargs['device']

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    # The domain classifier shares the autoencoder's encoder; only its MLP head
    # (``.decoder``) is handed to the classifier optimizer below.
    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder, decoder=classifier).to(device)

    ae_eval_train_history = defaultdict(list)
    ae_eval_val_history = defaultdict(list)
    classifier_pretrain_history = defaultdict(list)
    classification_eval_test_history = defaultdict(list)
    classification_eval_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        confounded_loss = nn.BCEWithLogitsLoss()
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr'])
        classifier_optimizer = torch.optim.AdamW(confounder_classifier.decoder.parameters(), lr=kwargs['lr'])

        # ---- Phase 1: autoencoder pre-training ----
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Pre-Training Epoch {epoch} ----')

            for s_batch in s_train_dataloader:
                # A fresh iterator per step: always takes the loader's first batch
                # (a different one each time only if the loader shuffles).
                t_batch = next(iter(t_train_dataloader))
                ae_eval_train_history = ae_train_step(ae=autoencoder,
                                                      s_batch=s_batch,
                                                      t_batch=t_batch,
                                                      device=device,
                                                      optimizer=ae_optimizer,
                                                      history=ae_eval_train_history)

            ae_eval_val_history = _eval_ae_on_both_domains(autoencoder, s_test_dataloader, t_test_dataloader,
                                                           device, ae_eval_val_history)
            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss',
                                                        tolerance_count=10)
                if save_flag:
                    torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'ae.pt'))
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Roll back to the best checkpoint written during early stopping.
            autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))

        # ---- Phase 2: adversarial (domain) classifier pre-training ----
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'Adversarial classifier pre-training epoch {epoch}')
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                classifier_pretrain_history = classification_train_step(classifier=confounder_classifier,
                                                                        s_batch=s_batch,
                                                                        t_batch=t_batch,
                                                                        loss_fn=confounded_loss,
                                                                        device=device,
                                                                        optimizer=classifier_optimizer,
                                                                        history=classifier_pretrain_history)

            classification_eval_test_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_test_dataloader,
                t_dataloader=t_test_dataloader,
                device=device,
                history=classification_eval_test_history)
            classification_eval_train_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_train_dataloader,
                t_dataloader=t_train_dataloader,
                device=device,
                history=classification_eval_train_history)

            save_flag, stop_flag = model_save_check(history=classification_eval_test_history, metric_name='acc',
                                                    tolerance_count=50)
            if kwargs['es_flag']:
                if save_flag:
                    torch.save(confounder_classifier.state_dict(),
                               os.path.join(kwargs['model_save_folder'], 'adv_classifier.pt'))
                if stop_flag:
                    break

        if kwargs['es_flag']:
            confounder_classifier.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'adv_classifier.pt')))

        # ---- Phase 3: alternating training ----
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'Alternative training epoch {epoch}')

            # Autoencoder epoch: reconstruct while maximising the classifier's loss.
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                ae_eval_train_history = customized_ae_train_step(classifier=confounder_classifier,
                                                                 ae=autoencoder,
                                                                 s_batch=s_batch,
                                                                 t_batch=t_batch,
                                                                 loss_fn=confounded_loss,
                                                                 alpha=kwargs['alpha'],
                                                                 device=device,
                                                                 optimizer=ae_optimizer,
                                                                 history=ae_eval_train_history,
                                                                 scheduler=None)

            ae_eval_val_history = _eval_ae_on_both_domains(autoencoder, s_test_dataloader, t_test_dataloader,
                                                           device, ae_eval_val_history)

            classification_eval_test_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_test_dataloader,
                t_dataloader=t_test_dataloader,
                device=device,
                history=classification_eval_test_history)
            classification_eval_train_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_train_dataloader,
                t_dataloader=t_train_dataloader,
                device=device,
                history=classification_eval_train_history)

            # Classifier epoch: let the domain classifier catch up with the new encoder.
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                classifier_pretrain_history = classification_train_step(classifier=confounder_classifier,
                                                                        s_batch=s_batch,
                                                                        t_batch=t_batch,
                                                                        loss_fn=confounded_loss,
                                                                        device=device,
                                                                        optimizer=classifier_optimizer,
                                                                        history=classifier_pretrain_history)

            classification_eval_test_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_test_dataloader,
                t_dataloader=t_test_dataloader,
                device=device,
                history=classification_eval_test_history)
            # Train-split metrics must use the target *train* loader, like every other
            # train-history evaluation in this function (this call used the test loader).
            classification_eval_train_history = evaluate_adv_classification_epoch(
                classifier=confounder_classifier,
                s_dataloader=s_train_dataloader,
                t_dataloader=t_train_dataloader,
                device=device,
                history=classification_eval_train_history)

        torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'ae.pt'))

    else:
        try:
            autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return autoencoder.encoder, (ae_eval_train_history,
                                 ae_eval_val_history,
                                 classification_eval_train_history,
                                 classification_eval_test_history,
                                 classifier_pretrain_history)

# --------------------------------------------------------------------------------
# /code/data.py:
# --------------------------------------------------------------------------------
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn import preprocessing
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import TensorDataset, DataLoader,Dataset
from rdkit import RDLogger
import math


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all visible GPUs) global RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(seed)


def get_unlabeled_dataloaders(gex_features_df, seed, batch_size, ccle_only=False):
    """Build unlabeled expression dataloaders for pre-training.

    Cancer cell lines are the source domain (``s_dataloaders``); patients are the
    target domain (``t_dataloaders``).

    :param gex_features_df: expression matrix indexed by sample id; rows are matched
        against the CCLE and Xena sample-info CSVs
    :param seed: seed passed to :func:`set_seed` and to the patient split
    :param batch_size: batch size of every returned loader
    :param ccle_only: when true, the cell-line loaders are returned for both domains
    :return: ``(source_loaders, target_loaders)``, each ``(all_samples_loader,
        test_loader)``. Note that the first loader of each pair iterates *all*
        samples of that domain, including the rows held out for the test loader.
    """
    set_seed(seed)

    ccle_sample_info_df = pd.read_csv('../data/supple_info/sample_info/ccle_sample_info.csv', index_col=0)
    ccle_sample_info_df = ccle_sample_info_df.reset_index().drop_duplicates(
        subset="Depmap_id", keep='first').set_index("Depmap_id")

    xena_sample_info_df = pd.read_csv('../data/supple_info/sample_info/xena_sample_info.csv', index_col=0)
    xena_samples = xena_sample_info_df.index.intersection(gex_features_df.index)
    ccle_samples = ccle_sample_info_df.index.intersection(gex_features_df.index)
    xena_sample_info_df = xena_sample_info_df.loc[xena_samples]
    ccle_sample_info_df = ccle_sample_info_df.loc[ccle_samples.intersection(ccle_sample_info_df.index)]

    xena_df = gex_features_df.loc[xena_samples]
    ccle_df = gex_features_df.loc[ccle_samples]

    # Cell lines that cannot take part in a stratified split (no sample info, or a
    # disease seen fewer than twice) are appended to the test set afterwards.
    excluded_ccle_samples = []
    excluded_ccle_samples.extend(ccle_df.index.difference(ccle_sample_info_df.index))
    disease_counts = ccle_sample_info_df.primary_disease.value_counts()
    excluded_ccle_diseases = disease_counts[disease_counts < 2].index
    excluded_ccle_samples.extend(
        ccle_sample_info_df[ccle_sample_info_df.primary_disease.isin(excluded_ccle_diseases)].index)

    to_split_ccle_df = ccle_df[~ccle_df.index.isin(excluded_ccle_samples)]
    # No random_state here: this split draws from the global NumPy RNG seeded above.
    train_ccle_df, test_ccle_df = train_test_split(to_split_ccle_df, test_size=0.1,
                                                   stratify=ccle_sample_info_df.loc[
                                                       to_split_ccle_df.index].primary_disease)
    # DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent.
    test_ccle_df = pd.concat([test_ccle_df, ccle_df.loc[excluded_ccle_samples]])
    train_xena_df, test_xena_df = train_test_split(xena_df, test_size=0.1,
                                                   random_state=seed)
    print(' ')
    print(f"Pretrain dataset: {xena_df.shape[0]}(TCGA) {to_split_ccle_df.shape[0]}(Cell line)")
    print(' ')

    xena_dataset = TensorDataset(
        torch.from_numpy(xena_df.values.astype('float32'))
    )
    ccle_dataset = TensorDataset(
        torch.from_numpy(ccle_df.values.astype('float32'))
    )
    test_xena_dateset = TensorDataset(
        torch.from_numpy(test_xena_df.values.astype('float32')))
    test_ccle_dateset = TensorDataset(
        torch.from_numpy(test_ccle_df.values.astype('float32')))

    xena_dataloader = DataLoader(xena_dataset,
                                 batch_size=batch_size,
                                 shuffle=True)
    test_xena_dataloader = DataLoader(test_xena_dateset,
                                      batch_size=batch_size,
                                      shuffle=True)

    ccle_data_loader = DataLoader(ccle_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=True
                                  )
    test_ccle_dataloader = DataLoader(test_ccle_dateset,
                                      batch_size=batch_size,
                                      shuffle=True)
    if ccle_only:
        return (ccle_data_loader, test_ccle_dataloader), (ccle_data_loader, test_ccle_dataloader)
    else:
        return (ccle_data_loader, test_ccle_dataloader), (xena_dataloader, test_xena_dataloader)


def get_finetune_dataloader_generator(gex_features_df,label_type, sample_size = 0.006,dataset = 'gdsc1_raw', seed = 2020 , batch_size = 64, ccle_measurement='AUC',
                                      n_splits=5,q=2,
                                      tumor_type = "TCGA",
                                      select_drug_method = True):
    """Yield fine-tuning folds together with the fixed zero-shot patient loader.

    Label convention: sensitive (responder) = 1, resistant (non-responder) = 0.

    :param gex_features_df: expression matrix indexed by sample id
    :param label_type: patient label source, ``"PFS"`` or ``"Imaging"``
    :param sample_size: fraction of rows sampled when ``select_drug_method == "random"``
    :param dataset: stem of the CSV under ``../data/finetune_data/``
    :param seed: seed of the stratified k-fold split
    :param batch_size: batch size of every loader
    :param ccle_measurement: forwarded as ``measurement`` to the cell-line generator
    :param n_splits: number of cross-validation folds
    :param q: forwarded to the cell-line generator
    :param tumor_type: TCGA project filter; ``"TCGA"`` keeps all patients
    :param select_drug_method: ``"all"``, ``"overlap"`` or ``"random"``; any other
        value (including the default ``True``) makes the cell-line generator raise
    :return: generator of ``(train_ccle_loader, test_ccle_loader, tcga_loader)``
    """
    RDLogger.DisableLog('rdApp.*')

    # NOTE(review): q is hard-coded to 2 for the patient labels instead of
    # forwarding the ``q`` argument - confirm whether that is intentional.
    test_labeled_dataloaders,cid_list = get_tcga_ZS_dataloaders(gex_features_df=gex_features_df,
                                                                label_type = label_type,
                                                                batch_size=batch_size,
                                                                q=2,
                                                                tumor_type = tumor_type)

    ccle_labeled_dataloader_generator = get_ccl_labeled_dataloader_generator(gex_features_df=gex_features_df,
                                                                             tumor_type = tumor_type,
                                                                             seed=seed,
                                                                             sample_size = sample_size,
                                                                             dataset = dataset,
                                                                             batch_size=batch_size,
                                                                             measurement=ccle_measurement,
                                                                             n_splits=n_splits,q=q,
                                                                             cid_list = cid_list,
                                                                             select_drug_method = select_drug_method)

    for train_labeled_ccle_dataloader, test_labeled_ccle_dataloader in ccle_labeled_dataloader_generator:
        yield train_labeled_ccle_dataloader, test_labeled_ccle_dataloader, test_labeled_dataloaders


def _ccl_rows_to_dataset(rows, n_drug_dims):
    """Split a ``[label | expression | drug embedding]`` matrix into a TensorDataset."""
    drug_start = rows.shape[1] - n_drug_dims
    return TensorDataset(
        torch.from_numpy(rows[:, 1:drug_start].astype('float32')),              # expression
        torch.from_numpy(rows[:, drug_start:rows.shape[1]].astype('float32')),  # drug embedding
        torch.from_numpy(rows[:, 0])                                            # label column
    )


def get_ccl_labeled_dataloader_generator(gex_features_df,tumor_type,cid_list,select_drug_method,sample_size,dataset = 'gdsc_raw', batch_size = 64, seed=2020,
                                         measurement='AUC', n_splits=5,q=2):
    """Yield stratified k-fold ``(train_loader, test_loader)`` pairs of labeled cell-line data.

    Every batch is ``(expression, drug_embedding, label)``.

    :param gex_features_df: expression matrix indexed by sample id
    :param tumor_type: not used by this function
    :param cid_list: drug cids used when ``select_drug_method == "overlap"``
    :param select_drug_method: ``"all"`` (every row), ``"overlap"`` (rows whose ``cid``
        is in ``cid_list``) or ``"random"`` (a random ``sample_size`` fraction)
    :param sample_size: fraction of rows kept by the ``"random"`` method
    :param dataset: stem of the CSV under ``../data/finetune_data/``
    :param batch_size: batch size of the yielded loaders
    :param seed: ``random_state`` of the stratified split
    :param measurement: not used by this function
    :param n_splits: number of folds
    :param q: not used by this function
    :raises ValueError: if ``select_drug_method`` is not one of the three methods
    """
    n_drug_dims = 300  # width of the drug embedding appended by the merge below

    Dataset_path = '../data/finetune_data/{}.csv'.format(dataset)
    sensitivity_df = pd.read_csv(Dataset_path,index_col=1)

    ccle_sample_with_gex = pd.read_csv("../data/supple_info/ccle_sample_with_gex.csv",index_col=0)
    sensitivity_df = sensitivity_df.loc[sensitivity_df.index.isin(ccle_sample_with_gex.index)]
    if select_drug_method == "all":
        target_df = sensitivity_df
    elif select_drug_method == "overlap":
        target_df = sensitivity_df.loc[sensitivity_df['cid'].isin(cid_list)]
        print("Drug num: target(TCGA) / source(GDSC) / overlap = {0} {1} / {2} / {3} {4}".format(
            pd.Series(cid_list).unique() , pd.Series(cid_list).nunique(),
            sensitivity_df['Drug_name'].nunique(),
            target_df['Drug_name'].unique() , target_df['Drug_name'].nunique()
            ))
    elif select_drug_method == "random":
        target_df = sensitivity_df.sample(n = round(sample_size*sensitivity_df.shape[0]))  # sample some data randomly
    else:
        # Previously this fell through and failed later on an unbound ``target_df``.
        raise ValueError(
            f"Unknown select_drug_method: {select_drug_method!r} (expected 'all', 'overlap' or 'random')")
    print(' ')

    print("Select {0} dataset {1} / {2}".format(dataset,
                                                target_df.shape[0],
                                                sensitivity_df.shape[0]
                                                ))
    print(' ')

    ccl_gex_df = gex_features_df
    keep_samples = target_df.index.isin(ccl_gex_df.index)

    ccle_labeled_feature_df = target_df.loc[keep_samples][["Drug_smile","label"]].dropna()
    ccle_labeled_feature_df = ccle_labeled_feature_df.merge(ccl_gex_df,left_index=True,right_index=True)

    drug_emb = pd.read_csv("../data/supple_info/drug_embedding/drug_embedding_for_cell_line.csv",index_col=0)
    # After this merge the column layout is: label, expression..., drug embedding...
    ccle_labeled_feature_df = ccle_labeled_feature_df.merge(drug_emb,left_on="Drug_smile",right_on="Drug_smile")
    del ccle_labeled_feature_df['Drug_smile']
    ccle_labeled_feature_df.dropna(inplace=True)

    ccle_labels = ccle_labeled_feature_df['label'].values
    # Labels strictly inside (0, 1) are treated as continuous and binarised at the
    # median, for stratification only; the tensors keep the stored label column.
    if max(ccle_labels) < 1 and min(ccle_labels) > 0:
        ccle_labels = (ccle_labels < np.median(ccle_labels)).astype('int')

    feature_matrix = ccle_labeled_feature_df.values
    # Honour the caller's ``n_splits`` (it used to be hard-coded to 5, the default).
    s_kfold = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=True)
    for train_index, test_index in s_kfold.split(feature_matrix, ccle_labels):
        train_labeled_ccle_dateset = _ccl_rows_to_dataset(feature_matrix[train_index], n_drug_dims)
        test_labeled_ccle_dateset = _ccl_rows_to_dataset(feature_matrix[test_index], n_drug_dims)

        train_labeled_ccle_dataloader = DataLoader(train_labeled_ccle_dateset,
                                                   batch_size=batch_size,
                                                   shuffle=True)

        test_labeled_ccle_dataloader = DataLoader(test_labeled_ccle_dateset,
                                                  batch_size=batch_size,
                                                  shuffle=True)

        yield train_labeled_ccle_dataloader, test_labeled_ccle_dataloader


def get_tcga_ZS_dataloaders(gex_features_df, batch_size, label_type = "PFS",q=2,tumor_type = "TCGA"):
    '''
    Build the zero-shot (patient) dataloader of ``(expression, drug_embedding, label)``.

    gex_features_df: pd.DataFrame, index: patient_id, columns: gene_id
    batch_size: int, batch size of the returned loader
    label_type: str, "PFS" (labels are ``q`` quantile bins of days to new tumor event)
        or "Imaging" (label 1 only for 'Complete Response')
    q: int, number of quantile bins for "PFS"
    tumor_type: str, TCGA project to keep; "TCGA" keeps every project

    Returns ``(dataloader, cid_list)`` where ``cid_list`` holds the drug cid of every
    row. Raises ValueError for an unknown ``label_type``.
    '''
    if label_type not in ("PFS", "Imaging"):
        # Previously this failed later on an unbound variable.
        raise ValueError(f"Unknown label_type: {label_type!r} (expected 'PFS' or 'Imaging')")

    # Collapse aliquot-level barcodes to the 12-character patient id and average.
    tcga_gex_feature_df = gex_features_df.loc[gex_features_df.index.str.startswith('TCGA')]
    tcga_gex_feature_df.index = tcga_gex_feature_df.index.map(lambda x: x[:12])
    tcga_gex_feature_df = tcga_gex_feature_df.groupby(level=0).mean()

    response_df = pd.read_csv('../data/zeroshot_data/tcga698_single_drug_response_df.csv',index_col=1)
    response_type_df = pd.read_csv('../data/zeroshot_data/tcga667_single_drug_response_type_df.csv',index_col=1)

    if tumor_type != "TCGA":
        response_df = response_df.loc[response_df['tcga_project'] == tumor_type]
        response_type_df = response_type_df[response_type_df['tcga_project'] == tumor_type]

    response_df = response_df[['days_to_new_tumor_event_after_initial_treatment','smiles','cid']]
    response_type_df = response_type_df[['treatment_best_response','smiles','cid']]

    if label_type == "PFS":
        tcga_drug_gex = response_df.merge(tcga_gex_feature_df,left_index=True,right_index=True)
    else:  # "Imaging"
        tcga_drug_gex = response_type_df.merge(tcga_gex_feature_df,left_index=True,right_index=True)

    # Get the drug embedding of each drug for TCGA patients
    drug_emb = pd.read_csv("../data/supple_info/drug_embedding/drug_embedding_for_patient.csv",index_col=0)
    tcga_drug_gex = tcga_drug_gex.merge(drug_emb,left_on="smiles",right_on="Drug_smile")  # response, cid, gex, drug(300)

    del tcga_drug_gex['smiles']
    del tcga_drug_gex['Drug_smile']
    # With the two SMILES columns gone, column 1 is 'cid'.
    cid_list = tcga_drug_gex.iloc[:,1].values
    del tcga_drug_gex['cid']

    if label_type == "PFS":
        drug_label = pd.qcut(tcga_drug_gex['days_to_new_tumor_event_after_initial_treatment'],q,labels = range(0,q))
        del tcga_drug_gex['days_to_new_tumor_event_after_initial_treatment']
    else:  # "Imaging"
        drug_label = np.array(tcga_drug_gex['treatment_best_response'].apply(
            lambda s: s in ['Complete Response']), dtype='int32')
        del tcga_drug_gex['treatment_best_response']

    # Final layout: expression..., drug embedding (300), label (last column).
    tcga_drug_gex['label'] = drug_label

    tcga_drug_gex = tcga_drug_gex.values
    print("Zero-shot to {0} num: {1}".format(tumor_type,tcga_drug_gex.shape[0]))

    labeled_tcga_dateset = TensorDataset(
        torch.from_numpy(tcga_drug_gex[:,0:tcga_drug_gex.shape[1]-301].astype('float32')),  # gex: length-301
        torch.from_numpy(tcga_drug_gex[:,tcga_drug_gex.shape[1]-301:tcga_drug_gex.shape[1]-1].astype('float32')),  # drug: 300
        torch.from_numpy(tcga_drug_gex[:,tcga_drug_gex.shape[1]-1])
    )

    labeled_tcga_dataloader = DataLoader(labeled_tcga_dateset,
                                         batch_size=batch_size,
                                         shuffle=False)  # keep row order aligned with cid_list

    return labeled_tcga_dataloader,cid_list


# Generating Zero-shot dataset for specific tumor type
def get_all_tcga_ZS_dataloaders(all_patient_gex, batch_size,Tumor_type_list, label_type = "PFS",q=2):
    """Yield ``(zero_shot_loader, tumor_type)`` for every tumor type in ``Tumor_type_list``.

    :param all_patient_gex: patient expression matrix indexed by sample id
    :param batch_size: batch size of each loader
    :param Tumor_type_list: iterable of tumor types matched against the ``tumor_type``
        column of the sample-info CSV
    :param label_type: forwarded to :func:`get_tcga_ZS_dataloaders`
    :param q: forwarded to :func:`get_tcga_ZS_dataloaders`
    """
    # NOTE(review): other functions in this module read xena_sample_info.csv from
    # ../data/supple_info/sample_info/ - confirm this path is still valid.
    patient_sample_info = pd.read_csv("../data/preprocessed_dat/xena_sample_info.csv", index_col=0)

    for TT in Tumor_type_list:
        print(f'Generating Zero-shot dataset: {TT}')
        patient_samples = all_patient_gex.index.intersection(patient_sample_info.loc[patient_sample_info.tumor_type == TT].index)
        gex_features_df = all_patient_gex.loc[patient_samples]
        TT_ZS_dataloaders,_ = get_tcga_ZS_dataloaders(gex_features_df,
                                                      batch_size, label_type = label_type,q=q,tumor_type = TT)
        yield TT_ZS_dataloaders, TT


# Generating PDR dataset for specific tumor type
def get_pdr_data_dataloaders(gex_features_df, batch_size, tumor_type = "TCGA"):
    """Pair every selected patient with every cell-line drug embedding.

    :param gex_features_df: expression matrix indexed by sample id; only ids starting
        with ``'TCGA'`` are used, averaged per 12-character patient id
    :param batch_size: batch size of the returned loader
    :param tumor_type: tumor type to keep; ``"TCGA"`` keeps every patient
    :return: ``(dataloader, sample_id)``; batches are ``(expression, drug_embedding)``
        in fixed order, and ``sample_id`` is the index of the selected patients (one
        entry per patient, not per patient-drug row)
    """
    tcga_gex_feature_df = gex_features_df.loc[gex_features_df.index.str.startswith('TCGA')]
    tcga_gex_feature_df.index = tcga_gex_feature_df.index.map(lambda x: x[:12])
    tcga_gex_feature_df = tcga_gex_feature_df.groupby(level=0).mean()

    patient_sample_info = pd.read_csv("../data/supple_info/sample_info/xena_sample_info.csv", index_col=0)

    select_index = patient_sample_info.loc[patient_sample_info['tumor_type'] == tumor_type].index.intersection(tcga_gex_feature_df.index)
    if tumor_type != "TCGA":
        tcga_gex_feature_df = tcga_gex_feature_df.loc[select_index]

    drug_emb = pd.read_csv("../data/supple_info/drug_embedding/drug_embedding_for_cell_line.csv",index_col=0)
    # Give both frames the same constant key so the merge below is a cross join.
    drug_emb['Drug_smile']="a"

    tcga_gex_feature_df['Drug_smile'] = "a"
    sample_id = tcga_gex_feature_df.index
    tcga_gex_feature_df.reset_index(inplace=True)
    tcga_drug_gex = tcga_gex_feature_df.merge(drug_emb,left_on="Drug_smile",right_on="Drug_smile")  # gex, drug(300)
    tcga_drug_gex.set_index("index",inplace=True)
    print(f"{tumor_type} Sample_num: {tcga_gex_feature_df.shape[0]}; PDR_test_num: {tcga_drug_gex.shape[0]}")

    del tcga_drug_gex['Drug_smile']

    is_drug_column = tcga_drug_gex.columns.str.startswith('drug_embedding')
    pdr_data_dateset = TensorDataset(
        torch.from_numpy(tcga_drug_gex.loc[:,~is_drug_column].values.astype('float32')),  # gex
        torch.from_numpy(tcga_drug_gex.loc[:,is_drug_column].values.astype('float32'))    # drug: 300
    )

    pdr_data_dataloader = DataLoader(pdr_data_dateset,
                                     batch_size=batch_size,
                                     shuffle=False)  # not shuffled, to keep the order of patients

    return (pdr_data_dataloader,sample_id)


def Rebalance_dataset_for_every_drug(Original_dataset):
    """Oversample the minority label of each drug by whole-number repetition.

    For every distinct ``Drug_name`` the rows with ``label == 0`` and ``label == 1``
    are compared; the smaller group is repeated ``floor(larger / smaller)`` times.
    Drugs that have only one of the two labels are passed through unchanged.

    :param Original_dataset: DataFrame with at least ``Drug_name`` and ``label`` columns
    :return: DataFrame of all drugs concatenated in first-appearance order
    """
    drug_list = Original_dataset.Drug_name.drop_duplicates().values
    balanced_parts = []
    for drug in drug_list:
        drug_data = Original_dataset.loc[Original_dataset.Drug_name == drug]
        drug_data_0 = drug_data.loc[drug_data.label == 0]
        drug_data_1 = drug_data.loc[drug_data.label == 1]
        n_neg = drug_data_0.shape[0]
        n_pos = drug_data_1.shape[0]
        # With a label missing there is nothing to balance (and n_neg == 0 used to
        # raise ZeroDivisionError).
        if n_neg > 0 and n_pos > 0:
            ratio = n_pos / n_neg
            if ratio > 1:
                ratio = math.floor(ratio)
                drug_data_0 = pd.DataFrame(np.repeat(drug_data_0.values,ratio,axis=0),columns=drug_data.columns)
            else:
                ratio = math.floor(1/ratio)
                drug_data_1 = pd.DataFrame(np.repeat(drug_data_1.values,ratio,axis=0),columns=drug_data.columns)
        balanced_parts.append(pd.concat([drug_data_0,drug_data_1]))
    # One concat at the end instead of re-copying the accumulated frame per drug.
    all_drug = pd.concat(balanced_parts) if balanced_parts else pd.DataFrame()
    print(all_drug.iloc[1:5])

    return all_drug

# --------------------------------------------------------------------------------