├── data
└── example
├── records
└── log.txt
├── results
├── table.csv
└── png
│ ├── 6_PDR_comp.png
│ ├── 5_PDR_result.png
│ ├── 1_AI_DRP_model.png
│ ├── 2_P-MDL_dataset.png
│ ├── 3_P-DML_model_zoo.png
│ └── 4_P-MDL_performance.png
├── png
├── .DS_Store
├── 6_PDR_comp.png
├── 5_PDR_result.png
├── 1_AI_DRP_model.png
├── 2_P-MDL_dataset.png
├── 3_P-DML_model_zoo.png
└── 4_P-MDL_performance.png
├── code
├── __pycache__
│ ├── data.cpython-38.pyc
│ ├── fine_tuning.cpython-38.pyc
│ ├── encoder_decoder.cpython-38.pyc
│ └── evaluation_utils.cpython-38.pyc
├── model_train
│ ├── __pycache__
│ │ ├── ae.cpython-38.pyc
│ │ ├── mlp.cpython-38.pyc
│ │ ├── base_ae.cpython-38.pyc
│ │ ├── dsn_ae.cpython-38.pyc
│ │ ├── train_ae.cpython-38.pyc
│ │ ├── train_dsn.cpython-38.pyc
│ │ ├── train_ae_adv.cpython-38.pyc
│ │ ├── train_ae_mmd.cpython-38.pyc
│ │ ├── train_dsn_adv.cpython-38.pyc
│ │ ├── train_dsn_mmd.cpython-38.pyc
│ │ ├── train_dsrn_adv.cpython-38.pyc
│ │ ├── train_dsrn_mmd.cpython-38.pyc
│ │ └── loss_and_metrics.cpython-38.pyc
│ ├── base_ae.py
│ ├── mlp.py
│ ├── loss_and_metrics.py
│ ├── ae.py
│ ├── train_ae_mmd.py
│ ├── train_ae.py
│ ├── dsn_ae.py
│ ├── train_dsn.py
│ ├── train_dsn_mmd.py
│ ├── train_dsrn_mmd.py
│ ├── train_dsn_adv.py
│ ├── train_dsrn_adv.py
│ └── train_ae_adv.py
├── benchmark.txt
├── params
│ └── train_params.json
├── pdr_task.txt
├── encoder_decoder.py
├── run_pretrain_for_pdr.sh
├── run_pdr_task.sh
├── run_pretrain.sh
├── evaluation_utils.py
├── P_MDL.py
├── fine_tuning.py
├── PDR_task.py
└── data.py
└── README.md
/data/example:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/records/log.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/results/table.csv:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/png/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/.DS_Store
--------------------------------------------------------------------------------
/png/6_PDR_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/6_PDR_comp.png
--------------------------------------------------------------------------------
/png/5_PDR_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/5_PDR_result.png
--------------------------------------------------------------------------------
/png/1_AI_DRP_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/1_AI_DRP_model.png
--------------------------------------------------------------------------------
/png/2_P-MDL_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/2_P-MDL_dataset.png
--------------------------------------------------------------------------------
/png/3_P-DML_model_zoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/3_P-DML_model_zoo.png
--------------------------------------------------------------------------------
/results/png/6_PDR_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/6_PDR_comp.png
--------------------------------------------------------------------------------
/png/4_P-MDL_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/png/4_P-MDL_performance.png
--------------------------------------------------------------------------------
/results/png/5_PDR_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/5_PDR_result.png
--------------------------------------------------------------------------------
/results/png/1_AI_DRP_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/1_AI_DRP_model.png
--------------------------------------------------------------------------------
/results/png/2_P-MDL_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/2_P-MDL_dataset.png
--------------------------------------------------------------------------------
/results/png/3_P-DML_model_zoo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/3_P-DML_model_zoo.png
--------------------------------------------------------------------------------
/code/__pycache__/data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/data.cpython-38.pyc
--------------------------------------------------------------------------------
/results/png/4_P-MDL_performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/results/png/4_P-MDL_performance.png
--------------------------------------------------------------------------------
/code/__pycache__/fine_tuning.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/fine_tuning.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/ae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/ae.cpython-38.pyc
--------------------------------------------------------------------------------
/code/__pycache__/encoder_decoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/encoder_decoder.cpython-38.pyc
--------------------------------------------------------------------------------
/code/__pycache__/evaluation_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/__pycache__/evaluation_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/mlp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/mlp.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/base_ae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/base_ae.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/dsn_ae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/dsn_ae.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_ae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_dsn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_ae_adv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae_adv.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_ae_mmd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_ae_mmd.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_dsn_adv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn_adv.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_dsn_mmd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsn_mmd.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_dsrn_adv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsrn_adv.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/train_dsrn_mmd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/train_dsrn_mmd.cpython-38.pyc
--------------------------------------------------------------------------------
/code/model_train/__pycache__/loss_and_metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wuys13/Multi-Drug-Transfer-Learning/HEAD/code/model_train/__pycache__/loss_and_metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/code/benchmark.txt:
--------------------------------------------------------------------------------
1 | nohup: ignoring input
2 | Process < 25, finished last task!
3 | All-data pre-training (ADP) model runing!
4 | Start to run Test-pairwise pre-training (TPP) model!
5 | Processing next task: 8 20230718-053353
6 | Running task num: 26 20230718-054854
7 |
--------------------------------------------------------------------------------
/code/params/train_params.json:
--------------------------------------------------------------------------------
1 | {"unlabeled": {"batch_size": 64, "lr": 0.0001, "pretrain_num_epochs": 500,
2 | "train_num_epochs": 1000, "alpha": 1.0, "classifier_hidden_dims": [64, 32]},
3 | "labeled": {"classifier_hidden_dims": [64, 32], "batch_size": 256, "lr": 0.0001,
4 | "train_num_epochs":1500, "decay_coefficient": 0.1}, "encoder_hidden_dims": [512, 256],
5 | "latent_dim": 128, "dop": 0.1}
6 |
--------------------------------------------------------------------------------
/code/pdr_task.txt:
--------------------------------------------------------------------------------
1 | nohup: ignoring input
2 | Processed over last task
3 | mkdir: cannot create directory ‘../record/pdr_task’: No such file or directory
4 | run_pdr_task.sh: line 24: ../record/pdr_task/7_2.txt: No such file or directory
5 | run_pdr_task.sh: line 24: ../record/pdr_task/7_3.txt: No such file or directory
6 | run_pdr_task.sh: line 24: ../record/pdr_task/7_4.txt: No such file or directory
7 | Can choose to wait for 5 min ~~~
8 | run_pdr_task.sh: line 24: ../record/pdr_task/7_5.txt: No such file or directory
9 |
--------------------------------------------------------------------------------
/code/model_train/base_ae.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from abc import abstractmethod
3 | from typing import List, Any
4 |
5 | from typing import TypeVar
6 | Tensor = TypeVar('torch.tensor')
7 |
8 |
class _SampleNotImplementedError(NotImplementedError, RuntimeWarning):
    """Raised by ``BaseAE.sample`` when a subclass provides no sampler.

    Also derives from ``RuntimeWarning`` for backward compatibility with
    handlers that catch ``RuntimeWarning`` from ``sample``.
    """


class BaseAE(nn.Module):
    """Common interface for the autoencoder models.

    Subclasses must override ``forward`` and ``loss_function``; ``encode``,
    ``decode``, ``sample`` and ``generate`` are optional hooks that raise
    ``NotImplementedError`` unless overridden.

    ``nn.Module`` is not built on ``abc.ABCMeta``, so ``@abstractmethod`` below
    documents intent but is not enforced when a subclass is instantiated. The
    abstract methods therefore raise explicitly instead of returning ``None``.
    """

    def __init__(self) -> None:
        super(BaseAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Map ``input`` to its latent representation (subclass hook)."""
        raise NotImplementedError

    def decode(self, input: Tensor) -> List[Tensor]:
        """Map a latent code back to input space (subclass hook)."""
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Draw ``batch_size`` samples on ``current_device`` (subclass hook).

        Raises:
            NotImplementedError: unless overridden (the exception is also a
                ``RuntimeWarning``).
        """
        raise _SampleNotImplementedError(
            f"{type(self).__name__} does not implement sample()")

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Produce the model output for ``x`` (subclass hook)."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Run the model; must be overridden by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()")

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Compute the training loss; must be overridden by subclasses."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement loss_function()")
32 |
--------------------------------------------------------------------------------
/code/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from molecules import Molecules
5 | # from neural_fingerprint import NeuralFingerprint
6 |
7 | from typing import TypeVar
8 | Tensor = TypeVar('torch.tensor')
9 |
10 |
class EncoderDecoder(nn.Module):
    """Cell-feature encoder followed by a decoder that also sees a drug embedding.

    ``forward`` encodes the cell features with ``encoder``, concatenates the
    result with the L2-normalised ``smiles`` embedding along the last
    dimension, and feeds the concatenation to ``decoder``.

    Args:
        encoder: module applied to the cell features.
        decoder: module applied to the concatenated representation.
        normalize_flag: if True, the encoded cell features are L2-normalised
            along dim 1 before concatenation.
    """

    def __init__(self, encoder, decoder, normalize_flag=False):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.normalize_flag = normalize_flag

    def forward(self, smiles, input: Tensor) -> Tensor:
        """Decode the concatenation of the encoded ``input`` and ``smiles``.

        Args:
            smiles: drug embedding; normalised along dim 1 (presumably a
                (batch, features) tensor -- TODO confirm against callers).
            input: cell features passed to ``encoder``.

        Returns:
            Output of ``decoder``.
        """
        encoded_cell = self.encode(input)
        if self.normalize_flag:
            encoded_cell = nn.functional.normalize(encoded_cell, p=2, dim=1)
        # Bound on every path so the concatenation below can never reference
        # an unbound name, whatever the value of normalize_flag.
        # NOTE(review): the smiles embedding is normalised regardless of
        # normalize_flag; confirm this is the intended behaviour.
        encoded_smiles = nn.functional.normalize(smiles, p=2, dim=1)
        return self.decoder(torch.cat((encoded_cell, encoded_smiles), dim=-1))

    def encode(self, input: Tensor) -> Tensor:
        """Apply the encoder to ``input``."""
        return self.encoder(input)

    def decode(self, z: Tensor) -> Tensor:
        """Apply the decoder to ``z``."""
        return self.decoder(z)
34 |
35 |
--------------------------------------------------------------------------------
/code/run_pretrain_for_pdr.sh:
--------------------------------------------------------------------------------
# nohup bash run_pretrain_for_pdr.sh 1>pdr_pretrain.txt 2>&1 &
#
# Launch the test-pairwise pre-training (TPP) runs of P_MDL.py used by the PDR
# task: one batch of runs (pretrain_num 0..13) per method index 0..7. Before
# each batch, wait until fewer than ${max_running} P_MDL processes are alive,
# re-checking every 12 hours.

# Start the next method only when fewer than this many P_MDL processes run.
max_running=20

# Launch all runs for one method in the background.
#   $1 - method index passed to P_MDL.py as --method_num
function next_task(){
    local select_drug_method="all"
    local method_num=$1
    local store_dir="pdr"
    local aim="${store_dir}_TPP_${select_drug_method}"
    local gpu_param=3  # run i is placed on GPU (i-1)/gpu_param
    local i gpu

    # -p: do not fail when the directory already exists from an earlier method.
    mkdir -p "../results/${store_dir}"
    # Logs go under ../records, the log directory used by the other launch
    # scripts and present in the repository tree.
    mkdir -p "../records/${aim}"

    for i in {0..13}; do
        # NOTE(review): i=0..13 maps to GPU indices 0..4 (five devices);
        # confirm this matches the machine.
        gpu=$(( (i - 1) / gpu_param ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u P_MDL.py --pretrain_num "$i" \
            --zero_shot_num 2 \
            --CCL_dataset gdsc1_rebalance \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "../records/${aim}/${method_num}_${i}.txt" 2>&1 &
    done
}


for method in {0..7}; do
    while true; do
        # Number of live P_MDL processes (the grep itself is filtered out).
        process=$(ps -ef | grep P_MDL | grep -v grep | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        if [ "$process" -lt "$max_running" ]; then
            echo "Process < ${max_running}, finished last task!"
            next_task "$method"
            echo "Processing next task: $method ${current_time}"
            break
        fi
        echo "Running task num: ${process} ${current_time}"
        sleep 12h
    done
done

echo "Finish all !!!!"
49 |
--------------------------------------------------------------------------------
/code/run_pdr_task.sh:
--------------------------------------------------------------------------------
#nohup bash run_pdr_task.sh 1>pdr_task.txt 2>&1 &
#
# For each selected method, wait until no PDR_task.py process is running,
# then launch PDR_task.py for pretrain_num 2..20 in the background.

# Launch the PDR runs for one method.
#   $1 - method index passed to PDR_task.py as --method_num
function next_task(){
    local select_drug_method="all"
    local method_num=$1
    local store_dir="pdr_task"
    local aim=${store_dir}
    local gpu_num=5  # run i is placed on GPU (i-1)/gpu_num
    local i gpu

    # -p: do not fail when the directories already exist.
    mkdir -p "../results/${store_dir}"
    mkdir -p "../records/${aim}"

    for i in {2..20}; do
        # Runs 2..5 start immediately; each later run waits 5 minutes first.
        if [ "$i" -gt 5 ]; then
            echo "Can choose to wait for 5 min ~~~"
            sleep 5m
            echo "Sleep over. Now for $i"
        fi
        gpu=$(( (i - 1) / gpu_num ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u PDR_task.py --pretrain_num "$i" \
            --zero_shot_num 2 \
            --CCL_dataset gdsc1_rebalance \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "../records/${aim}/${method_num}_${i}.txt" 2>&1 &
    done
}


for method in {7..7}; do
    while true; do
        # Number of live PDR_task processes (the grep itself is filtered out).
        process=$(ps -ef | grep PDR_task | grep -v grep | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        if [ "$process" -lt 1 ]; then
            echo "Processed over last task";
            next_task "$method"
            echo "Processing next task: $method ${current_time}"
            sleep 2s
            break
        fi
        echo "Last process running now ${current_time}"
        sleep 1h
    done
done

echo "Finish all !!!!"
60 |
61 |
--------------------------------------------------------------------------------
/code/run_pretrain.sh:
--------------------------------------------------------------------------------
# nohup bash run_pretrain.sh 1>benchmark.txt 2>&1 &
#
# Benchmark driver: for each method index 8 down to 1, launch 13 All-data
# pre-training (ADP) runs of P_MDL.py, wait 15 minutes, then launch 13
# Test-pairwise pre-training (TPP) runs. A method starts only when fewer than
# ${max_running} P_MDL processes are alive (re-checked hourly).

# Start the next method only when fewer than this many P_MDL processes run.
max_running=25

# Launch the ADP and TPP runs for one method in the background.
#   $1 - method index passed to P_MDL.py as --method_num
function next_task(){
    local select_drug_method="overlap"
    local method_num=$1
    local store_dir="benchmark"
    # Run i is placed on GPU (i-1)/gpu_param. With gpu_param=13 and i in 1..13
    # every run lands on GPU 0; with gpu_param=3, runs 1..12 use GPUs 0..3 and
    # run 13 uses GPU 4.
    local gpu_param=13
    local aim i gpu

    # -p: the result/record directories are shared by every method, so they
    # already exist from the second method on.
    mkdir -p "../results/${store_dir}"

    # All-data pre-training (ADP): pretrain_num 0, zero_shot_num 1..13.
    aim="${store_dir}_ADP_${select_drug_method}"
    mkdir -p "../records/${aim}"
    for i in {1..13}; do
        gpu=$(( (i - 1) / gpu_param ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u P_MDL.py --pretrain_num 0 \
            --zero_shot_num "$i" \
            --CCL_dataset gdsc1_raw \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "../records/${aim}/${method_num}_${i}.txt" 2>&1 &
    done

    echo "All-data pre-training (ADP) model runing!"
    sleep 15m
    echo "Start to run Test-pairwise pre-training (TPP) model!"

    # Test-pairwise pre-training (TPP): pretrain_num 1..13, zero_shot_num 2.
    aim="${store_dir}_TPP_${select_drug_method}"
    mkdir -p "../records/${aim}"
    for i in {1..13}; do
        gpu=$(( (i - 1) / gpu_param ))
        CUDA_VISIBLE_DEVICES=$gpu nohup python -u P_MDL.py --pretrain_num "$i" \
            --zero_shot_num 2 \
            --CCL_dataset gdsc1_raw \
            --select_drug_method "$select_drug_method" \
            --method_num "${method_num}" \
            --store_dir "$store_dir" \
            1> "../records/${aim}/${method_num}_${i}.txt" 2>&1 &
    done
}

for method in {8..1}; do
    while true; do
        # Number of live P_MDL processes (the grep itself is filtered out).
        process=$(ps -ef | grep P_MDL | grep -v grep | wc -l)
        current_time=$(date "+%Y%m%d-%H%M%S")
        if [ "$process" -lt "$max_running" ]; then
            echo "Process < ${max_running}, finished last task!"
            echo "Processing next task: $method ${current_time}"
            next_task "$method"
            break
        fi
        echo "Running task num: ${process} ${current_time}"
        sleep 1h
    done
done

echo "Finish all !!!!"
--------------------------------------------------------------------------------
/code/model_train/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.autograd import Function
5 |
6 | from typing import List
7 | from typing import TypeVar
8 |
9 | Tensor = TypeVar('torch.tensor')
10 |
11 | class RevGrad(Function):
12 | @staticmethod
13 | def forward(ctx, input_, alpha_):
14 | ctx.save_for_backward(input_, alpha_)
15 | output = input_
16 | return output
17 |
18 | @staticmethod
19 | def backward(ctx, grad_output): # pragma: no cover
20 | grad_input = None
21 | _, alpha_ = ctx.saved_tensors
22 | if ctx.needs_input_grad[0]:
23 | grad_input = -grad_output * alpha_
24 | return grad_input, None
25 |
26 |
27 | revgrad = RevGrad.apply
28 |
29 |
30 | class RevGrad(torch.nn.Module):
31 | def __init__(self, alpha=1., *args, **kwargs):
32 | """
33 | A gradient reversal layer.
34 | This layer has no parameters, and simply reverses the gradient
35 | in the backward pass.
36 | """
37 | super().__init__(*args, **kwargs)
38 |
39 | self._alpha = torch.tensor(alpha, requires_grad=False)
40 |
41 | def forward(self, input_):
42 | return revgrad(input_, self._alpha)
43 |
44 |
class MLP(nn.Module):
    """Feed-forward network: optional gradient reversal, hidden blocks, output head.

    Every hidden block is ``Linear -> act_fn() -> Dropout(dop)``; the head is a
    ``Linear`` to ``output_dim`` followed by ``out_fn()`` when one is given.

    Args:
        input_dim: number of input features.
        output_dim: number of outputs.
        hidden_dims: hidden widths; defaults to ``[32, 64, 128, 256, 512]``.
        dop: dropout probability of each hidden block.
        act_fn: activation class, instantiated once per hidden block.
        out_fn: optional activation class appended to the output head.
        gr_flag: if True, a ``RevGrad`` layer is placed before the first block.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: List = None, dop: float = 0.1, act_fn=nn.SELU, out_fn=None, gr_flag=False, **kwargs) -> None:
        super(MLP, self).__init__()
        self.output_dim = output_dim
        self.dop = dop

        widths = [32, 64, 128, 256, 512] if hidden_dims is None else hidden_dims

        stages = [RevGrad()] if gr_flag else []
        # Pair each block's fan-in with its fan-out: input -> w0, w0 -> w1, ...
        fan_ins = [input_dim] + list(widths[:-1])
        for fan_in, fan_out in zip(fan_ins, widths):
            stages.append(
                nn.Sequential(
                    nn.Linear(fan_in, fan_out, bias=True),
                    act_fn(),
                    nn.Dropout(self.dop),
                )
            )
        self.module = nn.Sequential(*stages)

        head = [nn.Linear(widths[-1], output_dim, bias=True)]
        if out_fn is not None:
            head.append(out_fn())
        self.output_layer = nn.Sequential(*head)

    def forward(self, input: Tensor) -> Tensor:
        """Pass ``input`` through the hidden stack and then the output head."""
        return self.output_layer(self.module(input))
97 |
--------------------------------------------------------------------------------
/code/model_train/loss_and_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from functools import partial
4 |
5 |
6 | def cov(m, rowvar=False):
7 | """Estimate a covariance matrix given data.
8 |
9 | Covariance indicates the level to which two variables vary together.
10 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
11 | then the covariance matrix element `C_{ij}` is the covariance of
12 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
13 |
14 | Args:
15 | m: A 1-D or 2-D array containing multiple variables and observations.
16 | Each row of `m` represents a variable, and each column a single
17 | observation of all those variables.
18 | rowvar: If `rowvar` is True, then each row represents a
19 | variable, with observations in the columns. Otherwise, the
20 | relationship is transposed: each column represents a variable,
21 | while the rows contain observations.
22 |
23 | Returns:
24 | The covariance matrix of the variables.
25 | """
26 | if m.dim() > 2:
27 | raise ValueError('m has more than 2 dimensions')
28 | if m.dim() < 2:
29 | m = m.view(1, -1)
30 | if not rowvar and m.size(0) != 1:
31 | m = m.t()
32 | # m = m.type(torch.double) # uncomment this line if desired
33 | fact = 1.0 / (m.size(1) - 1)
34 | m -= torch.mean(m, dim=1, keepdim=True)
35 | mt = m.t() # if complex: mt = m.t().conj()
36 | return fact * m.matmul(mt).squeeze()
37 |
38 |
39 | def pairwise_distance(x, y):
40 | if not len(x.shape) == len(y.shape) == 2:
41 | raise ValueError('Both inputs should be matrices.')
42 |
43 | if x.shape[1] != y.shape[1]:
44 | raise ValueError('The number of features should be the same.')
45 |
46 | x = x.view(x.shape[0], x.shape[1], 1)
47 | y = torch.transpose(y, 0, 1)
48 | output = torch.sum((x - y) ** 2, 1)
49 | output = torch.transpose(output, 0, 1)
50 |
51 | return output
52 |
53 |
54 | def gaussian_kernel_matrix(x, y, sigmas):
55 | sigmas = sigmas.view(sigmas.shape[0], 1)
56 | beta = 1. / (2. * sigmas)
57 | dist = pairwise_distance(x, y).contiguous()
58 | dist_ = dist.view(1, -1)
59 | s = torch.matmul(beta, dist_)
60 |
61 | return torch.sum(torch.exp(-s), 0).view_as(dist)
62 |
63 |
64 | def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
65 | cost = torch.mean(kernel(x, x))
66 | cost += torch.mean(kernel(y, y))
67 | cost -= 2 * torch.mean(kernel(x, y))
68 |
69 | return cost
70 |
71 |
72 | def mmd_loss(source_features, target_features, device):
73 | sigmas = [
74 | 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
75 | 1e3, 1e4, 1e5, 1e6
76 | ]
77 | gaussian_kernel = partial(
78 | gaussian_kernel_matrix, sigmas=Variable(torch.Tensor(sigmas), requires_grad=False).to(device)
79 | )
80 |
81 | loss_value = maximum_mean_discrepancy(source_features, target_features, kernel=gaussian_kernel)
82 | loss_value = loss_value
83 |
84 | return loss_value
85 |
--------------------------------------------------------------------------------
/code/model_train/ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from base_ae import BaseAE
5 |
6 | from typing import List
7 | from typing import TypeVar
8 | Tensor = TypeVar('torch.tensor')
9 |
class AE(BaseAE):
    """Fully connected autoencoder.

    Encoder: ``input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]`` (each
    step Linear -> ReLU -> Dropout), then an extra Dropout and a Linear to
    ``latent_dim``.
    Decoder: ``latent_dim -> hidden_dims[-1] -> ... -> hidden_dims[0]`` with the
    same block type, followed by ``final_layer``
    (``hidden_dims[0] -> hidden_dims[0] -> input_dim``).

    Args:
        input_dim: number of input features.
        latent_dim: size of the latent code.
        hidden_dims: encoder widths; defaults to ``[32, 64, 128, 256, 512]``.
            The sequence is copied and never modified.
        dop: dropout probability used by every Dropout layer.
        noise_flag: if True, ``encode`` adds Gaussian noise (std 0.1) to its
            input while the module is in training mode.
        **kwargs: accepted and ignored.
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List = None, dop: float = 0.1, noise_flag: bool = False, **kwargs) -> None:
        super(AE, self).__init__()
        self.latent_dim = latent_dim
        self.noise_flag = noise_flag
        self.dop = dop

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Work on a private copy: the decoder needs the widths reversed, and
        # reversing the caller's list in place corrupted e.g. a params dict that
        # is reused to build another model.
        hidden_dims = list(hidden_dims)

        def block(in_dim, out_dim):
            # One hidden step: Linear -> ReLU -> Dropout.
            return nn.Sequential(
                nn.Linear(in_dim, out_dim, bias=True),
                nn.ReLU(),
                nn.Dropout(self.dop)
            )

        # build encoder
        modules = [block(input_dim, hidden_dims[0])]
        for in_dim, out_dim in zip(hidden_dims, hidden_dims[1:]):
            modules.append(block(in_dim, out_dim))
        modules.append(nn.Dropout(self.dop))
        modules.append(nn.Linear(hidden_dims[-1], latent_dim, bias=True))
        self.encoder = nn.Sequential(*modules)

        # build decoder (mirror of the encoder widths)
        decoder_dims = hidden_dims[::-1]
        modules = [block(latent_dim, decoder_dims[0])]
        for in_dim, out_dim in zip(decoder_dims, decoder_dims[1:]):
            modules.append(block(in_dim, out_dim))
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.Linear(decoder_dims[-1], decoder_dims[-1], bias=True),
            nn.ReLU(),
            nn.Dropout(self.dop),
            nn.Linear(decoder_dims[-1], input_dim)
        )

    def encode(self, input: Tensor) -> Tensor:
        """Encode ``input``; in training mode with ``noise_flag`` set, add noise (std 0.1) first."""
        if self.noise_flag and self.training:
            latent_code = self.encoder(input + torch.randn_like(input, requires_grad=False) * 0.1)
        else:
            latent_code = self.encoder(input)

        return latent_code

    def decode(self, z: Tensor) -> Tensor:
        """Map latent code ``z`` back to input space."""
        embed = self.decoder(z)
        outputs = self.final_layer(embed)

        return outputs

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Return ``[input, reconstruction, latent_code]``."""
        z = self.encode(input)
        return [input, self.decode(z), z]

    def loss_function(self, *args, **kwargs) -> dict:
        """Reconstruction MSE.

        Args:
            *args: ``args[0]`` is the input and ``args[1]`` the reconstruction
                (the first two items returned by ``forward``).

        Returns:
            ``{'loss': mse, 'recons_loss': mse}`` (the same tensor under both keys).
        """
        input = args[0]
        recons = args[1]

        recons_loss = F.mse_loss(input, recons)
        loss = recons_loss

        return {'loss': loss, 'recons_loss': recons_loss}

    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
        """Decode ``num_samples`` latent codes drawn from a standard normal."""
        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)
        samples = self.decode(z)

        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return the reconstruction of ``x``."""
        return self.forward(x)[1]
118 |
--------------------------------------------------------------------------------
/code/model_train/train_ae_mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from evaluation_utils import eval_ae_epoch, model_save_check
4 | from collections import defaultdict
5 | from ae import AE
6 | from mlp import MLP
7 | from loss_and_metrics import mmd_loss
8 | from encoder_decoder import EncoderDecoder
9 |
10 |
def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Run one joint reconstruction + MMD update of ``ae`` on a source/target batch pair.

    The optimised loss is source reconstruction + target reconstruction + the
    MMD (``mmd_loss``) between ``ae.encode`` of the source and target inputs.

    Args:
        ae: autoencoder exposing ``__call__``, ``loss_function`` and ``encode``.
        s_batch, t_batch: batches whose first element is the input tensor.
        device: device the inputs are moved to.
        optimizer: optimizer stepped once.
        history: mapping of metric name -> list (e.g. ``defaultdict(list)``),
            appended to in place.
        scheduler: optional LR scheduler, stepped after the optimizer.

    Returns:
        ``history`` with, per loss key, the source + target value appended, and
        the MMD value appended under ``'mmd_loss'``.
    """
    ae.zero_grad()
    ae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    source_losses = ae.loss_function(*ae(source_x))
    target_losses = ae.loss_function(*ae(target_x))

    optimizer.zero_grad()
    latent_mmd = mmd_loss(source_features=ae.encode(source_x),
                          target_features=ae.encode(target_x),
                          device=device)
    total_loss = source_losses['loss'] + target_losses['loss'] + latent_mmd
    # Redundant with the call above (no backward has run in between).
    optimizer.zero_grad()

    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    # Build the combined per-key values first, then record them.
    combined = {}
    for name, source_value in source_losses.items():
        combined[name] = (source_value.cpu().detach().item()
                          + target_losses[name].cpu().detach().item())

    for name, value in combined.items():
        history[name].append(value)
    history['mmd_loss'].append(latent_mmd.cpu().detach().item())

    return history
38 |
39 |
def _cycle_batches(data_loader):
    """Yield batches from ``data_loader`` forever, starting a fresh pass whenever one ends.

    Replaces calling ``next(iter(data_loader))`` on every step. That builds a brand-new
    iterator per call (re-spawning workers / re-drawing the shuffle order). For a loader
    that does not shuffle, it also always returns the same first batch.

    :param data_loader: any re-iterable of batches (e.g. a torch DataLoader)
    :raises ValueError: if a complete pass over ``data_loader`` yields no batch
    """
    while True:
        yielded = False
        for batch in data_loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError('target dataloader yielded no batches')


def train_ae_mmd(s_dataloaders, t_dataloaders, **kwargs):
    """Pre-train (or reload) one autoencoder on source and target data with an MMD penalty.

    :param s_dataloaders: ``(train, test)`` dataloaders for the source domain
    :param t_dataloaders: ``(train, test)`` dataloaders for the target domain
    :param kwargs: configuration; keys read here: input_dim, latent_dim, encoder_hidden_dims,
        dop, device, retrain_flag, lr, pretrain_num_epochs, es_flag, model_save_folder
    :return: ``(autoencoder.encoder, (train_history, val_history))``
    :raises Exception: if ``retrain_flag`` is false and ``ae.pt`` cannot be found
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']
    checkpoint_path = os.path.join(kwargs['model_save_folder'], 'ae.pt')

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)

    ae_eval_train_history = defaultdict(list)
    ae_eval_val_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr'])
        t_batches = _cycle_batches(t_train_dataloader)

        # start autoencoder pretraining
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Pre-Training Epoch {epoch} ----')
            for s_batch in s_train_dataloader:
                ae_eval_train_history = ae_train_step(ae=autoencoder,
                                                      s_batch=s_batch,
                                                      t_batch=next(t_batches),
                                                      device=device,
                                                      optimizer=ae_optimizer,
                                                      history=ae_eval_train_history)

            ae_eval_val_history = eval_ae_epoch(model=autoencoder, data_loader=s_test_dataloader,
                                                device=device, history=ae_eval_val_history)
            ae_eval_val_history = eval_ae_epoch(model=autoencoder, data_loader=t_test_dataloader,
                                                device=device, history=ae_eval_val_history)
            # Sum the last two entries of every metric (source eval + target eval) into one.
            for k in ae_eval_val_history:
                if k != 'best_index':
                    ae_eval_val_history[k][-2] += ae_eval_val_history[k][-1]
                    ae_eval_val_history[k].pop()

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss',
                                                        tolerance_count=10)
                if save_flag:
                    torch.save(autoencoder.state_dict(), checkpoint_path)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Roll back to the checkpoint saved by early stopping.
            autoencoder.load_state_dict(torch.load(checkpoint_path))
        # Always persist the final weights so a later run with retrain_flag=False can load
        # them; without early stopping nothing would have been written otherwise.
        torch.save(autoencoder.state_dict(), checkpoint_path)
    else:
        try:
            autoencoder.load_state_dict(torch.load(checkpoint_path))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_val_history)
118 |
--------------------------------------------------------------------------------
/code/model_train/train_ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from evaluation_utils import eval_ae_epoch, model_save_check
4 | from collections import defaultdict
5 | from ae import AE
6 | from collections import OrderedDict
7 |
8 |
def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Run one optimisation step of ``ae`` on a source batch and a target batch.

    The objective is the sum of the autoencoder losses on the two batches.

    :param ae: autoencoder exposing ``__call__`` and ``loss_function``
    :param s_batch: source batch; element 0 holds the input features
    :param t_batch: target batch; element 0 holds the input features
    :param device: device the inputs are moved to
    :param optimizer: optimizer stepped once with the combined loss
    :param history: mapping metric name -> list; receives the source+target sum of every
        loss entry
    :param scheduler: optional scheduler stepped right after the optimizer
    :return: the updated ``history``
    """
    ae.zero_grad()
    ae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    source_losses = ae.loss_function(*ae(source_x))
    target_losses = ae.loss_function(*ae(target_x))

    optimizer.zero_grad()
    total_loss = source_losses['loss'] + target_losses['loss']
    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed = {name: value.cpu().detach().item() + target_losses[name].cpu().detach().item()
              for name, value in source_losses.items()}
    for name, value in summed.items():
        history[name].append(value)

    return history
34 |
35 |
def _cycle_batches(data_loader):
    """Yield batches from ``data_loader`` forever, starting a fresh pass whenever one ends.

    Replaces calling ``next(iter(data_loader))`` on every step. That builds a brand-new
    iterator per call (re-spawning workers / re-drawing the shuffle order). For a loader
    that does not shuffle, it also always returns the same first batch.

    :param data_loader: any re-iterable of batches (e.g. a torch DataLoader)
    :raises ValueError: if a complete pass over ``data_loader`` yields no batch
    """
    while True:
        yielded = False
        for batch in data_loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError('target dataloader yielded no batches')


def train_ae(s_dataloaders, t_dataloaders, **kwargs):
    """Train (or reload) a plain autoencoder on source and target data jointly.

    :param s_dataloaders: ``(train, test)`` dataloaders for the source domain
    :param t_dataloaders: ``(train, test)`` dataloaders for the target domain
    :param kwargs: configuration; keys read here: input_dim, latent_dim, encoder_hidden_dims,
        dop, device, retrain_flag, lr, train_num_epochs, es_flag, model_save_folder
    :return: ``(autoencoder.encoder, (train_history, val_history))``
    :raises Exception: if ``retrain_flag`` is false and ``ae.pt`` cannot be found
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']
    checkpoint_path = os.path.join(kwargs['model_save_folder'], 'ae.pt')

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     noise_flag=False,
                     dop=kwargs['dop']).to(device)

    ae_eval_train_history = defaultdict(list)
    ae_eval_val_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr'])
        t_batches = _cycle_batches(t_train_dataloader)

        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for s_batch in s_train_dataloader:
                ae_eval_train_history = ae_train_step(ae=autoencoder,
                                                      s_batch=s_batch,
                                                      t_batch=next(t_batches),
                                                      device=device,
                                                      optimizer=ae_optimizer,
                                                      history=ae_eval_train_history)

            ae_eval_val_history = eval_ae_epoch(model=autoencoder, data_loader=s_test_dataloader,
                                                device=device, history=ae_eval_val_history)
            ae_eval_val_history = eval_ae_epoch(model=autoencoder, data_loader=t_test_dataloader,
                                                device=device, history=ae_eval_val_history)
            # Sum the last two entries of every metric (source eval + target eval) into one.
            for k in ae_eval_val_history:
                if k != 'best_index':
                    ae_eval_val_history[k][-2] += ae_eval_val_history[k][-1]
                    ae_eval_val_history[k].pop()

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss',
                                                        tolerance_count=50)
                if save_flag:
                    torch.save(autoencoder.state_dict(), checkpoint_path)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Roll back to the checkpoint saved by early stopping.
            autoencoder.load_state_dict(torch.load(checkpoint_path))
        # Persist the final weights so a later run with retrain_flag=False can load them.
        torch.save(autoencoder.state_dict(), checkpoint_path)
    else:
        try:
            autoencoder.load_state_dict(torch.load(checkpoint_path))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_val_history)
125 |
--------------------------------------------------------------------------------
/code/model_train/dsn_ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from base_ae import BaseAE
5 | from typing import List
6 |
7 | from typing import TypeVar
8 | Tensor = TypeVar('torch.tensor')
9 |
10 |
class DSNAE(BaseAE):
    """Domain-separation autoencoder built around an injected shared encoder and decoder.

    The model owns only a *private* encoder. ``encode`` concatenates the private code
    and the shared-encoder code along dim 1. ``decode`` passes that concatenation to the
    injected ``decoder``. ``loss_function`` adds to the reconstruction MSE an
    orthogonality penalty between the two halves of the code, weighted by ``alpha``.
    """

    def __init__(self, shared_encoder, decoder, input_dim: int, latent_dim: int, alpha: float = 1.0,
                 hidden_dims: List = None, dop: float = 0.1, noise_flag: bool = False, norm_flag: bool = False,
                 **kwargs) -> None:
        """
        :param shared_encoder: module producing the shared code; ``loss_function`` splits the
            concatenated code into equal halves, so it is expected to output ``latent_dim`` features
        :param decoder: module mapping the concatenated (private + shared) code back to inputs
        :param input_dim: number of input features
        :param latent_dim: width of the private code
        :param alpha: weight of the orthogonality loss
        :param hidden_dims: hidden widths of the private encoder (default ``[32, 64, 128, 256, 512]``)
        :param dop: dropout probability used in the private encoder
        :param noise_flag: if true, add Gaussian noise (std 0.1) to inputs while training
        :param norm_flag: if true, L2-normalise each code row-wise
        """
        super(DSNAE, self).__init__()
        self.latent_dim = latent_dim
        self.alpha = alpha
        self.noise_flag = noise_flag
        self.dop = dop
        self.norm_flag = norm_flag

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        self.shared_encoder = shared_encoder
        self.decoder = decoder

        # Private encoder: Linear-ReLU-Dropout stages, an extra Dropout, then the projection to
        # latent_dim. NOTE: keep this exact module order; the Sequential indices become
        # state_dict keys, and changing them breaks loading existing checkpoints.
        modules = [
            nn.Sequential(
                nn.Linear(input_dim, hidden_dims[0], bias=True),
                nn.ReLU(),
                nn.Dropout(self.dop)
            )
        ]
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=True),
                    nn.ReLU(),
                    nn.Dropout(self.dop)
                )
            )
        modules.append(nn.Dropout(self.dop))
        modules.append(nn.Linear(hidden_dims[-1], latent_dim, bias=True))

        self.private_encoder = nn.Sequential(*modules)

    def _encode_with(self, encoder, input: Tensor) -> Tensor:
        """Apply ``encoder`` with the optional training-time input noise and output normalisation."""
        if self.noise_flag and self.training:
            input = input + torch.randn_like(input, requires_grad=False) * 0.1
        latent_code = encoder(input)
        if self.norm_flag:
            return F.normalize(latent_code, p=2, dim=1)
        return latent_code

    def p_encode(self, input: Tensor) -> Tensor:
        """Return the private code of ``input``."""
        return self._encode_with(self.private_encoder, input)

    def s_encode(self, input: Tensor) -> Tensor:
        """Return the shared code of ``input``."""
        return self._encode_with(self.shared_encoder, input)

    def encode(self, input: Tensor) -> Tensor:
        """Return ``[private code, shared code]`` concatenated along dim 1."""
        p_latent_code = self.p_encode(input)
        s_latent_code = self.s_encode(input)
        return torch.cat((p_latent_code, s_latent_code), dim=1)

    def decode(self, z: Tensor) -> Tensor:
        """Reconstruct inputs from a concatenated code ``z``."""
        return self.decoder(z)

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Return ``[input, reconstruction, z]``, where ``z`` is the concatenated code."""
        z = self.encode(input)
        return [input, self.decode(z), z]

    def loss_function(self, *args, **kwargs) -> dict:
        """Compute the reconstruction and orthogonality losses.

        :param args: ``(input, recons, z)`` as returned by ``forward``
        :return: dict with tensors ``loss``, ``recons_loss`` and ``ortho_loss``
        """
        inputs, recons, z = args[0], args[1], args[2]

        half = z.shape[1] // 2
        p_z = z[:, :half]
        s_z = z[:, half:]

        recons_loss = F.mse_loss(inputs, recons)

        # Row-wise L2 normalisation. The norms are detached, and 1e-6 guards against division by zero.
        s_l2_norm = torch.norm(s_z, p=2, dim=1, keepdim=True).detach()
        s_l2 = s_z.div(s_l2_norm.expand_as(s_z) + 1e-6)

        p_l2_norm = torch.norm(p_z, p=2, dim=1, keepdim=True).detach()
        p_l2 = p_z.div(p_l2_norm.expand_as(p_z) + 1e-6)

        # Mean of the squared entries of s_l2^T @ p_l2 (shared-dim x private-dim).
        ortho_loss = torch.mean((s_l2.t().mm(p_l2)).pow(2))

        loss = recons_loss + self.alpha * ortho_loss
        return {'loss': loss, 'recons_loss': recons_loss, 'ortho_loss': ortho_loss}

    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
        """Decode ``num_samples`` random codes drawn from a standard normal distribution.

        ``decode`` consumes the concatenated private + shared code, whose two halves are
        each ``latent_dim`` wide (see ``loss_function``'s split), so the sampled codes are
        ``2 * latent_dim`` wide. Drawing only ``latent_dim`` features made the decoder
        fail on a shape mismatch.
        """
        z = torch.randn(num_samples, 2 * self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return the reconstruction of ``x``."""
        return self.forward(x)[1]
161 |
--------------------------------------------------------------------------------
/code/model_train/train_dsn.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | from dsn_ae import DSNAE
5 | from evaluation_utils import *
6 | from mlp import MLP
7 |
8 |
def eval_dsnae_epoch(model, data_loader, device, history):
    """Evaluate ``model`` over ``data_loader`` and append the per-batch mean of each loss.

    :param model: autoencoder exposing ``__call__`` and ``loss_function``; put into eval mode
    :param data_loader: sized iterable of batches; element 0 of each batch is the input
    :param device: device the inputs are moved to
    :param history: mapping loss name -> list; one averaged value per loss is appended
    :return: the updated ``history``
    """
    model.eval()
    num_batches = len(data_loader)
    running = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            batch_losses = model.loss_function(*model(inputs))
        for name, value in batch_losses.items():
            running[name] += value.cpu().detach().item() / num_batches

    for name, mean_value in running.items():
        history[name].append(mean_value)
    return history
30 |
31 |
def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Jointly update the source and target DSN autoencoders on one batch each.

    :param s_dsnae: source-domain autoencoder (``__call__`` and ``loss_function``)
    :param t_dsnae: target-domain autoencoder (``__call__`` and ``loss_function``)
    :param s_batch: source batch; element 0 holds the input features
    :param t_batch: target batch; element 0 holds the input features
    :param device: device the inputs are moved to
    :param optimizer: optimizer stepped once with the summed loss
    :param history: mapping metric name -> list; receives the source+target sum per loss key
    :param scheduler: optional scheduler stepped right after the optimizer
    :return: the updated ``history``
    """
    for model in (s_dsnae, t_dsnae):
        model.zero_grad()
    for model in (s_dsnae, t_dsnae):
        model.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    source_losses = s_dsnae.loss_function(*s_dsnae(source_x))
    target_losses = t_dsnae.loss_function(*t_dsnae(target_x))

    optimizer.zero_grad()
    total_loss = source_losses['loss'] + target_losses['loss']
    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed = {name: value.cpu().detach().item() + target_losses[name].cpu().detach().item()
              for name, value in source_losses.items()}
    for name, value in summed.items():
        history[name].append(value)

    return history
57 |
58 |
def _cycle_batches(data_loader):
    """Yield batches from ``data_loader`` forever, starting a fresh pass whenever one ends.

    Replaces calling ``next(iter(data_loader))`` on every step. That builds a brand-new
    iterator per call (re-spawning workers / re-drawing the shuffle order). For a loader
    that does not shuffle, it also always returns the same first batch.

    :param data_loader: any re-iterable of batches (e.g. a torch DataLoader)
    :raises ValueError: if a complete pass over ``data_loader`` yields no batch
    """
    while True:
        yielded = False
        for batch in data_loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError('target dataloader yielded no batches')


def train_dsn(s_dataloaders, t_dataloaders, **kwargs):
    """Train (or reload) source/target DSN autoencoders that share one encoder and one decoder.

    :param s_dataloaders: ``(train, test)`` dataloaders for the source domain
    :param t_dataloaders: ``(train, test)`` dataloaders for the target domain
    :param kwargs: configuration; keys read here: input_dim, latent_dim, encoder_hidden_dims,
        dop, device, alpha, norm_flag, retrain_flag, lr, train_num_epochs, es_flag,
        model_save_folder
    :return: ``(shared encoder, (train_history, val_history))``
    :raises Exception: if ``retrain_flag`` is false and ``t_dsnae.pt`` cannot be found
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(device)

    # The decoder consumes the concatenated private + shared code, hence 2 * latent_dim.
    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(device)

    s_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    t_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(device)

    s_checkpoint = os.path.join(kwargs['model_save_folder'], 's_dsnae.pt')
    t_checkpoint = os.path.join(kwargs['model_save_folder'], 't_dsnae.pt')

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)

    if kwargs['retrain_flag']:
        # Shared modules appear once; each domain contributes its own private encoder.
        ae_params = [t_dsnae.private_encoder.parameters(),
                     s_dsnae.private_encoder.parameters(),
                     shared_decoder.parameters(),
                     shared_encoder.parameters()]
        ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
        t_batches = _cycle_batches(t_train_dataloader)

        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for s_batch in s_train_dataloader:
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=next(t_batches),
                                                        device=device,
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)
            dsnae_val_history = eval_dsnae_epoch(model=s_dsnae, data_loader=s_test_dataloader,
                                                 device=device, history=dsnae_val_history)
            dsnae_val_history = eval_dsnae_epoch(model=t_dsnae, data_loader=t_test_dataloader,
                                                 device=device, history=dsnae_val_history)
            # Sum the last two entries of every metric (source eval + target eval) into one.
            for k in dsnae_val_history:
                if k != 'best_index':
                    dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
                    dsnae_val_history[k].pop()

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50)
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_checkpoint)
                    torch.save(t_dsnae.state_dict(), t_checkpoint)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Roll back to the checkpoints saved by early stopping.
            s_dsnae.load_state_dict(torch.load(s_checkpoint))
            t_dsnae.load_state_dict(torch.load(t_checkpoint))

        torch.save(s_dsnae.state_dict(), s_checkpoint)
        torch.save(t_dsnae.state_dict(), t_checkpoint)
    else:
        try:
            t_dsnae.load_state_dict(torch.load(t_checkpoint))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history)
169 |
--------------------------------------------------------------------------------
/code/model_train/train_dsn_mmd.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | from dsn_ae import DSNAE
5 | from evaluation_utils import *
6 | from mlp import MLP
7 | from loss_and_metrics import mmd_loss
8 |
9 |
def eval_dsnae_epoch(model, data_loader, device, history):
    """Evaluate ``model`` over ``data_loader`` and append the per-batch mean of each loss.

    :param model: autoencoder exposing ``__call__`` and ``loss_function``; put into eval mode
    :param data_loader: sized iterable of batches; element 0 of each batch is the input
    :param device: device the inputs are moved to
    :param history: mapping loss name -> list; one averaged value per loss is appended
    :return: the updated ``history``
    """
    model.eval()
    num_batches = len(data_loader)
    running = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            batch_losses = model.loss_function(*model(inputs))
        for name, value in batch_losses.items():
            running[name] += value.cpu().detach().item() / num_batches

    for name, mean_value in running.items():
        history[name].append(mean_value)
    return history
31 |
32 |
def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Jointly update both DSN autoencoders on one batch each, adding an MMD term.

    The MMD penalty compares the *shared* codes (``s_encode``) of the source and target
    batches. It is added to the autoencoder losses of both domains.

    :param s_dsnae: source-domain autoencoder (``s_encode``, ``__call__``, ``loss_function``)
    :param t_dsnae: target-domain autoencoder (``s_encode``, ``__call__``, ``loss_function``)
    :param s_batch: source batch; element 0 holds the input features
    :param t_batch: target batch; element 0 holds the input features
    :param device: device the inputs are moved to
    :param optimizer: optimizer stepped once with the combined loss
    :param history: mapping metric name -> list; receives the source+target sum per loss key
        plus ``'mmd_loss'``
    :param scheduler: optional scheduler stepped right after the optimizer
    :return: the updated ``history``
    """
    for model in (s_dsnae, t_dsnae):
        model.zero_grad()
    for model in (s_dsnae, t_dsnae):
        model.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    # Shared-encoder features for the MMD term (computed before the reconstruction passes).
    source_shared = s_dsnae.s_encode(source_x)
    target_shared = t_dsnae.s_encode(target_x)

    source_losses = s_dsnae.loss_function(*s_dsnae(source_x))
    target_losses = t_dsnae.loss_function(*t_dsnae(target_x))

    optimizer.zero_grad()
    mmd_value = mmd_loss(source_features=source_shared, target_features=target_shared, device=device)
    total_loss = source_losses['loss'] + target_losses['loss'] + mmd_value
    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed = {name: value.cpu().detach().item() + target_losses[name].cpu().detach().item()
              for name, value in source_losses.items()}
    for name, value in summed.items():
        history[name].append(value)
    history['mmd_loss'].append(mmd_value.cpu().detach().item())

    return history
64 |
65 |
def _cycle_batches(data_loader):
    """Yield batches from ``data_loader`` forever, starting a fresh pass whenever one ends.

    Replaces calling ``next(iter(data_loader))`` on every step. That builds a brand-new
    iterator per call (re-spawning workers / re-drawing the shuffle order). For a loader
    that does not shuffle, it also always returns the same first batch.

    :param data_loader: any re-iterable of batches (e.g. a torch DataLoader)
    :raises ValueError: if a complete pass over ``data_loader`` yields no batch
    """
    while True:
        yielded = False
        for batch in data_loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError('target dataloader yielded no batches')


def train_dsn_mmd(s_dataloaders, t_dataloaders, **kwargs):
    """Train (or reload) shared-encoder DSN autoencoders with an MMD penalty on the shared code.

    :param s_dataloaders: ``(train, test)`` dataloaders for the source domain
    :param t_dataloaders: ``(train, test)`` dataloaders for the target domain
    :param kwargs: configuration; keys read here: input_dim, latent_dim, encoder_hidden_dims,
        dop, device, alpha, retrain_flag, lr, train_num_epochs, es_flag, model_save_folder
        (``norm_flag`` is fixed to False for this variant)
    :return: ``(shared encoder, (train_history, val_history))``
    :raises Exception: if ``retrain_flag`` is false and ``m_t_dsnae.pt`` cannot be found
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(device)

    # The decoder consumes the concatenated private + shared code, hence 2 * latent_dim.
    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(device)

    s_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=False).to(device)

    t_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=False).to(device)

    s_checkpoint = os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')
    t_checkpoint = os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)

    if kwargs['retrain_flag']:
        # Shared modules appear once; each domain contributes its own private encoder.
        ae_params = [t_dsnae.private_encoder.parameters(),
                     s_dsnae.private_encoder.parameters(),
                     shared_decoder.parameters(),
                     shared_encoder.parameters()]
        ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
        t_batches = _cycle_batches(t_train_dataloader)

        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for s_batch in s_train_dataloader:
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=next(t_batches),
                                                        device=device,
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)
            dsnae_val_history = eval_dsnae_epoch(model=s_dsnae, data_loader=s_test_dataloader,
                                                 device=device, history=dsnae_val_history)
            dsnae_val_history = eval_dsnae_epoch(model=t_dsnae, data_loader=t_test_dataloader,
                                                 device=device, history=dsnae_val_history)
            # Sum the last two entries of every metric (source eval + target eval) into one.
            for k in dsnae_val_history:
                if k != 'best_index':
                    dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
                    dsnae_val_history[k].pop()

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50)
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_checkpoint)
                    torch.save(t_dsnae.state_dict(), t_checkpoint)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Roll back to the checkpoints saved by early stopping.
            s_dsnae.load_state_dict(torch.load(s_checkpoint))
            t_dsnae.load_state_dict(torch.load(t_checkpoint))

        torch.save(s_dsnae.state_dict(), s_checkpoint)
        torch.save(t_dsnae.state_dict(), t_checkpoint)
    else:
        try:
            t_dsnae.load_state_dict(torch.load(t_checkpoint))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history)
173 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Drug-Transfer-Learning
2 | Pre-clinical drug discovery (PDD) suffers from low efficiency. One reason is the lack of infrastructure for evaluating cross-drug efficacy at the patient level. Here we propose the Patient Multi-Drug Learning (P-MDL) task and construct the P-MDL dataset and model zoo. The best P-MDL model, DSN-adv, achieves state-of-the-art (SOTA) performance in all 13 tumor types compared with previous SOTA models.
3 |
4 | You can also check out our [podcast](https://www.ximalaya.com/sound/938380402) and watch our intro videos (in both [Chinese](https://www.bilibili.com/video/BV1ta2LBzEJd/?spm_id_from=333.1387.homepage.video_card.click&vd_source=3d478d07006b6eaf39bdb03ccbf5ec3a) and [English](https://youtu.be/G3_kJq5isG8)) for a quick overview.
5 |
6 | ## Features
7 |
8 | ### P-MDL task
9 | Artificial intelligence (AI) models for drug response prediction (DRP) are generally classified into Single-Drug Learning (SDL) and Multi-Drug Learning (MDL) paradigms. SDL paradigms have been adapted to the patient level, but they evaluate only within-drug response and disregard tumor type. However, treatment response and survival outcomes differ substantially among tumor types, indicating that tumor type is a crucial confounding factor that cannot be overlooked when predicting drug response. Additionally, SDL paradigms cannot assess cross-drug response, while MDL paradigms are currently limited to the cell-line level. We therefore propose the P-MDL approach, which aims to provide a comprehensive view of drug response at the patient level.
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | ### P-MDL dataset
18 | We constructed the first P-MDL dataset from publicly available data. Only tumor types with sufficient data were retained; in the end, 13 tumor types were selected for the P-MDL dataset.
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | ### P-MDL model zoo
27 |
28 | P-MDL model zoo includes eight models employing different transfer learning methods:
29 | | **P-MDL models** | **Description** |
30 | | ------- | ------- |
31 | | **ae** | **Autoencoder** used to encode the gene expression profiles (GEPs) of both cell lines and patients. |
32 | | **ae-mmd** | **ae** model added with another **mmd-loss**. |
33 | | **ae-adv** | **ae** model added with another **adv-loss**. |
34 | | **dsn**| **Domain separation network**, which has been successfully applied in computer vision. Here, it is used to encode the GEPs of cell lines and patients. |
35 | | **dsn-adv**| **dsn** model added with another **adv-loss**. |
36 | | **dsrn** | A variant of the **dsn** model. |
37 | | **dsrn-mmd** | **dsrn** model added with another **mmd-loss**. |
38 | | **dsrn-adv**| **dsrn** model added with another **adv-loss**. |
39 |
40 |
49 |
50 |
51 |
52 | ### Model evaluation
53 | One of the P-MDL models (DSN-adv) outperforms all of the P-SDL and C-MDL models across all tumor types.
54 |
55 |
56 |
57 |
58 |
59 | ### Patient Drug Response (PDR) prediction for PDD
60 | To further validate the P-MDL models and demonstrate their potential in PDD applications, the test-pairwise pre-trained DSN-adv model was used to screen 233 small molecules for patients of 13 tumor types.
61 | Taking the tumor type COAD as an example, most drugs were ineffective, but a few drugs showed potential efficacy in over half of COAD patients.
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | ## Installation
70 |
71 | To support basic usage of P-MDL task, run the following commands:
72 |
73 | ```bash
74 | conda create -n P-MDL python=3.8
75 | conda activate P-MDL
76 | conda install -c conda-forge rdkit
77 |
78 | pip install torch
79 | pip install pandas
80 | pip install numpy
81 | pip install sklearn
82 | ```
83 |
84 | ## Quick Start
85 |
86 | Here, we provide instructions for a quick start.
87 |
88 |
89 |
96 |
97 | ### Step 1: Data Preparation
98 |
99 | Download the data from Zenodo [here](https://zenodo.org/record/8021167). Move the downloaded files to the ``data/`` folder.
100 |
101 | ### Step 2: Training and Evaluation
102 |
103 | Run:
104 |
105 | ```bash
106 | cd code
107 | nohup bash run_pretrain.sh 1>benchmark.txt 2>&1 &
108 | less benchmark.txt
109 | ```
110 |
111 | The ``benchmark.txt`` will look like the following:
112 | ```bash
113 | Processing next task: 8 20230718-043332
114 | All-data pre-training (ADP) model runing!
115 | Start to run Test-pairwise pre-training (TPP) model!
116 | ...
117 | ```
118 |
119 | The bash file [run_pretrain.sh](./code/run_pretrain.sh) runs the script [P_MDL.py](./code/P_MDL.py) with the appropriate parameter settings. You can also find the model log output in the ``records/`` folder and the model evaluation results in the ``results/`` folder.
120 |
121 | ### Step 3: Infer the PDR scores for PDD applications
122 |
123 | Run:
124 |
125 | ```bash
126 | cd code
127 | nohup bash run_pretrain_for_pdr.sh 1>pdr_pretrain.txt 2>&1 &
128 | less pdr_pretrain.txt
129 | ```
130 |
131 | The ``pdr_pretrain.txt`` will look like the following:
132 | ```bash
133 | Processing next task: 8 20230718-043332
134 | All-data pre-training (ADP) model runing!
135 | Start to run Test-pairwise pre-training (TPP) model!
136 | ...
137 | ```
138 |
139 | For PDD applications, we need to predict the efficacy of all drugs for every patient. Therefore, we set the parameter ``--select_drug_method all`` in the bash file [run_pretrain_for_pdr.sh](./code/run_pretrain_for_pdr.sh) to recover a model that can be used for all-drug response prediction.
140 |
141 | Then you can run the bash file [run_pdr_task.sh](./code/run_pdr_task.sh) by ```nohup bash run_pdr_task.sh 1>pdr_task.txt 2>&1 &```, which calls the script [PDR_task.py](./code/PDR_task.py) to predict the efficacy of all drugs for every patient.
142 |
143 |
144 |
145 | ## Contact Us
146 |
147 | As this is a pre-alpha release, we look forward to user feedback to help us improve our framework. If you have any questions or suggestions, please open an issue or contact [wuyushuai0727@gmail.com](mailto:wuyushuai0727@gmail.com).
148 |
149 |
150 | ## Cite Us
151 |
152 | If you find our open-sourced code & models helpful to your research, please consider giving this repo a star🌟 and citing📑 the following article. Thank you for your support!
153 | ```
154 | @misc{P_MDL_code,
155 | author={Yushuai Wu},
156 | title={Code of Multi-Drug-Transfer-Learning},
157 | year={2025},
158 | howpublished={\url{https://github.com/wuys13/Multi-Drug-Transfer-Learning.git}}
159 | }
160 | ```
161 |
162 | ## Contributing
163 |
164 | If you encounter problems, feel free to create an issue!
--------------------------------------------------------------------------------
/code/model_train/train_dsrn_mmd.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | from dsn_ae import DSNAE
5 | from evaluation_utils import *
6 | from mlp import MLP
7 | from loss_and_metrics import mmd_loss
8 | from collections import OrderedDict
9 |
def eval_dsnae_epoch(model, data_loader, device, history):
    """Evaluate ``model`` over ``data_loader`` and append the per-batch mean of each loss.

    :param model: autoencoder exposing ``__call__`` and ``loss_function``; put into eval mode
    :param data_loader: sized iterable of batches; element 0 of each batch is the input
    :param device: device the inputs are moved to
    :param history: mapping loss name -> list; one averaged value per loss is appended
    :return: the updated ``history``
    """
    model.eval()
    num_batches = len(data_loader)
    running = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            batch_losses = model.loss_function(*model(inputs))
        for name, value in batch_losses.items():
            running[name] += value.cpu().detach().item() / num_batches

    for name, mean_value in running.items():
        history[name].append(mean_value)
    return history
31 |
32 |
def dsn_ae_train_step(s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Jointly update both DSN autoencoders on one batch each, adding an MMD term.

    This variant applies the MMD penalty to the *full* concatenated code returned by
    ``encode`` (private + shared), not just the shared part.

    :param s_dsnae: source-domain autoencoder (``encode``, ``__call__``, ``loss_function``)
    :param t_dsnae: target-domain autoencoder (``encode``, ``__call__``, ``loss_function``)
    :param s_batch: source batch; element 0 holds the input features
    :param t_batch: target batch; element 0 holds the input features
    :param device: device the inputs are moved to
    :param optimizer: optimizer stepped once with the combined loss
    :param history: mapping metric name -> list; receives the source+target sum per loss key
        plus ``'mmd_loss'``
    :param scheduler: optional scheduler stepped right after the optimizer
    :return: the updated ``history``
    """
    for model in (s_dsnae, t_dsnae):
        model.zero_grad()
    for model in (s_dsnae, t_dsnae):
        model.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    # Full codes for the MMD term (computed before the reconstruction passes).
    source_code = s_dsnae.encode(source_x)
    target_code = t_dsnae.encode(target_x)

    source_losses = s_dsnae.loss_function(*s_dsnae(source_x))
    target_losses = t_dsnae.loss_function(*t_dsnae(target_x))

    optimizer.zero_grad()
    mmd_value = mmd_loss(source_features=source_code, target_features=target_code, device=device)
    total_loss = source_losses['loss'] + target_losses['loss'] + mmd_value
    total_loss.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed = {name: value.cpu().detach().item() + target_losses[name].cpu().detach().item()
              for name, value in source_losses.items()}
    for name, value in summed.items():
        history[name].append(value)
    history['mmd_loss'].append(mmd_value.cpu().detach().item())

    return history
63 |
64 |
def _merge_source_target_val_entries(history):
    """Sum the last two entries of every loss list in ``history`` into one.

    Each validation pass appends a source-domain entry and then a
    target-domain entry per loss key. Folding them keeps one combined value
    per epoch. The integer ``'best_index'`` written by ``model_save_check``
    is skipped.
    """
    for key, values in history.items():
        if key != 'best_index':
            values[-2] += values[-1]
            values.pop()


def train_dsrn_mmd(s_dataloaders, t_dataloaders, **kwargs):
    """Train source/target DSN autoencoders sharing encoder and decoder, with an MMD penalty.

    Both ``DSNAE`` models are built on the same ``shared_encoder`` and
    ``shared_decoder``; each has its own ``private_encoder``. Every training
    step minimises both autoencoder losses plus the MMD between their codes
    (see ``dsn_ae_train_step``).

    :param s_dataloaders: ``(train, test)`` dataloaders for the source domain.
    :param t_dataloaders: ``(train, test)`` dataloaders for the target domain.
    :param kwargs: hyper-parameters; reads ``input_dim``, ``latent_dim``,
        ``encoder_hidden_dims``, ``dop``, ``alpha``, ``norm_flag``,
        ``device``, ``retrain_flag``, ``lr``, ``train_num_epochs``,
        ``es_flag`` and ``model_save_folder``.
    :return: ``(t_dsnae.shared_encoder, (train_history, val_history))``.
    :raises Exception: ``"No pre-trained encoder"`` when ``retrain_flag`` is
        false and ``m_t_dsnae.pt`` is missing (chained to the
        ``FileNotFoundError``).
    """
    device = kwargs['device']
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(device)

    # Decoder input is 2 * latent_dim; presumably shared and private codes
    # are concatenated inside DSNAE -- confirm in dsn_ae.py.
    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(device)

    dsnae_kwargs = dict(shared_encoder=shared_encoder,
                        decoder=shared_decoder,
                        alpha=kwargs['alpha'],
                        input_dim=kwargs['input_dim'],
                        latent_dim=kwargs['latent_dim'],
                        hidden_dims=kwargs['encoder_hidden_dims'],
                        dop=kwargs['dop'],
                        norm_flag=kwargs['norm_flag'])
    s_dsnae = DSNAE(**dsnae_kwargs).to(device)
    t_dsnae = DSNAE(**dsnae_kwargs).to(device)

    s_model_path = os.path.join(kwargs['model_save_folder'], 'm_s_dsnae.pt')
    t_model_path = os.path.join(kwargs['model_save_folder'], 'm_t_dsnae.pt')

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)

    if kwargs['retrain_flag']:
        # Shared modules are listed once, so their parameters are optimised once.
        ae_params = [t_dsnae.private_encoder.parameters(),
                     s_dsnae.private_encoder.parameters(),
                     shared_decoder.parameters(),
                     shared_encoder.parameters()]
        ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])

        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for s_batch in s_train_dataloader:
                # NOTE(review): a fresh iterator is created every step, so this
                # is always the first batch of a new pass over the target
                # loader (a different batch only if that loader shuffles).
                # Kept as-is to preserve training behaviour.
                t_batch = next(iter(t_train_dataloader))
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=t_batch,
                                                        device=device,
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)

            for model, loader in ((s_dsnae, s_test_dataloader), (t_dsnae, t_test_dataloader)):
                dsnae_val_history = eval_dsnae_epoch(model=model,
                                                     data_loader=loader,
                                                     device=device,
                                                     history=dsnae_val_history)
            _merge_source_target_val_entries(dsnae_val_history)

            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=50)
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_model_path)
                    torch.save(t_dsnae.state_dict(), t_model_path)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            # Restore the best checkpoint selected by early stopping.
            s_dsnae.load_state_dict(torch.load(s_model_path))
            t_dsnae.load_state_dict(torch.load(t_model_path))

        torch.save(s_dsnae.state_dict(), s_model_path)
        torch.save(t_dsnae.state_dict(), t_model_path)

    else:
        try:
            t_state = torch.load(t_model_path)
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err
        t_dsnae.load_state_dict(t_state)

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history)
188 |
--------------------------------------------------------------------------------
/code/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 |
5 | from collections import defaultdict
6 | from torch import nn
7 | from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, \
8 | log_loss, auc, precision_recall_curve
9 | from scipy.stats import pearsonr #pearsonr(self.y_val, y_pred_val[:,0])[0]
10 |
11 |
12 |
def auprc(y_true, y_score):
    """Return the area under the precision-recall curve.

    The curve from ``precision_recall_curve`` is integrated with
    ``auc`` (trapezoidal rule), so the value can differ slightly from
    ``average_precision_score``.

    :param y_true: binary ground-truth labels.
    :param y_score: scores or probabilities for the positive class.
    :return: AUPRC as a float.
    """
    # Pass the scores positionally: scikit-learn 1.5 renamed the keyword
    # ``probas_pred`` to ``y_score`` and deprecated the old name, so the
    # positional form works across versions.
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return auc(recall, precision)
16 |
17 |
def model_save_check(history, metric_name, tolerance_count=5, reset_count=1):
    """Decide whether to checkpoint the model and whether to stop early.

    The newest value of ``history[metric_name]`` is compared with the value at
    ``history['best_index']`` (initialised to 0 on the first call). Metrics
    whose name ends in ``'loss'`` are minimised and all others maximised; a
    tie counts as an improvement. On improvement ``best_index`` moves to the
    newest entry.

    Stopping is signalled once more than ``tolerance_count * reset_count``
    entries exist past the best one, but never while the best entry is the
    first (index 0).

    :param history: mapping holding the metric list; ``'best_index'`` is
        stored in it.
    :param metric_name: key of the metric list to inspect.
    :param tolerance_count: patience per reset.
    :param reset_count: multiplier on the patience.
    :return: ``(save_flag, stop_flag)``.
    """
    history.setdefault('best_index', 0)
    series = history[metric_name]
    best_index = history['best_index']

    newest, best_so_far = series[-1], series[best_index]
    if metric_name.endswith('loss'):
        improved = newest <= best_so_far
    else:
        improved = newest >= best_so_far

    save_flag = False
    if improved:
        save_flag = True
        best_index = len(series) - 1
        history['best_index'] = best_index

    stop_flag = False
    if len(series) - best_index > tolerance_count * reset_count and best_index > 0:
        stop_flag = True
        print(f'The best epoch: {best_index} / {len(series) - 1}')
        print(f'Metric from first to stop: {series[0]} to {series[-1]}')

    return save_flag, stop_flag
40 |
41 |
def eval_ae_epoch(model, data_loader, device, history):
    """Append the epoch mean of every autoencoder loss term to ``history``.

    Each batch's first element is moved to ``device`` and passed through
    ``model``. Under ``torch.no_grad()``, ``model.loss_function`` is applied
    to the forward outputs. Each term is divided by ``len(data_loader)`` and
    accumulated, giving the mean over batches.

    :param model: object exposing ``eval()``, ``__call__`` and ``loss_function``.
    :param data_loader: iterable of batches supporting ``len()``.
    :param device: device the inputs are moved to.
    :param history: mapping term name -> list; one value per term is appended.
    :return: ``history`` (mutated in place).
    """
    model.eval()
    running = defaultdict(float)
    for batch in data_loader:
        inputs = batch[0].to(device)
        with torch.no_grad():
            terms = model.loss_function(*(model(inputs)))
            for term_name, term_value in terms.items():
                running[term_name] += term_value.cpu().detach().item() / len(data_loader)

    for term_name in running:
        history[term_name].append(running[term_name])
    return history
55 |
def _collect_predictions(classifier, dataloader, device, transform):
    """Run ``classifier`` over ``(x_gex, x_smiles, y)`` batches without gradients.

    :param transform: callable applied to the raw classifier output tensor.
    :return: ``(y_truths, y_preds)`` as float64 numpy arrays. ``y_truths`` is
        flattened; ``y_preds`` is concatenated along the batch axis, keeping
        the per-sample output dimensions. Both are empty 1-D arrays when the
        loader yields nothing.
    """
    truths, preds = [], []
    with torch.no_grad():
        for x_gex, x_smiles, y_batch in dataloader:
            x_gex = x_gex.to(device)
            x_smiles = x_smiles.to(device)
            truths.append(y_batch.detach().cpu().numpy().ravel())
            output = transform(classifier(x_smiles, x_gex))
            preds.append(output.detach().cpu().numpy().astype(np.float64))
    if not preds:
        return np.array([]), np.array([])
    return np.concatenate(truths).astype(np.float64), np.concatenate(preds, axis=0)


def evaluate_target_classification_epoch(classifier, dataloader, device, history, class_num, test_flag=False):
    """Evaluate the drug-response model on one dataloader and append metrics.

    The dataloader yields ``(x_gex, x_smiles, y)`` batches and the model is
    called as ``classifier(x_smiles, x_gex)``.

    :param class_num: output head type.
        ``0``: single logit -> sigmoid probability; binary metrics.
        ``1``: regression output; Pearson r, or with ``test_flag`` the AUROC
        of ``1 - prediction``.
        ``2``: two-way softmax; the class-1 probability is scored.
        ``>2``: multi-way softmax (see NOTE in the metrics section).
    :param test_flag: True when evaluating the zero-shot (patient) test set.
    :param history: mapping metric name -> list; metrics are appended.
    :return: ``history``.
    :raises Exception: ``"class_num error"`` for a negative ``class_num``.
    """
    classifier.eval()

    if class_num == 0:
        y_truths, y_preds = _collect_predictions(classifier, dataloader, device, torch.sigmoid)
        y_preds = y_preds.ravel()
    elif class_num == 1:
        y_truths, y_preds = _collect_predictions(classifier, dataloader, device, lambda logits: logits)
        # Flatten the (N, 1) regression output so pearsonr / roc_auc_score
        # receive a 1-D vector matching y_truths (a 2-D column makes pearsonr
        # broadcast instead of correlating two vectors).
        y_preds = y_preds.ravel()
    else:
        y_truths, y_preds = _collect_predictions(
            classifier, dataloader, device, lambda logits: nn.functional.softmax(logits, dim=1))
        if class_num == 2:
            y_preds = y_preds[:, 1]
        elif class_num > 2:
            if test_flag:
                # Zero-shot: map the arg-max class index through a sigmoid
                # into (0, 1) so the binary metrics below can be computed.
                _, max_class = torch.max(torch.from_numpy(y_preds), -1)
                y_preds = torch.sigmoid(max_class.float()).numpy()
        else:
            raise Exception("class_num error")

    if class_num == 1:
        if test_flag:
            # A larger regressed value is treated as a worse response, so
            # 1 - prediction is the positive-class score (AUROC only depends
            # on the ranking, so this is a simple sign flip).
            history['auroc'].append(roc_auc_score(y_true=y_truths, y_score=1 - y_preds))
        else:
            history['auroc'].append(pearsonr(y_truths, y_preds)[0])
    else:
        # NOTE(review): for class_num > 2 without test_flag, y_preds is 2-D;
        # only the 'ovo' AUROC supports that and the remaining binary metrics
        # will raise.
        history['auroc'].append(roc_auc_score(y_true=y_truths, y_score=y_preds, multi_class='ovo'))
        history['auprc'].append(auprc(y_true=y_truths, y_score=y_preds))
        history['acc'].append(accuracy_score(y_true=y_truths, y_pred=(y_preds > 0.5).astype('int')))
        history['f1'].append(f1_score(y_true=y_truths, y_pred=(y_preds > 0.5).astype('int')))
        history['aps'].append(average_precision_score(y_true=y_truths, y_score=y_preds))
        history['ce'].append(log_loss(y_true=y_truths, y_pred=y_preds))

    return history
141 |
142 |
def evaluate_adv_classification_epoch(classifier, s_dataloader, t_dataloader, device, history):
    """Score a source-vs-target domain classifier and append binary metrics.

    Source samples are labelled 0 and target samples 1. The classifier is
    applied to the first element of each batch and its outputs are mapped
    through a sigmoid.

    :return: ``history`` with ``acc``, ``auroc``, ``aps``, ``f1``, ``bce`` and
        ``auprc`` appended (in that order).
    """
    classifier.eval()

    # Seed with empty float64 arrays so the concatenated result is float64.
    label_chunks = [np.array([])]
    score_chunks = [np.array([])]
    for domain_label, loader in ((0, s_dataloader), (1, t_dataloader)):
        label_fill = np.ones if domain_label else np.zeros
        for batch in loader:
            inputs = batch[0].to(device)
            with torch.no_grad():
                label_chunks.append(label_fill(inputs.shape[0]).ravel())
                scores = torch.sigmoid(classifier(inputs)).detach()
                score_chunks.append(scores.cpu().detach().numpy().ravel())

    y_truths = np.concatenate(label_chunks)
    y_preds = np.concatenate(score_chunks)
    hard_preds = (y_preds > 0.5).astype('int')

    history['acc'].append(accuracy_score(y_true=y_truths, y_pred=hard_preds))
    history['auroc'].append(roc_auc_score(y_true=y_truths, y_score=y_preds))
    history['aps'].append(average_precision_score(y_true=y_truths, y_score=y_preds))
    history['f1'].append(f1_score(y_true=y_truths, y_pred=hard_preds))
    history['bce'].append(log_loss(y_true=y_truths, y_pred=y_preds))
    history['auprc'].append(auprc(y_true=y_truths, y_score=y_preds))

    return history
170 |
171 |
172 | # patient drug score matrix
def predict_pdr_score(classifier, pdr_dataloader, device,
                      drug_list_path='../data/preprocessed_dat/drug_embedding/CCL_dataset/drug_smiles.csv'):
    """Build a patient x drug matrix of predicted response probabilities.

    :param classifier: model called as ``classifier(x_smiles, x_gex)`` that
        returns logits; a sigmoid turns them into probabilities.
    :param pdr_dataloader: ``(dataloader, patient_index)``. The dataloader
        yields ``(x_gex, x_smiles)`` batches; ``patient_index`` labels the rows.
    :param device: device the inputs are moved to.
    :param drug_list_path: CSV with a ``Drug_name`` column used for the
        column labels. The default is the previously hard-coded path,
        relative to the working directory.
    :return: ``pandas.DataFrame`` indexed by ``patient_index`` with one column
        per drug.

    Predictions are reshaped row-major to ``(len(patient_index), -1)``, so the
    loader must yield every drug for the first patient, then every drug for
    the next patient, and so on. Presumably the drugs come in ``Drug_name``
    order -- confirm against the dataset builder.
    """
    test_dataloader, patient_index = pdr_dataloader
    classifier.eval()

    # Collect per-batch arrays and concatenate once; concatenating inside the
    # loop re-copied all previous predictions on every batch. The empty
    # float64 seed keeps the result float64.
    batch_preds = [np.array([])]
    with torch.no_grad():
        for x_gex, x_smiles in test_dataloader:
            x_gex = x_gex.to(device)
            x_smiles = x_smiles.to(device)
            probs = torch.sigmoid(classifier(x_smiles, x_gex))
            batch_preds.append(probs.detach().cpu().numpy().ravel())
    y_preds = np.concatenate(batch_preds)

    drug_list = pd.read_csv(drug_list_path)
    y_preds = y_preds.reshape(len(patient_index), -1)
    return pd.DataFrame(y_preds, index=patient_index, columns=drug_list['Drug_name'])
195 |
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``; element-wise for numpy arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
198 |
--------------------------------------------------------------------------------
/code/P_MDL.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | import torch
5 | import json
6 | import os
7 | import argparse
8 | import random
9 | import pickle
10 | import itertools
11 | import data
12 | import fine_tuning
13 |
14 | # add model_train path
15 | import sys
16 | sys.path.append('./model_train')
17 |
18 | import train_ae
19 | import train_ae_mmd
20 | import train_ae_adv
21 |
22 | import train_dsn
23 | import train_dsn_mmd
24 | import train_dsn_adv
25 |
26 | import train_dsrn_mmd
27 | import train_dsrn_adv
28 |
29 | from copy import deepcopy
30 | from collections import defaultdict
31 |
32 |
def dict_to_str(d):
    """Flatten a dict into ``key_value`` pieces joined by underscores.

    Example: ``{'dop': 0.1, 'train_num_epochs': 100}`` ->
    ``'dop_0.1_train_num_epochs_100'``. Keys must be strings; values are
    converted with ``str``.
    """
    pieces = []
    for key, value in d.items():
        pieces.append(key + "_" + str(value))
    return "_".join(pieces)
36 |
def wrap_training_params(training_params, type='unlabeled'):
    """Return a flat kwargs dict for one training stage.

    Every top-level entry except the ``'unlabeled'`` / ``'labeled'`` sections
    is kept. The entries of ``training_params[type]`` are then layered on top,
    overriding shared keys. ``training_params`` itself is not modified.

    :param training_params: parameter dict with ``'unlabeled'`` and
        ``'labeled'`` sub-dicts.
    :param type: which stage section to merge in (name kept for callers).
    :return: new merged dict.
    """
    shared = {key: value for key, value in training_params.items()
              if key not in ('unlabeled', 'labeled')}
    return dict(shared, **training_params[type])
42 |
def safe_make_dir(new_folder_name):
    """Create ``new_folder_name`` (and any missing parents); report if it exists.

    The directory is created first and ``FileExistsError`` is handled
    afterwards, instead of checking existence beforehand. This way two runs
    creating the same folder at the same time cannot crash between the check
    and the creation.
    """
    try:
        os.makedirs(new_folder_name)
    except FileExistsError:
        print(new_folder_name, 'exists!')
48 |
49 |
50 |
def main(args, update_params_dict):
    """Pre-train an encoder with ``args.method``, then fine-tune it on every CV fold.

    Reads the module-level globals ``tumor_type`` and ``gex_features_df``,
    which the ``__main__`` section assigns before calling this function.

    Steps:
    1. Load ``params/train_params.json`` and merge ``update_params_dict``
       into its ``'unlabeled'`` section.
    2. Pre-train on the unlabeled source/target expression data.
    3. Fine-tune a drug-response classifier on each fold yielded by
       ``data.get_finetune_dataloader_generator``.
    4. Write the per-fold best validation index and the test ``auroc`` at that
       index to ``<task_save_folder>/<param_str>_ft_evaluation_results.json``.

    :param args: parsed command-line arguments.
    :param update_params_dict: one point of the hyper-parameter grid.
    :raises NotImplementedError: unsupported ``args.method``.
    :raises ValueError: ``args.select_drug_method`` other than ``'overlap'``
        or ``'all'``; no fine-tuned model save location is defined for it.
    """
    train_fns = {
        'ae': train_ae.train_ae,
        'ae_mmd': train_ae_mmd.train_ae_mmd,
        'ae_adv': train_ae_adv.train_ae_adv,
        'dsn': train_dsn.train_dsn,
        'dsn_mmd': train_dsn_mmd.train_dsn_mmd,
        'dsn_adv': train_dsn_adv.train_dsn_adv,
        'dsrn_mmd': train_dsrn_mmd.train_dsrn_mmd,
        'dsrn_adv': train_dsrn_adv.train_dsrn_adv,
    }
    if args.method not in train_fns:
        raise NotImplementedError("Not true method supplied!")
    train_fn = train_fns[args.method]

    # Validate before any pre-training: only these two modes define where the
    # fine-tuned classifiers are stored.
    if args.select_drug_method not in ('overlap', 'all'):
        raise ValueError(f"No fine-tuned model save location is defined for "
                         f"select_drug_method={args.select_drug_method!r}; expected 'overlap' or 'all'.")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open('params/train_params.json', 'r') as f:
        training_params = json.load(f)

    training_params['unlabeled'].update(update_params_dict)
    param_str = dict_to_str(update_params_dict)

    source_dir = os.path.join(args.pretrain_dataset,
                              tumor_type,
                              args.select_drug_method)

    store_dir = os.path.join('../results', args.store_dir)
    # The folder name always carries the '_norm' suffix, independent of args.norm_flag.
    method_save_folder = os.path.join(store_dir, f'{args.method}_norm', source_dir)
    task_save_folder = os.path.join(method_save_folder, args.measurement)
    safe_make_dir(task_save_folder)  # unlabeled result save dir

    training_params.update(
        {
            'device': device,
            'input_dim': gex_features_df.shape[-1],
            'model_save_folder': os.path.join(method_save_folder, param_str),
            'es_flag': False,
            'retrain_flag': args.retrain_flag,
            'norm_flag': args.norm_flag
        })

    safe_make_dir(training_params['model_save_folder'])

    if args.select_drug_method == "overlap":
        save_folder = method_save_folder  # only one fine-tuned model saved, for benchmark
    else:
        save_folder = training_params['model_save_folder']  # all fine-tuned models saved, for PDR

    random.seed(2020)

    s_dataloaders, t_dataloaders = data.get_unlabeled_dataloaders(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['unlabeled']['batch_size'],
        ccle_only=False
    )

    encoder, historys = train_fn(s_dataloaders=s_dataloaders,
                                 t_dataloaders=t_dataloaders,
                                 **wrap_training_params(training_params, type='unlabeled'))
    print(' ')

    ft_evaluation_metrics = defaultdict(list)

    # Fine-tuning dataset
    labeled_dataloader_generator = data.get_finetune_dataloader_generator(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['labeled']['batch_size'],
        dataset=args.CCL_dataset,  # fine-tuning dataset
        ccle_measurement=args.measurement,
        n_splits=args.n,
        q=2,
        tumor_type=tumor_type,
        label_type=args.label_type,
        select_drug_method=args.select_drug_method)

    for fold_count, (train_labeled_ccle_dataloader, test_labeled_ccle_dataloader,
                     labeled_tcga_dataloader) in enumerate(labeled_dataloader_generator):
        ft_encoder = deepcopy(encoder)
        print(' ')
        print('Fold count = {}'.format(fold_count))
        target_classifier, ft_historys = fine_tuning.fine_tune_encoder_drug(
            encoder=ft_encoder,
            train_dataloader=train_labeled_ccle_dataloader,
            val_dataloader=test_labeled_ccle_dataloader,
            test_dataloader=labeled_tcga_dataloader,
            fold_count=fold_count,
            normalize_flag=args.norm_flag,
            metric_name=args.metric,
            task_save_folder=save_folder,
            drug_emb_dim=300,
            class_num=args.class_num,
            store_dir=method_save_folder,
            **wrap_training_params(training_params, type='labeled')
        )
        if fold_count == 0:
            # Fold-0 classifier is reused by fine_tune_encoder_drug as a warm
            # start for later folds.
            print(f"Save target_classifier {update_params_dict}.")
            torch.save(target_classifier.state_dict(),
                       os.path.join(method_save_folder, f'save_classifier_{fold_count}.pt'))
        # ft_historys[-2] is the validation history, ft_historys[-1] the test history.
        ft_evaluation_metrics['best_index'].append(ft_historys[-2]['best_index'])
        for metric in ['auroc']:
            ft_evaluation_metrics[metric].append(ft_historys[-1][metric][ft_historys[-2]['best_index']])
        print(' ')
        print(' ')

    with open(os.path.join(task_save_folder, f'{param_str}_ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
175 |
176 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser('P-MDL training and evaluation')
    # Optional integer shortcuts; when given they override --pretrain_dataset,
    # --tumor_type and --method by position in the lists defined below.
    parser.add_argument('--pretrain_num', default=None, type=int)
    parser.add_argument('--zero_shot_num', default=None, type=int)
    parser.add_argument('--method_num', default=None, type=int)

    parser.add_argument('--method', dest='method', nargs='?', default='dsn_adv',
                        choices=['ae', 'ae_mmd', 'ae_adv',
                                 'dsn', 'dsn_mmd', 'dsn_adv',
                                 'dsrn_mmd', 'dsrn_adv'])
    parser.add_argument('--metric', dest='metric', nargs='?', default='auroc', choices=['auroc', 'auprc'])

    parser.add_argument('--measurement', dest='measurement', nargs='?', default='AUC', choices=['AUC', 'LN_IC50'])

    parser.add_argument('--n', dest='n', nargs='?', type=int, default=5)

    train_group = parser.add_mutually_exclusive_group(required=False)
    train_group.add_argument('--train', dest='retrain_flag', action='store_true')
    train_group.add_argument('--no-train', dest='retrain_flag', action='store_false')
    parser.set_defaults(retrain_flag=True)

    norm_group = parser.add_mutually_exclusive_group(required=False)
    norm_group.add_argument('--norm', dest='norm_flag', action='store_true')
    norm_group.add_argument('--no-norm', dest='norm_flag', action='store_false')
    parser.set_defaults(norm_flag=True)

    parser.add_argument('--label_type', default="PFS", choices=["PFS", "Imaging"])

    parser.add_argument('--select_drug_method', default="overlap", choices=["overlap", "all", "random"])

    parser.add_argument('--store_dir', default="benchmark")
    parser.add_argument('--select_gene_method', default="Percent_sd", choices=["Percent_sd", "HVG"])
    parser.add_argument('--select_gene_num', default=1000, type=int)

    parser.add_argument('--pretrain_dataset', default="tcga",
                        choices=["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                                 "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                                 ])
    parser.add_argument('--tumor_type', default="BRCA",
                        choices=['TCGA', 'GBM', 'LGG', 'HNSC', 'KIRC', 'SARC', 'BRCA', 'STAD', 'CESC', 'SKCM',
                                 'LUSC', 'LUAD', 'READ', 'COAD'
                                 ])
    parser.add_argument('--CCL_dataset', default='gdsc1_raw',
                        choices=['gdsc1_raw', 'gdsc1_rebalance'])
    parser.add_argument('--class_num', default=0, type=int)

    args = parser.parse_args()
    if args.class_num == 1:
        print("Regression task. Use dataset:", args.CCL_dataset)

    params_grid = {
        "pretrain_num_epochs": [0, 100, 300],
        "train_num_epochs": [100, 200, 300, 500, 750, 1000, 1500, 2000, 2500, 3000],
        "dop": [0.0, 0.1]
    }

    Tumor_type_list = ["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                       "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                       ]
    # Compare with None: index 0 is a valid selection and must not be
    # mistaken for "not given". (method index 0, 'no', is not implemented
    # by main() and is rejected there.)
    if args.pretrain_num is not None:
        args.pretrain_dataset = Tumor_type_list[args.pretrain_num]
    if args.zero_shot_num is not None:
        args.tumor_type = [element.upper() for element in Tumor_type_list][args.zero_shot_num]
    if args.method_num is not None:
        args.method = [
            'no', 'ae', 'dsn',
            'ae_mmd', 'dsrn_mmd', 'dsn_mmd',
            'ae_adv', 'dsrn_adv', 'dsn_adv'][args.method_num]

    # Pairwise test: pre-training and zero-shot test use the same tumor type
    # by default; only the pan-cancer 'tcga' pre-training takes --tumor_type.
    tumor_type = args.pretrain_dataset.upper()
    if tumor_type == "TCGA":
        tumor_type = args.tumor_type

    # Only the adversarial methods use pretrain_num_epochs.
    if args.method not in ['dsrn_adv', 'dsn_adv', 'ae_adv']:
        params_grid.pop('pretrain_num_epochs')

    keys, values = zip(*params_grid.items())
    update_params_dict_list = [dict(zip(keys, v)) for v in itertools.product(*values)]

    gex_features_df = pd.read_csv(f'../data/pretrain_data/{args.pretrain_dataset}_pretrain_dataset.csv', index_col=0)
    CCL_tumor_type = 'all_CCL'

    print(f'Pretrain dataset: Patient({args.pretrain_dataset} {args.pretrain_num}) CCL( {CCL_tumor_type}). Select_gene_method: {args.select_gene_method}')
    print(f'Zero-shot dataset: {tumor_type}({args.zero_shot_num})')
    print(f'CCL_dataset: {args.CCL_dataset} Select_drug_method: {args.select_drug_method}')
    print(f'Store_dir: {args.store_dir} ')

    print(f'method: {args.method}({args.method_num}). label_type: {args.label_type}')

    for param_num, param_dict in enumerate(update_params_dict_list, start=1):
        print(' ')
        print('##############################################################################')
        print(f'####### Param_num {param_num}/{len(update_params_dict_list)} #######')
        print('Param_dict: {}'.format(param_dict))
        print('##############################################################################')
        main(args=args, update_params_dict=param_dict)
    print("Finsh All !!!!!!!!!!!!!!!")
283 |
--------------------------------------------------------------------------------
/code/fine_tuning.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./model_train')
3 |
4 | from evaluation_utils import evaluate_target_classification_epoch,model_save_check,predict_pdr_score
5 | from collections import defaultdict
6 | from itertools import chain
7 | from mlp import MLP
8 | from encoder_decoder import EncoderDecoder
9 | import os
10 | import torch
11 | import torch.nn as nn
12 |
13 |
def fine_tune_encoder_drug(encoder, train_dataloader, val_dataloader, fold_count, store_dir, task_save_folder,
                           test_dataloader=None,
                           metric_name='auroc',
                           class_num=2,
                           normalize_flag=False,
                           break_flag=False,
                           test_df=None,
                           drug_emb_dim=128,
                           to_roughly_test=False,
                           **kwargs):
    """Fine-tune a drug-response head on ``encoder`` with gradual unfreezing.

    A decoder MLP is wrapped together with ``encoder`` in an
    ``EncoderDecoder``. Training starts with only the decoder's parameters in
    the optimizer. Whenever ``model_save_check`` reports that the validation
    ``metric_name`` has stalled:
    - the best checkpoint is reloaded;
    - the next encoder ``Linear`` layer (last one first) is added to the
      optimizer, on top of everything already trainable;
    - the learning rate is multiplied by ``decay_coefficient``.
    Training ends when no encoder ``Linear`` layer is left to unfreeze or after
    ``train_num_epochs`` epochs. The best checkpoint is then reloaded.

    :param encoder: pre-trained expression encoder.
    :param train_dataloader: training batches ``(x_gex, x_smiles, y)``.
    :param val_dataloader: validation batches used for checkpoint selection.
    :param fold_count: fold number, used in checkpoint file names.
    :param store_dir: folder that may hold ``save_classifier_0.pt``; for
        ``fold_count > 0`` it is loaded (if present) as a warm start.
    :param task_save_folder: folder for ``target_classifier_<fold>.pt``.
    :param test_dataloader: optional zero-shot test batches, evaluated every
        epoch with ``test_flag=True``.
    :param metric_name: validation metric used for checkpointing and stalls.
    :param class_num: 0 = single-logit binary (BCEWithLogits), 1 = regression
        (MSE), >= 2 = ``class_num``-way classification (cross-entropy).
    :param normalize_flag: forwarded to ``EncoderDecoder``.
    :param break_flag: unused; kept for backward compatibility.
    :param test_df: unused; kept for backward compatibility.
    :param drug_emb_dim: added to ``latent_dim`` to size the decoder input.
    :param to_roughly_test: if a warm-start classifier was loaded, stop right
        after the first checkpoint (quick zero-shot check).
    :param kwargs: reads ``device``, ``latent_dim``, ``classifier_hidden_dims``,
        ``lr``, ``train_num_epochs`` and ``decay_coefficient``.
    :return: ``(target_classifier, (train_loss_history, eval_train_history,
        eval_val_history, eval_test_history))``.
    """
    device = kwargs['device']
    finetune_output_dim = 1 if class_num == 0 else class_num
    target_decoder = MLP(input_dim=kwargs['latent_dim'] + drug_emb_dim,
                         output_dim=finetune_output_dim,
                         hidden_dims=kwargs['classifier_hidden_dims']).to(device)

    target_classifier = EncoderDecoder(encoder=encoder, decoder=target_decoder,
                                       normalize_flag=normalize_flag).to(device)

    # Warm start later folds from the fold-0 classifier so they re-train faster.
    target_classifier_file = os.path.join(store_dir, 'save_classifier_0.pt')
    if os.path.exists(target_classifier_file) and fold_count > 0:
        print("Loading ", target_classifier_file)
        target_classifier.load_state_dict(torch.load(target_classifier_file))
    else:
        to_roughly_test = False
        print("No target_classifier_file stored for use. Generate first!")

    if class_num == 0:
        classification_loss = nn.BCEWithLogitsLoss()
    elif class_num == 1:
        classification_loss = nn.MSELoss()
    else:
        classification_loss = nn.CrossEntropyLoss()

    target_classification_train_history = defaultdict(list)
    target_classification_eval_train_history = defaultdict(list)
    target_classification_eval_val_history = defaultdict(list)
    target_classification_eval_test_history = defaultdict(list)

    checkpoint_path = os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt')

    # Positions (in encoder.modules() order) of the encoder's Linear layers.
    # They are popped from the end, so unfreezing proceeds from the last
    # layer backwards.
    encoder_modules = list(encoder.modules())
    encoder_module_indices = [i for i, module in enumerate(encoder_modules)
                              if str(module).startswith('Linear')]

    reset_count = 1
    lr = kwargs['lr']

    # Materialise parameter lists. Module.parameters() returns a one-shot
    # generator, so re-chaining stored generators when the optimizer is
    # rebuilt would silently drop every group already consumed (the decoder
    # and all previously unfrozen layers).
    target_classification_params = [list(target_classifier.decoder.parameters())]
    target_classification_optimizer = torch.optim.AdamW(chain(*target_classification_params), lr=lr)

    stop_flag_num = 0
    for epoch in range(kwargs['train_num_epochs']):
        for batch in train_dataloader:
            target_classification_train_history = classification_train_step_drug(
                model=target_classifier,
                batch=batch,
                loss_fn=classification_loss,
                device=device,
                optimizer=target_classification_optimizer,
                history=target_classification_train_history,
                class_num=class_num)
        target_classification_eval_train_history = evaluate_target_classification_epoch(
            classifier=target_classifier,
            dataloader=train_dataloader,
            device=device,
            history=target_classification_eval_train_history,
            class_num=class_num)
        target_classification_eval_val_history = evaluate_target_classification_epoch(
            classifier=target_classifier,
            dataloader=val_dataloader,
            device=device,
            history=target_classification_eval_val_history,
            class_num=class_num)

        if test_dataloader is not None:
            target_classification_eval_test_history = evaluate_target_classification_epoch(
                classifier=target_classifier,
                dataloader=test_dataloader,
                device=device,
                history=target_classification_eval_test_history,
                class_num=class_num,
                test_flag=True)
        save_flag, stop_flag = model_save_check(history=target_classification_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=5,  # stall after 5 epochs without improvement
                                                reset_count=reset_count)

        # Test metric at the current best validation epoch (None without a test set).
        test_metric = None
        if test_dataloader is not None:
            test_metric = target_classification_eval_test_history[metric_name][
                target_classification_eval_val_history['best_index']]

        if epoch % 50 == 0:
            print(f'Fine tuning epoch {epoch}. stop_flag_num: {stop_flag_num}. {test_metric}')

        if save_flag:
            torch.save(target_classifier.state_dict(), checkpoint_path)
            if to_roughly_test:
                print("To roughly test pass, just get the zero-shot metric at the begining for judge.")
                break

        if stop_flag:
            stop_flag_num += 1
            print(' ')
            if not encoder_module_indices:
                # Every encoder Linear layer is already trainable: stop.
                print('No encoder Linear layer left to unfreeze.')
                print(test_metric)
                break
            ind = encoder_module_indices.pop()
            print(f'Unfreezing Linear {ind} in the epoch {epoch}. {test_metric}')
            # Resume from the best checkpoint before widening the trainable set.
            target_classifier.load_state_dict(torch.load(checkpoint_path))

            target_classification_params.append(
                list(list(target_classifier.encoder.modules())[ind].parameters()))
            lr = lr * kwargs['decay_coefficient']
            target_classification_optimizer = torch.optim.AdamW(chain(*target_classification_params), lr=lr)
            reset_count += 1

    target_classifier.load_state_dict(torch.load(checkpoint_path))

    return target_classifier, (target_classification_train_history, target_classification_eval_train_history,
                               target_classification_eval_val_history, target_classification_eval_test_history)
151 |
152 |
def classification_train_step_drug(model, batch, loss_fn, device, optimizer, history, class_num, scheduler=None, clip=None):
    """Run one optimisation step of the drug-response model.

    :param model: called as ``model(x_smiles, x_gex)``.
    :param batch: ``(x_gex, x_smiles, y)``.
    :param loss_fn: loss matching ``class_num`` (BCEWithLogits / MSE /
        CrossEntropy).
    :param device: device the batch is moved to.
    :param optimizer: optimizer to step.
    :param history: mapping name -> list; the loss value is appended under ``'ce'``.
    :param class_num: 0 = single-logit binary, 1 = regression,
        >= 2 = multi-class with integer labels.
    :param scheduler: optional scheduler stepped after the optimizer.
    :param clip: optional max gradient norm.
    :return: the same ``history`` object.
    """
    model.zero_grad()
    model.train()

    x_smiles = batch[1].to(device)
    x_gex = batch[0].to(device)
    y = batch[2].to(device)

    output = model(x_smiles, x_gex)
    if class_num == 0:
        # Single-logit head: target shaped (N, 1) to match the output.
        target = y.double().unsqueeze(1)
    elif class_num == 1:
        # Regression head: float target shaped (N, 1).
        target = y.float().unsqueeze(1)
    else:
        # CrossEntropyLoss takes integer class indices of shape (N,); a float
        # (N, 1) target is rejected for an (N, class_num) output.
        target = y.long().reshape(-1)
    loss = loss_fn(output, target)

    optimizer.zero_grad()
    loss.backward()
    if clip is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    history['ce'].append(loss.cpu().detach().item())

    return history
182 |
183 |
def predict_pdr(encoder, fold_count, task_save_folder, test_dataloader=None,
                metric_name='auroc',
                class_num=2,
                normalize_flag=False,
                break_flag=False,
                pdr_dataloader=None,
                drug_emb_dim=128,
                **kwargs):
    """Rebuild the fine-tuned classifier of one fold, evaluate it and score PDR samples.

    The classifier is ``EncoderDecoder(encoder, MLP head)``. Its weights are loaded from
    ``<task_save_folder>/target_classifier_<fold_count>.pt``.

    :param encoder: pretrained encoder placed in front of the MLP head.
    :param fold_count: fold index used to pick the checkpoint file.
    :param task_save_folder: folder containing the classifier checkpoints.
    :param test_dataloader: optional labelled data for evaluation.
    :param metric_name: unused here; kept for call compatibility.
    :param class_num: number of head outputs; 0 is mapped to a single output.
    :param normalize_flag: forwarded to ``EncoderDecoder``.
    :param break_flag: unused here; kept for call compatibility.
    :param pdr_dataloader: optional data to score with ``predict_pdr_score``.
    :param drug_emb_dim: width of the drug embedding added to the head input.
    :param kwargs: must contain 'latent_dim', 'classifier_hidden_dims' and 'device'.
    :return: (evaluation history — an empty ``defaultdict(list)`` if no
        ``test_dataloader`` —, prediction DataFrame or None if no ``pdr_dataloader``).
    """
    device = kwargs['device']
    finetune_output_dim = 1 if class_num == 0 else class_num
    target_decoder = MLP(input_dim=kwargs['latent_dim'] + drug_emb_dim,
                         output_dim=finetune_output_dim,
                         hidden_dims=kwargs['classifier_hidden_dims']).to(device)

    target_classifier = EncoderDecoder(encoder=encoder, decoder=target_decoder,
                                       normalize_flag=normalize_flag).to(device)
    checkpoint_path = os.path.join(task_save_folder, f'target_classifier_{fold_count}.pt')
    # map_location lets a checkpoint written on GPU be restored when running on CPU.
    target_classifier.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print('Sucessfully loaded target_classifier_{}'.format(fold_count))

    target_classification_eval_test_history = defaultdict(list)
    if test_dataloader is not None:
        target_classification_eval_test_history = evaluate_target_classification_epoch_1(
            classifier=target_classifier,
            dataloader=test_dataloader,
            device=device,
            history=target_classification_eval_test_history,
            class_num=class_num,
            test_flag=True)

    prediction_df = None
    if pdr_dataloader is not None:
        prediction_df = predict_pdr_score(classifier=target_classifier, pdr_dataloader=pdr_dataloader,
                                          device=device)

    return target_classification_eval_test_history, prediction_df
--------------------------------------------------------------------------------
/code/PDR_task.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | import json
4 | import os
5 | import argparse
6 | import random
7 | import pickle
8 | import itertools
9 | import numpy as np
10 | import data
11 | import data_config
12 | import fine_tuning
13 |
14 | # add model_train path
15 | import sys
16 | sys.path.append('../model_train')
17 |
18 | import train_ae
19 | import train_ae_mmd
20 | import train_ae_adv
21 |
22 | import train_dsn
23 | import train_dsn_mmd
24 | import train_dsn_adv
25 |
26 | import train_dsrn_mmd
27 | import train_dsrn_adv
28 |
29 |
30 | from copy import deepcopy
31 | from collections import defaultdict
32 |
33 |
def wrap_training_params(training_params, type='unlabeled'):
    """Flatten one section of the training config into a single kwargs dict.

    Top-level entries (everything except the 'unlabeled' and 'labeled' sections) are
    kept. The entries of ``training_params[type]`` are layered on top and override
    top-level keys of the same name. ``training_params`` itself is not modified.

    :param training_params: config holding 'unlabeled' and 'labeled' sub-dicts.
    :param type: name of the section to merge in ('unlabeled' or 'labeled').
    :return: a new dict.
    """
    shared_entries = {key: value for key, value in training_params.items()
                      if key not in ('unlabeled', 'labeled')}
    return dict(shared_entries, **training_params[type])
39 |
def safe_make_dir(new_folder_name):
    """Create ``new_folder_name`` (with missing parents), or report that it already exists.

    Attempts the creation directly instead of checking first. Two processes creating the
    same folder concurrently therefore both succeed, instead of the slower one crashing
    with FileExistsError between the check and the ``makedirs`` call.
    """
    try:
        os.makedirs(new_folder_name)
    except FileExistsError:
        print(new_folder_name, 'exists!')
45 |
def dict_to_str(d):
    """Encode a flat dict as ``key1_value1_key2_value2...`` (keys must be strings)."""
    return "_".join(key + "_" + str(value) for key, value in d.items())
48 |
49 |
50 |
def _select_train_fn(method):
    """Return the unlabeled-training function for ``method``.

    Each supported module exposes a function with the module's own name
    (e.g. ``train_dsn_adv.train_dsn_adv``), so the attribute is looked up only for the
    requested method.

    :raises NotImplementedError: for unsupported methods. This includes 'dsrn', because
        no ``train_dsrn`` module is imported by this script.
    """
    train_modules = {
        'ae': train_ae,
        'ae_mmd': train_ae_mmd,
        'ae_adv': train_ae_adv,
        'dsn': train_dsn,
        'dsn_mmd': train_dsn_mmd,
        'dsn_adv': train_dsn_adv,
        'dsrn_mmd': train_dsrn_mmd,
        'dsrn_adv': train_dsrn_adv,
    }
    if method not in train_modules:
        raise NotImplementedError("Not true method supplied!")
    return getattr(train_modules[method], f'train_{method}')


def main(args, update_params_dict):
    """Pretrain (or load) an encoder for one hyper-parameter grid point, then predict PDR.

    NOTE: reads the module-level globals ``gex_features_df``, ``CCL_tumor_type`` and
    ``tumor_type``, which are assigned in this script's ``__main__`` section.

    :param args: parsed command-line arguments.
    :param update_params_dict: one grid point. It overrides the 'unlabeled' training
        params and names the model folder and output files.
    :raises NotImplementedError: if ``args.method`` is not supported.
    """
    train_fn = _select_train_fn(args.method)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open('model_save/train_params.json', 'r') as f:
        training_params = json.load(f)

    # Only the 'unlabeled' section is overridden by the grid point. The 'labeled'
    # section is used exactly as read from the JSON file.
    training_params['unlabeled'].update(update_params_dict)
    # Name this run's output files after the grid point (dict -> str).
    param_str = dict_to_str(update_params_dict)

    dataset_dir = "tcga" if args.use_tcga_pretrain_model else args.pretrain_dataset
    source_dir = os.path.join(dataset_dir, args.tcga_construction,
                              args.select_gene_method,
                              CCL_tumor_type, args.CCL_construction,
                              tumor_type,
                              args.CCL_dataset, args.select_drug_method)

    method_dir = f'{args.method}_norm' if args.norm_flag else args.method
    method_save_folder = os.path.join('../results', args.store_dir, method_dir, source_dir)

    training_params.update(
        {
            'device': device,
            'input_dim': gex_features_df.shape[-1],
            'model_save_folder': os.path.join(method_save_folder, param_str),
            'es_flag': False,
            'retrain_flag': args.retrain_flag,
            'norm_flag': args.norm_flag
        })
    task_save_folder = os.path.join(method_save_folder, 'predict')
    if args.pdtc_flag:
        task_save_folder = os.path.join(task_save_folder, 'pdtc')

    safe_make_dir(task_save_folder)  # prediction / evaluation output folder

    random.seed(2020)

    s_dataloaders, t_dataloaders = data.get_unlabeled_dataloaders(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=training_params['unlabeled']['batch_size'],
        ccle_only=False
    )

    if args.use_tcga_pretrain_model:
        # TCGA-pretrained models are all stored under the hard-coded "BRCA" folder
        # (replace it if needed), one unlabeled model per grid point.
        tcga_source_dir = os.path.join("tcga", args.tcga_construction,
                                       args.select_gene_method,
                                       CCL_tumor_type, args.CCL_construction,
                                       "BRCA",
                                       args.CCL_dataset, args.select_drug_method)
        training_params['model_save_folder'] = os.path.join(
            '../results', args.store_dir, f'{args.method}_norm', tcga_source_dir, param_str)
    print(f"model_save_folder:{training_params['model_save_folder']}")

    # Unlabeled training: an optional pretraining phase followed by the main phase.
    # With --no-train, the model is loaded from training_params['model_save_folder'].
    # Each dataloader pair is (train, test).
    encoder, _ = train_fn(s_dataloaders=s_dataloaders,
                          t_dataloaders=t_dataloaders,
                          **wrap_training_params(training_params, type='unlabeled'))

    if args.use_tcga_pretrain_model:
        patient_sample_info = pd.read_csv("../data/preprocessed_dat/xena_sample_info.csv", index_col=0)
        tumor_samples = patient_sample_info.loc[patient_sample_info.tumor_type == tumor_type].index
        gex_features_select = gex_features_df.loc[gex_features_df.index.intersection(tumor_samples)]
    else:
        gex_features_select = gex_features_df

    # Data for checking zero-shot results.
    tcga_ZS_dataloader, _ = data.get_tcga_ZS_dataloaders(gex_features_df=gex_features_select,
                                                         label_type=args.label_type,
                                                         batch_size=training_params['labeled']['batch_size'],
                                                         q=2,
                                                         tumor_type=tumor_type)
    print("Generating the pdr dataloader......")
    # Data for PDR prediction.
    pdr_dataloader = data.get_pdr_data_dataloaders(gex_features_df=gex_features_select,
                                                   batch_size=training_params['labeled']['batch_size'],
                                                   tumor_type=tumor_type)
    ft_evaluation_metrics = defaultdict(list)
    # 5 folds
    for fold_count in range(5):
        ft_encoder = deepcopy(encoder)
        print(' ')
        print('Fold count = {}'.format(fold_count))

        ft_historys, prediction_df = fine_tuning.predict_pdr(
            encoder=ft_encoder,
            pdr_dataloader=pdr_dataloader,
            test_dataloader=tcga_ZS_dataloader,
            fold_count=fold_count,
            normalize_flag=args.norm_flag,
            metric_name=args.metric,
            task_save_folder=training_params['model_save_folder'],
            drug_emb_dim=300,
            # Number of classes of CCL_dataset; must match the labeled CCLE dataloaders
            # the classifier was fine-tuned with.
            class_num=args.class_num,
            **wrap_training_params(training_params, type='labeled')
        )
        for metric in ['auroc', 'acc', 'aps', 'f1', 'auprc']:
            ft_evaluation_metrics[metric].append(ft_historys[metric])
        print(f'Saving results: {param_str} ; Fold count: {fold_count}')
        if prediction_df is not None:
            prediction_df.to_csv(os.path.join(task_save_folder, f'{param_str}__{fold_count}__predict.csv'))

    with open(os.path.join(task_save_folder, f'{param_str}_ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
195 |
196 |
if __name__ == '__main__':
    import traceback

    parser = argparse.ArgumentParser('PDR task')
    parser.add_argument('--pretrain_num', default=None, type=int)
    parser.add_argument('--zero_shot_num', default=None, type=int)
    parser.add_argument('--method_num', default=None, type=int)

    parser.add_argument('--method', dest='method', nargs='?', default='dsn_adv',
                        choices=['ae', 'ae_mmd', 'ae_adv',
                                 'dsn', 'dsn_mmd', 'dsn_adv',
                                 'dsrn_mmd', 'dsrn_adv'])
    parser.add_argument('--metric', dest='metric', nargs='?', default='auroc', choices=['auroc', 'auprc'])

    parser.add_argument('--measurement', dest='measurement', nargs='?', default='AUC', choices=['AUC', 'LN_IC50'])

    parser.add_argument('--n', dest='n', nargs='?', type=int, default=5)

    train_group = parser.add_mutually_exclusive_group(required=False)
    train_group.add_argument('--train', dest='retrain_flag', action='store_true')
    train_group.add_argument('--no-train', dest='retrain_flag', action='store_false')
    parser.set_defaults(retrain_flag=True)

    # The options below are read by main() / data.get_pretrain_dataset but were never
    # declared, so every run failed with AttributeError.
    norm_group = parser.add_mutually_exclusive_group(required=False)
    norm_group.add_argument('--norm', dest='norm_flag', action='store_true')
    norm_group.add_argument('--no-norm', dest='norm_flag', action='store_false')
    # NOTE(review): default True matches the hard-coded '<method>_norm' folder main() uses
    # for TCGA-pretrained models; confirm.
    parser.set_defaults(norm_flag=True)
    parser.add_argument('--pdtc', '--pdtc_flag', dest='pdtc_flag', action='store_true')
    parser.add_argument('--use_tcga_pretrain_model', dest='use_tcga_pretrain_model', action='store_true')
    # TODO: document accepted values. These are passed to data.get_pretrain_dataset and/or
    # used as result folder names.
    parser.add_argument('--tcga_construction', required=True)
    parser.add_argument('--CCL_construction', required=True)
    parser.add_argument('--CCL_type', required=True)

    parser.add_argument('--label_type', default="PFS", choices=["PFS", "Imaging"])

    parser.add_argument('--select_drug_method', default="overlap", choices=["overlap", "all", "random"])

    parser.add_argument('--store_dir', default="benchmark")
    parser.add_argument('--select_gene_method', default="Percent_sd", choices=["Percent_sd", "HVG"])
    parser.add_argument('--select_gene_num', default=1000, type=int)

    parser.add_argument('--pretrain_dataset', default="tcga",
                        choices=["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                                 "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                                 ])
    parser.add_argument('--tumor_type', default="BRCA",
                        choices=['TCGA', 'GBM', 'LGG', 'HNSC', 'KIRC', 'SARC', 'BRCA', 'STAD', 'CESC', 'SKCM',
                                 'LUSC', 'LUAD', 'READ', 'COAD'
                                 ])
    parser.add_argument('--CCL_dataset', default='gdsc1_raw',
                        choices=['gdsc1_raw', 'gdsc1_rebalance'])
    parser.add_argument('--class_num', default=0, type=int)

    args = parser.parse_args()

    params_grid = {
        "pretrain_num_epochs": [0, 100, 300],
        "train_num_epochs": [100, 200, 300, 500, 750, 1000, 1500, 2000, 2500, 3000],
        "dop": [0.0, 0.1]
    }

    Tumor_type_list = ["tcga", "brca", "cesc", "coad", "gbm", "hnsc", "kirc",
                       "lgg", "luad", "lusc", "paad", "read", "sarc", "skcm", "stad"
                       ]

    # Optional index-based selection. Compare with None so that index 0
    # (e.g. method 'ae') is honoured instead of being silently ignored.
    if args.pretrain_num is not None:
        args.pretrain_dataset = Tumor_type_list[args.pretrain_num]
    if args.zero_shot_num is not None:
        args.tumor_type = [element.upper() for element in Tumor_type_list][args.zero_shot_num]
    if args.method_num is not None:
        args.method = [
            'ae', 'dsn',
            'ae_mmd', 'dsrn_mmd', 'dsn_mmd',
            'ae_adv', 'dsrn_adv', 'dsn_adv'][args.method_num]

    # For TCGA pretraining, args.tumor_type selects the cancer type to fine-tune on.
    # Otherwise the pretraining cancer type is used. For cross-cancer experiments,
    # replace this with `tumor_type = args.tumor_type`.
    tumor_type = args.pretrain_dataset.upper()
    if tumor_type == "TCGA":
        tumor_type = args.tumor_type

    if args.method not in ['dsrn_adv', 'dsn_adv', 'ae_adv']:
        params_grid.pop('pretrain_num_epochs')

    keys, values = zip(*params_grid.items())
    update_params_dict_list = [dict(zip(keys, v)) for v in itertools.product(*values)]

    # Build the folder structure and the gex_features_df used for unlabeled training.
    patient_tumor_type = "tcga" if args.use_tcga_pretrain_model else args.pretrain_dataset
    gex_features_df, CCL_tumor_type, all_ccle_gex, all_patient_gex = data.get_pretrain_dataset(
        patient_tumor_type=patient_tumor_type,
        CCL_type=args.CCL_type,
        tumor_type=tumor_type,
        tcga_construction=args.tcga_construction,
        CCL_construction=args.CCL_construction,
        gene_num=args.select_gene_num, select_gene_method=args.select_gene_method
    )

    print(f'Pretrain dataset: Patient({args.pretrain_dataset} {args.pretrain_num} {args.tcga_construction}) CCL({args.CCL_type} {CCL_tumor_type} {args.CCL_construction}). Select_gene_method: {args.select_gene_method}')
    print(f'Zero-shot dataset: {tumor_type}({args.zero_shot_num})')
    print(f'CCL_dataset: {args.CCL_dataset} Select_drug_method: {args.select_drug_method}')
    print(f'Store_dir: {args.store_dir} ')

    print(f'pdtc_flag: {args.pdtc_flag}. method: {args.method}({args.method_num}). label_type: {args.label_type}')

    for param_num, param_dict in enumerate(update_params_dict_list, start=1):
        print(' ')
        print('##############################################################################')
        print(f'####### Param_num {param_num}/{len(update_params_dict_list)} #######')
        print('Param_dict: {}'.format(param_dict))
        print('##############################################################################')
        try:
            main(args=args, update_params_dict=param_dict)
        except Exception:
            # Keep sweeping the grid, but show the full traceback, not only the message.
            traceback.print_exc()
            print(param_dict, param_num)
    print("Finsh All !!!!!!!!!!!!!!!")
311 |
--------------------------------------------------------------------------------
/code/model_train/train_dsn_adv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.autograd as autograd
3 | from itertools import chain
4 | from dsn_ae import DSNAE
5 | from evaluation_utils import *
6 | from mlp import MLP
7 | from train_dsn import eval_dsnae_epoch, dsn_ae_train_step
8 | from collections import OrderedDict
9 |
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty of ``critic`` at random real/fake mixtures.

    Each real sample is mixed with the fake sample at the same batch position, using a
    per-row weight drawn from U[0, 1). The penalty is the batch mean of
    ``(||d critic / d mixture||_2 - 1) ** 2``.

    :param critic: callable scoring a batch. Assumed to return shape (batch, 1), which is
        the shape of the gradient seed used below.
    :param real_samples: tensor whose first dimension is the batch. The mixing weight has
        shape (batch, 1), so (batch, features) inputs are assumed.
    :param fake_samples: tensor with the same shape as ``real_samples``.
    :param device: device the random weights and the gradient seed are moved to.
    :return: scalar tensor, built with ``create_graph=True`` so it can be backpropagated
        into the critic.
    """
    batch_size = real_samples.shape[0]

    mix_weight = torch.rand((batch_size, 1)).to(device)
    mixtures = (mix_weight * real_samples + ((1 - mix_weight) * fake_samples)).requires_grad_(True)

    scores = critic(mixtures)
    grad_seed = torch.ones((batch_size, 1)).to(device)
    (input_grads,) = autograd.grad(
        outputs=scores,
        inputs=mixtures,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = input_grads.view(input_grads.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()
30 |
31 |
def critic_dsn_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None,
                          clip=None, gp=None):
    """Run one WGAN critic update on the shared codes of a source and a target batch.

    Both autoencoders are put in eval mode and their shared codes (``s_encode``) are
    computed under ``torch.no_grad()``. Only the critic is optimised here, so tracking
    gradients through the encoders would only cost time and memory and leave unused
    gradients on the autoencoder parameters. The critic update itself is unchanged.

    :param critic: module scoring codes; presumably the parameters of ``optimizer``.
    :param s_dsnae: source autoencoder exposing ``s_encode``.
    :param t_dsnae: target autoencoder exposing ``s_encode``.
    :param s_batch: source batch; ``s_batch[0]`` holds the inputs.
    :param t_batch: target batch; ``t_batch[0]`` holds the inputs.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer that is zeroed and stepped here.
    :param history: mapping of lists (e.g. ``defaultdict(list)``). The loss is appended
        under 'critic_loss'.
    :param scheduler: optional LR scheduler, stepped after the update.
    :param clip: if given, critic weights are clamped to ``[-clip, clip]`` after the step.
    :param gp: if given, weight of the gradient-penalty term.
    :return: ``history``.
    """
    critic.zero_grad()
    s_dsnae.zero_grad()
    t_dsnae.zero_grad()
    s_dsnae.eval()
    t_dsnae.eval()
    critic.train()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)

    # The codes are constants for this step: only the critic is trained.
    with torch.no_grad():
        s_code = s_dsnae.s_encode(s_x)
        t_code = t_dsnae.s_encode(t_x)

    # Minimising this pushes critic scores up on source codes and down on target codes.
    loss = torch.mean(critic(t_code)) - torch.mean(critic(s_code))
    if gp is not None:
        gradient_penalty = compute_gradient_penalty(critic,
                                                    real_samples=s_code,
                                                    fake_samples=t_code,
                                                    device=device)
        loss = loss + gp * gradient_penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if clip is not None:
        # WGAN weight clipping, applied after the step.
        for p in critic.parameters():
            p.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['critic_loss'].append(loss.cpu().detach().item())

    return history
70 |
71 |
def gan_dsn_gen_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, alpha, history,
                           scheduler=None):
    """Update the autoencoders to fool the critic while still reconstructing their inputs.

    Objective: source AE loss + target AE loss + ``alpha`` * (negative mean critic score of
    the target shared codes). The critic is put in eval mode. Gradients flow through it
    (and accumulate on its parameters), but only ``optimizer``'s parameters are stepped.

    :param critic: module scoring codes.
    :param s_dsnae: source autoencoder exposing ``loss_function`` and ``forward``.
    :param t_dsnae: target autoencoder exposing ``s_encode``, ``loss_function`` and ``forward``.
    :param s_batch: source batch; ``s_batch[0]`` holds the inputs.
    :param t_batch: target batch; ``t_batch[0]`` holds the inputs.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer that is zeroed and stepped here.
    :param alpha: weight of the adversarial term.
    :param history: mapping of lists (e.g. ``defaultdict(list)``).
    :param scheduler: optional LR scheduler, stepped after the update.
    :return: ``history``. For each key of the loss dicts, the source + target value is
        appended; the adversarial term is appended under 'gen_loss'.
    """
    for module in (critic, s_dsnae, t_dsnae):
        module.zero_grad()
    critic.eval()
    s_dsnae.train()
    t_dsnae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    # The generator wants high critic scores on target codes, i.e. a low negative score.
    adversarial_loss = -torch.mean(critic(t_dsnae.s_encode(target_x)))
    source_terms = s_dsnae.loss_function(*s_dsnae(source_x))
    target_terms = t_dsnae.loss_function(*t_dsnae(target_x))
    objective = source_terms['loss'] + target_terms['loss'] + alpha * adversarial_loss

    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed_terms = {key: value.item() + target_terms[key].item() for key, value in source_terms.items()}
    for key, value in summed_terms.items():
        history[key].append(value)
    history['gen_loss'].append(adversarial_loss.item())

    return history
106 |
107 |
def train_dsn_adv(s_dataloaders, t_dataloaders, **kwargs):
    """Train source/target DSN autoencoders adversarially with a WGAN-GP critic.

    Both autoencoders share one encoder (``shared_encoder``) and one decoder. With
    ``retrain_flag`` set, two phases run:
    1. Reconstruction pre-training for ``pretrain_num_epochs`` epochs.
    2. Adversarial training for ``train_num_epochs`` epochs. The critic is updated on every
       source batch and the autoencoders on every 5th.
    Both autoencoders are then saved to ``model_save_folder``. Without ``retrain_flag``,
    the saved target autoencoder is loaded instead.

    :param s_dataloaders: source dataloaders; [0] is train, [1] is test.
    :param t_dataloaders: target dataloaders; [0] is train, [1] is test.
    :param kwargs: configuration. Keys read here: input_dim, latent_dim, encoder_hidden_dims,
        classifier_hidden_dims, dop, alpha, norm_flag, device, lr, retrain_flag,
        pretrain_num_epochs, train_num_epochs, es_flag, model_save_folder.
    :return: (shared encoder of the target autoencoder,
              (dsnae_train_history, dsnae_val_history, critic_train_history, gen_train_history)).
    :raises Exception: "No pre-trained encoder" when not retraining and 'a_t_dsnae.pt' is missing.
    """
    s_train_dataloader = s_dataloaders[0]
    s_test_dataloader = s_dataloaders[1]

    t_train_dataloader = t_dataloaders[0]
    t_test_dataloader = t_dataloaders[1]

    shared_encoder = MLP(input_dim=kwargs['input_dim'],
                         output_dim=kwargs['latent_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'],
                         dop=kwargs['dop']).to(kwargs['device'])

    shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
                         output_dim=kwargs['input_dim'],
                         hidden_dims=kwargs['encoder_hidden_dims'][::-1],
                         dop=kwargs['dop']).to(kwargs['device'])

    s_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(kwargs['device'])

    t_dsnae = DSNAE(shared_encoder=shared_encoder,
                    decoder=shared_decoder,
                    alpha=kwargs['alpha'],
                    input_dim=kwargs['input_dim'],
                    latent_dim=kwargs['latent_dim'],
                    hidden_dims=kwargs['encoder_hidden_dims'],
                    dop=kwargs['dop'],
                    norm_flag=kwargs['norm_flag']).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    dsnae_train_history = defaultdict(list)
    dsnae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    s_ckpt_path = os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt')
    t_ckpt_path = os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')

    if kwargs['retrain_flag']:
        # torch.save below does not create missing parent folders.
        os.makedirs(kwargs['model_save_folder'], exist_ok=True)

        ae_params = [t_dsnae.private_encoder.parameters(),
                     s_dsnae.private_encoder.parameters(),
                     shared_decoder.parameters(),
                     shared_encoder.parameters()
                     ]
        t_ae_params = [t_dsnae.private_encoder.parameters(),
                       s_dsnae.private_encoder.parameters(),
                       shared_decoder.parameters(),
                       shared_encoder.parameters()
                       ]

        ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(), lr=kwargs['lr'])
        t_ae_optimizer = torch.optim.RMSprop(chain(*t_ae_params), lr=kwargs['lr'])

        # Phase 1: reconstruction pre-training of both autoencoders.
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'AE training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                # NOTE(review): a fresh iterator is built every step, so this yields the first
                # batch of a new pass each time; it only varies if the loader shuffles. Confirm.
                t_batch = next(iter(t_train_dataloader))
                dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
                                                        t_dsnae=t_dsnae,
                                                        s_batch=s_batch,
                                                        t_batch=t_batch,
                                                        device=kwargs['device'],
                                                        optimizer=ae_optimizer,
                                                        history=dsnae_train_history)
            dsnae_val_history = eval_dsnae_epoch(model=s_dsnae,
                                                 data_loader=s_test_dataloader,
                                                 device=kwargs['device'],
                                                 history=dsnae_val_history
                                                 )
            dsnae_val_history = eval_dsnae_epoch(model=t_dsnae,
                                                 data_loader=t_test_dataloader,
                                                 device=kwargs['device'],
                                                 history=dsnae_val_history
                                                 )
            # Merge the two entries just appended (source, target) into one per-epoch value.
            for k in dsnae_val_history:
                if k != 'best_index':
                    dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
                    dsnae_val_history[k].pop()

            save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=20)
            if kwargs['es_flag']:
                if save_flag:
                    torch.save(s_dsnae.state_dict(), s_ckpt_path)
                    torch.save(t_dsnae.state_dict(), t_ckpt_path)
                if stop_flag:
                    break

        if kwargs['es_flag']:
            s_dsnae.load_state_dict(torch.load(s_ckpt_path, map_location=kwargs['device']))
            t_dsnae.load_state_dict(torch.load(t_ckpt_path, map_location=kwargs['device']))

        # Phase 2: adversarial (WGAN-GP) training on the shared codes.
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, s_batch in enumerate(s_train_dataloader):
                t_batch = next(iter(t_train_dataloader))
                critic_train_history = critic_dsn_train_step(critic=confounding_classifier,
                                                             s_dsnae=s_dsnae,
                                                             t_dsnae=t_dsnae,
                                                             s_batch=s_batch,
                                                             t_batch=t_batch,
                                                             device=kwargs['device'],
                                                             optimizer=classifier_optimizer,
                                                             history=critic_train_history,
                                                             gp=10.0)
                # The autoencoders are updated once per 5 critic updates.
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_dsn_gen_train_step(critic=confounding_classifier,
                                                               s_dsnae=s_dsnae,
                                                               t_dsnae=t_dsnae,
                                                               s_batch=s_batch,
                                                               t_batch=t_batch,
                                                               device=kwargs['device'],
                                                               optimizer=t_ae_optimizer,
                                                               alpha=1.0,
                                                               history=gen_train_history)

        torch.save(s_dsnae.state_dict(), s_ckpt_path)
        torch.save(t_dsnae.state_dict(), t_ckpt_path)
    else:
        try:
            t_dsnae.load_state_dict(torch.load(t_ckpt_path, map_location=kwargs['device']))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history, critic_train_history, gen_train_history)
283 |
--------------------------------------------------------------------------------
/code/model_train/train_dsrn_adv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.autograd as autograd
3 | from itertools import chain
4 | from dsn_ae import DSNAE
5 | from evaluation_utils import *
6 | from mlp import MLP
7 | from train_dsn import eval_dsnae_epoch, dsn_ae_train_step
8 | from collections import OrderedDict
9 |
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Return the WGAN-GP gradient penalty of ``critic`` at random real/fake mixtures.

    Each real sample is mixed with the fake sample at the same batch position, using a
    per-row weight drawn from U[0, 1). The penalty is the batch mean of
    ``(||d critic / d mixture||_2 - 1) ** 2``.

    :param critic: callable scoring a batch. Assumed to return shape (batch, 1), which is
        the shape of the gradient seed used below.
    :param real_samples: tensor whose first dimension is the batch. The mixing weight has
        shape (batch, 1), so (batch, features) inputs are assumed.
    :param fake_samples: tensor with the same shape as ``real_samples``.
    :param device: device the random weights and the gradient seed are moved to.
    :return: scalar tensor, built with ``create_graph=True`` so it can be backpropagated
        into the critic.
    """
    batch_size = real_samples.shape[0]

    mix_weight = torch.rand((batch_size, 1)).to(device)
    mixtures = (mix_weight * real_samples + ((1 - mix_weight) * fake_samples)).requires_grad_(True)

    scores = critic(mixtures)
    grad_seed = torch.ones((batch_size, 1)).to(device)
    (input_grads,) = autograd.grad(
        outputs=scores,
        inputs=mixtures,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = input_grads.view(input_grads.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()
30 |
31 |
def critic_dsn_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, history, scheduler=None,
                          clip=None, gp=None):
    """Run one WGAN critic update on the codes (``encode``) of a source and a target batch.

    Both autoencoders are put in eval mode and their codes are computed under
    ``torch.no_grad()``. Only the critic is optimised here, so tracking gradients through
    the encoders would only cost time and memory and leave unused gradients on the
    autoencoder parameters. The critic update itself is unchanged.

    :param critic: module scoring codes; presumably the parameters of ``optimizer``.
    :param s_dsnae: source autoencoder exposing ``encode``.
    :param t_dsnae: target autoencoder exposing ``encode``.
    :param s_batch: source batch; ``s_batch[0]`` holds the inputs.
    :param t_batch: target batch; ``t_batch[0]`` holds the inputs.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer that is zeroed and stepped here.
    :param history: mapping of lists (e.g. ``defaultdict(list)``). The loss is appended
        under 'critic_loss'.
    :param scheduler: optional LR scheduler, stepped after the update.
    :param clip: if given, critic weights are clamped to ``[-clip, clip]`` after the step.
    :param gp: if given, weight of the gradient-penalty term.
    :return: ``history``.
    """
    critic.zero_grad()
    s_dsnae.zero_grad()
    t_dsnae.zero_grad()
    s_dsnae.eval()
    t_dsnae.eval()
    critic.train()

    s_x = s_batch[0].to(device)
    t_x = t_batch[0].to(device)

    # The codes are constants for this step: only the critic is trained.
    with torch.no_grad():
        s_code = s_dsnae.encode(s_x)
        t_code = t_dsnae.encode(t_x)

    # Minimising this pushes critic scores up on source codes and down on target codes.
    loss = torch.mean(critic(t_code)) - torch.mean(critic(s_code))

    if gp is not None:
        gradient_penalty = compute_gradient_penalty(critic,
                                                    real_samples=s_code,
                                                    fake_samples=t_code,
                                                    device=device)
        loss = loss + gp * gradient_penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if clip is not None:
        # WGAN weight clipping, applied after the step.
        for p in critic.parameters():
            p.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['critic_loss'].append(loss.cpu().detach().item())

    return history
71 |
72 |
def gan_dsn_gen_train_step(critic, s_dsnae, t_dsnae, s_batch, t_batch, device, optimizer, alpha, history,
                           scheduler=None):
    """Update the autoencoders to fool the critic while still reconstructing their inputs.

    Objective: source AE loss + target AE loss + ``alpha`` * (negative mean critic score of
    the target codes from ``encode``). Each AE's 'loss' entry is its total loss. The
    critic is put in eval mode. Gradients flow through it (and accumulate on its
    parameters), but only ``optimizer``'s parameters are stepped.

    :param critic: module scoring codes.
    :param s_dsnae: source autoencoder exposing ``loss_function`` and ``forward``.
    :param t_dsnae: target autoencoder exposing ``encode``, ``loss_function`` and ``forward``.
    :param s_batch: source batch; ``s_batch[0]`` holds the inputs.
    :param t_batch: target batch; ``t_batch[0]`` holds the inputs.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer that is zeroed and stepped here.
    :param alpha: weight of the adversarial term.
    :param history: mapping of lists (e.g. ``defaultdict(list)``).
    :param scheduler: optional LR scheduler, stepped after the update.
    :return: ``history``. For each key of the loss dicts, the source + target value is
        appended; the adversarial term is appended under 'gen_loss'.
    """
    for module in (critic, s_dsnae, t_dsnae):
        module.zero_grad()
    critic.eval()
    s_dsnae.train()
    t_dsnae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    # Generator loss: to make the critic score target codes as high as possible,
    # minimise the negative score.
    adversarial_loss = -torch.mean(critic(t_dsnae.encode(target_x)))
    source_terms = s_dsnae.loss_function(*s_dsnae(source_x))
    target_terms = t_dsnae.loss_function(*t_dsnae(target_x))
    # Add the source and target total losses (the code-AE base loss), then the weighted adversarial term.
    objective = source_terms['loss'] + target_terms['loss'] + alpha * adversarial_loss

    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    summed_terms = {key: value.item() + target_terms[key].item() for key, value in source_terms.items()}
    for key, value in summed_terms.items():
        history[key].append(value)
    history['gen_loss'].append(adversarial_loss.item())

    return history
107 |
108 |
109 | def train_dsrn_adv(s_dataloaders, t_dataloaders, **kwargs):
110 | """
111 |
112 | :param s_dataloaders:
113 | :param t_dataloaders:
114 | :param kwargs:
115 | :return:
116 | """
117 | s_train_dataloader = s_dataloaders[0]
118 | s_test_dataloader = s_dataloaders[1]
119 |
120 | t_train_dataloader = t_dataloaders[0]
121 | t_test_dataloader = t_dataloaders[1]
122 |
123 | shared_encoder = MLP(input_dim=kwargs['input_dim'],
124 | output_dim=kwargs['latent_dim'],
125 | hidden_dims=kwargs['encoder_hidden_dims'],
126 | dop=kwargs['dop']).to(kwargs['device'])
127 |
128 | shared_decoder = MLP(input_dim=2 * kwargs['latent_dim'],
129 | output_dim=kwargs['input_dim'],
130 | hidden_dims=kwargs['encoder_hidden_dims'][::-1],
131 | dop=kwargs['dop']).to(kwargs['device'])
132 |
133 | s_dsnae = DSNAE(shared_encoder=shared_encoder,
134 | decoder=shared_decoder,
135 | alpha=kwargs['alpha'],
136 | input_dim=kwargs['input_dim'],
137 | latent_dim=kwargs['latent_dim'],
138 | hidden_dims=kwargs['encoder_hidden_dims'],
139 | dop=kwargs['dop'],
140 | norm_flag=kwargs['norm_flag']).to(kwargs['device'])
141 |
142 | t_dsnae = DSNAE(shared_encoder=shared_encoder,
143 | decoder=shared_decoder,
144 | alpha=kwargs['alpha'],
145 | input_dim=kwargs['input_dim'],
146 | latent_dim=kwargs['latent_dim'],
147 | hidden_dims=kwargs['encoder_hidden_dims'],
148 | dop=kwargs['dop'],
149 | norm_flag=kwargs['norm_flag']).to(kwargs['device'])
150 |
151 | confounding_classifier = MLP(input_dim=kwargs['latent_dim'] * 2,
152 | output_dim=1,
153 | hidden_dims=kwargs['classifier_hidden_dims'],
154 | dop=kwargs['dop']).to(kwargs['device'])
155 |
156 | ae_params = [t_dsnae.private_encoder.parameters(),
157 | s_dsnae.private_encoder.parameters(),
158 | shared_decoder.parameters(),
159 | shared_encoder.parameters()
160 | ]
161 | t_ae_params = [t_dsnae.private_encoder.parameters(),
162 | s_dsnae.private_encoder.parameters(),
163 | shared_decoder.parameters(),
164 | shared_encoder.parameters()
165 | ]
166 |
167 | ae_optimizer = torch.optim.AdamW(chain(*ae_params), lr=kwargs['lr'])
168 | classifier_optimizer = torch.optim.RMSprop(confounding_classifier.parameters(), lr=kwargs['lr'])
169 | t_ae_optimizer = torch.optim.RMSprop(chain(*t_ae_params), lr=kwargs['lr'])
170 |
171 | dsnae_train_history = defaultdict(list)
172 | dsnae_val_history = defaultdict(list)
173 | critic_train_history = defaultdict(list)
174 | gen_train_history = defaultdict(list)
175 | # classification_eval_test_history = defaultdict(list)
176 | # classification_eval_train_history = defaultdict(list)
177 |
178 | # print(os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt'))
179 | if kwargs['retrain_flag']:
180 | # start dsnae pre-training
181 | for epoch in range(int(kwargs['pretrain_num_epochs'])):
182 | if epoch % 50 == 0:
183 | print(f'AE training epoch {epoch}')
184 | for step, s_batch in enumerate(s_train_dataloader):
185 | t_batch = next(iter(t_train_dataloader))
186 | dsnae_train_history = dsn_ae_train_step(s_dsnae=s_dsnae,
187 | t_dsnae=t_dsnae,
188 | s_batch=s_batch,
189 | t_batch=t_batch,
190 | device=kwargs['device'],
191 | optimizer=ae_optimizer,
192 | history=dsnae_train_history)
193 | dsnae_val_history = eval_dsnae_epoch(model=s_dsnae,
194 | data_loader=s_test_dataloader,
195 | device=kwargs['device'],
196 | history=dsnae_val_history
197 | )
198 | dsnae_val_history = eval_dsnae_epoch(model=t_dsnae,
199 | data_loader=t_test_dataloader,
200 | device=kwargs['device'],
201 | history=dsnae_val_history
202 | )
203 | for k in dsnae_val_history:
204 | if k != 'best_index':
205 | dsnae_val_history[k][-2] += dsnae_val_history[k][-1]
206 | dsnae_val_history[k].pop()
207 |
208 | if kwargs['es_flag']:
209 | save_flag, stop_flag = model_save_check(dsnae_val_history, metric_name='loss', tolerance_count=20)
210 | if save_flag:
211 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt'))
212 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt'))
213 | if stop_flag:
214 | break
215 | if kwargs['es_flag']:
216 | s_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt')))
217 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')))
218 |
219 | # start critic pre-training
220 | # for epoch in range(100):
221 | # if epoch % 10 == 0:
222 | # print(f'confounder critic pre-training epoch {epoch}')
223 | # for step, t_batch in enumerate(s_train_dataloader):
224 | # s_batch = next(iter(t_train_dataloader))
225 | # critic_train_history = critic_dsn_train_step(critic=confounding_classifier,
226 | # s_dsnae=s_dsnae,
227 | # t_dsnae=t_dsnae,
228 | # s_batch=s_batch,
229 | # t_batch=t_batch,
230 | # device=kwargs['device'],
231 | # optimizer=classifier_optimizer,
232 | # history=critic_train_history,
233 | # clip=None,
234 | # gp=None)
235 | # start GAN training
236 | for epoch in range(int(kwargs['train_num_epochs'])):
237 | if epoch % 50 == 0:
238 | print(f'confounder wgan training epoch {epoch}')
239 | for step, s_batch in enumerate(s_train_dataloader):
240 | t_batch = next(iter(t_train_dataloader))
241 |
242 | critic_train_history = critic_dsn_train_step(critic=confounding_classifier, #让判别器学的更好target分高,source分低,再加上梯度限制:loss = torch.mean(critic(t_code)) - torch.mean(critic(s_code)) + gradient_penalty
243 | s_dsnae=s_dsnae,
244 | t_dsnae=t_dsnae,
245 | s_batch=s_batch,
246 | t_batch=t_batch,
247 | device=kwargs['device'],
248 | optimizer=classifier_optimizer,
249 | history=critic_train_history,
250 | # clip=0.1,
251 | gp=10.0)
252 | if (step + 1) % 5 == 0:
253 | gen_train_history = gan_dsn_gen_train_step(critic=confounding_classifier,#让生成器更好:输入target生成更像source的表示,即分越高,但loss越小越好,故加上负号
254 | s_dsnae=s_dsnae,
255 | t_dsnae=t_dsnae,
256 | s_batch=s_batch,
257 | t_batch=t_batch,
258 | device=kwargs['device'],
259 | optimizer=t_ae_optimizer,
260 | alpha=1.0,
261 | history=gen_train_history)
262 |
263 | torch.save(s_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'a_s_dsnae.pt'))
264 | torch.save(t_dsnae.state_dict(), os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt'))
265 |
266 | # s_batch = next(iter(s_train_dataloader))
267 | # t_batch = next(iter(t_train_dataloader))
268 | # s_x = s_batch[0].to('cuda')
269 | # t_x = t_batch[0].to('cuda')
270 | # print(f"s_x: {s_x}")
271 | # print(f"t_x: {t_x}")
272 | # print(f"Patient encoding: {t_dsnae(t_x)}")
273 | # print(f"Cell line encoding: {t_dsnae(s_x)}")
274 | # exit()
275 | else:
276 | try:
277 | # if kwargs['norm_flag']:
278 | # loaded_model = torch.load(os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt'))
279 | # new_loaded_model = {key: val for key, val in loaded_model.items() if key in t_dsnae.state_dict()}
280 | # new_loaded_model['shared_encoder.output_layer.0.weight'] = loaded_model[
281 | # 'shared_encoder.output_layer.3.weight']
282 | # new_loaded_model['shared_encoder.output_layer.0.bias'] = loaded_model[
283 | # 'shared_encoder.output_layer.3.bias']
284 | # new_loaded_model['decoder.output_layer.0.weight'] = loaded_model['decoder.output_layer.3.weight']
285 | # new_loaded_model['decoder.output_layer.0.bias'] = loaded_model['decoder.output_layer.3.bias']
286 |
287 | # corrected_model = OrderedDict({key: new_loaded_model[key] for key in t_dsnae.state_dict()})
288 | # t_dsnae.load_state_dict(corrected_model)
289 | # else:
290 | #t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')))
291 | t_dsnae.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'a_t_dsnae.pt')))
292 | # t_dsnae.eval()
293 |
294 | except FileNotFoundError:
295 | raise Exception("No pre-trained encoder")
296 |
297 | return t_dsnae.shared_encoder, (dsnae_train_history, dsnae_val_history, critic_train_history, gen_train_history)
298 |
299 |
--------------------------------------------------------------------------------
/code/model_train/train_ae_adv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import os
5 | from evaluation_utils import eval_ae_epoch, evaluate_adv_classification_epoch, model_save_check
6 | from collections import defaultdict
7 | from ae import AE
8 | from mlp import MLP
9 | from encoder_decoder import EncoderDecoder
10 |
def ae_train_step(ae, s_batch, t_batch, device, optimizer, history, scheduler=None):
    """Run one joint reconstruction update of ``ae`` on a source and a target batch.

    Both batches go through the same autoencoder; their ``'loss'`` entries are summed
    and a single optimizer step is taken on the total.

    :param ae: model whose ``loss_function(*ae(x))`` returns a dict of scalar tensors
        containing at least ``'loss'``.
    :param s_batch: source-domain batch; element 0 is the input tensor.
    :param t_batch: target-domain batch; element 0 is the input tensor.
    :param device: device the inputs are moved to.
    :param optimizer: optimizer over the autoencoder parameters.
    :param history: mapping name -> list (e.g. ``defaultdict(list)``); receives, for each
        loss entry, the source value plus the target value.
    :param scheduler: optional LR scheduler, stepped after the optimizer.
    :return: ``history``, updated in place.
    """
    ae.zero_grad()
    ae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    source_losses = ae.loss_function(*ae(source_x))
    target_losses = ae.loss_function(*ae(target_x))

    optimizer.zero_grad()
    total = source_losses['loss'] + target_losses['loss']
    total.backward()
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    # Build every summed value first, then record them.
    summed = {
        name: value.cpu().detach().item() + target_losses[name].cpu().detach().item()
        for name, value in source_losses.items()
    }
    for name, value in summed.items():
        history[name].append(value)

    return history
36 |
37 |
def classification_train_step(classifier, s_batch, t_batch, loss_fn, device, optimizer, history, scheduler=None,
                              clip=None):
    """Run one update of a source-vs-target domain classifier.

    Source rows get target value 0 and target rows get 1; ``loss_fn`` compares the
    concatenated classifier outputs with those targets.

    :param classifier: module mapping inputs to one logit per row. ``classifier.decoder``
        must exist when ``clip`` is given.
    :param s_batch: source-domain batch; element 0 is the input tensor.
    :param t_batch: target-domain batch; element 0 is the input tensor.
    :param loss_fn: loss taking ``(outputs, targets)``, e.g. ``nn.BCEWithLogitsLoss()``.
    :param device: device the inputs and targets are moved to.
    :param optimizer: optimizer stepped on the loss.
    :param history: mapping name -> list; the loss value is appended under ``'bce'``.
    :param scheduler: optional LR scheduler, stepped after the optimizer.
    :param clip: if not None, every parameter of ``classifier.decoder`` is clamped to
        ``[-clip, clip]`` after the optimizer step.
    :return: ``history``, updated in place.
    """
    classifier.zero_grad()
    classifier.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)
    n_source, n_target = source_x.shape[0], target_x.shape[0]

    logits = torch.cat((classifier(source_x), classifier(target_x)), dim=0)
    domain_labels = torch.cat((torch.zeros(n_source, 1), torch.ones(n_target, 1)), dim=0).to(device)
    bce = loss_fn(logits, domain_labels)

    optimizer.zero_grad()
    bce.backward()
    optimizer.step()

    if clip is not None:
        # Weight clipping on the classifier head only.
        for param in classifier.decoder.parameters():
            param.data.clamp_(-clip, clip)
    if scheduler is not None:
        scheduler.step()

    history['bce'].append(bce.cpu().detach().item())
    return history
70 |
71 |
def customized_ae_train_step(classifier, ae, s_batch, t_batch, loss_fn, alpha, device, optimizer, history,
                             scheduler=None):
    """Run one adversarial autoencoder update against the domain classifier.

    The minimized objective is ``loss_s + loss_t - alpha * loss_fn(logits, domain)``.
    Here ``domain`` is 0 for source rows and 1 for target rows. The autoencoder is
    therefore rewarded for increasing the classifier's domain loss.

    :param classifier: domain classifier producing one logit per row.
    :param ae: autoencoder whose ``loss_function(*ae(x))`` returns a dict with ``'loss'``.
    :param s_batch: source-domain batch; element 0 is the input tensor.
    :param t_batch: target-domain batch; element 0 is the input tensor.
    :param loss_fn: adversarial loss taking ``(logits, targets)``.
    :param alpha: weight of the adversarial term.
    :param device: device the inputs and targets are moved to.
    :param optimizer: optimizer stepped on the objective.
    :param history: mapping name -> list; receives source + target value of each
        autoencoder loss entry (the adversarial term is not recorded).
    :param scheduler: optional LR scheduler, stepped after the optimizer.
    :return: ``history``, updated in place.
    """
    classifier.zero_grad()
    ae.zero_grad()
    # Order matters: if ``classifier`` wraps ``ae``'s encoder, the ``ae.train()`` call
    # switches that shared encoder back to training mode.
    classifier.eval()
    ae.train()

    source_x = s_batch[0].to(device)
    target_x = t_batch[0].to(device)

    domain_logits = torch.cat((classifier(source_x), classifier(target_x)), dim=0)
    domain_labels = torch.cat(
        (torch.zeros(source_x.shape[0], 1), torch.ones(target_x.shape[0], 1)), dim=0
    ).to(device)
    adversarial = loss_fn(domain_logits, domain_labels)

    source_losses = ae.loss_function(*ae(source_x))
    target_losses = ae.loss_function(*ae(target_x))
    objective = source_losses['loss'] + target_losses['loss'] - alpha * adversarial

    optimizer.zero_grad()
    objective.backward()
    optimizer.step()

    if scheduler is not None:
        scheduler.step()

    summed = {}
    for name, source_value in source_losses.items():
        summed[name] = source_value.cpu().detach().item() + target_losses[name].cpu().detach().item()
    for name, value in summed.items():
        history[name].append(value)

    return history
110 |
111 |
def _merge_domain_val_history(history):
    """Sum the two newest entries of every metric (source, then target) into one.

    Every key except ``'best_index'`` has its last two values merged, and the
    surplus entry is popped. ``history`` is modified in place and returned.
    """
    for key in history:
        if key != 'best_index':
            history[key][-2] += history[key][-1]
            history[key].pop()
    return history


def _eval_ae_both_domains(autoencoder, s_test_dataloader, t_test_dataloader, device, history):
    """Evaluate ``autoencoder`` on the source then the target test loader and merge the two entries."""
    history = eval_ae_epoch(model=autoencoder, data_loader=s_test_dataloader, device=device, history=history)
    history = eval_ae_epoch(model=autoencoder, data_loader=t_test_dataloader, device=device, history=history)
    return _merge_domain_val_history(history)


def _eval_adv_classifier(classifier, s_train_dataloader, s_test_dataloader, t_train_dataloader,
                         t_test_dataloader, device, test_history, train_history):
    """Evaluate the domain classifier on the test split, then on the train split.

    Each evaluation pairs source and target loaders of the *same* split.
    """
    test_history = evaluate_adv_classification_epoch(classifier=classifier,
                                                     s_dataloader=s_test_dataloader,
                                                     t_dataloader=t_test_dataloader,
                                                     device=device,
                                                     history=test_history)
    train_history = evaluate_adv_classification_epoch(classifier=classifier,
                                                      s_dataloader=s_train_dataloader,
                                                      t_dataloader=t_train_dataloader,
                                                      device=device,
                                                      history=train_history)
    return test_history, train_history


def train_ae_adv(s_dataloaders, t_dataloaders, **kwargs):
    """Train an autoencoder whose latent code is adversarially de-confounded by domain.

    When ``kwargs['retrain_flag']`` is truthy, training runs in three phases:

    1. Autoencoder pre-training on paired source/target batches (reconstruction only),
       with early stopping on validation ``'loss'`` if ``es_flag`` is set.
    2. Pre-training of the domain classifier head (an ``MLP`` on the shared encoder;
       source = 0, target = 1), with early stopping on test ``'acc'`` if ``es_flag`` is set.
    3. ``train_num_epochs`` epochs. Each epoch makes one pass of adversarial autoencoder
       updates (``recon - alpha * BCE``), then one pass of classifier-head updates,
       with evaluations between them.

    Otherwise the autoencoder weights are loaded from ``<model_save_folder>/ae.pt``.

    :param s_dataloaders: ``(train_loader, test_loader)`` of the source domain.
    :param t_dataloaders: ``(train_loader, test_loader)`` of the target domain.
    :param kwargs: reads ``input_dim``, ``latent_dim``, ``encoder_hidden_dims``,
        ``classifier_hidden_dims``, ``dop``, ``device``, ``lr``, ``alpha``,
        ``pretrain_num_epochs``, ``train_num_epochs``, ``es_flag``, ``retrain_flag``
        and ``model_save_folder``.
    :return: ``(encoder, (ae_train_history, ae_val_history, classification_train_history,
        classification_test_history, classifier_train_history))``.
    :raises Exception: if ``retrain_flag`` is falsy and ``ae.pt`` cannot be found.
    """
    s_train_dataloader, s_test_dataloader = s_dataloaders[0], s_dataloaders[1]
    t_train_dataloader, t_test_dataloader = t_dataloaders[0], t_dataloaders[1]
    device = kwargs['device']

    def _ckpt(file_name):
        # Checkpoint path inside the configured model folder.
        return os.path.join(kwargs['model_save_folder'], file_name)

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop']).to(device)
    # The domain classifier reuses the autoencoder's encoder; only its MLP head is
    # handed to the classifier optimizer.
    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder, decoder=classifier).to(device)

    ae_eval_train_history = defaultdict(list)
    ae_eval_val_history = defaultdict(list)
    classifier_pretrain_history = defaultdict(list)
    classification_eval_test_history = defaultdict(list)
    classification_eval_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        confounded_loss = nn.BCEWithLogitsLoss()
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(), lr=kwargs['lr'])
        classifier_optimizer = torch.optim.AdamW(confounder_classifier.decoder.parameters(), lr=kwargs['lr'])

        # Phase 1: autoencoder pre-training.
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Pre-Training Epoch {epoch} ----')
            for s_batch in s_train_dataloader:
                # A new iterator per step: takes the first batch a fresh pass would yield.
                t_batch = next(iter(t_train_dataloader))
                ae_eval_train_history = ae_train_step(ae=autoencoder,
                                                      s_batch=s_batch,
                                                      t_batch=t_batch,
                                                      device=device,
                                                      optimizer=ae_optimizer,
                                                      history=ae_eval_train_history)
            ae_eval_val_history = _eval_ae_both_domains(autoencoder, s_test_dataloader, t_test_dataloader,
                                                        device, ae_eval_val_history)
            if kwargs['es_flag']:
                save_flag, stop_flag = model_save_check(history=ae_eval_val_history, metric_name='loss',
                                                        tolerance_count=10)
                if save_flag:
                    torch.save(autoencoder.state_dict(), _ckpt('ae.pt'))
                if stop_flag:
                    break

        if kwargs['es_flag']:
            autoencoder.load_state_dict(torch.load(_ckpt('ae.pt')))

        # Phase 2: adversarial classifier pre-training.
        for epoch in range(int(kwargs['pretrain_num_epochs'])):
            if epoch % 50 == 0:
                print(f'Adversarial classifier pre-training epoch {epoch}')
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                classifier_pretrain_history = classification_train_step(classifier=confounder_classifier,
                                                                        s_batch=s_batch,
                                                                        t_batch=t_batch,
                                                                        loss_fn=confounded_loss,
                                                                        device=device,
                                                                        optimizer=classifier_optimizer,
                                                                        history=classifier_pretrain_history)
            classification_eval_test_history, classification_eval_train_history = _eval_adv_classifier(
                confounder_classifier, s_train_dataloader, s_test_dataloader, t_train_dataloader,
                t_test_dataloader, device, classification_eval_test_history, classification_eval_train_history)

            # Called unconditionally, as before, so the history bookkeeping is unchanged.
            save_flag, stop_flag = model_save_check(history=classification_eval_test_history, metric_name='acc',
                                                    tolerance_count=50)
            if kwargs['es_flag']:
                if save_flag:
                    torch.save(confounder_classifier.state_dict(), _ckpt('adv_classifier.pt'))
                if stop_flag:
                    break

        if kwargs['es_flag']:
            confounder_classifier.load_state_dict(torch.load(_ckpt('adv_classifier.pt')))

        # Phase 3: alternating training.
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'Alternative training epoch {epoch}')
            # Autoencoder pass against the current classifier.
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                ae_eval_train_history = customized_ae_train_step(classifier=confounder_classifier,
                                                                 ae=autoencoder,
                                                                 s_batch=s_batch,
                                                                 t_batch=t_batch,
                                                                 loss_fn=confounded_loss,
                                                                 alpha=kwargs['alpha'],
                                                                 device=device,
                                                                 optimizer=ae_optimizer,
                                                                 history=ae_eval_train_history,
                                                                 scheduler=None)
            ae_eval_val_history = _eval_ae_both_domains(autoencoder, s_test_dataloader, t_test_dataloader,
                                                        device, ae_eval_val_history)
            classification_eval_test_history, classification_eval_train_history = _eval_adv_classifier(
                confounder_classifier, s_train_dataloader, s_test_dataloader, t_train_dataloader,
                t_test_dataloader, device, classification_eval_test_history, classification_eval_train_history)

            # Classifier-head pass against the updated encoder.
            for s_batch in s_train_dataloader:
                t_batch = next(iter(t_train_dataloader))
                classifier_pretrain_history = classification_train_step(classifier=confounder_classifier,
                                                                        s_batch=s_batch,
                                                                        t_batch=t_batch,
                                                                        loss_fn=confounded_loss,
                                                                        device=device,
                                                                        optimizer=classifier_optimizer,
                                                                        history=classifier_pretrain_history)
            # Fixed: the train-split evaluation previously paired s_train with t_test.
            classification_eval_test_history, classification_eval_train_history = _eval_adv_classifier(
                confounder_classifier, s_train_dataloader, s_test_dataloader, t_train_dataloader,
                t_test_dataloader, device, classification_eval_test_history, classification_eval_train_history)

        torch.save(autoencoder.state_dict(), _ckpt('ae.pt'))

    else:
        try:
            autoencoder.load_state_dict(torch.load(_ckpt('ae.pt')))
        except FileNotFoundError as err:
            raise Exception("No pre-trained encoder") from err

    return autoencoder.encoder, (ae_eval_train_history,
                                 ae_eval_val_history,
                                 classification_eval_train_history,
                                 classification_eval_test_history,
                                 classifier_pretrain_history)
304 |
--------------------------------------------------------------------------------
/code/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from sklearn import preprocessing
8 | from sklearn.model_selection import train_test_split, StratifiedKFold
9 | from torch.utils.data import TensorDataset, DataLoader,Dataset
10 | from rdkit import RDLogger
11 | import math
12 |
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch, plus every CUDA device when one is present."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.device_count():
        torch.cuda.manual_seed_all(seed)
19 |
def get_unlabeled_dataloaders(gex_features_df, seed, batch_size, ccle_only=False):
    """Build unlabeled pre-training loaders for cell lines (source) and patients (target).

    Cancer cell lines form the source domain (``s_dataloaders``). Patients listed in
    ``xena_sample_info.csv`` form the target domain (``t_dataloaders``). For each domain,
    10% of the samples are held out as a test split. The cell-line split is stratified
    by ``primary_disease``; cell lines whose disease has fewer than two members are
    moved straight to the test split.

    NOTE: the first loader of each returned pair iterates over the *whole* domain
    (held-out samples included); only the second loader is restricted to the held-out split.

    :param gex_features_df: expression matrix whose rows are indexed by sample id.
    :param seed: passed to ``set_seed`` and used as ``random_state`` of the patient split.
    :param batch_size: batch size of every loader.
    :param ccle_only: if True, the cell-line loaders are returned for both domains.
    :return: ``((ccle_loader, ccle_test_loader), (xena_loader, xena_test_loader))``.
    """
    set_seed(seed)

    ccle_sample_info_df = pd.read_csv('../data/supple_info/sample_info/ccle_sample_info.csv', index_col=0)
    ccle_sample_info_df = ccle_sample_info_df.reset_index().drop_duplicates(
        subset="Depmap_id", keep='first').set_index("Depmap_id")
    xena_sample_info_df = pd.read_csv('../data/supple_info/sample_info/xena_sample_info.csv', index_col=0)

    xena_samples = xena_sample_info_df.index.intersection(gex_features_df.index)
    ccle_samples = ccle_sample_info_df.index.intersection(gex_features_df.index)
    ccle_sample_info_df = ccle_sample_info_df.loc[ccle_samples]

    xena_df = gex_features_df.loc[xena_samples]
    ccle_df = gex_features_df.loc[ccle_samples]

    # Cell lines that cannot be stratified (no sample info, or a disease with < 2
    # cell lines) are kept out of the split and appended to the test set.
    disease_counts = ccle_sample_info_df.primary_disease.value_counts()
    rare_diseases = disease_counts[disease_counts < 2].index
    excluded_ccle_samples = list(ccle_df.index.difference(ccle_sample_info_df.index))
    excluded_ccle_samples.extend(
        ccle_sample_info_df[ccle_sample_info_df.primary_disease.isin(rare_diseases)].index)

    to_split_ccle_df = ccle_df[~ccle_df.index.isin(excluded_ccle_samples)]
    _, test_ccle_df = train_test_split(to_split_ccle_df, test_size=0.1,
                                       stratify=ccle_sample_info_df.loc[
                                           to_split_ccle_df.index].primary_disease)
    # DataFrame.append was removed in pandas 2.0; pd.concat is its drop-in equivalent.
    test_ccle_df = pd.concat([test_ccle_df, ccle_df.loc[excluded_ccle_samples]])
    _, test_xena_df = train_test_split(xena_df, test_size=0.1, random_state=seed)

    print(' ')
    print(f"Pretrain dataset: {xena_df.shape[0]}(TCGA) {to_split_ccle_df.shape[0]}(Cell line)")
    print(' ')

    def _loader(df, **extra):
        # Shuffled float32 TensorDataset loader over the rows of ``df``.
        return DataLoader(TensorDataset(torch.from_numpy(df.values.astype('float32'))),
                          batch_size=batch_size,
                          shuffle=True,
                          **extra)

    ccle_data_loader = _loader(ccle_df, drop_last=True)
    test_ccle_dataloader = _loader(test_ccle_df)
    xena_dataloader = _loader(xena_df)
    test_xena_dataloader = _loader(test_xena_df)

    if ccle_only:
        return (ccle_data_loader, test_ccle_dataloader), (ccle_data_loader, test_ccle_dataloader)
    return (ccle_data_loader, test_ccle_dataloader), (xena_dataloader, test_xena_dataloader)
101 |
102 |
103 |
def get_finetune_dataloader_generator(gex_features_df, label_type, sample_size=0.006, dataset='gdsc1_raw', seed=2020,
                                      batch_size=64, ccle_measurement='AUC',
                                      n_splits=5, q=2,
                                      tumor_type="TCGA",
                                      select_drug_method=True):
    """Yield cell-line fine-tuning folds together with the zero-shot patient loader.

    Label convention: sensitive (responder) = 1, resistant (non-responder) = 0.

    :param gex_features_df: expression matrix indexed by sample id, used for both the
        patient (``'TCGA'``-prefixed ids) and the cell-line lookups.
    :param label_type: patient label source, ``"PFS"`` or ``"Imaging"``.
    :param sample_size: fraction of cell-line pairs sampled when ``select_drug_method == "random"``.
    :param dataset: name of the ``../data/finetune_data/<dataset>.csv`` cell-line table.
    :param seed: seed of the cell-line K-fold split.
    :param batch_size: batch size of every loader.
    :param ccle_measurement: forwarded as ``measurement`` (not used by the cell-line generator).
    :param n_splits: forwarded to the cell-line K-fold generator.
    :param q: number of quantile bins for ``"PFS"`` patient labels.
    :param tumor_type: ``"TCGA"`` for all patients, otherwise a single TCGA project.
    :param select_drug_method: ``"all"``, ``"overlap"`` (drugs also given to patients) or
        ``"random"``. NOTE: the default ``True`` is none of these and makes the
        cell-line generator fail.
    :yield: ``(train_ccl_loader, test_ccl_loader, patient_loader)`` for each fold; the same
        patient loader object is yielded with every fold.
    """
    RDLogger.DisableLog('rdApp.*')

    # Patient loader rows: gene expression, drug embedding, label.
    test_labeled_dataloaders, cid_list = get_tcga_ZS_dataloaders(gex_features_df=gex_features_df,
                                                                 label_type=label_type,
                                                                 batch_size=batch_size,
                                                                 # was hard-coded to 2, ignoring ``q``
                                                                 q=q,
                                                                 tumor_type=tumor_type)

    ccle_labeled_dataloader_generator = get_ccl_labeled_dataloader_generator(gex_features_df=gex_features_df,
                                                                             tumor_type=tumor_type,
                                                                             seed=seed,
                                                                             sample_size=sample_size,
                                                                             dataset=dataset,
                                                                             batch_size=batch_size,
                                                                             measurement=ccle_measurement,
                                                                             n_splits=n_splits,
                                                                             q=q,
                                                                             cid_list=cid_list,
                                                                             select_drug_method=select_drug_method)

    for train_labeled_ccle_dataloader, test_labeled_ccle_dataloader in ccle_labeled_dataloader_generator:
        yield train_labeled_ccle_dataloader, test_labeled_ccle_dataloader, test_labeled_dataloaders
135 |
def _ccl_labeled_dataset(rows):
    """Split ``[label, gex..., drug(300)]`` rows into a ``(gex, drug, label)`` TensorDataset."""
    n_cols = rows.shape[1]
    return TensorDataset(
        torch.from_numpy(rows[:, 1:n_cols - 300].astype('float32')),  # gene expression
        torch.from_numpy(rows[:, n_cols - 300:n_cols].astype('float32')),  # drug embedding (last 300)
        torch.from_numpy(rows[:, 0]),  # label (column 0, unconverted)
    )


def get_ccl_labeled_dataloader_generator(gex_features_df, tumor_type, cid_list, select_drug_method, sample_size,
                                         dataset='gdsc_raw', batch_size=64, seed=2020,
                                         measurement='AUC', n_splits=5, q=2):
    """Yield stratified K-fold ``(train_loader, test_loader)`` pairs of labeled cell-line/drug rows.

    Rows of ``../data/finetune_data/<dataset>.csv`` are restricted to cell lines listed in
    ``ccle_sample_with_gex.csv`` and present in ``gex_features_df``. They are joined with
    the cell-line expression and the cell-line drug embeddings. Each loader yields
    ``(gex, drug_embedding, label)``.

    :param gex_features_df: expression matrix indexed by cell-line id.
    :param tumor_type: unused (kept for interface compatibility).
    :param cid_list: drug CIDs of the patient set; used when ``select_drug_method == "overlap"``.
    :param select_drug_method: ``"all"`` (every pair), ``"overlap"`` (drugs in ``cid_list``) or
        ``"random"`` (a ``sample_size`` fraction of pairs).
    :param sample_size: fraction of pairs kept for ``"random"``.
    :param dataset: file stem of the fine-tuning table.
    :param batch_size: batch size of both loaders.
    :param seed: ``random_state`` of the K-fold split.
    :param measurement: unused (kept for interface compatibility).
    :param n_splits: number of folds (previously ignored and fixed at 5).
    :param q: unused.
    :raises ValueError: for an unsupported ``select_drug_method``.
    """
    if select_drug_method not in ("all", "overlap", "random"):
        raise ValueError("select_drug_method must be 'all', 'overlap' or 'random', "
                         f"got {select_drug_method!r}")

    dataset_path = '../data/finetune_data/{}.csv'.format(dataset)
    sensitivity_df = pd.read_csv(dataset_path, index_col=1)

    ccle_sample_with_gex = pd.read_csv("../data/supple_info/ccle_sample_with_gex.csv", index_col=0)
    sensitivity_df = sensitivity_df.loc[sensitivity_df.index.isin(ccle_sample_with_gex.index)]

    if select_drug_method == "all":
        target_df = sensitivity_df
    elif select_drug_method == "overlap":
        target_df = sensitivity_df.loc[sensitivity_df['cid'].isin(cid_list)]
        print("Drug num: target(TCGA) / source(GDSC) / overlap = {0} {1} / {2} / {3} {4}".format(
            pd.Series(cid_list).unique(), pd.Series(cid_list).nunique(),
            sensitivity_df['Drug_name'].nunique(),
            target_df['Drug_name'].unique(), target_df['Drug_name'].nunique()
        ))
    else:  # "random": sample a fraction of the pairs
        target_df = sensitivity_df.sample(n=round(sample_size * sensitivity_df.shape[0]))
    print(' ')

    print("Select {0} dataset {1} / {2}".format(dataset,
                                                target_df.shape[0],
                                                sensitivity_df.shape[0]
                                                ))
    print(' ')

    ccl_gex_df = gex_features_df
    keep_samples = target_df.index.isin(ccl_gex_df.index)

    ccle_labeled_feature_df = target_df.loc[keep_samples, ["Drug_smile", "label"]].dropna()
    ccle_labeled_feature_df = ccle_labeled_feature_df.merge(ccl_gex_df, left_index=True, right_index=True)

    drug_emb = pd.read_csv("../data/supple_info/drug_embedding/drug_embedding_for_cell_line.csv", index_col=0)
    # Resulting column layout: label, gene expression..., drug embedding (last 300 columns).
    ccle_labeled_feature_df = ccle_labeled_feature_df.merge(drug_emb, left_on="Drug_smile", right_on="Drug_smile")
    del ccle_labeled_feature_df['Drug_smile']
    ccle_labeled_feature_df = ccle_labeled_feature_df.dropna()

    ccle_labels = ccle_labeled_feature_df['label'].values
    # Labels strictly inside (0, 1) are binarized at the median for stratification.
    # NOTE(review): the dataset tensors still carry the raw column-0 label, not these binarized values.
    if max(ccle_labels) < 1 and min(ccle_labels) > 0:
        ccle_labels = (ccle_labels < np.median(ccle_labels)).astype('int')

    feature_values = ccle_labeled_feature_df.values
    s_kfold = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=True)
    for train_index, test_index in s_kfold.split(feature_values, ccle_labels):
        train_labeled_ccle_dataloader = DataLoader(_ccl_labeled_dataset(feature_values[train_index]),
                                                   batch_size=batch_size,
                                                   shuffle=True)
        test_labeled_ccle_dataloader = DataLoader(_ccl_labeled_dataset(feature_values[test_index]),
                                                  batch_size=batch_size,
                                                  shuffle=True)
        yield train_labeled_ccle_dataloader, test_labeled_ccle_dataloader
244 |
245 |
def get_tcga_ZS_dataloaders(gex_features_df, batch_size, label_type="PFS", q=2, tumor_type="TCGA"):
    """Build the zero-shot patient/drug loader and the drug CID of each of its rows.

    Rows of ``gex_features_df`` whose index starts with ``'TCGA'`` are re-keyed to the
    first 12 characters of the id (duplicates averaged). They are then joined with the
    single-drug response table and with the patient drug embeddings.

    :param gex_features_df: pd.DataFrame, index: sample id, columns: gene features.
    :param batch_size: loader batch size.
    :param label_type: ``"PFS"`` makes the label the quantile bin (``0..q-1``) of
        ``days_to_new_tumor_event_after_initial_treatment``. ``"Imaging"`` makes it 1 for
        ``'Complete Response'`` and 0 otherwise.
    :param q: number of quantile bins for ``"PFS"`` labels.
    :param tumor_type: ``"TCGA"`` keeps every project; otherwise only rows whose
        ``tcga_project`` equals this value are kept.
    :return: ``(loader, cid_list)``. The loader is not shuffled and yields
        ``(gex, drug_embedding, label)``, where the drug embedding is the 300 columns
        before the label. ``cid_list[i]`` is the drug CID of row ``i``.
    :raises ValueError: for an unsupported ``label_type`` (checked before any file is read).
    """
    if label_type == "PFS":
        label_col = 'days_to_new_tumor_event_after_initial_treatment'
        response_path = '../data/zeroshot_data/tcga698_single_drug_response_df.csv'
    elif label_type == "Imaging":
        label_col = 'treatment_best_response'
        response_path = '../data/zeroshot_data/tcga667_single_drug_response_type_df.csv'
    else:
        raise ValueError(f"label_type must be 'PFS' or 'Imaging', got {label_type!r}")

    tcga_gex_feature_df = gex_features_df.loc[gex_features_df.index.str.startswith('TCGA')]
    tcga_gex_feature_df = tcga_gex_feature_df.rename(index=lambda x: x[:12]).groupby(level=0).mean()

    response_df = pd.read_csv(response_path, index_col=1)
    if tumor_type != "TCGA":
        response_df = response_df.loc[response_df['tcga_project'] == tumor_type]

    tcga_drug_gex = response_df[[label_col, 'smiles', 'cid']].merge(
        tcga_gex_feature_df, left_index=True, right_index=True)

    # Attach the drug embedding of each patient's drug.
    drug_emb = pd.read_csv("../data/supple_info/drug_embedding/drug_embedding_for_patient.csv", index_col=0)
    tcga_drug_gex = tcga_drug_gex.merge(drug_emb, left_on="smiles", right_on="Drug_smile")

    cid_list = tcga_drug_gex['cid'].values
    tcga_drug_gex = tcga_drug_gex.drop(columns=['smiles', 'Drug_smile', 'cid'])

    if label_type == "PFS":
        drug_label = pd.qcut(tcga_drug_gex[label_col], q, labels=range(0, q))
    else:
        drug_label = np.array(tcga_drug_gex[label_col].apply(
            lambda s: s in ['Complete Response']), dtype='int32')
    del tcga_drug_gex[label_col]

    # Column layout from here on: gene expression..., drug embedding (300), label.
    tcga_drug_gex['label'] = drug_label
    tcga_drug_gex = tcga_drug_gex.values
    print("Zero-shot to {0} num: {1}".format(tumor_type, tcga_drug_gex.shape[0]))

    n_cols = tcga_drug_gex.shape[1]
    labeled_tcga_dateset = TensorDataset(
        torch.from_numpy(tcga_drug_gex[:, 0:n_cols - 301].astype('float32')),
        torch.from_numpy(tcga_drug_gex[:, n_cols - 301:n_cols - 1].astype('float32')),
        torch.from_numpy(tcga_drug_gex[:, n_cols - 1])
    )

    # Not shuffled, so batch order matches ``cid_list``.
    labeled_tcga_dataloader = DataLoader(labeled_tcga_dateset,
                                         batch_size=batch_size,
                                         shuffle=False)

    return labeled_tcga_dataloader, cid_list
319 |
320 |
321 | # Generating Zero-shot dataset for specific tumor type
def get_all_tcga_ZS_dataloaders(all_patient_gex, batch_size, Tumor_type_list, label_type="PFS", q=2):
    """Yield one zero-shot patient loader per tumor type in ``Tumor_type_list``.

    Patients are assigned to a tumor type through the ``tumor_type`` column of
    ``xena_sample_info.csv``. The file is read lazily, on the first ``next()``.

    :param all_patient_gex: expression matrix of all patients, indexed by sample id.
    :param batch_size: loader batch size.
    :param Tumor_type_list: iterable of tumor-type names.
    :param label_type: forwarded to ``get_tcga_ZS_dataloaders``.
    :param q: forwarded to ``get_tcga_ZS_dataloaders``.
    :yield: ``(loader, tumor_type)``; the CID list of each loader is discarded.
    """
    # The legacy ``preprocessed_dat`` location is tried first. If it does not exist,
    # fall back to the location used by the rest of this module.
    try:
        patient_sample_info = pd.read_csv("../data/preprocessed_dat/xena_sample_info.csv", index_col=0)
    except FileNotFoundError:
        patient_sample_info = pd.read_csv("../data/supple_info/sample_info/xena_sample_info.csv", index_col=0)

    for tumor_type in Tumor_type_list:
        print(f'Generating Zero-shot dataset: {tumor_type}')
        type_samples = patient_sample_info.loc[patient_sample_info.tumor_type == tumor_type].index
        gex_features_df = all_patient_gex.loc[all_patient_gex.index.intersection(type_samples)]
        tt_zs_dataloader, _ = get_tcga_ZS_dataloaders(gex_features_df,
                                                      batch_size, label_type=label_type, q=q,
                                                      tumor_type=tumor_type)
        yield tt_zs_dataloader, tumor_type
333 |
# Generating PDR dataset for specific tumor type
def get_pdr_data_dataloaders(gex_features_df, batch_size, tumor_type="TCGA",
                             sample_info_path="../data/supple_info/sample_info/xena_sample_info.csv",
                             drug_emb_path="../data/supple_info/drug_embedding/drug_embedding_for_cell_line.csv"):
    """Build an unshuffled DataLoader pairing every selected TCGA patient with every drug.

    Steps:
      1. Keep samples whose id starts with ``'TCGA'``.
      2. Truncate ids to their first 12 characters and average rows that share a
         truncated id.
      3. If ``tumor_type`` is not ``"TCGA"``, keep only ids listed with that
         ``tumor_type`` in the sample-info CSV.
      4. Cross-join the remaining patients with every row of the drug-embedding CSV.

    Args:
        gex_features_df: DataFrame of expression features indexed by sample id (strings).
        batch_size: batch size of the returned DataLoader.
        tumor_type: tumor type to select, or ``"TCGA"`` to keep all TCGA patients.
        sample_info_path: CSV indexed by patient id with a ``tumor_type`` column. It is
            only read when ``tumor_type != "TCGA"``.
        drug_emb_path: CSV whose ``drug_embedding*`` columns hold the drug vectors.

    Returns:
        ``(dataloader, sample_id)``.
        - ``dataloader`` yields ``(gex, drug_embedding)`` float32 batches in merge order
          (not shuffled). ``gex`` is every non-``drug_embedding*`` column.
        - ``sample_id`` is the index of the selected patients, with one entry per
          patient, not one per patient-drug pair.
    """
    # TCGA samples only; replicate samples of one patient (same 12-char prefix) are averaged.
    tcga_gex_feature_df = gex_features_df.loc[gex_features_df.index.str.startswith('TCGA')].copy()
    tcga_gex_feature_df.index = tcga_gex_feature_df.index.map(lambda x: x[:12])
    tcga_gex_feature_df = tcga_gex_feature_df.groupby(level=0).mean()

    if tumor_type != "TCGA":
        patient_sample_info = pd.read_csv(sample_info_path, index_col=0)
        tumor_patients = patient_sample_info.loc[patient_sample_info['tumor_type'] == tumor_type].index
        select_index = tumor_patients.intersection(tcga_gex_feature_df.index)
        tcga_gex_feature_df = tcga_gex_feature_df.loc[select_index]

    drug_emb = pd.read_csv(drug_emb_path, index_col=0)
    # A constant join key on both sides turns the merge below into a patient x drug cross join.
    drug_emb['Drug_smile'] = "a"

    sample_id = tcga_gex_feature_df.index
    # Name the index explicitly so reset_index always yields an "index" column, even when
    # the incoming index already carries a name (set_index("index") would otherwise fail).
    # reset_index also returns a fresh frame, so the assignment below never hits a slice.
    tcga_gex_feature_df = tcga_gex_feature_df.rename_axis("index").reset_index()
    tcga_gex_feature_df['Drug_smile'] = "a"

    tcga_drug_gex = tcga_gex_feature_df.merge(drug_emb, on="Drug_smile")  # index, gex..., Drug_smile, drug...
    tcga_drug_gex.set_index("index", inplace=True)
    print(f"{tumor_type} Sample_num: {tcga_gex_feature_df.shape[0]}; PDR_test_num: {tcga_drug_gex.shape[0]}")

    del tcga_drug_gex['Drug_smile']

    is_drug_col = tcga_drug_gex.columns.str.startswith('drug_embedding')
    pdr_data_dateset = TensorDataset(
        torch.from_numpy(tcga_drug_gex.loc[:, ~is_drug_col].values.astype('float32')),  # gex features
        torch.from_numpy(tcga_drug_gex.loc[:, is_drug_col].values.astype('float32')),  # drug embedding
    )

    pdr_data_dataloader = DataLoader(pdr_data_dateset,
                                     batch_size=batch_size,
                                     shuffle=False)  # not shuffled, so rows keep their merge order

    return (pdr_data_dataloader, sample_id)
375 |
376 |
def _repeat_rows(frame, times):
    """Return ``frame`` with each row repeated ``times`` times consecutively, on a fresh RangeIndex.

    Rows are selected by position instead of being rebuilt from ``frame.values``, which
    keeps every column's original dtype.
    """
    return frame.iloc[np.repeat(np.arange(len(frame)), times)].reset_index(drop=True)


def Rebalance_dataset_for_every_drug(Original_dataset):
    """Oversample the minority label of every drug so labels 0 and 1 are roughly balanced.

    Drugs are processed in order of first appearance of ``Drug_name``. For each drug, the
    rows of the smaller of the ``label == 0`` / ``label == 1`` groups are repeated
    ``floor(n_majority / n_minority)`` times (ties repeat once).

    Drugs with rows of only one of the two labels are not rebalanced. Rows whose label is
    neither 0 nor 1 are dropped.

    Args:
        Original_dataset: DataFrame with at least ``Drug_name`` and ``label`` columns.

    Returns:
        For each drug, its label-0 rows followed by its label-1 rows.
        - Repeated groups get a fresh RangeIndex; other groups keep their original index.
        - Column dtypes are preserved.
        - An empty input yields an empty frame with the same columns.
    """
    parts = []
    # sort=False iterates drugs in order of first appearance, like drop_duplicates().
    for _drug, drug_data in Original_dataset.groupby('Drug_name', sort=False):
        drug_data_0 = drug_data.loc[drug_data.label == 0]
        drug_data_1 = drug_data.loc[drug_data.label == 1]
        n0, n1 = len(drug_data_0), len(drug_data_1)
        # Skip rebalancing when a class is missing: there is nothing to replicate and a
        # class ratio would divide by zero.
        if n0 and n1:
            if n1 > n0:
                drug_data_0 = _repeat_rows(drug_data_0, n1 // n0)
            else:
                drug_data_1 = _repeat_rows(drug_data_1, n0 // n1)
        parts.extend(part for part in (drug_data_0, drug_data_1) if len(part))

    # Concatenate once instead of growing a frame inside the loop.
    all_drug = pd.concat(parts) if parts else Original_dataset.iloc[:0].copy()
    print(all_drug.iloc[1:5])  # debug preview of a few rebalanced rows
    # all_drug.groupby('Drug_name')['label'].value_counts()

    return all_drug
--------------------------------------------------------------------------------